From 5fb010ab29743f2a75f2b9e113df85fbedcf4009 Mon Sep 17 00:00:00 2001
From: Alexander Soare
Date: Wed, 22 May 2024 16:14:36 +0100
Subject: [PATCH] backup wip

---
 lerobot/common/datasets/lerobot_dataset.py | 46 ++++++++++++++++++++++
 1 file changed, 46 insertions(+)

diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py
index 057e4770..1a4a982e 100644
--- a/lerobot/common/datasets/lerobot_dataset.py
+++ b/lerobot/common/datasets/lerobot_dataset.py
@@ -198,3 +198,49 @@ class LeRobotDataset(torch.utils.data.Dataset):
         obj.info = info
         obj.videos_dir = videos_dir
         return obj
+
+
+def concatenate_datasets(datasets: list[LeRobotDataset]) -> LeRobotDataset:
+    """Take a list of datasets and concatenate them to form one dataset.
+
+    The resulting dataset is reindexed, starting from zero, and preserving the ordering of the provided datasets.
+
+    All datasets are expected to have at least "action" and "observation.state" keys. "observation.image" keys can vary
+    between datasets. Where keys match across datasets, it is expected that the data has a common shape. Note: in future
+    iterations of LeRobot, some of these limitations may be relaxed.
+    """
+
+
+if __name__ == "__main__":
+    sim_datasets = [
+        "lerobot/aloha_sim_insertion_human",
+        "lerobot/aloha_sim_insertion_scripted",
+        "lerobot/aloha_sim_transfer_cube_human",
+        "lerobot/aloha_sim_transfer_cube_scripted",
+    ]
+
+    real_datasets = [
+        "lerobot/aloha_static_battery",
+        "lerobot/aloha_static_candy",
+        "lerobot/aloha_static_coffee",
+        "lerobot/aloha_static_coffee_new",
+        "lerobot/aloha_static_cups_open",
+        "lerobot/aloha_static_fork_pick_up",
+        "lerobot/aloha_static_pingpong_test",
+        "lerobot/aloha_static_pro_pencil",
+        "lerobot/aloha_static_screw_driver",
+        "lerobot/aloha_static_tape",
+        "lerobot/aloha_static_thread_velcro",
+        "lerobot/aloha_static_towel",
+        "lerobot/aloha_static_vinh_cup",
+        "lerobot/aloha_static_vinh_cup_left",
+        "lerobot/aloha_static_ziploc_slide",
+    ]
+
+    concatenate_datasets([LeRobotDataset(dataset_name) for dataset_name in real_datasets])
+
+    # for dataset_name in real_datasets:
+    #     print(f"{dataset_name=}")
+    #     dataset = LeRobotDataset(dataset_name)
+    #     item = next(iter(dataset))
+    #     print(item.keys())