add EpisodeAwareSampler (#217)
Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
This commit is contained in:
parent
83f4f7f7e8
commit
504d2aaf48
|
@ -0,0 +1,61 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from typing import Iterator, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class EpisodeAwareSampler:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
episode_data_index: dict,
|
||||||
|
episode_indices_to_use: Union[list, None] = None,
|
||||||
|
drop_n_first_frames: int = 0,
|
||||||
|
drop_n_last_frames: int = 0,
|
||||||
|
shuffle: bool = False,
|
||||||
|
):
|
||||||
|
"""Sampler that optionally incorporates episode boundary information.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
episode_data_index: Dictionary with keys 'from' and 'to' containing the start and end indices of each episode.
|
||||||
|
episode_indices_to_use: List of episode indices to use. If None, all episodes are used.
|
||||||
|
Assumes that episodes are indexed from 0 to N-1.
|
||||||
|
drop_n_first_frames: Number of frames to drop from the start of each episode.
|
||||||
|
drop_n_last_frames: Number of frames to drop from the end of each episode.
|
||||||
|
shuffle: Whether to shuffle the indices.
|
||||||
|
"""
|
||||||
|
indices = []
|
||||||
|
for episode_idx, (start_index, end_index) in enumerate(
|
||||||
|
zip(episode_data_index["from"], episode_data_index["to"], strict=True)
|
||||||
|
):
|
||||||
|
if episode_indices_to_use is None or episode_idx in episode_indices_to_use:
|
||||||
|
indices.extend(
|
||||||
|
range(start_index.item() + drop_n_first_frames, end_index.item() - drop_n_last_frames)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.indices = indices
|
||||||
|
self.shuffle = shuffle
|
||||||
|
|
||||||
|
def __iter__(self) -> Iterator[int]:
|
||||||
|
if self.shuffle:
|
||||||
|
for i in torch.randperm(len(self.indices)):
|
||||||
|
yield self.indices[i]
|
||||||
|
else:
|
||||||
|
for i in self.indices:
|
||||||
|
yield i
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return len(self.indices)
|
|
@ -44,6 +44,10 @@ training:
|
||||||
observation.state: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
|
observation.state: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
|
||||||
action: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1 - ${policy.n_obs_steps} + ${policy.horizon})]"
|
action: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1 - ${policy.n_obs_steps} + ${policy.horizon})]"
|
||||||
|
|
||||||
|
# The original implementation doesn't sample frames for the last 7 steps,
|
||||||
|
# which avoids excessive padding and leads to improved training results.
|
||||||
|
drop_n_last_frames: 7 # ${policy.horizon} - ${policy.n_action_steps} - ${policy.n_obs_steps} + 1
|
||||||
|
|
||||||
eval:
|
eval:
|
||||||
n_episodes: 50
|
n_episodes: 50
|
||||||
batch_size: 50
|
batch_size: 50
|
||||||
|
|
|
@ -28,6 +28,7 @@ from torch.cuda.amp import GradScaler
|
||||||
|
|
||||||
from lerobot.common.datasets.factory import make_dataset, resolve_delta_timestamps
|
from lerobot.common.datasets.factory import make_dataset, resolve_delta_timestamps
|
||||||
from lerobot.common.datasets.lerobot_dataset import MultiLeRobotDataset
|
from lerobot.common.datasets.lerobot_dataset import MultiLeRobotDataset
|
||||||
|
from lerobot.common.datasets.sampler import EpisodeAwareSampler
|
||||||
from lerobot.common.datasets.utils import cycle
|
from lerobot.common.datasets.utils import cycle
|
||||||
from lerobot.common.envs.factory import make_env
|
from lerobot.common.envs.factory import make_env
|
||||||
from lerobot.common.logger import Logger, log_output_dir
|
from lerobot.common.logger import Logger, log_output_dir
|
||||||
|
@ -356,11 +357,22 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||||
logging.info("Resume training")
|
logging.info("Resume training")
|
||||||
|
|
||||||
# create dataloader for offline training
|
# create dataloader for offline training
|
||||||
|
if cfg.training.get("drop_n_last_frames"):
|
||||||
|
shuffle = False
|
||||||
|
sampler = EpisodeAwareSampler(
|
||||||
|
offline_dataset.episode_data_index,
|
||||||
|
drop_n_last_frames=cfg.training.drop_n_last_frames,
|
||||||
|
shuffle=True,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
shuffle = True
|
||||||
|
sampler = None
|
||||||
dataloader = torch.utils.data.DataLoader(
|
dataloader = torch.utils.data.DataLoader(
|
||||||
offline_dataset,
|
offline_dataset,
|
||||||
num_workers=cfg.training.num_workers,
|
num_workers=cfg.training.num_workers,
|
||||||
batch_size=cfg.training.batch_size,
|
batch_size=cfg.training.batch_size,
|
||||||
shuffle=True,
|
shuffle=shuffle,
|
||||||
|
sampler=sampler,
|
||||||
pin_memory=device.type != "cpu",
|
pin_memory=device.type != "cpu",
|
||||||
drop_last=False,
|
drop_last=False,
|
||||||
)
|
)
|
||||||
|
|
|
@ -0,0 +1,90 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from datasets import Dataset
|
||||||
|
|
||||||
|
from lerobot.common.datasets.sampler import EpisodeAwareSampler
|
||||||
|
from lerobot.common.datasets.utils import (
|
||||||
|
calculate_episode_data_index,
|
||||||
|
hf_transform_to_torch,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_drop_n_first_frames():
    """Dropping one leading frame per episode keeps only the later frames of each episode."""
    # Six frames grouped by `episode_index` into episodes of 2, 1 and 3 frames.
    frames = {
        "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
        "index": [0, 1, 2, 3, 4, 5],
        "episode_index": [0, 0, 1, 2, 2, 2],
    }
    hf_dataset = Dataset.from_dict(frames)
    hf_dataset.set_transform(hf_transform_to_torch)
    boundaries = calculate_episode_data_index(hf_dataset)

    sampler = EpisodeAwareSampler(boundaries, drop_n_first_frames=1)

    # The single-frame episode contributes nothing once its only frame is dropped.
    expected = [1, 4, 5]
    assert sampler.indices == expected
    assert len(sampler) == 3
    assert list(sampler) == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_drop_n_last_frames():
    """Dropping one trailing frame per episode keeps only the earlier frames of each episode."""
    # Six frames grouped by `episode_index` into episodes of 2, 1 and 3 frames.
    frames = {
        "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
        "index": [0, 1, 2, 3, 4, 5],
        "episode_index": [0, 0, 1, 2, 2, 2],
    }
    hf_dataset = Dataset.from_dict(frames)
    hf_dataset.set_transform(hf_transform_to_torch)
    boundaries = calculate_episode_data_index(hf_dataset)

    sampler = EpisodeAwareSampler(boundaries, drop_n_last_frames=1)

    # The single-frame episode contributes nothing once its only frame is dropped.
    expected = [0, 3, 4]
    assert sampler.indices == expected
    assert len(sampler) == 3
    assert list(sampler) == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_episode_indices_to_use():
    """Only frames belonging to the explicitly selected episodes are sampled."""
    # Six frames grouped by `episode_index` into episodes of 2, 1 and 3 frames.
    frames = {
        "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
        "index": [0, 1, 2, 3, 4, 5],
        "episode_index": [0, 0, 1, 2, 2, 2],
    }
    hf_dataset = Dataset.from_dict(frames)
    hf_dataset.set_transform(hf_transform_to_torch)
    boundaries = calculate_episode_data_index(hf_dataset)

    sampler = EpisodeAwareSampler(boundaries, episode_indices_to_use=[0, 2])

    # Episode 1 (frame 2) is excluded from the selection.
    expected = [0, 1, 3, 4, 5]
    assert sampler.indices == expected
    assert len(sampler) == 5
    assert list(sampler) == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_shuffle():
    """Shuffling changes only the iteration order, never the stored indices or their count."""
    # Six frames grouped by `episode_index` into episodes of 2, 1 and 3 frames.
    frames = {
        "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
        "index": [0, 1, 2, 3, 4, 5],
        "episode_index": [0, 0, 1, 2, 2, 2],
    }
    hf_dataset = Dataset.from_dict(frames)
    hf_dataset.set_transform(hf_transform_to_torch)
    boundaries = calculate_episode_data_index(hf_dataset)

    # Without shuffling, iteration follows the stored order exactly.
    ordered_sampler = EpisodeAwareSampler(boundaries, shuffle=False)
    assert ordered_sampler.indices == [0, 1, 2, 3, 4, 5]
    assert len(ordered_sampler) == 6
    assert list(ordered_sampler) == [0, 1, 2, 3, 4, 5]

    # With shuffling, the iteration order is random, so only the set of yielded indices is compared.
    shuffled_sampler = EpisodeAwareSampler(boundaries, shuffle=True)
    assert shuffled_sampler.indices == [0, 1, 2, 3, 4, 5]
    assert len(shuffled_sampler) == 6
    assert set(shuffled_sampler) == {0, 1, 2, 3, 4, 5}
|
Loading…
Reference in New Issue