pantheonrl.algos.adap.util

A collection of helper functions for ADAP.

Functions

get_categorical

Samples from a categorical distribution.
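Below is a minimal NumPy sketch of such a sampler. It assumes each context is encoded as a one-hot vector of length `ctx_size` and that `num` contexts are drawn at once. The `ctx_size` and `num` names come from the `get_natural_number` description. The one-hot encoding, the function name `sample_categorical_contexts`, and the `rng` argument are hypothetical and are not the library's API.

```python
import numpy as np


def sample_categorical_contexts(ctx_size: int, num: int, rng=None) -> np.ndarray:
    """Hypothetical sketch: draw `num` one-hot context vectors of length `ctx_size`."""
    rng = np.random.default_rng() if rng is None else rng
    ctxs = np.zeros((num, ctx_size))
    # Pick one category per row uniformly from {0, ..., ctx_size - 1} and set it to 1.
    ctxs[np.arange(num), rng.integers(0, ctx_size, size=num)] = 1.0
    return ctxs


# Usage: five one-hot contexts over three categories; the result has shape (5, 3).
ctxs = sample_categorical_contexts(3, 5, np.random.default_rng(0))
```

Each row has exactly one active entry, so the rows can be fed directly to a policy network that expects a fixed-length context input.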

get_context_kl_loss

Computes the context KL loss used by ADAP.
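The one-line description does not define the loss. The sketch below assumes the loss follows the ADAP idea of penalizing similarity between the action distributions a context-conditioned policy produces under different sampled contexts on the same states. Under that assumption, exp(-KL) can be averaged over states and over all pairs of contexts. The function names, the categorical-action assumption, and the exp(-KL) form are illustrative assumptions, not the library's implementation.

```python
import itertools

import numpy as np


def categorical_kl(p: np.ndarray, q: np.ndarray) -> np.ndarray:
    """KL(p || q) over the last axis; assumes strictly positive probabilities (e.g. softmax outputs)."""
    return np.sum(p * (np.log(p) - np.log(q)), axis=-1)


def context_similarity_loss(action_probs_per_context: list) -> float:
    """Hypothetical sketch of a KL-based context loss.

    action_probs_per_context[k] holds action probabilities of shape
    (num_states, num_actions) produced under sampled context k for the
    same batch of states. For each pair of contexts, exp(-KL) is averaged
    over states; the result is the mean over all pairs. It equals 1 when
    every context yields identical behaviour and approaches 0 as they diverge.
    """
    pair_terms = [
        np.mean(np.exp(-categorical_kl(a, b)))
        for a, b in itertools.combinations(action_probs_per_context, 2)
    ]
    return float(np.mean(pair_terms))
```

Under these assumptions, minimizing such a term would push the policy to act differently across contexts, since identical behaviour scores the maximum value of 1.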

get_l2_sphere

Samples from the L2 sphere.
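A common way to sample uniformly from the unit L2 sphere is to normalize standard Gaussian draws, and the sketch below uses that technique. This page does not say whether the library uses the same construction. The function name `sample_l2_sphere` and the `rng` argument are hypothetical.

```python
import numpy as np


def sample_l2_sphere(ctx_size: int, num: int, rng=None) -> np.ndarray:
    """Hypothetical sketch: `num` points uniformly distributed on the unit sphere in R^ctx_size."""
    rng = np.random.default_rng() if rng is None else rng
    # Isotropic Gaussian samples have no preferred direction.
    x = rng.standard_normal((num, ctx_size))
    # Dividing each row by its L2 norm therefore gives a uniform point on the sphere.
    return x / np.linalg.norm(x, axis=1, keepdims=True)


ctxs = sample_l2_sphere(4, 8, np.random.default_rng(0))
```

Drawing from a cube and then normalizing also lands on the sphere, but it over-represents directions toward the cube's corners. This is why the Gaussian construction is the usual choice when uniformity matters.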

get_natural_number

Returns a context vector of shape (num, 1) containing natural numbers in the range [0, ctx_size].
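The sketch below follows the description's closed range [0, ctx_size] literally by passing `endpoint=True` to NumPy's `Generator.integers`. If the library instead draws from the half-open range [0, ctx_size), drop that flag. The function name `sample_natural_numbers` and the `rng` argument are hypothetical.

```python
import numpy as np


def sample_natural_numbers(ctx_size: int, num: int, rng=None) -> np.ndarray:
    """Hypothetical sketch: integer contexts of shape (num, 1) drawn from [0, ctx_size]."""
    rng = np.random.default_rng() if rng is None else rng
    # endpoint=True makes the upper bound inclusive, matching the closed range [0, ctx_size].
    return rng.integers(0, ctx_size, size=(num, 1), endpoint=True)


ctxs = sample_natural_numbers(5, 100, np.random.default_rng(0))
```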

get_positive_square

Samples from the square whose axes each range from 0 to 1.
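The following is a minimal NumPy sketch, assuming each of the `ctx_size` coordinates is drawn independently and uniformly from [0, 1). The function name `sample_positive_square` and the `rng` argument are hypothetical.

```python
import numpy as np


def sample_positive_square(ctx_size: int, num: int, rng=None) -> np.ndarray:
    """Hypothetical sketch: `num` points drawn uniformly from [0, 1)^ctx_size."""
    rng = np.random.default_rng() if rng is None else rng
    # Generator.random samples each coordinate independently from [0.0, 1.0).
    return rng.random((num, ctx_size))


ctxs = sample_positive_square(3, 10, np.random.default_rng(0))
```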

get_unit_square

Samples from the unit square centered at 0.
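The description does not fix the square's extent. The NumPy sketch below assumes each coordinate is uniform on [-1, 1), which gives half-width 1 around the origin. If a square with side length 1 is intended instead, use bounds of -0.5 and 0.5. The function name `sample_centered_square` and the `rng` argument are hypothetical.

```python
import numpy as np


def sample_centered_square(ctx_size: int, num: int, rng=None) -> np.ndarray:
    """Hypothetical sketch: `num` points uniform on [-1, 1)^ctx_size, centered at the origin."""
    rng = np.random.default_rng() if rng is None else rng
    # Assumed bounds [-1, 1); change to (-0.5, 0.5) for a square with side length 1.
    return rng.uniform(-1.0, 1.0, size=(num, ctx_size))


ctxs = sample_centered_square(2, 20, np.random.default_rng(0))
```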

kl_divergence

Wrapper for the PyTorch implementation of the full-form KL divergence.

:param dist_true: the p distribution
:param dist_pred: the q distribution
:return: KL(dist_true || dist_pred)
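The wrapper delegates to PyTorch, where `torch.distributions.kl.kl_divergence(p, q)` is the corresponding function. To make the returned quantity concrete, the NumPy sketch below computes KL(p || q) = sum_i p_i log(p_i / q_i) for categorical distributions. It does not reproduce the wrapper itself, and `categorical_kl` is a hypothetical name.

```python
import numpy as np


def categorical_kl(p, q) -> np.ndarray:
    """Hypothetical sketch: KL(p || q) = sum_i p_i * log(p_i / q_i) over the last axis.

    Terms with p_i == 0 contribute 0 by convention; q must be > 0 wherever p > 0.
    """
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    # Evaluate p / q only where p > 0; elsewhere keep ratio 1 so the term is 0 * log(1) = 0.
    ratio = np.divide(p, q, out=np.ones_like(p), where=p > 0)
    return np.sum(p * np.log(ratio), axis=-1)


kl = categorical_kl([1.0, 0.0], [0.5, 0.5])  # ≈ 0.6931 (= log 2)
```

KL divergence is generally asymmetric, so the argument order matters. KL(dist_true || dist_pred) penalizes dist_pred for assigning low probability where dist_true places mass.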