From 48ded3dbc7041616cbb58069a907a5a0a4c51016 Mon Sep 17 00:00:00 2001 From: Remi Cadene Date: Sat, 2 Mar 2024 18:11:50 +0000 Subject: [PATCH 1/7] fix --- lerobot/common/logger.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lerobot/common/logger.py b/lerobot/common/logger.py index 2b877b2e..a8cb6a66 100644 --- a/lerobot/common/logger.py +++ b/lerobot/common/logger.py @@ -144,5 +144,5 @@ class Logger: def log_video(self, video, step, mode="train"): assert mode in {"train", "eval"} - wandb_video = self._wandb.Video(video, fps=self.cfg.fps, format="mp4") + wandb_video = self._wandb.Video(video, fps=self._cfg.fps, format="mp4") self._wandb.log({f"{mode}/video": wandb_video}, step=step) From cbbed590a9307065d52f7c94daad2734b69fad3f Mon Sep 17 00:00:00 2001 From: Remi Cadene Date: Sun, 3 Mar 2024 13:19:02 +0000 Subject: [PATCH 2/7] Add mode to NormalizeTransform with mean_std or min_max (Not fully tested) --- lerobot/common/datasets/pusht.py | 60 ++++++++++++++++++++++--------- lerobot/common/envs/factory.py | 8 +---- lerobot/common/envs/transforms.py | 38 +++++++++++++++----- lerobot/scripts/eval.py | 2 +- 4 files changed, 75 insertions(+), 33 deletions(-) diff --git a/lerobot/common/datasets/pusht.py b/lerobot/common/datasets/pusht.py index b5d4fab3..1c1b658e 100644 --- a/lerobot/common/datasets/pusht.py +++ b/lerobot/common/datasets/pusht.py @@ -1,4 +1,5 @@ import logging +import math import os from pathlib import Path from typing import Callable @@ -134,18 +135,19 @@ class PushtExperienceReplay(TensorDictReplayBuffer): else: storage = TensorStorage(TensorDict.load_memmap(self.root / dataset_id)) - mean_std = self._compute_or_load_mean_std(storage) - mean_std["next", "observation", "image"] = mean_std["observation", "image"] - mean_std["next", "observation", "state"] = mean_std["observation", "state"] + stats = self._compute_or_load_stats(storage) + stats["next", "observation", "image"] = stats["observation", "image"] + stats["next", "observation", "state"] = stats["observation", "state"] transform = NormalizeTransform( - mean_std, + stats, in_keys=[ - ("observation", "image"), + # ("observation", "image"), ("observation", "state"), - ("next", "observation", "image"), + # ("next", "observation", "image"), ("next", "observation", "state"), ("action"), ], + mode="min_max", ) if writer is None: @@ -282,7 +284,7 @@ class PushtExperienceReplay(TensorDictReplayBuffer): return TensorStorage(td_data.lock_()) - def _compute_mean_std(self, storage, num_batch=10, batch_size=32): + def _compute_stats(self, storage, num_batch=100, batch_size=32): rb = TensorDictReplayBuffer( storage=storage, batch_size=batch_size, @@ -291,15 +293,27 @@ class PushtExperienceReplay(TensorDictReplayBuffer): batch = rb.sample() image_mean = torch.zeros(batch["observation", "image"].shape[1]) image_std = torch.zeros(batch["observation", "image"].shape[1]) + image_max = -math.inf + image_min = math.inf state_mean = torch.zeros(batch["observation", "state"].shape[1]) state_std = torch.zeros(batch["observation", "state"].shape[1]) + state_max = -math.inf + state_min = math.inf action_mean = torch.zeros(batch["action"].shape[1]) action_std = torch.zeros(batch["action"].shape[1]) + action_max = -math.inf + action_min = math.inf for _ in tqdm.tqdm(range(num_batch)): image_mean += einops.reduce(batch["observation", "image"], "b c h w -> c", reduction="mean") state_mean += batch["observation", "state"].mean(dim=0) action_mean += batch["action"].mean(dim=0) + image_max = max(image_max, batch["observation", 
"image"].max().item()) + image_min = min(image_min, batch["observation", "image"].min().item()) + state_max = max(state_max, batch["observation", "state"].max().item()) + state_min = min(state_min, batch["observation", "state"].min().item()) + action_max = max(action_max, batch["action"].max().item()) + action_min = min(action_min, batch["action"].min().item()) batch = rb.sample() image_mean /= num_batch @@ -311,6 +325,12 @@ class PushtExperienceReplay(TensorDictReplayBuffer): image_std += (image_mean_batch - image_mean) ** 2 state_std += (batch["observation", "state"].mean(dim=0) - state_mean) ** 2 action_std += (batch["action"].mean(dim=0) - action_mean) ** 2 + image_max = max(image_max, batch["observation", "image"].max().item()) + image_min = min(image_min, batch["observation", "image"].min().item()) + state_max = max(state_max, batch["observation", "state"].max().item()) + state_min = min(state_min, batch["observation", "state"].min().item()) + action_max = max(action_max, batch["action"].max().item()) + action_min = min(action_min, batch["action"].min().item()) if i < num_batch - 1: batch = rb.sample() @@ -318,25 +338,31 @@ class PushtExperienceReplay(TensorDictReplayBuffer): state_std = torch.sqrt(state_std / num_batch) action_std = torch.sqrt(action_std / num_batch) - mean_std = TensorDict( + stats = TensorDict( { ("observation", "image", "mean"): image_mean[None, :, None, None], ("observation", "image", "std"): image_std[None, :, None, None], + ("observation", "image", "max"): torch.tensor(image_max), + ("observation", "image", "min"): torch.tensor(image_min), ("observation", "state", "mean"): state_mean[None, :], ("observation", "state", "std"): state_std[None, :], + ("observation", "state", "max"): torch.tensor(state_max), + ("observation", "state", "min"): torch.tensor(state_min), ("action", "mean"): action_mean[None, :], ("action", "std"): action_std[None, :], + ("action", "max"): torch.tensor(action_max), + ("action", "min"): torch.tensor(action_min), }, batch_size=[], ) - return mean_std + return stats - def _compute_or_load_mean_std(self, storage) -> TensorDict: - mean_std_path = self.root / self.dataset_id / "mean_std.pth" - if mean_std_path.exists(): - mean_std = torch.load(mean_std_path) + def _compute_or_load_stats(self, storage) -> TensorDict: + stats_path = self.root / self.dataset_id / "stats.pth" + if stats_path.exists(): + stats = torch.load(stats_path) else: - logging.info(f"compute_mean_std and save to {mean_std_path}") - mean_std = self._compute_mean_std(storage) - torch.save(mean_std, mean_std_path) - return mean_std + logging.info(f"compute_stats and save to {stats_path}") + stats = self._compute_stats(storage) + torch.save(stats, stats_path) + return stats diff --git a/lerobot/common/envs/factory.py b/lerobot/common/envs/factory.py index 2cd4f73b..d7dc8aae 100644 --- a/lerobot/common/envs/factory.py +++ b/lerobot/common/envs/factory.py @@ -1,7 +1,5 @@ from torchrl.envs.transforms import StepCounter, TransformedEnv -from lerobot.common.envs.transforms import Prod - def make_env(cfg, transform=None): kwargs = { @@ -28,12 +26,8 @@ def make_env(cfg, transform=None): # limit rollout to max_steps env = TransformedEnv(env, StepCounter(max_steps=cfg.env.episode_length)) - if cfg.env.name == "pusht": - # to ensure pusht is in [0,255] like simxarm - env.append_transform(Prod(in_keys=[("observation", "image")], prod=255.0)) - if transform is not None: - # useful to add mean and std normalization + # useful to add normalization env.append_transform(transform) return env 
diff --git a/lerobot/common/envs/transforms.py b/lerobot/common/envs/transforms.py index 67601eac..671c0827 100644 --- a/lerobot/common/envs/transforms.py +++ b/lerobot/common/envs/transforms.py @@ -28,11 +28,12 @@ class NormalizeTransform(Transform): def __init__( self, - mean_std: TensorDictBase, + stats: TensorDictBase, in_keys: Sequence[NestedKey] = None, out_keys: Sequence[NestedKey] | None = None, in_keys_inv: Sequence[NestedKey] | None = None, out_keys_inv: Sequence[NestedKey] | None = None, + mode="mean_std", ): if out_keys is None: out_keys = in_keys @@ -43,7 +44,14 @@ class NormalizeTransform(Transform): super().__init__( in_keys=in_keys, out_keys=out_keys, in_keys_inv=in_keys_inv, out_keys_inv=out_keys_inv ) - self.mean_std = mean_std + self.stats = stats + assert mode in ["mean_std", "min_max"] + self.mode = mode + + def _reset(self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase) -> TensorDictBase: + # _reset is called once when the environment reset to normalize the first observation + tensordict_reset = self._call(tensordict_reset) + return tensordict_reset @dispatch(source="in_keys", dest="out_keys") def forward(self, tensordict: TensorDictBase) -> TensorDictBase: @@ -54,9 +62,17 @@ class NormalizeTransform(Transform): # TODO(rcadene): don't know how to do `inkey not in td` if td.get(inkey, None) is None: continue - mean = self.mean_std[inkey]["mean"] - std = self.mean_std[inkey]["std"] - td[outkey] = (td[inkey] - mean) / (std + 1e-8) + if self.mode == "mean_std": + mean = self.stats[inkey]["mean"] + std = self.stats[inkey]["std"] + td[outkey] = (td[inkey] - mean) / (std + 1e-8) + else: + min = self.stats[inkey]["min"] + max = self.stats[inkey]["max"] + # normalize to [0,1] + td[outkey] = (td[inkey] - min) / (max - min) + # normalize to [-1, 1] + td[outkey] = td[outkey] * 2 - 1 return td def _inv_call(self, td: TensorDictBase) -> TensorDictBase: @@ -64,7 +80,13 @@ class NormalizeTransform(Transform): # TODO(rcadene): don't know how to do `inkey not in td` if td.get(inkey, None) is None: continue - mean = self.mean_std[inkey]["mean"] - std = self.mean_std[inkey]["std"] - td[outkey] = td[inkey] * std + mean + if self.mode == "mean_std": + mean = self.stats[inkey]["mean"] + std = self.stats[inkey]["std"] + td[outkey] = td[inkey] * std + mean + else: + min = self.stats[inkey]["min"] + max = self.stats[inkey]["max"] + td[outkey] = (td[inkey] + 1) / 2 + td[outkey] = td[outkey] * (max - min) + min return td diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index 6391903e..214f5dba 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -118,7 +118,7 @@ def eval(cfg: dict, out_dir=None): offline_buffer = make_offline_buffer(cfg) logging.info("make_env") - env = make_env(cfg, transform=offline_buffer.transform) + env = make_env(cfg, transform=offline_buffer._transform) if cfg.policy.pretrained_model_path: policy = make_policy(cfg) From 0f2fa4d9ef732047d6637ae8086cb69c96b3c4a3 Mon Sep 17 00:00:00 2001 From: Remi Cadene Date: Sun, 3 Mar 2024 13:21:31 +0000 Subject: [PATCH 3/7] Add obs queue to pusht, Set n_obs_steps=2 for diffusion (Not fully tested) --- lerobot/common/envs/pusht.py | 84 +++++++++++++++++++-- lerobot/common/policies/diffusion/policy.py | 13 ++-- lerobot/configs/policy/diffusion.yaml | 2 +- 3 files changed, 85 insertions(+), 14 deletions(-) diff --git a/lerobot/common/envs/pusht.py b/lerobot/common/envs/pusht.py index 39bf3bba..ab979b38 100644 --- a/lerobot/common/envs/pusht.py +++ b/lerobot/common/envs/pusht.py @@ -1,4 +1,5 @@ 
import importlib +from collections import deque from typing import Optional import torch @@ -27,12 +28,16 @@ class PushtEnv(EnvBase): image_size=None, seed=1337, device="cpu", + num_prev_obs=1, + num_prev_action=0, ): super().__init__(device=device, batch_size=[]) self.frame_skip = frame_skip self.from_pixels = from_pixels self.pixels_only = pixels_only self.image_size = image_size + self.num_prev_obs = num_prev_obs + self.num_prev_action = num_prev_action if pixels_only: assert from_pixels @@ -56,6 +61,12 @@ class PushtEnv(EnvBase): self._make_spec() self._current_seed = self.set_seed(seed) + if self.num_prev_obs > 0: + self._prev_obs_image_queue = deque(maxlen=self.num_prev_obs) + self._prev_obs_state_queue = deque(maxlen=self.num_prev_obs) + if self.num_prev_action > 0: + self._prev_action_queue = deque(maxlen=self.num_prev_action) + def render(self, mode="rgb_array", width=384, height=384): if width != height: raise NotImplementedError() @@ -67,7 +78,8 @@ class PushtEnv(EnvBase): def _format_raw_obs(self, raw_obs): if self.from_pixels: - obs = {"image": torch.from_numpy(raw_obs["image"])} + image = torch.from_numpy(raw_obs["image"]) + obs = {"image": image} if not self.pixels_only: obs["state"] = torch.from_numpy(raw_obs["agent_pos"]).type(torch.float32) @@ -75,7 +87,6 @@ class PushtEnv(EnvBase): # TODO: obs = {"state": torch.from_numpy(raw_obs["observation"]).type(torch.float32)} - obs = TensorDict(obs, batch_size=[]) return obs def _reset(self, tensordict: Optional[TensorDict] = None): @@ -87,9 +98,21 @@ class PushtEnv(EnvBase): raw_obs = self._env.reset() assert self._current_seed == self._env._seed + obs = self._format_raw_obs(raw_obs) + + if self.num_prev_obs > 0: + # remove all previous observations + if "image" in obs: + self._prev_obs_image_queue.clear() + if "state" in obs: + self._prev_obs_state_queue.clear() + + # copy the current observation n times + obs = self._stack_prev_obs(obs) + td = TensorDict( { - "observation": self._format_raw_obs(raw_obs), + "observation": TensorDict(obs, batch_size=[]), "done": torch.tensor([False], dtype=torch.bool), }, batch_size=[], @@ -98,6 +121,40 @@ class PushtEnv(EnvBase): raise NotImplementedError() return td + def _stack_prev_obs(self, obs): + """When the queue is empty, copy the current observation n times.""" + assert self.num_prev_obs > 0 + + def stack_update_queue(prev_obs_queue, obs, num_prev_obs): + # get n most recent observations + prev_obs = list(prev_obs_queue)[-num_prev_obs:] + + # if not enough observations, copy the oldest observation until we obtain n observations + if len(prev_obs) == 0: + prev_obs = [obs] * num_prev_obs # queue is empty when env reset + elif len(prev_obs) < num_prev_obs: + prev_obs = [prev_obs[0] for _ in range(num_prev_obs - len(prev_obs))] + prev_obs + + # stack n most recent observations with the current observation + stacked_obs = torch.stack(prev_obs + [obs], dim=0) + + # add current observation to the queue + # automatically remove oldest observation when queue is full + prev_obs_queue.appendleft(obs) + + return stacked_obs + + stacked_obs = {} + if "image" in obs: + stacked_obs["image"] = stack_update_queue( + self._prev_obs_image_queue, obs["image"], self.num_prev_obs + ) + if "state" in obs: + stacked_obs["state"] = stack_update_queue( + self._prev_obs_state_queue, obs["state"], self.num_prev_obs + ) + return stacked_obs + def _step(self, tensordict: TensorDict): td = tensordict # remove batch dim @@ -109,9 +166,14 @@ class PushtEnv(EnvBase): raw_obs, reward, done, info = self._env.step(action) 
sum_reward += reward + obs = self._format_raw_obs(raw_obs) + + if self.num_prev_obs > 0: + obs = self._stack_prev_obs(obs) + td = TensorDict( { - "observation": self._format_raw_obs(raw_obs), + "observation": TensorDict(obs, batch_size=[]), "reward": torch.tensor([sum_reward], dtype=torch.float32), # succes and done are true when coverage > self.success_threshold in env "done": torch.tensor([done], dtype=torch.bool), @@ -124,14 +186,22 @@ class PushtEnv(EnvBase): def _make_spec(self): obs = {} if self.from_pixels: + image_shape = (3, self.image_size, self.image_size) + if self.num_prev_obs > 0: + image_shape = (self.num_prev_obs, *image_shape) + obs["image"] = BoundedTensorSpec( low=0, high=1, - shape=(3, self.image_size, self.image_size), + shape=image_shape, dtype=torch.float32, device=self.device, ) if not self.pixels_only: + state_shape = self._env.observation_space["agent_pos"].shape + if self.num_prev_obs > 0: + state_shape = (self.num_prev_obs, *state_shape) + obs["state"] = BoundedTensorSpec( low=0, high=512, @@ -141,6 +211,10 @@ class PushtEnv(EnvBase): ) else: # TODO(rcadene): add observation_space achieved_goal and desired_goal? + state_shape = self._env.observation_space["observation"].shape + if self.num_prev_obs > 0: + state_shape = (self.num_prev_obs, *state_shape) + obs["state"] = UnboundedContinuousTensorSpec( # TODO: shape=self._env.observation_space["observation"].shape, diff --git a/lerobot/common/policies/diffusion/policy.py b/lerobot/common/policies/diffusion/policy.py index a484c65a..aeec502e 100644 --- a/lerobot/common/policies/diffusion/policy.py +++ b/lerobot/common/policies/diffusion/policy.py @@ -1,7 +1,6 @@ import copy import time -import einops import hydra import torch import torch.nn as nn @@ -101,15 +100,13 @@ class DiffusionPolicy(nn.Module): # TODO(rcadene): remove unused step_count del step_count - # TODO(rcadene): remove unsqueeze hack... 
- if observation["image"].ndim == 3: - observation["image"] = observation["image"].unsqueeze(0) - observation["state"] = observation["state"].unsqueeze(0) + # TODO(rcadene): remove unsqueeze hack to add bsize=1 + observation["image"] = observation["image"].unsqueeze(0) + observation["state"] = observation["state"].unsqueeze(0) obs_dict = { - # TODO(rcadene): hack to add temporal dim - "image": einops.rearrange(observation["image"], "b c h w -> b 1 c h w"), - "agent_pos": einops.rearrange(observation["state"], "b c -> b 1 c"), + "image": observation["image"], + "agent_pos": observation["state"], } out = self.diffusion.predict_action(obs_dict) diff --git a/lerobot/configs/policy/diffusion.yaml b/lerobot/configs/policy/diffusion.yaml index 6f18816a..f136fa55 100644 --- a/lerobot/configs/policy/diffusion.yaml +++ b/lerobot/configs/policy/diffusion.yaml @@ -13,7 +13,7 @@ shape_meta: shape: [2] horizon: 16 -n_obs_steps: 1 # TODO(rcadene): before 2 +n_obs_steps: 2 n_action_steps: 8 n_latency_steps: 0 dataset_obs_steps: ${n_obs_steps} From 4c400b41a53238239a6c0986da0f483bf277012d Mon Sep 17 00:00:00 2001 From: Remi Cadene Date: Sun, 3 Mar 2024 13:22:09 +0000 Subject: [PATCH 4/7] Improve log msg in train.py --- lerobot/scripts/train.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index a537835e..1c63fc97 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -170,6 +170,7 @@ def train(cfg: dict, out_dir=None, job_name=None): log_train_info(logger, train_info, step, cfg, offline_buffer, is_offline) if step > 0 and step % cfg.eval_freq == 0: + logging.info(f"Eval policy at step {step}") eval_info, first_video = eval_policy( env, td_policy, @@ -179,10 +180,12 @@ def train(cfg: dict, out_dir=None, job_name=None): log_eval_info(logger, eval_info, step, cfg, offline_buffer, is_offline) if cfg.wandb.enable: logger.log_video(first_video, step, mode="eval") + logging.info("Resume training") if step > 0 and cfg.save_model and step % cfg.save_freq == 0: - logging.info(f"Checkpoint model at step {step}") + logging.info(f"Checkpoint policy at step {step}") logger.save_model(policy, identifier=step) + logging.info("Resume training") step += 1 @@ -227,6 +230,7 @@ def train(cfg: dict, out_dir=None, job_name=None): log_train_info(logger, train_info, step, cfg, offline_buffer, is_offline) if step > 0 and step % cfg.eval_freq == 0: + logging.info(f"Eval policy at step {step}") eval_info, first_video = eval_policy( env, td_policy, @@ -236,10 +240,12 @@ def train(cfg: dict, out_dir=None, job_name=None): log_eval_info(logger, eval_info, step, cfg, offline_buffer, is_offline) if cfg.wandb.enable: logger.log_video(first_video, step, mode="eval") + logging.info("Resume training") if step > 0 and cfg.save_model and step % cfg.save_freq == 0: - logging.info(f"Checkpoint model at step {step}") + logging.info(f"Checkpoint policy at step {step}") logger.save_model(policy, identifier=step) + logging.info("Resume training") step += 1 online_step += 1 From fddd9f0311e15ce36572d90f69e3c533398b8290 Mon Sep 17 00:00:00 2001 From: Remi Cadene Date: Sun, 3 Mar 2024 14:02:24 +0000 Subject: [PATCH 5/7] Add possibility for the policy to provide a sequence of actions to the env --- lerobot/common/envs/pusht.py | 15 +++++++++++---- lerobot/common/policies/diffusion/policy.py | 5 +---- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/lerobot/common/envs/pusht.py b/lerobot/common/envs/pusht.py index ab979b38..927a1ba7 100644 --- 
a/lerobot/common/envs/pusht.py +++ b/lerobot/common/envs/pusht.py @@ -157,13 +157,20 @@ class PushtEnv(EnvBase): def _step(self, tensordict: TensorDict): td = tensordict - # remove batch dim - action = td["action"].squeeze(0).numpy() + action = td["action"].numpy() # step expects shape=(4,) so we pad if necessary # TODO(rcadene): add info["is_success"] and info["success"] ? sum_reward = 0 - for _ in range(self.frame_skip): - raw_obs, reward, done, info = self._env.step(action) + + if action.ndim == 1: + action = action.repeat(self.frame_skip, 1) + else: + if self.frame_skip > 1: + raise NotImplementedError() + + num_action_steps = action.shape[0] + for i in range(num_action_steps): + raw_obs, reward, done, info = self._env.step(action[i]) sum_reward += reward obs = self._format_raw_obs(raw_obs) diff --git a/lerobot/common/policies/diffusion/policy.py b/lerobot/common/policies/diffusion/policy.py index aeec502e..df05bfd8 100644 --- a/lerobot/common/policies/diffusion/policy.py +++ b/lerobot/common/policies/diffusion/policy.py @@ -12,8 +12,6 @@ from diffusion_policy.model.vision.model_getter import get_resnet from .diffusion_unet_image_policy import DiffusionUnetImagePolicy from .multi_image_obs_encoder import MultiImageObsEncoder -FIRST_ACTION = 0 - class DiffusionPolicy(nn.Module): def __init__( @@ -110,8 +108,7 @@ class DiffusionPolicy(nn.Module): } out = self.diffusion.predict_action(obs_dict) - # TODO(rcadene): add possibility to return >1 timestemps - action = out["action"].squeeze(0)[FIRST_ACTION] + action = out["action"].squeeze(0) return action def update(self, replay_buffer, step): From cfc304e870176a83c675c6b91fc760666b0c760f Mon Sep 17 00:00:00 2001 From: Remi Cadene Date: Mon, 4 Mar 2024 10:59:43 +0000 Subject: [PATCH 6/7] Refactor env queue, Training diffusion works (Still not converging) --- lerobot/common/datasets/factory.py | 4 +- lerobot/common/datasets/pusht.py | 13 ++++- lerobot/common/envs/factory.py | 4 ++ lerobot/common/envs/pusht.py | 59 +++++++-------------- lerobot/common/logger.py | 46 +--------------- lerobot/common/policies/diffusion/policy.py | 34 +++++++++--- lerobot/configs/default.yaml | 5 +- lerobot/configs/policy/diffusion.yaml | 18 +++---- lerobot/configs/policy/tdmpc.yaml | 2 + lerobot/scripts/eval.py | 2 +- lerobot/scripts/train.py | 20 +++++-- 11 files changed, 96 insertions(+), 111 deletions(-) diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py index 942a36dd..b6b93d89 100644 --- a/lerobot/common/datasets/factory.py +++ b/lerobot/common/datasets/factory.py @@ -69,7 +69,7 @@ def make_offline_buffer(cfg, sampler=None): sampler=sampler, batch_size=batch_size, pin_memory=pin_memory, - prefetch=prefetch, + prefetch=prefetch if isinstance(prefetch, int) else None, ) elif cfg.env.name == "pusht": offline_buffer = PushtExperienceReplay( @@ -79,7 +79,7 @@ def make_offline_buffer(cfg, sampler=None): sampler=sampler, batch_size=batch_size, pin_memory=pin_memory, - prefetch=prefetch, + prefetch=prefetch if isinstance(prefetch, int) else None, ) else: raise ValueError(cfg.env.name) diff --git a/lerobot/common/datasets/pusht.py b/lerobot/common/datasets/pusht.py index 1c1b658e..8ea64f86 100644 --- a/lerobot/common/datasets/pusht.py +++ b/lerobot/common/datasets/pusht.py @@ -143,13 +143,24 @@ class PushtExperienceReplay(TensorDictReplayBuffer): in_keys=[ # ("observation", "image"), ("observation", "state"), + # TODO(rcadene): for tdmpc, we might want image and state # ("next", "observation", "image"), - ("next", "observation", 
"state"), + # ("next", "observation", "state"), ("action"), ], mode="min_max", ) + # TODO(rcadene): make normalization strategy configurable between mean_std, min_max, min_max_spec + transform.stats["observation", "state", "min"] = torch.tensor( + [13.456424, 32.938293], dtype=torch.float32 + ) + transform.stats["observation", "state", "max"] = torch.tensor( + [496.14618, 510.9579], dtype=torch.float32 + ) + transform.stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32) + transform.stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32) + if writer is None: writer = ImmutableDatasetWriter() if collate_fn is None: diff --git a/lerobot/common/envs/factory.py b/lerobot/common/envs/factory.py index d7dc8aae..acf089f6 100644 --- a/lerobot/common/envs/factory.py +++ b/lerobot/common/envs/factory.py @@ -7,6 +7,8 @@ def make_env(cfg, transform=None): "from_pixels": cfg.env.from_pixels, "pixels_only": cfg.env.pixels_only, "image_size": cfg.env.image_size, + # TODO(rcadene): do we want a specific eval_env_seed? + "seed": cfg.seed, } if cfg.env.name == "simxarm": @@ -17,6 +19,8 @@ def make_env(cfg, transform=None): elif cfg.env.name == "pusht": from lerobot.common.envs.pusht import PushtEnv + # assert kwargs["seed"] > 200, "Seed 0-200 are used for the demonstration dataset, so we don't want to seed the eval env with this range." + clsfunc = PushtEnv else: raise ValueError(cfg.env.name) diff --git a/lerobot/common/envs/pusht.py b/lerobot/common/envs/pusht.py index 927a1ba7..62aa8d1b 100644 --- a/lerobot/common/envs/pusht.py +++ b/lerobot/common/envs/pusht.py @@ -101,14 +101,18 @@ class PushtEnv(EnvBase): obs = self._format_raw_obs(raw_obs) if self.num_prev_obs > 0: - # remove all previous observations + stacked_obs = {} if "image" in obs: - self._prev_obs_image_queue.clear() + self._prev_obs_image_queue = deque( + [obs["image"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1) + ) + stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue)) if "state" in obs: - self._prev_obs_state_queue.clear() - - # copy the current observation n times - obs = self._stack_prev_obs(obs) + self._prev_obs_state_queue = deque( + [obs["state"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1) + ) + stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue)) + obs = stacked_obs td = TensorDict( { @@ -121,40 +125,6 @@ class PushtEnv(EnvBase): raise NotImplementedError() return td - def _stack_prev_obs(self, obs): - """When the queue is empty, copy the current observation n times.""" - assert self.num_prev_obs > 0 - - def stack_update_queue(prev_obs_queue, obs, num_prev_obs): - # get n most recent observations - prev_obs = list(prev_obs_queue)[-num_prev_obs:] - - # if not enough observations, copy the oldest observation until we obtain n observations - if len(prev_obs) == 0: - prev_obs = [obs] * num_prev_obs # queue is empty when env reset - elif len(prev_obs) < num_prev_obs: - prev_obs = [prev_obs[0] for _ in range(num_prev_obs - len(prev_obs))] + prev_obs - - # stack n most recent observations with the current observation - stacked_obs = torch.stack(prev_obs + [obs], dim=0) - - # add current observation to the queue - # automatically remove oldest observation when queue is full - prev_obs_queue.appendleft(obs) - - return stacked_obs - - stacked_obs = {} - if "image" in obs: - stacked_obs["image"] = stack_update_queue( - self._prev_obs_image_queue, obs["image"], self.num_prev_obs - ) - if "state" in obs: - stacked_obs["state"] = stack_update_queue( 
- self._prev_obs_state_queue, obs["state"], self.num_prev_obs - ) - return stacked_obs - def _step(self, tensordict: TensorDict): td = tensordict action = td["action"].numpy() @@ -176,7 +146,14 @@ class PushtEnv(EnvBase): obs = self._format_raw_obs(raw_obs) if self.num_prev_obs > 0: - obs = self._stack_prev_obs(obs) + stacked_obs = {} + if "image" in obs: + self._prev_obs_image_queue.append(obs["image"]) + stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue)) + if "state" in obs: + self._prev_obs_state_queue.append(obs["state"]) + stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue)) + obs = stacked_obs td = TensorDict( { diff --git a/lerobot/common/logger.py b/lerobot/common/logger.py index a8cb6a66..54325bd4 100644 --- a/lerobot/common/logger.py +++ b/lerobot/common/logger.py @@ -1,51 +1,11 @@ -import contextlib +import logging import os from pathlib import Path -import numpy as np from omegaconf import OmegaConf from termcolor import colored -def make_dir(dir_path): - """Create directory if it does not already exist.""" - with contextlib.suppress(OSError): - dir_path.mkdir(parents=True, exist_ok=True) - return dir_path - - -def print_run(cfg, reward=None): - """Pretty-printing of run information. Call at start of training.""" - prefix, color, attrs = " ", "green", ["bold"] - - def limstr(s, maxlen=32): - return str(s[:maxlen]) + "..." if len(str(s)) > maxlen else s - - def pprint(k, v): - print( - prefix + colored(f'{k.capitalize() + ":":<16}', color, attrs=attrs), - limstr(v), - ) - - kvs = [ - ("task", cfg.env.task), - ("offline_steps", f"{cfg.offline_steps}"), - ("online_steps", f"{cfg.online_steps}"), - ("action_repeat", f"{cfg.env.action_repeat}"), - # ('observations', 'x'.join([str(s) for s in cfg.obs_shape])), - # ('actions', cfg.action_dim), - # ('experiment', cfg.exp_name), - ] - if reward is not None: - kvs.append(("episode reward", colored(str(int(reward)), "white", attrs=["bold"]))) - w = np.max([len(limstr(str(kv[1]))) for kv in kvs]) + 21 - div = "-" * w - print(div) - for k, v in kvs: - pprint(k, v) - print(div) - - def cfg_to_group(cfg, return_list=False): """Return a wandb-safe group name for logging. 
Optionally returns group name as list.""" # lst = [cfg.task, cfg.modality, re.sub("[^0-9a-zA-Z]+", "-", cfg.exp_name)] @@ -71,13 +31,12 @@ class Logger: self._seed = cfg.seed self._cfg = cfg self._eval = [] - print_run(cfg) project = cfg.get("wandb", {}).get("project") entity = cfg.get("wandb", {}).get("entity") enable_wandb = cfg.get("wandb", {}).get("enable", False) run_offline = not enable_wandb or not project or not entity if run_offline: - print(colored("Logs will be saved locally.", "yellow", attrs=["bold"])) + logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"])) self._wandb = None else: os.environ["WANDB_SILENT"] = "true" @@ -134,7 +93,6 @@ class Logger: self.save_buffer(buffer, identifier="buffer") if self._wandb: self._wandb.finish() - print_run(self._cfg, self._eval[-1][-1]) def log_dict(self, d, step, mode="train"): assert mode in {"train", "eval"} diff --git a/lerobot/common/policies/diffusion/policy.py b/lerobot/common/policies/diffusion/policy.py index df05bfd8..7ae0a529 100644 --- a/lerobot/common/policies/diffusion/policy.py +++ b/lerobot/common/policies/diffusion/policy.py @@ -4,10 +4,8 @@ import time import hydra import torch import torch.nn as nn -from diffusers.schedulers.scheduling_ddpm import DDPMScheduler from diffusion_policy.model.common.lr_scheduler import get_scheduler -from diffusion_policy.model.vision.model_getter import get_resnet from .diffusion_unet_image_policy import DiffusionUnetImagePolicy from .multi_image_obs_encoder import MultiImageObsEncoder @@ -39,8 +37,8 @@ class DiffusionPolicy(nn.Module): super().__init__() self.cfg = cfg - noise_scheduler = DDPMScheduler(**cfg_noise_scheduler) - rgb_model = get_resnet(**cfg_rgb_model) + noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler) + rgb_model = hydra.utils.instantiate(cfg_rgb_model) obs_encoder = MultiImageObsEncoder( rgb_model=rgb_model, **cfg_obs_encoder, @@ -127,16 +125,36 @@ class DiffusionPolicy(nn.Module): # (t h) ... -> t h ... 
batch = batch.reshape(num_slices, horizon) # .transpose(1, 0).contiguous() + # |-1|0|1|2|3|4|5|6|7|8|9|10|11|12|13|14| timestamps: 16 + # |o|o| observations: 2 + # | |a|a|a|a|a|a|a|a| actions executed: 8 + # |p|p|p|p|p|p|p|p|p|p|p| p| p| p| p| p| actions predicted: 16 + # note: we predict the action needed to go from t=-1 to t=0 similarly to an inverse kinematic model + + image = batch["observation", "image"] + state = batch["observation", "state"] + action = batch["action"] + assert image.shape[1] == horizon + assert state.shape[1] == horizon + assert action.shape[1] == horizon + + if not (horizon == 16 and self.cfg.n_obs_steps == 2): + raise NotImplementedError() + + # keep first 2 observations of the slice corresponding to t=[-1,0] + image = image[:, : self.cfg.n_obs_steps] + state = state[:, : self.cfg.n_obs_steps] + out = { "obs": { - "image": batch["observation", "image"].to(self.device, non_blocking=True), - "agent_pos": batch["observation", "state"].to(self.device, non_blocking=True), + "image": image.to(self.device, non_blocking=True), + "agent_pos": state.to(self.device, non_blocking=True), }, - "action": batch["action"].to(self.device, non_blocking=True), + "action": action.to(self.device, non_blocking=True), } return out - batch = replay_buffer.sample(batch_size) if self.cfg.balanced_sampling else replay_buffer.sample() + batch = replay_buffer.sample(batch_size) batch = process_batch(batch, self.cfg.horizon, num_slices) data_s = time.time() - start_time diff --git a/lerobot/configs/default.yaml b/lerobot/configs/default.yaml index 97f560f5..66e3dff1 100644 --- a/lerobot/configs/default.yaml +++ b/lerobot/configs/default.yaml @@ -1,7 +1,7 @@ defaults: - _self_ - - env: simxarm - - policy: tdmpc + - env: pusht + - policy: diffusion hydra: run: @@ -22,6 +22,7 @@ save_buffer: false train_steps: ??? fps: ??? +n_action_steps: ??? env: ??? policy: ??? 
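A standalone sketch (not part of the patch) of the slicing described by the timestamp diagram above, assuming horizon=16 and n_obs_steps=2: only the first two frames of each sampled slice condition the policy, while the full 16-step action sequence is supervised. Shapes are illustrative.

import torch

batch_size, horizon, n_obs_steps = 8, 16, 2
image = torch.zeros(batch_size, horizon, 3, 96, 96)
state = torch.zeros(batch_size, horizon, 2)
action = torch.zeros(batch_size, horizon, 2)

obs = {
    "image": image[:, :n_obs_steps],      # (8, 2, 3, 96, 96) -> frames at t=-1 and t=0
    "agent_pos": state[:, :n_obs_steps],  # (8, 2, 2)
}
target = action                           # (8, 16, 2) -> full predicted action horizon
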
diff --git a/lerobot/configs/policy/diffusion.yaml b/lerobot/configs/policy/diffusion.yaml index f136fa55..da1b6545 100644 --- a/lerobot/configs/policy/diffusion.yaml +++ b/lerobot/configs/policy/diffusion.yaml @@ -21,7 +21,7 @@ past_action_visible: False keypoint_visible_rate: 1.0 obs_as_global_cond: True -eval_episodes: 50 +eval_episodes: 1 eval_freq: 10000 save_freq: 100000 log_freq: 250 @@ -40,8 +40,8 @@ policy: num_inference_steps: 100 obs_as_global_cond: ${obs_as_global_cond} # crop_shape: null - diffusion_step_embed_dim: 128 - down_dims: [512, 1024, 2048] + diffusion_step_embed_dim: 256 # before 128 + down_dims: [256, 512, 1024] # before [512, 1024, 2048] kernel_size: 5 n_groups: 8 cond_predict_scale: True @@ -62,7 +62,7 @@ policy: grad_clip_norm: 0 noise_scheduler: - # _target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler + _target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler num_train_timesteps: 100 beta_start: 0.0001 beta_end: 0.02 @@ -74,16 +74,16 @@ noise_scheduler: obs_encoder: # _target_: diffusion_policy.model.vision.multi_image_obs_encoder.MultiImageObsEncoder shape_meta: ${shape_meta} - resize_shape: null - crop_shape: [76, 76] + # resize_shape: null + # crop_shape: [76, 76] # constant center crop - random_crop: True + # random_crop: True use_group_norm: True share_rgb_model: False - imagenet_norm: False # TODO(rcadene): was set to True + imagenet_norm: True rgb_model: - #_target_: diffusion_policy.model.vision.model_getter.get_resnet + _target_: diffusion_policy.model.vision.model_getter.get_resnet name: resnet18 weights: null diff --git a/lerobot/configs/policy/tdmpc.yaml b/lerobot/configs/policy/tdmpc.yaml index 26dc4e51..16b7018e 100644 --- a/lerobot/configs/policy/tdmpc.yaml +++ b/lerobot/configs/policy/tdmpc.yaml @@ -1,5 +1,7 @@ # @package _global_ +n_action_steps: 1 + policy: name: tdmpc diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index 214f5dba..abe4645a 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -137,7 +137,7 @@ def eval(cfg: dict, out_dir=None): save_video=True, video_dir=Path(out_dir) / "eval", fps=cfg.env.fps, - max_steps=cfg.env.episode_length, + max_steps=cfg.env.episode_length // cfg.n_action_steps, num_episodes=cfg.eval_episodes, ) print(metrics) diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 1c63fc97..1557495b 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -119,7 +119,6 @@ def train(cfg: dict, out_dir=None, job_name=None): torch.backends.cudnn.benchmark = True torch.backends.cuda.matmul.allow_tf32 = True set_seed(cfg.seed) - logging.info(colored("Work dir:", "yellow", attrs=["bold"]) + f" {out_dir}") logging.info("make_offline_buffer") offline_buffer = make_offline_buffer(cfg) @@ -149,6 +148,9 @@ def train(cfg: dict, out_dir=None, job_name=None): logging.info("make_policy") policy = make_policy(cfg) + num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad) + num_total_params = sum(p.numel() for p in policy.parameters()) + td_policy = TensorDictModule( policy, in_keys=["observation", "step_count"], @@ -158,6 +160,16 @@ def train(cfg: dict, out_dir=None, job_name=None): # log metrics to terminal and wandb logger = Logger(out_dir, job_name, cfg) + logging.info(colored("Work dir:", "yellow", attrs=["bold"]) + f" {out_dir}") + logging.info(f"{cfg.env.task=}") + logging.info(f"{cfg.offline_steps=} ({format_big_number(cfg.offline_steps)})") + logging.info(f"{cfg.online_steps=}") + logging.info(f"{cfg.env.action_repeat=}") + 
logging.info(f"{offline_buffer.num_samples=} ({format_big_number(offline_buffer.num_samples)})") + logging.info(f"{offline_buffer.num_episodes=}") + logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})") + logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})") + step = 0 # number of policy update is_offline = True @@ -175,6 +187,7 @@ def train(cfg: dict, out_dir=None, job_name=None): env, td_policy, num_episodes=cfg.eval_episodes, + max_steps=cfg.env.episode_length // cfg.n_action_steps, return_first_video=True, ) log_eval_info(logger, eval_info, step, cfg, offline_buffer, is_offline) @@ -199,11 +212,11 @@ def train(cfg: dict, out_dir=None, job_name=None): # TODO: add configurable number of rollout? (default=1) with torch.no_grad(): rollout = env.rollout( - max_steps=cfg.env.episode_length, + max_steps=cfg.env.episode_length // cfg.n_action_steps, policy=td_policy, auto_cast_to_device=True, ) - assert len(rollout) <= cfg.env.episode_length + assert len(rollout) <= cfg.env.episode_length // cfg.n_action_steps # set same episode index for all time steps contained in this rollout rollout["episode"] = torch.tensor([env_step] * len(rollout), dtype=torch.int) online_buffer.extend(rollout) @@ -235,6 +248,7 @@ def train(cfg: dict, out_dir=None, job_name=None): env, td_policy, num_episodes=cfg.eval_episodes, + max_steps=cfg.env.episode_length // cfg.n_action_steps, return_first_video=True, ) log_eval_info(logger, eval_info, step, cfg, offline_buffer, is_offline) From e29fbb50e8230211871197214405cababcad0202 Mon Sep 17 00:00:00 2001 From: Remi Cadene Date: Mon, 4 Mar 2024 17:26:34 +0000 Subject: [PATCH 7/7] Fix grad_clip_norm 0 -> 10, Fix normalization min_max to be per channel --- lerobot/common/datasets/pusht.py | 108 ++++++++++++++++---------- lerobot/configs/policy/diffusion.yaml | 2 +- 2 files changed, 68 insertions(+), 42 deletions(-) diff --git a/lerobot/common/datasets/pusht.py b/lerobot/common/datasets/pusht.py index 8ea64f86..11569ee2 100644 --- a/lerobot/common/datasets/pusht.py +++ b/lerobot/common/datasets/pusht.py @@ -136,14 +136,14 @@ class PushtExperienceReplay(TensorDictReplayBuffer): storage = TensorStorage(TensorDict.load_memmap(self.root / dataset_id)) stats = self._compute_or_load_stats(storage) - stats["next", "observation", "image"] = stats["observation", "image"] - stats["next", "observation", "state"] = stats["observation", "state"] transform = NormalizeTransform( stats, in_keys=[ + # TODO(rcadene): imagenet normalization is applied inside diffusion policy + # We need to automate this for tdmpc and others # ("observation", "image"), ("observation", "state"), - # TODO(rcadene): for tdmpc, we might want image and state + # TODO(rcadene): for tdmpc, we might want next image and state # ("next", "observation", "image"), # ("next", "observation", "state"), ("action"), @@ -151,7 +151,7 @@ class PushtExperienceReplay(TensorDictReplayBuffer): mode="min_max", ) - # TODO(rcadene): make normalization strategy configurable between mean_std, min_max, min_max_spec + # TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max, min_max_from_spec transform.stats["observation", "state", "min"] = torch.tensor( [13.456424, 32.938293], dtype=torch.float32 ) @@ -302,29 +302,43 @@ class PushtExperienceReplay(TensorDictReplayBuffer): prefetch=True, ) batch = rb.sample() - image_mean = torch.zeros(batch["observation", "image"].shape[1]) - image_std = torch.zeros(batch["observation", "image"].shape[1]) - 
image_max = -math.inf - image_min = math.inf - state_mean = torch.zeros(batch["observation", "state"].shape[1]) - state_std = torch.zeros(batch["observation", "state"].shape[1]) - state_max = -math.inf - state_min = math.inf - action_mean = torch.zeros(batch["action"].shape[1]) - action_std = torch.zeros(batch["action"].shape[1]) - action_max = -math.inf - action_min = math.inf + + image_channels = batch["observation", "image"].shape[1] + image_mean = torch.zeros(image_channels) + image_std = torch.zeros(image_channels) + image_max = torch.tensor([-math.inf] * image_channels) + image_min = torch.tensor([math.inf] * image_channels) + + state_channels = batch["observation", "state"].shape[1] + state_mean = torch.zeros(state_channels) + state_std = torch.zeros(state_channels) + state_max = torch.tensor([-math.inf] * state_channels) + state_min = torch.tensor([math.inf] * state_channels) + + action_channels = batch["action"].shape[1] + action_mean = torch.zeros(action_channels) + action_std = torch.zeros(action_channels) + action_max = torch.tensor([-math.inf] * action_channels) + action_min = torch.tensor([math.inf] * action_channels) for _ in tqdm.tqdm(range(num_batch)): - image_mean += einops.reduce(batch["observation", "image"], "b c h w -> c", reduction="mean") - state_mean += batch["observation", "state"].mean(dim=0) - action_mean += batch["action"].mean(dim=0) - image_max = max(image_max, batch["observation", "image"].max().item()) - image_min = min(image_min, batch["observation", "image"].min().item()) - state_max = max(state_max, batch["observation", "state"].max().item()) - state_min = min(state_min, batch["observation", "state"].min().item()) - action_max = max(action_max, batch["action"].max().item()) - action_min = min(action_min, batch["action"].min().item()) + image_mean += einops.reduce(batch["observation", "image"], "b c h w -> c", "mean") + state_mean += einops.reduce(batch["observation", "state"], "b c -> c", "mean") + action_mean += einops.reduce(batch["action"], "b c -> c", "mean") + + b_image_max = einops.reduce(batch["observation", "image"], "b c h w -> c", "max") + b_image_min = einops.reduce(batch["observation", "image"], "b c h w -> c", "min") + b_state_max = einops.reduce(batch["observation", "state"], "b c -> c", "max") + b_state_min = einops.reduce(batch["observation", "state"], "b c -> c", "min") + b_action_max = einops.reduce(batch["action"], "b c -> c", "max") + b_action_min = einops.reduce(batch["action"], "b c -> c", "min") + image_max = torch.maximum(image_max, b_image_max) + image_min = torch.maximum(image_min, b_image_min) + state_max = torch.maximum(state_max, b_state_max) + state_min = torch.maximum(state_min, b_state_min) + action_max = torch.maximum(action_max, b_action_max) + action_min = torch.maximum(action_min, b_action_min) + batch = rb.sample() image_mean /= num_batch @@ -332,16 +346,26 @@ class PushtExperienceReplay(TensorDictReplayBuffer): action_mean /= num_batch for i in tqdm.tqdm(range(num_batch)): - image_mean_batch = einops.reduce(batch["observation", "image"], "b c h w -> c", reduction="mean") - image_std += (image_mean_batch - image_mean) ** 2 - state_std += (batch["observation", "state"].mean(dim=0) - state_mean) ** 2 - action_std += (batch["action"].mean(dim=0) - action_mean) ** 2 - image_max = max(image_max, batch["observation", "image"].max().item()) - image_min = min(image_min, batch["observation", "image"].min().item()) - state_max = max(state_max, batch["observation", "state"].max().item()) - state_min = min(state_min, 
batch["observation", "state"].min().item()) - action_max = max(action_max, batch["action"].max().item()) - action_min = min(action_min, batch["action"].min().item()) + b_image_mean = einops.reduce(batch["observation", "image"], "b c h w -> c", "mean") + b_state_mean = einops.reduce(batch["observation", "state"], "b c -> c", "mean") + b_action_mean = einops.reduce(batch["action"], "b c -> c", "mean") + image_std += (b_image_mean - image_mean) ** 2 + state_std += (b_state_mean - state_mean) ** 2 + action_std += (b_action_mean - action_mean) ** 2 + + b_image_max = einops.reduce(batch["observation", "image"], "b c h w -> c", "max") + b_image_min = einops.reduce(batch["observation", "image"], "b c h w -> c", "min") + b_state_max = einops.reduce(batch["observation", "state"], "b c -> c", "max") + b_state_min = einops.reduce(batch["observation", "state"], "b c -> c", "min") + b_action_max = einops.reduce(batch["action"], "b c -> c", "max") + b_action_min = einops.reduce(batch["action"], "b c -> c", "min") + image_max = torch.maximum(image_max, b_image_max) + image_min = torch.maximum(image_min, b_image_min) + state_max = torch.maximum(state_max, b_state_max) + state_min = torch.maximum(state_min, b_state_min) + action_max = torch.maximum(action_max, b_action_max) + action_min = torch.maximum(action_min, b_action_min) + if i < num_batch - 1: batch = rb.sample() @@ -353,19 +377,21 @@ class PushtExperienceReplay(TensorDictReplayBuffer): { ("observation", "image", "mean"): image_mean[None, :, None, None], ("observation", "image", "std"): image_std[None, :, None, None], - ("observation", "image", "max"): torch.tensor(image_max), - ("observation", "image", "min"): torch.tensor(image_min), + ("observation", "image", "max"): image_max[None, :, None, None], + ("observation", "image", "min"): image_min[None, :, None, None], ("observation", "state", "mean"): state_mean[None, :], ("observation", "state", "std"): state_std[None, :], - ("observation", "state", "max"): torch.tensor(state_max), - ("observation", "state", "min"): torch.tensor(state_min), + ("observation", "state", "max"): state_max[None, :], + ("observation", "state", "min"): state_min[None, :], ("action", "mean"): action_mean[None, :], ("action", "std"): action_std[None, :], - ("action", "max"): torch.tensor(action_max), - ("action", "min"): torch.tensor(action_min), + ("action", "max"): action_max[None, :], + ("action", "min"): action_min[None, :], }, batch_size=[], ) + stats["next", "observation", "image"] = stats["observation", "image"] + stats["next", "observation", "state"] = stats["observation", "state"] return stats def _compute_or_load_stats(self, storage) -> TensorDict: diff --git a/lerobot/configs/policy/diffusion.yaml b/lerobot/configs/policy/diffusion.yaml index da1b6545..0ea8f638 100644 --- a/lerobot/configs/policy/diffusion.yaml +++ b/lerobot/configs/policy/diffusion.yaml @@ -59,7 +59,7 @@ policy: use_ema: true lr_scheduler: cosine lr_warmup_steps: 500 - grad_clip_norm: 0 + grad_clip_norm: 10 noise_scheduler: _target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler