diff --git a/lerobot/common/datasets/sampler.py b/lerobot/common/datasets/sampler.py new file mode 100644 index 00000000..b13691d2 --- /dev/null +++ b/lerobot/common/datasets/sampler.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Union + +from torch.utils.data import SubsetRandomSampler + + +class EpisodeAwareSampler(SubsetRandomSampler): + def __init__( + self, + episode_data_index: dict, + episode_indices_to_use: Union[list, None] = None, + drop_n_first_frames: int = 0, + drop_n_last_frames: int = 0, + ): + """Sampler that optionally incorporates episode boundary information. + + Args: + episode_data_index (dict): Dictionary with keys 'from' and 'to' containing the start and end indices of each episode. + episode_indices_to_use (list, optional): List of episode indices to use. If None, all episodes are used. Defaults to None. + Assumes that episodes are indexed from 0 to N-1. + drop_n_first_frames (int, optional): Number of frames to drop from the start of each episode. Defaults to 0. + drop_n_last_frames (int, optional): Number of frames to drop from the end of each episode. Defaults to 0. 
+ """ + indices = [] + for episode_idx, (start_index, end_index) in enumerate( + zip(episode_data_index["from"], episode_data_index["to"], strict=False) + ): + if episode_indices_to_use is None or episode_idx in episode_indices_to_use: + indices.extend( + range(start_index.item() + drop_n_first_frames, end_index.item() - drop_n_last_frames) + ) + + super().__init__(indices) diff --git a/lerobot/configs/policy/diffusion.yaml b/lerobot/configs/policy/diffusion.yaml index 36bd22cc..2eb7a162 100644 --- a/lerobot/configs/policy/diffusion.yaml +++ b/lerobot/configs/policy/diffusion.yaml @@ -44,6 +44,10 @@ training: observation.state: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]" action: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1 - ${policy.n_obs_steps} + ${policy.horizon})]" + # The original implementation doesn't sample frames for the last 7 steps, + # which avoids excessive padding and leads to improved training results. + drop_n_last_frames: ${policy.horizon} - ${policy.n_action_steps} - ${policy.n_obs_steps} + 1 + eval: n_episodes: 50 batch_size: 50 diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 5fb86f36..1c575888 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -28,6 +28,7 @@ from termcolor import colored from torch.cuda.amp import GradScaler from lerobot.common.datasets.factory import make_dataset, resolve_delta_timestamps +from lerobot.common.datasets.sampler import EpisodeAwareSampler from lerobot.common.datasets.utils import cycle from lerobot.common.envs.factory import make_env from lerobot.common.logger import Logger, log_output_dir @@ -351,7 +352,10 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No offline_dataset, num_workers=4, batch_size=cfg.training.batch_size, - shuffle=True, + sampler=EpisodeAwareSampler( + offline_dataset.episode_data_index, + drop_n_last_frames=cfg.drop_n_last_frames if hasattr(cfg, "drop_n_last_frames") else 0, + ), 
def _build_episode_data_index():
    """Return the episode data index of a 6-frame toy dataset.

    The dataset holds three episodes: episode 0 covers frames 0-1, episode 1
    covers frame 2 and episode 2 covers frames 3-5.
    """
    toy_dataset = Dataset.from_dict(
        {
            "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
            "index": [0, 1, 2, 3, 4, 5],
            "episode_index": [0, 0, 1, 2, 2, 2],
        },
    )
    toy_dataset.set_transform(hf_transform_to_torch)
    return calculate_episode_data_index(toy_dataset)


def test_drop_n_first_frames():
    """The first frame of every episode is excluded; the one-frame episode vanishes."""
    sampler = EpisodeAwareSampler(_build_episode_data_index(), drop_n_first_frames=1)
    expected = [1, 4, 5]
    assert sampler.indices == expected


def test_drop_n_last_frames():
    """The last frame of every episode is excluded; the one-frame episode vanishes."""
    sampler = EpisodeAwareSampler(_build_episode_data_index(), drop_n_last_frames=1)
    expected = [0, 3, 4]
    assert sampler.indices == expected


def test_episode_indices_to_use():
    """Only frames that belong to the requested episodes are kept."""
    sampler = EpisodeAwareSampler(_build_episode_data_index(), episode_indices_to_use=[0, 2])
    expected = [0, 1, 3, 4, 5]
    assert sampler.indices == expected