Most unit tests are passing

This commit is contained in:
Remi Cadene 2025-04-11 14:04:22 +02:00
parent c1b28f0b58
commit 34c5d4ce07
6 changed files with 391 additions and 322 deletions

@ -16,17 +16,18 @@
import contextlib
import logging
import shutil
from pathlib import Path
import tempfile
from pathlib import Path
from typing import Callable
import datasets
import numpy as np
import packaging.version
import PIL.Image
import pandas as pd
import PIL.Image
import torch
import torch.utils
from datasets import concatenate_datasets, load_dataset, Dataset
from datasets import Dataset
from huggingface_hub import HfApi, snapshot_download
from huggingface_hub.constants import REPOCARD_NAME
from huggingface_hub.errors import RevisionNotFoundError
@ -35,52 +36,36 @@ from lerobot.common.constants import HF_LEROBOT_HOME
from lerobot.common.datasets.compute_stats import aggregate_stats, compute_episode_stats
from lerobot.common.datasets.image_writer import AsyncImageWriter, write_image
from lerobot.common.datasets.utils import (
DEFAULT_DATA_PATH,
DEFAULT_EPISODES_PATH,
DEFAULT_EPISODES_STATS_PATH,
DEFAULT_FEATURES,
DEFAULT_IMAGE_PATH,
EPISODES_DIR,
EPISODES_STATS_DIR,
INFO_PATH,
LEGACY_TASKS_PATH,
append_jsonlines,
backward_compatible_episodes_stats,
check_delta_timestamps,
check_timestamps_sync,
check_version_compatibility,
concat_video_files,
create_empty_dataset_info,
create_lerobot_dataset_card,
embed_images,
flatten_dict,
get_chunk_file_indices,
get_delta_indices,
get_episode_data_index,
get_features_from_robot,
get_hf_dataset_size_in_mb,
get_hf_features_from_features,
get_latest_parquet_path,
get_latest_video_path,
get_parquet_num_frames,
get_pd_dataframe_size_in_mb,
get_safe_version,
get_video_duration_in_s,
get_video_size_in_mb,
hf_transform_to_torch,
is_valid_version,
legacy_load_episodes,
legacy_load_episodes_stats,
load_episodes,
load_episodes_stats,
load_info,
load_nested_dataset,
load_stats,
legacy_load_tasks,
load_tasks,
update_chunk_file_indices,
validate_episode_buffer,
validate_frame,
write_episode,
legacy_write_episode_stats,
write_info,
write_json,
write_tasks,
@ -118,15 +103,17 @@ class LeRobotDatasetMetadata:
self.revision = get_safe_version(self.repo_id, self.revision)
(self.root / "meta").mkdir(exist_ok=True, parents=True)
# TODO(rcadene): instead of downloading all episodes metadata files,
# download only the ones associated with the requested episodes. This would
# require adding `episodes: list[int]` as an argument.
self.pull_from_repo(allow_patterns="meta/")
self.load_metadata()
def load_metadata(self):
self.info = load_info(self.root)
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
self.tasks, self.task_to_task_index = load_tasks(self.root)
self.tasks = load_tasks(self.root)
self.episodes = load_episodes(self.root)
self.episodes_stats = load_episodes_stats(self.root)
# TODO(rcadene): https://huggingface.slack.com/archives/C02V51Q3800/p1743517952388249?thread_ts=1742896075.499119&cid=C02V51Q3800
# self.stats = aggregate_stats(list(self.episodes_stats.values()))
@ -150,8 +137,8 @@ class LeRobotDatasetMetadata:
return packaging.version.parse(self.info["codebase_version"])
def get_data_file_path(self, ep_index: int) -> Path:
chunk_idx = self.episodes[f"data/chunk_index"][ep_index]
file_idx = self.episodes[f"data/file_index"][ep_index]
chunk_idx = self.episodes["data/chunk_index"][ep_index]
file_idx = self.episodes["data/file_index"][ep_index]
fpath = self.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
return Path(fpath)
@ -161,9 +148,6 @@ class LeRobotDatasetMetadata:
fpath = self.video_path.format(video_key=vid_key, chunk_index=chunk_idx, file_index=file_idx)
return Path(fpath)
# def get_episode_chunk(self, ep_index: int) -> int:
# return ep_index // self.chunks_size
@property
def data_path(self) -> str:
"""Formattable string for the parquet files."""
@ -244,71 +228,84 @@ class LeRobotDatasetMetadata:
Given a task in natural language, returns its task_index if the task already exists in the dataset,
otherwise returns None.
"""
return self.tasks.index[task] if task in self.tasks.index else None
def has_task(self, task: str) -> bool:
return task in self.task_to_task_index
if task in self.tasks.index:
return int(self.tasks.loc[task].task_index)
else:
return None
def save_episode_tasks(self, tasks: list[str]):
new_tasks = [task for task in tasks if not self.has_task(task)]
if len(set(tasks)) != len(tasks):
raise ValueError(f"Tasks are not unique: {tasks}")
for task in new_tasks:
task_index = len(self.tasks)
self.tasks.loc[task] = task_index
if self.tasks is None:
new_tasks = tasks
task_indices = range(len(tasks))
self.tasks = pd.DataFrame({"task_index": task_indices}, index=tasks)
else:
new_tasks = [task for task in tasks if task not in self.tasks.index]
new_task_indices = range(len(self.tasks), len(self.tasks) + len(new_tasks))
for task_idx, task in zip(new_task_indices, new_tasks, strict=False):
self.tasks.loc[task] = task_idx
if len(new_tasks) > 0:
# Update on disk
write_tasks(self.tasks, self.root)
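The extension step above can be exercised on a plain DataFrame outside the class; a minimal sketch with illustrative task strings (note that new indices must continue after the existing ones):

import pandas as pd

tasks = pd.DataFrame({"task_index": [0, 1]}, index=["Pick the cube", "Place it in the bin"])

candidates = ["Stack the blocks", "Pick the cube"]  # one new task, one already known
new_tasks = [t for t in candidates if t not in tasks.index]
for task_idx, task in zip(range(len(tasks), len(tasks) + len(new_tasks)), new_tasks):
    tasks.loc[task] = task_idx  # append a new row indexed by the task string

assert int(tasks.loc["Stack the blocks"].task_index) == 2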
def _save_episode(self, episode_dict: dict):
def _save_episode_metadata(self, episode_dict: dict) -> None:
"""Save episode metadata to a parquet file and update the Hugging Face dataset of episodes metadata.
This function processes episodes metadata from a dictionary, converts it into a Hugging Face dataset,
and saves it as a parquet file. It handles both the creation of new parquet files and the
updating of existing ones based on size constraints. After saving the metadata, it reloads
the Hugging Face dataset to ensure it is up-to-date.
Note: we need to update both the parquet files and the HF dataset:
- `pandas` loads the parquet file fully into RAM
- `datasets` relies on memory mapping from pyarrow (no RAM usage). It either converts the parquet files to a pyarrow cache on disk,
or loads them directly from that cache.
"""
# Convert buffer into HF Dataset
episode_dict = {key: [value] for key, value in episode_dict.items()}
ep_dataset = Dataset.from_dict(episode_dict)
ep_size_in_mb = get_hf_dataset_size_in_mb(ep_dataset)
df = pd.DataFrame(ep_dataset)
# Access latest parquet file information
latest_path = get_latest_parquet_path(self.root / EPISODES_DIR)
latest_size_in_mb = get_parquet_file_size_in_mb(latest_path)
chunk_idx, file_idx = get_chunk_file_indices(latest_path)
if latest_size_in_mb + ep_size_in_mb >= self.meta.files_size_in_mb:
# Create new parquet file
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size)
new_path = self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
new_path.parent.mkdir(parents=True, exist_ok=True)
ep_df.to_parquet(new_path, index=False)
if self.episodes is None:
# Initialize indices and frame count for a new dataset made of the first episode data
chunk_idx, file_idx = 0, 0
df["meta/episodes/chunk_index"] = [chunk_idx]
df["meta/episodes/file_index"] = [file_idx]
else:
# Update latest parquet file with new row
ep_df = pd.DataFrame(ep_dataset)
latest_df = pd.concat([latest_df, ep_df], ignore_index=True) # RAM
latest_df.to_parquet(latest_path, index=False)
# Retrieve information from the latest parquet file
latest_ep = self.episodes.with_format(columns=["chunk_index", "file_index"])[-1]
chunk_idx, file_idx = latest_ep["chunk_index"], latest_ep["file_index"]
latest_path = self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
latest_size_in_mb = get_parquet_file_size_in_mb(latest_path)
# Determine if a new parquet file is needed
if latest_size_in_mb + ep_size_in_mb >= self.files_size_in_mb:
# Size limit is reached, prepare new parquet file
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.chunks_size)
df["meta/episodes/chunk_index"] = [chunk_idx]
df["meta/episodes/file_index"] = [file_idx]
else:
# Update the existing parquet file with new row
df["meta/episodes/chunk_index"] = [chunk_idx]
df["meta/episodes/file_index"] = [file_idx]
latest_df = pd.read_parquet(latest_path)
df = pd.concat([latest_df, df], ignore_index=True)
# Write the resulting dataframe from RAM to disk
path = self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
path.parent.mkdir(parents=True, exist_ok=True)
df.to_parquet(path, index=False)
# Update the Hugging Face dataset by reloading it.
# This process should be fast because only the latest Parquet file has been modified.
# Therefore, only this file needs to be converted to PyArrow; the rest is loaded from the PyArrow memory-mapped cache.
self.episodes = load_episodes(self.root)
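The append-or-rotate logic above is the core of the new storage layout. Here is a minimal, self-contained sketch of the same decision on a plain DataFrame; the size measure is simplified (on-disk bytes instead of the uncompressed sizes the library computes), and the rollover rule is only a stand-in for `update_chunk_file_indices`:

from pathlib import Path

import pandas as pd


def append_or_rotate(row: pd.DataFrame, root: Path, chunk_idx: int, file_idx: int,
                     size_limit_mb: float, chunks_size: int) -> tuple[int, int]:
    """Append `row` to the latest parquet file, or rotate to a new one when the size limit is hit."""
    template = "meta/episodes/chunk-{chunk_index:03d}/file-{file_index:03d}.parquet"
    latest = root / template.format(chunk_index=chunk_idx, file_index=file_idx)
    if latest.exists():
        row_mb = row.memory_usage(deep=True).sum() / (1024**2)
        if latest.stat().st_size / (1024**2) + row_mb >= size_limit_mb:
            # Rotate: bump the file index, roll over to the next chunk when it is full.
            file_idx += 1
            if file_idx == chunks_size:
                chunk_idx, file_idx = chunk_idx + 1, 0
            latest = root / template.format(chunk_index=chunk_idx, file_index=file_idx)
        else:
            # Append in RAM, then rewrite the same file.
            row = pd.concat([pd.read_parquet(latest), row], ignore_index=True)
    latest.parent.mkdir(parents=True, exist_ok=True)
    row.to_parquet(latest, index=False)
    return chunk_idx, file_idx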
def _save_episode_stats(self, episodes_stats: dict):
ep_dataset = Dataset.from_dict(episodes_stats)
ep_size_in_mb = get_hf_dataset_size_in_mb(ep_dataset)
# Access latest parquet file information
latest_path = get_latest_parquet_path(self.root / EPISODES_STATS_DIR)
latest_size_in_mb = get_parquet_file_size_in_mb(latest_path)
chunk_idx, file_idx = get_chunk_file_indices(latest_path)
if latest_size_in_mb + ep_size_in_mb >= self.meta.files_size_in_mb:
# Create new parquet file
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size)
new_path = self.root / DEFAULT_EPISODES_STATS_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
new_path.parent.mkdir(parents=True, exist_ok=True)
ep_df.to_parquet(new_path, index=False)
else:
# Update latest parquet file with new row
ep_df = pd.DataFrame(ep_dataset)
latest_df = pd.concat([latest_df, ep_df], ignore_index=True) # RAM
latest_df.to_parquet(latest_path, index=False)
self.episodes_stats = load_episodes_stats(self.root)
def save_episode(
self,
episode_index: int,
@ -331,8 +328,8 @@ class LeRobotDatasetMetadata:
"length": episode_length,
}
episode_dict.update(episode_metadata)
self._save_episode(episode_dict)
self._save_episode_stats(episode_stats)
episode_dict.update(flatten_dict({"stats": episode_stats}))
self._save_episode_metadata(episode_dict)
self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats else episode_stats
# TODO: write stats
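`flatten_dict` (defined in utils.py further down in this diff, with `/` as separator) is what turns the nested per-episode stats into the flat `stats/...` columns stored alongside the episode metadata; a small illustration, assuming the lerobot package from this commit is importable:

import numpy as np

from lerobot.common.datasets.utils import flatten_dict

episode_stats = {"action": {"max": np.array([1.0, 2.0]), "min": np.array([0.0, -1.0])}}
flat = flatten_dict({"stats": episode_stats})
# -> {"stats/action/max": array([1., 2.]), "stats/action/min": array([ 0., -1.])}
assert "stats/action/max" in flat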
@ -401,7 +398,6 @@ class LeRobotDatasetMetadata:
features = {**features, **DEFAULT_FEATURES}
obj.tasks = None
obj.episodes_stats = None
obj.episodes = None
# TODO(rcadene) stats
obj.stats = {}
@ -557,7 +553,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.hf_dataset = self.load_hf_dataset()
except (AssertionError, FileNotFoundError, NotADirectoryError):
self.revision = get_safe_version(self.repo_id, self.revision)
self.download_episodes(download_videos)
self.download(download_videos)
self.hf_dataset = self.load_hf_dataset()
# Setup delta_indices
@ -635,7 +631,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
ignore_patterns=ignore_patterns,
)
def download_episodes(self, download_videos: bool = True) -> None:
def download(self, download_videos: bool = True) -> None:
"""Downloads the dataset from the given 'repo_id' at the provided version. If 'episodes' is given, this
will only download those episodes (selected by their episode_index). If 'episodes' is None, the whole
dataset will be downloaded. Thanks to the behavior of snapshot_download, if the files are already present
@ -795,10 +791,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
# Add task as a string
task_idx = item["task_index"].item()
if self.meta.tasks["task_index"][task_idx] != task_idx:
raise ValueError("Sanity check on task index failed.")
item["task"] = self.meta.tasks["task"][task_idx]
item["task"] = self.meta.tasks.iloc[task_idx].name
return item
def __repr__(self):
@ -926,11 +919,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
self._wait_image_writer()
ep_stats = compute_episode_stats(episode_buffer, self.features)
ep_metadata = self._save_episode_data(episode_buffer, episode_index)
ep_metadata = self._save_episode_data(episode_buffer)
for video_key in self.meta.video_keys:
ep_metadata.update(self._save_episode_video(video_key, episode_index))
# `meta.save_episode` neeed to be executed after encoding the videos
# `meta.save_episode` need to be executed after encoding the videos
self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats, ep_metadata)
# TODO(rcadene): remove? there is only one episode in the episode buffer, no need for ep_data_index
@ -954,31 +947,57 @@ class LeRobotDataset(torch.utils.data.Dataset):
# Reset episode buffer
self.episode_buffer = self.create_episode_buffer()
def _save_episode_data(self, episode_buffer: dict) -> None:
def _save_episode_data(self, episode_buffer: dict) -> dict:
"""Save episode data to a parquet file and update the Hugging Face dataset of frames data.
This function processes episodes data from a buffer, converts it into a Hugging Face dataset,
and saves it as a parquet file. It handles both the creation of new parquet files and the
updating of existing ones based on size constraints. After saving the data, it reloads
the Hugging Face dataset to ensure it is up-to-date.
Note: we need to update both the parquet files and the HF dataset:
- `pandas` loads the parquet file fully into RAM
- `datasets` relies on memory mapping from pyarrow (no RAM usage). It either converts the parquet files to a pyarrow cache on disk,
or loads them directly from that cache.
"""
# Convert buffer into HF Dataset
ep_dict = {key: episode_buffer[key] for key in self.hf_features}
ep_dataset = datasets.Dataset.from_dict(ep_dict, features=self.hf_features, split="train")
ep_dataset = embed_images(ep_dataset)
ep_size_in_mb = get_hf_dataset_size_in_mb(ep_dataset)
ep_num_frames = len(ep_dataset)
df = pd.DataFrame(ep_dataset)
# Access latest parquet file information
latest_path = get_latest_parquet_path(self.root / "data")
if self.meta.episodes is None:
# Initialize indices and frame count for a new dataset made of the first episode data
chunk_idx, file_idx = 0, 0
latest_num_frames = 0
else:
# Retrieve information from the latest parquet file
latest_ep = self.meta.episodes.with_format(columns=["data/chunk_index", "data/file_index"])[-1]
chunk_idx, file_idx = latest_ep["data/chunk_index"], latest_ep["data/file_index"]
latest_path = self.root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
latest_size_in_mb = get_parquet_file_size_in_mb(latest_path)
latest_num_frames = get_parquet_num_frames(latest_path)
chunk_idx, file_idx = get_chunk_file_indices(latest_path)
# Determine if a new parquet file is needed
if latest_size_in_mb + ep_size_in_mb >= self.meta.files_size_in_mb:
# Create new parquet file
# Size limit is reached, prepare new parquet file
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size)
new_path = self.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
new_path.parent.mkdir(parents=True, exist_ok=True)
ep_df.to_parquet(new_path, index=False)
latest_num_frames = 0
else:
# Update latest parquet file with new rows
ep_df = pd.DataFrame(ep_dataset)
latest_df = pd.concat([latest_df, ep_df], ignore_index=True) # RAM
latest_df.to_parquet(latest_path, index=False)
# Update the existing parquet file with new rows
latest_df = pd.read_parquet(latest_path)
df = pd.concat([latest_df, df], ignore_index=True)
# Write the resulting dataframe from RAM to disk
path = self.root / self.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
path.parent.mkdir(parents=True, exist_ok=True)
if len(self.meta.image_keys) > 0:
datasets.Dataset.from_dict(df.to_dict(orient="list")).to_parquet(path)
else:
df.to_parquet(path)
# Update the Hugging Face dataset by reloading it.
# This process should be fast because only the latest Parquet file has been modified.
@ -1008,7 +1027,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
if latest_size_in_mb + ep_size_in_mb >= self.meta.files_size_in_mb:
# Move temporary episode video to a new video file in the dataset
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size)
new_path = self.meta.video_path.format(video_key=video_key, chunk_index=chunk_idx, file_index=file_idx)
new_path = self.meta.video_path.format(
video_key=video_key, chunk_index=chunk_idx, file_index=file_idx
)
new_path.parent.mkdir(parents=True, exist_ok=True)
ep_path.replace(new_path)
else:

@ -17,12 +17,12 @@ import contextlib
import importlib.resources
import json
import logging
import subprocess
import tempfile
from collections.abc import Iterator
from itertools import accumulate
from pathlib import Path
from pprint import pformat
import subprocess
import tempfile
from types import SimpleNamespace
from typing import Any, Tuple
@ -31,24 +31,24 @@ import jsonlines
import numpy as np
import packaging.version
import pandas
import pandas as pd
import pyarrow.parquet as pq
import torch
from datasets import Dataset, concatenate_datasets
from datasets.table import embed_table_storage
from huggingface_hub import DatasetCard, DatasetCardData, HfApi
from huggingface_hub.errors import RevisionNotFoundError
from PIL import Image as PILImage
from torchvision import transforms
from datasets import Dataset, concatenate_datasets
from lerobot.common.datasets.backward_compatibility import (
V21_MESSAGE,
V30_MESSAGE,
BackwardCompatibilityError,
ForwardCompatibilityError,
)
from lerobot.common.robot_devices.robots.utils import Robot
from lerobot.common.utils.utils import is_valid_numpy_dtype_string
from lerobot.configs.types import DictLike, FeatureType, PolicyFeature
import pyarrow.parquet as pq
DEFAULT_CHUNK_SIZE = 1000 # Max number of files per chunk
DEFAULT_FILE_SIZE_IN_MB = 500.0 # Max size per file
@ -65,16 +65,13 @@ LEGACY_DEFAULT_PARQUET_PATH = "data/chunk-{episode_chunk:03d}/episode_{episode_i
DEFAULT_IMAGE_PATH = "images/{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"
EPISODES_DIR = "meta/episodes"
EPISODES_STATS_DIR = "meta/episodes_stats"
TASKS_DIR = "meta/tasks"
DATA_DIR = "data"
VIDEO_DIR = "videos"
INFO_PATH = "meta/info.json"
CHUNK_FILE_PATTERN = "chunk-{chunk_index:03d}/file-{file_index:03d}"
DEFAULT_EPISODES_PATH = EPISODES_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
DEFAULT_EPISODES_STATS_PATH = EPISODES_STATS_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
DEFAULT_TASKS_PATH = TASKS_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
DEFAULT_TASKS_PATH = "meta/tasks.parquet"
DEFAULT_DATA_PATH = DATA_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
DEFAULT_VIDEO_PATH = VIDEO_DIR + "/{video_key}/" + CHUNK_FILE_PATTERN + ".mp4"
@ -98,11 +95,13 @@ DEFAULT_FEATURES = {
def get_hf_dataset_size_in_mb(hf_ds: Dataset) -> int:
return hf_ds.data.nbytes / (1024 ** 2)
return hf_ds.data.nbytes / (1024**2)
def get_pd_dataframe_size_in_mb(df: pandas.DataFrame) -> int:
memory_usage_bytes = df.memory_usage(deep=True).sum()
return memory_usage_bytes / (1024 ** 2)
return memory_usage_bytes / (1024**2)
def get_chunk_file_indices(path: Path) -> Tuple[int, int]:
if not path.stem.startswith("file-") or not path.parent.name.startswith("chunk-"):
@ -112,6 +111,7 @@ def get_chunk_file_indices(path: Path) -> Tuple[int, int]:
file_index = int(path.stem.replace("file-", ""))
return chunk_index, file_index
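Since the chunk/file naming is fully determined by CHUNK_FILE_PATTERN, formatting and parsing round-trip cleanly; a quick illustration (the index values are arbitrary):

from pathlib import Path

from lerobot.common.datasets.utils import DEFAULT_DATA_PATH, get_chunk_file_indices

rel_path = DEFAULT_DATA_PATH.format(chunk_index=0, file_index=12)
# -> "data/chunk-000/file-012.parquet"
assert get_chunk_file_indices(Path(rel_path)) == (0, 12)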
def update_chunk_file_indices(chunk_idx: int, file_idx: int, chunks_size: int):
if file_idx == chunks_size - 1:
file_idx = 0
@ -122,63 +122,86 @@ def update_chunk_file_indices(chunk_idx: int, file_idx: int, chunks_size: int):
def load_nested_dataset(pq_dir: Path) -> Dataset:
""" Find parquet files in provided directory {pq_dir}/chunk-xxx/file-xxx.parquet
"""Find parquet files in provided directory {pq_dir}/chunk-xxx/file-xxx.parquet
Convert parquet files to pyarrow memory mapped in a cache folder for efficient RAM usage
Concatenate all pyarrow references to return HF Dataset format
"""
paths = sorted(pq_dir.glob("*/*.parquet"))
if len(paths) == 0:
raise FileNotFoundError(f"Provided directory does not contain any parquet file: {pq_dir}")
# TODO(rcadene): set num_proc to accelerate conversion to pyarrow
return concatenate_datasets([Dataset.from_parquet(str(path)) for path in sorted(pq_dir.glob("*/*.parquet"))])
return concatenate_datasets(
[Dataset.from_parquet(str(path)) for path in sorted(pq_dir.glob("*/*.parquet"))]
)
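For instance, loading all episode-metadata parquet files back as a single memory-mapped HF dataset (a hedged sketch; the local root below is hypothetical):

from pathlib import Path

from lerobot.common.datasets.utils import load_nested_dataset

root = Path("/path/to/my_dataset")  # hypothetical local dataset root
episodes = load_nested_dataset(root / "meta" / "episodes")  # expects chunk-xxx/file-xxx.parquet inside
print(len(episodes), episodes.column_names)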
def get_latest_parquet_path(pq_dir: Path) -> Path:
return sorted(pq_dir.glob("*/*.parquet"))[-1]
def get_latest_video_path(pq_dir: Path, video_key: str) -> Path:
return sorted(pq_dir.glob(f"{video_key}/*/*.mp4"))[-1]
def get_parquet_num_frames(parquet_path):
metadata = pq.read_metadata(parquet_path)
return metadata.num_rows
def get_video_size_in_mb(mp4_path: Path):
file_size_bytes = mp4_path.stat().st_size
file_size_mb = file_size_bytes / (1024 ** 2)
file_size_mb = file_size_bytes / (1024**2)
return file_size_mb
def concat_video_files(paths_to_cat, new_root, video_key, chunk_idx, file_idx):
# Create a text file with the list of files to concatenate
with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f:
with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".txt") as f:
temp_file_path = f.name
for ep_path in paths_to_cat:
f.write(f"file '{str(ep_path)}'\n")
output_path = new_root / DEFAULT_VIDEO_PATH.format(video_key=video_key, chunk_index=chunk_idx, file_index=file_idx)
output_path = new_root / DEFAULT_VIDEO_PATH.format(
video_key=video_key, chunk_index=chunk_idx, file_index=file_idx
)
output_path.parent.mkdir(parents=True, exist_ok=True)
command = [
'ffmpeg',
'-y',
'-f', 'concat',
'-safe', '0',
'-i', str(temp_file_path),
'-c', 'copy',
str(output_path)
"ffmpeg",
"-y",
"-f",
"concat",
"-safe",
"0",
"-i",
str(temp_file_path),
"-c",
"copy",
str(output_path),
]
subprocess.run(command, check=True)
Path(temp_file_path).unlink()
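A hedged usage sketch of the helper above (the clip paths and video key are hypothetical, and ffmpeg must be available on PATH):

from pathlib import Path

from lerobot.common.datasets.utils import concat_video_files

clips = [Path("/tmp/episode_000000.mp4"), Path("/tmp/episode_000001.mp4")]  # hypothetical inputs
concat_video_files(clips, new_root=Path("/tmp/new_dataset"), video_key="observation.images.top",
                   chunk_idx=0, file_idx=0)
# -> writes /tmp/new_dataset/videos/observation.images.top/chunk-000/file-000.mp4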
def get_video_duration_in_s(mp4_file: Path):
command = [
"ffprobe",
"-v",
"error",
"-show_entries",
"format=duration",
"-of",
"default=noprint_wrappers=1:nokey=1",
mp4_file,
]
result = subprocess.run(
[
'ffprobe',
'-v', 'error',
'-show_entries', 'format=duration',
'-of', 'default=noprint_wrappers=1:nokey=1',
mp4_file
],
command,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT
stderr=subprocess.STDOUT,
)
return float(result.stdout)
def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict:
"""Flatten a nested dictionary structure by collapsing nested keys into one key with a separator.
@ -314,10 +337,11 @@ def write_hf_dataset(hf_dataset: Dataset, local_dir: Path):
def write_tasks(tasks: pandas.DataFrame, local_dir: Path):
path = local_dir / DEFAULT_TASKS_PATH.format(chunk_index=0, file_index=0)
path = local_dir / DEFAULT_TASKS_PATH
path.parent.mkdir(parents=True, exist_ok=True)
tasks.to_parquet(path)
def legacy_write_task(task_index: int, task: dict, local_dir: Path):
task_dict = {
"task_index": task_index,
@ -332,31 +356,42 @@ def legacy_load_tasks(local_dir: Path) -> tuple[dict, dict]:
task_to_task_index = {task: task_index for task_index, task in tasks.items()}
return tasks, task_to_task_index
def load_tasks(local_dir: Path):
tasks = load_nested_dataset(local_dir / TASKS_DIR)
# TODO(rcadene): optimize this
task_to_task_index = {d["task"]: d["task_index"] for d in tasks}
return tasks, task_to_task_index
tasks = pd.read_parquet(local_dir / DEFAULT_TASKS_PATH)
return tasks
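With tasks now stored as a single meta/tasks.parquet indexed by the task string, lookups in both directions are plain pandas; a small sketch (the DataFrame below is illustrative):

import pandas as pd

tasks = pd.DataFrame({"task_index": [0, 1]}, index=["Pick the cube", "Place it in the bin"])

# task string -> index, as used by LeRobotDatasetMetadata.get_task_index
task_idx = int(tasks.loc["Place it in the bin"].task_index)  # 1
# index -> task string, as used in LeRobotDataset.__getitem__
# (valid as long as rows are stored in task_index order)
task = tasks.iloc[task_idx].name  # "Place it in the bin"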
def write_episode(episode: dict, local_dir: Path):
append_jsonlines(episode, local_dir / LEGACY_EPISODES_PATH)
def write_episodes(episodes: Dataset, local_dir: Path):
if get_hf_dataset_size_in_mb(episodes) > DEFAULT_FILE_SIZE_IN_MB:
raise NotImplementedError("Contact a maintainer.")
def add_chunk_file_indices(row):
row["chunk_index"] = 0
row["file_index"] = 0
return row
episodes = episodes.map(add_chunk_file_indices)
fpath = local_dir / DEFAULT_EPISODES_PATH.format(chunk_index=0, file_index=0)
fpath.parent.mkdir(parents=True, exist_ok=True)
episodes.to_parquet(fpath)
def legacy_load_episodes(local_dir: Path) -> dict:
episodes = load_jsonlines(local_dir / LEGACY_EPISODES_PATH)
return {item["episode_index"]: item for item in sorted(episodes, key=lambda x: x["episode_index"])}
def load_episodes(local_dir: Path):
hf_dataset = load_nested_dataset(local_dir / EPISODES_DIR)
return hf_dataset
def legacy_write_episode_stats(episode_index: int, episode_stats: dict, local_dir: Path):
# We wrap episode_stats in a dictionary since `episode_stats["episode_index"]`
# is a dictionary of stats and not an integer.
@ -364,13 +399,13 @@ def legacy_write_episode_stats(episode_index: int, episode_stats: dict, local_di
append_jsonlines(episode_stats, local_dir / LEGACY_EPISODES_STATS_PATH)
def write_episodes_stats(episodes_stats: Dataset, local_dir: Path):
if get_hf_dataset_size_in_mb(episodes_stats) > DEFAULT_FILE_SIZE_IN_MB:
raise NotImplementedError("Contact a maintainer.")
# def write_episodes_stats(episodes_stats: Dataset, local_dir: Path):
# if get_hf_dataset_size_in_mb(episodes_stats) > DEFAULT_FILE_SIZE_IN_MB:
# raise NotImplementedError("Contact a maintainer.")
fpath = local_dir / DEFAULT_EPISODES_STATS_PATH.format(chunk_index=0, file_index=0)
fpath.parent.mkdir(parents=True, exist_ok=True)
episodes_stats.to_parquet(fpath)
# fpath = local_dir / DEFAULT_EPISODES_STATS_PATH.format(chunk_index=0, file_index=0)
# fpath.parent.mkdir(parents=True, exist_ok=True)
# episodes_stats.to_parquet(fpath)
def legacy_load_episodes_stats(local_dir: Path) -> dict:
@ -380,9 +415,11 @@ def legacy_load_episodes_stats(local_dir: Path) -> dict:
for item in sorted(episodes_stats, key=lambda x: x["episode_index"])
}
def load_episodes_stats(local_dir: Path):
hf_dataset = load_nested_dataset(local_dir / EPISODES_STATS_DIR)
return hf_dataset
# def load_episodes_stats(local_dir: Path):
# hf_dataset = load_nested_dataset(local_dir / EPISODES_STATS_DIR)
# return hf_dataset
def backward_compatible_episodes_stats(
stats: dict[str, dict[str, np.ndarray]], episodes: list[int]

@ -18,13 +18,13 @@ python lerobot/common/datasets/v30/convert_dataset_v21_to_v30.py \
"""
import argparse
import logging
from pathlib import Path
import sys
import pandas as pd
import pyarrow.parquet as pq
import tqdm
from datasets import Dataset
from huggingface_hub import snapshot_download
import tqdm
from lerobot.common.constants import HF_LEROBOT_HOME
from lerobot.common.datasets.utils import (
@ -39,18 +39,13 @@ from lerobot.common.datasets.utils import (
get_video_size_in_mb,
legacy_load_episodes,
legacy_load_episodes_stats,
load_info,
legacy_load_tasks,
load_info,
update_chunk_file_indices,
write_episodes,
write_episodes_stats,
write_info,
write_tasks,
)
import subprocess
import tempfile
import pandas as pd
import pyarrow.parquet as pq
V21 = "v2.1"
@ -97,32 +92,31 @@ meta/info.json
-------------------------
"""
def get_parquet_file_size_in_mb(parquet_path):
metadata = pq.read_metadata(parquet_path)
uncompressed_size = metadata.num_rows * metadata.row_group(0).total_byte_size
return uncompressed_size / (1024 ** 2)
return uncompressed_size / (1024**2)
# def generate_flat_ep_stats(episodes_stats):
# for ep_idx, ep_stats in episodes_stats.items():
# flat_ep_stats = flatten_dict(ep_stats)
# flat_ep_stats["episode_index"] = ep_idx
# yield flat_ep_stats
def generate_flat_ep_stats(episodes_stats):
for ep_idx, ep_stats in episodes_stats.items():
flat_ep_stats = flatten_dict(ep_stats)
flat_ep_stats["episode_index"] = ep_idx
yield flat_ep_stats
# def convert_episodes_stats(root, new_root):
# episodes_stats = legacy_load_episodes_stats(root)
# ds_episodes_stats = Dataset.from_generator(lambda: generate_flat_ep_stats(episodes_stats))
# write_episodes_stats(ds_episodes_stats, new_root)
def convert_episodes_stats(root, new_root):
episodes_stats = legacy_load_episodes_stats(root)
ds_episodes_stats = Dataset.from_generator(lambda: generate_flat_ep_stats(episodes_stats))
write_episodes_stats(ds_episodes_stats, new_root)
def generate_task_dict(tasks):
for task_index, task in tasks.items():
yield {"task_index": task_index, "task": task}
def convert_tasks(root, new_root):
tasks, _ = legacy_load_tasks(root)
ds_tasks = Dataset.from_generator(lambda: generate_task_dict(tasks))
write_tasks(ds_tasks, new_root)
task_indices = tasks.keys()
task_strings = tasks.values()
df_tasks = pd.DataFrame({"task_index": task_indices}, index=task_strings)
write_tasks(df_tasks, new_root)
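The legacy format keeps tasks as a {task_index: task} mapping; the conversion above just reshapes that mapping into the new index-by-task DataFrame. A concrete, illustrative example:

import pandas as pd

legacy_tasks = {0: "Pick the cube", 1: "Place it in the bin"}  # as returned by legacy_load_tasks
df_tasks = pd.DataFrame({"task_index": list(legacy_tasks.keys())}, index=list(legacy_tasks.values()))
# df_tasks is indexed by the task string with a single `task_index` column,
# ready to be written to meta/tasks.parquet by write_tasks().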
def concat_data_files(paths_to_cat, new_root, chunk_idx, file_idx):
@ -138,17 +132,15 @@ def concat_data_files(paths_to_cat, new_root, chunk_idx, file_idx):
def convert_data(root, new_root):
data_dir = root / "data"
ep_paths = sorted(data_dir.glob("*/*.parquet"))
ep_paths = [path for path in data_dir.glob("*/*.parquet")]
ep_paths = sorted(ep_paths)
episodes_metadata = []
ep_idx = 0
chunk_idx = 0
file_idx = 0
size_in_mb = 0
num_frames = 0
paths_to_cat = []
episodes_metadata = []
for ep_path in ep_paths:
ep_size_in_mb = get_parquet_file_size_in_mb(ep_path)
ep_num_frames = get_parquet_num_frames(ep_path)
@ -184,7 +176,6 @@ def convert_data(root, new_root):
return episodes_metadata
def get_video_keys(root):
info = load_info(root)
features = info["features"]
@ -230,16 +221,15 @@ def convert_videos(root: Path, new_root: Path):
def convert_videos_of_camera(root: Path, new_root: Path, video_key):
# Access old paths to mp4
videos_dir = root / "videos"
ep_paths = [path for path in videos_dir.glob(f"*/{video_key}/*.mp4")]
ep_paths = sorted(ep_paths)
ep_paths = sorted(videos_dir.glob(f"*/{video_key}/*.mp4"))
episodes_metadata = []
ep_idx = 0
chunk_idx = 0
file_idx = 0
size_in_mb = 0
duration_in_s = 0.0
paths_to_cat = []
episodes_metadata = []
for ep_path in tqdm.tqdm(ep_paths, desc=f"convert videos of {video_key}"):
ep_size_in_mb = get_video_size_in_mb(ep_path)
ep_duration_in_s = get_video_duration_in_s(ep_path)
@ -274,30 +264,53 @@ def convert_videos_of_camera(root: Path, new_root: Path, video_key):
return episodes_metadata
def generate_episode_dict(episodes, episodes_data, episodes_videos):
for ep, ep_data, ep_video in zip(episodes.values(), episodes_data, episodes_videos):
ep_idx = ep["episode_index"]
ep_idx_data = ep_data["episode_index"]
def generate_episode_metadata_dict(
episodes_legacy_metadata, episodes_metadata, episodes_videos, episodes_stats
):
for ep_legacy_metadata, ep_metadata, ep_video, ep_stats, ep_idx_stats in zip(
episodes_legacy_metadata.values(),
episodes_metadata,
episodes_videos,
episodes_stats.values(),
episodes_stats.keys(),
strict=False,
):
ep_idx = ep_legacy_metadata["episode_index"]
ep_idx_data = ep_metadata["episode_index"]
ep_idx_video = ep_video["episode_index"]
if len(set([ep_idx, ep_idx_data, ep_idx_video])) != 1:
raise ValueError(f"Number of episodes is not the same ({ep_idx=},{ep_idx_data=},{ep_idx_video=}).")
if len({ep_idx, ep_idx_data, ep_idx_video, ep_idx_stats}) != 1:
raise ValueError(
f"Number of episodes is not the same ({ep_idx=},{ep_idx_data=},{ep_idx_video=},{ep_idx_stats=})."
)
ep_dict = {**ep_data, **ep_video, **ep}
ep_dict = {**ep_metadata, **ep_video, **ep_legacy_metadata, **flatten_dict({"stats": ep_stats})}
ep_dict["meta/episodes/chunk_index"] = 0
ep_dict["meta/episodes/file_index"] = 0
yield ep_dict
def convert_episodes(root, new_root, episodes_data, episodes_videos):
episodes = legacy_load_episodes(root)
num_eps = len(episodes)
num_eps_data = len(episodes_data)
num_eps_video = len(episodes_videos)
if len(set([num_eps, num_eps_data, num_eps_video])) != 1:
raise ValueError(f"Number of episodes is not the same ({num_eps=},{num_eps_data=},{num_eps_video=}).")
def convert_episodes_metadata(root, new_root, episodes_metadata, episodes_video_metadata):
episodes_legacy_metadata = legacy_load_episodes(root)
episodes_stats = legacy_load_episodes_stats(root)
ds_episodes = Dataset.from_generator(lambda: generate_episode_dict(episodes, episodes_data, episodes_videos))
num_eps = len(episodes_legacy_metadata)
num_eps_metadata = len(episodes_metadata)
num_eps_video_metadata = len(episodes_video_metadata)
if len({num_eps, num_eps_metadata, num_eps_video_metadata}) != 1:
raise ValueError(
f"Number of episodes is not the same ({num_eps=},{num_eps_metadata=},{num_eps_video_metadata=})."
)
ds_episodes = Dataset.from_generator(
lambda: generate_episode_metadata_dict(
episodes_legacy_metadata, episodes_metadata, episodes_video_metadata, episodes_stats
)
)
write_episodes(ds_episodes, new_root)
def convert_info(root, new_root):
info = load_info(root)
info["codebase_version"] = "v3.0"
@ -315,6 +328,7 @@ def convert_info(root, new_root):
info["features"][key]["fps"] = info["fps"]
write_info(info, new_root)
def convert_dataset(
repo_id: str,
branch: str | None = None,
@ -331,11 +345,11 @@ def convert_dataset(
)
convert_info(root, new_root)
convert_episodes_stats(root, new_root)
convert_tasks(root, new_root)
episodes_data_mapping = convert_data(root, new_root)
episodes_videos_mapping = convert_videos(root, new_root)
convert_episodes(root, new_root, episodes_data_mapping, episodes_videos_mapping)
episodes_metadata = convert_data(root, new_root)
episodes_videos_metadata = convert_videos(root, new_root)
convert_episodes_metadata(root, new_root, episodes_metadata, episodes_videos_metadata)
if __name__ == "__main__":
parser = argparse.ArgumentParser()

@ -6,6 +6,7 @@ from unittest.mock import patch
import datasets
import numpy as np
import pandas as pd
import PIL.Image
import pytest
import torch
@ -14,13 +15,12 @@ from datasets import Dataset
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset, LeRobotDatasetMetadata
from lerobot.common.datasets.utils import (
DEFAULT_CHUNK_SIZE,
DEFAULT_FEATURES,
DEFAULT_DATA_PATH,
DEFAULT_FEATURES,
DEFAULT_FILE_SIZE_IN_MB,
DEFAULT_VIDEO_PATH,
flatten_dict,
get_hf_features_from_features,
hf_transform_to_torch,
)
from tests.fixtures.constants import (
DEFAULT_FPS,
@ -36,8 +36,9 @@ class LeRobotDatasetFactory(Protocol):
def __call__(self, *args, **kwargs) -> LeRobotDataset: ...
def get_task_index(tasks: Dataset, task: str) -> int:
task_idx = tasks["task"].index(task)
def get_task_index(tasks: pd.DataFrame, task: str) -> int:
# TODO(rcadene): a bit complicated no? ^^
task_idx = tasks.loc[task].task_index.item()
return task_idx
@ -164,42 +165,44 @@ def stats_factory():
return _create_stats
@pytest.fixture(scope="session")
def episodes_stats_factory(stats_factory):
def _create_episodes_stats(
features: dict[str],
total_episodes: int = 3,
) -> dict:
# @pytest.fixture(scope="session")
# def episodes_stats_factory(stats_factory):
# def _create_episodes_stats(
# features: dict[str],
# total_episodes: int = 3,
# ) -> dict:
def _generator(total_episodes):
for ep_idx in range(total_episodes):
flat_ep_stats = flatten_dict(stats_factory(features))
flat_ep_stats["episode_index"] = ep_idx
yield flat_ep_stats
# def _generator(total_episodes):
# for ep_idx in range(total_episodes):
# flat_ep_stats = flatten_dict(stats_factory(features))
# flat_ep_stats["episode_index"] = ep_idx
# yield flat_ep_stats
# Simpler to rely on generator instead of from_dict
return Dataset.from_generator(lambda: _generator(total_episodes))
# # Simpler to rely on generator instead of from_dict
# return Dataset.from_generator(lambda: _generator(total_episodes))
return _create_episodes_stats
# return _create_episodes_stats
@pytest.fixture(scope="session")
def tasks_factory():
def _create_tasks(total_tasks: int = 3) -> Dataset:
def _create_tasks(total_tasks: int = 3) -> pd.DataFrame:
ids = list(range(total_tasks))
tasks = [f"Perform action {i}." for i in ids]
return Dataset.from_dict({"task_index": ids, "task": tasks})
df = pd.DataFrame({"task_index": ids}, index=tasks)
return df
return _create_tasks
@pytest.fixture(scope="session")
def episodes_factory(tasks_factory):
def episodes_factory(tasks_factory, stats_factory):
def _create_episodes(
features: dict[str],
total_episodes: int = 3,
total_frames: int = 400,
video_keys: list[str] | None = None,
tasks: dict | None = None,
tasks: pd.DataFrame | None = None,
multi_task: bool = False,
):
if total_episodes <= 0 or total_frames <= 0:
@ -207,21 +210,24 @@ def episodes_factory(tasks_factory):
if total_frames < total_episodes:
raise ValueError("total_length must be greater than or equal to num_episodes.")
if not tasks:
if tasks is None:
min_tasks = 2 if multi_task else 1
total_tasks = random.randint(min_tasks, total_episodes)
tasks = tasks_factory(total_tasks)
if total_episodes < len(tasks) and not multi_task:
num_tasks_available = len(tasks)
if total_episodes < num_tasks_available and not multi_task:
raise ValueError("The number of tasks should be less than the number of episodes.")
# Generate random lengths that sum up to total_frames
lengths = np.random.multinomial(total_frames, [1 / total_episodes] * total_episodes).tolist()
num_tasks_available = len(tasks["task"])
# Create empty dictionaries with all keys
d = {
"episode_index": [],
"meta/episodes/chunk_index": [],
"meta/episodes/file_index": [],
"data/chunk_index": [],
"data/file_index": [],
"tasks": [],
@ -232,10 +238,13 @@ def episodes_factory(tasks_factory):
d[f"{video_key}/chunk_index"] = []
d[f"{video_key}/file_index"] = []
remaining_tasks = tasks["task"].copy()
for stats_key in flatten_dict({"stats": stats_factory(features)}):
d[stats_key] = []
remaining_tasks = list(tasks.index)
for ep_idx in range(total_episodes):
num_tasks_in_episode = random.randint(1, min(3, num_tasks_available)) if multi_task else 1
tasks_to_sample = remaining_tasks if remaining_tasks else tasks["task"]
tasks_to_sample = remaining_tasks if len(remaining_tasks) > 0 else list(tasks.index)
episode_tasks = random.sample(tasks_to_sample, min(num_tasks_in_episode, len(tasks_to_sample)))
if remaining_tasks:
for task in episode_tasks:
@ -243,15 +252,22 @@ def episodes_factory(tasks_factory):
d["episode_index"].append(ep_idx)
# TODO(rcadene): remove heuristic of only one file
d["meta/episodes/chunk_index"].append(0)
d["meta/episodes/file_index"].append(0)
d["data/chunk_index"].append(0)
d["data/file_index"].append(0)
d["tasks"].append(episode_tasks)
d["length"].append(lengths[ep_idx])
if video_keys is not None:
for video_key in video_keys:
d[f"{video_key}/chunk_index"].append(0)
d[f"{video_key}/file_index"].append(0)
# Add stats columns like "stats/action/max"
for stats_key, stats in flatten_dict({"stats": stats_factory(features)}).items():
d[stats_key].append(stats)
return Dataset.from_dict(d)
return _create_episodes
@ -261,15 +277,15 @@ def episodes_factory(tasks_factory):
def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_array_factory):
def _create_hf_dataset(
features: dict | None = None,
tasks: list[dict] | None = None,
episodes: list[dict] | None = None,
tasks: pd.DataFrame | None = None,
episodes: datasets.Dataset | None = None,
fps: int = DEFAULT_FPS,
) -> datasets.Dataset:
if not tasks:
if tasks is None:
tasks = tasks_factory()
if not episodes:
if episodes is None:
episodes = episodes_factory()
if not features:
if features is None:
features = features_factory()
timestamp_col = np.array([], dtype=np.float32)
@ -282,6 +298,8 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar
episode_index_col = np.concatenate(
(episode_index_col, np.full(ep_dict["length"], ep_dict["episode_index"], dtype=int))
)
# Slightly incorrect, but for simplicity, we assign the first task defined in the episode metadata to all frames.
# TODO(rcadene): assign the tasks of the episode per chunks of frames
ep_task_index = get_task_index(tasks, ep_dict["tasks"][0])
task_index = np.concatenate((task_index, np.full(ep_dict["length"], ep_task_index, dtype=int)))
@ -319,7 +337,6 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar
def lerobot_dataset_metadata_factory(
info_factory,
stats_factory,
episodes_stats_factory,
tasks_factory,
episodes_factory,
mock_snapshot_download_factory,
@ -329,30 +346,28 @@ def lerobot_dataset_metadata_factory(
repo_id: str = DUMMY_REPO_ID,
info: dict | None = None,
stats: dict | None = None,
episodes_stats: list[dict] | None = None,
tasks: list[dict] | None = None,
episodes: list[dict] | None = None,
tasks: pd.DataFrame | None = None,
episodes: datasets.Dataset | None = None,
) -> LeRobotDatasetMetadata:
if not info:
if info is None:
info = info_factory()
if not stats:
if stats is None:
stats = stats_factory(features=info["features"])
if not episodes_stats:
episodes_stats = episodes_stats_factory(
features=info["features"], total_episodes=info["total_episodes"]
)
if not tasks:
if tasks is None:
tasks = tasks_factory(total_tasks=info["total_tasks"])
if not episodes:
if episodes is None:
video_keys = [key for key, ft in info["features"].items() if ft["dtype"] == "video"]
episodes = episodes_factory(
total_episodes=info["total_episodes"], total_frames=info["total_frames"], video_keys=video_keys, tasks=tasks
features=info["features"],
total_episodes=info["total_episodes"],
total_frames=info["total_frames"],
video_keys=video_keys,
tasks=tasks,
)
mock_snapshot_download = mock_snapshot_download_factory(
info=info,
stats=stats,
episodes_stats=episodes_stats,
tasks=tasks,
episodes=episodes,
)
@ -374,7 +389,6 @@ def lerobot_dataset_metadata_factory(
def lerobot_dataset_factory(
info_factory,
stats_factory,
episodes_stats_factory,
tasks_factory,
episodes_factory,
hf_dataset_factory,
@ -390,25 +404,23 @@ def lerobot_dataset_factory(
multi_task: bool = False,
info: dict | None = None,
stats: dict | None = None,
episodes_stats: datasets.Dataset | None = None,
tasks: datasets.Dataset | None = None,
episode_dicts: datasets.Dataset | None = None,
tasks: pd.DataFrame | None = None,
episodes_metadata: datasets.Dataset | None = None,
hf_dataset: datasets.Dataset | None = None,
**kwargs,
) -> LeRobotDataset:
if not info:
if info is None:
info = info_factory(
total_episodes=total_episodes, total_frames=total_frames, total_tasks=total_tasks
)
if not stats:
if stats is None:
stats = stats_factory(features=info["features"])
if not episodes_stats:
episodes_stats = episodes_stats_factory(features=info["features"], total_episodes=total_episodes)
if not tasks:
if tasks is None:
tasks = tasks_factory(total_tasks=info["total_tasks"])
if not episode_dicts:
if episodes_metadata is None:
video_keys = [key for key, ft in info["features"].items() if ft["dtype"] == "video"]
episode_dicts = episodes_factory(
episodes_metadata = episodes_factory(
features=info["features"],
total_episodes=info["total_episodes"],
total_frames=info["total_frames"],
video_keys=video_keys,
@ -416,14 +428,13 @@ def lerobot_dataset_factory(
multi_task=multi_task,
)
if not hf_dataset:
hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episode_dicts, fps=info["fps"])
hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episodes_metadata, fps=info["fps"])
mock_snapshot_download = mock_snapshot_download_factory(
info=info,
stats=stats,
episodes_stats=episodes_stats,
tasks=tasks,
episodes=episode_dicts,
episodes=episodes_metadata,
hf_dataset=hf_dataset,
)
mock_metadata = lerobot_dataset_metadata_factory(
@ -431,9 +442,8 @@ def lerobot_dataset_factory(
repo_id=repo_id,
info=info,
stats=stats,
episodes_stats=episodes_stats,
tasks=tasks,
episodes=episode_dicts,
episodes=episodes_metadata,
)
with (
patch("lerobot.common.datasets.lerobot_dataset.LeRobotDatasetMetadata") as mock_metadata_patch,

@ -1,17 +1,13 @@
import json
from pathlib import Path
import datasets
import jsonlines
import pandas as pd
import pyarrow.compute as pc
import pyarrow.parquet as pq
import pytest
from datasets import Dataset
from lerobot.common.datasets.utils import (
write_episodes,
write_episodes_stats,
write_hf_dataset,
write_info,
write_stats,
@ -22,7 +18,7 @@ from lerobot.common.datasets.utils import (
@pytest.fixture(scope="session")
def create_info(info_factory):
def _create_info(dir: Path, info: dict | None = None):
if not info:
if info is None:
info = info_factory()
write_info(info, dir)
@ -32,27 +28,27 @@ def create_info(info_factory):
@pytest.fixture(scope="session")
def create_stats(stats_factory):
def _create_stats(dir: Path, stats: dict | None = None):
if not stats:
if stats is None:
stats = stats_factory()
write_stats(stats, dir)
return _create_stats
@pytest.fixture(scope="session")
def create_episodes_stats(episodes_stats_factory):
def _create_episodes_stats(dir: Path, episodes_stats: Dataset | None = None):
if not episodes_stats:
episodes_stats = episodes_stats_factory()
write_episodes_stats(episodes_stats, dir)
# @pytest.fixture(scope="session")
# def create_episodes_stats(episodes_stats_factory):
# def _create_episodes_stats(dir: Path, episodes_stats: Dataset | None = None):
# if episodes_stats is None:
# episodes_stats = episodes_stats_factory()
# write_episodes_stats(episodes_stats, dir)
return _create_episodes_stats
# return _create_episodes_stats
@pytest.fixture(scope="session")
def create_tasks(tasks_factory):
def _create_tasks(dir: Path, tasks: Dataset | None = None):
if not tasks:
def _create_tasks(dir: Path, tasks: pd.DataFrame | None = None):
if tasks is None:
tasks = tasks_factory()
write_tasks(tasks, dir)
@ -61,17 +57,18 @@ def create_tasks(tasks_factory):
@pytest.fixture(scope="session")
def create_episodes(episodes_factory):
def _create_episodes(dir: Path, episodes: Dataset | None = None):
if not episodes:
def _create_episodes(dir: Path, episodes: datasets.Dataset | None = None):
if episodes is None:
episodes = episodes_factory()
write_episodes(episodes, dir)
return _create_episodes
@pytest.fixture(scope="session")
def create_hf_dataset(hf_dataset_factory):
def _create_hf_dataset(dir: Path, hf_dataset: Dataset | None = None):
if not hf_dataset:
def _create_hf_dataset(dir: Path, hf_dataset: datasets.Dataset | None = None):
if hf_dataset is None:
hf_dataset = hf_dataset_factory()
write_hf_dataset(hf_dataset, dir)
@ -84,7 +81,7 @@ def single_episode_parquet_path(hf_dataset_factory, info_factory):
dir: Path, ep_idx: int = 0, hf_dataset: datasets.Dataset | None = None, info: dict | None = None
) -> Path:
raise NotImplementedError()
if not info:
if info is None:
info = info_factory()
if hf_dataset is None:
hf_dataset = hf_dataset_factory()
@ -108,7 +105,7 @@ def multi_episode_parquet_path(hf_dataset_factory, info_factory):
dir: Path, hf_dataset: datasets.Dataset | None = None, info: dict | None = None
) -> Path:
raise NotImplementedError()
if not info:
if info is None:
info = info_factory()
if hf_dataset is None:
hf_dataset = hf_dataset_factory()

tests/fixtures/hub.py
@ -1,13 +1,13 @@
from pathlib import Path
import datasets
import pandas as pd
import pytest
from huggingface_hub.utils import filter_repo_objects
from lerobot.common.datasets.utils import (
DEFAULT_DATA_PATH,
DEFAULT_EPISODES_PATH,
DEFAULT_EPISODES_STATS_PATH,
DEFAULT_TASKS_PATH,
INFO_PATH,
LEGACY_STATS_PATH,
@ -21,8 +21,6 @@ def mock_snapshot_download_factory(
create_info,
stats_factory,
create_stats,
episodes_stats_factory,
create_episodes_stats,
tasks_factory,
create_tasks,
episodes_factory,
@ -38,46 +36,43 @@ def mock_snapshot_download_factory(
def _mock_snapshot_download_func(
info: dict | None = None,
stats: dict | None = None,
episodes_stats: datasets.Dataset | None = None,
tasks: datasets.Dataset | None = None,
tasks: pd.DataFrame | None = None,
episodes: datasets.Dataset | None = None,
hf_dataset: datasets.Dataset | None = None,
):
if not info:
if info is None:
info = info_factory()
if not stats:
if stats is None:
stats = stats_factory(features=info["features"])
if not episodes_stats:
episodes_stats = episodes_stats_factory(
features=info["features"], total_episodes=info["total_episodes"]
)
if not tasks:
if tasks is None:
tasks = tasks_factory(total_tasks=info["total_tasks"])
if not episodes:
if episodes is None:
episodes = episodes_factory(
total_episodes=info["total_episodes"], total_frames=info["total_frames"], tasks=tasks
features=info["features"],
total_episodes=info["total_episodes"],
total_frames=info["total_frames"],
tasks=tasks,
)
if not hf_dataset:
if hf_dataset is None:
hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episodes, fps=info["fps"])
def _mock_snapshot_download(
repo_id: str,
repo_id: str, # TODO(rcadene): repo_id should be used no?
local_dir: str | Path | None = None,
allow_patterns: str | list[str] | None = None,
ignore_patterns: str | list[str] | None = None,
*args,
**kwargs,
) -> str:
if not local_dir:
if local_dir is None:
local_dir = LEROBOT_TEST_DIR
# List all possible files
all_files = [
INFO_PATH,
LEGACY_STATS_PATH,
# TODO(rcadene)
# TODO(rcadene): remove naive chunk 0 file 0 ?
DEFAULT_TASKS_PATH.format(chunk_index=0, file_index=0),
DEFAULT_EPISODES_STATS_PATH.format(chunk_index=0, file_index=0),
DEFAULT_EPISODES_PATH.format(chunk_index=0, file_index=0),
DEFAULT_DATA_PATH.format(chunk_index=0, file_index=0),
]
@ -89,7 +84,6 @@ def mock_snapshot_download_factory(
has_info = False
has_tasks = False
has_episodes = False
has_episodes_stats = False
has_stats = False
has_data = False
for rel_path in allowed_files:
@ -99,8 +93,6 @@ def mock_snapshot_download_factory(
has_stats = True
elif rel_path.startswith("meta/tasks"):
has_tasks = True
elif rel_path.startswith("meta/episodes_stats"):
has_episodes_stats = True
elif rel_path.startswith("meta/episodes"):
has_episodes = True
elif rel_path.startswith("data/"):
@ -116,8 +108,6 @@ def mock_snapshot_download_factory(
create_tasks(local_dir, tasks)
if has_episodes:
create_episodes(local_dir, episodes)
if has_episodes_stats:
create_episodes_stats(local_dir, episodes_stats)
if has_data:
create_hf_dataset(local_dir, hf_dataset)