Most unit tests are passing

parent c1b28f0b58
commit 34c5d4ce07

@@ -16,17 +16,18 @@
 import contextlib
 import logging
 import shutil
-from pathlib import Path
 import tempfile
+from pathlib import Path
 from typing import Callable

 import datasets
 import numpy as np
 import packaging.version
-import PIL.Image
 import pandas as pd
+import PIL.Image
 import torch
 import torch.utils
-from datasets import concatenate_datasets, load_dataset, Dataset
+from datasets import Dataset
 from huggingface_hub import HfApi, snapshot_download
 from huggingface_hub.constants import REPOCARD_NAME
 from huggingface_hub.errors import RevisionNotFoundError

@@ -35,52 +36,36 @@ from lerobot.common.constants import HF_LEROBOT_HOME
 from lerobot.common.datasets.compute_stats import aggregate_stats, compute_episode_stats
 from lerobot.common.datasets.image_writer import AsyncImageWriter, write_image
 from lerobot.common.datasets.utils import (
+    DEFAULT_DATA_PATH,
     DEFAULT_EPISODES_PATH,
-    DEFAULT_EPISODES_STATS_PATH,
     DEFAULT_FEATURES,
     DEFAULT_IMAGE_PATH,
-    EPISODES_DIR,
-    EPISODES_STATS_DIR,
     INFO_PATH,
-    LEGACY_TASKS_PATH,
-    append_jsonlines,
-    backward_compatible_episodes_stats,
     check_delta_timestamps,
-    check_timestamps_sync,
     check_version_compatibility,
     concat_video_files,
     create_empty_dataset_info,
     create_lerobot_dataset_card,
     embed_images,
+    flatten_dict,
     get_chunk_file_indices,
     get_delta_indices,
-    get_episode_data_index,
     get_features_from_robot,
     get_hf_dataset_size_in_mb,
     get_hf_features_from_features,
-    get_latest_parquet_path,
     get_latest_video_path,
     get_parquet_num_frames,
-    get_pd_dataframe_size_in_mb,
     get_safe_version,
     get_video_duration_in_s,
     get_video_size_in_mb,
-    hf_transform_to_torch,
     is_valid_version,
-    legacy_load_episodes,
-    legacy_load_episodes_stats,
     load_episodes,
-    load_episodes_stats,
     load_info,
     load_nested_dataset,
-    load_stats,
-    legacy_load_tasks,
     load_tasks,
     update_chunk_file_indices,
     validate_episode_buffer,
     validate_frame,
-    write_episode,
-    legacy_write_episode_stats,
     write_info,
     write_json,
     write_tasks,

@@ -118,15 +103,17 @@ class LeRobotDatasetMetadata:
                 self.revision = get_safe_version(self.repo_id, self.revision)

             (self.root / "meta").mkdir(exist_ok=True, parents=True)
+            # TODO(rcadene): instead of downloading all episodes metadata files,
+            # download only the ones associated to the requested episodes. This would
+            # require adding `episodes: list[int]` as argument.
            self.pull_from_repo(allow_patterns="meta/")
            self.load_metadata()

     def load_metadata(self):
         self.info = load_info(self.root)
         check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
-        self.tasks, self.task_to_task_index = load_tasks(self.root)
+        self.tasks = load_tasks(self.root)
         self.episodes = load_episodes(self.root)
-        self.episodes_stats = load_episodes_stats(self.root)
         # TODO(rcadene): https://huggingface.slack.com/archives/C02V51Q3800/p1743517952388249?thread_ts=1742896075.499119&cid=C02V51Q3800
         # self.stats = aggregate_stats(list(self.episodes_stats.values()))

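The metadata-only sync above boils down to a filtered snapshot_download. A minimal sketch, assuming a public dataset repo ("lerobot/pusht" and the local path are placeholders) and fnmatch-style allow_patterns:

    from pathlib import Path

    from huggingface_hub import snapshot_download

    # Fetch only the metadata files of a LeRobot dataset repo, mirroring
    # pull_from_repo(allow_patterns="meta/") above. Repo id and local root are examples.
    root = Path.home() / ".cache" / "lerobot_example" / "pusht"
    snapshot_download(
        repo_id="lerobot/pusht",
        repo_type="dataset",
        local_dir=root,
        allow_patterns="meta/*",  # meta/info.json, meta/tasks.parquet, meta/episodes/..., etc.
    )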
@@ -150,8 +137,8 @@ class LeRobotDatasetMetadata:
         return packaging.version.parse(self.info["codebase_version"])

     def get_data_file_path(self, ep_index: int) -> Path:
-        chunk_idx = self.episodes[f"data/chunk_index"][ep_index]
-        file_idx = self.episodes[f"data/file_index"][ep_index]
+        chunk_idx = self.episodes["data/chunk_index"][ep_index]
+        file_idx = self.episodes["data/file_index"][ep_index]
         fpath = self.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
         return Path(fpath)

@@ -161,9 +148,6 @@ class LeRobotDatasetMetadata:
         fpath = self.video_path.format(video_key=vid_key, chunk_index=chunk_idx, file_index=file_idx)
         return Path(fpath)

-    # def get_episode_chunk(self, ep_index: int) -> int:
-    #     return ep_index // self.chunks_size
-
     @property
     def data_path(self) -> str:
         """Formattable string for the parquet files."""

@@ -233,7 +217,7 @@ class LeRobotDatasetMetadata:
     def chunks_size(self) -> int:
         """Max number of files per chunk."""
         return self.info["chunks_size"]

     @property
     def files_size_in_mb(self) -> int:
         """Max size of file in mega bytes."""

@@ -244,71 +228,84 @@ class LeRobotDatasetMetadata:
         Given a task in natural language, returns its task_index if the task already exists in the dataset,
         otherwise return None.
         """
-        return self.tasks.index[task] if task in self.tasks.index else None
-
-    def has_task(self, task: str) -> bool:
-        return task in self.task_to_task_index
+        if task in self.tasks.index:
+            return int(self.tasks.loc[task].task_index)
+        else:
+            return None

     def save_episode_tasks(self, tasks: list[str]):
-        new_tasks = [task for task in tasks if not self.has_task(task)]
+        if len(set(tasks)) != len(tasks):
+            raise ValueError(f"Tasks are not unique: {tasks}")

-        for task in new_tasks:
-            task_index = len(self.tasks)
-            self.tasks.loc[task] = task_index
+        if self.tasks is None:
+            new_tasks = tasks
+            task_indices = range(len(tasks))
+            self.tasks = pd.DataFrame({"task_index": task_indices}, index=tasks)
+        else:
+            new_tasks = [task for task in tasks if task not in self.tasks.index]
+            new_task_indices = range(len(self.tasks), len(new_tasks))
+            for task_idx, task in zip(new_task_indices, new_tasks, strict=False):
+                self.tasks.loc[task] = task_idx

         if len(new_tasks) > 0:
             # Update on disk
             write_tasks(self.tasks, self.root)

-    def _save_episode(self, episode_dict: dict):
+    def _save_episode_metadata(self, episode_dict: dict) -> None:
+        """Save episode metadata to a parquet file and update the Hugging Face dataset of episodes metadata.
+
+        This function processes episodes metadata from a dictionary, converts it into a Hugging Face dataset,
+        and saves it as a parquet file. It handles both the creation of new parquet files and the
+        updating of existing ones based on size constraints. After saving the metadata, it reloads
+        the Hugging Face dataset to ensure it is up-to-date.
+
+        Notes: We both need to update parquet files and HF dataset:
+        - `pandas` loads parquet file in RAM
+        - `datasets` relies on a memory mapping from pyarrow (no RAM). It either converts parquet files to a pyarrow cache on disk,
+          or loads directly from pyarrow cache.
+        """
+        # Convert buffer into HF Dataset
+        episode_dict = {key: [value] for key, value in episode_dict.items()}
         ep_dataset = Dataset.from_dict(episode_dict)
         ep_size_in_mb = get_hf_dataset_size_in_mb(ep_dataset)
+        df = pd.DataFrame(ep_dataset)

-        # Access latest parquet file information
-        latest_path = get_latest_parquet_path(self.root / EPISODES_DIR)
-        latest_size_in_mb = get_parquet_file_size_in_mb(latest_path)
-        chunk_idx, file_idx = get_chunk_file_indices(latest_path)
-
-        if latest_size_in_mb + ep_size_in_mb >= self.meta.files_size_in_mb:
-            # Create new parquet file
-            chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size)
-            new_path = self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
-            new_path.parent.mkdir(parents=True, exist_ok=True)
-            ep_df.to_parquet(new_path, index=False)
+        if self.episodes is None:
+            # Initialize indices and frame count for a new dataset made of the first episode data
+            chunk_idx, file_idx = 0, 0
+            df["meta/episodes/chunk_index"] = [chunk_idx]
+            df["meta/episodes/file_index"] = [file_idx]
         else:
-            # Update latest parquet file with new row
-            ep_df = pd.DataFrame(ep_dataset)
-            latest_df = pd.concat([latest_df, ep_df], ignore_index=True)  # RAM
-            latest_df.to_parquet(latest_path, index=False)
+            # Retrieve information from the latest parquet file
+            latest_ep = self.episodes.with_format(columns=["chunk_index", "file_index"])[-1]
+            chunk_idx, file_idx = latest_ep["chunk_index"], latest_ep["file_index"]
+            latest_path = self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
+            latest_size_in_mb = get_parquet_file_size_in_mb(latest_path)
+
+            # Determine if a new parquet file is needed
+            if latest_size_in_mb + ep_size_in_mb >= self.files_size_in_mb:
+                # Size limit is reached, prepare new parquet file
+                chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size)
+                df["meta/episodes/chunk_index"] = [chunk_idx]
+                df["meta/episodes/file_index"] = [file_idx]
+            else:
+                # Update the existing parquet file with new row
+                df["meta/episodes/chunk_index"] = [chunk_idx]
+                df["meta/episodes/file_index"] = [file_idx]
+                latest_df = pd.read_parquet(latest_path)
+                latest_df = pd.concat([latest_df, df], ignore_index=True)
+
+        # Write the resulting dataframe from RAM to disk
+        path = self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
+        path.parent.mkdir(parents=True, exist_ok=True)
+        df.to_parquet(path, index=False)

         # Update the Hugging Face dataset by reloading it.
         # This process should be fast because only the latest Parquet file has been modified.
         # Therefore, only this file needs to be converted to PyArrow; the rest is loaded from the PyArrow memory-mapped cache.
         self.episodes = load_episodes(self.root)

-    def _save_episode_stats(self, episodes_stats: dict):
-        ep_dataset = Dataset.from_dict(episodes_stats)
-        ep_size_in_mb = get_hf_dataset_size_in_mb(ep_dataset)
-
-        # Access latest parquet file information
-        latest_path = get_latest_parquet_path(self.root / EPISODES_STATS_DIR)
-        latest_size_in_mb = get_parquet_file_size_in_mb(latest_path)
-        chunk_idx, file_idx = get_chunk_file_indices(latest_path)
-
-        if latest_size_in_mb + ep_size_in_mb >= self.meta.files_size_in_mb:
-            # Create new parquet file
-            chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size)
-            new_path = self.root / DEFAULT_EPISODES_STATS_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
-            new_path.parent.mkdir(parents=True, exist_ok=True)
-            ep_df.to_parquet(new_path, index=False)
-        else:
-            # Update latest parquet file with new row
-            ep_df = pd.DataFrame(ep_dataset)
-            latest_df = pd.concat([latest_df, ep_df], ignore_index=True)  # RAM
-            latest_df.to_parquet(latest_path, index=False)
-
-        self.episodes_stats = load_episodes_stats(self.root)
-
     def save_episode(
         self,
         episode_index: int,
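The task table is now a pandas DataFrame indexed by the task string with a single task_index column. A standalone sketch of the lookup-and-append bookkeeping used by get_task_index and save_episode_tasks above (task strings are made up, and this is a sketch rather than the dataset class itself):

    import pandas as pd

    # Tasks: index = natural-language task, column = integer task_index.
    tasks = pd.DataFrame({"task_index": range(2)}, index=["pick the cube", "place the cube"])

    def get_task_index(tasks: pd.DataFrame, task: str) -> int | None:
        return int(tasks.loc[task].task_index) if task in tasks.index else None

    def add_tasks(tasks: pd.DataFrame, new_tasks: list[str]) -> pd.DataFrame:
        # Unseen tasks get consecutive indices appended at the end, as in save_episode_tasks.
        unseen = [t for t in new_tasks if t not in tasks.index]
        start = len(tasks)
        for i, task in enumerate(unseen):
            tasks.loc[task] = start + i
        return tasks

    assert get_task_index(tasks, "pick the cube") == 0
    assert get_task_index(tasks, "open the drawer") is None
    add_tasks(tasks, ["open the drawer"])
    assert get_task_index(tasks, "open the drawer") == 2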
@@ -331,8 +328,8 @@ class LeRobotDatasetMetadata:
             "length": episode_length,
         }
         episode_dict.update(episode_metadata)
-        self._save_episode(episode_dict)
-        self._save_episode_stats(episode_stats)
+        episode_dict.update(flatten_dict({"stats": episode_stats}))
+        self._save_episode_metadata(episode_dict)

         self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats else episode_stats
         # TODO: write stats

@@ -401,7 +398,6 @@ class LeRobotDatasetMetadata:
         features = {**features, **DEFAULT_FEATURES}

         obj.tasks = None
-        obj.episodes_stats = None
         obj.episodes = None
         # TODO(rcadene) stats
         obj.stats = {}

@@ -557,7 +553,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
             self.hf_dataset = self.load_hf_dataset()
         except (AssertionError, FileNotFoundError, NotADirectoryError):
             self.revision = get_safe_version(self.repo_id, self.revision)
-            self.download_episodes(download_videos)
+            self.download(download_videos)
             self.hf_dataset = self.load_hf_dataset()

         # Setup delta_indices

@@ -635,7 +631,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
             ignore_patterns=ignore_patterns,
         )

-    def download_episodes(self, download_videos: bool = True) -> None:
+    def download(self, download_videos: bool = True) -> None:
         """Downloads the dataset from the given 'repo_id' at the provided version. If 'episodes' is given, this
         will only download those episodes (selected by their episode_index). If 'episodes' is None, the whole
         dataset will be downloaded. Thanks to the behavior of snapshot_download, if the files are already present

@@ -795,10 +791,7 @@ class LeRobotDataset(torch.utils.data.Dataset):

         # Add task as a string
         task_idx = item["task_index"].item()
-        if self.meta.tasks["task_index"][task_idx] != task_idx:
-            raise ValueError("Sanity check on task index failed.")
-        item["task"] = self.meta.tasks["task"][task_idx]
+        item["task"] = self.meta.tasks.iloc[task_idx].name

         return item

     def __repr__(self):

@@ -827,7 +820,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
             image_key=image_key, episode_index=episode_index, frame_index=frame_index
         )
         return self.root / fpath

     def _get_image_file_dir(self, episode_index: int, image_key: str) -> Path:
         return self._get_image_file_path(episode_index, image_key, frame_index=0).parent

@@ -926,11 +919,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
         self._wait_image_writer()
         ep_stats = compute_episode_stats(episode_buffer, self.features)

-        ep_metadata = self._save_episode_data(episode_buffer, episode_index)
+        ep_metadata = self._save_episode_data(episode_buffer)
         for video_key in self.meta.video_keys:
             ep_metadata.update(self._save_episode_video(video_key, episode_index))

-        # `meta.save_episode` neeed to be executed after encoding the videos
+        # `meta.save_episode` need to be executed after encoding the videos
         self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats, ep_metadata)

         # TODO(rcadene): remove? there is only one episode in the episode buffer, no need for ep_data_index

@@ -954,31 +947,57 @@ class LeRobotDataset(torch.utils.data.Dataset):
         # Reset episode buffer
         self.episode_buffer = self.create_episode_buffer()

-    def _save_episode_data(self, episode_buffer: dict) -> None:
+    def _save_episode_data(self, episode_buffer: dict) -> dict:
+        """Save episode data to a parquet file and update the Hugging Face dataset of frames data.
+
+        This function processes episodes data from a buffer, converts it into a Hugging Face dataset,
+        and saves it as a parquet file. It handles both the creation of new parquet files and the
+        updating of existing ones based on size constraints. After saving the data, it reloads
+        the Hugging Face dataset to ensure it is up-to-date.
+
+        Notes: We both need to update parquet files and HF dataset:
+        - `pandas` loads parquet file in RAM
+        - `datasets` relies on a memory mapping from pyarrow (no RAM). It either converts parquet files to a pyarrow cache on disk,
+          or loads directly from pyarrow cache.
+        """
         # Convert buffer into HF Dataset
         ep_dict = {key: episode_buffer[key] for key in self.hf_features}
         ep_dataset = datasets.Dataset.from_dict(ep_dict, features=self.hf_features, split="train")
         ep_dataset = embed_images(ep_dataset)
         ep_size_in_mb = get_hf_dataset_size_in_mb(ep_dataset)
         ep_num_frames = len(ep_dataset)
+        df = pd.DataFrame(ep_dataset)

-        # Access latest parquet file information
-        latest_path = get_latest_parquet_path(self.root / "data")
-        latest_size_in_mb = get_parquet_file_size_in_mb(latest_path)
-        latest_num_frames = get_parquet_num_frames(latest_path)
-        chunk_idx, file_idx = get_chunk_file_indices(latest_path)
-
-        if latest_size_in_mb + ep_size_in_mb >= self.meta.files_size_in_mb:
-            # Create new parquet file
-            chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size)
-            new_path = self.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
-            new_path.parent.mkdir(parents=True, exist_ok=True)
-            ep_df.to_parquet(new_path, index=False)
+        if self.meta.episodes is None:
+            # Initialize indices and frame count for a new dataset made of the first episode data
+            chunk_idx, file_idx = 0, 0
+            latest_num_frames = 0
         else:
-            # Update latest parquet file with new rows
-            ep_df = pd.DataFrame(ep_dataset)
-            latest_df = pd.concat([latest_df, ep_df], ignore_index=True)  # RAM
-            latest_df.to_parquet(latest_path, index=False)
+            # Retrieve information from the latest parquet file
+            latest_ep = self.meta.episodes.with_format(columns=["data/chunk_index", "data/file_index"])[-1]
+            chunk_idx, file_idx = latest_ep["data/chunk_index"], latest_ep["data/file_index"]
+            latest_path = self.root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
+            latest_size_in_mb = get_parquet_file_size_in_mb(latest_path)
+            latest_num_frames = get_parquet_num_frames(latest_path)
+
+            # Determine if a new parquet file is needed
+            if latest_size_in_mb + ep_size_in_mb >= self.meta.files_size_in_mb:
+                # Size limit is reached, prepare new parquet file
+                chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size)
+                latest_num_frames = 0
+            else:
+                # Update the existing parquet file with new rows
+                latest_df = pd.read_parquet(latest_path)
+                df = pd.concat([latest_df, df], ignore_index=True)
+
+        # Write the resulting dataframe from RAM to disk
+        path = self.root / self.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
+        path.parent.mkdir(parents=True, exist_ok=True)
+        if len(self.meta.image_keys) > 0:
+            datasets.Dataset.from_dict(df.to_dict(orient="list")).to_parquet(path)
+        else:
+            df.to_parquet(path)

         # Update the Hugging Face dataset by reloading it.
         # This process should be fast because only the latest Parquet file has been modified.
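The rewritten _save_episode_data keeps appending episodes to the most recent data parquet file and only rolls over to a fresh chunk/file slot once a size budget would be exceeded. A standalone sketch of that policy, assuming the default budget and chunk size, and measuring plain on-disk size (the diff estimates an uncompressed size from parquet metadata instead):

    from pathlib import Path

    import pandas as pd

    CHUNK_FILE_PATTERN = "chunk-{chunk_index:03d}/file-{file_index:03d}"
    FILES_SIZE_IN_MB = 500.0  # assumed per-file budget (DEFAULT_FILE_SIZE_IN_MB)
    CHUNKS_SIZE = 1000        # assumed max number of files per chunk (DEFAULT_CHUNK_SIZE)

    def append_episode_frames(root: Path, df: pd.DataFrame, chunk_idx: int, file_idx: int) -> tuple[int, int]:
        latest = root / "data" / (CHUNK_FILE_PATTERN.format(chunk_index=chunk_idx, file_index=file_idx) + ".parquet")
        ep_size_in_mb = df.memory_usage(deep=True).sum() / (1024**2)
        if latest.exists():
            latest_size_in_mb = latest.stat().st_size / (1024**2)
            if latest_size_in_mb + ep_size_in_mb >= FILES_SIZE_IN_MB:
                # Budget reached: move to a fresh file, and to a fresh chunk every CHUNKS_SIZE files.
                if file_idx == CHUNKS_SIZE - 1:
                    chunk_idx, file_idx = chunk_idx + 1, 0
                else:
                    file_idx += 1
            else:
                # Otherwise append the new rows to the latest file, in RAM, and rewrite it.
                df = pd.concat([pd.read_parquet(latest), df], ignore_index=True)
        out = root / "data" / (CHUNK_FILE_PATTERN.format(chunk_index=chunk_idx, file_index=file_idx) + ".parquet")
        out.parent.mkdir(parents=True, exist_ok=True)
        df.to_parquet(out, index=False)
        return chunk_idx, file_idx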
@@ -1008,7 +1027,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
         if latest_size_in_mb + ep_size_in_mb >= self.meta.files_size_in_mb:
             # Move temporary episode video to a new video file in the dataset
             chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size)
-            new_path = self.meta.video_path.format(video_key=video_key, chunk_index=chunk_idx, file_index=file_idx)
+            new_path = self.meta.video_path.format(
+                video_key=video_key, chunk_index=chunk_idx, file_index=file_idx
+            )
             new_path.parent.mkdir(parents=True, exist_ok=True)
             ep_path.replace(new_path)
         else:
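The video path applies the same budget to mp4 files: if appending the freshly encoded episode would push the latest video file past files_size_in_mb, the temporary file is promoted to the next chunk/file slot with Path.replace; otherwise (the else branch cut off above) it is concatenated onto the latest file with ffmpeg, as concat_video_files does further down. A sketch of that decision, with hypothetical paths and values:

    from pathlib import Path

    files_size_in_mb = 500.0  # assumed budget
    latest_path = Path("videos/observation.images.cam_high/chunk-000/file-000.mp4")  # hypothetical
    ep_path = Path("/tmp/episode_000042.mp4")  # hypothetical temporary encode

    latest_size_in_mb = latest_path.stat().st_size / (1024**2)
    ep_size_in_mb = ep_path.stat().st_size / (1024**2)

    if latest_size_in_mb + ep_size_in_mb >= files_size_in_mb:
        # Promote the temporary episode video to a brand-new file in the next slot.
        new_path = Path("videos/observation.images.cam_high/chunk-000/file-001.mp4")
        new_path.parent.mkdir(parents=True, exist_ok=True)
        ep_path.replace(new_path)
    else:
        # Append onto the latest file via ffmpeg's concat demuxer (see concat_video_files).
        pass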
@@ -17,12 +17,12 @@ import contextlib
 import importlib.resources
 import json
 import logging
+import subprocess
+import tempfile
 from collections.abc import Iterator
 from itertools import accumulate
 from pathlib import Path
 from pprint import pformat
-import subprocess
-import tempfile
 from types import SimpleNamespace
 from typing import Any, Tuple

@@ -31,24 +31,24 @@ import jsonlines
 import numpy as np
 import packaging.version
 import pandas
+import pandas as pd
+import pyarrow.parquet as pq
 import torch
+from datasets import Dataset, concatenate_datasets
 from datasets.table import embed_table_storage
 from huggingface_hub import DatasetCard, DatasetCardData, HfApi
 from huggingface_hub.errors import RevisionNotFoundError
 from PIL import Image as PILImage
 from torchvision import transforms
-from datasets import Dataset, concatenate_datasets

 from lerobot.common.datasets.backward_compatibility import (
     V21_MESSAGE,
-    V30_MESSAGE,
     BackwardCompatibilityError,
     ForwardCompatibilityError,
 )
 from lerobot.common.robot_devices.robots.utils import Robot
 from lerobot.common.utils.utils import is_valid_numpy_dtype_string
 from lerobot.configs.types import DictLike, FeatureType, PolicyFeature
-import pyarrow.parquet as pq

 DEFAULT_CHUNK_SIZE = 1000  # Max number of files per chunk
 DEFAULT_FILE_SIZE_IN_MB = 500.0  # Max size per file

@@ -65,16 +65,13 @@ LEGACY_DEFAULT_PARQUET_PATH = "data/chunk-{episode_chunk:03d}/episode_{episode_i
 DEFAULT_IMAGE_PATH = "images/{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"

 EPISODES_DIR = "meta/episodes"
-EPISODES_STATS_DIR = "meta/episodes_stats"
-TASKS_DIR = "meta/tasks"
 DATA_DIR = "data"
 VIDEO_DIR = "videos"

 INFO_PATH = "meta/info.json"
 CHUNK_FILE_PATTERN = "chunk-{chunk_index:03d}/file-{file_index:03d}"
 DEFAULT_EPISODES_PATH = EPISODES_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
-DEFAULT_EPISODES_STATS_PATH = EPISODES_STATS_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
-DEFAULT_TASKS_PATH = TASKS_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
+DEFAULT_TASKS_PATH = "meta/tasks.parquet"
 DEFAULT_DATA_PATH = DATA_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
 DEFAULT_VIDEO_PATH = VIDEO_DIR + "/{video_key}/" + CHUNK_FILE_PATTERN + ".mp4"

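Once formatted, the path templates above resolve to concrete relative paths; for example (the video_key value is illustrative):

    CHUNK_FILE_PATTERN = "chunk-{chunk_index:03d}/file-{file_index:03d}"
    DEFAULT_EPISODES_PATH = "meta/episodes/" + CHUNK_FILE_PATTERN + ".parquet"
    DEFAULT_DATA_PATH = "data/" + CHUNK_FILE_PATTERN + ".parquet"
    DEFAULT_VIDEO_PATH = "videos/{video_key}/" + CHUNK_FILE_PATTERN + ".mp4"

    print(DEFAULT_DATA_PATH.format(chunk_index=0, file_index=12))
    # data/chunk-000/file-012.parquet
    print(DEFAULT_VIDEO_PATH.format(video_key="observation.images.top", chunk_index=1, file_index=0))
    # videos/observation.images.top/chunk-001/file-000.mp4
    print(DEFAULT_EPISODES_PATH.format(chunk_index=0, file_index=0))
    # meta/episodes/chunk-000/file-000.parquet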
@@ -98,11 +95,13 @@ DEFAULT_FEATURES = {


 def get_hf_dataset_size_in_mb(hf_ds: Dataset) -> int:
-    return hf_ds.data.nbytes / (1024 ** 2)
+    return hf_ds.data.nbytes / (1024**2)


 def get_pd_dataframe_size_in_mb(df: pandas.DataFrame) -> int:
     memory_usage_bytes = df.memory_usage(deep=True).sum()
-    return memory_usage_bytes / (1024 ** 2)
+    return memory_usage_bytes / (1024**2)


 def get_chunk_file_indices(path: Path) -> Tuple[int, int]:
     if not path.stem.startswith("file-") or not path.parent.name.startswith("chunk-"):
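Both size helpers report mebibytes, one from pandas' deep memory usage and one from the Arrow table backing a datasets.Dataset. A quick check on synthetic data (exact numbers vary slightly with index overhead and platform):

    import pandas as pd
    from datasets import Dataset

    df = pd.DataFrame({"x": range(1_000_000)})
    print(df.memory_usage(deep=True).sum() / (1024**2))  # DataFrame size in MiB, roughly 7.6

    hf_ds = Dataset.from_pandas(df)
    print(hf_ds.data.nbytes / (1024**2))  # Arrow table size in MiB, roughly 7.6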
@@ -112,6 +111,7 @@ def get_chunk_file_indices(path: Path) -> Tuple[int, int]:
     file_index = int(path.stem.replace("file-", ""))
     return chunk_index, file_index

+
 def update_chunk_file_indices(chunk_idx: int, file_idx: int, chunks_size: int):
     if file_idx == chunks_size - 1:
         file_idx = 0

@@ -122,63 +122,86 @@ def update_chunk_file_indices(chunk_idx: int, file_idx: int, chunks_size: int):


 def load_nested_dataset(pq_dir: Path) -> Dataset:
-    """ Find parquet files in provided directory {pq_dir}/chunk-xxx/file-xxx.parquet
+    """Find parquet files in provided directory {pq_dir}/chunk-xxx/file-xxx.parquet
     Convert parquet files to pyarrow memory mapped in a cache folder for efficient RAM usage
     Concatenate all pyarrow references to return HF Dataset format
     """
+    paths = sorted(pq_dir.glob("*/*.parquet"))
+    if len(paths) == 0:
+        raise FileNotFoundError(f"Provided directory does not contain any parquet file: {pq_dir}")
+
     # TODO(rcadene): set num_proc to accelerate conversion to pyarrow
-    return concatenate_datasets([Dataset.from_parquet(str(path)) for path in sorted(pq_dir.glob("*/*.parquet"))])
+    return concatenate_datasets(
+        [Dataset.from_parquet(str(path)) for path in sorted(pq_dir.glob("*/*.parquet"))]
+    )


 def get_latest_parquet_path(pq_dir: Path) -> Path:
     return sorted(pq_dir.glob("*/*.parquet"))[-1]


 def get_latest_video_path(pq_dir: Path, video_key: str) -> Path:
     return sorted(pq_dir.glob(f"{video_key}/*/*.mp4"))[-1]


 def get_parquet_num_frames(parquet_path):
     metadata = pq.read_metadata(parquet_path)
     return metadata.num_rows


 def get_video_size_in_mb(mp4_path: Path):
     file_size_bytes = mp4_path.stat().st_size
-    file_size_mb = file_size_bytes / (1024 ** 2)
+    file_size_mb = file_size_bytes / (1024**2)
     return file_size_mb


 def concat_video_files(paths_to_cat, new_root, video_key, chunk_idx, file_idx):
     # Create a text file with the list of files to concatenate
-    with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f:
+    with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".txt") as f:
         temp_file_path = f.name
         for ep_path in paths_to_cat:
             f.write(f"file '{str(ep_path)}'\n")

-    output_path = new_root / DEFAULT_VIDEO_PATH.format(video_key=video_key, chunk_index=chunk_idx, file_index=file_idx)
+    output_path = new_root / DEFAULT_VIDEO_PATH.format(
+        video_key=video_key, chunk_index=chunk_idx, file_index=file_idx
+    )
     output_path.parent.mkdir(parents=True, exist_ok=True)
     command = [
-        'ffmpeg',
-        '-y',
-        '-f', 'concat',
-        '-safe', '0',
-        '-i', str(temp_file_path),
-        '-c', 'copy',
-        str(output_path)
+        "ffmpeg",
+        "-y",
+        "-f",
+        "concat",
+        "-safe",
+        "0",
+        "-i",
+        str(temp_file_path),
+        "-c",
+        "copy",
+        str(output_path),
     ]
     subprocess.run(command, check=True)
     Path(temp_file_path).unlink()


 def get_video_duration_in_s(mp4_file: Path):
+    command = [
+        "ffprobe",
+        "-v",
+        "error",
+        "-show_entries",
+        "format=duration",
+        "-of",
+        "default=noprint_wrappers=1:nokey=1",
+        mp4_file,
+    ]
     result = subprocess.run(
-        [
-            'ffprobe',
-            '-v', 'error',
-            '-show_entries', 'format=duration',
-            '-of', 'default=noprint_wrappers=1:nokey=1',
-            mp4_file
-        ],
+        command,
         stdout=subprocess.PIPE,
-        stderr=subprocess.STDOUT
+        stderr=subprocess.STDOUT,
     )
     return float(result.stdout)


 def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict:
     """Flatten a nested dictionary structure by collapsing nested keys into one key with a separator.

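load_nested_dataset's contract, shown on a toy layout: parquet shards stored as <dir>/chunk-XXX/file-XXX.parquet are read into Arrow-backed datasets and concatenated into a single HF Dataset (a scratch directory is created on the fly; values are made up):

    import tempfile
    from pathlib import Path

    import pandas as pd
    from datasets import Dataset, concatenate_datasets

    root = Path(tempfile.mkdtemp()) / "meta" / "episodes"
    (root / "chunk-000").mkdir(parents=True)
    pd.DataFrame({"episode_index": [0, 1]}).to_parquet(root / "chunk-000" / "file-000.parquet")
    pd.DataFrame({"episode_index": [2]}).to_parquet(root / "chunk-000" / "file-001.parquet")

    # Same concatenation as load_nested_dataset performs above.
    ds = concatenate_datasets([Dataset.from_parquet(str(p)) for p in sorted(root.glob("*/*.parquet"))])
    print(ds["episode_index"])  # [0, 1, 2]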
@@ -314,10 +337,11 @@ def write_hf_dataset(hf_dataset: Dataset, local_dir: Path):


 def write_tasks(tasks: pandas.DataFrame, local_dir: Path):
-    path = local_dir / DEFAULT_TASKS_PATH.format(chunk_index=0, file_index=0)
+    path = local_dir / DEFAULT_TASKS_PATH
     path.parent.mkdir(parents=True, exist_ok=True)
     tasks.to_parquet(path)


 def legacy_write_task(task_index: int, task: dict, local_dir: Path):
     task_dict = {
         "task_index": task_index,

@@ -332,31 +356,42 @@ def legacy_load_tasks(local_dir: Path) -> tuple[dict, dict]:
     task_to_task_index = {task: task_index for task_index, task in tasks.items()}
     return tasks, task_to_task_index


 def load_tasks(local_dir: Path):
-    tasks = load_nested_dataset(local_dir / TASKS_DIR)
-    # TODO(rcadene): optimize this
-    task_to_task_index = {d["task"]: d["task_index"] for d in tasks}
-    return tasks, task_to_task_index
+    tasks = pd.read_parquet(local_dir / DEFAULT_TASKS_PATH)
+    return tasks

 def write_episode(episode: dict, local_dir: Path):
     append_jsonlines(episode, local_dir / LEGACY_EPISODES_PATH)


 def write_episodes(episodes: Dataset, local_dir: Path):
     if get_hf_dataset_size_in_mb(episodes) > DEFAULT_FILE_SIZE_IN_MB:
         raise NotImplementedError("Contact a maintainer.")

+    def add_chunk_file_indices(row):
+        row["chunk_index"] = 0
+        row["file_index"] = 0
+        return row
+
+    episodes = episodes.map(add_chunk_file_indices)
+
     fpath = local_dir / DEFAULT_EPISODES_PATH.format(chunk_index=0, file_index=0)
     fpath.parent.mkdir(parents=True, exist_ok=True)
     episodes.to_parquet(fpath)


 def legacy_load_episodes(local_dir: Path) -> dict:
     episodes = load_jsonlines(local_dir / LEGACY_EPISODES_PATH)
     return {item["episode_index"]: item for item in sorted(episodes, key=lambda x: x["episode_index"])}


 def load_episodes(local_dir: Path):
     hf_dataset = load_nested_dataset(local_dir / EPISODES_DIR)
     return hf_dataset


 def legacy_write_episode_stats(episode_index: int, episode_stats: dict, local_dir: Path):
     # We wrap episode_stats in a dictionary since `episode_stats["episode_index"]`
     # is a dictionary of stats and not an integer.
|
||||||
append_jsonlines(episode_stats, local_dir / LEGACY_EPISODES_STATS_PATH)
|
append_jsonlines(episode_stats, local_dir / LEGACY_EPISODES_STATS_PATH)
|
||||||
|
|
||||||
|
|
||||||
def write_episodes_stats(episodes_stats: Dataset, local_dir: Path):
|
# def write_episodes_stats(episodes_stats: Dataset, local_dir: Path):
|
||||||
if get_hf_dataset_size_in_mb(episodes_stats) > DEFAULT_FILE_SIZE_IN_MB:
|
# if get_hf_dataset_size_in_mb(episodes_stats) > DEFAULT_FILE_SIZE_IN_MB:
|
||||||
raise NotImplementedError("Contact a maintainer.")
|
# raise NotImplementedError("Contact a maintainer.")
|
||||||
|
|
||||||
fpath = local_dir / DEFAULT_EPISODES_STATS_PATH.format(chunk_index=0, file_index=0)
|
# fpath = local_dir / DEFAULT_EPISODES_STATS_PATH.format(chunk_index=0, file_index=0)
|
||||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
# fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||||
episodes_stats.to_parquet(fpath)
|
# episodes_stats.to_parquet(fpath)
|
||||||
|
|
||||||
|
|
||||||
def legacy_load_episodes_stats(local_dir: Path) -> dict:
|
def legacy_load_episodes_stats(local_dir: Path) -> dict:
|
||||||
|
@ -380,9 +415,11 @@ def legacy_load_episodes_stats(local_dir: Path) -> dict:
|
||||||
for item in sorted(episodes_stats, key=lambda x: x["episode_index"])
|
for item in sorted(episodes_stats, key=lambda x: x["episode_index"])
|
||||||
}
|
}
|
||||||
|
|
||||||
def load_episodes_stats(local_dir: Path):
|
|
||||||
hf_dataset = load_nested_dataset(local_dir / EPISODES_STATS_DIR)
|
# def load_episodes_stats(local_dir: Path):
|
||||||
return hf_dataset
|
# hf_dataset = load_nested_dataset(local_dir / EPISODES_STATS_DIR)
|
||||||
|
# return hf_dataset
|
||||||
|
|
||||||
|
|
||||||
def backward_compatible_episodes_stats(
|
def backward_compatible_episodes_stats(
|
||||||
stats: dict[str, dict[str, np.ndarray]], episodes: list[int]
|
stats: dict[str, dict[str, np.ndarray]], episodes: list[int]
|
||||||
|
|
|
@ -18,13 +18,13 @@ python lerobot/common/datasets/v30/convert_dataset_v21_to_v30.py \
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import sys
|
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
import pyarrow.parquet as pq
|
||||||
|
import tqdm
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
import tqdm
|
|
||||||
|
|
||||||
from lerobot.common.constants import HF_LEROBOT_HOME
|
from lerobot.common.constants import HF_LEROBOT_HOME
|
||||||
from lerobot.common.datasets.utils import (
|
from lerobot.common.datasets.utils import (
|
||||||
|
@ -39,18 +39,13 @@ from lerobot.common.datasets.utils import (
|
||||||
get_video_size_in_mb,
|
get_video_size_in_mb,
|
||||||
legacy_load_episodes,
|
legacy_load_episodes,
|
||||||
legacy_load_episodes_stats,
|
legacy_load_episodes_stats,
|
||||||
load_info,
|
|
||||||
legacy_load_tasks,
|
legacy_load_tasks,
|
||||||
|
load_info,
|
||||||
update_chunk_file_indices,
|
update_chunk_file_indices,
|
||||||
write_episodes,
|
write_episodes,
|
||||||
write_episodes_stats,
|
|
||||||
write_info,
|
write_info,
|
||||||
write_tasks,
|
write_tasks,
|
||||||
)
|
)
|
||||||
import subprocess
|
|
||||||
import tempfile
|
|
||||||
import pandas as pd
|
|
||||||
import pyarrow.parquet as pq
|
|
||||||
|
|
||||||
V21 = "v2.1"
|
V21 = "v2.1"
|
||||||
|
|
||||||
|
@ -97,32 +92,31 @@ meta/info.json
|
||||||
-------------------------
|
-------------------------
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def get_parquet_file_size_in_mb(parquet_path):
|
def get_parquet_file_size_in_mb(parquet_path):
|
||||||
metadata = pq.read_metadata(parquet_path)
|
metadata = pq.read_metadata(parquet_path)
|
||||||
uncompressed_size = metadata.num_rows * metadata.row_group(0).total_byte_size
|
uncompressed_size = metadata.num_rows * metadata.row_group(0).total_byte_size
|
||||||
return uncompressed_size / (1024 ** 2)
|
return uncompressed_size / (1024**2)
|
||||||
|
|
||||||
|
|
||||||
|
# def generate_flat_ep_stats(episodes_stats):
|
||||||
|
# for ep_idx, ep_stats in episodes_stats.items():
|
||||||
|
# flat_ep_stats = flatten_dict(ep_stats)
|
||||||
|
# flat_ep_stats["episode_index"] = ep_idx
|
||||||
|
# yield flat_ep_stats
|
||||||
|
|
||||||
def generate_flat_ep_stats(episodes_stats):
|
# def convert_episodes_stats(root, new_root):
|
||||||
for ep_idx, ep_stats in episodes_stats.items():
|
# episodes_stats = legacy_load_episodes_stats(root)
|
||||||
flat_ep_stats = flatten_dict(ep_stats)
|
# ds_episodes_stats = Dataset.from_generator(lambda: generate_flat_ep_stats(episodes_stats))
|
||||||
flat_ep_stats["episode_index"] = ep_idx
|
# write_episodes_stats(ds_episodes_stats, new_root)
|
||||||
yield flat_ep_stats
|
|
||||||
|
|
||||||
def convert_episodes_stats(root, new_root):
|
|
||||||
episodes_stats = legacy_load_episodes_stats(root)
|
|
||||||
ds_episodes_stats = Dataset.from_generator(lambda: generate_flat_ep_stats(episodes_stats))
|
|
||||||
write_episodes_stats(ds_episodes_stats, new_root)
|
|
||||||
|
|
||||||
def generate_task_dict(tasks):
|
|
||||||
for task_index, task in tasks.items():
|
|
||||||
yield {"task_index": task_index, "task": task}
|
|
||||||
|
|
||||||
def convert_tasks(root, new_root):
|
def convert_tasks(root, new_root):
|
||||||
tasks, _ = legacy_load_tasks(root)
|
tasks, _ = legacy_load_tasks(root)
|
||||||
ds_tasks = Dataset.from_generator(lambda: generate_task_dict(tasks))
|
task_indices = tasks.keys()
|
||||||
write_tasks(ds_tasks, new_root)
|
task_strings = tasks.values()
|
||||||
|
df_tasks = pd.DataFrame({"task_index": task_indices}, index=task_strings)
|
||||||
|
write_tasks(df_tasks, new_root)
|
||||||
|
|
||||||
|
|
||||||
def concat_data_files(paths_to_cat, new_root, chunk_idx, file_idx):
|
def concat_data_files(paths_to_cat, new_root, chunk_idx, file_idx):
|
||||||
|
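convert_tasks turns the legacy v2.1 task table, a plain {task_index: task string} dict, into the new DataFrame layout. The same construction on example values (the diff passes the dict views directly; list() is used here only to be explicit):

    import pandas as pd

    legacy_tasks = {0: "Pick up the cube.", 1: "Stack the cubes."}  # made-up legacy content

    df_tasks = pd.DataFrame(
        {"task_index": list(legacy_tasks.keys())}, index=list(legacy_tasks.values())
    )
    print(int(df_tasks.loc["Stack the cubes."].task_index))  # 1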
@@ -138,17 +132,15 @@ def concat_data_files(paths_to_cat, new_root, chunk_idx, file_idx):

 def convert_data(root, new_root):
     data_dir = root / "data"
+    ep_paths = sorted(data_dir.glob("*/*.parquet"))

-    ep_paths = [path for path in data_dir.glob("*/*.parquet")]
-    ep_paths = sorted(ep_paths)
-
-    episodes_metadata = []
     ep_idx = 0
     chunk_idx = 0
     file_idx = 0
     size_in_mb = 0
     num_frames = 0
     paths_to_cat = []
+    episodes_metadata = []
     for ep_path in ep_paths:
         ep_size_in_mb = get_parquet_file_size_in_mb(ep_path)
         ep_num_frames = get_parquet_num_frames(ep_path)

@@ -184,7 +176,6 @@ def convert_data(root, new_root):
     return episodes_metadata


-
 def get_video_keys(root):
     info = load_info(root)
     features = info["features"]

@@ -204,11 +195,11 @@ def convert_videos(root: Path, new_root: Path):
     for camera in video_keys:
         eps_metadata = convert_videos_of_camera(root, new_root, camera)
         eps_metadata_per_cam.append(eps_metadata)

     num_eps_per_cam = [len(eps_cam_map) for eps_cam_map in eps_metadata_per_cam]
     if len(set(num_eps_per_cam)) != 1:
         raise ValueError(f"All cams dont have same number of episodes ({num_eps_per_cam}).")

     episods_metadata = []
     num_cameras = len(video_keys)
     num_episodes = num_eps_per_cam[0]

@@ -223,23 +214,22 @@ def convert_videos(root: Path, new_root: Path):
         for cam_idx in range(num_cameras):
             ep_dict.update(eps_metadata_per_cam[cam_idx][ep_idx])
         episods_metadata.append(ep_dict)

     return episods_metadata


 def convert_videos_of_camera(root: Path, new_root: Path, video_key):
     # Access old paths to mp4
     videos_dir = root / "videos"
-    ep_paths = [path for path in videos_dir.glob(f"*/{video_key}/*.mp4")]
-    ep_paths = sorted(ep_paths)
+    ep_paths = sorted(videos_dir.glob(f"*/{video_key}/*.mp4"))

-    episodes_metadata = []
     ep_idx = 0
     chunk_idx = 0
     file_idx = 0
     size_in_mb = 0
     duration_in_s = 0.0
     paths_to_cat = []
+    episodes_metadata = []
     for ep_path in tqdm.tqdm(ep_paths, desc=f"convert videos of {video_key}"):
         ep_size_in_mb = get_video_size_in_mb(ep_path)
         ep_duration_in_s = get_video_duration_in_s(ep_path)

@@ -274,30 +264,53 @@ def convert_videos_of_camera(root: Path, new_root: Path, video_key):

     return episodes_metadata

-def generate_episode_dict(episodes, episodes_data, episodes_videos):
-    for ep, ep_data, ep_video in zip(episodes.values(), episodes_data, episodes_videos):
-        ep_idx = ep["episode_index"]
-        ep_idx_data = ep_data["episode_index"]
+def generate_episode_metadata_dict(
+    episodes_legacy_metadata, episodes_metadata, episodes_videos, episodes_stats
+):
+    for ep_legacy_metadata, ep_metadata, ep_video, ep_stats, ep_idx_stats in zip(
+        episodes_legacy_metadata.values(),
+        episodes_metadata,
+        episodes_videos,
+        episodes_stats.values(),
+        episodes_stats.keys(),
+        strict=False,
+    ):
+        ep_idx = ep_legacy_metadata["episode_index"]
+        ep_idx_data = ep_metadata["episode_index"]
         ep_idx_video = ep_video["episode_index"]

-        if len(set([ep_idx, ep_idx_data, ep_idx_video])) != 1:
-            raise ValueError(f"Number of episodes is not the same ({ep_idx=},{ep_idx_data=},{ep_idx_video=}).")
+        if len({ep_idx, ep_idx_data, ep_idx_video, ep_idx_stats}) != 1:
+            raise ValueError(
+                f"Number of episodes is not the same ({ep_idx=},{ep_idx_data=},{ep_idx_video=},{ep_idx_stats=})."
+            )

-        ep_dict = {**ep_data, **ep_video, **ep}
+        ep_dict = {**ep_metadata, **ep_video, **ep_legacy_metadata, **flatten_dict({"stats": ep_stats})}
+        ep_dict["meta/episodes/chunk_index"] = 0
+        ep_dict["meta/episodes/file_index"] = 0
         yield ep_dict

-def convert_episodes(root, new_root, episodes_data, episodes_videos):
-    episodes = legacy_load_episodes(root)
-
-    num_eps = len(episodes)
-    num_eps_data = len(episodes_data)
-    num_eps_video = len(episodes_videos)
-    if len(set([num_eps, num_eps_data, num_eps_video])) != 1:
-        raise ValueError(f"Number of episodes is not the same ({num_eps=},{num_eps_data=},{num_eps_video=}).")
-
-    ds_episodes = Dataset.from_generator(lambda: generate_episode_dict(episodes, episodes_data, episodes_videos))
+
+def convert_episodes_metadata(root, new_root, episodes_metadata, episodes_video_metadata):
+    episodes_legacy_metadata = legacy_load_episodes(root)
+    episodes_stats = legacy_load_episodes_stats(root)
+
+    num_eps = len(episodes_legacy_metadata)
+    num_eps_metadata = len(episodes_metadata)
+    num_eps_video_metadata = len(episodes_video_metadata)
+    if len({num_eps, num_eps_metadata, num_eps_video_metadata}) != 1:
+        raise ValueError(
+            f"Number of episodes is not the same ({num_eps=},{num_eps_metadata=},{num_eps_video_metadata=})."
+        )
+
+    ds_episodes = Dataset.from_generator(
+        lambda: generate_episode_metadata_dict(
+            episodes_legacy_metadata, episodes_metadata, episodes_video_metadata, episodes_stats
+        )
+    )
     write_episodes(ds_episodes, new_root)


 def convert_info(root, new_root):
     info = load_info(root)
     info["codebase_version"] = "v3.0"
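generate_episode_metadata_dict folds the per-episode stats into the episode record via flatten_dict({"stats": ep_stats}), so every leaf of the nested stats dict becomes a "stats/..." column. A sketch with a truncated, made-up stats dict and a minimal re-implementation of the helper (the real one lives in lerobot/common/datasets/utils.py):

    # Minimal re-implementation for the sketch, matching the "/" separator used above.
    def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict:
        items = {}
        for k, v in d.items():
            key = f"{parent_key}{sep}{k}" if parent_key else k
            if isinstance(v, dict):
                items.update(flatten_dict(v, key, sep))
            else:
                items[key] = v
        return items

    ep_stats = {
        "observation.state": {"min": [0.0], "max": [1.0]},
        "action": {"mean": [0.2], "std": [0.1]},
    }
    print(flatten_dict({"stats": ep_stats}))
    # {'stats/observation.state/min': [0.0], 'stats/observation.state/max': [1.0],
    #  'stats/action/mean': [0.2], 'stats/action/std': [0.1]}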
@ -315,6 +328,7 @@ def convert_info(root, new_root):
|
||||||
info["features"][key]["fps"] = info["fps"]
|
info["features"][key]["fps"] = info["fps"]
|
||||||
write_info(info, new_root)
|
write_info(info, new_root)
|
||||||
|
|
||||||
|
|
||||||
def convert_dataset(
|
def convert_dataset(
|
||||||
repo_id: str,
|
repo_id: str,
|
||||||
branch: str | None = None,
|
branch: str | None = None,
|
||||||
|
@ -331,11 +345,11 @@ def convert_dataset(
|
||||||
)
|
)
|
||||||
|
|
||||||
convert_info(root, new_root)
|
convert_info(root, new_root)
|
||||||
convert_episodes_stats(root, new_root)
|
convert_tasks(root, new_root)
|
||||||
convert_tasks(root, new_root)
|
episodes_metadata = convert_data(root, new_root)
|
||||||
episodes_data_mapping = convert_data(root, new_root)
|
episodes_videos_metadata = convert_videos(root, new_root)
|
||||||
episodes_videos_mapping = convert_videos(root, new_root)
|
convert_episodes_metadata(root, new_root, episodes_metadata, episodes_videos_metadata)
|
||||||
convert_episodes(root, new_root, episodes_data_mapping, episodes_videos_mapping)
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|
|
@@ -6,6 +6,7 @@ from unittest.mock import patch

 import datasets
 import numpy as np
+import pandas as pd
 import PIL.Image
 import pytest
 import torch

@@ -14,13 +15,12 @@ from datasets import Dataset
 from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset, LeRobotDatasetMetadata
 from lerobot.common.datasets.utils import (
     DEFAULT_CHUNK_SIZE,
-    DEFAULT_FEATURES,
     DEFAULT_DATA_PATH,
+    DEFAULT_FEATURES,
     DEFAULT_FILE_SIZE_IN_MB,
     DEFAULT_VIDEO_PATH,
     flatten_dict,
     get_hf_features_from_features,
-    hf_transform_to_torch,
 )
 from tests.fixtures.constants import (
     DEFAULT_FPS,

@@ -36,8 +36,9 @@ class LeRobotDatasetFactory(Protocol):
     def __call__(self, *args, **kwargs) -> LeRobotDataset: ...


-def get_task_index(tasks: Dataset, task: str) -> int:
-    task_idx = tasks["task"].index(task)
+def get_task_index(tasks: pd.DataFrame, task: str) -> int:
+    # TODO(rcadene): a bit complicated no? ^^
+    task_idx = tasks.loc[task].task_index.item()
     return task_idx

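With tasks now stored as a pandas DataFrame indexed by the task string (see tasks_factory below), looking up a task index becomes a .loc access. A tiny illustration with made-up tasks:

import pandas as pd

tasks = pd.DataFrame({"task_index": [0, 1]}, index=["Pick the cube.", "Place the cube."])
task_idx = tasks.loc["Place the cube."].task_index.item()  # -> 1
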
@@ -164,42 +165,44 @@ def stats_factory():
     return _create_stats


-@pytest.fixture(scope="session")
-def episodes_stats_factory(stats_factory):
-    def _create_episodes_stats(
-        features: dict[str],
-        total_episodes: int = 3,
-    ) -> dict:
-        def _generator(total_episodes):
-            for ep_idx in range(total_episodes):
-                flat_ep_stats = flatten_dict(stats_factory(features))
-                flat_ep_stats["episode_index"] = ep_idx
-                yield flat_ep_stats
-
-        # Simpler to rely on generator instead of from_dict
-        return Dataset.from_generator(lambda: _generator(total_episodes))
-
-    return _create_episodes_stats
+# @pytest.fixture(scope="session")
+# def episodes_stats_factory(stats_factory):
+#     def _create_episodes_stats(
+#         features: dict[str],
+#         total_episodes: int = 3,
+#     ) -> dict:
+#         def _generator(total_episodes):
+#             for ep_idx in range(total_episodes):
+#                 flat_ep_stats = flatten_dict(stats_factory(features))
+#                 flat_ep_stats["episode_index"] = ep_idx
+#                 yield flat_ep_stats

+#         # Simpler to rely on generator instead of from_dict
+#         return Dataset.from_generator(lambda: _generator(total_episodes))

+#     return _create_episodes_stats


 @pytest.fixture(scope="session")
 def tasks_factory():
-    def _create_tasks(total_tasks: int = 3) -> Dataset:
+    def _create_tasks(total_tasks: int = 3) -> pd.DataFrame:
         ids = list(range(total_tasks))
         tasks = [f"Perform action {i}." for i in ids]
-        return Dataset.from_dict({"task_index": ids, "task": tasks})
+        df = pd.DataFrame({"task_index": ids}, index=tasks)
+        return df

     return _create_tasks

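Because the default values handed around below are now DataFrames and datasets.Dataset objects, the fixtures switch from truthiness checks to explicit is None tests: a DataFrame refuses to be coerced to bool, and an empty Dataset would read as falsy even when explicitly provided. A quick illustration:

import pandas as pd

df = pd.DataFrame({"task_index": [0]})
try:
    if not df:  # raises: the truth value of a DataFrame is ambiguous
        pass
except ValueError as err:
    print(err)

if df is None:  # the safe check used throughout the fixtures
    print("no tasks provided")
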
 @pytest.fixture(scope="session")
-def episodes_factory(tasks_factory):
+def episodes_factory(tasks_factory, stats_factory):
     def _create_episodes(
+        features: dict[str],
         total_episodes: int = 3,
         total_frames: int = 400,
         video_keys: list[str] | None = None,
-        tasks: dict | None = None,
+        tasks: pd.DataFrame | None = None,
         multi_task: bool = False,
     ):
         if total_episodes <= 0 or total_frames <= 0:

@@ -207,21 +210,24 @@ def episodes_factory(tasks_factory):
         if total_frames < total_episodes:
             raise ValueError("total_length must be greater than or equal to num_episodes.")

-        if not tasks:
+        if tasks is None:
             min_tasks = 2 if multi_task else 1
             total_tasks = random.randint(min_tasks, total_episodes)
             tasks = tasks_factory(total_tasks)

-        if total_episodes < len(tasks) and not multi_task:
+        num_tasks_available = len(tasks)
+
+        if total_episodes < num_tasks_available and not multi_task:
             raise ValueError("The number of tasks should be less than the number of episodes.")

         # Generate random lengths that sum up to total_length
         lengths = np.random.multinomial(total_frames, [1 / total_episodes] * total_episodes).tolist()

-        num_tasks_available = len(tasks["task"])
-
+        # Create empty dictionaries with all keys
         d = {
             "episode_index": [],
+            "meta/episodes/chunk_index": [],
+            "meta/episodes/file_index": [],
             "data/chunk_index": [],
             "data/file_index": [],
             "tasks": [],

@@ -232,10 +238,13 @@ def episodes_factory(tasks_factory):
                 d[f"{video_key}/chunk_index"] = []
                 d[f"{video_key}/file_index"] = []

-        remaining_tasks = tasks["task"].copy()
+        for stats_key in flatten_dict({"stats": stats_factory(features)}):
+            d[stats_key] = []
+
+        remaining_tasks = list(tasks.index)
         for ep_idx in range(total_episodes):
             num_tasks_in_episode = random.randint(1, min(3, num_tasks_available)) if multi_task else 1
-            tasks_to_sample = remaining_tasks if remaining_tasks else tasks["task"]
+            tasks_to_sample = remaining_tasks if len(remaining_tasks) > 0 else list(tasks.index)
             episode_tasks = random.sample(tasks_to_sample, min(num_tasks_in_episode, len(tasks_to_sample)))
             if remaining_tasks:
                 for task in episode_tasks:

@@ -243,15 +252,22 @@ def episodes_factory(tasks_factory):

             d["episode_index"].append(ep_idx)
             # TODO(rcadene): remove heuristic of only one file
+            d["meta/episodes/chunk_index"].append(0)
+            d["meta/episodes/file_index"].append(0)
             d["data/chunk_index"].append(0)
             d["data/file_index"].append(0)
             d["tasks"].append(episode_tasks)
             d["length"].append(lengths[ep_idx])

             if video_keys is not None:
                 for video_key in video_keys:
                     d[f"{video_key}/chunk_index"].append(0)
                     d[f"{video_key}/file_index"].append(0)

+            # Add stats columns like "stats/action/max"
+            for stats_key, stats in flatten_dict({"stats": stats_factory(features)}).items():
+                d[stats_key].append(stats)
+
         return Dataset.from_dict(d)

     return _create_episodes

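The per-episode lengths above come from a single multinomial draw, so they always sum exactly to total_frames. For example:

import numpy as np

total_frames, total_episodes = 400, 3
lengths = np.random.multinomial(total_frames, [1 / total_episodes] * total_episodes).tolist()
assert sum(lengths) == total_frames  # e.g. [137, 128, 135]; a draw can be 0 for tiny totals
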
@@ -261,15 +277,15 @@ def episodes_factory(tasks_factory):
 def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_array_factory):
     def _create_hf_dataset(
         features: dict | None = None,
-        tasks: list[dict] | None = None,
-        episodes: list[dict] | None = None,
+        tasks: pd.DataFrame | None = None,
+        episodes: datasets.Dataset | None = None,
         fps: int = DEFAULT_FPS,
     ) -> datasets.Dataset:
-        if not tasks:
+        if tasks is None:
             tasks = tasks_factory()
-        if not episodes:
+        if episodes is None:
             episodes = episodes_factory()
-        if not features:
+        if features is None:
             features = features_factory()

         timestamp_col = np.array([], dtype=np.float32)

@@ -282,6 +298,8 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_array_factory):
             episode_index_col = np.concatenate(
                 (episode_index_col, np.full(ep_dict["length"], ep_dict["episode_index"], dtype=int))
             )
+            # Slightly incorrect, but for simplicity, we assign to all frames the first task defined in the episode metadata.
+            # TODO(rcadene): assign the tasks of the episode per chunks of frames
             ep_task_index = get_task_index(tasks, ep_dict["tasks"][0])
             task_index = np.concatenate((task_index, np.full(ep_dict["length"], ep_task_index, dtype=int)))

@@ -319,7 +337,6 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_array_factory):
 def lerobot_dataset_metadata_factory(
     info_factory,
     stats_factory,
-    episodes_stats_factory,
     tasks_factory,
     episodes_factory,
     mock_snapshot_download_factory,

@@ -329,30 +346,28 @@ def lerobot_dataset_metadata_factory(
         repo_id: str = DUMMY_REPO_ID,
         info: dict | None = None,
         stats: dict | None = None,
-        episodes_stats: list[dict] | None = None,
-        tasks: list[dict] | None = None,
-        episodes: list[dict] | None = None,
+        tasks: pd.DataFrame | None = None,
+        episodes: datasets.Dataset | None = None,
     ) -> LeRobotDatasetMetadata:
-        if not info:
+        if info is None:
             info = info_factory()
-        if not stats:
+        if stats is None:
             stats = stats_factory(features=info["features"])
-        if not episodes_stats:
-            episodes_stats = episodes_stats_factory(
-                features=info["features"], total_episodes=info["total_episodes"]
-            )
-        if not tasks:
+        if tasks is None:
             tasks = tasks_factory(total_tasks=info["total_tasks"])
-        if not episodes:
+        if episodes is None:
             video_keys = [key for key, ft in info["features"].items() if ft["dtype"] == "video"]
             episodes = episodes_factory(
-                total_episodes=info["total_episodes"], total_frames=info["total_frames"], video_keys=video_keys, tasks=tasks
+                features=info["features"],
+                total_episodes=info["total_episodes"],
+                total_frames=info["total_frames"],
+                video_keys=video_keys,
+                tasks=tasks,
             )

         mock_snapshot_download = mock_snapshot_download_factory(
             info=info,
             stats=stats,
-            episodes_stats=episodes_stats,
             tasks=tasks,
             episodes=episodes,
         )

@@ -374,7 +389,6 @@ def lerobot_dataset_metadata_factory(
 def lerobot_dataset_factory(
     info_factory,
     stats_factory,
-    episodes_stats_factory,
     tasks_factory,
     episodes_factory,
     hf_dataset_factory,

@@ -390,25 +404,23 @@ def lerobot_dataset_factory(
         multi_task: bool = False,
         info: dict | None = None,
         stats: dict | None = None,
-        episodes_stats: datasets.Dataset | None = None,
-        tasks: datasets.Dataset | None = None,
-        episode_dicts: datasets.Dataset | None = None,
+        tasks: pd.DataFrame | None = None,
+        episodes_metadata: datasets.Dataset | None = None,
         hf_dataset: datasets.Dataset | None = None,
         **kwargs,
     ) -> LeRobotDataset:
-        if not info:
+        if info is None:
             info = info_factory(
                 total_episodes=total_episodes, total_frames=total_frames, total_tasks=total_tasks
             )
-        if not stats:
+        if stats is None:
             stats = stats_factory(features=info["features"])
-        if not episodes_stats:
-            episodes_stats = episodes_stats_factory(features=info["features"], total_episodes=total_episodes)
-        if not tasks:
+        if tasks is None:
             tasks = tasks_factory(total_tasks=info["total_tasks"])
-        if not episode_dicts:
+        if episodes_metadata is None:
             video_keys = [key for key, ft in info["features"].items() if ft["dtype"] == "video"]
-            episode_dicts = episodes_factory(
+            episodes_metadata = episodes_factory(
+                features=info["features"],
                 total_episodes=info["total_episodes"],
                 total_frames=info["total_frames"],
                 video_keys=video_keys,

@@ -416,14 +428,13 @@ def lerobot_dataset_factory(
                 multi_task=multi_task,
             )
         if not hf_dataset:
-            hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episode_dicts, fps=info["fps"])
+            hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episodes_metadata, fps=info["fps"])

         mock_snapshot_download = mock_snapshot_download_factory(
             info=info,
             stats=stats,
-            episodes_stats=episodes_stats,
             tasks=tasks,
-            episodes=episode_dicts,
+            episodes=episodes_metadata,
             hf_dataset=hf_dataset,
         )
         mock_metadata = lerobot_dataset_metadata_factory(

@@ -431,9 +442,8 @@ def lerobot_dataset_factory(
             repo_id=repo_id,
             info=info,
             stats=stats,
-            episodes_stats=episodes_stats,
             tasks=tasks,
-            episodes=episode_dicts,
+            episodes=episodes_metadata,
         )
         with (
             patch("lerobot.common.datasets.lerobot_dataset.LeRobotDatasetMetadata") as mock_metadata_patch,

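The factory returned here builds a LeRobotDataset while the metadata class is patched out and replaced with the pre-built mock. A self-contained sketch of that unittest.mock.patch pattern, using toy stand-ins for the real classes:

from unittest.mock import patch

class Metadata:
    def __init__(self, repo_id):
        self.repo_id = repo_id

def load(repo_id):
    # Looks Metadata up in this module, so patching the module attribute redirects it.
    return Metadata(repo_id)

with patch(f"{__name__}.Metadata") as metadata_cls:
    metadata_cls.return_value = "fake-metadata"  # stand-in for the factory's mock_metadata
    assert load("lerobot/test") == "fake-metadata"
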
@@ -1,17 +1,13 @@
-import json
 from pathlib import Path

 import datasets
-import jsonlines
+import pandas as pd
 import pyarrow.compute as pc
 import pyarrow.parquet as pq
 import pytest

-from datasets import Dataset
-
 from lerobot.common.datasets.utils import (
     write_episodes,
-    write_episodes_stats,
     write_hf_dataset,
     write_info,
     write_stats,

@@ -22,7 +18,7 @@ from lerobot.common.datasets.utils import (
 @pytest.fixture(scope="session")
 def create_info(info_factory):
     def _create_info(dir: Path, info: dict | None = None):
-        if not info:
+        if info is None:
             info = info_factory()
         write_info(info, dir)

@@ -32,27 +28,27 @@ def create_info(info_factory):
 @pytest.fixture(scope="session")
 def create_stats(stats_factory):
     def _create_stats(dir: Path, stats: dict | None = None):
-        if not stats:
+        if stats is None:
             stats = stats_factory()
         write_stats(stats, dir)

     return _create_stats


-@pytest.fixture(scope="session")
-def create_episodes_stats(episodes_stats_factory):
-    def _create_episodes_stats(dir: Path, episodes_stats: Dataset | None = None):
-        if not episodes_stats:
-            episodes_stats = episodes_stats_factory()
-        write_episodes_stats(episodes_stats, dir)
-
-    return _create_episodes_stats
+# @pytest.fixture(scope="session")
+# def create_episodes_stats(episodes_stats_factory):
+#     def _create_episodes_stats(dir: Path, episodes_stats: Dataset | None = None):
+#         if episodes_stats is None:
+#             episodes_stats = episodes_stats_factory()
+#         write_episodes_stats(episodes_stats, dir)

+#     return _create_episodes_stats


 @pytest.fixture(scope="session")
 def create_tasks(tasks_factory):
-    def _create_tasks(dir: Path, tasks: Dataset | None = None):
-        if not tasks:
+    def _create_tasks(dir: Path, tasks: pd.DataFrame | None = None):
+        if tasks is None:
             tasks = tasks_factory()
         write_tasks(tasks, dir)

@@ -61,17 +57,18 @@ def create_tasks(tasks_factory):

 @pytest.fixture(scope="session")
 def create_episodes(episodes_factory):
-    def _create_episodes(dir: Path, episodes: Dataset | None = None):
-        if not episodes:
+    def _create_episodes(dir: Path, episodes: datasets.Dataset | None = None):
+        if episodes is None:
             episodes = episodes_factory()
         write_episodes(episodes, dir)

     return _create_episodes


 @pytest.fixture(scope="session")
 def create_hf_dataset(hf_dataset_factory):
-    def _create_hf_dataset(dir: Path, hf_dataset: Dataset | None = None):
-        if not hf_dataset:
+    def _create_hf_dataset(dir: Path, hf_dataset: datasets.Dataset | None = None):
+        if hf_dataset is None:
             hf_dataset = hf_dataset_factory()
         write_hf_dataset(hf_dataset, dir)

@@ -84,7 +81,7 @@ def single_episode_parquet_path(hf_dataset_factory, info_factory):
         dir: Path, ep_idx: int = 0, hf_dataset: datasets.Dataset | None = None, info: dict | None = None
     ) -> Path:
         raise NotImplementedError()
-        if not info:
+        if info is None:
             info = info_factory()
         if hf_dataset is None:
             hf_dataset = hf_dataset_factory()

@@ -108,7 +105,7 @@ def multi_episode_parquet_path(hf_dataset_factory, info_factory):
         dir: Path, hf_dataset: datasets.Dataset | None = None, info: dict | None = None
     ) -> Path:
         raise NotImplementedError()
-        if not info:
+        if info is None:
             info = info_factory()
         if hf_dataset is None:
             hf_dataset = hf_dataset_factory()

@@ -1,13 +1,13 @@
 from pathlib import Path

 import datasets
+import pandas as pd
 import pytest
 from huggingface_hub.utils import filter_repo_objects

 from lerobot.common.datasets.utils import (
     DEFAULT_DATA_PATH,
     DEFAULT_EPISODES_PATH,
-    DEFAULT_EPISODES_STATS_PATH,
     DEFAULT_TASKS_PATH,
     INFO_PATH,
     LEGACY_STATS_PATH,

@@ -21,8 +21,6 @@ def mock_snapshot_download_factory(
     create_info,
     stats_factory,
     create_stats,
-    episodes_stats_factory,
-    create_episodes_stats,
     tasks_factory,
     create_tasks,
     episodes_factory,

@@ -38,46 +36,43 @@ def mock_snapshot_download_factory(
     def _mock_snapshot_download_func(
         info: dict | None = None,
         stats: dict | None = None,
-        episodes_stats: datasets.Dataset | None = None,
-        tasks: datasets.Dataset | None = None,
+        tasks: pd.DataFrame | None = None,
         episodes: datasets.Dataset | None = None,
         hf_dataset: datasets.Dataset | None = None,
     ):
-        if not info:
+        if info is None:
             info = info_factory()
-        if not stats:
+        if stats is None:
             stats = stats_factory(features=info["features"])
-        if not episodes_stats:
-            episodes_stats = episodes_stats_factory(
-                features=info["features"], total_episodes=info["total_episodes"]
-            )
-        if not tasks:
+        if tasks is None:
             tasks = tasks_factory(total_tasks=info["total_tasks"])
-        if not episodes:
+        if episodes is None:
             episodes = episodes_factory(
-                total_episodes=info["total_episodes"], total_frames=info["total_frames"], tasks=tasks
+                features=info["features"],
+                total_episodes=info["total_episodes"],
+                total_frames=info["total_frames"],
+                tasks=tasks,
             )
-        if not hf_dataset:
+        if hf_dataset is None:
             hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episodes, fps=info["fps"])

         def _mock_snapshot_download(
-            repo_id: str,
+            repo_id: str,  # TODO(rcadene): repo_id should be used no?
             local_dir: str | Path | None = None,
             allow_patterns: str | list[str] | None = None,
             ignore_patterns: str | list[str] | None = None,
             *args,
             **kwargs,
         ) -> str:
-            if not local_dir:
+            if local_dir is None:
                 local_dir = LEROBOT_TEST_DIR

             # List all possible files
             all_files = [
                 INFO_PATH,
                 LEGACY_STATS_PATH,
-                # TODO(rcadene)
+                # TODO(rcadene): remove naive chunk 0 file 0 ?
                 DEFAULT_TASKS_PATH.format(chunk_index=0, file_index=0),
-                DEFAULT_EPISODES_STATS_PATH.format(chunk_index=0, file_index=0),
                 DEFAULT_EPISODES_PATH.format(chunk_index=0, file_index=0),
                 DEFAULT_DATA_PATH.format(chunk_index=0, file_index=0),
             ]

@@ -89,7 +84,6 @@ def mock_snapshot_download_factory(
             has_info = False
             has_tasks = False
             has_episodes = False
-            has_episodes_stats = False
             has_stats = False
             has_data = False
             for rel_path in allowed_files:

@@ -99,8 +93,6 @@ def mock_snapshot_download_factory(
                     has_stats = True
                 elif rel_path.startswith("meta/tasks"):
                     has_tasks = True
-                elif rel_path.startswith("meta/episodes_stats"):
-                    has_episodes_stats = True
                 elif rel_path.startswith("meta/episodes"):
                     has_episodes = True
                 elif rel_path.startswith("data/"):

@@ -116,8 +108,6 @@ def mock_snapshot_download_factory(
                 create_tasks(local_dir, tasks)
             if has_episodes:
                 create_episodes(local_dir, episodes)
-            if has_episodes_stats:
-                create_episodes_stats(local_dir, episodes_stats)
             if has_data:
                 create_hf_dataset(local_dir, hf_dataset)

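The mock lists every file the fake repo could contain and relies on huggingface_hub's filter_repo_objects to keep only what the caller's allow/ignore patterns select; just those files are then written to local_dir. A rough sketch of that selection step (paths illustrative, not the exact templates used above):

from huggingface_hub.utils import filter_repo_objects

all_files = [
    "meta/info.json",
    "meta/tasks/chunk-000/file-000.parquet",
    "data/chunk-000/file-000.parquet",
]
# Keep only metadata files, as a snapshot_download(allow_patterns=...) call would.
allowed = list(filter_repo_objects(all_files, allow_patterns=["meta/*"], ignore_patterns=None))
print(allowed)  # the two "meta/..." entries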