"""
Module defining the Policy for ADAP
"""
# pylint: disable=locally-disabled, not-callable
from typing import Any, Dict, Optional, Type, Union, List, Tuple
import torch
import gymnasium as gym
from torch import nn
from stable_baselines3.common.utils import get_device
from stable_baselines3.common.policies import ActorCriticPolicy, BasePolicy
from stable_baselines3.common.type_aliases import Schedule
from stable_baselines3.common.torch_layers import (
BaseFeaturesExtractor,
FlattenExtractor,
MlpExtractor,
)
class AdapPolicy(ActorCriticPolicy):
"""
Base Policy for the ADAP Actor-critic policy
"""
def __init__(
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
lr_schedule: Schedule,
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
activation_fn: Type[nn.Module] = nn.Tanh,
ortho_init: bool = True,
use_sde: bool = False,
log_std_init: float = 0.0,
full_std: bool = True,
use_expln: bool = False,
squash_output: bool = False,
features_extractor_class: Type[
BaseFeaturesExtractor
] = FlattenExtractor,
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
share_features_extractor: bool = True,
normalize_images: bool = True,
optimizer_class: Type[torch.optim.Optimizer] = torch.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
context_size: int = 3,
):
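        # ``context`` is a latent vector of length ``context_size`` that is
        # set externally via ``set_context`` and concatenated to the extracted
        # features before the policy and value networks.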
self.context_size = context_size
self.context = None
self.mlp_extractor = None
super().__init__(
observation_space=observation_space,
action_space=action_space,
lr_schedule=lr_schedule,
net_arch=net_arch,
activation_fn=activation_fn,
ortho_init=ortho_init,
use_sde=use_sde,
log_std_init=log_std_init,
full_std=full_std,
use_expln=use_expln,
squash_output=squash_output,
features_extractor_class=features_extractor_class,
features_extractor_kwargs=features_extractor_kwargs,
share_features_extractor=share_features_extractor,
normalize_images=normalize_images,
optimizer_class=optimizer_class,
optimizer_kwargs=optimizer_kwargs,
)
    def set_context(self, ctxt: torch.Tensor) -> None:
"""Set the context"""
self.context = ctxt
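    # Illustrative call (hypothetical values; the ADAP training loop is
    # expected to supply the real context):
    #     policy.set_context(torch.zeros(1, policy.context_size))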
    def get_context(self) -> Optional[torch.Tensor]:
"""Get the current context"""
return self.context
def _build_mlp_extractor(self) -> None:
"""
Create the policy and value networks.
Part of the layers can be shared.
"""
# Note: If net_arch is None and some features extractor is used,
# net_arch here is an empty list and mlp_extractor does not
# really contain any layers (acts like an identity module).
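        # The input width is ``features_dim + context_size`` because the
        # context vector is concatenated to the extracted features.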
self.mlp_extractor = MlpExtractor(
self.features_dim + self.context_size,
net_arch=self.net_arch,
activation_fn=self.activation_fn,
device=self.device,
)
def _get_latent(
self, obs: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Get the latent code (activations of the last layer of each network)
for the different networks.
:param obs: Observation
        :return: Latent codes for the actor and the value function
"""
# Preprocess the observation if needed
features = self.extract_features(obs)
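        # Append the stored context (set via ``set_context``) to every row of
        # the feature batch.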
features = torch.cat(
(features, self.context.repeat(features.size()[0], 1)), dim=1
)
latent_pi, latent_vf = self.mlp_extractor(features)
return latent_pi, latent_vf
def forward(
self, obs: torch.Tensor, deterministic: bool = False
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Forward pass in all the networks (actor and critic)
:param obs: Observation
:param deterministic: Whether to sample or use deterministic actions
:return: action, value and log probability of the action
"""
        # The ADAP context occupies the last ``context_size`` entries of each
        # observation; keep a single row, since all samples in the batch are
        # assumed to share the same context.
latents = obs[..., -self.context_size :].reshape(
-1, self.context_size
)[0]
obs = obs[..., : -self.context_size].reshape(
-1, obs.size(dim=-1) - self.context_size
)
        features = self.extract_features(obs)
        if self.share_features_extractor:
            latents = latents.to(features.device, features.dtype)
            features = torch.cat(
                (features, latents.repeat(features.size()[0], 1)), dim=1
            )
            latent_pi, latent_vf = self.mlp_extractor(features)
        else:
            # With separate extractors, ``extract_features`` returns one
            # feature tensor per network; append the context to each.
            pi_features, vf_features = features
            latents = latents.to(pi_features.device, pi_features.dtype)
            ctx = latents.repeat(pi_features.size()[0], 1)
            latent_pi = self.mlp_extractor.forward_actor(
                torch.cat((pi_features, ctx), dim=1)
            )
            latent_vf = self.mlp_extractor.forward_critic(
                torch.cat((vf_features, ctx), dim=1)
            )
# Evaluate the values for the given observations
values = self.value_net(latent_vf)
distribution = self._get_action_dist_from_latent(latent_pi)
actions = distribution.get_actions(deterministic=deterministic)
log_prob = distribution.log_prob(actions)
actions = actions.reshape((-1, *self.action_space.shape))
return actions, values, log_prob
def predict_values(self, obs: torch.Tensor) -> torch.Tensor:
"""
Get the estimated values according to the current policy given the observations.
:param obs: Observation
:return: the estimated values.
"""
latents = obs[..., -self.context_size :].reshape(
-1, self.context_size
)[0]
obs = obs[..., : -self.context_size].reshape(
-1, obs.size(dim=-1) - self.context_size
)
features = super(BasePolicy, self).extract_features(
obs, self.vf_features_extractor
)
latents = latents.to(features.device, features.dtype)
features = torch.cat(
(features, latents.repeat(features.size()[0], 1)), dim=1
)
latent_vf = self.mlp_extractor.forward_critic(features)
return self.value_net(latent_vf)
def evaluate_actions(
self, obs: torch.Tensor, actions: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""
Evaluate actions according to the current policy,
given the observations.
:param obs: Observation
:param actions: Actions
:return: estimated value, log likelihood of taking those actions
and entropy of the action distribution.
"""
        # Split off the ADAP context exactly as in ``forward``.
latents = obs[..., -self.context_size :].reshape(
-1, self.context_size
)[0]
obs = obs[..., : -self.context_size].reshape(
-1, obs.size(dim=-1) - self.context_size
)
print("NEW OBS", obs)
features = self.extract_features(obs)
latents = latents.to(features.device, features.dtype)
print(features.shape, latents.shape)
features = torch.cat(
(features, latents.repeat(features.size()[0], 1)), dim=1
)
print(features.shape)
if self.share_features_extractor:
latent_pi, latent_vf = self.mlp_extractor(features)
else:
pi_features, vf_features = features
latent_pi = self.mlp_extractor.forward_actor(pi_features)
latent_vf = self.mlp_extractor.forward_critic(vf_features)
distribution = self._get_action_dist_from_latent(latent_pi)
log_prob = distribution.log_prob(actions)
values = self.value_net(latent_vf)
entropy = distribution.entropy()
return values, log_prob, entropy
class MultModel(nn.Module):
"""Neural Network representing multiplicative layers"""
def __init__(
self, feature_dim, net_arch, activation_fn, device, context_size
):
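        # ``feature_dim`` is the width of the observation features only; the
        # context arrives concatenated at the end of the feature tensor and is
        # split off again in ``forward``.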
super().__init__()
self.context_size = context_size
device = get_device(device)
policy_net: List[nn.Module] = []
value_net: List[nn.Module] = []
last_layer_dim_pi = feature_dim
last_layer_dim_vf = feature_dim
# save dimensions of layers in policy and value nets
if isinstance(net_arch, dict):
            # Note: if key is not specified, assume linear network
pi_layers_dims = net_arch.get(
"pi", []
) # Layer sizes of the policy network
vf_layers_dims = net_arch.get(
"vf", []
) # Layer sizes of the value network
else:
pi_layers_dims = vf_layers_dims = net_arch
# Iterate through the policy layers and build the policy net
for curr_layer_dim in pi_layers_dims:
policy_net.append(nn.Linear(last_layer_dim_pi, curr_layer_dim))
policy_net.append(activation_fn())
last_layer_dim_pi = curr_layer_dim
# Iterate through the value layers and build the value net
for curr_layer_dim in vf_layers_dims:
value_net.append(nn.Linear(last_layer_dim_vf, curr_layer_dim))
value_net.append(activation_fn())
last_layer_dim_vf = curr_layer_dim
# Save dim, used to create the distributions
self.latent_dim_pi = last_layer_dim_pi
self.latent_dim_vf = last_layer_dim_vf
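        # Each branch is split into: the first linear layer + activation, a
        # scaling head that expands the hidden state to one copy per context
        # channel, and the remaining layers of the branch.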
self.hidden_dim1 = policy_net[0].out_features
self.agent_branch_1 = nn.Sequential(*policy_net[0:2]).to(device)
self.agent_scaling = nn.Sequential(
nn.Linear(self.hidden_dim1, self.hidden_dim1 * self.context_size),
activation_fn(),
).to(device)
self.agent_branch_2 = nn.Sequential(*policy_net[2:]).to(device)
self.hidden_dim2 = value_net[0].out_features
self.value_branch_1 = nn.Sequential(*value_net[0:2]).to(device)
self.value_scaling = nn.Sequential(
nn.Linear(self.hidden_dim2, self.hidden_dim2 * self.context_size),
activation_fn(),
).to(device)
self.value_branch_2 = nn.Sequential(*value_net[2:]).to(device)
def policies(
self, observations: torch.Tensor, contexts: torch.Tensor
) -> torch.Tensor:
"""Returns the logits from the policy function"""
batch_size = observations.shape[0]
x = self.agent_branch_1(observations)
x_a = self.agent_scaling(x)
# reshape to do context multiplication
x_a = x_a.view((batch_size, self.hidden_dim1, self.context_size))
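        # (batch, hidden, context) @ (batch, context, 1) -> (batch, hidden, 1),
        # squeezed back to (batch, hidden).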
x_a_out = torch.matmul(x_a, contexts.unsqueeze(-1)).squeeze(-1)
logits = self.agent_branch_2(x + x_a_out)
return logits
def values(
self, observations: torch.Tensor, contexts: torch.Tensor
) -> torch.Tensor:
"""Returns the response from the value function"""
batch_size = observations.shape[0]
x = self.value_branch_1(observations)
x_a = self.value_scaling(x)
# reshape to do context multiplication
x_a = x_a.view((batch_size, self.hidden_dim2, self.context_size))
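        # Same multiplicative gating as in ``policies``, applied to the value
        # branch.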
x_a_out = torch.matmul(x_a, contexts.unsqueeze(-1)).squeeze(-1)
values = self.value_branch_2(x + x_a_out)
# values = self.value_branch_2(x_a_out)
return values
def forward(
self, features: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Returns the action logits and values"""
observations = features[:, : -self.context_size]
contexts = features[:, -self.context_size :]
return self.policies(observations, contexts), self.values(
observations, contexts
)
def forward_actor(self, features: torch.Tensor) -> torch.Tensor:
"""Returns the action logits and values"""
observations = features[:, : -self.context_size]
contexts = features[:, -self.context_size :]
return self.policies(observations, contexts)
def forward_critic(self, features: torch.Tensor) -> torch.Tensor:
"""Returns the action logits and values"""
observations = features[:, : -self.context_size]
contexts = features[:, -self.context_size :]
return self.values(observations, contexts)
class AdapPolicyMult(AdapPolicy):
"""
Multiplicative Policy for the ADAP Actor-critic policy
"""
def _build_mlp_extractor(self) -> None:
"""
Create the policy and value networks.
Part of the layers can be shared.
"""
# Note: If net_arch is None and some features extractor is used,
# net_arch here is an empty list and mlp_extractor does not
# really contain any layers (acts like an identity module).
self.mlp_extractor = MultModel(
self.features_dim,
net_arch=self.net_arch,
activation_fn=self.activation_fn,
device=self.device,
context_size=self.context_size,
)
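
# Illustrative usage (a sketch, not part of the module): the surrounding ADAP
# algorithm is expected to append the sampled context to every observation it
# feeds the policy and to register it via ``set_context``. The spaces and
# hyperparameters below are hypothetical.
#
#     import gymnasium as gym
#     import torch
#
#     policy = AdapPolicyMult(
#         observation_space=gym.spaces.Box(-1.0, 1.0, (8,)),
#         action_space=gym.spaces.Discrete(2),
#         lr_schedule=lambda _: 3e-4,
#         context_size=3,
#     )
#     context = torch.zeros(1, 3)
#     policy.set_context(context)
#     obs_with_context = torch.cat((torch.zeros(1, 8), context), dim=1)
#     actions, values, log_prob = policy(obs_with_context)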