Move normalize/unnormalize transforms to policy for act and diffusion

Cadene 2024-04-20 21:08:14 +00:00
parent c1bcf857c5
commit 42ed7bb670
19 changed files with 145 additions and 195 deletions

View File

@@ -263,15 +263,13 @@ Secondly, assuming you have trained a policy, you need:
 - `config.yaml` which you can get from the `.hydra` directory of your training output folder.
 - `model.pt` which should be one of the saved models in the `models` directory of your training output folder (they won't be named `model.pt` but you will need to choose one).
-- `stats.pth` which should point to the same file in the dataset directory (found in `data/{dataset_name}`).
 To upload these to the hub, prepare a folder with the following structure (you can use symlinks rather than copying):
 ```
 to_upload
 ├── config.yaml
-├── model.pt
-└── stats.pth
+└── model.pt
 ```
 With the folder prepared, run the following with a desired revision ID.
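The upload command itself is outside this hunk; as a rough orientation only, the prepared folder can be pushed with `huggingface_hub` along these lines (repo ID and revision are placeholders, and the revision branch is assumed to already exist):

```python
# Sketch only: upload the prepared folder (config.yaml + model.pt) to the Hub.
from huggingface_hub import HfApi

api = HfApi()
api.upload_folder(
    folder_path="to_upload",              # the folder prepared above
    repo_id="your-username/your-policy",  # hypothetical repo ID
    revision="v1.0",                      # the desired revision ID (must exist on the repo)
    repo_type="model",
)
```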

View File

@@ -19,7 +19,6 @@ folder = Path(snapshot_download(hub_id))
 config_path = folder / "config.yaml"
 weights_path = folder / "model.pt"
-stats_path = folder / "stats.pth"  # normalization stats
 # Override some config parameters to do with evaluation.
 overrides = [
@@ -36,5 +35,4 @@ cfg = init_hydra_config(config_path, overrides)
 eval(
 cfg,
 out_dir=f"outputs/eval/example_{cfg.env.name}_{cfg.policy.name}",
-stats_path=stats_path,
 )

View File

@@ -62,7 +62,6 @@ while not done:
 done = True
 break
-# Save the policy, configuration, and normalization stats for later use.
+# Save the policy and configuration for later use.
 policy.save(output_directory / "model.pt")
 OmegaConf.save(hydra_cfg, output_directory / "config.yaml")
-torch.save(dataset.transform.transforms[-1].stats, output_directory / "stats.pth")

View File

@@ -2,18 +2,12 @@ import os
 from pathlib import Path
 import torch
-from torchvision.transforms import v2
-from lerobot.common.transforms import NormalizeTransform
 DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
 def make_dataset(
 cfg,
-# set normalize=False to remove all transformations and keep images unnormalized in [0,255]
-normalize=True,
-stats_path=None,
 split="train",
 ):
 if cfg.env.name == "xarm":
@@ -33,58 +27,23 @@ def make_dataset(
 else:
 raise ValueError(cfg.env.name)
-transforms = None
-if normalize:
-# TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max,
-# min_max_from_spec
-# TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std
-normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
-if cfg.policy.name == "diffusion" and cfg.env.name == "pusht":
-stats = {}
-# TODO(rcadene): we overwrite stats to have the same as pretrained model, but we should remove this
-stats["observation.state"] = {}
-stats["observation.state"]["min"] = torch.tensor([13.456424, 32.938293], dtype=torch.float32)
-stats["observation.state"]["max"] = torch.tensor([496.14618, 510.9579], dtype=torch.float32)
-stats["action"] = {}
-stats["action"]["min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
-stats["action"]["max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
-elif stats_path is None:
-# load a first dataset to access precomputed stats
-stats_dataset = clsfunc(
-dataset_id=cfg.dataset_id,
-split="train",
-root=DATA_DIR,
-)
-stats = stats_dataset.stats
-else:
-stats = torch.load(stats_path)
-transforms = v2.Compose(
-[
-NormalizeTransform(
-stats,
-in_keys=[
-"observation.state",
-"action",
-],
-mode=normalization_mode,
-),
-]
-)
 delta_timestamps = cfg.policy.get("delta_timestamps")
 if delta_timestamps is not None:
 for key in delta_timestamps:
 if isinstance(delta_timestamps[key], str):
 delta_timestamps[key] = eval(delta_timestamps[key])
+# TODO(rcadene): add data augmentations
 dataset = clsfunc(
 dataset_id=cfg.dataset_id,
 split=split,
 root=DATA_DIR,
 delta_timestamps=delta_timestamps,
-transform=transforms,
 )
+if cfg.get("override_dataset_stats"):
+for key, val in cfg.override_dataset_stats.items():
+dataset.stats[key] = torch.tensor(val)
 return dataset
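In the new flow, the dataset's precomputed `stats` (per-key tensors such as `mean`/`std` or `min`/`max`) can be patched by `override_dataset_stats` from the config, and the merged dict is what the training script hands to `make_policy`. A rough sketch of that wiring, with illustrative values rather than real dataset statistics:

```python
# Sketch of the stats flow after this commit (illustrative values, not from the repo).
import torch

# 1) Stats precomputed by the dataset, keyed by feature name.
dataset_stats = {
    "observation.state": {"mean": torch.zeros(2), "std": torch.ones(2)},
    "action": {"min": torch.tensor([12.0, 25.0]), "max": torch.tensor([511.0, 511.0])},
}

# 2) `override_dataset_stats` entries from the yaml replace individual keys
#    (each statistic converted to a tensor here for clarity).
override = {
    "observation.image": {"mean": [[[0.5]], [[0.5]], [[0.5]]], "std": [[[0.5]], [[0.5]], [[0.5]]]},
}
for key, val in override.items():
    dataset_stats[key] = {name: torch.tensor(v) for name, v in val.items()}

# 3) The policy receives the merged stats and uses them in
#    normalize_inputs / unnormalize_outputs:
# policy = make_policy(cfg, dataset_stats=dataset_stats)
```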

View File

@@ -1,10 +1,8 @@
 import einops
 import torch
-from lerobot.common.transforms import apply_inverse_transform
-def preprocess_observation(observation, transform=None):
+def preprocess_observation(observation):
 # map to expected inputs for the policy
 obs = {}
@@ -33,19 +31,11 @@ def preprocess_observation(observation, transform=None):
 # TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing requirement for "agent_pos"
 obs["observation.state"] = torch.from_numpy(observation["agent_pos"]).float()
-# apply same transforms as in training
-if transform is not None:
-for key in obs:
-obs[key] = torch.stack([transform({key: item})[key] for item in obs[key]])
 return obs
-def postprocess_action(action, transform=None):
-action = action.to("cpu")
-# action is a batch (num_env,action_dim) instead of an item (action_dim),
-# we assume applying inverse transform on a batch works the same
-action = apply_inverse_transform({"action": action}, transform)["action"].numpy()
+def postprocess_action(action):
+action = action.to("cpu").numpy()
 assert (
 action.ndim == 2
 ), "we assume dimensions are respectively the number of parallel envs, action dimensions"

View File

@@ -1,4 +1,4 @@
-from dataclasses import dataclass, field
+from dataclasses import dataclass
 @dataclass
@@ -60,12 +60,14 @@ class ActionChunkingTransformerConfig:
 chunk_size: int = 100
 n_action_steps: int = 100
-# Vision preprocessing.
-image_normalization_mean: tuple[float, float, float] = field(
-default_factory=lambda: [0.485, 0.456, 0.406]
-)
-image_normalization_std: tuple[float, float, float] = field(default_factory=lambda: [0.229, 0.224, 0.225])
+# Normalization / Unnormalization
+normalize_input_modes: dict[str, str] = {
+"observation.image": "mean_std",
+"observation.state": "mean_std",
+}
+unnormalize_output_modes: dict[str, str] = {
+"action": "mean_std",
+}
 # Architecture.
 # Vision backbone.
 vision_backbone: str = "resnet18"
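For anyone adapting these configs: Python dataclasses reject mutable (dict) class-attribute defaults at class-definition time, so dict-valued fields like `normalize_input_modes` are normally declared with `field(default_factory=...)`. A minimal sketch of that pattern (not a claim about the final repo code):

```python
from dataclasses import dataclass, field


@dataclass
class ExampleConfig:
    # A plain `dict` default would raise "mutable default ... use default_factory";
    # default_factory builds a fresh dict for each instance instead.
    normalize_input_modes: dict[str, str] = field(
        default_factory=lambda: {
            "observation.image": "mean_std",
            "observation.state": "mean_std",
        }
    )
    unnormalize_output_modes: dict[str, str] = field(
        default_factory=lambda: {"action": "mean_std"}
    )


cfg = ExampleConfig()
assert cfg.unnormalize_output_modes["action"] == "mean_std"
```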

View File

@@ -15,12 +15,15 @@ import numpy as np
 import torch
 import torch.nn.functional as F  # noqa: N812
 import torchvision
-import torchvision.transforms as transforms
 from torch import Tensor, nn
 from torchvision.models._utils import IntermediateLayerGetter
 from torchvision.ops.misc import FrozenBatchNorm2d
 from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
+from lerobot.common.policies.utils import (
+normalize_inputs,
+unnormalize_outputs,
+)
 class ActionChunkingTransformerPolicy(nn.Module):
@@ -62,7 +65,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
 name = "act"
-def __init__(self, cfg: ActionChunkingTransformerConfig | None = None):
+def __init__(self, cfg: ActionChunkingTransformerConfig | None = None, dataset_stats=None):
 """
 Args:
 cfg: Policy configuration class instance or None, in which case the default instantiation of the
@@ -72,6 +75,9 @@ class ActionChunkingTransformerPolicy(nn.Module):
 if cfg is None:
 cfg = ActionChunkingTransformerConfig()
 self.cfg = cfg
+self.register_buffer("dataset_stats", dataset_stats)
+self.normalize_input_modes = cfg.normalize_input_modes
+self.unnormalize_output_modes = cfg.unnormalize_output_modes
 # BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence].
 # The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
@@ -93,9 +99,6 @@ class ActionChunkingTransformerPolicy(nn.Module):
 )
 # Backbone for image feature extraction.
-self.image_normalizer = transforms.Normalize(
-mean=cfg.image_normalization_mean, std=cfg.image_normalization_std
-)
 backbone_model = getattr(torchvision.models, cfg.vision_backbone)(
 replace_stride_with_dilation=[False, False, cfg.replace_final_stride_with_dilation],
 pretrained=cfg.use_pretrained_backbone,
@@ -169,10 +172,15 @@ class ActionChunkingTransformerPolicy(nn.Module):
 queue is empty.
 """
 self.eval()
+batch = normalize_inputs(batch, self.dataset_stats, self.normalize_input_modes)
 if len(self._action_queue) == 0:
 # `_forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue effectively
 # has shape (n_action_steps, batch_size, *), hence the transpose.
-self._action_queue.extend(self._forward(batch)[0][: self.cfg.n_action_steps].transpose(0, 1))
+actions = self._forward(batch)[0][: self.cfg.n_action_steps]
+actions = unnormalize_outputs(actions, self.dataset_stats, self.unnormalize_output_modes)
+self._action_queue.extend(actions.transpose(0, 1))
 return self._action_queue.popleft()
 def forward(self, batch, **_) -> dict[str, Tensor]:
@@ -203,7 +211,10 @@ class ActionChunkingTransformerPolicy(nn.Module):
 """Run the model in train mode, compute the loss, and do an optimization step."""
 start_time = time.time()
 self.train()
+batch = normalize_inputs(batch, self.dataset_stats, self.normalize_input_modes)
 loss_dict = self.forward(batch)
+# TODO(rcadene): unnormalize_outputs(actions, self.dataset_stats, self.unnormalize_output_modes)
 loss = loss_dict["loss"]
 loss.backward()
@@ -309,7 +320,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
 # Camera observation features and positional embeddings.
 all_cam_features = []
 all_cam_pos_embeds = []
-images = self.image_normalizer(batch["observation.images"])
+images = batch["observation.images"]
 for cam_index in range(len(self.cfg.camera_names)):
 cam_features = self.backbone(images[:, cam_index])["feature_map"]
 cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
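The stats are attached to the policy with `register_buffer` so they follow `.to(device)` and travel with the checkpoint. Note that `nn.Module.register_buffer` accepts a single `Tensor` (or `None`) per name, so a nested stats dict is typically flattened into one buffer per statistic; the sketch below illustrates that pattern (class and names are illustrative, not this file's code):

```python
import torch
from torch import nn


class StatsHolder(nn.Module):
    """Illustrative only: store per-key normalization stats as module buffers."""

    def __init__(self, dataset_stats: dict[str, dict[str, torch.Tensor]]):
        super().__init__()
        for key, sub in dataset_stats.items():
            for stat_name, value in sub.items():
                # Buffer names may not contain ".", so replace it.
                name = f"{key}.{stat_name}".replace(".", "_")
                self.register_buffer(name, value.clone())


stats = {"observation.state": {"mean": torch.zeros(2), "std": torch.ones(2)}}
holder = StatsHolder(stats)
# Buffers move with .to(device) and are included in state_dict().
print(list(dict(holder.named_buffers()).keys()))
```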

View File

@@ -69,9 +69,14 @@ class DiffusionConfig:
 horizon: int = 16
 n_action_steps: int = 8
-# Vision preprocessing.
-image_normalization_mean: tuple[float, float, float] = (0.5, 0.5, 0.5)
-image_normalization_std: tuple[float, float, float] = (0.5, 0.5, 0.5)
+# Normalization / Unnormalization
+normalize_input_modes: dict[str, str] = {
+"observation.image": "mean_std",
+"observation.state": "min_max",
+}
+unnormalize_output_modes: dict[str, str] = {
+"action": "min_max",
+}
 # Architecture / modeling.
 # Vision backbone.

View File

@@ -13,7 +13,6 @@ import logging
 import math
 import time
 from collections import deque
-from itertools import chain
 from typing import Callable
 import einops
@@ -30,7 +29,9 @@ from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionC
 from lerobot.common.policies.utils import (
 get_device_from_parameters,
 get_dtype_from_parameters,
+normalize_inputs,
 populate_queues,
+unnormalize_outputs,
 )
@@ -42,7 +43,9 @@ class DiffusionPolicy(nn.Module):
 name = "diffusion"
-def __init__(self, cfg: DiffusionConfig | None, lr_scheduler_num_training_steps: int = 0):
+def __init__(
+self, cfg: DiffusionConfig | None, lr_scheduler_num_training_steps: int = 0, dataset_stats=None
+):
 """
 Args:
 cfg: Policy configuration class instance or None, in which case the default instantiation of the
@@ -54,6 +57,9 @@
 if cfg is None:
 cfg = DiffusionConfig()
 self.cfg = cfg
+self.register_buffer("dataset_stats", dataset_stats)
+self.normalize_input_modes = cfg.normalize_input_modes
+self.unnormalize_output_modes = cfg.unnormalize_output_modes
 # queues are populated during rollout of the policy, they contain the n latest observations and actions
 self._queues = None
@@ -126,6 +132,8 @@
 assert "observation.state" in batch
 assert len(batch) == 2
+batch = normalize_inputs(batch, self.dataset_stats, self.normalize_input_modes)
 self._queues = populate_queues(self._queues, batch)
 if len(self._queues["action"]) == 0:
@@ -135,6 +143,8 @@
 actions = self.ema_diffusion.generate_actions(batch)
 else:
 actions = self.diffusion.generate_actions(batch)
+actions = unnormalize_outputs(actions, self.dataset_stats, self.unnormalize_output_modes)
 self._queues["action"].extend(actions.transpose(0, 1))
 action = self._queues["action"].popleft()
@@ -151,9 +161,13 @@
 self.diffusion.train()
+batch = normalize_inputs(batch, self.dataset_stats, self.normalize_input_modes)
 loss = self.forward(batch)["loss"]
 loss.backward()
+# TODO(rcadene): unnormalize_outputs(actions, self.dataset_stats, self.unnormalize_output_modes)
 grad_norm = torch.nn.utils.clip_grad_norm_(
 self.diffusion.parameters(),
 self.cfg.grad_clip_norm,
@@ -346,12 +360,6 @@
 def __init__(self, cfg: DiffusionConfig):
 super().__init__()
 # Set up optional preprocessing.
-if all(v == 1.0 for v in chain(cfg.image_normalization_mean, cfg.image_normalization_std)):
-self.normalizer = nn.Identity()
-else:
-self.normalizer = torchvision.transforms.Normalize(
-mean=cfg.image_normalization_mean, std=cfg.image_normalization_std
-)
 if cfg.crop_shape is not None:
 self.do_crop = True
 # Always use center crop for eval
@@ -397,8 +405,7 @@
 Returns:
 (B, D) image feature.
 """
-# Preprocess: normalize and maybe crop (if it was set up in the __init__).
-x = self.normalizer(x)
+# Preprocess: maybe crop (if it was set up in the __init__).
 if self.do_crop:
 if self.training:  # noqa: SIM108
 x = self.maybe_random_crop(x)

View File

@@ -20,7 +20,7 @@ def _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg):
 return policy_cfg
-def make_policy(hydra_cfg: DictConfig):
+def make_policy(hydra_cfg: DictConfig, dataset_stats=None):
 if hydra_cfg.policy.name == "tdmpc":
 from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
@@ -35,14 +35,14 @@ def make_policy(hydra_cfg: DictConfig):
 from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
 policy_cfg = _policy_cfg_from_hydra_cfg(DiffusionConfig, hydra_cfg)
-policy = DiffusionPolicy(policy_cfg, hydra_cfg.offline_steps)
+policy = DiffusionPolicy(policy_cfg, hydra_cfg.offline_steps, dataset_stats)
 policy.to(get_safe_torch_device(hydra_cfg.device))
 elif hydra_cfg.policy.name == "act":
 from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
 from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy
 policy_cfg = _policy_cfg_from_hydra_cfg(ActionChunkingTransformerConfig, hydra_cfg)
-policy = ActionChunkingTransformerPolicy(policy_cfg)
+policy = ActionChunkingTransformerPolicy(policy_cfg, dataset_stats)
 policy.to(get_safe_torch_device(hydra_cfg.device))
 else:
 raise ValueError(hydra_cfg.policy.name)

View File

@@ -28,3 +28,41 @@ def get_dtype_from_parameters(module: nn.Module) -> torch.dtype:
 Note: assumes that all parameters have the same dtype.
 """
 return next(iter(module.parameters())).dtype
+def normalize_inputs(batch, stats, normalize_input_modes):
+if normalize_input_modes is None:
+return batch
+for key, mode in normalize_input_modes.items():
+if mode == "mean_std":
+mean = stats[key]["mean"].unsqueeze(0)
+std = stats[key]["std"].unsqueeze(0)
+batch[key] = (batch[key] - mean) / (std + 1e-8)
+elif mode == "min_max":
+min = stats[key]["min"].unsqueeze(0)
+max = stats[key]["max"].unsqueeze(0)
+# normalize to [0,1]
+batch[key] = (batch[key] - min) / (max - min)
+# normalize to [-1, 1]
+batch[key] = batch[key] * 2 - 1
+else:
+raise ValueError(mode)
+return batch
+def unnormalize_outputs(batch, stats, unnormalize_output_modes):
+if unnormalize_output_modes is None:
+return batch
+for key, mode in unnormalize_output_modes.items():
+if mode == "mean_std":
+mean = stats[key]["mean"].unsqueeze(0)
+std = stats[key]["std"].unsqueeze(0)
+batch[key] = batch[key] * std + mean
+elif mode == "min_max":
+min = stats[key]["min"].unsqueeze(0)
+max = stats[key]["max"].unsqueeze(0)
+batch[key] = (batch[key] + 1) / 2
+batch[key] = batch[key] * (max - min) + min
+else:
+raise ValueError(mode)
+return batch
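These two helpers operate per key on a batch dict: `mean_std` applies `(x - mean) / (std + 1e-8)`, while `min_max` maps values into [-1, 1] and `unnormalize_outputs` reverses both steps. A small round-trip check, with made-up stats, could look like this:

```python
import torch

# Hypothetical stats for a 2-dim state and action (not taken from a real dataset).
stats = {
    "observation.state": {"mean": torch.tensor([0.5, -1.0]), "std": torch.tensor([2.0, 0.1])},
    "action": {"min": torch.tensor([12.0, 25.0]), "max": torch.tensor([511.0, 511.0])},
}

batch = {"observation.state": torch.tensor([[1.0, -1.2]]), "action": torch.tensor([[100.0, 200.0]])}

# mean_std: (x - mean) / (std + 1e-8)
normed_state = (batch["observation.state"] - stats["observation.state"]["mean"]) / (
    stats["observation.state"]["std"] + 1e-8
)

# min_max: scale to [0, 1] then to [-1, 1]; unnormalizing reverses both steps.
mn, mx = stats["action"]["min"], stats["action"]["max"]
normed_action = (batch["action"] - mn) / (mx - mn) * 2 - 1
recovered = (normed_action + 1) / 2 * (mx - mn) + mn
assert torch.allclose(recovered, batch["action"])
```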

View File

@@ -1,65 +0,0 @@
-from torchvision.transforms.v2 import Compose, Transform
-def apply_inverse_transform(item, transform):
-transforms = transform.transforms if isinstance(transform, Compose) else [transform]
-for tf in transforms[::-1]:
-if tf.invertible:
-item = tf.inverse_transform(item)
-else:
-raise ValueError(f"Inverse transform called on a non invertible transform ({tf}).")
-return item
-class NormalizeTransform(Transform):
-invertible = True
-def __init__(
-self,
-stats: dict,
-in_keys: list[str] = None,
-out_keys: list[str] | None = None,
-in_keys_inv: list[str] | None = None,
-out_keys_inv: list[str] | None = None,
-mode="mean_std",
-):
-super().__init__()
-self.in_keys = in_keys
-self.out_keys = in_keys if out_keys is None else out_keys
-self.in_keys_inv = self.out_keys if in_keys_inv is None else in_keys_inv
-self.out_keys_inv = self.in_keys if out_keys_inv is None else out_keys_inv
-self.stats = stats
-assert mode in ["mean_std", "min_max"]
-self.mode = mode
-def forward(self, item):
-for inkey, outkey in zip(self.in_keys, self.out_keys, strict=False):
-if inkey not in item:
-continue
-if self.mode == "mean_std":
-mean = self.stats[inkey]["mean"]
-std = self.stats[inkey]["std"]
-item[outkey] = (item[inkey] - mean) / (std + 1e-8)
-else:
-min = self.stats[inkey]["min"]
-max = self.stats[inkey]["max"]
-# normalize to [0,1]
-item[outkey] = (item[inkey] - min) / (max - min)
-# normalize to [-1, 1]
-item[outkey] = item[outkey] * 2 - 1
-return item
-def inverse_transform(self, item):
-for inkey, outkey in zip(self.in_keys_inv, self.out_keys_inv, strict=False):
-if inkey not in item:
-continue
-if self.mode == "mean_std":
-mean = self.stats[inkey]["mean"]
-std = self.stats[inkey]["std"]
-item[outkey] = item[inkey] * std + mean
-else:
-min = self.stats[inkey]["min"]
-max = self.stats[inkey]["max"]
-item[outkey] = (item[inkey] + 1) / 2
-item[outkey] = item[outkey] * (max - min) + min
-return item

View File

@@ -11,6 +11,11 @@ log_freq: 250
 n_obs_steps: 1
 # when temporal_agg=False, n_action_steps=horizon
+override_dataset_stats:
+observation.image:
+mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
+std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
 # See `configuration_act.py` for more details.
 policy:
 name: act
@@ -28,9 +33,12 @@ policy:
 chunk_size: 100 # chunk_size
 n_action_steps: 100
-# Vision preprocessing.
-image_normalization_mean: [0.485, 0.456, 0.406]
-image_normalization_std: [0.229, 0.224, 0.225]
+# Normalization / Unnormalization
+normalize_input_modes:
+observation.image: mean_std
+observation.state: mean_std
+unnormalize_output_modes:
+action: mean_std
 # Architecture.
 # Vision backbone.

View File

@@ -18,6 +18,17 @@ online_steps: 0
 offline_prioritized_sampler: true
+override_dataset_stats:
+observation.image:
+mean: [[[0.5]], [[0.5]], [[0.5]]] # (c,1,1)
+std: [[[0.5]], [[0.5]], [[0.5]]] # (c,1,1)
+observation.state:
+min: [13.456424, 32.938293]
+max: [496.14618, 510.9579]
+action:
+min: [12.0, 25.0]
+max: [511.0, 511.0]
 policy:
 name: diffusion
@@ -36,9 +47,12 @@ policy:
 horizon: ${horizon}
 n_action_steps: ${n_action_steps}
-# Vision preprocessing.
-image_normalization_mean: [0.5, 0.5, 0.5]
-image_normalization_std: [0.5, 0.5, 0.5]
+# Normalization / Unnormalization
+normalize_input_modes:
+observation.image: mean_std
+observation.state: min_max
+unnormalize_output_modes:
+action: min_max
 # Architecture / modeling.
 # Vision backbone.
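The `(c,1,1)` shape noted in the yaml lets the per-channel image stats broadcast over a `(b, c, h, w)` image batch, while flat vectors are enough for state and action bounds. A quick shape check, reusing the values from the config above (batch size and state values are invented):

```python
import torch

# Per-channel image stats shaped (c, 1, 1), as in the yaml comments above.
mean = torch.tensor([[[0.5]], [[0.5]], [[0.5]]])
std = torch.tensor([[[0.5]], [[0.5]], [[0.5]]])

images = torch.rand(4, 3, 96, 96)  # hypothetical (b, c, h, w) batch
normed = (images - mean) / (std + 1e-8)
print(normed.shape)  # torch.Size([4, 3, 96, 96]); broadcasting preserves the batch shape

# State bounds from the pusht config, used for min_max scaling into [-1, 1].
state_min = torch.tensor([13.456424, 32.938293])
state_max = torch.tensor([496.14618, 510.9579])
state = torch.tensor([[100.0, 200.0]])  # hypothetical observation.state
print((state - state_min) / (state_max - state_min) * 2 - 1)
```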

View File

@@ -46,7 +46,6 @@ from huggingface_hub import snapshot_download
 from PIL import Image as PILImage
 from tqdm import trange
-from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.datasets.utils import hf_transform_to_torch
 from lerobot.common.envs.factory import make_env
 from lerobot.common.envs.utils import postprocess_action, preprocess_observation
@@ -64,8 +63,6 @@ def eval_policy(
 policy: torch.nn.Module,
 max_episodes_rendered: int = 0,
 video_dir: Path = None,
-# TODO(rcadene): make it possible to overwrite fps? we should use env.fps
-transform: callable = None,
 return_episode_data: bool = False,
 seed=None,
 ):
@@ -132,10 +129,6 @@ def eval_policy(
 if return_episode_data:
 observations.append(deepcopy(observation))
-# apply transform to normalize the observations
-for key in observation:
-observation[key] = torch.stack([transform({key: item})[key] for item in observation[key]])
 # send observation to device/gpu
 observation = {key: observation[key].to(device, non_blocking=True) for key in observation}
@@ -143,8 +136,8 @@
 with torch.inference_mode():
 action = policy.select_action(observation, step=step)
-# apply inverse transform to unnormalize the action
-action = postprocess_action(action, transform)
+# convert to cpu numpy
+action = postprocess_action(action)
 # apply the next action
 observation, reward, terminated, truncated, info = env.step(action)
@@ -360,7 +353,7 @@ def eval_policy(
 return info
-def eval(cfg: dict, out_dir=None, stats_path=None):
+def eval(cfg: dict, out_dir=None):
 if out_dir is None:
 raise NotImplementedError()
@@ -375,10 +368,6 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
 log_output_dir(out_dir)
-logging.info("Making transforms.")
-# TODO(alexander-soare): Completely decouple datasets from evaluation.
-transform = make_dataset(cfg, stats_path=stats_path).transform
 logging.info("Making environment.")
 env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)
@@ -390,7 +379,6 @@
 policy,
 max_episodes_rendered=10,
 video_dir=Path(out_dir) / "eval",
-transform=transform,
 return_episode_data=False,
 seed=cfg.seed,
 )
@@ -423,17 +411,13 @@ if __name__ == "__main__":
 if args.config is not None:
 # Note: For the config_path, Hydra wants a path relative to this script file.
 cfg = init_hydra_config(args.config, args.overrides)
-# TODO(alexander-soare): Save and load stats in trained model directory.
-stats_path = None
 elif args.hub_id is not None:
 folder = Path(snapshot_download(args.hub_id, revision=args.revision))
 cfg = init_hydra_config(
 folder / "config.yaml", [f"policy.pretrained_model_path={folder / 'model.pt'}", *args.overrides]
 )
-stats_path = folder / "stats.pth"
 eval(
 cfg,
 out_dir=f"outputs/eval/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{cfg.env.name}_{cfg.policy.name}",
-stats_path=stats_path,
 )

View File

@@ -232,7 +232,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
 env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)
 logging.info("make_policy")
-policy = make_policy(cfg)
+policy = make_policy(cfg, dataset_stats=offline_dataset.stats)
 num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
 num_total_params = sum(p.numel() for p in policy.parameters())

View File

@@ -42,7 +42,7 @@ def test_factory(env_name):
 env = make_env(cfg, num_parallel_envs=1)
 obs, _ = env.reset()
-obs = preprocess_observation(obs, transform=dataset.transform)
+obs = preprocess_observation(obs)
 for key in dataset.image_keys:
 img = obs[key]
 assert img.dtype == torch.float32

View File

@@ -51,7 +51,7 @@ def test_examples_4_and_3():
 # Pass empty globals to allow dictionary comprehension https://stackoverflow.com/a/32897127/4391249.
 exec(file_contents, {})
-for file_name in ["model.pt", "stats.pth", "config.yaml"]:
+for file_name in ["model.pt", "config.yaml"]:
 assert Path(f"outputs/train/example_pusht_diffusion/{file_name}").exists()
 path = "examples/3_evaluate_pretrained_policy.py"

View File

@@ -44,14 +44,16 @@ def test_policy(env_name, policy_name, extra_overrides):
 ]
 + extra_overrides,
 )
 # Check that we can make the policy object.
-policy = make_policy(cfg)
+dataset = make_dataset(cfg)
+policy = make_policy(cfg, dataset_stats=dataset.stats)
 # Check that the policy follows the required protocol.
 assert isinstance(
 policy, Policy
 ), f"The policy does not follow the required protocol. Please see {Policy.__module__}.{Policy.__name__}."
 # Check that we run select_actions and get the appropriate output.
-dataset = make_dataset(cfg)
 env = make_env(cfg, num_parallel_envs=2)
 dataloader = torch.utils.data.DataLoader(
@@ -77,7 +79,7 @@
 observation, _ = env.reset(seed=cfg.seed)
 # apply transform to normalize the observations
-observation = preprocess_observation(observation, dataset.transform)
+observation = preprocess_observation(observation)
 # send observation to device/gpu
 observation = {key: observation[key].to(DEVICE, non_blocking=True) for key in observation}
@@ -86,8 +88,8 @@
 with torch.inference_mode():
 action = policy.select_action(observation, step=0)
-# apply inverse transform to unnormalize the action
-action = postprocess_action(action, dataset.transform)
+# convert action to cpu numpy array
+action = postprocess_action(action)
 # Test step through policy
 env.step(action)