Source code for pantheonrl.algos.adap.agent

"""
Module defining the ADAP partner agent.
"""
from typing import Optional

import numpy as np

from pantheonrl.common.agents import OnPolicyAgent
from pantheonrl.common.observation import Observation

from .adap_learn import ADAP
from .util import SAMPLERS
from .policies import AdapPolicy


class AdapAgent(OnPolicyAgent):
    """
    Agent representing an ADAP learning algorithm.

    The `get_action` and `update` functions are based on the `learn`
    function from ``OnPolicyAlgorithm``.

    :param model: Model representing the agent's learning algorithm
    :param log_interval: Optional log interval for policy logging
    :param working_timesteps: Estimate for number of timesteps to train for.
    :param callback: Optional callback fed into the OnPolicyAlgorithm
    :param tb_log_name: Name for tensorboard log
    :param latent_syncer: Optional policy to copy the latent context from,
        instead of sampling a new context after each episode
    """

    def __init__(
        self,
        model: ADAP,
        log_interval=None,
        working_timesteps=1000,
        callback=None,
        tb_log_name="AdapAgent",
        latent_syncer: Optional[AdapPolicy] = None,
    ):
        super().__init__(
            model, log_interval, working_timesteps, callback, tb_log_name
        )
        self.latent_syncer = latent_syncer

        # Widen the rollout buffer so each stored observation has room for
        # the latent context that get_action appends.
        buf = self.model.rollout_buffer
        self.model.full_obs_shape = (
            buf.obs_shape[0] + self.model.context_size,
        )
        buf.obs_shape = self.model.full_obs_shape
        buf.reset()
    def get_action(self, obs: Observation) -> np.ndarray:
        """
        Return an action given an observation.

        The agent saves the last transition into its buffer. It also
        updates the model if the buffer is full.

        :param obs: The observation to use
        :returns: The action to take
        """
        # If a latent syncer is provided, mirror its context before acting.
        if self.latent_syncer is not None:
            self.model.policy.set_context(
                self.latent_syncer.policy.get_context()
            )

        # Append the current latent context to the observation so the
        # context-conditioned policy (and the rollout buffer) receive it.
        if not isinstance(obs.obs, np.ndarray):
            obs.obs = np.array([obs.obs])
        obs.obs = np.concatenate(
            (np.reshape(obs.obs, (1, -1)), self.model.policy.get_context()),
            axis=1,
        )
        return super().get_action(obs)
    def update(self, reward: float, done: bool) -> None:
        super().update(reward, done)
        # At the end of an episode, resample a fresh latent context unless a
        # latent syncer controls the context externally.
        if done and self.latent_syncer is None:
            sampled_context = SAMPLERS[self.model.context_sampler](
                ctx_size=self.model.context_size, num=1, use_torch=True
            )
            self.model.policy.set_context(sampled_context)
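
# --- Hypothetical usage sketch (not part of this module) ---
# Shows how an AdapAgent might be attached as a partner in a PantheonRL
# two-player environment. The environment id and the exact ADAP/AdapPolicy
# constructor arguments below are assumptions; consult the ADAP examples and
# trainer in this repository for the canonical setup.
#
# import gym
# from pantheonrl.algos.adap.adap_learn import ADAP
# from pantheonrl.algos.adap.policies import AdapPolicy
# from pantheonrl.algos.adap.agent import AdapAgent
#
# env = gym.make('RPS-v0')  # assumed PantheonRL multi-agent environment
# partner_model = ADAP(AdapPolicy, env.getDummyEnv(1), context_size=3)  # assumed args
# partner = AdapAgent(partner_model, working_timesteps=10000)
# env.add_partner_agent(partner)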