This commit is contained in:
Remi Cadene 2024-08-25 15:55:29 +02:00
parent 15e91b5905
commit fa7e40a4a5
1 changed file with 10 additions and 10 deletions

View File

@ -34,8 +34,15 @@ def remove_episodes(dataset, episodes):
repo_id = dataset.repo_id repo_id = dataset.repo_id
info = dataset.info info = dataset.info
hf_dataset = dataset.hf_dataset hf_dataset = dataset.hf_dataset
# TODO(rcadene): implement tags
# if None, should use the same tags
tags = None
local_dir = dataset.videos_dir.parent
train_dir = local_dir / "train"
new_train_dir = local_dir / "new_train"
meta_data_dir = local_dir / "meta_data"
local_dir = dataset.videos_dir.parent / repo_id
new_hf_dataset = hf_dataset.filter(lambda row: row["episode_index"] not in episodes) new_hf_dataset = hf_dataset.filter(lambda row: row["episode_index"] not in episodes)
unique_episode_idxs = torch.stack(new_hf_dataset["episode_index"]).unique().tolist() unique_episode_idxs = torch.stack(new_hf_dataset["episode_index"]).unique().tolist()
@ -70,28 +77,21 @@ def remove_episodes(dataset, episodes):
info=info, info=info,
videos_dir=dataset.videos_dir, videos_dir=dataset.videos_dir,
) )
stats = compute_stats(new_dataset) stats = compute_stats(new_dataset)
new_dataset.stats = stats
new_hf_dataset = new_hf_dataset.with_format(None) # to remove transforms that cant be saved new_hf_dataset = new_hf_dataset.with_format(None) # to remove transforms that cant be saved
train_dir = local_dir / "train"
new_train_dir = local_dir / "new_train"
new_hf_dataset.save_to_disk(str(new_train_dir)) new_hf_dataset.save_to_disk(str(new_train_dir))
shutil.rmtree(train_dir) shutil.rmtree(train_dir)
new_train_dir.rename(train_dir) new_train_dir.rename(train_dir)
meta_data_dir = local_dir / "meta_data"
save_meta_data(info, stats, episode_data_index, meta_data_dir) save_meta_data(info, stats, episode_data_index, meta_data_dir)
new_hf_dataset.push_to_hub(repo_id, revision="main") new_hf_dataset.push_to_hub(repo_id, revision="main")
push_meta_data_to_hub(repo_id, meta_data_dir, revision="main") push_meta_data_to_hub(repo_id, meta_data_dir, revision="main")
tags = None
push_dataset_card_to_hub(repo_id, revision="main", tags=tags) push_dataset_card_to_hub(repo_id, revision="main", tags=tags)
if new_dataset.video: if dataset.video:
push_videos_to_hub(repo_id, new_dataset.videos_dir, revision="main") push_videos_to_hub(repo_id, dataset.videos_dir, revision="main")
create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION) create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION)