From f95ecd66fcaaed8513b5b10f97304ff21aab275f Mon Sep 17 00:00:00 2001 From: Remi Cadene Date: Wed, 6 Mar 2024 10:14:03 +0000 Subject: [PATCH] Improve visualize_dataset, Improve AbstractReplayBuffer, Small improvements --- lerobot/common/datasets/abstract.py | 149 +++++++++++---------------- lerobot/common/datasets/factory.py | 53 ++++++---- lerobot/common/datasets/pusht.py | 9 +- lerobot/common/logger.py | 4 + lerobot/scripts/eval.py | 18 +++- lerobot/scripts/train.py | 8 +- lerobot/scripts/visualize_dataset.py | 104 ++++++++++++++----- 7 files changed, 195 insertions(+), 150 deletions(-) diff --git a/lerobot/common/datasets/abstract.py b/lerobot/common/datasets/abstract.py index 4f56d96f..af30cf8c 100644 --- a/lerobot/common/datasets/abstract.py +++ b/lerobot/common/datasets/abstract.py @@ -1,6 +1,5 @@ import abc import logging -import math from pathlib import Path from typing import Callable @@ -50,6 +49,22 @@ class AbstractExperienceReplay(TensorDictReplayBuffer): transform=transform, ) + @property + def stats_patterns(self) -> dict: + return { + ("observation", "state"): "b c -> 1 c", + ("observation", "image"): "b c h w -> 1 c 1 1", + ("action"): "b c -> 1 c", + } + + @property + def image_keys(self) -> list: + return [("observation", "image")] + + @property + def num_cameras(self) -> int: + return len(self.image_keys) + @property def num_samples(self) -> int: return len(self) @@ -67,7 +82,7 @@ class AbstractExperienceReplay(TensorDictReplayBuffer): stats = torch.load(stats_path) else: logging.info(f"compute_stats and save to {stats_path}") - stats = self._compute_stats(self._storage._storage, num_batch, batch_size) + stats = self._compute_stats(num_batch, batch_size) torch.save(stats, stats_path) return stats @@ -85,101 +100,59 @@ class AbstractExperienceReplay(TensorDictReplayBuffer): def _is_downloaded(self) -> bool: return self.data_dir.is_dir() - def _compute_stats(self, storage, num_batch=100, batch_size=32): + def _compute_stats(self, num_batch=100, batch_size=32): rb = TensorDictReplayBuffer( - storage=storage, + storage=self._storage, batch_size=batch_size, prefetch=True, ) - batch = rb.sample() - image_channels = batch["observation", "image"].shape[1] - image_mean = torch.zeros(image_channels) - image_std = torch.zeros(image_channels) - image_max = torch.tensor([-math.inf] * image_channels) - image_min = torch.tensor([math.inf] * image_channels) - - state_channels = batch["observation", "state"].shape[1] - state_mean = torch.zeros(state_channels) - state_std = torch.zeros(state_channels) - state_max = torch.tensor([-math.inf] * state_channels) - state_min = torch.tensor([math.inf] * state_channels) - - action_channels = batch["action"].shape[1] - action_mean = torch.zeros(action_channels) - action_std = torch.zeros(action_channels) - action_max = torch.tensor([-math.inf] * action_channels) - action_min = torch.tensor([math.inf] * action_channels) + mean, std, max, min = {}, {}, {}, {} + # compute mean, min, max for _ in tqdm.tqdm(range(num_batch)): - image_mean += einops.reduce(batch["observation", "image"], "b c h w -> c", "mean") - state_mean += einops.reduce(batch["observation", "state"], "b c -> c", "mean") - action_mean += einops.reduce(batch["action"], "b c -> c", "mean") - - b_image_max = einops.reduce(batch["observation", "image"], "b c h w -> c", "max") - b_image_min = einops.reduce(batch["observation", "image"], "b c h w -> c", "min") - b_state_max = einops.reduce(batch["observation", "state"], "b c -> c", "max") - b_state_min = einops.reduce(batch["observation", "state"], "b c -> c", "min") - b_action_max = einops.reduce(batch["action"], "b c -> c", "max") - b_action_min = einops.reduce(batch["action"], "b c -> c", "min") - image_max = torch.maximum(image_max, b_image_max) - image_min = torch.maximum(image_min, b_image_min) - state_max = torch.maximum(state_max, b_state_max) - state_min = torch.maximum(state_min, b_state_min) - action_max = torch.maximum(action_max, b_action_max) - action_min = torch.maximum(action_min, b_action_min) - batch = rb.sample() - - image_mean /= num_batch - state_mean /= num_batch - action_mean /= num_batch - - for i in tqdm.tqdm(range(num_batch)): - b_image_mean = einops.reduce(batch["observation", "image"], "b c h w -> c", "mean") - b_state_mean = einops.reduce(batch["observation", "state"], "b c -> c", "mean") - b_action_mean = einops.reduce(batch["action"], "b c -> c", "mean") - image_std += (b_image_mean - image_mean) ** 2 - state_std += (b_state_mean - state_mean) ** 2 - action_std += (b_action_mean - action_mean) ** 2 - - b_image_max = einops.reduce(batch["observation", "image"], "b c h w -> c", "max") - b_image_min = einops.reduce(batch["observation", "image"], "b c h w -> c", "min") - b_state_max = einops.reduce(batch["observation", "state"], "b c -> c", "max") - b_state_min = einops.reduce(batch["observation", "state"], "b c -> c", "min") - b_action_max = einops.reduce(batch["action"], "b c -> c", "max") - b_action_min = einops.reduce(batch["action"], "b c -> c", "min") - image_max = torch.maximum(image_max, b_image_max) - image_min = torch.maximum(image_min, b_image_min) - state_max = torch.maximum(state_max, b_state_max) - state_min = torch.maximum(state_min, b_state_min) - action_max = torch.maximum(action_max, b_action_max) - action_min = torch.maximum(action_min, b_action_min) - - if i < num_batch - 1: + for key, pattern in self.stats_patterns.items(): + batch[key] = batch[key].float() + if key not in mean: + # first batch initialize mean, min, max + mean[key] = einops.reduce(batch[key], pattern, "mean") + max[key] = einops.reduce(batch[key], pattern, "max") + min[key] = einops.reduce(batch[key], pattern, "min") + else: + mean[key] += einops.reduce(batch[key], pattern, "mean") + max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max")) + min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min")) batch = rb.sample() - image_std = torch.sqrt(image_std / num_batch) - state_std = torch.sqrt(state_std / num_batch) - action_std = torch.sqrt(action_std / num_batch) + for key in self.stats_patterns: + mean[key] /= num_batch - stats = TensorDict( - { - ("observation", "image", "mean"): image_mean[None, :, None, None], - ("observation", "image", "std"): image_std[None, :, None, None], - ("observation", "image", "max"): image_max[None, :, None, None], - ("observation", "image", "min"): image_min[None, :, None, None], - ("observation", "state", "mean"): state_mean[None, :], - ("observation", "state", "std"): state_std[None, :], - ("observation", "state", "max"): state_max[None, :], - ("observation", "state", "min"): state_min[None, :], - ("action", "mean"): action_mean[None, :], - ("action", "std"): action_std[None, :], - ("action", "max"): action_max[None, :], - ("action", "min"): action_min[None, :], - }, - batch_size=[], - ) - stats["next", "observation", "image"] = stats["observation", "image"] - stats["next", "observation", "state"] = stats["observation", "state"] + # compute std, min, max + for _ in tqdm.tqdm(range(num_batch)): + batch = rb.sample() + for key, pattern in self.stats_patterns.items(): + batch[key] = batch[key].float() + batch_mean = einops.reduce(batch[key], pattern, "mean") + if key not in std: + # first batch initialize std + std[key] = (batch_mean - mean[key]) ** 2 + else: + std[key] += (batch_mean - mean[key]) ** 2 + max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max")) + min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min")) + + for key in self.stats_patterns: + std[key] = torch.sqrt(std[key] / num_batch) + + stats = TensorDict({}, batch_size=[]) + for key in self.stats_patterns: + stats[(*key, "mean")] = mean[key] + stats[(*key, "std")] = std[key] + stats[(*key, "max")] = max[key] + stats[(*key, "min")] = min[key] + + if key[0] == "observation": + # use same stats for the next observations + stats[("next", *key)] = stats[key] return stats diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py index feccdf21..c72da5e5 100644 --- a/lerobot/common/datasets/factory.py +++ b/lerobot/common/datasets/factory.py @@ -10,7 +10,9 @@ from lerobot.common.envs.transforms import NormalizeTransform DATA_DIR = Path(os.environ.get("DATA_DIR", "data")) -def make_offline_buffer(cfg, sampler=None): +def make_offline_buffer( + cfg, overwrite_sampler=None, normalize=True, overwrite_batch_size=None, overwrite_prefetch=None +): if cfg.policy.balanced_sampling: assert cfg.online_steps > 0 batch_size = None @@ -23,9 +25,13 @@ def make_offline_buffer(cfg, sampler=None): pin_memory = cfg.device == "cuda" prefetch = cfg.prefetch - overwrite_sampler = sampler is not None + if overwrite_batch_size is not None: + batch_size = overwrite_batch_size - if not overwrite_sampler: + if overwrite_prefetch is not None: + prefetch = overwrite_prefetch + + if overwrite_sampler is None: # TODO(rcadene): move batch_size outside num_traj_per_batch = cfg.policy.batch_size # // cfg.horizon # TODO(rcadene): Sampler outputs a batch_size <= cfg.batch_size. @@ -46,6 +52,8 @@ def make_offline_buffer(cfg, sampler=None): num_slices=num_traj_per_batch, strict_length=False, ) + else: + sampler = overwrite_sampler if cfg.env.name == "simxarm": from lerobot.common.datasets.simxarm import SimxarmExperienceReplay @@ -70,30 +78,31 @@ def make_offline_buffer(cfg, sampler=None): prefetch=prefetch if isinstance(prefetch, int) else None, ) - # TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max, min_max_from_spec - stats = offline_buffer.compute_or_load_stats() - in_keys = [("observation", "state"), ("action")] + if normalize: + # TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max, min_max_from_spec + stats = offline_buffer.compute_or_load_stats() + in_keys = [("observation", "state"), ("action")] - if cfg.policy == "tdmpc": - # TODO(rcadene): imagenet normalization is applied inside diffusion policy, but no normalization inside tdmpc - in_keys.append(("observation", "image")) - # since we use next observations in tdmpc - in_keys.append(("next", "observation", "image")) - in_keys.append(("next", "observation", "state")) + if cfg.policy == "tdmpc": + for key in offline_buffer.image_keys: + # TODO(rcadene): imagenet normalization is applied inside diffusion policy, but no normalization inside tdmpc + in_keys.append(key) + # since we use next observations in tdmpc + in_keys.append(("next", *key)) + in_keys.append(("next", "observation", "state")) - if cfg.policy == "diffusion" and cfg.env.name == "pusht": - # TODO(rcadene): we overwrite stats to have the same as pretrained model, but we should remove this - stats["observation", "state", "min"] = torch.tensor([13.456424, 32.938293], dtype=torch.float32) - stats["observation", "state", "max"] = torch.tensor([496.14618, 510.9579], dtype=torch.float32) - stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32) - stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32) + if cfg.policy == "diffusion" and cfg.env.name == "pusht": + # TODO(rcadene): we overwrite stats to have the same as pretrained model, but we should remove this + stats["observation", "state", "min"] = torch.tensor([13.456424, 32.938293], dtype=torch.float32) + stats["observation", "state", "max"] = torch.tensor([496.14618, 510.9579], dtype=torch.float32) + stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32) + stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32) - transform = NormalizeTransform(stats, in_keys, mode="min_max") - offline_buffer.set_transform(transform) + transform = NormalizeTransform(stats, in_keys, mode="min_max") + offline_buffer.set_transform(transform) if not overwrite_sampler: - num_steps = len(offline_buffer) - index = torch.arange(0, num_steps, 1) + index = torch.arange(0, offline_buffer.num_frames, 1) sampler.extend(index) return offline_buffer diff --git a/lerobot/common/datasets/pusht.py b/lerobot/common/datasets/pusht.py index 77334851..b93b519b 100644 --- a/lerobot/common/datasets/pusht.py +++ b/lerobot/common/datasets/pusht.py @@ -183,8 +183,7 @@ class PushtExperienceReplay(AbstractExperienceReplay): # last step of demonstration is considered done done[-1] = True - print("before " + """episode = TensorDict(""") - episode = TensorDict( + ep_td = TensorDict( { ("observation", "image"): image[:-1], ("observation", "state"): agent_pos[:-1], @@ -203,11 +202,11 @@ class PushtExperienceReplay(AbstractExperienceReplay): if episode_id == 0: # hack to initialize tensordict data structure to store episodes - td_data = episode[0].expand(total_frames).memmap_like(self.data_dir) + td_data = ep_td[0].expand(total_frames).memmap_like(self.data_dir) - td_data[idxtd : idxtd + len(episode)] = episode + td_data[idxtd : idxtd + len(ep_td)] = ep_td idx0 = idx1 - idxtd = idxtd + len(episode) + idxtd = idxtd + len(ep_td) return TensorStorage(td_data.lock_()) diff --git a/lerobot/common/logger.py b/lerobot/common/logger.py index 54325bd4..3d98d726 100644 --- a/lerobot/common/logger.py +++ b/lerobot/common/logger.py @@ -6,6 +6,10 @@ from omegaconf import OmegaConf from termcolor import colored +def log_output_dir(out_dir): + logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {out_dir}") + + def cfg_to_group(cfg, return_list=False): """Return a wandb-safe group name for logging. Optionally returns group name as list.""" # lst = [cfg.task, cfg.modality, re.sub("[^0-9a-zA-Z]+", "-", cfg.exp_name)] diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index abe4645a..c9338dca 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -9,13 +9,13 @@ import numpy as np import torch import tqdm from tensordict.nn import TensorDictModule -from termcolor import colored from torchrl.envs import EnvBase from lerobot.common.datasets.factory import make_offline_buffer from lerobot.common.envs.factory import make_env +from lerobot.common.logger import log_output_dir from lerobot.common.policies.factory import make_policy -from lerobot.common.utils import set_seed +from lerobot.common.utils import init_logging, set_seed def write_video(video_path, stacked_frames, fps): @@ -109,10 +109,18 @@ def eval(cfg: dict, out_dir=None): if out_dir is None: raise NotImplementedError() - assert torch.cuda.is_available() + init_logging() + + if cfg.device == "cuda": + assert torch.cuda.is_available() + else: + logging.warning("Using CPU, this will be slow.") + torch.backends.cudnn.benchmark = True + torch.backends.cuda.matmul.allow_tf32 = True set_seed(cfg.seed) - print(colored("Log dir:", "yellow", attrs=["bold"]), out_dir) + + log_output_dir(out_dir) logging.info("make_offline_buffer") offline_buffer = make_offline_buffer(cfg) @@ -142,6 +150,8 @@ def eval(cfg: dict, out_dir=None): ) print(metrics) + logging.info("End of eval") + if __name__ == "__main__": eval_cli() diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index c169b49b..be3bef8b 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -4,13 +4,12 @@ import hydra import numpy as np import torch from tensordict.nn import TensorDictModule -from termcolor import colored from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer from torchrl.data.replay_buffers import PrioritizedSliceSampler from lerobot.common.datasets.factory import make_offline_buffer from lerobot.common.envs.factory import make_env -from lerobot.common.logger import Logger +from lerobot.common.logger import Logger, log_output_dir from lerobot.common.policies.factory import make_policy from lerobot.common.utils import format_big_number, init_logging, set_seed from lerobot.scripts.eval import eval_policy @@ -164,7 +163,7 @@ def train(cfg: dict, out_dir=None, job_name=None): # log metrics to terminal and wandb logger = Logger(out_dir, job_name, cfg) - logging.info(colored("Work dir:", "yellow", attrs=["bold"]) + f" {out_dir}") + log_output_dir(out_dir) logging.info(f"{cfg.env.task=}") logging.info(f"{cfg.offline_steps=} ({format_big_number(cfg.offline_steps)})") logging.info(f"{cfg.online_steps=}") @@ -212,7 +211,6 @@ def train(cfg: dict, out_dir=None, job_name=None): for env_step in range(cfg.online_steps): if env_step == 0: logging.info("Start online training by interacting with environment") - # TODO: use SyncDataCollector for that? # TODO: add configurable number of rollout? (default=1) with torch.no_grad(): rollout = env.rollout( @@ -268,6 +266,8 @@ def train(cfg: dict, out_dir=None, job_name=None): step += 1 online_step += 1 + logging.info("End of training") + if __name__ == "__main__": train_cli() diff --git a/lerobot/scripts/visualize_dataset.py b/lerobot/scripts/visualize_dataset.py index 5aa0a278..1bd63f6e 100644 --- a/lerobot/scripts/visualize_dataset.py +++ b/lerobot/scripts/visualize_dataset.py @@ -1,13 +1,20 @@ +import logging +import threading from pathlib import Path +import einops import hydra import imageio import torch -from torchrl.data.replay_buffers import SliceSamplerWithoutReplacement +from torchrl.data.replay_buffers import ( + SamplerWithoutReplacement, +) from lerobot.common.datasets.factory import make_offline_buffer +from lerobot.common.logger import log_output_dir +from lerobot.common.utils import init_logging -NUM_EPISODES_TO_RENDER = 10 +NUM_EPISODES_TO_RENDER = 50 MAX_NUM_STEPS = 1000 FIRST_FRAME = 0 @@ -17,45 +24,88 @@ def visualize_dataset_cli(cfg: dict): visualize_dataset(cfg, out_dir=hydra.core.hydra_config.HydraConfig.get().runtime.output_dir) +def cat_and_write_video(video_path, frames, fps): + frames = torch.cat(frames) + assert frames.dtype == torch.uint8 + frames = einops.rearrange(frames, "b c h w -> b h w c").numpy() + imageio.mimsave(video_path, frames, fps=fps) + + def visualize_dataset(cfg: dict, out_dir=None): if out_dir is None: raise NotImplementedError() - sampler = SliceSamplerWithoutReplacement( - num_slices=1, - strict_length=False, + init_logging() + log_output_dir(out_dir) + + # we expect frames of each episode to be stored next to each others sequentially + sampler = SamplerWithoutReplacement( shuffle=False, ) - offline_buffer = make_offline_buffer(cfg, sampler) + logging.info("make_offline_buffer") + offline_buffer = make_offline_buffer( + cfg, overwrite_sampler=sampler, normalize=False, overwrite_batch_size=1, overwrite_prefetch=12 + ) - for _ in range(NUM_EPISODES_TO_RENDER): - episode = offline_buffer.sample(MAX_NUM_STEPS) + logging.info("Start rendering episodes from offline buffer") - ep_idx = episode["episode"][FIRST_FRAME].item() - ep_frames = torch.cat( - [ - episode["observation"]["image"][FIRST_FRAME][None, ...], - episode["next", "observation"]["image"], - ], - dim=0, - ) + threads = [] + frames = {} + current_ep_idx = 0 + logging.info(f"Visualizing episode {current_ep_idx}") + for _ in range(MAX_NUM_STEPS * NUM_EPISODES_TO_RENDER): + # TODO(rcadene): make it work with bsize > 1 + ep_td = offline_buffer.sample(1) + ep_idx = ep_td["episode"][FIRST_FRAME].item() - video_dir = Path(out_dir) / "visualize_dataset" - video_dir.mkdir(parents=True, exist_ok=True) - # TODO(rcadene): make fps configurable - video_path = video_dir / f"episode_{ep_idx}.mp4" + # TODO(rcaene): modify offline_buffer._sampler._sample_list or sampler to randomly sample an episode, but sequentially sample frames + no_more_frames = offline_buffer._sampler._sample_list.numel() == 0 + new_episode = ep_idx != current_ep_idx - assert ep_frames.min().item() >= 0 - assert ep_frames.max().item() > 1, "Not mendatory, but sanity check" - assert ep_frames.max().item() <= 255 - ep_frames = ep_frames.type(torch.uint8) - imageio.mimsave(video_path, ep_frames.numpy().transpose(0, 2, 3, 1), fps=cfg.fps) + if new_episode: + logging.info(f"Visualizing episode {current_ep_idx}") - # ran out of episodes - if offline_buffer._sampler._sample_list.numel() == 0: + for im_key in offline_buffer.image_keys: + if new_episode or no_more_frames: + # append last observed frames (the ones after last action taken) + frames[im_key].append(ep_td[("next", *im_key)]) + + video_dir = Path(out_dir) / "visualize_dataset" + video_dir.mkdir(parents=True, exist_ok=True) + + if len(offline_buffer.image_keys) > 1: + camera = im_key[-1] + video_path = video_dir / f"episode_{current_ep_idx}_{camera}.mp4" + else: + video_path = video_dir / f"episode_{current_ep_idx}.mp4" + + thread = threading.Thread( + target=cat_and_write_video, + args=(str(video_path), frames[im_key], cfg.fps), + ) + thread.start() + threads.append(thread) + + current_ep_idx = ep_idx + + # reset list of frames + del frames[im_key] + + # append current cameras images to list of frames + if im_key not in frames: + frames[im_key] = [] + frames[im_key].append(ep_td[im_key]) + + if no_more_frames: + logging.info("Ran out of frames") break + for thread in threads: + thread.join() + + logging.info("End of visualize_dataset") + if __name__ == "__main__": visualize_dataset_cli()