pre-commit run --all-files

Remi Cadene 2025-04-21 09:34:19 +02:00
parent 5a6ea09248
commit 4acf99f622
7 changed files with 85 additions and 33 deletions

View File

@@ -26,7 +26,12 @@ from datatrove.pipeline.base import PipelineStep
from examples.port_datasets.droid_rlds.port_droid import DROID_SHARDS
from lerobot.common.datasets.aggregate import validate_all_metadata
from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.common.datasets.utils import write_episode, legacy_write_episode_stats, write_info, legacy_write_task
from lerobot.common.datasets.utils import (
legacy_write_episode_stats,
legacy_write_task,
write_episode,
write_info,
)
from lerobot.common.utils.utils import init_logging

View File

@@ -1,13 +1,21 @@
import logging
from pathlib import Path
import shutil
from pathlib import Path
import pandas as pd
import tqdm
from lerobot.common.datasets.compute_stats import aggregate_stats
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.common.datasets.utils import DEFAULT_CHUNK_SIZE, DEFAULT_DATA_PATH, DEFAULT_EPISODES_PATH, DEFAULT_VIDEO_PATH, write_episode, legacy_write_episode_stats, write_info, legacy_write_task, write_stats, write_tasks
from lerobot.common.datasets.utils import (
DEFAULT_CHUNK_SIZE,
DEFAULT_DATA_PATH,
DEFAULT_EPISODES_PATH,
DEFAULT_VIDEO_PATH,
write_info,
write_stats,
write_tasks,
)
from lerobot.common.utils.utils import init_logging
@@ -32,14 +40,17 @@ def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]):
return fps, robot_type, features
def get_update_episode_and_task_func(episode_index_to_add, old_tasks, new_tasks):
def _update(row):
row["episode_index"] = row["episode_index"] + episode_index_to_add
task = old_tasks.iloc[row["task_index"]].name
row["task_index"] = new_tasks.loc[task].task_index.item()
return row
return _update
def get_update_meta_func(
meta_chunk_index_to_add,
meta_file_index_to_add,
@@ -55,20 +66,26 @@ def get_update_meta_func(
row["data/chunk_index"] = row["data/chunk_index"] + data_chunk_index_to_add
row["data/file_index"] = row["data/file_index"] + data_file_index_to_add
for key in videos_chunk_index_to_add:
row[f"videos/{key}/chunk_index"] = row[f"videos/{key}/chunk_index"] + videos_chunk_index_to_add[key]
row[f"videos/{key}/chunk_index"] = (
row[f"videos/{key}/chunk_index"] + videos_chunk_index_to_add[key]
)
row[f"videos/{key}/file_index"] = row[f"videos/{key}/file_index"] + videos_file_index_to_add[key]
row["dataset_from_index"] = row["dataset_from_index"] + frame_index_to_add
row["dataset_to_index"] = row["dataset_to_index"] + frame_index_to_add
return row
return _update
def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path] = None, aggr_root=None):
logging.info("Start aggregate_datasets")
if roots is None:
all_metadata = [LeRobotDatasetMetadata(repo_id) for repo_id in repo_ids]
else:
all_metadata = [LeRobotDatasetMetadata(repo_id, root=root) for repo_id, root in zip(repo_ids, roots)]
all_metadata = [
LeRobotDatasetMetadata(repo_id, root=root) for repo_id, root in zip(repo_ids, roots, strict=False)
]
fps, robot_type, features = validate_all_metadata(all_metadata)
video_keys = [key for key in features if features[key]["dtype"] == "video"]
@@ -96,12 +113,18 @@ def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path]
aggr_data_chunk_idx = 0
aggr_data_file_idx = 0
aggr_videos_chunk_idx = {key: 0 for key in video_keys}
aggr_videos_file_idx = {key: 0 for key in video_keys}
aggr_videos_chunk_idx = dict.fromkeys(video_keys, 0)
aggr_videos_file_idx = dict.fromkeys(video_keys, 0)
for meta in tqdm.tqdm(all_metadata, desc="Copy data and videos"):
meta_chunk_file_ids = set([(c,f) for c, f in zip(meta.episodes["meta/episodes/chunk_index"], meta.episodes["meta/episodes/file_index"])])
meta_chunk_file_ids = {
(c, f)
for c, f in zip(
meta.episodes["meta/episodes/chunk_index"],
meta.episodes["meta/episodes/file_index"],
strict=False,
)
}
for chunk_idx, file_idx in meta_chunk_file_ids:
path = meta.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
df = pd.read_parquet(path)
@@ -116,7 +139,9 @@ def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path]
)
df = df.apply(update_meta_func, axis=1)
aggr_path = aggr_root / DEFAULT_EPISODES_PATH.format(chunk_index=aggr_meta_chunk_idx, file_index=aggr_meta_file_idx)
aggr_path = aggr_root / DEFAULT_EPISODES_PATH.format(
chunk_index=aggr_meta_chunk_idx, file_index=aggr_meta_file_idx
)
aggr_path.parent.mkdir(parents=True, exist_ok=True)
df.to_parquet(aggr_path)
@@ -127,10 +152,23 @@ def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path]
# cp videos
for key in video_keys:
video_chunk_file_ids = set([(c,f) for c, f in zip(meta.episodes[f"videos/{key}/chunk_index"], meta.episodes[f"videos/{key}/file_index"])])
video_chunk_file_ids = {
(c, f)
for c, f in zip(
meta.episodes[f"videos/{key}/chunk_index"],
meta.episodes[f"videos/{key}/file_index"],
strict=False,
)
}
for chunk_idx, file_idx in video_chunk_file_ids:
path = meta.root / DEFAULT_VIDEO_PATH.format(video_key=key, chunk_index=chunk_idx, file_index=file_idx)
aggr_path = aggr_root / DEFAULT_VIDEO_PATH.format(video_key=key, chunk_index=aggr_videos_chunk_idx[key], file_index=aggr_videos_file_idx[key])
path = meta.root / DEFAULT_VIDEO_PATH.format(
video_key=key, chunk_index=chunk_idx, file_index=file_idx
)
aggr_path = aggr_root / DEFAULT_VIDEO_PATH.format(
video_key=key,
chunk_index=aggr_videos_chunk_idx[key],
file_index=aggr_videos_file_idx[key],
)
aggr_path.parent.mkdir(parents=True, exist_ok=True)
shutil.copy(str(path), str(aggr_path))
@@ -142,14 +180,19 @@ def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path]
aggr_videos_file_idx[key] = 0
aggr_videos_chunk_idx[key] += 1
data_chunk_file_ids = set([(c,f) for c, f in zip(meta.episodes["data/chunk_index"], meta.episodes["data/file_index"])])
data_chunk_file_ids = {
(c, f)
for c, f in zip(meta.episodes["data/chunk_index"], meta.episodes["data/file_index"], strict=False)
}
for chunk_idx, file_idx in data_chunk_file_ids:
path = meta.root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
df = pd.read_parquet(path)
update_data_func = get_update_episode_and_task_func(num_episodes, meta.tasks, aggr_meta.tasks)
df = df.apply(update_data_func, axis=1)
aggr_path = aggr_root / DEFAULT_DATA_PATH.format(chunk_index=aggr_data_chunk_idx, file_index=aggr_data_file_idx)
aggr_path = aggr_root / DEFAULT_DATA_PATH.format(
chunk_index=aggr_data_chunk_idx, file_index=aggr_data_file_idx
)
aggr_path.parent.mkdir(parents=True, exist_ok=True)
df.to_parquet(aggr_path)
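The edits in this file are mechanical style rewrites, presumably applied by the pre-commit hooks (ruff and the formatter): set([...]) around a list comprehension becomes a set comprehension, {key: 0 for key in keys} becomes dict.fromkeys(keys, 0), every zip() gains an explicit strict=False, and long calls are wrapped across lines. A minimal sketch of the first three patterns, using illustrative variable names that are not taken from the diff:

# Hypothetical example, not part of the commit above.
video_keys = ["cam_top", "cam_wrist"]
chunks = [0, 0, 1]
files = [0, 1, 0]

# set([...]) wrapping a list comprehension -> set comprehension
chunk_file_ids = {(c, f) for c, f in zip(chunks, files, strict=False)}

# {key: 0 for key in keys} -> dict.fromkeys(keys, 0)
counters = dict.fromkeys(video_keys, 0)

# zip() gets an explicit strict=False, which keeps the previous behavior:
# mismatched lengths are still silently truncated rather than raising.
pairs = list(zip(video_keys, chunks, strict=False))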

View File

@@ -75,7 +75,9 @@ class BackwardCompatibilityError(CompatibilityError):
elif version.major == 2:
message = V2_MESSAGE.format(repo_id=repo_id, version=version)
else:
raise NotImplementedError("Contact the maintainer on [Discord](https://discord.com/invite/s3KuuzsPFb).")
raise NotImplementedError(
"Contact the maintainer on [Discord](https://discord.com/invite/s3KuuzsPFb)."
)
super().__init__(message)

View File

@@ -24,7 +24,7 @@ from collections.abc import Iterator
from pathlib import Path
from pprint import pformat
from types import SimpleNamespace
from typing import Any, Tuple
from typing import Any
import datasets
import numpy as np
@@ -47,7 +47,7 @@ from lerobot.common.datasets.backward_compatibility import (
)
from lerobot.common.robot_devices.robots.utils import Robot
from lerobot.common.utils.utils import is_valid_numpy_dtype_string
from lerobot.configs.types import DictLike, FeatureType, PolicyFeature
from lerobot.configs.types import FeatureType, PolicyFeature
DEFAULT_CHUNK_SIZE = 1000 # Max number of files per chunk
DEFAULT_FILE_SIZE_IN_MB = 500.0 # Max size per file
@@ -249,34 +249,41 @@ def load_json(fpath: Path) -> Any:
with open(fpath) as f:
return json.load(f)
def write_json(data: dict, fpath: Path) -> None:
fpath.parent.mkdir(exist_ok=True, parents=True)
with open(fpath, "w") as f:
json.dump(data, f, indent=4, ensure_ascii=False)
def write_info(info: dict, local_dir: Path):
write_json(info, local_dir / INFO_PATH)
def load_info(local_dir: Path) -> dict:
info = load_json(local_dir / INFO_PATH)
for ft in info["features"].values():
ft["shape"] = tuple(ft["shape"])
return info
def write_stats(stats: dict, local_dir: Path):
serialized_stats = serialize_dict(stats)
write_json(serialized_stats, local_dir / STATS_PATH)
def cast_stats_to_numpy(stats) -> dict[str, dict[str, np.ndarray]]:
stats = {key: np.array(value) for key, value in flatten_dict(stats).items()}
return unflatten_dict(stats)
def load_stats(local_dir: Path) -> dict[str, dict[str, np.ndarray]]:
if not (local_dir / STATS_PATH).exists():
return None
stats = load_json(local_dir / STATS_PATH)
return cast_stats_to_numpy(stats)
def write_hf_dataset(hf_dataset: Dataset, local_dir: Path):
if get_hf_dataset_size_in_mb(hf_dataset) > DEFAULT_FILE_SIZE_IN_MB:
raise NotImplementedError("Contact a maintainer.")
@@ -292,7 +299,6 @@ def write_tasks(tasks: pandas.DataFrame, local_dir: Path):
tasks.to_parquet(path)
def load_tasks(local_dir: Path):
tasks = pd.read_parquet(local_dir / DEFAULT_TASKS_PATH)
return tasks

View File

@@ -30,7 +30,7 @@ from huggingface_hub import HfApi, snapshot_download
from lerobot.common.constants import HF_LEROBOT_HOME
from lerobot.common.datasets.compute_stats import aggregate_stats
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
from lerobot.common.datasets.utils import (
DEFAULT_CHUNK_SIZE,
DEFAULT_DATA_PATH,
@@ -43,11 +43,7 @@ from lerobot.common.datasets.utils import (
get_parquet_num_frames,
get_video_duration_in_s,
get_video_size_in_mb,
legacy_load_episodes,
legacy_load_episodes_stats,
legacy_load_tasks,
load_info,
serialize_dict,
update_chunk_file_indices,
write_episodes,
write_info,
@@ -111,10 +107,12 @@ def load_jsonlines(fpath: Path) -> list[Any]:
with jsonlines.open(fpath, "r") as reader:
return list(reader)
def legacy_load_episodes(local_dir: Path) -> dict:
episodes = load_jsonlines(local_dir / LEGACY_EPISODES_PATH)
return {item["episode_index"]: item for item in sorted(episodes, key=lambda x: x["episode_index"])}
def legacy_load_episodes_stats(local_dir: Path) -> dict:
episodes_stats = load_jsonlines(local_dir / LEGACY_EPISODES_STATS_PATH)
return {
@@ -122,6 +120,7 @@ def legacy_load_episodes_stats(local_dir: Path) -> dict:
for item in sorted(episodes_stats, key=lambda x: x["episode_index"])
}
def legacy_load_tasks(local_dir: Path) -> tuple[dict, dict]:
tasks = load_jsonlines(local_dir / LEGACY_TASKS_PATH)
tasks = {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
@@ -355,8 +354,6 @@ def convert_dataset(
branch: str | None = None,
num_workers: int = 4,
):
root = HF_LEROBOT_HOME / repo_id
old_root = HF_LEROBOT_HOME / f"{repo_id}_old"
new_root = HF_LEROBOT_HOME / f"{repo_id}_v30"

View File

@@ -21,9 +21,9 @@ def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
repo_ids=[ds_0.repo_id, ds_1.repo_id],
roots=[ds_0.root, ds_1.root],
aggr_repo_id=f"{DUMMY_REPO_ID}_aggr",
aggr_root=tmp_path / "test_aggr"
aggr_root=tmp_path / "test_aggr",
)
aggr_ds = LeRobotDataset(f"{DUMMY_REPO_ID}_aggr", root=tmp_path / "test_aggr")
for item in aggr_ds:
for _ in aggr_ds:
pass
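The test change is cosmetic in the same spirit: the unused loop variable is renamed to an underscore and a trailing comma is added to the multi-line call. A minimal, hypothetical illustration of the unused-variable convention:

# Hypothetical example: iterate only to exercise item loading,
# so the loop variable is intentionally named "_".
for _ in range(3):
    pass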

View File

@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from pathlib import Path
import datasets