From 15e91b5905c84afff95137e6b0dc309994670d9f Mon Sep 17 00:00:00 2001
From: Remi Cadene
Date: Sat, 24 Aug 2024 18:16:37 +0200
Subject: [PATCH 1/2] Add edit_dataset.py with a `remove --episodes` command

---
 lerobot/scripts/edit_dataset.py | 132 ++++++++++++++++++++++++++++++++
 1 file changed, 132 insertions(+)
 create mode 100644 lerobot/scripts/edit_dataset.py

diff --git a/lerobot/scripts/edit_dataset.py b/lerobot/scripts/edit_dataset.py
new file mode 100644
index 00000000..4c7a6547
--- /dev/null
+++ b/lerobot/scripts/edit_dataset.py
@@ -0,0 +1,132 @@
+"""
+Edit your dataset in place.
+
+Example usage:
+```bash
+python lerobot/scripts/edit_dataset.py remove \
+    --root data \
+    --repo-id cadene/koch_bimanual_folding_2 \
+    --episodes 0 4 7 10 34 54 69
+```
+"""
+
+import argparse
+import shutil
+from pathlib import Path
+
+import torch
+
+from lerobot.common.datasets.compute_stats import compute_stats
+from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
+from lerobot.common.datasets.utils import calculate_episode_data_index, create_branch
+from lerobot.scripts.push_dataset_to_hub import (
+    push_dataset_card_to_hub,
+    push_meta_data_to_hub,
+    push_videos_to_hub,
+    save_meta_data,
+)
+
+
+def remove_episodes(dataset, episodes):
+    if not dataset.video:
+        raise NotImplementedError("Removing episodes is only supported for video datasets.")
+
+    repo_id = dataset.repo_id
+    info = dataset.info
+    hf_dataset = dataset.hf_dataset
+
+    local_dir = dataset.videos_dir.parent / repo_id
+    new_hf_dataset = hf_dataset.filter(lambda row: row["episode_index"] not in episodes)
+
+    unique_episode_idxs = torch.stack(new_hf_dataset["episode_index"]).unique().tolist()
+
+    episode_idx_to_reset_idx_mapping = {}
+    for new_ep_idx, ep_idx in enumerate(sorted(unique_episode_idxs)):
+        episode_idx_to_reset_idx_mapping[ep_idx] = new_ep_idx
+
+        for key in dataset.video_frame_keys:
+            path = dataset.videos_dir / f"{key}_episode_{ep_idx:06d}.mp4"
+            new_path = dataset.videos_dir / f"{key}_episode_{new_ep_idx:06d}.mp4"
+            path.rename(new_path)
+
+    def modify_ep_idx(row):
+        new_ep_idx = episode_idx_to_reset_idx_mapping[row["episode_index"].item()]
+
+        for key in dataset.video_frame_keys:
+            fname = f"{key}_episode_{new_ep_idx:06d}.mp4"
+            row[key]["path"] = f"videos/{fname}"
+
+        row["episode_index"] = new_ep_idx
+        return row
+
+    new_hf_dataset = new_hf_dataset.map(modify_ep_idx)
+
+    episode_data_index = calculate_episode_data_index(new_hf_dataset)
+
+    new_dataset = LeRobotDataset.from_preloaded(
+        repo_id=dataset.repo_id,
+        hf_dataset=new_hf_dataset,
+        episode_data_index=episode_data_index,
+        info=info,
+        videos_dir=dataset.videos_dir,
+    )
+
+    stats = compute_stats(new_dataset)
+    new_dataset.stats = stats
+
+    new_hf_dataset = new_hf_dataset.with_format(None)  # to remove transforms that can't be saved
+
+    train_dir = local_dir / "train"
+    new_train_dir = local_dir / "new_train"
+
+    new_hf_dataset.save_to_disk(str(new_train_dir))
+    shutil.rmtree(train_dir)
+    new_train_dir.rename(train_dir)
+
+    meta_data_dir = local_dir / "meta_data"
+    save_meta_data(info, stats, episode_data_index, meta_data_dir)
+
+    new_hf_dataset.push_to_hub(repo_id, revision="main")
+    push_meta_data_to_hub(repo_id, meta_data_dir, revision="main")
+    tags = None
+    push_dataset_card_to_hub(repo_id, revision="main", tags=tags)
+    if new_dataset.video:
+        push_videos_to_hub(repo_id, new_dataset.videos_dir, revision="main")
+    create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    subparsers = parser.add_subparsers(dest="mode", required=True)
+
+    # Set common options for all the subparsers
+    base_parser = argparse.ArgumentParser(add_help=False)
+    base_parser.add_argument(
+        "--root",
+        type=Path,
+        default="data",
+        help="Root directory where the dataset will be stored locally at '{root}/{repo_id}' (e.g. 'data/hf_username/dataset_name').",
+    )
+    base_parser.add_argument(
+        "--repo-id",
+        type=str,
+        default="lerobot/test",
+        help="Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).",
+    )
+
+    parser_remove = subparsers.add_parser("remove", parents=[base_parser])
+    parser_remove.add_argument(
+        "--episodes",
+        type=int,
+        nargs="+",
+        help="Episode indices to remove (e.g. `0 1 5 6`).",
+    )
+
+    args = parser.parse_args()
+
+    input("It is recommended to make a copy of your dataset before modifying it. Press Enter to continue.")
+
+    dataset = LeRobotDataset(args.repo_id, root=args.root)
+
+    if args.mode == "remove":
+        remove_episodes(dataset, args.episodes)
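For reference, the core of the patch above is the episode re-indexing performed by `remove_episodes`: after the removed episodes are filtered out, the remaining episode indices are compacted into a contiguous `0..N-1` range and the per-episode video files are renamed to match. A minimal standalone sketch of that scheme follows (toy index values and a hypothetical camera key, not part of the patch):

```python
# Sketch of the re-indexing scheme used by remove_episodes(); toy values only.
remaining = [1, 2, 3, 5]  # episode indices left after removing episodes 0 and 4

# Old index -> new contiguous index, as in episode_idx_to_reset_idx_mapping.
mapping = {ep_idx: new_ep_idx for new_ep_idx, ep_idx in enumerate(sorted(remaining))}
assert mapping == {1: 0, 2: 1, 3: 2, 5: 3}

# Videos follow the "{key}_episode_{index:06d}.mp4" naming, so each remaining
# episode's file is renamed from its old index to its new one.
key = "observation.images.cam_high"  # hypothetical camera key
renames = [
    (f"{key}_episode_{old:06d}.mp4", f"{key}_episode_{new:06d}.mp4")
    for old, new in mapping.items()
]
```
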
From fa7e40a4a578298db95f001a3471e6c479acfd0c Mon Sep 17 00:00:00 2001
From: Remi Cadene
Date: Sun, 25 Aug 2024 15:55:29 +0200
Subject: [PATCH 2/2] Fix local_dir path and consolidate tags and directory definitions

---
 lerobot/scripts/edit_dataset.py | 20 ++++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/lerobot/scripts/edit_dataset.py b/lerobot/scripts/edit_dataset.py
index 4c7a6547..8e2afcd8 100644
--- a/lerobot/scripts/edit_dataset.py
+++ b/lerobot/scripts/edit_dataset.py
@@ -34,8 +34,15 @@ def remove_episodes(dataset, episodes):
     repo_id = dataset.repo_id
     info = dataset.info
     hf_dataset = dataset.hf_dataset
+    # TODO(rcadene): implement tags
+    # If None, the dataset's existing tags should be reused.
+    tags = None
+
+    local_dir = dataset.videos_dir.parent
+    train_dir = local_dir / "train"
+    new_train_dir = local_dir / "new_train"
+    meta_data_dir = local_dir / "meta_data"
 
-    local_dir = dataset.videos_dir.parent / repo_id
     new_hf_dataset = hf_dataset.filter(lambda row: row["episode_index"] not in episodes)
 
     unique_episode_idxs = torch.stack(new_hf_dataset["episode_index"]).unique().tolist()
@@ -70,28 +77,21 @@ def remove_episodes(dataset, episodes):
         info=info,
         videos_dir=dataset.videos_dir,
     )
-
     stats = compute_stats(new_dataset)
-    new_dataset.stats = stats
 
     new_hf_dataset = new_hf_dataset.with_format(None)  # to remove transforms that can't be saved
 
-    train_dir = local_dir / "train"
-    new_train_dir = local_dir / "new_train"
-
     new_hf_dataset.save_to_disk(str(new_train_dir))
     shutil.rmtree(train_dir)
     new_train_dir.rename(train_dir)
 
-    meta_data_dir = local_dir / "meta_data"
     save_meta_data(info, stats, episode_data_index, meta_data_dir)
 
     new_hf_dataset.push_to_hub(repo_id, revision="main")
     push_meta_data_to_hub(repo_id, meta_data_dir, revision="main")
-    tags = None
     push_dataset_card_to_hub(repo_id, revision="main", tags=tags)
-    if new_dataset.video:
-        push_videos_to_hub(repo_id, new_dataset.videos_dir, revision="main")
+    if dataset.video:
+        push_videos_to_hub(repo_id, dataset.videos_dir, revision="main")
     create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION)
 
 
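The `local_dir` change in PATCH 2/2 is the key fix: `--root` and `--repo-id` imply a local layout of `{root}/{repo_id}` containing `train/`, `meta_data/`, and `videos/`, so the dataset root is `videos_dir.parent`, not `videos_dir.parent / repo_id`. A small illustrative sketch of the difference (paths are hypothetical, assuming that layout):

```python
from pathlib import Path

# Assumed local layout: {root}/{repo_id}/{train,meta_data,videos}
repo_id = "cadene/koch_bimanual_folding_2"
videos_dir = Path("data") / repo_id / "videos"

# PATCH 1/2 duplicated the repo id, pointing at a directory that does not exist:
#   data/cadene/koch_bimanual_folding_2/cadene/koch_bimanual_folding_2
print(videos_dir.parent / repo_id)

# PATCH 2/2 uses the dataset root, which actually holds train/ and meta_data/:
#   data/cadene/koch_bimanual_folding_2
print(videos_dir.parent)
```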