From 13310681b1c1b3e7052d6164c76eb6ad046563d9 Mon Sep 17 00:00:00 2001 From: Simon Alibert <75076266+aliberts@users.noreply.github.com> Date: Wed, 29 May 2024 23:02:23 +0200 Subject: [PATCH 1/6] Enable cuda for end-to-end tests (#222) --- .github/workflows/nightly-tests.yml | 2 ++ Makefile | 41 +++++++++++++++-------------- 2 files changed, 23 insertions(+), 20 deletions(-) diff --git a/.github/workflows/nightly-tests.yml b/.github/workflows/nightly-tests.yml index b30a0bca..b3a2157b 100644 --- a/.github/workflows/nightly-tests.yml +++ b/.github/workflows/nightly-tests.yml @@ -70,6 +70,8 @@ jobs: # files: ./coverage.xml # verbose: true - name: Tests end-to-end + env: + DEVICE: cuda run: make test-end-to-end # - name: Generate Report diff --git a/Makefile b/Makefile index dd98228f..33f3edf2 100644 --- a/Makefile +++ b/Makefile @@ -10,6 +10,7 @@ endif export PATH := $(dir $(PYTHON_PATH)):$(PATH) +DEVICE ?= cpu build-cpu: docker build -t lerobot:latest -f docker/lerobot-cpu/Dockerfile . @@ -18,16 +19,16 @@ build-gpu: docker build -t lerobot:latest -f docker/lerobot-gpu/Dockerfile . test-end-to-end: - ${MAKE} test-act-ete-train - ${MAKE} test-act-ete-eval - ${MAKE} test-act-ete-train-amp - ${MAKE} test-act-ete-eval-amp - ${MAKE} test-diffusion-ete-train - ${MAKE} test-diffusion-ete-eval - ${MAKE} test-tdmpc-ete-train - ${MAKE} test-tdmpc-ete-eval - ${MAKE} test-default-ete-eval - ${MAKE} test-act-pusht-tutorial + ${MAKE} DEVICE=$(DEVICE) test-act-ete-train + ${MAKE} DEVICE=$(DEVICE) test-act-ete-eval + ${MAKE} DEVICE=$(DEVICE) test-act-ete-train-amp + ${MAKE} DEVICE=$(DEVICE) test-act-ete-eval-amp + ${MAKE} DEVICE=$(DEVICE) test-diffusion-ete-train + ${MAKE} DEVICE=$(DEVICE) test-diffusion-ete-eval + ${MAKE} DEVICE=$(DEVICE) test-tdmpc-ete-train + ${MAKE} DEVICE=$(DEVICE) test-tdmpc-ete-eval + ${MAKE} DEVICE=$(DEVICE) test-default-ete-eval + ${MAKE} DEVICE=$(DEVICE) test-act-pusht-tutorial test-act-ete-train: python lerobot/scripts/train.py \ @@ -39,7 +40,7 @@ test-act-ete-train: training.online_steps=0 \ eval.n_episodes=1 \ eval.batch_size=1 \ - device=cpu \ + device=$(DEVICE) \ training.save_checkpoint=true \ training.save_freq=2 \ policy.n_action_steps=20 \ @@ -53,7 +54,7 @@ test-act-ete-eval: eval.n_episodes=1 \ eval.batch_size=1 \ env.episode_length=8 \ - device=cpu \ + device=$(DEVICE) \ test-act-ete-train-amp: python lerobot/scripts/train.py \ @@ -65,7 +66,7 @@ test-act-ete-train-amp: training.online_steps=0 \ eval.n_episodes=1 \ eval.batch_size=1 \ - device=cpu \ + device=$(DEVICE) \ training.save_checkpoint=true \ training.save_freq=2 \ policy.n_action_steps=20 \ @@ -80,7 +81,7 @@ test-act-ete-eval-amp: eval.n_episodes=1 \ eval.batch_size=1 \ env.episode_length=8 \ - device=cpu \ + device=$(DEVICE) \ use_amp=true test-diffusion-ete-train: @@ -95,7 +96,7 @@ test-diffusion-ete-train: training.online_steps=0 \ eval.n_episodes=1 \ eval.batch_size=1 \ - device=cpu \ + device=$(DEVICE) \ training.save_checkpoint=true \ training.save_freq=2 \ training.batch_size=2 \ @@ -107,7 +108,7 @@ test-diffusion-ete-eval: eval.n_episodes=1 \ eval.batch_size=1 \ env.episode_length=8 \ - device=cpu \ + device=$(DEVICE) \ # TODO(alexander-soare): Restore online_steps to 2 when it is reinstated. 
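# How DEVICE propagates: `DEVICE ?= cpu` only assigns a default when the variable
# is not already defined, so the workflow's `env: DEVICE: cuda` takes precedence,
# and each recursive invocation passes it down explicitly via
# `${MAKE} DEVICE=$(DEVICE) <target>`. The same override works locally, e.g.:
#
#   make test-end-to-end DEVICE=cuda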
test-tdmpc-ete-train: @@ -122,7 +123,7 @@ test-tdmpc-ete-train: eval.n_episodes=1 \ eval.batch_size=1 \ env.episode_length=2 \ - device=cpu \ + device=$(DEVICE) \ training.save_checkpoint=true \ training.save_freq=2 \ training.batch_size=2 \ @@ -134,7 +135,7 @@ test-tdmpc-ete-eval: eval.n_episodes=1 \ eval.batch_size=1 \ env.episode_length=8 \ - device=cpu \ + device=$(DEVICE) \ test-default-ete-eval: python lerobot/scripts/eval.py \ @@ -142,7 +143,7 @@ test-default-ete-eval: eval.n_episodes=1 \ eval.batch_size=1 \ env.episode_length=8 \ - device=cpu \ + device=$(DEVICE) \ test-act-pusht-tutorial: cp examples/advanced/1_train_act_pusht/act_pusht.yaml lerobot/configs/policy/created_by_Makefile.yaml @@ -154,7 +155,7 @@ test-act-pusht-tutorial: eval.n_episodes=1 \ eval.batch_size=1 \ env.episode_length=2 \ - device=cpu \ + device=$(DEVICE) \ training.save_model=true \ training.save_freq=2 \ training.batch_size=2 \ From 2c2e4e14edd5586bd6cced4a3a7af656548ff282 Mon Sep 17 00:00:00 2001 From: Remi Date: Thu, 30 May 2024 11:26:39 +0200 Subject: [PATCH 2/6] Add `aloha_dora_format.py` (#201) Co-authored-by: Thomas Wolf --- .../push_dataset_to_hub/aloha_dora_format.py | 228 ++++++++++++++++++ lerobot/scripts/push_dataset_to_hub.py | 6 +- 2 files changed, 233 insertions(+), 1 deletion(-) create mode 100644 lerobot/common/datasets/push_dataset_to_hub/aloha_dora_format.py diff --git a/lerobot/common/datasets/push_dataset_to_hub/aloha_dora_format.py b/lerobot/common/datasets/push_dataset_to_hub/aloha_dora_format.py new file mode 100644 index 00000000..d1e5a52c --- /dev/null +++ b/lerobot/common/datasets/push_dataset_to_hub/aloha_dora_format.py @@ -0,0 +1,228 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
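+# Expected on-disk layout of a dora-record capture, as assumed by this loader
+# (one parquet file per data stream, plus the rendered camera videos):
+#
+#   raw_dir/
+#       observation.images.cam_<name>.parquet  # per-camera frame references
+#       observation.state.parquet
+#       action.parquet
+#       ...
+#       videos/
+#           observation.images.cam_<name>_episode_<NNNNNN>.mp4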
+"""
+Contains utilities to process the raw data format from dora-record.
+"""
+
+import logging
+import re
+from pathlib import Path
+
+import pandas as pd
+import torch
+from datasets import Dataset, Features, Image, Sequence, Value
+
+from lerobot.common.datasets.utils import (
+    hf_transform_to_torch,
+)
+from lerobot.common.datasets.video_utils import VideoFrame
+from lerobot.common.utils.utils import init_logging
+
+
+def check_format(raw_dir) -> bool:
+    assert raw_dir.exists()
+
+    leader_file = list(raw_dir.glob("*.parquet"))
+    if len(leader_file) == 0:
+        raise ValueError(f"Missing parquet files in '{raw_dir}'")
+    return True
+
+
+def load_from_raw(raw_dir: Path, out_dir: Path, fps: int):
+    # Load the data stream that will be used as the reference for timestamp synchronization
+    reference_files = list(raw_dir.glob("observation.images.cam_*.parquet"))
+    if len(reference_files) == 0:
+        raise ValueError(f"Missing reference files for camera, starting with 'observation.images.cam_', in '{raw_dir}'")
+    # select first camera in alphanumeric order
+    reference_key = sorted(reference_files)[0].stem
+    reference_df = pd.read_parquet(raw_dir / f"{reference_key}.parquet")
+    reference_df = reference_df[["timestamp_utc", reference_key]]
+
+    # Merge all data streams using a nearest-timestamp strategy
+    df = reference_df
+    for path in raw_dir.glob("*.parquet"):
+        key = path.stem  # action or observation.state or ...
+        if key == reference_key:
+            continue
+        if "failed_episode_index" in key:
+            # TODO(rcadene): add support for removing episodes that are tagged as "failed"
+            continue
+        modality_df = pd.read_parquet(path)
+        modality_df = modality_df[["timestamp_utc", key]]
+        df = pd.merge_asof(
+            df,
+            modality_df,
+            on="timestamp_utc",
+            # "nearest" is the best option over "backward", since the latter can desynchronize camera timestamps by
+            # matching timestamps that are too far apart, in order to fit the backward constraints. It's not the case for "nearest".
+            # However, note that "nearest" might synchronize the reference camera with other cameras on slightly future timestamps.
+            # This is not a problem when the tolerance is set to be low enough to avoid matching timestamps that
+            # are too far apart.
+            direction="nearest",
+            tolerance=pd.Timedelta(f"{1/fps} seconds"),
+        )
+
+    # Remove rows with episode_index -1, which indicates data recorded in between episodes
+    df = df[df["episode_index"] != -1]
+
+    image_keys = [key for key in df if "observation.images."
 in key]
+
+    def get_episode_index(row):
+        episode_index_per_cam = {}
+        for key in image_keys:
+            path = row[key][0]["path"]
+            match = re.search(r"_(\d{6}).mp4", path)
+            if not match:
+                raise ValueError(path)
+            episode_index = int(match.group(1))
+            episode_index_per_cam[key] = episode_index
+        assert (
+            len(set(episode_index_per_cam.values())) == 1
+        ), f"All cameras are expected to belong to the same episode, but got {episode_index_per_cam}"
+        return episode_index
+
+    df["episode_index"] = df.apply(get_episode_index, axis=1)
+
+    # dora only uses arrays, so single values are encapsulated into a list
+    df["frame_index"] = df.groupby("episode_index").cumcount()
+    df = df.reset_index()
+    df["index"] = df.index
+
+    # set 'next.done' to True for the last frame of each episode
+    df["next.done"] = False
+    df.loc[df.groupby("episode_index").tail(1).index, "next.done"] = True
+
+    df["timestamp"] = df["timestamp_utc"].map(lambda x: x.timestamp())
+    # each episode starts with timestamp 0 to match the ones from the video
+    df["timestamp"] = df.groupby("episode_index")["timestamp"].transform(lambda x: x - x.iloc[0])
+
+    del df["timestamp_utc"]
+
+    # sanity check
+    has_nan = df.isna().any().any()
+    if has_nan:
+        raise ValueError("Dataset contains NaN values.")
+
+    # sanity check episode indices go from 0 to n-1
+    ep_ids = [ep_idx for ep_idx, _ in df.groupby("episode_index")]
+    expected_ep_ids = list(range(df["episode_index"].max() + 1))
+    assert ep_ids == expected_ep_ids, f"Episode indices go from {ep_ids} instead of {expected_ep_ids}"
+
+    # Create symlink to raw videos directory (that needs to be absolute not relative)
+    out_dir.mkdir(parents=True, exist_ok=True)
+    videos_dir = out_dir / "videos"
+    videos_dir.symlink_to((raw_dir / "videos").absolute())
+
+    # sanity check the video paths are well formatted
+    for key in df:
+        if "observation.images." not in key:
+            continue
+        for ep_idx in ep_ids:
+            video_path = videos_dir / f"{key}_episode_{ep_idx:06d}.mp4"
+            assert video_path.exists(), f"Video file not found in {video_path}"
+
+    data_dict = {}
+    for key in df:
+        # is video frame
+        if "observation.images." in key:
+            # we need `[0]` because dora only uses arrays, so single values are encapsulated into a list.
+            # it is the case for video_frame dictionary = [{"path": ..., "timestamp": ...}]
+            data_dict[key] = [video_frame[0] for video_frame in df[key].values]
+
+            # sanity check the video path is well formatted
+            video_path = videos_dir.parent / data_dict[key][0]["path"]
+            assert video_path.exists(), f"Video file not found in {video_path}"
+        # is number
+        elif df[key].iloc[0].ndim == 0 or df[key].iloc[0].shape[0] == 1:
+            data_dict[key] = torch.from_numpy(df[key].values)
+        # is vector
+        elif df[key].iloc[0].shape[0] > 1:
+            data_dict[key] = torch.stack([torch.from_numpy(x.copy()) for x in df[key].values])
+        else:
+            raise ValueError(key)
+
+    # Get the first frame index for each unique episode index
+    first_ep_index_df = df.groupby("episode_index").agg(start_index=("index", "first")).reset_index()
+    from_ = first_ep_index_df["start_index"].tolist()
+    to_ = from_[1:] + [len(df)]
+    episode_data_index = {
+        "from": from_,
+        "to": to_,
+    }
+
+    return data_dict, episode_data_index
+
+
+def to_hf_dataset(data_dict, video) -> Dataset:
+    features = {}
+
+    keys = [key for key in data_dict if "observation.images."
in key] + for key in keys: + if video: + features[key] = VideoFrame() + else: + features[key] = Image() + + features["observation.state"] = Sequence( + length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None) + ) + if "observation.velocity" in data_dict: + features["observation.velocity"] = Sequence( + length=data_dict["observation.velocity"].shape[1], feature=Value(dtype="float32", id=None) + ) + if "observation.effort" in data_dict: + features["observation.effort"] = Sequence( + length=data_dict["observation.effort"].shape[1], feature=Value(dtype="float32", id=None) + ) + features["action"] = Sequence( + length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None) + ) + features["episode_index"] = Value(dtype="int64", id=None) + features["frame_index"] = Value(dtype="int64", id=None) + features["timestamp"] = Value(dtype="float32", id=None) + features["next.done"] = Value(dtype="bool", id=None) + features["index"] = Value(dtype="int64", id=None) + + hf_dataset = Dataset.from_dict(data_dict, features=Features(features)) + hf_dataset.set_transform(hf_transform_to_torch) + return hf_dataset + + +def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=True, debug=False): + init_logging() + + if debug: + logging.warning("debug=True not implemented. Falling back to debug=False.") + + # sanity check + check_format(raw_dir) + + if fps is None: + fps = 30 + else: + raise NotImplementedError() + + if not video: + raise NotImplementedError() + + data_df, episode_data_index = load_from_raw(raw_dir, out_dir, fps) + hf_dataset = to_hf_dataset(data_df, video) + + info = { + "fps": fps, + "video": video, + } + return hf_dataset, episode_data_index, info diff --git a/lerobot/scripts/push_dataset_to_hub.py b/lerobot/scripts/push_dataset_to_hub.py index 19af1cf8..c6eac5e9 100644 --- a/lerobot/scripts/push_dataset_to_hub.py +++ b/lerobot/scripts/push_dataset_to_hub.py @@ -84,10 +84,14 @@ def get_from_raw_to_lerobot_format_fn(raw_format): from lerobot.common.datasets.push_dataset_to_hub.umi_zarr_format import from_raw_to_lerobot_format elif raw_format == "aloha_hdf5": from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import from_raw_to_lerobot_format + elif raw_format == "aloha_dora": + from lerobot.common.datasets.push_dataset_to_hub.aloha_dora_format import from_raw_to_lerobot_format elif raw_format == "xarm_pkl": from lerobot.common.datasets.push_dataset_to_hub.xarm_pkl_format import from_raw_to_lerobot_format else: - raise ValueError(raw_format) + raise ValueError( + f"The selected {raw_format} can't be found. Did you add it to `lerobot/scripts/push_dataset_to_hub.py::get_from_raw_to_lerobot_format_fn`?" 
+    )
 return from_raw_to_lerobot_format

From 265b0ec44d42ea8cb912534df5b1e818d6caec0d Mon Sep 17 00:00:00 2001
From: Remi
Date: Thu, 30 May 2024 13:45:22 +0200
Subject: [PATCH 3/6] Refactor env to add key word arguments from config yaml
 (#223)

Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
---
 .../push_dataset_to_hub/aloha_dora_format.py | 18 +++++++-----
 lerobot/common/envs/factory.py               | 14 ++++-----
 lerobot/configs/default.yaml                 |  2 ++
 lerobot/configs/env/aloha.yaml               | 10 +++----
 lerobot/configs/env/pusht.yaml               | 11 ++++---
 lerobot/configs/env/xarm.yaml                | 11 ++++---
 lerobot/scripts/train.py                     | 29 +++++++++++++++----
 7 files changed, 59 insertions(+), 36 deletions(-)

diff --git a/lerobot/common/datasets/push_dataset_to_hub/aloha_dora_format.py b/lerobot/common/datasets/push_dataset_to_hub/aloha_dora_format.py
index d1e5a52c..4a21bc2d 100644
--- a/lerobot/common/datasets/push_dataset_to_hub/aloha_dora_format.py
+++ b/lerobot/common/datasets/push_dataset_to_hub/aloha_dora_format.py
@@ -69,12 +69,10 @@ def load_from_raw(raw_dir: Path, out_dir: Path, fps: int):
             # "nearest" is the best option over "backward", since the latter can desynchronize camera timestamps by
             # matching timestamps that are too far apart, in order to fit the backward constraints. It's not the case for "nearest".
             # However, note that "nearest" might synchronize the reference camera with other cameras on slightly future timestamps.
-            # This is not a problem when the tolerance is set to be low enough to avoid matching timestamps that
             # are too far apart.
             direction="nearest",
             tolerance=pd.Timedelta(f"{1/fps} seconds"),
         )
-
     # Remove rows with episode_index -1, which indicates data recorded in between episodes
     df = df[df["episode_index"] != -1]
 
@@ -89,9 +87,10 @@ def load_from_raw(raw_dir: Path, out_dir: Path, fps: int):
                 raise ValueError(path)
             episode_index = int(match.group(1))
             episode_index_per_cam[key] = episode_index
-        assert (
-            len(set(episode_index_per_cam.values())) == 1
-        ), f"All cameras are expected to belong to the same episode, but got {episode_index_per_cam}"
+        if len(set(episode_index_per_cam.values())) != 1:
+            raise ValueError(
+                f"All cameras are expected to belong to the same episode, but got {episode_index_per_cam}"
+            )
         return episode_index
 
     df["episode_index"] = df.apply(get_episode_index, axis=1)
@@ -119,7 +118,8 @@ def load_from_raw(raw_dir: Path, out_dir: Path, fps: int):
     # sanity check episode indices go from 0 to n-1
     ep_ids = [ep_idx for ep_idx, _ in df.groupby("episode_index")]
     expected_ep_ids = list(range(df["episode_index"].max() + 1))
-    assert ep_ids == expected_ep_ids, f"Episode indices go from {ep_ids} instead of {expected_ep_ids}"
+    if ep_ids != expected_ep_ids:
+        raise ValueError(f"Episode indices go from {ep_ids} instead of {expected_ep_ids}")
 
     # Create symlink to raw videos directory (that needs to be absolute not relative)
     out_dir.mkdir(parents=True, exist_ok=True)
@@ -132,7 +132,8 @@ def load_from_raw(raw_dir: Path, out_dir: Path, fps: int):
             continue
         for ep_idx in ep_ids:
             video_path = videos_dir / f"{key}_episode_{ep_idx:06d}.mp4"
-            assert video_path.exists(), f"Video file not found in {video_path}"
+            if not video_path.exists():
+                raise ValueError(f"Video file not found in {video_path}")
 
     data_dict = {}
     for key in df:
@@ -144,7 +145,8 @@ def load_from_raw(raw_dir: Path, out_dir: Path, fps: int):
 
             # sanity check the video path is well formatted
             video_path = videos_dir.parent / data_dict[key][0]["path"]
-            assert video_path.exists(), f"Video file not found
in {video_path}" + if not video_path.exists(): + raise ValueError(f"Video file not found in {video_path}") # is number elif df[key].iloc[0].ndim == 0 or df[key].iloc[0].shape[0] == 1: data_dict[key] = torch.from_numpy(df[key].values) diff --git a/lerobot/common/envs/factory.py b/lerobot/common/envs/factory.py index 83f94cfe..d73939b1 100644 --- a/lerobot/common/envs/factory.py +++ b/lerobot/common/envs/factory.py @@ -27,14 +27,6 @@ def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv if n_envs is not None and n_envs < 1: raise ValueError("`n_envs must be at least 1") - kwargs = { - "obs_type": "pixels_agent_pos", - "render_mode": "rgb_array", - "max_episode_steps": cfg.env.episode_length, - "visualization_width": 384, - "visualization_height": 384, - } - package_name = f"gym_{cfg.env.name}" try: @@ -46,12 +38,16 @@ def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv raise e gym_handle = f"{package_name}/{cfg.env.task}" + gym_kwgs = dict(cfg.env.get("gym", {})) + + if cfg.env.get("episode_length"): + gym_kwgs["max_episode_steps"] = cfg.env.episode_length # batched version of the env that returns an observation of shape (b, c) env_cls = gym.vector.AsyncVectorEnv if cfg.eval.use_async_envs else gym.vector.SyncVectorEnv env = env_cls( [ - lambda: gym.make(gym_handle, disable_env_checker=True, **kwargs) + lambda: gym.make(gym_handle, disable_env_checker=True, **gym_kwgs) for _ in range(n_envs if n_envs is not None else cfg.eval.batch_size) ] ) diff --git a/lerobot/configs/default.yaml b/lerobot/configs/default.yaml index 9ae30784..f2238769 100644 --- a/lerobot/configs/default.yaml +++ b/lerobot/configs/default.yaml @@ -37,6 +37,8 @@ training: save_freq: ??? log_freq: 250 save_checkpoint: true + num_workers: 4 + batch_size: ??? 
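+  # Note: `???` is OmegaConf's mandatory-value marker: resolving `training.batch_size`
+  # without overriding it first (e.g. from a policy config or the command line) raises
+  # a MissingMandatoryValue error.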
eval: n_episodes: 1 diff --git a/lerobot/configs/env/aloha.yaml b/lerobot/configs/env/aloha.yaml index 95e4503d..296a4481 100644 --- a/lerobot/configs/env/aloha.yaml +++ b/lerobot/configs/env/aloha.yaml @@ -5,10 +5,10 @@ fps: 50 env: name: aloha task: AlohaInsertion-v0 - from_pixels: True - pixels_only: False - image_size: [3, 480, 640] - episode_length: 400 - fps: ${fps} state_dim: 14 action_dim: 14 + fps: ${fps} + episode_length: 400 + gym: + obs_type: pixels_agent_pos + render_mode: rgb_array diff --git a/lerobot/configs/env/pusht.yaml b/lerobot/configs/env/pusht.yaml index 43e9d187..771fbbf4 100644 --- a/lerobot/configs/env/pusht.yaml +++ b/lerobot/configs/env/pusht.yaml @@ -5,10 +5,13 @@ fps: 10 env: name: pusht task: PushT-v0 - from_pixels: True - pixels_only: False image_size: 96 - episode_length: 300 - fps: ${fps} state_dim: 2 action_dim: 2 + fps: ${fps} + episode_length: 300 + gym: + obs_type: pixels_agent_pos + render_mode: rgb_array + visualization_width: 384 + visualization_height: 384 diff --git a/lerobot/configs/env/xarm.yaml b/lerobot/configs/env/xarm.yaml index 098b0396..9dbb96f5 100644 --- a/lerobot/configs/env/xarm.yaml +++ b/lerobot/configs/env/xarm.yaml @@ -5,10 +5,13 @@ fps: 15 env: name: xarm task: XarmLift-v0 - from_pixels: True - pixels_only: False image_size: 84 - episode_length: 25 - fps: ${fps} state_dim: 4 action_dim: 4 + fps: ${fps} + episode_length: 25 + gym: + obs_type: pixels_agent_pos + render_mode: rgb_array + visualization_width: 384 + visualization_height: 384 diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 5fb86f36..eb33b268 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -281,8 +281,12 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No logging.info("make_dataset") offline_dataset = make_dataset(cfg) - logging.info("make_env") - eval_env = make_env(cfg) + # Create environment used for evaluating checkpoints during training on simulation data. + # On real-world data, no need to create an environment as evaluations are done outside train.py, + # using the eval.py instead, with gym_dora environment and dora-rs. + if cfg.training.eval_freq > 0: + logging.info("make_env") + eval_env = make_env(cfg) logging.info("make_policy") policy = make_policy( @@ -315,7 +319,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No # Note: this helper will be used in offline and online training loops. 
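    # (When cfg.training.eval_freq <= 0, e.g. for real-world training with gym_dora where
    # evaluation happens outside train.py, the guard added below makes this helper skip
    # the eval branch entirely.)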
def evaluate_and_checkpoint_if_needed(step): - if step % cfg.training.eval_freq == 0: + if cfg.training.eval_freq > 0 and step % cfg.training.eval_freq == 0: logging.info(f"Eval policy at step {step}") with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext(): eval_info = eval_policy( @@ -349,7 +353,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No # create dataloader for offline training dataloader = torch.utils.data.DataLoader( offline_dataset, - num_workers=4, + num_workers=cfg.training.num_workers, batch_size=cfg.training.batch_size, shuffle=True, pin_memory=device.type != "cpu", @@ -386,6 +390,16 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No step += 1 + logging.info("End of offline training") + + if cfg.training.online_steps == 0: + if cfg.training.eval_freq > 0: + eval_env.close() + return + + # create an env dedicated to online episodes collection from policy rollout + online_training_env = make_env(cfg, n_envs=1) + # create an empty online dataset similar to offline dataset online_dataset = deepcopy(offline_dataset) online_dataset.hf_dataset = {} @@ -406,8 +420,11 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No drop_last=False, ) - eval_env.close() - logging.info("End of training") + logging.info("End of online training") + + if cfg.training.eval_freq > 0: + eval_env.close() + online_training_env.close() @hydra.main(version_base="1.2", config_name="default", config_path="../configs") From 111cd58f8add8ff5ededabf6ee270471930ed04b Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Thu, 30 May 2024 16:12:21 +0100 Subject: [PATCH 4/6] Add `MultiLerobotDataset` for training with multiple `LeRobotDataset`s (#229) --- .../compute_stats.py | 60 ++++- lerobot/common/datasets/factory.py | 45 +++- lerobot/common/datasets/lerobot_dataset.py | 206 +++++++++++++++++- lerobot/common/datasets/utils.py | 8 +- lerobot/configs/default.yaml | 4 + lerobot/scripts/push_dataset_to_hub.py | 2 +- lerobot/scripts/train.py | 49 +---- tests/test_datasets.py | 50 ++++- 8 files changed, 352 insertions(+), 72 deletions(-) rename lerobot/common/datasets/{push_dataset_to_hub => }/compute_stats.py (70%) diff --git a/lerobot/common/datasets/push_dataset_to_hub/compute_stats.py b/lerobot/common/datasets/compute_stats.py similarity index 70% rename from lerobot/common/datasets/push_dataset_to_hub/compute_stats.py rename to lerobot/common/datasets/compute_stats.py index ec296658..a69bc573 100644 --- a/lerobot/common/datasets/push_dataset_to_hub/compute_stats.py +++ b/lerobot/common/datasets/compute_stats.py @@ -16,17 +16,15 @@ from copy import deepcopy from math import ceil -import datasets import einops import torch import tqdm from datasets import Image -from lerobot.common.datasets.lerobot_dataset import LeRobotDataset from lerobot.common.datasets.video_utils import VideoFrame -def get_stats_einops_patterns(dataset: LeRobotDataset | datasets.Dataset, num_workers=0): +def get_stats_einops_patterns(dataset, num_workers=0): """These einops patterns will be used to aggregate batches and compute statistics. 
     Note: We assume the images are in channel first format
@@ -66,9 +64,8 @@ def get_stats_einops_patterns(dataset: LeRobotDataset | datasets.Dataset, num_wo
     return stats_patterns
 
 
-def compute_stats(
-    dataset: LeRobotDataset | datasets.Dataset, batch_size=32, num_workers=16, max_num_samples=None
-):
+def compute_stats(dataset, batch_size=32, num_workers=16, max_num_samples=None):
+    """Compute mean/std and min/max statistics of all data keys in a LeRobotDataset."""
     if max_num_samples is None:
         max_num_samples = len(dataset)
 
@@ -159,3 +156,54 @@ def compute_stats(
         "min": min[key],
     }
     return stats
+
+
+def aggregate_stats(ls_datasets) -> dict[str, torch.Tensor]:
+    """Aggregate stats of multiple LeRobot datasets into one set of stats without recomputing from scratch.
+
+    The final stats will have the union of all data keys from each of the datasets. For instance:
+    - new_max = max(max_dataset_0, max_dataset_1, ...)
+    - new_min = min(min_dataset_0, min_dataset_1, ...)
+    - new_mean = (mean of all data)
+    - new_std = (std of all data)
+    """
+    data_keys = set()
+    for dataset in ls_datasets:
+        data_keys.update(dataset.stats.keys())
+    stats = {k: {} for k in data_keys}
+    for data_key in data_keys:
+        for stat_key in ["min", "max"]:
+            # compute `max(dataset_0["max"], dataset_1["max"], ...)`
+            stats[data_key][stat_key] = einops.reduce(
+                torch.stack([d.stats[data_key][stat_key] for d in ls_datasets if data_key in d.stats], dim=0),
+                "n ... -> ...",
+                stat_key,
+            )
+        total_samples = sum(d.num_samples for d in ls_datasets if data_key in d.stats)
+        # Compute the "sum" statistic by multiplying each mean by the number of samples in the respective
+        # dataset, then divide by total_samples to get the overall "mean".
+        # NOTE: the brackets around (d.num_samples / total_samples) are needed to minimize the risk of
+        # numerical overflow!
+        stats[data_key]["mean"] = sum(
+            d.stats[data_key]["mean"] * (d.num_samples / total_samples)
+            for d in ls_datasets
+            if data_key in d.stats
+        )
+        # The derivation for standard deviation is a little more involved but is much in the same spirit as
+        # the computation of the mean.
+        # Given two sets of data where the statistics are known:
+        # σ_combined = sqrt[ (n1 * (σ1^2 + d1^2) + n2 * (σ2^2 + d2^2)) / (n1 + n2) ]
+        # where d1 = μ1 - μ_combined, d2 = μ2 - μ_combined
+        # NOTE: the brackets around (d.num_samples / total_samples) are needed to minimize the risk of
+        # numerical overflow!
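+        # Worked example with n1 = n2 = 2: data1 = [0, 2] (μ1 = 1, σ1^2 = 1) and
+        # data2 = [4, 6] (μ2 = 5, σ2^2 = 1) give μ_combined = 3 and
+        # σ_combined^2 = (2 * (1 + (1 - 3)^2) + 2 * (1 + (5 - 3)^2)) / 4 = 5,
+        # which matches the population variance of [0, 2, 4, 6] computed directly.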
+        stats[data_key]["std"] = torch.sqrt(
+            sum(
+                (d.stats[data_key]["std"] ** 2 + (d.stats[data_key]["mean"] - stats[data_key]["mean"]) ** 2)
+                * (d.num_samples / total_samples)
+                for d in ls_datasets
+                if data_key in d.stats
+            )
+        )
+    return stats
diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py
index 7bdc2ca9..b48a9211 100644
--- a/lerobot/common/datasets/factory.py
+++ b/lerobot/common/datasets/factory.py
@@ -16,9 +16,9 @@
 import logging
 
 import torch
-from omegaconf import OmegaConf
+from omegaconf import ListConfig, OmegaConf
 
-from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset
 
 
 def resolve_delta_timestamps(cfg):
@@ -35,11 +35,27 @@
             cfg.training.delta_timestamps[key] = eval(delta_timestamps[key])
 
 
-def make_dataset(
-    cfg,
-    split="train",
-):
-    if cfg.env.name not in cfg.dataset_repo_id:
+def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotDataset:
+    """
+    Args:
+        cfg: A Hydra config as per the LeRobot config scheme.
+        split: Select the data subset used to create an instance of LeRobotDataset.
+            All datasets hosted on [lerobot](https://huggingface.co/lerobot) contain only one subset: "train".
+            Thus, by default, `split="train"` selects all the available data. `split` aims to work like the
+            slicer in the Hugging Face datasets:
+            https://huggingface.co/docs/datasets/v2.19.0/loading#slice-splits
+            As of now, it only supports `split="train[:n]"` to load the first n frames of the dataset or
+            `split="train[n:]"` to load the frames from index n to the end. For instance `split="train[:1000]"`.
+    Returns:
+        The LeRobotDataset or MultiLeRobotDataset.
+    """
+    if not isinstance(cfg.dataset_repo_id, (str, ListConfig)):
+        raise ValueError(
+            "Expected cfg.dataset_repo_id to be either a single string to load one dataset or a list of "
+            "strings to load multiple datasets."
+        )
+
+    if isinstance(cfg.dataset_repo_id, str) and cfg.env.name not in cfg.dataset_repo_id:
         logging.warning(
             f"There might be a mismatch between your training dataset ({cfg.dataset_repo_id=}) and your "
             f"environment ({cfg.env.name=})."
@@ -49,11 +65,16 @@
 
     # TODO(rcadene): add data augmentations
 
-    dataset = LeRobotDataset(
-        cfg.dataset_repo_id,
-        split=split,
-        delta_timestamps=cfg.training.get("delta_timestamps"),
-    )
+    if isinstance(cfg.dataset_repo_id, str):
+        dataset = LeRobotDataset(
+            cfg.dataset_repo_id,
+            split=split,
+            delta_timestamps=cfg.training.get("delta_timestamps"),
+        )
+    else:
+        dataset = MultiLeRobotDataset(
+            cfg.dataset_repo_id, split=split, delta_timestamps=cfg.training.get("delta_timestamps")
+        )
 
     if cfg.get("override_dataset_stats"):
         for key, stats_dict in cfg.override_dataset_stats.items():
diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py
index 057e4770..a87c3ee8 100644
--- a/lerobot/common/datasets/lerobot_dataset.py
+++ b/lerobot/common/datasets/lerobot_dataset.py
@@ -13,12 +13,16 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging import os from pathlib import Path +from typing import Callable import datasets import torch +import torch.utils +from lerobot.common.datasets.compute_stats import aggregate_stats from lerobot.common.datasets.utils import ( calculate_episode_data_index, load_episode_data_index, @@ -42,7 +46,7 @@ class LeRobotDataset(torch.utils.data.Dataset): version: str | None = CODEBASE_VERSION, root: Path | None = DATA_DIR, split: str = "train", - transform: callable = None, + transform: Callable | None = None, delta_timestamps: dict[list[float]] | None = None, ): super().__init__() @@ -171,7 +175,7 @@ class LeRobotDataset(torch.utils.data.Dataset): @classmethod def from_preloaded( cls, - repo_id: str, + repo_id: str = "from_preloaded", version: str | None = CODEBASE_VERSION, root: Path | None = None, split: str = "train", @@ -183,7 +187,15 @@ class LeRobotDataset(torch.utils.data.Dataset): stats=None, info=None, videos_dir=None, - ): + ) -> "LeRobotDataset": + """Create a LeRobot Dataset from existing data and attributes instead of loading from the filesystem. + + It is especially useful when converting raw data into LeRobotDataset before saving the dataset + on the filesystem or uploading to the hub. + + Note: Meta-data attributes like `repo_id`, `version`, `root`, etc are optional and potentially + meaningless depending on the downstream usage of the return dataset. + """ # create an empty object of type LeRobotDataset obj = cls.__new__(cls) obj.repo_id = repo_id @@ -195,6 +207,192 @@ class LeRobotDataset(torch.utils.data.Dataset): obj.hf_dataset = hf_dataset obj.episode_data_index = episode_data_index obj.stats = stats - obj.info = info + obj.info = info if info is not None else {} obj.videos_dir = videos_dir return obj + + +class MultiLeRobotDataset(torch.utils.data.Dataset): + """A dataset consisting of multiple underlying `LeRobotDataset`s. + + The underlying `LeRobotDataset`s are effectively concatenated, and this class adopts much of the API + structure of `LeRobotDataset`. + """ + + def __init__( + self, + repo_ids: list[str], + version: str | None = CODEBASE_VERSION, + root: Path | None = DATA_DIR, + split: str = "train", + transform: Callable | None = None, + delta_timestamps: dict[list[float]] | None = None, + ): + super().__init__() + self.repo_ids = repo_ids + # Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which + # are handled by this class. + self._datasets = [ + LeRobotDataset( + repo_id, + version=version, + root=root, + split=split, + delta_timestamps=delta_timestamps, + transform=transform, + ) + for repo_id in repo_ids + ] + # Check that some properties are consistent across datasets. Note: We may relax some of these + # consistency requirements in future iterations of this class. + for repo_id, dataset in zip(self.repo_ids, self._datasets, strict=True): + if dataset.info != self._datasets[0].info: + raise ValueError( + f"Detected a mismatch in dataset info between {self.repo_ids[0]} and {repo_id}. This is " + "not yet supported." + ) + # Disable any data keys that are not common across all of the datasets. Note: we may relax this + # restriction in future iterations of this class. For now, this is necessary at least for being able + # to use PyTorch's default DataLoader collate function. 
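+        # For example, if one dataset also recorded "observation.velocity" or "observation.effort"
+        # (keys the aloha_dora format optionally provides) and another did not, those keys are disabled
+        # and stripped from every item returned by `__getitem__`.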
+        self.disabled_data_keys = set()
+        intersection_data_keys = set(self._datasets[0].hf_dataset.features)
+        for dataset in self._datasets:
+            intersection_data_keys.intersection_update(dataset.hf_dataset.features)
+        if len(intersection_data_keys) == 0:
+            raise RuntimeError(
+                "Multiple datasets were provided but they had no keys common to all of them. The "
+                "multi-dataset functionality currently only keeps common keys."
+            )
+        for repo_id, dataset in zip(self.repo_ids, self._datasets, strict=True):
+            extra_keys = set(dataset.hf_dataset.features).difference(intersection_data_keys)
+            logging.warning(
+                f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
+                "other datasets."
+            )
+            self.disabled_data_keys.update(extra_keys)
+
+        self.version = version
+        self.root = root
+        self.split = split
+        self.transform = transform
+        self.delta_timestamps = delta_timestamps
+        self.stats = aggregate_stats(self._datasets)
+
+    @property
+    def repo_id_to_index(self):
+        """Return a mapping from dataset repo_id to a dataset index automatically created by this class.
+
+        This index is incorporated as a data key in the dictionary returned by `__getitem__`.
+        """
+        return {repo_id: i for i, repo_id in enumerate(self.repo_ids)}
+
+    @property
+    def repo_index_to_id(self):
+        """Return the inverse mapping of repo_id_to_index."""
+        return {v: k for k, v in self.repo_id_to_index.items()}
+
+    @property
+    def fps(self) -> int:
+        """Frames per second used during data collection.
+
+        NOTE: For now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
+        """
+        return self._datasets[0].info["fps"]
+
+    @property
+    def video(self) -> bool:
+        """Returns True if this dataset loads video frames from mp4 files.
+
+        Returns False if it only loads images from png files.
+
+        NOTE: For now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
+        """
+        return self._datasets[0].info.get("video", False)
+
+    @property
+    def features(self) -> datasets.Features:
+        features = {}
+        for dataset in self._datasets:
+            features.update({k: v for k, v in dataset.features.items() if k not in self.disabled_data_keys})
+        return features
+
+    @property
+    def camera_keys(self) -> list[str]:
+        """Keys to access image and video streams from cameras."""
+        keys = []
+        for key, feats in self.features.items():
+            if isinstance(feats, (datasets.Image, VideoFrame)):
+                keys.append(key)
+        return keys
+
+    @property
+    def video_frame_keys(self) -> list[str]:
+        """Keys to access video frames that require to be decoded into images.
+
+        Note: It is empty if the dataset contains images only,
+        or equal to `self.cameras` if the dataset contains videos only,
+        or can even be a subset of `self.cameras` in the case of a mixed image/video dataset.
+        """
+        video_frame_keys = []
+        for key, feats in self.features.items():
+            if isinstance(feats, VideoFrame):
+                video_frame_keys.append(key)
+        return video_frame_keys
+
+    @property
+    def num_samples(self) -> int:
+        """Number of samples/frames."""
+        return sum(d.num_samples for d in self._datasets)
+
+    @property
+    def num_episodes(self) -> int:
+        """Number of episodes."""
+        return sum(d.num_episodes for d in self._datasets)
+
+    @property
+    def tolerance_s(self) -> float:
+        """Tolerance in seconds used to discard loaded frames when their timestamps
+        are not close enough to the requested frames. It is only used when `delta_timestamps`
+        is provided or when loading video frames from mp4 files.
+ """ + # 1e-4 to account for possible numerical error + return 1 / self.fps - 1e-4 + + def __len__(self): + return self.num_samples + + def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: + if idx >= len(self): + raise IndexError(f"Index {idx} out of bounds.") + # Determine which dataset to get an item from based on the index. + start_idx = 0 + dataset_idx = 0 + for dataset in self._datasets: + if idx >= start_idx + dataset.num_samples: + start_idx += dataset.num_samples + dataset_idx += 1 + break + else: + raise AssertionError("We expect the loop to break out as long as the index is within bounds.") + item = self._datasets[dataset_idx][idx - start_idx] + item["dataset_index"] = torch.tensor(dataset_idx) + for data_key in self.disabled_data_keys: + if data_key in item: + del item[data_key] + return item + + def __repr__(self): + return ( + f"{self.__class__.__name__}(\n" + f" Repository IDs: '{self.repo_ids}',\n" + f" Version: '{self.version}',\n" + f" Split: '{self.split}',\n" + f" Number of Samples: {self.num_samples},\n" + f" Number of Episodes: {self.num_episodes},\n" + f" Type: {'video (.mp4)' if self.video else 'image (.png)'},\n" + f" Recorded Frames per Second: {self.fps},\n" + f" Camera Keys: {self.camera_keys},\n" + f" Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n" + f" Transformations: {self.transform},\n" + f")" + ) diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py index 86fef8d4..cb2fee95 100644 --- a/lerobot/common/datasets/utils.py +++ b/lerobot/common/datasets/utils.py @@ -59,7 +59,7 @@ def unflatten_dict(d, sep="/"): return outdict -def hf_transform_to_torch(items_dict): +def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]): """Get a transform function that convert items from Hugging Face dataset (pyarrow) to torch tensors. Importantly, images are converted from PIL, which corresponds to a channel last representation (h w c) of uint8 type, to a torch image representation @@ -73,6 +73,8 @@ def hf_transform_to_torch(items_dict): elif isinstance(first_item, dict) and "path" in first_item and "timestamp" in first_item: # video frame will be processed downstream pass + elif first_item is None: + pass else: items_dict[key] = [torch.tensor(x) for x in items_dict[key]] return items_dict @@ -318,8 +320,7 @@ def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torc def reset_episode_index(hf_dataset: datasets.Dataset) -> datasets.Dataset: - """ - Reset the `episode_index` of the provided HuggingFace Dataset. + """Reset the `episode_index` of the provided HuggingFace Dataset. `episode_data_index` (and related functionality such as `load_previous_and_future_frames`) requires the `episode_index` to be sorted, continuous (1,1,1 and not 1,2,1) and start at 0. @@ -338,6 +339,7 @@ def reset_episode_index(hf_dataset: datasets.Dataset) -> datasets.Dataset: return example hf_dataset = hf_dataset.map(modify_ep_idx_func) + return hf_dataset diff --git a/lerobot/configs/default.yaml b/lerobot/configs/default.yaml index f2238769..85b9ceea 100644 --- a/lerobot/configs/default.yaml +++ b/lerobot/configs/default.yaml @@ -23,6 +23,10 @@ use_amp: false # `seed` is used for training (eg: model initialization, dataset shuffling) # AND for the evaluation environments. seed: ??? +# You may provide a list of datasets here. `train.py` creates them all and concatenates them. Note: only data +# keys common between the datasets are kept. 
Each dataset gets an additional transform that inserts the
+# "dataset_index" into the returned item. The index mapping is made according to the order in which the
+# datasets are provided.
 dataset_repo_id: lerobot/pusht
 
 training:
diff --git a/lerobot/scripts/push_dataset_to_hub.py b/lerobot/scripts/push_dataset_to_hub.py
index c6eac5e9..52252b57 100644
--- a/lerobot/scripts/push_dataset_to_hub.py
+++ b/lerobot/scripts/push_dataset_to_hub.py
@@ -71,9 +71,9 @@
 import torch
 from huggingface_hub import HfApi
 from safetensors.torch import save_file
 
+from lerobot.common.datasets.compute_stats import compute_stats
 from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
 from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw
-from lerobot.common.datasets.push_dataset_to_hub.compute_stats import compute_stats
 from lerobot.common.datasets.utils import flatten_dict
 
 
diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index eb33b268..08ad6e66 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -16,7 +16,6 @@
 import logging
 import time
 from contextlib import nullcontext
-from copy import deepcopy
 from pathlib import Path
 from pprint import pformat
 
@@ -28,6 +27,7 @@
 from termcolor import colored
 from torch.cuda.amp import GradScaler
 
 from lerobot.common.datasets.factory import make_dataset, resolve_delta_timestamps
+from lerobot.common.datasets.lerobot_dataset import MultiLeRobotDataset
 from lerobot.common.datasets.utils import cycle
 from lerobot.common.envs.factory import make_env
 from lerobot.common.logger import Logger, log_output_dir
@@ -280,6 +280,11 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
 
     logging.info("make_dataset")
     offline_dataset = make_dataset(cfg)
+    if isinstance(offline_dataset, MultiLeRobotDataset):
+        logging.info(
+            "Multiple datasets were provided. Applied the following index mapping to the provided datasets: "
+            f"{pformat(offline_dataset.repo_id_to_index, indent=2)}"
+        )
 
     # Create environment used for evaluating checkpoints during training on simulation data.
     # On real-world data, no need to create an environment as evaluations are done outside train.py,
     # using the eval.py instead, with gym_dora environment and dora-rs.
@@ -330,7 +335,7 @@
                 max_episodes_rendered=4,
                 start_seed=cfg.seed,
             )
-            log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_offline)
+            log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_offline=True)
             if cfg.wandb.enable:
                 logger.log_video(eval_info["video_paths"][0], step, mode="eval")
             logging.info("Resume training")
@@ -362,7 +367,6 @@
     dl_iter = cycle(dataloader)
 
     policy.train()
-    is_offline = True
     for _ in range(step, cfg.training.offline_steps):
         if step == 0:
             logging.info("Start offline training on a fixed dataset")
@@ -382,7 +386,7 @@
         )
 
         if step % cfg.training.log_freq == 0:
-            log_train_info(logger, train_info, step, cfg, offline_dataset, is_offline)
+            log_train_info(logger, train_info, step, cfg, offline_dataset, is_offline=True)
 
         # Note: evaluate_and_checkpoint_if_needed happens **after** the `step`th training update has completed,
        # so we pass in step + 1.
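With this patch, `dataset_repo_id` may also be a list of repository ids, in which case `make_dataset` returns a `MultiLeRobotDataset` and the index mapping above is logged at startup. Assuming Hydra's usual list-override syntax, a two-dataset run (the pair used in the new test) could be launched as:

    python lerobot/scripts/train.py \
      policy=act \
      env=aloha \
      dataset_repo_id='[lerobot/aloha_sim_insertion_human,lerobot/aloha_sim_transfer_cube_human]'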
@@ -390,41 +394,8 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No step += 1 - logging.info("End of offline training") - - if cfg.training.online_steps == 0: - if cfg.training.eval_freq > 0: - eval_env.close() - return - - # create an env dedicated to online episodes collection from policy rollout - online_training_env = make_env(cfg, n_envs=1) - - # create an empty online dataset similar to offline dataset - online_dataset = deepcopy(offline_dataset) - online_dataset.hf_dataset = {} - online_dataset.episode_data_index = {} - - # create dataloader for online training - concat_dataset = torch.utils.data.ConcatDataset([offline_dataset, online_dataset]) - weights = [1.0] * len(concat_dataset) - sampler = torch.utils.data.WeightedRandomSampler( - weights, num_samples=len(concat_dataset), replacement=True - ) - dataloader = torch.utils.data.DataLoader( - concat_dataset, - num_workers=4, - batch_size=cfg.training.batch_size, - sampler=sampler, - pin_memory=device.type != "cpu", - drop_last=False, - ) - - logging.info("End of online training") - - if cfg.training.eval_freq > 0: - eval_env.close() - online_training_env.close() + eval_env.close() + logging.info("End of training") @hydra.main(version_base="1.2", config_name="default", config_path="../configs") diff --git a/tests/test_datasets.py b/tests/test_datasets.py index afea16a5..e01fc52c 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -25,26 +25,34 @@ from datasets import Dataset from safetensors.torch import load_file import lerobot -from lerobot.common.datasets.factory import make_dataset -from lerobot.common.datasets.lerobot_dataset import ( - LeRobotDataset, -) -from lerobot.common.datasets.push_dataset_to_hub.compute_stats import ( +from lerobot.common.datasets.compute_stats import ( + aggregate_stats, compute_stats, get_stats_einops_patterns, ) +from lerobot.common.datasets.factory import make_dataset +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset from lerobot.common.datasets.utils import ( flatten_dict, hf_transform_to_torch, load_previous_and_future_frames, unflatten_dict, ) -from lerobot.common.utils.utils import init_hydra_config +from lerobot.common.utils.utils import init_hydra_config, seeded_context from tests.utils import DEFAULT_CONFIG_PATH, DEVICE -@pytest.mark.parametrize("env_name, repo_id, policy_name", lerobot.env_dataset_policy_triplets) +@pytest.mark.parametrize( + "env_name, repo_id, policy_name", + lerobot.env_dataset_policy_triplets + + [("aloha", ["lerobot/aloha_sim_insertion_human", "lerobot/aloha_sim_transfer_cube_human"], "act")], +) def test_factory(env_name, repo_id, policy_name): + """ + Tests that: + - we can create a dataset with the factory. + - for a commonly used set of data keys, the data dimensions are correct. 
+ """ cfg = init_hydra_config( DEFAULT_CONFIG_PATH, overrides=[ @@ -315,3 +323,31 @@ def test_backward_compatibility(repo_id): # i = dataset.episode_data_index["to"][-1].item() # load_and_compare(i - 2) # load_and_compare(i - 1) + + +def test_aggregate_stats(): + """Makes 3 basic datasets and checks that aggregate stats are computed correctly.""" + with seeded_context(0): + data_a = torch.rand(30, dtype=torch.float32) + data_b = torch.rand(20, dtype=torch.float32) + data_c = torch.rand(20, dtype=torch.float32) + + hf_dataset_1 = Dataset.from_dict( + {"a": data_a[:10], "b": data_b[:10], "c": data_c[:10], "index": torch.arange(10)} + ) + hf_dataset_1.set_transform(hf_transform_to_torch) + hf_dataset_2 = Dataset.from_dict({"a": data_a[10:20], "b": data_b[10:], "index": torch.arange(10)}) + hf_dataset_2.set_transform(hf_transform_to_torch) + hf_dataset_3 = Dataset.from_dict({"a": data_a[20:], "c": data_c[10:], "index": torch.arange(10)}) + hf_dataset_3.set_transform(hf_transform_to_torch) + dataset_1 = LeRobotDataset.from_preloaded("d1", hf_dataset=hf_dataset_1) + dataset_1.stats = compute_stats(dataset_1, batch_size=len(hf_dataset_1), num_workers=0) + dataset_2 = LeRobotDataset.from_preloaded("d2", hf_dataset=hf_dataset_2) + dataset_2.stats = compute_stats(dataset_2, batch_size=len(hf_dataset_2), num_workers=0) + dataset_3 = LeRobotDataset.from_preloaded("d3", hf_dataset=hf_dataset_3) + dataset_3.stats = compute_stats(dataset_3, batch_size=len(hf_dataset_3), num_workers=0) + stats = aggregate_stats([dataset_1, dataset_2, dataset_3]) + for data_key, data in zip(["a", "b", "c"], [data_a, data_b, data_c], strict=True): + for agg_fn in ["mean", "min", "max"]: + assert torch.allclose(stats[data_key][agg_fn], einops.reduce(data, "n -> 1", agg_fn)) + assert torch.allclose(stats[data_key]["std"], torch.std(data, correction=0)) From 0b51a335bc8256f9830019ed3b74b2762daa75b7 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Thu, 30 May 2024 17:46:25 +0100 Subject: [PATCH 5/6] Add a test for MultiLeRobotDataset making sure it produces all frames. (#230) Co-authored-by: Remi --- tests/test_datasets.py | 29 ++++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index e01fc52c..dac18c14 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -16,6 +16,7 @@ import json import logging from copy import deepcopy +from itertools import chain from pathlib import Path import einops @@ -31,7 +32,7 @@ from lerobot.common.datasets.compute_stats import ( get_stats_einops_patterns, ) from lerobot.common.datasets.factory import make_dataset -from lerobot.common.datasets.lerobot_dataset import LeRobotDataset +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset from lerobot.common.datasets.utils import ( flatten_dict, hf_transform_to_torch, @@ -113,6 +114,32 @@ def test_factory(env_name, repo_id, policy_name): assert key in item, f"{key}" +def test_multilerobotdataset_frames(): + """Check that all dataset frames are incorporated.""" + # Note: use the image variants of the dataset to make the test approx 3x faster. 
+ repo_ids = ["lerobot/aloha_sim_insertion_human_image", "lerobot/aloha_sim_transfer_cube_human_image"] + sub_datasets = [LeRobotDataset(repo_id) for repo_id in repo_ids] + dataset = MultiLeRobotDataset(repo_ids) + assert len(dataset) == sum(len(d) for d in sub_datasets) + assert dataset.num_samples == sum(d.num_samples for d in sub_datasets) + assert dataset.num_episodes == sum(d.num_episodes for d in sub_datasets) + + # Run through all items of the LeRobotDatasets in parallel with the items of the MultiLerobotDataset and + # check they match. + expected_dataset_indices = [] + for i, sub_dataset in enumerate(sub_datasets): + expected_dataset_indices.extend([i] * len(sub_dataset)) + + for expected_dataset_index, sub_dataset_item, dataset_item in zip( + expected_dataset_indices, chain(*sub_datasets), dataset, strict=True + ): + dataset_index = dataset_item.pop("dataset_index") + assert dataset_index == expected_dataset_index + assert sub_dataset_item.keys() == dataset_item.keys() + for k in sub_dataset_item: + assert torch.equal(sub_dataset_item[k], dataset_item[k]) + + def test_compute_stats_on_xarm(): """Check that the statistics are computed correctly according to the stats_patterns property. From 57fb5fe8a68d231c8a8c323ac7f845daa11e2079 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Thu, 30 May 2024 18:16:44 +0100 Subject: [PATCH 6/6] Improve documentation on VAE encoder inputs (#215) --- lerobot/common/policies/act/modeling_act.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py index 72ebdd7a..eafe677b 100644 --- a/lerobot/common/policies/act/modeling_act.py +++ b/lerobot/common/policies/act/modeling_act.py @@ -198,7 +198,7 @@ class ACT(nn.Module): def __init__(self, config: ACTConfig): super().__init__() self.config = config - # BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence]. + # BERT style VAE encoder with input tokens [cls, robot_state, *action_sequence]. # The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]). if self.config.use_vae: self.vae_encoder = ACTEncoder(config) @@ -214,7 +214,7 @@ class ACT(nn.Module): self.latent_dim = config.latent_dim # Projection layer from the VAE encoder's output to the latent distribution's parameter space. self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, self.latent_dim * 2) - # Fixed sinusoidal positional embedding the whole input to the VAE encoder. Unsqueeze for batch + # Fixed sinusoidal positional embedding for the input to the VAE encoder. Unsqueeze for batch # dimension. self.register_buffer( "vae_encoder_pos_enc",
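Taken together, patches 4 and 5 enable the following usage. A minimal sketch, assuming the two image-variant repositories used in the tests are reachable on the hub:

    from lerobot.common.datasets.lerobot_dataset import MultiLeRobotDataset

    # Concatenate two compatible datasets; data keys that are not common to all
    # sub-datasets are disabled and dropped from the returned items.
    dataset = MultiLeRobotDataset(
        ["lerobot/aloha_sim_insertion_human_image", "lerobot/aloha_sim_transfer_cube_human_image"]
    )
    print(dataset.repo_id_to_index)  # {'lerobot/aloha_sim_insertion_human_image': 0, ...}
    print(dataset.num_samples, dataset.num_episodes, dataset.fps)

    item = dataset[0]
    # Every item carries the index of the sub-dataset it came from.
    print(item["dataset_index"])  # tensor(0)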