diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py index 942a36dd..b6b93d89 100644 --- a/lerobot/common/datasets/factory.py +++ b/lerobot/common/datasets/factory.py @@ -69,7 +69,7 @@ def make_offline_buffer(cfg, sampler=None): sampler=sampler, batch_size=batch_size, pin_memory=pin_memory, - prefetch=prefetch, + prefetch=prefetch if isinstance(prefetch, int) else None, ) elif cfg.env.name == "pusht": offline_buffer = PushtExperienceReplay( @@ -79,7 +79,7 @@ def make_offline_buffer(cfg, sampler=None): sampler=sampler, batch_size=batch_size, pin_memory=pin_memory, - prefetch=prefetch, + prefetch=prefetch if isinstance(prefetch, int) else None, ) else: raise ValueError(cfg.env.name) diff --git a/lerobot/common/datasets/pusht.py b/lerobot/common/datasets/pusht.py index 1c1b658e..8ea64f86 100644 --- a/lerobot/common/datasets/pusht.py +++ b/lerobot/common/datasets/pusht.py @@ -143,13 +143,24 @@ class PushtExperienceReplay(TensorDictReplayBuffer): in_keys=[ # ("observation", "image"), ("observation", "state"), + # TODO(rcadene): for tdmpc, we might want image and state # ("next", "observation", "image"), - ("next", "observation", "state"), + # ("next", "observation", "state"), ("action"), ], mode="min_max", ) + # TODO(rcadene): make normalization strategy configurable between mean_std, min_max, min_max_spec + transform.stats["observation", "state", "min"] = torch.tensor( + [13.456424, 32.938293], dtype=torch.float32 + ) + transform.stats["observation", "state", "max"] = torch.tensor( + [496.14618, 510.9579], dtype=torch.float32 + ) + transform.stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32) + transform.stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32) + if writer is None: writer = ImmutableDatasetWriter() if collate_fn is None: diff --git a/lerobot/common/envs/factory.py b/lerobot/common/envs/factory.py index d7dc8aae..acf089f6 100644 --- a/lerobot/common/envs/factory.py +++ b/lerobot/common/envs/factory.py @@ -7,6 +7,8 @@ def make_env(cfg, transform=None): "from_pixels": cfg.env.from_pixels, "pixels_only": cfg.env.pixels_only, "image_size": cfg.env.image_size, + # TODO(rcadene): do we want a specific eval_env_seed? + "seed": cfg.seed, } if cfg.env.name == "simxarm": @@ -17,6 +19,8 @@ def make_env(cfg, transform=None): elif cfg.env.name == "pusht": from lerobot.common.envs.pusht import PushtEnv + # assert kwargs["seed"] > 200, "Seed 0-200 are used for the demonstration dataset, so we don't want to seed the eval env with this range." + clsfunc = PushtEnv else: raise ValueError(cfg.env.name) diff --git a/lerobot/common/envs/pusht.py b/lerobot/common/envs/pusht.py index 927a1ba7..62aa8d1b 100644 --- a/lerobot/common/envs/pusht.py +++ b/lerobot/common/envs/pusht.py @@ -101,14 +101,18 @@ class PushtEnv(EnvBase): obs = self._format_raw_obs(raw_obs) if self.num_prev_obs > 0: - # remove all previous observations + stacked_obs = {} if "image" in obs: - self._prev_obs_image_queue.clear() + self._prev_obs_image_queue = deque( + [obs["image"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1) + ) + stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue)) if "state" in obs: - self._prev_obs_state_queue.clear() - - # copy the current observation n times - obs = self._stack_prev_obs(obs) + self._prev_obs_state_queue = deque( + [obs["state"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1) + ) + stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue)) + obs = stacked_obs td = TensorDict( { @@ -121,40 +125,6 @@ class PushtEnv(EnvBase): raise NotImplementedError() return td - def _stack_prev_obs(self, obs): - """When the queue is empty, copy the current observation n times.""" - assert self.num_prev_obs > 0 - - def stack_update_queue(prev_obs_queue, obs, num_prev_obs): - # get n most recent observations - prev_obs = list(prev_obs_queue)[-num_prev_obs:] - - # if not enough observations, copy the oldest observation until we obtain n observations - if len(prev_obs) == 0: - prev_obs = [obs] * num_prev_obs # queue is empty when env reset - elif len(prev_obs) < num_prev_obs: - prev_obs = [prev_obs[0] for _ in range(num_prev_obs - len(prev_obs))] + prev_obs - - # stack n most recent observations with the current observation - stacked_obs = torch.stack(prev_obs + [obs], dim=0) - - # add current observation to the queue - # automatically remove oldest observation when queue is full - prev_obs_queue.appendleft(obs) - - return stacked_obs - - stacked_obs = {} - if "image" in obs: - stacked_obs["image"] = stack_update_queue( - self._prev_obs_image_queue, obs["image"], self.num_prev_obs - ) - if "state" in obs: - stacked_obs["state"] = stack_update_queue( - self._prev_obs_state_queue, obs["state"], self.num_prev_obs - ) - return stacked_obs - def _step(self, tensordict: TensorDict): td = tensordict action = td["action"].numpy() @@ -176,7 +146,14 @@ class PushtEnv(EnvBase): obs = self._format_raw_obs(raw_obs) if self.num_prev_obs > 0: - obs = self._stack_prev_obs(obs) + stacked_obs = {} + if "image" in obs: + self._prev_obs_image_queue.append(obs["image"]) + stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue)) + if "state" in obs: + self._prev_obs_state_queue.append(obs["state"]) + stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue)) + obs = stacked_obs td = TensorDict( { diff --git a/lerobot/common/logger.py b/lerobot/common/logger.py index a8cb6a66..54325bd4 100644 --- a/lerobot/common/logger.py +++ b/lerobot/common/logger.py @@ -1,51 +1,11 @@ -import contextlib +import logging import os from pathlib import Path -import numpy as np from omegaconf import OmegaConf from termcolor import colored -def make_dir(dir_path): - """Create directory if it does not already exist.""" - with contextlib.suppress(OSError): - dir_path.mkdir(parents=True, exist_ok=True) - return dir_path - - -def print_run(cfg, reward=None): - """Pretty-printing of run information. Call at start of training.""" - prefix, color, attrs = " ", "green", ["bold"] - - def limstr(s, maxlen=32): - return str(s[:maxlen]) + "..." if len(str(s)) > maxlen else s - - def pprint(k, v): - print( - prefix + colored(f'{k.capitalize() + ":":<16}', color, attrs=attrs), - limstr(v), - ) - - kvs = [ - ("task", cfg.env.task), - ("offline_steps", f"{cfg.offline_steps}"), - ("online_steps", f"{cfg.online_steps}"), - ("action_repeat", f"{cfg.env.action_repeat}"), - # ('observations', 'x'.join([str(s) for s in cfg.obs_shape])), - # ('actions', cfg.action_dim), - # ('experiment', cfg.exp_name), - ] - if reward is not None: - kvs.append(("episode reward", colored(str(int(reward)), "white", attrs=["bold"]))) - w = np.max([len(limstr(str(kv[1]))) for kv in kvs]) + 21 - div = "-" * w - print(div) - for k, v in kvs: - pprint(k, v) - print(div) - - def cfg_to_group(cfg, return_list=False): """Return a wandb-safe group name for logging. Optionally returns group name as list.""" # lst = [cfg.task, cfg.modality, re.sub("[^0-9a-zA-Z]+", "-", cfg.exp_name)] @@ -71,13 +31,12 @@ class Logger: self._seed = cfg.seed self._cfg = cfg self._eval = [] - print_run(cfg) project = cfg.get("wandb", {}).get("project") entity = cfg.get("wandb", {}).get("entity") enable_wandb = cfg.get("wandb", {}).get("enable", False) run_offline = not enable_wandb or not project or not entity if run_offline: - print(colored("Logs will be saved locally.", "yellow", attrs=["bold"])) + logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"])) self._wandb = None else: os.environ["WANDB_SILENT"] = "true" @@ -134,7 +93,6 @@ class Logger: self.save_buffer(buffer, identifier="buffer") if self._wandb: self._wandb.finish() - print_run(self._cfg, self._eval[-1][-1]) def log_dict(self, d, step, mode="train"): assert mode in {"train", "eval"} diff --git a/lerobot/common/policies/diffusion/policy.py b/lerobot/common/policies/diffusion/policy.py index df05bfd8..7ae0a529 100644 --- a/lerobot/common/policies/diffusion/policy.py +++ b/lerobot/common/policies/diffusion/policy.py @@ -4,10 +4,8 @@ import time import hydra import torch import torch.nn as nn -from diffusers.schedulers.scheduling_ddpm import DDPMScheduler from diffusion_policy.model.common.lr_scheduler import get_scheduler -from diffusion_policy.model.vision.model_getter import get_resnet from .diffusion_unet_image_policy import DiffusionUnetImagePolicy from .multi_image_obs_encoder import MultiImageObsEncoder @@ -39,8 +37,8 @@ class DiffusionPolicy(nn.Module): super().__init__() self.cfg = cfg - noise_scheduler = DDPMScheduler(**cfg_noise_scheduler) - rgb_model = get_resnet(**cfg_rgb_model) + noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler) + rgb_model = hydra.utils.instantiate(cfg_rgb_model) obs_encoder = MultiImageObsEncoder( rgb_model=rgb_model, **cfg_obs_encoder, @@ -127,16 +125,36 @@ class DiffusionPolicy(nn.Module): # (t h) ... -> t h ... batch = batch.reshape(num_slices, horizon) # .transpose(1, 0).contiguous() + # |-1|0|1|2|3|4|5|6|7|8|9|10|11|12|13|14| timestamps: 16 + # |o|o| observations: 2 + # | |a|a|a|a|a|a|a|a| actions executed: 8 + # |p|p|p|p|p|p|p|p|p|p|p| p| p| p| p| p| actions predicted: 16 + # note: we predict the action needed to go from t=-1 to t=0 similarly to an inverse kinematic model + + image = batch["observation", "image"] + state = batch["observation", "state"] + action = batch["action"] + assert image.shape[1] == horizon + assert state.shape[1] == horizon + assert action.shape[1] == horizon + + if not (horizon == 16 and self.cfg.n_obs_steps == 2): + raise NotImplementedError() + + # keep first 2 observations of the slice corresponding to t=[-1,0] + image = image[:, : self.cfg.n_obs_steps] + state = state[:, : self.cfg.n_obs_steps] + out = { "obs": { - "image": batch["observation", "image"].to(self.device, non_blocking=True), - "agent_pos": batch["observation", "state"].to(self.device, non_blocking=True), + "image": image.to(self.device, non_blocking=True), + "agent_pos": state.to(self.device, non_blocking=True), }, - "action": batch["action"].to(self.device, non_blocking=True), + "action": action.to(self.device, non_blocking=True), } return out - batch = replay_buffer.sample(batch_size) if self.cfg.balanced_sampling else replay_buffer.sample() + batch = replay_buffer.sample(batch_size) batch = process_batch(batch, self.cfg.horizon, num_slices) data_s = time.time() - start_time diff --git a/lerobot/configs/default.yaml b/lerobot/configs/default.yaml index 97f560f5..66e3dff1 100644 --- a/lerobot/configs/default.yaml +++ b/lerobot/configs/default.yaml @@ -1,7 +1,7 @@ defaults: - _self_ - - env: simxarm - - policy: tdmpc + - env: pusht + - policy: diffusion hydra: run: @@ -22,6 +22,7 @@ save_buffer: false train_steps: ??? fps: ??? +n_action_steps: ??? env: ??? policy: ??? diff --git a/lerobot/configs/policy/diffusion.yaml b/lerobot/configs/policy/diffusion.yaml index f136fa55..da1b6545 100644 --- a/lerobot/configs/policy/diffusion.yaml +++ b/lerobot/configs/policy/diffusion.yaml @@ -21,7 +21,7 @@ past_action_visible: False keypoint_visible_rate: 1.0 obs_as_global_cond: True -eval_episodes: 50 +eval_episodes: 1 eval_freq: 10000 save_freq: 100000 log_freq: 250 @@ -40,8 +40,8 @@ policy: num_inference_steps: 100 obs_as_global_cond: ${obs_as_global_cond} # crop_shape: null - diffusion_step_embed_dim: 128 - down_dims: [512, 1024, 2048] + diffusion_step_embed_dim: 256 # before 128 + down_dims: [256, 512, 1024] # before [512, 1024, 2048] kernel_size: 5 n_groups: 8 cond_predict_scale: True @@ -62,7 +62,7 @@ policy: grad_clip_norm: 0 noise_scheduler: - # _target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler + _target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler num_train_timesteps: 100 beta_start: 0.0001 beta_end: 0.02 @@ -74,16 +74,16 @@ noise_scheduler: obs_encoder: # _target_: diffusion_policy.model.vision.multi_image_obs_encoder.MultiImageObsEncoder shape_meta: ${shape_meta} - resize_shape: null - crop_shape: [76, 76] + # resize_shape: null + # crop_shape: [76, 76] # constant center crop - random_crop: True + # random_crop: True use_group_norm: True share_rgb_model: False - imagenet_norm: False # TODO(rcadene): was set to True + imagenet_norm: True rgb_model: - #_target_: diffusion_policy.model.vision.model_getter.get_resnet + _target_: diffusion_policy.model.vision.model_getter.get_resnet name: resnet18 weights: null diff --git a/lerobot/configs/policy/tdmpc.yaml b/lerobot/configs/policy/tdmpc.yaml index 26dc4e51..16b7018e 100644 --- a/lerobot/configs/policy/tdmpc.yaml +++ b/lerobot/configs/policy/tdmpc.yaml @@ -1,5 +1,7 @@ # @package _global_ +n_action_steps: 1 + policy: name: tdmpc diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index 214f5dba..abe4645a 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -137,7 +137,7 @@ def eval(cfg: dict, out_dir=None): save_video=True, video_dir=Path(out_dir) / "eval", fps=cfg.env.fps, - max_steps=cfg.env.episode_length, + max_steps=cfg.env.episode_length // cfg.n_action_steps, num_episodes=cfg.eval_episodes, ) print(metrics) diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 1c63fc97..1557495b 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -119,7 +119,6 @@ def train(cfg: dict, out_dir=None, job_name=None): torch.backends.cudnn.benchmark = True torch.backends.cuda.matmul.allow_tf32 = True set_seed(cfg.seed) - logging.info(colored("Work dir:", "yellow", attrs=["bold"]) + f" {out_dir}") logging.info("make_offline_buffer") offline_buffer = make_offline_buffer(cfg) @@ -149,6 +148,9 @@ def train(cfg: dict, out_dir=None, job_name=None): logging.info("make_policy") policy = make_policy(cfg) + num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad) + num_total_params = sum(p.numel() for p in policy.parameters()) + td_policy = TensorDictModule( policy, in_keys=["observation", "step_count"], @@ -158,6 +160,16 @@ def train(cfg: dict, out_dir=None, job_name=None): # log metrics to terminal and wandb logger = Logger(out_dir, job_name, cfg) + logging.info(colored("Work dir:", "yellow", attrs=["bold"]) + f" {out_dir}") + logging.info(f"{cfg.env.task=}") + logging.info(f"{cfg.offline_steps=} ({format_big_number(cfg.offline_steps)})") + logging.info(f"{cfg.online_steps=}") + logging.info(f"{cfg.env.action_repeat=}") + logging.info(f"{offline_buffer.num_samples=} ({format_big_number(offline_buffer.num_samples)})") + logging.info(f"{offline_buffer.num_episodes=}") + logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})") + logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})") + step = 0 # number of policy update is_offline = True @@ -175,6 +187,7 @@ def train(cfg: dict, out_dir=None, job_name=None): env, td_policy, num_episodes=cfg.eval_episodes, + max_steps=cfg.env.episode_length // cfg.n_action_steps, return_first_video=True, ) log_eval_info(logger, eval_info, step, cfg, offline_buffer, is_offline) @@ -199,11 +212,11 @@ def train(cfg: dict, out_dir=None, job_name=None): # TODO: add configurable number of rollout? (default=1) with torch.no_grad(): rollout = env.rollout( - max_steps=cfg.env.episode_length, + max_steps=cfg.env.episode_length // cfg.n_action_steps, policy=td_policy, auto_cast_to_device=True, ) - assert len(rollout) <= cfg.env.episode_length + assert len(rollout) <= cfg.env.episode_length // cfg.n_action_steps # set same episode index for all time steps contained in this rollout rollout["episode"] = torch.tensor([env_step] * len(rollout), dtype=torch.int) online_buffer.extend(rollout) @@ -235,6 +248,7 @@ def train(cfg: dict, out_dir=None, job_name=None): env, td_policy, num_episodes=cfg.eval_episodes, + max_steps=cfg.env.episode_length // cfg.n_action_steps, return_first_video=True, ) log_eval_info(logger, eval_info, step, cfg, offline_buffer, is_offline)