diff --git a/download_and_upload_dataset.py b/download_and_upload_dataset.py index 0ff86697..2e5c806c 100644 --- a/download_and_upload_dataset.py +++ b/download_and_upload_dataset.py @@ -17,6 +17,17 @@ from datasets import Dataset, Features, Image, Sequence, Value from PIL import Image as PILImage +def download_and_upload(root, root_tests, dataset_id): + if "pusht" in dataset_id: + download_and_upload_pusht(root, root_tests, dataset_id) + elif "xarm" in dataset_id: + download_and_upload_xarm(root, root_tests, dataset_id) + elif "aloha" in dataset_id: + download_and_upload_aloha(root, root_tests, dataset_id) + else: + raise ValueError(dataset_id) + + def download_and_extract_zip(url: str, destination_folder: Path) -> bool: import zipfile @@ -87,7 +98,6 @@ def download_and_upload_pusht(root, root_tests, dataset_id="pusht", fps=10): states = torch.from_numpy(dataset_dict["state"]) actions = torch.from_numpy(dataset_dict["action"]) - data_ids_per_episode = {} ep_dicts = [] id_from = 0 @@ -150,15 +160,11 @@ def download_and_upload_pusht(root, root_tests, dataset_id="pusht", fps=10): "next.reward": torch.cat([reward[1:], reward[[-1]]]), "next.done": torch.cat([done[1:], done[[-1]]]), "next.success": torch.cat([success[1:], success[[-1]]]), - "episode_data_id_from": torch.tensor([id_from] * num_frames), - "episode_data_id_to": torch.tensor([id_from + num_frames - 1] * num_frames), + "episode_data_index_from": torch.tensor([id_from] * num_frames), + "episode_data_index_to": torch.tensor([id_from + num_frames] * num_frames), } ep_dicts.append(ep_dict) - assert isinstance(episode_id, int) - data_ids_per_episode[episode_id] = torch.arange(id_from, id_to, 1) - assert len(data_ids_per_episode[episode_id]) == num_frames - id_from += num_frames data_dict = {} @@ -190,8 +196,8 @@ def download_and_upload_pusht(root, root_tests, dataset_id="pusht", fps=10): "next.done": Value(dtype="bool", id=None), "next.success": Value(dtype="bool", id=None), "index": Value(dtype="int64", id=None), - "episode_data_id_from": Value(dtype="int64", id=None), - "episode_data_id_to": Value(dtype="int64", id=None), + "episode_data_index_from": Value(dtype="int64", id=None), + "episode_data_index_to": Value(dtype="int64", id=None), } features = Features(features) dataset = Dataset.from_dict(data_dict, features=features) @@ -265,8 +271,8 @@ def download_and_upload_xarm(root, root_tests, dataset_id, fps=15): # "next.observation.state": next_state, "next.reward": next_reward, "next.done": next_done, - "episode_data_id_from": torch.tensor([id_from] * num_frames), - "episode_data_id_to": torch.tensor([id_from + num_frames - 1] * num_frames), + "episode_data_index_from": torch.tensor([id_from] * num_frames), + "episode_data_index_to": torch.tensor([id_from + num_frames] * num_frames), } ep_dicts.append(ep_dict) @@ -301,8 +307,8 @@ def download_and_upload_xarm(root, root_tests, dataset_id, fps=15): "next.done": Value(dtype="bool", id=None), #'next.success': Value(dtype='bool', id=None), "index": Value(dtype="int64", id=None), - "episode_data_id_from": Value(dtype="int64", id=None), - "episode_data_id_to": Value(dtype="int64", id=None), + "episode_data_index_from": Value(dtype="int64", id=None), + "episode_data_index_to": Value(dtype="int64", id=None), } features = Features(features) dataset = Dataset.from_dict(data_dict, features=features) @@ -390,20 +396,7 @@ def download_and_upload_aloha(root, root_tests, dataset_id, fps=50): state = torch.from_numpy(ep["/observations/qpos"][:]) action = torch.from_numpy(ep["/action"][:]) - ep_dict = { - "observation.state": state, - "action": action, - "episode_id": torch.tensor([ep_id] * num_frames), - "frame_id": torch.arange(0, num_frames, 1), - "timestamp": torch.arange(0, num_frames, 1) / fps, - # "next.observation.state": state, - # TODO(rcadene): compute reward and success - # "next.reward": reward, - "next.done": done, - # "next.success": success, - "episode_data_id_from": torch.tensor([id_from] * num_frames), - "episode_data_id_to": torch.tensor([id_from + num_frames - 1] * num_frames), - } + ep_dict = {} for cam in cameras[dataset_id]: image = torch.from_numpy(ep[f"/observations/images/{cam}"][:]) # b h w c @@ -411,6 +404,23 @@ def download_and_upload_aloha(root, root_tests, dataset_id, fps=50): ep_dict[f"observation.images.{cam}"] = [PILImage.fromarray(x.numpy()) for x in image] # ep_dict[f"next.observation.images.{cam}"] = image + ep_dict.update( + { + "observation.state": state, + "action": action, + "episode_id": torch.tensor([ep_id] * num_frames), + "frame_id": torch.arange(0, num_frames, 1), + "timestamp": torch.arange(0, num_frames, 1) / fps, + # "next.observation.state": state, + # TODO(rcadene): compute reward and success + # "next.reward": reward, + "next.done": done, + # "next.success": success, + "episode_data_index_from": torch.tensor([id_from] * num_frames), + "episode_data_index_to": torch.tensor([id_from + num_frames] * num_frames), + } + ) + assert isinstance(ep_id, int) ep_dicts.append(ep_dict) @@ -446,8 +456,8 @@ def download_and_upload_aloha(root, root_tests, dataset_id, fps=50): "next.done": Value(dtype="bool", id=None), #'next.success': Value(dtype='bool', id=None), "index": Value(dtype="int64", id=None), - "episode_data_id_from": Value(dtype="int64", id=None), - "episode_data_id_to": Value(dtype="int64", id=None), + "episode_data_index_from": Value(dtype="int64", id=None), + "episode_data_index_to": Value(dtype="int64", id=None), } features = Features(features) dataset = Dataset.from_dict(data_dict, features=features) @@ -461,23 +471,17 @@ def download_and_upload_aloha(root, root_tests, dataset_id, fps=50): if __name__ == "__main__": root = "data" - root_tests = "{root_tests}" - - download_and_upload_pusht(root, root_tests, dataset_id="pusht") - download_and_upload_xarm(root, root_tests, dataset_id="xarm_lift_medium") - download_and_upload_aloha(root, root_tests, dataset_id="aloha_sim_insertion_human") - download_and_upload_aloha(root, root_tests, dataset_id="aloha_sim_insertion_scripted") - download_and_upload_aloha(root, root_tests, dataset_id="aloha_sim_transfer_cube_human") - download_and_upload_aloha(root, root_tests, dataset_id="aloha_sim_transfer_cube_scripted") + root_tests = "tests/data" dataset_ids = [ - "pusht", - "xarm_lift_medium", - "aloha_sim_insertion_human", - "aloha_sim_insertion_scripted", - "aloha_sim_transfer_cube_human", + # "pusht", + # "xarm_lift_medium", + # "aloha_sim_insertion_human", + # "aloha_sim_insertion_scripted", + # "aloha_sim_transfer_cube_human", "aloha_sim_transfer_cube_scripted", ] for dataset_id in dataset_ids: + download_and_upload(root, root_tests, dataset_id) # assume stats have been precomputed shutil.copy(f"{root}/{dataset_id}/stats.pth", f"{root_tests}/{dataset_id}/stats.pth") diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py index 154dcb68..1b353e69 100644 --- a/lerobot/common/datasets/utils.py +++ b/lerobot/common/datasets/utils.py @@ -7,9 +7,9 @@ import tqdm def load_previous_and_future_frames( - item: dict[torch.Tensor], - data_dict: dict[torch.Tensor], - delta_timestamps: dict[list[float]], + item: dict[str, torch.Tensor], + data_dict: dict[str, torch.Tensor], + delta_timestamps: dict[str, list[float]], tol: float = 0.04, ) -> dict[torch.Tensor]: """ @@ -35,12 +35,12 @@ def load_previous_and_future_frames( - AssertionError: If any of the frames unexpectedly violate the tolerance level. This could indicate synchronization issues with timestamps during data collection. """ # get indices of the frames associated to the episode, and their timestamps - ep_data_id_from = item["episode_data_id_from"].item() - ep_data_id_to = item["episode_data_id_to"].item() - ep_data_ids = torch.arange(ep_data_id_from, ep_data_id_to + 1, 1) + ep_data_id_from = item["episode_data_index_from"].item() + ep_data_id_to = item["episode_data_index_to"].item() + ep_data_ids = torch.arange(ep_data_id_from, ep_data_id_to, 1) # load timestamps - ep_timestamps = data_dict.select_columns("timestamp")[ep_data_id_from : ep_data_id_to + 1]["timestamp"] + ep_timestamps = data_dict.select_columns("timestamp")[ep_data_id_from:ep_data_id_to]["timestamp"] # we make the assumption that the timestamps are sorted ep_first_ts = ep_timestamps[0] diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index d8c697c2..6c5b757d 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -215,8 +215,8 @@ def eval_policy( "timestamp": torch.arange(0, num_frames, 1) / fps, "next.done": dones[ep_id, :num_frames], "next.reward": rewards[ep_id, :num_frames].type(torch.float32), - "episode_data_id_from": torch.tensor([idx_from] * num_frames), - "episode_data_id_to": torch.tensor([idx_from + num_frames - 1] * num_frames), + "episode_data_index_from": torch.tensor([idx_from] * num_frames), + "episode_data_index_to": torch.tensor([idx_from + num_frames] * num_frames), } for key in observations: ep_dict[key] = observations[key][ep_id][:num_frames] diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 71bc1e72..19218355 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -141,15 +141,15 @@ def add_episodes_inplace(data_dict, online_dataset, concat_dataset, sampler, pc_ online_dataset.data_dict = data_dict else: # find episode index and data frame indices according to previous episode in online_dataset - start_episode = online_dataset.data_dict["episode_id"][-1].item() + 1 - start_index = online_dataset.data_dict["index"][-1].item() + 1 + start_episode = online_dataset.select_columns("episode_id")[-1]["episode_id"].item() + 1 + start_index = online_dataset.select_columns("index")[-1]["index"].item() + 1 def shift_indices(example): # note: we dont shift "frame_id" since it represents the index of the frame in the episode it belongs to example["episode_id"] += start_episode example["index"] += start_index - example["episode_data_id_from"] += start_index - example["episode_data_id_to"] += start_index + example["episode_data_index_from"] += start_index + example["episode_data_index_to"] += start_index return example disable_progress_bar() # map has a tqdm progress bar diff --git a/lerobot/scripts/visualize_dataset.py b/lerobot/scripts/visualize_dataset.py index 10ed98d5..e7bd0693 100644 --- a/lerobot/scripts/visualize_dataset.py +++ b/lerobot/scripts/visualize_dataset.py @@ -77,7 +77,7 @@ def render_dataset(dataset, out_dir, max_num_episodes): # add current frame to list of frames to render frames[im_key].append(item[im_key]) - end_of_episode = item["index"].item() == item["episode_data_id_to"].item() + end_of_episode = item["index"].item() == item["episode_data_index_to"].item() out_dir.mkdir(parents=True, exist_ok=True) for im_key in dataset.image_keys: diff --git a/tests/data/aloha_sim_insertion_human/train/data-00000-of-00001.arrow b/tests/data/aloha_sim_insertion_human/train/data-00000-of-00001.arrow index 4d357e34..165298cf 100644 Binary files a/tests/data/aloha_sim_insertion_human/train/data-00000-of-00001.arrow and b/tests/data/aloha_sim_insertion_human/train/data-00000-of-00001.arrow differ diff --git a/tests/data/aloha_sim_insertion_human/train/dataset_info.json b/tests/data/aloha_sim_insertion_human/train/dataset_info.json index 473812f3..542c7bf1 100644 --- a/tests/data/aloha_sim_insertion_human/train/dataset_info.json +++ b/tests/data/aloha_sim_insertion_human/train/dataset_info.json @@ -2,6 +2,9 @@ "citation": "", "description": "", "features": { + "observation.images.top": { + "_type": "Image" + }, "observation.state": { "feature": { "dtype": "float32", @@ -34,17 +37,14 @@ "dtype": "bool", "_type": "Value" }, - "episode_data_id_from": { + "episode_data_index_from": { "dtype": "int64", "_type": "Value" }, - "episode_data_id_to": { + "episode_data_index_to": { "dtype": "int64", "_type": "Value" }, - "observation.images.top": { - "_type": "Image" - }, "index": { "dtype": "int64", "_type": "Value" diff --git a/tests/data/aloha_sim_insertion_human/train/state.json b/tests/data/aloha_sim_insertion_human/train/state.json index 5b56e98c..39101fd5 100644 --- a/tests/data/aloha_sim_insertion_human/train/state.json +++ b/tests/data/aloha_sim_insertion_human/train/state.json @@ -4,7 +4,7 @@ "filename": "data-00000-of-00001.arrow" } ], - "_fingerprint": "05980bca35112ebd", + "_fingerprint": "d79cf82ffc86f110", "_format_columns": null, "_format_kwargs": {}, "_format_type": "torch", diff --git a/tests/data/aloha_sim_insertion_scripted/train/data-00000-of-00001.arrow b/tests/data/aloha_sim_insertion_scripted/train/data-00000-of-00001.arrow index 421474a2..034f759f 100644 Binary files a/tests/data/aloha_sim_insertion_scripted/train/data-00000-of-00001.arrow and b/tests/data/aloha_sim_insertion_scripted/train/data-00000-of-00001.arrow differ diff --git a/tests/data/aloha_sim_insertion_scripted/train/dataset_info.json b/tests/data/aloha_sim_insertion_scripted/train/dataset_info.json index 473812f3..542c7bf1 100644 --- a/tests/data/aloha_sim_insertion_scripted/train/dataset_info.json +++ b/tests/data/aloha_sim_insertion_scripted/train/dataset_info.json @@ -2,6 +2,9 @@ "citation": "", "description": "", "features": { + "observation.images.top": { + "_type": "Image" + }, "observation.state": { "feature": { "dtype": "float32", @@ -34,17 +37,14 @@ "dtype": "bool", "_type": "Value" }, - "episode_data_id_from": { + "episode_data_index_from": { "dtype": "int64", "_type": "Value" }, - "episode_data_id_to": { + "episode_data_index_to": { "dtype": "int64", "_type": "Value" }, - "observation.images.top": { - "_type": "Image" - }, "index": { "dtype": "int64", "_type": "Value" diff --git a/tests/data/aloha_sim_insertion_scripted/train/state.json b/tests/data/aloha_sim_insertion_scripted/train/state.json index 8f202c3a..ecaa8fd8 100644 --- a/tests/data/aloha_sim_insertion_scripted/train/state.json +++ b/tests/data/aloha_sim_insertion_scripted/train/state.json @@ -4,7 +4,7 @@ "filename": "data-00000-of-00001.arrow" } ], - "_fingerprint": "f3330a7e1d8bc55b", + "_fingerprint": "d8e4a817b5449498", "_format_columns": null, "_format_kwargs": {}, "_format_type": "torch", diff --git a/tests/data/aloha_sim_transfer_cube_human/train/data-00000-of-00001.arrow b/tests/data/aloha_sim_transfer_cube_human/train/data-00000-of-00001.arrow index 9e371c8f..9682f005 100644 Binary files a/tests/data/aloha_sim_transfer_cube_human/train/data-00000-of-00001.arrow and b/tests/data/aloha_sim_transfer_cube_human/train/data-00000-of-00001.arrow differ diff --git a/tests/data/aloha_sim_transfer_cube_human/train/dataset_info.json b/tests/data/aloha_sim_transfer_cube_human/train/dataset_info.json index 473812f3..542c7bf1 100644 --- a/tests/data/aloha_sim_transfer_cube_human/train/dataset_info.json +++ b/tests/data/aloha_sim_transfer_cube_human/train/dataset_info.json @@ -2,6 +2,9 @@ "citation": "", "description": "", "features": { + "observation.images.top": { + "_type": "Image" + }, "observation.state": { "feature": { "dtype": "float32", @@ -34,17 +37,14 @@ "dtype": "bool", "_type": "Value" }, - "episode_data_id_from": { + "episode_data_index_from": { "dtype": "int64", "_type": "Value" }, - "episode_data_id_to": { + "episode_data_index_to": { "dtype": "int64", "_type": "Value" }, - "observation.images.top": { - "_type": "Image" - }, "index": { "dtype": "int64", "_type": "Value" diff --git a/tests/data/aloha_sim_transfer_cube_human/train/state.json b/tests/data/aloha_sim_transfer_cube_human/train/state.json index ec1fdf06..0167986b 100644 --- a/tests/data/aloha_sim_transfer_cube_human/train/state.json +++ b/tests/data/aloha_sim_transfer_cube_human/train/state.json @@ -4,7 +4,7 @@ "filename": "data-00000-of-00001.arrow" } ], - "_fingerprint": "42aa77ffb6863924", + "_fingerprint": "f03482befa767127", "_format_columns": null, "_format_kwargs": {}, "_format_type": "torch", diff --git a/tests/data/aloha_sim_transfer_cube_scripted/train/data-00000-of-00001.arrow b/tests/data/aloha_sim_transfer_cube_scripted/train/data-00000-of-00001.arrow index 99d3363b..567191d5 100644 Binary files a/tests/data/aloha_sim_transfer_cube_scripted/train/data-00000-of-00001.arrow and b/tests/data/aloha_sim_transfer_cube_scripted/train/data-00000-of-00001.arrow differ diff --git a/tests/data/aloha_sim_transfer_cube_scripted/train/dataset_info.json b/tests/data/aloha_sim_transfer_cube_scripted/train/dataset_info.json index 473812f3..542c7bf1 100644 --- a/tests/data/aloha_sim_transfer_cube_scripted/train/dataset_info.json +++ b/tests/data/aloha_sim_transfer_cube_scripted/train/dataset_info.json @@ -2,6 +2,9 @@ "citation": "", "description": "", "features": { + "observation.images.top": { + "_type": "Image" + }, "observation.state": { "feature": { "dtype": "float32", @@ -34,17 +37,14 @@ "dtype": "bool", "_type": "Value" }, - "episode_data_id_from": { + "episode_data_index_from": { "dtype": "int64", "_type": "Value" }, - "episode_data_id_to": { + "episode_data_index_to": { "dtype": "int64", "_type": "Value" }, - "observation.images.top": { - "_type": "Image" - }, "index": { "dtype": "int64", "_type": "Value" diff --git a/tests/data/aloha_sim_transfer_cube_scripted/train/state.json b/tests/data/aloha_sim_transfer_cube_scripted/train/state.json index ee3cc1fe..56005bc9 100644 --- a/tests/data/aloha_sim_transfer_cube_scripted/train/state.json +++ b/tests/data/aloha_sim_transfer_cube_scripted/train/state.json @@ -4,7 +4,7 @@ "filename": "data-00000-of-00001.arrow" } ], - "_fingerprint": "43f176a3740fe622", + "_fingerprint": "93e03c6320c7d56e", "_format_columns": null, "_format_kwargs": {}, "_format_type": "torch", diff --git a/tests/data/pusht/train/data-00000-of-00001.arrow b/tests/data/pusht/train/data-00000-of-00001.arrow index 71f657ab..9a36a8db 100644 Binary files a/tests/data/pusht/train/data-00000-of-00001.arrow and b/tests/data/pusht/train/data-00000-of-00001.arrow differ diff --git a/tests/data/pusht/train/dataset_info.json b/tests/data/pusht/train/dataset_info.json index b21231fe..667e06f7 100644 --- a/tests/data/pusht/train/dataset_info.json +++ b/tests/data/pusht/train/dataset_info.json @@ -45,11 +45,11 @@ "dtype": "bool", "_type": "Value" }, - "episode_data_id_from": { + "episode_data_index_from": { "dtype": "int64", "_type": "Value" }, - "episode_data_id_to": { + "episode_data_index_to": { "dtype": "int64", "_type": "Value" }, diff --git a/tests/data/pusht/train/state.json b/tests/data/pusht/train/state.json index 090326e1..7e0ff574 100644 --- a/tests/data/pusht/train/state.json +++ b/tests/data/pusht/train/state.json @@ -4,7 +4,7 @@ "filename": "data-00000-of-00001.arrow" } ], - "_fingerprint": "f7ed966ae18000ae", + "_fingerprint": "21bb9a76ed78a475", "_format_columns": null, "_format_kwargs": {}, "_format_type": "torch", diff --git a/tests/data/xarm_lift_medium/train/data-00000-of-00001.arrow b/tests/data/xarm_lift_medium/train/data-00000-of-00001.arrow index f6ee0e50..45d527e0 100644 Binary files a/tests/data/xarm_lift_medium/train/data-00000-of-00001.arrow and b/tests/data/xarm_lift_medium/train/data-00000-of-00001.arrow differ diff --git a/tests/data/xarm_lift_medium/train/dataset_info.json b/tests/data/xarm_lift_medium/train/dataset_info.json index 81ba7c8c..bb647c41 100644 --- a/tests/data/xarm_lift_medium/train/dataset_info.json +++ b/tests/data/xarm_lift_medium/train/dataset_info.json @@ -41,11 +41,11 @@ "dtype": "bool", "_type": "Value" }, - "episode_data_id_from": { + "episode_data_index_from": { "dtype": "int64", "_type": "Value" }, - "episode_data_id_to": { + "episode_data_index_to": { "dtype": "int64", "_type": "Value" }, diff --git a/tests/data/xarm_lift_medium/train/state.json b/tests/data/xarm_lift_medium/train/state.json index 500bdb85..c930c52c 100644 --- a/tests/data/xarm_lift_medium/train/state.json +++ b/tests/data/xarm_lift_medium/train/state.json @@ -4,7 +4,7 @@ "filename": "data-00000-of-00001.arrow" } ], - "_fingerprint": "7dcd82fc3815bba6", + "_fingerprint": "a95cbec45e3bb9d6", "_format_columns": null, "_format_kwargs": {}, "_format_type": "torch", diff --git a/tests/test_datasets.py b/tests/test_datasets.py index c40d478a..85ddb00f 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -95,12 +95,14 @@ def test_compute_stats(): """ from lerobot.common.datasets.xarm import XarmDataset + DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None # get transform to convert images from uint8 [0,255] to float32 [0,1] transform = Prod(in_keys=XarmDataset.image_keys, prod=1 / 255.0) dataset = XarmDataset( dataset_id="xarm_lift_medium", + root=DATA_DIR, transform=transform, ) @@ -115,11 +117,11 @@ def test_compute_stats(): # get all frames from the dataset in the same dtype and range as during compute_stats dataloader = torch.utils.data.DataLoader( dataset, - num_workers=16, + num_workers=8, batch_size=len(dataset), shuffle=False, ) - data_dict = next(iter(dataloader)) # takes 23 seconds + data_dict = next(iter(dataloader)) # compute stats based on all frames from the dataset without any batching expected_stats = {} @@ -154,8 +156,8 @@ def test_load_previous_and_future_frames_within_tolerance(): data_dict = Dataset.from_dict({ "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5], "index": [0, 1, 2, 3, 4], - "episode_data_id_from": [0, 0, 0, 0, 0], - "episode_data_id_to": [4, 4, 4, 4, 4], + "episode_data_index_from": [0, 0, 0, 0, 0], + "episode_data_index_to": [4, 4, 4, 4, 4], }) data_dict = data_dict.with_format("torch") item = data_dict[2] @@ -170,8 +172,8 @@ def test_load_previous_and_future_frames_outside_tolerance_inside_episode_range( data_dict = Dataset.from_dict({ "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5], "index": [0, 1, 2, 3, 4], - "episode_data_id_from": [0, 0, 0, 0, 0], - "episode_data_id_to": [4, 4, 4, 4, 4], + "episode_data_index_from": [0, 0, 0, 0, 0], + "episode_data_index_to": [4, 4, 4, 4, 4], }) data_dict = data_dict.with_format("torch") item = data_dict[2] @@ -184,8 +186,8 @@ def test_load_previous_and_future_frames_outside_tolerance_outside_episode_range data_dict = Dataset.from_dict({ "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5], "index": [0, 1, 2, 3, 4], - "episode_data_id_from": [0, 0, 0, 0, 0], - "episode_data_id_to": [4, 4, 4, 4, 4], + "episode_data_index_from": [0, 0, 0, 0, 0], + "episode_data_index_to": [4, 4, 4, 4, 4], }) data_dict = data_dict.with_format("torch") item = data_dict[2]