diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py
index e2b65a19..4e100d1f 100644
--- a/lerobot/common/datasets/lerobot_dataset.py
+++ b/lerobot/common/datasets/lerobot_dataset.py
@@ -31,9 +31,7 @@ from lerobot.common.datasets.utils import (
     get_episode_data_index,
     get_hub_safe_version,
     load_hf_dataset,
-    load_info,
-    load_stats,
-    load_tasks,
+    load_metadata,
 )
 from lerobot.common.datasets.video_utils import VideoFrame, decode_video_frames_torchvision
 
@@ -41,6 +39,12 @@ from lerobot.common.datasets.video_utils import VideoFrame, decode_video_frames_
 CODEBASE_VERSION = "v2.0"
 LEROBOT_HOME = Path(os.getenv("LEROBOT_HOME", "~/.cache/huggingface/lerobot")).expanduser()
 
+DEFAULT_CHUNK_SIZE = 1000
+DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
+DEFAULT_PARQUET_PATH = (
+    "data/chunk-{episode_chunk:03d}/train-{episode_index:05d}-of-{total_episodes:05d}.parquet"
+)
+
 
 class LeRobotDataset(torch.utils.data.Dataset):
     def __init__(
@@ -70,7 +74,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
             Instantiating this class with this 'repo_id' will download the dataset from that address and load
             it, pending your dataset is compliant with codebase_version v2.0. If your dataset has been created
             before this new format, you will be prompted to convert it using our conversion script from v1.6
-            to v2.0, which you can find at [TODO(aliberts): move conversion script & add location here].
+            to v2.0, which you can find at lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py.
 
         2. Your dataset already exists on your local disk in the 'root' folder:
             This is typically the case when you recorded your dataset locally and you may or may not have
@@ -139,7 +143,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
                 timestamps is separated to the next by 1/fps +/- tolerance_s. This also applies to frames
                 decoded from video files. It is also used to check that `delta_timestamps` (when provided)
                 are multiples of 1/fps. Defaults to 1e-4.
-            download_videos (bool, optional): Flag to download the videos. Defaults to True.
+            download_videos (bool, optional): Flag to download the videos. Note that when set to True but the
+                video files are already present on local disk, they won't be downloaded again. Defaults to
+                True.
             video_backend (str | None, optional): Video backend to use for decoding videos. There is
                 currently a single option which is the pyav decoder used by Torchvision. Defaults to pyav.
         """
@@ -157,9 +163,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
         # Load metadata
         self.root.mkdir(exist_ok=True, parents=True)
         self._version = get_hub_safe_version(repo_id, CODEBASE_VERSION)
-        self.info = load_info(repo_id, self._version, self.root)
-        self.stats = load_stats(repo_id, self._version, self.root)
-        self.tasks = load_tasks(repo_id, self._version, self.root)
+        self.download_metadata()
+        self.info, self.episode_dicts, self.stats, self.tasks = load_metadata(self.root)
 
         # Load actual data
         self.download_episodes()
@@ -185,6 +190,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
         # - [ ] Update episode_index (arg update=True)
         # - [ ] Update info.json (arg update=True)
 
+    def download_metadata(self) -> None:
+        snapshot_download(
+            self.repo_id,
+            repo_type="dataset",
+            revision=self._version,
+            local_dir=self.root,
+            allow_patterns="meta/",
+        )
+
     def download_episodes(self) -> None:
         """Downloads the dataset from the given 'repo_id' at the provided version. If 'episodes' is given, this
         will only download those episodes (selected by their episode_index). If 'episodes' is None, the whole
@@ -227,11 +241,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
         """Formattable string for the video files."""
         return self.info["videos"]["videos_path"] if len(self.video_keys) > 0 else None
 
-    @property
-    def episode_dicts(self) -> list[dict]:
-        """List of dictionary containing information for each episode, indexed by episode_index."""
-        return self.info["episodes"]
-
     @property
     def fps(self) -> int:
         """Frames per second used during data collection."""
@@ -254,7 +263,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
 
     @property
     def camera_keys(self) -> list[str]:
-        """Keys to access image and video streams from cameras (regardless of their storage method)."""
+        """Keys to access visual modalities (regardless of their storage method)."""
         return self.image_keys + self.video_keys
 
     @property
@@ -277,6 +286,16 @@ class LeRobotDataset(torch.utils.data.Dataset):
         """Total number of episodes available."""
         return self.info["total_episodes"]
 
+    @property
+    def total_chunks(self) -> int:
+        """Total number of chunks (groups of episodes)."""
+        return self.info["total_chunks"]
+
+    @property
+    def chunks_size(self) -> int:
+        """Max number of episodes per chunk."""
+        return self.info["chunks_size"]
+
     @property
     def shapes(self) -> dict:
         """Shapes for the different features."""
@@ -397,42 +416,28 @@ class LeRobotDataset(torch.utils.data.Dataset):
         )
 
     @classmethod
-    def from_preloaded(
+    def create(
         cls,
-        repo_id: str = "from_preloaded",
+        repo_id: str,
         root: Path | None = None,
-        split: str = "train",
-        transform: callable = None,
+        image_transforms: Callable | None = None,
         delta_timestamps: dict[list[float]] | None = None,
-        # additional preloaded attributes
-        hf_dataset=None,
-        episode_data_index=None,
-        stats=None,
-        info=None,
-        videos_dir=None,
-        video_backend=None,
+        tolerance_s: float = 1e-4,
+        video_backend: str | None = None,
    ) -> "LeRobotDataset":
-        """Create a LeRobot Dataset from existing data and attributes instead of loading from the filesystem.
-
-        It is especially useful when converting raw data into LeRobotDataset before saving the dataset
-        on the filesystem or uploading to the hub.
-
-        Note: Meta-data attributes like `repo_id`, `version`, `root`, etc are optional and potentially
-        meaningless depending on the downstream usage of the return dataset.
-        """
+        """Create a LeRobot Dataset from scratch in order to record data."""
         # create an empty object of type LeRobotDataset
         obj = cls.__new__(cls)
         obj.repo_id = repo_id
-        obj.root = root
-        obj.split = split
-        obj.image_transforms = transform
-        obj.delta_timestamps = delta_timestamps
-        obj.hf_dataset = hf_dataset
-        obj.episode_data_index = episode_data_index
-        obj.stats = stats
-        obj.info = info if info is not None else {}
-        obj.videos_dir = videos_dir
-        obj.video_backend = video_backend if video_backend is not None else "pyav"
+        obj.root = root if root is not None else LEROBOT_HOME / repo_id
+        # obj.episodes = None
+        # obj.image_transforms = None
+        # obj.delta_timestamps = None
+        # obj.episode_data_index = episode_data_index
+        # obj.stats = stats
+        # obj.info = info if info is not None else {}
+        # obj.videos_dir = videos_dir
+        # obj.video_backend = video_backend if video_backend is not None else "pyav"
         return obj
diff --git a/lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py b/lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py
index 36113fb1..a498f9c1 100644
--- a/lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py
+++ b/lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py
@@ -120,6 +120,11 @@ from huggingface_hub.errors import EntryNotFoundError
 from PIL import Image
 from safetensors.torch import load_file
 
+from lerobot.common.datasets.lerobot_dataset import (
+    DEFAULT_CHUNK_SIZE,
+    DEFAULT_PARQUET_PATH,
+    DEFAULT_VIDEO_PATH,
+)
 from lerobot.common.datasets.utils import create_branch, flatten_dict, get_hub_safe_version, unflatten_dict
 from lerobot.common.utils.utils import init_hydra_config
 from lerobot.scripts.push_dataset_to_hub import push_dataset_card_to_hub
@@ -127,15 +132,8 @@ from lerobot.scripts.push_dataset_to_hub import push_dataset_card_to_hub
 V16 = "v1.6"
 V20 = "v2.0"
 
-EPISODE_CHUNK_SIZE = 1000
-
 GITATTRIBUTES_REF = "aliberts/gitattributes_reference"
-
 VIDEO_FILE = "{video_key}_episode_{episode_index:06d}.mp4"
-PARQUET_CHUNK_PATH = (
-    "data/chunk-{episode_chunk:03d}/train-{episode_index:05d}-of-{total_episodes:05d}.parquet"
-)
-VIDEO_CHUNK_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
 
 
 def parse_robot_config(config_path: Path, config_overrides: list[str] | None = None) -> tuple[str, dict]:
@@ -269,15 +267,15 @@ def split_parquet_by_episodes(
     table = dataset.remove_columns(keys["video"])._data.table
     episode_lengths = []
     for ep_chunk in range(total_chunks):
-        ep_chunk_start = EPISODE_CHUNK_SIZE * ep_chunk
-        ep_chunk_end = min(EPISODE_CHUNK_SIZE * (ep_chunk + 1), total_episodes)
+        ep_chunk_start = DEFAULT_CHUNK_SIZE * ep_chunk
+        ep_chunk_end = min(DEFAULT_CHUNK_SIZE * (ep_chunk + 1), total_episodes)
 
-        chunk_dir = "/".join(PARQUET_CHUNK_PATH.split("/")[:-1]).format(episode_chunk=ep_chunk)
+        chunk_dir = "/".join(DEFAULT_PARQUET_PATH.split("/")[:-1]).format(episode_chunk=ep_chunk)
         (output_dir / chunk_dir).mkdir(parents=True, exist_ok=True)
         for ep_idx in range(ep_chunk_start, ep_chunk_end):
             ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
             episode_lengths.insert(ep_idx, len(ep_table))
-            output_file = output_dir / PARQUET_CHUNK_PATH.format(
+            output_file = output_dir / DEFAULT_PARQUET_PATH.format(
                 episode_chunk=ep_chunk, episode_index=ep_idx, total_episodes=total_episodes
             )
             pq.write_table(ep_table, output_file)
@@ -323,16 +321,16 @@ def move_videos(
 
     video_dirs = sorted(work_dir.glob("videos*/"))
 
     for ep_chunk in range(total_chunks):
-        ep_chunk_start = EPISODE_CHUNK_SIZE * ep_chunk
-        ep_chunk_end = min(EPISODE_CHUNK_SIZE * (ep_chunk + 1), total_episodes)
+        ep_chunk_start = DEFAULT_CHUNK_SIZE * ep_chunk
+        ep_chunk_end = min(DEFAULT_CHUNK_SIZE * (ep_chunk + 1), total_episodes)
         for vid_key in video_keys:
-            chunk_dir = "/".join(VIDEO_CHUNK_PATH.split("/")[:-1]).format(
+            chunk_dir = "/".join(DEFAULT_VIDEO_PATH.split("/")[:-1]).format(
                 episode_chunk=ep_chunk, video_key=vid_key
             )
             (work_dir / chunk_dir).mkdir(parents=True, exist_ok=True)
             for ep_idx in range(ep_chunk_start, ep_chunk_end):
-                target_path = VIDEO_CHUNK_PATH.format(
+                target_path = DEFAULT_VIDEO_PATH.format(
                     episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_idx
                 )
                 video_file = VIDEO_FILE.format(video_key=vid_key, episode_index=ep_idx)
@@ -476,11 +474,12 @@ def _get_video_info(video_path: Path | str) -> dict:
 
 def get_videos_info(repo_id: str, local_dir: Path, video_keys: list[str], branch: str) -> dict:
     hub_api = HfApi()
-    videos_info_dict = {"videos_path": VIDEO_CHUNK_PATH}
+    videos_info_dict = {"videos_path": DEFAULT_VIDEO_PATH}
 
     # Assumes first episode
     video_files = [
-        VIDEO_CHUNK_PATH.format(episode_chunk=0, video_key=vid_key, episode_index=0) for vid_key in video_keys
+        DEFAULT_VIDEO_PATH.format(episode_chunk=0, video_key=vid_key, episode_index=0)
+        for vid_key in video_keys
     ]
     hub_api.snapshot_download(
         repo_id=repo_id, repo_type="dataset", local_dir=local_dir, revision=branch, allow_patterns=video_files
@@ -587,8 +586,8 @@ def convert_dataset(
     total_episodes = len(episode_indices)
     assert episode_indices == list(range(total_episodes))
     total_videos = total_episodes * len(keys["video"])
-    total_chunks = total_episodes // EPISODE_CHUNK_SIZE
-    if total_episodes % EPISODE_CHUNK_SIZE != 0:
+    total_chunks = total_episodes // DEFAULT_CHUNK_SIZE
+    if total_episodes % DEFAULT_CHUNK_SIZE != 0:
         total_chunks += 1
 
     # Tasks
@@ -670,14 +669,14 @@ def convert_dataset(
     # Assemble metadata v2.0
     metadata_v2_0 = {
         "codebase_version": V20,
-        "data_path": PARQUET_CHUNK_PATH,
+        "data_path": DEFAULT_PARQUET_PATH,
         "robot_type": robot_type,
         "total_episodes": total_episodes,
         "total_frames": len(dataset),
        "total_tasks": len(tasks),
         "total_videos": total_videos,
         "total_chunks": total_chunks,
-        "chunks_size": EPISODE_CHUNK_SIZE,
+        "chunks_size": DEFAULT_CHUNK_SIZE,
         "fps": metadata_v1["fps"],
         "splits": {"train": f"0:{total_episodes}"},
         "keys": keys["sequence"],
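
Below is a minimal, illustrative sketch (not part of the patch) of how the new chunked layout resolves file paths. DEFAULT_CHUNK_SIZE, DEFAULT_VIDEO_PATH and DEFAULT_PARQUET_PATH are copied from the diff above; the episode counts and the video key are hypothetical, and the episode_chunk = episode_index // chunks_size mapping is inferred from the conversion loops.

# Path templates introduced in lerobot_dataset.py by this diff.
DEFAULT_CHUNK_SIZE = 1000
DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
DEFAULT_PARQUET_PATH = (
    "data/chunk-{episode_chunk:03d}/train-{episode_index:05d}-of-{total_episodes:05d}.parquet"
)

total_episodes = 2050  # hypothetical dataset size
episode_index = 1234   # hypothetical episode

# Chunk holding this episode, as implied by the ep_chunk_start/ep_chunk_end loops.
episode_chunk = episode_index // DEFAULT_CHUNK_SIZE  # -> 1

# Same rounding-up arithmetic as convert_dataset().
total_chunks = total_episodes // DEFAULT_CHUNK_SIZE
if total_episodes % DEFAULT_CHUNK_SIZE != 0:
    total_chunks += 1  # -> 3

parquet_path = DEFAULT_PARQUET_PATH.format(
    episode_chunk=episode_chunk, episode_index=episode_index, total_episodes=total_episodes
)
# -> "data/chunk-001/train-01234-of-02050.parquet"

video_path = DEFAULT_VIDEO_PATH.format(
    episode_chunk=episode_chunk, video_key="observation.images.cam_high", episode_index=episode_index
)
# -> "videos/chunk-001/observation.images.cam_high/episode_001234.mp4"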