"""
This module defines classes and methods related to saving trajectories.
Most of these functions come directly from the HumanCompatibleAI imitation
repo: https://github.com/HumanCompatibleAI/imitation/blob/master/src/
imitation/data/types.py
"""
import dataclasses
from abc import ABC, abstractmethod
from typing import Dict, Mapping, Sequence, Union, TypeVar, overload
import numpy as np
import torch as th
from torch.utils import data as th_data
from torch.utils.data._utils.collate import default_collate
from .util import get_space_size
T = TypeVar("T")
def transitions_collate_fn(
batch: Sequence[Mapping[str, np.ndarray]],
) -> Dict[str, Union[np.ndarray, th.Tensor]]:
"""
Custom `torch.utils.data.DataLoader` collate_fn for `TransitionsMinimal`.
Use this as the `collate_fn` argument to `DataLoader` if using an instance
of `TransitionsMinimal` as the `dataset` argument.
"""
# Each sample is a dict whose values are all arrays, so `default_collate`
# can stack them directly into a batch of tensors.
batch_no_infos = list(batch)
result = default_collate(batch_no_infos)
assert isinstance(result, dict)
return result
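# Illustrative usage sketch (not part of the original module; array shapes are
# hypothetical): wiring the collate function into a DataLoader over a
# `TransitionsMinimal` dataset (defined below).
#
#     transitions = TransitionsMinimal(obs=np.zeros((100, 4)), acts=np.zeros((100, 2)))
#     loader = th_data.DataLoader(
#         transitions,
#         batch_size=32,
#         shuffle=True,
#         collate_fn=transitions_collate_fn,
#     )
#     for batch in loader:
#         # batch is a Dict[str, th.Tensor] with keys "obs" and "acts",
#         # each with a leading batch dimension.
#         pass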
def dataclass_quick_asdict(dataclass_instance) -> dict:
"""
Extract dataclass to items using `dataclasses.fields` + dict comprehension.
This is a quick alternative to `dataclasses.asdict`, which (expensively, and
without documenting it) deep-copies every numpy array value.
See https://stackoverflow.com/a/52229565/1091722.
"""
obj = dataclass_instance
d = {f.name: getattr(obj, f.name) for f in dataclasses.fields(obj)}
return d
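# Illustrative sketch (hypothetical arrays, not from the original source): the
# returned dict holds references to the original field values, not copies.
#
#     trans = TransitionsMinimal(obs=np.zeros((5, 3)), acts=np.zeros((5, 2)))
#     d = dataclass_quick_asdict(trans)
#     assert d["obs"] is trans.obs  # same object, no deep copy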
@dataclasses.dataclass(frozen=True)
class TransitionsMinimal(th_data.Dataset):
"""
This class is modified from HumanCompatibleAI's imitation repo:
https://github.com/HumanCompatibleAI/imitation/blob/master/src/imitation/data/types.py
A Torch-compatible `Dataset` of obs-act transitions.
This class and its subclasses are usually instantiated via
`imitation.data.rollout.flatten_trajectories`.
Indexing an instance `trans` of TransitionsMinimal with an integer `i`
returns the `i`th `Dict[str, np.ndarray]` sample, whose keys are the field
names of each dataclass field and whose values are the ith elements of each
field value.
Slicing returns a possibly empty instance of `TransitionsMinimal` where
each field has been sliced.
"""
obs: np.ndarray
"""
Previous observations. Shape: (batch_size, ) + observation_shape.
The i'th observation `obs[i]` in this array is the observation seen
by the agent when choosing action `acts[i]`. `obs[i]` is not required to
be from the timestep preceding `obs[i+1]`.
"""
acts: np.ndarray
"""Actions. Shape: (batch_size,) + action_shape."""
def __len__(self):
"""Returns number of transitions. Always positive."""
return len(self.obs)
def __post_init__(self):
"""Performs input validation: check shapes & dtypes match docstring.
Also make array values read-only.
"""
for val in vars(self).values():
if isinstance(val, np.ndarray):
val.setflags(write=False)
if len(self.obs) != len(self.acts):
raise ValueError(
"obs and acts must have same number of timesteps: "
f"{len(self.obs)} != {len(self.acts)}"
)
@overload
def __getitem__(self: T, key: slice) -> T:
pass # pragma: no cover
@overload
def __getitem__(self, key: int) -> Dict[str, np.ndarray]:
pass # pragma: no cover
def __getitem__(self, key):
"""See TransitionsMinimal docstring for indexing and slicing semantics."""
d = dataclass_quick_asdict(self)
d_item = {k: v[key] for k, v in d.items()}
if isinstance(key, slice):
# Return type is the same as this dataclass. Replace field value
# with slices.
return dataclasses.replace(self, **d_item)
assert isinstance(key, int)
# Return type is a dictionary whose array values have no batch dimension.
# A dictionary of np.ndarray values is a convenient torch.utils.data.Dataset
# return type, as a torch.utils.data.DataLoader taking in this `Dataset` as
# its first argument knows how to automatically collate several such
# dictionaries into a single dictionary batch with `torch.Tensor` values.
return d_item
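# Illustrative sketch of the indexing and slicing semantics described in the
# class docstring (hypothetical arrays, not from the original source):
#
#     trans = TransitionsMinimal(obs=np.zeros((10, 4)), acts=np.zeros((10, 2)))
#     sample = trans[3]     # Dict[str, np.ndarray]; sample["obs"].shape == (4,)
#     subset = trans[2:5]   # TransitionsMinimal containing 3 transitions
#     empty = trans[0:0]    # possibly empty instance; len(empty) == 0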
def write_transition(self, file):
"""Write transition to a given file."""
length = len(self.obs)
full_list = np.concatenate(
(self.obs.reshape((length, -1)), self.acts.reshape((length, -1))),
axis=1,
)
np.save(file, full_list)
@classmethod
def read_transition(cls, file, obs_space, act_space):
"""Construct TransitionsMinimal from file"""
full_list = np.load(file)
obs_size = get_space_size(obs_space)
act_size = get_space_size(act_space)
obs = full_list[:, :obs_size]
acts = full_list[:, obs_size : obs_size + act_size]
if obs_size == 1:
obs = obs.flatten()
if act_size == 1:
acts = acts.flatten()
return TransitionsMinimal(obs, acts)
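# Illustrative round-trip sketch (hypothetical, not from the original source;
# assumes `obs_space` and `act_space` are the gym spaces the transitions were
# collected from, and `trans` is an existing TransitionsMinimal instance):
#
#     with open("transitions.npy", "wb") as f:
#         trans.write_transition(f)
#     with open("transitions.npy", "rb") as f:
#         restored = TransitionsMinimal.read_transition(f, obs_space, act_space)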
class MultiTransitions(ABC):
"""Base class for all classes that store multiple transitions"""
@abstractmethod
def get_ego_transitions(self) -> TransitionsMinimal:
"""Returns the ego's transitions"""
@abstractmethod
def get_alt_transitions(self) -> TransitionsMinimal:
"""Returns the partner's transitions"""
@abstractmethod
def write_transition(self, file):
"""Write transition to a given file."""
@dataclasses.dataclass(frozen=True)
class TurnBasedTransitions(MultiTransitions):
"""Class that stores transitions from TurnBasedEnv"""
obs: np.ndarray
acts: np.ndarray
flags: np.ndarray
def get_ego_transitions(self) -> TransitionsMinimal:
"""Returns the ego's transitions"""
mask = self.flags % 2 == 0
return TransitionsMinimal(self.obs[mask], self.acts[mask])
def get_alt_transitions(self) -> TransitionsMinimal:
"""Returns the partner's transitions"""
mask = self.flags % 2 == 1
return TransitionsMinimal(self.obs[mask], self.acts[mask])
def write_transition(self, file):
"""Write the transitions (obs, acts, flags) to a given file."""
flags = np.reshape(self.flags, (-1, 1))
length = flags.shape[0]
obs = np.reshape(self.obs, (length, -1))
acts = np.reshape(self.acts, (length, -1))
full_list = np.concatenate((obs, acts, flags), axis=1)
np.save(file, full_list)
@classmethod
def read_transition(cls, file, obs_space, act_space):
"""Construct TurnBasedTransitions from file"""
full_list = np.load(file)
obs_size = get_space_size(obs_space)
act_size = get_space_size(act_space)
obs = full_list[:, :obs_size]
acts = full_list[:, obs_size : obs_size + act_size]
flags = full_list[:, -1]
if obs_size == 1:
obs = obs.flatten()
if act_size == 1:
acts = acts.flatten()
return TurnBasedTransitions(obs, acts, flags)
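# Illustrative sketch (hypothetical arrays, not from the original source;
# assumes, per the masks above, that even flag values mark the ego's turns and
# odd values the partner's):
#
#     turns = TurnBasedTransitions(
#         obs=np.arange(8).reshape(4, 2),
#         acts=np.array([0, 1, 0, 1]),
#         flags=np.array([0, 1, 2, 3]),
#     )
#     ego = turns.get_ego_transitions()   # rows where flags % 2 == 0
#     alt = turns.get_alt_transitions()   # rows where flags % 2 == 1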
@dataclasses.dataclass(frozen=True)
class SimultaneousTransitions(MultiTransitions):
"""Class that stores transitions from SimultaneousEnv"""
egoobs: np.ndarray
egoacts: np.ndarray
altobs: np.ndarray
altacts: np.ndarray
flags: np.ndarray
def get_ego_transitions(self) -> TransitionsMinimal:
"""Returns the ego's transitions"""
return TransitionsMinimal(self.egoobs, self.egoacts)
def get_alt_transitions(self) -> TransitionsMinimal:
"""Returns the partner's transitions"""
return TransitionsMinimal(self.altobs, self.altacts)
def write_transition(self, file):
"""Write the transitions for both agents, plus flags, to a given file."""
flags = np.reshape(self.flags, (-1, 1))
length = flags.shape[0]
egoobs = np.reshape(self.egoobs, (length, -1))
egoacts = np.reshape(self.egoacts, (length, -1))
altobs = np.reshape(self.altobs, (length, -1))
altacts = np.reshape(self.altacts, (length, -1))
full_list = np.concatenate(
(egoobs, egoacts, altobs, altacts, flags), axis=1
)
np.save(file, full_list)
@classmethod
def read_transition(cls, file, obs_space, act_space):
"""Construct SimultaneousTransitions from file"""
full_list = np.load(file)
obs_size = get_space_size(obs_space)
act_size = get_space_size(act_space)
egoobs = full_list[:, :obs_size]
egoacts = full_list[:, obs_size : (obs_size + act_size)]
altobs = full_list[
:, (obs_size + act_size) : (2 * obs_size + act_size)
]
altacts = full_list[:, (2 * obs_size + act_size) : -1]
flags = full_list[:, -1]
if obs_size == 1:
egoobs = egoobs.flatten()
altobs = altobs.flatten()
if act_size == 1:
egoacts = egoacts.flatten()
altacts = altacts.flatten()
return SimultaneousTransitions(egoobs, egoacts, altobs, altacts, flags)
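# Illustrative round-trip sketch (hypothetical, not from the original source;
# each saved row is laid out as [egoobs, egoacts, altobs, altacts, flag]):
#
#     with open("sim_transitions.npy", "wb") as f:
#         sim_transitions.write_transition(f)
#     with open("sim_transitions.npy", "rb") as f:
#         restored = SimultaneousTransitions.read_transition(f, obs_space, act_space)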