address comments
This commit is contained in:
parent
bc96284ca0
commit
0ec28bf71a
|
@ -30,15 +30,11 @@ class ActionChunkingTransformerConfig:
|
||||||
The key represents the output data name, and the value is a list indicating the dimensions
|
The key represents the output data name, and the value is a list indicating the dimensions
|
||||||
of the corresponding data. For example, "action" refers to an output shape of [14], indicating
|
of the corresponding data. For example, "action" refers to an output shape of [14], indicating
|
||||||
14-dimensional actions. Importantly, shapes doesnt include batch dimension or temporal dimension.
|
14-dimensional actions. Importantly, shapes doesnt include batch dimension or temporal dimension.
|
||||||
normalize_input_modes: A dictionary specifying the normalization mode to be applied to various inputs.
|
normalize_input_modes: A dictionary with key represents the modality (e.g. "observation.state"),
|
||||||
The key represents the input data name, and the value specifies the type of normalization to apply.
|
and the value specifies the normalization mode to apply. The two availables
|
||||||
Common normalization methods include "mean_std" (mean and standard deviation) or "min_max" (to normalize
|
modes are "mean_std" which substracts the mean and divide by the standard
|
||||||
between -1 and 1).
|
deviation and "min_max" which rescale in a [-1, 1] range.
|
||||||
unnormalize_output_modes: A dictionary specifying the method to unnormalize outputs.
|
unnormalize_output_modes: Similar dictionary as `normalize_input_modes`, but to unormalize in original scale.
|
||||||
This parameter maps output data types to their unnormalization modes, allowing the results to be
|
|
||||||
transformed back from a normalized state to a standard state. It is typically used when output
|
|
||||||
data needs to be interpreted in its original scale or units. For example, for "action", the
|
|
||||||
unnormalization mode might be "mean_std" or "min_max".
|
|
||||||
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
|
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
|
||||||
use_pretrained_backbone: Whether the backbone should be initialized with pretrained weights from
|
use_pretrained_backbone: Whether the backbone should be initialized with pretrained weights from
|
||||||
torchvision.
|
torchvision.
|
||||||
|
@ -65,7 +61,7 @@ class ActionChunkingTransformerConfig:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Environment.
|
# Environment.
|
||||||
# TODO(rcadene, alexander-soar): remove these as they are defined in input_shapes, output_shapes
|
# TODO(rcadene, alexander-soare): remove these as they are defined in input_shapes, output_shapes
|
||||||
state_dim: int = 14
|
state_dim: int = 14
|
||||||
action_dim: int = 14
|
action_dim: int = 14
|
||||||
|
|
||||||
|
@ -75,13 +71,13 @@ class ActionChunkingTransformerConfig:
|
||||||
chunk_size: int = 100
|
chunk_size: int = 100
|
||||||
n_action_steps: int = 100
|
n_action_steps: int = 100
|
||||||
|
|
||||||
input_shapes: dict[str, str] = field(
|
input_shapes: dict[str, list[str]] = field(
|
||||||
default_factory=lambda: {
|
default_factory=lambda: {
|
||||||
"observation.images.top": [3, 480, 640],
|
"observation.images.top": [3, 480, 640],
|
||||||
"observation.state": [14],
|
"observation.state": [14],
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
output_shapes: dict[str, str] = field(
|
output_shapes: dict[str, list[str]] = field(
|
||||||
default_factory=lambda: {
|
default_factory=lambda: {
|
||||||
"action": [14],
|
"action": [14],
|
||||||
}
|
}
|
||||||
|
|
|
@ -72,8 +72,6 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||||
if cfg is None:
|
if cfg is None:
|
||||||
cfg = ActionChunkingTransformerConfig()
|
cfg = ActionChunkingTransformerConfig()
|
||||||
self.cfg = cfg
|
self.cfg = cfg
|
||||||
self.normalize_input_modes = cfg.normalize_input_modes
|
|
||||||
self.unnormalize_output_modes = cfg.unnormalize_output_modes
|
|
||||||
self.normalize_inputs = Normalize(cfg.input_shapes, cfg.normalize_input_modes, dataset_stats)
|
self.normalize_inputs = Normalize(cfg.input_shapes, cfg.normalize_input_modes, dataset_stats)
|
||||||
self.unnormalize_outputs = Unnormalize(cfg.output_shapes, cfg.unnormalize_output_modes, dataset_stats)
|
self.unnormalize_outputs = Unnormalize(cfg.output_shapes, cfg.unnormalize_output_modes, dataset_stats)
|
||||||
|
|
||||||
|
|
|
@ -28,15 +28,11 @@ class DiffusionConfig:
|
||||||
The key represents the output data name, and the value is a list indicating the dimensions
|
The key represents the output data name, and the value is a list indicating the dimensions
|
||||||
of the corresponding data. For example, "action" refers to an output shape of [14], indicating
|
of the corresponding data. For example, "action" refers to an output shape of [14], indicating
|
||||||
14-dimensional actions. Importantly, shapes doesnt include batch dimension or temporal dimension.
|
14-dimensional actions. Importantly, shapes doesnt include batch dimension or temporal dimension.
|
||||||
normalize_input_modes: A dictionary specifying the normalization mode to be applied to various inputs.
|
normalize_input_modes: A dictionary with key represents the modality (e.g. "observation.state"),
|
||||||
The key represents the input data name, and the value specifies the type of normalization to apply.
|
and the value specifies the normalization mode to apply. The two availables
|
||||||
Common normalization methods include "mean_std" (mean and standard deviation) or "min_max" (to normalize
|
modes are "mean_std" which substracts the mean and divide by the standard
|
||||||
between -1 and 1).
|
deviation and "min_max" which rescale in a [-1, 1] range.
|
||||||
unnormalize_output_modes: A dictionary specifying the method to unnormalize outputs.
|
unnormalize_output_modes: Similar dictionary as `normalize_input_modes`, but to unormalize in original scale.
|
||||||
This parameter maps output data types to their unnormalization modes, allowing the results to be
|
|
||||||
transformed back from a normalized state to a standard state. It is typically used when output
|
|
||||||
data needs to be interpreted in its original scale or units. For example, for "action", the
|
|
||||||
unnormalization mode might be "mean_std" or "min_max".
|
|
||||||
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
|
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
|
||||||
crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit
|
crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit
|
||||||
within the image size. If None, no cropping is done.
|
within the image size. If None, no cropping is done.
|
||||||
|
@ -74,7 +70,7 @@ class DiffusionConfig:
|
||||||
|
|
||||||
# Environment.
|
# Environment.
|
||||||
# Inherit these from the environment config.
|
# Inherit these from the environment config.
|
||||||
# TODO(rcadene, alexander-soar): remove these as they are defined in input_shapes, output_shapes
|
# TODO(rcadene, alexander-soare): remove these as they are defined in input_shapes, output_shapes
|
||||||
state_dim: int = 2
|
state_dim: int = 2
|
||||||
action_dim: int = 2
|
action_dim: int = 2
|
||||||
image_size: tuple[int, int] = (96, 96)
|
image_size: tuple[int, int] = (96, 96)
|
||||||
|
@ -84,13 +80,13 @@ class DiffusionConfig:
|
||||||
horizon: int = 16
|
horizon: int = 16
|
||||||
n_action_steps: int = 8
|
n_action_steps: int = 8
|
||||||
|
|
||||||
input_shapes: dict[str, str] = field(
|
input_shapes: dict[str, list[str]] = field(
|
||||||
default_factory=lambda: {
|
default_factory=lambda: {
|
||||||
"observation.image": [3, 96, 96],
|
"observation.image": [3, 96, 96],
|
||||||
"observation.state": [2],
|
"observation.state": [2],
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
output_shapes: dict[str, str] = field(
|
output_shapes: dict[str, list[str]] = field(
|
||||||
default_factory=lambda: {
|
default_factory=lambda: {
|
||||||
"action": [2],
|
"action": [2],
|
||||||
}
|
}
|
||||||
|
|
|
@ -56,8 +56,6 @@ class DiffusionPolicy(nn.Module):
|
||||||
if cfg is None:
|
if cfg is None:
|
||||||
cfg = DiffusionConfig()
|
cfg = DiffusionConfig()
|
||||||
self.cfg = cfg
|
self.cfg = cfg
|
||||||
self.normalize_input_modes = cfg.normalize_input_modes
|
|
||||||
self.unnormalize_output_modes = cfg.unnormalize_output_modes
|
|
||||||
self.normalize_inputs = Normalize(cfg.input_shapes, cfg.normalize_input_modes, dataset_stats)
|
self.normalize_inputs = Normalize(cfg.input_shapes, cfg.normalize_input_modes, dataset_stats)
|
||||||
self.unnormalize_outputs = Unnormalize(cfg.output_shapes, cfg.unnormalize_output_modes, dataset_stats)
|
self.unnormalize_outputs = Unnormalize(cfg.output_shapes, cfg.unnormalize_output_modes, dataset_stats)
|
||||||
|
|
||||||
|
|
|
@ -31,18 +31,24 @@ def create_stats_buffers(shapes, modes, stats=None):
|
||||||
for key, mode in modes.items():
|
for key, mode in modes.items():
|
||||||
assert mode in ["mean_std", "min_max"]
|
assert mode in ["mean_std", "min_max"]
|
||||||
|
|
||||||
shape = shapes[key]
|
shape = tuple(shapes[key])
|
||||||
|
|
||||||
# override shape to be invariant to height and width
|
|
||||||
if "image" in key:
|
if "image" in key:
|
||||||
# assume shape is channel first (b, c, h, w) or (b, t, c, h, w)
|
# sanity checks
|
||||||
shape[-1] = 1
|
assert len(shape) == 3, f"number of dimensions of {key} != 3 ({shape=}"
|
||||||
shape[-2] = 1
|
c, h, w = shape
|
||||||
|
assert c < h and c < w, f"{key} is not channel first ({shape=})"
|
||||||
|
# override image shape to be invariant to height and width
|
||||||
|
shape = (c, 1, 1)
|
||||||
|
|
||||||
|
# Note: we initialize mean, std, min, max to infinity. They should be overwritten
|
||||||
|
# downstream by `stats` or `policy.load_state_dict`, as expected. During forward,
|
||||||
|
# we assert they are not infinity anymore.
|
||||||
|
|
||||||
buffer = {}
|
buffer = {}
|
||||||
if mode == "mean_std":
|
if mode == "mean_std":
|
||||||
mean = torch.zeros(shape, dtype=torch.float32)
|
mean = torch.ones(shape, dtype=torch.float32) * torch.inf
|
||||||
std = torch.ones(shape, dtype=torch.float32)
|
std = torch.ones(shape, dtype=torch.float32) * torch.inf
|
||||||
buffer = nn.ParameterDict(
|
buffer = nn.ParameterDict(
|
||||||
{
|
{
|
||||||
"mean": nn.Parameter(mean, requires_grad=False),
|
"mean": nn.Parameter(mean, requires_grad=False),
|
||||||
|
@ -50,9 +56,8 @@ def create_stats_buffers(shapes, modes, stats=None):
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
elif mode == "min_max":
|
elif mode == "min_max":
|
||||||
# TODO(rcadene): should we assume input is in [-1, 1] range?
|
min = torch.ones(shape, dtype=torch.float32) * torch.inf
|
||||||
min = torch.ones(shape, dtype=torch.float32) * -1
|
max = torch.ones(shape, dtype=torch.float32) * torch.inf
|
||||||
max = torch.ones(shape, dtype=torch.float32)
|
|
||||||
buffer = nn.ParameterDict(
|
buffer = nn.ParameterDict(
|
||||||
{
|
{
|
||||||
"min": nn.Parameter(min, requires_grad=False),
|
"min": nn.Parameter(min, requires_grad=False),
|
||||||
|
@ -109,12 +114,24 @@ class Normalize(nn.Module):
|
||||||
buffer = getattr(self, "buffer_" + key.replace(".", "_"))
|
buffer = getattr(self, "buffer_" + key.replace(".", "_"))
|
||||||
|
|
||||||
if mode == "mean_std":
|
if mode == "mean_std":
|
||||||
mean = buffer["mean"].unsqueeze(0)
|
mean = buffer["mean"]
|
||||||
std = buffer["std"].unsqueeze(0)
|
std = buffer["std"]
|
||||||
|
assert not torch.isinf(
|
||||||
|
mean
|
||||||
|
).any(), "`mean` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
|
||||||
|
assert not torch.isinf(
|
||||||
|
std
|
||||||
|
).any(), "`std` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
|
||||||
batch[key] = (batch[key] - mean) / (std + 1e-8)
|
batch[key] = (batch[key] - mean) / (std + 1e-8)
|
||||||
elif mode == "min_max":
|
elif mode == "min_max":
|
||||||
min = buffer["min"].unsqueeze(0)
|
min = buffer["min"]
|
||||||
max = buffer["max"].unsqueeze(0)
|
max = buffer["max"]
|
||||||
|
assert not torch.isinf(
|
||||||
|
min
|
||||||
|
).any(), "`min` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
|
||||||
|
assert not torch.isinf(
|
||||||
|
max
|
||||||
|
).any(), "`max` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
|
||||||
# normalize to [0,1]
|
# normalize to [0,1]
|
||||||
batch[key] = (batch[key] - min) / (max - min)
|
batch[key] = (batch[key] - min) / (max - min)
|
||||||
# normalize to [-1, 1]
|
# normalize to [-1, 1]
|
||||||
|
@ -131,8 +148,8 @@ class Unnormalize(nn.Module):
|
||||||
The class is initialized with a set of shapes, modes, and optional pre-defined statistics. It creates buffers for unnormalization based
|
The class is initialized with a set of shapes, modes, and optional pre-defined statistics. It creates buffers for unnormalization based
|
||||||
on these inputs, which are then used to adjust data during the forward pass. The unnormalization process operates on a batch of data,
|
on these inputs, which are then used to adjust data during the forward pass. The unnormalization process operates on a batch of data,
|
||||||
with different keys in the batch being normalized according to the specified modes. The following unnormalization modes are supported:
|
with different keys in the batch being normalized according to the specified modes. The following unnormalization modes are supported:
|
||||||
- "mean_std": Unnormalizes data using the mean and standard deviation.
|
- "mean_std": Subtracts the mean and divides by the standard deviation.
|
||||||
- "min_max": Unnormalizes data to a [0, 1] range and then to a [-1, 1] range.
|
- "min_max": Scales and offsets the data such that the minimum is -1 and the maximum is +1.
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
shapes (dict): A dictionary where keys represent tensor identifiers and values represent the shapes of those tensors.
|
shapes (dict): A dictionary where keys represent tensor identifiers and values represent the shapes of those tensors.
|
||||||
|
@ -161,12 +178,24 @@ class Unnormalize(nn.Module):
|
||||||
buffer = getattr(self, "buffer_" + key.replace(".", "_"))
|
buffer = getattr(self, "buffer_" + key.replace(".", "_"))
|
||||||
|
|
||||||
if mode == "mean_std":
|
if mode == "mean_std":
|
||||||
mean = buffer["mean"].unsqueeze(0)
|
mean = buffer["mean"]
|
||||||
std = buffer["std"].unsqueeze(0)
|
std = buffer["std"]
|
||||||
|
assert not torch.isinf(
|
||||||
|
mean
|
||||||
|
).any(), "`mean` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
|
||||||
|
assert not torch.isinf(
|
||||||
|
std
|
||||||
|
).any(), "`std` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
|
||||||
batch[key] = batch[key] * std + mean
|
batch[key] = batch[key] * std + mean
|
||||||
elif mode == "min_max":
|
elif mode == "min_max":
|
||||||
min = buffer["min"].unsqueeze(0)
|
min = buffer["min"]
|
||||||
max = buffer["max"].unsqueeze(0)
|
max = buffer["max"]
|
||||||
|
assert not torch.isinf(
|
||||||
|
min
|
||||||
|
).any(), "`min` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
|
||||||
|
assert not torch.isinf(
|
||||||
|
max
|
||||||
|
).any(), "`max` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
|
||||||
batch[key] = (batch[key] + 1) / 2
|
batch[key] = (batch[key] + 1) / 2
|
||||||
batch[key] = batch[key] * (max - min) + min
|
batch[key] = batch[key] * (max - min) + min
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -35,7 +35,7 @@ policy:
|
||||||
n_action_steps: 100
|
n_action_steps: 100
|
||||||
|
|
||||||
input_shapes:
|
input_shapes:
|
||||||
# TODO(rcadene, alexander-soar): add variables for height and width from the dataset/env?
|
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
|
||||||
observation.images.top: [3, 480, 640]
|
observation.images.top: [3, 480, 640]
|
||||||
observation.state: ["${policy.state_dim}"]
|
observation.state: ["${policy.state_dim}"]
|
||||||
output_shapes:
|
output_shapes:
|
||||||
|
|
|
@ -51,7 +51,7 @@ policy:
|
||||||
n_action_steps: ${n_action_steps}
|
n_action_steps: ${n_action_steps}
|
||||||
|
|
||||||
input_shapes:
|
input_shapes:
|
||||||
# TODO(rcadene, alexander-soar): add variables for height and width from the dataset/env?
|
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
|
||||||
observation.image: [3, 96, 96]
|
observation.image: [3, 96, 96]
|
||||||
observation.state: ["${policy.state_dim}"]
|
observation.state: ["${policy.state_dim}"]
|
||||||
output_shapes:
|
output_shapes:
|
||||||
|
|
|
@ -339,7 +339,6 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||||
eval_info = eval_policy(
|
eval_info = eval_policy(
|
||||||
rollout_env,
|
rollout_env,
|
||||||
policy,
|
policy,
|
||||||
transform=offline_dataset.transform,
|
|
||||||
return_episode_data=True,
|
return_episode_data=True,
|
||||||
seed=cfg.seed,
|
seed=cfg.seed,
|
||||||
)
|
)
|
||||||
|
|
|
@ -96,9 +96,8 @@ def test_policy(env_name, policy_name, extra_overrides):
|
||||||
|
|
||||||
# Test load state_dict
|
# Test load state_dict
|
||||||
if policy_name != "tdmpc":
|
if policy_name != "tdmpc":
|
||||||
# TODO(rcadene, alexander-soar): make it work for tdmpc
|
# TODO(rcadene, alexander-soare): make it work for tdmpc
|
||||||
# TODO(rcadene, alexander-soar): how to remove need for dataset_stats?
|
new_policy = make_policy(cfg)
|
||||||
new_policy = make_policy(cfg, dataset_stats=dataset.stats)
|
|
||||||
new_policy.load_state_dict(policy.state_dict())
|
new_policy.load_state_dict(policy.state_dict())
|
||||||
|
|
||||||
|
|
||||||
|
@ -110,7 +109,7 @@ def test_policy(env_name, policy_name, extra_overrides):
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_normalize(insert_temporal_dim):
|
def test_normalize(insert_temporal_dim):
|
||||||
# TODO(rcadene, alexander-soar): test with real data and assert results of normalization/unnormalization
|
# TODO(rcadene, alexander-soare): test with real data and assert results of normalization/unnormalization
|
||||||
|
|
||||||
input_shapes = {
|
input_shapes = {
|
||||||
"observation.image": [3, 96, 96],
|
"observation.image": [3, 96, 96],
|
||||||
|
@ -170,7 +169,8 @@ def test_normalize(insert_temporal_dim):
|
||||||
|
|
||||||
# test without stats
|
# test without stats
|
||||||
normalize = Normalize(input_shapes, normalize_input_modes, stats=None)
|
normalize = Normalize(input_shapes, normalize_input_modes, stats=None)
|
||||||
normalize(input_batch)
|
with pytest.raises(AssertionError):
|
||||||
|
normalize(input_batch)
|
||||||
|
|
||||||
# test with stats
|
# test with stats
|
||||||
normalize = Normalize(input_shapes, normalize_input_modes, stats=dataset_stats)
|
normalize = Normalize(input_shapes, normalize_input_modes, stats=dataset_stats)
|
||||||
|
@ -183,7 +183,8 @@ def test_normalize(insert_temporal_dim):
|
||||||
|
|
||||||
# test wihtout stats
|
# test wihtout stats
|
||||||
unnormalize = Unnormalize(output_shapes, unnormalize_output_modes, stats=None)
|
unnormalize = Unnormalize(output_shapes, unnormalize_output_modes, stats=None)
|
||||||
unnormalize(output_batch)
|
with pytest.raises(AssertionError):
|
||||||
|
unnormalize(output_batch)
|
||||||
|
|
||||||
# test with stats
|
# test with stats
|
||||||
unnormalize = Unnormalize(output_shapes, unnormalize_output_modes, stats=dataset_stats)
|
unnormalize = Unnormalize(output_shapes, unnormalize_output_modes, stats=dataset_stats)
|
||||||
|
|
Loading…
Reference in New Issue