From 228c0456740b29ea0248de24f3da427b245df85b Mon Sep 17 00:00:00 2001 From: Cadene Date: Sat, 10 Feb 2024 15:46:24 +0000 Subject: [PATCH] Eval reproduced! Train running (but not reproduced) --- .gitignore | 1 + README.md | 1 + lerobot/common/datasets/__init__.py | 0 lerobot/common/datasets/simxarm.py | 190 +++++++++++++++++++++ lerobot/common/envs/factory.py | 1 + lerobot/common/envs/simxarm.py | 9 +- lerobot/common/logger.py | 243 +++++++++++++++++++++++++++ lerobot/common/tdmpc.py | 68 +++++++- lerobot/common/tdmpc_helper.py | 12 +- lerobot/configs/default.yaml | 41 ++++- lerobot/scripts/eval.py | 14 +- lerobot/scripts/train.py | 186 ++++++++++++++++++-- lerobot/scripts/visualize.py | 80 --------- lerobot/scripts/visualize_dataset.py | 59 +++++++ 14 files changed, 787 insertions(+), 118 deletions(-) create mode 100644 lerobot/common/datasets/__init__.py create mode 100644 lerobot/common/datasets/simxarm.py create mode 100644 lerobot/common/logger.py delete mode 100644 lerobot/scripts/visualize.py create mode 100644 lerobot/scripts/visualize_dataset.py diff --git a/.gitignore b/.gitignore index 8a917d71..01308c2e 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,7 @@ wandb data outputs .vscode +rl # HPC nautilus/*.yaml diff --git a/README.md b/README.md index 74a9b6ca..4bba66df 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,7 @@ conda activate lerobot python setup.py develop ``` + ## Contribute **style** diff --git a/lerobot/common/datasets/__init__.py b/lerobot/common/datasets/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/lerobot/common/datasets/simxarm.py b/lerobot/common/datasets/simxarm.py new file mode 100644 index 00000000..84e6ca7c --- /dev/null +++ b/lerobot/common/datasets/simxarm.py @@ -0,0 +1,190 @@ +import os +import pickle +from pathlib import Path +from typing import Any, Callable, Dict, Tuple + +import torch +import torchrl +import tqdm +from tensordict import TensorDict +from torchrl.data.datasets.utils import _get_root_dir +from torchrl.data.replay_buffers.replay_buffers import ( + TensorDictPrioritizedReplayBuffer, + TensorDictReplayBuffer, +) +from torchrl.data.replay_buffers.samplers import ( + Sampler, + SliceSampler, + SliceSamplerWithoutReplacement, +) +from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id +from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer + + +class SimxarmExperienceReplay(TensorDictReplayBuffer): + + available_datasets = [ + "xarm_lift_medium", + ] + + def __init__( + self, + dataset_id, + batch_size: int = None, + *, + shuffle: bool = True, + num_slices: int = None, + slice_len: int = None, + pad: float = None, + replacement: bool = None, + streaming: bool = False, + root: Path = None, + download: bool = False, + sampler: Sampler = None, + writer: Writer = None, + collate_fn: Callable = None, + pin_memory: bool = False, + prefetch: int = None, + transform: "torchrl.envs.Transform" = None, # noqa-F821 + split_trajs: bool = False, + strict_length: bool = True, + ): + self.download = download + if streaming: + raise NotImplementedError + self.streaming = streaming + self.dataset_id = dataset_id + self.split_trajs = split_trajs + self.shuffle = shuffle + self.num_slices = num_slices + self.slice_len = slice_len + self.pad = pad + + self.strict_length = strict_length + if (self.num_slices is not None) and (self.slice_len is not None): + raise ValueError("num_slices or slice_len can be not None, but not both.") + if split_trajs: + raise NotImplementedError + + if root 
is None: + root = _get_root_dir("simxarm") + os.makedirs(root, exist_ok=True) + self.root = Path(root) + if self.download == "force" or (self.download and not self._is_downloaded()): + storage = self._download_and_preproc() + else: + storage = TensorStorage(TensorDict.load_memmap(self.root / dataset_id)) + + if num_slices is not None or slice_len is not None: + if sampler is not None: + raise ValueError( + "`num_slices` and `slice_len` are exclusive with the `sampler` argument." + ) + + if replacement: + if not self.shuffle: + raise RuntimeError( + "shuffle=False can only be used when replacement=False." + ) + sampler = SliceSampler( + num_slices=num_slices, + slice_len=slice_len, + strict_length=strict_length, + ) + else: + sampler = SliceSamplerWithoutReplacement( + num_slices=num_slices, + slice_len=slice_len, + strict_length=strict_length, + shuffle=self.shuffle, + ) + + if writer is None: + writer = ImmutableDatasetWriter() + if collate_fn is None: + collate_fn = _collate_id + + super().__init__( + storage=storage, + sampler=sampler, + writer=writer, + collate_fn=collate_fn, + pin_memory=pin_memory, + prefetch=prefetch, + batch_size=batch_size, + transform=transform, + ) + + @property + def data_path_root(self): + if self.streaming: + return None + return self.root / self.dataset_id + + def _is_downloaded(self): + return os.path.exists(self.data_path_root) + + def _download_and_preproc(self): + # download + # TODO(rcadene) + + # load + dataset_dir = Path("data") / self.dataset_id + dataset_path = dataset_dir / f"buffer.pkl" + print(f"Using offline dataset '{dataset_path}'") + with open(dataset_path, "rb") as f: + dataset_dict = pickle.load(f) + + total_frames = dataset_dict["actions"].shape[0] + + idx0 = 0 + idx1 = 0 + episode_id = 0 + for i in tqdm.tqdm(range(total_frames)): + idx1 += 1 + + if not dataset_dict["dones"][i]: + continue + + num_frames = idx1 - idx0 + + image = torch.tensor(dataset_dict["observations"]["rgb"][idx0:idx1]) + state = torch.tensor(dataset_dict["observations"]["state"][idx0:idx1]) + next_image = torch.tensor( + dataset_dict["next_observations"]["rgb"][idx0:idx1] + ) + next_state = torch.tensor( + dataset_dict["next_observations"]["state"][idx0:idx1] + ) + next_reward = torch.tensor(dataset_dict["rewards"][idx0:idx1]) + next_done = torch.tensor(dataset_dict["dones"][idx0:idx1]) + + episode = TensorDict( + { + ("observation", "image"): image, + ("observation", "state"): state, + "action": torch.tensor(dataset_dict["actions"][idx0:idx1]), + "episode": torch.tensor([episode_id] * num_frames, dtype=torch.int), + "frame_id": torch.arange(0, num_frames, 1), + ("next", "observation", "image"): next_image, + ("next", "observation", "state"): next_state, + ("next", "observation", "reward"): next_reward, + ("next", "observation", "done"): next_done, + }, + batch_size=num_frames, + ) + + if episode_id == 0: + # hack to initialize tensordict data structure to store episodes + td_data = ( + episode[0] + .expand(total_frames) + .memmap_like(self.root / self.dataset_id) + ) + + td_data[idx0:idx1] = episode + + episode_id += 1 + idx0 = idx1 + + return TensorStorage(td_data.lock_()) diff --git a/lerobot/common/envs/factory.py b/lerobot/common/envs/factory.py index de42bc26..9f491f71 100644 --- a/lerobot/common/envs/factory.py +++ b/lerobot/common/envs/factory.py @@ -7,6 +7,7 @@ def make_env(cfg): assert cfg.env == "simxarm" env = SimxarmEnv( task=cfg.task, + frame_skip=cfg.action_repeat, from_pixels=cfg.from_pixels, pixels_only=cfg.pixels_only, image_size=cfg.image_size, diff 
--git a/lerobot/common/envs/simxarm.py b/lerobot/common/envs/simxarm.py index 3ca3ae0f..1470fceb 100644 --- a/lerobot/common/envs/simxarm.py +++ b/lerobot/common/envs/simxarm.py @@ -24,6 +24,7 @@ class SimxarmEnv(EnvBase): def __init__( self, task, + frame_skip: int = 1, from_pixels: bool = False, pixels_only: bool = False, image_size=None, @@ -32,6 +33,7 @@ class SimxarmEnv(EnvBase): ): super().__init__(device=device, batch_size=[]) self.task = task + self.frame_skip = frame_skip self.from_pixels = from_pixels self.pixels_only = pixels_only self.image_size = image_size @@ -115,12 +117,15 @@ class SimxarmEnv(EnvBase): # step expects shape=(4,) so we pad if necessary action = np.concatenate([action, self._action_padding]) # TODO(rcadene): add info["is_success"] and info["success"] ? - raw_obs, reward, done, info = self._env.step(action) + sum_reward = 0 + for t in range(self.frame_skip): + raw_obs, reward, done, info = self._env.step(action) + sum_reward += reward td = TensorDict( { "observation": self._format_raw_obs(raw_obs), - "reward": torch.tensor([reward], dtype=torch.float32), + "reward": torch.tensor([sum_reward], dtype=torch.float32), "done": torch.tensor([done], dtype=torch.bool), "success": torch.tensor([info["success"]], dtype=torch.bool), }, diff --git a/lerobot/common/logger.py b/lerobot/common/logger.py new file mode 100644 index 00000000..cd6139bc --- /dev/null +++ b/lerobot/common/logger.py @@ -0,0 +1,243 @@ +import datetime +import os +import re +from pathlib import Path + +import numpy as np +import pandas as pd +from omegaconf import OmegaConf +from termcolor import colored + +CONSOLE_FORMAT = [ + ("episode", "E", "int"), + ("env_step", "S", "int"), + ("avg_reward", "R", "float"), + ("pc_success", "R", "float"), + ("total_time", "T", "time"), +] +AGENT_METRICS = [ + "consistency_loss", + "reward_loss", + "value_loss", + "total_loss", + "weighted_loss", + "pi_loss", + "grad_norm", +] + + +def make_dir(dir_path): + """Create directory if it does not already exist.""" + try: + dir_path.mkdir(parents=True, exist_ok=True) + except OSError: + pass + return dir_path + + +def print_run(cfg, reward=None): + """Pretty-printing of run information. Call at start of training.""" + prefix, color, attrs = " ", "green", ["bold"] + + def limstr(s, maxlen=32): + return str(s[:maxlen]) + "..." if len(str(s)) > maxlen else s + + def pprint(k, v): + print( + prefix + colored(f'{k.capitalize() + ":":<16}', color, attrs=attrs), + limstr(v), + ) + + kvs = [ + ("task", cfg.task), + ("train steps", f"{int(cfg.train_steps * cfg.action_repeat):,}"), + # ('observations', 'x'.join([str(s) for s in cfg.obs_shape])), + # ('actions', cfg.action_dim), + # ('experiment', cfg.exp_name), + ] + if reward is not None: + kvs.append( + ("episode reward", colored(str(int(reward)), "white", attrs=["bold"])) + ) + w = np.max([len(limstr(str(kv[1]))) for kv in kvs]) + 21 + div = "-" * w + print(div) + for k, v in kvs: + pprint(k, v) + print(div) + + +def cfg_to_group(cfg, return_list=False): + """Return a wandb-safe group name for logging. 
Optionally returns group name as list.""" + lst = [cfg.task, cfg.modality, re.sub("[^0-9a-zA-Z]+", "-", cfg.exp_name)] + return lst if return_list else "-".join(lst) + + +class VideoRecorder: + """Utility class for logging evaluation videos.""" + + def __init__(self, root_dir, wandb, render_size=384, fps=15): + self.save_dir = (root_dir / "eval_video") if root_dir else None + self._wandb = wandb + self.render_size = render_size + self.fps = fps + self.frames = [] + self.enabled = False + self.camera_id = 0 + + def init(self, env, enabled=True): + self.frames = [] + self.enabled = self.save_dir and self._wandb and enabled + try: + env_name = env.unwrapped.spec.id + except: + env_name = "" + if "maze2d" in env_name: + self.camera_id = -1 + elif "quadruped" in env_name: + self.camera_id = 2 + self.record(env) + + def record(self, env): + if self.enabled: + frame = env.render( + mode="rgb_array", + height=self.render_size, + width=self.render_size, + camera_id=self.camera_id, + ) + self.frames.append(frame) + + def save(self, step): + if self.enabled: + frames = np.stack(self.frames).transpose(0, 3, 1, 2) + self._wandb.log( + {"eval_video": self._wandb.Video(frames, fps=self.fps, format="mp4")}, + step=step, + ) + + +class Logger(object): + """Primary logger object. Logs either locally or using wandb.""" + + def __init__(self, log_dir, cfg): + self._log_dir = make_dir(Path(log_dir)) + self._model_dir = make_dir(self._log_dir / "models") + self._buffer_dir = make_dir(self._log_dir / "buffers") + self._save_model = cfg.save_model + self._save_buffer = cfg.save_buffer + self._group = cfg_to_group(cfg) + self._seed = cfg.seed + self._cfg = cfg + self._eval = [] + print_run(cfg) + project, entity = cfg.get("wandb_project", "none"), cfg.get( + "wandb_entity", "none" + ) + run_offline = ( + not cfg.get("use_wandb", False) or project == "none" or entity == "none" + ) + if run_offline: + print(colored("Logs will be saved locally.", "yellow", attrs=["bold"])) + self._wandb = None + else: + try: + os.environ["WANDB_SILENT"] = "true" + import wandb + + wandb.init( + project=project, + entity=entity, + name=str(cfg.seed), + notes=cfg.notes, + group=self._group, + tags=cfg_to_group(cfg, return_list=True) + [f"seed:{cfg.seed}"], + dir=self._log_dir, + config=OmegaConf.to_container(cfg, resolve=True), + ) + print( + colored("Logs will be synced with wandb.", "blue", attrs=["bold"]) + ) + self._wandb = wandb + except: + print( + colored( + "Warning: failed to init wandb. Make sure `wandb_entity` is set to your username in `config.yaml`. 
Logs will be saved locally.", + "yellow", + attrs=["bold"], + ) + ) + self._wandb = None + self._video = ( + VideoRecorder(log_dir, self._wandb) + if self._wandb and cfg.save_video + else None + ) + + @property + def video(self): + return self._video + + def save_model(self, agent, identifier): + if self._save_model: + fp = self._model_dir / f"{str(identifier)}.pt" + agent.save(fp) + if self._wandb: + artifact = self._wandb.Artifact( + self._group + "-" + str(self._seed) + "-" + str(identifier), + type="model", + ) + artifact.add_file(fp) + self._wandb.log_artifact(artifact) + + def save_buffer(self, buffer, identifier): + fp = self._buffer_dir / f"{str(identifier)}.pkl" + buffer.save(fp) + if self._wandb: + artifact = self._wandb.Artifact( + self._group + "-" + str(self._seed) + "-" + str(identifier), + type="buffer", + ) + artifact.add_file(fp) + self._wandb.log_artifact(artifact) + + def finish(self, agent, buffer): + if self._save_model: + self.save_model(agent, identifier="final") + if self._save_buffer: + self.save_buffer(buffer, identifier="buffer") + if self._wandb: + self._wandb.finish() + print_run(self._cfg, self._eval[-1][-1]) + + def _format(self, key, value, ty): + if ty == "int": + return f'{colored(key + ":", "grey")} {int(value):,}' + elif ty == "float": + return f'{colored(key + ":", "grey")} {value:.01f}' + elif ty == "time": + value = str(datetime.timedelta(seconds=int(value))) + return f'{colored(key + ":", "grey")} {value}' + else: + raise f"invalid log format type: {ty}" + + def _print(self, d, category): + category = colored(category, "blue" if category == "train" else "green") + pieces = [f" {category:<14}"] + for k, disp_k, ty in CONSOLE_FORMAT: + pieces.append(f"{self._format(disp_k, d.get(k, 0), ty):<26}") + print(" ".join(pieces)) + + def log(self, d, category="train"): + assert category in {"train", "eval"} + if self._wandb is not None: + for k, v in d.items(): + self._wandb.log({category + "/" + k: v}, step=d["env_step"]) + if category == "eval": + # keys = ['env_step', 'avg_reward'] + keys = ["env_step", "avg_reward", "pc_success"] + self._eval.append(np.array([d[key] for key in keys])) + pd.DataFrame(np.array(self._eval)).to_csv( + self._log_dir / "eval.log", header=keys, index=None + ) + self._print(d, category) diff --git a/lerobot/common/tdmpc.py b/lerobot/common/tdmpc.py index df4e647b..da8638dd 100644 --- a/lerobot/common/tdmpc.py +++ b/lerobot/common/tdmpc.py @@ -1,5 +1,6 @@ from copy import deepcopy +import einops import numpy as np import torch import torch.nn as nn @@ -90,7 +91,7 @@ class TDMPC(nn.Module): self.model_target = deepcopy(self.model) self.optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr) self.pi_optim = torch.optim.Adam(self.model._pi.parameters(), lr=self.cfg.lr) - self.bc_optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr) + # self.bc_optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr) self.model.eval() self.model_target.eval() self.batch_size = cfg.batch_size @@ -308,9 +309,41 @@ class TDMPC(nn.Module): self.demo_batch_size = 0 # Sample from interaction dataset - obs, next_obses, action, reward, mask, done, idxs, weights = ( - replay_buffer.sample() + + # to not have to mask + # batch_size = (self.cfg.batch_size // self.cfg.horizon) * self.cfg.horizon + batch_size = self.cfg.horizon * self.cfg.batch_size + batch = replay_buffer.sample(batch_size) + + # trajectory t = 256, horizon h = 5 + # (t h) ... -> h t ... 
+ batch = ( + batch.reshape(self.cfg.batch_size, self.cfg.horizon) + .transpose(1, 0) + .contiguous() ) + batch = batch.to("cuda") + + FIRST_FRAME = 0 + obs = { + "rgb": batch["observation", "image"][FIRST_FRAME].float(), + "state": batch["observation", "state"][FIRST_FRAME], + } + action = batch["action"] + next_obses = { + "rgb": batch["next", "observation", "image"].float(), + "state": batch["next", "observation", "state"], + } + reward = batch["next", "reward"] + reward = einops.rearrange(reward, "h t -> h t 1") + # We dont use `batch["next", "done"]` since it only indicates the end of an + # episode, but not the end of the trajectory of an episode. + # Neither does `batch["next", "terminated"]` + done = torch.zeros_like(reward, dtype=torch.bool, device=reward.device) + mask = torch.ones_like(reward, dtype=torch.bool, device=reward.device) + + idxs = batch["frame_id"][FIRST_FRAME] + weights = batch["_weight"][FIRST_FRAME, :, None] # Sample from demonstration dataset if self.demo_batch_size > 0: @@ -341,6 +374,21 @@ class TDMPC(nn.Module): idxs = torch.cat([idxs, demo_idxs]) weights = torch.cat([weights, demo_weights]) + # Apply augmentations + aug_tf = h.aug(self.cfg) + obs = aug_tf(obs) + + for k in next_obses: + next_obses[k] = einops.rearrange(next_obses[k], "h t ... -> (h t) ...") + next_obses = aug_tf(next_obses) + for k in next_obses: + next_obses[k] = einops.rearrange( + next_obses[k], + "(h t) ... -> h t ...", + h=self.cfg.horizon, + t=self.cfg.batch_size, + ) + horizon = self.cfg.horizon loss_mask = torch.ones_like(mask, device=self.device) for t in range(1, horizon): @@ -407,6 +455,7 @@ class TDMPC(nn.Module): weighted_loss = (total_loss.squeeze(1) * weights).mean() weighted_loss.register_hook(lambda grad: grad * (1 / self.cfg.horizon)) weighted_loss.backward() + grad_norm = torch.nn.utils.clip_grad_norm_( self.model.parameters(), self.cfg.grad_clip_norm, error_if_nonfinite=False ) @@ -415,13 +464,16 @@ class TDMPC(nn.Module): if self.cfg.per: # Update priorities priorities = priority_loss.clamp(max=1e4).detach() - replay_buffer.update_priorities( - idxs[: replay_buffer.cfg.batch_size], - priorities[: replay_buffer.cfg.batch_size], + # normalize between [0,1] to fit torchrl specification + priorities /= 1e4 + priorities = priorities.clamp(max=1.0) + replay_buffer.update_priority( + idxs[: self.cfg.batch_size], + priorities[: self.cfg.batch_size], ) if self.demo_batch_size > 0: - demo_buffer.update_priorities( - demo_idxs, priorities[replay_buffer.cfg.batch_size :] + demo_buffer.update_priority( + demo_idxs, priorities[self.cfg.batch_size :] ) # Update policy + target network diff --git a/lerobot/common/tdmpc_helper.py b/lerobot/common/tdmpc_helper.py index 11e5c098..8f629988 100644 --- a/lerobot/common/tdmpc_helper.py +++ b/lerobot/common/tdmpc_helper.py @@ -306,13 +306,21 @@ class RandomShiftsAug(nn.Module): x = F.pad(x, padding, "replicate") eps = 1.0 / (h + 2 * self.pad) arange = torch.linspace( - -1.0 + eps, 1.0 - eps, h + 2 * self.pad, device=x.device, dtype=x.dtype + -1.0 + eps, + 1.0 - eps, + h + 2 * self.pad, + device=x.device, + dtype=torch.float32, )[:h] arange = arange.unsqueeze(0).repeat(h, 1).unsqueeze(2) base_grid = torch.cat([arange, arange.transpose(1, 0)], dim=2) base_grid = base_grid.unsqueeze(0).repeat(n, 1, 1, 1) shift = torch.randint( - 0, 2 * self.pad + 1, size=(n, 1, 1, 2), device=x.device, dtype=x.dtype + 0, + 2 * self.pad + 1, + size=(n, 1, 1, 2), + device=x.device, + dtype=torch.float32, ) shift *= 2.0 / (h + 2 * self.pad) grid = base_grid + shift 
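
[Editor's note, not part of the patch] A minimal sketch of the `(t h) ... -> h t ...` reshaping that the TDMPC.update() hunk above performs: the slice sampler returns `horizon` consecutive frames per sampled trajectory on a single flat batch axis, which is then split into per-step tensors of shape (horizon, batch_size, ...). Plain tensors stand in for the TensorDict fields; the field shapes and names below are illustrative assumptions.

    # Illustrative sketch of the (t h) -> h t reshape used in TDMPC.update() above.
    import einops
    import torch

    batch_size, horizon = 256, 5                 # t trajectories, h frames each
    flat = torch.randn(batch_size * horizon, 4)  # stand-in for a flat per-frame field

    # The flat axis is trajectory-major: frame j of trajectory i sits at index i*h + j,
    # so "(t h) ... -> h t ..." regroups frames into trajectories, step index first.
    per_step = einops.rearrange(flat, "(t h) d -> h t d", t=batch_size, h=horizon)
    obs0 = per_step[0]                           # FIRST_FRAME: initial observation per trajectory

    reward = torch.randn(batch_size * horizon)   # stand-in for batch["next", "reward"]
    reward = einops.rearrange(reward, "(t h) -> h t 1", t=batch_size, h=horizon)

    print(per_step.shape, obs0.shape, reward.shape)
    # torch.Size([5, 256, 4]) torch.Size([256, 4]) torch.Size([5, 256, 1])
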
diff --git a/lerobot/configs/default.yaml b/lerobot/configs/default.yaml index f6202afa..ce43b293 100644 --- a/lerobot/configs/default.yaml +++ b/lerobot/configs/default.yaml @@ -1,5 +1,14 @@ seed: 1337 log_dir: logs/2024_01_26_train +exp_name: default +device: cuda +buffer_device: cuda +eval_freq: 1000 +save_freq: 10000 +eval_episodes: 20 +save_video: false +save_model: false +save_buffer: false # env env: simxarm @@ -8,6 +17,14 @@ from_pixels: True pixels_only: False image_size: 84 +reward_scale: 1.0 + +# xarm_lift +episode_length: 25 +modality: 'all' +action_repeat: 2 # TODO(rcadene): verify we use this +discount: 0.9 +train_steps: 50000 # pixels frame_stack: 1 @@ -54,6 +71,19 @@ update_freq: 2 tau: 0.01 utd: 1 +# offline rl +# dataset_dir: ??? +data_first_percent: 1.0 +is_data_clip: true +data_clip_eps: 1e-5 +expectile: 0.9 +A_scaling: 3.0 + +# offline->online +offline_steps: ${train_steps}/2 +pretrained_model_path: "" +balanced_sampling: true +demo_schedule: 0.5 # architecture enc_dim: 256 @@ -61,11 +91,8 @@ num_q: 5 mlp_dim: 512 latent_dim: 50 +# wandb +use_wandb: false +wandb_project: FOWM +wandb_entity: rcadene # insert your own -# xarm_lift -A_scaling: 3.0 -expectile: 0.9 -episode_length: 25 -modality: 'all' -action_repeat: 2 -discount: 0.9 \ No newline at end of file diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index 4137e5d0..5268ebff 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -32,23 +32,25 @@ def eval_policy( ep_frames.append(frame) tensordict = env.reset() - # render first frame before rollout - rendering_callback(env) + if save_video: + # render first frame before rollout + rendering_callback(env) rollout = env.rollout( max_steps=max_steps, policy=policy, - callback=rendering_callback, + callback=rendering_callback if save_video else None, auto_reset=False, tensordict=tensordict, ) + # print(", ".join([f"{x:.3f}" for x in rollout["next", "reward"][:,0].tolist()])) ep_reward = rollout["next", "reward"].sum() ep_success = rollout["next", "success"].any() rewards.append(ep_reward.item()) successes.append(ep_success.item()) if save_video: - video_dir.parent.mkdir(parents=True, exist_ok=True) + video_dir.mkdir(parents=True, exist_ok=True) # TODO(rcadene): make fps configurable video_path = video_dir / f"eval_episode_{i}.mp4" imageio.mimsave(video_path, np.stack(ep_frames), fps=15) @@ -82,8 +84,8 @@ def eval(cfg: dict): metrics = eval_policy( env, policy=policy, - num_episodes=10, - save_video=True, + num_episodes=20, + save_video=False, video_dir=Path("tmp/2023_01_29_xarm_lift_final"), ) print(metrics) diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index b4a9edad..55c9c0f8 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -1,11 +1,24 @@ +import pickle +import time +from pathlib import Path + import hydra +import imageio +import numpy as np import torch +from tensordict.nn import TensorDictModule from termcolor import colored +from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer +from torchrl.data.datasets.d4rl import D4RLExperienceReplay +from torchrl.data.datasets.openx import OpenXExperienceReplay +from torchrl.data.replay_buffers import PrioritizedSliceSampler +from lerobot.common.datasets.simxarm import SimxarmExperienceReplay from lerobot.common.envs.factory import make_env +from lerobot.common.logger import Logger from lerobot.common.tdmpc import TDMPC - -from ..common.utils import set_seed +from lerobot.common.utils import set_seed +from lerobot.scripts.eval import eval_policy 
@hydra.main(version_base=None, config_name="default", config_path="../configs") @@ -15,22 +28,169 @@ def train(cfg: dict): print(colored("Work dir:", "yellow", attrs=["bold"]), cfg.log_dir) env = make_env(cfg) - agent = TDMPC(cfg) + policy = TDMPC(cfg) # ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/offline.pt" ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/final.pt" - agent.load(ckpt_path) + policy.load(ckpt_path) - # online training - - eval_metrics = train_agent( - env, - agent, - num_episodes=10, - save_video=True, - video_dir=Path("tmp/2023_01_29_xarm_lift_final"), + td_policy = TensorDictModule( + policy, + in_keys=["observation", "step_count"], + out_keys=["action"], ) - print(eval_metrics) + # initialize offline dataset + + dataset_id = f"xarm_{cfg.task}_medium" + + num_traj_per_batch = cfg.batch_size # // cfg.horizon + # TODO(rcadene): Sampler outputs a batch_size <= cfg.batch_size. + # We would need to add a transform to pad the tensordict to ensure batch_size == cfg.batch_size. + sampler = PrioritizedSliceSampler( + max_capacity=100_000, + alpha=0.7, + beta=0.9, + num_slices=num_traj_per_batch, + strict_length=False, + ) + + # TODO(rcadene): use PrioritizedReplayBuffer + offline_buffer = SimxarmExperienceReplay( + dataset_id, + # download="force", + download=True, + streaming=False, + root="data", + sampler=sampler, + ) + + num_steps = len(offline_buffer) + index = torch.arange(0, num_steps, 1) + sampler.extend(index) + + # offline_buffer._storage.device = torch.device("cuda") + # offline_buffer._storage._storage.to(torch.device("cuda")) + # TODO(rcadene): add online_buffer + + # Observation encoder + # Dynamics predictor + # Reward predictor + # Policy + # Qs state-action value predictor + # V state value predictor + + L = Logger(cfg.log_dir, cfg) + + episode_idx = 0 + start_time = time.time() + step = 0 + last_log_step = 0 + last_save_step = 0 + + while step < cfg.train_steps: + is_offline = True + num_updates = cfg.episode_length + _step = step + num_updates + rollout_metrics = {} + + # if step >= cfg.offline_steps: + # is_offline = False + + # # Collect trajectory + # obs = env.reset() + # episode = Episode(cfg, obs) + # success = False + # while not episode.done: + # action = policy.act(obs, step=step, t0=episode.first) + # obs, reward, done, info = env.step(action.cpu().numpy()) + # reward = reward_normalizer(reward) + # mask = 1.0 if (not done or "TimeLimit.truncated" in info) else 0.0 + # success = info.get('success', False) + # episode += (obs, action, reward, done, mask, success) + # assert len(episode) <= cfg.episode_length + # buffer += episode + # episode_idx += 1 + # rollout_metrics = { + # 'episode_reward': episode.cumulative_reward, + # 'episode_success': float(success), + # 'episode_length': len(episode) + # } + # num_updates = len(episode) * cfg.utd + # _step = min(step + len(episode), cfg.train_steps) + + # Update model + train_metrics = {} + if is_offline: + for i in range(num_updates): + train_metrics.update(policy.update(offline_buffer, step + i)) + # else: + # for i in range(num_updates): + # train_metrics.update( + # policy.update(buffer, step + i // cfg.utd, + # demo_buffer=offline_buffer if cfg.balanced_sampling else None) + # ) + + # Log training metrics + env_step = int(_step * cfg.action_repeat) + common_metrics = { + "episode": episode_idx, + "step": _step, + "env_step": env_step, + "total_time": time.time() - start_time, + "is_offline": float(is_offline), + } + 
train_metrics.update(common_metrics) + train_metrics.update(rollout_metrics) + L.log(train_metrics, category="train") + + # Evaluate policy periodically + if step == 0 or env_step - last_log_step >= cfg.eval_freq: + + eval_metrics = eval_policy( + env, + td_policy, + num_episodes=cfg.eval_episodes, + # TODO(rcadene): add step, env_step, L.video + ) + + # TODO(rcadene): + # if hasattr(env, "get_normalized_score"): + # eval_metrics['normalized_score'] = env.get_normalized_score(eval_metrics["episode_reward"]) * 100.0 + + common_metrics.update(eval_metrics) + + L.log(common_metrics, category="eval") + last_log_step = env_step - env_step % cfg.eval_freq + + # Save model periodically + # if cfg.save_model and env_step - last_save_step >= cfg.save_freq: + # L.save_model(policy, identifier=env_step) + # print(f"Model has been checkpointed at step {env_step}") + # last_save_step = env_step - env_step % cfg.save_freq + + # if cfg.save_model and is_offline and _step >= cfg.offline_steps: + # # save the model after offline training + # L.save_model(policy, identifier="offline") + + step = _step + + # dataset_d4rl = D4RLExperienceReplay( + # dataset_id="maze2d-umaze-v1", + # split_trajs=False, + # batch_size=1, + # sampler=SamplerWithoutReplacement(drop_last=False), + # prefetch=4, + # direct_download=True, + # ) + + # dataset_openx = OpenXExperienceReplay( + # "cmu_stretch", + # batch_size=1, + # num_slices=1, + # #download="force", + # streaming=False, + # root="data", + # ) if __name__ == "__main__": diff --git a/lerobot/scripts/visualize.py b/lerobot/scripts/visualize.py deleted file mode 100644 index 64d24504..00000000 --- a/lerobot/scripts/visualize.py +++ /dev/null @@ -1,80 +0,0 @@ -import pickle -from pathlib import Path - -import imageio -import simxarm - -if __name__ == "__main__": - - task = "lift" - dataset_dir = Path(f"data/xarm_{task}_medium") - dataset_path = dataset_dir / f"buffer.pkl" - print(f"Using offline dataset '{dataset_path}'") - with open(dataset_path, "rb") as f: - dataset_dict = pickle.load(f) - - required_keys = [ - "observations", - "next_observations", - "actions", - "rewards", - "dones", - "masks", - ] - for k in required_keys: - if k not in dataset_dict and k[:-1] in dataset_dict: - dataset_dict[k] = dataset_dict.pop(k[:-1]) - - out_dir = Path("tmp/2023_01_26_xarm_lift_medium") - out_dir.mkdir(parents=True, exist_ok=True) - - frames = dataset_dict["observations"]["rgb"][:100] - frames = frames.transpose(0, 2, 3, 1) - imageio.mimsave(out_dir / "test.mp4", frames, fps=30) - - frames = [] - cfg = {} - - env = simxarm.make( - task=task, - obs_mode="all", - image_size=84, - action_repeat=cfg.get("action_repeat", 1), - frame_stack=cfg.get("frame_stack", 1), - seed=1, - ) - - obs = env.reset() - frame = env.render(mode="rgb_array", width=384, height=384) - frames.append(frame) - - # def is_first_obs(obs): - # nonlocal first_obs - # print(((dataset_dict["observations"]["state"][i]-obs["state"])**2).sum()) - # print(((dataset_dict["observations"]["rgb"][i]-obs["rgb"])**2).sum()) - - for i in range(25): - action = dataset_dict["actions"][i] - - print(f"#{i}") - # print(obs["state"]) - # print(dataset_dict["observations"]["state"][i]) - print(((dataset_dict["observations"]["state"][i] - obs["state"]) ** 2).sum()) - print(((dataset_dict["observations"]["rgb"][i] - obs["rgb"]) ** 2).sum()) - - obs, reward, done, info = env.step(action) - frame = env.render(mode="rgb_array", width=384, height=384) - frames.append(frame) - - print(reward) - print(dataset_dict["rewards"][i]) - - 
print(done) - print(dataset_dict["dones"][i]) - - if dataset_dict["dones"][i]: - obs = env.reset() - frame = env.render(mode="rgb_array", width=384, height=384) - frames.append(frame) - - # imageio.mimsave(out_dir / 'test_rollout.mp4', frames, fps=60) diff --git a/lerobot/scripts/visualize_dataset.py b/lerobot/scripts/visualize_dataset.py new file mode 100644 index 00000000..9d7d980d --- /dev/null +++ b/lerobot/scripts/visualize_dataset.py @@ -0,0 +1,59 @@ +import pickle +from pathlib import Path + +import imageio +import simxarm +import torch +from torchrl.data.replay_buffers import ( + SamplerWithoutReplacement, + SliceSampler, + SliceSamplerWithoutReplacement, +) + +from lerobot.common.datasets.simxarm import SimxarmExperienceReplay + + +def visualize_simxarm_dataset(dataset_id="xarm_lift_medium"): + sampler = SliceSamplerWithoutReplacement( + num_slices=1, + strict_length=False, + shuffle=False, + ) + + dataset = SimxarmExperienceReplay( + dataset_id, + # download="force", + download=True, + streaming=False, + root="data", + sampler=sampler, + ) + + NUM_EPISODES_TO_RENDER = 10 + MAX_NUM_STEPS = 50 + FIRST_FRAME = 0 + for _ in range(NUM_EPISODES_TO_RENDER): + episode = dataset.sample(MAX_NUM_STEPS) + + ep_idx = episode["episode"][FIRST_FRAME].item() + ep_frames = torch.cat( + [ + episode["observation"]["image"][FIRST_FRAME][None, ...], + episode["next", "observation"]["image"], + ], + dim=0, + ) + + video_dir = Path("tmp/2024_02_03_xarm_lift_medium") + video_dir.mkdir(parents=True, exist_ok=True) + # TODO(rcadene): make fps configurable + video_path = video_dir / f"eval_episode_{ep_idx}.mp4" + imageio.mimsave(video_path, ep_frames.numpy().transpose(0, 2, 3, 1), fps=15) + + # ran out of episodes + if dataset._sampler._sample_list.numel() == 0: + break + + +if __name__ == "__main__": + visualize_simxarm_dataset()
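
[Editor's note, not part of the patch] As a closing illustration of the frame_skip change made to SimxarmEnv._step earlier in this patch: the chosen action is repeated frame_skip times and the per-step rewards are summed, keeping only the final observation/done/info (the patch does not break out of the loop early). The standalone helper name below is an assumption for illustration only.

    # Sketch (assumed helper) of the frame_skip reward accumulation in SimxarmEnv._step.
    def step_with_frame_skip(env, action, frame_skip: int = 2):
        sum_reward = 0.0
        for _ in range(frame_skip):
            raw_obs, reward, done, info = env.step(action)  # old gym API: 4-tuple step
            sum_reward += reward
        return raw_obs, sum_reward, done, info
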