Source code for pantheonrl.algos.bc

"""
Behavioural Cloning (BC).
Trains a policy by applying supervised learning to a fixed dataset of
(observation, action) pairs generated by some expert demonstrator.

https://github.com/HumanCompatibleAI/imitation/blob/master/src/imitation/algorithms/bc.py
"""

import contextlib
from dataclasses import dataclass
from typing import (
    Any,
    Callable,
    Dict,
    Iterable,
    Mapping,
    Optional,
    Tuple,
    Type,
    Union,
)

import gymnasium as gym
import numpy as np
import torch
import torch.utils.data as th_data
from torch.optim.optimizer import Optimizer
from torch.optim.adam import Adam
import tqdm.autonotebook as tqdm
from stable_baselines3.common import policies, utils

from pantheonrl.common.trajsaver import (
    TransitionsMinimal,
    transitions_collate_fn,
)
from pantheonrl.common.util import FeedForward32Policy

log = utils.configure_logger(verbose=0)  # change to 1 for debugging


@dataclass
class BCShell:
    """Shell class for BC policy"""

    policy: FeedForward32Policy

    def get_policy(self):
        """Get the current policy"""
        return self.policy

    def set_policy(self, new_policy):
        """Set the BC policy"""
        self.policy = new_policy
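
# Usage sketch (illustrative, not part of the module): `BCShell` just
# wraps a trained policy so other components can read or swap it. The
# `trained_policy` and `new_policy` names below are hypothetical.
#
#   shell = BCShell(policy=trained_policy)
#   current = shell.get_policy()
#   shell.set_policy(new_policy)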
def reconstruct_policy(
    policy_path: str,
    device: Union[torch.device, str] = "auto",
) -> policies.BasePolicy:
    """Reconstruct a saved policy.

    Args:
        policy_path: path where `.save_policy()` has been run.
        device: device on which to load the policy.

    Returns:
        policy: policy with reloaded weights.
    """
    policy = torch.load(policy_path, map_location=utils.get_device(device))
    assert isinstance(policy, policies.BasePolicy)
    return policy
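
# Round-trip sketch (illustrative): pairing `BC.save_policy` (defined
# below) with `reconstruct_policy`. The `bc_trainer` instance and the
# "bc_policy.pt" filename are hypothetical.
#
#   bc_trainer.save_policy("bc_policy.pt")
#   policy = reconstruct_policy("bc_policy.pt", device="cpu")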
class ConstantLRSchedule:
    """A callable that returns a constant learning rate."""

    def __init__(self, lr: float = 1e-3):
        """
        Args:
            lr: the constant learning rate that calls to this object will
                return.
        """
        self.lr = lr

    def __call__(self, _):
        """Returns the constant learning rate."""
        return self.lr

    def set_lr(self, new_lr):
        """Sets a new learning rate."""
        self.lr = new_lr
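
# Usage sketch (illustrative): Stable-Baselines3 policies expect an
# `lr_schedule` callable of remaining training progress; this schedule
# ignores its argument and always returns the same rate.
#
#   schedule = ConstantLRSchedule(lr=3e-4)
#   schedule(0.5)          # -> 3e-4, regardless of progress
#   schedule.set_lr(1e-4)
#   schedule(0.0)          # -> 1e-4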
class EpochOrBatchIteratorWithProgress:
    """
    Wraps DataLoader so that all BC batches can be processed in one
    for-loop. Also uses `tqdm` to show progress in stdout.

    Args:
        data_loader: An iterable over data dicts, as used in `BC`.
        n_epochs: The number of epochs to iterate through in one call to
            `__iter__`. Exactly one of `n_epochs` and `n_batches` should
            be provided.
        n_batches: The number of batches to iterate through in one call
            to `__iter__`. Exactly one of `n_epochs` and `n_batches`
            should be provided.
        on_epoch_end: A callback function without parameters to be called
            at the end of every epoch.
        on_batch_end: A callback function without parameters to be called
            at the end of every batch.
    """

    def __init__(
        self,
        data_loader: Iterable[dict],
        n_epochs: Optional[int] = None,
        n_batches: Optional[int] = None,
        on_epoch_end: Optional[Callable[[], None]] = None,
        on_batch_end: Optional[Callable[[], None]] = None,
    ):
        if n_epochs is not None and n_batches is None:
            self.use_epochs = True
        elif n_epochs is None and n_batches is not None:
            self.use_epochs = False
        else:
            raise ValueError(
                "Must provide exactly one of `n_epochs` and `n_batches` "
                "arguments."
            )
        self.data_loader = data_loader
        self.n_epochs = n_epochs
        self.n_batches = n_batches
        self.on_epoch_end = on_epoch_end
        self.on_batch_end = on_batch_end

    def __iter__(self) -> Iterable[Tuple[dict, dict]]:
        """Yields batches while updating tqdm display to show progress."""
        samples_so_far = 0
        epoch_num = 0
        batch_num = 0
        batch_suffix = epoch_suffix = ""
        if self.use_epochs:
            display = tqdm.tqdm(total=self.n_epochs)
            epoch_suffix = f"/{self.n_epochs}"
        else:  # Use batches.
            display = tqdm.tqdm(total=self.n_batches)
            batch_suffix = f"/{self.n_batches}"

        def update_desc():
            display.set_description(
                f"batch: {batch_num}{batch_suffix} "
                f"epoch: {epoch_num}{epoch_suffix}"
            )

        with contextlib.closing(display):
            while True:
                update_desc()
                got_data_on_epoch = False
                for batch in self.data_loader:
                    got_data_on_epoch = True
                    batch_num += 1
                    batch_size = len(batch["obs"])
                    assert batch_size > 0
                    samples_so_far += batch_size
                    stats = {
                        "epoch_num": epoch_num,
                        "batch_num": batch_num,
                        "samples_so_far": samples_so_far,
                    }
                    yield batch, stats
                    if self.on_batch_end is not None:
                        self.on_batch_end()
                    if not self.use_epochs:
                        update_desc()
                        display.update(1)
                        assert self.n_batches is not None
                        if batch_num >= self.n_batches:
                            return
                if not got_data_on_epoch:
                    raise AssertionError(
                        f"Data loader returned no data after "
                        f"{batch_num} batches, during epoch "
                        f"{epoch_num} -- did it reset correctly?"
                    )
                epoch_num += 1
                if self.on_epoch_end is not None:
                    self.on_epoch_end()
                if self.use_epochs:
                    update_desc()
                    display.update(1)
                    assert self.n_epochs is not None
                    if epoch_num >= self.n_epochs:
                        return

    def set_data_loader(self, new_data_loader):
        """Set the data loader to a new value."""
        self.data_loader = new_data_loader
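
# Usage sketch (illustrative): iterating a loader for a fixed number of
# batches with a tqdm progress display. `loader` is a hypothetical
# torch.utils.data.DataLoader yielding dicts with "obs" and "acts" keys.
#
#   it = EpochOrBatchIteratorWithProgress(loader, n_batches=500)
#   for batch, stats in it:
#       ...  # stats carries epoch_num, batch_num, samples_so_far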
class BC:
    """
    Behavioral cloning (BC).

    Recovers a policy via supervised learning on observation-action
    Tensor pairs, sampled from a Torch DataLoader or any Iterator that
    ducktypes `torch.utils.data.DataLoader`.

    Args:
        observation_space: the observation space of the environment.
        action_space: the action space of the environment.
        policy_class: used to instantiate imitation policy.
        policy_kwargs: keyword arguments passed to policy's constructor.
        expert_data: If not None, then immediately call
            `self.set_expert_data_loader(expert_data)` during
            initialization.
        optimizer_cls: optimiser to use for supervised training.
        optimizer_kwargs: keyword arguments, excluding learning rate and
            weight decay, for optimiser construction.
        ent_weight: scaling applied to the policy's entropy regularization.
        l2_weight: scaling applied to the policy's L2 regularization.
        device: name/identity of device to place policy on.
    """

    DEFAULT_BATCH_SIZE: int = 32
    """
    Default batch size for DataLoader automatically constructed from
    Transitions. See `set_expert_data_loader()`.
    """

    def __init__(
        self,
        observation_space: gym.Space,
        action_space: gym.Space,
        *,
        policy_class: Type[policies.BasePolicy] = FeedForward32Policy,
        policy_kwargs: Optional[Mapping[str, Any]] = None,
        expert_data: Union[Iterable[Mapping], TransitionsMinimal, None] = None,
        optimizer_cls: Type[Optimizer] = Adam,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
        ent_weight: float = 1e-3,
        l2_weight: float = 0.0,
        device: Union[str, torch.device] = "auto",
    ):
        if optimizer_kwargs:
            if "weight_decay" in optimizer_kwargs:
                raise ValueError(
                    "Use the parameter l2_weight instead of weight_decay."
                )
        self.action_space = action_space
        self.observation_space = observation_space
        self.policy_class = policy_class
        self.device = utils.get_device(device)
        self.policy_kwargs = {
            "observation_space": self.observation_space,
            "action_space": self.action_space,
            "lr_schedule": ConstantLRSchedule(),
        }
        self.policy_kwargs.update(policy_kwargs or {})
        self.policy = self.policy_class(**self.policy_kwargs).to(
            self.device
        )  # pytype: disable=not-instantiable
        optimizer_kwargs = optimizer_kwargs or {}
        self.optimizer = optimizer_cls(
            self.policy.parameters(), **optimizer_kwargs
        )
        self.expert_data_loader: Optional[Iterable[Mapping]] = None
        self.ent_weight = ent_weight
        self.l2_weight = l2_weight
        if expert_data is not None:
            self.set_expert_data_loader(expert_data)
    def set_expert_data_loader(
        self,
        expert_data: Union[Iterable[Mapping], TransitionsMinimal],
    ) -> None:
        """Set the expert data loader, which yields batches of obs-act pairs.

        Changing the expert data loader on-demand is useful for DAgger and
        other interactive algorithms.

        Args:
            expert_data: Either a Torch `DataLoader`, any other iterator
                that yields dictionaries containing "obs" and "acts"
                Tensors or Numpy arrays, or a `TransitionsMinimal`
                instance. If this is a `TransitionsMinimal` instance, then
                it is automatically converted into a shuffled `DataLoader`
                with batch size `BC.DEFAULT_BATCH_SIZE`.
        """
        if isinstance(expert_data, TransitionsMinimal):
            self.expert_data_loader = th_data.DataLoader(
                expert_data,
                shuffle=True,
                batch_size=BC.DEFAULT_BATCH_SIZE,
                collate_fn=transitions_collate_fn,
            )
        else:
            self.expert_data_loader = expert_data
    def _calculate_loss(
        self,
        obs: Union[torch.Tensor, np.ndarray],
        acts: Union[torch.Tensor, np.ndarray],
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """
        Calculate the supervised learning loss used to train the
        behavioral clone.

        Args:
            obs: The observations seen by the expert. If this is a Tensor,
                then it is detached from the graph before the loss is
                calculated.
            acts: The actions taken by the expert. If this is a Tensor,
                then it is detached from the graph before the loss is
                calculated.

        Returns:
            loss: The supervised learning loss for the behavioral clone
                to optimize.
            stats_dict: Statistics about the learning process to be
                logged.
        """
        obs = torch.as_tensor(obs, device=self.device).detach()
        acts = torch.as_tensor(acts, device=self.device).detach()

        _, log_prob, entropy = self.policy.evaluate_actions(obs, acts)
        prob_true_act = log_prob.exp().mean()
        log_prob = log_prob.mean()
        entropy = entropy.mean()

        l2_norms = [w.square().sum() for w in self.policy.parameters()]
        # divide by 2 to cancel with gradient of square
        l2_norm = sum(l2_norms) / 2

        ent_loss = -self.ent_weight * entropy
        neglogp = -log_prob
        l2_loss = self.l2_weight * l2_norm
        loss = neglogp + ent_loss + l2_loss

        stats_dict = {
            "neglogp": neglogp.item(),
            "loss": loss.item(),
            "entropy": entropy.item(),
            "ent_loss": ent_loss.item(),
            "prob_true_act": prob_true_act.item(),
            "l2_norm": l2_norm.item(),
            "l2_loss": l2_loss.item(),
        }
        return loss, stats_dict
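
    # For reference, the loss above decomposes as
    #     loss = -E[log pi(a|s)]                    (behavioral-cloning term)
    #            - ent_weight * E[H(pi(.|s))]       (entropy regularization)
    #            + l2_weight * sum_w ||w||^2 / 2    (weight decay)
    # so with ent_weight = l2_weight = 0 it reduces to the plain negative
    # log-likelihood of the expert actions under the current policy.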
    def train(
        self,
        *,
        n_epochs: Optional[int] = None,
        n_batches: Optional[int] = None,
        on_epoch_end: Optional[Callable[[], None]] = None,
        on_batch_end: Optional[Callable[[], None]] = None,
        log_interval: int = 100,
    ):
        """Train with supervised learning for some number of epochs.

        Here an 'epoch' is just a complete pass through the expert data
        loader, as set by `self.set_expert_data_loader()`.

        Args:
            n_epochs: Number of complete passes made through expert data
                before ending training. Provide exactly one of `n_epochs`
                and `n_batches`.
            n_batches: Number of batches loaded from dataset before ending
                training. Provide exactly one of `n_epochs` and
                `n_batches`.
            on_epoch_end: Optional callback with no parameters to run at
                the end of each epoch.
            on_batch_end: Optional callback with no parameters to run at
                the end of each batch.
            log_interval: Log stats after every `log_interval` batches.
        """
        it = EpochOrBatchIteratorWithProgress(
            self.expert_data_loader,
            n_epochs=n_epochs,
            n_batches=n_batches,
            on_epoch_end=on_epoch_end,
            on_batch_end=on_batch_end,
        )
        batch_num = 0
        for batch, stats_dict_it in it:
            loss, stats_dict_loss = self._calculate_loss(
                batch["obs"], batch["acts"]
            )
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            if batch_num % log_interval == 0:
                for stats in [stats_dict_it, stats_dict_loss]:
                    for k, v in stats.items():
                        log.record(k, v)
                log.dump(batch_num)
            batch_num += 1
    def save_policy(self, policy_path: str) -> None:
        """Save policy to a path. Can be reloaded by `.reconstruct_policy()`.

        Args:
            policy_path: path to save policy to.
        """
        torch.save(self.policy, policy_path)
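
# End-to-end sketch (illustrative, not part of the module), assuming a
# Gymnasium environment `env` and `TransitionsMinimal` datasets `demos`
# and `new_demos` of expert (obs, act) pairs; all three names are
# hypothetical.
#
#   bc_trainer = BC(
#       observation_space=env.observation_space,
#       action_space=env.action_space,
#       expert_data=demos,
#   )
#   bc_trainer.train(n_epochs=1)
#   # For DAgger-style loops, swap in fresh demonstrations and keep training:
#   bc_trainer.set_expert_data_loader(new_demos)
#   bc_trainer.train(n_batches=500)
#   bc_trainer.save_policy("bc_policy.pt")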