From 21670dce90451532e31655bbccfa68fd906d7e69 Mon Sep 17 00:00:00 2001
From: Cadene
Date: Mon, 26 Feb 2024 01:10:09 +0000
Subject: [PATCH] Refactor train, eval_policy, logger, Add diffusion.yaml (WIP)

---
 lerobot/common/datasets/factory.py      |  20 ++
 lerobot/common/logger.py                |  76 +------
 lerobot/common/policies/diffusion.py    |   2 +
 lerobot/common/policies/factory.py      |  22 +-
 lerobot/common/policies/tdmpc_helper.py | 255 ------------------------
 lerobot/configs/env/pusht.yaml          |   4 +-
 lerobot/configs/env/simxarm.yaml        |   4 +-
 lerobot/configs/policy/diffusion.yaml   | 117 +++++++++++
 lerobot/configs/policy/tdmpc.yaml       |   2 -
 lerobot/scripts/eval.py                 |  34 ++--
 lerobot/scripts/train.py                | 196 +++++++++---------
 tests/test_policies.py                  |  17 +-
 12 files changed, 306 insertions(+), 443 deletions(-)
 create mode 100644 lerobot/configs/policy/diffusion.yaml

diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py
index e8b61135..54155478 100644
--- a/lerobot/common/datasets/factory.py
+++ b/lerobot/common/datasets/factory.py
@@ -4,6 +4,26 @@
 from lerobot.common.datasets.pusht import PushtExperienceReplay
 from lerobot.common.datasets.simxarm import SimxarmExperienceReplay
 from rl.torchrl.data.replay_buffers.samplers import PrioritizedSliceSampler
 
+# TODO(rcadene): implement
+
+# dataset_d4rl = D4RLExperienceReplay(
+#     dataset_id="maze2d-umaze-v1",
+#     split_trajs=False,
+#     batch_size=1,
+#     sampler=SamplerWithoutReplacement(drop_last=False),
+#     prefetch=4,
+#     direct_download=True,
+# )
+
+# dataset_openx = OpenXExperienceReplay(
+#     "cmu_stretch",
+#     batch_size=1,
+#     num_slices=1,
+#     #download="force",
+#     streaming=False,
+#     root="data",
+# )
+
 
 def make_offline_buffer(cfg, sampler=None):
diff --git a/lerobot/common/logger.py b/lerobot/common/logger.py
index cb1bf0eb..a013c9ec 100644
--- a/lerobot/common/logger.py
+++ b/lerobot/common/logger.py
@@ -10,10 +10,10 @@ from termcolor import colored
 
 CONSOLE_FORMAT = [
     ("episode", "E", "int"),
-    ("env_step", "S", "int"),
+    ("step", "S", "int"),
     ("avg_sum_reward", "RS", "float"),
     ("avg_max_reward", "RM", "float"),
-    ("pc_success", "S", "float"),
+    ("pc_success", "SR", "float"),
     ("total_time", "T", "time"),
 ]
 AGENT_METRICS = [
@@ -51,7 +51,9 @@
 
     kvs = [
         ("task", cfg.env.task),
-        ("train steps", f"{int(cfg.train_steps * cfg.env.action_repeat):,}"),
+        ("offline_steps", f"{cfg.offline_steps}"),
+        ("online_steps", f"{cfg.online_steps}"),
+        ("action_repeat", f"{cfg.env.action_repeat}"),
         # ('observations', 'x'.join([str(s) for s in cfg.obs_shape])),
         # ('actions', cfg.action_dim),
         # ('experiment', cfg.exp_name),
@@ -78,54 +80,6 @@ def cfg_to_group(cfg, return_list=False):
     return lst if return_list else "-".join(lst)
 
 
-class VideoRecorder:
-    """Utility class for logging evaluation videos."""
-
-    def __init__(self, root_dir, wandb, render_size=384, fps=15):
-        self.save_dir = (root_dir / "eval_video") if root_dir else None
-        self._wandb = wandb
-        self.render_size = render_size
-        self.fps = fps
-        self.frames = []
-        self.enabled = False
-        self.camera_id = 0
-
-    def init(self, env, enabled=True):
-        self.frames = []
-        self.enabled = self.save_dir and self._wandb and enabled
-        try:
-            env_name = env.unwrapped.spec.id
-        except:
-            env_name = ""
-        if "maze2d" in env_name:
-            self.camera_id = -1
-        elif "quadruped" in env_name:
-            self.camera_id = 2
-        self.record(env)
-
-    def record(self, env):
-        if self.enabled:
-            frame = env.render(
-                mode="rgb_array",
-                height=self.render_size,
-                width=self.render_size,
-                camera_id=self.camera_id,
-            )
-            self.frames.append(frame)
-
-    def save(self, step):
-        if self.enabled:
-            frames = np.stack(self.frames).transpose(0, 3, 1, 2)
-            self._wandb.log(
-                {
-                    "eval_video": self._wandb.Video(
-                        frames, fps=self.env.fps, format="mp4"
-                    )
-                },
-                step=step,
-            )
-
-
 class Logger(object):
     """Primary logger object. Logs either locally or using wandb."""
 
@@ -170,15 +124,6 @@
             )
             print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
             self._wandb = wandb
-        self._video = (
-            VideoRecorder(self._log_dir, self._wandb)
-            if self._wandb and cfg.save_video
-            else None
-        )
-
-    @property
-    def video(self):
-        return self._video
 
     def save_model(self, agent, identifier):
         if self._save_model:
@@ -214,12 +159,12 @@
 
     def _format(self, key, value, ty):
         if ty == "int":
-            return f'{colored(key + ":", "grey")} {int(value):,}'
+            return f'{colored(key + ":", "yellow")} {int(value):,}'
         elif ty == "float":
-            return f'{colored(key + ":", "grey")} {value:.01f}'
+            return f'{colored(key + ":", "yellow")} {value:.01f}'
         elif ty == "time":
             value = str(datetime.timedelta(seconds=int(value)))
-            return f'{colored(key + ":", "grey")} {value}'
+            return f'{colored(key + ":", "yellow")} {value}'
         else:
             raise f"invalid log format type: {ty}"
 
@@ -234,10 +179,9 @@
         assert category in {"train", "eval"}
         if self._wandb is not None:
             for k, v in d.items():
-                self._wandb.log({category + "/" + k: v}, step=d["env_step"])
+                self._wandb.log({category + "/" + k: v}, step=d["step"])
         if category == "eval":
-            # keys = ['env_step', 'avg_reward']
-            keys = ["env_step", "avg_sum_reward", "avg_max_reward", "pc_success"]
+            keys = ["step", "avg_sum_reward", "avg_max_reward", "pc_success"]
             self._eval.append(np.array([d[key] for key in keys]))
             pd.DataFrame(np.array(self._eval)).to_csv(
                 self._log_dir / "eval.log", header=keys, index=None
diff --git a/lerobot/common/policies/diffusion.py b/lerobot/common/policies/diffusion.py
index b8272453..d05e7ac2 100644
--- a/lerobot/common/policies/diffusion.py
+++ b/lerobot/common/policies/diffusion.py
@@ -1,6 +1,8 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
+from diffusion_policy.model.vision.multi_image_obs_encoder import MultiImageObsEncoder
 from diffusion_policy.policy.diffusion_unet_image_policy import DiffusionUnetImagePolicy
 
diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py
index d2407e1f..7a87b5af 100644
--- a/lerobot/common/policies/factory.py
+++ b/lerobot/common/policies/factory.py
@@ -4,9 +4,29 @@ def make_policy(cfg):
         policy = TDMPC(cfg.policy)
     elif cfg.policy.name == "diffusion":
+        from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
+        from diffusion_policy.model.vision.model_getter import get_resnet
+        from diffusion_policy.model.vision.multi_image_obs_encoder import (
+            MultiImageObsEncoder,
+        )
+
         from lerobot.common.policies.diffusion import DiffusionPolicy
 
-        policy = DiffusionPolicy(cfg.policy)
+        noise_scheduler = DDPMScheduler(**cfg.noise_scheduler)
+
+        rgb_model = get_resnet(**cfg.rgb_model)
+
+        obs_encoder = MultiImageObsEncoder(
+            rgb_model=rgb_model,
+            **cfg.obs_encoder,
+        )
+
+        policy = DiffusionPolicy(
+            noise_scheduler=noise_scheduler,
+            obs_encoder=obs_encoder,
+            n_action_steps=cfg.n_action_steps + cfg.n_latency_steps,
+            **cfg.policy,
+        )
     else:
         raise ValueError(cfg.policy.name)
diff --git a/lerobot/common/policies/tdmpc_helper.py b/lerobot/common/policies/tdmpc_helper.py
index d143da3d..dd7abbec 100644
--- a/lerobot/common/policies/tdmpc_helper.py
+++ b/lerobot/common/policies/tdmpc_helper.py
@@ -441,261 +441,6 @@ class Episode(object):
         self._idx += 1
 
 
-class ReplayBuffer:
-    """
-    Storage and sampling functionality.
-    """
-
-    def __init__(self, cfg, dataset=None):
-        action_dim = cfg.action_dim
-        obs_shape = {"rgb": (3, cfg.img_size, cfg.img_size), "state": (cfg.state_dim,)}
-
-        self.cfg = cfg
-        self.device = torch.device(cfg.buffer_device)
-        print("Replay buffer device: ", self.device)
-
-        if dataset is not None:
-            self.capacity = max(dataset["rewards"].shape[0], cfg.max_buffer_size)
-        else:
-            self.capacity = min(cfg.train_steps, cfg.max_buffer_size)
-
-        if cfg.modality in {"pixels", "state"}:
-            dtype = torch.float32 if cfg.modality == "state" else torch.uint8
-            # Note self.obs_shape always has single frame, which is different from cfg.obs_shape
-            self.obs_shape = (
-                obs_shape if cfg.modality == "state" else (3, *obs_shape[-2:])
-            )
-            self._obs = torch.zeros(
-                (self.capacity + cfg.horizon - 1, *self.obs_shape),
-                dtype=dtype,
-                device=self.device,
-            )
-            self._next_obs = torch.zeros(
-                (self.capacity + cfg.horizon - 1, *self.obs_shape),
-                dtype=dtype,
-                device=self.device,
-            )
-        elif cfg.modality == "all":
-            self.obs_shape = {}
-            self._obs, self._next_obs = {}, {}
-            for k, v in obs_shape.items():
-                assert k in {"rgb", "state"}
-                dtype = torch.float32 if k == "state" else torch.uint8
-                self.obs_shape[k] = v if k == "state" else (3, *v[-2:])
-                self._obs[k] = torch.zeros(
-                    (self.capacity + cfg.horizon - 1, *self.obs_shape[k]),
-                    dtype=dtype,
-                    device=self.device,
-                )
-                self._next_obs[k] = self._obs[k].clone()
-        else:
-            raise ValueError
-
-        self._action = torch.zeros(
-            (self.capacity + cfg.horizon - 1, action_dim),
-            dtype=torch.float32,
-            device=self.device,
-        )
-        self._reward = torch.zeros(
-            (self.capacity + cfg.horizon - 1,), dtype=torch.float32, device=self.device
-        )
-        self._mask = torch.zeros(
-            (self.capacity + cfg.horizon - 1,), dtype=torch.float32, device=self.device
-        )
-        self._done = torch.zeros(
-            (self.capacity + cfg.horizon - 1,), dtype=torch.bool, device=self.device
-        )
-        self._priorities = torch.ones(
-            (self.capacity + cfg.horizon - 1,), dtype=torch.float32, device=self.device
-        )
-        self._eps = 1e-6
-        self._full = False
-        self.idx = 0
-        if dataset is not None:
-            self.init_from_offline_dataset(dataset)
-
-        self._aug = aug(cfg)
-
-    def init_from_offline_dataset(self, dataset):
-        """Initialize the replay buffer from an offline dataset."""
-        assert self.idx == 0 and not self._full
-        n_transitions = int(len(dataset["rewards"]) * self.cfg.data_first_percent)
-
-        def copy_data(dst, src, n):
-            assert isinstance(dst, dict) == isinstance(src, dict)
-            if isinstance(dst, dict):
-                for k in dst:
-                    copy_data(dst[k], src[k], n)
-            else:
-                dst[:n] = torch.from_numpy(src[:n])
-
-        copy_data(self._obs, dataset["observations"], n_transitions)
-        copy_data(self._next_obs, dataset["next_observations"], n_transitions)
-        copy_data(self._action, dataset["actions"], n_transitions)
-        copy_data(self._reward, dataset["rewards"], n_transitions)
-        copy_data(self._mask, dataset["masks"], n_transitions)
-        copy_data(self._done, dataset["dones"], n_transitions)
-        self.idx = (self.idx + n_transitions) % self.capacity
-        self._full = n_transitions >= self.capacity
-
-    def __add__(self, episode: Episode):
-        self.add(episode)
-        return self
-
-    def add(self, episode: Episode):
-        """Add an episode to the replay buffer."""
-        if self.idx + len(episode) > self.capacity:
-            print("Warning: episode got truncated")
-        ep_len = min(len(episode), self.capacity - self.idx)
-        idxs = slice(self.idx, self.idx + ep_len)
-        assert self.idx + ep_len <= self.capacity
-        if self.cfg.modality in {"pixels", "state"}:
-            self._obs[idxs] = (
-                episode.obses[:ep_len]
-                if self.cfg.modality == "state"
-                else episode.obses[:ep_len, -3:]
-            )
-            self._next_obs[idxs] = (
-                episode.obses[1 : ep_len + 1]
-                if self.cfg.modality == "state"
-                else episode.obses[1 : ep_len + 1, -3:]
-            )
-        elif self.cfg.modality == "all":
-            for k, v in episode.obses.items():
-                assert k in {"rgb", "state"}
-                assert k in self._obs
-                assert k in self._next_obs
-                if k == "rgb":
-                    self._obs[k][idxs] = episode.obses[k][:ep_len, -3:]
-                    self._next_obs[k][idxs] = episode.obses[k][1 : ep_len + 1, -3:]
-                else:
-                    self._obs[k][idxs] = episode.obses[k][:ep_len]
-                    self._next_obs[k][idxs] = episode.obses[k][1 : ep_len + 1]
-        self._action[idxs] = episode.actions[:ep_len]
-        self._reward[idxs] = episode.rewards[:ep_len]
-        self._mask[idxs] = episode.masks[:ep_len]
-        self._done[idxs] = episode.dones[:ep_len]
-        self._done[self.idx + ep_len - 1] = True  # in case truncated
-        if self._full:
-            max_priority = (
-                self._priorities[: self.capacity].max().to(self.device).item()
-            )
-        else:
-            max_priority = (
-                1.0
-                if self.idx == 0
-                else self._priorities[: self.idx].max().to(self.device).item()
-            )
-        new_priorities = torch.full((ep_len,), max_priority, device=self.device)
-        self._priorities[idxs] = new_priorities
-        self.idx = (self.idx + ep_len) % self.capacity
-        self._full = self._full or self.idx == 0
-
-    def update_priorities(self, idxs, priorities):
-        """Update priorities for Prioritized Experience Replay (PER)"""
-        self._priorities[idxs] = priorities.squeeze(1).to(self.device) + self._eps
-
-    def _get_obs(self, arr, idxs):
-        """Retrieve observations by indices"""
-        if isinstance(arr, dict):
-            return {k: self._get_obs(v, idxs) for k, v in arr.items()}
-        if arr.ndim <= 2:  # if self.cfg.modality == 'state':
-            return arr[idxs].cuda()
-        obs = torch.empty(
-            (self.cfg.batch_size, 3 * self.cfg.frame_stack, *arr.shape[-2:]),
-            dtype=arr.dtype,
-            device=torch.device("cuda"),
-        )
-        obs[:, -3:] = arr[idxs].cuda()
-        _idxs = idxs.clone()
-        mask = torch.ones_like(_idxs, dtype=torch.bool)
-        for i in range(1, self.cfg.frame_stack):
-            mask[_idxs % self.cfg.episode_length == 0] = False
-            _idxs[mask] -= 1
-            obs[:, -(i + 1) * 3 : -i * 3] = arr[_idxs].cuda()
-        return obs.float()
-
-    def sample(self):
-        """Sample transitions from the replay buffer."""
-        probs = (
-            self._priorities[: self.capacity]
-            if self._full
-            else self._priorities[: self.idx]
-        ) ** self.cfg.per_alpha
-        probs /= probs.sum()
-        total = len(probs)
-        idxs = torch.from_numpy(
-            np.random.choice(
-                total,
-                self.cfg.batch_size,
-                p=probs.cpu().numpy(),
-                replace=not self._full,
-            )
-        ).to(self.device)
-        weights = (total * probs[idxs]) ** (-self.cfg.per_beta)
-        weights /= weights.max()
-
-        idxs_in_horizon = torch.stack([idxs + t for t in range(self.cfg.horizon)])
-
-        obs = self._aug(self._get_obs(self._obs, idxs))
-        next_obs = [
-            self._aug(self._get_obs(self._next_obs, _idxs)) for _idxs in idxs_in_horizon
-        ]
-        if isinstance(next_obs[0], dict):
-            next_obs = {k: torch.stack([o[k] for o in next_obs]) for k in next_obs[0]}
-        else:
-            next_obs = torch.stack(next_obs)
-        action = self._action[idxs_in_horizon]
-        reward = self._reward[idxs_in_horizon]
-        mask = self._mask[idxs_in_horizon]
-        done = self._done[idxs_in_horizon]
-
-        if not action.is_cuda:
-            action, reward, mask, done, idxs, weights = (
-                action.cuda(),
-                reward.cuda(),
-                mask.cuda(),
-                done.cuda(),
-                idxs.cuda(),
-                weights.cuda(),
-            )
-
-        return (
-            obs,
-            next_obs,
-            action,
-            reward.unsqueeze(2),
-            mask.unsqueeze(2),
-            done.unsqueeze(2),
-            idxs,
-            weights,
-        )
-
-    def save(self, path):
-        """Save the replay buffer to path"""
-        print(f"saving replay buffer to '{path}'...")
-        sz = self.capacity if self._full else self.idx
-        dataset = {
-            "observations": (
-                {k: v[:sz].cpu().numpy() for k, v in self._obs.items()}
-                if isinstance(self._obs, dict)
-                else self._obs[:sz].cpu().numpy()
-            ),
-            "next_observations": (
-                {k: v[:sz].cpu().numpy() for k, v in self._next_obs.items()}
-                if isinstance(self._next_obs, dict)
-                else self._next_obs[:sz].cpu().numpy()
-            ),
-            "actions": self._action[:sz].cpu().numpy(),
-            "rewards": self._reward[:sz].cpu().numpy(),
-            "dones": self._done[:sz].cpu().numpy(),
-            "masks": self._mask[:sz].cpu().numpy(),
-        }
-        with open(path, "wb") as f:
-            pickle.dump(dataset, f)
-        return dataset
-
-
 def get_dataset_dict(cfg, env, return_reward_normalizer=False):
     """Construct a dataset for env"""
     required_keys = [
diff --git a/lerobot/configs/env/pusht.yaml b/lerobot/configs/env/pusht.yaml
index 1a500120..60fc594e 100644
--- a/lerobot/configs/env/pusht.yaml
+++ b/lerobot/configs/env/pusht.yaml
@@ -3,7 +3,9 @@
 eval_episodes: 50
 eval_freq: 7500
 save_freq: 75000
-train_steps: 50000 # TODO: same as simxarm, need to adjust
+# TODO: same as simxarm, need to adjust
+offline_steps: 25000
+online_steps: 25000
 
 fps: 10
diff --git a/lerobot/configs/env/simxarm.yaml b/lerobot/configs/env/simxarm.yaml
index 3972631c..0658636a 100644
--- a/lerobot/configs/env/simxarm.yaml
+++ b/lerobot/configs/env/simxarm.yaml
@@ -3,7 +3,9 @@
 eval_episodes: 20
 eval_freq: 1000
 save_freq: 10000
-train_steps: 50000
+log_freq: 50
+offline_steps: 25000
+online_steps: 25000
 
 fps: 15
diff --git a/lerobot/configs/policy/diffusion.yaml b/lerobot/configs/policy/diffusion.yaml
new file mode 100644
index 00000000..40c708db
--- /dev/null
+++ b/lerobot/configs/policy/diffusion.yaml
@@ -0,0 +1,117 @@
+# @package _global_
+
+shape_meta:
+  # acceptable types: rgb, low_dim
+  obs:
+    image:
+      shape: [3, 96, 96]
+      type: rgb
+    agent_pos:
+      shape: [2]
+      type: low_dim
+  action:
+    shape: [2]
+
+horizon: 16
+n_obs_steps: 2
+n_action_steps: 8
+n_latency_steps: 0
+dataset_obs_steps: ${n_obs_steps}
+past_action_visible: False
+keypoint_visible_rate: 1.0
+obs_as_global_cond: True
+
+policy:
+  name: diffusion
+
+  shape_meta: ${shape_meta}
+
+  horizon: ${horizon}
+  # n_action_steps: ${eval:'${n_action_steps}+${n_latency_steps}'}
+  n_obs_steps: ${n_obs_steps}
+  num_inference_steps: 100
+  obs_as_global_cond: ${obs_as_global_cond}
+  # crop_shape: null
+  diffusion_step_embed_dim: 128
+  down_dims: [512, 1024, 2048]
+  kernel_size: 5
+  n_groups: 8
+  cond_predict_scale: True
+
+  pretrained_model_path:
+
+  batch_size: 64
+
+  per_alpha: 0.6
+  per_beta: 0.4
+
+  balanced_sampling: true
+
+  utd: 1
+
+noise_scheduler:
+  # _target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
+  num_train_timesteps: 100
+  beta_start: 0.0001
+  beta_end: 0.02
+  beta_schedule: squaredcos_cap_v2
+  variance_type: fixed_small # Yilun's paper uses fixed_small_log instead, but easy to cause Nan
+  clip_sample: True # required when predict_epsilon=False
+  prediction_type: epsilon # or sample
+
+obs_encoder:
+  # _target_: diffusion_policy.model.vision.multi_image_obs_encoder.MultiImageObsEncoder
+  shape_meta: ${shape_meta}
+  resize_shape: null
+  crop_shape: [76, 76]
+  # constant center crop
+  random_crop: True
+  use_group_norm: True
+  share_rgb_model: False
+  imagenet_norm: True
+
+rgb_model:
+  #_target_: diffusion_policy.model.vision.model_getter.get_resnet
+  name: resnet18
+  weights: null
+
+ema:
+  _target_: diffusion_policy.model.diffusion.ema_model.EMAModel
+  update_after_step: 0
+  inv_gamma: 1.0
+  power: 0.75
+  min_value: 0.0
+  max_value: 0.9999
+
+optimizer:
+  _target_: torch.optim.AdamW
+  lr: 1.0e-4
+  betas: [0.95, 0.999]
+  eps: 1.0e-8
+  weight_decay: 1.0e-6
+
+training:
+  device: "cuda:0"
+  seed: 42
+  debug: False
+  resume: True
+  # optimization
+  lr_scheduler: cosine
+  lr_warmup_steps: 500
+  num_epochs: 8000
+  gradient_accumulate_every: 1
+  # EMA destroys performance when used with BatchNorm
+  # replace BatchNorm with GroupNorm.
+  use_ema: True
+  freeze_encoder: False
+  # training loop control
+  # in epochs
+  rollout_every: 50
+  checkpoint_every: 50
+  val_every: 1
+  sample_every: 5
+  # steps per epoch
+  max_train_steps: null
+  max_val_steps: null
+  # misc
+  tqdm_interval_sec: 1.0
diff --git a/lerobot/configs/policy/tdmpc.yaml b/lerobot/configs/policy/tdmpc.yaml
index 1c2140f8..f4bb46ed 100644
--- a/lerobot/configs/policy/tdmpc.yaml
+++ b/lerobot/configs/policy/tdmpc.yaml
@@ -5,8 +5,6 @@ policy:
 
   reward_scale: 1.0
 
-  # xarm_lift
-  train_steps: ${train_steps}
   episode_length: ${env.episode_length}
   discount: 0.9
   modality: 'all'
diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py
index 077f2556..b0f67834 100644
--- a/lerobot/scripts/eval.py
+++ b/lerobot/scripts/eval.py
@@ -26,31 +26,31 @@ def eval_policy(
     save_video: bool = False,
     video_dir: Path = None,
     fps: int = 15,
-    env_step: int = None,
-    wandb=None,
+    return_first_video: bool = False,
 ):
-    if wandb is not None:
-        assert env_step is not None
     sum_rewards = []
     max_rewards = []
     successes = []
     threads = []
     for i in range(num_episodes):
-        ep_frames = []
-
-        def rendering_callback(env, td=None):
-            ep_frames.append(env.render())
-
         tensordict = env.reset()
-        if save_video or wandb:
+
+        ep_frames = []
+        if save_video or (return_first_video and i == 0):
+
+            def rendering_callback(env, td=None):
+                ep_frames.append(env.render())
+
             # render first frame before rollout
             rendering_callback(env)
+        else:
+            rendering_callback = None
 
         with torch.inference_mode():
             rollout = env.rollout(
                 max_steps=max_steps,
                 policy=policy,
-                callback=rendering_callback if save_video or wandb else None,
+                callback=rendering_callback,
                 auto_reset=False,
                 tensordict=tensordict,
                 auto_cast_to_device=True,
@@ -63,7 +63,7 @@
         max_rewards.append(ep_max_reward.item())
         successes.append(ep_success.item())
 
-        if save_video or wandb:
+        if save_video or (return_first_video and i == 0):
             stacked_frames = np.stack(ep_frames)
 
             if save_video:
@@ -76,12 +76,8 @@
                 thread.start()
                 threads.append(thread)
 
-            first_episode = i == 0
-            if wandb and first_episode:
-                eval_video = wandb.Video(
-                    stacked_frames.transpose(0, 3, 1, 2), fps=fps, format="mp4"
-                )
-                wandb.log({"eval_video": eval_video}, step=env_step)
+            if return_first_video and i == 0:
+                first_video = stacked_frames.transpose(0, 3, 1, 2)
 
     for thread in threads:
         thread.join()
@@ -91,6 +87,8 @@
         "avg_max_reward": np.nanmean(max_rewards),
         "pc_success": np.nanmean(successes) * 100,
     }
+    if return_first_video:
+        return metrics, first_video
     return metrics
diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index 5a0e2c16..1f516cce 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -38,6 +38,40 @@ def train_notebook(
     train(cfg, out_dir=out_dir, job_name=job_name)
 
 
+def log_training_metrics(L, metrics, step, online_episode_idx, start_time, is_offline):
+    common_metrics = {
+        "episode": online_episode_idx,
+        "step": step,
+        "total_time": time.time() - start_time,
+        "is_offline": float(is_offline),
+    }
+    metrics.update(common_metrics)
+    L.log(metrics, category="train")
+
+
+def eval_policy_and_log(
+    env, td_policy, step, online_episode_idx, start_time, is_offline, cfg, L
+):
+    common_metrics = {
+        "episode": online_episode_idx,
+        "step": step,
+        "total_time": time.time() - start_time,
+        "is_offline": float(is_offline),
+    }
+    metrics, first_video = eval_policy(
+        env,
+        td_policy,
+        num_episodes=cfg.eval_episodes,
+        return_first_video=True,
+    )
+    metrics.update(common_metrics)
+    L.log(metrics, category="eval")
+
+    if cfg.wandb.enable:
+        eval_video = L._wandb.Video(first_video, fps=cfg.fps, format="mp4")
+        L._wandb.log({"eval_video": eval_video}, step=step)
+
+
 def train(cfg: dict, out_dir=None, job_name=None):
     if out_dir is None:
         raise NotImplementedError()
@@ -84,115 +118,89 @@ def train(cfg: dict, out_dir=None, job_name=None):
     online_episode_idx = 0
     start_time = time.time()
     step = 0
-    last_log_step = 0
-    last_save_step = 0
 
-    while step < cfg.train_steps:
-        is_offline = True
-        num_updates = cfg.env.episode_length
-        _step = step + num_updates
-        rollout_metrics = {}
+    # First eval with a random model or pretrained
+    eval_policy_and_log(
+        env, td_policy, step, online_episode_idx, start_time, is_offline, cfg, L
+    )
 
-        # TODO(rcadene): move offline_steps outside policy
-        if step >= cfg.policy.offline_steps:
-            is_offline = False
+    # Train offline
+    for _ in range(cfg.offline_steps):
+        # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
+        metrics = policy.update(offline_buffer, step)
-
-            # TODO: use SyncDataCollector for that?
-            with torch.no_grad():
-                rollout = env.rollout(
-                    max_steps=cfg.env.episode_length,
-                    policy=td_policy,
-                    auto_cast_to_device=True,
-                )
-            assert len(rollout) <= cfg.env.episode_length
-            rollout["episode"] = torch.tensor(
-                [online_episode_idx] * len(rollout), dtype=torch.int
+        if step % cfg.log_freq == 0:
+            log_training_metrics(
+                L, metrics, step, online_episode_idx, start_time, is_offline=False
             )
-            online_buffer.extend(rollout)
 
-            ep_sum_reward = rollout["next", "reward"].sum()
-            ep_max_reward = rollout["next", "reward"].max()
-            ep_success = rollout["next", "success"].any()
+        if step > 0 and step % cfg.eval_freq == 0:
+            eval_policy_and_log(
+                env, td_policy, step, online_episode_idx, start_time, is_offline, cfg, L
+            )
 
-            online_episode_idx += 1
-            rollout_metrics = {
-                "avg_sum_reward": np.nanmean(ep_sum_reward),
-                "avg_max_reward": np.nanmean(ep_max_reward),
-                "pc_success": np.nanmean(ep_success) * 100,
-            }
-            num_updates = len(rollout) * cfg.policy.utd
-            _step = min(step + len(rollout), cfg.train_steps)
+        if step > 0 and cfg.save_model and step % cfg.save_freq == 0:
+            print(f"Checkpoint model at step {step}")
+            L.save_model(policy, identifier=step)
 
-        # Update model
-        for i in range(num_updates):
-            if is_offline:
-                train_metrics = policy.update(offline_buffer, step + i)
-            else:
-                train_metrics = policy.update(
-                    online_buffer,
-                    step + i // cfg.policy.utd,
-                    demo_buffer=(
-                        offline_buffer if cfg.policy.balanced_sampling else None
-                    ),
-                )
+        step += 1
 
-        # Log training metrics
-        env_step = int(_step * cfg.env.action_repeat)
-        common_metrics = {
-            "episode": online_episode_idx,
-            "step": _step,
-            "env_step": env_step,
-            "total_time": time.time() - start_time,
-            "is_offline": float(is_offline),
+    # Train online
+    demo_buffer = offline_buffer if cfg.policy.balanced_sampling else None
+    for _ in range(cfg.online_steps):
+        # TODO: use SyncDataCollector for that?
+        with torch.no_grad():
+            rollout = env.rollout(
+                max_steps=cfg.env.episode_length,
+                policy=td_policy,
+                auto_cast_to_device=True,
+            )
+        assert len(rollout) <= cfg.env.episode_length
+        rollout["episode"] = torch.tensor(
+            [online_episode_idx] * len(rollout), dtype=torch.int
+        )
+        online_buffer.extend(rollout)
+
+        ep_sum_reward = rollout["next", "reward"].sum()
+        ep_max_reward = rollout["next", "reward"].max()
+        ep_success = rollout["next", "success"].any()
+        metrics = {
+            "avg_sum_reward": np.nanmean(ep_sum_reward),
+            "avg_max_reward": np.nanmean(ep_max_reward),
+            "pc_success": np.nanmean(ep_success) * 100,
         }
-        train_metrics.update(common_metrics)
-        train_metrics.update(rollout_metrics)
-        L.log(train_metrics, category="train")
 
-        # Evaluate policy periodically
-        if step == 0 or env_step - last_log_step >= cfg.eval_freq:
+        online_episode_idx += 1
 
-            eval_metrics = eval_policy(
-                env,
-                td_policy,
-                num_episodes=cfg.eval_episodes,
-                env_step=env_step,
-                wandb=L._wandb,
+        for _ in range(cfg.policy.utd):
+            train_metrics = policy.update(
+                online_buffer,
+                step,
+                demo_buffer=demo_buffer,
             )
+            metrics.update(train_metrics)
+        if step % cfg.log_freq == 0:
+            log_training_metrics(
+                L, metrics, step, online_episode_idx, start_time, is_offline=False
+            )
 
-            common_metrics.update(eval_metrics)
-            L.log(common_metrics, category="eval")
-            last_log_step = env_step - env_step % cfg.eval_freq
+        if step > 0 and step % cfg.eval_freq == 0:
+            eval_policy_and_log(
+                env,
+                td_policy,
+                step,
+                online_episode_idx,
+                start_time,
+                is_offline,
+                cfg,
+                L,
+            )
 
-        # Save model periodically
-        if cfg.save_model and env_step - last_save_step >= cfg.save_freq:
-            L.save_model(policy, identifier=env_step)
-            print(f"Model has been checkpointed at step {env_step}")
-            last_save_step = env_step - env_step % cfg.save_freq
+        if step > 0 and cfg.save_model and step % cfg.save_freq == 0:
+            print(f"Checkpoint model at step {step}")
+            L.save_model(policy, identifier=step)
 
-        if cfg.save_model and is_offline and _step >= cfg.offline_steps:
-            # save the model after offline training
-            L.save_model(policy, identifier="offline")
-
-        step = _step
-
-    # dataset_d4rl = D4RLExperienceReplay(
-    #     dataset_id="maze2d-umaze-v1",
-    #     split_trajs=False,
-    #     batch_size=1,
-    #     sampler=SamplerWithoutReplacement(drop_last=False),
-    #     prefetch=4,
-    #     direct_download=True,
-    # )
-
-    # dataset_openx = OpenXExperienceReplay(
-    #     "cmu_stretch",
-    #     batch_size=1,
-    #     num_slices=1,
-    #     #download="force",
-    #     streaming=False,
-    #     root="data",
-    # )
+        step += 1
 
 
 if __name__ == "__main__":
diff --git a/tests/test_policies.py b/tests/test_policies.py
index 7408c729..03f20bd0 100644
--- a/tests/test_policies.py
+++ b/tests/test_policies.py
@@ -6,12 +6,19 @@ from .utils import init_config
 
 
 @pytest.mark.parametrize(
-    "env_name",
+    "env_name,policy_name",
     [
-        "simxarm",
-        "pusht",
+        ("simxarm", "tdmpc"),
+        ("pusht", "tdmpc"),
+        ("simxarm", "diffusion"),
+        ("pusht", "diffusion"),
     ],
 )
-def test_factory(env_name):
-    cfg = init_config(overrides=[f"env={env_name}"])
+def test_factory(env_name, policy_name):
+    cfg = init_config(
+        overrides=[
+            f"env={env_name}",
+            f"policy={policy_name}",
+        ]
+    )
     policy = make_policy(cfg)