From 1ae6205269094c8b31ceb94db7b95a50d234c08a Mon Sep 17 00:00:00 2001 From: Remi Cadene Date: Sat, 2 Mar 2024 15:53:29 +0000 Subject: [PATCH] Add Normalize, non_blocking=True in tdmpc, tdmpc run (TODO: diffusion) --- lerobot/common/datasets/factory.py | 7 +- lerobot/common/datasets/pusht.py | 102 +++++++++++++++++++++-------- lerobot/common/datasets/utils.py | 3 + lerobot/common/envs/factory.py | 6 +- lerobot/common/envs/pusht.py | 3 +- lerobot/common/envs/transforms.py | 47 +++++++++++++ lerobot/common/policies/tdmpc.py | 47 ++++++------- lerobot/scripts/eval.py | 8 ++- lerobot/scripts/train.py | 27 ++++---- 9 files changed, 183 insertions(+), 67 deletions(-) diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py index 9fc0d2da..942a36dd 100644 --- a/lerobot/common/datasets/factory.py +++ b/lerobot/common/datasets/factory.py @@ -1,3 +1,4 @@ +import os from pathlib import Path import torch @@ -6,7 +7,7 @@ from torchrl.data.replay_buffers import PrioritizedSliceSampler from lerobot.common.datasets.pusht import PushtExperienceReplay from lerobot.common.datasets.simxarm import SimxarmExperienceReplay -DATA_PATH = Path("data/") +DATA_DIR = Path(os.environ.get("DATA_DIR", "data")) # TODO(rcadene): implement @@ -64,7 +65,7 @@ def make_offline_buffer(cfg, sampler=None): # download="force", download=True, streaming=False, - root=str(DATA_PATH), + root=str(DATA_DIR), sampler=sampler, batch_size=batch_size, pin_memory=pin_memory, @@ -74,7 +75,7 @@ def make_offline_buffer(cfg, sampler=None): offline_buffer = PushtExperienceReplay( "pusht", streaming=False, - root=DATA_PATH, + root=DATA_DIR, sampler=sampler, batch_size=batch_size, pin_memory=pin_memory, diff --git a/lerobot/common/datasets/pusht.py b/lerobot/common/datasets/pusht.py index 3fb0b20e..744b4f07 100644 --- a/lerobot/common/datasets/pusht.py +++ b/lerobot/common/datasets/pusht.py @@ -1,3 +1,4 @@ +import logging import os from pathlib import Path from typing import Callable @@ -16,9 +17,10 @@ from torchrl.data.replay_buffers.samplers import Sampler from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer -from diffusion_policy.common.replay_buffer import ReplayBuffer +from diffusion_policy.common.replay_buffer import ReplayBuffer as DiffusionPolicyReplayBuffer from diffusion_policy.env.pusht.pusht_env import pymunk_to_shapely -from lerobot.common.datasets import utils +from lerobot.common.datasets.utils import download_and_extract_zip +from lerobot.common.envs.transforms import NormalizeTransform # as define in env SUCCESS_THRESHOLD = 0.95 # 95% coverage, @@ -132,29 +134,16 @@ class PushtExperienceReplay(TensorDictReplayBuffer): else: storage = TensorStorage(TensorDict.load_memmap(self.root / dataset_id)) - # if num_slices is not None or slice_len is not None: - # if sampler is not None: - # raise ValueError( - # "`num_slices` and `slice_len` are exclusive with the `sampler` argument." - # ) - - # if replacement: - # if not self.shuffle: - # raise RuntimeError( - # "shuffle=False can only be used when replacement=False." 
- # ) - # sampler = SliceSampler( - # num_slices=num_slices, - # slice_len=slice_len, - # strict_length=strict_length, - # ) - # else: - # sampler = SliceSamplerWithoutReplacement( - # num_slices=num_slices, - # slice_len=slice_len, - # strict_length=strict_length, - # shuffle=self.shuffle, - # ) + mean_std = self._compute_or_load_mean_std(storage) + mean_std["next", "observation", "image"] = mean_std["observation", "image"] + mean_std["next", "observation", "state"] = mean_std["observation", "state"] + transform = NormalizeTransform(mean_std, in_keys=[ + ("observation", "image"), + ("observation", "state"), + ("next", "observation", "image"), + ("next", "observation", "state"), + ("action"), + ]) if writer is None: writer = ImmutableDatasetWriter() @@ -193,10 +182,10 @@ class PushtExperienceReplay(TensorDictReplayBuffer): zarr_path = (raw_dir / PUSHT_ZARR).resolve() if not zarr_path.is_dir(): raw_dir.mkdir(parents=True, exist_ok=True) - utils.download_and_extract_zip(PUSHT_URL, raw_dir) + download_and_extract_zip(PUSHT_URL, raw_dir) # load - dataset_dict = ReplayBuffer.copy_from_path(zarr_path) # , keys=['img', 'state', 'action']) + dataset_dict = DiffusionPolicyReplayBuffer.copy_from_path(zarr_path) # , keys=['img', 'state', 'action']) episode_ids = torch.from_numpy(dataset_dict.get_episode_idxs()) num_episodes = dataset_dict.meta["episode_ends"].shape[0] @@ -287,3 +276,62 @@ class PushtExperienceReplay(TensorDictReplayBuffer): idxtd = idxtd + len(episode) return TensorStorage(td_data.lock_()) + + def _compute_mean_std(self, storage, num_batch=10, batch_size=32): + rb = TensorDictReplayBuffer( + storage=storage, + batch_size=batch_size, + prefetch=True, + ) + batch = rb.sample() + image_mean = torch.zeros(batch["observation", "image"].shape[1]) + image_std = torch.zeros(batch["observation", "image"].shape[1]) + state_mean = torch.zeros(batch["observation", "state"].shape[1]) + state_std = torch.zeros(batch["observation", "state"].shape[1]) + action_mean = torch.zeros(batch["action"].shape[1]) + action_std = torch.zeros(batch["action"].shape[1]) + + for i in tqdm.tqdm(range(num_batch)): + image_mean += einops.reduce(batch["observation", "image"], 'b c h w -> c', reduction='mean') + state_mean += batch["observation", "state"].mean(dim=0) + action_mean += batch["action"].mean(dim=0) + batch = rb.sample() + + image_mean /= num_batch + state_mean /= num_batch + action_mean /= num_batch + + for i in tqdm.tqdm(range(num_batch)): + image_mean_batch = einops.reduce(batch["observation", "image"], 'b c h w -> c', reduction='mean') + image_std += (image_mean_batch - image_mean) ** 2 + state_std += (batch["observation", "state"].mean(dim=0) - state_mean) ** 2 + action_std += (batch["action"].mean(dim=0) - action_mean) ** 2 + if i < num_batch - 1: + batch = rb.sample() + + image_std = torch.sqrt(image_std / num_batch) + state_std = torch.sqrt(state_std / num_batch) + action_std = torch.sqrt(action_std / num_batch) + + mean_std = TensorDict( + { + ("observation", "image", "mean"): image_mean[None,:,None,None], + ("observation", "image", "std"): image_std[None,:,None,None], + ("observation", "state", "mean"): state_mean[None,:], + ("observation", "state", "std"): state_std[None,:], + ("action", "mean"): action_mean[None,:], + ("action", "std"): action_std[None,:], + }, + batch_size=[], + ) + return mean_std + + def _compute_or_load_mean_std(self, storage) -> TensorDict: + mean_std_path = self.root / self.dataset_id / "mean_std.pth" + if mean_std_path.exists(): + mean_std = torch.load(mean_std_path) + 
else: + logging.info(f"compute_mean_std and save to {mean_std_path}") + mean_std = self._compute_mean_std(storage) + torch.save(mean_std, mean_std_path) + return mean_std diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py index 0ad43a65..8572f6e9 100644 --- a/lerobot/common/datasets/utils.py +++ b/lerobot/common/datasets/utils.py @@ -3,9 +3,11 @@ import zipfile from pathlib import Path import requests +from tensordict import TensorDictBase import tqdm + def download_and_extract_zip(url: str, destination_folder: Path) -> bool: print(f"downloading from {url}") response = requests.get(url, stream=True) @@ -28,3 +30,4 @@ def download_and_extract_zip(url: str, destination_folder: Path) -> bool: return True else: return False + diff --git a/lerobot/common/envs/factory.py b/lerobot/common/envs/factory.py index b93f3541..2cd4f73b 100644 --- a/lerobot/common/envs/factory.py +++ b/lerobot/common/envs/factory.py @@ -3,7 +3,7 @@ from torchrl.envs.transforms import StepCounter, TransformedEnv from lerobot.common.envs.transforms import Prod -def make_env(cfg): +def make_env(cfg, transform=None): kwargs = { "frame_skip": cfg.env.action_repeat, "from_pixels": cfg.env.from_pixels, @@ -32,6 +32,10 @@ def make_env(cfg): # to ensure pusht is in [0,255] like simxarm env.append_transform(Prod(in_keys=[("observation", "image")], prod=255.0)) + if transform is not None: + # useful to add mean and std normalization + env.append_transform(transform) + return env diff --git a/lerobot/common/envs/pusht.py b/lerobot/common/envs/pusht.py index 45cbc705..39bf3bba 100644 --- a/lerobot/common/envs/pusht.py +++ b/lerobot/common/envs/pusht.py @@ -100,7 +100,8 @@ class PushtEnv(EnvBase): def _step(self, tensordict: TensorDict): td = tensordict - action = td["action"].numpy() + # remove batch dim + action = td["action"].squeeze(0).numpy() # step expects shape=(4,) so we pad if necessary # TODO(rcadene): add info["is_success"] and info["success"] ? 
sum_reward = 0 diff --git a/lerobot/common/envs/transforms.py b/lerobot/common/envs/transforms.py index 1a3c1ce1..2579dc18 100644 --- a/lerobot/common/envs/transforms.py +++ b/lerobot/common/envs/transforms.py @@ -1,7 +1,10 @@ from typing import Sequence +from tensordict import TensorDictBase from tensordict.utils import NestedKey from torchrl.envs.transforms import ObservationTransform +from torchrl.envs.transforms import Transform +from tensordict.nn import dispatch class Prod(ObservationTransform): @@ -19,3 +22,47 @@ class Prod(ObservationTransform): for key in self.in_keys: obs_spec[key].space.high *= self.prod return obs_spec + + +class NormalizeTransform(Transform): + invertible = True + + def __init__(self, + mean_std: TensorDictBase, + in_keys: Sequence[NestedKey] = None, + out_keys: Sequence[NestedKey] | None = None, + in_keys_inv: Sequence[NestedKey] | None = None, + out_keys_inv: Sequence[NestedKey] | None = None, + ): + if out_keys is None: + out_keys = in_keys + if in_keys_inv is None: + in_keys_inv = out_keys + if out_keys_inv is None: + out_keys_inv = in_keys + super().__init__(in_keys=in_keys, out_keys=out_keys, in_keys_inv=in_keys_inv, out_keys_inv=out_keys_inv) + self.mean_std = mean_std + + @dispatch(source="in_keys", dest="out_keys") + def forward(self, tensordict: TensorDictBase) -> TensorDictBase: + return self._call(tensordict) + + def _call(self, td: TensorDictBase) -> TensorDictBase: + for inkey, outkey in zip(self.in_keys, self.out_keys): + # TODO(rcadene): don't know how to do `inkey not in td` + if td.get(inkey, None) is None: + continue + mean = self.mean_std[inkey]["mean"] + std = self.mean_std[inkey]["std"] + td[outkey] = (td[inkey] - mean) / (std + 1e-8) + return td + + def _inv_call(self, td: TensorDictBase) -> TensorDictBase: + for inkey, outkey in zip(self.in_keys_inv, self.out_keys_inv): + # TODO(rcadene): don't know how to do `inkey not in td` + if td.get(inkey, None) is None: + continue + mean = self.mean_std[inkey]["mean"] + std = self.mean_std[inkey]["std"] + td[outkey] = td[inkey] * std + mean + return td diff --git a/lerobot/common/policies/tdmpc.py b/lerobot/common/policies/tdmpc.py index 19f00b5f..124c4438 100644 --- a/lerobot/common/policies/tdmpc.py +++ b/lerobot/common/policies/tdmpc.py @@ -5,6 +5,7 @@ from copy import deepcopy import einops import numpy as np +from tensordict import TensorDict import torch import torch.nn as nn @@ -126,19 +127,30 @@ class TDMPC(nn.Module): @torch.no_grad() def forward(self, observation, step_count): t0 = step_count.item() == 0 + + # TODO(rcadene): remove unsqueeze hack... + if observation["image"].ndim == 3: + observation["image"] = observation["image"].unsqueeze(0) + observation["state"] = observation["state"].unsqueeze(0) + obs = { - "rgb": observation["image"], - "state": observation["state"], + # TODO(rcadene): remove contiguous hack... + "rgb": observation["image"].contiguous(), + "state": observation["state"].contiguous(), } - return self.act(obs, t0=t0, step=self.step.item()) + action = self.act(obs, t0=t0, step=self.step.item()) + + # TODO(rcadene): hack to postprocess action (e.g. unnormalize) + # action = action * self.action_std + self.action_mean + return action @torch.no_grad() def act(self, obs, t0=False, step=None): """Take an action. 
Uses either MPC or the learned policy, depending on the self.cfg.mpc flag.""" if isinstance(obs, dict): - obs = {k: o.detach().unsqueeze(0) for k, o in obs.items()} + obs = {k: o.detach() for k, o in obs.items()} else: - obs = obs.detach().unsqueeze(0) + obs = obs.detach() z = self.model.encode(obs) if self.cfg.mpc: a = self.plan(z, t0=t0, step=step) @@ -315,26 +327,20 @@ class TDMPC(nn.Module): # trajectory t = 256, horizon h = 5 # (t h) ... -> h t ... batch = batch.reshape(num_slices, horizon).transpose(1, 0).contiguous() - batch = batch.to(self.device) obs = { - "rgb": batch["observation", "image"][FIRST_FRAME].float(), - "state": batch["observation", "state"][FIRST_FRAME], + "rgb": batch["observation", "image"][FIRST_FRAME].to(self.device, non_blocking=True), + "state": batch["observation", "state"][FIRST_FRAME].to(self.device, non_blocking=True), } - action = batch["action"] + action = batch["action"].to(self.device, non_blocking=True) next_obses = { - "rgb": batch["next", "observation", "image"].float(), - "state": batch["next", "observation", "state"], + "rgb": batch["next", "observation", "image"].to(self.device, non_blocking=True), + "state": batch["next", "observation", "state"].to(self.device, non_blocking=True), } - reward = batch["next", "reward"] + reward = batch["next", "reward"].to(self.device, non_blocking=True) - # TODO(rcadene): add non_blocking=True - # for key in obs: - # obs[key] = obs[key].to(self.device, non_blocking=True) - # next_obses[key] = next_obses[key].to(self.device, non_blocking=True) - - # action = action.to(self.device, non_blocking=True) - # reward = reward.to(self.device, non_blocking=True) + idxs = batch["index"][FIRST_FRAME].to(self.device, non_blocking=True) + weights = batch["_weight"][FIRST_FRAME, :, None].to(self.device, non_blocking=True) # TODO(rcadene): rearrange directly in offline dataset if reward.ndim == 2: @@ -347,9 +353,6 @@ class TDMPC(nn.Module): # Neither does `batch["next", "terminated"]` done = torch.zeros_like(reward, dtype=torch.bool, device=reward.device) mask = torch.ones_like(reward, dtype=torch.bool, device=reward.device) - - idxs = batch["index"][FIRST_FRAME] - weights = batch["_weight"][FIRST_FRAME, :, None] return obs, action, next_obses, reward, mask, done, idxs, weights batch = replay_buffer.sample(batch_size) if self.cfg.balanced_sampling else replay_buffer.sample() diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index 3c5b8cf1..587e3bb7 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -1,3 +1,4 @@ +import logging import threading import time from pathlib import Path @@ -10,6 +11,7 @@ import tqdm from tensordict.nn import TensorDictModule from termcolor import colored from torchrl.envs import EnvBase +from lerobot.common.datasets.factory import make_offline_buffer from lerobot.common.envs.factory import make_env from lerobot.common.policies.factory import make_policy @@ -112,7 +114,11 @@ def eval(cfg: dict, out_dir=None): set_seed(cfg.seed) print(colored("Log dir:", "yellow", attrs=["bold"]), out_dir) - env = make_env(cfg) + logging.info("make_offline_buffer") + offline_buffer = make_offline_buffer(cfg) + + logging.info("make_env") + env = make_env(cfg, transform=offline_buffer.transform) if cfg.policy.pretrained_model_path: policy = make_policy(cfg) diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 954c8f87..cd1fe15e 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -117,21 +117,10 @@ def train(cfg: dict, out_dir=None, job_name=None): assert 
torch.cuda.is_available() torch.backends.cudnn.benchmark = True + torch.backends.cuda.matmul.allow_tf32 = True set_seed(cfg.seed) logging.info(colored("Work dir:", "yellow", attrs=["bold"]) + f" {out_dir}") - logging.info("make_env") - env = make_env(cfg) - - logging.info("make_policy") - policy = make_policy(cfg) - - td_policy = TensorDictModule( - policy, - in_keys=["observation", "step_count"], - out_keys=["action"], - ) - logging.info("make_offline_buffer") offline_buffer = make_offline_buffer(cfg) @@ -151,8 +140,22 @@ def train(cfg: dict, out_dir=None, job_name=None): online_buffer = TensorDictReplayBuffer( storage=LazyMemmapStorage(100_000), sampler=online_sampler, + transform=offline_buffer._transform, ) + logging.info("make_env") + env = make_env(cfg, transform=offline_buffer._transform) + + logging.info("make_policy") + policy = make_policy(cfg, transform=offline_buffer._transform) + + td_policy = TensorDictModule( + policy, + in_keys=["observation", "step_count"], + out_keys=["action"], + ) + + # log metrics to terminal and wandb logger = Logger(out_dir, job_name, cfg) step = 0 # number of policy update
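
A minimal usage sketch of the NormalizeTransform introduced above, with made-up statistics: in the real code the mean_std TensorDict comes from PushtExperienceReplay._compute_or_load_mean_std, and the same transform instance is shared by the offline buffer, the online buffer, make_env(cfg, transform=...) and make_policy so every consumer sees identically normalized data. The key layout mirrors the one built in _compute_mean_std; the numbers and shapes below are placeholders.

import torch
from tensordict import TensorDict

from lerobot.common.envs.transforms import NormalizeTransform

# Per-key statistics stored as a nested TensorDict (batch_size=[]), same layout
# as the one produced by _compute_mean_std. The values are illustrative only.
mean_std = TensorDict(
    {
        ("observation", "state", "mean"): torch.tensor([[256.0, 256.0]]),
        ("observation", "state", "std"): torch.tensor([[100.0, 100.0]]),
        ("action", "mean"): torch.tensor([[256.0, 256.0]]),
        ("action", "std"): torch.tensor([[100.0, 100.0]]),
    },
    batch_size=[],
)

transform = NormalizeTransform(
    mean_std,
    in_keys=[("observation", "state"), "action"],
)

# A fake batch of 8 transitions with state/action in raw environment units.
td = TensorDict(
    {
        ("observation", "state"): torch.rand(8, 2) * 512,
        "action": torch.rand(8, 2) * 512,
    },
    batch_size=[8],
)

# Forward pass: x -> (x - mean) / (std + 1e-8); keys absent from td are skipped.
normalized = transform._call(td.clone())

# Inverse pass: x -> x * std + mean, e.g. to map a normalized action predicted
# by the policy back to the units the environment expects.
restored = transform._inv_call(normalized.clone())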
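
The tdmpc update also replaces the synchronous batch.to(self.device) with per-tensor .to(self.device, non_blocking=True) copies. A short sketch of the pattern this relies on, assuming the sampled host tensors are pinned (the replay buffers created in the dataset factory already take a pin_memory flag); shapes and values are illustrative only.

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# A stand-in for one sampled field of the batch (e.g. PushT image frames).
batch_cpu = torch.rand(256, 3, 96, 96)

if device.type == "cuda":
    # non_blocking copies only overlap with computation when the source host
    # tensor sits in pinned (page-locked) memory.
    batch_cpu = batch_cpu.pin_memory()

# On CUDA this enqueues an async host-to-device copy and returns immediately;
# on CPU it is a no-op and non_blocking is ignored.
batch_gpu = batch_cpu.to(device, non_blocking=True)

# Later kernels on the same stream are ordered after the copy, so the result
# can be used without an explicit synchronize.
loss = batch_gpu.mean()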