diff --git a/examples/1_visualize_dataset.py b/examples/1_visualize_dataset.py
index f52ab76a..a9406d1e 100644
--- a/examples/1_visualize_dataset.py
+++ b/examples/1_visualize_dataset.py
@@ -1,6 +1,5 @@
 import os
-
-from torchrl.data.replay_buffers import SamplerWithoutReplacement
+from pathlib import Path

 import lerobot
 from lerobot.common.datasets.aloha import AlohaDataset
@@ -9,16 +8,13 @@ from lerobot.scripts.visualize_dataset import render_dataset
 print(lerobot.available_datasets)
 # >>> ['aloha_sim_insertion_human', 'aloha_sim_insertion_scripted', 'aloha_sim_transfer_cube_human', 'aloha_sim_transfer_cube_scripted', 'pusht', 'xarm_lift_medium']

-# we use this sampler to sample 1 frame after the other
-sampler = SamplerWithoutReplacement(shuffle=False)
-
-dataset = AlohaDataset("aloha_sim_transfer_cube_human", sampler=sampler, root=os.environ.get("DATA_DIR"))
+# TODO(rcadene): remove DATA_DIR
+dataset = AlohaDataset("aloha_sim_transfer_cube_human", root=Path(os.environ.get("DATA_DIR")))

 video_paths = render_dataset(
     dataset,
     out_dir="outputs/visualize_dataset/example",
-    max_num_samples=300,
-    fps=50,
+    max_num_episodes=1,
 )
 print(video_paths)
 # ['outputs/visualize_dataset/example/episode_0.mp4']
diff --git a/examples/3_train_policy.py b/examples/3_train_policy.py
index 6e01a5d5..238f953d 100644
--- a/examples/3_train_policy.py
+++ b/examples/3_train_policy.py
@@ -9,9 +9,8 @@ from pathlib import Path

 import torch
 from omegaconf import OmegaConf
-from tqdm import trange

-from lerobot.common.datasets.factory import make_offline_buffer
+from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.policies.diffusion.policy import DiffusionPolicy
 from lerobot.common.utils import init_hydra_config

@@ -37,19 +36,33 @@ policy = DiffusionPolicy(
     cfg_obs_encoder=cfg.obs_encoder,
     cfg_optimizer=cfg.optimizer,
     cfg_ema=cfg.ema,
-    n_action_steps=cfg.n_action_steps,
     **cfg.policy,
 )
 policy.train()

-offline_buffer = make_offline_buffer(cfg)
+dataset = make_dataset(cfg)
+
+# create dataloader for offline training
+dataloader = torch.utils.data.DataLoader(
+    dataset,
+    num_workers=4,
+    batch_size=cfg.policy.batch_size,
+    shuffle=True,
+    pin_memory=cfg.device != "cpu",
+    drop_last=True,
+)
+
+for step, batch in enumerate(dataloader):
+    info = policy(batch, step)
+
+    if step % cfg.log_freq == 0:
+        num_samples = (step + 1) * cfg.policy.batch_size
+        loss = info["loss"]
+        update_s = info["update_s"]
+        print(f"step:{step} samples:{num_samples} loss:{loss:.3f} update_time:{update_s:.3f}(seconds)")

-for offline_step in trange(cfg.offline_steps):
-    train_info = policy.update(offline_buffer, offline_step)
-    if offline_step % cfg.log_freq == 0:
-        print(train_info)

 # Save the policy, configuration, and normalization stats for later use.
 policy.save(output_directory / "model.pt")
 OmegaConf.save(cfg, output_directory / "config.yaml")
-torch.save(offline_buffer.transform[-1].stats, output_directory / "stats.pth")
+torch.save(dataset.transform.transforms[-1].stats, output_directory / "stats.pth")
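Aside (illustration, not part of the patch): the new training loop in `examples/3_train_policy.py` drives the policy from a standard `torch.utils.data.DataLoader` instead of sampling from a torchrl replay buffer. Below is a minimal, self-contained sketch of that pattern with a toy `TensorDataset` standing in for the lerobot dataset; all names and shapes here are illustrative, not lerobot APIs.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# toy stand-in for make_dataset(cfg): 256 (observation, action) pairs
dataset = TensorDataset(torch.randn(256, 4), torch.randn(256, 2))

dataloader = DataLoader(
    dataset,
    num_workers=0,      # the patch uses 4 background worker processes
    batch_size=32,
    shuffle=True,       # random batches each epoch, replacing the replay-buffer sampler
    pin_memory=False,   # the patch enables this off-CPU to speed host-to-device copies
    drop_last=True,     # keep batch shapes constant by dropping the final partial batch
)

for step, (obs, action) in enumerate(dataloader):
    # a real policy would compute its loss and take an optimizer step here
    loss = torch.nn.functional.mse_loss(obs.mean(dim=1), action.mean(dim=1))
    if step % 2 == 0:
        print(f"step:{step} samples:{(step + 1) * 32} loss:{loss.item():.3f}")
```

`drop_last=True` mirrors the patch's choice: a final short batch would otherwise change tensor shapes mid-training, which some policies and logging code do not handle gracefully.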
diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py
index 0dab5d4b..63507cce 100644
--- a/lerobot/common/datasets/factory.py
+++ b/lerobot/common/datasets/factory.py
@@ -40,7 +40,8 @@ def make_dataset(
     if normalize:
         # TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max,
         # min_max_from_spec
-        # stats = dataset.compute_or_load_stats() if stats_path is None else torch.load(stats_path)
+        # TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std
+        normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"

         if cfg.policy.name == "diffusion" and cfg.env.name == "pusht":
             stats = {}
@@ -51,7 +52,7 @@ def make_dataset(
             stats["action"] = {}
             stats["action"]["min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
             stats["action"]["max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
-        else:
+        elif stats_path is None:
             # instantiate a one frame dataset with light transform
             stats_dataset = clsfunc(
                 dataset_id=cfg.dataset_id,
@@ -59,9 +60,8 @@ def make_dataset(
                 transform=Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
             )
             stats = compute_or_load_stats(stats_dataset)
-
-    # TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std
-    normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
+        else:
+            stats = torch.load(stats_path)

     transforms = v2.Compose(
         [
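Aside (illustration, not part of the patch): after this change, `make_dataset` resolves normalization stats through a three-way precedence: hardcoded stats for the pretrained diffusion/pusht combination, then stats computed from the dataset when no `stats_path` is given, then stats loaded from `stats_path`. A hedged sketch of that control flow in isolation; `compute_stats_fn` is a hypothetical stand-in for `compute_or_load_stats`.

```python
import torch

def resolve_stats(policy_name, env_name, compute_stats_fn, stats_path=None):
    if policy_name == "diffusion" and env_name == "pusht":
        # hardcoded stats matching the pretrained model win over everything else
        return {
            "action": {
                "min": torch.tensor([12.0, 25.0], dtype=torch.float32),
                "max": torch.tensor([511.0, 511.0], dtype=torch.float32),
            }
        }
    if stats_path is None:
        # no precomputed file provided: compute (or load cached) stats from the data
        return compute_stats_fn()
    # an explicit stats file takes precedence over recomputation
    return torch.load(stats_path)
```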
diff --git a/lerobot/scripts/visualize_dataset.py b/lerobot/scripts/visualize_dataset.py
index 93315e90..ed95b39a 100644
--- a/lerobot/scripts/visualize_dataset.py
+++ b/lerobot/scripts/visualize_dataset.py
@@ -6,9 +6,6 @@ import einops
 import hydra
 import imageio
 import torch
-from torchrl.data.replay_buffers import (
-    SamplerWithoutReplacement,
-)

 from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.logger import log_output_dir
@@ -39,19 +36,11 @@ def visualize_dataset(cfg: dict, out_dir=None):
     init_logging()
     log_output_dir(out_dir)

-    # we expect frames of each episode to be stored next to each other sequentially
-    sampler = SamplerWithoutReplacement(
-        shuffle=False,
-    )
-
     logging.info("make_dataset")
     dataset = make_dataset(
         cfg,
-        overwrite_sampler=sampler,
         # remove all transformations such as rescale images from [0,255] to [0,1] or normalization
         normalize=False,
-        overwrite_batch_size=1,
-        overwrite_prefetch=12,
     )

     logging.info("Start rendering episodes from offline buffer")
@@ -60,64 +49,49 @@ def visualize_dataset(cfg: dict, out_dir=None):
         logging.info(video_path)


-def render_dataset(dataset, out_dir, max_num_samples, fps):
+def render_dataset(dataset, out_dir, max_num_episodes):
     out_dir = Path(out_dir)
     video_paths = []
     threads = []
-    frames = {}
-    current_ep_idx = 0
-    logging.info(f"Visualizing episode {current_ep_idx}")
-    for i in range(max_num_samples):
-        # TODO(rcadene): make it work with bsize > 1
-        ep_td = dataset.sample(1)
-        ep_idx = ep_td["episode"][FIRST_FRAME].item()
-        # TODO(rcadene): modify dataset._sampler._sample_list or sampler to randomly sample an episode, but sequentially sample frames
-        num_frames_left = dataset._sampler._sample_list.numel()
-        episode_is_done = ep_idx != current_ep_idx
+    dataloader = torch.utils.data.DataLoader(
+        dataset,
+        num_workers=4,
+        batch_size=1,
+        shuffle=False,
+    )
+    dl_iter = iter(dataloader)

-        if episode_is_done:
-            logging.info(f"Rendering episode {current_ep_idx}")
+    num_episodes = len(dataset.data_ids_per_episode)
+    for ep_id in range(min(max_num_episodes, num_episodes)):
+        logging.info(f"Rendering episode {ep_id}")

-        for im_key in dataset.image_keys:
-            if not episode_is_done and num_frames_left > 0 and i < (max_num_samples - 1):
+        frames = {}
+        for _ in dataset.data_ids_per_episode[ep_id]:
+            item = next(dl_iter)
+
+            for im_key in dataset.image_keys:
                 # when first frame of episode, initialize frames dict
                 if im_key not in frames:
                     frames[im_key] = []

                 # add current frame to list of frames to render
-                frames[im_key].append(ep_td[im_key])
+                frames[im_key].append(item[im_key])
+
+        out_dir.mkdir(parents=True, exist_ok=True)
+        for im_key in dataset.image_keys:
+            if len(dataset.image_keys) > 1:
+                im_name = im_key.replace("observation.images.", "")
+                video_path = out_dir / f"episode_{ep_id}_{im_name}.mp4"
             else:
-                # When episode has no more frames in its list of observations,
-                # one frame still remains. It is the result of the last action taken.
-                # It is stored in `"next"`, so we add it to the list of frames to render.
-                frames[im_key].append(ep_td["next"][im_key])
+                video_path = out_dir / f"episode_{ep_id}.mp4"
+            video_paths.append(video_path)

-                out_dir.mkdir(parents=True, exist_ok=True)
-                if len(dataset.image_keys) > 1:
-                    camera = im_key[-1]
-                    video_path = out_dir / f"episode_{current_ep_idx}_{camera}.mp4"
-                else:
-                    video_path = out_dir / f"episode_{current_ep_idx}.mp4"
-                video_paths.append(str(video_path))
-
-                thread = threading.Thread(
-                    target=cat_and_write_video,
-                    args=(str(video_path), frames[im_key], fps),
-                )
-                thread.start()
-                threads.append(thread)
-
-                current_ep_idx = ep_idx
-
-                # reset list of frames
-                del frames[im_key]
-
-        if num_frames_left == 0:
-            logging.info("Ran out of frames")
-            break
-
-        if current_ep_idx == NUM_EPISODES_TO_RENDER:
-            break
+            thread = threading.Thread(
+                target=cat_and_write_video,
+                args=(str(video_path), frames[im_key], dataset.fps),
+            )
+            thread.start()
+            threads.append(thread)

     for thread in threads:
         thread.join()
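Aside (illustration, not part of the patch): `render_dataset` now collects the frames of one episode at a time from a `batch_size=1` DataLoader and hands them to `cat_and_write_video` on a background thread, at `dataset.fps` instead of a hardcoded `fps=50`. A hedged sketch of what such a helper might do, assuming (1, C, H, W) uint8 frames in [0, 255] as produced with `normalize=False`; the actual lerobot implementation may differ.

```python
import einops
import imageio
import torch

def write_episode_video(video_path, frames, fps):
    # frames: list of (1, C, H, W) uint8 tensors, one per timestep
    video = torch.cat(frames)                              # (T, C, H, W)
    video = einops.rearrange(video, "t c h w -> t h w c")  # imageio expects channel-last
    imageio.mimsave(video_path, video.numpy(), fps=fps)
```

Running the writes on threads, as the script does, lets one episode encode while the frames of the next are being fetched from the DataLoader.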