From 621f69d98fce1a8d91d232080e1c8b017f544a1d Mon Sep 17 00:00:00 2001 From: Radek Osmulski Date: Tue, 28 May 2024 23:14:08 +1000 Subject: [PATCH] Implement 2nd round of review changes --- lerobot/common/datasets/sampler.py | 4 ++-- lerobot/scripts/train.py | 1 + tests/test_sampler.py | 6 +++--- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/lerobot/common/datasets/sampler.py b/lerobot/common/datasets/sampler.py index 3a2a2173..3905bcc3 100644 --- a/lerobot/common/datasets/sampler.py +++ b/lerobot/common/datasets/sampler.py @@ -25,7 +25,7 @@ class EpisodeAwareSampler: episode_indices_to_use: Union[list, None] = None, drop_n_first_frames: int = 0, drop_n_last_frames: int = 0, - shuffle: bool = True, + shuffle: bool = False, ): """Sampler that optionally incorporates episode boundary information. @@ -35,7 +35,7 @@ class EpisodeAwareSampler: Assumes that episodes are indexed from 0 to N-1. drop_n_first_frames (int, optional): Number of frames to drop from the start of each episode. Defaults to 0. drop_n_last_frames (int, optional): Number of frames to drop from the end of each episode. Defaults to 0. - shuffle (bool, optional): Whether to shuffle the indices. Defaults to True. + shuffle (bool, optional): Whether to shuffle the indices. Defaults to False. """ indices = [] for episode_idx, (start_index, end_index) in enumerate( diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 1c575888..06c821d7 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -355,6 +355,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No sampler=EpisodeAwareSampler( offline_dataset.episode_data_index, drop_n_last_frames=cfg.drop_n_last_frames if hasattr(cfg, "drop_n_last_frames") else 0, + shuffle=True, ), pin_memory=device.type != "cpu", drop_last=False, diff --git a/tests/test_sampler.py b/tests/test_sampler.py index 90ab5fd3..2326d12c 100644 --- a/tests/test_sampler.py +++ b/tests/test_sampler.py @@ -35,7 +35,7 @@ def test_drop_n_first_frames(): sampler = EpisodeAwareSampler(episode_data_index, drop_n_first_frames=1) assert sampler.indices == [1, 4, 5] assert len(sampler) == 3 - assert set(sampler) == {1, 4, 5} + assert list(sampler) == [1, 4, 5] def test_drop_n_last_frames(): @@ -51,7 +51,7 @@ def test_drop_n_last_frames(): sampler = EpisodeAwareSampler(episode_data_index, drop_n_last_frames=1) assert sampler.indices == [0, 3, 4] assert len(sampler) == 3 - assert set(sampler) == {0, 3, 4} + assert list(sampler) == [0, 3, 4] def test_episode_indices_to_use(): @@ -67,7 +67,7 @@ def test_episode_indices_to_use(): sampler = EpisodeAwareSampler(episode_data_index, episode_indices_to_use=[0, 2]) assert sampler.indices == [0, 1, 3, 4, 5] assert len(sampler) == 5 - assert set(sampler) == {0, 1, 3, 4, 5} + assert list(sampler) == [0, 1, 3, 4, 5] def test_shuffle():