Source code for pantheonrl.envs.pettingzoo

"""
Simple wrapper for Petting Zoo environments.
"""

from typing import Tuple, Optional, List, Dict

import numpy as np
import gymnasium as gym

from pantheonrl.common.multiagentenv import MultiAgentEnv
from pantheonrl.common.observation import Observation


[docs] class PettingZooAECWrapper(MultiAgentEnv): """ Wrapper for Petting Zoo AEC environments. :param base_env: base PettingZoo env :param ego_ind: index of the ego agent """ def __init__(self, base_env, ego_ind=0): self.base_env = base_env observation_spaces = [] action_spaces = [] for player_ind in range(base_env.max_num_agents): agent = self.base_env.possible_agents[player_ind] ospace = self.base_env.observation_space(agent) if isinstance(ospace, gym.spaces.dict.Dict): ospace = ospace.spaces["observation"] aspace = self.base_env.action_space(agent) observation_spaces.append(ospace) action_spaces.append(aspace) super().__init__( observation_spaces, action_spaces, ego_ind, base_env.max_num_agents ) self._action_mask = None
[docs] def n_step( self, actions: List[np.ndarray], ) -> Tuple[ Tuple[int, ...], Tuple[Optional[Observation], ...], Tuple[float, ...], bool, Dict, ]: agent = self.base_env.agent_selection act = actions[0] if self._action_mask is not None and not self._action_mask[act]: act = self._action_mask.tolist().index(1) self.base_env.step(act) agent = self.base_env.agent_selection agent_idx = self.base_env.possible_agents.index(agent) obs = self.base_env.observe(agent) if isinstance(obs, dict): self._action_mask = obs["action_mask"] obs = obs["observation"] rewards = [0] * self.n_players for key, val in self.base_env.rewards.items(): rewards[self.base_env.possible_agents.index(key)] = val done = all( self.base_env.terminations[x] or self.base_env.truncations[x] for x in self.base_env.possible_agents ) # print(self.base_env.terminations) # done = all(self.base_env.dones.values()) info = self.base_env.infos[self.base_env.possible_agents[self.ego_ind]] obs = Observation(obs=obs, action_mask=self._action_mask) return (agent_idx,), (obs,), tuple(rewards), done, info
[docs] def n_reset( self, ) -> Tuple[Tuple[int, ...], Tuple[Optional[Observation], ...]]: self.base_env.reset() agent = self.base_env.agent_selection agent_idx = self.base_env.possible_agents.index(agent) obs = self.base_env.observe(agent) if isinstance(obs, dict): self._action_mask = obs["action_mask"] obs = obs["observation"] obs = Observation(obs=obs, action_mask=self._action_mask) return (agent_idx,), (obs,)