add EpisodeAwareSampler

EpisodeAwareSampler allows the dropping of first n or last n frames
from each episode.

Additionally, you can select which episodes to run on by passig in
their ids (complements functionality of the hf dataset slice syntax
to give extra flexibility to the user).

* modify train.py and diffusion.yaml to leverage new functionality
  and train on pusht with last 7 steps per episode dropped
* add unit tests
This commit is contained in:
Radek Osmulski 2024-05-28 18:12:57 +10:00
parent 13310681b1
commit e7abcc2ffd
4 changed files with 120 additions and 1 deletions

View File

@ -0,0 +1,47 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union
from torch.utils.data import SubsetRandomSampler
class EpisodeAwareSampler(SubsetRandomSampler):
def __init__(
self,
episode_data_index: dict,
episode_indices_to_use: Union[list, None] = None,
drop_n_first_frames: int = 0,
drop_n_last_frames: int = 0,
):
"""Sampler that optionally incorporates episode boundary information.
Args:
episode_data_index (dict): Dictionary with keys 'from' and 'to' containing the start and end indices of each episode.
episode_indices_to_use (list, optional): List of episode indices to use. If None, all episodes are used. Defaults to None.
Assumes that episodes are indexed from 0 to N-1.
drop_n_first_frames (int, optional): Number of frames to drop from the start of each episode. Defaults to 0.
drop_n_last_frames (int, optional): Number of frames to drop from the end of each episode. Defaults to 0.
"""
indices = []
for episode_idx, (start_index, end_index) in enumerate(
zip(episode_data_index["from"], episode_data_index["to"], strict=False)
):
if episode_indices_to_use is None or episode_idx in episode_indices_to_use:
indices.extend(
range(start_index.item() + drop_n_first_frames, end_index.item() - drop_n_last_frames)
)
super().__init__(indices)

View File

@ -44,6 +44,10 @@ training:
observation.state: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
action: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1 - ${policy.n_obs_steps} + ${policy.horizon})]"
# The original implementation doesn't sample frames for the last 7 steps,
# which avoids excessive padding and leads to improved training results.
drop_n_last_frames: ${policy.horizon} - ${policy.n_action_steps} - ${policy.n_obs_steps} + 1
eval:
n_episodes: 50
batch_size: 50

View File

@ -28,6 +28,7 @@ from termcolor import colored
from torch.cuda.amp import GradScaler
from lerobot.common.datasets.factory import make_dataset, resolve_delta_timestamps
from lerobot.common.datasets.sampler import EpisodeAwareSampler
from lerobot.common.datasets.utils import cycle
from lerobot.common.envs.factory import make_env
from lerobot.common.logger import Logger, log_output_dir
@ -351,7 +352,10 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
offline_dataset,
num_workers=4,
batch_size=cfg.training.batch_size,
shuffle=True,
sampler=EpisodeAwareSampler(
offline_dataset.episode_data_index,
drop_n_last_frames=cfg.drop_n_last_frames if hasattr(cfg, "drop_n_last_frames") else 0,
),
pin_memory=device.type != "cpu",
drop_last=False,
)

64
tests/test_sampler.py Normal file
View File

@ -0,0 +1,64 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from datasets import Dataset
from lerobot.common.datasets.sampler import EpisodeAwareSampler
from lerobot.common.datasets.utils import (
calculate_episode_data_index,
hf_transform_to_torch,
)
def test_drop_n_first_frames():
dataset = Dataset.from_dict(
{
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
"index": [0, 1, 2, 3, 4, 5],
"episode_index": [0, 0, 1, 2, 2, 2],
},
)
dataset.set_transform(hf_transform_to_torch)
episode_data_index = calculate_episode_data_index(dataset)
sampler = EpisodeAwareSampler(episode_data_index, drop_n_first_frames=1)
assert sampler.indices == [1, 4, 5]
def test_drop_n_last_frames():
dataset = Dataset.from_dict(
{
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
"index": [0, 1, 2, 3, 4, 5],
"episode_index": [0, 0, 1, 2, 2, 2],
},
)
dataset.set_transform(hf_transform_to_torch)
episode_data_index = calculate_episode_data_index(dataset)
sampler = EpisodeAwareSampler(episode_data_index, drop_n_last_frames=1)
assert sampler.indices == [0, 3, 4]
def test_episode_indices_to_use():
dataset = Dataset.from_dict(
{
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
"index": [0, 1, 2, 3, 4, 5],
"episode_index": [0, 0, 1, 2, 2, 2],
},
)
dataset.set_transform(hf_transform_to_torch)
episode_data_index = calculate_episode_data_index(dataset)
sampler = EpisodeAwareSampler(episode_data_index, episode_indices_to_use=[0, 2])
assert sampler.indices == [0, 1, 3, 4, 5]