Add clone, delete, WIP on remove_episode, drop_frame
This commit is contained in:
parent
ebe0bfad77
commit
49ae3e19e1
|
@ -964,6 +964,35 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
obj.video_backend = video_backend if video_backend is not None else "pyav"
|
obj.video_backend = video_backend if video_backend is not None else "pyav"
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
|
def clone(self, new_repo_id: str, new_root: str | Path | None = None) -> "LeRobotDataset":
    """Create a new dataset that reuses this dataset's configuration.

    The new dataset is built with `LeRobotDataset.create`, forwarding this
    dataset's fps, robot, robot type, features, video usage, timestamp
    tolerance, image-writer settings and video backend.

    Args:
        new_repo_id: Identifier given to the new dataset.
        new_root: Root directory handed to `create` for the new dataset;
            `None` is forwarded unchanged.

    Returns:
        Whatever `LeRobotDataset.create` returns for these settings.

    NOTE(review): only configuration is forwarded here; nothing in this
    method copies episodes or frames -- confirm whether that is intended.
    """
    create_kwargs = {
        "repo_id": new_repo_id,
        "fps": self.fps,
        "root": new_root,
        "robot": self.robot,
        "robot_type": self.robot_type,
        "features": self.features,
        "use_videos": self.use_videos,
        "tolerance_s": self.tolerance_s,
        "image_writer_processes": self.image_writer_processes,
        "image_writer_threads": self.image_writer_threads,
        "video_backend": self.video_backend,
    }
    return LeRobotDataset.create(**create_kwargs)
|
||||||
|
|
||||||
|
def delete(self):
    """Remove the dataset's local root directory and everything under it.

    Only local files are touched: if the dataset was pushed to the hub, it
    can still be retrieved by downloading it again.
    """
    local_root = self.root
    shutil.rmtree(local_root)
|
||||||
|
|
||||||
|
def remove_episode(self, episode: int | list[int]):
|
||||||
|
if isinstance(episode, int):
|
||||||
|
episode = [episode]
|
||||||
|
|
||||||
|
for ep in episode:
|
||||||
|
self.meta.info
|
||||||
|
|
||||||
|
def drop_frame(self, episode_range: dict[int, list[tuple[int, int]]]):
    """Placeholder for dropping frame ranges from episodes (not implemented).

    Args:
        episode_range: Maps an episode index to a list of `(start, end)`
            frame-index pairs, i.e. the structure produced by
            `parse_episode_range_strings` in the dataset-management script.

    Currently a no-op that returns `None`.
    """
    # TODO: implement the actual frame removal.
    pass
|
||||||
|
|
||||||
|
|
||||||
class MultiLeRobotDataset(torch.utils.data.Dataset):
|
class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||||
"""A dataset consisting of multiple underlying `LeRobotDataset`s.
|
"""A dataset consisting of multiple underlying `LeRobotDataset`s.
|
||||||
|
|
|
@ -28,6 +28,28 @@ from pathlib import Path
|
||||||
|
|
||||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||||
|
|
||||||
|
|
||||||
|
def parse_episode_range_string(ep_range_str):
    """Parse one 'EP-FROM-TO' string into an `(episode, start, end)` tuple.

    Args:
        ep_range_str: String made of three '-'-separated integers,
            e.g. '1-5-10'.

    Returns:
        A tuple `(episode, start, end)` of ints.

    Raises:
        ValueError: If the string does not contain exactly three fields, or
            if any field is not a valid integer.
    """
    error_msg = (
        f"Invalid episode range string '{ep_range_str}'. Expected format: 'EP-FROM-TO', e.g., '1-5-10'."
    )

    parts = ep_range_str.split("-")
    if len(parts) != 3:
        raise ValueError(error_msg)

    try:
        ep, start, end = (int(part) for part in parts)
    except ValueError as err:
        # Report non-numeric fields with the same descriptive message rather
        # than the bare `int()` error, so the offending argument is named.
        raise ValueError(error_msg) from err

    return ep, start, end
|
||||||
|
|
||||||
|
|
||||||
|
def parse_episode_range_strings(ep_range_strings):
    """Group several 'EP-FROM-TO' strings by episode.

    Args:
        ep_range_strings: Iterable of strings accepted by
            `parse_episode_range_string`, or `None` (what argparse yields
            for an omitted `nargs="*"` option), which is treated as empty.

    Returns:
        A dict mapping each episode index to the list of its `(start, end)`
        pairs, in the order they were given.

    Raises:
        ValueError: Propagated from `parse_episode_range_string` for a
            malformed entry.
    """
    ep_ranges = {}
    # `None` means the option was not supplied: nothing to parse.
    for ep_range_str in ep_range_strings or ():
        ep, start, end = parse_episode_range_string(ep_range_str)
        ep_ranges.setdefault(ep, []).append((start, end))

    return ep_ranges
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
subparsers = parser.add_subparsers(dest="mode", required=True)
|
subparsers = parser.add_subparsers(dest="mode", required=True)
|
||||||
|
@ -38,7 +60,7 @@ if __name__ == "__main__":
|
||||||
"--root",
|
"--root",
|
||||||
type=Path,
|
type=Path,
|
||||||
default=None,
|
default=None,
|
||||||
help="Root directory where the dataset will be stored (e.g. 'dataset/path').",
|
help="Root directory where the dataset is stored (e.g. 'dataset/path').",
|
||||||
)
|
)
|
||||||
base_parser.add_argument(
|
base_parser.add_argument(
|
||||||
"--repo-id",
|
"--repo-id",
|
||||||
|
@ -53,7 +75,6 @@ if __name__ == "__main__":
|
||||||
help="Use local files only. By default, this script will try to fetch the dataset from the hub if it exists.",
|
help="Use local files only. By default, this script will try to fetch the dataset from the hub if it exists.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
############################################################################
|
############################################################################
|
||||||
# consolidate
|
# consolidate
|
||||||
parser_conso = subparsers.add_parser("consolidate", parents=[base_parser])
|
parser_conso = subparsers.add_parser("consolidate", parents=[base_parser])
|
||||||
|
@ -93,6 +114,45 @@ if __name__ == "__main__":
|
||||||
help="Create a private dataset repository on the Hugging Face Hub. Push publicly by default.",
|
help="Create a private dataset repository on the Hugging Face Hub. Push publicly by default.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
############################################################################
|
||||||
|
# clone
|
||||||
|
parser_clone = subparsers.add_parser("clone", parents=[base_parser])
# Use `--new-root` rather than `--root`: the parent parser is expected to
# define `--root` already (re-declaring an inherited option string makes
# argparse raise a conflict error), and the resulting `new_root` dest
# matches the keyword expected by `LeRobotDataset.clone`.
parser_clone.add_argument(
    "--new-root",
    type=Path,
    default=None,
    help="New root directory where the cloned dataset will be stored (e.g. 'dataset/path').",
)
parser_clone.add_argument(
    "--new-repo-id",
    type=str,
    # `clone` has no default for `new_repo_id`, so fail early at the CLI.
    required=True,
    help="New dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).",
)
|
||||||
|
|
||||||
|
############################################################################
|
||||||
|
# delete
|
||||||
|
parser_del = subparsers.add_parser("delete", parents=[base_parser])
|
||||||
|
|
||||||
|
############################################################################
|
||||||
|
# remove_episode
|
||||||
|
parser_rm_ep = subparsers.add_parser("remove_episode", parents=[base_parser])
|
||||||
|
parser_rm_ep.add_argument(
|
||||||
|
"--episode",
|
||||||
|
type=int,
|
||||||
|
nargs="*",
|
||||||
|
help="List of one or several episodes to be removed from the dataset locally.",
|
||||||
|
)
|
||||||
|
|
||||||
|
############################################################################
|
||||||
|
# drop_frame
|
||||||
|
parser_drop_frame = subparsers.add_parser("drop_frame", parents=[base_parser])
# The option must be registered on the `drop_frame` subparser (not on the
# `remove_episode` one) so that `episode_range` is present when running in
# `drop_frame` mode and absent from `remove_episode`'s arguments.
parser_drop_frame.add_argument(
    "--episode-range",
    type=str,
    nargs="*",
    help="List of one or several frame ranges per episode to be removed from the dataset locally. For instance, using `--episode-range 0-0-10 3-5-20` will remove from episode 0, the frames from indices 0 to 10 excluded, and from episode 3 the frames from indices 5 to 20.",
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
kwargs = vars(args)
|
kwargs = vars(args)
|
||||||
|
|
||||||
|
@ -114,15 +174,15 @@ if __name__ == "__main__":
|
||||||
private = kwargs.pop("private") == 1
|
private = kwargs.pop("private") == 1
|
||||||
dataset.push_to_hub(private=private, **kwargs)
|
dataset.push_to_hub(private=private, **kwargs)
|
||||||
|
|
||||||
|
elif mode == "clone":
|
||||||
|
dataset.clone(**kwargs)
|
||||||
|
|
||||||
|
elif mode == "delete":
|
||||||
|
dataset.delete(**kwargs)
|
||||||
|
|
||||||
elif mode == "remove_episode":
|
elif mode == "remove_episode":
|
||||||
remove_episode(**kwargs)
|
dataset.remove_episode(**kwargs)
|
||||||
|
|
||||||
elif mode == "delete_dataset":
|
elif mode == "drop_frame":
|
||||||
delete_dataset()
|
ep_range = parse_episode_range_strings(kwargs.pop("episode_range"))
|
||||||
|
dataset.drop_frame(episode_range=ep_range, **kwargs)
|
||||||
elif mode == "_episode":
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue