fix
This commit is contained in:
parent
15e91b5905
commit
fa7e40a4a5
|
@ -34,8 +34,15 @@ def remove_episodes(dataset, episodes):
|
|||
repo_id = dataset.repo_id
|
||||
info = dataset.info
|
||||
hf_dataset = dataset.hf_dataset
|
||||
# TODO(rcadene): implement tags
|
||||
# if None, should use the same tags
|
||||
tags = None
|
||||
|
||||
local_dir = dataset.videos_dir.parent
|
||||
train_dir = local_dir / "train"
|
||||
new_train_dir = local_dir / "new_train"
|
||||
meta_data_dir = local_dir / "meta_data"
|
||||
|
||||
local_dir = dataset.videos_dir.parent / repo_id
|
||||
new_hf_dataset = hf_dataset.filter(lambda row: row["episode_index"] not in episodes)
|
||||
|
||||
unique_episode_idxs = torch.stack(new_hf_dataset["episode_index"]).unique().tolist()
|
||||
|
@ -70,28 +77,21 @@ def remove_episodes(dataset, episodes):
|
|||
info=info,
|
||||
videos_dir=dataset.videos_dir,
|
||||
)
|
||||
|
||||
stats = compute_stats(new_dataset)
|
||||
new_dataset.stats = stats
|
||||
|
||||
new_hf_dataset = new_hf_dataset.with_format(None) # to remove transforms that can't be saved
|
||||
|
||||
train_dir = local_dir / "train"
|
||||
new_train_dir = local_dir / "new_train"
|
||||
|
||||
new_hf_dataset.save_to_disk(str(new_train_dir))
|
||||
shutil.rmtree(train_dir)
|
||||
new_train_dir.rename(train_dir)
|
||||
|
||||
meta_data_dir = local_dir / "meta_data"
|
||||
save_meta_data(info, stats, episode_data_index, meta_data_dir)
|
||||
|
||||
new_hf_dataset.push_to_hub(repo_id, revision="main")
|
||||
push_meta_data_to_hub(repo_id, meta_data_dir, revision="main")
|
||||
tags = None
|
||||
push_dataset_card_to_hub(repo_id, revision="main", tags=tags)
|
||||
if new_dataset.video:
|
||||
push_videos_to_hub(repo_id, new_dataset.videos_dir, revision="main")
|
||||
if dataset.video:
|
||||
push_videos_to_hub(repo_id, dataset.videos_dir, revision="main")
|
||||
create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue