diff --git a/lerobot/common/datasets/pusht.py b/lerobot/common/datasets/pusht.py
index b5d4fab3..1c1b658e 100644
--- a/lerobot/common/datasets/pusht.py
+++ b/lerobot/common/datasets/pusht.py
@@ -1,4 +1,5 @@
 import logging
+import math
 import os
 from pathlib import Path
 from typing import Callable
@@ -134,18 +135,19 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
         else:
             storage = TensorStorage(TensorDict.load_memmap(self.root / dataset_id))
 
-        mean_std = self._compute_or_load_mean_std(storage)
-        mean_std["next", "observation", "image"] = mean_std["observation", "image"]
-        mean_std["next", "observation", "state"] = mean_std["observation", "state"]
+        stats = self._compute_or_load_stats(storage)
+        stats["next", "observation", "image"] = stats["observation", "image"]
+        stats["next", "observation", "state"] = stats["observation", "state"]
         transform = NormalizeTransform(
-            mean_std,
+            stats,
             in_keys=[
-                ("observation", "image"),
+                # ("observation", "image"),
                 ("observation", "state"),
-                ("next", "observation", "image"),
+                # ("next", "observation", "image"),
                 ("next", "observation", "state"),
                 ("action"),
             ],
+            mode="min_max",
         )
 
         if writer is None:
@@ -282,7 +284,7 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
 
         return TensorStorage(td_data.lock_())
 
-    def _compute_mean_std(self, storage, num_batch=10, batch_size=32):
+    def _compute_stats(self, storage, num_batch=100, batch_size=32):
         rb = TensorDictReplayBuffer(
             storage=storage,
             batch_size=batch_size,
@@ -291,15 +293,27 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
         batch = rb.sample()
         image_mean = torch.zeros(batch["observation", "image"].shape[1])
         image_std = torch.zeros(batch["observation", "image"].shape[1])
+        image_max = -math.inf
+        image_min = math.inf
         state_mean = torch.zeros(batch["observation", "state"].shape[1])
         state_std = torch.zeros(batch["observation", "state"].shape[1])
+        state_max = -math.inf
+        state_min = math.inf
         action_mean = torch.zeros(batch["action"].shape[1])
         action_std = torch.zeros(batch["action"].shape[1])
+        action_max = -math.inf
+        action_min = math.inf
 
         for _ in tqdm.tqdm(range(num_batch)):
             image_mean += einops.reduce(batch["observation", "image"], "b c h w -> c", reduction="mean")
             state_mean += batch["observation", "state"].mean(dim=0)
             action_mean += batch["action"].mean(dim=0)
+            image_max = max(image_max, batch["observation", "image"].max().item())
+            image_min = min(image_min, batch["observation", "image"].min().item())
+            state_max = max(state_max, batch["observation", "state"].max().item())
+            state_min = min(state_min, batch["observation", "state"].min().item())
+            action_max = max(action_max, batch["action"].max().item())
+            action_min = min(action_min, batch["action"].min().item())
             batch = rb.sample()
 
         image_mean /= num_batch
@@ -311,6 +325,12 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
             image_std += (image_mean_batch - image_mean) ** 2
             state_std += (batch["observation", "state"].mean(dim=0) - state_mean) ** 2
             action_std += (batch["action"].mean(dim=0) - action_mean) ** 2
+            image_max = max(image_max, batch["observation", "image"].max().item())
+            image_min = min(image_min, batch["observation", "image"].min().item())
+            state_max = max(state_max, batch["observation", "state"].max().item())
+            state_min = min(state_min, batch["observation", "state"].min().item())
+            action_max = max(action_max, batch["action"].max().item())
+            action_min = min(action_min, batch["action"].min().item())
             if i < num_batch - 1:
                 batch = rb.sample()
 
@@ -318,25 +338,31 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
         state_std = torch.sqrt(state_std / num_batch)
         action_std = torch.sqrt(action_std / num_batch)
 
-        mean_std = TensorDict(
+        stats = TensorDict(
             {
                 ("observation", "image", "mean"): image_mean[None, :, None, None],
                 ("observation", "image", "std"): image_std[None, :, None, None],
+                ("observation", "image", "max"): torch.tensor(image_max),
+                ("observation", "image", "min"): torch.tensor(image_min),
                 ("observation", "state", "mean"): state_mean[None, :],
                 ("observation", "state", "std"): state_std[None, :],
+                ("observation", "state", "max"): torch.tensor(state_max),
+                ("observation", "state", "min"): torch.tensor(state_min),
                 ("action", "mean"): action_mean[None, :],
                 ("action", "std"): action_std[None, :],
+                ("action", "max"): torch.tensor(action_max),
+                ("action", "min"): torch.tensor(action_min),
             },
             batch_size=[],
         )
-        return mean_std
+        return stats
 
-    def _compute_or_load_mean_std(self, storage) -> TensorDict:
-        mean_std_path = self.root / self.dataset_id / "mean_std.pth"
-        if mean_std_path.exists():
-            mean_std = torch.load(mean_std_path)
+    def _compute_or_load_stats(self, storage) -> TensorDict:
+        stats_path = self.root / self.dataset_id / "stats.pth"
+        if stats_path.exists():
+            stats = torch.load(stats_path)
         else:
-            logging.info(f"compute_mean_std and save to {mean_std_path}")
-            mean_std = self._compute_mean_std(storage)
-            torch.save(mean_std, mean_std_path)
-        return mean_std
+            logging.info(f"compute_stats and save to {stats_path}")
+            stats = self._compute_stats(storage)
+            torch.save(stats, stats_path)
+        return stats
diff --git a/lerobot/common/envs/factory.py b/lerobot/common/envs/factory.py
index 2cd4f73b..d7dc8aae 100644
--- a/lerobot/common/envs/factory.py
+++ b/lerobot/common/envs/factory.py
@@ -1,7 +1,5 @@
 from torchrl.envs.transforms import StepCounter, TransformedEnv
 
-from lerobot.common.envs.transforms import Prod
-
 
 def make_env(cfg, transform=None):
     kwargs = {
@@ -28,12 +26,8 @@ def make_env(cfg, transform=None):
     # limit rollout to max_steps
     env = TransformedEnv(env, StepCounter(max_steps=cfg.env.episode_length))
 
-    if cfg.env.name == "pusht":
-        # to ensure pusht is in [0,255] like simxarm
-        env.append_transform(Prod(in_keys=[("observation", "image")], prod=255.0))
-
     if transform is not None:
-        # useful to add mean and std normalization
+        # useful to add normalization
        env.append_transform(transform)
 
     return env
diff --git a/lerobot/common/envs/transforms.py b/lerobot/common/envs/transforms.py
index 67601eac..671c0827 100644
--- a/lerobot/common/envs/transforms.py
+++ b/lerobot/common/envs/transforms.py
@@ -28,11 +28,12 @@ class NormalizeTransform(Transform):
 
     def __init__(
         self,
-        mean_std: TensorDictBase,
+        stats: TensorDictBase,
         in_keys: Sequence[NestedKey] = None,
         out_keys: Sequence[NestedKey] | None = None,
         in_keys_inv: Sequence[NestedKey] | None = None,
         out_keys_inv: Sequence[NestedKey] | None = None,
+        mode="mean_std",
     ):
         if out_keys is None:
             out_keys = in_keys
@@ -43,7 +44,14 @@ class NormalizeTransform(Transform):
         super().__init__(
             in_keys=in_keys, out_keys=out_keys, in_keys_inv=in_keys_inv, out_keys_inv=out_keys_inv
         )
-        self.mean_std = mean_std
+        self.stats = stats
+        assert mode in ["mean_std", "min_max"]
+        self.mode = mode
+
+    def _reset(self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase) -> TensorDictBase:
+        # _reset is called once when the environment reset to normalize the first observation
+        tensordict_reset = self._call(tensordict_reset)
+        return tensordict_reset
 
     @dispatch(source="in_keys", dest="out_keys")
     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
@@ -54,9 +62,17 @@ class NormalizeTransform(Transform):
             # TODO(rcadene): don't know how to do `inkey not in td`
             if td.get(inkey, None) is None:
                 continue
-            mean = self.mean_std[inkey]["mean"]
-            std = self.mean_std[inkey]["std"]
-            td[outkey] = (td[inkey] - mean) / (std + 1e-8)
+            if self.mode == "mean_std":
+                mean = self.stats[inkey]["mean"]
+                std = self.stats[inkey]["std"]
+                td[outkey] = (td[inkey] - mean) / (std + 1e-8)
+            else:
+                min = self.stats[inkey]["min"]
+                max = self.stats[inkey]["max"]
+                # normalize to [0,1]
+                td[outkey] = (td[inkey] - min) / (max - min)
+                # normalize to [-1, 1]
+                td[outkey] = td[outkey] * 2 - 1
         return td
 
     def _inv_call(self, td: TensorDictBase) -> TensorDictBase:
@@ -64,7 +80,13 @@ class NormalizeTransform(Transform):
             # TODO(rcadene): don't know how to do `inkey not in td`
             if td.get(inkey, None) is None:
                 continue
-            mean = self.mean_std[inkey]["mean"]
-            std = self.mean_std[inkey]["std"]
-            td[outkey] = td[inkey] * std + mean
+            if self.mode == "mean_std":
+                mean = self.stats[inkey]["mean"]
+                std = self.stats[inkey]["std"]
+                td[outkey] = td[inkey] * std + mean
+            else:
+                min = self.stats[inkey]["min"]
+                max = self.stats[inkey]["max"]
+                td[outkey] = (td[inkey] + 1) / 2
+                td[outkey] = td[outkey] * (max - min) + min
         return td
diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py
index 6391903e..214f5dba 100644
--- a/lerobot/scripts/eval.py
+++ b/lerobot/scripts/eval.py
@@ -118,7 +118,7 @@ def eval(cfg: dict, out_dir=None):
     offline_buffer = make_offline_buffer(cfg)
 
     logging.info("make_env")
-    env = make_env(cfg, transform=offline_buffer.transform)
+    env = make_env(cfg, transform=offline_buffer._transform)
 
     if cfg.policy.pretrained_model_path:
         policy = make_policy(cfg)
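
For reference, a minimal standalone sketch (not part of the diff; the stats values and the sample tensor below are illustrative, not taken from a real dataset) of the round trip that the new mode="min_max" branches in forward and _inv_call implement: map a value to [0, 1] using the per-key min/max, rescale to [-1, 1], then invert:

    import torch

    # illustrative per-key stats, as _compute_stats would produce for e.g. ("action",)
    stats = {"min": torch.tensor(-1.0), "max": torch.tensor(512.0)}

    x = torch.tensor([0.0, 256.0, 512.0])

    # forward: normalize to [0, 1], then to [-1, 1]
    x_norm = (x - stats["min"]) / (stats["max"] - stats["min"])
    x_norm = x_norm * 2 - 1

    # _inv_call: map back from [-1, 1] to the original range
    x_back = (x_norm + 1) / 2 * (stats["max"] - stats["min"]) + stats["min"]
    assert torch.allclose(x_back, x, atol=1e-5)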