Remove reset_episode_index

This commit is contained in:
Simon Alibert 2024-11-03 19:07:43 +01:00
parent a6762ec316
commit 74270c8c91
2 changed files with 0 additions and 40 deletions

View File

@ -423,31 +423,6 @@ def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torc
return episode_data_index
# TODO(aliberts): remove
def reset_episode_index(hf_dataset: datasets.Dataset) -> datasets.Dataset:
"""Reset the `episode_index` of the provided HuggingFace Dataset.
`episode_data_index` (and related functionality such as `load_previous_and_future_frames`) requires the
`episode_index` to be sorted, continuous (1,1,1 and not 1,2,1) and start at 0.
This brings the `episode_index` to the required format.
"""
if len(hf_dataset) == 0:
return hf_dataset
unique_episode_idxs = torch.stack(hf_dataset["episode_index"]).unique().tolist()
episode_idx_to_reset_idx_mapping = {
ep_id: reset_ep_id for reset_ep_id, ep_id in enumerate(unique_episode_idxs)
}
def modify_ep_idx_func(example):
example["episode_index"] = episode_idx_to_reset_idx_mapping[example["episode_index"].item()]
return example
hf_dataset = hf_dataset.map(modify_ep_idx_func)
return hf_dataset
def cycle(iterable):
"""The equivalent of itertools.cycle, but safe for Pytorch dataloaders.

View File

@ -10,7 +10,6 @@ from datasets import Dataset
from lerobot.common.datasets.utils import (
calculate_episode_data_index,
hf_transform_to_torch,
reset_episode_index,
)
from lerobot.common.utils.utils import (
get_global_random_state,
@ -73,20 +72,6 @@ def test_calculate_episode_data_index():
assert torch.equal(episode_data_index["to"], torch.tensor([2, 3, 6]))
def test_reset_episode_index():
dataset = Dataset.from_dict(
{
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
"index": [0, 1, 2, 3, 4, 5],
"episode_index": [10, 10, 11, 12, 12, 12],
},
)
dataset.set_transform(hf_transform_to_torch)
correct_episode_index = [0, 0, 1, 2, 2, 2]
dataset = reset_episode_index(dataset)
assert dataset["episode_index"] == correct_episode_index
def test_init_hydra_config_empty():
test_file = f"/tmp/test_init_hydra_config_empty_{uuid4().hex}.yaml"
with open(test_file, "w") as f: