improve docstring

Cadene 2024-04-24 21:40:42 +00:00
parent 0ec28bf71a
commit 6d56bcb5de
1 changed file with 38 additions and 45 deletions

@@ -4,27 +4,24 @@ from torch import nn
 def create_stats_buffers(shapes, modes, stats=None):
     """
-    This function generates buffers to store the mean and standard deviation, or minimum and maximum values,
-    used for normalizing tensors. The mode of normalization is determined by the `modes` dictionary, which can
-    be either "mean_std" (for mean and standard deviation) or "min_max" (for minimum and maximum). These buffers
-    are created as PyTorch nn.ParameterDict objects with nn.Parameters set to not require gradients, suitable
-    for normalization purposes.
-    If the provided `shapes` contain keys related to images, the shape is adjusted to be invariant to height
-    and width, assuming a channel-first (c, h, w) format.
+    Create buffers per modality (e.g. "observation.image", "action") containing their mean, std, min, max statistics.
     Parameters:
-        shapes (dict): A dictionary where keys represent tensor identifiers and values represent the shapes of those tensors.
-        modes (dict): A dictionary specifying the normalization mode for each key in `shapes`. Valid modes are "mean_std" or "min_max".
-        stats (dict, optional): A dictionary containing pre-defined statistics for normalization. It can contain 'mean' and 'std' for
-            "mean_std" mode, or 'min' and 'max' for "min_max" mode. If provided, these statistics will overwrite the default buffers.
-            It's expected for training the model for the first time. If not provided, the default buffers are supposed to be overriden
-            by a call to `policy.load_state_dict(state_dict)`. It's useful for loading a pretrained model for finetuning or evaluation,
-            without requiring to initialize the dataset used to train the model just to acess the `stats`.
+        shapes (dict): A dictionary where keys are modalities (e.g. "observation.image") and values are their shapes (e.g. `[3,96,96]`).
+            These shapes are used to create the tensor buffer containing mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape is adjusted to be invariant to height
+            and width, assuming a channel-first (c, h, w) format.
+        modes (dict): A dictionary where keys are modalities (e.g. "observation.image") and values are their normalization modes among:
+            - "mean_std": subtract the mean and divide by the standard deviation.
+            - "min_max": map to [-1, 1] range.
+        stats (dict, optional): A dictionary where keys are modalities (e.g. "observation.image") and values are dictionaries of statistic types and their values
+            (e.g. `{"mean": torch.randn(3,1,1), "std": torch.randn(3,1,1)}`). If provided, as expected for training the model for the first time,
+            these statistics will overwrite the default buffers. If not provided, as expected for finetuning or evaluation, the default buffers should
+            be overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the dataset is not needed to get the stats, since
+            they are already in the policy state_dict.
     Returns:
-        dict: A dictionary where keys match the `modes` and `shapes` keys, and values are nn.ParameterDict objects containing
-            the appropriate buffers for normalization.
+        dict: A dictionary where keys are modalities and values are `nn.ParameterDict` objects containing `nn.Parameter` entries set to
+            `requires_grad=False`, so that they are not updated during backpropagation.
     """
     stats_buffers = {}
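
Note on the docstring above: the body of `create_stats_buffers` is not visible in this hunk, so the following is only a minimal sketch of the behavior the docstring describes, not the actual implementation. The function name `sketch_stats_buffers`, the `"image" in key` test, and the `inf` placeholder values are assumptions made for illustration.

```python
import torch
from torch import nn


def sketch_stats_buffers(shapes, modes, stats=None):
    """Hypothetical illustration of the documented behavior, not the real function."""
    buffers = {}
    for key, mode in modes.items():
        if mode not in ("mean_std", "min_max"):
            raise ValueError(f"Unknown normalization mode: {mode}")
        shape = tuple(shapes[key])
        if "image" in key:
            # Assumption: image shapes are channel-first (c, h, w), so one
            # statistic per channel is kept and broadcast over height and width.
            shape = (shape[0], 1, 1)
        names = ("mean", "std") if mode == "mean_std" else ("min", "max")
        params = {}
        for name in names:
            if stats is not None:
                value = torch.as_tensor(stats[key][name], dtype=torch.float32).clone()
            else:
                # Assumed placeholder, meant to be overwritten by load_state_dict.
                value = torch.full(shape, float("inf"), dtype=torch.float32)
            # requires_grad=False keeps the statistics out of backpropagation.
            params[name] = nn.Parameter(value, requires_grad=False)
        buffers[key] = nn.ParameterDict(params)
    return buffers


buffers = sketch_stats_buffers(
    shapes={"observation.image": [3, 96, 96], "action": [2]},
    modes={"observation.image": "mean_std", "action": "min_max"},
)
```

Because the entries are `nn.Parameter` objects, they appear in the module's `state_dict` once the `nn.ParameterDict` is attached to a module, which is what lets `policy.load_state_dict(state_dict)` restore the statistics without touching the dataset.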
@@ -79,22 +76,20 @@ def create_stats_buffers(shapes, modes, stats=None):
 class Normalize(nn.Module):
     """
-    A PyTorch module for normalizing data based on predefined statistics.
-    The class is initialized with a set of shapes, modes, and optional pre-defined statistics. It creates buffers for normalization based
-    on these inputs, which are then used to adjust data during the forward pass. The normalization process operates on a batch of data,
-    with different keys in the batch being normalized according to the specified modes. The following normalization modes are supported:
-    - "mean_std": Normalizes data using the mean and standard deviation.
-    - "min_max": Normalizes data to a [0, 1] range and then to a [-1, 1] range.
+    Normalizes the input data (e.g. "observation.image") for more stable and faster convergence during training.
     Parameters:
-        shapes (dict): A dictionary where keys represent tensor identifiers and values represent the shapes of those tensors.
-        modes (dict): A dictionary indicating the normalization mode for each tensor key. Valid modes are "mean_std" or "min_max".
-        stats (dict, optional): A dictionary containing pre-defined statistics for normalization. It can contain 'mean' and 'std' for
-            "mean_std" mode, or 'min' and 'max' for "min_max" mode. If provided, these statistics will overwrite the default buffers.
-            It's expected for training the model for the first time. If not provided, the default buffers are supposed to be overriden
-            by a call to `policy.load_state_dict(state_dict)`. It's useful for loading a pretrained model for finetuning or evaluation,
-            without requiring to initialize the dataset used to train the model just to acess the `stats`.
+        shapes (dict): A dictionary where keys are input modalities (e.g. "observation.image") and values are their shapes (e.g. `[3,96,96]`).
+            These shapes are used to create the tensor buffer containing mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape is adjusted to be invariant to height
+            and width, assuming a channel-first (c, h, w) format.
+        modes (dict): A dictionary where keys are input modalities (e.g. "observation.image") and values are their normalization modes among:
+            - "mean_std": subtract the mean and divide by the standard deviation.
+            - "min_max": map to [-1, 1] range.
+        stats (dict, optional): A dictionary where keys are input modalities (e.g. "observation.image") and values are dictionaries of statistic types and their values
+            (e.g. `{"mean": torch.randn(3,1,1), "std": torch.randn(3,1,1)}`). If provided, as expected for training the model for the first time,
+            these statistics will overwrite the default buffers. If not provided, as expected for finetuning or evaluation, the default buffers should
+            be overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the dataset is not needed to get the stats, since
+            they are already in the policy state_dict.
     """
     def __init__(self, shapes, modes, stats=None):
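
Note on the two modes listed in the `Normalize` docstring: the `forward` method is outside this hunk, so the arithmetic below is a hedged sketch of what "mean_std" and "min_max" conventionally mean, consistent with the removed docstring text ("to a [0, 1] range and then to a [-1, 1] range"). The function names and the small `eps` guard against division by zero are assumptions, not code taken from the commit.

```python
def normalize_mean_std(x, mean, std, eps=1e-8):
    # Subtract the mean and divide by the standard deviation.
    return (x - mean) / (std + eps)


def normalize_min_max(x, min_value, max_value, eps=1e-8):
    # First map to [0, 1], then stretch and shift to [-1, 1].
    unit = (x - min_value) / (max_value - min_value + eps)
    return unit * 2 - 1


# Plain floats are used here; the same arithmetic applies elementwise to tensors.
lowest = normalize_min_max(0.0, min_value=0.0, max_value=10.0)
highest = normalize_min_max(10.0, min_value=0.0, max_value=10.0)
standardized = normalize_mean_std(5.0, mean=3.0, std=2.0)
```

With per-channel statistics of shape (c, 1, 1), the same expressions broadcast over a batch of images of shape (b, c, h, w).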
@@ -143,22 +138,20 @@ class Normalize(nn.Module):
 class Unnormalize(nn.Module):
     """
-    A PyTorch module for unnormalizing data based on predefined statistics.
-    The class is initialized with a set of shapes, modes, and optional pre-defined statistics. It creates buffers for unnormalization based
-    on these inputs, which are then used to adjust data during the forward pass. The unnormalization process operates on a batch of data,
-    with different keys in the batch being normalized according to the specified modes. The following unnormalization modes are supported:
-    - "mean_std": Subtracts the mean and divides by the standard deviation.
-    - "min_max": Scales and offsets the data such that the minimum is -1 and the maximum is +1.
+    Similar to `Normalize` but unnormalizes output data (e.g. `{"action": torch.randn(b,c)}`) back to the original range used by the environment.
     Parameters:
-        shapes (dict): A dictionary where keys represent tensor identifiers and values represent the shapes of those tensors.
-        modes (dict): A dictionary indicating the unnormalization mode for each tensor key. Valid modes are "mean_std" or "min_max".
-        stats (dict, optional): A dictionary containing pre-defined statistics for unnormalization. It can contain 'mean' and 'std' for
-            "mean_std" mode, or 'min' and 'max' for "min_max" mode. If provided, these statistics will overwrite the default buffers.
-            It's expected for training the model for the first time. If not provided, the default buffers are supposed to be overriden
-            by a call to `policy.load_state_dict(state_dict)`. It's useful for loading a pretrained model for finetuning or evaluation,
-            without requiring to initialize the dataset used to train the model just to acess the `stats`.
+        shapes (dict): A dictionary where keys are output modalities (e.g. "action") and values are their shapes (e.g. `[10]`).
+            These shapes are used to create the tensor buffer containing mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape is adjusted to be invariant to height
+            and width, assuming a channel-first (c, h, w) format.
+        modes (dict): A dictionary where keys are output modalities (e.g. "action") and values are their unnormalization modes among:
+            - "mean_std": multiply by the standard deviation and add the mean.
+            - "min_max": go from [-1, 1] range to original range.
+        stats (dict, optional): A dictionary where keys are output modalities (e.g. "action") and values are dictionaries of statistic types and their values
+            (e.g. `{"max": torch.tensor(1), "min": torch.tensor(0)}`). If provided, as expected for training the model for the first time,
+            these statistics will overwrite the default buffers. If not provided, as expected for finetuning or evaluation, the default buffers should
+            be overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the dataset is not needed to get the stats, since
+            they are already in the policy state_dict.
     """
     def __init__(self, shapes, modes, stats=None):
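
Note on the `Unnormalize` modes: as a rough sketch (the module's `forward` is not shown in this diff), the inverse mappings described in the new docstring can be written as below. The function names are hypothetical, and the round-trip check assumes the forward "min_max" mapping is the usual 2 * (x - min) / (max - min) - 1.

```python
def unnormalize_mean_std(y, mean, std):
    # Multiply by the standard deviation and add the mean.
    return y * std + mean


def unnormalize_min_max(y, min_value, max_value):
    # Go from [-1, 1] back to [0, 1], then to the original [min, max] range.
    unit = (y + 1) / 2
    return unit * (max_value - min_value) + min_value


# Round trip on a plain float: normalize with the assumed forward formula,
# then recover the original value.
x = 7.5
min_value, max_value = 0.0, 10.0
y = 2 * (x - min_value) / (max_value - min_value) - 1
recovered = unnormalize_min_max(y, min_value, max_value)
```

This is the step that converts a policy's normalized action output into values the environment can execute.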