From 566a8aa98e9c731a87af3aadf7f7a8d71df4329a Mon Sep 17 00:00:00 2001 From: Radek Osmulski Date: Tue, 28 May 2024 22:10:15 +1000 Subject: [PATCH] Implement review feedback --- lerobot/common/datasets/sampler.py | 24 +++++++++++++++++++----- tests/test_sampler.py | 25 +++++++++++++++++++++++++ 2 files changed, 44 insertions(+), 5 deletions(-) diff --git a/lerobot/common/datasets/sampler.py b/lerobot/common/datasets/sampler.py index b13691d2..3a2a2173 100644 --- a/lerobot/common/datasets/sampler.py +++ b/lerobot/common/datasets/sampler.py @@ -13,18 +13,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Union +from typing import Iterator, Union -from torch.utils.data import SubsetRandomSampler +import torch -class EpisodeAwareSampler(SubsetRandomSampler): +class EpisodeAwareSampler: def __init__( self, episode_data_index: dict, episode_indices_to_use: Union[list, None] = None, drop_n_first_frames: int = 0, drop_n_last_frames: int = 0, + shuffle: bool = True, ): """Sampler that optionally incorporates episode boundary information. @@ -34,14 +35,27 @@ class EpisodeAwareSampler(SubsetRandomSampler): Assumes that episodes are indexed from 0 to N-1. drop_n_first_frames (int, optional): Number of frames to drop from the start of each episode. Defaults to 0. drop_n_last_frames (int, optional): Number of frames to drop from the end of each episode. Defaults to 0. + shuffle (bool, optional): Whether to shuffle the indices. Defaults to True. """ indices = [] for episode_idx, (start_index, end_index) in enumerate( - zip(episode_data_index["from"], episode_data_index["to"], strict=False) + zip(episode_data_index["from"], episode_data_index["to"], strict=True) ): if episode_indices_to_use is None or episode_idx in episode_indices_to_use: indices.extend( range(start_index.item() + drop_n_first_frames, end_index.item() - drop_n_last_frames) ) - super().__init__(indices) + self.indices = indices + self.shuffle = shuffle + + def __iter__(self) -> Iterator[int]: + if self.shuffle: + for i in torch.randperm(len(self.indices)): + yield self.indices[i] + else: + for i in self.indices: + yield i + + def __len__(self) -> int: + return len(self.indices) diff --git a/tests/test_sampler.py b/tests/test_sampler.py index 6effda37..90ab5fd3 100644 --- a/tests/test_sampler.py +++ b/tests/test_sampler.py @@ -34,6 +34,8 @@ def test_drop_n_first_frames(): episode_data_index = calculate_episode_data_index(dataset) sampler = EpisodeAwareSampler(episode_data_index, drop_n_first_frames=1) assert sampler.indices == [1, 4, 5] + assert len(sampler) == 3 + assert set(sampler) == {1, 4, 5} def test_drop_n_last_frames(): @@ -48,6 +50,8 @@ def test_drop_n_last_frames(): episode_data_index = calculate_episode_data_index(dataset) sampler = EpisodeAwareSampler(episode_data_index, drop_n_last_frames=1) assert sampler.indices == [0, 3, 4] + assert len(sampler) == 3 + assert set(sampler) == {0, 3, 4} def test_episode_indices_to_use(): @@ -62,3 +66,24 @@ def test_episode_indices_to_use(): episode_data_index = calculate_episode_data_index(dataset) sampler = EpisodeAwareSampler(episode_data_index, episode_indices_to_use=[0, 2]) assert sampler.indices == [0, 1, 3, 4, 5] + assert len(sampler) == 5 + assert set(sampler) == {0, 1, 3, 4, 5} + + +def test_shuffle(): + dataset = Dataset.from_dict( + { + "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + "index": [0, 1, 2, 3, 4, 5], + "episode_index": [0, 0, 1, 2, 2, 2], + }, + ) + dataset.set_transform(hf_transform_to_torch) + episode_data_index = calculate_episode_data_index(dataset) + sampler = EpisodeAwareSampler(episode_data_index, shuffle=False) + assert sampler.indices == [0, 1, 2, 3, 4, 5] + assert len(sampler) == 6 + assert list(sampler) == [0, 1, 2, 3, 4, 5] + sampler = EpisodeAwareSampler(episode_data_index, shuffle=True) + assert len(sampler) == 6 + assert set(sampler) == {0, 1, 2, 3, 4, 5}