From 74270c8c916aeb9b379e770956393de7e8b1d379 Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Sun, 3 Nov 2024 19:07:43 +0100 Subject: [PATCH] Remove reset_episode_index --- lerobot/common/datasets/utils.py | 25 ------------------------- tests/test_utils.py | 15 --------------- 2 files changed, 40 deletions(-) diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py index 5ade25ae..e21c0128 100644 --- a/lerobot/common/datasets/utils.py +++ b/lerobot/common/datasets/utils.py @@ -423,31 +423,6 @@ def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torc return episode_data_index -# TODO(aliberts): remove -def reset_episode_index(hf_dataset: datasets.Dataset) -> datasets.Dataset: - """Reset the `episode_index` of the provided HuggingFace Dataset. - - `episode_data_index` (and related functionality such as `load_previous_and_future_frames`) requires the - `episode_index` to be sorted, continuous (1,1,1 and not 1,2,1) and start at 0. - - This brings the `episode_index` to the required format. - """ - if len(hf_dataset) == 0: - return hf_dataset - unique_episode_idxs = torch.stack(hf_dataset["episode_index"]).unique().tolist() - episode_idx_to_reset_idx_mapping = { - ep_id: reset_ep_id for reset_ep_id, ep_id in enumerate(unique_episode_idxs) - } - - def modify_ep_idx_func(example): - example["episode_index"] = episode_idx_to_reset_idx_mapping[example["episode_index"].item()] - return example - - hf_dataset = hf_dataset.map(modify_ep_idx_func) - - return hf_dataset - - def cycle(iterable): """The equivalent of itertools.cycle, but safe for Pytorch dataloaders. diff --git a/tests/test_utils.py b/tests/test_utils.py index e5ba2267..42715e00 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -10,7 +10,6 @@ from datasets import Dataset from lerobot.common.datasets.utils import ( calculate_episode_data_index, hf_transform_to_torch, - reset_episode_index, ) from lerobot.common.utils.utils import ( get_global_random_state, @@ -73,20 +72,6 @@ def test_calculate_episode_data_index(): assert torch.equal(episode_data_index["to"], torch.tensor([2, 3, 6])) -def test_reset_episode_index(): - dataset = Dataset.from_dict( - { - "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - "index": [0, 1, 2, 3, 4, 5], - "episode_index": [10, 10, 11, 12, 12, 12], - }, - ) - dataset.set_transform(hf_transform_to_torch) - correct_episode_index = [0, 0, 1, 2, 2, 2] - dataset = reset_episode_index(dataset) - assert dataset["episode_index"] == correct_episode_index - - def test_init_hydra_config_empty(): test_file = f"/tmp/test_init_hydra_config_empty_{uuid4().hex}.yaml" with open(test_file, "w") as f: