diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py index 3bf684c9..ecb28e26 100644 --- a/lerobot/common/datasets/factory.py +++ b/lerobot/common/datasets/factory.py @@ -53,7 +53,11 @@ def make_dataset( stats["action"]["max"] = torch.tensor([511.0, 511.0], dtype=torch.float32) elif stats_path is None: # load stats if the file exists already or compute stats and save it - precomputed_stats_path = Path("data") / cfg.dataset_id / "stats.pth" + if DATA_DIR is None: + # TODO(rcadene): clean stats + precomputed_stats_path = Path("data") / cfg.dataset_id / "stats.pth" + else: + precomputed_stats_path = DATA_DIR / cfg.dataset_id / "stats.pth" if precomputed_stats_path.exists(): stats = torch.load(precomputed_stats_path) else: diff --git a/lerobot/scripts/download_and_upload_dataset.py b/lerobot/scripts/download_and_upload_dataset.py index 267b619d..0ff86697 100644 --- a/lerobot/scripts/download_and_upload_dataset.py +++ b/lerobot/scripts/download_and_upload_dataset.py @@ -5,6 +5,7 @@ useless dependencies when using datasets. 
import io import pickle +import shutil from pathlib import Path import einops @@ -44,7 +45,7 @@ def download_and_extract_zip(url: str, destination_folder: Path) -> bool: return False -def download_and_upload_pusht(root, dataset_id="pusht", fps=10): +def download_and_upload_pusht(root, root_tests, dataset_id="pusht", fps=10): try: import pymunk from gym_pusht.envs.pusht import PushTEnv, pymunk_to_shapely @@ -197,12 +198,12 @@ def download_and_upload_pusht(root, dataset_id="pusht", fps=10): dataset = dataset.with_format("torch") num_items_first_ep = ep_dicts[0]["frame_id"].shape[0] - dataset.select(range(num_items_first_ep)).save_to_disk(f"tests/data/{dataset_id}/train") + dataset.select(range(num_items_first_ep)).save_to_disk(f"{root_tests}/{dataset_id}/train") dataset.push_to_hub(f"lerobot/{dataset_id}", token=True) dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0") -def download_and_upload_xarm(root, dataset_id, fps=15): +def download_and_upload_xarm(root, root_tests, dataset_id, fps=15): root = Path(root) raw_dir = root / f"{dataset_id}_raw" if not raw_dir.exists(): @@ -308,12 +309,12 @@ def download_and_upload_xarm(root, dataset_id, fps=15): dataset = dataset.with_format("torch") num_items_first_ep = ep_dicts[0]["frame_id"].shape[0] - dataset.select(range(num_items_first_ep)).save_to_disk(f"tests/data/{dataset_id}/train") + dataset.select(range(num_items_first_ep)).save_to_disk(f"{root_tests}/{dataset_id}/train") dataset.push_to_hub(f"lerobot/{dataset_id}", token=True) dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0") -def download_and_upload_aloha(root, dataset_id, fps=50): +def download_and_upload_aloha(root, root_tests, dataset_id, fps=50): folder_urls = { "aloha_sim_insertion_human": "https://drive.google.com/drive/folders/1RgyD0JgTX30H4IM5XZn8I3zSV_mr8pyF", "aloha_sim_insertion_scripted": "https://drive.google.com/drive/folders/1TsojQQSXtHEoGnqgJ3gmpPQR2DPLtS2N", @@ -453,16 +454,30 @@ def 
download_and_upload_aloha(root, dataset_id, fps=50): dataset = dataset.with_format("torch") num_items_first_ep = ep_dicts[0]["frame_id"].shape[0] - dataset.select(range(num_items_first_ep)).save_to_disk(f"tests/data/{dataset_id}/train") + dataset.select(range(num_items_first_ep)).save_to_disk(f"{root_tests}/{dataset_id}/train") dataset.push_to_hub(f"lerobot/{dataset_id}", token=True) dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0") if __name__ == "__main__": root = "data" - download_and_upload_pusht(root, dataset_id="pusht") - download_and_upload_xarm(root, dataset_id="xarm_lift_medium") - download_and_upload_aloha(root, dataset_id="aloha_sim_insertion_human") - download_and_upload_aloha(root, dataset_id="aloha_sim_insertion_scripted") - download_and_upload_aloha(root, dataset_id="aloha_sim_transfer_cube_human") - download_and_upload_aloha(root, dataset_id="aloha_sim_transfer_cube_scripted") + root_tests = "tests/data" + + download_and_upload_pusht(root, root_tests, dataset_id="pusht") + download_and_upload_xarm(root, root_tests, dataset_id="xarm_lift_medium") + download_and_upload_aloha(root, root_tests, dataset_id="aloha_sim_insertion_human") + download_and_upload_aloha(root, root_tests, dataset_id="aloha_sim_insertion_scripted") + download_and_upload_aloha(root, root_tests, dataset_id="aloha_sim_transfer_cube_human") + download_and_upload_aloha(root, root_tests, dataset_id="aloha_sim_transfer_cube_scripted") + + dataset_ids = [ + "pusht", + "xarm_lift_medium", + "aloha_sim_insertion_human", + "aloha_sim_insertion_scripted", + "aloha_sim_transfer_cube_human", + "aloha_sim_transfer_cube_scripted", + ] + for dataset_id in dataset_ids: + # assume stats have been precomputed + shutil.copy(f"{root}/{dataset_id}/stats.pth", f"{root_tests}/{dataset_id}/stats.pth") diff --git a/lerobot/scripts/visualize_dataset.py b/lerobot/scripts/visualize_dataset.py index 4b7b7d6c..10ed98d5 100644 --- a/lerobot/scripts/visualize_dataset.py +++ 
b/lerobot/scripts/visualize_dataset.py @@ -62,12 +62,12 @@ def render_dataset(dataset, out_dir, max_num_episodes): ) dl_iter = iter(dataloader) - num_episodes = len(dataset.data_ids_per_episode) - for ep_id in range(min(max_num_episodes, num_episodes)): + for ep_id in range(min(max_num_episodes, dataset.num_episodes)): logging.info(f"Rendering episode {ep_id}") frames = {} - for _ in dataset.data_ids_per_episode[ep_id]: + end_of_episode = False + while not end_of_episode: item = next(dl_iter) for im_key in dataset.image_keys: @@ -77,6 +77,8 @@ def render_dataset(dataset, out_dir, max_num_episodes): # add current frame to list of frames to render frames[im_key].append(item[im_key]) + end_of_episode = item["index"].item() == item["episode_data_id_to"].item() + out_dir.mkdir(parents=True, exist_ok=True) for im_key in dataset.image_keys: if len(dataset.image_keys) > 1: diff --git a/tests/data/aloha_sim_insertion_human/stats.pth b/tests/data/aloha_sim_insertion_human/stats.pth new file mode 100644 index 00000000..a7b9248f Binary files /dev/null and b/tests/data/aloha_sim_insertion_human/stats.pth differ diff --git a/tests/data/aloha_sim_insertion_scripted/stats.pth b/tests/data/aloha_sim_insertion_scripted/stats.pth new file mode 100644 index 00000000..990d4647 Binary files /dev/null and b/tests/data/aloha_sim_insertion_scripted/stats.pth differ diff --git a/tests/data/aloha_sim_transfer_cube_human/stats.pth b/tests/data/aloha_sim_transfer_cube_human/stats.pth new file mode 100644 index 00000000..1ae356e3 Binary files /dev/null and b/tests/data/aloha_sim_transfer_cube_human/stats.pth differ diff --git a/tests/data/aloha_sim_transfer_cube_scripted/stats.pth b/tests/data/aloha_sim_transfer_cube_scripted/stats.pth new file mode 100644 index 00000000..71547f09 Binary files /dev/null and b/tests/data/aloha_sim_transfer_cube_scripted/stats.pth differ diff --git a/tests/data/pusht/stats.pth b/tests/data/pusht/stats.pth new file mode 100644 index 00000000..636985fd Binary files 
/dev/null and b/tests/data/pusht/stats.pth differ diff --git a/tests/data/xarm_lift_medium/stats.pth b/tests/data/xarm_lift_medium/stats.pth new file mode 100644 index 00000000..3ab4e05b Binary files /dev/null and b/tests/data/xarm_lift_medium/stats.pth differ