Move normalization to policy for act and diffusion (#90)

Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
Remi 2024-04-25 11:47:38 +02:00 committed by GitHub
parent c1bcf857c5
commit e760e4cd63
25 changed files with 543 additions and 288 deletions

View File

@ -263,15 +263,13 @@ Secondly, assuming you have trained a policy, you need:
- `config.yaml` which you can get from the `.hydra` directory of your training output folder. - `config.yaml` which you can get from the `.hydra` directory of your training output folder.
- `model.pt` which should be one of the saved models in the `models` directory of your training output folder (they won't be named `model.pt` but you will need to choose one). - `model.pt` which should be one of the saved models in the `models` directory of your training output folder (they won't be named `model.pt` but you will need to choose one).
- `stats.pth` which should point to the same file in the dataset directory (found in `data/{dataset_name}`).
To upload these to the hub, prepare a folder with the following structure (you can use symlinks rather than copying): To upload these to the hub, prepare a folder with the following structure (you can use symlinks rather than copying):
``` ```
to_upload to_upload
├── config.yaml ├── config.yaml
├── model.pt └── model.pt
└── stats.pth
``` ```
With the folder prepared, run the following with a desired revision ID. With the folder prepared, run the following with a desired revision ID.
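The exact upload command is truncated in this diff. As one hedged possibility (not necessarily the command this PR documents), the `huggingface_hub` Python API can push the folder to a given revision; the repo id and revision below are placeholders:
```python
from huggingface_hub import HfApi

api = HfApi()
repo_id = "your-username/your-policy"  # placeholder
api.create_repo(repo_id, exist_ok=True)
api.create_branch(repo_id, branch="v1.0", exist_ok=True)  # the desired revision ID
api.upload_folder(folder_path="to_upload", repo_id=repo_id, revision="v1.0")
```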

View File

@ -44,7 +44,7 @@ from datasets import load_dataset
# TODO(rcadene): list available datasets on lerobot page using `datasets` # TODO(rcadene): list available datasets on lerobot page using `datasets`
# download/load hugging face dataset in pyarrow format # download/load hugging face dataset in pyarrow format
hf_dataset, fps = load_dataset("lerobot/pusht", split="train"), 10 hf_dataset, fps = load_dataset("lerobot/pusht", split="train", revision="v1.1"), 10
# display name of dataset and its features # display name of dataset and its features
# TODO(rcadene): update to make the print pretty # TODO(rcadene): update to make the print pretty

View File

@ -19,7 +19,6 @@ folder = Path(snapshot_download(hub_id))
config_path = folder / "config.yaml" config_path = folder / "config.yaml"
weights_path = folder / "model.pt" weights_path = folder / "model.pt"
stats_path = folder / "stats.pth" # normalization stats
# Override some config parameters to do with evaluation. # Override some config parameters to do with evaluation.
overrides = [ overrides = [
@ -36,5 +35,4 @@ cfg = init_hydra_config(config_path, overrides)
eval( eval(
cfg, cfg,
out_dir=f"outputs/eval/example_{cfg.env.name}_{cfg.policy.name}", out_dir=f"outputs/eval/example_{cfg.env.name}_{cfg.policy.name}",
stats_path=stats_path,
) )

View File

@ -34,7 +34,7 @@ dataset = make_dataset(hydra_cfg)
# If you're doing something different, you will likely need to change at least some of the defaults. # If you're doing something different, you will likely need to change at least some of the defaults.
cfg = DiffusionConfig() cfg = DiffusionConfig()
# TODO(alexander-soare): Remove LR scheduler from the policy. # TODO(alexander-soare): Remove LR scheduler from the policy.
policy = DiffusionPolicy(cfg, lr_scheduler_num_training_steps=training_steps) policy = DiffusionPolicy(cfg, lr_scheduler_num_training_steps=training_steps, dataset_stats=dataset.stats)
policy.train() policy.train()
policy.to(device) policy.to(device)
@ -62,7 +62,6 @@ while not done:
done = True done = True
break break
# Save the policy, configuration, and normalization stats for later use. # Save the policy and configuration for later use.
policy.save(output_directory / "model.pt") policy.save(output_directory / "model.pt")
OmegaConf.save(hydra_cfg, output_directory / "config.yaml") OmegaConf.save(hydra_cfg, output_directory / "config.yaml")
torch.save(dataset.transform.transforms[-1].stats, output_directory / "stats.pth")

View File

@ -2,18 +2,13 @@ import os
from pathlib import Path from pathlib import Path
import torch import torch
from torchvision.transforms import v2 from omegaconf import OmegaConf
from lerobot.common.transforms import NormalizeTransform
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
def make_dataset( def make_dataset(
cfg, cfg,
# set normalize=False to remove all transformations and keep images unnormalized in [0,255]
normalize=True,
stats_path=None,
split="train", split="train",
): ):
if cfg.env.name == "xarm": if cfg.env.name == "xarm":
@ -33,58 +28,26 @@ def make_dataset(
else: else:
raise ValueError(cfg.env.name) raise ValueError(cfg.env.name)
transforms = None
if normalize:
# TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max,
# min_max_from_spec
# TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std
normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
if cfg.policy.name == "diffusion" and cfg.env.name == "pusht":
stats = {}
# TODO(rcadene): we overwrite stats to have the same as pretrained model, but we should remove this
stats["observation.state"] = {}
stats["observation.state"]["min"] = torch.tensor([13.456424, 32.938293], dtype=torch.float32)
stats["observation.state"]["max"] = torch.tensor([496.14618, 510.9579], dtype=torch.float32)
stats["action"] = {}
stats["action"]["min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
stats["action"]["max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
elif stats_path is None:
# load a first dataset to access precomputed stats
stats_dataset = clsfunc(
dataset_id=cfg.dataset_id,
split="train",
root=DATA_DIR,
)
stats = stats_dataset.stats
else:
stats = torch.load(stats_path)
transforms = v2.Compose(
[
NormalizeTransform(
stats,
in_keys=[
"observation.state",
"action",
],
mode=normalization_mode,
),
]
)
delta_timestamps = cfg.policy.get("delta_timestamps") delta_timestamps = cfg.policy.get("delta_timestamps")
if delta_timestamps is not None: if delta_timestamps is not None:
for key in delta_timestamps: for key in delta_timestamps:
if isinstance(delta_timestamps[key], str): if isinstance(delta_timestamps[key], str):
delta_timestamps[key] = eval(delta_timestamps[key]) delta_timestamps[key] = eval(delta_timestamps[key])
# TODO(rcadene): add data augmentations
dataset = clsfunc( dataset = clsfunc(
dataset_id=cfg.dataset_id, dataset_id=cfg.dataset_id,
split=split, split=split,
root=DATA_DIR, root=DATA_DIR,
delta_timestamps=delta_timestamps, delta_timestamps=delta_timestamps,
transform=transforms,
) )
if cfg.get("override_dataset_stats"):
for key, stats_dict in cfg.override_dataset_stats.items():
for stats_type, listconfig in stats_dict.items():
# example of stats_type: min, max, mean, std
stats = OmegaConf.to_container(listconfig, resolve=True)
dataset.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)
return dataset return dataset
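For illustration, a minimal sketch of what this override does at runtime; the config fragment and the stand-in stats below are made up, mirroring the `override_dataset_stats` blocks in the yaml configs further down:
```python
import torch
from omegaconf import OmegaConf

# Stand-in for `dataset.stats` as precomputed from the data.
dataset_stats = {"observation.image": {"mean": torch.zeros(3, 1, 1), "std": torch.ones(3, 1, 1)}}

# Stand-in for `cfg.override_dataset_stats` (same structure as in the yaml configs).
cfg = OmegaConf.create(
    {
        "override_dataset_stats": {
            "observation.image": {
                "mean": [[[0.5]], [[0.5]], [[0.5]]],  # (c, 1, 1)
                "std": [[[0.5]], [[0.5]], [[0.5]]],  # (c, 1, 1)
            }
        }
    }
)

# Same loop as in `make_dataset`: each ListConfig becomes a float32 tensor that replaces the stat.
for key, stats_dict in cfg.override_dataset_stats.items():
    for stats_type, listconfig in stats_dict.items():
        stats = OmegaConf.to_container(listconfig, resolve=True)
        dataset_stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)

print(dataset_stats["observation.image"]["mean"].shape)  # torch.Size([3, 1, 1])
```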

View File

@ -1,10 +1,8 @@
import einops import einops
import torch import torch
from lerobot.common.transforms import apply_inverse_transform
def preprocess_observation(observation):
def preprocess_observation(observation, transform=None):
# map to expected inputs for the policy # map to expected inputs for the policy
obs = {} obs = {}
@ -24,7 +22,7 @@ def preprocess_observation(observation, transform=None):
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}" assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
# convert to channel first of type float32 in range [0,1] # convert to channel first of type float32 in range [0,1]
img = einops.rearrange(img, "b h w c -> b c h w") img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
img = img.type(torch.float32) img = img.type(torch.float32)
img /= 255 img /= 255
@ -33,19 +31,11 @@ def preprocess_observation(observation, transform=None):
# TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing requirement for "agent_pos" # TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing requirement for "agent_pos"
obs["observation.state"] = torch.from_numpy(observation["agent_pos"]).float() obs["observation.state"] = torch.from_numpy(observation["agent_pos"]).float()
# apply same transforms as in training
if transform is not None:
for key in obs:
obs[key] = torch.stack([transform({key: item})[key] for item in obs[key]])
return obs return obs
def postprocess_action(action, transform=None): def postprocess_action(action):
action = action.to("cpu") action = action.to("cpu").numpy()
# action is a batch (num_env,action_dim) instead of an item (action_dim),
# we assume applying inverse transform on a batch works the same
action = apply_inverse_transform({"action": action}, transform)["action"].numpy()
assert ( assert (
action.ndim == 2 action.ndim == 2
), "we assume dimensions are respectively the number of parallel envs, action dimensions" ), "we assume dimensions are respectively the number of parallel envs, action dimensions"

View File

@ -8,23 +8,30 @@ class ActionChunkingTransformerConfig:
Defaults are configured for training on bimanual Aloha tasks like "insertion" or "transfer". Defaults are configured for training on bimanual Aloha tasks like "insertion" or "transfer".
The parameters you will most likely need to change are the ones which depend on the environment / sensors. The parameters you will most likely need to change are the ones which depend on the environment / sensors.
Those are: `state_dim`, `action_dim` and `camera_names`. Those are: `input_shapes` and `output_shapes`.
Args: Args:
state_dim: Dimensionality of the observation state space (excluding images).
action_dim: Dimensionality of the action space.
n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
current step and additional steps going back). current step and additional steps going back).
camera_names: The (unique) set of names for the cameras.
chunk_size: The size of the action prediction "chunks" in units of environment steps. chunk_size: The size of the action prediction "chunks" in units of environment steps.
n_action_steps: The number of action steps to run in the environment for one invocation of the policy. n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
This should be no greater than the chunk size. For example, if the chunk size is 100, you may This should be no greater than the chunk size. For example, if the chunk size is 100, you may
set this to 50. This would mean that the model predicts 100 steps worth of actions, runs 50 in the set this to 50. This would mean that the model predicts 100 steps worth of actions, runs 50 in the
environment, and throws the other 50 out. environment, and throws the other 50 out.
image_normalization_mean: Value to subtract from the input image pixels (inputs are assumed to be in input_shapes: A dictionary defining the shapes of the input data for the policy.
[0, 1]) for normalization. The key represents the input data name, and the value is a list indicating the dimensions
image_normalization_std: Value by which to divide the input image pixels (after the mean has been of the corresponding data. For example, "observation.images.top" refers to an input from the
subtracted). "top" camera with dimensions [3, 480, 640], indicating it has three color channels and 480x640 resolution.
Importantly, shapes don't include the batch or temporal dimensions.
output_shapes: A dictionary defining the shapes of the output data for the policy.
The key represents the output data name, and the value is a list indicating the dimensions
of the corresponding data. For example, "action" refers to an output shape of [14], indicating
14-dimensional actions. Importantly, shapes don't include the batch or temporal dimensions.
normalize_input_modes: A dictionary where the key represents the modality (e.g. "observation.state")
and the value specifies the normalization mode to apply. The two available
modes are "mean_std", which subtracts the mean and divides by the standard
deviation, and "min_max", which rescales to a [-1, 1] range.
unnormalize_output_modes: Similar dictionary to `normalize_input_modes`, but used to unnormalize back to the original scale.
vision_backbone: Name of the torchvision resnet backbone to use for encoding images. vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
use_pretrained_backbone: Whether the backbone should be initialized with pretrained weights from use_pretrained_backbone: Whether the backbone should be initialized with pretrained weights from
torchvision. torchvision.
@ -50,21 +57,35 @@ class ActionChunkingTransformerConfig:
is enabled. Loss is then calculated as: `reconstruction_loss + kl_weight * kld_loss`. is enabled. Loss is then calculated as: `reconstruction_loss + kl_weight * kld_loss`.
""" """
# Environment. # Input / output structure.
state_dim: int = 14
action_dim: int = 14
# Inputs / output structure.
n_obs_steps: int = 1 n_obs_steps: int = 1
camera_names: tuple[str] = ("top",)
chunk_size: int = 100 chunk_size: int = 100
n_action_steps: int = 100 n_action_steps: int = 100
# Vision preprocessing. input_shapes: dict[str, list[str]] = field(
image_normalization_mean: tuple[float, float, float] = field( default_factory=lambda: {
default_factory=lambda: [0.485, 0.456, 0.406] "observation.images.top": [3, 480, 640],
"observation.state": [14],
}
)
output_shapes: dict[str, list[str]] = field(
default_factory=lambda: {
"action": [14],
}
)
# Normalization / Unnormalization
normalize_input_modes: dict[str, str] = field(
default_factory=lambda: {
"observation.image": "mean_std",
"observation.state": "mean_std",
}
)
unnormalize_output_modes: dict[str, str] = field(
default_factory=lambda: {
"action": "mean_std",
}
) )
image_normalization_std: tuple[float, float, float] = field(default_factory=lambda: [0.229, 0.224, 0.225])
# Architecture. # Architecture.
# Vision backbone. # Vision backbone.
@ -117,7 +138,10 @@ class ActionChunkingTransformerConfig:
raise ValueError( raise ValueError(
f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`" f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
) )
if self.camera_names != ["top"]: # Check that there is only one image.
raise ValueError(f"For now, `camera_names` can only be ['top']. Got {self.camera_names}.") # TODO(alexander-soare): generalize this to multiple images.
if len(set(self.camera_names)) != len(self.camera_names): if (
raise ValueError(f"`camera_names` should not have any repeated entries. Got {self.camera_names}.") sum(k.startswith("observation.images.") for k in self.input_shapes) != 1
or "observation.images.top" not in self.input_shapes
):
raise ValueError('For now, only "observation.images.top" is accepted for an image input.')
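For a quick feel of the new config surface, a sketch (the non-"top" camera key below is hypothetical, used only to trigger the check):
```python
from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig

# Defaults: one "top" camera at 480x640, a 14-D state, mean/std (un)normalization.
cfg = ActionChunkingTransformerConfig()
print(cfg.input_shapes["observation.images.top"])  # [3, 480, 640]
print(cfg.output_shapes["action"])                 # [14]

# Any image key other than "observation.images.top" is rejected for now.
try:
    ActionChunkingTransformerConfig(
        input_shapes={"observation.images.front": [3, 480, 640], "observation.state": [14]},
    )
except ValueError as e:
    print(e)
```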

View File

@ -15,12 +15,12 @@ import numpy as np
import torch import torch
import torch.nn.functional as F # noqa: N812 import torch.nn.functional as F # noqa: N812
import torchvision import torchvision
import torchvision.transforms as transforms
from torch import Tensor, nn from torch import Tensor, nn
from torchvision.models._utils import IntermediateLayerGetter from torchvision.models._utils import IntermediateLayerGetter
from torchvision.ops.misc import FrozenBatchNorm2d from torchvision.ops.misc import FrozenBatchNorm2d
from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
from lerobot.common.policies.normalize import Normalize, Unnormalize
class ActionChunkingTransformerPolicy(nn.Module): class ActionChunkingTransformerPolicy(nn.Module):
@ -62,7 +62,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
name = "act" name = "act"
def __init__(self, cfg: ActionChunkingTransformerConfig | None = None): def __init__(self, cfg: ActionChunkingTransformerConfig | None = None, dataset_stats=None):
""" """
Args: Args:
cfg: Policy configuration class instance or None, in which case the default instantiation of the cfg: Policy configuration class instance or None, in which case the default instantiation of the
@ -72,6 +72,8 @@ class ActionChunkingTransformerPolicy(nn.Module):
if cfg is None: if cfg is None:
cfg = ActionChunkingTransformerConfig() cfg = ActionChunkingTransformerConfig()
self.cfg = cfg self.cfg = cfg
self.normalize_inputs = Normalize(cfg.input_shapes, cfg.normalize_input_modes, dataset_stats)
self.unnormalize_outputs = Unnormalize(cfg.output_shapes, cfg.unnormalize_output_modes, dataset_stats)
# BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence]. # BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence].
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]). # The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
@ -79,9 +81,13 @@ class ActionChunkingTransformerPolicy(nn.Module):
self.vae_encoder = _TransformerEncoder(cfg) self.vae_encoder = _TransformerEncoder(cfg)
self.vae_encoder_cls_embed = nn.Embedding(1, cfg.d_model) self.vae_encoder_cls_embed = nn.Embedding(1, cfg.d_model)
# Projection layer for joint-space configuration to hidden dimension. # Projection layer for joint-space configuration to hidden dimension.
self.vae_encoder_robot_state_input_proj = nn.Linear(cfg.state_dim, cfg.d_model) self.vae_encoder_robot_state_input_proj = nn.Linear(
cfg.input_shapes["observation.state"][0], cfg.d_model
)
# Projection layer for action (joint-space target) to hidden dimension. # Projection layer for action (joint-space target) to hidden dimension.
self.vae_encoder_action_input_proj = nn.Linear(cfg.state_dim, cfg.d_model) self.vae_encoder_action_input_proj = nn.Linear(
cfg.input_shapes["observation.state"][0], cfg.d_model
)
self.latent_dim = cfg.latent_dim self.latent_dim = cfg.latent_dim
# Projection layer from the VAE encoder's output to the latent distribution's parameter space. # Projection layer from the VAE encoder's output to the latent distribution's parameter space.
self.vae_encoder_latent_output_proj = nn.Linear(cfg.d_model, self.latent_dim * 2) self.vae_encoder_latent_output_proj = nn.Linear(cfg.d_model, self.latent_dim * 2)
@ -93,9 +99,6 @@ class ActionChunkingTransformerPolicy(nn.Module):
) )
# Backbone for image feature extraction. # Backbone for image feature extraction.
self.image_normalizer = transforms.Normalize(
mean=cfg.image_normalization_mean, std=cfg.image_normalization_std
)
backbone_model = getattr(torchvision.models, cfg.vision_backbone)( backbone_model = getattr(torchvision.models, cfg.vision_backbone)(
replace_stride_with_dilation=[False, False, cfg.replace_final_stride_with_dilation], replace_stride_with_dilation=[False, False, cfg.replace_final_stride_with_dilation],
pretrained=cfg.use_pretrained_backbone, pretrained=cfg.use_pretrained_backbone,
@ -112,7 +115,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
# Transformer encoder input projections. The tokens will be structured like # Transformer encoder input projections. The tokens will be structured like
# [latent, robot_state, image_feature_map_pixels]. # [latent, robot_state, image_feature_map_pixels].
self.encoder_robot_state_input_proj = nn.Linear(cfg.state_dim, cfg.d_model) self.encoder_robot_state_input_proj = nn.Linear(cfg.input_shapes["observation.state"][0], cfg.d_model)
self.encoder_latent_input_proj = nn.Linear(self.latent_dim, cfg.d_model) self.encoder_latent_input_proj = nn.Linear(self.latent_dim, cfg.d_model)
self.encoder_img_feat_input_proj = nn.Conv2d( self.encoder_img_feat_input_proj = nn.Conv2d(
backbone_model.fc.in_features, cfg.d_model, kernel_size=1 backbone_model.fc.in_features, cfg.d_model, kernel_size=1
@ -126,7 +129,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
self.decoder_pos_embed = nn.Embedding(cfg.chunk_size, cfg.d_model) self.decoder_pos_embed = nn.Embedding(cfg.chunk_size, cfg.d_model)
# Final action regression head on the output of the transformer's decoder. # Final action regression head on the output of the transformer's decoder.
self.action_head = nn.Linear(cfg.d_model, cfg.action_dim) self.action_head = nn.Linear(cfg.d_model, cfg.output_shapes["action"][0])
self._reset_parameters() self._reset_parameters()
self._create_optimizer() self._create_optimizer()
@ -169,10 +172,18 @@ class ActionChunkingTransformerPolicy(nn.Module):
queue is empty. queue is empty.
""" """
self.eval() self.eval()
batch = self.normalize_inputs(batch)
if len(self._action_queue) == 0: if len(self._action_queue) == 0:
# `_forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue effectively # `_forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue effectively
# has shape (n_action_steps, batch_size, *), hence the transpose. # has shape (n_action_steps, batch_size, *), hence the transpose.
self._action_queue.extend(self._forward(batch)[0][: self.cfg.n_action_steps].transpose(0, 1)) actions = self._forward(batch)[0][: self.cfg.n_action_steps]
# TODO(rcadene): make _forward return output dictionary?
actions = self.unnormalize_outputs({"action": actions})["action"]
self._action_queue.extend(actions.transpose(0, 1))
return self._action_queue.popleft() return self._action_queue.popleft()
def forward(self, batch, **_) -> dict[str, Tensor]: def forward(self, batch, **_) -> dict[str, Tensor]:
@ -203,7 +214,11 @@ class ActionChunkingTransformerPolicy(nn.Module):
"""Run the model in train mode, compute the loss, and do an optimization step.""" """Run the model in train mode, compute the loss, and do an optimization step."""
start_time = time.time() start_time = time.time()
self.train() self.train()
batch = self.normalize_inputs(batch)
loss_dict = self.forward(batch) loss_dict = self.forward(batch)
# TODO(rcadene): self.unnormalize_outputs(out_dict)
loss = loss_dict["loss"] loss = loss_dict["loss"]
loss.backward() loss.backward()
@ -232,17 +247,9 @@ class ActionChunkingTransformerPolicy(nn.Module):
"observation.images.{name}": (B, C, H, W) tensor of images. "observation.images.{name}": (B, C, H, W) tensor of images.
} }
""" """
# Check that there is only one image. # Stack images in the order dictated by input_shapes.
# TODO(alexander-soare): generalize this to multiple images.
provided_cameras = {k.rsplit(".", 1)[-1] for k in batch if k.startswith("observation.images.")}
if len(missing := set(self.cfg.camera_names).difference(provided_cameras)) > 0:
raise ValueError(
f"The following camera images are missing from the provided batch: {missing}. Check the "
"configuration parameter: `camera_names`."
)
# Stack images in the order dictated by the camera names.
batch["observation.images"] = torch.stack( batch["observation.images"] = torch.stack(
[batch[f"observation.images.{name}"] for name in self.cfg.camera_names], [batch[k] for k in self.cfg.input_shapes if k.startswith("observation.images.")],
dim=-4, dim=-4,
) )
@ -309,8 +316,8 @@ class ActionChunkingTransformerPolicy(nn.Module):
# Camera observation features and positional embeddings. # Camera observation features and positional embeddings.
all_cam_features = [] all_cam_features = []
all_cam_pos_embeds = [] all_cam_pos_embeds = []
images = self.image_normalizer(batch["observation.images"]) images = batch["observation.images"]
for cam_index in range(len(self.cfg.camera_names)): for cam_index in range(images.shape[-4]):
cam_features = self.backbone(images[:, cam_index])["feature_map"] cam_features = self.backbone(images[:, cam_index])["feature_map"]
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype) cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
cam_features = self.encoder_img_feat_input_proj(cam_features) # (B, C, h, w) cam_features = self.encoder_img_feat_input_proj(cam_features) # (B, C, h, w)
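Because `Normalize`/`Unnormalize` (see `normalize.py` further down) register their statistics as frozen parameters on the policy, the stats now travel with the policy's `state_dict`, so evaluation no longer needs the dataset or a separate `stats.pth`. A minimal sketch; the stats values and file name are placeholders, and `use_pretrained_backbone=False` is only set to avoid a weight download here:
```python
import torch
from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy

cfg = ActionChunkingTransformerConfig(
    normalize_input_modes={"observation.images.top": "mean_std", "observation.state": "mean_std"},
    unnormalize_output_modes={"action": "mean_std"},
    use_pretrained_backbone=False,
)
dataset_stats = {  # placeholders; normally taken from `dataset.stats`
    "observation.images.top": {"mean": torch.zeros(3, 1, 1), "std": torch.ones(3, 1, 1)},
    "observation.state": {"mean": torch.zeros(14), "std": torch.ones(14)},
    "action": {"mean": torch.zeros(14), "std": torch.ones(14)},
}
policy = ActionChunkingTransformerPolicy(cfg, dataset_stats=dataset_stats)

torch.save(policy.state_dict(), "model.pt")       # the checkpoint now carries the stats

restored = ActionChunkingTransformerPolicy(cfg)   # without stats, buffers start at +inf...
restored.load_state_dict(torch.load("model.pt"))  # ...and are filled back in from the checkpoint
```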

View File

@ -1,4 +1,4 @@
from dataclasses import dataclass from dataclasses import dataclass, field
@dataclass @dataclass
@ -8,21 +8,28 @@ class DiffusionConfig:
Defaults are configured for training with PushT providing proprioceptive and single camera observations. Defaults are configured for training with PushT providing proprioceptive and single camera observations.
The parameters you will most likely need to change are the ones which depend on the environment / sensors. The parameters you will most likely need to change are the ones which depend on the environment / sensors.
Those are: `state_dim`, `action_dim` and `image_size`. Those are: `input_shapes` and `output_shapes`.
Args: Args:
state_dim: Dimensionality of the observation state space (excluding images).
action_dim: Dimensionality of the action space.
image_size: (H, W) size of the input images.
n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
current step and additional steps going back). current step and additional steps going back).
horizon: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`. horizon: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`.
n_action_steps: The number of action steps to run in the environment for one invocation of the policy. n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
See `DiffusionPolicy.select_action` for more details. See `DiffusionPolicy.select_action` for more details.
image_normalization_mean: Value to subtract from the input image pixels (inputs are assumed to be in input_shapes: A dictionary defining the shapes of the input data for the policy.
[0, 1]) for normalization. The key represents the input data name, and the value is a list indicating the dimensions
image_normalization_std: Value by which to divide the input image pixels (after the mean has been of the corresponding data. For example, "observation.image" refers to an input from
subtracted). a camera with dimensions [3, 96, 96], indicating it has three color channels and 96x96 resolution.
Importantly, shapes don't include the batch or temporal dimensions.
output_shapes: A dictionary defining the shapes of the output data for the policy.
The key represents the output data name, and the value is a list indicating the dimensions
of the corresponding data. For example, "action" refers to an output shape of [2], indicating
2-dimensional actions. Importantly, shapes don't include the batch or temporal dimensions.
normalize_input_modes: A dictionary where the key represents the modality (e.g. "observation.state")
and the value specifies the normalization mode to apply. The two available
modes are "mean_std", which subtracts the mean and divides by the standard
deviation, and "min_max", which rescales to a [-1, 1] range.
unnormalize_output_modes: Similar dictionary to `normalize_input_modes`, but used to unnormalize back to the original scale.
vision_backbone: Name of the torchvision resnet backbone to use for encoding images. vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit
within the image size. If None, no cropping is done. within the image size. If None, no cropping is done.
@ -58,20 +65,35 @@ class DiffusionConfig:
spaced). If not provided, this defaults to be the same as `num_train_timesteps`. spaced). If not provided, this defaults to be the same as `num_train_timesteps`.
""" """
# Environment.
# Inherit these from the environment config.
state_dim: int = 2
action_dim: int = 2
image_size: tuple[int, int] = (96, 96)
# Inputs / output structure. # Inputs / output structure.
n_obs_steps: int = 2 n_obs_steps: int = 2
horizon: int = 16 horizon: int = 16
n_action_steps: int = 8 n_action_steps: int = 8
# Vision preprocessing. input_shapes: dict[str, list[str]] = field(
image_normalization_mean: tuple[float, float, float] = (0.5, 0.5, 0.5) default_factory=lambda: {
image_normalization_std: tuple[float, float, float] = (0.5, 0.5, 0.5) "observation.image": [3, 96, 96],
"observation.state": [2],
}
)
output_shapes: dict[str, list[str]] = field(
default_factory=lambda: {
"action": [2],
}
)
# Normalization / Unnormalization
normalize_input_modes: dict[str, str] = field(
default_factory=lambda: {
"observation.image": "mean_std",
"observation.state": "min_max",
}
)
unnormalize_output_modes: dict[str, str] = field(
default_factory=lambda: {
"action": "min_max",
}
)
# Architecture / modeling. # Architecture / modeling.
# Vision backbone. # Vision backbone.
@ -123,10 +145,14 @@ class DiffusionConfig:
raise ValueError( raise ValueError(
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}." f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
) )
if self.crop_shape[0] > self.image_size[0] or self.crop_shape[1] > self.image_size[1]: if (
self.crop_shape[0] > self.input_shapes["observation.image"][1]
or self.crop_shape[1] > self.input_shapes["observation.image"][2]
):
raise ValueError( raise ValueError(
f"`crop_shape` should fit within `image_size`. Got {self.crop_shape} for `crop_shape` and " f'`crop_shape` should fit within `input_shapes["observation.image"]`. Got {self.crop_shape} '
f"{self.image_size} for `image_size`." f'for `crop_shape` and {self.input_shapes["observation.image"]} for '
'`input_shapes["observation.image"]`.'
) )
supported_prediction_types = ["epsilon", "sample"] supported_prediction_types = ["epsilon", "sample"]
if self.prediction_type not in supported_prediction_types: if self.prediction_type not in supported_prediction_types:
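A short sketch of the updated validation (the oversized crop below is arbitrary): `crop_shape` is now checked against `input_shapes["observation.image"]` rather than the removed `image_size`:
```python
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig

cfg = DiffusionConfig()                       # defaults: a 96x96 camera image, 2-D state and action
print(cfg.input_shapes["observation.image"])  # [3, 96, 96]

try:
    DiffusionConfig(crop_shape=(200, 200))    # does not fit within the 96x96 default image
except ValueError as e:
    print(e)
```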

View File

@ -13,7 +13,6 @@ import logging
import math import math
import time import time
from collections import deque from collections import deque
from itertools import chain
from typing import Callable from typing import Callable
import einops import einops
@ -27,6 +26,7 @@ from torch import Tensor, nn
from torch.nn.modules.batchnorm import _BatchNorm from torch.nn.modules.batchnorm import _BatchNorm
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.utils import ( from lerobot.common.policies.utils import (
get_device_from_parameters, get_device_from_parameters,
get_dtype_from_parameters, get_dtype_from_parameters,
@ -42,7 +42,9 @@ class DiffusionPolicy(nn.Module):
name = "diffusion" name = "diffusion"
def __init__(self, cfg: DiffusionConfig | None, lr_scheduler_num_training_steps: int = 0): def __init__(
self, cfg: DiffusionConfig | None, lr_scheduler_num_training_steps: int = 0, dataset_stats=None
):
""" """
Args: Args:
cfg: Policy configuration class instance or None, in which case the default instantiation of the cfg: Policy configuration class instance or None, in which case the default instantiation of the
@ -54,6 +56,8 @@ class DiffusionPolicy(nn.Module):
if cfg is None: if cfg is None:
cfg = DiffusionConfig() cfg = DiffusionConfig()
self.cfg = cfg self.cfg = cfg
self.normalize_inputs = Normalize(cfg.input_shapes, cfg.normalize_input_modes, dataset_stats)
self.unnormalize_outputs = Unnormalize(cfg.output_shapes, cfg.unnormalize_output_modes, dataset_stats)
# queues are populated during rollout of the policy, they contain the n latest observations and actions # queues are populated during rollout of the policy, they contain the n latest observations and actions
self._queues = None self._queues = None
@ -126,6 +130,8 @@ class DiffusionPolicy(nn.Module):
assert "observation.state" in batch assert "observation.state" in batch
assert len(batch) == 2 assert len(batch) == 2
batch = self.normalize_inputs(batch)
self._queues = populate_queues(self._queues, batch) self._queues = populate_queues(self._queues, batch)
if len(self._queues["action"]) == 0: if len(self._queues["action"]) == 0:
@ -135,6 +141,10 @@ class DiffusionPolicy(nn.Module):
actions = self.ema_diffusion.generate_actions(batch) actions = self.ema_diffusion.generate_actions(batch)
else: else:
actions = self.diffusion.generate_actions(batch) actions = self.diffusion.generate_actions(batch)
# TODO(rcadene): make above methods return output dictionary?
actions = self.unnormalize_outputs({"action": actions})["action"]
self._queues["action"].extend(actions.transpose(0, 1)) self._queues["action"].extend(actions.transpose(0, 1))
action = self._queues["action"].popleft() action = self._queues["action"].popleft()
@ -151,9 +161,13 @@ class DiffusionPolicy(nn.Module):
self.diffusion.train() self.diffusion.train()
batch = self.normalize_inputs(batch)
loss = self.forward(batch)["loss"] loss = self.forward(batch)["loss"]
loss.backward() loss.backward()
# TODO(rcadene): self.unnormalize_outputs(out_dict)
grad_norm = torch.nn.utils.clip_grad_norm_( grad_norm = torch.nn.utils.clip_grad_norm_(
self.diffusion.parameters(), self.diffusion.parameters(),
self.cfg.grad_clip_norm, self.cfg.grad_clip_norm,
@ -197,7 +211,8 @@ class _DiffusionUnetImagePolicy(nn.Module):
self.rgb_encoder = _RgbEncoder(cfg) self.rgb_encoder = _RgbEncoder(cfg)
self.unet = _ConditionalUnet1D( self.unet = _ConditionalUnet1D(
cfg, global_cond_dim=(cfg.action_dim + self.rgb_encoder.feature_dim) * cfg.n_obs_steps cfg,
global_cond_dim=(cfg.output_shapes["action"][0] + self.rgb_encoder.feature_dim) * cfg.n_obs_steps,
) )
self.noise_scheduler = DDPMScheduler( self.noise_scheduler = DDPMScheduler(
@ -225,7 +240,7 @@ class _DiffusionUnetImagePolicy(nn.Module):
# Sample prior. # Sample prior.
sample = torch.randn( sample = torch.randn(
size=(batch_size, self.cfg.horizon, self.cfg.action_dim), size=(batch_size, self.cfg.horizon, self.cfg.output_shapes["action"][0]),
dtype=dtype, dtype=dtype,
device=device, device=device,
generator=generator, generator=generator,
@ -268,7 +283,7 @@ class _DiffusionUnetImagePolicy(nn.Module):
sample = self.conditional_sample(batch_size, global_cond=global_cond) sample = self.conditional_sample(batch_size, global_cond=global_cond)
# `horizon` steps worth of actions (from the first observation). # `horizon` steps worth of actions (from the first observation).
actions = sample[..., : self.cfg.action_dim] actions = sample[..., : self.cfg.output_shapes["action"][0]]
# Extract `n_action_steps` steps worth of actions (from the current observation). # Extract `n_action_steps` steps worth of actions (from the current observation).
start = n_obs_steps - 1 start = n_obs_steps - 1
end = start + self.cfg.n_action_steps end = start + self.cfg.n_action_steps
@ -346,12 +361,6 @@ class _RgbEncoder(nn.Module):
def __init__(self, cfg: DiffusionConfig): def __init__(self, cfg: DiffusionConfig):
super().__init__() super().__init__()
# Set up optional preprocessing. # Set up optional preprocessing.
if all(v == 1.0 for v in chain(cfg.image_normalization_mean, cfg.image_normalization_std)):
self.normalizer = nn.Identity()
else:
self.normalizer = torchvision.transforms.Normalize(
mean=cfg.image_normalization_mean, std=cfg.image_normalization_std
)
if cfg.crop_shape is not None: if cfg.crop_shape is not None:
self.do_crop = True self.do_crop = True
# Always use center crop for eval # Always use center crop for eval
@ -384,7 +393,9 @@ class _RgbEncoder(nn.Module):
# Set up pooling and final layers. # Set up pooling and final layers.
# Use a dry run to get the feature map shape. # Use a dry run to get the feature map shape.
with torch.inference_mode(): with torch.inference_mode():
feat_map_shape = tuple(self.backbone(torch.zeros(size=(1, 3, *cfg.image_size))).shape[1:]) feat_map_shape = tuple(
self.backbone(torch.zeros(size=(1, *cfg.input_shapes["observation.image"]))).shape[1:]
)
self.pool = SpatialSoftmax(feat_map_shape, num_kp=cfg.spatial_softmax_num_keypoints) self.pool = SpatialSoftmax(feat_map_shape, num_kp=cfg.spatial_softmax_num_keypoints)
self.feature_dim = cfg.spatial_softmax_num_keypoints * 2 self.feature_dim = cfg.spatial_softmax_num_keypoints * 2
self.out = nn.Linear(cfg.spatial_softmax_num_keypoints * 2, self.feature_dim) self.out = nn.Linear(cfg.spatial_softmax_num_keypoints * 2, self.feature_dim)
@ -397,8 +408,7 @@ class _RgbEncoder(nn.Module):
Returns: Returns:
(B, D) image feature. (B, D) image feature.
""" """
# Preprocess: normalize and maybe crop (if it was set up in the __init__). # Preprocess: maybe crop (if it was set up in the __init__).
x = self.normalizer(x)
if self.do_crop: if self.do_crop:
if self.training: # noqa: SIM108 if self.training: # noqa: SIM108
x = self.maybe_random_crop(x) x = self.maybe_random_crop(x)
@ -502,7 +512,7 @@ class _ConditionalUnet1D(nn.Module):
# In channels / out channels for each downsampling block in the Unet's encoder. For the decoder, we # In channels / out channels for each downsampling block in the Unet's encoder. For the decoder, we
# just reverse these. # just reverse these.
in_out = [(cfg.action_dim, cfg.down_dims[0])] + list( in_out = [(cfg.output_shapes["action"][0], cfg.down_dims[0])] + list(
zip(cfg.down_dims[:-1], cfg.down_dims[1:], strict=True) zip(cfg.down_dims[:-1], cfg.down_dims[1:], strict=True)
) )
@ -553,7 +563,7 @@ class _ConditionalUnet1D(nn.Module):
self.final_conv = nn.Sequential( self.final_conv = nn.Sequential(
_Conv1dBlock(cfg.down_dims[0], cfg.down_dims[0], kernel_size=cfg.kernel_size), _Conv1dBlock(cfg.down_dims[0], cfg.down_dims[0], kernel_size=cfg.kernel_size),
nn.Conv1d(cfg.down_dims[0], cfg.action_dim, 1), nn.Conv1d(cfg.down_dims[0], cfg.output_shapes["action"][0], 1),
) )
def forward(self, x: Tensor, timestep: Tensor | int, global_cond=None) -> Tensor: def forward(self, x: Tensor, timestep: Tensor | int, global_cond=None) -> Tensor:

View File

@ -20,7 +20,7 @@ def _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg):
return policy_cfg return policy_cfg
def make_policy(hydra_cfg: DictConfig): def make_policy(hydra_cfg: DictConfig, dataset_stats=None):
if hydra_cfg.policy.name == "tdmpc": if hydra_cfg.policy.name == "tdmpc":
from lerobot.common.policies.tdmpc.policy import TDMPCPolicy from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
@ -35,14 +35,14 @@ def make_policy(hydra_cfg: DictConfig):
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
policy_cfg = _policy_cfg_from_hydra_cfg(DiffusionConfig, hydra_cfg) policy_cfg = _policy_cfg_from_hydra_cfg(DiffusionConfig, hydra_cfg)
policy = DiffusionPolicy(policy_cfg, hydra_cfg.offline_steps) policy = DiffusionPolicy(policy_cfg, hydra_cfg.offline_steps, dataset_stats)
policy.to(get_safe_torch_device(hydra_cfg.device)) policy.to(get_safe_torch_device(hydra_cfg.device))
elif hydra_cfg.policy.name == "act": elif hydra_cfg.policy.name == "act":
from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy
policy_cfg = _policy_cfg_from_hydra_cfg(ActionChunkingTransformerConfig, hydra_cfg) policy_cfg = _policy_cfg_from_hydra_cfg(ActionChunkingTransformerConfig, hydra_cfg)
policy = ActionChunkingTransformerPolicy(policy_cfg) policy = ActionChunkingTransformerPolicy(policy_cfg, dataset_stats)
policy.to(get_safe_torch_device(hydra_cfg.device)) policy.to(get_safe_torch_device(hydra_cfg.device))
else: else:
raise ValueError(hydra_cfg.policy.name) raise ValueError(hydra_cfg.policy.name)

View File

@ -0,0 +1,196 @@
import torch
from torch import nn
def create_stats_buffers(shapes, modes, stats=None):
"""
Create buffers per modality (e.g. "observation.image", "action") containing their mean, std, min, max statistics.
Parameters:
shapes (dict): A dictionary where keys are input modalities (e.g. "observation.image") and values are their shapes (e.g. `[3, 96, 96]`).
These shapes are used to create the tensor buffers containing mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape is adjusted to be invariant to height
and width, assuming a channel-first (c, h, w) format.
modes (dict): A dictionary where keys are modalities (e.g. "observation.image") and values are their normalization modes among:
- "mean_std": subtract the mean and divide by the standard deviation.
- "min_max": map to a [-1, 1] range.
stats (dict, optional): A dictionary where keys are modalities (e.g. "observation.image") and values are dictionaries of statistic types and their values
(e.g. `{"mean": torch.randn(3,1,1), "std": torch.randn(3,1,1)}`). If provided, as expected when training a model from scratch,
these statistics overwrite the default buffers. If not provided, as expected when finetuning or evaluating, the default buffers are expected to
be overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the dataset is not needed to get the stats, since
they are already in the policy state_dict.
Returns:
dict: A dictionary where keys are modalities and values are `nn.ParameterDict` containing `nn.Parameter` tensors set to
`requires_grad=False`, so they are not updated during backpropagation.
"""
stats_buffers = {}
for key, mode in modes.items():
assert mode in ["mean_std", "min_max"]
shape = tuple(shapes[key])
if "image" in key:
# sanity checks
assert len(shape) == 3, f"number of dimensions of {key} != 3 ({shape=}"
c, h, w = shape
assert c < h and c < w, f"{key} is not channel first ({shape=})"
# override image shape to be invariant to height and width
shape = (c, 1, 1)
# Note: we initialize mean, std, min, max to infinity. They should be overwritten
# downstream by `stats` or `policy.load_state_dict`, as expected. During forward,
# we assert they are not infinity anymore.
buffer = {}
if mode == "mean_std":
mean = torch.ones(shape, dtype=torch.float32) * torch.inf
std = torch.ones(shape, dtype=torch.float32) * torch.inf
buffer = nn.ParameterDict(
{
"mean": nn.Parameter(mean, requires_grad=False),
"std": nn.Parameter(std, requires_grad=False),
}
)
elif mode == "min_max":
min = torch.ones(shape, dtype=torch.float32) * torch.inf
max = torch.ones(shape, dtype=torch.float32) * torch.inf
buffer = nn.ParameterDict(
{
"min": nn.Parameter(min, requires_grad=False),
"max": nn.Parameter(max, requires_grad=False),
}
)
if stats is not None:
if mode == "mean_std":
buffer["mean"].data = stats[key]["mean"]
buffer["std"].data = stats[key]["std"]
elif mode == "min_max":
buffer["min"].data = stats[key]["min"]
buffer["max"].data = stats[key]["max"]
stats_buffers[key] = buffer
return stats_buffers
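For intuition, a small sketch of what the returned buffers look like (the shapes and modes below are arbitrary):
```python
import torch
from lerobot.common.policies.normalize import create_stats_buffers

buffers = create_stats_buffers(
    shapes={"observation.image": [3, 96, 96], "action": [2]},
    modes={"observation.image": "mean_std", "action": "min_max"},
)
# Image stats are stored as (c, 1, 1) so they broadcast over height and width.
print(buffers["observation.image"]["mean"].shape)  # torch.Size([3, 1, 1])
# Without `stats` (and before `load_state_dict`), buffers are deliberately filled with +inf.
print(buffers["action"]["min"])
```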
class Normalize(nn.Module):
"""
Normalizes the input data (e.g. "observation.image") for more stable and faster convergence during training.
Parameters:
shapes (dict): A dictionary where keys are input modalities (e.g. "observation.image") and values are their shapes (e.g. `[3, 96, 96]`).
These shapes are used to create the tensor buffers containing mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape is adjusted to be invariant to height
and width, assuming a channel-first (c, h, w) format.
modes (dict): A dictionary where keys are input modalities (e.g. "observation.image") and values are their normalization modes among:
- "mean_std": subtract the mean and divide by the standard deviation.
- "min_max": map to a [-1, 1] range.
stats (dict, optional): A dictionary where keys are input modalities (e.g. "observation.image") and values are dictionaries of statistic types and their values
(e.g. `{"mean": torch.randn(3,1,1), "std": torch.randn(3,1,1)}`). If provided, as expected when training a model from scratch,
these statistics overwrite the default buffers. If not provided, as expected when finetuning or evaluating, the default buffers are expected to
be overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the dataset is not needed to get the stats, since
they are already in the policy state_dict.
"""
def __init__(self, shapes, modes, stats=None):
super().__init__()
self.shapes = shapes
self.modes = modes
self.stats = stats
# `self.buffer_observation_state["mean"]` contains `torch.tensor(state_dim)`
stats_buffers = create_stats_buffers(shapes, modes, stats)
for key, buffer in stats_buffers.items():
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
# TODO(rcadene): should we remove torch.no_grad?
@torch.no_grad
def forward(self, batch):
for key, mode in self.modes.items():
buffer = getattr(self, "buffer_" + key.replace(".", "_"))
if mode == "mean_std":
mean = buffer["mean"]
std = buffer["std"]
assert not torch.isinf(
mean
).any(), "`mean` is infinity. You forgot to initialize with `stats` as an argument, or to call `policy.load_state_dict`."
assert not torch.isinf(
std
).any(), "`std` is infinity. You forgot to initialize with `stats` as an argument, or to call `policy.load_state_dict`."
batch[key] = (batch[key] - mean) / (std + 1e-8)
elif mode == "min_max":
min = buffer["min"]
max = buffer["max"]
assert not torch.isinf(
min
).any(), "`min` is infinity. You forgot to initialize with `stats` as an argument, or to call `policy.load_state_dict`."
assert not torch.isinf(
max
).any(), "`max` is infinity. You forgot to initialize with `stats` as an argument, or to call `policy.load_state_dict`."
# normalize to [0,1]
batch[key] = (batch[key] - min) / (max - min)
# normalize to [-1, 1]
batch[key] = batch[key] * 2 - 1
else:
raise ValueError(mode)
return batch
class Unnormalize(nn.Module):
"""
Similar to `Normalize`, but unnormalizes output data (e.g. `{"action": torch.randn(b,c)}`) back to the original range used by the environment.
Parameters:
shapes (dict): A dictionary where keys are output modalities (e.g. "action") and values are their shapes (e.g. [10]).
These shapes are used to create the tensor buffer containing mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape is adjusted to be invariant to height
and width, assuming a channel-first (c, h, w) format.
modes (dict): A dictionary where keys are output modalities (e.g. "action") and values are their unnormalization modes among:
- "mean_std": multiply by standard deviation and add mean
- "min_max": go from [-1, 1] range to original range.
stats (dict, optional): A dictionary where keys are output modalities (e.g. "action") and values are dictionaries of statistic types and their values
(e.g. `{"max": torch.tensor(1), "min": torch.tensor(0)}`). If provided, as expected when training a model from scratch,
these statistics overwrite the default buffers. If not provided, as expected when finetuning or evaluating, the default buffers are expected to
be overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the dataset is not needed to get the stats, since
they are already in the policy state_dict.
"""
def __init__(self, shapes, modes, stats=None):
super().__init__()
self.shapes = shapes
self.modes = modes
self.stats = stats
# `self.buffer_observation_state["mean"]` contains `torch.tensor(state_dim)`
stats_buffers = create_stats_buffers(shapes, modes, stats)
for key, buffer in stats_buffers.items():
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
# TODO(rcadene): should we remove torch.no_grad?
@torch.no_grad
def forward(self, batch):
for key, mode in self.modes.items():
buffer = getattr(self, "buffer_" + key.replace(".", "_"))
if mode == "mean_std":
mean = buffer["mean"]
std = buffer["std"]
assert not torch.isinf(
mean
).any(), "`mean` is infinity. You forgot to initialize with `stats` as an argument, or to call `policy.load_state_dict`."
assert not torch.isinf(
std
).any(), "`std` is infinity. You forgot to initialize with `stats` as an argument, or to call `policy.load_state_dict`."
batch[key] = batch[key] * std + mean
elif mode == "min_max":
min = buffer["min"]
max = buffer["max"]
assert not torch.isinf(
min
).any(), "`min` is infinity. You forgot to initialize with `stats` as an argument, or to call `policy.load_state_dict`."
assert not torch.isinf(
max
).any(), "`max` is infinity. You forgot to initialize with `stats` as an argument, or to call `policy.load_state_dict`."
batch[key] = (batch[key] + 1) / 2
batch[key] = batch[key] * (max - min) + min
else:
raise ValueError(mode)
return batch
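And a usage sketch of the two modules on their own; the min/max values below are made up, in the spirit of the PushT overrides elsewhere in this PR:
```python
import torch
from lerobot.common.policies.normalize import Normalize, Unnormalize

shapes = {"observation.state": [2], "action": [2]}
stats = {
    "observation.state": {"min": torch.tensor([13.5, 32.9]), "max": torch.tensor([496.1, 511.0])},
    "action": {"min": torch.tensor([12.0, 25.0]), "max": torch.tensor([511.0, 511.0])},
}

normalize = Normalize(shapes, {"observation.state": "min_max"}, stats)
unnormalize = Unnormalize(shapes, {"action": "min_max"}, stats)

batch = normalize({"observation.state": torch.tensor([[100.0, 200.0]])})
print(batch["observation.state"])  # rescaled to roughly [-1, 1]

out = unnormalize({"action": torch.tensor([[0.0, 0.0]])})
print(out["action"])               # back in the environment's native action range
```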

View File

@ -1,65 +0,0 @@
from torchvision.transforms.v2 import Compose, Transform
def apply_inverse_transform(item, transform):
transforms = transform.transforms if isinstance(transform, Compose) else [transform]
for tf in transforms[::-1]:
if tf.invertible:
item = tf.inverse_transform(item)
else:
raise ValueError(f"Inverse transform called on a non invertible transform ({tf}).")
return item
class NormalizeTransform(Transform):
invertible = True
def __init__(
self,
stats: dict,
in_keys: list[str] = None,
out_keys: list[str] | None = None,
in_keys_inv: list[str] | None = None,
out_keys_inv: list[str] | None = None,
mode="mean_std",
):
super().__init__()
self.in_keys = in_keys
self.out_keys = in_keys if out_keys is None else out_keys
self.in_keys_inv = self.out_keys if in_keys_inv is None else in_keys_inv
self.out_keys_inv = self.in_keys if out_keys_inv is None else out_keys_inv
self.stats = stats
assert mode in ["mean_std", "min_max"]
self.mode = mode
def forward(self, item):
for inkey, outkey in zip(self.in_keys, self.out_keys, strict=False):
if inkey not in item:
continue
if self.mode == "mean_std":
mean = self.stats[inkey]["mean"]
std = self.stats[inkey]["std"]
item[outkey] = (item[inkey] - mean) / (std + 1e-8)
else:
min = self.stats[inkey]["min"]
max = self.stats[inkey]["max"]
# normalize to [0,1]
item[outkey] = (item[inkey] - min) / (max - min)
# normalize to [-1, 1]
item[outkey] = item[outkey] * 2 - 1
return item
def inverse_transform(self, item):
for inkey, outkey in zip(self.in_keys_inv, self.out_keys_inv, strict=False):
if inkey not in item:
continue
if self.mode == "mean_std":
mean = self.stats[inkey]["mean"]
std = self.stats[inkey]["std"]
item[outkey] = item[inkey] * std + mean
else:
min = self.stats[inkey]["min"]
max = self.stats[inkey]["max"]
item[outkey] = (item[inkey] + 1) / 2
item[outkey] = item[outkey] * (max - min) + min
return item

View File

@ -20,7 +20,5 @@ env:
image_size: [3, 480, 640] image_size: [3, 480, 640]
episode_length: 400 episode_length: 400
fps: ${fps} fps: ${fps}
policy:
state_dim: 14 state_dim: 14
action_dim: 14 action_dim: 14

View File

@ -20,7 +20,5 @@ env:
image_size: 96 image_size: 96
episode_length: 300 episode_length: 300
fps: ${fps} fps: ${fps}
policy:
state_dim: 2 state_dim: 2
action_dim: 2 action_dim: 2

View File

@ -19,7 +19,5 @@ env:
image_size: 84 image_size: 84
episode_length: 25 episode_length: 25
fps: ${fps} fps: ${fps}
policy:
state_dim: 4 state_dim: 4
action_dim: 4 action_dim: 4

View File

@ -11,26 +11,36 @@ log_freq: 250
n_obs_steps: 1 n_obs_steps: 1
# when temporal_agg=False, n_action_steps=horizon # when temporal_agg=False, n_action_steps=horizon
override_dataset_stats:
observation.images.top:
# stats from imagenet, since we use a pretrained vision model
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
# See `configuration_act.py` for more details. # See `configuration_act.py` for more details.
policy: policy:
name: act name: act
pretrained_model_path: pretrained_model_path:
# Environment. # Input / output structure.
# Inherit these from the environment config.
state_dim: ???
action_dim: ???
# Inputs / output structure.
n_obs_steps: ${n_obs_steps} n_obs_steps: ${n_obs_steps}
camera_names: [top] # [top, front_close, left_pillar, right_pillar]
chunk_size: 100 # chunk_size chunk_size: 100 # chunk_size
n_action_steps: 100 n_action_steps: 100
# Vision preprocessing. input_shapes:
image_normalization_mean: [0.485, 0.456, 0.406] # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
image_normalization_std: [0.229, 0.224, 0.225] observation.images.top: [3, 480, 640]
observation.state: ["${env.state_dim}"]
output_shapes:
action: ["${env.action_dim}"]
# Normalization / Unnormalization
normalize_input_modes:
observation.images.top: mean_std
observation.state: mean_std
unnormalize_output_modes:
action: mean_std
# Architecture. # Architecture.
# Vision backbone. # Vision backbone.

View File

@ -18,27 +18,43 @@ online_steps: 0
offline_prioritized_sampler: true offline_prioritized_sampler: true
override_dataset_stats:
# TODO(rcadene, alexander-soare): should we remove image stats as well? do we use a pretrained vision model?
observation.image:
mean: [[[0.5]], [[0.5]], [[0.5]]] # (c,1,1)
std: [[[0.5]], [[0.5]], [[0.5]]] # (c,1,1)
# TODO(rcadene, alexander-soare): we override state and action stats to use the same as the pretrained model
# from the original codebase, but we should remove these and train our own pretrained model
observation.state:
min: [13.456424, 32.938293]
max: [496.14618, 510.9579]
action:
min: [12.0, 25.0]
max: [511.0, 511.0]
policy: policy:
name: diffusion name: diffusion
pretrained_model_path: pretrained_model_path:
# Environment. # Input / output structure.
# Inherit these from the environment config.
state_dim: ???
action_dim: ???
image_size:
- ${env.image_size} # height
- ${env.image_size} # width
# Inputs / output structure.
n_obs_steps: ${n_obs_steps} n_obs_steps: ${n_obs_steps}
horizon: ${horizon} horizon: ${horizon}
n_action_steps: ${n_action_steps} n_action_steps: ${n_action_steps}
# Vision preprocessing. input_shapes:
image_normalization_mean: [0.5, 0.5, 0.5] # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
image_normalization_std: [0.5, 0.5, 0.5] observation.image: [3, 96, 96]
observation.state: ["${env.state_dim}"]
output_shapes:
action: ["${env.action_dim}"]
# Normalization / Unnormalization
normalize_input_modes:
observation.image: mean_std
observation.state: min_max
unnormalize_output_modes:
action: min_max
# Architecture / modeling. # Architecture / modeling.
# Vision backbone. # Vision backbone.

View File

@ -16,8 +16,8 @@ policy:
frame_stack: 1 frame_stack: 1
num_channels: 32 num_channels: 32
img_size: ${env.image_size} img_size: ${env.image_size}
state_dim: ??? state_dim: ${env.state_dim}
action_dim: ??? action_dim: ${env.action_dim}
# planning # planning
mpc: true mpc: true

View File

@@ -46,7 +46,6 @@ from huggingface_hub import snapshot_download
 from PIL import Image as PILImage
 from tqdm import trange
-from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.datasets.utils import hf_transform_to_torch
 from lerobot.common.envs.factory import make_env
 from lerobot.common.envs.utils import postprocess_action, preprocess_observation

@@ -64,8 +63,6 @@ def eval_policy(
     policy: torch.nn.Module,
     max_episodes_rendered: int = 0,
     video_dir: Path = None,
-    # TODO(rcadene): make it possible to overwrite fps? we should use env.fps
-    transform: callable = None,
     return_episode_data: bool = False,
     seed=None,
 ):

@@ -132,10 +129,6 @@ def eval_policy(
         if return_episode_data:
             observations.append(deepcopy(observation))
-        # apply transform to normalize the observations
-        for key in observation:
-            observation[key] = torch.stack([transform({key: item})[key] for item in observation[key]])
         # send observation to device/gpu
         observation = {key: observation[key].to(device, non_blocking=True) for key in observation}

@@ -143,8 +136,8 @@ def eval_policy(
         with torch.inference_mode():
             action = policy.select_action(observation, step=step)
-        # apply inverse transform to unnormalize the action
-        action = postprocess_action(action, transform)
+        # convert to cpu numpy
+        action = postprocess_action(action)
         # apply the next action
         observation, reward, terminated, truncated, info = env.step(action)

@@ -360,7 +353,7 @@ def eval_policy(
     return info

-def eval(cfg: dict, out_dir=None, stats_path=None):
+def eval(cfg: dict, out_dir=None):
     if out_dir is None:
         raise NotImplementedError()

@@ -375,10 +368,6 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
     log_output_dir(out_dir)
-    logging.info("Making transforms.")
-    # TODO(alexander-soare): Completely decouple datasets from evaluation.
-    transform = make_dataset(cfg, stats_path=stats_path).transform
     logging.info("Making environment.")
     env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)

@@ -390,7 +379,6 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
         policy,
         max_episodes_rendered=10,
         video_dir=Path(out_dir) / "eval",
-        transform=transform,
         return_episode_data=False,
         seed=cfg.seed,
     )

@@ -423,17 +411,13 @@ if __name__ == "__main__":
     if args.config is not None:
         # Note: For the config_path, Hydra wants a path relative to this script file.
         cfg = init_hydra_config(args.config, args.overrides)
-        # TODO(alexander-soare): Save and load stats in trained model directory.
-        stats_path = None
     elif args.hub_id is not None:
         folder = Path(snapshot_download(args.hub_id, revision=args.revision))
         cfg = init_hydra_config(
             folder / "config.yaml", [f"policy.pretrained_model_path={folder / 'model.pt'}", *args.overrides]
         )
-        stats_path = folder / "stats.pth"
     eval(
         cfg,
         out_dir=f"outputs/eval/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{cfg.env.name}_{cfg.policy.name}",
-        stats_path=stats_path,
     )
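The net effect on evaluation is that `eval_policy` no longer threads a dataset `transform` through the rollout: observations reach the policy exactly as `preprocess_observation` produced them, normalization and unnormalization happen inside the policy, and `postprocess_action` is reduced to moving the action back to a CPU NumPy array. The sketch below illustrates the simplified rollout step; the function name `rollout_step` and the body are assumptions drawn from the diff, not the actual `eval.py` code.

```
import torch


def rollout_step(env, policy, observation: dict[str, torch.Tensor], step: int, device: str = "cuda"):
    # Observations arrive from preprocess_observation as float32 tensors; no dataset transform is applied here.
    observation = {key: observation[key].to(device, non_blocking=True) for key in observation}
    with torch.inference_mode():
        # The policy normalizes its inputs and unnormalizes its outputs internally.
        action = policy.select_action(observation, step=step)
    # Equivalent to the slimmed-down postprocess_action: move the action to CPU and convert to NumPy.
    action = action.to("cpu").numpy()
    return env.step(action)
```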
@@ -232,7 +232,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
     env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)
     logging.info("make_policy")
-    policy = make_policy(cfg)
+    policy = make_policy(cfg, dataset_stats=offline_dataset.stats)
     num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
     num_total_params = sum(p.numel() for p in policy.parameters())

@@ -339,7 +339,6 @@ def train(cfg: dict, out_dir=None, job_name=None):
            eval_info = eval_policy(
                rollout_env,
                policy,
-               transform=offline_dataset.transform,
                return_episode_data=True,
                seed=cfg.seed,
            )
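Because `make_policy` now receives `dataset_stats`, the statistics end up inside the policy itself. The `Normalize` / `Unnormalize` names and their constructor signature come from this commit; the wrapping class below, its name `MyPolicy`, and the way it reads the config are hypothetical, shown only to make the wiring concrete.

```
from torch import nn

from lerobot.common.policies.normalize import Normalize, Unnormalize


class MyPolicy(nn.Module):
    # Hypothetical constructor: dataset_stats is turned into normalization modules
    # that live inside the policy and therefore inside its state_dict.
    def __init__(self, cfg, dataset_stats: dict | None = None):
        super().__init__()
        self.normalize_inputs = Normalize(cfg.input_shapes, cfg.normalize_input_modes, stats=dataset_stats)
        self.unnormalize_outputs = Unnormalize(cfg.output_shapes, cfg.unnormalize_output_modes, stats=dataset_stats)
```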
@@ -50,11 +50,7 @@ def visualize_dataset(cfg: dict, out_dir=None):
     log_output_dir(out_dir)
     logging.info("make_dataset")
-    dataset = make_dataset(
-        cfg,
-        # remove all transformations such as rescale images from [0,255] to [0,1] or normalization
-        normalize=False,
-    )
+    dataset = make_dataset(cfg)
     logging.info("Start rendering episodes from offline buffer")
     video_paths = render_dataset(dataset, out_dir, MAX_NUM_STEPS * NUM_EPISODES_TO_RENDER)
@@ -6,7 +6,6 @@ import torch
 from gymnasium.utils.env_checker import check_env
 import lerobot
-from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.envs.factory import make_env
 from lerobot.common.envs.utils import preprocess_observation
 from lerobot.common.utils.utils import init_hydra_config

@@ -38,12 +37,14 @@ def test_factory(env_name):
         overrides=[f"env={env_name}", f"device={DEVICE}"],
     )
-    dataset = make_dataset(cfg)
     env = make_env(cfg, num_parallel_envs=1)
     obs, _ = env.reset()
-    obs = preprocess_observation(obs, transform=dataset.transform)
-    for key in dataset.image_keys:
+    obs = preprocess_observation(obs)
+    # test image keys are float32 in range [0,1]
+    for key in obs:
+        if "image" not in key:
+            continue
         img = obs[key]
         assert img.dtype == torch.float32
         # TODO(rcadene): we assume for now that image normalization takes place in the model
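The updated test encodes the new contract for `preprocess_observation`: it should return float32 image tensors scaled to [0, 1], with any statistics-based normalization deferred to the policy. A rough sketch of that conversion follows; the helper name `to_policy_image` is an assumption, and the real function operates on full observation dicts from batched environments rather than single images.

```
import numpy as np
import torch


def to_policy_image(img: np.ndarray) -> torch.Tensor:
    # HWC uint8 in [0, 255] -> CHW float32 in [0, 1]; dataset-stat normalization is left to the policy.
    assert img.dtype == np.uint8
    return torch.from_numpy(img).permute(2, 0, 1).float() / 255.0
```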
@@ -51,7 +51,7 @@ def test_examples_4_and_3():
     # Pass empty globals to allow dictionary comprehension https://stackoverflow.com/a/32897127/4391249.
     exec(file_contents, {})
-    for file_name in ["model.pt", "stats.pth", "config.yaml"]:
+    for file_name in ["model.pt", "config.yaml"]:
         assert Path(f"outputs/train/example_pusht_diffusion/{file_name}").exists()
     path = "examples/3_evaluate_pretrained_policy.py"
@@ -6,10 +6,10 @@ from lerobot.common.datasets.utils import cycle
 from lerobot.common.envs.factory import make_env
 from lerobot.common.envs.utils import postprocess_action, preprocess_observation
 from lerobot.common.policies.factory import make_policy
+from lerobot.common.policies.normalize import Normalize, Unnormalize
 from lerobot.common.policies.policy_protocol import Policy
 from lerobot.common.utils.utils import init_hydra_config
-from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_env
+from .utils import DEFAULT_CONFIG_PATH, DEVICE, require_env
 # TODO(aliberts): refactor using lerobot/__init__.py variables

@@ -44,14 +44,16 @@ def test_policy(env_name, policy_name, extra_overrides):
         ]
         + extra_overrides,
     )
     # Check that we can make the policy object.
-    policy = make_policy(cfg)
+    dataset = make_dataset(cfg)
+    policy = make_policy(cfg, dataset_stats=dataset.stats)
     # Check that the policy follows the required protocol.
     assert isinstance(
         policy, Policy
     ), f"The policy does not follow the required protocol. Please see {Policy.__module__}.{Policy.__name__}."
     # Check that we run select_actions and get the appropriate output.
-    dataset = make_dataset(cfg)
     env = make_env(cfg, num_parallel_envs=2)
     dataloader = torch.utils.data.DataLoader(

@@ -77,7 +79,7 @@ def test_policy(env_name, policy_name, extra_overrides):
     observation, _ = env.reset(seed=cfg.seed)
     # apply transform to normalize the observations
-    observation = preprocess_observation(observation, dataset.transform)
+    observation = preprocess_observation(observation)
     # send observation to device/gpu
     observation = {key: observation[key].to(DEVICE, non_blocking=True) for key in observation}

@@ -86,8 +88,115 @@ def test_policy(env_name, policy_name, extra_overrides):
     with torch.inference_mode():
         action = policy.select_action(observation, step=0)
-    # apply inverse transform to unnormalize the action
-    action = postprocess_action(action, dataset.transform)
+    # convert action to cpu numpy array
+    action = postprocess_action(action)
     # Test step through policy
     env.step(action)
+    # Test load state_dict
+    if policy_name != "tdmpc":
+        # TODO(rcadene, alexander-soare): make it work for tdmpc
+        new_policy = make_policy(cfg)
+        new_policy.load_state_dict(policy.state_dict())
+
+
+@pytest.mark.parametrize(
+    "insert_temporal_dim",
+    [
+        False,
+        True,
+    ],
+)
+def test_normalize(insert_temporal_dim):
+    """
+    Test that normalize/unnormalize can run without exceptions when properly set up, and that they raise
+    an exception when the forward pass is called without the stats having been provided.
+
+    TODO(rcadene, alexander-soare): This should also test that the normalization / unnormalization works as
+    expected.
+    """
+    input_shapes = {
+        "observation.image": [3, 96, 96],
+        "observation.state": [10],
+    }
+    output_shapes = {
+        "action": [5],
+    }
+
+    normalize_input_modes = {
+        "observation.image": "mean_std",
+        "observation.state": "min_max",
+    }
+    unnormalize_output_modes = {
+        "action": "min_max",
+    }
+
+    dataset_stats = {
+        "observation.image": {
+            "mean": torch.randn(3, 1, 1),
+            "std": torch.randn(3, 1, 1),
+            "min": torch.randn(3, 1, 1),
+            "max": torch.randn(3, 1, 1),
+        },
+        "observation.state": {
+            "mean": torch.randn(10),
+            "std": torch.randn(10),
+            "min": torch.randn(10),
+            "max": torch.randn(10),
+        },
+        "action": {
+            "mean": torch.randn(5),
+            "std": torch.randn(5),
+            "min": torch.randn(5),
+            "max": torch.randn(5),
+        },
+    }
+
+    bsize = 2
+    input_batch = {
+        "observation.image": torch.randn(bsize, 3, 96, 96),
+        "observation.state": torch.randn(bsize, 10),
+    }
+    output_batch = {
+        "action": torch.randn(bsize, 5),
+    }
+
+    if insert_temporal_dim:
+        tdim = 4
+        for key in input_batch:
+            # [2,3,96,96] -> [2,tdim,3,96,96]
+            input_batch[key] = torch.stack([input_batch[key]] * tdim, dim=1)
+        for key in output_batch:
+            output_batch[key] = torch.stack([output_batch[key]] * tdim, dim=1)
+
+    # test without stats
+    normalize = Normalize(input_shapes, normalize_input_modes, stats=None)
+    with pytest.raises(AssertionError):
+        normalize(input_batch)
+
+    # test with stats
+    normalize = Normalize(input_shapes, normalize_input_modes, stats=dataset_stats)
+    normalize(input_batch)
+
+    # test loading pretrained models
+    new_normalize = Normalize(input_shapes, normalize_input_modes, stats=None)
+    new_normalize.load_state_dict(normalize.state_dict())
+    new_normalize(input_batch)
+
+    # test without stats
+    unnormalize = Unnormalize(output_shapes, unnormalize_output_modes, stats=None)
+    with pytest.raises(AssertionError):
+        unnormalize(output_batch)
+
+    # test with stats
+    unnormalize = Unnormalize(output_shapes, unnormalize_output_modes, stats=dataset_stats)
+    unnormalize(output_batch)
+
+    # test loading pretrained models
+    new_unnormalize = Unnormalize(output_shapes, unnormalize_output_modes, stats=None)
+    new_unnormalize.load_state_dict(unnormalize.state_dict())
+    unnormalize(output_batch)
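The `load_state_dict` checks in this test only pass if the normalization statistics live in the module's state dict rather than in plain Python attributes. The sketch below shows one way such a module could be structured by registering stats as buffers; it is an illustration of the idea behind `Normalize`, not the actual `lerobot/common/policies/normalize.py`, and the class name `MeanStdNormalizer` is hypothetical.

```
import torch
from torch import nn


class MeanStdNormalizer(nn.Module):
    """Illustrative stand-in for a single (key, mode="mean_std") entry handled by Normalize."""

    def __init__(self, shape: list[int], stats: dict[str, torch.Tensor] | None = None):
        super().__init__()
        # Buffers are part of state_dict, so stats can be restored via load_state_dict
        # even when the module was constructed with stats=None.
        init = torch.full(shape, float("inf"))
        self.register_buffer("mean", stats["mean"].clone() if stats else init.clone())
        self.register_buffer("std", stats["std"].clone() if stats else init.clone())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Mirrors the behaviour exercised by test_normalize: fail if stats were never provided or loaded.
        assert not torch.isinf(self.mean).any(), "mean is infinity: provide stats or load a checkpoint"
        assert not torch.isinf(self.std).any(), "std is infinity: provide stats or load a checkpoint"
        return (x - self.mean) / (self.std + 1e-8)
```

With the stats stored this way, saving a trained policy automatically captures its normalization constants, which is what allows the separate `stats.pth` file to be dropped from evaluation and from the upload folder.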