pantheonrl.common.trajsaver

This module defines classes and methods related to saving trajectories.

Most of these functions come directly from the HumanCompatibleAI imitation repo: https://github.com/HumanCompatibleAI/imitation/blob/master/src/imitation/data/types.py

Functions

dataclass_quick_asdict

Extract a dataclass's fields into a dict using dataclasses.fields and a dict comprehension.

transitions_collate_fn

Custom torch.utils.data.DataLoader collate_fn for TransitionsMinimal.
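The two helpers above can be sketched as follows. This is a minimal illustration, not this module's source. `dataclass_quick_asdict` builds a shallow dict, so field values such as large NumPy arrays are returned by reference; `dataclasses.asdict` would instead recursively copy them. The collate function stacks per-sample arrays along a new batch axis with NumPy, standing in for PyTorch's default collation so the example runs without PyTorch. The field names `obs`, `acts`, and `infos` are assumed from imitation's `TransitionsMinimal`. Keeping `infos` as a plain list of dicts, rather than collating it, is also an assumption.

```python
import dataclasses
from typing import Any, Dict, Mapping, Sequence

import numpy as np


def dataclass_quick_asdict(obj: Any) -> Dict[str, Any]:
    """Shallow dataclass -> dict conversion.

    Unlike dataclasses.asdict, field values are not copied, so large
    NumPy arrays are returned by reference.
    """
    return {f.name: getattr(obj, f.name) for f in dataclasses.fields(obj)}


def transitions_collate_fn(batch: Sequence[Mapping[str, Any]]) -> Dict[str, Any]:
    """Collate a list of per-transition dicts into one batch dict (sketch).

    Array-valued keys are stacked along a new leading batch axis with
    NumPy (PyTorch's default collation would produce tensors instead).
    The assumed "infos" key is kept as a list of dicts, because merging
    arbitrary info dicts key by key is not meaningful.
    """
    keys = [k for k in batch[0] if k != "infos"]
    result: Dict[str, Any] = {
        k: np.stack([sample[k] for sample in batch]) for k in keys
    }
    if "infos" in batch[0]:
        result["infos"] = [sample["infos"] for sample in batch]
    return result


@dataclasses.dataclass(frozen=True)
class _Example:
    obs: np.ndarray
    acts: np.ndarray


ex = _Example(obs=np.zeros((4, 3)), acts=np.arange(4))
d = dataclass_quick_asdict(ex)
assert d["obs"] is ex.obs  # same object: no copy was made

# Two single-transition samples, as a Dataset's __getitem__ might return them.
batch = [
    {"obs": np.ones(3), "acts": np.int64(i), "infos": {"step": i}}
    for i in range(2)
]
out = transitions_collate_fn(batch)
# out["obs"] has shape (2, 3); out["infos"] is a list of the two dicts.
```

In a real training loop such a function would be passed as the `collate_fn` argument of `torch.utils.data.DataLoader`.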

Classes

MultiTransitions

Base class for all classes that store multiple transitions.

SimultaneousTransitions

Class that stores transitions from SimultaneousEnv.

TransitionsMinimal

This class is modified from HumanCompatibleAI's imitation repo: https://github.com/HumanCompatibleAI/imitation/blob/master/src/imitation/data/types.py

TurnBasedTransitions

Class that stores transitions from TurnBasedEnv.
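Because TransitionsMinimal is adapted from imitation, a container in that style can be sketched as a frozen dataclass of equal-length arrays. This is a hedged sketch, not PantheonRL's class. The fields `obs`, `acts`, and `infos` are assumed from imitation's version. imitation's class also subclasses `torch.utils.data.Dataset`; that base is omitted here so the sketch runs without PyTorch. Integer indexing returns one transition as a dict, which is the per-sample shape a DataLoader `collate_fn` receives. Slicing returns a smaller container of the same type.

```python
import dataclasses
from typing import Any, Dict, Union

import numpy as np


@dataclasses.dataclass(frozen=True)
class TransitionsMinimal:
    """Sketch of a minimal transitions container (field names assumed).

    Each field is an array whose first axis indexes transitions.
    """

    obs: np.ndarray
    acts: np.ndarray
    infos: np.ndarray  # object array holding one info dict per step

    def __post_init__(self) -> None:
        # All fields must describe the same number of transitions.
        if not (len(self.obs) == len(self.acts) == len(self.infos)):
            raise ValueError("obs, acts and infos must have equal length")

    def __len__(self) -> int:
        return len(self.obs)

    def __getitem__(
        self, key: Union[int, slice]
    ) -> Union[Dict[str, Any], "TransitionsMinimal"]:
        # Shallow field -> value mapping (same idea as dataclass_quick_asdict).
        fields = {f.name: getattr(self, f.name) for f in dataclasses.fields(self)}
        items = {name: value[key] for name, value in fields.items()}
        if isinstance(key, slice):
            # A slice yields a smaller container of the same type.
            return dataclasses.replace(self, **items)
        # An integer index yields a single transition as a dict.
        return items


infos = np.array([{"t": i} for i in range(5)], dtype=object)
data = TransitionsMinimal(obs=np.zeros((5, 2)), acts=np.arange(5), infos=infos)
first = data[0]  # dict with keys "obs", "acts", "infos"
head = data[:3]  # TransitionsMinimal holding the first 3 transitions
```

Storing `infos` as an object array lets it be sliced with the same `value[key]` expression as the numeric fields.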