diff --git a/lerobot/common/datasets/pusht.py b/lerobot/common/datasets/pusht.py
index 744b4f07..b5d4fab3 100644
--- a/lerobot/common/datasets/pusht.py
+++ b/lerobot/common/datasets/pusht.py
@@ -137,13 +137,16 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
         mean_std = self._compute_or_load_mean_std(storage)
         mean_std["next", "observation", "image"] = mean_std["observation", "image"]
         mean_std["next", "observation", "state"] = mean_std["observation", "state"]
-        transform = NormalizeTransform(mean_std, in_keys=[
-            ("observation", "image"),
-            ("observation", "state"),
-            ("next", "observation", "image"),
-            ("next", "observation", "state"),
-            ("action"),
-        ])
+        transform = NormalizeTransform(
+            mean_std,
+            in_keys=[
+                ("observation", "image"),
+                ("observation", "state"),
+                ("next", "observation", "image"),
+                ("next", "observation", "state"),
+                ("action"),
+            ],
+        )
 
         if writer is None:
             writer = ImmutableDatasetWriter()
@@ -185,7 +188,9 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
             download_and_extract_zip(PUSHT_URL, raw_dir)
 
         # load
-        dataset_dict = DiffusionPolicyReplayBuffer.copy_from_path(zarr_path)  # , keys=['img', 'state', 'action'])
+        dataset_dict = DiffusionPolicyReplayBuffer.copy_from_path(
+            zarr_path
+        )  # , keys=['img', 'state', 'action'])
 
         episode_ids = torch.from_numpy(dataset_dict.get_episode_idxs())
         num_episodes = dataset_dict.meta["episode_ends"].shape[0]
@@ -291,8 +296,8 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
         action_mean = torch.zeros(batch["action"].shape[1])
         action_std = torch.zeros(batch["action"].shape[1])
 
-        for i in tqdm.tqdm(range(num_batch)):
-            image_mean += einops.reduce(batch["observation", "image"], 'b c h w -> c', reduction='mean')
+        for _ in tqdm.tqdm(range(num_batch)):
+            image_mean += einops.reduce(batch["observation", "image"], "b c h w -> c", reduction="mean")
             state_mean += batch["observation", "state"].mean(dim=0)
             action_mean += batch["action"].mean(dim=0)
             batch = rb.sample()
@@ -302,25 +307,25 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
         action_mean /= num_batch
 
         for i in tqdm.tqdm(range(num_batch)):
-            image_mean_batch = einops.reduce(batch["observation", "image"], 'b c h w -> c', reduction='mean')
+            image_mean_batch = einops.reduce(batch["observation", "image"], "b c h w -> c", reduction="mean")
             image_std += (image_mean_batch - image_mean) ** 2
             state_std += (batch["observation", "state"].mean(dim=0) - state_mean) ** 2
             action_std += (batch["action"].mean(dim=0) - action_mean) ** 2
             if i < num_batch - 1:
                 batch = rb.sample()
-
+
         image_std = torch.sqrt(image_std / num_batch)
         state_std = torch.sqrt(state_std / num_batch)
         action_std = torch.sqrt(action_std / num_batch)
 
         mean_std = TensorDict(
             {
-                ("observation", "image", "mean"): image_mean[None,:,None,None],
-                ("observation", "image", "std"): image_std[None,:,None,None],
-                ("observation", "state", "mean"): state_mean[None,:],
-                ("observation", "state", "std"): state_std[None,:],
-                ("action", "mean"): action_mean[None,:],
-                ("action", "std"): action_std[None,:],
+                ("observation", "image", "mean"): image_mean[None, :, None, None],
+                ("observation", "image", "std"): image_std[None, :, None, None],
+                ("observation", "state", "mean"): state_mean[None, :],
+                ("observation", "state", "std"): state_std[None, :],
+                ("action", "mean"): action_mean[None, :],
+                ("action", "std"): action_std[None, :],
             },
             batch_size=[],
         )
diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py
index 8572f6e9..0ad43a65 100644
--- a/lerobot/common/datasets/utils.py
+++ b/lerobot/common/datasets/utils.py
@@ -3,11 +3,9 @@ import zipfile
 from pathlib import Path
 
 import requests
-from tensordict import TensorDictBase
 import tqdm
 
 
-
 def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
     print(f"downloading from {url}")
     response = requests.get(url, stream=True)
@@ -30,4 +28,3 @@ def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
         return True
     else:
         return False
-
diff --git a/lerobot/common/envs/transforms.py b/lerobot/common/envs/transforms.py
index 2579dc18..67601eac 100644
--- a/lerobot/common/envs/transforms.py
+++ b/lerobot/common/envs/transforms.py
@@ -1,10 +1,9 @@
 from typing import Sequence
 
-from tensordict import TensorDictBase
-from tensordict.utils import NestedKey
-from torchrl.envs.transforms import ObservationTransform
-from torchrl.envs.transforms import Transform
+from tensordict import TensorDictBase
 from tensordict.nn import dispatch
+from tensordict.utils import NestedKey
+from torchrl.envs.transforms import ObservationTransform, Transform
 
 
 class Prod(ObservationTransform):
@@ -27,28 +26,31 @@ class Prod(ObservationTransform):
 class NormalizeTransform(Transform):
     invertible = True
 
-    def __init__(self,
-                 mean_std: TensorDictBase,
-                 in_keys: Sequence[NestedKey] = None,
-                 out_keys: Sequence[NestedKey] | None = None,
-                 in_keys_inv: Sequence[NestedKey] | None = None,
-                 out_keys_inv: Sequence[NestedKey] | None = None,
-                 ):
+    def __init__(
+        self,
+        mean_std: TensorDictBase,
+        in_keys: Sequence[NestedKey] = None,
+        out_keys: Sequence[NestedKey] | None = None,
+        in_keys_inv: Sequence[NestedKey] | None = None,
+        out_keys_inv: Sequence[NestedKey] | None = None,
+    ):
         if out_keys is None:
             out_keys = in_keys
         if in_keys_inv is None:
             in_keys_inv = out_keys
         if out_keys_inv is None:
             out_keys_inv = in_keys
-        super().__init__(in_keys=in_keys, out_keys=out_keys, in_keys_inv=in_keys_inv, out_keys_inv=out_keys_inv)
+        super().__init__(
+            in_keys=in_keys, out_keys=out_keys, in_keys_inv=in_keys_inv, out_keys_inv=out_keys_inv
+        )
         self.mean_std = mean_std
-
+
     @dispatch(source="in_keys", dest="out_keys")
     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
         return self._call(tensordict)
 
     def _call(self, td: TensorDictBase) -> TensorDictBase:
-        for inkey, outkey in zip(self.in_keys, self.out_keys):
+        for inkey, outkey in zip(self.in_keys, self.out_keys, strict=False):
             # TODO(rcadene): don't know how to do `inkey not in td`
             if td.get(inkey, None) is None:
                 continue
@@ -58,7 +60,7 @@ class NormalizeTransform(Transform):
         return td
 
     def _inv_call(self, td: TensorDictBase) -> TensorDictBase:
-        for inkey, outkey in zip(self.in_keys_inv, self.out_keys_inv):
+        for inkey, outkey in zip(self.in_keys_inv, self.out_keys_inv, strict=False):
             # TODO(rcadene): don't know how to do `inkey not in td`
             if td.get(inkey, None) is None:
                 continue
diff --git a/lerobot/common/policies/diffusion.py b/lerobot/common/policies/diffusion.py
index 6615fcc6..50e72b23 100644
--- a/lerobot/common/policies/diffusion.py
+++ b/lerobot/common/policies/diffusion.py
@@ -5,6 +5,7 @@ import hydra
 import torch
 import torch.nn as nn
 from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
+
 from diffusion_policy.model.common.lr_scheduler import get_scheduler
 from diffusion_policy.model.vision.model_getter import get_resnet
 from diffusion_policy.model.vision.multi_image_obs_encoder import MultiImageObsEncoder
diff --git a/lerobot/common/policies/tdmpc.py b/lerobot/common/policies/tdmpc.py
index 124c4438..b56e45df 100644
--- a/lerobot/common/policies/tdmpc.py
+++ b/lerobot/common/policies/tdmpc.py
@@ -5,7 +5,6 @@ from copy import deepcopy
 
 import einops
 import numpy as np
-from tensordict import TensorDict
 import torch
 import torch.nn as nn
 
@@ -127,7 +126,7 @@ class TDMPC(nn.Module):
     @torch.no_grad()
     def forward(self, observation, step_count):
         t0 = step_count.item() == 0
-
+
         # TODO(rcadene): remove unsqueeze hack...
         if observation["image"].ndim == 3:
             observation["image"] = observation["image"].unsqueeze(0)
@@ -147,10 +146,7 @@ class TDMPC(nn.Module):
     @torch.no_grad()
     def act(self, obs, t0=False, step=None):
         """Take an action. Uses either MPC or the learned policy, depending on the self.cfg.mpc flag."""
-        if isinstance(obs, dict):
-            obs = {k: o.detach() for k, o in obs.items()}
-        else:
-            obs = obs.detach()
+        obs = {k: o.detach() for k, o in obs.items()} if isinstance(obs, dict) else obs.detach()
         z = self.model.encode(obs)
         if self.cfg.mpc:
             a = self.plan(z, t0=t0, step=step)
diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py
index 587e3bb7..6391903e 100644
--- a/lerobot/scripts/eval.py
+++ b/lerobot/scripts/eval.py
@@ -11,8 +11,8 @@ import tqdm
 from tensordict.nn import TensorDictModule
 from termcolor import colored
 from torchrl.envs import EnvBase
-from lerobot.common.datasets.factory import make_offline_buffer
 
+from lerobot.common.datasets.factory import make_offline_buffer
 from lerobot.common.envs.factory import make_env
 from lerobot.common.policies.factory import make_policy
 from lerobot.common.utils import set_seed
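
Reviewer note: for context, a minimal usage sketch of the `NormalizeTransform` touched by this patch. This is illustrative only, not part of the diff; the actual normalization arithmetic lives in unchanged lines and is assumed here to be `(x - mean) / std`, and the `mean_std` layout follows the one built in `_compute_or_load_mean_std` above (a `("<key>", "mean")` / `("<key>", "std")` leaf per normalized key).

```python
import torch
from tensordict import TensorDict

from lerobot.common.envs.transforms import NormalizeTransform

# Stats laid out the way pusht.py builds them: one (key, "mean"/"std") pair
# per normalized entry, with a leading broadcast dimension.
mean_std = TensorDict(
    {
        ("action", "mean"): torch.zeros(1, 2),
        ("action", "std"): torch.ones(1, 2),
    },
    batch_size=[],
)

# out_keys / inverse keys default to in_keys inside __init__.
transform = NormalizeTransform(mean_std, in_keys=["action"])

td = TensorDict({"action": torch.randn(4, 2)}, batch_size=[4])
out = transform(td)  # forward() dispatches to _call(); keys absent from td are skipped
```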