address comments

This commit is contained in:
Cadene 2024-04-24 20:57:09 +00:00
parent bc96284ca0
commit 0ec28bf71a
9 changed files with 74 additions and 57 deletions

View File

@ -30,15 +30,11 @@ class ActionChunkingTransformerConfig:
The key represents the output data name, and the value is a list indicating the dimensions
of the corresponding data. For example, "action" refers to an output shape of [14], indicating
14-dimensional actions. Importantly, shapes doesnt include batch dimension or temporal dimension.
normalize_input_modes: A dictionary specifying the normalization mode to be applied to various inputs.
The key represents the input data name, and the value specifies the type of normalization to apply.
Common normalization methods include "mean_std" (mean and standard deviation) or "min_max" (to normalize
between -1 and 1).
unnormalize_output_modes: A dictionary specifying the method to unnormalize outputs.
This parameter maps output data types to their unnormalization modes, allowing the results to be
transformed back from a normalized state to a standard state. It is typically used when output
data needs to be interpreted in its original scale or units. For example, for "action", the
unnormalization mode might be "mean_std" or "min_max".
normalize_input_modes: A dictionary with key represents the modality (e.g. "observation.state"),
and the value specifies the normalization mode to apply. The two availables
modes are "mean_std" which substracts the mean and divide by the standard
deviation and "min_max" which rescale in a [-1, 1] range.
unnormalize_output_modes: Similar dictionary as `normalize_input_modes`, but to unormalize in original scale.
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
use_pretrained_backbone: Whether the backbone should be initialized with pretrained weights from
torchvision.
@ -65,7 +61,7 @@ class ActionChunkingTransformerConfig:
"""
# Environment.
# TODO(rcadene, alexander-soar): remove these as they are defined in input_shapes, output_shapes
# TODO(rcadene, alexander-soare): remove these as they are defined in input_shapes, output_shapes
state_dim: int = 14
action_dim: int = 14
@ -75,13 +71,13 @@ class ActionChunkingTransformerConfig:
chunk_size: int = 100
n_action_steps: int = 100
input_shapes: dict[str, str] = field(
input_shapes: dict[str, list[str]] = field(
default_factory=lambda: {
"observation.images.top": [3, 480, 640],
"observation.state": [14],
}
)
output_shapes: dict[str, str] = field(
output_shapes: dict[str, list[str]] = field(
default_factory=lambda: {
"action": [14],
}

View File

@ -72,8 +72,6 @@ class ActionChunkingTransformerPolicy(nn.Module):
if cfg is None:
cfg = ActionChunkingTransformerConfig()
self.cfg = cfg
self.normalize_input_modes = cfg.normalize_input_modes
self.unnormalize_output_modes = cfg.unnormalize_output_modes
self.normalize_inputs = Normalize(cfg.input_shapes, cfg.normalize_input_modes, dataset_stats)
self.unnormalize_outputs = Unnormalize(cfg.output_shapes, cfg.unnormalize_output_modes, dataset_stats)

View File

@ -28,15 +28,11 @@ class DiffusionConfig:
The key represents the output data name, and the value is a list indicating the dimensions
of the corresponding data. For example, "action" refers to an output shape of [14], indicating
14-dimensional actions. Importantly, shapes doesnt include batch dimension or temporal dimension.
normalize_input_modes: A dictionary specifying the normalization mode to be applied to various inputs.
The key represents the input data name, and the value specifies the type of normalization to apply.
Common normalization methods include "mean_std" (mean and standard deviation) or "min_max" (to normalize
between -1 and 1).
unnormalize_output_modes: A dictionary specifying the method to unnormalize outputs.
This parameter maps output data types to their unnormalization modes, allowing the results to be
transformed back from a normalized state to a standard state. It is typically used when output
data needs to be interpreted in its original scale or units. For example, for "action", the
unnormalization mode might be "mean_std" or "min_max".
normalize_input_modes: A dictionary with key represents the modality (e.g. "observation.state"),
and the value specifies the normalization mode to apply. The two availables
modes are "mean_std" which substracts the mean and divide by the standard
deviation and "min_max" which rescale in a [-1, 1] range.
unnormalize_output_modes: Similar dictionary as `normalize_input_modes`, but to unormalize in original scale.
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit
within the image size. If None, no cropping is done.
@ -74,7 +70,7 @@ class DiffusionConfig:
# Environment.
# Inherit these from the environment config.
# TODO(rcadene, alexander-soar): remove these as they are defined in input_shapes, output_shapes
# TODO(rcadene, alexander-soare): remove these as they are defined in input_shapes, output_shapes
state_dim: int = 2
action_dim: int = 2
image_size: tuple[int, int] = (96, 96)
@ -84,13 +80,13 @@ class DiffusionConfig:
horizon: int = 16
n_action_steps: int = 8
input_shapes: dict[str, str] = field(
input_shapes: dict[str, list[str]] = field(
default_factory=lambda: {
"observation.image": [3, 96, 96],
"observation.state": [2],
}
)
output_shapes: dict[str, str] = field(
output_shapes: dict[str, list[str]] = field(
default_factory=lambda: {
"action": [2],
}

View File

@ -56,8 +56,6 @@ class DiffusionPolicy(nn.Module):
if cfg is None:
cfg = DiffusionConfig()
self.cfg = cfg
self.normalize_input_modes = cfg.normalize_input_modes
self.unnormalize_output_modes = cfg.unnormalize_output_modes
self.normalize_inputs = Normalize(cfg.input_shapes, cfg.normalize_input_modes, dataset_stats)
self.unnormalize_outputs = Unnormalize(cfg.output_shapes, cfg.unnormalize_output_modes, dataset_stats)

View File

@ -31,18 +31,24 @@ def create_stats_buffers(shapes, modes, stats=None):
for key, mode in modes.items():
assert mode in ["mean_std", "min_max"]
shape = shapes[key]
shape = tuple(shapes[key])
# override shape to be invariant to height and width
if "image" in key:
# assume shape is channel first (b, c, h, w) or (b, t, c, h, w)
shape[-1] = 1
shape[-2] = 1
# sanity checks
assert len(shape) == 3, f"number of dimensions of {key} != 3 ({shape=}"
c, h, w = shape
assert c < h and c < w, f"{key} is not channel first ({shape=})"
# override image shape to be invariant to height and width
shape = (c, 1, 1)
# Note: we initialize mean, std, min, max to infinity. They should be overwritten
# downstream by `stats` or `policy.load_state_dict`, as expected. During forward,
# we assert they are not infinity anymore.
buffer = {}
if mode == "mean_std":
mean = torch.zeros(shape, dtype=torch.float32)
std = torch.ones(shape, dtype=torch.float32)
mean = torch.ones(shape, dtype=torch.float32) * torch.inf
std = torch.ones(shape, dtype=torch.float32) * torch.inf
buffer = nn.ParameterDict(
{
"mean": nn.Parameter(mean, requires_grad=False),
@ -50,9 +56,8 @@ def create_stats_buffers(shapes, modes, stats=None):
}
)
elif mode == "min_max":
# TODO(rcadene): should we assume input is in [-1, 1] range?
min = torch.ones(shape, dtype=torch.float32) * -1
max = torch.ones(shape, dtype=torch.float32)
min = torch.ones(shape, dtype=torch.float32) * torch.inf
max = torch.ones(shape, dtype=torch.float32) * torch.inf
buffer = nn.ParameterDict(
{
"min": nn.Parameter(min, requires_grad=False),
@ -109,12 +114,24 @@ class Normalize(nn.Module):
buffer = getattr(self, "buffer_" + key.replace(".", "_"))
if mode == "mean_std":
mean = buffer["mean"].unsqueeze(0)
std = buffer["std"].unsqueeze(0)
mean = buffer["mean"]
std = buffer["std"]
assert not torch.isinf(
mean
).any(), "`mean` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
assert not torch.isinf(
std
).any(), "`std` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
batch[key] = (batch[key] - mean) / (std + 1e-8)
elif mode == "min_max":
min = buffer["min"].unsqueeze(0)
max = buffer["max"].unsqueeze(0)
min = buffer["min"]
max = buffer["max"]
assert not torch.isinf(
min
).any(), "`min` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
assert not torch.isinf(
max
).any(), "`max` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
# normalize to [0,1]
batch[key] = (batch[key] - min) / (max - min)
# normalize to [-1, 1]
@ -131,8 +148,8 @@ class Unnormalize(nn.Module):
The class is initialized with a set of shapes, modes, and optional pre-defined statistics. It creates buffers for unnormalization based
on these inputs, which are then used to adjust data during the forward pass. The unnormalization process operates on a batch of data,
with different keys in the batch being normalized according to the specified modes. The following unnormalization modes are supported:
- "mean_std": Unnormalizes data using the mean and standard deviation.
- "min_max": Unnormalizes data to a [0, 1] range and then to a [-1, 1] range.
- "mean_std": Subtracts the mean and divides by the standard deviation.
- "min_max": Scales and offsets the data such that the minimum is -1 and the maximum is +1.
Parameters:
shapes (dict): A dictionary where keys represent tensor identifiers and values represent the shapes of those tensors.
@ -161,12 +178,24 @@ class Unnormalize(nn.Module):
buffer = getattr(self, "buffer_" + key.replace(".", "_"))
if mode == "mean_std":
mean = buffer["mean"].unsqueeze(0)
std = buffer["std"].unsqueeze(0)
mean = buffer["mean"]
std = buffer["std"]
assert not torch.isinf(
mean
).any(), "`mean` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
assert not torch.isinf(
std
).any(), "`std` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
batch[key] = batch[key] * std + mean
elif mode == "min_max":
min = buffer["min"].unsqueeze(0)
max = buffer["max"].unsqueeze(0)
min = buffer["min"]
max = buffer["max"]
assert not torch.isinf(
min
).any(), "`min` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
assert not torch.isinf(
max
).any(), "`max` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
batch[key] = (batch[key] + 1) / 2
batch[key] = batch[key] * (max - min) + min
else:

View File

@ -35,7 +35,7 @@ policy:
n_action_steps: 100
input_shapes:
# TODO(rcadene, alexander-soar): add variables for height and width from the dataset/env?
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
observation.images.top: [3, 480, 640]
observation.state: ["${policy.state_dim}"]
output_shapes:

View File

@ -51,7 +51,7 @@ policy:
n_action_steps: ${n_action_steps}
input_shapes:
# TODO(rcadene, alexander-soar): add variables for height and width from the dataset/env?
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
observation.image: [3, 96, 96]
observation.state: ["${policy.state_dim}"]
output_shapes:

View File

@ -339,7 +339,6 @@ def train(cfg: dict, out_dir=None, job_name=None):
eval_info = eval_policy(
rollout_env,
policy,
transform=offline_dataset.transform,
return_episode_data=True,
seed=cfg.seed,
)

View File

@ -96,9 +96,8 @@ def test_policy(env_name, policy_name, extra_overrides):
# Test load state_dict
if policy_name != "tdmpc":
# TODO(rcadene, alexander-soar): make it work for tdmpc
# TODO(rcadene, alexander-soar): how to remove need for dataset_stats?
new_policy = make_policy(cfg, dataset_stats=dataset.stats)
# TODO(rcadene, alexander-soare): make it work for tdmpc
new_policy = make_policy(cfg)
new_policy.load_state_dict(policy.state_dict())
@ -110,7 +109,7 @@ def test_policy(env_name, policy_name, extra_overrides):
],
)
def test_normalize(insert_temporal_dim):
# TODO(rcadene, alexander-soar): test with real data and assert results of normalization/unnormalization
# TODO(rcadene, alexander-soare): test with real data and assert results of normalization/unnormalization
input_shapes = {
"observation.image": [3, 96, 96],
@ -170,6 +169,7 @@ def test_normalize(insert_temporal_dim):
# test without stats
normalize = Normalize(input_shapes, normalize_input_modes, stats=None)
with pytest.raises(AssertionError):
normalize(input_batch)
# test with stats
@ -183,6 +183,7 @@ def test_normalize(insert_temporal_dim):
# test wihtout stats
unnormalize = Unnormalize(output_shapes, unnormalize_output_modes, stats=None)
with pytest.raises(AssertionError):
unnormalize(output_batch)
# test with stats