diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py
index 057e4770..1a4a982e 100644
--- a/lerobot/common/datasets/lerobot_dataset.py
+++ b/lerobot/common/datasets/lerobot_dataset.py
@@ -198,3 +198,52 @@ class LeRobotDataset(torch.utils.data.Dataset):
         obj.info = info
         obj.videos_dir = videos_dir
         return obj
+
+
+def concatenate_datasets(datasets: list[LeRobotDataset]) -> LeRobotDataset:
+    """Take a list of datasets and concatenate them to form one dataset.
+
+    The resulting dataset is reindexed, starting from zero, and preserving the ordering of the provided datasets.
+
+    All datasets are expected to have at least "action" and "observation.state" keys. "observation.image" keys can vary
+    between datasets. Where keys match across datasets, it is expected that the data has a common shape. Note: in future
+    iterations of LeRobot, some of these limitations may be relaxed.
+
+    TODO: not implemented yet; the body is only this docstring, so a call currently returns None.
+    """
+
+
+if __name__ == "__main__":
+    sim_datasets = [
+        "lerobot/aloha_sim_insertion_human",
+        "lerobot/aloha_sim_insertion_scripted",
+        "lerobot/aloha_sim_transfer_cube_human",
+        "lerobot/aloha_sim_transfer_cube_scripted",
+    ]
+
+    real_datasets = [
+        "lerobot/aloha_static_battery",
+        "lerobot/aloha_static_candy",
+        "lerobot/aloha_static_coffee",
+        "lerobot/aloha_static_coffee_new",
+        "lerobot/aloha_static_cups_open",
+        "lerobot/aloha_static_fork_pick_up",
+        "lerobot/aloha_static_pingpong_test",
+        "lerobot/aloha_static_pro_pencil",
+        "lerobot/aloha_static_screw_driver",
+        "lerobot/aloha_static_tape",
+        "lerobot/aloha_static_thread_velcro",
+        "lerobot/aloha_static_towel",
+        "lerobot/aloha_static_vinh_cup",
+        "lerobot/aloha_static_vinh_cup_left",
+        "lerobot/aloha_static_ziploc_slide",
+    ]
+
+    # Build one dataset per name (a list comprehension), rather than handing the constructor a generator of names.
+    concatenate_datasets([LeRobotDataset(dataset_name) for dataset_name in real_datasets])
+
+    # for dataset_name in real_datasets:
+    #     print(f"{dataset_name=}")
+    #     dataset = LeRobotDataset(dataset_name)
+    #     item = next(iter(dataset))
+    #     print(item.keys())