improve docstring
parent
0ec28bf71a
commit
6d56bcb5de
|
@ -4,27 +4,24 @@ from torch import nn
|
||||||
|
|
||||||
def create_stats_buffers(shapes, modes, stats=None):
|
def create_stats_buffers(shapes, modes, stats=None):
|
||||||
"""
|
"""
|
||||||
This function generates buffers to store the mean and standard deviation, or minimum and maximum values,
|
Create buffers per modality (e.g. "observation.image", "action") containing their mean, std, min, max statistics.
|
||||||
used for normalizing tensors. The mode of normalization is determined by the `modes` dictionary, which can
|
|
||||||
be either "mean_std" (for mean and standard deviation) or "min_max" (for minimum and maximum). These buffers
|
|
||||||
are created as PyTorch nn.ParameterDict objects with nn.Parameters set to not require gradients, suitable
|
|
||||||
for normalization purposes.
|
|
||||||
|
|
||||||
If the provided `shapes` contain keys related to images, the shape is adjusted to be invariant to height
|
|
||||||
and width, assuming a channel-first (c, h, w) format.
|
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
shapes (dict): A dictionary where keys represent tensor identifiers and values represent the shapes of those tensors.
|
shapes (dict): A dictionary where keys are input modalities (e.g. "observation.image") and values are their shapes (e.g. `[3,96,96]`).
|
||||||
modes (dict): A dictionary specifying the normalization mode for each key in `shapes`. Valid modes are "mean_std" or "min_max".
|
These shapes are used to create the tensor buffer containing mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape is adjusted to be invariant to height
|
||||||
stats (dict, optional): A dictionary containing pre-defined statistics for normalization. It can contain 'mean' and 'std' for
|
and width, assuming a channel-first (c, h, w) format.
|
||||||
"mean_std" mode, or 'min' and 'max' for "min_max" mode. If provided, these statistics will overwrite the default buffers.
|
modes (dict): A dictionary where keys are output modalities (e.g. "observation.image") and values are their normalization modes among:
|
||||||
It's expected for training the model for the first time. If not provided, the default buffers are supposed to be overriden
|
- "mean_std": subtract the mean and divide by standard deviation.
|
||||||
by a call to `policy.load_state_dict(state_dict)`. It's useful for loading a pretrained model for finetuning or evaluation,
|
- "min_max": map to [-1, 1] range.
|
||||||
without requiring to initialize the dataset used to train the model just to acess the `stats`.
|
stats (dict, optional): A dictionary where keys are output modalities (e.g. "observation.image") and values are dictionaries of statistic types and their values
|
||||||
|
(e.g. `{"mean": torch.randn(3,1,1), "std": torch.randn(3,1,1)}`). If provided, as expected for training the model for the first time,
|
||||||
|
these statistics will overwrite the default buffers. If not provided, as expected for finetuning or evaluation, the default buffers should
|
||||||
|
be overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the dataset is not needed to get the stats, since
|
||||||
|
they are already in the policy state_dict.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict: A dictionary where keys match the `modes` and `shapes` keys, and values are nn.ParameterDict objects containing
|
dict: A dictionary where keys are modalities and values are `nn.ParameterDict` containing `nn.Parameters` set to
|
||||||
the appropriate buffers for normalization.
|
`requires_grad=False`, so that they are not updated during backpropagation.
|
||||||
"""
|
"""
|
||||||
stats_buffers = {}
|
stats_buffers = {}
|
||||||
|
|
||||||
|
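Reviewer note: the docstring above describes the buffer layout in prose only. The following is a minimal sketch of that behaviour, not the repository's implementation. The name `make_stats_buffers`, the `"image" in key` test, and the `inf` placeholder values are assumptions made for illustration.

```python
import torch
from torch import nn


def make_stats_buffers(shapes, modes, stats=None):
    """Hypothetical stand-in for the documented function (illustration only)."""
    buffers = {}
    for key, mode in modes.items():
        shape = tuple(shapes[key])
        # Assumption: a key containing "image" marks a channel-first (c, h, w)
        # input, so statistics are stored per channel with shape (c, 1, 1).
        if "image" in key:
            shape = (shape[0], 1, 1)

        if mode == "mean_std":
            names = ("mean", "std")
        elif mode == "min_max":
            names = ("min", "max")
        else:
            raise ValueError(f"Unknown normalization mode: {mode}")

        params = {}
        for name in names:
            if stats is not None:
                # Statistics supplied up front (e.g. first training run).
                value = stats[key][name].clone().reshape(shape)
            else:
                # Assumed placeholder, meant to be overwritten later by
                # load_state_dict when loading a pretrained policy.
                value = torch.full(shape, float("inf"))
            # requires_grad=False keeps the statistics out of backpropagation.
            params[name] = nn.Parameter(value, requires_grad=False)
        buffers[key] = nn.ParameterDict(params)
    return buffers
```

With `shapes={"observation.image": [3, 96, 96]}` and mode `"mean_std"`, this sketch yields `mean` and `std` parameters of shape `(3, 1, 1)`, which broadcast over any image height and width.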
@ -79,22 +76,20 @@ def create_stats_buffers(shapes, modes, stats=None):
|
||||||
|
|
||||||
class Normalize(nn.Module):
|
class Normalize(nn.Module):
|
||||||
"""
|
"""
|
||||||
A PyTorch module for normalizing data based on predefined statistics.
|
Normalizes the input data (e.g. "observation.image") for more stable and faster convergence during training.
|
||||||
|
|
||||||
The class is initialized with a set of shapes, modes, and optional pre-defined statistics. It creates buffers for normalization based
|
|
||||||
on these inputs, which are then used to adjust data during the forward pass. The normalization process operates on a batch of data,
|
|
||||||
with different keys in the batch being normalized according to the specified modes. The following normalization modes are supported:
|
|
||||||
- "mean_std": Normalizes data using the mean and standard deviation.
|
|
||||||
- "min_max": Normalizes data to a [0, 1] range and then to a [-1, 1] range.
|
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
shapes (dict): A dictionary where keys represent tensor identifiers and values represent the shapes of those tensors.
|
shapes (dict): A dictionary where keys are input modalities (e.g. "observation.image") and values are their shapes (e.g. `[3,96,96]`).
|
||||||
modes (dict): A dictionary indicating the normalization mode for each tensor key. Valid modes are "mean_std" or "min_max".
|
These shapes are used to create the tensor buffer containing mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape is adjusted to be invariant to height
|
||||||
stats (dict, optional): A dictionary containing pre-defined statistics for normalization. It can contain 'mean' and 'std' for
|
and width, assuming a channel-first (c, h, w) format.
|
||||||
"mean_std" mode, or 'min' and 'max' for "min_max" mode. If provided, these statistics will overwrite the default buffers.
|
modes (dict): A dictionary where keys are output modalities (e.g. "observation.image") and values are their normalization modes among:
|
||||||
It's expected for training the model for the first time. If not provided, the default buffers are supposed to be overriden
|
- "mean_std": subtract the mean and divide by standard deviation.
|
||||||
by a call to `policy.load_state_dict(state_dict)`. It's useful for loading a pretrained model for finetuning or evaluation,
|
- "min_max": map to [-1, 1] range.
|
||||||
without requiring to initialize the dataset used to train the model just to acess the `stats`.
|
stats (dict, optional): A dictionary where keys are output modalities (e.g. "observation.image") and values are dictionaries of statistic types and their values
|
||||||
|
(e.g. `{"mean": torch.randn(3,1,1), "std": torch.randn(3,1,1)}`). If provided, as expected for training the model for the first time,
|
||||||
|
these statistics will overwrite the default buffers. If not provided, as expected for finetuning or evaluation, the default buffers should
|
||||||
|
be overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the dataset is not needed to get the stats, since
|
||||||
|
they are already in the policy state_dict.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, shapes, modes, stats=None):
|
def __init__(self, shapes, modes, stats=None):
|
||||||
|
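Reviewer note: the two normalization modes listed in the `Normalize` docstring can be written out as plain tensor arithmetic. This is a sketch under stated assumptions rather than the module's actual forward pass. The function name `normalize` and the `eps` guard against division by zero are hypothetical.

```python
import torch


def normalize(x, mode, stats, eps=1e-8):
    """Illustrative version of the two documented modes (not the real module)."""
    if mode == "mean_std":
        # Subtract the mean and divide by the standard deviation.
        return (x - stats["mean"]) / (stats["std"] + eps)
    if mode == "min_max":
        # Map to [0, 1] first, then stretch to [-1, 1].
        unit = (x - stats["min"]) / (stats["max"] - stats["min"] + eps)
        return unit * 2 - 1
    raise ValueError(f"Unknown normalization mode: {mode}")
```

For an input spanning exactly `stats["min"]` to `stats["max"]`, the `"min_max"` branch maps the endpoints to approximately -1 and +1 and the midpoint to 0.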
@ -143,22 +138,20 @@ class Normalize(nn.Module):
|
||||||
|
|
||||||
class Unnormalize(nn.Module):
|
class Unnormalize(nn.Module):
|
||||||
"""
|
"""
|
||||||
A PyTorch module for unnormalizing data based on predefined statistics.
|
Similar to `Normalize` but unnormalizes output data (e.g. `{"action": torch.randn(b,c)}`) back to the original range used by the environment.
|
||||||
|
|
||||||
The class is initialized with a set of shapes, modes, and optional pre-defined statistics. It creates buffers for unnormalization based
|
|
||||||
on these inputs, which are then used to adjust data during the forward pass. The unnormalization process operates on a batch of data,
|
|
||||||
with different keys in the batch being normalized according to the specified modes. The following unnormalization modes are supported:
|
|
||||||
- "mean_std": Subtracts the mean and divides by the standard deviation.
|
|
||||||
- "min_max": Scales and offsets the data such that the minimum is -1 and the maximum is +1.
|
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
shapes (dict): A dictionary where keys represent tensor identifiers and values represent the shapes of those tensors.
|
shapes (dict): A dictionary where keys are output modalities (e.g. "action") and values are their shapes (e.g. [10]).
|
||||||
modes (dict): A dictionary indicating the unnormalization mode for each tensor key. Valid modes are "mean_std" or "min_max".
|
These shapes are used to create the tensor buffer containing mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape is adjusted to be invariant to height
|
||||||
stats (dict, optional): A dictionary containing pre-defined statistics for unnormalization. It can contain 'mean' and 'std' for
|
and width, assuming a channel-first (c, h, w) format.
|
||||||
"mean_std" mode, or 'min' and 'max' for "min_max" mode. If provided, these statistics will overwrite the default buffers.
|
modes (dict): A dictionary where keys are output modalities (e.g. "action") and values are their unnormalization modes among:
|
||||||
It's expected for training the model for the first time. If not provided, the default buffers are supposed to be overriden
|
- "mean_std": multiply by standard deviation and add mean
|
||||||
by a call to `policy.load_state_dict(state_dict)`. It's useful for loading a pretrained model for finetuning or evaluation,
|
- "min_max": go from [-1, 1] range to original range.
|
||||||
without requiring to initialize the dataset used to train the model just to acess the `stats`.
|
stats (dict, optional): A dictionary where keys are output modalities (e.g. "action") and values are dictionaries of statistic types and their values
|
||||||
|
(e.g. `{"max": torch.tensor(1), "min": torch.tensor(0)}`). If provided, as expected for training the model for the first time,
|
||||||
|
these statistics will overwrite the default buffers. If not provided, as expected for finetuning or evaluation, the default buffers should
|
||||||
|
be overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the dataset is not needed to get the stats, since
|
||||||
|
they are already in the policy state_dict.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, shapes, modes, stats=None):
|
def __init__(self, shapes, modes, stats=None):
|
||||||
|
|
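Reviewer note: the updated `Unnormalize` docstring describes the inverse of each normalization mode. A minimal sketch of those inverses follows. The function name `unnormalize` is hypothetical and the code is an illustration of the documented arithmetic, not the module's implementation.

```python
import torch


def unnormalize(y, mode, stats):
    """Illustrative inverse of the two documented modes (not the real module)."""
    if mode == "mean_std":
        # Multiply by the standard deviation and add the mean.
        return y * stats["std"] + stats["mean"]
    if mode == "min_max":
        # Go from [-1, 1] to [0, 1], then back to the original [min, max] range.
        unit = (y + 1) / 2
        return unit * (stats["max"] - stats["min"]) + stats["min"]
    raise ValueError(f"Unknown unnormalization mode: {mode}")
```

Applied to a policy output in [-1, 1], the `"min_max"` branch returns values in the range the environment expects, for example actions between `stats["min"]` and `stats["max"]`.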