Commit before episodes/episodes_stats merging

parent 53ecec5fb2
commit c1b28f0b58
@@ -26,7 +26,7 @@ from datatrove.pipeline.base import PipelineStep
from examples.port_datasets.droid_rlds.port_droid import DROID_SHARDS
from lerobot.common.datasets.aggregate import validate_all_metadata
from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.common.datasets.utils import write_episode, write_episode_stats, write_info, write_task
from lerobot.common.datasets.utils import write_episode, legacy_write_episode_stats, write_info, legacy_write_task
from lerobot.common.utils.utils import init_logging

@@ -124,11 +124,11 @@ class AggregateDatasets(PipelineStep):
        for episode_index, episode_stats in tqdm.tqdm(
            aggr_meta.episodes_stats.items(), desc="Write episodes stats"
        ):
            write_episode_stats(episode_index, episode_stats, aggr_meta.root)
            legacy_write_episode_stats(episode_index, episode_stats, aggr_meta.root)

        # create a new task jsonl with updated episode_index using write_task
        for task_index, task in tqdm.tqdm(aggr_meta.tasks.items(), desc="Write tasks"):
            write_task(task_index, task, aggr_meta.root)
            legacy_write_task(task_index, task, aggr_meta.root)

        write_info(aggr_meta.info, aggr_meta.root)

@@ -5,7 +5,7 @@ import pandas as pd
import tqdm

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.common.datasets.utils import write_episode, write_episode_stats, write_info, write_task
from lerobot.common.datasets.utils import write_episode, legacy_write_episode_stats, write_info, legacy_write_task
from lerobot.common.utils.utils import init_logging

@@ -136,11 +136,11 @@ def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, aggr_root=None):

    # create a new episode_stats jsonl with updated episode_index using write_episode_stats
    for episode_index, episode_stats in aggr_meta.episodes_stats.items():
        write_episode_stats(episode_index, episode_stats, aggr_meta.root)
        legacy_write_episode_stats(episode_index, episode_stats, aggr_meta.root)

    # create a new task jsonl with updated episode_index using write_task
    for task_index, task in aggr_meta.tasks.items():
        write_task(task_index, task, aggr_meta.root)
        legacy_write_task(task_index, task, aggr_meta.root)

    write_info(aggr_meta.info, aggr_meta.root)

@@ -33,6 +33,18 @@ If you encounter a problem, contact LeRobot maintainers on [Discord](https://dis
or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
"""

V30_MESSAGE = """
The dataset you requested ({repo_id}) is in {version} format.
While the current version of LeRobot is backward compatible with it, your dataset still uses global
stats instead of per-episode stats. Update your dataset stats to the new format using this command:
```
python lerobot/common/datasets/v30/convert_dataset_v21_to_v30.py --repo-id={repo_id}
```

If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb)
or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
"""

FUTURE_MESSAGE = """
The dataset you requested ({repo_id}) is only available in {version} format.
As we cannot ensure forward compatibility with it, please update your current version of lerobot.
@@ -44,7 +56,12 @@ class CompatibilityError(Exception): ...

class BackwardCompatibilityError(CompatibilityError):
    def __init__(self, repo_id: str, version: packaging.version.Version):
        message = V2_MESSAGE.format(repo_id=repo_id, version=version)
        if version.major == 3:
            message = V30_MESSAGE.format(repo_id=repo_id, version=version)
        elif version.major == 2:
            message = V2_MESSAGE.format(repo_id=repo_id, version=version)
        else:
            raise NotImplementedError("Contact the maintainer on [Discord](https://discord.com/invite/s3KuuzsPFb).")
        super().__init__(message)

@ -17,15 +17,16 @@ import contextlib
|
|||
import logging
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
import tempfile
|
||||
from typing import Callable
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
import packaging.version
|
||||
import PIL.Image
|
||||
import pandas as pd
|
||||
import torch
|
||||
import torch.utils
|
||||
from datasets import concatenate_datasets, load_dataset
|
||||
from datasets import concatenate_datasets, load_dataset, Dataset
|
||||
from huggingface_hub import HfApi, snapshot_download
|
||||
from huggingface_hub.constants import REPOCARD_NAME
|
||||
from huggingface_hub.errors import RevisionNotFoundError
|
||||
|
@ -34,37 +35,57 @@ from lerobot.common.constants import HF_LEROBOT_HOME
|
|||
from lerobot.common.datasets.compute_stats import aggregate_stats, compute_episode_stats
|
||||
from lerobot.common.datasets.image_writer import AsyncImageWriter, write_image
|
||||
from lerobot.common.datasets.utils import (
|
||||
DEFAULT_EPISODES_PATH,
|
||||
DEFAULT_EPISODES_STATS_PATH,
|
||||
DEFAULT_FEATURES,
|
||||
DEFAULT_IMAGE_PATH,
|
||||
EPISODES_DIR,
|
||||
EPISODES_STATS_DIR,
|
||||
INFO_PATH,
|
||||
TASKS_PATH,
|
||||
LEGACY_TASKS_PATH,
|
||||
append_jsonlines,
|
||||
backward_compatible_episodes_stats,
|
||||
check_delta_timestamps,
|
||||
check_timestamps_sync,
|
||||
check_version_compatibility,
|
||||
concat_video_files,
|
||||
create_empty_dataset_info,
|
||||
create_lerobot_dataset_card,
|
||||
embed_images,
|
||||
get_chunk_file_indices,
|
||||
get_delta_indices,
|
||||
get_episode_data_index,
|
||||
get_features_from_robot,
|
||||
get_hf_dataset_size_in_mb,
|
||||
get_hf_features_from_features,
|
||||
get_latest_parquet_path,
|
||||
get_latest_video_path,
|
||||
get_parquet_num_frames,
|
||||
get_pd_dataframe_size_in_mb,
|
||||
get_safe_version,
|
||||
get_video_duration_in_s,
|
||||
get_video_size_in_mb,
|
||||
hf_transform_to_torch,
|
||||
is_valid_version,
|
||||
legacy_load_episodes,
|
||||
legacy_load_episodes_stats,
|
||||
load_episodes,
|
||||
load_episodes_stats,
|
||||
load_info,
|
||||
load_nested_dataset,
|
||||
load_stats,
|
||||
legacy_load_tasks,
|
||||
load_tasks,
|
||||
update_chunk_file_indices,
|
||||
validate_episode_buffer,
|
||||
validate_frame,
|
||||
write_episode,
|
||||
write_episode_stats,
|
||||
legacy_write_episode_stats,
|
||||
write_info,
|
||||
write_json,
|
||||
write_tasks,
|
||||
)
|
||||
from lerobot.common.datasets.v30.convert_dataset_v21_to_v30 import get_parquet_file_size_in_mb
|
||||
from lerobot.common.datasets.video_utils import (
|
||||
VideoFrame,
|
||||
decode_video_frames_torchvision,
|
||||
|
@ -105,12 +126,9 @@ class LeRobotDatasetMetadata:
|
|||
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
|
||||
self.tasks, self.task_to_task_index = load_tasks(self.root)
|
||||
self.episodes = load_episodes(self.root)
|
||||
if self._version < packaging.version.parse("v2.1"):
|
||||
self.stats = load_stats(self.root)
|
||||
self.episodes_stats = backward_compatible_episodes_stats(self.stats, self.episodes)
|
||||
else:
|
||||
self.episodes_stats = load_episodes_stats(self.root)
|
||||
self.stats = aggregate_stats(list(self.episodes_stats.values()))
|
||||
self.episodes_stats = load_episodes_stats(self.root)
|
||||
# TODO(rcadene): https://huggingface.slack.com/archives/C02V51Q3800/p1743517952388249?thread_ts=1742896075.499119&cid=C02V51Q3800
|
||||
# self.stats = aggregate_stats(list(self.episodes_stats.values()))
|
||||
|
||||
def pull_from_repo(
|
||||
self,
|
||||
|
@ -132,17 +150,19 @@ class LeRobotDatasetMetadata:
|
|||
return packaging.version.parse(self.info["codebase_version"])
|
||||
|
||||
def get_data_file_path(self, ep_index: int) -> Path:
|
||||
ep_chunk = self.get_episode_chunk(ep_index)
|
||||
fpath = self.data_path.format(episode_chunk=ep_chunk, episode_index=ep_index)
|
||||
chunk_idx = self.episodes[f"data/chunk_index"][ep_index]
|
||||
file_idx = self.episodes[f"data/file_index"][ep_index]
|
||||
fpath = self.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
|
||||
return Path(fpath)
|
||||
|
||||
def get_video_file_path(self, ep_index: int, vid_key: str) -> Path:
|
||||
ep_chunk = self.get_episode_chunk(ep_index)
|
||||
fpath = self.video_path.format(episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index)
|
||||
chunk_idx = self.episodes[f"{vid_key}/chunk_index"][ep_index]
|
||||
file_idx = self.episodes[f"{vid_key}/file_index"][ep_index]
|
||||
fpath = self.video_path.format(video_key=vid_key, chunk_index=chunk_idx, file_index=file_idx)
|
||||
return Path(fpath)
|
||||
|
||||
def get_episode_chunk(self, ep_index: int) -> int:
|
||||
return ep_index // self.chunks_size
|
||||
# def get_episode_chunk(self, ep_index: int) -> int:
|
||||
# return ep_index // self.chunks_size
|
||||
|
||||
@property
|
||||
def data_path(self) -> str:
|
||||
|
@@ -209,40 +229,85 @@ class LeRobotDatasetMetadata:
        """Total number of different tasks performed in this dataset."""
        return self.info["total_tasks"]

    @property
    def total_chunks(self) -> int:
        """Total number of chunks (groups of episodes)."""
        return self.info["total_chunks"]

    @property
    def chunks_size(self) -> int:
        """Max number of episodes per chunk."""
        """Max number of files per chunk."""
        return self.info["chunks_size"]

    @property
    def files_size_in_mb(self) -> int:
        """Max size of file in mega bytes."""
        return self.info["files_size_in_mb"]

    def get_task_index(self, task: str) -> int | None:
        """
        Given a task in natural language, returns its task_index if the task already exists in the dataset,
        otherwise return None.
        """
        return self.task_to_task_index.get(task, None)
        return self.tasks.index[task] if task in self.tasks.index else None

    def has_task(self, task: str) -> bool:
        return task in self.task_to_task_index

    def add_task(self, task: str):
        """
        Given a task in natural language, add it to the dictionary of tasks.
        """
        if task in self.task_to_task_index:
            raise ValueError(f"The task '{task}' already exists and can't be added twice.")
    def save_episode_tasks(self, tasks: list[str]):
        new_tasks = [task for task in tasks if not self.has_task(task)]

        task_index = self.info["total_tasks"]
        self.task_to_task_index[task] = task_index
        self.tasks[task_index] = task
        self.info["total_tasks"] += 1
        for task in new_tasks:
            task_index = len(self.tasks)
            self.tasks.loc[task] = task_index

        task_dict = {
            "task_index": task_index,
            "task": task,
        }
        append_jsonlines(task_dict, self.root / TASKS_PATH)
        if len(new_tasks) > 0:
            # Update on disk
            write_tasks(self.tasks, self.root)

    def _save_episode(self, episode_dict: dict):
        ep_dataset = Dataset.from_dict(episode_dict)
        ep_size_in_mb = get_hf_dataset_size_in_mb(ep_dataset)
        ep_df = pd.DataFrame(ep_dataset)

        # Access latest parquet file information
        latest_path = get_latest_parquet_path(self.root / EPISODES_DIR)
        latest_size_in_mb = get_parquet_file_size_in_mb(latest_path)
        chunk_idx, file_idx = get_chunk_file_indices(latest_path)

        if latest_size_in_mb + ep_size_in_mb >= self.files_size_in_mb:
            # Create new parquet file
            chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.chunks_size)
            new_path = self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
            new_path.parent.mkdir(parents=True, exist_ok=True)
            ep_df.to_parquet(new_path, index=False)
        else:
            # Update latest parquet file with new row
            latest_df = pd.read_parquet(latest_path)  # load current contents so the new row can be appended
            latest_df = pd.concat([latest_df, ep_df], ignore_index=True)  # RAM
            latest_df.to_parquet(latest_path, index=False)

        # Update the Hugging Face dataset by reloading it.
        # This process should be fast because only the latest Parquet file has been modified.
        # Therefore, only this file needs to be converted to PyArrow; the rest is loaded from the PyArrow memory-mapped cache.
        self.episodes = load_episodes(self.root)

    def _save_episode_stats(self, episodes_stats: dict):
        ep_dataset = Dataset.from_dict(episodes_stats)
        ep_size_in_mb = get_hf_dataset_size_in_mb(ep_dataset)
        ep_df = pd.DataFrame(ep_dataset)

        # Access latest parquet file information
        latest_path = get_latest_parquet_path(self.root / EPISODES_STATS_DIR)
        latest_size_in_mb = get_parquet_file_size_in_mb(latest_path)
        chunk_idx, file_idx = get_chunk_file_indices(latest_path)

        if latest_size_in_mb + ep_size_in_mb >= self.files_size_in_mb:
            # Create new parquet file
            chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.chunks_size)
            new_path = self.root / DEFAULT_EPISODES_STATS_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
            new_path.parent.mkdir(parents=True, exist_ok=True)
            ep_df.to_parquet(new_path, index=False)
        else:
            # Update latest parquet file with new row
            latest_df = pd.read_parquet(latest_path)  # load current contents so the new row can be appended
            latest_df = pd.concat([latest_df, ep_df], ignore_index=True)  # RAM
            latest_df.to_parquet(latest_path, index=False)

        self.episodes_stats = load_episodes_stats(self.root)

    def save_episode(
        self,
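Both `_save_episode` and `_save_episode_stats` above follow the same append-or-rollover pattern: new rows are appended to the latest `chunk-XXX/file-XXX.parquet` until adding them would exceed `files_size_in_mb`, at which point the indices roll over to a fresh file. Below is a minimal standalone sketch of that bookkeeping; the helper name, paths and sizes are made up for illustration and this is not the library code.
```
from pathlib import Path

import pandas as pd


def next_indices(chunk_idx: int, file_idx: int, chunks_size: int) -> tuple[int, int]:
    # Same rollover rule as update_chunk_file_indices: start a new chunk
    # once the current one already holds `chunks_size` files.
    if file_idx == chunks_size - 1:
        return chunk_idx + 1, 0
    return chunk_idx, file_idx + 1


def append_or_rollover(
    root: Path,
    new_rows: pd.DataFrame,
    latest_path: Path,
    latest_size_in_mb: float,
    new_size_in_mb: float,
    files_size_in_mb: float,
    chunks_size: int,
) -> tuple[int, int]:
    """Append `new_rows` to the latest parquet file, or start a new chunk/file."""
    chunk_idx = int(latest_path.parent.name.removeprefix("chunk-"))
    file_idx = int(latest_path.stem.removeprefix("file-"))

    if latest_size_in_mb + new_size_in_mb >= files_size_in_mb:
        # Roll over: write the rows into a brand-new file.
        chunk_idx, file_idx = next_indices(chunk_idx, file_idx, chunks_size)
        new_path = root / f"chunk-{chunk_idx:03d}" / f"file-{file_idx:03d}.parquet"
        new_path.parent.mkdir(parents=True, exist_ok=True)
        new_rows.to_parquet(new_path, index=False)
    else:
        # Append: re-read the latest file and concatenate in memory.
        latest_df = pd.concat([pd.read_parquet(latest_path), new_rows], ignore_index=True)
        latest_df.to_parquet(latest_path, index=False)
    return chunk_idx, file_idx
```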
@ -250,19 +315,14 @@ class LeRobotDatasetMetadata:
|
|||
episode_length: int,
|
||||
episode_tasks: list[str],
|
||||
episode_stats: dict[str, dict],
|
||||
episode_metadata: dict,
|
||||
) -> None:
|
||||
# Update info
|
||||
self.info["total_episodes"] += 1
|
||||
self.info["total_frames"] += episode_length
|
||||
|
||||
chunk = self.get_episode_chunk(episode_index)
|
||||
if chunk >= self.total_chunks:
|
||||
self.info["total_chunks"] += 1
|
||||
|
||||
self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"}
|
||||
self.info["total_videos"] += len(self.video_keys)
|
||||
if len(self.video_keys) > 0:
|
||||
self.update_video_info()
|
||||
|
||||
write_info(self.info, self.root)
|
||||
|
||||
episode_dict = {
|
||||
|
@ -270,12 +330,12 @@ class LeRobotDatasetMetadata:
|
|||
"tasks": episode_tasks,
|
||||
"length": episode_length,
|
||||
}
|
||||
self.episodes[episode_index] = episode_dict
|
||||
write_episode(episode_dict, self.root)
|
||||
episode_dict.update(episode_metadata)
|
||||
self._save_episode(episode_dict)
|
||||
self._save_episode_stats(episode_stats)
|
||||
|
||||
self.episodes_stats[episode_index] = episode_stats
|
||||
self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats else episode_stats
|
||||
write_episode_stats(episode_index, episode_stats, self.root)
|
||||
# TODO: write stats
|
||||
|
||||
def update_video_info(self) -> None:
|
||||
"""
|
||||
|
@ -340,8 +400,11 @@ class LeRobotDatasetMetadata:
|
|||
|
||||
features = {**features, **DEFAULT_FEATURES}
|
||||
|
||||
obj.tasks, obj.task_to_task_index = {}, {}
|
||||
obj.episodes_stats, obj.stats, obj.episodes = {}, {}, {}
|
||||
obj.tasks = None
|
||||
obj.episodes_stats = None
|
||||
obj.episodes = None
|
||||
# TODO(rcadene) stats
|
||||
obj.stats = {}
|
||||
obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, robot_type, features, use_videos)
|
||||
if len(obj.video_keys) > 0 and not use_videos:
|
||||
raise ValueError()
|
||||
|
@ -486,29 +549,17 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
self.meta = LeRobotDatasetMetadata(
|
||||
self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync
|
||||
)
|
||||
if self.episodes is not None and self.meta._version >= packaging.version.parse("v2.1"):
|
||||
episodes_stats = [self.meta.episodes_stats[ep_idx] for ep_idx in self.episodes]
|
||||
self.stats = aggregate_stats(episodes_stats)
|
||||
|
||||
# Load actual data
|
||||
try:
|
||||
if force_cache_sync:
|
||||
raise FileNotFoundError
|
||||
assert all((self.root / fpath).is_file() for fpath in self.get_episodes_file_paths())
|
||||
self.hf_dataset = self.load_hf_dataset()
|
||||
except (AssertionError, FileNotFoundError, NotADirectoryError):
|
||||
self.revision = get_safe_version(self.repo_id, self.revision)
|
||||
self.download_episodes(download_videos)
|
||||
self.hf_dataset = self.load_hf_dataset()
|
||||
|
||||
self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
|
||||
|
||||
# Check timestamps
|
||||
timestamps = torch.stack(self.hf_dataset["timestamp"]).numpy()
|
||||
episode_indices = torch.stack(self.hf_dataset["episode_index"]).numpy()
|
||||
ep_data_index_np = {k: t.numpy() for k, t in self.episode_data_index.items()}
|
||||
check_timestamps_sync(timestamps, episode_indices, ep_data_index_np, self.fps, self.tolerance_s)
|
||||
|
||||
# Setup delta_indices
|
||||
if self.delta_timestamps is not None:
|
||||
check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
|
||||
|
@ -592,11 +643,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
"""
|
||||
# TODO(rcadene, aliberts): implement faster transfer
|
||||
# https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
|
||||
files = None
|
||||
ignore_patterns = None if download_videos else "videos/"
|
||||
files = None
|
||||
if self.episodes is not None:
|
||||
files = self.get_episodes_file_paths()
|
||||
|
||||
self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns)
|
||||
|
||||
def get_episodes_file_paths(self) -> list[Path]:
|
||||
|
@ -609,31 +659,21 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
for ep_idx in episodes
|
||||
]
|
||||
fpaths += video_files
|
||||
|
||||
# episodes are stored in the same files, so we return unique paths only
|
||||
fpaths = list(set(fpaths))
|
||||
return fpaths
|
||||
|
||||
def load_hf_dataset(self) -> datasets.Dataset:
|
||||
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
|
||||
if self.episodes is None:
|
||||
path = str(self.root / "data")
|
||||
# TODO(rcadene): load_dataset convert parquet to arrow.
|
||||
# set num_proc to accelerate this conversion
|
||||
hf_dataset = load_dataset("parquet", data_dir=path, split="train")
|
||||
else:
|
||||
files = [str(self.root / self.meta.get_data_file_path(ep_idx)) for ep_idx in self.episodes]
|
||||
hf_dataset = load_dataset("parquet", data_files=files, split="train")
|
||||
|
||||
# TODO(aliberts): hf_dataset.set_format("torch")
|
||||
hf_dataset.set_transform(hf_transform_to_torch)
|
||||
hf_dataset = load_nested_dataset(self.root / "data")
|
||||
hf_dataset.set_format("torch")
|
||||
return hf_dataset
|
||||
|
||||
def create_hf_dataset(self) -> datasets.Dataset:
|
||||
features = get_hf_features_from_features(self.features)
|
||||
ft_dict = {col: [] for col in features}
|
||||
hf_dataset = datasets.Dataset.from_dict(ft_dict, features=features, split="train")
|
||||
|
||||
# TODO(aliberts): hf_dataset.set_format("torch")
|
||||
hf_dataset.set_transform(hf_transform_to_torch)
|
||||
hf_dataset.set_format("torch")
|
||||
return hf_dataset
|
||||
|
||||
@property
|
||||
|
@ -664,15 +704,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
return get_hf_features_from_features(self.features)
|
||||
|
||||
def _get_query_indices(self, idx: int, ep_idx: int) -> tuple[dict[str, list[int | bool]]]:
|
||||
ep_start = self.episode_data_index["from"][ep_idx]
|
||||
ep_end = self.episode_data_index["to"][ep_idx]
|
||||
ep_start = self.meta.episodes["data/from_index"][ep_idx]
|
||||
ep_end = self.meta.episodes["data/to_index"][ep_idx]
|
||||
query_indices = {
|
||||
key: [max(ep_start.item(), min(ep_end.item() - 1, idx + delta)) for delta in delta_idx]
|
||||
key: [max(ep_start, min(ep_end - 1, idx + delta)) for delta in delta_idx]
|
||||
for key, delta_idx in self.delta_indices.items()
|
||||
}
|
||||
padding = { # Pad values outside of current episode range
|
||||
f"{key}_is_pad": torch.BoolTensor(
|
||||
[(idx + delta < ep_start.item()) | (idx + delta >= ep_end.item()) for delta in delta_idx]
|
||||
[(idx + delta < ep_start) | (idx + delta >= ep_end) for delta in delta_idx]
|
||||
)
|
||||
for key, delta_idx in self.delta_indices.items()
|
||||
}
|
||||
|
@ -687,7 +727,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
for key in self.meta.video_keys:
|
||||
if query_indices is not None and key in query_indices:
|
||||
timestamps = self.hf_dataset.select(query_indices[key])["timestamp"]
|
||||
query_timestamps[key] = torch.stack(timestamps).tolist()
|
||||
query_timestamps[key] = timestamps.tolist()
|
||||
else:
|
||||
query_timestamps[key] = [current_ts]
|
||||
|
||||
|
@ -695,7 +735,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
|
||||
def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict:
|
||||
return {
|
||||
key: torch.stack(self.hf_dataset.select(q_idx)[key])
|
||||
key: self.hf_dataset.select(q_idx)[key]
|
||||
for key, q_idx in query_indices.items()
|
||||
if key not in self.meta.video_keys
|
||||
}
|
||||
|
@@ -708,9 +748,15 @@
        """
        item = {}
        for vid_key, query_ts in query_timestamps.items():
            # Episodes are stored sequentially on a single mp4 to reduce the number of files.
            # Thus we load the start timestamp of the episode on this mp4 and
            # shift the query timestamps accordingly.
            from_timestamp = self.meta.episodes[f"{vid_key}/from_timestamp"][ep_idx]
            shifted_query_ts = [from_timestamp + ts for ts in query_ts]

            video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key)
            frames = decode_video_frames_torchvision(
                video_path, query_ts, self.tolerance_s, self.video_backend
                video_path, shifted_query_ts, self.tolerance_s, self.video_backend
            )
            item[vid_key] = frames.squeeze(0)

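Since several episodes share one mp4 in the new layout, a timestamp that is relative to the episode has to be offset by the episode's start time inside that file before decoding. A tiny worked example of the shift (the numbers are invented):
```
# Hypothetical case: the episode starts 12.5 s into the concatenated mp4
# and we want frames at 0.0 s, 0.5 s and 1.0 s of the episode.
from_timestamp = 12.5
query_ts = [0.0, 0.5, 1.0]

# Same shift as in _query_videos: timestamps become absolute positions in the file.
shifted_query_ts = [from_timestamp + ts for ts in query_ts]
assert shifted_query_ts == [12.5, 13.0, 13.5]
```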
@ -749,7 +795,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
|
||||
# Add task as a string
|
||||
task_idx = item["task_index"].item()
|
||||
item["task"] = self.meta.tasks[task_idx]
|
||||
if self.meta.tasks["task_index"][task_idx] != task_idx:
|
||||
raise ValueError("Sanity check on task index failed.")
|
||||
item["task"] = self.meta.tasks["task"][task_idx]
|
||||
|
||||
return item
|
||||
|
||||
|
@ -779,6 +827,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
image_key=image_key, episode_index=episode_index, frame_index=frame_index
|
||||
)
|
||||
return self.root / fpath
|
||||
|
||||
def _get_image_file_dir(self, episode_index: int, image_key: str) -> Path:
|
||||
return self._get_image_file_path(episode_index, image_key, frame_index=0).parent
|
||||
|
||||
def _save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path) -> None:
|
||||
if self.image_writer is None:
|
||||
|
@ -858,11 +909,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
episode_buffer["index"] = np.arange(self.meta.total_frames, self.meta.total_frames + episode_length)
|
||||
episode_buffer["episode_index"] = np.full((episode_length,), episode_index)
|
||||
|
||||
# Add new tasks to the tasks dictionary
|
||||
for task in episode_tasks:
|
||||
task_index = self.meta.get_task_index(task)
|
||||
if task_index is None:
|
||||
self.meta.add_task(task)
|
||||
# Update tasks and task indices with new tasks if any
|
||||
self.meta.save_episode_tasks(episode_tasks)
|
||||
|
||||
# Given tasks in natural language, find their corresponding task indices
|
||||
episode_buffer["task_index"] = np.array([self.meta.get_task_index(task) for task in tasks])
|
||||
|
@ -874,51 +922,107 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
continue
|
||||
episode_buffer[key] = np.stack(episode_buffer[key])
|
||||
|
||||
# Wait for image writer to end, so that episode stats over images can be computed
|
||||
self._wait_image_writer()
|
||||
self._save_episode_table(episode_buffer, episode_index)
|
||||
ep_stats = compute_episode_stats(episode_buffer, self.features)
|
||||
|
||||
if len(self.meta.video_keys) > 0:
|
||||
video_paths = self.encode_episode_videos(episode_index)
|
||||
for key in self.meta.video_keys:
|
||||
episode_buffer[key] = video_paths[key]
|
||||
ep_metadata = self._save_episode_data(episode_buffer, episode_index)
|
||||
for video_key in self.meta.video_keys:
|
||||
ep_metadata.update(self._save_episode_video(video_key, episode_index))
|
||||
|
||||
        # `meta.save_episode` needs to be executed after encoding the videos
        self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats, ep_metadata)

        # `meta.save_episode` needs to be executed after encoding the videos
        self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats)
|
||||
|
||||
ep_data_index = get_episode_data_index(self.meta.episodes, [episode_index])
|
||||
ep_data_index_np = {k: t.numpy() for k, t in ep_data_index.items()}
|
||||
check_timestamps_sync(
|
||||
episode_buffer["timestamp"],
|
||||
episode_buffer["episode_index"],
|
||||
ep_data_index_np,
|
||||
self.fps,
|
||||
self.tolerance_s,
|
||||
)
|
||||
|
||||
video_files = list(self.root.rglob("*.mp4"))
|
||||
assert len(video_files) == self.num_episodes * len(self.meta.video_keys)
|
||||
|
||||
parquet_files = list(self.root.rglob("*.parquet"))
|
||||
assert len(parquet_files) == self.num_episodes
|
||||
# TODO(rcadene): remove? there is only one episode in the episode buffer, no need for ep_data_index
|
||||
# ep_data_index = get_episode_data_index(self.meta.episodes, [episode_index])
|
||||
# ep_data_index_np = {k: t.numpy() for k, t in ep_data_index.items()}
|
||||
# check_timestamps_sync(
|
||||
# episode_buffer["timestamp"],
|
||||
# episode_buffer["episode_index"],
|
||||
# ep_data_index_np,
|
||||
# self.fps,
|
||||
# self.tolerance_s,
|
||||
# )
|
||||
|
||||
# TODO(rcadene): images are also deleted in clear_episode_buffer
|
||||
# delete images
|
||||
img_dir = self.root / "images"
|
||||
if img_dir.is_dir():
|
||||
shutil.rmtree(self.root / "images")
|
||||
|
||||
if not episode_data: # Reset the buffer
|
||||
if not episode_data:
|
||||
# Reset episode buffer
|
||||
self.episode_buffer = self.create_episode_buffer()
|
||||
|
||||
    def _save_episode_table(self, episode_buffer: dict, episode_index: int) -> None:
        episode_dict = {key: episode_buffer[key] for key in self.hf_features}
        ep_dataset = datasets.Dataset.from_dict(episode_dict, features=self.hf_features, split="train")
    def _save_episode_data(self, episode_buffer: dict) -> dict:
        # Convert buffer into HF Dataset
        ep_dict = {key: episode_buffer[key] for key in self.hf_features}
        ep_dataset = datasets.Dataset.from_dict(ep_dict, features=self.hf_features, split="train")
        ep_dataset = embed_images(ep_dataset)
        self.hf_dataset = concatenate_datasets([self.hf_dataset, ep_dataset])
        self.hf_dataset.set_transform(hf_transform_to_torch)
        ep_data_path = self.root / self.meta.get_data_file_path(ep_index=episode_index)
        ep_data_path.parent.mkdir(parents=True, exist_ok=True)
        ep_dataset.to_parquet(ep_data_path)
        ep_size_in_mb = get_hf_dataset_size_in_mb(ep_dataset)
        ep_num_frames = len(ep_dataset)
        ep_df = pd.DataFrame(ep_dataset)

        # Access latest parquet file information
        latest_path = get_latest_parquet_path(self.root / "data")
        latest_size_in_mb = get_parquet_file_size_in_mb(latest_path)
        latest_num_frames = get_parquet_num_frames(latest_path)
        chunk_idx, file_idx = get_chunk_file_indices(latest_path)

        if latest_size_in_mb + ep_size_in_mb >= self.meta.files_size_in_mb:
            # Create new parquet file
            chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size)
            new_path = self.root / self.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
            new_path.parent.mkdir(parents=True, exist_ok=True)
            ep_df.to_parquet(new_path, index=False)
        else:
            # Update latest parquet file with new rows
            latest_df = pd.read_parquet(latest_path)  # load current contents so the new rows can be appended
            latest_df = pd.concat([latest_df, ep_df], ignore_index=True)  # RAM
            latest_df.to_parquet(latest_path, index=False)

        # Update the Hugging Face dataset by reloading it.
        # This process should be fast because only the latest Parquet file has been modified.
        # Therefore, only this file needs to be converted to PyArrow; the rest is loaded from the PyArrow memory-mapped cache.
        self.hf_dataset = self.load_hf_dataset()

        metadata = {
            "data/chunk_index": chunk_idx,
            "data/file_index": file_idx,
            "data/from_index": latest_num_frames,
            "data/to_index": latest_num_frames + ep_num_frames,
        }
        return metadata
|
||||
|
||||
    def _save_episode_video(self, video_key: str, episode_index: int):
        # Encode episode frames into a temporary video
        ep_path = self._encode_temporary_episode_video(video_key, episode_index)
        ep_size_in_mb = get_video_size_in_mb(ep_path)
        ep_duration_in_s = get_video_duration_in_s(ep_path)

        # Access latest video file information
        latest_path = get_latest_video_path(self.root / "videos", video_key)
        latest_size_in_mb = get_video_size_in_mb(latest_path)
        latest_duration_in_s = get_video_duration_in_s(latest_path)
        chunk_idx, file_idx = get_chunk_file_indices(latest_path)

        if latest_size_in_mb + ep_size_in_mb >= self.meta.files_size_in_mb:
            # Move temporary episode video to a new video file in the dataset
            chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size)
            new_path = self.root / self.meta.video_path.format(video_key=video_key, chunk_index=chunk_idx, file_index=file_idx)
            new_path.parent.mkdir(parents=True, exist_ok=True)
            ep_path.replace(new_path)
        else:
            # Update latest video file
            concat_video_files([latest_path, ep_path], self.root, video_key, chunk_idx, file_idx)

        metadata = {
            "episode_index": episode_index,
            f"{video_key}/chunk_index": chunk_idx,
            f"{video_key}/file_index": file_idx,
            f"{video_key}/from_timestamp": latest_duration_in_s,
            f"{video_key}/to_timestamp": latest_duration_in_s + ep_duration_in_s,
        }
        return metadata
|
||||
|
||||
def clear_episode_buffer(self) -> None:
|
||||
episode_index = self.episode_buffer["episode_index"]
|
||||
|
@ -958,34 +1062,26 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
if self.image_writer is not None:
|
||||
self.image_writer.wait_until_done()
|
||||
|
||||
def encode_videos(self) -> None:
|
||||
# TODO(rcadene): this method is currently not used
|
||||
# def encode_videos(self) -> None:
|
||||
# """
|
||||
# Use ffmpeg to convert frames stored as png into mp4 videos.
|
||||
# Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
|
||||
# since video encoding with ffmpeg is already using multithreading.
|
||||
# """
|
||||
# for ep_idx in range(self.meta.total_episodes):
|
||||
# self.encode_episode_videos(ep_idx)
|
||||
|
||||
def _encode_temporary_episode_video(self, video_key: str, episode_index: int) -> dict:
|
||||
"""
|
||||
Use ffmpeg to convert frames stored as png into mp4 videos.
|
||||
Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
|
||||
since video encoding with ffmpeg is already using multithreading.
|
||||
"""
|
||||
for ep_idx in range(self.meta.total_episodes):
|
||||
self.encode_episode_videos(ep_idx)
|
||||
|
||||
def encode_episode_videos(self, episode_index: int) -> dict:
|
||||
"""
|
||||
Use ffmpeg to convert frames stored as png into mp4 videos.
|
||||
Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
|
||||
since video encoding with ffmpeg is already using multithreading.
|
||||
"""
|
||||
video_paths = {}
|
||||
for key in self.meta.video_keys:
|
||||
video_path = self.root / self.meta.get_video_file_path(episode_index, key)
|
||||
video_paths[key] = str(video_path)
|
||||
if video_path.is_file():
|
||||
# Skip if video is already encoded. Could be the case when resuming data recording.
|
||||
continue
|
||||
img_dir = self._get_image_file_path(
|
||||
episode_index=episode_index, image_key=key, frame_index=0
|
||||
).parent
|
||||
encode_video_frames(img_dir, video_path, self.fps, overwrite=True)
|
||||
|
||||
return video_paths
|
||||
temp_path = Path(tempfile.mkdtemp()) / f"{video_key}_{episode_index:3d}.mp4"
|
||||
img_dir = self._get_image_file_dir(episode_index, video_key)
|
||||
encode_video_frames(img_dir, temp_path, self.fps, overwrite=True)
|
||||
return temp_path
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
|
@ -1030,7 +1126,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
obj.image_transforms = None
|
||||
obj.delta_timestamps = None
|
||||
obj.delta_indices = None
|
||||
obj.episode_data_index = None
|
||||
obj.video_backend = video_backend if video_backend is not None else "pyav"
|
||||
return obj
|
||||
|
||||
|
|
|
@@ -21,40 +21,62 @@ from collections.abc import Iterator
from itertools import accumulate
from pathlib import Path
from pprint import pformat
import subprocess
import tempfile
from types import SimpleNamespace
from typing import Any
from typing import Any, Tuple

import datasets
import jsonlines
import numpy as np
import packaging.version
import pandas
import torch
from datasets.table import embed_table_storage
from huggingface_hub import DatasetCard, DatasetCardData, HfApi
from huggingface_hub.errors import RevisionNotFoundError
from PIL import Image as PILImage
from torchvision import transforms
from datasets import Dataset, concatenate_datasets

from lerobot.common.datasets.backward_compatibility import (
    V21_MESSAGE,
    V30_MESSAGE,
    BackwardCompatibilityError,
    ForwardCompatibilityError,
)
from lerobot.common.robot_devices.robots.utils import Robot
from lerobot.common.utils.utils import is_valid_numpy_dtype_string
from lerobot.configs.types import DictLike, FeatureType, PolicyFeature
import pyarrow.parquet as pq

DEFAULT_CHUNK_SIZE = 1000  # Max number of episodes per chunk
DEFAULT_CHUNK_SIZE = 1000  # Max number of files per chunk
DEFAULT_FILE_SIZE_IN_MB = 500.0  # Max size per file

# Keep legacy for `convert_dataset_v21_to_v30.py`
LEGACY_EPISODES_PATH = "meta/episodes.jsonl"
LEGACY_STATS_PATH = "meta/stats.json"
LEGACY_EPISODES_STATS_PATH = "meta/episodes_stats.jsonl"
LEGACY_TASKS_PATH = "meta/tasks.jsonl"
LEGACY_DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
LEGACY_DEFAULT_PARQUET_PATH = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"

# TODO
DEFAULT_IMAGE_PATH = "images/{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"

EPISODES_DIR = "meta/episodes"
EPISODES_STATS_DIR = "meta/episodes_stats"
TASKS_DIR = "meta/tasks"
DATA_DIR = "data"
VIDEO_DIR = "videos"

INFO_PATH = "meta/info.json"
EPISODES_PATH = "meta/episodes.jsonl"
STATS_PATH = "meta/stats.json"
EPISODES_STATS_PATH = "meta/episodes_stats.jsonl"
TASKS_PATH = "meta/tasks.jsonl"

DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
DEFAULT_PARQUET_PATH = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"
DEFAULT_IMAGE_PATH = "images/{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"
CHUNK_FILE_PATTERN = "chunk-{chunk_index:03d}/file-{file_index:03d}"
DEFAULT_EPISODES_PATH = EPISODES_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
DEFAULT_EPISODES_STATS_PATH = EPISODES_STATS_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
DEFAULT_TASKS_PATH = TASKS_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
DEFAULT_DATA_PATH = DATA_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
DEFAULT_VIDEO_PATH = VIDEO_DIR + "/{video_key}/" + CHUNK_FILE_PATTERN + ".mp4"

DATASET_CARD_TEMPLATE = """
---
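For reference, this is how the new chunk/file patterns above expand into concrete relative paths; the indices and the camera key in this snippet are invented for illustration:
```
CHUNK_FILE_PATTERN = "chunk-{chunk_index:03d}/file-{file_index:03d}"
DEFAULT_DATA_PATH = "data/" + CHUNK_FILE_PATTERN + ".parquet"
DEFAULT_VIDEO_PATH = "videos/{video_key}/" + CHUNK_FILE_PATTERN + ".mp4"

# A data file and a video file living in chunk 1, file 34 (0-based indices).
print(DEFAULT_DATA_PATH.format(chunk_index=1, file_index=34))
# data/chunk-001/file-034.parquet
print(DEFAULT_VIDEO_PATH.format(video_key="observation.images.cam_high", chunk_index=1, file_index=34))
# videos/observation.images.cam_high/chunk-001/file-034.mp4
```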
@@ -75,6 +97,88 @@ DEFAULT_FEATURES = {
}


def get_hf_dataset_size_in_mb(hf_ds: Dataset) -> float:
    return hf_ds.data.nbytes / (1024 ** 2)

def get_pd_dataframe_size_in_mb(df: pandas.DataFrame) -> float:
    memory_usage_bytes = df.memory_usage(deep=True).sum()
    return memory_usage_bytes / (1024 ** 2)

def get_chunk_file_indices(path: Path) -> Tuple[int, int]:
    if not path.stem.startswith("file-") or not path.parent.name.startswith("chunk-"):
        raise ValueError(f"Path does not follow {CHUNK_FILE_PATTERN}: '{path}'")

    chunk_index = int(path.parent.name.replace("chunk-", ""))
    file_index = int(path.stem.replace("file-", ""))
    return chunk_index, file_index

def update_chunk_file_indices(chunk_idx: int, file_idx: int, chunks_size: int):
    if file_idx == chunks_size - 1:
        file_idx = 0
        chunk_idx += 1
    else:
        file_idx += 1
    return chunk_idx, file_idx


def load_nested_dataset(pq_dir: Path) -> Dataset:
    """Find parquet files in the provided directory {pq_dir}/chunk-xxx/file-xxx.parquet
    Convert parquet files to pyarrow memory mapped in a cache folder for efficient RAM usage
    Concatenate all pyarrow references to return HF Dataset format
    """
    # TODO(rcadene): set num_proc to accelerate conversion to pyarrow
    return concatenate_datasets([Dataset.from_parquet(str(path)) for path in sorted(pq_dir.glob("*/*.parquet"))])

def get_latest_parquet_path(pq_dir: Path) -> Path:
    return sorted(pq_dir.glob("*/*.parquet"))[-1]

def get_latest_video_path(pq_dir: Path, video_key: str) -> Path:
    return sorted(pq_dir.glob(f"{video_key}/*/*.mp4"))[-1]

def get_parquet_num_frames(parquet_path):
    metadata = pq.read_metadata(parquet_path)
    return metadata.num_rows

def get_video_size_in_mb(mp4_path: Path):
    file_size_bytes = mp4_path.stat().st_size
    file_size_mb = file_size_bytes / (1024 ** 2)
    return file_size_mb

def concat_video_files(paths_to_cat, new_root, video_key, chunk_idx, file_idx):
    # Create a text file with the list of files to concatenate
    with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f:
        temp_file_path = f.name
        for ep_path in paths_to_cat:
            f.write(f"file '{str(ep_path)}'\n")

    output_path = new_root / DEFAULT_VIDEO_PATH.format(video_key=video_key, chunk_index=chunk_idx, file_index=file_idx)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    command = [
        'ffmpeg',
        '-y',
        '-f', 'concat',
        '-safe', '0',
        '-i', str(temp_file_path),
        '-c', 'copy',
        str(output_path)
    ]
    subprocess.run(command, check=True)
    Path(temp_file_path).unlink()

def get_video_duration_in_s(mp4_file: Path):
    result = subprocess.run(
        [
            'ffprobe',
            '-v', 'error',
            '-show_entries', 'format=duration',
            '-of', 'default=noprint_wrappers=1:nokey=1',
            mp4_file
        ],
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT
    )
    return float(result.stdout)

def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict:
    """Flatten a nested dictionary structure by collapsing nested keys into one key with a separator.
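A short usage sketch of the chunk/file index helpers defined above; the path and the chunk size are made up, and the snippet mirrors `get_chunk_file_indices` and `update_chunk_file_indices` inline so it runs standalone:
```
from pathlib import Path

# Parse the indices back out of a path that follows CHUNK_FILE_PATTERN.
path = Path("data/chunk-000/file-099.parquet")
chunk_idx = int(path.parent.name.replace("chunk-", ""))
file_idx = int(path.stem.replace("file-", ""))
assert (chunk_idx, file_idx) == (0, 99)

# Advance to the next file, rolling over to a new chunk once the current
# chunk already holds `chunks_size` files (toy chunks_size of 100 here).
chunks_size = 100
if file_idx == chunks_size - 1:
    chunk_idx, file_idx = chunk_idx + 1, 0
else:
    file_idx += 1
assert (chunk_idx, file_idx) == (1, 0)
```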
@ -124,6 +228,8 @@ def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict:
|
|||
for key, value in flatten_dict(stats).items():
|
||||
if isinstance(value, (torch.Tensor, np.ndarray)):
|
||||
serialized_dict[key] = value.tolist()
|
||||
elif isinstance(value, list) and isinstance(value[0], (int, float, list)):
|
||||
serialized_dict[key] = value
|
||||
elif isinstance(value, np.generic):
|
||||
serialized_dict[key] = value.item()
|
||||
elif isinstance(value, (int, float)):
|
||||
|
@ -183,7 +289,7 @@ def load_info(local_dir: Path) -> dict:
|
|||
|
||||
def write_stats(stats: dict, local_dir: Path):
|
||||
serialized_stats = serialize_dict(stats)
|
||||
write_json(serialized_stats, local_dir / STATS_PATH)
|
||||
write_json(serialized_stats, local_dir / LEGACY_STATS_PATH)
|
||||
|
||||
|
||||
def cast_stats_to_numpy(stats) -> dict[str, dict[str, np.ndarray]]:
|
||||
|
@ -192,50 +298,91 @@ def cast_stats_to_numpy(stats) -> dict[str, dict[str, np.ndarray]]:
|
|||
|
||||
|
||||
def load_stats(local_dir: Path) -> dict[str, dict[str, np.ndarray]]:
|
||||
if not (local_dir / STATS_PATH).exists():
|
||||
if not (local_dir / LEGACY_STATS_PATH).exists():
|
||||
return None
|
||||
stats = load_json(local_dir / STATS_PATH)
|
||||
stats = load_json(local_dir / LEGACY_STATS_PATH)
|
||||
return cast_stats_to_numpy(stats)
|
||||
|
||||
|
||||
def write_task(task_index: int, task: dict, local_dir: Path):
|
||||
def write_hf_dataset(hf_dataset: Dataset, local_dir: Path):
|
||||
if get_hf_dataset_size_in_mb(hf_dataset) > DEFAULT_FILE_SIZE_IN_MB:
|
||||
raise NotImplementedError("Contact a maintainer.")
|
||||
|
||||
path = local_dir / DEFAULT_DATA_PATH.format(chunk_index=0, file_index=0)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
hf_dataset.to_parquet(path)
|
||||
|
||||
|
||||
def write_tasks(tasks: pandas.DataFrame, local_dir: Path):
|
||||
path = local_dir / DEFAULT_TASKS_PATH.format(chunk_index=0, file_index=0)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
tasks.to_parquet(path)
|
||||
|
||||
def legacy_write_task(task_index: int, task: dict, local_dir: Path):
|
||||
task_dict = {
|
||||
"task_index": task_index,
|
||||
"task": task,
|
||||
}
|
||||
append_jsonlines(task_dict, local_dir / TASKS_PATH)
|
||||
append_jsonlines(task_dict, local_dir / LEGACY_TASKS_PATH)
|
||||
|
||||
|
||||
def load_tasks(local_dir: Path) -> tuple[dict, dict]:
|
||||
tasks = load_jsonlines(local_dir / TASKS_PATH)
|
||||
def legacy_load_tasks(local_dir: Path) -> tuple[dict, dict]:
|
||||
tasks = load_jsonlines(local_dir / LEGACY_TASKS_PATH)
|
||||
tasks = {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
|
||||
task_to_task_index = {task: task_index for task_index, task in tasks.items()}
|
||||
return tasks, task_to_task_index
|
||||
|
||||
def load_tasks(local_dir: Path):
|
||||
tasks = load_nested_dataset(local_dir / TASKS_DIR)
|
||||
# TODO(rcadene): optimize this
|
||||
task_to_task_index = {d["task"]: d["task_index"] for d in tasks}
|
||||
return tasks, task_to_task_index
|
||||
|
||||
def write_episode(episode: dict, local_dir: Path):
|
||||
append_jsonlines(episode, local_dir / EPISODES_PATH)
|
||||
append_jsonlines(episode, local_dir / LEGACY_EPISODES_PATH)
|
||||
|
||||
def write_episodes(episodes: Dataset, local_dir: Path):
|
||||
if get_hf_dataset_size_in_mb(episodes) > DEFAULT_FILE_SIZE_IN_MB:
|
||||
raise NotImplementedError("Contact a maintainer.")
|
||||
|
||||
fpath = local_dir / DEFAULT_EPISODES_PATH.format(chunk_index=0, file_index=0)
|
||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||
episodes.to_parquet(fpath)
|
||||
|
||||
def load_episodes(local_dir: Path) -> dict:
|
||||
episodes = load_jsonlines(local_dir / EPISODES_PATH)
|
||||
def legacy_load_episodes(local_dir: Path) -> dict:
|
||||
episodes = load_jsonlines(local_dir / LEGACY_EPISODES_PATH)
|
||||
return {item["episode_index"]: item for item in sorted(episodes, key=lambda x: x["episode_index"])}
|
||||
|
||||
def load_episodes(local_dir: Path):
|
||||
hf_dataset = load_nested_dataset(local_dir / EPISODES_DIR)
|
||||
return hf_dataset
|
||||
|
||||
def write_episode_stats(episode_index: int, episode_stats: dict, local_dir: Path):
|
||||
def legacy_write_episode_stats(episode_index: int, episode_stats: dict, local_dir: Path):
|
||||
# We wrap episode_stats in a dictionary since `episode_stats["episode_index"]`
|
||||
# is a dictionary of stats and not an integer.
|
||||
episode_stats = {"episode_index": episode_index, "stats": serialize_dict(episode_stats)}
|
||||
append_jsonlines(episode_stats, local_dir / EPISODES_STATS_PATH)
|
||||
append_jsonlines(episode_stats, local_dir / LEGACY_EPISODES_STATS_PATH)
|
||||
|
||||
|
||||
def load_episodes_stats(local_dir: Path) -> dict:
|
||||
episodes_stats = load_jsonlines(local_dir / EPISODES_STATS_PATH)
|
||||
def write_episodes_stats(episodes_stats: Dataset, local_dir: Path):
|
||||
if get_hf_dataset_size_in_mb(episodes_stats) > DEFAULT_FILE_SIZE_IN_MB:
|
||||
raise NotImplementedError("Contact a maintainer.")
|
||||
|
||||
fpath = local_dir / DEFAULT_EPISODES_STATS_PATH.format(chunk_index=0, file_index=0)
|
||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||
episodes_stats.to_parquet(fpath)
|
||||
|
||||
|
||||
def legacy_load_episodes_stats(local_dir: Path) -> dict:
|
||||
episodes_stats = load_jsonlines(local_dir / LEGACY_EPISODES_STATS_PATH)
|
||||
return {
|
||||
item["episode_index"]: cast_stats_to_numpy(item["stats"])
|
||||
for item in sorted(episodes_stats, key=lambda x: x["episode_index"])
|
||||
}
|
||||
|
||||
def load_episodes_stats(local_dir: Path):
|
||||
hf_dataset = load_nested_dataset(local_dir / EPISODES_STATS_DIR)
|
||||
return hf_dataset
|
||||
|
||||
def backward_compatible_episodes_stats(
|
||||
stats: dict[str, dict[str, np.ndarray]], episodes: list[int]
|
||||
|
@ -388,6 +535,7 @@ def get_hf_features_from_features(features: dict) -> datasets.Features:
|
|||
|
||||
|
||||
def get_features_from_robot(robot: Robot, use_videos: bool = True) -> dict:
|
||||
# TODO(rcadene): add fps for each feature
|
||||
camera_ft = {}
|
||||
if robot.cameras:
|
||||
camera_ft = {
|
||||
|
@ -442,11 +590,11 @@ def create_empty_dataset_info(
|
|||
"total_frames": 0,
|
||||
"total_tasks": 0,
|
||||
"total_videos": 0,
|
||||
"total_chunks": 0,
|
||||
"chunks_size": DEFAULT_CHUNK_SIZE,
|
||||
"files_size_in_mb": DEFAULT_FILE_SIZE_IN_MB,
|
||||
"fps": fps,
|
||||
"splits": {},
|
||||
"data_path": DEFAULT_PARQUET_PATH,
|
||||
"data_path": DEFAULT_DATA_PATH,
|
||||
"video_path": DEFAULT_VIDEO_PATH if use_videos else None,
|
||||
"features": features,
|
||||
}
|
||||
|
|
|
@ -121,12 +121,12 @@ from safetensors.torch import load_file
|
|||
|
||||
from lerobot.common.datasets.utils import (
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
DEFAULT_PARQUET_PATH,
|
||||
DEFAULT_DATA_PATH,
|
||||
DEFAULT_VIDEO_PATH,
|
||||
EPISODES_PATH,
|
||||
LEGACY_EPISODES_PATH,
|
||||
INFO_PATH,
|
||||
STATS_PATH,
|
||||
TASKS_PATH,
|
||||
LEGACY_STATS_PATH,
|
||||
LEGACY_TASKS_PATH,
|
||||
create_branch,
|
||||
create_lerobot_dataset_card,
|
||||
flatten_dict,
|
||||
|
@ -188,7 +188,7 @@ def convert_stats_to_json(v1_dir: Path, v2_dir: Path) -> None:
|
|||
serialized_stats = {key: value.tolist() for key, value in stats.items()}
|
||||
serialized_stats = unflatten_dict(serialized_stats)
|
||||
|
||||
json_path = v2_dir / STATS_PATH
|
||||
json_path = v2_dir / LEGACY_STATS_PATH
|
||||
json_path.parent.mkdir(exist_ok=True, parents=True)
|
||||
with open(json_path, "w") as f:
|
||||
json.dump(serialized_stats, f, indent=4)
|
||||
|
@ -291,12 +291,12 @@ def split_parquet_by_episodes(
|
|||
for ep_chunk in range(total_chunks):
|
||||
ep_chunk_start = DEFAULT_CHUNK_SIZE * ep_chunk
|
||||
ep_chunk_end = min(DEFAULT_CHUNK_SIZE * (ep_chunk + 1), total_episodes)
|
||||
chunk_dir = "/".join(DEFAULT_PARQUET_PATH.split("/")[:-1]).format(episode_chunk=ep_chunk)
|
||||
chunk_dir = "/".join(DEFAULT_DATA_PATH.split("/")[:-1]).format(episode_chunk=ep_chunk)
|
||||
(output_dir / chunk_dir).mkdir(parents=True, exist_ok=True)
|
||||
for ep_idx in range(ep_chunk_start, ep_chunk_end):
|
||||
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
|
||||
episode_lengths.insert(ep_idx, len(ep_table))
|
||||
output_file = output_dir / DEFAULT_PARQUET_PATH.format(
|
||||
output_file = output_dir / DEFAULT_DATA_PATH.format(
|
||||
episode_chunk=ep_chunk, episode_index=ep_idx
|
||||
)
|
||||
pq.write_table(ep_table, output_file)
|
||||
|
@ -496,7 +496,7 @@ def convert_dataset(
|
|||
|
||||
assert set(tasks) == {task for ep_tasks in tasks_by_episodes.values() for task in ep_tasks}
|
||||
tasks = [{"task_index": task_idx, "task": task} for task_idx, task in enumerate(tasks)]
|
||||
write_jsonlines(tasks, v20_dir / TASKS_PATH)
|
||||
write_jsonlines(tasks, v20_dir / LEGACY_TASKS_PATH)
|
||||
features["task_index"] = {
|
||||
"dtype": "int64",
|
||||
"shape": (1,),
|
||||
|
@ -546,7 +546,7 @@ def convert_dataset(
|
|||
{"episode_index": ep_idx, "tasks": tasks_by_episodes[ep_idx], "length": episode_lengths[ep_idx]}
|
||||
for ep_idx in episode_indices
|
||||
]
|
||||
write_jsonlines(episodes, v20_dir / EPISODES_PATH)
|
||||
write_jsonlines(episodes, v20_dir / LEGACY_EPISODES_PATH)
|
||||
|
||||
# Assemble metadata v2.0
|
||||
metadata_v2_0 = {
|
||||
|
@ -560,7 +560,7 @@ def convert_dataset(
|
|||
"chunks_size": DEFAULT_CHUNK_SIZE,
|
||||
"fps": metadata_v1["fps"],
|
||||
"splits": {"train": f"0:{total_episodes}"},
|
||||
"data_path": DEFAULT_PARQUET_PATH,
|
||||
"data_path": DEFAULT_DATA_PATH,
|
||||
"video_path": DEFAULT_VIDEO_PATH if video_keys else None,
|
||||
"features": features,
|
||||
}
|
||||
|
|
|
@ -23,7 +23,7 @@ import logging
|
|||
from huggingface_hub import HfApi
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
|
||||
from lerobot.common.datasets.utils import EPISODES_STATS_PATH, STATS_PATH, load_stats, write_info
|
||||
from lerobot.common.datasets.utils import LEGACY_EPISODES_STATS_PATH, LEGACY_STATS_PATH, load_stats, write_info
|
||||
from lerobot.common.datasets.v21.convert_stats import check_aggregate_stats, convert_stats
|
||||
|
||||
V20 = "v2.0"
|
||||
|
@ -47,8 +47,8 @@ def convert_dataset(
|
|||
with SuppressWarnings():
|
||||
dataset = LeRobotDataset(repo_id, revision=V20, force_cache_sync=True)
|
||||
|
||||
if (dataset.root / EPISODES_STATS_PATH).is_file():
|
||||
(dataset.root / EPISODES_STATS_PATH).unlink()
|
||||
if (dataset.root / LEGACY_EPISODES_STATS_PATH).is_file():
|
||||
(dataset.root / LEGACY_EPISODES_STATS_PATH).unlink()
|
||||
|
||||
convert_stats(dataset, num_workers=num_workers)
|
||||
ref_stats = load_stats(dataset.root)
|
||||
|
@ -60,15 +60,15 @@ def convert_dataset(
|
|||
dataset.push_to_hub(branch=branch, tag_version=False, allow_patterns="meta/")
|
||||
|
||||
# delete old stats.json file
|
||||
if (dataset.root / STATS_PATH).is_file:
|
||||
(dataset.root / STATS_PATH).unlink()
|
||||
if (dataset.root / LEGACY_STATS_PATH).is_file:
|
||||
(dataset.root / LEGACY_STATS_PATH).unlink()
|
||||
|
||||
hub_api = HfApi()
|
||||
if hub_api.file_exists(
|
||||
repo_id=dataset.repo_id, filename=STATS_PATH, revision=branch, repo_type="dataset"
|
||||
repo_id=dataset.repo_id, filename=LEGACY_STATS_PATH, revision=branch, repo_type="dataset"
|
||||
):
|
||||
hub_api.delete_file(
|
||||
path_in_repo=STATS_PATH, repo_id=dataset.repo_id, revision=branch, repo_type="dataset"
|
||||
path_in_repo=LEGACY_STATS_PATH, repo_id=dataset.repo_id, revision=branch, repo_type="dataset"
|
||||
)
|
||||
|
||||
hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
|
||||
|
|
|
@ -5,7 +5,7 @@ from tqdm import tqdm
|
|||
|
||||
from lerobot.common.datasets.compute_stats import aggregate_stats, get_feature_stats, sample_indices
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.datasets.utils import write_episode_stats
|
||||
from lerobot.common.datasets.utils import legacy_write_episode_stats
|
||||
|
||||
|
||||
def sample_episode_video_frames(dataset: LeRobotDataset, episode_index: int, ft_key: str) -> np.ndarray:
|
||||
|
@ -58,7 +58,7 @@ def convert_stats(dataset: LeRobotDataset, num_workers: int = 0):
|
|||
convert_episode_stats(dataset, ep_idx)
|
||||
|
||||
for ep_idx in tqdm(range(total_episodes)):
|
||||
write_episode_stats(ep_idx, dataset.meta.episodes_stats[ep_idx], dataset.root)
|
||||
legacy_write_episode_stats(ep_idx, dataset.meta.episodes_stats[ep_idx], dataset.root)
|
||||
|
||||
|
||||
def check_aggregate_stats(
|
||||
|
|
|
@@ -19,26 +19,301 @@ python lerobot/common/datasets/v30/convert_dataset_v21_to_v30.py \

import argparse
import logging
from pathlib import Path
import sys

from datasets import Dataset
from huggingface_hub import snapshot_download
import tqdm

from lerobot.common.constants import HF_LEROBOT_HOME
from lerobot.common.datasets.utils import (
    load_episodes_stats,
    DEFAULT_CHUNK_SIZE,
    DEFAULT_DATA_PATH,
    DEFAULT_FILE_SIZE_IN_MB,
    DEFAULT_VIDEO_PATH,
    concat_video_files,
    flatten_dict,
    get_parquet_num_frames,
    get_video_duration_in_s,
    get_video_size_in_mb,
    legacy_load_episodes,
    legacy_load_episodes_stats,
    load_info,
    legacy_load_tasks,
    update_chunk_file_indices,
    write_episodes,
    write_episodes_stats,
    write_info,
    write_tasks,
)
import subprocess
import tempfile
import pandas as pd
import pyarrow.parquet as pq

V21 = "v2.1"


class SuppressWarnings:
    def __enter__(self):
        self.previous_level = logging.getLogger().getEffectiveLevel()
        logging.getLogger().setLevel(logging.ERROR)

    def __exit__(self, exc_type, exc_val, exc_tb):
        logging.getLogger().setLevel(self.previous_level)

"""
-------------------------
OLD
data/chunk-000/episode_000000.parquet

NEW
data/chunk-000/file_000.parquet
-------------------------
OLD
videos/chunk-000/CAMERA/episode_000000.mp4

NEW
videos/chunk-000/file_000.mp4
-------------------------
OLD
episodes.jsonl
{"episode_index": 1, "tasks": ["Put the blue block in the green bowl"], "length": 266}

NEW
meta/episodes/chunk-000/episodes_000.parquet
episode_index | video_chunk_index | video_file_index | data_chunk_index | data_file_index | tasks | length
-------------------------
OLD
tasks.jsonl
{"task_index": 1, "task": "Put the blue block in the green bowl"}

NEW
meta/tasks/chunk-000/file_000.parquet
task_index | task
-------------------------
OLD
episodes_stats.jsonl

NEW
meta/episodes_stats/chunk-000/file_000.parquet
episode_index | mean | std | min | max
-------------------------
UPDATE
meta/info.json
-------------------------
"""
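The docstring above summarizes the v2.1 to v3.0 layout change. As a rough, hand-written illustration of the mapping it describes (all names, indices and values below are invented; the new video path follows `DEFAULT_VIDEO_PATH`, which nests the camera key above the chunk folder):
```
# OLD v2.1 layout: one parquet / one mp4 per episode.
old_data_path = "data/chunk-000/episode_000000.parquet"
old_video_path = "videos/chunk-000/observation.images.cam/episode_000000.mp4"

# NEW v3.0 layout: episodes are packed into shared chunk/file parquet and mp4 files,
# and meta/episodes records where each episode landed.
new_data_path = "data/chunk-000/file-000.parquet"
new_video_path = "videos/observation.images.cam/chunk-000/file-000.mp4"
episode_row = {
    "episode_index": 0,
    "data/chunk_index": 0,
    "data/file_index": 0,
    "data/from_index": 0,    # first frame of the episode inside the shared parquet
    "data/to_index": 266,    # one past the last frame
    "tasks": ["Put the blue block in the green bowl"],
    "length": 266,
}
```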
def get_parquet_file_size_in_mb(parquet_path):
|
||||
metadata = pq.read_metadata(parquet_path)
|
||||
uncompressed_size = metadata.num_rows * metadata.row_group(0).total_byte_size
|
||||
return uncompressed_size / (1024 ** 2)
|
||||
|
||||
|
||||
|
||||
def generate_flat_ep_stats(episodes_stats):
|
||||
for ep_idx, ep_stats in episodes_stats.items():
|
||||
flat_ep_stats = flatten_dict(ep_stats)
|
||||
flat_ep_stats["episode_index"] = ep_idx
|
||||
yield flat_ep_stats
|
||||
|
||||
def convert_episodes_stats(root, new_root):
|
||||
episodes_stats = legacy_load_episodes_stats(root)
|
||||
ds_episodes_stats = Dataset.from_generator(lambda: generate_flat_ep_stats(episodes_stats))
|
||||
write_episodes_stats(ds_episodes_stats, new_root)
|
||||
|
||||
def generate_task_dict(tasks):
|
||||
for task_index, task in tasks.items():
|
||||
yield {"task_index": task_index, "task": task}
|
||||
|
||||
def convert_tasks(root, new_root):
|
||||
tasks, _ = legacy_load_tasks(root)
|
||||
ds_tasks = Dataset.from_generator(lambda: generate_task_dict(tasks))
|
||||
write_tasks(ds_tasks, new_root)
|
||||
|
||||
|
||||
def concat_data_files(paths_to_cat, new_root, chunk_idx, file_idx):
    # TODO(rcadene): to save RAM use Dataset.from_parquet(file) and concatenate_datasets
    dataframes = [pd.read_parquet(file) for file in paths_to_cat]
    # Concatenate all DataFrames along rows
    concatenated_df = pd.concat(dataframes, ignore_index=True)

    path = new_root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
    path.parent.mkdir(parents=True, exist_ok=True)
    concatenated_df.to_parquet(path, index=False)

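The TODO in `concat_data_files` points at a lower-memory variant. A hedged sketch of what that could look like with the Hugging Face `datasets` API is below; it reuses this module's `DEFAULT_DATA_PATH` import, the extra `concatenate_datasets` import is not present in this file, and the helper name is hypothetical.

```python
from datasets import Dataset, concatenate_datasets


def concat_data_files_low_ram(paths_to_cat, new_root, chunk_idx, file_idx):
    # Memory-map each episode parquet instead of materializing pandas DataFrames.
    parts = [Dataset.from_parquet(str(path)) for path in paths_to_cat]
    merged = concatenate_datasets(parts)
    out_path = new_root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    merged.to_parquet(str(out_path))
```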
def convert_data(root, new_root):
    data_dir = root / "data"

    ep_paths = sorted(data_dir.glob("*/*.parquet"))

    episodes_metadata = []
    ep_idx = 0
    chunk_idx = 0
    file_idx = 0
    size_in_mb = 0
    num_frames = 0
    paths_to_cat = []
    for ep_path in ep_paths:
        ep_size_in_mb = get_parquet_file_size_in_mb(ep_path)
        ep_num_frames = get_parquet_num_frames(ep_path)
        ep_metadata = {
            "episode_index": ep_idx,
            "data/chunk_index": chunk_idx,
            "data/file_index": file_idx,
            "data/from_index": num_frames,
            "data/to_index": num_frames + ep_num_frames,
        }
        size_in_mb += ep_size_in_mb
        num_frames += ep_num_frames
        episodes_metadata.append(ep_metadata)
        ep_idx += 1

        if size_in_mb < DEFAULT_FILE_SIZE_IN_MB:
            paths_to_cat.append(ep_path)
            continue

        concat_data_files(paths_to_cat, new_root, chunk_idx, file_idx)

        # Reset for the next file
        size_in_mb = ep_size_in_mb
        num_frames = ep_num_frames
        paths_to_cat = [ep_path]

        chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, DEFAULT_CHUNK_SIZE)

    # Write remaining data if any
    if paths_to_cat:
        concat_data_files(paths_to_cat, new_root, chunk_idx, file_idx)

    return episodes_metadata

def get_video_keys(root):
    info = load_info(root)
    features = info["features"]
    image_keys = [key for key, ft in features.items() if ft["dtype"] == "image"]
    if len(image_keys) != 0:
        raise NotImplementedError("Converting datasets with raw image features is not supported yet.")

    video_keys = [key for key, ft in features.items() if ft["dtype"] == "video"]
    return video_keys

def convert_videos(root: Path, new_root: Path):
    video_keys = get_video_keys(root)
    video_keys = sorted(video_keys)

    eps_metadata_per_cam = []
    for camera in video_keys:
        eps_metadata = convert_videos_of_camera(root, new_root, camera)
        eps_metadata_per_cam.append(eps_metadata)

    num_eps_per_cam = [len(eps_cam_map) for eps_cam_map in eps_metadata_per_cam]
    if len(set(num_eps_per_cam)) != 1:
        raise ValueError(f"All cameras don't have the same number of episodes ({num_eps_per_cam}).")

    episodes_metadata = []
    num_cameras = len(video_keys)
    num_episodes = num_eps_per_cam[0]
    for ep_idx in range(num_episodes):
        # Sanity check
        ep_ids = [eps_metadata_per_cam[cam_idx][ep_idx]["episode_index"] for cam_idx in range(num_cameras)]
        ep_ids += [ep_idx]
        if len(set(ep_ids)) != 1:
            raise ValueError(f"All episode indices need to match ({ep_ids}).")

        ep_dict = {}
        for cam_idx in range(num_cameras):
            ep_dict.update(eps_metadata_per_cam[cam_idx][ep_idx])
        episodes_metadata.append(ep_dict)

    return episodes_metadata

def convert_videos_of_camera(root: Path, new_root: Path, video_key):
    # Access old paths to mp4
    videos_dir = root / "videos"
    ep_paths = sorted(videos_dir.glob(f"*/{video_key}/*.mp4"))

    episodes_metadata = []
    ep_idx = 0
    chunk_idx = 0
    file_idx = 0
    size_in_mb = 0
    duration_in_s = 0.0
    paths_to_cat = []
    for ep_path in tqdm.tqdm(ep_paths, desc=f"convert videos of {video_key}"):
        ep_size_in_mb = get_video_size_in_mb(ep_path)
        ep_duration_in_s = get_video_duration_in_s(ep_path)
        ep_metadata = {
            "episode_index": ep_idx,
            f"{video_key}/chunk_index": chunk_idx,
            f"{video_key}/file_index": file_idx,
            f"{video_key}/from_timestamp": duration_in_s,
            f"{video_key}/to_timestamp": duration_in_s + ep_duration_in_s,
        }
        size_in_mb += ep_size_in_mb
        duration_in_s += ep_duration_in_s
        episodes_metadata.append(ep_metadata)
        ep_idx += 1

        if size_in_mb < DEFAULT_FILE_SIZE_IN_MB:
            paths_to_cat.append(ep_path)
            continue

        concat_video_files(paths_to_cat, new_root, video_key, chunk_idx, file_idx)

        # Reset for the next file
        size_in_mb = ep_size_in_mb
        duration_in_s = ep_duration_in_s
        paths_to_cat = [ep_path]

        chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, DEFAULT_CHUNK_SIZE)

    # Write remaining videos if any
    if paths_to_cat:
        concat_video_files(paths_to_cat, new_root, video_key, chunk_idx, file_idx)

    return episodes_metadata

def generate_episode_dict(episodes, episodes_data, episodes_videos):
    for ep, ep_data, ep_video in zip(episodes.values(), episodes_data, episodes_videos):
        ep_idx = ep["episode_index"]
        ep_idx_data = ep_data["episode_index"]
        ep_idx_video = ep_video["episode_index"]

        if len(set([ep_idx, ep_idx_data, ep_idx_video])) != 1:
            raise ValueError(f"Episode indices do not match ({ep_idx=},{ep_idx_data=},{ep_idx_video=}).")

        ep_dict = {**ep_data, **ep_video, **ep}
        yield ep_dict

def convert_episodes(root, new_root, episodes_data, episodes_videos):
    episodes = legacy_load_episodes(root)

    num_eps = len(episodes)
    num_eps_data = len(episodes_data)
    num_eps_video = len(episodes_videos)
    if len(set([num_eps, num_eps_data, num_eps_video])) != 1:
        raise ValueError(f"Number of episodes is not the same ({num_eps=},{num_eps_data=},{num_eps_video=}).")

    ds_episodes = Dataset.from_generator(lambda: generate_episode_dict(episodes, episodes_data, episodes_videos))
    write_episodes(ds_episodes, new_root)

def convert_info(root, new_root):
    info = load_info(root)
    info["codebase_version"] = "v3.0"
    del info["total_chunks"]
    del info["total_videos"]
    info["files_size_in_mb"] = DEFAULT_FILE_SIZE_IN_MB
    # TODO(rcadene): chunk- or chunk_ or file- or file_
    info["data_path"] = DEFAULT_DATA_PATH
    info["video_path"] = DEFAULT_VIDEO_PATH
    info["fps"] = float(info["fps"])
    for key in info["features"]:
        if info["features"][key]["dtype"] == "video":
            # already has fps in video_info
            continue
        info["features"][key]["fps"] = info["fps"]
    write_info(info, new_root)

def convert_dataset(
    repo_id: str,

@@ -46,6 +321,8 @@ def convert_dataset(
    num_workers: int = 4,
):
    root = HF_LEROBOT_HOME / repo_id
    new_root = HF_LEROBOT_HOME / f"{repo_id}_v30"

    snapshot_download(
        repo_id,
        repo_type="dataset",

@@ -53,63 +330,12 @@
        local_dir=root,
    )

    # Concatenate videos

    # Create

    """
    -------------------------
    OLD
    data/chunk-000/episode_000000.parquet

    NEW
    data/chunk-000/file_000.parquet
    -------------------------
    OLD
    videos/chunk-000/CAMERA/episode_000000.mp4

    NEW
    videos/chunk-000/file_000.mp4
    -------------------------
    OLD
    episodes.jsonl
    {"episode_index": 1, "tasks": ["Put the blue block in the green bowl"], "length": 266}

    NEW
    meta/episodes/chunk-000/episodes_000.parquet
    episode_index | video_chunk_index | video_file_index | data_chunk_index | data_file_index | tasks | length
    -------------------------
    OLD
    tasks.jsonl
    {"task_index": 1, "task": "Put the blue block in the green bowl"}

    NEW
    meta/tasks/chunk-000/file_000.parquet
    task_index | task
    -------------------------
    OLD
    episodes_stats.jsonl

    NEW
    meta/episodes_stats/chunk-000/file_000.parquet
    episode_index | mean | std | min | max
    -------------------------
    UPDATE
    meta/info.json
    -------------------------
    """

    new_root = HF_LEROBOT_HOME / f"{repo_id}_v30"
    new_root.mkdir(parents=True, exist_ok=True)

    episodes_stats = load_episodes_stats(root)
    hf_dataset = Dataset.from_dict(episodes_stats)  # noqa: F841

    meta_ep_st_ch = new_root / "meta/episodes_stats/chunk-000"
    meta_ep_st_ch.mkdir(parents=True, exist_ok=True)

    # hf_dataset.to_parquet(meta_ep_st_ch / 'file_000.parquet')

    convert_info(root, new_root)
    convert_episodes_stats(root, new_root)
    convert_tasks(root, new_root)
    episodes_data_mapping = convert_data(root, new_root)
    episodes_videos_mapping = convert_videos(root, new_root)
    convert_episodes(root, new_root, episodes_data_mapping, episodes_videos_mapping)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
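The argparse block is cut off by the hunk above. For reference, a minimal sketch of driving the conversion from Python, using only the parameters visible in this diff (`repo_id`, `num_workers`); the module path and the example repo id are assumptions, not something the diff states.

```python
# Not part of the commit: direct-call sketch with assumed module path and repo id.
from lerobot.common.datasets.v30.convert_dataset_v21_to_v30 import convert_dataset

convert_dataset(repo_id="lerobot/pusht", num_workers=4)
```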
@@ -9,13 +9,16 @@ import numpy as np
import PIL.Image
import pytest
import torch
from datasets import Dataset

from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset, LeRobotDatasetMetadata
from lerobot.common.datasets.utils import (
    DEFAULT_CHUNK_SIZE,
    DEFAULT_FEATURES,
    DEFAULT_PARQUET_PATH,
    DEFAULT_DATA_PATH,
    DEFAULT_FILE_SIZE_IN_MB,
    DEFAULT_VIDEO_PATH,
    flatten_dict,
    get_hf_features_from_features,
    hf_transform_to_torch,
)
@@ -33,10 +36,9 @@ class LeRobotDatasetFactory(Protocol):
    def __call__(self, *args, **kwargs) -> LeRobotDataset: ...


def get_task_index(task_dicts: dict, task: str) -> int:
    tasks = {d["task_index"]: d["task"] for d in task_dicts.values()}
    task_to_task_index = {task: task_idx for task_idx, task in tasks.items()}
    return task_to_task_index[task]
def get_task_index(tasks: Dataset, task: str) -> int:
    task_idx = tasks["task"].index(task)
    return task_idx

@pytest.fixture(scope="session")

@@ -104,9 +106,9 @@ def info_factory(features_factory):
        total_frames: int = 0,
        total_tasks: int = 0,
        total_videos: int = 0,
        total_chunks: int = 0,
        chunks_size: int = DEFAULT_CHUNK_SIZE,
        data_path: str = DEFAULT_PARQUET_PATH,
        files_size_in_mb: float = DEFAULT_FILE_SIZE_IN_MB,
        data_path: str = DEFAULT_DATA_PATH,
        video_path: str = DEFAULT_VIDEO_PATH,
        motor_features: dict = DUMMY_MOTOR_FEATURES,
        camera_features: dict = DUMMY_CAMERA_FEATURES,

@@ -120,8 +122,8 @@ def info_factory(features_factory):
            "total_frames": total_frames,
            "total_tasks": total_tasks,
            "total_videos": total_videos,
            "total_chunks": total_chunks,
            "chunks_size": chunks_size,
            "files_size_in_mb": files_size_in_mb,
            "fps": fps,
            "splits": {},
            "data_path": data_path,
@@ -168,25 +170,25 @@ def episodes_stats_factory(stats_factory):
        features: dict[str],
        total_episodes: int = 3,
    ) -> dict:
        episodes_stats = {}
        for episode_index in range(total_episodes):
            episodes_stats[episode_index] = {
                "episode_index": episode_index,
                "stats": stats_factory(features),
            }
        return episodes_stats

        def _generator(total_episodes):
            for ep_idx in range(total_episodes):
                flat_ep_stats = flatten_dict(stats_factory(features))
                flat_ep_stats["episode_index"] = ep_idx
                yield flat_ep_stats

        # Simpler to rely on generator instead of from_dict
        return Dataset.from_generator(lambda: _generator(total_episodes))

    return _create_episodes_stats


@pytest.fixture(scope="session")
def tasks_factory():
    def _create_tasks(total_tasks: int = 3) -> int:
        tasks = {}
        for task_index in range(total_tasks):
            task_dict = {"task_index": task_index, "task": f"Perform action {task_index}."}
            tasks[task_index] = task_dict
        return tasks
    def _create_tasks(total_tasks: int = 3) -> Dataset:
        ids = list(range(total_tasks))
        tasks = [f"Perform action {i}." for i in ids]
        return Dataset.from_dict({"task_index": ids, "task": tasks})

    return _create_tasks

@@ -196,6 +198,7 @@ def episodes_factory(tasks_factory):
    def _create_episodes(
        total_episodes: int = 3,
        total_frames: int = 400,
        video_keys: list[str] | None = None,
        tasks: dict | None = None,
        multi_task: bool = False,
    ):

@@ -215,26 +218,41 @@ def episodes_factory(tasks_factory):
        # Generate random lengths that sum up to total_length
        lengths = np.random.multinomial(total_frames, [1 / total_episodes] * total_episodes).tolist()

        tasks_list = [task_dict["task"] for task_dict in tasks.values()]
        num_tasks_available = len(tasks_list)
        num_tasks_available = len(tasks["task"])

        episodes = {}
        remaining_tasks = tasks_list.copy()
        d = {
            "episode_index": [],
            "data/chunk_index": [],
            "data/file_index": [],
            "tasks": [],
            "length": [],
        }
        if video_keys is not None:
            for video_key in video_keys:
                d[f"{video_key}/chunk_index"] = []
                d[f"{video_key}/file_index"] = []

        remaining_tasks = tasks["task"].copy()
        for ep_idx in range(total_episodes):
            num_tasks_in_episode = random.randint(1, min(3, num_tasks_available)) if multi_task else 1
            tasks_to_sample = remaining_tasks if remaining_tasks else tasks_list
            tasks_to_sample = remaining_tasks if remaining_tasks else tasks["task"]
            episode_tasks = random.sample(tasks_to_sample, min(num_tasks_in_episode, len(tasks_to_sample)))
            if remaining_tasks:
                for task in episode_tasks:
                    remaining_tasks.remove(task)

            episodes[ep_idx] = {
                "episode_index": ep_idx,
                "tasks": episode_tasks,
                "length": lengths[ep_idx],
            }
            d["episode_index"].append(ep_idx)
            # TODO(rcadene): remove heuristic of only one file
            d["data/chunk_index"].append(0)
            d["data/file_index"].append(0)
            d["tasks"].append(episode_tasks)
            d["length"].append(lengths[ep_idx])
            if video_keys is not None:
                for video_key in video_keys:
                    d[f"{video_key}/chunk_index"].append(0)
                    d[f"{video_key}/file_index"].append(0)

        return episodes
        return Dataset.from_dict(d)

    return _create_episodes

@@ -258,7 +276,7 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar
        frame_index_col = np.array([], dtype=np.int64)
        episode_index_col = np.array([], dtype=np.int64)
        task_index = np.array([], dtype=np.int64)
        for ep_dict in episodes.values():
        for ep_dict in episodes:
            timestamp_col = np.concatenate((timestamp_col, np.arange(ep_dict["length"]) / fps))
            frame_index_col = np.concatenate((frame_index_col, np.arange(ep_dict["length"], dtype=int)))
            episode_index_col = np.concatenate(

@@ -291,7 +309,7 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar
            },
            features=hf_features,
        )
        dataset.set_transform(hf_transform_to_torch)
        dataset.set_format("torch")
        return dataset

    return _create_hf_dataset

@@ -326,8 +344,9 @@ def lerobot_dataset_metadata_factory(
        if not tasks:
            tasks = tasks_factory(total_tasks=info["total_tasks"])
        if not episodes:
            video_keys = [key for key, ft in info["features"].items() if ft["dtype"] == "video"]
            episodes = episodes_factory(
                total_episodes=info["total_episodes"], total_frames=info["total_frames"], tasks=tasks
                total_episodes=info["total_episodes"], total_frames=info["total_frames"], video_keys=video_keys, tasks=tasks
            )

        mock_snapshot_download = mock_snapshot_download_factory(

@@ -371,9 +390,9 @@ def lerobot_dataset_factory(
    multi_task: bool = False,
    info: dict | None = None,
    stats: dict | None = None,
    episodes_stats: list[dict] | None = None,
    tasks: list[dict] | None = None,
    episode_dicts: list[dict] | None = None,
    episodes_stats: datasets.Dataset | None = None,
    tasks: datasets.Dataset | None = None,
    episode_dicts: datasets.Dataset | None = None,
    hf_dataset: datasets.Dataset | None = None,
    **kwargs,
) -> LeRobotDataset:

@@ -388,9 +407,11 @@ def lerobot_dataset_factory(
    if not tasks:
        tasks = tasks_factory(total_tasks=info["total_tasks"])
    if not episode_dicts:
        video_keys = [key for key, ft in info["features"].items() if ft["dtype"] == "video"]
        episode_dicts = episodes_factory(
            total_episodes=info["total_episodes"],
            total_frames=info["total_frames"],
            video_keys=video_keys,
            tasks=tasks,
            multi_task=multi_task,
        )

@@ -7,83 +7,75 @@ import pyarrow.compute as pc
import pyarrow.parquet as pq
import pytest

from datasets import Dataset

from lerobot.common.datasets.utils import (
    EPISODES_PATH,
    EPISODES_STATS_PATH,
    INFO_PATH,
    STATS_PATH,
    TASKS_PATH,
    write_episodes,
    write_episodes_stats,
    write_hf_dataset,
    write_info,
    write_stats,
    write_tasks,
)


@pytest.fixture(scope="session")
def info_path(info_factory):
    def _create_info_json_file(dir: Path, info: dict | None = None) -> Path:
def create_info(info_factory):
    def _create_info(dir: Path, info: dict | None = None):
        if not info:
            info = info_factory()
        fpath = dir / INFO_PATH
        fpath.parent.mkdir(parents=True, exist_ok=True)
        with open(fpath, "w") as f:
            json.dump(info, f, indent=4, ensure_ascii=False)
        return fpath
        write_info(info, dir)

    return _create_info_json_file
    return _create_info


@pytest.fixture(scope="session")
def stats_path(stats_factory):
    def _create_stats_json_file(dir: Path, stats: dict | None = None) -> Path:
def create_stats(stats_factory):
    def _create_stats(dir: Path, stats: dict | None = None):
        if not stats:
            stats = stats_factory()
        fpath = dir / STATS_PATH
        fpath.parent.mkdir(parents=True, exist_ok=True)
        with open(fpath, "w") as f:
            json.dump(stats, f, indent=4, ensure_ascii=False)
        return fpath
        write_stats(stats, dir)

    return _create_stats_json_file
    return _create_stats


@pytest.fixture(scope="session")
def episodes_stats_path(episodes_stats_factory):
    def _create_episodes_stats_jsonl_file(dir: Path, episodes_stats: list[dict] | None = None) -> Path:
def create_episodes_stats(episodes_stats_factory):
    def _create_episodes_stats(dir: Path, episodes_stats: Dataset | None = None):
        if not episodes_stats:
            episodes_stats = episodes_stats_factory()
        fpath = dir / EPISODES_STATS_PATH
        fpath.parent.mkdir(parents=True, exist_ok=True)
        with jsonlines.open(fpath, "w") as writer:
            writer.write_all(episodes_stats.values())
        return fpath
        write_episodes_stats(episodes_stats, dir)

    return _create_episodes_stats_jsonl_file
    return _create_episodes_stats


@pytest.fixture(scope="session")
def tasks_path(tasks_factory):
    def _create_tasks_jsonl_file(dir: Path, tasks: list | None = None) -> Path:
def create_tasks(tasks_factory):
    def _create_tasks(dir: Path, tasks: Dataset | None = None):
        if not tasks:
            tasks = tasks_factory()
        fpath = dir / TASKS_PATH
        fpath.parent.mkdir(parents=True, exist_ok=True)
        with jsonlines.open(fpath, "w") as writer:
            writer.write_all(tasks.values())
        return fpath
        write_tasks(tasks, dir)

    return _create_tasks_jsonl_file
    return _create_tasks


@pytest.fixture(scope="session")
def episode_path(episodes_factory):
    def _create_episodes_jsonl_file(dir: Path, episodes: list | None = None) -> Path:
def create_episodes(episodes_factory):
    def _create_episodes(dir: Path, episodes: Dataset | None = None):
        if not episodes:
            episodes = episodes_factory()
        fpath = dir / EPISODES_PATH
        fpath.parent.mkdir(parents=True, exist_ok=True)
        with jsonlines.open(fpath, "w") as writer:
            writer.write_all(episodes.values())
        return fpath
        write_episodes(episodes, dir)

    return _create_episodes_jsonl_file
    return _create_episodes


@pytest.fixture(scope="session")
def create_hf_dataset(hf_dataset_factory):
    def _create_hf_dataset(dir: Path, hf_dataset: Dataset | None = None):
        if not hf_dataset:
            hf_dataset = hf_dataset_factory()
        write_hf_dataset(hf_dataset, dir)

    return _create_hf_dataset


@pytest.fixture(scope="session")
@@ -91,6 +83,7 @@ def single_episode_parquet_path(hf_dataset_factory, info_factory):
    def _create_single_episode_parquet(
        dir: Path, ep_idx: int = 0, hf_dataset: datasets.Dataset | None = None, info: dict | None = None
    ) -> Path:
        raise NotImplementedError()
        if not info:
            info = info_factory()
        if hf_dataset is None:

@@ -114,6 +107,7 @@ def multi_episode_parquet_path(hf_dataset_factory, info_factory):
    def _create_multi_episode_parquet(
        dir: Path, hf_dataset: datasets.Dataset | None = None, info: dict | None = None
    ) -> Path:
        raise NotImplementedError()
        if not info:
            info = info_factory()
        if hf_dataset is None:
@@ -5,11 +5,12 @@ import pytest
from huggingface_hub.utils import filter_repo_objects

from lerobot.common.datasets.utils import (
    EPISODES_PATH,
    EPISODES_STATS_PATH,
    DEFAULT_DATA_PATH,
    DEFAULT_EPISODES_PATH,
    DEFAULT_EPISODES_STATS_PATH,
    DEFAULT_TASKS_PATH,
    INFO_PATH,
    STATS_PATH,
    TASKS_PATH,
    LEGACY_STATS_PATH,
)
from tests.fixtures.constants import LEROBOT_TEST_DIR
@@ -17,17 +18,17 @@ from tests.fixtures.constants import LEROBOT_TEST_DIR
@pytest.fixture(scope="session")
def mock_snapshot_download_factory(
    info_factory,
    info_path,
    create_info,
    stats_factory,
    stats_path,
    create_stats,
    episodes_stats_factory,
    episodes_stats_path,
    create_episodes_stats,
    tasks_factory,
    tasks_path,
    create_tasks,
    episodes_factory,
    episode_path,
    single_episode_parquet_path,
    create_episodes,
    hf_dataset_factory,
    create_hf_dataset,
):
    """
    This factory allows to patch snapshot_download such that when called, it will create expected files rather

@@ -37,9 +38,9 @@ def mock_snapshot_download_factory(
    def _mock_snapshot_download_func(
        info: dict | None = None,
        stats: dict | None = None,
        episodes_stats: list[dict] | None = None,
        tasks: list[dict] | None = None,
        episodes: list[dict] | None = None,
        episodes_stats: datasets.Dataset | None = None,
        tasks: datasets.Dataset | None = None,
        episodes: datasets.Dataset | None = None,
        hf_dataset: datasets.Dataset | None = None,
    ):
        if not info:

@@ -59,14 +60,6 @@ def mock_snapshot_download_factory(
        if not hf_dataset:
            hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episodes, fps=info["fps"])

        def _extract_episode_index_from_path(fpath: str) -> int:
            path = Path(fpath)
            if path.suffix == ".parquet" and path.stem.startswith("episode_"):
                episode_index = int(path.stem[len("episode_") :])  # 'episode_000000' -> 0
                return episode_index
            else:
                return None

        def _mock_snapshot_download(
            repo_id: str,
            local_dir: str | Path | None = None,

@@ -79,40 +72,55 @@ def mock_snapshot_download_factory(
                local_dir = LEROBOT_TEST_DIR

            # List all possible files
            all_files = []
            meta_files = [INFO_PATH, STATS_PATH, EPISODES_STATS_PATH, TASKS_PATH, EPISODES_PATH]
            all_files.extend(meta_files)

            data_files = []
            for episode_dict in episodes.values():
                ep_idx = episode_dict["episode_index"]
                ep_chunk = ep_idx // info["chunks_size"]
                data_path = info["data_path"].format(episode_chunk=ep_chunk, episode_index=ep_idx)
                data_files.append(data_path)
            all_files.extend(data_files)
            all_files = [
                INFO_PATH,
                LEGACY_STATS_PATH,
                # TODO(rcadene)
                DEFAULT_TASKS_PATH.format(chunk_index=0, file_index=0),
                DEFAULT_EPISODES_STATS_PATH.format(chunk_index=0, file_index=0),
                DEFAULT_EPISODES_PATH.format(chunk_index=0, file_index=0),
                DEFAULT_DATA_PATH.format(chunk_index=0, file_index=0),
            ]

            allowed_files = filter_repo_objects(
                all_files, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns
            )

            # Create allowed files
            has_info = False
            has_tasks = False
            has_episodes = False
            has_episodes_stats = False
            has_stats = False
            has_data = False
            for rel_path in allowed_files:
                if rel_path.startswith("data/"):
                    episode_index = _extract_episode_index_from_path(rel_path)
                    if episode_index is not None:
                        _ = single_episode_parquet_path(local_dir, episode_index, hf_dataset, info)
                if rel_path == INFO_PATH:
                    _ = info_path(local_dir, info)
                elif rel_path == STATS_PATH:
                    _ = stats_path(local_dir, stats)
                elif rel_path == EPISODES_STATS_PATH:
                    _ = episodes_stats_path(local_dir, episodes_stats)
                elif rel_path == TASKS_PATH:
                    _ = tasks_path(local_dir, tasks)
                elif rel_path == EPISODES_PATH:
                    _ = episode_path(local_dir, episodes)
                if rel_path.startswith("meta/info.json"):
                    has_info = True
                elif rel_path.startswith("meta/stats"):
                    has_stats = True
                elif rel_path.startswith("meta/tasks"):
                    has_tasks = True
                elif rel_path.startswith("meta/episodes_stats"):
                    has_episodes_stats = True
                elif rel_path.startswith("meta/episodes"):
                    has_episodes = True
                elif rel_path.startswith("data/"):
                    has_data = True
                else:
                    pass
                    raise ValueError(f"{rel_path} not supported.")

            if has_info:
                create_info(local_dir, info)
            if has_stats:
                create_stats(local_dir, stats)
            if has_tasks:
                create_tasks(local_dir, tasks)
            if has_episodes:
                create_episodes(local_dir, episodes)
            if has_episodes_stats:
                create_episodes_stats(local_dir, episodes_stats)
            if has_data:
                create_hf_dataset(local_dir, hf_dataset)

            return str(local_dir)

        return _mock_snapshot_download
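A hedged sketch of how a test could consume the fixture above; the patch target and the default-argument call are assumptions not shown in this diff, and the mocked download simply materializes the generated fixture files under the test directory.

```python
from pathlib import Path


def test_snapshot_download_is_mocked(mock_snapshot_download_factory, monkeypatch):
    # Build the mock with default fixtures (info, stats, tasks, episodes, data).
    mock_snapshot_download = mock_snapshot_download_factory()
    # Patch target is an assumption: wherever snapshot_download is looked up at call time.
    monkeypatch.setattr(
        "lerobot.common.datasets.lerobot_dataset.snapshot_download",
        mock_snapshot_download,
    )
    # Calling the mock writes locally generated files instead of hitting the Hub.
    local_dir = mock_snapshot_download("dummy/repo_id")
    assert (Path(local_dir) / "meta/info.json").exists()
```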