Commit before episodes episodes_stats merging

This commit is contained in:
Remi Cadene 2025-04-09 15:20:15 +02:00
parent 53ecec5fb2
commit c1b28f0b58
12 changed files with 905 additions and 396 deletions

View File

@ -26,7 +26,7 @@ from datatrove.pipeline.base import PipelineStep
 from examples.port_datasets.droid_rlds.port_droid import DROID_SHARDS
 from lerobot.common.datasets.aggregate import validate_all_metadata
 from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
-from lerobot.common.datasets.utils import write_episode, write_episode_stats, write_info, write_task
+from lerobot.common.datasets.utils import write_episode, legacy_write_episode_stats, write_info, legacy_write_task
 from lerobot.common.utils.utils import init_logging
@ -124,11 +124,11 @@ class AggregateDatasets(PipelineStep):
         for episode_index, episode_stats in tqdm.tqdm(
             aggr_meta.episodes_stats.items(), desc="Write episodes stats"
         ):
-            write_episode_stats(episode_index, episode_stats, aggr_meta.root)
+            legacy_write_episode_stats(episode_index, episode_stats, aggr_meta.root)

         # create a new task jsonl with updated episode_index using write_task
         for task_index, task in tqdm.tqdm(aggr_meta.tasks.items(), desc="Write tasks"):
-            write_task(task_index, task, aggr_meta.root)
+            legacy_write_task(task_index, task, aggr_meta.root)

         write_info(aggr_meta.info, aggr_meta.root)

View File

@ -5,7 +5,7 @@ import pandas as pd
 import tqdm

 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
-from lerobot.common.datasets.utils import write_episode, write_episode_stats, write_info, write_task
+from lerobot.common.datasets.utils import write_episode, legacy_write_episode_stats, write_info, legacy_write_task
 from lerobot.common.utils.utils import init_logging
@ -136,11 +136,11 @@ def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, aggr_root=None):
     # create a new episode_stats jsonl with updated episode_index using write_episode_stats
     for episode_index, episode_stats in aggr_meta.episodes_stats.items():
-        write_episode_stats(episode_index, episode_stats, aggr_meta.root)
+        legacy_write_episode_stats(episode_index, episode_stats, aggr_meta.root)

     # create a new task jsonl with updated episode_index using write_task
     for task_index, task in aggr_meta.tasks.items():
-        write_task(task_index, task, aggr_meta.root)
+        legacy_write_task(task_index, task, aggr_meta.root)

     write_info(aggr_meta.info, aggr_meta.root)

View File

@ -33,6 +33,18 @@ If you encounter a problem, contact LeRobot maintainers on [Discord](https://dis
 or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
 """

+V30_MESSAGE = """
+The dataset you requested ({repo_id}) is in {version} format.
+While the current version of LeRobot is backward-compatible with it, your dataset still uses global
+stats instead of per-episode stats. Update your dataset stats to the new format using this command:
+```
+python lerobot/common/datasets/v30/convert_dataset_v21_to_v30.py --repo-id={repo_id}
+```
+If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb)
+or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
+"""
+
 FUTURE_MESSAGE = """
 The dataset you requested ({repo_id}) is only available in {version} format.
 As we cannot ensure forward compatibility with it, please update your current version of lerobot.
@ -44,7 +56,12 @@ class CompatibilityError(Exception): ...

 class BackwardCompatibilityError(CompatibilityError):
     def __init__(self, repo_id: str, version: packaging.version.Version):
-        message = V2_MESSAGE.format(repo_id=repo_id, version=version)
+        if version.major == 3:
+            message = V30_MESSAGE.format(repo_id=repo_id, version=version)
+        elif version.major == 2:
+            message = V2_MESSAGE.format(repo_id=repo_id, version=version)
+        else:
+            raise NotImplementedError("Contact the maintainer on [Discord](https://discord.com/invite/s3KuuzsPFb).")
         super().__init__(message)
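For reference, a minimal sketch (not part of the diff) of how `packaging` exposes the major version component this dispatch relies on, assuming version tags shaped like the `v2.1`/`v3.0` strings used elsewhere in the codebase:

```
import packaging.version

# PEP 440 parsing accepts a leading "v", so stored tags like "v2.1" work directly.
for tag in ["v2.1", "v3.0"]:
    version = packaging.version.parse(tag)
    print(tag, version.major)  # -> 2 for v2.1, 3 for v3.0
```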

View File

@ -17,15 +17,16 @@ import contextlib
import logging import logging
import shutil import shutil
from pathlib import Path from pathlib import Path
import tempfile
from typing import Callable from typing import Callable
import datasets import datasets
import numpy as np import numpy as np
import packaging.version import packaging.version
import PIL.Image import PIL.Image
import pandas as pd
import torch import torch
import torch.utils import torch.utils
from datasets import concatenate_datasets, load_dataset from datasets import concatenate_datasets, load_dataset, Dataset
from huggingface_hub import HfApi, snapshot_download from huggingface_hub import HfApi, snapshot_download
from huggingface_hub.constants import REPOCARD_NAME from huggingface_hub.constants import REPOCARD_NAME
from huggingface_hub.errors import RevisionNotFoundError from huggingface_hub.errors import RevisionNotFoundError
@ -34,37 +35,57 @@ from lerobot.common.constants import HF_LEROBOT_HOME
from lerobot.common.datasets.compute_stats import aggregate_stats, compute_episode_stats from lerobot.common.datasets.compute_stats import aggregate_stats, compute_episode_stats
from lerobot.common.datasets.image_writer import AsyncImageWriter, write_image from lerobot.common.datasets.image_writer import AsyncImageWriter, write_image
from lerobot.common.datasets.utils import ( from lerobot.common.datasets.utils import (
DEFAULT_EPISODES_PATH,
DEFAULT_EPISODES_STATS_PATH,
DEFAULT_FEATURES, DEFAULT_FEATURES,
DEFAULT_IMAGE_PATH, DEFAULT_IMAGE_PATH,
EPISODES_DIR,
EPISODES_STATS_DIR,
INFO_PATH, INFO_PATH,
TASKS_PATH, LEGACY_TASKS_PATH,
append_jsonlines, append_jsonlines,
backward_compatible_episodes_stats, backward_compatible_episodes_stats,
check_delta_timestamps, check_delta_timestamps,
check_timestamps_sync, check_timestamps_sync,
check_version_compatibility, check_version_compatibility,
concat_video_files,
create_empty_dataset_info, create_empty_dataset_info,
create_lerobot_dataset_card, create_lerobot_dataset_card,
embed_images, embed_images,
get_chunk_file_indices,
get_delta_indices, get_delta_indices,
get_episode_data_index, get_episode_data_index,
get_features_from_robot, get_features_from_robot,
get_hf_dataset_size_in_mb,
get_hf_features_from_features, get_hf_features_from_features,
get_latest_parquet_path,
get_latest_video_path,
get_parquet_num_frames,
get_pd_dataframe_size_in_mb,
get_safe_version, get_safe_version,
get_video_duration_in_s,
get_video_size_in_mb,
hf_transform_to_torch, hf_transform_to_torch,
is_valid_version, is_valid_version,
legacy_load_episodes,
legacy_load_episodes_stats,
load_episodes, load_episodes,
load_episodes_stats, load_episodes_stats,
load_info, load_info,
load_nested_dataset,
load_stats, load_stats,
legacy_load_tasks,
load_tasks, load_tasks,
update_chunk_file_indices,
validate_episode_buffer, validate_episode_buffer,
validate_frame, validate_frame,
write_episode, write_episode,
write_episode_stats, legacy_write_episode_stats,
write_info, write_info,
write_json, write_json,
write_tasks,
) )
from lerobot.common.datasets.v30.convert_dataset_v21_to_v30 import get_parquet_file_size_in_mb
from lerobot.common.datasets.video_utils import ( from lerobot.common.datasets.video_utils import (
VideoFrame, VideoFrame,
decode_video_frames_torchvision, decode_video_frames_torchvision,
@ -105,12 +126,9 @@ class LeRobotDatasetMetadata:
         check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
         self.tasks, self.task_to_task_index = load_tasks(self.root)
         self.episodes = load_episodes(self.root)
-        if self._version < packaging.version.parse("v2.1"):
-            self.stats = load_stats(self.root)
-            self.episodes_stats = backward_compatible_episodes_stats(self.stats, self.episodes)
-        else:
-            self.episodes_stats = load_episodes_stats(self.root)
-        self.stats = aggregate_stats(list(self.episodes_stats.values()))
+        self.episodes_stats = load_episodes_stats(self.root)
+        # TODO(rcadene): https://huggingface.slack.com/archives/C02V51Q3800/p1743517952388249?thread_ts=1742896075.499119&cid=C02V51Q3800
+        # self.stats = aggregate_stats(list(self.episodes_stats.values()))

     def pull_from_repo(
         self,
@ -132,17 +150,19 @@ class LeRobotDatasetMetadata:
         return packaging.version.parse(self.info["codebase_version"])

     def get_data_file_path(self, ep_index: int) -> Path:
-        ep_chunk = self.get_episode_chunk(ep_index)
-        fpath = self.data_path.format(episode_chunk=ep_chunk, episode_index=ep_index)
+        chunk_idx = self.episodes[f"data/chunk_index"][ep_index]
+        file_idx = self.episodes[f"data/file_index"][ep_index]
+        fpath = self.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
         return Path(fpath)

     def get_video_file_path(self, ep_index: int, vid_key: str) -> Path:
-        ep_chunk = self.get_episode_chunk(ep_index)
-        fpath = self.video_path.format(episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index)
+        chunk_idx = self.episodes[f"{vid_key}/chunk_index"][ep_index]
+        file_idx = self.episodes[f"{vid_key}/file_index"][ep_index]
+        fpath = self.video_path.format(video_key=vid_key, chunk_index=chunk_idx, file_index=file_idx)
         return Path(fpath)

-    def get_episode_chunk(self, ep_index: int) -> int:
-        return ep_index // self.chunks_size
+    # def get_episode_chunk(self, ep_index: int) -> int:
+    #     return ep_index // self.chunks_size

     @property
     def data_path(self) -> str:
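As an illustration with hypothetical values (not from the commit), the new lookup replaces arithmetic on the episode index with two columns read from the episodes metadata, which then fill the `data_path` template:

```
from pathlib import Path

# Hypothetical values mirroring the new per-episode metadata columns.
episodes = {"data/chunk_index": [0, 0, 1], "data/file_index": [0, 1, 0]}
data_path = "data/chunk-{chunk_index:03d}/file-{file_index:03d}.parquet"

ep_index = 2
fpath = Path(data_path.format(
    chunk_index=episodes["data/chunk_index"][ep_index],
    file_index=episodes["data/file_index"][ep_index],
))
print(fpath)  # data/chunk-001/file-000.parquet
```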
@ -210,39 +230,84 @@ class LeRobotDatasetMetadata:
         return self.info["total_tasks"]

     @property
-    def total_chunks(self) -> int:
-        """Total number of chunks (groups of episodes)."""
-        return self.info["total_chunks"]
+    def chunks_size(self) -> int:
+        """Max number of files per chunk."""
+        return self.info["chunks_size"]

     @property
-    def chunks_size(self) -> int:
-        """Max number of episodes per chunk."""
-        return self.info["chunks_size"]
+    def files_size_in_mb(self) -> int:
+        """Max size of file in mega bytes."""
+        return self.info["files_size_in_mb"]

     def get_task_index(self, task: str) -> int | None:
         """
         Given a task in natural language, returns its task_index if the task already exists in the dataset,
         otherwise return None.
         """
-        return self.task_to_task_index.get(task, None)
+        return self.tasks.index[task] if task in self.tasks.index else None

-    def add_task(self, task: str):
-        """
-        Given a task in natural language, add it to the dictionary of tasks.
-        """
-        if task in self.task_to_task_index:
-            raise ValueError(f"The task '{task}' already exists and can't be added twice.")
-
-        task_index = self.info["total_tasks"]
-        self.task_to_task_index[task] = task_index
-        self.tasks[task_index] = task
-        self.info["total_tasks"] += 1
-
-        task_dict = {
-            "task_index": task_index,
-            "task": task,
-        }
-        append_jsonlines(task_dict, self.root / TASKS_PATH)
+    def has_task(self, task: str) -> bool:
+        return task in self.task_to_task_index
+
+    def save_episode_tasks(self, tasks: list[str]):
+        new_tasks = [task for task in tasks if not self.has_task(task)]
+
+        for task in new_tasks:
+            task_index = len(self.tasks)
+            self.tasks.loc[task] = task_index
+
+        if len(new_tasks) > 0:
+            # Update on disk
+            write_tasks(self.tasks, self.root)
+
+    def _save_episode(self, episode_dict: dict):
+        ep_dataset = Dataset.from_dict(episode_dict)
+        ep_size_in_mb = get_hf_dataset_size_in_mb(ep_dataset)
+        ep_df = pd.DataFrame(ep_dataset)
+
+        # Access latest parquet file information
+        latest_path = get_latest_parquet_path(self.root / EPISODES_DIR)
+        latest_size_in_mb = get_parquet_file_size_in_mb(latest_path)
+        chunk_idx, file_idx = get_chunk_file_indices(latest_path)
+
+        if latest_size_in_mb + ep_size_in_mb >= self.files_size_in_mb:
+            # Create new parquet file
+            chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.chunks_size)
+            new_path = self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
+            new_path.parent.mkdir(parents=True, exist_ok=True)
+            ep_df.to_parquet(new_path, index=False)
+        else:
+            # Update latest parquet file with new row
+            latest_df = pd.read_parquet(latest_path)
+            latest_df = pd.concat([latest_df, ep_df], ignore_index=True)  # RAM
+            latest_df.to_parquet(latest_path, index=False)
+
+        # Update the Hugging Face dataset by reloading it.
+        # This process should be fast because only the latest Parquet file has been modified.
+        # Therefore, only this file needs to be converted to PyArrow; the rest is loaded from the PyArrow memory-mapped cache.
+        self.episodes = load_episodes(self.root)
+
+    def _save_episode_stats(self, episodes_stats: dict):
+        ep_dataset = Dataset.from_dict(episodes_stats)
+        ep_size_in_mb = get_hf_dataset_size_in_mb(ep_dataset)
+        ep_df = pd.DataFrame(ep_dataset)
+
+        # Access latest parquet file information
+        latest_path = get_latest_parquet_path(self.root / EPISODES_STATS_DIR)
+        latest_size_in_mb = get_parquet_file_size_in_mb(latest_path)
+        chunk_idx, file_idx = get_chunk_file_indices(latest_path)
+
+        if latest_size_in_mb + ep_size_in_mb >= self.files_size_in_mb:
+            # Create new parquet file
+            chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.chunks_size)
+            new_path = self.root / DEFAULT_EPISODES_STATS_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
+            new_path.parent.mkdir(parents=True, exist_ok=True)
+            ep_df.to_parquet(new_path, index=False)
+        else:
+            # Update latest parquet file with new row
+            latest_df = pd.read_parquet(latest_path)
+            latest_df = pd.concat([latest_df, ep_df], ignore_index=True)  # RAM
+            latest_df.to_parquet(latest_path, index=False)
+
+        self.episodes_stats = load_episodes_stats(self.root)
+
     def save_episode(
         self,
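For readers unfamiliar with the new task table, here is a small standalone sketch (not from the commit) assuming, as `save_episode_tasks` above suggests, a pandas DataFrame indexed by the task string with a single `task_index` column:

```
import pandas as pd

# Task strings as the index, task_index as the only column.
tasks = pd.DataFrame({"task_index": [0, 1]}, index=["pick the cube", "open the drawer"])
tasks.index.name = "task"

new_task = "close the drawer"
if new_task not in tasks.index:
    tasks.loc[new_task] = len(tasks)  # append a row; task_index = current table length

print(tasks.loc[new_task, "task_index"])  # 2
```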
@ -250,19 +315,14 @@ class LeRobotDatasetMetadata:
episode_length: int, episode_length: int,
episode_tasks: list[str], episode_tasks: list[str],
episode_stats: dict[str, dict], episode_stats: dict[str, dict],
episode_metadata: dict,
) -> None: ) -> None:
# Update info
self.info["total_episodes"] += 1 self.info["total_episodes"] += 1
self.info["total_frames"] += episode_length self.info["total_frames"] += episode_length
chunk = self.get_episode_chunk(episode_index)
if chunk >= self.total_chunks:
self.info["total_chunks"] += 1
self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"} self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"}
self.info["total_videos"] += len(self.video_keys)
if len(self.video_keys) > 0: if len(self.video_keys) > 0:
self.update_video_info() self.update_video_info()
write_info(self.info, self.root) write_info(self.info, self.root)
episode_dict = { episode_dict = {
@ -270,12 +330,12 @@ class LeRobotDatasetMetadata:
"tasks": episode_tasks, "tasks": episode_tasks,
"length": episode_length, "length": episode_length,
} }
self.episodes[episode_index] = episode_dict episode_dict.update(episode_metadata)
write_episode(episode_dict, self.root) self._save_episode(episode_dict)
self._save_episode_stats(episode_stats)
self.episodes_stats[episode_index] = episode_stats
self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats else episode_stats self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats else episode_stats
write_episode_stats(episode_index, episode_stats, self.root) # TODO: write stats
def update_video_info(self) -> None: def update_video_info(self) -> None:
""" """
@ -340,8 +400,11 @@ class LeRobotDatasetMetadata:
features = {**features, **DEFAULT_FEATURES} features = {**features, **DEFAULT_FEATURES}
obj.tasks, obj.task_to_task_index = {}, {} obj.tasks = None
obj.episodes_stats, obj.stats, obj.episodes = {}, {}, {} obj.episodes_stats = None
obj.episodes = None
# TODO(rcadene) stats
obj.stats = {}
obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, robot_type, features, use_videos) obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, robot_type, features, use_videos)
if len(obj.video_keys) > 0 and not use_videos: if len(obj.video_keys) > 0 and not use_videos:
raise ValueError() raise ValueError()
@ -486,29 +549,17 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.meta = LeRobotDatasetMetadata( self.meta = LeRobotDatasetMetadata(
self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync
) )
if self.episodes is not None and self.meta._version >= packaging.version.parse("v2.1"):
episodes_stats = [self.meta.episodes_stats[ep_idx] for ep_idx in self.episodes]
self.stats = aggregate_stats(episodes_stats)
# Load actual data # Load actual data
try: try:
if force_cache_sync: if force_cache_sync:
raise FileNotFoundError raise FileNotFoundError
assert all((self.root / fpath).is_file() for fpath in self.get_episodes_file_paths())
self.hf_dataset = self.load_hf_dataset() self.hf_dataset = self.load_hf_dataset()
except (AssertionError, FileNotFoundError, NotADirectoryError): except (AssertionError, FileNotFoundError, NotADirectoryError):
self.revision = get_safe_version(self.repo_id, self.revision) self.revision = get_safe_version(self.repo_id, self.revision)
self.download_episodes(download_videos) self.download_episodes(download_videos)
self.hf_dataset = self.load_hf_dataset() self.hf_dataset = self.load_hf_dataset()
self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
# Check timestamps
timestamps = torch.stack(self.hf_dataset["timestamp"]).numpy()
episode_indices = torch.stack(self.hf_dataset["episode_index"]).numpy()
ep_data_index_np = {k: t.numpy() for k, t in self.episode_data_index.items()}
check_timestamps_sync(timestamps, episode_indices, ep_data_index_np, self.fps, self.tolerance_s)
# Setup delta_indices # Setup delta_indices
if self.delta_timestamps is not None: if self.delta_timestamps is not None:
check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s) check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
@ -592,11 +643,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
""" """
# TODO(rcadene, aliberts): implement faster transfer # TODO(rcadene, aliberts): implement faster transfer
# https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads # https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
files = None
ignore_patterns = None if download_videos else "videos/" ignore_patterns = None if download_videos else "videos/"
files = None
if self.episodes is not None: if self.episodes is not None:
files = self.get_episodes_file_paths() files = self.get_episodes_file_paths()
self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns) self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns)
def get_episodes_file_paths(self) -> list[Path]: def get_episodes_file_paths(self) -> list[Path]:
@ -609,31 +659,21 @@ class LeRobotDataset(torch.utils.data.Dataset):
for ep_idx in episodes for ep_idx in episodes
] ]
fpaths += video_files fpaths += video_files
# episodes are stored in the same files, so we return unique paths only
fpaths = list(set(fpaths))
return fpaths return fpaths
def load_hf_dataset(self) -> datasets.Dataset: def load_hf_dataset(self) -> datasets.Dataset:
"""hf_dataset contains all the observations, states, actions, rewards, etc.""" """hf_dataset contains all the observations, states, actions, rewards, etc."""
if self.episodes is None: hf_dataset = load_nested_dataset(self.root / "data")
path = str(self.root / "data") hf_dataset.set_format("torch")
# TODO(rcadene): load_dataset convert parquet to arrow.
# set num_proc to accelerate this conversion
hf_dataset = load_dataset("parquet", data_dir=path, split="train")
else:
files = [str(self.root / self.meta.get_data_file_path(ep_idx)) for ep_idx in self.episodes]
hf_dataset = load_dataset("parquet", data_files=files, split="train")
# TODO(aliberts): hf_dataset.set_format("torch")
hf_dataset.set_transform(hf_transform_to_torch)
return hf_dataset return hf_dataset
def create_hf_dataset(self) -> datasets.Dataset: def create_hf_dataset(self) -> datasets.Dataset:
features = get_hf_features_from_features(self.features) features = get_hf_features_from_features(self.features)
ft_dict = {col: [] for col in features} ft_dict = {col: [] for col in features}
hf_dataset = datasets.Dataset.from_dict(ft_dict, features=features, split="train") hf_dataset = datasets.Dataset.from_dict(ft_dict, features=features, split="train")
hf_dataset.set_format("torch")
# TODO(aliberts): hf_dataset.set_format("torch")
hf_dataset.set_transform(hf_transform_to_torch)
return hf_dataset return hf_dataset
@property @property
@ -664,15 +704,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
return get_hf_features_from_features(self.features) return get_hf_features_from_features(self.features)
def _get_query_indices(self, idx: int, ep_idx: int) -> tuple[dict[str, list[int | bool]]]: def _get_query_indices(self, idx: int, ep_idx: int) -> tuple[dict[str, list[int | bool]]]:
ep_start = self.episode_data_index["from"][ep_idx] ep_start = self.meta.episodes["data/from_index"][ep_idx]
ep_end = self.episode_data_index["to"][ep_idx] ep_end = self.meta.episodes["data/to_index"][ep_idx]
query_indices = { query_indices = {
key: [max(ep_start.item(), min(ep_end.item() - 1, idx + delta)) for delta in delta_idx] key: [max(ep_start, min(ep_end - 1, idx + delta)) for delta in delta_idx]
for key, delta_idx in self.delta_indices.items() for key, delta_idx in self.delta_indices.items()
} }
padding = { # Pad values outside of current episode range padding = { # Pad values outside of current episode range
f"{key}_is_pad": torch.BoolTensor( f"{key}_is_pad": torch.BoolTensor(
[(idx + delta < ep_start.item()) | (idx + delta >= ep_end.item()) for delta in delta_idx] [(idx + delta < ep_start) | (idx + delta >= ep_end) for delta in delta_idx]
) )
for key, delta_idx in self.delta_indices.items() for key, delta_idx in self.delta_indices.items()
} }
@ -687,7 +727,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
for key in self.meta.video_keys: for key in self.meta.video_keys:
if query_indices is not None and key in query_indices: if query_indices is not None and key in query_indices:
timestamps = self.hf_dataset.select(query_indices[key])["timestamp"] timestamps = self.hf_dataset.select(query_indices[key])["timestamp"]
query_timestamps[key] = torch.stack(timestamps).tolist() query_timestamps[key] = timestamps.tolist()
else: else:
query_timestamps[key] = [current_ts] query_timestamps[key] = [current_ts]
@ -695,7 +735,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict: def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict:
return { return {
key: torch.stack(self.hf_dataset.select(q_idx)[key]) key: self.hf_dataset.select(q_idx)[key]
for key, q_idx in query_indices.items() for key, q_idx in query_indices.items()
if key not in self.meta.video_keys if key not in self.meta.video_keys
} }
@ -708,9 +748,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
""" """
item = {} item = {}
for vid_key, query_ts in query_timestamps.items(): for vid_key, query_ts in query_timestamps.items():
# Episodes are stored sequentially on a single mp4 to reduce the number of files.
# Thus we load the start timestamp of the episode in this mp4 and
# shift the query timestamps accordingly.
from_timestamp = self.meta.episodes[f"{vid_key}/from_timestamp"][ep_idx]
shifted_query_ts = [from_timestamp + ts for ts in query_ts]
video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key) video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key)
frames = decode_video_frames_torchvision( frames = decode_video_frames_torchvision(
video_path, query_ts, self.tolerance_s, self.video_backend video_path, shifted_query_ts, self.tolerance_s, self.video_backend
) )
item[vid_key] = frames.squeeze(0) item[vid_key] = frames.squeeze(0)
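To make the shift explicit, a toy example (made-up values) of mapping episode-relative query timestamps into the concatenated video's timeline before decoding:

```
# Suppose this episode starts 12.5 s into the shared mp4 (its "from_timestamp"),
# and the caller asks for frames at 0.0 s, 0.5 s and 1.0 s within the episode.
from_timestamp = 12.5
query_ts = [0.0, 0.5, 1.0]

shifted_query_ts = [from_timestamp + ts for ts in query_ts]
print(shifted_query_ts)  # [12.5, 13.0, 13.5]
```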
@ -749,7 +795,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
# Add task as a string # Add task as a string
task_idx = item["task_index"].item() task_idx = item["task_index"].item()
item["task"] = self.meta.tasks[task_idx] if self.meta.tasks["task_index"][task_idx] != task_idx:
raise ValueError("Sanity check on task index failed.")
item["task"] = self.meta.tasks["task"][task_idx]
return item return item
@ -780,6 +828,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
) )
return self.root / fpath return self.root / fpath
def _get_image_file_dir(self, episode_index: int, image_key: str) -> Path:
return self._get_image_file_path(episode_index, image_key, frame_index=0).parent
def _save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path) -> None: def _save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path) -> None:
if self.image_writer is None: if self.image_writer is None:
if isinstance(image, torch.Tensor): if isinstance(image, torch.Tensor):
@ -858,11 +909,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
episode_buffer["index"] = np.arange(self.meta.total_frames, self.meta.total_frames + episode_length) episode_buffer["index"] = np.arange(self.meta.total_frames, self.meta.total_frames + episode_length)
episode_buffer["episode_index"] = np.full((episode_length,), episode_index) episode_buffer["episode_index"] = np.full((episode_length,), episode_index)
# Add new tasks to the tasks dictionary # Update tasks and task indices with new tasks if any
for task in episode_tasks: self.meta.save_episode_tasks(episode_tasks)
task_index = self.meta.get_task_index(task)
if task_index is None:
self.meta.add_task(task)
# Given tasks in natural language, find their corresponding task indices # Given tasks in natural language, find their corresponding task indices
episode_buffer["task_index"] = np.array([self.meta.get_task_index(task) for task in tasks]) episode_buffer["task_index"] = np.array([self.meta.get_task_index(task) for task in tasks])
@ -874,51 +922,107 @@ class LeRobotDataset(torch.utils.data.Dataset):
continue continue
episode_buffer[key] = np.stack(episode_buffer[key]) episode_buffer[key] = np.stack(episode_buffer[key])
# Wait for image writer to end, so that episode stats over images can be computed
self._wait_image_writer() self._wait_image_writer()
self._save_episode_table(episode_buffer, episode_index)
ep_stats = compute_episode_stats(episode_buffer, self.features) ep_stats = compute_episode_stats(episode_buffer, self.features)
if len(self.meta.video_keys) > 0: ep_metadata = self._save_episode_data(episode_buffer, episode_index)
video_paths = self.encode_episode_videos(episode_index) for video_key in self.meta.video_keys:
for key in self.meta.video_keys: ep_metadata.update(self._save_episode_video(video_key, episode_index))
episode_buffer[key] = video_paths[key]
# `meta.save_episode` be executed after encoding the videos # `meta.save_episode` needs to be executed after encoding the videos
self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats) self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats, ep_metadata)
ep_data_index = get_episode_data_index(self.meta.episodes, [episode_index]) # TODO(rcadene): remove? there is only one episode in the episode buffer, no need for ep_data_index
ep_data_index_np = {k: t.numpy() for k, t in ep_data_index.items()} # ep_data_index = get_episode_data_index(self.meta.episodes, [episode_index])
check_timestamps_sync( # ep_data_index_np = {k: t.numpy() for k, t in ep_data_index.items()}
episode_buffer["timestamp"], # check_timestamps_sync(
episode_buffer["episode_index"], # episode_buffer["timestamp"],
ep_data_index_np, # episode_buffer["episode_index"],
self.fps, # ep_data_index_np,
self.tolerance_s, # self.fps,
) # self.tolerance_s,
# )
video_files = list(self.root.rglob("*.mp4"))
assert len(video_files) == self.num_episodes * len(self.meta.video_keys)
parquet_files = list(self.root.rglob("*.parquet"))
assert len(parquet_files) == self.num_episodes
# TODO(rcadene): images are also deleted in clear_episode_buffer
# delete images # delete images
img_dir = self.root / "images" img_dir = self.root / "images"
if img_dir.is_dir(): if img_dir.is_dir():
shutil.rmtree(self.root / "images") shutil.rmtree(self.root / "images")
if not episode_data: # Reset the buffer if not episode_data:
# Reset episode buffer
self.episode_buffer = self.create_episode_buffer() self.episode_buffer = self.create_episode_buffer()
def _save_episode_table(self, episode_buffer: dict, episode_index: int) -> None: def _save_episode_data(self, episode_buffer: dict) -> None:
episode_dict = {key: episode_buffer[key] for key in self.hf_features} # Convert buffer into HF Dataset
ep_dataset = datasets.Dataset.from_dict(episode_dict, features=self.hf_features, split="train") ep_dict = {key: episode_buffer[key] for key in self.hf_features}
ep_dataset = datasets.Dataset.from_dict(ep_dict, features=self.hf_features, split="train")
ep_dataset = embed_images(ep_dataset) ep_dataset = embed_images(ep_dataset)
self.hf_dataset = concatenate_datasets([self.hf_dataset, ep_dataset]) ep_size_in_mb = get_hf_dataset_size_in_mb(ep_dataset)
self.hf_dataset.set_transform(hf_transform_to_torch) ep_num_frames = len(ep_dataset)
ep_data_path = self.root / self.meta.get_data_file_path(ep_index=episode_index)
ep_data_path.parent.mkdir(parents=True, exist_ok=True) # Access latest parquet file information
ep_dataset.to_parquet(ep_data_path) latest_path = get_latest_parquet_path(self.root / "data")
latest_size_in_mb = get_parquet_file_size_in_mb(latest_path)
latest_num_frames = get_parquet_num_frames(latest_path)
chunk_idx, file_idx = get_chunk_file_indices(latest_path)
ep_df = pd.DataFrame(ep_dataset)

if latest_size_in_mb + ep_size_in_mb >= self.meta.files_size_in_mb:
    # Create new parquet file
    chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size)
    new_path = self.root / self.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
    new_path.parent.mkdir(parents=True, exist_ok=True)
    ep_df.to_parquet(new_path, index=False)
else:
    # Update latest parquet file with new rows
    latest_df = pd.read_parquet(latest_path)
    latest_df = pd.concat([latest_df, ep_df], ignore_index=True)  # RAM
    latest_df.to_parquet(latest_path, index=False)
# Update the Hugging Face dataset by reloading it.
# This process should be fast because only the latest Parquet file has been modified.
# Therefore, only this file needs to be converted to PyArrow; the rest is loaded from the PyArrow memory-mapped cache.
self.hf_dataset = self.load_hf_dataset()
metadata = {
"data/chunk_index": chunk_idx,
"data/file_index": file_idx,
"data/from_index": latest_num_frames,
"data/to_index": latest_num_frames + ep_num_frames,
}
return metadata
def _save_episode_video(self, video_key: str, episode_index: int):
# Encode episode frames into a temporary video
ep_path = self._encode_temporary_episode_video(video_key, episode_index)
ep_size_in_mb = get_video_size_in_mb(ep_path)
ep_duration_in_s = get_video_duration_in_s(ep_path)
# Access latest video file information
latest_path = get_latest_video_path(self.root / "videos", video_key)
latest_size_in_mb = get_video_size_in_mb(latest_path)
latest_duration_in_s = get_video_duration_in_s(latest_path)
chunk_idx, file_idx = get_chunk_file_indices(latest_path)
if latest_size_in_mb + ep_size_in_mb >= self.meta.files_size_in_mb:
    # Move temporary episode video to a new video file in the dataset
    chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size)
    new_path = self.root / self.meta.video_path.format(video_key=video_key, chunk_index=chunk_idx, file_index=file_idx)
    new_path.parent.mkdir(parents=True, exist_ok=True)
    ep_path.replace(new_path)
    # The episode starts at the beginning of the freshly created file
    latest_duration_in_s = 0.0
else:
    # Update latest video file
    concat_video_files([latest_path, ep_path], self.root, video_key, chunk_idx, file_idx)
metadata = {
"episode_index": episode_index,
f"{video_key}/chunk_index": chunk_idx,
f"{video_key}/file_index": file_idx,
f"{video_key}/from_timestamp": latest_duration_in_s,
f"{video_key}/to_timestamp": latest_duration_in_s + ep_duration_in_s,
}
return metadata
def clear_episode_buffer(self) -> None: def clear_episode_buffer(self) -> None:
episode_index = self.episode_buffer["episode_index"] episode_index = self.episode_buffer["episode_index"]
@ -958,34 +1062,26 @@ class LeRobotDataset(torch.utils.data.Dataset):
if self.image_writer is not None: if self.image_writer is not None:
self.image_writer.wait_until_done() self.image_writer.wait_until_done()
def encode_videos(self) -> None: # TODO(rcadene): this method is currently not used
# def encode_videos(self) -> None:
# """
# Use ffmpeg to convert frames stored as png into mp4 videos.
# Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
# since video encoding with ffmpeg is already using multithreading.
# """
# for ep_idx in range(self.meta.total_episodes):
# self.encode_episode_videos(ep_idx)
def _encode_temporary_episode_video(self, video_key: str, episode_index: int) -> Path:
""" """
Use ffmpeg to convert frames stored as png into mp4 videos. Use ffmpeg to convert frames stored as png into mp4 videos.
Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding, Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
since video encoding with ffmpeg is already using multithreading. since video encoding with ffmpeg is already using multithreading.
""" """
for ep_idx in range(self.meta.total_episodes): temp_path = Path(tempfile.mkdtemp()) / f"{video_key}_{episode_index:3d}.mp4"
self.encode_episode_videos(ep_idx) img_dir = self._get_image_file_dir(episode_index, video_key)
encode_video_frames(img_dir, temp_path, self.fps, overwrite=True)
def encode_episode_videos(self, episode_index: int) -> dict: return temp_path
"""
Use ffmpeg to convert frames stored as png into mp4 videos.
Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
since video encoding with ffmpeg is already using multithreading.
"""
video_paths = {}
for key in self.meta.video_keys:
video_path = self.root / self.meta.get_video_file_path(episode_index, key)
video_paths[key] = str(video_path)
if video_path.is_file():
# Skip if video is already encoded. Could be the case when resuming data recording.
continue
img_dir = self._get_image_file_path(
episode_index=episode_index, image_key=key, frame_index=0
).parent
encode_video_frames(img_dir, video_path, self.fps, overwrite=True)
return video_paths
@classmethod @classmethod
def create( def create(
@ -1030,7 +1126,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj.image_transforms = None obj.image_transforms = None
obj.delta_timestamps = None obj.delta_timestamps = None
obj.delta_indices = None obj.delta_indices = None
obj.episode_data_index = None
obj.video_backend = video_backend if video_backend is not None else "pyav" obj.video_backend = video_backend if video_backend is not None else "pyav"
return obj return obj

View File

@ -21,40 +21,62 @@ from collections.abc import Iterator
from itertools import accumulate from itertools import accumulate
from pathlib import Path from pathlib import Path
from pprint import pformat from pprint import pformat
import subprocess
import tempfile
from types import SimpleNamespace from types import SimpleNamespace
from typing import Any from typing import Any, Tuple
import datasets import datasets
import jsonlines import jsonlines
import numpy as np import numpy as np
import packaging.version import packaging.version
import pandas
import torch import torch
from datasets.table import embed_table_storage from datasets.table import embed_table_storage
from huggingface_hub import DatasetCard, DatasetCardData, HfApi from huggingface_hub import DatasetCard, DatasetCardData, HfApi
from huggingface_hub.errors import RevisionNotFoundError from huggingface_hub.errors import RevisionNotFoundError
from PIL import Image as PILImage from PIL import Image as PILImage
from torchvision import transforms from torchvision import transforms
from datasets import Dataset, concatenate_datasets
from lerobot.common.datasets.backward_compatibility import ( from lerobot.common.datasets.backward_compatibility import (
V21_MESSAGE, V21_MESSAGE,
V30_MESSAGE,
BackwardCompatibilityError, BackwardCompatibilityError,
ForwardCompatibilityError, ForwardCompatibilityError,
) )
from lerobot.common.robot_devices.robots.utils import Robot from lerobot.common.robot_devices.robots.utils import Robot
from lerobot.common.utils.utils import is_valid_numpy_dtype_string from lerobot.common.utils.utils import is_valid_numpy_dtype_string
from lerobot.configs.types import DictLike, FeatureType, PolicyFeature from lerobot.configs.types import DictLike, FeatureType, PolicyFeature
import pyarrow.parquet as pq
DEFAULT_CHUNK_SIZE = 1000 # Max number of episodes per chunk DEFAULT_CHUNK_SIZE = 1000 # Max number of files per chunk
DEFAULT_FILE_SIZE_IN_MB = 500.0 # Max size per file
# Keep legacy for `convert_dataset_v21_to_v30.py`
LEGACY_EPISODES_PATH = "meta/episodes.jsonl"
LEGACY_STATS_PATH = "meta/stats.json"
LEGACY_EPISODES_STATS_PATH = "meta/episodes_stats.jsonl"
LEGACY_TASKS_PATH = "meta/tasks.jsonl"
LEGACY_DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
LEGACY_DEFAULT_PARQUET_PATH = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"
# TODO
DEFAULT_IMAGE_PATH = "images/{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"
EPISODES_DIR = "meta/episodes"
EPISODES_STATS_DIR = "meta/episodes_stats"
TASKS_DIR = "meta/tasks"
DATA_DIR = "data"
VIDEO_DIR = "videos"
INFO_PATH = "meta/info.json" INFO_PATH = "meta/info.json"
EPISODES_PATH = "meta/episodes.jsonl" CHUNK_FILE_PATTERN = "chunk-{chunk_index:03d}/file-{file_index:03d}"
STATS_PATH = "meta/stats.json" DEFAULT_EPISODES_PATH = EPISODES_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
EPISODES_STATS_PATH = "meta/episodes_stats.jsonl" DEFAULT_EPISODES_STATS_PATH = EPISODES_STATS_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
TASKS_PATH = "meta/tasks.jsonl" DEFAULT_TASKS_PATH = TASKS_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
DEFAULT_DATA_PATH = DATA_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4" DEFAULT_VIDEO_PATH = VIDEO_DIR + "/{video_key}/" + CHUNK_FILE_PATTERN + ".mp4"
DEFAULT_PARQUET_PATH = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"
DEFAULT_IMAGE_PATH = "images/{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"
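For concreteness, a quick illustration (not part of the commit) of how the new `CHUNK_FILE_PATTERN` templates expand compared with a legacy per-episode path:

```
CHUNK_FILE_PATTERN = "chunk-{chunk_index:03d}/file-{file_index:03d}"
DEFAULT_DATA_PATH = "data/" + CHUNK_FILE_PATTERN + ".parquet"
DEFAULT_VIDEO_PATH = "videos/{video_key}/" + CHUNK_FILE_PATTERN + ".mp4"
LEGACY_DEFAULT_PARQUET_PATH = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"

print(LEGACY_DEFAULT_PARQUET_PATH.format(episode_chunk=0, episode_index=42))
# data/chunk-000/episode_000042.parquet
print(DEFAULT_DATA_PATH.format(chunk_index=0, file_index=3))
# data/chunk-000/file-003.parquet
print(DEFAULT_VIDEO_PATH.format(video_key="observation.images.top", chunk_index=1, file_index=0))
# videos/observation.images.top/chunk-001/file-000.mp4
```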
DATASET_CARD_TEMPLATE = """ DATASET_CARD_TEMPLATE = """
--- ---
@ -75,6 +97,88 @@ DEFAULT_FEATURES = {
} }
def get_hf_dataset_size_in_mb(hf_ds: Dataset) -> int:
return hf_ds.data.nbytes / (1024 ** 2)
def get_pd_dataframe_size_in_mb(df: pandas.DataFrame) -> int:
memory_usage_bytes = df.memory_usage(deep=True).sum()
return memory_usage_bytes / (1024 ** 2)
def get_chunk_file_indices(path: Path) -> Tuple[int, int]:
    if not path.stem.startswith("file-") or not path.parent.name.startswith("chunk-"):
        raise ValueError(f"Path does not follow {CHUNK_FILE_PATTERN}: '{path}'")
    chunk_index = int(path.parent.name.replace("chunk-", ""))
    file_index = int(path.stem.replace("file-", ""))
    return chunk_index, file_index
def update_chunk_file_indices(chunk_idx: int, file_idx: int, chunks_size: int):
if file_idx == chunks_size - 1:
file_idx = 0
chunk_idx += 1
else:
file_idx += 1
return chunk_idx, file_idx
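A short usage sketch of the rollover rule above: the file index increments within a chunk and wraps into a new chunk once `chunks_size` files exist (toy `chunks_size=3`):

```
def update_chunk_file_indices(chunk_idx: int, file_idx: int, chunks_size: int):
    if file_idx == chunks_size - 1:
        file_idx = 0
        chunk_idx += 1
    else:
        file_idx += 1
    return chunk_idx, file_idx

chunk_idx, file_idx = 0, 0
for _ in range(5):
    chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, chunks_size=3)
    print(chunk_idx, file_idx)
# (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)
```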
def load_nested_dataset(pq_dir: Path) -> Dataset:
""" Find parquet files in provided directory {pq_dir}/chunk-xxx/file-xxx.parquet
Convert parquet files to pyarrow memory mapped in a cache folder for efficient RAM usage
Concatenate all pyarrow references to return HF Dataset format
"""
# TODO(rcadene): set num_proc to accelerate conversion to pyarrow
return concatenate_datasets([Dataset.from_parquet(str(path)) for path in sorted(pq_dir.glob("*/*.parquet"))])
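A hedged end-to-end sketch of the on-disk layout `load_nested_dataset` expects, using throwaway parquet files in a temporary directory (the chunked naming is the assumption here; the `datasets` calls are standard):

```
import tempfile
from pathlib import Path

import pandas as pd
from datasets import Dataset, concatenate_datasets

root = Path(tempfile.mkdtemp())
for chunk_idx, file_idx in [(0, 0), (0, 1)]:
    path = root / f"chunk-{chunk_idx:03d}" / f"file-{file_idx:03d}.parquet"
    path.parent.mkdir(parents=True, exist_ok=True)
    pd.DataFrame({"episode_index": [file_idx], "length": [100 + file_idx]}).to_parquet(path)

# Same pattern as load_nested_dataset: one Dataset per parquet file, then concatenate.
ds = concatenate_datasets([Dataset.from_parquet(str(p)) for p in sorted(root.glob("*/*.parquet"))])
print(ds["length"])  # [100, 101]
```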
def get_latest_parquet_path(pq_dir: Path) -> Path:
return sorted(pq_dir.glob("*/*.parquet"))[-1]
def get_latest_video_path(pq_dir: Path, video_key: str) -> Path:
return sorted(pq_dir.glob(f"{video_key}/*/*.mp4"))[-1]
def get_parquet_num_frames(parquet_path):
metadata = pq.read_metadata(parquet_path)
return metadata.num_rows
def get_video_size_in_mb(mp4_path: Path):
file_size_bytes = mp4_path.stat().st_size
file_size_mb = file_size_bytes / (1024 ** 2)
return file_size_mb
def concat_video_files(paths_to_cat, new_root, video_key, chunk_idx, file_idx):
# Create a text file with the list of files to concatenate
with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f:
temp_file_path = f.name
for ep_path in paths_to_cat:
f.write(f"file '{str(ep_path)}'\n")
output_path = new_root / DEFAULT_VIDEO_PATH.format(video_key=video_key, chunk_index=chunk_idx, file_index=file_idx)
output_path.parent.mkdir(parents=True, exist_ok=True)
command = [
'ffmpeg',
'-y',
'-f', 'concat',
'-safe', '0',
'-i', str(temp_file_path),
'-c', 'copy',
str(output_path)
]
subprocess.run(command, check=True)
Path(temp_file_path).unlink()
def get_video_duration_in_s(mp4_file: Path):
result = subprocess.run(
[
'ffprobe',
'-v', 'error',
'-show_entries', 'format=duration',
'-of', 'default=noprint_wrappers=1:nokey=1',
mp4_file
],
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT
)
return float(result.stdout)
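For reference, the temporary list file written by `concat_video_files` above is the standard input of ffmpeg's concat demuxer; a sketch of its contents and of the equivalent command (paths are placeholders):

```
# Contents of the temporary list file consumed by ffmpeg's concat demuxer
# (one "file '<path>'" line per input, in playback order):
concat_list = "\n".join(f"file '{p}'" for p in ["videos/cam/chunk-000/file-000.mp4", "/tmp/episode_new.mp4"])
print(concat_list)
# file 'videos/cam/chunk-000/file-000.mp4'
# file '/tmp/episode_new.mp4'

# The function then runs, roughly:  ffmpeg -y -f concat -safe 0 -i <list> -c copy <output>
# "-c copy" remuxes without re-encoding, so appending an episode is fast and lossless.
```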
def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict: def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict:
"""Flatten a nested dictionary structure by collapsing nested keys into one key with a separator. """Flatten a nested dictionary structure by collapsing nested keys into one key with a separator.
@ -124,6 +228,8 @@ def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict:
for key, value in flatten_dict(stats).items(): for key, value in flatten_dict(stats).items():
if isinstance(value, (torch.Tensor, np.ndarray)): if isinstance(value, (torch.Tensor, np.ndarray)):
serialized_dict[key] = value.tolist() serialized_dict[key] = value.tolist()
elif isinstance(value, list) and isinstance(value[0], (int, float, list)):
serialized_dict[key] = value
elif isinstance(value, np.generic): elif isinstance(value, np.generic):
serialized_dict[key] = value.item() serialized_dict[key] = value.item()
elif isinstance(value, (int, float)): elif isinstance(value, (int, float)):
@ -183,7 +289,7 @@ def load_info(local_dir: Path) -> dict:
def write_stats(stats: dict, local_dir: Path): def write_stats(stats: dict, local_dir: Path):
serialized_stats = serialize_dict(stats) serialized_stats = serialize_dict(stats)
write_json(serialized_stats, local_dir / STATS_PATH) write_json(serialized_stats, local_dir / LEGACY_STATS_PATH)
def cast_stats_to_numpy(stats) -> dict[str, dict[str, np.ndarray]]: def cast_stats_to_numpy(stats) -> dict[str, dict[str, np.ndarray]]:
@ -192,50 +298,91 @@ def cast_stats_to_numpy(stats) -> dict[str, dict[str, np.ndarray]]:
def load_stats(local_dir: Path) -> dict[str, dict[str, np.ndarray]]: def load_stats(local_dir: Path) -> dict[str, dict[str, np.ndarray]]:
if not (local_dir / STATS_PATH).exists(): if not (local_dir / LEGACY_STATS_PATH).exists():
return None return None
stats = load_json(local_dir / STATS_PATH) stats = load_json(local_dir / LEGACY_STATS_PATH)
return cast_stats_to_numpy(stats) return cast_stats_to_numpy(stats)
def write_task(task_index: int, task: dict, local_dir: Path): def write_hf_dataset(hf_dataset: Dataset, local_dir: Path):
if get_hf_dataset_size_in_mb(hf_dataset) > DEFAULT_FILE_SIZE_IN_MB:
raise NotImplementedError("Contact a maintainer.")
path = local_dir / DEFAULT_DATA_PATH.format(chunk_index=0, file_index=0)
path.parent.mkdir(parents=True, exist_ok=True)
hf_dataset.to_parquet(path)
def write_tasks(tasks: pandas.DataFrame, local_dir: Path):
path = local_dir / DEFAULT_TASKS_PATH.format(chunk_index=0, file_index=0)
path.parent.mkdir(parents=True, exist_ok=True)
tasks.to_parquet(path)
def legacy_write_task(task_index: int, task: dict, local_dir: Path):
task_dict = { task_dict = {
"task_index": task_index, "task_index": task_index,
"task": task, "task": task,
} }
append_jsonlines(task_dict, local_dir / TASKS_PATH) append_jsonlines(task_dict, local_dir / LEGACY_TASKS_PATH)
def load_tasks(local_dir: Path) -> tuple[dict, dict]: def legacy_load_tasks(local_dir: Path) -> tuple[dict, dict]:
tasks = load_jsonlines(local_dir / TASKS_PATH) tasks = load_jsonlines(local_dir / LEGACY_TASKS_PATH)
tasks = {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])} tasks = {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
task_to_task_index = {task: task_index for task_index, task in tasks.items()} task_to_task_index = {task: task_index for task_index, task in tasks.items()}
return tasks, task_to_task_index return tasks, task_to_task_index
def load_tasks(local_dir: Path):
tasks = load_nested_dataset(local_dir / TASKS_DIR)
# TODO(rcadene): optimize this
task_to_task_index = {d["task"]: d["task_index"] for d in tasks}
return tasks, task_to_task_index
def write_episode(episode: dict, local_dir: Path): def write_episode(episode: dict, local_dir: Path):
append_jsonlines(episode, local_dir / EPISODES_PATH) append_jsonlines(episode, local_dir / LEGACY_EPISODES_PATH)
def write_episodes(episodes: Dataset, local_dir: Path):
if get_hf_dataset_size_in_mb(episodes) > DEFAULT_FILE_SIZE_IN_MB:
raise NotImplementedError("Contact a maintainer.")
def load_episodes(local_dir: Path) -> dict: fpath = local_dir / DEFAULT_EPISODES_PATH.format(chunk_index=0, file_index=0)
episodes = load_jsonlines(local_dir / EPISODES_PATH) fpath.parent.mkdir(parents=True, exist_ok=True)
episodes.to_parquet(fpath)
def legacy_load_episodes(local_dir: Path) -> dict:
episodes = load_jsonlines(local_dir / LEGACY_EPISODES_PATH)
return {item["episode_index"]: item for item in sorted(episodes, key=lambda x: x["episode_index"])} return {item["episode_index"]: item for item in sorted(episodes, key=lambda x: x["episode_index"])}
def load_episodes(local_dir: Path):
hf_dataset = load_nested_dataset(local_dir / EPISODES_DIR)
return hf_dataset
def write_episode_stats(episode_index: int, episode_stats: dict, local_dir: Path): def legacy_write_episode_stats(episode_index: int, episode_stats: dict, local_dir: Path):
# We wrap episode_stats in a dictionary since `episode_stats["episode_index"]` # We wrap episode_stats in a dictionary since `episode_stats["episode_index"]`
# is a dictionary of stats and not an integer. # is a dictionary of stats and not an integer.
episode_stats = {"episode_index": episode_index, "stats": serialize_dict(episode_stats)} episode_stats = {"episode_index": episode_index, "stats": serialize_dict(episode_stats)}
append_jsonlines(episode_stats, local_dir / EPISODES_STATS_PATH) append_jsonlines(episode_stats, local_dir / LEGACY_EPISODES_STATS_PATH)
def load_episodes_stats(local_dir: Path) -> dict: def write_episodes_stats(episodes_stats: Dataset, local_dir: Path):
episodes_stats = load_jsonlines(local_dir / EPISODES_STATS_PATH) if get_hf_dataset_size_in_mb(episodes_stats) > DEFAULT_FILE_SIZE_IN_MB:
raise NotImplementedError("Contact a maintainer.")
fpath = local_dir / DEFAULT_EPISODES_STATS_PATH.format(chunk_index=0, file_index=0)
fpath.parent.mkdir(parents=True, exist_ok=True)
episodes_stats.to_parquet(fpath)
def legacy_load_episodes_stats(local_dir: Path) -> dict:
episodes_stats = load_jsonlines(local_dir / LEGACY_EPISODES_STATS_PATH)
return { return {
item["episode_index"]: cast_stats_to_numpy(item["stats"]) item["episode_index"]: cast_stats_to_numpy(item["stats"])
for item in sorted(episodes_stats, key=lambda x: x["episode_index"]) for item in sorted(episodes_stats, key=lambda x: x["episode_index"])
} }
def load_episodes_stats(local_dir: Path):
hf_dataset = load_nested_dataset(local_dir / EPISODES_STATS_DIR)
return hf_dataset
def backward_compatible_episodes_stats( def backward_compatible_episodes_stats(
stats: dict[str, dict[str, np.ndarray]], episodes: list[int] stats: dict[str, dict[str, np.ndarray]], episodes: list[int]
@ -388,6 +535,7 @@ def get_hf_features_from_features(features: dict) -> datasets.Features:
def get_features_from_robot(robot: Robot, use_videos: bool = True) -> dict: def get_features_from_robot(robot: Robot, use_videos: bool = True) -> dict:
# TODO(rcadene): add fps for each feature
camera_ft = {} camera_ft = {}
if robot.cameras: if robot.cameras:
camera_ft = { camera_ft = {
@ -442,11 +590,11 @@ def create_empty_dataset_info(
"total_frames": 0, "total_frames": 0,
"total_tasks": 0, "total_tasks": 0,
"total_videos": 0, "total_videos": 0,
"total_chunks": 0,
"chunks_size": DEFAULT_CHUNK_SIZE, "chunks_size": DEFAULT_CHUNK_SIZE,
"files_size_in_mb": DEFAULT_FILE_SIZE_IN_MB,
"fps": fps, "fps": fps,
"splits": {}, "splits": {},
"data_path": DEFAULT_PARQUET_PATH, "data_path": DEFAULT_DATA_PATH,
"video_path": DEFAULT_VIDEO_PATH if use_videos else None, "video_path": DEFAULT_VIDEO_PATH if use_videos else None,
"features": features, "features": features,
} }

View File

@ -121,12 +121,12 @@ from safetensors.torch import load_file
from lerobot.common.datasets.utils import ( from lerobot.common.datasets.utils import (
DEFAULT_CHUNK_SIZE, DEFAULT_CHUNK_SIZE,
DEFAULT_PARQUET_PATH, DEFAULT_DATA_PATH,
DEFAULT_VIDEO_PATH, DEFAULT_VIDEO_PATH,
EPISODES_PATH, LEGACY_EPISODES_PATH,
INFO_PATH, INFO_PATH,
STATS_PATH, LEGACY_STATS_PATH,
TASKS_PATH, LEGACY_TASKS_PATH,
create_branch, create_branch,
create_lerobot_dataset_card, create_lerobot_dataset_card,
flatten_dict, flatten_dict,
@ -188,7 +188,7 @@ def convert_stats_to_json(v1_dir: Path, v2_dir: Path) -> None:
serialized_stats = {key: value.tolist() for key, value in stats.items()} serialized_stats = {key: value.tolist() for key, value in stats.items()}
serialized_stats = unflatten_dict(serialized_stats) serialized_stats = unflatten_dict(serialized_stats)
json_path = v2_dir / STATS_PATH json_path = v2_dir / LEGACY_STATS_PATH
json_path.parent.mkdir(exist_ok=True, parents=True) json_path.parent.mkdir(exist_ok=True, parents=True)
with open(json_path, "w") as f: with open(json_path, "w") as f:
json.dump(serialized_stats, f, indent=4) json.dump(serialized_stats, f, indent=4)
@ -291,12 +291,12 @@ def split_parquet_by_episodes(
for ep_chunk in range(total_chunks): for ep_chunk in range(total_chunks):
ep_chunk_start = DEFAULT_CHUNK_SIZE * ep_chunk ep_chunk_start = DEFAULT_CHUNK_SIZE * ep_chunk
ep_chunk_end = min(DEFAULT_CHUNK_SIZE * (ep_chunk + 1), total_episodes) ep_chunk_end = min(DEFAULT_CHUNK_SIZE * (ep_chunk + 1), total_episodes)
chunk_dir = "/".join(DEFAULT_PARQUET_PATH.split("/")[:-1]).format(episode_chunk=ep_chunk) chunk_dir = "/".join(DEFAULT_DATA_PATH.split("/")[:-1]).format(episode_chunk=ep_chunk)
(output_dir / chunk_dir).mkdir(parents=True, exist_ok=True) (output_dir / chunk_dir).mkdir(parents=True, exist_ok=True)
for ep_idx in range(ep_chunk_start, ep_chunk_end): for ep_idx in range(ep_chunk_start, ep_chunk_end):
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx)) ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
episode_lengths.insert(ep_idx, len(ep_table)) episode_lengths.insert(ep_idx, len(ep_table))
output_file = output_dir / DEFAULT_PARQUET_PATH.format( output_file = output_dir / DEFAULT_DATA_PATH.format(
episode_chunk=ep_chunk, episode_index=ep_idx episode_chunk=ep_chunk, episode_index=ep_idx
) )
pq.write_table(ep_table, output_file) pq.write_table(ep_table, output_file)
@ -496,7 +496,7 @@ def convert_dataset(
assert set(tasks) == {task for ep_tasks in tasks_by_episodes.values() for task in ep_tasks} assert set(tasks) == {task for ep_tasks in tasks_by_episodes.values() for task in ep_tasks}
tasks = [{"task_index": task_idx, "task": task} for task_idx, task in enumerate(tasks)] tasks = [{"task_index": task_idx, "task": task} for task_idx, task in enumerate(tasks)]
write_jsonlines(tasks, v20_dir / TASKS_PATH) write_jsonlines(tasks, v20_dir / LEGACY_TASKS_PATH)
features["task_index"] = { features["task_index"] = {
"dtype": "int64", "dtype": "int64",
"shape": (1,), "shape": (1,),
@ -546,7 +546,7 @@ def convert_dataset(
{"episode_index": ep_idx, "tasks": tasks_by_episodes[ep_idx], "length": episode_lengths[ep_idx]} {"episode_index": ep_idx, "tasks": tasks_by_episodes[ep_idx], "length": episode_lengths[ep_idx]}
for ep_idx in episode_indices for ep_idx in episode_indices
] ]
write_jsonlines(episodes, v20_dir / EPISODES_PATH) write_jsonlines(episodes, v20_dir / LEGACY_EPISODES_PATH)
# Assemble metadata v2.0 # Assemble metadata v2.0
metadata_v2_0 = { metadata_v2_0 = {
@ -560,7 +560,7 @@ def convert_dataset(
"chunks_size": DEFAULT_CHUNK_SIZE, "chunks_size": DEFAULT_CHUNK_SIZE,
"fps": metadata_v1["fps"], "fps": metadata_v1["fps"],
"splits": {"train": f"0:{total_episodes}"}, "splits": {"train": f"0:{total_episodes}"},
"data_path": DEFAULT_PARQUET_PATH, "data_path": DEFAULT_DATA_PATH,
"video_path": DEFAULT_VIDEO_PATH if video_keys else None, "video_path": DEFAULT_VIDEO_PATH if video_keys else None,
"features": features, "features": features,
} }

View File

@ -23,7 +23,7 @@ import logging
from huggingface_hub import HfApi from huggingface_hub import HfApi
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
from lerobot.common.datasets.utils import EPISODES_STATS_PATH, STATS_PATH, load_stats, write_info from lerobot.common.datasets.utils import LEGACY_EPISODES_STATS_PATH, LEGACY_STATS_PATH, load_stats, write_info
from lerobot.common.datasets.v21.convert_stats import check_aggregate_stats, convert_stats from lerobot.common.datasets.v21.convert_stats import check_aggregate_stats, convert_stats
V20 = "v2.0" V20 = "v2.0"
@ -47,8 +47,8 @@ def convert_dataset(
with SuppressWarnings(): with SuppressWarnings():
dataset = LeRobotDataset(repo_id, revision=V20, force_cache_sync=True) dataset = LeRobotDataset(repo_id, revision=V20, force_cache_sync=True)
if (dataset.root / EPISODES_STATS_PATH).is_file(): if (dataset.root / LEGACY_EPISODES_STATS_PATH).is_file():
(dataset.root / EPISODES_STATS_PATH).unlink() (dataset.root / LEGACY_EPISODES_STATS_PATH).unlink()
convert_stats(dataset, num_workers=num_workers) convert_stats(dataset, num_workers=num_workers)
ref_stats = load_stats(dataset.root) ref_stats = load_stats(dataset.root)
@ -60,15 +60,15 @@ def convert_dataset(
dataset.push_to_hub(branch=branch, tag_version=False, allow_patterns="meta/") dataset.push_to_hub(branch=branch, tag_version=False, allow_patterns="meta/")
# delete old stats.json file # delete old stats.json file
if (dataset.root / STATS_PATH).is_file: if (dataset.root / LEGACY_STATS_PATH).is_file:
(dataset.root / STATS_PATH).unlink() (dataset.root / LEGACY_STATS_PATH).unlink()
hub_api = HfApi() hub_api = HfApi()
if hub_api.file_exists( if hub_api.file_exists(
repo_id=dataset.repo_id, filename=STATS_PATH, revision=branch, repo_type="dataset" repo_id=dataset.repo_id, filename=LEGACY_STATS_PATH, revision=branch, repo_type="dataset"
): ):
hub_api.delete_file( hub_api.delete_file(
path_in_repo=STATS_PATH, repo_id=dataset.repo_id, revision=branch, repo_type="dataset" path_in_repo=LEGACY_STATS_PATH, repo_id=dataset.repo_id, revision=branch, repo_type="dataset"
) )
hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset") hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")

View File

@ -5,7 +5,7 @@ from tqdm import tqdm
from lerobot.common.datasets.compute_stats import aggregate_stats, get_feature_stats, sample_indices from lerobot.common.datasets.compute_stats import aggregate_stats, get_feature_stats, sample_indices
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.utils import write_episode_stats from lerobot.common.datasets.utils import legacy_write_episode_stats
def sample_episode_video_frames(dataset: LeRobotDataset, episode_index: int, ft_key: str) -> np.ndarray: def sample_episode_video_frames(dataset: LeRobotDataset, episode_index: int, ft_key: str) -> np.ndarray:
@ -58,7 +58,7 @@ def convert_stats(dataset: LeRobotDataset, num_workers: int = 0):
convert_episode_stats(dataset, ep_idx) convert_episode_stats(dataset, ep_idx)
for ep_idx in tqdm(range(total_episodes)): for ep_idx in tqdm(range(total_episodes)):
write_episode_stats(ep_idx, dataset.meta.episodes_stats[ep_idx], dataset.root) legacy_write_episode_stats(ep_idx, dataset.meta.episodes_stats[ep_idx], dataset.root)
def check_aggregate_stats( def check_aggregate_stats(

View File

@ -19,26 +19,301 @@ python lerobot/common/datasets/v30/convert_dataset_v21_to_v30.py \
import argparse import argparse
import logging import logging
from pathlib import Path
import sys
from datasets import Dataset from datasets import Dataset
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
import tqdm
from lerobot.common.constants import HF_LEROBOT_HOME from lerobot.common.constants import HF_LEROBOT_HOME
from lerobot.common.datasets.utils import ( from lerobot.common.datasets.utils import (
load_episodes_stats, DEFAULT_CHUNK_SIZE,
DEFAULT_DATA_PATH,
DEFAULT_FILE_SIZE_IN_MB,
DEFAULT_VIDEO_PATH,
concat_video_files,
flatten_dict,
get_parquet_num_frames,
get_video_duration_in_s,
get_video_size_in_mb,
legacy_load_episodes,
legacy_load_episodes_stats,
load_info,
legacy_load_tasks,
update_chunk_file_indices,
write_episodes,
write_episodes_stats,
write_info,
write_tasks,
) )
import subprocess
import tempfile
import pandas as pd
import pyarrow.parquet as pq
V21 = "v2.1" V21 = "v2.1"
class SuppressWarnings:
    def __enter__(self):
        self.previous_level = logging.getLogger().getEffectiveLevel()
        logging.getLogger().setLevel(logging.ERROR)

    def __exit__(self, exc_type, exc_val, exc_tb):
        logging.getLogger().setLevel(self.previous_level)

"""
-------------------------
OLD
data/chunk-000/episode_000000.parquet
NEW
data/chunk-000/file_000.parquet
-------------------------
OLD
videos/chunk-000/CAMERA/episode_000000.mp4
NEW
videos/chunk-000/file_000.mp4
-------------------------
OLD
episodes.jsonl
{"episode_index": 1, "tasks": ["Put the blue block in the green bowl"], "length": 266}
NEW
meta/episodes/chunk-000/episodes_000.parquet
episode_index | video_chunk_index | video_file_index | data_chunk_index | data_file_index | tasks | length
-------------------------
OLD
tasks.jsonl
{"task_index": 1, "task": "Put the blue block in the green bowl"}
NEW
meta/tasks/chunk-000/file_000.parquet
task_index | task
-------------------------
OLD
episodes_stats.jsonl
NEW
meta/episodes_stats/chunk-000/file_000.parquet
episode_index | mean | std | min | max
-------------------------
UPDATE
meta/info.json
-------------------------
"""
def get_parquet_file_size_in_mb(parquet_path):
    metadata = pq.read_metadata(parquet_path)
    # Sum the uncompressed byte size of every row group in the file.
    uncompressed_size = sum(
        metadata.row_group(i).total_byte_size for i in range(metadata.num_row_groups)
    )
    return uncompressed_size / (1024 ** 2)
def generate_flat_ep_stats(episodes_stats):
for ep_idx, ep_stats in episodes_stats.items():
flat_ep_stats = flatten_dict(ep_stats)
flat_ep_stats["episode_index"] = ep_idx
yield flat_ep_stats
def convert_episodes_stats(root, new_root):
episodes_stats = legacy_load_episodes_stats(root)
ds_episodes_stats = Dataset.from_generator(lambda: generate_flat_ep_stats(episodes_stats))
write_episodes_stats(ds_episodes_stats, new_root)
def generate_task_dict(tasks):
for task_index, task in tasks.items():
yield {"task_index": task_index, "task": task}
def convert_tasks(root, new_root):
tasks, _ = legacy_load_tasks(root)
ds_tasks = Dataset.from_generator(lambda: generate_task_dict(tasks))
write_tasks(ds_tasks, new_root)
def concat_data_files(paths_to_cat, new_root, chunk_idx, file_idx):
# TODO(rcadene): to save RAM use Dataset.from_parquet(file) and concatenate_datasets
dataframes = [pd.read_parquet(file) for file in paths_to_cat]
# Concatenate all DataFrames along rows
concatenated_df = pd.concat(dataframes, ignore_index=True)
path = new_root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
path.parent.mkdir(parents=True, exist_ok=True)
concatenated_df.to_parquet(path, index=False)
def convert_data(root, new_root):
data_dir = root / "data"
ep_paths = [path for path in data_dir.glob("*/*.parquet")]
ep_paths = sorted(ep_paths)
episodes_metadata = []
ep_idx = 0
chunk_idx = 0
file_idx = 0
size_in_mb = 0
num_frames = 0
paths_to_cat = []
for ep_path in ep_paths:
ep_size_in_mb = get_parquet_file_size_in_mb(ep_path)
ep_num_frames = get_parquet_num_frames(ep_path)
ep_metadata = {
"episode_index": ep_idx,
"data/chunk_index": chunk_idx,
"data/file_index": file_idx,
"data/from_index": num_frames,
"data/to_index": num_frames + ep_num_frames,
}
size_in_mb += ep_size_in_mb
num_frames += ep_num_frames
episodes_metadata.append(ep_metadata)
ep_idx += 1
if size_in_mb < DEFAULT_FILE_SIZE_IN_MB:
paths_to_cat.append(ep_path)
continue
concat_data_files(paths_to_cat, new_root, chunk_idx, file_idx)
# Reset for the next file
size_in_mb = ep_size_in_mb
num_frames = ep_num_frames
paths_to_cat = [ep_path]
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, DEFAULT_CHUNK_SIZE)
# Write remaining data if any
if paths_to_cat:
concat_data_files(paths_to_cat, new_root, chunk_idx, file_idx)
return episodes_metadata
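The loop above uses an accumulate-and-flush strategy: per-episode parquet files are collected in paths_to_cat until their cumulative size reaches DEFAULT_FILE_SIZE_IN_MB, then concat_data_files writes them as one data/chunk-XXX/file_XXX parquet and update_chunk_file_indices advances the indices. Below is a standalone sketch of the same pattern using plain numbers instead of files; the function name and values are illustrative and not part of the commit.

def pack_by_size(sizes_in_mb: list[float], max_file_size_mb: float = 100.0) -> list[list[int]]:
    # Group item indices so that each group stays under max_file_size_mb,
    # mirroring the flush logic of convert_data()/convert_videos_of_camera().
    groups: list[list[int]] = []
    current: list[int] = []
    current_size = 0.0
    for idx, size in enumerate(sizes_in_mb):
        current_size += size
        if current_size < max_file_size_mb:
            current.append(idx)
            continue
        # Threshold reached: flush what was accumulated so far, then start a
        # new group with the current item (mirrors `paths_to_cat = [ep_path]`).
        if current:
            groups.append(current)
        current = [idx]
        current_size = size
    if current:
        groups.append(current)
    return groups


print(pack_by_size([30, 30, 30, 30, 30], max_file_size_mb=100))  # [[0, 1, 2], [3, 4]]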
def get_video_keys(root):
info = load_info(root)
features = info["features"]
image_keys = [key for key, ft in features.items() if ft["dtype"] == "image"]
if len(image_keys) != 0:
raise NotImplementedError()
video_keys = [key for key, ft in features.items() if ft["dtype"] == "video"]
return video_keys
def convert_videos(root: Path, new_root: Path):
video_keys = get_video_keys(root)
video_keys = sorted(video_keys)
eps_metadata_per_cam = []
for camera in video_keys:
eps_metadata = convert_videos_of_camera(root, new_root, camera)
eps_metadata_per_cam.append(eps_metadata)
num_eps_per_cam = [len(eps_cam_map) for eps_cam_map in eps_metadata_per_cam]
if len(set(num_eps_per_cam)) != 1:
        raise ValueError(f"All cameras don't have the same number of episodes ({num_eps_per_cam}).")
    episodes_metadata = []
num_cameras = len(video_keys)
num_episodes = num_eps_per_cam[0]
for ep_idx in range(num_episodes):
# Sanity check
ep_ids = [eps_metadata_per_cam[cam_idx][ep_idx]["episode_index"] for cam_idx in range(num_cameras)]
ep_ids += [ep_idx]
if len(set(ep_ids)) != 1:
raise ValueError(f"All episode indices need to match ({ep_ids}).")
ep_dict = {}
for cam_idx in range(num_cameras):
ep_dict.update(eps_metadata_per_cam[cam_idx][ep_idx])
        episodes_metadata.append(ep_dict)
    return episodes_metadata
def convert_videos_of_camera(root: Path, new_root: Path, video_key):
# Access old paths to mp4
videos_dir = root / "videos"
ep_paths = [path for path in videos_dir.glob(f"*/{video_key}/*.mp4")]
ep_paths = sorted(ep_paths)
episodes_metadata = []
ep_idx = 0
chunk_idx = 0
file_idx = 0
size_in_mb = 0
duration_in_s = 0.0
paths_to_cat = []
for ep_path in tqdm.tqdm(ep_paths, desc=f"convert videos of {video_key}"):
ep_size_in_mb = get_video_size_in_mb(ep_path)
ep_duration_in_s = get_video_duration_in_s(ep_path)
ep_metadata = {
"episode_index": ep_idx,
f"{video_key}/chunk_index": chunk_idx,
f"{video_key}/file_index": file_idx,
f"{video_key}/from_timestamp": duration_in_s,
f"{video_key}/to_timestamp": duration_in_s + ep_duration_in_s,
}
size_in_mb += ep_size_in_mb
duration_in_s += ep_duration_in_s
episodes_metadata.append(ep_metadata)
ep_idx += 1
if size_in_mb < DEFAULT_FILE_SIZE_IN_MB:
paths_to_cat.append(ep_path)
continue
concat_video_files(paths_to_cat, new_root, video_key, chunk_idx, file_idx)
# Reset for the next file
size_in_mb = ep_size_in_mb
duration_in_s = ep_duration_in_s
paths_to_cat = [ep_path]
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, DEFAULT_CHUNK_SIZE)
# Write remaining videos if any
if paths_to_cat:
concat_video_files(paths_to_cat, new_root, video_key, chunk_idx, file_idx)
return episodes_metadata
def generate_episode_dict(episodes, episodes_data, episodes_videos):
for ep, ep_data, ep_video in zip(episodes.values(), episodes_data, episodes_videos):
ep_idx = ep["episode_index"]
ep_idx_data = ep_data["episode_index"]
ep_idx_video = ep_video["episode_index"]
if len(set([ep_idx, ep_idx_data, ep_idx_video])) != 1:
            raise ValueError(f"Episode indices do not match ({ep_idx=},{ep_idx_data=},{ep_idx_video=}).")
ep_dict = {**ep_data, **ep_video, **ep}
yield ep_dict
def convert_episodes(root, new_root, episodes_data, episodes_videos):
episodes = legacy_load_episodes(root)
num_eps = len(episodes)
num_eps_data = len(episodes_data)
num_eps_video = len(episodes_videos)
if len(set([num_eps, num_eps_data, num_eps_video])) != 1:
raise ValueError(f"Number of episodes is not the same ({num_eps=},{num_eps_data=},{num_eps_video=}).")
ds_episodes = Dataset.from_generator(lambda: generate_episode_dict(episodes, episodes_data, episodes_videos))
write_episodes(ds_episodes, new_root)
def convert_info(root, new_root):
info = load_info(root)
info["codebase_version"] = "v3.0"
del info["total_chunks"]
del info["total_videos"]
info["files_size_in_mb"] = DEFAULT_FILE_SIZE_IN_MB
# TODO(rcadene): chunk- or chunk_ or file- or file_
info["data_path"] = DEFAULT_DATA_PATH
info["video_path"] = DEFAULT_VIDEO_PATH
info["fps"] = float(info["fps"])
for key in info["features"]:
if info["features"][key]["dtype"] == "video":
# already has fps in video_info
continue
info["features"][key]["fps"] = info["fps"]
write_info(info, new_root)
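After conversion, meta/info.json should no longer contain total_chunks or total_videos, should expose files_size_in_mb and the new path templates, and every non-video feature should carry its own fps entry. A hedged sanity-check sketch, with placeholder paths and dataset name:

import json
from pathlib import Path

# Placeholder root of a converted v3.0 dataset.
new_root = Path("~/.cache/huggingface/lerobot/my_dataset_v30").expanduser()
info = json.loads((new_root / "meta/info.json").read_text())

assert info["codebase_version"] == "v3.0"
assert "total_chunks" not in info and "total_videos" not in info
assert "files_size_in_mb" in info
assert isinstance(info["fps"], float)
# Non-video features now carry their own fps entry.
non_video = {k: ft for k, ft in info["features"].items() if ft["dtype"] != "video"}
assert all("fps" in ft for ft in non_video.values())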
def convert_dataset( def convert_dataset(
repo_id: str, repo_id: str,
@ -46,6 +321,8 @@ def convert_dataset(
num_workers: int = 4, num_workers: int = 4,
): ):
root = HF_LEROBOT_HOME / repo_id root = HF_LEROBOT_HOME / repo_id
new_root = HF_LEROBOT_HOME / f"{repo_id}_v30"
snapshot_download( snapshot_download(
repo_id, repo_id,
repo_type="dataset", repo_type="dataset",
@ -53,63 +330,12 @@ def convert_dataset(
local_dir=root, local_dir=root,
) )
    # Concatenate videos

    # Create

    """
    -------------------------
    OLD
    data/chunk-000/episode_000000.parquet
    NEW
    data/chunk-000/file_000.parquet
    -------------------------
    OLD
    videos/chunk-000/CAMERA/episode_000000.mp4
    NEW
    videos/chunk-000/file_000.mp4
    -------------------------
    OLD
    episodes.jsonl
    {"episode_index": 1, "tasks": ["Put the blue block in the green bowl"], "length": 266}
    NEW
    meta/episodes/chunk-000/episodes_000.parquet
    episode_index | video_chunk_index | video_file_index | data_chunk_index | data_file_index | tasks | length
    -------------------------
    OLD
    tasks.jsonl
    {"task_index": 1, "task": "Put the blue block in the green bowl"}
    NEW
    meta/tasks/chunk-000/file_000.parquet
    task_index | task
    -------------------------
    OLD
    episodes_stats.jsonl
    NEW
    meta/episodes_stats/chunk-000/file_000.parquet
    episode_index | mean | std | min | max
    -------------------------
    UPDATE
    meta/info.json
    -------------------------
    """
    new_root = HF_LEROBOT_HOME / f"{repo_id}_v30"
    new_root.mkdir(parents=True, exist_ok=True)

    episodes_stats = load_episodes_stats(root)
    hf_dataset = Dataset.from_dict(episodes_stats)  # noqa: F841

    meta_ep_st_ch = new_root / "meta/episodes_stats/chunk-000"
    meta_ep_st_ch.mkdir(parents=True, exist_ok=True)

    # hf_dataset.to_parquet(meta_ep_st_ch / 'file_000.parquet')

    convert_info(root, new_root)
    convert_episodes_stats(root, new_root)
    convert_tasks(root, new_root)
    episodes_data_mapping = convert_data(root, new_root)
    episodes_videos_mapping = convert_videos(root, new_root)
    convert_episodes(root, new_root, episodes_data_mapping, episodes_videos_mapping)
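Assuming the signature shown above (only repo_id and num_workers are visible in this hunk; any other parameters are elided), the converter can be driven from Python roughly as follows; the repo id is a placeholder.

from lerobot.common.constants import HF_LEROBOT_HOME
from lerobot.common.datasets.v30.convert_dataset_v21_to_v30 import convert_dataset

# Downloads the v2.1 dataset into HF_LEROBOT_HOME / repo_id and writes the
# converted copy to HF_LEROBOT_HOME / f"{repo_id}_v30".
convert_dataset(repo_id="my_org/my_dataset", num_workers=4)
print(HF_LEROBOT_HOME / "my_org/my_dataset_v30")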
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()

View File

@ -9,13 +9,16 @@ import numpy as np
import PIL.Image import PIL.Image
import pytest import pytest
import torch import torch
from datasets import Dataset
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset, LeRobotDatasetMetadata from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset, LeRobotDatasetMetadata
from lerobot.common.datasets.utils import ( from lerobot.common.datasets.utils import (
DEFAULT_CHUNK_SIZE, DEFAULT_CHUNK_SIZE,
DEFAULT_FEATURES, DEFAULT_FEATURES,
DEFAULT_PARQUET_PATH, DEFAULT_DATA_PATH,
DEFAULT_FILE_SIZE_IN_MB,
DEFAULT_VIDEO_PATH, DEFAULT_VIDEO_PATH,
flatten_dict,
get_hf_features_from_features, get_hf_features_from_features,
hf_transform_to_torch, hf_transform_to_torch,
) )
@ -33,10 +36,9 @@ class LeRobotDatasetFactory(Protocol):
def __call__(self, *args, **kwargs) -> LeRobotDataset: ... def __call__(self, *args, **kwargs) -> LeRobotDataset: ...
def get_task_index(task_dicts: dict, task: str) -> int:
    tasks = {d["task_index"]: d["task"] for d in task_dicts.values()}
    task_to_task_index = {task: task_idx for task_idx, task in tasks.items()}
    return task_to_task_index[task]

def get_task_index(tasks: Dataset, task: str) -> int:
    task_idx = tasks["task"].index(task)
    return task_idx
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
@ -104,9 +106,9 @@ def info_factory(features_factory):
total_frames: int = 0, total_frames: int = 0,
total_tasks: int = 0, total_tasks: int = 0,
total_videos: int = 0, total_videos: int = 0,
total_chunks: int = 0,
chunks_size: int = DEFAULT_CHUNK_SIZE, chunks_size: int = DEFAULT_CHUNK_SIZE,
data_path: str = DEFAULT_PARQUET_PATH, files_size_in_mb: float = DEFAULT_FILE_SIZE_IN_MB,
data_path: str = DEFAULT_DATA_PATH,
video_path: str = DEFAULT_VIDEO_PATH, video_path: str = DEFAULT_VIDEO_PATH,
motor_features: dict = DUMMY_MOTOR_FEATURES, motor_features: dict = DUMMY_MOTOR_FEATURES,
camera_features: dict = DUMMY_CAMERA_FEATURES, camera_features: dict = DUMMY_CAMERA_FEATURES,
@ -120,8 +122,8 @@ def info_factory(features_factory):
"total_frames": total_frames, "total_frames": total_frames,
"total_tasks": total_tasks, "total_tasks": total_tasks,
"total_videos": total_videos, "total_videos": total_videos,
"total_chunks": total_chunks,
"chunks_size": chunks_size, "chunks_size": chunks_size,
"files_size_in_mb": files_size_in_mb,
"fps": fps, "fps": fps,
"splits": {}, "splits": {},
"data_path": data_path, "data_path": data_path,
@ -168,25 +170,25 @@ def episodes_stats_factory(stats_factory):
features: dict[str], features: dict[str],
total_episodes: int = 3, total_episodes: int = 3,
) -> dict: ) -> dict:
        episodes_stats = {}
        for episode_index in range(total_episodes):
            episodes_stats[episode_index] = {
                "episode_index": episode_index,
                "stats": stats_factory(features),
            }
        return episodes_stats

        def _generator(total_episodes):
            for ep_idx in range(total_episodes):
                flat_ep_stats = flatten_dict(stats_factory(features))
                flat_ep_stats["episode_index"] = ep_idx
                yield flat_ep_stats

        # Simpler to rely on generator instead of from_dict
        return Dataset.from_generator(lambda: _generator(total_episodes))
return _create_episodes_stats return _create_episodes_stats
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def tasks_factory(): def tasks_factory():
    def _create_tasks(total_tasks: int = 3) -> int:
        tasks = {}
        for task_index in range(total_tasks):
            task_dict = {"task_index": task_index, "task": f"Perform action {task_index}."}
            tasks[task_index] = task_dict
        return tasks

    def _create_tasks(total_tasks: int = 3) -> Dataset:
        ids = list(range(total_tasks))
        tasks = [f"Perform action {i}." for i in ids]
        return Dataset.from_dict({"task_index": ids, "task": tasks})
return _create_tasks return _create_tasks
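The reworked fixtures return datasets.Dataset objects instead of plain dicts. A small sketch of how the new get_task_index defined earlier in this file consumes such a table; the values are illustrative only.

from datasets import Dataset

tasks = Dataset.from_dict(
    {"task_index": [0, 1, 2], "task": [f"Perform action {i}." for i in range(3)]}
)

def get_task_index(tasks: Dataset, task: str) -> int:
    # tasks["task"] is a plain Python list, so list.index() returns the row position,
    # which equals task_index because rows are stored in task_index order.
    return tasks["task"].index(task)

assert get_task_index(tasks, "Perform action 1.") == 1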
@ -196,6 +198,7 @@ def episodes_factory(tasks_factory):
def _create_episodes( def _create_episodes(
total_episodes: int = 3, total_episodes: int = 3,
total_frames: int = 400, total_frames: int = 400,
video_keys: list[str] | None = None,
tasks: dict | None = None, tasks: dict | None = None,
multi_task: bool = False, multi_task: bool = False,
): ):
@ -215,26 +218,41 @@ def episodes_factory(tasks_factory):
        # Generate random lengths that sum up to total_length
        lengths = np.random.multinomial(total_frames, [1 / total_episodes] * total_episodes).tolist()

        tasks_list = [task_dict["task"] for task_dict in tasks.values()]
        num_tasks_available = len(tasks_list)

        episodes = {}
        remaining_tasks = tasks_list.copy()
        for ep_idx in range(total_episodes):
            num_tasks_in_episode = random.randint(1, min(3, num_tasks_available)) if multi_task else 1
            tasks_to_sample = remaining_tasks if remaining_tasks else tasks_list
            episode_tasks = random.sample(tasks_to_sample, min(num_tasks_in_episode, len(tasks_to_sample)))
            if remaining_tasks:
                for task in episode_tasks:
                    remaining_tasks.remove(task)

            episodes[ep_idx] = {
                "episode_index": ep_idx,
                "tasks": episode_tasks,
                "length": lengths[ep_idx],
            }

        return episodes

        # Generate random lengths that sum up to total_length
        lengths = np.random.multinomial(total_frames, [1 / total_episodes] * total_episodes).tolist()

        num_tasks_available = len(tasks["task"])

        d = {
            "episode_index": [],
            "data/chunk_index": [],
            "data/file_index": [],
            "tasks": [],
            "length": [],
        }
        if video_keys is not None:
            for video_key in video_keys:
                d[f"{video_key}/chunk_index"] = []
                d[f"{video_key}/file_index"] = []

        remaining_tasks = tasks["task"].copy()
        for ep_idx in range(total_episodes):
            num_tasks_in_episode = random.randint(1, min(3, num_tasks_available)) if multi_task else 1
            tasks_to_sample = remaining_tasks if remaining_tasks else tasks["task"]
            episode_tasks = random.sample(tasks_to_sample, min(num_tasks_in_episode, len(tasks_to_sample)))
            if remaining_tasks:
                for task in episode_tasks:
                    remaining_tasks.remove(task)

            d["episode_index"].append(ep_idx)
            # TODO(rcadene): remove heuristic of only one file
            d["data/chunk_index"].append(0)
            d["data/file_index"].append(0)
            d["tasks"].append(episode_tasks)
            d["length"].append(lengths[ep_idx])
            if video_keys is not None:
                for video_key in video_keys:
                    d[f"{video_key}/chunk_index"].append(0)
                    d[f"{video_key}/file_index"].append(0)

        return Dataset.from_dict(d)
return _create_episodes return _create_episodes
@ -258,7 +276,7 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar
frame_index_col = np.array([], dtype=np.int64) frame_index_col = np.array([], dtype=np.int64)
episode_index_col = np.array([], dtype=np.int64) episode_index_col = np.array([], dtype=np.int64)
task_index = np.array([], dtype=np.int64) task_index = np.array([], dtype=np.int64)
for ep_dict in episodes.values(): for ep_dict in episodes:
timestamp_col = np.concatenate((timestamp_col, np.arange(ep_dict["length"]) / fps)) timestamp_col = np.concatenate((timestamp_col, np.arange(ep_dict["length"]) / fps))
frame_index_col = np.concatenate((frame_index_col, np.arange(ep_dict["length"], dtype=int))) frame_index_col = np.concatenate((frame_index_col, np.arange(ep_dict["length"], dtype=int)))
episode_index_col = np.concatenate( episode_index_col = np.concatenate(
@ -291,7 +309,7 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar
}, },
features=hf_features, features=hf_features,
) )
dataset.set_transform(hf_transform_to_torch) dataset.set_format("torch")
return dataset return dataset
return _create_hf_dataset return _create_hf_dataset
@ -326,8 +344,9 @@ def lerobot_dataset_metadata_factory(
if not tasks: if not tasks:
tasks = tasks_factory(total_tasks=info["total_tasks"]) tasks = tasks_factory(total_tasks=info["total_tasks"])
if not episodes: if not episodes:
video_keys = [key for key, ft in info["features"].items() if ft["dtype"] == "video"]
episodes = episodes_factory( episodes = episodes_factory(
total_episodes=info["total_episodes"], total_frames=info["total_frames"], tasks=tasks total_episodes=info["total_episodes"], total_frames=info["total_frames"], video_keys=video_keys, tasks=tasks
) )
mock_snapshot_download = mock_snapshot_download_factory( mock_snapshot_download = mock_snapshot_download_factory(
@ -371,9 +390,9 @@ def lerobot_dataset_factory(
multi_task: bool = False, multi_task: bool = False,
info: dict | None = None, info: dict | None = None,
stats: dict | None = None, stats: dict | None = None,
episodes_stats: list[dict] | None = None, episodes_stats: datasets.Dataset | None = None,
tasks: list[dict] | None = None, tasks: datasets.Dataset | None = None,
episode_dicts: list[dict] | None = None, episode_dicts: datasets.Dataset | None = None,
hf_dataset: datasets.Dataset | None = None, hf_dataset: datasets.Dataset | None = None,
**kwargs, **kwargs,
) -> LeRobotDataset: ) -> LeRobotDataset:
@ -388,9 +407,11 @@ def lerobot_dataset_factory(
if not tasks: if not tasks:
tasks = tasks_factory(total_tasks=info["total_tasks"]) tasks = tasks_factory(total_tasks=info["total_tasks"])
if not episode_dicts: if not episode_dicts:
video_keys = [key for key, ft in info["features"].items() if ft["dtype"] == "video"]
episode_dicts = episodes_factory( episode_dicts = episodes_factory(
total_episodes=info["total_episodes"], total_episodes=info["total_episodes"],
total_frames=info["total_frames"], total_frames=info["total_frames"],
video_keys=video_keys,
tasks=tasks, tasks=tasks,
multi_task=multi_task, multi_task=multi_task,
) )

View File

@ -7,83 +7,75 @@ import pyarrow.compute as pc
import pyarrow.parquet as pq import pyarrow.parquet as pq
import pytest import pytest
from datasets import Dataset
from lerobot.common.datasets.utils import ( from lerobot.common.datasets.utils import (
EPISODES_PATH, write_episodes,
EPISODES_STATS_PATH, write_episodes_stats,
INFO_PATH, write_hf_dataset,
STATS_PATH, write_info,
TASKS_PATH, write_stats,
write_tasks,
) )
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def info_path(info_factory): def create_info(info_factory):
def _create_info_json_file(dir: Path, info: dict | None = None) -> Path: def _create_info(dir: Path, info: dict | None = None):
if not info: if not info:
info = info_factory() info = info_factory()
fpath = dir / INFO_PATH write_info(info, dir)
fpath.parent.mkdir(parents=True, exist_ok=True)
with open(fpath, "w") as f:
json.dump(info, f, indent=4, ensure_ascii=False)
return fpath
return _create_info_json_file return _create_info
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def stats_path(stats_factory): def create_stats(stats_factory):
def _create_stats_json_file(dir: Path, stats: dict | None = None) -> Path: def _create_stats(dir: Path, stats: dict | None = None):
if not stats: if not stats:
stats = stats_factory() stats = stats_factory()
fpath = dir / STATS_PATH write_stats(stats, dir)
fpath.parent.mkdir(parents=True, exist_ok=True)
with open(fpath, "w") as f:
json.dump(stats, f, indent=4, ensure_ascii=False)
return fpath
return _create_stats_json_file return _create_stats
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def episodes_stats_path(episodes_stats_factory): def create_episodes_stats(episodes_stats_factory):
def _create_episodes_stats_jsonl_file(dir: Path, episodes_stats: list[dict] | None = None) -> Path: def _create_episodes_stats(dir: Path, episodes_stats: Dataset | None = None):
if not episodes_stats: if not episodes_stats:
episodes_stats = episodes_stats_factory() episodes_stats = episodes_stats_factory()
fpath = dir / EPISODES_STATS_PATH write_episodes_stats(episodes_stats, dir)
fpath.parent.mkdir(parents=True, exist_ok=True)
with jsonlines.open(fpath, "w") as writer:
writer.write_all(episodes_stats.values())
return fpath
return _create_episodes_stats_jsonl_file return _create_episodes_stats
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def tasks_path(tasks_factory): def create_tasks(tasks_factory):
def _create_tasks_jsonl_file(dir: Path, tasks: list | None = None) -> Path: def _create_tasks(dir: Path, tasks: Dataset | None = None):
if not tasks: if not tasks:
tasks = tasks_factory() tasks = tasks_factory()
fpath = dir / TASKS_PATH write_tasks(tasks, dir)
fpath.parent.mkdir(parents=True, exist_ok=True)
with jsonlines.open(fpath, "w") as writer:
writer.write_all(tasks.values())
return fpath
return _create_tasks_jsonl_file return _create_tasks
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def episode_path(episodes_factory): def create_episodes(episodes_factory):
def _create_episodes_jsonl_file(dir: Path, episodes: list | None = None) -> Path: def _create_episodes(dir: Path, episodes: Dataset | None = None):
if not episodes: if not episodes:
episodes = episodes_factory() episodes = episodes_factory()
fpath = dir / EPISODES_PATH write_episodes(episodes, dir)
fpath.parent.mkdir(parents=True, exist_ok=True)
with jsonlines.open(fpath, "w") as writer:
writer.write_all(episodes.values())
return fpath
return _create_episodes_jsonl_file return _create_episodes
@pytest.fixture(scope="session")
def create_hf_dataset(hf_dataset_factory):
def _create_hf_dataset(dir: Path, hf_dataset: Dataset | None = None):
if not hf_dataset:
hf_dataset = hf_dataset_factory()
write_hf_dataset(hf_dataset, dir)
return _create_hf_dataset
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
@ -91,6 +83,7 @@ def single_episode_parquet_path(hf_dataset_factory, info_factory):
def _create_single_episode_parquet( def _create_single_episode_parquet(
dir: Path, ep_idx: int = 0, hf_dataset: datasets.Dataset | None = None, info: dict | None = None dir: Path, ep_idx: int = 0, hf_dataset: datasets.Dataset | None = None, info: dict | None = None
) -> Path: ) -> Path:
raise NotImplementedError()
if not info: if not info:
info = info_factory() info = info_factory()
if hf_dataset is None: if hf_dataset is None:
@ -114,6 +107,7 @@ def multi_episode_parquet_path(hf_dataset_factory, info_factory):
def _create_multi_episode_parquet( def _create_multi_episode_parquet(
dir: Path, hf_dataset: datasets.Dataset | None = None, info: dict | None = None dir: Path, hf_dataset: datasets.Dataset | None = None, info: dict | None = None
) -> Path: ) -> Path:
raise NotImplementedError()
if not info: if not info:
info = info_factory() info = info_factory()
if hf_dataset is None: if hf_dataset is None:

104
tests/fixtures/hub.py vendored
View File

@ -5,11 +5,12 @@ import pytest
from huggingface_hub.utils import filter_repo_objects from huggingface_hub.utils import filter_repo_objects
from lerobot.common.datasets.utils import ( from lerobot.common.datasets.utils import (
EPISODES_PATH, DEFAULT_DATA_PATH,
EPISODES_STATS_PATH, DEFAULT_EPISODES_PATH,
DEFAULT_EPISODES_STATS_PATH,
DEFAULT_TASKS_PATH,
INFO_PATH, INFO_PATH,
STATS_PATH, LEGACY_STATS_PATH,
TASKS_PATH,
) )
from tests.fixtures.constants import LEROBOT_TEST_DIR from tests.fixtures.constants import LEROBOT_TEST_DIR
@ -17,17 +18,17 @@ from tests.fixtures.constants import LEROBOT_TEST_DIR
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def mock_snapshot_download_factory( def mock_snapshot_download_factory(
info_factory, info_factory,
info_path, create_info,
stats_factory, stats_factory,
stats_path, create_stats,
episodes_stats_factory, episodes_stats_factory,
episodes_stats_path, create_episodes_stats,
tasks_factory, tasks_factory,
tasks_path, create_tasks,
episodes_factory, episodes_factory,
episode_path, create_episodes,
single_episode_parquet_path,
hf_dataset_factory, hf_dataset_factory,
create_hf_dataset,
): ):
""" """
This factory allows to patch snapshot_download such that when called, it will create expected files rather This factory allows to patch snapshot_download such that when called, it will create expected files rather
@ -37,9 +38,9 @@ def mock_snapshot_download_factory(
def _mock_snapshot_download_func( def _mock_snapshot_download_func(
info: dict | None = None, info: dict | None = None,
stats: dict | None = None, stats: dict | None = None,
episodes_stats: list[dict] | None = None, episodes_stats: datasets.Dataset | None = None,
tasks: list[dict] | None = None, tasks: datasets.Dataset | None = None,
episodes: list[dict] | None = None, episodes: datasets.Dataset | None = None,
hf_dataset: datasets.Dataset | None = None, hf_dataset: datasets.Dataset | None = None,
): ):
if not info: if not info:
@ -59,14 +60,6 @@ def mock_snapshot_download_factory(
if not hf_dataset: if not hf_dataset:
hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episodes, fps=info["fps"]) hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episodes, fps=info["fps"])
def _extract_episode_index_from_path(fpath: str) -> int:
path = Path(fpath)
if path.suffix == ".parquet" and path.stem.startswith("episode_"):
episode_index = int(path.stem[len("episode_") :]) # 'episode_000000' -> 0
return episode_index
else:
return None
def _mock_snapshot_download( def _mock_snapshot_download(
repo_id: str, repo_id: str,
local_dir: str | Path | None = None, local_dir: str | Path | None = None,
@ -79,40 +72,55 @@ def mock_snapshot_download_factory(
local_dir = LEROBOT_TEST_DIR local_dir = LEROBOT_TEST_DIR
# List all possible files # List all possible files
        all_files = []
        meta_files = [INFO_PATH, STATS_PATH, EPISODES_STATS_PATH, TASKS_PATH, EPISODES_PATH]
        all_files.extend(meta_files)

        data_files = []
        for episode_dict in episodes.values():
            ep_idx = episode_dict["episode_index"]
            ep_chunk = ep_idx // info["chunks_size"]
            data_path = info["data_path"].format(episode_chunk=ep_chunk, episode_index=ep_idx)
            data_files.append(data_path)
        all_files.extend(data_files)

        allowed_files = filter_repo_objects(
            all_files, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns
        )

        # Create allowed files
        for rel_path in allowed_files:
            if rel_path.startswith("data/"):
                episode_index = _extract_episode_index_from_path(rel_path)
                if episode_index is not None:
                    _ = single_episode_parquet_path(local_dir, episode_index, hf_dataset, info)
            if rel_path == INFO_PATH:
                _ = info_path(local_dir, info)
            elif rel_path == STATS_PATH:
                _ = stats_path(local_dir, stats)
            elif rel_path == EPISODES_STATS_PATH:
                _ = episodes_stats_path(local_dir, episodes_stats)
            elif rel_path == TASKS_PATH:
                _ = tasks_path(local_dir, tasks)
            elif rel_path == EPISODES_PATH:
                _ = episode_path(local_dir, episodes)
            else:
                pass

        all_files = [
            INFO_PATH,
            LEGACY_STATS_PATH,
            # TODO(rcadene)
            DEFAULT_TASKS_PATH.format(chunk_index=0, file_index=0),
            DEFAULT_EPISODES_STATS_PATH.format(chunk_index=0, file_index=0),
            DEFAULT_EPISODES_PATH.format(chunk_index=0, file_index=0),
            DEFAULT_DATA_PATH.format(chunk_index=0, file_index=0),
        ]

        allowed_files = filter_repo_objects(
            all_files, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns
        )

        has_info = False
        has_tasks = False
        has_episodes = False
        has_episodes_stats = False
        has_stats = False
        has_data = False
        for rel_path in allowed_files:
            if rel_path.startswith("meta/info.json"):
                has_info = True
            elif rel_path.startswith("meta/stats"):
                has_stats = True
            elif rel_path.startswith("meta/tasks"):
                has_tasks = True
            elif rel_path.startswith("meta/episodes_stats"):
                has_episodes_stats = True
            elif rel_path.startswith("meta/episodes"):
                has_episodes = True
            elif rel_path.startswith("data/"):
                has_data = True
            else:
                raise ValueError(f"{rel_path} not supported.")
if has_info:
create_info(local_dir, info)
if has_stats:
create_stats(local_dir, stats)
if has_tasks:
create_tasks(local_dir, tasks)
if has_episodes:
create_episodes(local_dir, episodes)
if has_episodes_stats:
create_episodes_stats(local_dir, episodes_stats)
if has_data:
create_hf_dataset(local_dir, hf_dataset)
return str(local_dir) return str(local_dir)
return _mock_snapshot_download return _mock_snapshot_download