From 633115d861f05ed0dbc71ac70790fd2b6610a527 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Fri, 31 May 2024 09:03:28 +0100 Subject: [PATCH 1/3] Fix chaining in MultiLerobotDataset (#233) --- lerobot/common/datasets/lerobot_dataset.py | 1 + tests/test_datasets.py | 9 ++++++++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index a87c3ee8..58ae51b1 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -371,6 +371,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset): if idx >= start_idx + dataset.num_samples: start_idx += dataset.num_samples dataset_idx += 1 + continue break else: raise AssertionError("We expect the loop to break out as long as the index is within bounds.") diff --git a/tests/test_datasets.py b/tests/test_datasets.py index dac18c14..da0ae755 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -114,10 +114,17 @@ def test_factory(env_name, repo_id, policy_name): assert key in item, f"{key}" +# TODO(alexander-soare): If you're hunting for savings on testing time, this takes about 5 seconds. def test_multilerobotdataset_frames(): """Check that all dataset frames are incorporated.""" # Note: use the image variants of the dataset to make the test approx 3x faster. - repo_ids = ["lerobot/aloha_sim_insertion_human_image", "lerobot/aloha_sim_transfer_cube_human_image"] + # Note: We really do need three repo_ids here as at some point this caught an issue with the chaining + # logic that wouldn't be caught with two repo IDs. + repo_ids = [ + "lerobot/aloha_sim_insertion_human_image", + "lerobot/aloha_sim_transfer_cube_human_image", + "lerobot/aloha_sim_insertion_scripted_image", + ] sub_datasets = [LeRobotDataset(repo_id) for repo_id in repo_ids] dataset = MultiLeRobotDataset(repo_ids) assert len(dataset) == sum(len(d) for d in sub_datasets) From 83f4f7f7e83d8f0115463c7dd1b8e0b0da863dc2 Mon Sep 17 00:00:00 2001 From: Radek Osmulski Date: Fri, 31 May 2024 18:19:01 +1000 Subject: [PATCH 2/3] Add precision param to format_big_number (#232) --- lerobot/common/utils/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lerobot/common/utils/utils.py b/lerobot/common/utils/utils.py index 696999ad..c429efbd 100644 --- a/lerobot/common/utils/utils.py +++ b/lerobot/common/utils/utils.py @@ -120,13 +120,13 @@ def init_logging(): logging.getLogger().addHandler(console_handler) -def format_big_number(num): +def format_big_number(num, precision=0): suffixes = ["", "K", "M", "B", "T", "Q"] divisor = 1000.0 for suffix in suffixes: if abs(num) < divisor: - return f"{num:.0f}{suffix}" + return f"{num:.{precision}f}{suffix}" num /= divisor return num From 504d2aaf485b3819b2d39c8faa2953b1a49f2aed Mon Sep 17 00:00:00 2001 From: Radek Osmulski Date: Fri, 31 May 2024 22:43:47 +1000 Subject: [PATCH 3/3] add EpisodeAwareSampler (#217) Co-authored-by: Alexander Soare --- lerobot/common/datasets/sampler.py | 61 ++++++++++++++++++ lerobot/configs/policy/diffusion.yaml | 4 ++ lerobot/scripts/train.py | 14 ++++- tests/test_sampler.py | 90 +++++++++++++++++++++++++++ 4 files changed, 168 insertions(+), 1 deletion(-) create mode 100644 lerobot/common/datasets/sampler.py create mode 100644 tests/test_sampler.py diff --git a/lerobot/common/datasets/sampler.py b/lerobot/common/datasets/sampler.py new file mode 100644 index 00000000..2f6c15c1 --- /dev/null +++ b/lerobot/common/datasets/sampler.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Iterator, Union + +import torch + + +class EpisodeAwareSampler: + def __init__( + self, + episode_data_index: dict, + episode_indices_to_use: Union[list, None] = None, + drop_n_first_frames: int = 0, + drop_n_last_frames: int = 0, + shuffle: bool = False, + ): + """Sampler that optionally incorporates episode boundary information. + + Args: + episode_data_index: Dictionary with keys 'from' and 'to' containing the start and end indices of each episode. + episode_indices_to_use: List of episode indices to use. If None, all episodes are used. + Assumes that episodes are indexed from 0 to N-1. + drop_n_first_frames: Number of frames to drop from the start of each episode. + drop_n_last_frames: Number of frames to drop from the end of each episode. + shuffle: Whether to shuffle the indices. + """ + indices = [] + for episode_idx, (start_index, end_index) in enumerate( + zip(episode_data_index["from"], episode_data_index["to"], strict=True) + ): + if episode_indices_to_use is None or episode_idx in episode_indices_to_use: + indices.extend( + range(start_index.item() + drop_n_first_frames, end_index.item() - drop_n_last_frames) + ) + + self.indices = indices + self.shuffle = shuffle + + def __iter__(self) -> Iterator[int]: + if self.shuffle: + for i in torch.randperm(len(self.indices)): + yield self.indices[i] + else: + for i in self.indices: + yield i + + def __len__(self) -> int: + return len(self.indices) diff --git a/lerobot/configs/policy/diffusion.yaml b/lerobot/configs/policy/diffusion.yaml index 36bd22cc..b04ecf1b 100644 --- a/lerobot/configs/policy/diffusion.yaml +++ b/lerobot/configs/policy/diffusion.yaml @@ -44,6 +44,10 @@ training: observation.state: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]" action: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1 - ${policy.n_obs_steps} + ${policy.horizon})]" + # The original implementation doesn't sample frames for the last 7 steps, + # which avoids excessive padding and leads to improved training results. + drop_n_last_frames: 7 # ${policy.horizon} - ${policy.n_action_steps} - ${policy.n_obs_steps} + 1 + eval: n_episodes: 50 batch_size: 50 diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 08ad6e66..860412bd 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -28,6 +28,7 @@ from torch.cuda.amp import GradScaler from lerobot.common.datasets.factory import make_dataset, resolve_delta_timestamps from lerobot.common.datasets.lerobot_dataset import MultiLeRobotDataset +from lerobot.common.datasets.sampler import EpisodeAwareSampler from lerobot.common.datasets.utils import cycle from lerobot.common.envs.factory import make_env from lerobot.common.logger import Logger, log_output_dir @@ -356,11 +357,22 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No logging.info("Resume training") # create dataloader for offline training + if cfg.training.get("drop_n_last_frames"): + shuffle = False + sampler = EpisodeAwareSampler( + offline_dataset.episode_data_index, + drop_n_last_frames=cfg.training.drop_n_last_frames, + shuffle=True, + ) + else: + shuffle = True + sampler = None dataloader = torch.utils.data.DataLoader( offline_dataset, num_workers=cfg.training.num_workers, batch_size=cfg.training.batch_size, - shuffle=True, + shuffle=shuffle, + sampler=sampler, pin_memory=device.type != "cpu", drop_last=False, ) diff --git a/tests/test_sampler.py b/tests/test_sampler.py new file mode 100644 index 00000000..635e7f11 --- /dev/null +++ b/tests/test_sampler.py @@ -0,0 +1,90 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from datasets import Dataset + +from lerobot.common.datasets.sampler import EpisodeAwareSampler +from lerobot.common.datasets.utils import ( + calculate_episode_data_index, + hf_transform_to_torch, +) + + +def test_drop_n_first_frames(): + dataset = Dataset.from_dict( + { + "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + "index": [0, 1, 2, 3, 4, 5], + "episode_index": [0, 0, 1, 2, 2, 2], + }, + ) + dataset.set_transform(hf_transform_to_torch) + episode_data_index = calculate_episode_data_index(dataset) + sampler = EpisodeAwareSampler(episode_data_index, drop_n_first_frames=1) + assert sampler.indices == [1, 4, 5] + assert len(sampler) == 3 + assert list(sampler) == [1, 4, 5] + + +def test_drop_n_last_frames(): + dataset = Dataset.from_dict( + { + "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + "index": [0, 1, 2, 3, 4, 5], + "episode_index": [0, 0, 1, 2, 2, 2], + }, + ) + dataset.set_transform(hf_transform_to_torch) + episode_data_index = calculate_episode_data_index(dataset) + sampler = EpisodeAwareSampler(episode_data_index, drop_n_last_frames=1) + assert sampler.indices == [0, 3, 4] + assert len(sampler) == 3 + assert list(sampler) == [0, 3, 4] + + +def test_episode_indices_to_use(): + dataset = Dataset.from_dict( + { + "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + "index": [0, 1, 2, 3, 4, 5], + "episode_index": [0, 0, 1, 2, 2, 2], + }, + ) + dataset.set_transform(hf_transform_to_torch) + episode_data_index = calculate_episode_data_index(dataset) + sampler = EpisodeAwareSampler(episode_data_index, episode_indices_to_use=[0, 2]) + assert sampler.indices == [0, 1, 3, 4, 5] + assert len(sampler) == 5 + assert list(sampler) == [0, 1, 3, 4, 5] + + +def test_shuffle(): + dataset = Dataset.from_dict( + { + "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + "index": [0, 1, 2, 3, 4, 5], + "episode_index": [0, 0, 1, 2, 2, 2], + }, + ) + dataset.set_transform(hf_transform_to_torch) + episode_data_index = calculate_episode_data_index(dataset) + sampler = EpisodeAwareSampler(episode_data_index, shuffle=False) + assert sampler.indices == [0, 1, 2, 3, 4, 5] + assert len(sampler) == 6 + assert list(sampler) == [0, 1, 2, 3, 4, 5] + sampler = EpisodeAwareSampler(episode_data_index, shuffle=True) + assert sampler.indices == [0, 1, 2, 3, 4, 5] + assert len(sampler) == 6 + assert set(sampler) == {0, 1, 2, 3, 4, 5}