From 5af00d0c1ee0aa3d9a90e6afe646474073ff5065 Mon Sep 17 00:00:00 2001 From: Cadene Date: Fri, 5 Apr 2024 09:31:39 +0000 Subject: [PATCH] fix train.py, stats, eval.py (training is running) --- lerobot/common/datasets/aloha.py | 16 +++++++------ lerobot/common/datasets/pusht.py | 16 +++++++------ lerobot/common/datasets/simxarm.py | 16 +++++++------ lerobot/common/datasets/utils.py | 15 ++++++++---- .../diffusion/diffusion_unet_image_policy.py | 7 +++--- lerobot/common/policies/diffusion/policy.py | 3 ++- lerobot/common/transforms.py | 16 ++++++------- lerobot/scripts/eval.py | 20 +++++++--------- lerobot/scripts/train.py | 9 +++---- tests/scripts/mock_dataset.py | 24 +++++++++---------- tests/test_datasets.py | 6 +---- 11 files changed, 76 insertions(+), 72 deletions(-) diff --git a/lerobot/common/datasets/aloha.py b/lerobot/common/datasets/aloha.py index 2744f595..102de08e 100644 --- a/lerobot/common/datasets/aloha.py +++ b/lerobot/common/datasets/aloha.py @@ -91,15 +91,17 @@ class AlohaDataset(torch.utils.data.Dataset): self.transform = transform self.delta_timestamps = delta_timestamps - data_dir = self.root / f"{self.dataset_id}" - if (data_dir / "data_dict.pth").exists() and (data_dir / "data_ids_per_episode.pth").exists(): - self.data_dict = torch.load(data_dir / "data_dict.pth") - self.data_ids_per_episode = torch.load(data_dir / "data_ids_per_episode.pth") + self.data_dir = self.root / f"{self.dataset_id}" + if (self.data_dir / "data_dict.pth").exists() and ( + self.data_dir / "data_ids_per_episode.pth" + ).exists(): + self.data_dict = torch.load(self.data_dir / "data_dict.pth") + self.data_ids_per_episode = torch.load(self.data_dir / "data_ids_per_episode.pth") else: self._download_and_preproc_obsolete() - data_dir.mkdir(parents=True, exist_ok=True) - torch.save(self.data_dict, data_dir / "data_dict.pth") - torch.save(self.data_ids_per_episode, data_dir / "data_ids_per_episode.pth") + self.data_dir.mkdir(parents=True, exist_ok=True) + torch.save(self.data_dict, self.data_dir / "data_dict.pth") + torch.save(self.data_ids_per_episode, self.data_dir / "data_ids_per_episode.pth") @property def num_samples(self) -> int: diff --git a/lerobot/common/datasets/pusht.py b/lerobot/common/datasets/pusht.py index 3de70b1f..9b73b101 100644 --- a/lerobot/common/datasets/pusht.py +++ b/lerobot/common/datasets/pusht.py @@ -105,15 +105,17 @@ class PushtDataset(torch.utils.data.Dataset): self.transform = transform self.delta_timestamps = delta_timestamps - data_dir = self.root / f"{self.dataset_id}" - if (data_dir / "data_dict.pth").exists() and (data_dir / "data_ids_per_episode.pth").exists(): - self.data_dict = torch.load(data_dir / "data_dict.pth") - self.data_ids_per_episode = torch.load(data_dir / "data_ids_per_episode.pth") + self.data_dir = self.root / f"{self.dataset_id}" + if (self.data_dir / "data_dict.pth").exists() and ( + self.data_dir / "data_ids_per_episode.pth" + ).exists(): + self.data_dict = torch.load(self.data_dir / "data_dict.pth") + self.data_ids_per_episode = torch.load(self.data_dir / "data_ids_per_episode.pth") else: self._download_and_preproc_obsolete() - data_dir.mkdir(parents=True, exist_ok=True) - torch.save(self.data_dict, data_dir / "data_dict.pth") - torch.save(self.data_ids_per_episode, data_dir / "data_ids_per_episode.pth") + self.data_dir.mkdir(parents=True, exist_ok=True) + torch.save(self.data_dict, self.data_dir / "data_dict.pth") + torch.save(self.data_ids_per_episode, self.data_dir / "data_ids_per_episode.pth") @property def num_samples(self) -> int: diff --git a/lerobot/common/datasets/simxarm.py b/lerobot/common/datasets/simxarm.py index 4b2c68ad..7bddf608 100644 --- a/lerobot/common/datasets/simxarm.py +++ b/lerobot/common/datasets/simxarm.py @@ -46,15 +46,17 @@ class SimxarmDataset(torch.utils.data.Dataset): self.transform = transform self.delta_timestamps = delta_timestamps - data_dir = self.root / f"{self.dataset_id}" - if (data_dir / "data_dict.pth").exists() and (data_dir / "data_ids_per_episode.pth").exists(): - self.data_dict = torch.load(data_dir / "data_dict.pth") - self.data_ids_per_episode = torch.load(data_dir / "data_ids_per_episode.pth") + self.data_dir = self.root / f"{self.dataset_id}" + if (self.data_dir / "data_dict.pth").exists() and ( + self.data_dir / "data_ids_per_episode.pth" + ).exists(): + self.data_dict = torch.load(self.data_dir / "data_dict.pth") + self.data_ids_per_episode = torch.load(self.data_dir / "data_ids_per_episode.pth") else: self._download_and_preproc_obsolete() - data_dir.mkdir(parents=True, exist_ok=True) - torch.save(self.data_dict, data_dir / "data_dict.pth") - torch.save(self.data_ids_per_episode, data_dir / "data_ids_per_episode.pth") + self.data_dir.mkdir(parents=True, exist_ok=True) + torch.save(self.data_dict, self.data_dir / "data_dict.pth") + torch.save(self.data_ids_per_episode, self.data_dir / "data_ids_per_episode.pth") @property def num_samples(self) -> int: diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py index 522227d7..6b207b4d 100644 --- a/lerobot/common/datasets/utils.py +++ b/lerobot/common/datasets/utils.py @@ -112,16 +112,19 @@ def compute_or_load_stats(dataset, batch_size=32, max_num_samples=None): if max_num_samples is None: max_num_samples = len(dataset) + else: + raise NotImplementedError("We need to set shuffle=True, but this violate an assert for now.") dataloader = torch.utils.data.DataLoader( dataset, num_workers=4, batch_size=batch_size, - shuffle=True, + shuffle=False, # pin_memory=cfg.device != "cpu", drop_last=False, ) + # these einops patterns will be used to aggregate batches and compute statistics stats_patterns = { "action": "b c -> c", "observation.state": "b c -> c", @@ -142,9 +145,9 @@ def compute_or_load_stats(dataset, batch_size=32, max_num_samples=None): first_batch = None running_item_count = 0 # for online mean computation for i, batch in enumerate( - tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute mean, min, max") + tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute mean, min, max") ): - this_batch_size = batch.batch_size[0] + this_batch_size = len(batch["index"]) running_item_count += this_batch_size if first_batch is None: first_batch = deepcopy(batch) @@ -166,8 +169,10 @@ def compute_or_load_stats(dataset, batch_size=32, max_num_samples=None): first_batch_ = None running_item_count = 0 # for online std computation - for i, batch in enumerate(tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std")): - this_batch_size = batch.batch_size[0] + for i, batch in enumerate( + tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std") + ): + this_batch_size = len(batch["index"]) running_item_count += this_batch_size # Sanity check to make sure the batches are still in the same order as before. if first_batch_ is None: diff --git a/lerobot/common/policies/diffusion/diffusion_unet_image_policy.py b/lerobot/common/policies/diffusion/diffusion_unet_image_policy.py index 7719fdde..373e4b6c 100644 --- a/lerobot/common/policies/diffusion/diffusion_unet_image_policy.py +++ b/lerobot/common/policies/diffusion/diffusion_unet_image_policy.py @@ -243,10 +243,9 @@ class DiffusionUnetImagePolicy(BaseImagePolicy): result = {"action": action, "action_pred": action_pred} return result - def compute_loss(self, batch): - assert "valid_mask" not in batch - nobs = batch["obs"] - nactions = batch["action"] + def compute_loss(self, obs_dict, action): + nobs = obs_dict + nactions = action batch_size = nactions.shape[0] horizon = nactions.shape[1] diff --git a/lerobot/common/policies/diffusion/policy.py b/lerobot/common/policies/diffusion/policy.py index a0fe0eba..de8796ab 100644 --- a/lerobot/common/policies/diffusion/policy.py +++ b/lerobot/common/policies/diffusion/policy.py @@ -157,7 +157,8 @@ class DiffusionPolicy(nn.Module): "image": batch["observation.image"], "agent_pos": batch["observation.state"], } - loss = self.diffusion.compute_loss(obs_dict) + action = batch["action"] + loss = self.diffusion.compute_loss(obs_dict, action) loss.backward() grad_norm = torch.nn.utils.clip_grad_norm_( diff --git a/lerobot/common/transforms.py b/lerobot/common/transforms.py index 4974c086..ec967614 100644 --- a/lerobot/common/transforms.py +++ b/lerobot/common/transforms.py @@ -72,12 +72,12 @@ class NormalizeTransform(Transform): if inkey not in item: continue if self.mode == "mean_std": - mean = self.stats[f"{inkey}.mean"] - std = self.stats[f"{inkey}.std"] + mean = self.stats[inkey]["mean"] + std = self.stats[inkey]["std"] item[outkey] = (item[inkey] - mean) / (std + 1e-8) else: - min = self.stats[f"{inkey}.min"] - max = self.stats[f"{inkey}.max"] + min = self.stats[inkey]["min"] + max = self.stats[inkey]["max"] # normalize to [0,1] item[outkey] = (item[inkey] - min) / (max - min) # normalize to [-1, 1] @@ -89,12 +89,12 @@ class NormalizeTransform(Transform): if inkey not in item: continue if self.mode == "mean_std": - mean = self.stats[f"{inkey}.mean"] - std = self.stats[f"{inkey}.std"] + mean = self.stats[inkey]["mean"] + std = self.stats[inkey]["std"] item[outkey] = item[inkey] * std + mean else: - min = self.stats[f"{inkey}.min"] - max = self.stats[f"{inkey}.max"] + min = self.stats[inkey]["min"] + max = self.stats[inkey]["max"] item[outkey] = (item[inkey] + 1) / 2 item[outkey] = item[outkey] * (max - min) + min return item diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index fe0f7bb2..09399878 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -37,7 +37,6 @@ from pathlib import Path import einops import gymnasium as gym -import hydra import imageio import numpy as np import torch @@ -47,8 +46,8 @@ from lerobot.common.datasets.factory import make_dataset from lerobot.common.envs.factory import make_env from lerobot.common.logger import log_output_dir from lerobot.common.policies.factory import make_policy -from lerobot.common.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed from lerobot.common.transforms import apply_inverse_transform +from lerobot.common.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed def write_video(video_path, stacked_frames, fps): @@ -92,9 +91,12 @@ def eval_policy( fps: int = 15, return_first_video: bool = False, transform: callable = None, + seed=None, ): if policy is not None: policy.eval() + device = "cpu" if policy is None else next(policy.parameters()).device + start = time.time() sum_rewards = [] max_rewards = [] @@ -125,11 +127,11 @@ def eval_policy( policy.reset() else: logging.warning( - f"Policy {policy} doesnt have a `reset` method. This find if the policy doesnt rely on an internal state during rollout." + f"Policy {policy} doesnt have a `reset` method. It is required if the policy relies on an internal state during rollout." ) # reset the environment - observation, info = env.reset(seed=cfg.seed) + observation, info = env.reset(seed=seed) maybe_render_frame(env) rewards = [] @@ -138,13 +140,12 @@ def eval_policy( done = torch.tensor([False for _ in env.envs]) step = 0 - do_rollout = True - while do_rollout: + while not done.all(): # apply transform to normalize the observations observation = preprocess_observation(observation, transform) # send observation to device/gpu - observation = {key: observation[key].to(cfg.device, non_blocking=True) for key in observation} + observation = {key: observation[key].to(device, non_blocking=True) for key in observation} # get the next action for the environment with torch.inference_mode(): @@ -180,10 +181,6 @@ def eval_policy( step += 1 - if done.all(): - do_rollout = False - break - rewards = torch.stack(rewards, dim=1) successes = torch.stack(successes, dim=1) dones = torch.stack(dones, dim=1) @@ -295,6 +292,7 @@ def eval(cfg: dict, out_dir=None, stats_path=None): fps=cfg.env.fps, # TODO(rcadene): what should we do with the transform? transform=dataset.transform, + seed=cfg.seed, ) print(info["aggregated"]) diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 5e9cd361..602fa5ab 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -145,7 +145,7 @@ def train(cfg: dict, out_dir=None, job_name=None): # ) logging.info("make_env") - env = make_env(cfg) + env = make_env(cfg, num_parallel_envs=cfg.eval_episodes) logging.info("make_policy") policy = make_policy(cfg) @@ -173,12 +173,11 @@ def train(cfg: dict, out_dir=None, job_name=None): eval_info, first_video = eval_policy( env, policy, - num_episodes=cfg.eval_episodes, - max_steps=cfg.env.episode_length, return_first_video=True, video_dir=Path(out_dir) / "eval", save_video=True, transform=dataset.transform, + seed=cfg.seed, ) log_eval_info(logger, eval_info["aggregated"], step, cfg, dataset, is_offline) if cfg.wandb.enable: @@ -211,7 +210,7 @@ def train(cfg: dict, out_dir=None, job_name=None): for key in batch: batch[key] = batch[key].to(cfg.device, non_blocking=True) - train_info = policy.update(batch, step) + train_info = policy(batch, step) # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done? if step % cfg.log_freq == 0: @@ -223,6 +222,8 @@ def train(cfg: dict, out_dir=None, job_name=None): step += 1 + raise NotImplementedError() + demo_buffer = dataset if cfg.policy.balanced_sampling else None online_step = 0 is_offline = False diff --git a/tests/scripts/mock_dataset.py b/tests/scripts/mock_dataset.py index d9c86464..044417aa 100644 --- a/tests/scripts/mock_dataset.py +++ b/tests/scripts/mock_dataset.py @@ -18,28 +18,26 @@ Example: import argparse import shutil -from tensordict import TensorDict from pathlib import Path +import torch + def mock_dataset(in_data_dir, out_data_dir, num_frames): in_data_dir = Path(in_data_dir) out_data_dir = Path(out_data_dir) - # load full dataset as a tensor dict - in_td_data = TensorDict.load_memmap(in_data_dir / "replay_buffer") + # copy the first `n` frames for each data key so that we have real data + in_data_dict = torch.load(in_data_dir / "data_dict.pth") + out_data_dict = {key: in_data_dict[key][:num_frames].clone() for key in in_data_dict} + torch.save(out_data_dict, out_data_dir / "data_dict.pth") - # use 1 frame to know the specification of the dataset - # and copy it over `n` frames in the test artifact directory - out_td_data = in_td_data[0].expand(num_frames).memmap_like(out_data_dir / "replay_buffer") + # copy the full mapping between data_id and episode since it's small + in_ids_per_ep_path = in_data_dir / "data_ids_per_episode.pth" + out_ids_per_ep_path = out_data_dir / "data_ids_per_episode.pth" + shutil.copy(in_ids_per_ep_path, out_ids_per_ep_path) - # copy the first `n` frames so that we have real data - out_td_data[:num_frames] = in_td_data[:num_frames].clone() - - # make sure everything has been properly written - out_td_data.lock_() - - # copy the full statistics of dataset since it's pretty small + # copy the full statistics of dataset since it's small in_stats_path = in_data_dir / "stats.pth" out_stats_path = out_data_dir / "stats.pth" shutil.copy(in_stats_path, out_stats_path) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 00008259..e5ca0099 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -59,11 +59,7 @@ def test_factory(env_name, dataset_id): # ) # dataset = make_dataset(cfg) # # Get all of the data. -# all_data = TensorDictReplayBuffer( -# storage=buffer._storage, -# batch_size=len(buffer), -# sampler=SamplerWithoutReplacement(), -# ).sample().float() +# all_data = dataset.data_dict # # Note: we set the batch size to be smaller than the whole dataset to make sure we are testing batched # # computation of the statistics. While doing this, we also make sure it works when we don't divide the # # dataset into even batches.