Source code for pantheonrl.common.multiagentenv

"""
This module defines the standard Env classes for PantheonRL.

It defines the following Environments:

- The abstract base MultiAgentEnv class
- The abstract SimultaneousEnv
- The abstract TurnBasedEnv

It defines a convenience DummyEnv for interacting with single-agent RL
(SARL) algorithms.

It also defines the PlayerException and KillEnvException.
"""
import warnings

from abc import ABC, abstractmethod
from typing import List, Tuple, Dict, Optional, Callable, Any, Union

import threading
from threading import Condition

import gymnasium as gym
import numpy as np

from .agents import Agent, DummyAgent
from .observation import Observation, extract_obs


class PlayerException(Exception):
    """Raise when players in the environment are incorrectly set"""


class KillEnvException(Exception):
    """Raise when the DummyEnv is killed"""


class DummyEnv(gym.Env):
    """
    Environment representing an interface for single-agent RL algorithms
    that assume access to a gym environment.

    In its basic use, it just defines the observation and action spaces.
    However, it may also be used directly to run a single-agent RL
    algorithm.

    .. warning::

        Use caution when trying to directly train a policy on this
        environment. You must create a separate thread and manage potential
        deadlocks. If you are using the SB3 algorithms, we strongly advise
        using our OnPolicyAgent and OffPolicyAgent classes instead to avoid
        deadlocks.

    :param base_env: The base MultiAgentEnv
    :param agent_ind: The player number in the larger environment
    :param extractor: Function to call to process the Observation into a
        usable value. By default, transforms the Observation into a numpy
        array of the partial observation.
    """

    def __init__(
        self,
        base_env: gym.Env,
        agent_ind: int,
        extractor: Callable[[Observation], Any] = extract_obs,
    ):
        super().__init__()
        self.base_env = base_env
        self.agent_ind = agent_ind
        self.observation_space = self.base_env.observation_spaces[agent_ind]
        self.action_space = self.base_env.action_spaces[agent_ind]

        self._obs = None
        self._rew = None
        self._done = True

        self.obs_cv = Condition()
        self.extractor = extractor
        self.associated_agent = None
        self.steps = 0
        self.dead = False

    def step(
        self, action: np.ndarray
    ) -> tuple[Union[Observation, Any], float, bool, bool, dict[str, Any]]:
        """
        Run one timestep from the perspective of the agent.

        Accepts the agent's action and returns a tuple of
        (observation, reward, terminated, truncated, info) from the
        perspective of that agent.

        Note that when the environment is done, the final observation is
        the latest observation provided by the environment, which may be
        the same as the previous observation given to the agent, especially
        in turn-based settings.

        :param action: An action provided by the agent.

        :returns:
            observation: Agent's next observation

            reward: Amount of reward returned after previous action

            terminated: Whether the episode has ended (call reset() if True)

            truncated: Whether the episode was truncated (call reset() if
            True)

            info: Extra information about the environment
        """
        assert threading.current_thread() is not threading.main_thread()

        # Hand the new action to the associated agent and wake it up.
        with self.associated_agent.action_cv:
            self.associated_agent._action = action
            self.associated_agent.action_cv.notify()
            self._obs = None

        # Wait until the base environment publishes the next observation.
        with self.obs_cv:
            while self._obs is None:
                self.obs_cv.wait()

            if self.dead:
                raise KillEnvException("Killing dummy environment")

            to_return = (
                self.extractor(self._obs),
                self._rew,
                self._done,
                False,
                {},
            )

            if not self._done:
                self._obs = None

        self.steps += 1
        return to_return

    def reset(
        self,
        *,
        seed: Optional[int] = None,
        options: Optional[dict[str, Any]] = None,
    ) -> tuple[Observation, dict[str, Any]]:
        """
        Block until the base environment provides this agent's first
        observation, then return it along with an empty info dict.
        """
        assert self._done
        assert threading.current_thread() is not threading.main_thread()

        # Wait for the base environment to publish the initial observation.
        with self.obs_cv:
            while self._obs is None:
                self.obs_cv.wait()
            to_return = self.extractor(self._obs), {}
            self._done = False

        return to_return

    def close(self):
        self.associated_agent._action = 0
        self.associated_agent.dummy_env = None
        # Wake the waiting agent so it can see its dummy environment is gone.
        with self.associated_agent.action_cv:
            self.associated_agent.action_cv.notify()
        warnings.warn(
            "Partner agent's dummy environment is dead. Remember to set the "
            "learning time for the partner to be much larger than the "
            "program lifetime"
        )

    def render(self):
        pass
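

# Illustrative sketch of the single-agent RL pattern that DummyEnv supports,
# assuming Stable-Baselines3 is installed and that ``env`` is some concrete
# MultiAgentEnv subclass.  The helper name is hypothetical; see
# examples/custom_sarl.py for the canonical version of this pattern.
def _sarl_training_sketch(env: "MultiAgentEnv") -> None:
    """Train a partner through its DummyEnv on a background thread."""
    from stable_baselines3 import PPO

    # The returned DummyEnv blocks in reset()/step() until the base
    # environment hands it an observation, so the partner's learner must
    # never run on the main thread.
    partner_env = env.construct_single_agent_interface(player_num=1)
    partner_model = PPO("MlpPolicy", partner_env)

    # Give the partner far more timesteps than the ego so its dummy
    # environment outlives the main training loop (see DummyEnv.close).
    partner_thread = threading.Thread(
        target=partner_model.learn,
        kwargs={"total_timesteps": 10_000_000},
        daemon=True,
    )
    partner_thread.start()

    # The ego agent trains on the main thread through the usual gym API,
    # which in turn drives the partner's DummyEnv.
    ego_model = PPO("MlpPolicy", env)
    ego_model.learn(total_timesteps=100_000)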


class MultiAgentEnv(gym.Env, ABC):
    """
    Base class for all multi-agent environments.

    :param observation_spaces: The observation space for each player
    :param action_spaces: The action space for each player
    :param ego_ind: The player number that the ego represents
    :param n_players: The number of players in the game
    :param resample_policy: The resampling policy (see set_resample_policy)
    :param partners: Lists of agents to choose from for the partner players
    :param ego_extractor: Function to extract Observation into the type the
        ego agent expects
    """

    def __init__(
        self,
        observation_spaces: List[gym.spaces.Space],
        action_spaces: List[gym.spaces.Space],
        ego_ind: int = 0,
        n_players: int = 2,
        resample_policy: str = "default",
        partners: Optional[List[List[Agent]]] = None,
        ego_extractor: Callable[[Observation], Any] = extract_obs,
    ):
        super().__init__()
        self.observation_spaces = observation_spaces
        self.action_spaces = action_spaces

        self.ego_ind = ego_ind
        self.n_players = n_players

        if partners is not None:
            if len(partners) != n_players - 1:
                raise PlayerException(
                    "The number of partners needs to equal the number "
                    "of non-ego players"
                )

            for plist in partners:
                if not isinstance(plist, list) or not plist:
                    raise PlayerException(
                        "Sublist for each partner must be nonempty list"
                    )

        # Use independent lists so that adding a partner for one player does
        # not add it for every other player.
        self.partners = partners or [[] for _ in range(n_players - 1)]
        self.partnerids = [0] * (n_players - 1)

        self._players: Tuple[int, ...] = tuple()
        self._obs: Tuple[Optional[np.ndarray], ...] = tuple()
        self._old_ego_obs: Optional[np.ndarray] = None

        self.should_update = [False] * (self.n_players - 1)
        self.total_rews = [0] * self.n_players
        self.ego_moved = False

        self.set_resample_policy(resample_policy)
        self.ego_extractor = ego_extractor

    def get_ego_ind(self):
        """Returns the current player number for the ego agent"""
        return self.ego_ind

    def set_ego_ind(self, new_ind: int, silence_partner_warning: bool = False):
        """
        Sets the current player number for the ego agent.

        .. warning::

            Modifying the ego_ind after partners have been added will change
            the player number of those partners as well.

        :param new_ind: the new index of the ego player
        :param silence_partner_warning: Whether to suppress the partner
            warning
        """
        if not silence_partner_warning:
            for plist in self.partners:
                if len(plist) > 0:
                    warnings.warn(
                        "Modifying the ego_ind after partners have been "
                        "added will change the player number of those "
                        "partners as well"
                    )
                    break
        self.ego_ind = new_ind

    def get_dummy_env(self, player_num: int):
        """
        Returns a dummy environment with just an observation and action
        space that a partner agent can use to construct their policy
        network.

        :param player_num: the partner number to query

        :returns: Dummy environment for this player number
        """
        return DummyEnv(self, player_num)

    def construct_single_agent_interface(self, player_num: int):
        """
        Construct a gym interface to be used by a single-agent RL algorithm.

        Note that when training a policy using this interface, it must be
        spawned in a separate Thread. Please refer to the custom_sarl.py
        file in examples (and the sketch after the DummyEnv class above) to
        see how to appropriately use this function.

        :param player_num: the player number to build the interface around

        :returns: environment to use for the new player
        """
        dummy_env = self.get_dummy_env(player_num)
        dummy_agent = DummyAgent(dummy_env)
        partner_num = self._get_partner_num(player_num)
        if len(self.partners[partner_num]) != 0:
            raise PlayerException(
                "Cannot construct multiple single agent "
                "interfaces for the same player_num"
            )
        self.add_partner_agent(dummy_agent, player_num)
        return dummy_env

    @property
    def observation_space(self) -> gym.spaces.Space:
        """The observation space of the ego agent"""
        return self.observation_spaces[self.ego_ind]

    @property
    def action_space(self) -> gym.spaces.Space:
        """The action space of the ego agent"""
        return self.action_spaces[self.ego_ind]

    def set_ego_extractor(self, ego_extractor: Callable[[Observation], Any]):
        """
        Sets the function to extract Observation for the ego agent.

        :param ego_extractor: Function to extract Observation into the type
            the ego agent expects
        """
        self.ego_extractor = ego_extractor

    def _get_partner_num(self, player_num: int) -> int:
        # Map an absolute player number to an index into self.partners.
        if player_num == self.ego_ind:
            raise PlayerException("Ego agent is not set by the environment")
        if player_num > self.ego_ind:
            return player_num - 1
        return player_num

    def add_partner_agent(self, agent: Agent, player_num: int = 1) -> None:
        """
        Add agent to the list of potential partner agents. If there are
        multiple agents that can be a specific player number, the
        environment resamples from them at the start of every episode
        (see the usage sketch after this class).

        :param agent: Agent to add
        :param player_num: the player number that this new agent can be
        """
        self.partners[self._get_partner_num(player_num)].append(agent)

    def set_partnerid(self, agent_id: int, player_num: int = 1) -> None:
        """
        Set the current partner agent to use

        :param agent_id: agent_id to use as current partner
        :param player_num: The player number
        """
        partner_num = self._get_partner_num(player_num)
        assert 0 <= agent_id < len(self.partners[partner_num])
        self.partnerids[partner_num] = agent_id

    def resample_random(self) -> None:
        """Randomly resamples each partner policy"""
        self.partnerids = [
            self.np_random.integers(0, len(plist)) for plist in self.partners
        ]

    def resample_null(self) -> None:
        """Do not resample each partner policy"""

    def resample_round_robin(self) -> None:
        """
        Sets the partner policy to the next option on the list for
        round-robin sampling.

        Note: This function is only valid for 2-player environments
        """
        self.partnerids = [(self.partnerids[0] + 1) % len(self.partners[0])]

    def set_resample_policy(self, resample_policy: str) -> None:
        """
        Set the resample_partner method to round-robin ("robin") or random
        ("random") sampling, to never resample ("null"), or to pick
        automatically ("default": round-robin for 2 players, random
        otherwise).

        :param resample_policy: The new resampling policy to use. Valid
            values are: "default", "robin", "random", or "null"
        """
        if resample_policy == "default":
            resample_policy = "robin" if self.n_players == 2 else "random"

        if resample_policy == "robin" and self.n_players != 2:
            raise PlayerException(
                "Cannot do round robin resampling for >2 players"
            )

        if resample_policy == "robin":
            self.resample_partner = self.resample_round_robin
        elif resample_policy == "random":
            self.resample_partner = self.resample_random
        elif resample_policy == "null":
            self.resample_partner = self.resample_null
        else:
            raise PlayerException(
                f"Invalid resampling policy: {resample_policy}"
            )

    def _get_actions(self, players, obs, ego_act=None):
        # Gather one action per acting player; the ego's action is supplied
        # by the caller as ego_act.
        actions = []
        for player, ob in zip(players, obs):
            if player == self.ego_ind:
                actions.append(ego_act)
            else:
                p = self._get_partner_num(player)
                agent = self.partners[p][self.partnerids[p]]
                actions.append(agent.get_action(ob))
                if not self.should_update[p]:
                    # First action of the episode for this partner: report
                    # the reward it accumulated before it started acting.
                    agent.update(self.total_rews[player], False)
                    self.should_update[p] = True
        return np.array(actions)

    def _update_players(self, rews, done):
        # Forward the newest rewards to partners that have already acted,
        # then accumulate episode totals for every player.
        for i in range(self.n_players - 1):
            nextrew = rews[i + (0 if i < self.ego_ind else 1)]
            if self.should_update[i]:
                self.partners[i][self.partnerids[i]].update(nextrew, done)
        for i in range(self.n_players):
            self.total_rews[i] += rews[i]

    def step(
        self, action: np.ndarray
    ) -> tuple[Union[Observation, Any], float, bool, bool, dict[str, Any]]:
        """
        Run one timestep from the perspective of the ego-agent.

        This involves calling the n_step function, possibly several times,
        to get to the next observation of the ego agent.

        Accepts the ego-agent's action and returns a tuple of
        (observation, reward, terminated, truncated, info) from the
        perspective of the ego agent.

        Note that when the environment is done, the final observation is
        the latest observation provided by the environment, which may be
        the same as the previous observation given to the agent, especially
        in turn-based settings.

        :param action: An action provided by the ego-agent.

        :returns:
            observation: Ego-agent's next observation

            reward: Amount of reward returned after previous action

            terminated: Whether the episode has ended (call reset() if True)

            truncated: Whether the episode was truncated (call reset() if
            True)

            info: Extra information about the environment
        """
        ego_rew = 0.0

        while True:
            acts = self._get_actions(self._players, self._obs, action)
            self._players, self._obs, rews, done, info = self.n_step(acts)
            info["_partnerid"] = self.partnerids

            self._update_players(rews, done)

            ego_rew += (
                rews[self.ego_ind]
                if self.ego_moved
                else self.total_rews[self.ego_ind]
            )
            self.ego_moved = True

            if self.ego_ind in self._players:
                break

            if done:
                # The episode ended before the ego's next turn; return the
                # ego's most recent observation.
                ego_obs = self._old_ego_obs
                return self.ego_extractor(ego_obs), ego_rew, done, False, info

        ego_obs = self._obs[self._players.index(self.ego_ind)]
        self._old_ego_obs = ego_obs
        return self.ego_extractor(ego_obs), ego_rew, done, False, info

    def reset(
        self,
        *,
        seed: Optional[int] = None,
        options: Optional[dict[str, Any]] = None,
    ) -> tuple[Observation, dict[str, Any]]:
        """
        Reset the environment to an initial state, stepping the partner
        agents until it is the ego's turn, and return the first observation
        for the ego agent.

        :returns: Ego-agent's first observation and an empty info dict
        """
        super().reset(seed=seed)
        self.resample_partner()
        self._players, self._obs = self.n_reset()

        self.should_update = [False] * (self.n_players - 1)
        self.total_rews = [0] * self.n_players
        self.ego_moved = False

        while self.ego_ind not in self._players:
            acts = self._get_actions(self._players, self._obs)
            self._players, self._obs, rews, done, _ = self.n_step(acts)
            self._update_players(rews, done)

            if done:
                # The episode ended before the ego ever acted; start over.
                self.resample_partner()
                self._players, self._obs = self.n_reset()
                self.should_update = [False] * (self.n_players - 1)
                self.total_rews = [0] * self.n_players
                self.ego_moved = False

        ego_obs = self._obs[self._players.index(self.ego_ind)]
        assert ego_obs is not None
        self._old_ego_obs = ego_obs
        return self.ego_extractor(ego_obs), {}

    @abstractmethod
    def n_step(
        self,
        actions: List[np.ndarray],
    ) -> Tuple[
        Tuple[int, ...],
        Tuple[Optional[Observation], ...],
        Tuple[float, ...],
        bool,
        Dict,
    ]:
        """
        Perform the actions specified by the agents that will move.

        This function returns a tuple of (next agents, observations,
        rewards, done, info). It is called by the `step` function.

        :param actions: List of actions provided by the agents that are
            acting on this step.

        :returns:
            agents: Tuple representing the agents to call for the next
            actions

            observations: Tuple representing the next observations of those
            agents

            rewards: Tuple representing the rewards of all agents

            done: Whether the episode has ended

            info: Extra information about the environment
        """

    @abstractmethod
    def n_reset(
        self,
    ) -> Tuple[Tuple[int, ...], Tuple[Optional[Observation], ...]]:
        """
        Reset the environment and return which agents will move first along
        with their initial observations.

        This function is called by the `reset` function.

        :returns:
            agents: Tuple representing the agents that will move first

            observations: Tuple representing the observations of those
            agents
        """
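

# Illustrative sketch of the ego-side training loop: two partner candidates
# are registered for player 1 and alternated round-robin between episodes.
# ``FixedAgent`` is hypothetical and only fills in the two Agent methods
# this environment calls (get_action and update); the real Agent base class
# may require more.  ``env`` stands for any concrete MultiAgentEnv subclass.
def _ego_training_loop_sketch(env: "MultiAgentEnv") -> None:
    """Register partners, pick a resampling policy, and run episodes."""

    class FixedAgent(Agent):
        """Hypothetical partner that always plays the same action."""

        def __init__(self, action):
            self.action = action

        def get_action(self, obs):
            return self.action

        def update(self, reward, done):
            pass  # a learning partner would record the transition here

    env.add_partner_agent(FixedAgent(0), player_num=1)
    env.add_partner_agent(FixedAgent(1), player_num=1)
    env.set_resample_policy("robin")  # alternate the two partners

    for _ in range(10):
        obs, info = env.reset()  # resamples the partner for this episode
        done = False
        while not done:
            action = env.action_space.sample()
            obs, reward, terminated, truncated, info = env.step(action)
            done = terminated or truncated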


class TurnBasedEnv(MultiAgentEnv, ABC):
    """
    Base class for all 2-player turn-based games.

    In turn-based games, players take turns receiving observations and
    making actions.

    :param observation_spaces: The observation space for each player
    :param action_spaces: The action space for each player
    :param probegostart: Probability that the ego agent gets the first turn
    :param partners: List of policies to choose from for the partner agent
    """

    def __init__(
        self,
        observation_spaces: List[gym.spaces.Space],
        action_spaces: List[gym.spaces.Space],
        probegostart: float = 0.5,
        partners: Optional[List[Agent]] = None,
    ):
        partners = [partners] if partners else None
        super().__init__(
            observation_spaces,
            action_spaces,
            ego_ind=0,
            n_players=2,
            partners=partners,
        )
        self.probegostart = probegostart
        self.ego_next = True

    def n_step(
        self,
        actions: List[np.ndarray],
    ) -> Tuple[
        Tuple[int, ...],
        Tuple[Optional[Observation], ...],
        Tuple[float, ...],
        bool,
        Dict,
    ]:
        # The player acting next is the one who did not act on this step.
        agents = (1 if self.ego_next else 0,)
        obs, rews, done, info = (
            self.ego_step(actions[0])
            if self.ego_next
            else self.alt_step(actions[0])
        )
        self.ego_next = not self.ego_next
        return agents, (Observation(obs),), rews, done, info

    def n_reset(
        self,
    ) -> Tuple[Tuple[int, ...], Tuple[Optional[np.ndarray], ...]]:
        self.ego_next = self.np_random.random() < self.probegostart
        obs = self.multi_reset(self.ego_next)
        return (0 if self.ego_next else 1,), (Observation(obs),)

    @abstractmethod
    def ego_step(
        self, action: np.ndarray
    ) -> Tuple[Optional[np.ndarray], Tuple[float, float], bool, Dict]:
        """
        Perform the ego-agent's action and return a tuple of (partner's
        observation, both rewards, done, info).

        This function is called by the `step` function along with alt_step.

        :param action: An action provided by the ego-agent.

        :returns:
            partner observation: Partner's next observation

            rewards: Tuple representing the rewards of both agents
            (ego, alt)

            done: Whether the episode has ended

            info: Extra information about the environment
        """

    @abstractmethod
    def alt_step(
        self, action: np.ndarray
    ) -> Tuple[Optional[np.ndarray], Tuple[float, float], bool, Dict]:
        """
        Perform the partner's action and return a tuple of (ego's
        observation, both rewards, done, info).

        This function is called by the `step` function along with ego_step.

        :param action: An action provided by the partner.

        :returns:
            ego observation: Ego-agent's next observation

            rewards: Tuple representing the rewards of both agents
            (ego, alt)

            done: Whether the episode has ended

            info: Extra information about the environment
        """

    @abstractmethod
    def multi_reset(self, egofirst: bool) -> np.ndarray:
        """
        Reset the environment and give the observation of the starting agent
        (based on the value of `egofirst`).

        This function is called by the `reset` function.

        :param egofirst: True if the ego has the first turn, False otherwise

        :returns: The observation for the starting agent (ego if `egofirst`
            is True, and the partner's observation otherwise)
        """
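

# Hypothetical example of a concrete TurnBasedEnv (illustration only, not
# from the PantheonRL codebase): the players alternately add 1 or 2 to a
# shared counter, and whoever pushes it to 10 or beyond wins.
class ExampleCountingGame(TurnBasedEnv):
    """Race-to-10 turn-based game; the observation is the current count."""

    def __init__(self):
        obs_space = gym.spaces.Box(0.0, 10.0, shape=(1,), dtype=np.float32)
        act_space = gym.spaces.Discrete(2)  # add 1 or add 2
        super().__init__([obs_space, obs_space], [act_space, act_space])
        self.count = 0

    def _apply(self, action, ego_moved):
        self.count += int(action) + 1
        done = self.count >= 10
        if done:
            rews = (1.0, -1.0) if ego_moved else (-1.0, 1.0)
        else:
            rews = (0.0, 0.0)
        obs = np.array([min(self.count, 10)], dtype=np.float32)
        return obs, rews, done, {}

    def ego_step(self, action):
        # Returns the *partner's* next observation with (ego, alt) rewards.
        return self._apply(action, ego_moved=True)

    def alt_step(self, action):
        # Returns the *ego's* next observation with (ego, alt) rewards.
        return self._apply(action, ego_moved=False)

    def multi_reset(self, egofirst):
        self.count = 0
        return np.array([0.0], dtype=np.float32)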


class SimultaneousEnv(MultiAgentEnv, ABC):
    """
    Base class for all 2-player simultaneous games.

    :param observation_spaces: The observation space for each player
    :param action_spaces: The action space for each player
    :param partners: List of policies to choose from for the partner agent
    """

    def __init__(
        self,
        observation_spaces: List[gym.spaces.Space],
        action_spaces: List[gym.spaces.Space],
        partners: Optional[List[Agent]] = None,
    ):
        partners = [partners] if partners else None
        super().__init__(
            observation_spaces,
            action_spaces,
            ego_ind=0,
            n_players=2,
            partners=partners,
        )

    def n_step(
        self,
        actions: List[np.ndarray],
    ) -> Tuple[
        Tuple[int, ...],
        Tuple[Optional[Observation], ...],
        Tuple[float, ...],
        bool,
        Dict,
    ]:
        (obs0, obs1), rews, done, info = self.multi_step(
            actions[0], actions[1]
        )
        return (
            (0, 1),
            (Observation(obs0), Observation(obs1)),
            rews,
            done,
            info,
        )

    def n_reset(
        self,
    ) -> Tuple[Tuple[int, ...], Tuple[Optional[Observation], ...]]:
        (obs0, obs1) = self.multi_reset()
        return (0, 1), (Observation(obs0), Observation(obs1))

    @abstractmethod
    def multi_step(
        self, ego_action: np.ndarray, alt_action: np.ndarray
    ) -> Tuple[
        Tuple[Optional[np.ndarray], Optional[np.ndarray]],
        Tuple[float, float],
        bool,
        Dict,
    ]:
        """
        Perform the ego-agent's and partner's actions.

        This function returns a tuple of (observations, both rewards, done,
        info). It is called by the `step` function.

        :param ego_action: An action provided by the ego-agent.
        :param alt_action: An action provided by the partner.

        :returns:
            observations: Tuple representing the next observations
            (ego, alt)

            rewards: Tuple representing the rewards of both agents
            (ego, alt)

            done: Whether the episode has ended

            info: Extra information about the environment
        """

    @abstractmethod
    def multi_reset(self) -> Tuple[np.ndarray, np.ndarray]:
        """
        Reset the environment and give the observation of both agents.

        This function is called by the `reset` function.

        :returns: The observations of both agents
        """
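

# Hypothetical example of a concrete SimultaneousEnv (illustration only, not
# from the PantheonRL codebase): a one-shot matching-pennies game where the
# ego is rewarded when both players pick the same action.
class ExampleMatchingPennies(SimultaneousEnv):
    """Both players pick 0 or 1 simultaneously; matching favours the ego."""

    def __init__(self):
        obs_space = gym.spaces.Box(0.0, 1.0, shape=(1,), dtype=np.float32)
        act_space = gym.spaces.Discrete(2)
        super().__init__([obs_space, obs_space], [act_space, act_space])

    def multi_step(self, ego_action, alt_action):
        dummy_obs = np.zeros(1, dtype=np.float32)
        ego_rew = 1.0 if int(ego_action) == int(alt_action) else -1.0
        # Zero-sum, single simultaneous move, so the episode ends here.
        return (dummy_obs, dummy_obs), (ego_rew, -ego_rew), True, {}

    def multi_reset(self):
        dummy_obs = np.zeros(1, dtype=np.float32)
        return dummy_obs, dummy_obs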