diff --git a/README.md b/README.md index a0045bf2..8b78ca3e 100644 --- a/README.md +++ b/README.md @@ -263,15 +263,13 @@ Secondly, assuming you have trained a policy, you need: - `config.yaml` which you can get from the `.hydra` directory of your training output folder. - `model.pt` which should be one of the saved models in the `models` directory of your training output folder (they won't be named `model.pt` but you will need to choose one). -- `stats.pth` which should point to the same file in the dataset directory (found in `data/{dataset_name}`). To upload these to the hub, prepare a folder with the following structure (you can use symlinks rather than copying): ``` to_upload ├── config.yaml - ├── model.pt - └── stats.pth + └── model.pt ``` With the folder prepared, run the following with a desired revision ID. diff --git a/examples/3_evaluate_pretrained_policy.py b/examples/3_evaluate_pretrained_policy.py index a892fa23..392ad1c6 100644 --- a/examples/3_evaluate_pretrained_policy.py +++ b/examples/3_evaluate_pretrained_policy.py @@ -19,7 +19,6 @@ folder = Path(snapshot_download(hub_id)) config_path = folder / "config.yaml" weights_path = folder / "model.pt" -stats_path = folder / "stats.pth" # normalization stats # Override some config parameters to do with evaluation. overrides = [ @@ -36,5 +35,4 @@ cfg = init_hydra_config(config_path, overrides) eval( cfg, out_dir=f"outputs/eval/example_{cfg.env.name}_{cfg.policy.name}", - stats_path=stats_path, ) diff --git a/examples/4_train_policy.py b/examples/4_train_policy.py index 1ccb40d6..6068f3b8 100644 --- a/examples/4_train_policy.py +++ b/examples/4_train_policy.py @@ -62,7 +62,6 @@ while not done: done = True break -# Save the policy, configuration, and normalization stats for later use. +# Save the policy and configuration for later use. policy.save(output_directory / "model.pt") OmegaConf.save(hydra_cfg, output_directory / "config.yaml") -torch.save(dataset.transform.transforms[-1].stats, output_directory / "stats.pth") diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py index 0fbfff65..59c65f1f 100644 --- a/lerobot/common/datasets/factory.py +++ b/lerobot/common/datasets/factory.py @@ -2,18 +2,12 @@ import os from pathlib import Path import torch -from torchvision.transforms import v2 - -from lerobot.common.transforms import NormalizeTransform DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None def make_dataset( cfg, - # set normalize=False to remove all transformations and keep images unnormalized in [0,255] - normalize=True, - stats_path=None, split="train", ): if cfg.env.name == "xarm": @@ -33,58 +27,23 @@ def make_dataset( else: raise ValueError(cfg.env.name) - transforms = None - if normalize: - # TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max, - # min_max_from_spec - # TODO(rcadene): remove this and put it in config. 
Ideally we want to reproduce SOTA results just with mean_std
-        normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
-
-        if cfg.policy.name == "diffusion" and cfg.env.name == "pusht":
-            stats = {}
-            # TODO(rcadene): we overwrite stats to have the same as pretrained model, but we should remove this
-            stats["observation.state"] = {}
-            stats["observation.state"]["min"] = torch.tensor([13.456424, 32.938293], dtype=torch.float32)
-            stats["observation.state"]["max"] = torch.tensor([496.14618, 510.9579], dtype=torch.float32)
-            stats["action"] = {}
-            stats["action"]["min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
-            stats["action"]["max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
-        elif stats_path is None:
-            # load a first dataset to access precomputed stats
-            stats_dataset = clsfunc(
-                dataset_id=cfg.dataset_id,
-                split="train",
-                root=DATA_DIR,
-            )
-            stats = stats_dataset.stats
-        else:
-            stats = torch.load(stats_path)
-
-        transforms = v2.Compose(
-            [
-                NormalizeTransform(
-                    stats,
-                    in_keys=[
-                        "observation.state",
-                        "action",
-                    ],
-                    mode=normalization_mode,
-                ),
-            ]
-        )
-
     delta_timestamps = cfg.policy.get("delta_timestamps")
     if delta_timestamps is not None:
         for key in delta_timestamps:
             if isinstance(delta_timestamps[key], str):
                 delta_timestamps[key] = eval(delta_timestamps[key])
 
+    # TODO(rcadene): add data augmentations
+
     dataset = clsfunc(
         dataset_id=cfg.dataset_id,
         split=split,
         root=DATA_DIR,
         delta_timestamps=delta_timestamps,
-        transform=transforms,
     )
 
+    if cfg.get("override_dataset_stats"):
+        for key, stats_dict in cfg.override_dataset_stats.items():
+            for stats_type, stats in stats_dict.items():
+                # example of stats_type: min, max, mean, std
+                dataset.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)
+
     return dataset
diff --git a/lerobot/common/envs/utils.py b/lerobot/common/envs/utils.py
index dcce1bcc..4a72ea52 100644
--- a/lerobot/common/envs/utils.py
+++ b/lerobot/common/envs/utils.py
@@ -1,10 +1,8 @@
 import einops
 import torch
 
-from lerobot.common.transforms import apply_inverse_transform
-
 
-def preprocess_observation(observation, transform=None):
+def preprocess_observation(observation):
     # map to expected inputs for the policy
     obs = {}
 
@@ -33,19 +31,11 @@
     # TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing requirement for "agent_pos"
     obs["observation.state"] = torch.from_numpy(observation["agent_pos"]).float()
 
-    # apply same transforms as in training
-    if transform is not None:
-        for key in obs:
-            obs[key] = torch.stack([transform({key: item})[key] for item in obs[key]])
-
     return obs
 
 
-def postprocess_action(action, transform=None):
-    action = action.to("cpu")
-    # action is a batch (num_env,action_dim) instead of an item (action_dim),
-    # we assume applying inverse transform on a batch works the same
-    action = apply_inverse_transform({"action": action}, transform)["action"].numpy()
+def postprocess_action(action):
+    action = action.to("cpu").numpy()
     assert (
         action.ndim == 2
     ), "we assume dimensions are respectively the number of parallel envs, action dimensions"
diff --git a/lerobot/common/policies/act/configuration_act.py b/lerobot/common/policies/act/configuration_act.py
index 211a8ed0..2a29d994 100644
--- a/lerobot/common/policies/act/configuration_act.py
+++ b/lerobot/common/policies/act/configuration_act.py
@@ -60,12 +60,14 @@ class ActionChunkingTransformerConfig:
     chunk_size: int = 100
     n_action_steps: int = 100
 
-    # Vision preprocessing.
-    image_normalization_mean: tuple[float, float, float] = field(
-        default_factory=lambda: [0.485, 0.456, 0.406]
-    )
-    image_normalization_std: tuple[float, float, float] = field(default_factory=lambda: [0.229, 0.224, 0.225])
-
+    # Normalization / Unnormalization
+    # Note: mutable defaults are not allowed on a dataclass, so these use `field(default_factory=...)`.
+    normalize_input_modes: dict[str, str] = field(
+        default_factory=lambda: {
+            "observation.image": "mean_std",
+            "observation.state": "mean_std",
+        }
+    )
+    unnormalize_output_modes: dict[str, str] = field(
+        default_factory=lambda: {
+            "action": "mean_std",
+        }
+    )
     # Architecture.
     # Vision backbone.
     vision_backbone: str = "resnet18"
diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py
index c1af4ef4..78010255 100644
--- a/lerobot/common/policies/act/modeling_act.py
+++ b/lerobot/common/policies/act/modeling_act.py
@@ -15,12 +15,15 @@ import numpy as np
 import torch
 import torch.nn.functional as F  # noqa: N812
 import torchvision
-import torchvision.transforms as transforms
 from torch import Tensor, nn
 from torchvision.models._utils import IntermediateLayerGetter
 from torchvision.ops.misc import FrozenBatchNorm2d
 
 from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
+from lerobot.common.policies.utils import (
+    normalize_inputs,
+    unnormalize_outputs,
+)
 
 
 class ActionChunkingTransformerPolicy(nn.Module):
@@ -62,7 +65,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
 
     name = "act"
 
-    def __init__(self, cfg: ActionChunkingTransformerConfig | None = None):
+    def __init__(self, cfg: ActionChunkingTransformerConfig | None = None, dataset_stats=None):
         """
         Args:
             cfg: Policy configuration class instance or None, in which case the default instantiation of the
@@ -72,6 +75,9 @@
         if cfg is None:
             cfg = ActionChunkingTransformerConfig()
         self.cfg = cfg
+        # Note: `dataset_stats` is a nested dict of tensors, not a single tensor, so it is kept as a
+        # plain attribute (a dict can't be registered as a buffer).
+        self.dataset_stats = dataset_stats
+        self.normalize_input_modes = cfg.normalize_input_modes
+        self.unnormalize_output_modes = cfg.unnormalize_output_modes
 
         # BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence].
         # The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
@@ -93,9 +99,6 @@
         )
 
         # Backbone for image feature extraction.
-        self.image_normalizer = transforms.Normalize(
-            mean=cfg.image_normalization_mean, std=cfg.image_normalization_std
-        )
         backbone_model = getattr(torchvision.models, cfg.vision_backbone)(
             replace_stride_with_dilation=[False, False, cfg.replace_final_stride_with_dilation],
             pretrained=cfg.use_pretrained_backbone,
@@ -169,10 +172,15 @@
         queue is empty.
         """
         self.eval()
+
+        batch = normalize_inputs(batch, self.dataset_stats, self.normalize_input_modes)
+
         if len(self._action_queue) == 0:
             # `_forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue effectively
             # has shape (n_action_steps, batch_size, *), hence the transpose.
-            self._action_queue.extend(self._forward(batch)[0][: self.cfg.n_action_steps].transpose(0, 1))
+            actions = self._forward(batch)[0][: self.cfg.n_action_steps]
+            # `unnormalize_outputs` operates on a dict batch, so wrap/unwrap the action tensor.
+            actions = unnormalize_outputs(
+                {"action": actions}, self.dataset_stats, self.unnormalize_output_modes
+            )["action"]
+            self._action_queue.extend(actions.transpose(0, 1))
         return self._action_queue.popleft()
 
     def forward(self, batch, **_) -> dict[str, Tensor]:
@@ -203,7 +211,10 @@
         """Run the model in train mode, compute the loss, and do an optimization step."""
         start_time = time.time()
         self.train()
+
+        batch = normalize_inputs(batch, self.dataset_stats, self.normalize_input_modes)
         loss_dict = self.forward(batch)
+        # TODO(rcadene): unnormalize_outputs(actions, self.dataset_stats, self.unnormalize_output_modes)
         loss = loss_dict["loss"]
         loss.backward()
 
@@ -309,7 +320,7 @@
         # Camera observation features and positional embeddings.
         all_cam_features = []
         all_cam_pos_embeds = []
-        images = self.image_normalizer(batch["observation.images"])
+        images = batch["observation.images"]
         for cam_index in range(len(self.cfg.camera_names)):
             cam_features = self.backbone(images[:, cam_index])["feature_map"]
             cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
diff --git a/lerobot/common/policies/diffusion/configuration_diffusion.py b/lerobot/common/policies/diffusion/configuration_diffusion.py
index d8820a0b..b868adad 100644
--- a/lerobot/common/policies/diffusion/configuration_diffusion.py
+++ b/lerobot/common/policies/diffusion/configuration_diffusion.py
@@ -69,9 +69,14 @@ class DiffusionConfig:
     horizon: int = 16
     n_action_steps: int = 8
 
-    # Vision preprocessing.
-    image_normalization_mean: tuple[float, float, float] = (0.5, 0.5, 0.5)
-    image_normalization_std: tuple[float, float, float] = (0.5, 0.5, 0.5)
+    # Normalization / Unnormalization
+    # Note: mutable defaults are not allowed on a dataclass, so these use `field(default_factory=...)`.
+    normalize_input_modes: dict[str, str] = field(
+        default_factory=lambda: {
+            "observation.image": "mean_std",
+            "observation.state": "min_max",
+        }
+    )
+    unnormalize_output_modes: dict[str, str] = field(
+        default_factory=lambda: {
+            "action": "min_max",
+        }
+    )
 
     # Architecture / modeling.
     # Vision backbone.
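Editor's note (not part of the patch): the mode strings above select the arithmetic that the new `normalize_inputs`/`unnormalize_outputs` helpers apply per key; the helpers themselves are added in `lerobot/common/policies/utils.py` later in this diff. A self-contained sketch of the `min_max` round trip, with made-up stats values:

```python
import torch

# Made-up stats, in the same nested format as `dataset.stats`.
stats = {"action": {"min": torch.tensor([12.0, 25.0]), "max": torch.tensor([511.0, 511.0])}}

action = torch.tensor([[100.0, 200.0]])
min_ = stats["action"]["min"].unsqueeze(0)
max_ = stats["action"]["max"].unsqueeze(0)

# "min_max" normalization maps values to [0, 1] and then to [-1, 1].
normalized = (action - min_) / (max_ - min_) * 2 - 1

# Unnormalization is the exact inverse, as in `unnormalize_outputs`.
restored = (normalized + 1) / 2 * (max_ - min_) + min_
assert torch.allclose(restored, action)

# "mean_std" is the analogous round trip: (x - mean) / (std + 1e-8), inverted by x * std + mean.
```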
diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py
index e7cc62f4..80ab4eba 100644
--- a/lerobot/common/policies/diffusion/modeling_diffusion.py
+++ b/lerobot/common/policies/diffusion/modeling_diffusion.py
@@ -13,7 +13,6 @@ import logging
 import math
 import time
 from collections import deque
-from itertools import chain
 from typing import Callable
 
 import einops
@@ -30,7 +29,9 @@ from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
 from lerobot.common.policies.utils import (
     get_device_from_parameters,
     get_dtype_from_parameters,
+    normalize_inputs,
     populate_queues,
+    unnormalize_outputs,
 )
 
 
@@ -42,7 +43,9 @@ class DiffusionPolicy(nn.Module):
 
     name = "diffusion"
 
-    def __init__(self, cfg: DiffusionConfig | None, lr_scheduler_num_training_steps: int = 0):
+    def __init__(
+        self, cfg: DiffusionConfig | None, lr_scheduler_num_training_steps: int = 0, dataset_stats=None
+    ):
         """
         Args:
             cfg: Policy configuration class instance or None, in which case the default instantiation of the
@@ -54,6 +57,9 @@
         if cfg is None:
             cfg = DiffusionConfig()
         self.cfg = cfg
+        # Note: as in the ACT policy, `dataset_stats` is a nested dict of tensors, so it is kept as a
+        # plain attribute (a dict can't be registered as a buffer).
+        self.dataset_stats = dataset_stats
+        self.normalize_input_modes = cfg.normalize_input_modes
+        self.unnormalize_output_modes = cfg.unnormalize_output_modes
 
         # queues are populated during rollout of the policy, they contain the n latest observations and actions
         self._queues = None
@@ -126,6 +132,8 @@
         assert "observation.state" in batch
         assert len(batch) == 2
 
+        batch = normalize_inputs(batch, self.dataset_stats, self.normalize_input_modes)
+
         self._queues = populate_queues(self._queues, batch)
 
         if len(self._queues["action"]) == 0:
@@ -135,6 +143,8 @@
                 actions = self.ema_diffusion.generate_actions(batch)
             else:
                 actions = self.diffusion.generate_actions(batch)
+
+            # `unnormalize_outputs` operates on a dict batch, so wrap/unwrap the action tensor.
+            actions = unnormalize_outputs(
+                {"action": actions}, self.dataset_stats, self.unnormalize_output_modes
+            )["action"]
             self._queues["action"].extend(actions.transpose(0, 1))
 
         action = self._queues["action"].popleft()
@@ -151,9 +161,13 @@
 
         self.diffusion.train()
 
+        batch = normalize_inputs(batch, self.dataset_stats, self.normalize_input_modes)
+
         loss = self.forward(batch)["loss"]
         loss.backward()
 
+        # TODO(rcadene): unnormalize_outputs(actions, self.dataset_stats, self.unnormalize_output_modes)
+
         grad_norm = torch.nn.utils.clip_grad_norm_(
             self.diffusion.parameters(),
             self.cfg.grad_clip_norm,
@@ -346,12 +360,6 @@ class _RgbEncoder(nn.Module):
     def __init__(self, cfg: DiffusionConfig):
         super().__init__()
         # Set up optional preprocessing.
-        if all(v == 1.0 for v in chain(cfg.image_normalization_mean, cfg.image_normalization_std)):
-            self.normalizer = nn.Identity()
-        else:
-            self.normalizer = torchvision.transforms.Normalize(
-                mean=cfg.image_normalization_mean, std=cfg.image_normalization_std
-            )
         if cfg.crop_shape is not None:
             self.do_crop = True
             # Always use center crop for eval
@@ -397,8 +405,7 @@ class _RgbEncoder(nn.Module):
         Returns:
             (B, D) image feature.
         """
-        # Preprocess: normalize and maybe crop (if it was set up in the __init__).
-        x = self.normalizer(x)
+        # Preprocess: maybe crop (if it was set up in the __init__).
if self.do_crop: if self.training: # noqa: SIM108 x = self.maybe_random_crop(x) diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py index 9698175d..a8235388 100644 --- a/lerobot/common/policies/factory.py +++ b/lerobot/common/policies/factory.py @@ -20,7 +20,7 @@ def _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg): return policy_cfg -def make_policy(hydra_cfg: DictConfig): +def make_policy(hydra_cfg: DictConfig, dataset_stats=None): if hydra_cfg.policy.name == "tdmpc": from lerobot.common.policies.tdmpc.policy import TDMPCPolicy @@ -35,14 +35,14 @@ def make_policy(hydra_cfg: DictConfig): from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy policy_cfg = _policy_cfg_from_hydra_cfg(DiffusionConfig, hydra_cfg) - policy = DiffusionPolicy(policy_cfg, hydra_cfg.offline_steps) + policy = DiffusionPolicy(policy_cfg, hydra_cfg.offline_steps, dataset_stats) policy.to(get_safe_torch_device(hydra_cfg.device)) elif hydra_cfg.policy.name == "act": from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy policy_cfg = _policy_cfg_from_hydra_cfg(ActionChunkingTransformerConfig, hydra_cfg) - policy = ActionChunkingTransformerPolicy(policy_cfg) + policy = ActionChunkingTransformerPolicy(policy_cfg, dataset_stats) policy.to(get_safe_torch_device(hydra_cfg.device)) else: raise ValueError(hydra_cfg.policy.name) diff --git a/lerobot/common/policies/utils.py b/lerobot/common/policies/utils.py index b23c1336..658f8085 100644 --- a/lerobot/common/policies/utils.py +++ b/lerobot/common/policies/utils.py @@ -28,3 +28,41 @@ def get_dtype_from_parameters(module: nn.Module) -> torch.dtype: Note: assumes that all parameters have the same dtype. 
""" return next(iter(module.parameters())).dtype + + +def normalize_inputs(batch, stats, normalize_input_modes): + if normalize_input_modes is None: + return batch + for key, mode in normalize_input_modes.items(): + if mode == "mean_std": + mean = stats[key]["mean"].unsqueeze(0) + std = stats[key]["std"].unsqueeze(0) + batch[key] = (batch[key] - mean) / (std + 1e-8) + elif mode == "min_max": + min = stats[key]["min"].unsqueeze(0) + max = stats[key]["max"].unsqueeze(0) + # normalize to [0,1] + batch[key] = (batch[key] - min) / (max - min) + # normalize to [-1, 1] + batch[key] = batch[key] * 2 - 1 + else: + raise ValueError(mode) + return batch + + +def unnormalize_outputs(batch, stats, unnormalize_output_modes): + if unnormalize_output_modes is None: + return batch + for key, mode in unnormalize_output_modes.items(): + if mode == "mean_std": + mean = stats[key]["mean"].unsqueeze(0) + std = stats[key]["std"].unsqueeze(0) + batch[key] = batch[key] * std + mean + elif mode == "min_max": + min = stats[key]["min"].unsqueeze(0) + max = stats[key]["max"].unsqueeze(0) + batch[key] = (batch[key] + 1) / 2 + batch[key] = batch[key] * (max - min) + min + else: + raise ValueError(mode) + return batch diff --git a/lerobot/common/transforms.py b/lerobot/common/transforms.py deleted file mode 100644 index fffa835a..00000000 --- a/lerobot/common/transforms.py +++ /dev/null @@ -1,65 +0,0 @@ -from torchvision.transforms.v2 import Compose, Transform - - -def apply_inverse_transform(item, transform): - transforms = transform.transforms if isinstance(transform, Compose) else [transform] - for tf in transforms[::-1]: - if tf.invertible: - item = tf.inverse_transform(item) - else: - raise ValueError(f"Inverse transform called on a non invertible transform ({tf}).") - return item - - -class NormalizeTransform(Transform): - invertible = True - - def __init__( - self, - stats: dict, - in_keys: list[str] = None, - out_keys: list[str] | None = None, - in_keys_inv: list[str] | None = None, - out_keys_inv: list[str] | None = None, - mode="mean_std", - ): - super().__init__() - self.in_keys = in_keys - self.out_keys = in_keys if out_keys is None else out_keys - self.in_keys_inv = self.out_keys if in_keys_inv is None else in_keys_inv - self.out_keys_inv = self.in_keys if out_keys_inv is None else out_keys_inv - self.stats = stats - assert mode in ["mean_std", "min_max"] - self.mode = mode - - def forward(self, item): - for inkey, outkey in zip(self.in_keys, self.out_keys, strict=False): - if inkey not in item: - continue - if self.mode == "mean_std": - mean = self.stats[inkey]["mean"] - std = self.stats[inkey]["std"] - item[outkey] = (item[inkey] - mean) / (std + 1e-8) - else: - min = self.stats[inkey]["min"] - max = self.stats[inkey]["max"] - # normalize to [0,1] - item[outkey] = (item[inkey] - min) / (max - min) - # normalize to [-1, 1] - item[outkey] = item[outkey] * 2 - 1 - return item - - def inverse_transform(self, item): - for inkey, outkey in zip(self.in_keys_inv, self.out_keys_inv, strict=False): - if inkey not in item: - continue - if self.mode == "mean_std": - mean = self.stats[inkey]["mean"] - std = self.stats[inkey]["std"] - item[outkey] = item[inkey] * std + mean - else: - min = self.stats[inkey]["min"] - max = self.stats[inkey]["max"] - item[outkey] = (item[inkey] + 1) / 2 - item[outkey] = item[outkey] * (max - min) + min - return item diff --git a/lerobot/configs/policy/act.yaml b/lerobot/configs/policy/act.yaml index eb4e512b..bc859067 100644 --- a/lerobot/configs/policy/act.yaml +++ 
b/lerobot/configs/policy/act.yaml @@ -11,6 +11,11 @@ log_freq: 250 n_obs_steps: 1 # when temporal_agg=False, n_action_steps=horizon +override_dataset_stats: + observation.image: + mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1) + std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1) + # See `configuration_act.py` for more details. policy: name: act @@ -28,9 +33,12 @@ policy: chunk_size: 100 # chunk_size n_action_steps: 100 - # Vision preprocessing. - image_normalization_mean: [0.485, 0.456, 0.406] - image_normalization_std: [0.229, 0.224, 0.225] + # Normalization / Unnormalization + normalize_input_modes: + observation.image: mean_std + observation.state: mean_std + unnormalize_output_modes: + action: mean_std # Architecture. # Vision backbone. diff --git a/lerobot/configs/policy/diffusion.yaml b/lerobot/configs/policy/diffusion.yaml index 44746dfc..94eacb49 100644 --- a/lerobot/configs/policy/diffusion.yaml +++ b/lerobot/configs/policy/diffusion.yaml @@ -18,6 +18,17 @@ online_steps: 0 offline_prioritized_sampler: true +override_dataset_stats: + observation.image: + mean: [[[0.5]], [[0.5]], [[0.5]]] # (c,1,1) + std: [[[0.5]], [[0.5]], [[0.5]]] # (c,1,1) + observation.state: + min: [13.456424, 32.938293] + max: [496.14618, 510.9579] + action: + min: [12.0, 25.0] + max: [511.0, 511.0] + policy: name: diffusion @@ -36,9 +47,12 @@ policy: horizon: ${horizon} n_action_steps: ${n_action_steps} - # Vision preprocessing. - image_normalization_mean: [0.5, 0.5, 0.5] - image_normalization_std: [0.5, 0.5, 0.5] + # Normalization / Unnormalization + normalize_input_modes: + observation.image: mean_std + observation.state: min_max + unnormalize_output_modes: + action: min_max # Architecture / modeling. # Vision backbone. diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index 32b7e26b..c66e7ee9 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -46,7 +46,6 @@ from huggingface_hub import snapshot_download from PIL import Image as PILImage from tqdm import trange -from lerobot.common.datasets.factory import make_dataset from lerobot.common.datasets.utils import hf_transform_to_torch from lerobot.common.envs.factory import make_env from lerobot.common.envs.utils import postprocess_action, preprocess_observation @@ -64,8 +63,6 @@ def eval_policy( policy: torch.nn.Module, max_episodes_rendered: int = 0, video_dir: Path = None, - # TODO(rcadene): make it possible to overwrite fps? 
we should use env.fps - transform: callable = None, return_episode_data: bool = False, seed=None, ): @@ -132,10 +129,6 @@ def eval_policy( if return_episode_data: observations.append(deepcopy(observation)) - # apply transform to normalize the observations - for key in observation: - observation[key] = torch.stack([transform({key: item})[key] for item in observation[key]]) - # send observation to device/gpu observation = {key: observation[key].to(device, non_blocking=True) for key in observation} @@ -143,8 +136,8 @@ def eval_policy( with torch.inference_mode(): action = policy.select_action(observation, step=step) - # apply inverse transform to unnormalize the action - action = postprocess_action(action, transform) + # convert to cpu numpy + action = postprocess_action(action) # apply the next action observation, reward, terminated, truncated, info = env.step(action) @@ -360,7 +353,7 @@ def eval_policy( return info -def eval(cfg: dict, out_dir=None, stats_path=None): +def eval(cfg: dict, out_dir=None): if out_dir is None: raise NotImplementedError() @@ -375,10 +368,6 @@ def eval(cfg: dict, out_dir=None, stats_path=None): log_output_dir(out_dir) - logging.info("Making transforms.") - # TODO(alexander-soare): Completely decouple datasets from evaluation. - transform = make_dataset(cfg, stats_path=stats_path).transform - logging.info("Making environment.") env = make_env(cfg, num_parallel_envs=cfg.eval_episodes) @@ -390,7 +379,6 @@ def eval(cfg: dict, out_dir=None, stats_path=None): policy, max_episodes_rendered=10, video_dir=Path(out_dir) / "eval", - transform=transform, return_episode_data=False, seed=cfg.seed, ) @@ -423,17 +411,13 @@ if __name__ == "__main__": if args.config is not None: # Note: For the config_path, Hydra wants a path relative to this script file. cfg = init_hydra_config(args.config, args.overrides) - # TODO(alexander-soare): Save and load stats in trained model directory. 
-        stats_path = None
     elif args.hub_id is not None:
         folder = Path(snapshot_download(args.hub_id, revision=args.revision))
         cfg = init_hydra_config(
             folder / "config.yaml", [f"policy.pretrained_model_path={folder / 'model.pt'}", *args.overrides]
         )
-        stats_path = folder / "stats.pth"
 
     eval(
         cfg,
         out_dir=f"outputs/eval/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{cfg.env.name}_{cfg.policy.name}",
-        stats_path=stats_path,
     )
diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index 1f4ee16a..1033ae8f 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -232,7 +232,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
     env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)
 
     logging.info("make_policy")
-    policy = make_policy(cfg)
+    policy = make_policy(cfg, dataset_stats=offline_dataset.stats)
 
     num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
     num_total_params = sum(p.numel() for p in policy.parameters())
diff --git a/tests/test_envs.py b/tests/test_envs.py
index 33928a62..01a90c87 100644
--- a/tests/test_envs.py
+++ b/tests/test_envs.py
@@ -42,7 +42,7 @@ def test_factory(env_name):
     env = make_env(cfg, num_parallel_envs=1)
 
     obs, _ = env.reset()
-    obs = preprocess_observation(obs, transform=dataset.transform)
+    obs = preprocess_observation(obs)
     for key in dataset.image_keys:
         img = obs[key]
         assert img.dtype == torch.float32
diff --git a/tests/test_examples.py b/tests/test_examples.py
index 3ac040b1..876735b4 100644
--- a/tests/test_examples.py
+++ b/tests/test_examples.py
@@ -51,7 +51,7 @@ def test_examples_4_and_3():
     # Pass empty globals to allow dictionary comprehension https://stackoverflow.com/a/32897127/4391249.
     exec(file_contents, {})
 
-    for file_name in ["model.pt", "stats.pth", "config.yaml"]:
+    for file_name in ["model.pt", "config.yaml"]:
         assert Path(f"outputs/train/example_pusht_diffusion/{file_name}").exists()
 
     path = "examples/3_evaluate_pretrained_policy.py"
diff --git a/tests/test_policies.py b/tests/test_policies.py
index ab679fcb..37401598 100644
--- a/tests/test_policies.py
+++ b/tests/test_policies.py
@@ -44,14 +44,16 @@ def test_policy(env_name, policy_name, extra_overrides):
         ]
         + extra_overrides,
     )
+
     # Check that we can make the policy object.
-    policy = make_policy(cfg)
+    dataset = make_dataset(cfg)
+    policy = make_policy(cfg, dataset_stats=dataset.stats)
 
     # Check that the policy follows the required protocol.
     assert isinstance(
         policy, Policy
    ), f"The policy does not follow the required protocol. Please see {Policy.__module__}.{Policy.__name__}."
 
-    # Check that we run select_actions and get the appropriate output.
-    dataset = make_dataset(cfg)
+    # Check that we run select_action and get the appropriate output.
     env = make_env(cfg, num_parallel_envs=2)
 
     dataloader = torch.utils.data.DataLoader(
@@ -77,7 +79,7 @@ def test_policy(env_name, policy_name, extra_overrides):
     observation, _ = env.reset(seed=cfg.seed)
 
-    # apply transform to normalize the observations
-    observation = preprocess_observation(observation, dataset.transform)
+    # convert observation to the policy's expected input format
+    observation = preprocess_observation(observation)
 
     # send observation to device/gpu
     observation = {key: observation[key].to(DEVICE, non_blocking=True) for key in observation}
@@ -86,8 +88,8 @@
     with torch.inference_mode():
         action = policy.select_action(observation, step=0)
 
-    # apply inverse transform to unnormalize the action
-    action = postprocess_action(action, dataset.transform)
+    # convert action to cpu numpy array
+    action = postprocess_action(action)
 
     # Test step through policy
     env.step(action)
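Editor's note (not part of the patch): taken together, stats now flow from the dataset into `make_policy(..., dataset_stats=...)` and normalization happens inside the policy, instead of through a dataset transform. A minimal sketch exercising the new helpers directly; the functions and signatures are the ones added in `lerobot/common/policies/utils.py` above, and the stats values are illustrative:

```python
import torch

from lerobot.common.policies.utils import normalize_inputs, unnormalize_outputs

# Illustrative stats, in the same nested format as `dataset.stats`.
stats = {
    "observation.state": {"mean": torch.tensor([0.0, 0.0]), "std": torch.tensor([1.0, 2.0])},
    "action": {"min": torch.tensor([12.0, 25.0]), "max": torch.tensor([511.0, 511.0])},
}

# Per-key modes, as declared in the policy configs.
normalize_input_modes = {"observation.state": "mean_std"}
unnormalize_output_modes = {"action": "min_max"}

# mean_std: (x - mean) / (std + 1e-8)
batch = normalize_inputs({"observation.state": torch.tensor([[1.0, 4.0]])}, stats, normalize_input_modes)
print(batch["observation.state"])  # ~[[1.0, 2.0]]

# min_max inverse: (x + 1) / 2 * (max - min) + min
out = unnormalize_outputs({"action": torch.zeros(1, 2)}, stats, unnormalize_output_modes)
print(out["action"])  # [[261.5, 268.0]]
```

This is the round trip that previously lived in `NormalizeTransform.forward`/`inverse_transform` and is now applied inside each policy's `select_action`.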