Source code for pantheonrl.common.wrappers

"""
Collection of environment wrappers.

This module implements the FrameStack and Recording wrappers for
TurnBasedEnv and SimultaneousEnv.
"""
from abc import ABC, abstractmethod
from typing import List, Optional, Dict, Tuple

import numpy as np

import gymnasium as gym

from .multiagentenv import TurnBasedEnv, SimultaneousEnv, MultiAgentEnv
from .trajsaver import (TurnBasedTransitions, SimultaneousTransitions,
                        MultiTransitions)
from .util import (calculate_space, get_default_obs, SpaceException)

# Flags for the TurnBasedRecorder wrapper
EGO_NOT_DONE = 0
ALT_NOT_DONE = 1
EGO_DONE = 2
ALT_DONE = 3

# Flags for the SimultaneousRecorder wrapper
NOT_DONE = 0
DONE = 1


def frame_wrap(env: MultiAgentEnv, numframes: int):
    """ Construct FrameStack environment for the given env """
    if isinstance(env, TurnBasedEnv):
        return TurnBasedFrameStack(env, numframes,
                                   altenv=env.get_dummy_env(1))
    if isinstance(env, SimultaneousEnv):
        return SimultaneousFrameStack(env, numframes)
    raise SpaceException


def recorder_wrap(env: MultiAgentEnv):
    """ Construct a Recorder environment for the given env """
    if isinstance(env, TurnBasedEnv):
        return TurnBasedRecorder(env)
    if isinstance(env, SimultaneousEnv):
        return SimultaneousRecorder(env)
    raise SpaceException


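# Illustrative sketch, not part of the original module: one way the two
# factory functions might be composed. The caller is assumed to supply an
# already-constructed TurnBasedEnv or SimultaneousEnv; only frame_wrap and
# recorder_wrap come from this file.
def _wrap_example(env: MultiAgentEnv) -> MultiAgentEnv:
    """Hypothetical helper: record transitions, then stack observations."""
    env = recorder_wrap(env)              # raises SpaceException for other env types
    return frame_wrap(env, numframes=4)   # each obs becomes the last 4 frames

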
class HistoryQueue:
    """
    Ring buffer representing the saved history for the FrameStack wrappers.

    :param defaultelem: The default element for an empty buffer
    :param size: The length of the queue
    """

    def __init__(self, defaultelem: np.ndarray, size: int):
        self.defaultelem = defaultelem
        self.size = size
        self.pos = 0
        self.history: List[np.ndarray] = [defaultelem] * size

    def add(self, toadd: np.ndarray) -> np.ndarray:
        """
        Add the given value to the queue and return the new representation

        :param toadd: The new value to add. This overrides the oldest value
        :return: The new queue representation, where the first element is the
            most recently added element and the last element is the oldest
        """
        if isinstance(toadd, int):
            toadd = np.array([toadd])
        self.history[self.pos] = toadd
        ans = np.array([val for ind in range(self.size)
                        for val in self.history[self.pos - ind]])
        self.pos = (self.pos + 1) % self.size
        return ans

    def reset(self) -> None:
        """ Reset the queue. This fills the buffer with the default element. """
        self.history = [self.defaultelem] * self.size
        self.pos = 0


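# Illustrative sketch, not part of the original module: how HistoryQueue
# stacks frames for a queue of size 3 over length-2 observations.
def _history_queue_example() -> np.ndarray:
    """Hypothetical helper demonstrating HistoryQueue.add and reset."""
    queue = HistoryQueue(np.zeros(2), 3)
    queue.add(np.array([1.0, 2.0]))            # -> [1., 2., 0., 0., 0., 0.]
    stacked = queue.add(np.array([3.0, 4.0]))  # -> [3., 4., 1., 2., 0., 0.]
    # The newest frame always comes first; untouched slots still hold the
    # default element until the buffer wraps around.
    queue.reset()  # refill every slot with the default element
    return stacked

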
class MultiRecorder(ABC):
    """ Base Class for all Recorder Wrappers """

    @abstractmethod
    def get_transitions(self) -> MultiTransitions:
        """ Get the transitions that have been recorded """

    def write_transition(self, file):
        """ Write the recorded transitions to the given file """
        self.get_transitions().write_transition(file)


class TurnBasedRecorder(TurnBasedEnv, MultiRecorder):
    """
    Recorder for all turn-based environments

    :param env: The environment to record
    """

    def __init__(self, env: gym.Env):
        super().__init__(
            env.observation_spaces,
            env.action_spaces,
            probegostart=env.probegostart,
            partners=env.partners[0])
        self.env = env
        self.allobs: List[np.ndarray] = []
        self.allacts: List[np.ndarray] = []
        self.flags: List[int] = []
        self.incomplete = False

    def ego_step(
        self,
        action: np.ndarray
    ) -> Tuple[Optional[np.ndarray], Tuple[float, float], bool, Dict]:
        """
        This function calls the embedded environment's ego_step and records
        the action and new observation.
        """
        altobs, rews, done, info = self.env.ego_step(action)
        self.allacts.append(action)
        if not done:
            self.allobs.append(altobs)
            self.flags.append(EGO_NOT_DONE)
        else:
            self.flags.append(EGO_DONE)
            self.incomplete = False
        return altobs, rews, done, info

    def alt_step(
        self,
        action: np.ndarray
    ) -> Tuple[Optional[np.ndarray], Tuple[float, float], bool, Dict]:
        """
        This function calls the embedded environment's alt_step and records
        the action and new observation.
        """
        egoobs, rews, done, info = self.env.alt_step(action)
        self.allacts.append(action)
        if not done:
            self.allobs.append(egoobs)
            self.flags.append(ALT_NOT_DONE)
        else:
            self.flags.append(ALT_DONE)
            self.incomplete = False
        return egoobs, rews, done, info

    def multi_reset(self, egofirst: bool) -> np.ndarray:
        """
        This function calls the embedded environment's multi_reset and
        records the new observation.
        """
        newobs = self.env.multi_reset(egofirst)
        if self.incomplete:
            self.allobs[-1] = newobs
        else:
            self.allobs.append(newobs)
        self.incomplete = True
        return newobs

    def get_transitions(self) -> TurnBasedTransitions:
        """ Return the recorded transitions """
        obsarray = np.array(self.allobs)
        if self.incomplete:
            obsarray = obsarray[:-1]
        return TurnBasedTransitions(
            obsarray,
            np.array(self.allacts),
            np.array(self.flags)
        )


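# Illustrative sketch, not part of the original module: record a single
# episode of a turn-based environment and save it. ``policy`` is a
# hypothetical callable mapping an observation to an action, and ``file``
# is whatever object TurnBasedTransitions.write_transition accepts.
def _record_turn_based_episode(env: TurnBasedEnv, policy, file) -> None:
    """Hypothetical helper showing the TurnBasedRecorder recording loop."""
    recorder = TurnBasedRecorder(env)
    obs = recorder.multi_reset(egofirst=True)
    done = False
    ego_turn = True
    while not done:
        if ego_turn:
            obs, _, done, _ = recorder.ego_step(policy(obs))
        else:
            obs, _, done, _ = recorder.alt_step(policy(obs))
        ego_turn = not ego_turn
    recorder.write_transition(file)  # serialize the recorded TurnBasedTransitions

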
class SimultaneousRecorder(SimultaneousEnv, MultiRecorder):
    """
    Recorder for all simultaneous environments

    :param env: The environment to record
    """

    def __init__(self, env):
        super().__init__(
            env.observation_spaces,
            env.action_spaces,
            partners=env.partners[0])
        self.env = env
        self.allegoobs = []
        self.allegoacts = []
        self.allaltobs = []
        self.allaltacts = []
        self.allflags = []
        self.incomplete = False

    def multi_step(
        self,
        ego_action: np.ndarray,
        alt_action: np.ndarray
    ) -> Tuple[Tuple[Optional[np.ndarray], Optional[np.ndarray]],
               Tuple[float, float], bool, Dict]:
        """
        This function calls the embedded environment's multi_step and
        records the new actions and observations.
        """
        obs, rews, done, info = self.env.multi_step(ego_action, alt_action)
        self.allegoacts.append(ego_action)
        self.allaltacts.append(alt_action)
        if not done:
            self.allegoobs.append(obs[0])
            self.allaltobs.append(obs[1])
            self.allflags.append(NOT_DONE)
        else:
            self.allflags.append(DONE)
            self.incomplete = False
        return obs, rews, done, info

    def multi_reset(self) -> Tuple[np.ndarray, np.ndarray]:
        """
        This function calls the embedded environment's multi_reset and
        records the new observations.
        """
        obs = self.env.multi_reset()
        self.allegoobs.append(obs[0])
        self.allaltobs.append(obs[1])
        self.incomplete = True
        return obs

    def get_transitions(self) -> SimultaneousTransitions:
        """ Return the recorded transitions """
        egoobsarr = np.array(self.allegoobs)
        altobsarr = np.array(self.allaltobs)
        if self.incomplete:
            egoobsarr = egoobsarr[:-1]
            altobsarr = altobsarr[:-1]
        return SimultaneousTransitions(
            egoobsarr,
            np.array(self.allegoacts),
            altobsarr,
            np.array(self.allaltacts),
            np.array(self.allflags)
        )


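# Illustrative sketch, not part of the original module: the same recording
# loop for a simultaneous environment, where both agents act on every step.
# ``ego_policy`` and ``alt_policy`` are hypothetical callables mapping an
# observation to an action; ``file`` is whatever object
# SimultaneousTransitions.write_transition accepts.
def _record_simultaneous_episode(env: SimultaneousEnv,
                                 ego_policy, alt_policy, file) -> None:
    """Hypothetical helper showing the SimultaneousRecorder recording loop."""
    recorder = SimultaneousRecorder(env)
    egoobs, altobs = recorder.multi_reset()
    done = False
    while not done:
        obs, _, done, _ = recorder.multi_step(ego_policy(egoobs),
                                              alt_policy(altobs))
        egoobs, altobs = obs
    recorder.write_transition(file)

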
class TurnBasedFrameStack(TurnBasedEnv):
    """
    Wrapper that stacks the observations of a turn-based environment.

    :param env: The environment to wrap
    :param numframes: The number of frames to stack for each observation
    :param defaultobs: The default observation that fills old segments of the
        frame stacks.
    :param altenv: The optional dummy environment representing the spaces of
        the partner agent.
    :param defaultaltobs: The default observation that fills old segments of
        the frame stacks for the partner agent.
    """

    def __init__(
        self,
        env: gym.Env,
        numframes: int,
        defaultobs: Optional[np.ndarray] = None,
        altenv: Optional[gym.Env] = None,
        defaultaltobs: Optional[np.ndarray] = None
    ):
        super().__init__(
            [calculate_space(o, numframes) for o in env.observation_spaces],
            env.action_spaces,
            probegostart=env.probegostart,
            partners=env.partners[0])
        self.env = env
        self.numframes = numframes

        if defaultobs is not None:
            defobs = defaultobs
        else:
            defobs = get_default_obs(env)

        if altenv is None:
            altenv = env

        if defaultaltobs is not None:
            defaltobs = defaultaltobs
        else:
            defaltobs = get_default_obs(altenv)

        self.egohistory = HistoryQueue(defobs, numframes)
        self.althistory = HistoryQueue(defaltobs, numframes)

    def ego_step(
        self,
        action: np.ndarray
    ) -> Tuple[Optional[np.ndarray], Tuple[float, float], bool, Dict]:
        altobs, rews, done, info = self.env.ego_step(action)
        return self.althistory.add(altobs), rews, done, info

    def alt_step(
        self,
        action: np.ndarray
    ) -> Tuple[Optional[np.ndarray], Tuple[float, float], bool, Dict]:
        egoobs, rews, done, info = self.env.alt_step(action)
        return self.egohistory.add(egoobs), rews, done, info

    def multi_reset(self, egofirst: bool) -> np.ndarray:
        newobs = self.env.multi_reset(egofirst)
        self.egohistory.reset()
        self.althistory.reset()
        if egofirst:
            return self.egohistory.add(newobs)
        return self.althistory.add(newobs)


class SimultaneousFrameStack(SimultaneousEnv):
    """
    Wrapper that stacks the observations of a simultaneous environment.

    :param env: The environment to wrap
    :param numframes: The number of frames to stack for each observation
    :param defaultobs: The default observation that fills old segments of the
        frame stacks.
    """

    def __init__(
        self,
        env: gym.Env,
        numframes: int,
        defaultobs: Optional[np.ndarray] = None
    ):
        super().__init__(
            [calculate_space(o, numframes) for o in env.observation_spaces],
            env.action_spaces,
            partners=env.partners[0])
        self.env = env
        self.numframes = numframes
        self.defaultobs = (get_default_obs(env) if defaultobs is None
                           else list(defaultobs))
        self.egohistory = HistoryQueue(self.defaultobs, self.numframes)
        self.althistory = HistoryQueue(self.defaultobs, self.numframes)

    def multi_step(
        self,
        ego_action: np.ndarray,
        alt_action: np.ndarray
    ) -> Tuple[Tuple[Optional[np.ndarray], Optional[np.ndarray]],
               Tuple[float, float], bool, Dict]:
        obs, rews, done, info = self.env.multi_step(ego_action, alt_action)
        return ((self.egohistory.add(obs[0]), self.althistory.add(obs[1])),
                rews, done, info)

    def multi_reset(self) -> Tuple[np.ndarray, np.ndarray]:
        obs = self.env.multi_reset()
        self.egohistory.reset()
        self.althistory.reset()
        return (self.egohistory.add(obs[0]), self.althistory.add(obs[1]))


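# Illustrative sketch, not part of the original module: after frame stacking,
# every observation returned by the wrapper is the concatenation of the
# ``numframes`` most recent frames (newest first), padded with the default
# observation at the start of an episode. The concrete environment is assumed
# to be supplied by the caller.
def _frame_stack_example(env: SimultaneousEnv) -> Tuple[np.ndarray, np.ndarray]:
    """Hypothetical helper showing the effect of SimultaneousFrameStack."""
    stacked = SimultaneousFrameStack(env, numframes=4)
    egoobs, altobs = stacked.multi_reset()
    # For a flat base observation of length n, egoobs and altobs now have
    # length 4 * n; the three older frame slots hold the default observation.
    return egoobs, altobs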