From 7e024fdce646c1713202bd1ddc4042b58fb6e0c0 Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Thu, 29 Feb 2024 13:37:48 +0100 Subject: [PATCH] Ran pre-commit run --all-files --- lerobot/common/datasets/factory.py | 1 + lerobot/common/datasets/pusht.py | 36 ++++------- lerobot/common/datasets/simxarm.py | 28 ++------ lerobot/common/envs/pusht.py | 14 +--- lerobot/common/envs/simxarm.py | 18 ++---- lerobot/common/envs/transforms.py | 1 - lerobot/common/logger.py | 16 ++--- lerobot/common/policies/diffusion.py | 11 ++-- lerobot/common/policies/tdmpc.py | 85 ++++++++----------------- lerobot/common/policies/tdmpc_helper.py | 81 +++++++++-------------- lerobot/configs/env/pusht.yaml | 2 +- lerobot/configs/env/simxarm.yaml | 2 +- lerobot/scripts/eval.py | 3 +- lerobot/scripts/train.py | 42 +++++------- lerobot/scripts/visualize_dataset.py | 19 ++---- sbatch.sh | 2 +- 16 files changed, 124 insertions(+), 237 deletions(-) diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py index 3f436b74..9a129ba1 100644 --- a/lerobot/common/datasets/factory.py +++ b/lerobot/common/datasets/factory.py @@ -70,6 +70,7 @@ def make_offline_buffer(cfg, sampler=None): offline_buffer = PushtExperienceReplay( "pusht", # download="force", + # TODO(aliberts): automate download download=False, streaming=False, root="data", diff --git a/lerobot/common/datasets/pusht.py b/lerobot/common/datasets/pusht.py index 640afac3..3602856e 100644 --- a/lerobot/common/datasets/pusht.py +++ b/lerobot/common/datasets/pusht.py @@ -1,7 +1,6 @@ import os -import pickle from pathlib import Path -from typing import Any, Callable, Dict, Tuple +from typing import Callable import einops import numpy as np @@ -10,25 +9,25 @@ import pymunk import torch import torchrl import tqdm -from diffusion_policy.common.replay_buffer import ReplayBuffer -from diffusion_policy.env.pusht.pusht_env import pymunk_to_shapely from tensordict import TensorDict from torchrl.data.datasets.utils import _get_root_dir from torchrl.data.replay_buffers.replay_buffers import ( - TensorDictPrioritizedReplayBuffer, TensorDictReplayBuffer, ) from torchrl.data.replay_buffers.samplers import ( Sampler, - SliceSampler, - SliceSamplerWithoutReplacement, ) from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer +from diffusion_policy.common.replay_buffer import ReplayBuffer +from diffusion_policy.env.pusht.pusht_env import pymunk_to_shapely + # as define in env SUCCESS_THRESHOLD = 0.95 # 95% coverage, +DEFAULT_TEE_MASK = pymunk.ShapeFilter.ALL_MASKS() + def get_goal_pose_body(pose): mass = 1 @@ -53,7 +52,7 @@ def add_tee( angle, scale=30, color="LightSlateGray", - mask=pymunk.ShapeFilter.ALL_MASKS(), + mask=DEFAULT_TEE_MASK, ): mass = 1 length = 4 @@ -87,7 +86,6 @@ def add_tee( class PushtExperienceReplay(TensorDictReplayBuffer): - def __init__( self, dataset_id, @@ -127,7 +125,7 @@ class PushtExperienceReplay(TensorDictReplayBuffer): if split_trajs: raise NotImplementedError - if self.download == True: + if self.download: raise NotImplementedError() if root is None: @@ -193,18 +191,18 @@ class PushtExperienceReplay(TensorDictReplayBuffer): # TODO(rcadene) # load + # TODO(aliberts): Dynamic paths zarr_path = ( "/home/rcadene/code/diffusion_policy/data/pusht/pusht_cchi_v7_replay.zarr" + # "/home/simon/build/diffusion_policy/data/pusht/pusht_cchi_v7_replay.zarr" ) - dataset_dict = ReplayBuffer.copy_from_path( - zarr_path - ) # , keys=['img', 'state', 'action']) + dataset_dict = ReplayBuffer.copy_from_path(zarr_path) # , keys=['img', 'state', 'action']) episode_ids = dataset_dict.get_episode_idxs() num_episodes = dataset_dict.meta["episode_ends"].shape[0] total_frames = dataset_dict["action"].shape[0] assert len( - set([dataset_dict[key].shape[0] for key in dataset_dict.keys()]) + {dataset_dict[key].shape[0] for key in dataset_dict} ), "Some data type dont have the same number of total frames." # TODO: verify that goal pose is expected to be fixed @@ -245,9 +243,7 @@ class PushtExperienceReplay(TensorDictReplayBuffer): ] space.add(*walls) - block_body = add_tee( - space, block_pos[i].tolist(), block_angle[i].item() - ) + block_body = add_tee(space, block_pos[i].tolist(), block_angle[i].item()) goal_geom = pymunk_to_shapely(goal_body, block_body.shapes) block_geom = pymunk_to_shapely(block_body, block_body.shapes) intersection_area = goal_geom.intersection(block_geom).area @@ -278,11 +274,7 @@ class PushtExperienceReplay(TensorDictReplayBuffer): if episode_id == 0: # hack to initialize tensordict data structure to store episodes - td_data = ( - episode[0] - .expand(total_frames) - .memmap_like(self.root / self.dataset_id) - ) + td_data = episode[0].expand(total_frames).memmap_like(self.root / self.dataset_id) td_data[idxtd : idxtd + len(episode)] = episode diff --git a/lerobot/common/datasets/simxarm.py b/lerobot/common/datasets/simxarm.py index 84e6ca7c..b0e17d52 100644 --- a/lerobot/common/datasets/simxarm.py +++ b/lerobot/common/datasets/simxarm.py @@ -1,7 +1,7 @@ import os import pickle from pathlib import Path -from typing import Any, Callable, Dict, Tuple +from typing import Callable import torch import torchrl @@ -9,7 +9,6 @@ import tqdm from tensordict import TensorDict from torchrl.data.datasets.utils import _get_root_dir from torchrl.data.replay_buffers.replay_buffers import ( - TensorDictPrioritizedReplayBuffer, TensorDictReplayBuffer, ) from torchrl.data.replay_buffers.samplers import ( @@ -22,7 +21,6 @@ from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer class SimxarmExperienceReplay(TensorDictReplayBuffer): - available_datasets = [ "xarm_lift_medium", ] @@ -77,15 +75,11 @@ class SimxarmExperienceReplay(TensorDictReplayBuffer): if num_slices is not None or slice_len is not None: if sampler is not None: - raise ValueError( - "`num_slices` and `slice_len` are exclusive with the `sampler` argument." - ) + raise ValueError("`num_slices` and `slice_len` are exclusive with the `sampler` argument.") if replacement: if not self.shuffle: - raise RuntimeError( - "shuffle=False can only be used when replacement=False." - ) + raise RuntimeError("shuffle=False can only be used when replacement=False.") sampler = SliceSampler( num_slices=num_slices, slice_len=slice_len, @@ -130,7 +124,7 @@ class SimxarmExperienceReplay(TensorDictReplayBuffer): # load dataset_dir = Path("data") / self.dataset_id - dataset_path = dataset_dir / f"buffer.pkl" + dataset_path = dataset_dir / "buffer.pkl" print(f"Using offline dataset '{dataset_path}'") with open(dataset_path, "rb") as f: dataset_dict = pickle.load(f) @@ -150,12 +144,8 @@ class SimxarmExperienceReplay(TensorDictReplayBuffer): image = torch.tensor(dataset_dict["observations"]["rgb"][idx0:idx1]) state = torch.tensor(dataset_dict["observations"]["state"][idx0:idx1]) - next_image = torch.tensor( - dataset_dict["next_observations"]["rgb"][idx0:idx1] - ) - next_state = torch.tensor( - dataset_dict["next_observations"]["state"][idx0:idx1] - ) + next_image = torch.tensor(dataset_dict["next_observations"]["rgb"][idx0:idx1]) + next_state = torch.tensor(dataset_dict["next_observations"]["state"][idx0:idx1]) next_reward = torch.tensor(dataset_dict["rewards"][idx0:idx1]) next_done = torch.tensor(dataset_dict["dones"][idx0:idx1]) @@ -176,11 +166,7 @@ class SimxarmExperienceReplay(TensorDictReplayBuffer): if episode_id == 0: # hack to initialize tensordict data structure to store episodes - td_data = ( - episode[0] - .expand(total_frames) - .memmap_like(self.root / self.dataset_id) - ) + td_data = episode[0].expand(total_frames).memmap_like(self.root / self.dataset_id) td_data[idx0:idx1] = episode diff --git a/lerobot/common/envs/pusht.py b/lerobot/common/envs/pusht.py index adc8c015..cf890046 100644 --- a/lerobot/common/envs/pusht.py +++ b/lerobot/common/envs/pusht.py @@ -1,7 +1,6 @@ import importlib from typing import Optional -import numpy as np import torch from tensordict import TensorDict from torchrl.data.tensor_specs import ( @@ -20,7 +19,6 @@ _has_diffpolicy = importlib.util.find_spec("diffusion_policy") is not None and _ class PushtEnv(EnvBase): - def __init__( self, frame_skip: int = 1, @@ -46,8 +44,6 @@ class PushtEnv(EnvBase): if not _has_gym: raise ImportError("Cannot import gym.") - from diffusion_policy.env.pusht.pusht_env import PushTEnv - if not from_pixels: raise NotImplementedError("Use PushTEnv, instead of PushTImageEnv") from diffusion_policy.env.pusht.pusht_image_env import PushTImageEnv @@ -71,14 +67,10 @@ class PushtEnv(EnvBase): obs = {"image": torch.from_numpy(raw_obs["image"])} if not self.pixels_only: - obs["state"] = torch.from_numpy(raw_obs["agent_pos"]).type( - torch.float32 - ) + obs["state"] = torch.from_numpy(raw_obs["agent_pos"]).type(torch.float32) else: # TODO: - obs = { - "state": torch.from_numpy(raw_obs["observation"]).type(torch.float32) - } + obs = {"state": torch.from_numpy(raw_obs["observation"]).type(torch.float32)} obs = TensorDict(obs, batch_size=[]) return obs @@ -109,7 +101,7 @@ class PushtEnv(EnvBase): # step expects shape=(4,) so we pad if necessary # TODO(rcadene): add info["is_success"] and info["success"] ? sum_reward = 0 - for t in range(self.frame_skip): + for _ in range(self.frame_skip): raw_obs, reward, done, info = self._env.step(action) sum_reward += reward diff --git a/lerobot/common/envs/simxarm.py b/lerobot/common/envs/simxarm.py index e25841fe..24fd9ba4 100644 --- a/lerobot/common/envs/simxarm.py +++ b/lerobot/common/envs/simxarm.py @@ -15,12 +15,13 @@ from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform from lerobot.common.utils import set_seed +MAX_NUM_ACTIONS = 4 + _has_gym = importlib.util.find_spec("gym") is not None _has_simxarm = importlib.util.find_spec("simxarm") is not None and _has_gym class SimxarmEnv(EnvBase): - def __init__( self, task, @@ -52,18 +53,13 @@ class SimxarmEnv(EnvBase): from simxarm import TASKS if self.task not in TASKS: - raise ValueError( - f"Unknown task {self.task}. Must be one of {list(TASKS.keys())}" - ) + raise ValueError(f"Unknown task {self.task}. Must be one of {list(TASKS.keys())}") self._env = TASKS[self.task]["env"]() - MAX_NUM_ACTIONS = 4 num_actions = len(TASKS[self.task]["action_space"]) self._action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(num_actions,)) - self._action_padding = np.zeros( - (MAX_NUM_ACTIONS - num_actions), dtype=np.float32 - ) + self._action_padding = np.zeros((MAX_NUM_ACTIONS - num_actions), dtype=np.float32) if "w" not in TASKS[self.task]["action_space"]: self._action_padding[-1] = 1.0 @@ -75,9 +71,7 @@ class SimxarmEnv(EnvBase): def _format_raw_obs(self, raw_obs): if self.from_pixels: - image = self.render( - mode="rgb_array", width=self.image_size, height=self.image_size - ) + image = self.render(mode="rgb_array", width=self.image_size, height=self.image_size) image = image.transpose(2, 0, 1) # (H, W, C) -> (C, H, W) image = torch.tensor(image.copy(), dtype=torch.uint8) @@ -114,7 +108,7 @@ class SimxarmEnv(EnvBase): action = np.concatenate([action, self._action_padding]) # TODO(rcadene): add info["is_success"] and info["success"] ? sum_reward = 0 - for t in range(self.frame_skip): + for _ in range(self.frame_skip): raw_obs, reward, done, info = self._env.step(action) sum_reward += reward diff --git a/lerobot/common/envs/transforms.py b/lerobot/common/envs/transforms.py index f1e6657b..1a3c1ce1 100644 --- a/lerobot/common/envs/transforms.py +++ b/lerobot/common/envs/transforms.py @@ -5,7 +5,6 @@ from torchrl.envs.transforms import ObservationTransform class Prod(ObservationTransform): - def __init__(self, in_keys: Sequence[NestedKey], prod: float): super().__init__() self.in_keys = in_keys diff --git a/lerobot/common/logger.py b/lerobot/common/logger.py index a013c9ec..ddf2ef04 100644 --- a/lerobot/common/logger.py +++ b/lerobot/common/logger.py @@ -1,6 +1,6 @@ +import contextlib import datetime import os -import re from pathlib import Path import numpy as np @@ -29,10 +29,8 @@ AGENT_METRICS = [ def make_dir(dir_path): """Create directory if it does not already exist.""" - try: + with contextlib.suppress(OSError): dir_path.mkdir(parents=True, exist_ok=True) - except OSError: - pass return dir_path @@ -59,9 +57,7 @@ def print_run(cfg, reward=None): # ('experiment', cfg.exp_name), ] if reward is not None: - kvs.append( - ("episode reward", colored(str(int(reward)), "white", attrs=["bold"])) - ) + kvs.append(("episode reward", colored(str(int(reward)), "white", attrs=["bold"]))) w = np.max([len(limstr(str(kv[1]))) for kv in kvs]) + 21 div = "-" * w print(div) @@ -80,7 +76,7 @@ def cfg_to_group(cfg, return_list=False): return lst if return_list else "-".join(lst) -class Logger(object): +class Logger: """Primary logger object. Logs either locally or using wandb.""" def __init__(self, log_dir, job_name, cfg): @@ -183,7 +179,5 @@ class Logger(object): if category == "eval": keys = ["step", "avg_sum_reward", "avg_max_reward", "pc_success"] self._eval.append(np.array([d[key] for key in keys])) - pd.DataFrame(np.array(self._eval)).to_csv( - self._log_dir / "eval.log", header=keys, index=None - ) + pd.DataFrame(np.array(self._eval)).to_csv(self._log_dir / "eval.log", header=keys, index=None) self._print(d, category) diff --git a/lerobot/common/policies/diffusion.py b/lerobot/common/policies/diffusion.py index 3bd9f515..92187290 100644 --- a/lerobot/common/policies/diffusion.py +++ b/lerobot/common/policies/diffusion.py @@ -3,16 +3,17 @@ import copy import hydra import torch import torch.nn as nn -import torch.nn.functional as F from diffusers.schedulers.scheduling_ddpm import DDPMScheduler + from diffusion_policy.model.common.lr_scheduler import get_scheduler from diffusion_policy.model.vision.model_getter import get_resnet from diffusion_policy.model.vision.multi_image_obs_encoder import MultiImageObsEncoder from diffusion_policy.policy.diffusion_unet_image_policy import DiffusionUnetImagePolicy +FIRST_ACTION = 0 + class DiffusionPolicy(nn.Module): - def __init__( self, cfg, @@ -105,7 +106,6 @@ class DiffusionPolicy(nn.Module): out = self.diffusion.predict_action(obs_dict) # TODO(rcadene): add possibility to return >1 timestemps - FIRST_ACTION = 0 action = out["action"].squeeze(0)[FIRST_ACTION] return action @@ -132,10 +132,7 @@ class DiffusionPolicy(nn.Module): } return out - if self.cfg.balanced_sampling: - batch = replay_buffer.sample(batch_size) - else: - batch = replay_buffer.sample() + batch = replay_buffer.sample(batch_size) if self.cfg.balanced_sampling else replay_buffer.sample() batch = process_batch(batch, self.cfg.horizon, num_slices) loss = self.diffusion.compute_loss(batch) diff --git a/lerobot/common/policies/tdmpc.py b/lerobot/common/policies/tdmpc.py index 55c022df..76a7b9aa 100644 --- a/lerobot/common/policies/tdmpc.py +++ b/lerobot/common/policies/tdmpc.py @@ -7,6 +7,8 @@ import torch.nn as nn import lerobot.common.policies.tdmpc_helper as h +FIRST_FRAME = 0 + class TOLD(nn.Module): """Task-Oriented Latent Dynamics (TOLD) model used in TD-MPC.""" @@ -17,9 +19,7 @@ class TOLD(nn.Module): self.cfg = cfg self._encoder = h.enc(cfg) - self._dynamics = h.dynamics( - cfg.latent_dim + action_dim, cfg.mlp_dim, cfg.latent_dim - ) + self._dynamics = h.dynamics(cfg.latent_dim + action_dim, cfg.mlp_dim, cfg.latent_dim) self._reward = h.mlp(cfg.latent_dim + action_dim, cfg.mlp_dim, 1) self._pi = h.mlp(cfg.latent_dim, cfg.mlp_dim, action_dim) self._Qs = nn.ModuleList([h.q(cfg) for _ in range(cfg.num_q)]) @@ -65,20 +65,20 @@ class TOLD(nn.Module): return h.TruncatedNormal(mu, std).sample(clip=0.3) return mu - def V(self, z): + def V(self, z): # noqa: N802 """Predict state value (V).""" return self._V(z) - def Q(self, z, a, return_type): + def Q(self, z, a, return_type): # noqa: N802 """Predict state-action value (Q).""" assert return_type in {"min", "avg", "all"} x = torch.cat([z, a], dim=-1) if return_type == "all": - return torch.stack(list(q(x) for q in self._Qs), dim=0) + return torch.stack([q(x) for q in self._Qs], dim=0) idxs = np.random.choice(self.cfg.num_q, 2, replace=False) - Q1, Q2 = self._Qs[idxs[0]](x), self._Qs[idxs[1]](x) + Q1, Q2 = self._Qs[idxs[0]](x), self._Qs[idxs[1]](x) # noqa: N806 return torch.min(Q1, Q2) if return_type == "min" else (Q1 + Q2) / 2 @@ -146,25 +146,21 @@ class TDMPC(nn.Module): @torch.no_grad() def estimate_value(self, z, actions, horizon): """Estimate value of a trajectory starting at latent state z and executing given actions.""" - G, discount = 0, 1 + G, discount = 0, 1 # noqa: N806 for t in range(horizon): if self.cfg.uncertainty_cost > 0: - G -= ( + G -= ( # noqa: N806 discount * self.cfg.uncertainty_cost * self.model.Q(z, actions[t], return_type="all").std(dim=0) ) z, reward = self.model.next(z, actions[t]) - G += discount * reward + G += discount * reward # noqa: N806 discount *= self.cfg.discount pi = self.model.pi(z, self.cfg.min_std) - G += discount * self.model.Q(z, pi, return_type="min") + G += discount * self.model.Q(z, pi, return_type="min") # noqa: N806 if self.cfg.uncertainty_cost > 0: - G -= ( - discount - * self.cfg.uncertainty_cost - * self.model.Q(z, pi, return_type="all").std(dim=0) - ) + G -= discount * self.cfg.uncertainty_cost * self.model.Q(z, pi, return_type="all").std(dim=0) # noqa: N806 return G @torch.no_grad() @@ -180,19 +176,13 @@ class TDMPC(nn.Module): assert step is not None # Seed steps if step < self.cfg.seed_steps and self.model.training: - return torch.empty( - self.action_dim, dtype=torch.float32, device=self.device - ).uniform_(-1, 1) + return torch.empty(self.action_dim, dtype=torch.float32, device=self.device).uniform_(-1, 1) # Sample policy trajectories - horizon = int( - min(self.cfg.horizon, h.linear_schedule(self.cfg.horizon_schedule, step)) - ) + horizon = int(min(self.cfg.horizon, h.linear_schedule(self.cfg.horizon_schedule, step))) num_pi_trajs = int(self.cfg.mixture_coef * self.cfg.num_samples) if num_pi_trajs > 0: - pi_actions = torch.empty( - horizon, num_pi_trajs, self.action_dim, device=self.device - ) + pi_actions = torch.empty(horizon, num_pi_trajs, self.action_dim, device=self.device) _z = z.repeat(num_pi_trajs, 1) for t in range(horizon): pi_actions[t] = self.model.pi(_z, self.cfg.min_std) @@ -201,20 +191,16 @@ class TDMPC(nn.Module): # Initialize state and parameters z = z.repeat(self.cfg.num_samples + num_pi_trajs, 1) mean = torch.zeros(horizon, self.action_dim, device=self.device) - std = self.cfg.max_std * torch.ones( - horizon, self.action_dim, device=self.device - ) + std = self.cfg.max_std * torch.ones(horizon, self.action_dim, device=self.device) if not t0 and hasattr(self, "_prev_mean"): mean[:-1] = self._prev_mean[1:] # Iterate CEM - for i in range(self.cfg.iterations): + for _ in range(self.cfg.iterations): actions = torch.clamp( mean.unsqueeze(1) + std.unsqueeze(1) - * torch.randn( - horizon, self.cfg.num_samples, self.action_dim, device=std.device - ), + * torch.randn(horizon, self.cfg.num_samples, self.action_dim, device=std.device), -1, 1, ) @@ -223,18 +209,14 @@ class TDMPC(nn.Module): # Compute elite actions value = self.estimate_value(z, actions, horizon).nan_to_num_(0) - elite_idxs = torch.topk( - value.squeeze(1), self.cfg.num_elites, dim=0 - ).indices + elite_idxs = torch.topk(value.squeeze(1), self.cfg.num_elites, dim=0).indices elite_value, elite_actions = value[elite_idxs], actions[:, elite_idxs] # Update parameters max_value = elite_value.max(0)[0] score = torch.exp(self.cfg.temperature * (elite_value - max_value)) score /= score.sum(0) - _mean = torch.sum(score.unsqueeze(0) * elite_actions, dim=1) / ( - score.sum(0) + 1e-9 - ) + _mean = torch.sum(score.unsqueeze(0) * elite_actions, dim=1) / (score.sum(0) + 1e-9) _std = torch.sqrt( torch.sum( score.unsqueeze(0) * (elite_actions - _mean.unsqueeze(1)) ** 2, @@ -331,7 +313,6 @@ class TDMPC(nn.Module): batch = batch.reshape(num_slices, horizon).transpose(1, 0).contiguous() batch = batch.to(self.device) - FIRST_FRAME = 0 obs = { "rgb": batch["observation", "image"][FIRST_FRAME].float(), "state": batch["observation", "state"][FIRST_FRAME], @@ -359,10 +340,7 @@ class TDMPC(nn.Module): weights = batch["_weight"][FIRST_FRAME, :, None] return obs, action, next_obses, reward, mask, done, idxs, weights - if self.cfg.balanced_sampling: - batch = replay_buffer.sample(batch_size) - else: - batch = replay_buffer.sample() + batch = replay_buffer.sample(batch_size) if self.cfg.balanced_sampling else replay_buffer.sample() obs, action, next_obses, reward, mask, done, idxs, weights = process_batch( batch, self.cfg.horizon, num_slices @@ -384,10 +362,7 @@ class TDMPC(nn.Module): if isinstance(obs, dict): obs = {k: torch.cat([obs[k], demo_obs[k]]) for k in obs} - next_obses = { - k: torch.cat([next_obses[k], demo_next_obses[k]], dim=1) - for k in next_obses - } + next_obses = {k: torch.cat([next_obses[k], demo_next_obses[k]], dim=1) for k in next_obses} else: obs = torch.cat([obs, demo_obs]) next_obses = torch.cat([next_obses, demo_next_obses], dim=1) @@ -429,9 +404,7 @@ class TDMPC(nn.Module): td_targets = self._td_target(next_z, reward, mask) # Latent rollout - zs = torch.empty( - horizon + 1, self.batch_size, self.cfg.latent_dim, device=self.device - ) + zs = torch.empty(horizon + 1, self.batch_size, self.cfg.latent_dim, device=self.device) reward_preds = torch.empty_like(reward, device=self.device) assert reward.shape[0] == horizon z = self.model.encode(obs) @@ -452,12 +425,10 @@ class TDMPC(nn.Module): value_info["V"] = v.mean().item() # Losses - rho = torch.pow(self.cfg.rho, torch.arange(horizon, device=self.device)).view( - -1, 1, 1 + rho = torch.pow(self.cfg.rho, torch.arange(horizon, device=self.device)).view(-1, 1, 1) + consistency_loss = (rho * torch.mean(h.mse(zs[1:], z_targets), dim=2, keepdim=True) * loss_mask).sum( + dim=0 ) - consistency_loss = ( - rho * torch.mean(h.mse(zs[1:], z_targets), dim=2, keepdim=True) * loss_mask - ).sum(dim=0) reward_loss = (rho * h.mse(reward_preds, reward) * loss_mask).sum(dim=0) q_value_loss, priority_loss = 0, 0 for q in range(self.cfg.num_q): @@ -465,9 +436,7 @@ class TDMPC(nn.Module): priority_loss += (rho * h.l1(qs[q], td_targets) * loss_mask).sum(dim=0) expectile = h.linear_schedule(self.cfg.expectile, step) - v_value_loss = ( - rho * h.l2_expectile(v_target - v, expectile=expectile) * loss_mask - ).sum(dim=0) + v_value_loss = (rho * h.l2_expectile(v_target - v, expectile=expectile) * loss_mask).sum(dim=0) total_loss = ( self.cfg.consistency_coef * consistency_loss diff --git a/lerobot/common/policies/tdmpc_helper.py b/lerobot/common/policies/tdmpc_helper.py index dd7abbec..264cd829 100644 --- a/lerobot/common/policies/tdmpc_helper.py +++ b/lerobot/common/policies/tdmpc_helper.py @@ -5,11 +5,15 @@ import re import numpy as np import torch import torch.nn as nn -import torch.nn.functional as F +import torch.nn.functional as F # noqa: N812 from torch import distributions as pyd from torch.distributions.utils import _standard_normal -__REDUCE__ = lambda b: "mean" if b else "none" +DEFAULT_ACT_FN = nn.Mish() + + +def __REDUCE__(b): # noqa: N802, N807 + return "mean" if b else "none" def l1(pred, target, reduce=False): @@ -36,11 +40,7 @@ def l2_expectile(diff, expectile=0.7, reduce=False): def _get_out_shape(in_shape, layers): """Utility function. Returns the output shape of a network for a given input shape.""" x = torch.randn(*in_shape).unsqueeze(0) - return ( - (nn.Sequential(*layers) if isinstance(layers, list) else layers)(x) - .squeeze(0) - .shape - ) + return (nn.Sequential(*layers) if isinstance(layers, list) else layers)(x).squeeze(0).shape def gaussian_logprob(eps, log_std): @@ -73,7 +73,7 @@ def orthogonal_init(m): def ema(m, m_target, tau): """Update slow-moving average of online network (target network) at rate tau.""" with torch.no_grad(): - for p, p_target in zip(m.parameters(), m_target.parameters()): + for p, p_target in zip(m.parameters(), m_target.parameters(), strict=False): p_target.data.lerp_(p.data, tau) @@ -86,6 +86,8 @@ def set_requires_grad(net, value): class TruncatedNormal(pyd.Normal): """Utility class implementing the truncated normal distribution.""" + default_sample_shape = torch.Size() + def __init__(self, loc, scale, low=-1.0, high=1.0, eps=1e-6): super().__init__(loc, scale, validate_args=False) self.low = low @@ -97,7 +99,7 @@ class TruncatedNormal(pyd.Normal): x = x - x.detach() + clamped_x.detach() return x - def sample(self, clip=None, sample_shape=torch.Size()): + def sample(self, clip=None, sample_shape=default_sample_shape): shape = self._extended_shape(sample_shape) eps = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device) eps *= self.scale @@ -136,7 +138,7 @@ def enc(cfg): """Returns a TOLD encoder.""" pixels_enc_layers, state_enc_layers = None, None if cfg.modality in {"pixels", "all"}: - C = int(3 * cfg.frame_stack) + C = int(3 * cfg.frame_stack) # noqa: N806 pixels_enc_layers = [ NormalizeImg(), nn.Conv2d(C, cfg.num_channels, 7, stride=2), @@ -184,7 +186,7 @@ def enc(cfg): return Multiplexer(nn.ModuleDict(encoders)) -def mlp(in_dim, mlp_dim, out_dim, act_fn=nn.Mish()): +def mlp(in_dim, mlp_dim, out_dim, act_fn=DEFAULT_ACT_FN): """Returns an MLP.""" if isinstance(mlp_dim, int): mlp_dim = [mlp_dim, mlp_dim] @@ -199,7 +201,7 @@ def mlp(in_dim, mlp_dim, out_dim, act_fn=nn.Mish()): ) -def dynamics(in_dim, mlp_dim, out_dim, act_fn=nn.Mish()): +def dynamics(in_dim, mlp_dim, out_dim, act_fn=DEFAULT_ACT_FN): """Returns a dynamics network.""" return nn.Sequential( mlp(in_dim, mlp_dim, out_dim, act_fn), @@ -327,7 +329,7 @@ class RandomShiftsAug(nn.Module): return F.grid_sample(x, grid, padding_mode="zeros", align_corners=False) -class Episode(object): +class Episode: """Storage object for a single episode.""" def __init__(self, cfg, init_obs): @@ -354,18 +356,10 @@ class Episode(object): self.obses[k][0] = torch.tensor(v, dtype=dtype, device=self.device) else: raise ValueError - self.actions = torch.empty( - (cfg.episode_length, action_dim), dtype=torch.float32, device=self.device - ) - self.rewards = torch.empty( - (cfg.episode_length,), dtype=torch.float32, device=self.device - ) - self.dones = torch.empty( - (cfg.episode_length,), dtype=torch.bool, device=self.device - ) - self.masks = torch.empty( - (cfg.episode_length,), dtype=torch.float32, device=self.device - ) + self.actions = torch.empty((cfg.episode_length, action_dim), dtype=torch.float32, device=self.device) + self.rewards = torch.empty((cfg.episode_length,), dtype=torch.float32, device=self.device) + self.dones = torch.empty((cfg.episode_length,), dtype=torch.bool, device=self.device) + self.masks = torch.empty((cfg.episode_length,), dtype=torch.float32, device=self.device) self.cumulative_reward = 0 self.done = False self.success = False @@ -380,23 +374,17 @@ class Episode(object): if cfg.modality in {"pixels", "state"}: episode = cls(cfg, obses[0]) - episode.obses[1:] = torch.tensor( - obses[1:], dtype=episode.obses.dtype, device=episode.device - ) + episode.obses[1:] = torch.tensor(obses[1:], dtype=episode.obses.dtype, device=episode.device) elif cfg.modality == "all": episode = cls(cfg, {k: v[0] for k, v in obses.items()}) - for k, v in obses.items(): + for k in obses: episode.obses[k][1:] = torch.tensor( obses[k][1:], dtype=episode.obses[k].dtype, device=episode.device ) else: raise NotImplementedError - episode.actions = torch.tensor( - actions, dtype=episode.actions.dtype, device=episode.device - ) - episode.rewards = torch.tensor( - rewards, dtype=episode.rewards.dtype, device=episode.device - ) + episode.actions = torch.tensor(actions, dtype=episode.actions.dtype, device=episode.device) + episode.rewards = torch.tensor(rewards, dtype=episode.rewards.dtype, device=episode.device) episode.dones = ( torch.tensor(dones, dtype=episode.dones.dtype, device=episode.device) if dones is not None @@ -428,9 +416,7 @@ class Episode(object): v, dtype=self.obses[k].dtype, device=self.obses[k].device ) else: - self.obses[self._idx + 1] = torch.tensor( - obs, dtype=self.obses.dtype, device=self.obses.device - ) + self.obses[self._idx + 1] = torch.tensor(obs, dtype=self.obses.dtype, device=self.obses.device) self.actions[self._idx] = action self.rewards[self._idx] = reward self.dones[self._idx] = done @@ -453,7 +439,7 @@ def get_dataset_dict(cfg, env, return_reward_normalizer=False): ] if cfg.task.startswith("xarm"): - dataset_path = os.path.join(cfg.dataset_dir, f"buffer.pkl") + dataset_path = os.path.join(cfg.dataset_dir, "buffer.pkl") print(f"Using offline dataset '{dataset_path}'") with open(dataset_path, "rb") as f: dataset_dict = pickle.load(f) @@ -461,7 +447,7 @@ def get_dataset_dict(cfg, env, return_reward_normalizer=False): if k not in dataset_dict and k[:-1] in dataset_dict: dataset_dict[k] = dataset_dict.pop(k[:-1]) elif cfg.task.startswith("legged"): - dataset_path = os.path.join(cfg.dataset_dir, f"buffer.pkl") + dataset_path = os.path.join(cfg.dataset_dir, "buffer.pkl") print(f"Using offline dataset '{dataset_path}'") with open(dataset_path, "rb") as f: dataset_dict = pickle.load(f) @@ -475,10 +461,7 @@ def get_dataset_dict(cfg, env, return_reward_normalizer=False): for i in range(len(dones) - 1): if ( - np.linalg.norm( - dataset_dict["observations"][i + 1] - - dataset_dict["next_observations"][i] - ) + np.linalg.norm(dataset_dict["observations"][i + 1] - dataset_dict["next_observations"][i]) > 1e-6 or dataset_dict["terminals"][i] == 1.0 ): @@ -501,7 +484,7 @@ def get_dataset_dict(cfg, env, return_reward_normalizer=False): dataset_dict["rewards"] = reward_normalizer(dataset_dict["rewards"]) for key in required_keys: - assert key in dataset_dict.keys(), f"Missing `{key}` in dataset." + assert key in dataset_dict, f"Missing `{key}` in dataset." if return_reward_normalizer: return dataset_dict, reward_normalizer @@ -553,9 +536,7 @@ def get_reward_normalizer(cfg, dataset): return lambda x: x - 1.0 elif cfg.task.split("-")[0] in ["hopper", "halfcheetah", "walker2d"]: (_, _, episode_returns) = get_trajectory_boundaries_and_returns(dataset) - return ( - lambda x: x / (np.max(episode_returns) - np.min(episode_returns)) * 1000.0 - ) + return lambda x: x / (np.max(episode_returns) - np.min(episode_returns)) * 1000.0 elif hasattr(cfg, "reward_scale"): return lambda x: x * cfg.reward_scale return lambda x: x @@ -571,12 +552,12 @@ def linear_schedule(schdl, step): except ValueError: match = re.match(r"linear\((.+),(.+),(.+),(.+)\)", schdl) if match: - init, final, start, end = [float(g) for g in match.groups()] + init, final, start, end = (float(g) for g in match.groups()) mix = np.clip((step - start) / (end - start), 0.0, 1.0) return (1.0 - mix) * init + mix * final match = re.match(r"linear\((.+),(.+),(.+)\)", schdl) if match: - init, final, duration = [float(g) for g in match.groups()] + init, final, duration = (float(g) for g in match.groups()) mix = np.clip(step / duration, 0.0, 1.0) return (1.0 - mix) * init + mix * final raise NotImplementedError(schdl) diff --git a/lerobot/configs/env/pusht.yaml b/lerobot/configs/env/pusht.yaml index 1719321b..6866f053 100644 --- a/lerobot/configs/env/pusht.yaml +++ b/lerobot/configs/env/pusht.yaml @@ -22,4 +22,4 @@ env: policy: state_dim: 2 - action_dim: 2 \ No newline at end of file + action_dim: 2 diff --git a/lerobot/configs/env/simxarm.yaml b/lerobot/configs/env/simxarm.yaml index 0658636a..b235dff7 100644 --- a/lerobot/configs/env/simxarm.yaml +++ b/lerobot/configs/env/simxarm.yaml @@ -21,4 +21,4 @@ env: policy: state_dim: 4 - action_dim: 4 \ No newline at end of file + action_dim: 4 diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index 2e8f6965..8240f654 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -37,10 +37,11 @@ def eval_policy( tensordict = env.reset() ep_frames = [] + if save_video or (return_first_video and i == 0): def rendering_callback(env, td=None): - ep_frames.append(env.render()) + ep_frames.append(env.render()) # noqa: B023 # render first frame before rollout rendering_callback(env) diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 26f13f37..c5268aa8 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -6,8 +6,6 @@ import torch from tensordict.nn import TensorDictModule from termcolor import colored from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer -from torchrl.data.datasets.d4rl import D4RLExperienceReplay -from torchrl.data.datasets.openx import OpenXExperienceReplay from torchrl.data.replay_buffers import PrioritizedSliceSampler from lerobot.common.datasets.factory import make_offline_buffer @@ -27,9 +25,7 @@ def train_cli(cfg: dict): ) -def train_notebook( - out_dir=None, job_name=None, config_name="default", config_path="../configs" -): +def train_notebook(out_dir=None, job_name=None, config_name="default", config_path="../configs"): from hydra import compose, initialize hydra.core.global_hydra.GlobalHydra.instance().clear() @@ -38,7 +34,7 @@ def train_notebook( train(cfg, out_dir=out_dir, job_name=job_name) -def log_training_metrics(L, metrics, step, online_episode_idx, start_time, is_offline): +def log_training_metrics(logger, metrics, step, online_episode_idx, start_time, is_offline): common_metrics = { "episode": online_episode_idx, "step": step, @@ -46,12 +42,10 @@ def log_training_metrics(L, metrics, step, online_episode_idx, start_time, is_of "is_offline": float(is_offline), } metrics.update(common_metrics) - L.log(metrics, category="train") + logger.log(metrics, category="train") -def eval_policy_and_log( - env, td_policy, step, online_episode_idx, start_time, cfg, L, is_offline -): +def eval_policy_and_log(env, td_policy, step, online_episode_idx, start_time, cfg, logger, is_offline): common_metrics = { "episode": online_episode_idx, "step": step, @@ -65,11 +59,11 @@ def eval_policy_and_log( return_first_video=True, ) metrics.update(common_metrics) - L.log(metrics, category="eval") + logger.log(metrics, category="eval") if cfg.wandb.enable: - eval_video = L._wandb.Video(first_video, fps=cfg.fps, format="mp4") - L._wandb.log({"eval_video": eval_video}, step=step) + eval_video = logger._wandb.Video(first_video, fps=cfg.fps, format="mp4") + logger._wandb.log({"eval_video": eval_video}, step=step) def train(cfg: dict, out_dir=None, job_name=None): @@ -116,7 +110,7 @@ def train(cfg: dict, out_dir=None, job_name=None): sampler=online_sampler, ) - L = Logger(out_dir, job_name, cfg) + logger = Logger(out_dir, job_name, cfg) online_episode_idx = 0 start_time = time.time() @@ -129,9 +123,7 @@ def train(cfg: dict, out_dir=None, job_name=None): metrics = policy.update(offline_buffer, step) if step % cfg.log_freq == 0: - log_training_metrics( - L, metrics, step, online_episode_idx, start_time, is_offline=False - ) + log_training_metrics(logger, metrics, step, online_episode_idx, start_time, is_offline=False) if step > 0 and step % cfg.eval_freq == 0: eval_policy_and_log( @@ -141,13 +133,13 @@ def train(cfg: dict, out_dir=None, job_name=None): online_episode_idx, start_time, cfg, - L, + logger, is_offline=True, ) if step > 0 and cfg.save_model and step % cfg.save_freq == 0: print(f"Checkpoint model at step {step}") - L.save_model(policy, identifier=step) + logger.save_model(policy, identifier=step) step += 1 @@ -164,9 +156,7 @@ def train(cfg: dict, out_dir=None, job_name=None): auto_cast_to_device=True, ) assert len(rollout) <= cfg.env.episode_length - rollout["episode"] = torch.tensor( - [online_episode_idx] * len(rollout), dtype=torch.int - ) + rollout["episode"] = torch.tensor([online_episode_idx] * len(rollout), dtype=torch.int) online_buffer.extend(rollout) ep_sum_reward = rollout["next", "reward"].sum() @@ -188,9 +178,7 @@ def train(cfg: dict, out_dir=None, job_name=None): ) metrics.update(train_metrics) if step % cfg.log_freq == 0: - log_training_metrics( - L, metrics, step, online_episode_idx, start_time, is_offline=False - ) + log_training_metrics(logger, metrics, step, online_episode_idx, start_time, is_offline=False) if step > 0 and step % cfg.eval_freq == 0: eval_policy_and_log( @@ -200,13 +188,13 @@ def train(cfg: dict, out_dir=None, job_name=None): online_episode_idx, start_time, cfg, - L, + logger, is_offline=False, ) if step > 0 and cfg.save_model and step % cfg.save_freq == 0: print(f"Checkpoint model at step {step}") - L.save_model(policy, identifier=step) + logger.save_model(policy, identifier=step) step += 1 diff --git a/lerobot/scripts/visualize_dataset.py b/lerobot/scripts/visualize_dataset.py index 107b6a71..fcb4c20e 100644 --- a/lerobot/scripts/visualize_dataset.py +++ b/lerobot/scripts/visualize_dataset.py @@ -1,24 +1,22 @@ -import pickle from pathlib import Path import hydra import imageio -import simxarm import torch from torchrl.data.replay_buffers import ( - SamplerWithoutReplacement, - SliceSampler, SliceSamplerWithoutReplacement, ) from lerobot.common.datasets.factory import make_offline_buffer +NUM_EPISODES_TO_RENDER = 10 +MAX_NUM_STEPS = 1000 +FIRST_FRAME = 0 + @hydra.main(version_base=None, config_name="default", config_path="../configs") def visualize_dataset_cli(cfg: dict): - visualize_dataset( - cfg, out_dir=hydra.core.hydra_config.HydraConfig.get().runtime.output_dir - ) + visualize_dataset(cfg, out_dir=hydra.core.hydra_config.HydraConfig.get().runtime.output_dir) def visualize_dataset(cfg: dict, out_dir=None): @@ -33,9 +31,6 @@ def visualize_dataset(cfg: dict, out_dir=None): offline_buffer = make_offline_buffer(cfg, sampler) - NUM_EPISODES_TO_RENDER = 10 - MAX_NUM_STEPS = 1000 - FIRST_FRAME = 0 for _ in range(NUM_EPISODES_TO_RENDER): episode = offline_buffer.sample(MAX_NUM_STEPS) @@ -57,9 +52,7 @@ def visualize_dataset(cfg: dict, out_dir=None): assert ep_frames.max().item() > 1, "Not mendatory, but sanity check" assert ep_frames.max().item() <= 255 ep_frames = ep_frames.type(torch.uint8) - imageio.mimsave( - video_path, ep_frames.numpy().transpose(0, 2, 3, 1), fps=cfg.fps - ) + imageio.mimsave(video_path, ep_frames.numpy().transpose(0, 2, 3, 1), fps=cfg.fps) # ran out of episodes if offline_buffer._sampler._sample_list.numel() == 0: diff --git a/sbatch.sh b/sbatch.sh index 52a4df4b..da52c472 100644 --- a/sbatch.sh +++ b/sbatch.sh @@ -18,5 +18,5 @@ apptainer exec --nv \ source ~/.bashrc conda activate fowm - + srun $CMD