pre-commit run --all-files

Remi Cadene 2025-04-21 09:34:19 +02:00
parent 5a6ea09248
commit 4acf99f622
7 changed files with 85 additions and 33 deletions

View File

@@ -26,7 +26,12 @@ from datatrove.pipeline.base import PipelineStep
 from examples.port_datasets.droid_rlds.port_droid import DROID_SHARDS
 from lerobot.common.datasets.aggregate import validate_all_metadata
 from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
-from lerobot.common.datasets.utils import write_episode, legacy_write_episode_stats, write_info, legacy_write_task
+from lerobot.common.datasets.utils import (
+    legacy_write_episode_stats,
+    legacy_write_task,
+    write_episode,
+    write_info,
+)
 from lerobot.common.utils.utils import init_logging

View File

@@ -1,13 +1,21 @@
 import logging
-from pathlib import Path
 import shutil
+from pathlib import Path

 import pandas as pd
 import tqdm

 from lerobot.common.datasets.compute_stats import aggregate_stats
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
-from lerobot.common.datasets.utils import DEFAULT_CHUNK_SIZE, DEFAULT_DATA_PATH, DEFAULT_EPISODES_PATH, DEFAULT_VIDEO_PATH, write_episode, legacy_write_episode_stats, write_info, legacy_write_task, write_stats, write_tasks
+from lerobot.common.datasets.utils import (
+    DEFAULT_CHUNK_SIZE,
+    DEFAULT_DATA_PATH,
+    DEFAULT_EPISODES_PATH,
+    DEFAULT_VIDEO_PATH,
+    write_info,
+    write_stats,
+    write_tasks,
+)
 from lerobot.common.utils.utils import init_logging
@@ -32,14 +40,17 @@ def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]):
     return fps, robot_type, features


 def get_update_episode_and_task_func(episode_index_to_add, old_tasks, new_tasks):
     def _update(row):
         row["episode_index"] = row["episode_index"] + episode_index_to_add
         task = old_tasks.iloc[row["task_index"]].name
         row["task_index"] = new_tasks.loc[task].task_index.item()
         return row

     return _update


 def get_update_meta_func(
     meta_chunk_index_to_add,
     meta_file_index_to_add,
@@ -55,20 +66,26 @@ def get_update_meta_func(
         row["data/chunk_index"] = row["data/chunk_index"] + data_chunk_index_to_add
         row["data/file_index"] = row["data/file_index"] + data_file_index_to_add
         for key in videos_chunk_index_to_add:
-            row[f"videos/{key}/chunk_index"] = row[f"videos/{key}/chunk_index"] + videos_chunk_index_to_add[key]
+            row[f"videos/{key}/chunk_index"] = (
+                row[f"videos/{key}/chunk_index"] + videos_chunk_index_to_add[key]
+            )
             row[f"videos/{key}/file_index"] = row[f"videos/{key}/file_index"] + videos_file_index_to_add[key]
         row["dataset_from_index"] = row["dataset_from_index"] + frame_index_to_add
         row["dataset_to_index"] = row["dataset_to_index"] + frame_index_to_add
         return row

     return _update


-def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path]=None, aggr_root=None):
+def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path] = None, aggr_root=None):
     logging.info("Start aggregate_datasets")
     if roots is None:
         all_metadata = [LeRobotDatasetMetadata(repo_id) for repo_id in repo_ids]
     else:
-        all_metadata = [LeRobotDatasetMetadata(repo_id, root=root) for repo_id, root in zip(repo_ids, roots)]
+        all_metadata = [
+            LeRobotDatasetMetadata(repo_id, root=root) for repo_id, root in zip(repo_ids, roots, strict=False)
+        ]

     fps, robot_type, features = validate_all_metadata(all_metadata)
     video_keys = [key for key in features if features[key]["dtype"] == "video"]
@@ -96,12 +113,18 @@ def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path]
     aggr_data_chunk_idx = 0
     aggr_data_file_idx = 0
-    aggr_videos_chunk_idx = {key: 0 for key in video_keys}
-    aggr_videos_file_idx = {key: 0 for key in video_keys}
+    aggr_videos_chunk_idx = dict.fromkeys(video_keys, 0)
+    aggr_videos_file_idx = dict.fromkeys(video_keys, 0)

     for meta in tqdm.tqdm(all_metadata, desc="Copy data and videos"):
-        meta_chunk_file_ids = set([(c,f) for c, f in zip(meta.episodes["meta/episodes/chunk_index"], meta.episodes["meta/episodes/file_index"])])
+        meta_chunk_file_ids = {
+            (c, f)
+            for c, f in zip(
+                meta.episodes["meta/episodes/chunk_index"],
+                meta.episodes["meta/episodes/file_index"],
+                strict=False,
+            )
+        }
         for chunk_idx, file_idx in meta_chunk_file_ids:
             path = meta.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
             df = pd.read_parquet(path)
@@ -115,11 +138,13 @@ def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path]
                 num_frames,
             )
             df = df.apply(update_meta_func, axis=1)
-            aggr_path = aggr_root / DEFAULT_EPISODES_PATH.format(chunk_index=aggr_meta_chunk_idx, file_index=aggr_meta_file_idx)
+            aggr_path = aggr_root / DEFAULT_EPISODES_PATH.format(
+                chunk_index=aggr_meta_chunk_idx, file_index=aggr_meta_file_idx
+            )
             aggr_path.parent.mkdir(parents=True, exist_ok=True)
             df.to_parquet(aggr_path)
             aggr_meta_file_idx += 1
             if aggr_meta_file_idx >= DEFAULT_CHUNK_SIZE:
                 aggr_meta_file_idx = 0
@@ -127,10 +152,23 @@ def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path]
         # cp videos
         for key in video_keys:
-            video_chunk_file_ids = set([(c,f) for c, f in zip(meta.episodes[f"videos/{key}/chunk_index"], meta.episodes[f"videos/{key}/file_index"])])
+            video_chunk_file_ids = {
+                (c, f)
+                for c, f in zip(
+                    meta.episodes[f"videos/{key}/chunk_index"],
+                    meta.episodes[f"videos/{key}/file_index"],
+                    strict=False,
+                )
+            }
             for chunk_idx, file_idx in video_chunk_file_ids:
-                path = meta.root / DEFAULT_VIDEO_PATH.format(video_key=key, chunk_index=chunk_idx, file_index=file_idx)
-                aggr_path = aggr_root / DEFAULT_VIDEO_PATH.format(video_key=key, chunk_index=aggr_videos_chunk_idx[key], file_index=aggr_videos_file_idx[key])
+                path = meta.root / DEFAULT_VIDEO_PATH.format(
+                    video_key=key, chunk_index=chunk_idx, file_index=file_idx
+                )
+                aggr_path = aggr_root / DEFAULT_VIDEO_PATH.format(
+                    video_key=key,
+                    chunk_index=aggr_videos_chunk_idx[key],
+                    file_index=aggr_videos_file_idx[key],
+                )
                 aggr_path.parent.mkdir(parents=True, exist_ok=True)
                 shutil.copy(str(path), str(aggr_path))
@@ -142,14 +180,19 @@ def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path]
                     aggr_videos_file_idx[key] = 0
                     aggr_videos_chunk_idx[key] += 1

-        data_chunk_file_ids = set([(c,f) for c, f in zip(meta.episodes["data/chunk_index"], meta.episodes["data/file_index"])])
+        data_chunk_file_ids = {
+            (c, f)
+            for c, f in zip(meta.episodes["data/chunk_index"], meta.episodes["data/file_index"], strict=False)
+        }
         for chunk_idx, file_idx in data_chunk_file_ids:
             path = meta.root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
             df = pd.read_parquet(path)
             update_data_func = get_update_episode_and_task_func(num_episodes, meta.tasks, aggr_meta.tasks)
             df = df.apply(update_data_func, axis=1)
-            aggr_path = aggr_root / DEFAULT_DATA_PATH.format(chunk_index=aggr_data_chunk_idx, file_index=aggr_data_file_idx)
+            aggr_path = aggr_root / DEFAULT_DATA_PATH.format(
+                chunk_index=aggr_data_chunk_idx, file_index=aggr_data_file_idx
+            )
             aggr_path.parent.mkdir(parents=True, exist_ok=True)
             df.to_parquet(aggr_path)
@@ -157,7 +200,7 @@ def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path]
             if aggr_data_file_idx >= DEFAULT_CHUNK_SIZE:
                 aggr_data_file_idx = 0
                 aggr_data_chunk_idx += 1

         num_episodes += meta.total_episodes
         num_frames += meta.total_frames
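
The changes to the aggregation script above are behavior-preserving rewrites applied by the linter: a constant-value dict comprehension becomes dict.fromkeys, set() over a list comprehension becomes a set comprehension, and zip gains an explicit strict=False (its default). A minimal standalone sketch of these equivalences follows; the key names are made up for illustration and are not taken from the commit.

# Illustrative sketch only; the video keys below are hypothetical.
video_keys = ["observation.images.cam_high", "observation.images.cam_low"]

# dict.fromkeys(keys, 0) builds the same mapping as a constant-value dict comprehension.
assert dict.fromkeys(video_keys, 0) == {key: 0 for key in video_keys}

# A set comprehension yields the same set as set() over a list comprehension,
# without materializing the intermediate list.
chunk_indices = [0, 0, 1]
file_indices = [0, 1, 0]
old_style = set([(c, f) for c, f in zip(chunk_indices, file_indices)])
new_style = {(c, f) for c, f in zip(chunk_indices, file_indices, strict=False)}
assert old_style == new_style

# strict=False is zip's default, so behavior is unchanged; strict=True (Python 3.10+)
# would instead raise ValueError if the two index columns ever differed in length.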

View File

@@ -75,7 +75,9 @@ class BackwardCompatibilityError(CompatibilityError):
         elif version.major == 2:
             message = V2_MESSAGE.format(repo_id=repo_id, version=version)
         else:
-            raise NotImplementedError("Contact the maintainer on [Discord](https://discord.com/invite/s3KuuzsPFb).")
+            raise NotImplementedError(
+                "Contact the maintainer on [Discord](https://discord.com/invite/s3KuuzsPFb)."
+            )
         super().__init__(message)

View File

@@ -24,7 +24,7 @@ from collections.abc import Iterator
 from pathlib import Path
 from pprint import pformat
 from types import SimpleNamespace
-from typing import Any, Tuple
+from typing import Any

 import datasets
 import numpy as np
@@ -47,7 +47,7 @@ from lerobot.common.datasets.backward_compatibility import (
 )
 from lerobot.common.robot_devices.robots.utils import Robot
 from lerobot.common.utils.utils import is_valid_numpy_dtype_string
-from lerobot.configs.types import DictLike, FeatureType, PolicyFeature
+from lerobot.configs.types import FeatureType, PolicyFeature

 DEFAULT_CHUNK_SIZE = 1000  # Max number of files per chunk
 DEFAULT_FILE_SIZE_IN_MB = 500.0  # Max size per file
@@ -249,34 +249,41 @@ def load_json(fpath: Path) -> Any:
     with open(fpath) as f:
         return json.load(f)


 def write_json(data: dict, fpath: Path) -> None:
     fpath.parent.mkdir(exist_ok=True, parents=True)
     with open(fpath, "w") as f:
         json.dump(data, f, indent=4, ensure_ascii=False)


 def write_info(info: dict, local_dir: Path):
     write_json(info, local_dir / INFO_PATH)


 def load_info(local_dir: Path) -> dict:
     info = load_json(local_dir / INFO_PATH)
     for ft in info["features"].values():
         ft["shape"] = tuple(ft["shape"])
     return info


 def write_stats(stats: dict, local_dir: Path):
     serialized_stats = serialize_dict(stats)
     write_json(serialized_stats, local_dir / STATS_PATH)


 def cast_stats_to_numpy(stats) -> dict[str, dict[str, np.ndarray]]:
     stats = {key: np.array(value) for key, value in flatten_dict(stats).items()}
     return unflatten_dict(stats)


 def load_stats(local_dir: Path) -> dict[str, dict[str, np.ndarray]]:
     if not (local_dir / STATS_PATH).exists():
         return None
     stats = load_json(local_dir / STATS_PATH)
     return cast_stats_to_numpy(stats)


 def write_hf_dataset(hf_dataset: Dataset, local_dir: Path):
     if get_hf_dataset_size_in_mb(hf_dataset) > DEFAULT_FILE_SIZE_IN_MB:
         raise NotImplementedError("Contact a maintainer.")
@@ -292,7 +299,6 @@ def write_tasks(tasks: pandas.DataFrame, local_dir: Path):
     tasks.to_parquet(path)


 def load_tasks(local_dir: Path):
     tasks = pd.read_parquet(local_dir / DEFAULT_TASKS_PATH)
     return tasks

View File

@@ -30,7 +30,7 @@ from huggingface_hub import HfApi, snapshot_download
 from lerobot.common.constants import HF_LEROBOT_HOME
 from lerobot.common.datasets.compute_stats import aggregate_stats
-from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
+from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
 from lerobot.common.datasets.utils import (
     DEFAULT_CHUNK_SIZE,
     DEFAULT_DATA_PATH,
@@ -43,11 +43,7 @@ from lerobot.common.datasets.utils import (
     get_parquet_num_frames,
     get_video_duration_in_s,
     get_video_size_in_mb,
-    legacy_load_episodes,
-    legacy_load_episodes_stats,
-    legacy_load_tasks,
     load_info,
-    serialize_dict,
     update_chunk_file_indices,
     write_episodes,
     write_info,
@@ -111,10 +107,12 @@ def load_jsonlines(fpath: Path) -> list[Any]:
     with jsonlines.open(fpath, "r") as reader:
         return list(reader)


 def legacy_load_episodes(local_dir: Path) -> dict:
     episodes = load_jsonlines(local_dir / LEGACY_EPISODES_PATH)
     return {item["episode_index"]: item for item in sorted(episodes, key=lambda x: x["episode_index"])}


 def legacy_load_episodes_stats(local_dir: Path) -> dict:
     episodes_stats = load_jsonlines(local_dir / LEGACY_EPISODES_STATS_PATH)
     return {
@@ -122,6 +120,7 @@ def legacy_load_episodes_stats(local_dir: Path) -> dict:
         for item in sorted(episodes_stats, key=lambda x: x["episode_index"])
     }


 def legacy_load_tasks(local_dir: Path) -> tuple[dict, dict]:
     tasks = load_jsonlines(local_dir / LEGACY_TASKS_PATH)
     tasks = {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
@@ -355,8 +354,6 @@ def convert_dataset(
     branch: str | None = None,
     num_workers: int = 4,
 ):
     root = HF_LEROBOT_HOME / repo_id
     old_root = HF_LEROBOT_HOME / f"{repo_id}_old"
     new_root = HF_LEROBOT_HOME / f"{repo_id}_v30"

View File

@@ -21,9 +21,9 @@ def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
         repo_ids=[ds_0.repo_id, ds_1.repo_id],
         roots=[ds_0.root, ds_1.root],
         aggr_repo_id=f"{DUMMY_REPO_ID}_aggr",
-        aggr_root=tmp_path / "test_aggr"
+        aggr_root=tmp_path / "test_aggr",
     )

     aggr_ds = LeRobotDataset(f"{DUMMY_REPO_ID}_aggr", root=tmp_path / "test_aggr")
-    for item in aggr_ds:
+    for _ in aggr_ds:
         pass
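
In the test above, the loop only consumes the aggregated dataset as a smoke test, so the unused loop variable is renamed to `_`, which is the conventional way to tell linters the value is intentionally discarded. A minimal sketch of the same idiom; the function name and iterable below are hypothetical and not from the test suite.

# Illustrative sketch only; name and argument are made up.
def smoke_test_iteration(dataset):
    """Consume every item so iteration errors surface; the items themselves are discarded."""
    for _ in dataset:  # "_" marks the variable as intentionally unused, silencing lint warnings
        pass

smoke_test_iteration(range(10))  # in the real test, a LeRobotDataset instance is iterated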

View File

@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import json
 from pathlib import Path

 import datasets