address comments
This commit is contained in:
parent
bc96284ca0
commit
0ec28bf71a
|
@ -30,15 +30,11 @@ class ActionChunkingTransformerConfig:
|
|||
The key represents the output data name, and the value is a list indicating the dimensions
|
||||
of the corresponding data. For example, "action" refers to an output shape of [14], indicating
|
||||
14-dimensional actions. Importantly, shapes don't include the batch dimension or the temporal dimension.
|
||||
normalize_input_modes: A dictionary specifying the normalization mode to be applied to various inputs.
|
||||
The key represents the input data name, and the value specifies the type of normalization to apply.
|
||||
Common normalization methods include "mean_std" (mean and standard deviation) or "min_max" (to normalize
|
||||
between -1 and 1).
|
||||
unnormalize_output_modes: A dictionary specifying the method to unnormalize outputs.
|
||||
This parameter maps output data types to their unnormalization modes, allowing the results to be
|
||||
transformed back from a normalized state to a standard state. It is typically used when output
|
||||
data needs to be interpreted in its original scale or units. For example, for "action", the
|
||||
unnormalization mode might be "mean_std" or "min_max".
|
||||
normalize_input_modes: A dictionary where the key represents the modality (e.g. "observation.state"),
|
||||
and the value specifies the normalization mode to apply. The two available
|
||||
modes are "mean_std" which subtracts the mean and divides by the standard
|
||||
deviation and "min_max" which rescales to a [-1, 1] range.
|
||||
unnormalize_output_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the original scale.
|
||||
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
|
||||
use_pretrained_backbone: Whether the backbone should be initialized with pretrained weights from
|
||||
torchvision.
|
||||
|
@ -65,7 +61,7 @@ class ActionChunkingTransformerConfig:
|
|||
"""
|
||||
|
||||
# Environment.
|
||||
# TODO(rcadene, alexander-soar): remove these as they are defined in input_shapes, output_shapes
|
||||
# TODO(rcadene, alexander-soare): remove these as they are defined in input_shapes, output_shapes
|
||||
state_dim: int = 14
|
||||
action_dim: int = 14
|
||||
|
||||
|
@ -75,13 +71,13 @@ class ActionChunkingTransformerConfig:
|
|||
chunk_size: int = 100
|
||||
n_action_steps: int = 100
|
||||
|
||||
input_shapes: dict[str, str] = field(
|
||||
input_shapes: dict[str, list[str]] = field(
|
||||
default_factory=lambda: {
|
||||
"observation.images.top": [3, 480, 640],
|
||||
"observation.state": [14],
|
||||
}
|
||||
)
|
||||
output_shapes: dict[str, str] = field(
|
||||
output_shapes: dict[str, list[str]] = field(
|
||||
default_factory=lambda: {
|
||||
"action": [14],
|
||||
}
|
||||
|
|
|
@ -72,8 +72,6 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
|||
if cfg is None:
|
||||
cfg = ActionChunkingTransformerConfig()
|
||||
self.cfg = cfg
|
||||
self.normalize_input_modes = cfg.normalize_input_modes
|
||||
self.unnormalize_output_modes = cfg.unnormalize_output_modes
|
||||
self.normalize_inputs = Normalize(cfg.input_shapes, cfg.normalize_input_modes, dataset_stats)
|
||||
self.unnormalize_outputs = Unnormalize(cfg.output_shapes, cfg.unnormalize_output_modes, dataset_stats)
|
||||
|
||||
|
|
|
@ -28,15 +28,11 @@ class DiffusionConfig:
|
|||
The key represents the output data name, and the value is a list indicating the dimensions
|
||||
of the corresponding data. For example, "action" refers to an output shape of [14], indicating
|
||||
14-dimensional actions. Importantly, shapes don't include the batch dimension or the temporal dimension.
|
||||
normalize_input_modes: A dictionary specifying the normalization mode to be applied to various inputs.
|
||||
The key represents the input data name, and the value specifies the type of normalization to apply.
|
||||
Common normalization methods include "mean_std" (mean and standard deviation) or "min_max" (to normalize
|
||||
between -1 and 1).
|
||||
unnormalize_output_modes: A dictionary specifying the method to unnormalize outputs.
|
||||
This parameter maps output data types to their unnormalization modes, allowing the results to be
|
||||
transformed back from a normalized state to a standard state. It is typically used when output
|
||||
data needs to be interpreted in its original scale or units. For example, for "action", the
|
||||
unnormalization mode might be "mean_std" or "min_max".
|
||||
normalize_input_modes: A dictionary where the key represents the modality (e.g. "observation.state"),
|
||||
and the value specifies the normalization mode to apply. The two available
|
||||
modes are "mean_std" which subtracts the mean and divides by the standard
|
||||
deviation and "min_max" which rescales to a [-1, 1] range.
|
||||
unnormalize_output_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the original scale.
|
||||
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
|
||||
crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit
|
||||
within the image size. If None, no cropping is done.
|
||||
|
@ -74,7 +70,7 @@ class DiffusionConfig:
|
|||
|
||||
# Environment.
|
||||
# Inherit these from the environment config.
|
||||
# TODO(rcadene, alexander-soar): remove these as they are defined in input_shapes, output_shapes
|
||||
# TODO(rcadene, alexander-soare): remove these as they are defined in input_shapes, output_shapes
|
||||
state_dim: int = 2
|
||||
action_dim: int = 2
|
||||
image_size: tuple[int, int] = (96, 96)
|
||||
|
@ -84,13 +80,13 @@ class DiffusionConfig:
|
|||
horizon: int = 16
|
||||
n_action_steps: int = 8
|
||||
|
||||
input_shapes: dict[str, str] = field(
|
||||
input_shapes: dict[str, list[str]] = field(
|
||||
default_factory=lambda: {
|
||||
"observation.image": [3, 96, 96],
|
||||
"observation.state": [2],
|
||||
}
|
||||
)
|
||||
output_shapes: dict[str, str] = field(
|
||||
output_shapes: dict[str, list[str]] = field(
|
||||
default_factory=lambda: {
|
||||
"action": [2],
|
||||
}
|
||||
|
|
|
@ -56,8 +56,6 @@ class DiffusionPolicy(nn.Module):
|
|||
if cfg is None:
|
||||
cfg = DiffusionConfig()
|
||||
self.cfg = cfg
|
||||
self.normalize_input_modes = cfg.normalize_input_modes
|
||||
self.unnormalize_output_modes = cfg.unnormalize_output_modes
|
||||
self.normalize_inputs = Normalize(cfg.input_shapes, cfg.normalize_input_modes, dataset_stats)
|
||||
self.unnormalize_outputs = Unnormalize(cfg.output_shapes, cfg.unnormalize_output_modes, dataset_stats)
|
||||
|
||||
|
|
|
@ -31,18 +31,24 @@ def create_stats_buffers(shapes, modes, stats=None):
|
|||
for key, mode in modes.items():
|
||||
assert mode in ["mean_std", "min_max"]
|
||||
|
||||
shape = shapes[key]
|
||||
shape = tuple(shapes[key])
|
||||
|
||||
# override shape to be invariant to height and width
|
||||
if "image" in key:
|
||||
# assume shape is channel first (b, c, h, w) or (b, t, c, h, w)
|
||||
shape[-1] = 1
|
||||
shape[-2] = 1
|
||||
# sanity checks
|
||||
assert len(shape) == 3, f"number of dimensions of {key} != 3 ({shape=}"
|
||||
c, h, w = shape
|
||||
assert c < h and c < w, f"{key} is not channel first ({shape=})"
|
||||
# override image shape to be invariant to height and width
|
||||
shape = (c, 1, 1)
|
||||
|
||||
# Note: we initialize mean, std, min, max to infinity. They should be overwritten
|
||||
# downstream by `stats` or `policy.load_state_dict`, as expected. During forward,
|
||||
# we assert they are not infinity anymore.
|
||||
|
||||
buffer = {}
|
||||
if mode == "mean_std":
|
||||
mean = torch.zeros(shape, dtype=torch.float32)
|
||||
std = torch.ones(shape, dtype=torch.float32)
|
||||
mean = torch.ones(shape, dtype=torch.float32) * torch.inf
|
||||
std = torch.ones(shape, dtype=torch.float32) * torch.inf
|
||||
buffer = nn.ParameterDict(
|
||||
{
|
||||
"mean": nn.Parameter(mean, requires_grad=False),
|
||||
|
@ -50,9 +56,8 @@ def create_stats_buffers(shapes, modes, stats=None):
|
|||
}
|
||||
)
|
||||
elif mode == "min_max":
|
||||
# TODO(rcadene): should we assume input is in [-1, 1] range?
|
||||
min = torch.ones(shape, dtype=torch.float32) * -1
|
||||
max = torch.ones(shape, dtype=torch.float32)
|
||||
min = torch.ones(shape, dtype=torch.float32) * torch.inf
|
||||
max = torch.ones(shape, dtype=torch.float32) * torch.inf
|
||||
buffer = nn.ParameterDict(
|
||||
{
|
||||
"min": nn.Parameter(min, requires_grad=False),
|
||||
|
@ -109,12 +114,24 @@ class Normalize(nn.Module):
|
|||
buffer = getattr(self, "buffer_" + key.replace(".", "_"))
|
||||
|
||||
if mode == "mean_std":
|
||||
mean = buffer["mean"].unsqueeze(0)
|
||||
std = buffer["std"].unsqueeze(0)
|
||||
mean = buffer["mean"]
|
||||
std = buffer["std"]
|
||||
assert not torch.isinf(
|
||||
mean
|
||||
).any(), "`mean` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
|
||||
assert not torch.isinf(
|
||||
std
|
||||
).any(), "`std` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
|
||||
batch[key] = (batch[key] - mean) / (std + 1e-8)
|
||||
elif mode == "min_max":
|
||||
min = buffer["min"].unsqueeze(0)
|
||||
max = buffer["max"].unsqueeze(0)
|
||||
min = buffer["min"]
|
||||
max = buffer["max"]
|
||||
assert not torch.isinf(
|
||||
min
|
||||
).any(), "`min` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
|
||||
assert not torch.isinf(
|
||||
max
|
||||
).any(), "`max` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
|
||||
# normalize to [0,1]
|
||||
batch[key] = (batch[key] - min) / (max - min)
|
||||
# normalize to [-1, 1]
|
||||
|
@ -131,8 +148,8 @@ class Unnormalize(nn.Module):
|
|||
The class is initialized with a set of shapes, modes, and optional pre-defined statistics. It creates buffers for unnormalization based
|
||||
on these inputs, which are then used to adjust data during the forward pass. The unnormalization process operates on a batch of data,
|
||||
with different keys in the batch being normalized according to the specified modes. The following unnormalization modes are supported:
|
||||
- "mean_std": Unnormalizes data using the mean and standard deviation.
|
||||
- "min_max": Unnormalizes data to a [0, 1] range and then to a [-1, 1] range.
|
||||
- "mean_std": Multiplies by the standard deviation and adds the mean.
|
||||
- "min_max": Maps the data from the [-1, 1] range back to the original [min, max] range.
|
||||
|
||||
Parameters:
|
||||
shapes (dict): A dictionary where keys represent tensor identifiers and values represent the shapes of those tensors.
|
||||
|
@ -161,12 +178,24 @@ class Unnormalize(nn.Module):
|
|||
buffer = getattr(self, "buffer_" + key.replace(".", "_"))
|
||||
|
||||
if mode == "mean_std":
|
||||
mean = buffer["mean"].unsqueeze(0)
|
||||
std = buffer["std"].unsqueeze(0)
|
||||
mean = buffer["mean"]
|
||||
std = buffer["std"]
|
||||
assert not torch.isinf(
|
||||
mean
|
||||
).any(), "`mean` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
|
||||
assert not torch.isinf(
|
||||
std
|
||||
).any(), "`std` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
|
||||
batch[key] = batch[key] * std + mean
|
||||
elif mode == "min_max":
|
||||
min = buffer["min"].unsqueeze(0)
|
||||
max = buffer["max"].unsqueeze(0)
|
||||
min = buffer["min"]
|
||||
max = buffer["max"]
|
||||
assert not torch.isinf(
|
||||
min
|
||||
).any(), "`min` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
|
||||
assert not torch.isinf(
|
||||
max
|
||||
).any(), "`max` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
|
||||
batch[key] = (batch[key] + 1) / 2
|
||||
batch[key] = batch[key] * (max - min) + min
|
||||
else:
|
||||
|
|
|
@ -35,7 +35,7 @@ policy:
|
|||
n_action_steps: 100
|
||||
|
||||
input_shapes:
|
||||
# TODO(rcadene, alexander-soar): add variables for height and width from the dataset/env?
|
||||
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
|
||||
observation.images.top: [3, 480, 640]
|
||||
observation.state: ["${policy.state_dim}"]
|
||||
output_shapes:
|
||||
|
|
|
@ -51,7 +51,7 @@ policy:
|
|||
n_action_steps: ${n_action_steps}
|
||||
|
||||
input_shapes:
|
||||
# TODO(rcadene, alexander-soar): add variables for height and width from the dataset/env?
|
||||
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
|
||||
observation.image: [3, 96, 96]
|
||||
observation.state: ["${policy.state_dim}"]
|
||||
output_shapes:
|
||||
|
|
|
@ -339,7 +339,6 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
|||
eval_info = eval_policy(
|
||||
rollout_env,
|
||||
policy,
|
||||
transform=offline_dataset.transform,
|
||||
return_episode_data=True,
|
||||
seed=cfg.seed,
|
||||
)
|
||||
|
|
|
@ -96,9 +96,8 @@ def test_policy(env_name, policy_name, extra_overrides):
|
|||
|
||||
# Test load state_dict
|
||||
if policy_name != "tdmpc":
|
||||
# TODO(rcadene, alexander-soar): make it work for tdmpc
|
||||
# TODO(rcadene, alexander-soar): how to remove need for dataset_stats?
|
||||
new_policy = make_policy(cfg, dataset_stats=dataset.stats)
|
||||
# TODO(rcadene, alexander-soare): make it work for tdmpc
|
||||
new_policy = make_policy(cfg)
|
||||
new_policy.load_state_dict(policy.state_dict())
|
||||
|
||||
|
||||
|
@ -110,7 +109,7 @@ def test_policy(env_name, policy_name, extra_overrides):
|
|||
],
|
||||
)
|
||||
def test_normalize(insert_temporal_dim):
|
||||
# TODO(rcadene, alexander-soar): test with real data and assert results of normalization/unnormalization
|
||||
# TODO(rcadene, alexander-soare): test with real data and assert results of normalization/unnormalization
|
||||
|
||||
input_shapes = {
|
||||
"observation.image": [3, 96, 96],
|
||||
|
@ -170,6 +169,7 @@ def test_normalize(insert_temporal_dim):
|
|||
|
||||
# test without stats
|
||||
normalize = Normalize(input_shapes, normalize_input_modes, stats=None)
|
||||
with pytest.raises(AssertionError):
|
||||
normalize(input_batch)
|
||||
|
||||
# test with stats
|
||||
|
@ -183,6 +183,7 @@ def test_normalize(insert_temporal_dim):
|
|||
|
||||
# test without stats
|
||||
unnormalize = Unnormalize(output_shapes, unnormalize_output_modes, stats=None)
|
||||
with pytest.raises(AssertionError):
|
||||
unnormalize(output_batch)
|
||||
|
||||
# test with stats
|
||||
|
|
Loading…
Reference in New Issue