Move normalization to policy for act and diffusion (#90)
Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
parent c1bcf857c5
commit e760e4cd63

@@ -263,15 +263,13 @@ Secondly, assuming you have trained a policy, you need:
- `config.yaml` which you can get from the `.hydra` directory of your training output folder.
- `model.pt` which should be one of the saved models in the `models` directory of your training output folder (they won't be named `model.pt` but you will need to choose one).
- `stats.pth` which should point to the same file in the dataset directory (found in `data/{dataset_name}`).

To upload these to the hub, prepare a folder with the following structure (you can use symlinks rather than copying):

```
to_upload
├── config.yaml
├── model.pt
└── stats.pth
└── model.pt
```

With the folder prepared, run the following with a desired revision ID.
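
The upload command itself is not part of this excerpt. As a rough, hedged sketch of what such an upload could look like with `huggingface_hub` (the repo id and revision below are placeholders, not taken from the source):

```python
from huggingface_hub import HfApi

api = HfApi()
# Hypothetical repo id and revision; substitute your own.
repo_id = "lerobot/my_pretrained_policy"
revision = "v1.0"

# Create the target branch if it does not exist yet, then upload the prepared folder.
api.create_branch(repo_id=repo_id, branch=revision, exist_ok=True)
api.upload_folder(folder_path="to_upload", repo_id=repo_id, revision=revision)
```
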
@@ -44,7 +44,7 @@ from datasets import load_dataset
# TODO(rcadene): list available datasets on lerobot page using `datasets`

# download/load hugging face dataset in pyarrow format
hf_dataset, fps = load_dataset("lerobot/pusht", split="train"), 10
hf_dataset, fps = load_dataset("lerobot/pusht", split="train", revision="v1.1"), 10

# display name of dataset and its features
# TODO(rcadene): update to make the print pretty

@@ -19,7 +19,6 @@ folder = Path(snapshot_download(hub_id))

config_path = folder / "config.yaml"
weights_path = folder / "model.pt"
stats_path = folder / "stats.pth" # normalization stats

# Override some config parameters to do with evaluation.
overrides = [

@@ -36,5 +35,4 @@ cfg = init_hydra_config(config_path, overrides)
eval(
    cfg,
    out_dir=f"outputs/eval/example_{cfg.env.name}_{cfg.policy.name}",
    stats_path=stats_path,
)

@@ -34,7 +34,7 @@ dataset = make_dataset(hydra_cfg)
# If you're doing something different, you will likely need to change at least some of the defaults.
cfg = DiffusionConfig()
# TODO(alexander-soare): Remove LR scheduler from the policy.
policy = DiffusionPolicy(cfg, lr_scheduler_num_training_steps=training_steps)
policy = DiffusionPolicy(cfg, lr_scheduler_num_training_steps=training_steps, dataset_stats=dataset.stats)
policy.train()
policy.to(device)

@@ -62,7 +62,6 @@ while not done:
        done = True
        break

# Save the policy, configuration, and normalization stats for later use.
# Save the policy and configuration for later use.
policy.save(output_directory / "model.pt")
OmegaConf.save(hydra_cfg, output_directory / "config.yaml")
torch.save(dataset.transform.transforms[-1].stats, output_directory / "stats.pth")

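Note the change in what gets saved: `stats.pth` is no longer written, because the normalization statistics now live as non-trainable buffers inside the policy's `Normalize`/`Unnormalize` modules and are saved and restored with the model weights. A minimal sketch of reloading such a checkpoint later (assuming `policy.save` writes the module's state dict, as the `normalize.py` docstrings below suggest):

```python
import torch

from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy

cfg = DiffusionConfig()
policy = DiffusionPolicy(cfg, lr_scheduler_num_training_steps=1)
# The stats buffers inside Normalize/Unnormalize are part of the state dict,
# so no dataset access is needed at load time.
policy.load_state_dict(torch.load("outputs/train/example_pusht_diffusion/model.pt"))
policy.eval()
```
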
@ -2,18 +2,13 @@ import os
|
|||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from torchvision.transforms import v2
|
||||
|
||||
from lerobot.common.transforms import NormalizeTransform
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
|
||||
|
||||
|
||||
def make_dataset(
|
||||
cfg,
|
||||
# set normalize=False to remove all transformations and keep images unnormalized in [0,255]
|
||||
normalize=True,
|
||||
stats_path=None,
|
||||
split="train",
|
||||
):
|
||||
if cfg.env.name == "xarm":
|
||||
|
@ -33,58 +28,26 @@ def make_dataset(
|
|||
else:
|
||||
raise ValueError(cfg.env.name)
|
||||
|
||||
transforms = None
|
||||
if normalize:
|
||||
# TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max,
|
||||
# min_max_from_spec
|
||||
# TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std
|
||||
normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
|
||||
|
||||
if cfg.policy.name == "diffusion" and cfg.env.name == "pusht":
|
||||
stats = {}
|
||||
# TODO(rcadene): we overwrite stats to have the same as pretrained model, but we should remove this
|
||||
stats["observation.state"] = {}
|
||||
stats["observation.state"]["min"] = torch.tensor([13.456424, 32.938293], dtype=torch.float32)
|
||||
stats["observation.state"]["max"] = torch.tensor([496.14618, 510.9579], dtype=torch.float32)
|
||||
stats["action"] = {}
|
||||
stats["action"]["min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
|
||||
stats["action"]["max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
|
||||
elif stats_path is None:
|
||||
# load a first dataset to access precomputed stats
|
||||
stats_dataset = clsfunc(
|
||||
dataset_id=cfg.dataset_id,
|
||||
split="train",
|
||||
root=DATA_DIR,
|
||||
)
|
||||
stats = stats_dataset.stats
|
||||
else:
|
||||
stats = torch.load(stats_path)
|
||||
|
||||
transforms = v2.Compose(
|
||||
[
|
||||
NormalizeTransform(
|
||||
stats,
|
||||
in_keys=[
|
||||
"observation.state",
|
||||
"action",
|
||||
],
|
||||
mode=normalization_mode,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
delta_timestamps = cfg.policy.get("delta_timestamps")
|
||||
if delta_timestamps is not None:
|
||||
for key in delta_timestamps:
|
||||
if isinstance(delta_timestamps[key], str):
|
||||
delta_timestamps[key] = eval(delta_timestamps[key])
|
||||
|
||||
# TODO(rcadene): add data augmentations
|
||||
|
||||
dataset = clsfunc(
|
||||
dataset_id=cfg.dataset_id,
|
||||
split=split,
|
||||
root=DATA_DIR,
|
||||
delta_timestamps=delta_timestamps,
|
||||
transform=transforms,
|
||||
)
|
||||
|
||||
if cfg.get("override_dataset_stats"):
|
||||
for key, stats_dict in cfg.override_dataset_stats.items():
|
||||
for stats_type, listconfig in stats_dict.items():
|
||||
# example of stats_type: min, max, mean, std
|
||||
stats = OmegaConf.to_container(listconfig, resolve=True)
|
||||
dataset.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)
|
||||
|
||||
return dataset
|
||||
|
|
|
@ -1,10 +1,8 @@
|
|||
import einops
|
||||
import torch
|
||||
|
||||
from lerobot.common.transforms import apply_inverse_transform
|
||||
|
||||
|
||||
def preprocess_observation(observation, transform=None):
|
||||
def preprocess_observation(observation):
|
||||
# map to expected inputs for the policy
|
||||
obs = {}
|
||||
|
||||
|
@ -24,7 +22,7 @@ def preprocess_observation(observation, transform=None):
|
|||
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
|
||||
|
||||
# convert to channel first of type float32 in range [0,1]
|
||||
img = einops.rearrange(img, "b h w c -> b c h w")
|
||||
img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
|
||||
img = img.type(torch.float32)
|
||||
img /= 255
|
||||
|
||||
|
@ -33,19 +31,11 @@ def preprocess_observation(observation, transform=None):
|
|||
# TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing requirement for "agent_pos"
|
||||
obs["observation.state"] = torch.from_numpy(observation["agent_pos"]).float()
|
||||
|
||||
# apply same transforms as in training
|
||||
if transform is not None:
|
||||
for key in obs:
|
||||
obs[key] = torch.stack([transform({key: item})[key] for item in obs[key]])
|
||||
|
||||
return obs
|
||||
|
||||
|
||||
def postprocess_action(action, transform=None):
|
||||
action = action.to("cpu")
|
||||
# action is a batch (num_env,action_dim) instead of an item (action_dim),
|
||||
# we assume applying inverse transform on a batch works the same
|
||||
action = apply_inverse_transform({"action": action}, transform)["action"].numpy()
|
||||
def postprocess_action(action):
|
||||
action = action.to("cpu").numpy()
|
||||
assert (
|
||||
action.ndim == 2
|
||||
), "we assume dimensions are respectively the number of parallel envs, action dimensions"
|
||||
|
|
|
@ -8,23 +8,30 @@ class ActionChunkingTransformerConfig:
|
|||
Defaults are configured for training on bimanual Aloha tasks like "insertion" or "transfer".
|
||||
|
||||
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
|
||||
Those are: `state_dim`, `action_dim` and `camera_names`.
|
||||
Those are: `input_shapes` and 'output_shapes`.
|
||||
|
||||
Args:
|
||||
state_dim: Dimensionality of the observation state space (excluding images).
|
||||
action_dim: Dimensionality of the action space.
|
||||
n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
|
||||
current step and additional steps going back).
|
||||
camera_names: The (unique) set of names for the cameras.
|
||||
chunk_size: The size of the action prediction "chunks" in units of environment steps.
|
||||
n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
|
||||
This should be no greater than the chunk size. For example, if the chunk size is 100, you may
|
||||
set this to 50. This would mean that the model predicts 100 steps worth of actions, runs 50 in the
|
||||
environment, and throws the other 50 out.
|
||||
image_normalization_mean: Value to subtract from the input image pixels (inputs are assumed to be in
|
||||
[0, 1]) for normalization.
|
||||
image_normalization_std: Value by which to divide the input image pixels (after the mean has been
|
||||
subtracted).
|
||||
input_shapes: A dictionary defining the shapes of the input data for the policy.
|
||||
The key represents the input data name, and the value is a list indicating the dimensions
|
||||
of the corresponding data. For example, "observation.images.top" refers to an input from the
|
||||
"top" camera with dimensions [3, 96, 96], indicating it has three color channels and 96x96 resolution.
|
||||
Importantly, shapes don't include the batch or temporal dimension.
|
||||
output_shapes: A dictionary defining the shapes of the output data for the policy.
|
||||
The key represents the output data name, and the value is a list indicating the dimensions
|
||||
of the corresponding data. For example, "action" refers to an output shape of [14], indicating
|
||||
14-dimensional actions. Importantly, shapes don't include the batch or temporal dimension.
|
||||
normalize_input_modes: A dictionary where the key represents the modality (e.g. "observation.state"),
and the value specifies the normalization mode to apply. The two available
modes are "mean_std", which subtracts the mean and divides by the standard
deviation, and "min_max", which rescales to a [-1, 1] range.
unnormalize_output_modes: A dictionary similar to `normalize_input_modes`, but used to unnormalize outputs back to their original scale.
|
||||
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
|
||||
use_pretrained_backbone: Whether the backbone should be initialized with pretrained weights from
|
||||
torchvision.
|
||||
|
@ -50,21 +57,35 @@ class ActionChunkingTransformerConfig:
|
|||
is enabled. Loss is then calculated as: `reconstruction_loss + kl_weight * kld_loss`.
|
||||
"""
|
||||
|
||||
# Environment.
|
||||
state_dim: int = 14
|
||||
action_dim: int = 14
|
||||
|
||||
# Inputs / output structure.
|
||||
# Input / output structure.
|
||||
n_obs_steps: int = 1
|
||||
camera_names: tuple[str] = ("top",)
|
||||
chunk_size: int = 100
|
||||
n_action_steps: int = 100
|
||||
|
||||
# Vision preprocessing.
|
||||
image_normalization_mean: tuple[float, float, float] = field(
|
||||
default_factory=lambda: [0.485, 0.456, 0.406]
|
||||
input_shapes: dict[str, list[str]] = field(
|
||||
default_factory=lambda: {
|
||||
"observation.images.top": [3, 480, 640],
|
||||
"observation.state": [14],
|
||||
}
|
||||
)
|
||||
output_shapes: dict[str, list[str]] = field(
|
||||
default_factory=lambda: {
|
||||
"action": [14],
|
||||
}
|
||||
)
|
||||
|
||||
# Normalization / Unnormalization
|
||||
normalize_input_modes: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"observation.image": "mean_std",
|
||||
"observation.state": "mean_std",
|
||||
}
|
||||
)
|
||||
unnormalize_output_modes: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"action": "mean_std",
|
||||
}
|
||||
)
|
||||
image_normalization_std: tuple[float, float, float] = field(default_factory=lambda: [0.229, 0.224, 0.225])
|
||||
|
||||
# Architecture.
|
||||
# Vision backbone.
|
||||
|
@ -117,7 +138,10 @@ class ActionChunkingTransformerConfig:
|
|||
raise ValueError(
|
||||
f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
|
||||
)
|
||||
if self.camera_names != ["top"]:
|
||||
raise ValueError(f"For now, `camera_names` can only be ['top']. Got {self.camera_names}.")
|
||||
if len(set(self.camera_names)) != len(self.camera_names):
|
||||
raise ValueError(f"`camera_names` should not have any repeated entries. Got {self.camera_names}.")
|
||||
# Check that there is only one image.
|
||||
# TODO(alexander-soare): generalize this to multiple images.
|
||||
if (
|
||||
sum(k.startswith("observation.images.") for k in self.input_shapes) != 1
|
||||
or "observation.images.top" not in self.input_shapes
|
||||
):
|
||||
raise ValueError('For now, only "observation.images.top" is accepted for an image input.')
|
||||
|
|
|
@ -15,12 +15,12 @@ import numpy as np
|
|||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
import torchvision
|
||||
import torchvision.transforms as transforms
|
||||
from torch import Tensor, nn
|
||||
from torchvision.models._utils import IntermediateLayerGetter
|
||||
from torchvision.ops.misc import FrozenBatchNorm2d
|
||||
|
||||
from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
|
||||
|
||||
class ActionChunkingTransformerPolicy(nn.Module):
|
||||
|
@ -62,7 +62,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
|||
|
||||
name = "act"
|
||||
|
||||
def __init__(self, cfg: ActionChunkingTransformerConfig | None = None):
|
||||
def __init__(self, cfg: ActionChunkingTransformerConfig | None = None, dataset_stats=None):
|
||||
"""
|
||||
Args:
|
||||
cfg: Policy configuration class instance or None, in which case the default instantiation of the
|
||||
|
@ -72,6 +72,8 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
|||
if cfg is None:
|
||||
cfg = ActionChunkingTransformerConfig()
|
||||
self.cfg = cfg
|
||||
self.normalize_inputs = Normalize(cfg.input_shapes, cfg.normalize_input_modes, dataset_stats)
|
||||
self.unnormalize_outputs = Unnormalize(cfg.output_shapes, cfg.unnormalize_output_modes, dataset_stats)
|
||||
|
||||
# BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence].
|
||||
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
|
||||
|
@ -79,9 +81,13 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
|||
self.vae_encoder = _TransformerEncoder(cfg)
|
||||
self.vae_encoder_cls_embed = nn.Embedding(1, cfg.d_model)
|
||||
# Projection layer for joint-space configuration to hidden dimension.
|
||||
self.vae_encoder_robot_state_input_proj = nn.Linear(cfg.state_dim, cfg.d_model)
|
||||
self.vae_encoder_robot_state_input_proj = nn.Linear(
|
||||
cfg.input_shapes["observation.state"][0], cfg.d_model
|
||||
)
|
||||
# Projection layer for action (joint-space target) to hidden dimension.
|
||||
self.vae_encoder_action_input_proj = nn.Linear(cfg.state_dim, cfg.d_model)
|
||||
self.vae_encoder_action_input_proj = nn.Linear(
|
||||
cfg.input_shapes["observation.state"][0], cfg.d_model
|
||||
)
|
||||
self.latent_dim = cfg.latent_dim
|
||||
# Projection layer from the VAE encoder's output to the latent distribution's parameter space.
|
||||
self.vae_encoder_latent_output_proj = nn.Linear(cfg.d_model, self.latent_dim * 2)
|
||||
|
@ -93,9 +99,6 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
|||
)
|
||||
|
||||
# Backbone for image feature extraction.
|
||||
self.image_normalizer = transforms.Normalize(
|
||||
mean=cfg.image_normalization_mean, std=cfg.image_normalization_std
|
||||
)
|
||||
backbone_model = getattr(torchvision.models, cfg.vision_backbone)(
|
||||
replace_stride_with_dilation=[False, False, cfg.replace_final_stride_with_dilation],
|
||||
pretrained=cfg.use_pretrained_backbone,
|
||||
|
@ -112,7 +115,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
|||
|
||||
# Transformer encoder input projections. The tokens will be structured like
|
||||
# [latent, robot_state, image_feature_map_pixels].
|
||||
self.encoder_robot_state_input_proj = nn.Linear(cfg.state_dim, cfg.d_model)
|
||||
self.encoder_robot_state_input_proj = nn.Linear(cfg.input_shapes["observation.state"][0], cfg.d_model)
|
||||
self.encoder_latent_input_proj = nn.Linear(self.latent_dim, cfg.d_model)
|
||||
self.encoder_img_feat_input_proj = nn.Conv2d(
|
||||
backbone_model.fc.in_features, cfg.d_model, kernel_size=1
|
||||
|
@ -126,7 +129,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
|||
self.decoder_pos_embed = nn.Embedding(cfg.chunk_size, cfg.d_model)
|
||||
|
||||
# Final action regression head on the output of the transformer's decoder.
|
||||
self.action_head = nn.Linear(cfg.d_model, cfg.action_dim)
|
||||
self.action_head = nn.Linear(cfg.d_model, cfg.output_shapes["action"][0])
|
||||
|
||||
self._reset_parameters()
|
||||
self._create_optimizer()
|
||||
|
@ -169,10 +172,18 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
|||
queue is empty.
|
||||
"""
|
||||
self.eval()
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
|
||||
if len(self._action_queue) == 0:
|
||||
# `_forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue effectively
|
||||
# has shape (n_action_steps, batch_size, *), hence the transpose.
|
||||
self._action_queue.extend(self._forward(batch)[0][: self.cfg.n_action_steps].transpose(0, 1))
|
||||
actions = self._forward(batch)[0][: self.cfg.n_action_steps]
|
||||
|
||||
# TODO(rcadene): make _forward return output dictionary?
|
||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||
|
||||
self._action_queue.extend(actions.transpose(0, 1))
|
||||
return self._action_queue.popleft()
|
||||
|
||||
def forward(self, batch, **_) -> dict[str, Tensor]:
|
||||
|
@ -203,7 +214,11 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
|||
"""Run the model in train mode, compute the loss, and do an optimization step."""
|
||||
start_time = time.time()
|
||||
self.train()
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
|
||||
loss_dict = self.forward(batch)
|
||||
# TODO(rcadene): self.unnormalize_outputs(out_dict)
|
||||
loss = loss_dict["loss"]
|
||||
loss.backward()
|
||||
|
||||
|
@ -232,17 +247,9 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
|||
"observation.images.{name}": (B, C, H, W) tensor of images.
|
||||
}
|
||||
"""
|
||||
# Check that there is only one image.
|
||||
# TODO(alexander-soare): generalize this to multiple images.
|
||||
provided_cameras = {k.rsplit(".", 1)[-1] for k in batch if k.startswith("observation.images.")}
|
||||
if len(missing := set(self.cfg.camera_names).difference(provided_cameras)) > 0:
|
||||
raise ValueError(
|
||||
f"The following camera images are missing from the provided batch: {missing}. Check the "
|
||||
"configuration parameter: `camera_names`."
|
||||
)
|
||||
# Stack images in the order dictated by the camera names.
|
||||
# Stack images in the order dictated by input_shapes.
|
||||
batch["observation.images"] = torch.stack(
|
||||
[batch[f"observation.images.{name}"] for name in self.cfg.camera_names],
|
||||
[batch[k] for k in self.cfg.input_shapes if k.startswith("observation.images.")],
|
||||
dim=-4,
|
||||
)
|
||||
|
||||
|
@ -309,8 +316,8 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
|||
# Camera observation features and positional embeddings.
|
||||
all_cam_features = []
|
||||
all_cam_pos_embeds = []
|
||||
images = self.image_normalizer(batch["observation.images"])
|
||||
for cam_index in range(len(self.cfg.camera_names)):
|
||||
images = batch["observation.images"]
|
||||
for cam_index in range(images.shape[-4]):
|
||||
cam_features = self.backbone(images[:, cam_index])["feature_map"]
|
||||
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
|
||||
cam_features = self.encoder_img_feat_input_proj(cam_features) # (B, C, h, w)
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -8,21 +8,28 @@ class DiffusionConfig:
|
|||
Defaults are configured for training with PushT providing proprioceptive and single camera observations.
|
||||
|
||||
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
|
||||
Those are: `state_dim`, `action_dim` and `image_size`.
|
||||
Those are: `input_shapes` and `output_shapes`.
|
||||
|
||||
Args:
|
||||
state_dim: Dimensionality of the observation state space (excluding images).
|
||||
action_dim: Dimensionality of the action space.
|
||||
image_size: (H, W) size of the input images.
|
||||
n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
|
||||
current step and additional steps going back).
|
||||
horizon: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`.
|
||||
n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
|
||||
See `DiffusionPolicy.select_action` for more details.
|
||||
image_normalization_mean: Value to subtract from the input image pixels (inputs are assumed to be in
|
||||
[0, 1]) for normalization.
|
||||
image_normalization_std: Value by which to divide the input image pixels (after the mean has been
|
||||
subtracted).
|
||||
input_shapes: A dictionary defining the shapes of the input data for the policy.
|
||||
The key represents the input data name, and the value is a list indicating the dimensions
|
||||
of the corresponding data. For example, "observation.image" refers to an input from
|
||||
a camera with dimensions [3, 96, 96], indicating it has three color channels and 96x96 resolution.
|
||||
Importantly, shapes don't include the batch or temporal dimension.
|
||||
output_shapes: A dictionary defining the shapes of the output data for the policy.
|
||||
The key represents the output data name, and the value is a list indicating the dimensions
|
||||
of the corresponding data. For example, "action" refers to an output shape of [14], indicating
|
||||
14-dimensional actions. Importantly, shapes don't include the batch or temporal dimension.
|
||||
normalize_input_modes: A dictionary where the key represents the modality (e.g. "observation.state"),
and the value specifies the normalization mode to apply. The two available
modes are "mean_std", which subtracts the mean and divides by the standard
deviation, and "min_max", which rescales to a [-1, 1] range.
unnormalize_output_modes: A dictionary similar to `normalize_input_modes`, but used to unnormalize outputs back to their original scale.
|
||||
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
|
||||
crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit
|
||||
within the image size. If None, no cropping is done.
|
||||
|
@ -58,20 +65,35 @@ class DiffusionConfig:
|
|||
spaced). If not provided, this defaults to be the same as `num_train_timesteps`.
|
||||
"""
|
||||
|
||||
# Environment.
|
||||
# Inherit these from the environment config.
|
||||
state_dim: int = 2
|
||||
action_dim: int = 2
|
||||
image_size: tuple[int, int] = (96, 96)
|
||||
|
||||
# Inputs / output structure.
|
||||
n_obs_steps: int = 2
|
||||
horizon: int = 16
|
||||
n_action_steps: int = 8
|
||||
|
||||
# Vision preprocessing.
|
||||
image_normalization_mean: tuple[float, float, float] = (0.5, 0.5, 0.5)
|
||||
image_normalization_std: tuple[float, float, float] = (0.5, 0.5, 0.5)
|
||||
input_shapes: dict[str, list[str]] = field(
|
||||
default_factory=lambda: {
|
||||
"observation.image": [3, 96, 96],
|
||||
"observation.state": [2],
|
||||
}
|
||||
)
|
||||
output_shapes: dict[str, list[str]] = field(
|
||||
default_factory=lambda: {
|
||||
"action": [2],
|
||||
}
|
||||
)
|
||||
|
||||
# Normalization / Unnormalization
|
||||
normalize_input_modes: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"observation.image": "mean_std",
|
||||
"observation.state": "min_max",
|
||||
}
|
||||
)
|
||||
unnormalize_output_modes: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"action": "min_max",
|
||||
}
|
||||
)
|
||||
|
||||
# Architecture / modeling.
|
||||
# Vision backbone.
|
||||
|
@ -123,10 +145,14 @@ class DiffusionConfig:
|
|||
raise ValueError(
|
||||
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
|
||||
)
|
||||
if self.crop_shape[0] > self.image_size[0] or self.crop_shape[1] > self.image_size[1]:
|
||||
if (
|
||||
self.crop_shape[0] > self.input_shapes["observation.image"][1]
|
||||
or self.crop_shape[1] > self.input_shapes["observation.image"][2]
|
||||
):
|
||||
raise ValueError(
|
||||
f"`crop_shape` should fit within `image_size`. Got {self.crop_shape} for `crop_shape` and "
|
||||
f"{self.image_size} for `image_size`."
|
||||
f'`crop_shape` should fit within `input_shapes["observation.image"]`. Got {self.crop_shape} '
|
||||
f'for `crop_shape` and {self.input_shapes["observation.image"]} for '
|
||||
'`input_shapes["observation.image"]`.'
|
||||
)
|
||||
supported_prediction_types = ["epsilon", "sample"]
|
||||
if self.prediction_type not in supported_prediction_types:
|
||||
|
|
|
@ -13,7 +13,6 @@ import logging
|
|||
import math
|
||||
import time
|
||||
from collections import deque
|
||||
from itertools import chain
|
||||
from typing import Callable
|
||||
|
||||
import einops
|
||||
|
@ -27,6 +26,7 @@ from torch import Tensor, nn
|
|||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
|
||||
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.utils import (
|
||||
get_device_from_parameters,
|
||||
get_dtype_from_parameters,
|
||||
|
@ -42,7 +42,9 @@ class DiffusionPolicy(nn.Module):
|
|||
|
||||
name = "diffusion"
|
||||
|
||||
def __init__(self, cfg: DiffusionConfig | None, lr_scheduler_num_training_steps: int = 0):
|
||||
def __init__(
|
||||
self, cfg: DiffusionConfig | None, lr_scheduler_num_training_steps: int = 0, dataset_stats=None
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
cfg: Policy configuration class instance or None, in which case the default instantiation of the
|
||||
|
@ -54,6 +56,8 @@ class DiffusionPolicy(nn.Module):
|
|||
if cfg is None:
|
||||
cfg = DiffusionConfig()
|
||||
self.cfg = cfg
|
||||
self.normalize_inputs = Normalize(cfg.input_shapes, cfg.normalize_input_modes, dataset_stats)
|
||||
self.unnormalize_outputs = Unnormalize(cfg.output_shapes, cfg.unnormalize_output_modes, dataset_stats)
|
||||
|
||||
# queues are populated during rollout of the policy, they contain the n latest observations and actions
|
||||
self._queues = None
|
||||
|
@ -126,6 +130,8 @@ class DiffusionPolicy(nn.Module):
|
|||
assert "observation.state" in batch
|
||||
assert len(batch) == 2
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
|
||||
self._queues = populate_queues(self._queues, batch)
|
||||
|
||||
if len(self._queues["action"]) == 0:
|
||||
|
@ -135,6 +141,10 @@ class DiffusionPolicy(nn.Module):
|
|||
actions = self.ema_diffusion.generate_actions(batch)
|
||||
else:
|
||||
actions = self.diffusion.generate_actions(batch)
|
||||
|
||||
# TODO(rcadene): make above methods return output dictionary?
|
||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||
|
||||
self._queues["action"].extend(actions.transpose(0, 1))
|
||||
|
||||
action = self._queues["action"].popleft()
|
||||
|
@ -151,9 +161,13 @@ class DiffusionPolicy(nn.Module):
|
|||
|
||||
self.diffusion.train()
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
|
||||
loss = self.forward(batch)["loss"]
|
||||
loss.backward()
|
||||
|
||||
# TODO(rcadene): self.unnormalize_outputs(out_dict)
|
||||
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
self.diffusion.parameters(),
|
||||
self.cfg.grad_clip_norm,
|
||||
|
@ -197,7 +211,8 @@ class _DiffusionUnetImagePolicy(nn.Module):
|
|||
|
||||
self.rgb_encoder = _RgbEncoder(cfg)
|
||||
self.unet = _ConditionalUnet1D(
|
||||
cfg, global_cond_dim=(cfg.action_dim + self.rgb_encoder.feature_dim) * cfg.n_obs_steps
|
||||
cfg,
|
||||
global_cond_dim=(cfg.output_shapes["action"][0] + self.rgb_encoder.feature_dim) * cfg.n_obs_steps,
|
||||
)
|
||||
|
||||
self.noise_scheduler = DDPMScheduler(
|
||||
|
@ -225,7 +240,7 @@ class _DiffusionUnetImagePolicy(nn.Module):
|
|||
|
||||
# Sample prior.
|
||||
sample = torch.randn(
|
||||
size=(batch_size, self.cfg.horizon, self.cfg.action_dim),
|
||||
size=(batch_size, self.cfg.horizon, self.cfg.output_shapes["action"][0]),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
generator=generator,
|
||||
|
@ -268,7 +283,7 @@ class _DiffusionUnetImagePolicy(nn.Module):
|
|||
sample = self.conditional_sample(batch_size, global_cond=global_cond)
|
||||
|
||||
# `horizon` steps worth of actions (from the first observation).
|
||||
actions = sample[..., : self.cfg.action_dim]
|
||||
actions = sample[..., : self.cfg.output_shapes["action"][0]]
|
||||
# Extract `n_action_steps` steps worth of actions (from the current observation).
|
||||
start = n_obs_steps - 1
|
||||
end = start + self.cfg.n_action_steps
|
||||
|
@ -346,12 +361,6 @@ class _RgbEncoder(nn.Module):
|
|||
def __init__(self, cfg: DiffusionConfig):
|
||||
super().__init__()
|
||||
# Set up optional preprocessing.
|
||||
if all(v == 1.0 for v in chain(cfg.image_normalization_mean, cfg.image_normalization_std)):
|
||||
self.normalizer = nn.Identity()
|
||||
else:
|
||||
self.normalizer = torchvision.transforms.Normalize(
|
||||
mean=cfg.image_normalization_mean, std=cfg.image_normalization_std
|
||||
)
|
||||
if cfg.crop_shape is not None:
|
||||
self.do_crop = True
|
||||
# Always use center crop for eval
|
||||
|
@ -384,7 +393,9 @@ class _RgbEncoder(nn.Module):
|
|||
# Set up pooling and final layers.
|
||||
# Use a dry run to get the feature map shape.
|
||||
with torch.inference_mode():
|
||||
feat_map_shape = tuple(self.backbone(torch.zeros(size=(1, 3, *cfg.image_size))).shape[1:])
|
||||
feat_map_shape = tuple(
|
||||
self.backbone(torch.zeros(size=(1, *cfg.input_shapes["observation.image"]))).shape[1:]
|
||||
)
|
||||
self.pool = SpatialSoftmax(feat_map_shape, num_kp=cfg.spatial_softmax_num_keypoints)
|
||||
self.feature_dim = cfg.spatial_softmax_num_keypoints * 2
|
||||
self.out = nn.Linear(cfg.spatial_softmax_num_keypoints * 2, self.feature_dim)
|
||||
|
@ -397,8 +408,7 @@ class _RgbEncoder(nn.Module):
|
|||
Returns:
|
||||
(B, D) image feature.
|
||||
"""
|
||||
# Preprocess: normalize and maybe crop (if it was set up in the __init__).
|
||||
x = self.normalizer(x)
|
||||
# Preprocess: maybe crop (if it was set up in the __init__).
|
||||
if self.do_crop:
|
||||
if self.training: # noqa: SIM108
|
||||
x = self.maybe_random_crop(x)
|
||||
|
@ -502,7 +512,7 @@ class _ConditionalUnet1D(nn.Module):
|
|||
|
||||
# In channels / out channels for each downsampling block in the Unet's encoder. For the decoder, we
|
||||
# just reverse these.
|
||||
in_out = [(cfg.action_dim, cfg.down_dims[0])] + list(
|
||||
in_out = [(cfg.output_shapes["action"][0], cfg.down_dims[0])] + list(
|
||||
zip(cfg.down_dims[:-1], cfg.down_dims[1:], strict=True)
|
||||
)
|
||||
|
||||
|
@ -553,7 +563,7 @@ class _ConditionalUnet1D(nn.Module):
|
|||
|
||||
self.final_conv = nn.Sequential(
|
||||
_Conv1dBlock(cfg.down_dims[0], cfg.down_dims[0], kernel_size=cfg.kernel_size),
|
||||
nn.Conv1d(cfg.down_dims[0], cfg.action_dim, 1),
|
||||
nn.Conv1d(cfg.down_dims[0], cfg.output_shapes["action"][0], 1),
|
||||
)
|
||||
|
||||
def forward(self, x: Tensor, timestep: Tensor | int, global_cond=None) -> Tensor:
|
||||
|
|
|
@ -20,7 +20,7 @@ def _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg):
|
|||
return policy_cfg
|
||||
|
||||
|
||||
def make_policy(hydra_cfg: DictConfig):
|
||||
def make_policy(hydra_cfg: DictConfig, dataset_stats=None):
|
||||
if hydra_cfg.policy.name == "tdmpc":
|
||||
from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
|
||||
|
||||
|
@ -35,14 +35,14 @@ def make_policy(hydra_cfg: DictConfig):
|
|||
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||
|
||||
policy_cfg = _policy_cfg_from_hydra_cfg(DiffusionConfig, hydra_cfg)
|
||||
policy = DiffusionPolicy(policy_cfg, hydra_cfg.offline_steps)
|
||||
policy = DiffusionPolicy(policy_cfg, hydra_cfg.offline_steps, dataset_stats)
|
||||
policy.to(get_safe_torch_device(hydra_cfg.device))
|
||||
elif hydra_cfg.policy.name == "act":
|
||||
from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
|
||||
from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy
|
||||
|
||||
policy_cfg = _policy_cfg_from_hydra_cfg(ActionChunkingTransformerConfig, hydra_cfg)
|
||||
policy = ActionChunkingTransformerPolicy(policy_cfg)
|
||||
policy = ActionChunkingTransformerPolicy(policy_cfg, dataset_stats)
|
||||
policy.to(get_safe_torch_device(hydra_cfg.device))
|
||||
else:
|
||||
raise ValueError(hydra_cfg.policy.name)
|
||||
|
|
|
@ -0,0 +1,196 @@
|
|||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
def create_stats_buffers(shapes, modes, stats=None):
|
||||
"""
|
||||
Create buffers per modality (e.g. "observation.image", "action") containing their mean, std, min, max statistics.
|
||||
|
||||
Parameters:
|
||||
shapes (dict): A dictionary where keys are input modalities (e.g. "observation.image") and values are their shapes (e.g. `[3,96,96]`]).
|
||||
These shapes are used to create the tensor buffer containing mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape is adjusted to be invariant to height
|
||||
and width, assuming a channel-first (c, h, w) format.
|
||||
modes (dict): A dictionary where keys are output modalities (e.g. "observation.image") and values are their normalization modes among:
|
||||
- "mean_std": substract the mean and divide by standard deviation.
|
||||
- "min_max": map to [-1, 1] range.
|
||||
stats (dict, optional): A dictionary where keys are output modalities (e.g. "observation.image") and values are dictionaries of statistic types and their values
|
||||
(e.g. `{"mean": torch.randn(3,1,1)}, "std": torch.randn(3,1,1)}`). If provided, as expected for training the model for the first time,
|
||||
these statistics will overwrite the default buffers. If not provided, as expected for finetuning or evaluation, the default buffers should to be
|
||||
be overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the dataset is not needed to get the stats, since
|
||||
they are already in the policy state_dict.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary where keys are modalities and values are `nn.ParameterDict` objects containing `nn.Parameter` tensors set to
`requires_grad=False`, so that they are not updated during backpropagation.
|
||||
"""
|
||||
stats_buffers = {}
|
||||
|
||||
for key, mode in modes.items():
|
||||
assert mode in ["mean_std", "min_max"]
|
||||
|
||||
shape = tuple(shapes[key])
|
||||
|
||||
if "image" in key:
|
||||
# sanity checks
|
||||
assert len(shape) == 3, f"number of dimensions of {key} != 3 ({shape=}"
|
||||
c, h, w = shape
|
||||
assert c < h and c < w, f"{key} is not channel first ({shape=})"
|
||||
# override image shape to be invariant to height and width
|
||||
shape = (c, 1, 1)
|
||||
|
||||
# Note: we initialize mean, std, min, max to infinity. They should be overwritten
|
||||
# downstream by `stats` or `policy.load_state_dict`, as expected. During forward,
|
||||
# we assert they are not infinity anymore.
|
||||
|
||||
buffer = {}
|
||||
if mode == "mean_std":
|
||||
mean = torch.ones(shape, dtype=torch.float32) * torch.inf
|
||||
std = torch.ones(shape, dtype=torch.float32) * torch.inf
|
||||
buffer = nn.ParameterDict(
|
||||
{
|
||||
"mean": nn.Parameter(mean, requires_grad=False),
|
||||
"std": nn.Parameter(std, requires_grad=False),
|
||||
}
|
||||
)
|
||||
elif mode == "min_max":
|
||||
min = torch.ones(shape, dtype=torch.float32) * torch.inf
|
||||
max = torch.ones(shape, dtype=torch.float32) * torch.inf
|
||||
buffer = nn.ParameterDict(
|
||||
{
|
||||
"min": nn.Parameter(min, requires_grad=False),
|
||||
"max": nn.Parameter(max, requires_grad=False),
|
||||
}
|
||||
)
|
||||
|
||||
if stats is not None:
|
||||
if mode == "mean_std":
|
||||
buffer["mean"].data = stats[key]["mean"]
|
||||
buffer["std"].data = stats[key]["std"]
|
||||
elif mode == "min_max":
|
||||
buffer["min"].data = stats[key]["min"]
|
||||
buffer["max"].data = stats[key]["max"]
|
||||
|
||||
stats_buffers[key] = buffer
|
||||
return stats_buffers
|
||||
|
||||
|
||||
class Normalize(nn.Module):
|
||||
"""
|
||||
Normalizes the input data (e.g. "observation.image") for more stable and faster convergence during training.
|
||||
|
||||
Parameters:
|
||||
shapes (dict): A dictionary where keys are input modalities (e.g. "observation.image") and values are their shapes (e.g. `[3,96,96]`]).
|
||||
These shapes are used to create the tensor buffer containing mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape is adjusted to be invariant to height
|
||||
and width, assuming a channel-first (c, h, w) format.
|
||||
modes (dict): A dictionary where keys are output modalities (e.g. "observation.image") and values are their normalization modes among:
|
||||
- "mean_std": substract the mean and divide by standard deviation.
|
||||
- "min_max": map to [-1, 1] range.
|
||||
stats (dict, optional): A dictionary where keys are output modalities (e.g. "observation.image") and values are dictionaries of statistic types and their values
|
||||
(e.g. `{"mean": torch.randn(3,1,1)}, "std": torch.randn(3,1,1)}`). If provided, as expected for training the model for the first time,
|
||||
these statistics will overwrite the default buffers. If not provided, as expected for finetuning or evaluation, the default buffers should to be
|
||||
be overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the dataset is not needed to get the stats, since
|
||||
they are already in the policy state_dict.
|
||||
"""
|
||||
|
||||
def __init__(self, shapes, modes, stats=None):
|
||||
super().__init__()
|
||||
self.shapes = shapes
|
||||
self.modes = modes
|
||||
self.stats = stats
|
||||
# `self.buffer_observation_state["mean"]` contains `torch.tensor(state_dim)`
|
||||
stats_buffers = create_stats_buffers(shapes, modes, stats)
|
||||
for key, buffer in stats_buffers.items():
|
||||
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
|
||||
|
||||
# TODO(rcadene): should we remove torch.no_grad?
|
||||
@torch.no_grad
|
||||
def forward(self, batch):
|
||||
for key, mode in self.modes.items():
|
||||
buffer = getattr(self, "buffer_" + key.replace(".", "_"))
|
||||
|
||||
if mode == "mean_std":
|
||||
mean = buffer["mean"]
|
||||
std = buffer["std"]
|
||||
assert not torch.isinf(
|
||||
mean
|
||||
).any(), "`mean` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
|
||||
assert not torch.isinf(
|
||||
std
|
||||
).any(), "`std` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
|
||||
batch[key] = (batch[key] - mean) / (std + 1e-8)
|
||||
elif mode == "min_max":
|
||||
min = buffer["min"]
|
||||
max = buffer["max"]
|
||||
assert not torch.isinf(
|
||||
min
|
||||
).any(), "`min` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
|
||||
assert not torch.isinf(
|
||||
max
|
||||
).any(), "`max` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
|
||||
# normalize to [0,1]
|
||||
batch[key] = (batch[key] - min) / (max - min)
|
||||
# normalize to [-1, 1]
|
||||
batch[key] = batch[key] * 2 - 1
|
||||
else:
|
||||
raise ValueError(mode)
|
||||
return batch
|
||||
|
||||
|
||||
class Unnormalize(nn.Module):
|
||||
"""
|
||||
Similar to `Normalize`, but unnormalizes output data (e.g. `{"action": torch.randn(b,c)}`) back to the original range used by the environment.
|
||||
|
||||
Parameters:
|
||||
shapes (dict): A dictionary where keys are output modalities (e.g. "action") and values are their shapes (e.g. [10]).
|
||||
These shapes are used to create the tensor buffer containing mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape is adjusted to be invariant to height
|
||||
and width, assuming a channel-first (c, h, w) format.
|
||||
modes (dict): A dictionary where keys are output modalities (e.g. "action") and values are their unnormalization modes among:
|
||||
- "mean_std": multiply by standard deviation and add mean
|
||||
- "min_max": go from [-1, 1] range to original range.
|
||||
stats (dict, optional): A dictionary where keys are output modalities (e.g. "action") and values are dictionaries of statistic types and their values
|
||||
(e.g. `{"max": torch.tensor(1)}, "min": torch.tensor(0)}`). If provided, as expected for training the model for the first time,
|
||||
these statistics will overwrite the default buffers. If not provided, as expected for finetuning or evaluation, the default buffers should to be
|
||||
be overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the dataset is not needed to get the stats, since
|
||||
they are already in the policy state_dict.
|
||||
"""
|
||||
|
||||
def __init__(self, shapes, modes, stats=None):
|
||||
super().__init__()
|
||||
self.shapes = shapes
|
||||
self.modes = modes
|
||||
self.stats = stats
|
||||
# `self.buffer_observation_state["mean"]` contains `torch.tensor(state_dim)`
|
||||
stats_buffers = create_stats_buffers(shapes, modes, stats)
|
||||
for key, buffer in stats_buffers.items():
|
||||
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
|
||||
|
||||
# TODO(rcadene): should we remove torch.no_grad?
|
||||
@torch.no_grad
|
||||
def forward(self, batch):
|
||||
for key, mode in self.modes.items():
|
||||
buffer = getattr(self, "buffer_" + key.replace(".", "_"))
|
||||
|
||||
if mode == "mean_std":
|
||||
mean = buffer["mean"]
|
||||
std = buffer["std"]
|
||||
assert not torch.isinf(
|
||||
mean
|
||||
).any(), "`mean` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
|
||||
assert not torch.isinf(
|
||||
std
|
||||
).any(), "`std` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
|
||||
batch[key] = batch[key] * std + mean
|
||||
elif mode == "min_max":
|
||||
min = buffer["min"]
|
||||
max = buffer["max"]
|
||||
assert not torch.isinf(
|
||||
min
|
||||
).any(), "`min` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
|
||||
assert not torch.isinf(
|
||||
max
|
||||
).any(), "`max` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
|
||||
batch[key] = (batch[key] + 1) / 2
|
||||
batch[key] = batch[key] * (max - min) + min
|
||||
else:
|
||||
raise ValueError(mode)
|
||||
return batch
|
|
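
Taken together, a policy built on these modules normalizes its inputs on the way in and unnormalizes its outputs on the way out. A minimal sketch of the intended usage, mirroring the constructor calls in the ACT and diffusion policies above (the shapes, modes, and stats here are illustrative placeholders, not values from the source):

```python
import torch

from lerobot.common.policies.normalize import Normalize, Unnormalize

input_shapes = {"observation.state": [2]}
output_shapes = {"action": [2]}
normalize_input_modes = {"observation.state": "min_max"}
unnormalize_output_modes = {"action": "min_max"}
# Stats would normally come from `dataset.stats`.
dataset_stats = {
    "observation.state": {"min": torch.zeros(2), "max": torch.ones(2)},
    "action": {"min": torch.zeros(2), "max": torch.ones(2)},
}

normalize_inputs = Normalize(input_shapes, normalize_input_modes, dataset_stats)
unnormalize_outputs = Unnormalize(output_shapes, unnormalize_output_modes, dataset_stats)

batch = {"observation.state": torch.rand(8, 2)}
batch = normalize_inputs(batch)  # rescaled to [-1, 1] via min/max
actions = torch.rand(8, 2) * 2 - 1  # stand-in for model output in [-1, 1]
actions = unnormalize_outputs({"action": actions})["action"]  # back to env scale
```
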
@ -1,65 +0,0 @@
|
|||
from torchvision.transforms.v2 import Compose, Transform
|
||||
|
||||
|
||||
def apply_inverse_transform(item, transform):
|
||||
transforms = transform.transforms if isinstance(transform, Compose) else [transform]
|
||||
for tf in transforms[::-1]:
|
||||
if tf.invertible:
|
||||
item = tf.inverse_transform(item)
|
||||
else:
|
||||
raise ValueError(f"Inverse transform called on a non invertible transform ({tf}).")
|
||||
return item
|
||||
|
||||
|
||||
class NormalizeTransform(Transform):
|
||||
invertible = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
stats: dict,
|
||||
in_keys: list[str] = None,
|
||||
out_keys: list[str] | None = None,
|
||||
in_keys_inv: list[str] | None = None,
|
||||
out_keys_inv: list[str] | None = None,
|
||||
mode="mean_std",
|
||||
):
|
||||
super().__init__()
|
||||
self.in_keys = in_keys
|
||||
self.out_keys = in_keys if out_keys is None else out_keys
|
||||
self.in_keys_inv = self.out_keys if in_keys_inv is None else in_keys_inv
|
||||
self.out_keys_inv = self.in_keys if out_keys_inv is None else out_keys_inv
|
||||
self.stats = stats
|
||||
assert mode in ["mean_std", "min_max"]
|
||||
self.mode = mode
|
||||
|
||||
def forward(self, item):
|
||||
for inkey, outkey in zip(self.in_keys, self.out_keys, strict=False):
|
||||
if inkey not in item:
|
||||
continue
|
||||
if self.mode == "mean_std":
|
||||
mean = self.stats[inkey]["mean"]
|
||||
std = self.stats[inkey]["std"]
|
||||
item[outkey] = (item[inkey] - mean) / (std + 1e-8)
|
||||
else:
|
||||
min = self.stats[inkey]["min"]
|
||||
max = self.stats[inkey]["max"]
|
||||
# normalize to [0,1]
|
||||
item[outkey] = (item[inkey] - min) / (max - min)
|
||||
# normalize to [-1, 1]
|
||||
item[outkey] = item[outkey] * 2 - 1
|
||||
return item
|
||||
|
||||
def inverse_transform(self, item):
|
||||
for inkey, outkey in zip(self.in_keys_inv, self.out_keys_inv, strict=False):
|
||||
if inkey not in item:
|
||||
continue
|
||||
if self.mode == "mean_std":
|
||||
mean = self.stats[inkey]["mean"]
|
||||
std = self.stats[inkey]["std"]
|
||||
item[outkey] = item[inkey] * std + mean
|
||||
else:
|
||||
min = self.stats[inkey]["min"]
|
||||
max = self.stats[inkey]["max"]
|
||||
item[outkey] = (item[inkey] + 1) / 2
|
||||
item[outkey] = item[outkey] * (max - min) + min
|
||||
return item
|
|
@ -20,7 +20,5 @@ env:
|
|||
image_size: [3, 480, 640]
|
||||
episode_length: 400
|
||||
fps: ${fps}
|
||||
|
||||
policy:
|
||||
state_dim: 14
|
||||
action_dim: 14
|
||||
|
|
|
@ -20,7 +20,5 @@ env:
|
|||
image_size: 96
|
||||
episode_length: 300
|
||||
fps: ${fps}
|
||||
|
||||
policy:
|
||||
state_dim: 2
|
||||
action_dim: 2
|
||||
|
|
|
@ -19,7 +19,5 @@ env:
|
|||
image_size: 84
|
||||
episode_length: 25
|
||||
fps: ${fps}
|
||||
|
||||
policy:
|
||||
state_dim: 4
|
||||
action_dim: 4
|
||||
|
|
|
@ -11,26 +11,36 @@ log_freq: 250
|
|||
n_obs_steps: 1
|
||||
# when temporal_agg=False, n_action_steps=horizon
|
||||
|
||||
override_dataset_stats:
|
||||
observation.images.top:
|
||||
# stats from imagenet, since we use a pretrained vision model
|
||||
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||
|
||||
# See `configuration_act.py` for more details.
|
||||
policy:
|
||||
name: act
|
||||
|
||||
pretrained_model_path:
|
||||
|
||||
# Environment.
|
||||
# Inherit these from the environment config.
|
||||
state_dim: ???
|
||||
action_dim: ???
|
||||
|
||||
# Inputs / output structure.
|
||||
# Input / output structure.
|
||||
n_obs_steps: ${n_obs_steps}
|
||||
camera_names: [top] # [top, front_close, left_pillar, right_pillar]
|
||||
chunk_size: 100 # chunk_size
|
||||
n_action_steps: 100
|
||||
|
||||
# Vision preprocessing.
|
||||
image_normalization_mean: [0.485, 0.456, 0.406]
|
||||
image_normalization_std: [0.229, 0.224, 0.225]
|
||||
input_shapes:
|
||||
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
|
||||
observation.images.top: [3, 480, 640]
|
||||
observation.state: ["${env.state_dim}"]
|
||||
output_shapes:
|
||||
action: ["${env.action_dim}"]
|
||||
|
||||
# Normalization / Unnormalization
|
||||
normalize_input_modes:
|
||||
observation.images.top: mean_std
|
||||
observation.state: mean_std
|
||||
unnormalize_output_modes:
|
||||
action: mean_std
|
||||
|
||||
# Architecture.
|
||||
# Vision backbone.
|
||||
|
|
|
@ -18,27 +18,43 @@ online_steps: 0
|
|||
|
||||
offline_prioritized_sampler: true
|
||||
|
||||
override_dataset_stats:
|
||||
# TODO(rcadene, alexander-soare): should we remove image stats as well? do we use a pretrained vision model?
|
||||
observation.image:
|
||||
mean: [[[0.5]], [[0.5]], [[0.5]]] # (c,1,1)
|
||||
std: [[[0.5]], [[0.5]], [[0.5]]] # (c,1,1)
|
||||
# TODO(rcadene, alexander-soare): we override state and action stats to use the same as the pretrained model
|
||||
# from the original codebase, but we should remove these and train our own pretrained model
|
||||
observation.state:
|
||||
min: [13.456424, 32.938293]
|
||||
max: [496.14618, 510.9579]
|
||||
action:
|
||||
min: [12.0, 25.0]
|
||||
max: [511.0, 511.0]
|
||||
|
||||
policy:
|
||||
name: diffusion
|
||||
|
||||
pretrained_model_path:
|
||||
|
||||
# Environment.
|
||||
# Inherit these from the environment config.
|
||||
state_dim: ???
|
||||
action_dim: ???
|
||||
image_size:
|
||||
- ${env.image_size} # height
|
||||
- ${env.image_size} # width
|
||||
|
||||
# Inputs / output structure.
|
||||
# Input / output structure.
|
||||
n_obs_steps: ${n_obs_steps}
|
||||
horizon: ${horizon}
|
||||
n_action_steps: ${n_action_steps}
|
||||
|
||||
# Vision preprocessing.
|
||||
image_normalization_mean: [0.5, 0.5, 0.5]
|
||||
image_normalization_std: [0.5, 0.5, 0.5]
|
||||
input_shapes:
|
||||
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
|
||||
observation.image: [3, 96, 96]
|
||||
observation.state: ["${env.state_dim}"]
|
||||
output_shapes:
|
||||
action: ["${env.action_dim}"]
|
||||
|
||||
# Normalization / Unnormalization
|
||||
normalize_input_modes:
|
||||
observation.image: mean_std
|
||||
observation.state: min_max
|
||||
unnormalize_output_modes:
|
||||
action: min_max
|
||||
|
||||
# Architecture / modeling.
|
||||
# Vision backbone.
|
||||
|
|
|
@ -16,8 +16,8 @@ policy:
|
|||
frame_stack: 1
|
||||
num_channels: 32
|
||||
img_size: ${env.image_size}
|
||||
state_dim: ???
|
||||
action_dim: ???
|
||||
state_dim: ${env.action_dim}
|
||||
action_dim: ${env.action_dim}
|
||||
|
||||
# planning
|
||||
mpc: true
|
||||
|
|
|
@ -46,7 +46,6 @@ from huggingface_hub import snapshot_download
|
|||
from PIL import Image as PILImage
|
||||
from tqdm import trange
|
||||
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.datasets.utils import hf_transform_to_torch
|
||||
from lerobot.common.envs.factory import make_env
|
||||
from lerobot.common.envs.utils import postprocess_action, preprocess_observation
|
||||
|
@ -64,8 +63,6 @@ def eval_policy(
|
|||
policy: torch.nn.Module,
|
||||
max_episodes_rendered: int = 0,
|
||||
video_dir: Path = None,
|
||||
# TODO(rcadene): make it possible to overwrite fps? we should use env.fps
|
||||
transform: callable = None,
|
||||
return_episode_data: bool = False,
|
||||
seed=None,
|
||||
):
|
||||
|
@ -132,10 +129,6 @@ def eval_policy(
|
|||
if return_episode_data:
|
||||
observations.append(deepcopy(observation))
|
||||
|
||||
# apply transform to normalize the observations
|
||||
for key in observation:
|
||||
observation[key] = torch.stack([transform({key: item})[key] for item in observation[key]])
|
||||
|
||||
# send observation to device/gpu
|
||||
observation = {key: observation[key].to(device, non_blocking=True) for key in observation}
|
||||
|
||||
|
@ -143,8 +136,8 @@ def eval_policy(
|
|||
with torch.inference_mode():
|
||||
action = policy.select_action(observation, step=step)
|
||||
|
||||
# apply inverse transform to unnormalize the action
|
||||
action = postprocess_action(action, transform)
|
||||
# convert to cpu numpy
|
||||
action = postprocess_action(action)
|
||||
|
||||
# apply the next action
|
||||
observation, reward, terminated, truncated, info = env.step(action)
|
||||
|
@ -360,7 +353,7 @@ def eval_policy(
|
|||
return info
|
||||
|
||||
|
||||
def eval(cfg: dict, out_dir=None, stats_path=None):
|
||||
def eval(cfg: dict, out_dir=None):
|
||||
if out_dir is None:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
@ -375,10 +368,6 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
|
|||
|
||||
log_output_dir(out_dir)
|
||||
|
||||
logging.info("Making transforms.")
|
||||
# TODO(alexander-soare): Completely decouple datasets from evaluation.
|
||||
transform = make_dataset(cfg, stats_path=stats_path).transform
|
||||
|
||||
logging.info("Making environment.")
|
||||
env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)
|
||||
|
||||
|
@ -390,7 +379,6 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
|
|||
policy,
|
||||
max_episodes_rendered=10,
|
||||
video_dir=Path(out_dir) / "eval",
|
||||
transform=transform,
|
||||
return_episode_data=False,
|
||||
seed=cfg.seed,
|
||||
)
|
||||
|
@ -423,17 +411,13 @@ if __name__ == "__main__":
|
|||
if args.config is not None:
|
||||
# Note: For the config_path, Hydra wants a path relative to this script file.
|
||||
cfg = init_hydra_config(args.config, args.overrides)
|
||||
# TODO(alexander-soare): Save and load stats in trained model directory.
|
||||
stats_path = None
|
||||
elif args.hub_id is not None:
|
||||
folder = Path(snapshot_download(args.hub_id, revision=args.revision))
|
||||
cfg = init_hydra_config(
|
||||
folder / "config.yaml", [f"policy.pretrained_model_path={folder / 'model.pt'}", *args.overrides]
|
||||
)
|
||||
stats_path = folder / "stats.pth"
|
||||
|
||||
eval(
|
||||
cfg,
|
||||
out_dir=f"outputs/eval/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{cfg.env.name}_{cfg.policy.name}",
|
||||
stats_path=stats_path,
|
||||
)
|
||||
|
|
|
@ -232,7 +232,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
|||
env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)
|
||||
|
||||
logging.info("make_policy")
|
||||
policy = make_policy(cfg)
|
||||
policy = make_policy(cfg, dataset_stats=offline_dataset.stats)
|
||||
|
||||
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
|
||||
num_total_params = sum(p.numel() for p in policy.parameters())
|
||||
|
@ -339,7 +339,6 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
|||
eval_info = eval_policy(
|
||||
rollout_env,
|
||||
policy,
|
||||
transform=offline_dataset.transform,
|
||||
return_episode_data=True,
|
||||
seed=cfg.seed,
|
||||
)
|
||||
|
|
|
@ -50,11 +50,7 @@ def visualize_dataset(cfg: dict, out_dir=None):
|
|||
log_output_dir(out_dir)
|
||||
|
||||
logging.info("make_dataset")
|
||||
dataset = make_dataset(
|
||||
cfg,
|
||||
# remove all transformations such as rescale images from [0,255] to [0,1] or normalization
|
||||
normalize=False,
|
||||
)
|
||||
dataset = make_dataset(cfg)
|
||||
|
||||
logging.info("Start rendering episodes from offline buffer")
|
||||
video_paths = render_dataset(dataset, out_dir, MAX_NUM_STEPS * NUM_EPISODES_TO_RENDER)
|
||||
|
|
|
@ -6,7 +6,6 @@ import torch
|
|||
from gymnasium.utils.env_checker import check_env
|
||||
|
||||
import lerobot
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.envs.factory import make_env
|
||||
from lerobot.common.envs.utils import preprocess_observation
|
||||
from lerobot.common.utils.utils import init_hydra_config
|
||||
|
@ -38,12 +37,14 @@ def test_factory(env_name):
|
|||
overrides=[f"env={env_name}", f"device={DEVICE}"],
|
||||
)
|
||||
|
||||
dataset = make_dataset(cfg)
|
||||
|
||||
env = make_env(cfg, num_parallel_envs=1)
|
||||
obs, _ = env.reset()
|
||||
obs = preprocess_observation(obs, transform=dataset.transform)
|
||||
for key in dataset.image_keys:
|
||||
obs = preprocess_observation(obs)
|
||||
|
||||
# test image keys are float32 in range [0,1]
|
||||
for key in obs:
|
||||
if "image" not in key:
|
||||
continue
|
||||
img = obs[key]
|
||||
assert img.dtype == torch.float32
|
||||
# TODO(rcadene): we assume for now that image normalization takes place in the model
|
||||
|
|
|
@ -51,7 +51,7 @@ def test_examples_4_and_3():
|
|||
# Pass empty globals to allow dictionary comprehension https://stackoverflow.com/a/32897127/4391249.
|
||||
exec(file_contents, {})
|
||||
|
||||
for file_name in ["model.pt", "stats.pth", "config.yaml"]:
|
||||
for file_name in ["model.pt", "config.yaml"]:
|
||||
assert Path(f"outputs/train/example_pusht_diffusion/{file_name}").exists()
|
||||
|
||||
path = "examples/3_evaluate_pretrained_policy.py"
|
||||
|
|
|
@ -6,10 +6,10 @@ from lerobot.common.datasets.utils import cycle
|
|||
from lerobot.common.envs.factory import make_env
|
||||
from lerobot.common.envs.utils import postprocess_action, preprocess_observation
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.policy_protocol import Policy
|
||||
from lerobot.common.utils.utils import init_hydra_config
|
||||
|
||||
from .utils import DEFAULT_CONFIG_PATH, DEVICE, require_env
|
||||
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_env
|
||||
|
||||
|
||||
# TODO(aliberts): refactor using lerobot/__init__.py variables
|
||||
|
@ -44,14 +44,16 @@ def test_policy(env_name, policy_name, extra_overrides):
|
|||
]
|
||||
+ extra_overrides,
|
||||
)
|
||||
|
||||
# Check that we can make the policy object.
|
||||
policy = make_policy(cfg)
|
||||
dataset = make_dataset(cfg)
|
||||
policy = make_policy(cfg, dataset_stats=dataset.stats)
|
||||
# Check that the policy follows the required protocol.
|
||||
assert isinstance(
|
||||
policy, Policy
|
||||
), f"The policy does not follow the required protocol. Please see {Policy.__module__}.{Policy.__name__}."
|
||||
|
||||
# Check that we run select_actions and get the appropriate output.
|
||||
dataset = make_dataset(cfg)
|
||||
env = make_env(cfg, num_parallel_envs=2)
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
|
@ -77,7 +79,7 @@ def test_policy(env_name, policy_name, extra_overrides):
|
|||
observation, _ = env.reset(seed=cfg.seed)
|
||||
|
||||
# apply transform to normalize the observations
|
||||
observation = preprocess_observation(observation, dataset.transform)
|
||||
observation = preprocess_observation(observation)
|
||||
|
||||
# send observation to device/gpu
|
||||
observation = {key: observation[key].to(DEVICE, non_blocking=True) for key in observation}
|
||||
|
@ -86,8 +88,115 @@ def test_policy(env_name, policy_name, extra_overrides):
|
|||
with torch.inference_mode():
|
||||
action = policy.select_action(observation, step=0)
|
||||
|
||||
# apply inverse transform to unnormalize the action
|
||||
action = postprocess_action(action, dataset.transform)
|
||||
# convert action to cpu numpy array
|
||||
action = postprocess_action(action)
|
||||
|
||||
# Test step through policy
|
||||
env.step(action)
|
||||
|
||||
# Test load state_dict
|
||||
if policy_name != "tdmpc":
|
||||
# TODO(rcadene, alexander-soare): make it work for tdmpc
|
||||
new_policy = make_policy(cfg)
|
||||
new_policy.load_state_dict(policy.state_dict())
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"insert_temporal_dim",
|
||||
[
|
||||
False,
|
||||
True,
|
||||
],
|
||||
)
|
||||
def test_normalize(insert_temporal_dim):
|
||||
"""
|
||||
Test that normalize/unnormalize can run without exceptions when properly set up, and that they raise
|
||||
an exception when the forward pass is called without the stats having been provided.
|
||||
|
||||
TODO(rcadene, alexander-soare): This should also test that the normalization / unnormalization works as
|
||||
expected.
|
||||
"""
|
||||
|
||||
input_shapes = {
|
||||
"observation.image": [3, 96, 96],
|
||||
"observation.state": [10],
|
||||
}
|
||||
output_shapes = {
|
||||
"action": [5],
|
||||
}
|
||||
|
||||
normalize_input_modes = {
|
||||
"observation.image": "mean_std",
|
||||
"observation.state": "min_max",
|
||||
}
|
||||
unnormalize_output_modes = {
|
||||
"action": "min_max",
|
||||
}
|
||||
|
||||
dataset_stats = {
|
||||
"observation.image": {
|
||||
"mean": torch.randn(3, 1, 1),
|
||||
"std": torch.randn(3, 1, 1),
|
||||
"min": torch.randn(3, 1, 1),
|
||||
"max": torch.randn(3, 1, 1),
|
||||
},
|
||||
"observation.state": {
|
||||
"mean": torch.randn(10),
|
||||
"std": torch.randn(10),
|
||||
"min": torch.randn(10),
|
||||
"max": torch.randn(10),
|
||||
},
|
||||
"action": {
|
||||
"mean": torch.randn(5),
|
||||
"std": torch.randn(5),
|
||||
"min": torch.randn(5),
|
||||
"max": torch.randn(5),
|
||||
},
|
||||
}
|
||||
|
||||
bsize = 2
|
||||
input_batch = {
|
||||
"observation.image": torch.randn(bsize, 3, 96, 96),
|
||||
"observation.state": torch.randn(bsize, 10),
|
||||
}
|
||||
output_batch = {
|
||||
"action": torch.randn(bsize, 5),
|
||||
}
|
||||
|
||||
if insert_temporal_dim:
|
||||
tdim = 4
|
||||
|
||||
for key in input_batch:
|
||||
# [2,3,96,96] -> [2,tdim,3,96,96]
|
||||
input_batch[key] = torch.stack([input_batch[key]] * tdim, dim=1)
|
||||
|
||||
for key in output_batch:
|
||||
output_batch[key] = torch.stack([output_batch[key]] * tdim, dim=1)
|
||||
|
||||
# test without stats
|
||||
normalize = Normalize(input_shapes, normalize_input_modes, stats=None)
|
||||
with pytest.raises(AssertionError):
|
||||
normalize(input_batch)
|
||||
|
||||
# test with stats
|
||||
normalize = Normalize(input_shapes, normalize_input_modes, stats=dataset_stats)
|
||||
normalize(input_batch)
|
||||
|
||||
# test loading pretrained models
|
||||
new_normalize = Normalize(input_shapes, normalize_input_modes, stats=None)
|
||||
new_normalize.load_state_dict(normalize.state_dict())
|
||||
new_normalize(input_batch)
|
||||
|
||||
# test without stats
|
||||
unnormalize = Unnormalize(output_shapes, unnormalize_output_modes, stats=None)
|
||||
with pytest.raises(AssertionError):
|
||||
unnormalize(output_batch)
|
||||
|
||||
# test with stats
|
||||
unnormalize = Unnormalize(output_shapes, unnormalize_output_modes, stats=dataset_stats)
|
||||
unnormalize(output_batch)
|
||||
|
||||
# test loading pretrained models
|
||||
new_unnormalize = Unnormalize(output_shapes, unnormalize_output_modes, stats=None)
|
||||
new_unnormalize.load_state_dict(unnormalize.state_dict())
|
||||
unnormalize(output_batch)
|
||||
|
|