address comments

This commit is contained in:
Cadene 2024-04-24 20:57:09 +00:00
parent bc96284ca0
commit 0ec28bf71a
9 changed files with 74 additions and 57 deletions

View File

@ -30,15 +30,11 @@ class ActionChunkingTransformerConfig:
The key represents the output data name, and the value is a list indicating the dimensions The key represents the output data name, and the value is a list indicating the dimensions
of the corresponding data. For example, "action" refers to an output shape of [14], indicating of the corresponding data. For example, "action" refers to an output shape of [14], indicating
14-dimensional actions. Importantly, shapes doesnt include batch dimension or temporal dimension. 14-dimensional actions. Importantly, shapes doesnt include batch dimension or temporal dimension.
normalize_input_modes: A dictionary specifying the normalization mode to be applied to various inputs. normalize_input_modes: A dictionary with key represents the modality (e.g. "observation.state"),
The key represents the input data name, and the value specifies the type of normalization to apply. and the value specifies the normalization mode to apply. The two availables
Common normalization methods include "mean_std" (mean and standard deviation) or "min_max" (to normalize modes are "mean_std" which substracts the mean and divide by the standard
between -1 and 1). deviation and "min_max" which rescale in a [-1, 1] range.
unnormalize_output_modes: A dictionary specifying the method to unnormalize outputs. unnormalize_output_modes: Similar dictionary as `normalize_input_modes`, but to unormalize in original scale.
This parameter maps output data types to their unnormalization modes, allowing the results to be
transformed back from a normalized state to a standard state. It is typically used when output
data needs to be interpreted in its original scale or units. For example, for "action", the
unnormalization mode might be "mean_std" or "min_max".
vision_backbone: Name of the torchvision resnet backbone to use for encoding images. vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
use_pretrained_backbone: Whether the backbone should be initialized with pretrained weights from use_pretrained_backbone: Whether the backbone should be initialized with pretrained weights from
torchvision. torchvision.
@ -65,7 +61,7 @@ class ActionChunkingTransformerConfig:
""" """
# Environment. # Environment.
# TODO(rcadene, alexander-soar): remove these as they are defined in input_shapes, output_shapes # TODO(rcadene, alexander-soare): remove these as they are defined in input_shapes, output_shapes
state_dim: int = 14 state_dim: int = 14
action_dim: int = 14 action_dim: int = 14
@ -75,13 +71,13 @@ class ActionChunkingTransformerConfig:
chunk_size: int = 100 chunk_size: int = 100
n_action_steps: int = 100 n_action_steps: int = 100
input_shapes: dict[str, str] = field( input_shapes: dict[str, list[str]] = field(
default_factory=lambda: { default_factory=lambda: {
"observation.images.top": [3, 480, 640], "observation.images.top": [3, 480, 640],
"observation.state": [14], "observation.state": [14],
} }
) )
output_shapes: dict[str, str] = field( output_shapes: dict[str, list[str]] = field(
default_factory=lambda: { default_factory=lambda: {
"action": [14], "action": [14],
} }

View File

@ -72,8 +72,6 @@ class ActionChunkingTransformerPolicy(nn.Module):
if cfg is None: if cfg is None:
cfg = ActionChunkingTransformerConfig() cfg = ActionChunkingTransformerConfig()
self.cfg = cfg self.cfg = cfg
self.normalize_input_modes = cfg.normalize_input_modes
self.unnormalize_output_modes = cfg.unnormalize_output_modes
self.normalize_inputs = Normalize(cfg.input_shapes, cfg.normalize_input_modes, dataset_stats) self.normalize_inputs = Normalize(cfg.input_shapes, cfg.normalize_input_modes, dataset_stats)
self.unnormalize_outputs = Unnormalize(cfg.output_shapes, cfg.unnormalize_output_modes, dataset_stats) self.unnormalize_outputs = Unnormalize(cfg.output_shapes, cfg.unnormalize_output_modes, dataset_stats)

View File

@ -28,15 +28,11 @@ class DiffusionConfig:
The key represents the output data name, and the value is a list indicating the dimensions The key represents the output data name, and the value is a list indicating the dimensions
of the corresponding data. For example, "action" refers to an output shape of [14], indicating of the corresponding data. For example, "action" refers to an output shape of [14], indicating
14-dimensional actions. Importantly, shapes doesnt include batch dimension or temporal dimension. 14-dimensional actions. Importantly, shapes doesnt include batch dimension or temporal dimension.
normalize_input_modes: A dictionary specifying the normalization mode to be applied to various inputs. normalize_input_modes: A dictionary with key represents the modality (e.g. "observation.state"),
The key represents the input data name, and the value specifies the type of normalization to apply. and the value specifies the normalization mode to apply. The two availables
Common normalization methods include "mean_std" (mean and standard deviation) or "min_max" (to normalize modes are "mean_std" which substracts the mean and divide by the standard
between -1 and 1). deviation and "min_max" which rescale in a [-1, 1] range.
unnormalize_output_modes: A dictionary specifying the method to unnormalize outputs. unnormalize_output_modes: Similar dictionary as `normalize_input_modes`, but to unormalize in original scale.
This parameter maps output data types to their unnormalization modes, allowing the results to be
transformed back from a normalized state to a standard state. It is typically used when output
data needs to be interpreted in its original scale or units. For example, for "action", the
unnormalization mode might be "mean_std" or "min_max".
vision_backbone: Name of the torchvision resnet backbone to use for encoding images. vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit
within the image size. If None, no cropping is done. within the image size. If None, no cropping is done.
@ -74,7 +70,7 @@ class DiffusionConfig:
# Environment. # Environment.
# Inherit these from the environment config. # Inherit these from the environment config.
# TODO(rcadene, alexander-soar): remove these as they are defined in input_shapes, output_shapes # TODO(rcadene, alexander-soare): remove these as they are defined in input_shapes, output_shapes
state_dim: int = 2 state_dim: int = 2
action_dim: int = 2 action_dim: int = 2
image_size: tuple[int, int] = (96, 96) image_size: tuple[int, int] = (96, 96)
@ -84,13 +80,13 @@ class DiffusionConfig:
horizon: int = 16 horizon: int = 16
n_action_steps: int = 8 n_action_steps: int = 8
input_shapes: dict[str, str] = field( input_shapes: dict[str, list[str]] = field(
default_factory=lambda: { default_factory=lambda: {
"observation.image": [3, 96, 96], "observation.image": [3, 96, 96],
"observation.state": [2], "observation.state": [2],
} }
) )
output_shapes: dict[str, str] = field( output_shapes: dict[str, list[str]] = field(
default_factory=lambda: { default_factory=lambda: {
"action": [2], "action": [2],
} }

View File

@ -56,8 +56,6 @@ class DiffusionPolicy(nn.Module):
if cfg is None: if cfg is None:
cfg = DiffusionConfig() cfg = DiffusionConfig()
self.cfg = cfg self.cfg = cfg
self.normalize_input_modes = cfg.normalize_input_modes
self.unnormalize_output_modes = cfg.unnormalize_output_modes
self.normalize_inputs = Normalize(cfg.input_shapes, cfg.normalize_input_modes, dataset_stats) self.normalize_inputs = Normalize(cfg.input_shapes, cfg.normalize_input_modes, dataset_stats)
self.unnormalize_outputs = Unnormalize(cfg.output_shapes, cfg.unnormalize_output_modes, dataset_stats) self.unnormalize_outputs = Unnormalize(cfg.output_shapes, cfg.unnormalize_output_modes, dataset_stats)

View File

@ -31,18 +31,24 @@ def create_stats_buffers(shapes, modes, stats=None):
for key, mode in modes.items(): for key, mode in modes.items():
assert mode in ["mean_std", "min_max"] assert mode in ["mean_std", "min_max"]
shape = shapes[key] shape = tuple(shapes[key])
# override shape to be invariant to height and width
if "image" in key: if "image" in key:
# assume shape is channel first (b, c, h, w) or (b, t, c, h, w) # sanity checks
shape[-1] = 1 assert len(shape) == 3, f"number of dimensions of {key} != 3 ({shape=}"
shape[-2] = 1 c, h, w = shape
assert c < h and c < w, f"{key} is not channel first ({shape=})"
# override image shape to be invariant to height and width
shape = (c, 1, 1)
# Note: we initialize mean, std, min, max to infinity. They should be overwritten
# downstream by `stats` or `policy.load_state_dict`, as expected. During forward,
# we assert they are not infinity anymore.
buffer = {} buffer = {}
if mode == "mean_std": if mode == "mean_std":
mean = torch.zeros(shape, dtype=torch.float32) mean = torch.ones(shape, dtype=torch.float32) * torch.inf
std = torch.ones(shape, dtype=torch.float32) std = torch.ones(shape, dtype=torch.float32) * torch.inf
buffer = nn.ParameterDict( buffer = nn.ParameterDict(
{ {
"mean": nn.Parameter(mean, requires_grad=False), "mean": nn.Parameter(mean, requires_grad=False),
@ -50,9 +56,8 @@ def create_stats_buffers(shapes, modes, stats=None):
} }
) )
elif mode == "min_max": elif mode == "min_max":
# TODO(rcadene): should we assume input is in [-1, 1] range? min = torch.ones(shape, dtype=torch.float32) * torch.inf
min = torch.ones(shape, dtype=torch.float32) * -1 max = torch.ones(shape, dtype=torch.float32) * torch.inf
max = torch.ones(shape, dtype=torch.float32)
buffer = nn.ParameterDict( buffer = nn.ParameterDict(
{ {
"min": nn.Parameter(min, requires_grad=False), "min": nn.Parameter(min, requires_grad=False),
@ -109,12 +114,24 @@ class Normalize(nn.Module):
buffer = getattr(self, "buffer_" + key.replace(".", "_")) buffer = getattr(self, "buffer_" + key.replace(".", "_"))
if mode == "mean_std": if mode == "mean_std":
mean = buffer["mean"].unsqueeze(0) mean = buffer["mean"]
std = buffer["std"].unsqueeze(0) std = buffer["std"]
assert not torch.isinf(
mean
).any(), "`mean` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
assert not torch.isinf(
std
).any(), "`std` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
batch[key] = (batch[key] - mean) / (std + 1e-8) batch[key] = (batch[key] - mean) / (std + 1e-8)
elif mode == "min_max": elif mode == "min_max":
min = buffer["min"].unsqueeze(0) min = buffer["min"]
max = buffer["max"].unsqueeze(0) max = buffer["max"]
assert not torch.isinf(
min
).any(), "`min` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
assert not torch.isinf(
max
).any(), "`max` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
# normalize to [0,1] # normalize to [0,1]
batch[key] = (batch[key] - min) / (max - min) batch[key] = (batch[key] - min) / (max - min)
# normalize to [-1, 1] # normalize to [-1, 1]
@ -131,8 +148,8 @@ class Unnormalize(nn.Module):
The class is initialized with a set of shapes, modes, and optional pre-defined statistics. It creates buffers for unnormalization based The class is initialized with a set of shapes, modes, and optional pre-defined statistics. It creates buffers for unnormalization based
on these inputs, which are then used to adjust data during the forward pass. The unnormalization process operates on a batch of data, on these inputs, which are then used to adjust data during the forward pass. The unnormalization process operates on a batch of data,
with different keys in the batch being normalized according to the specified modes. The following unnormalization modes are supported: with different keys in the batch being normalized according to the specified modes. The following unnormalization modes are supported:
- "mean_std": Unnormalizes data using the mean and standard deviation. - "mean_std": Subtracts the mean and divides by the standard deviation.
- "min_max": Unnormalizes data to a [0, 1] range and then to a [-1, 1] range. - "min_max": Scales and offsets the data such that the minimum is -1 and the maximum is +1.
Parameters: Parameters:
shapes (dict): A dictionary where keys represent tensor identifiers and values represent the shapes of those tensors. shapes (dict): A dictionary where keys represent tensor identifiers and values represent the shapes of those tensors.
@ -161,12 +178,24 @@ class Unnormalize(nn.Module):
buffer = getattr(self, "buffer_" + key.replace(".", "_")) buffer = getattr(self, "buffer_" + key.replace(".", "_"))
if mode == "mean_std": if mode == "mean_std":
mean = buffer["mean"].unsqueeze(0) mean = buffer["mean"]
std = buffer["std"].unsqueeze(0) std = buffer["std"]
assert not torch.isinf(
mean
).any(), "`mean` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
assert not torch.isinf(
std
).any(), "`std` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
batch[key] = batch[key] * std + mean batch[key] = batch[key] * std + mean
elif mode == "min_max": elif mode == "min_max":
min = buffer["min"].unsqueeze(0) min = buffer["min"]
max = buffer["max"].unsqueeze(0) max = buffer["max"]
assert not torch.isinf(
min
).any(), "`min` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
assert not torch.isinf(
max
).any(), "`max` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
batch[key] = (batch[key] + 1) / 2 batch[key] = (batch[key] + 1) / 2
batch[key] = batch[key] * (max - min) + min batch[key] = batch[key] * (max - min) + min
else: else:

View File

@ -35,7 +35,7 @@ policy:
n_action_steps: 100 n_action_steps: 100
input_shapes: input_shapes:
# TODO(rcadene, alexander-soar): add variables for height and width from the dataset/env? # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
observation.images.top: [3, 480, 640] observation.images.top: [3, 480, 640]
observation.state: ["${policy.state_dim}"] observation.state: ["${policy.state_dim}"]
output_shapes: output_shapes:

View File

@ -51,7 +51,7 @@ policy:
n_action_steps: ${n_action_steps} n_action_steps: ${n_action_steps}
input_shapes: input_shapes:
# TODO(rcadene, alexander-soar): add variables for height and width from the dataset/env? # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
observation.image: [3, 96, 96] observation.image: [3, 96, 96]
observation.state: ["${policy.state_dim}"] observation.state: ["${policy.state_dim}"]
output_shapes: output_shapes:

View File

@ -339,7 +339,6 @@ def train(cfg: dict, out_dir=None, job_name=None):
eval_info = eval_policy( eval_info = eval_policy(
rollout_env, rollout_env,
policy, policy,
transform=offline_dataset.transform,
return_episode_data=True, return_episode_data=True,
seed=cfg.seed, seed=cfg.seed,
) )

View File

@ -96,9 +96,8 @@ def test_policy(env_name, policy_name, extra_overrides):
# Test load state_dict # Test load state_dict
if policy_name != "tdmpc": if policy_name != "tdmpc":
# TODO(rcadene, alexander-soar): make it work for tdmpc # TODO(rcadene, alexander-soare): make it work for tdmpc
# TODO(rcadene, alexander-soar): how to remove need for dataset_stats? new_policy = make_policy(cfg)
new_policy = make_policy(cfg, dataset_stats=dataset.stats)
new_policy.load_state_dict(policy.state_dict()) new_policy.load_state_dict(policy.state_dict())
@ -110,7 +109,7 @@ def test_policy(env_name, policy_name, extra_overrides):
], ],
) )
def test_normalize(insert_temporal_dim): def test_normalize(insert_temporal_dim):
# TODO(rcadene, alexander-soar): test with real data and assert results of normalization/unnormalization # TODO(rcadene, alexander-soare): test with real data and assert results of normalization/unnormalization
input_shapes = { input_shapes = {
"observation.image": [3, 96, 96], "observation.image": [3, 96, 96],
@ -170,7 +169,8 @@ def test_normalize(insert_temporal_dim):
# test without stats # test without stats
normalize = Normalize(input_shapes, normalize_input_modes, stats=None) normalize = Normalize(input_shapes, normalize_input_modes, stats=None)
normalize(input_batch) with pytest.raises(AssertionError):
normalize(input_batch)
# test with stats # test with stats
normalize = Normalize(input_shapes, normalize_input_modes, stats=dataset_stats) normalize = Normalize(input_shapes, normalize_input_modes, stats=dataset_stats)
@ -183,7 +183,8 @@ def test_normalize(insert_temporal_dim):
# test wihtout stats # test wihtout stats
unnormalize = Unnormalize(output_shapes, unnormalize_output_modes, stats=None) unnormalize = Unnormalize(output_shapes, unnormalize_output_modes, stats=None)
unnormalize(output_batch) with pytest.raises(AssertionError):
unnormalize(output_batch)
# test with stats # test with stats
unnormalize = Unnormalize(output_shapes, unnormalize_output_modes, stats=dataset_stats) unnormalize = Unnormalize(output_shapes, unnormalize_output_modes, stats=dataset_stats)