diff --git a/lerobot/scripts/edit_dataset.py b/lerobot/scripts/edit_dataset.py index 4c7a6547..8e2afcd8 100644 --- a/lerobot/scripts/edit_dataset.py +++ b/lerobot/scripts/edit_dataset.py @@ -34,8 +34,15 @@ def remove_episodes(dataset, episodes): repo_id = dataset.repo_id info = dataset.info hf_dataset = dataset.hf_dataset + # TODO(rcadene): implement tags + # if None, should use the same tags + tags = None + + local_dir = dataset.videos_dir.parent + train_dir = local_dir / "train" + new_train_dir = local_dir / "new_train" + meta_data_dir = local_dir / "meta_data" - local_dir = dataset.videos_dir.parent / repo_id new_hf_dataset = hf_dataset.filter(lambda row: row["episode_index"] not in episodes) unique_episode_idxs = torch.stack(new_hf_dataset["episode_index"]).unique().tolist() @@ -70,28 +77,21 @@ def remove_episodes(dataset, episodes): info=info, videos_dir=dataset.videos_dir, ) - stats = compute_stats(new_dataset) - new_dataset.stats = stats new_hf_dataset = new_hf_dataset.with_format(None) # to remove transforms that cant be saved - train_dir = local_dir / "train" - new_train_dir = local_dir / "new_train" - new_hf_dataset.save_to_disk(str(new_train_dir)) shutil.rmtree(train_dir) new_train_dir.rename(train_dir) - meta_data_dir = local_dir / "meta_data" save_meta_data(info, stats, episode_data_index, meta_data_dir) new_hf_dataset.push_to_hub(repo_id, revision="main") push_meta_data_to_hub(repo_id, meta_data_dir, revision="main") - tags = None push_dataset_card_to_hub(repo_id, revision="main", tags=tags) - if new_dataset.video: - push_videos_to_hub(repo_id, new_dataset.videos_dir, revision="main") + if dataset.video: + push_videos_to_hub(repo_id, dataset.videos_dir, revision="main") create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION)