pre-commit run --all-files

parent 5a6ea09248
commit 4acf99f622
@@ -26,7 +26,12 @@ from datatrove.pipeline.base import PipelineStep
from examples.port_datasets.droid_rlds.port_droid import DROID_SHARDS
from lerobot.common.datasets.aggregate import validate_all_metadata
from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.common.datasets.utils import write_episode, legacy_write_episode_stats, write_info, legacy_write_task
from lerobot.common.datasets.utils import (
    legacy_write_episode_stats,
    legacy_write_task,
    write_episode,
    write_info,
)
from lerobot.common.utils.utils import init_logging
@@ -1,13 +1,21 @@
import logging
from pathlib import Path
import shutil
from pathlib import Path

import pandas as pd
import tqdm

from lerobot.common.datasets.compute_stats import aggregate_stats
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.common.datasets.utils import DEFAULT_CHUNK_SIZE, DEFAULT_DATA_PATH, DEFAULT_EPISODES_PATH, DEFAULT_VIDEO_PATH, write_episode, legacy_write_episode_stats, write_info, legacy_write_task, write_stats, write_tasks
from lerobot.common.datasets.utils import (
    DEFAULT_CHUNK_SIZE,
    DEFAULT_DATA_PATH,
    DEFAULT_EPISODES_PATH,
    DEFAULT_VIDEO_PATH,
    write_info,
    write_stats,
    write_tasks,
)
from lerobot.common.utils.utils import init_logging
@@ -32,14 +40,17 @@ def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]):

    return fps, robot_type, features


def get_update_episode_and_task_func(episode_index_to_add, old_tasks, new_tasks):
    def _update(row):
        row["episode_index"] = row["episode_index"] + episode_index_to_add
        task = old_tasks.iloc[row["task_index"]].name
        row["task_index"] = new_tasks.loc[task].task_index.item()
        return row

    return _update


def get_update_meta_func(
    meta_chunk_index_to_add,
    meta_file_index_to_add,
@@ -55,20 +66,26 @@ def get_update_meta_func(
        row["data/chunk_index"] = row["data/chunk_index"] + data_chunk_index_to_add
        row["data/file_index"] = row["data/file_index"] + data_file_index_to_add
        for key in videos_chunk_index_to_add:
            row[f"videos/{key}/chunk_index"] = row[f"videos/{key}/chunk_index"] + videos_chunk_index_to_add[key]
            row[f"videos/{key}/chunk_index"] = (
                row[f"videos/{key}/chunk_index"] + videos_chunk_index_to_add[key]
            )
            row[f"videos/{key}/file_index"] = row[f"videos/{key}/file_index"] + videos_file_index_to_add[key]
        row["dataset_from_index"] = row["dataset_from_index"] + frame_index_to_add
        row["dataset_to_index"] = row["dataset_to_index"] + frame_index_to_add
        return row

    return _update


def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path]=None, aggr_root=None):
def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path] = None, aggr_root=None):
    logging.info("Start aggregate_datasets")

    if roots is None:
        all_metadata = [LeRobotDatasetMetadata(repo_id) for repo_id in repo_ids]
    else:
        all_metadata = [LeRobotDatasetMetadata(repo_id, root=root) for repo_id, root in zip(repo_ids, roots)]
        all_metadata = [
            LeRobotDatasetMetadata(repo_id, root=root) for repo_id, root in zip(repo_ids, roots, strict=False)
        ]

    fps, robot_type, features = validate_all_metadata(all_metadata)
    video_keys = [key for key in features if features[key]["dtype"] == "video"]
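
Note on the `zip(..., strict=False)` change above: ruff flags bare `zip` calls (B905) because silent truncation can hide length mismatches, and passing `strict=False` keeps the old truncating behaviour while making it explicit. A minimal sketch of the difference on Python 3.10+, using made-up lists rather than the real repo_ids/roots:

repo_ids = ["user/ds_a", "user/ds_b"]
roots = ["/data/ds_a"]  # one root missing on purpose

# strict=False (the behaviour chosen in the diff): silently truncates to the shorter input.
print(list(zip(repo_ids, roots, strict=False)))  # [('user/ds_a', '/data/ds_a')]

# strict=True: raises ValueError because the lengths differ, which surfaces the bug early.
try:
    list(zip(repo_ids, roots, strict=True))
except ValueError as err:
    print(err)
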
@@ -96,12 +113,18 @@ def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path]
    aggr_data_chunk_idx = 0
    aggr_data_file_idx = 0

    aggr_videos_chunk_idx = {key: 0 for key in video_keys}
    aggr_videos_file_idx = {key: 0 for key in video_keys}
    aggr_videos_chunk_idx = dict.fromkeys(video_keys, 0)
    aggr_videos_file_idx = dict.fromkeys(video_keys, 0)

    for meta in tqdm.tqdm(all_metadata, desc="Copy data and videos"):
        meta_chunk_file_ids = set([(c,f) for c, f in zip(meta.episodes["meta/episodes/chunk_index"], meta.episodes["meta/episodes/file_index"])])
        meta_chunk_file_ids = {
            (c, f)
            for c, f in zip(
                meta.episodes["meta/episodes/chunk_index"],
                meta.episodes["meta/episodes/file_index"],
                strict=False,
            )
        }
        for chunk_idx, file_idx in meta_chunk_file_ids:
            path = meta.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
            df = pd.read_parquet(path)
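
The `{key: 0 for key in video_keys}` → `dict.fromkeys(video_keys, 0)` rewrite above is behaviour-preserving because the shared default `0` is immutable. A short illustrative sketch with hypothetical video keys (not taken from the diff):

video_keys = ["observation.images.cam_high", "observation.images.cam_low"]  # hypothetical keys

# Equivalent ways to initialise every counter to zero.
a = {key: 0 for key in video_keys}
b = dict.fromkeys(video_keys, 0)
assert a == b

# Caveat: dict.fromkeys shares a single value object, so it is only safe for
# immutable defaults. With a mutable default every key would alias the same list.
shared = dict.fromkeys(video_keys, [])
shared[video_keys[0]].append(1)
print(shared[video_keys[1]])  # [1] -- both keys point at the same list
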
@@ -115,11 +138,13 @@ def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path]
                num_frames,
            )
            df = df.apply(update_meta_func, axis=1)

            aggr_path = aggr_root / DEFAULT_EPISODES_PATH.format(chunk_index=aggr_meta_chunk_idx, file_index=aggr_meta_file_idx)
            aggr_path = aggr_root / DEFAULT_EPISODES_PATH.format(
                chunk_index=aggr_meta_chunk_idx, file_index=aggr_meta_file_idx
            )
            aggr_path.parent.mkdir(parents=True, exist_ok=True)
            df.to_parquet(aggr_path)

            aggr_meta_file_idx += 1
            if aggr_meta_file_idx >= DEFAULT_CHUNK_SIZE:
                aggr_meta_file_idx = 0
@@ -127,10 +152,23 @@ def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path]

        # cp videos
        for key in video_keys:
            video_chunk_file_ids = set([(c,f) for c, f in zip(meta.episodes[f"videos/{key}/chunk_index"], meta.episodes[f"videos/{key}/file_index"])])
            video_chunk_file_ids = {
                (c, f)
                for c, f in zip(
                    meta.episodes[f"videos/{key}/chunk_index"],
                    meta.episodes[f"videos/{key}/file_index"],
                    strict=False,
                )
            }
            for chunk_idx, file_idx in video_chunk_file_ids:
                path = meta.root / DEFAULT_VIDEO_PATH.format(video_key=key, chunk_index=chunk_idx, file_index=file_idx)
                aggr_path = aggr_root / DEFAULT_VIDEO_PATH.format(video_key=key, chunk_index=aggr_videos_chunk_idx[key], file_index=aggr_videos_file_idx[key])
                path = meta.root / DEFAULT_VIDEO_PATH.format(
                    video_key=key, chunk_index=chunk_idx, file_index=file_idx
                )
                aggr_path = aggr_root / DEFAULT_VIDEO_PATH.format(
                    video_key=key,
                    chunk_index=aggr_videos_chunk_idx[key],
                    file_index=aggr_videos_file_idx[key],
                )
                aggr_path.parent.mkdir(parents=True, exist_ok=True)
                shutil.copy(str(path), str(aggr_path))
@@ -142,14 +180,19 @@ def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path]
                    aggr_videos_file_idx[key] = 0
                    aggr_videos_chunk_idx[key] += 1

        data_chunk_file_ids = set([(c,f) for c, f in zip(meta.episodes["data/chunk_index"], meta.episodes["data/file_index"])])
        data_chunk_file_ids = {
            (c, f)
            for c, f in zip(meta.episodes["data/chunk_index"], meta.episodes["data/file_index"], strict=False)
        }
        for chunk_idx, file_idx in data_chunk_file_ids:
            path = meta.root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
            df = pd.read_parquet(path)
            update_data_func = get_update_episode_and_task_func(num_episodes, meta.tasks, aggr_meta.tasks)
            df = df.apply(update_data_func, axis=1)

            aggr_path = aggr_root / DEFAULT_DATA_PATH.format(chunk_index=aggr_data_chunk_idx, file_index=aggr_data_file_idx)
            aggr_path = aggr_root / DEFAULT_DATA_PATH.format(
                chunk_index=aggr_data_chunk_idx, file_index=aggr_data_file_idx
            )
            aggr_path.parent.mkdir(parents=True, exist_ok=True)
            df.to_parquet(aggr_path)
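
The repeated `set([...])` → set-comprehension rewrites drop the intermediate list while producing the same set of (chunk_index, file_index) pairs; this is the usual flake8-comprehensions fix. A self-contained sketch with toy columns:

chunk_col = [0, 0, 1, 1]
file_col = [0, 1, 0, 0]

# Old style: build a list of tuples, then convert it to a set.
old = set([(c, f) for c, f in zip(chunk_col, file_col, strict=False)])

# New style: set comprehension, no intermediate list.
new = {(c, f) for c, f in zip(chunk_col, file_col, strict=False)}

assert old == new == {(0, 0), (0, 1), (1, 0)}
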
@@ -157,7 +200,7 @@ def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path]
            if aggr_data_file_idx >= DEFAULT_CHUNK_SIZE:
                aggr_data_file_idx = 0
                aggr_data_chunk_idx += 1

        num_episodes += meta.total_episodes
        num_frames += meta.total_frames
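
The increment/rollover pattern in the last hunks (bump the file index, reset it and advance the chunk index once DEFAULT_CHUNK_SIZE is reached) is presumably what the imported `update_chunk_file_indices` helper encapsulates elsewhere; the stand-alone sketch below only illustrates the pattern and is not the library's implementation (the helper name `bump_chunk_file_indices` is invented here):

DEFAULT_CHUNK_SIZE = 1000  # max number of files per chunk, as defined in the diff


def bump_chunk_file_indices(chunk_idx: int, file_idx: int) -> tuple[int, int]:
    """Advance to the next file slot, rolling over to a new chunk when full."""
    file_idx += 1
    if file_idx >= DEFAULT_CHUNK_SIZE:
        file_idx = 0
        chunk_idx += 1
    return chunk_idx, file_idx


assert bump_chunk_file_indices(0, 998) == (0, 999)
assert bump_chunk_file_indices(0, 999) == (1, 0)
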
@@ -75,7 +75,9 @@ class BackwardCompatibilityError(CompatibilityError):
        elif version.major == 2:
            message = V2_MESSAGE.format(repo_id=repo_id, version=version)
        else:
            raise NotImplementedError("Contact the maintainer on [Discord](https://discord.com/invite/s3KuuzsPFb).")
            raise NotImplementedError(
                "Contact the maintainer on [Discord](https://discord.com/invite/s3KuuzsPFb)."
            )
        super().__init__(message)
@@ -24,7 +24,7 @@ from collections.abc import Iterator
from pathlib import Path
from pprint import pformat
from types import SimpleNamespace
from typing import Any, Tuple
from typing import Any

import datasets
import numpy as np
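
Removing `Tuple` from the `typing` import matches the built-in generics available since Python 3.9 (`tuple[dict, dict]` instead of `typing.Tuple[dict, dict]`), which ruff's pyupgrade rules enforce. A tiny self-contained example of the new-style annotation:

def split_pair(pair: tuple[int, int]) -> tuple[int, int]:
    # PEP 585 built-in generics replace typing.Tuple from Python 3.9 onwards.
    first, second = pair
    return second, first


assert split_pair((1, 2)) == (2, 1)
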
@@ -47,7 +47,7 @@ from lerobot.common.datasets.backward_compatibility import (
)
from lerobot.common.robot_devices.robots.utils import Robot
from lerobot.common.utils.utils import is_valid_numpy_dtype_string
from lerobot.configs.types import DictLike, FeatureType, PolicyFeature
from lerobot.configs.types import FeatureType, PolicyFeature

DEFAULT_CHUNK_SIZE = 1000  # Max number of files per chunk
DEFAULT_FILE_SIZE_IN_MB = 500.0  # Max size per file
@@ -249,34 +249,41 @@ def load_json(fpath: Path) -> Any:
    with open(fpath) as f:
        return json.load(f)


def write_json(data: dict, fpath: Path) -> None:
    fpath.parent.mkdir(exist_ok=True, parents=True)
    with open(fpath, "w") as f:
        json.dump(data, f, indent=4, ensure_ascii=False)


def write_info(info: dict, local_dir: Path):
    write_json(info, local_dir / INFO_PATH)


def load_info(local_dir: Path) -> dict:
    info = load_json(local_dir / INFO_PATH)
    for ft in info["features"].values():
        ft["shape"] = tuple(ft["shape"])
    return info


def write_stats(stats: dict, local_dir: Path):
    serialized_stats = serialize_dict(stats)
    write_json(serialized_stats, local_dir / STATS_PATH)


def cast_stats_to_numpy(stats) -> dict[str, dict[str, np.ndarray]]:
    stats = {key: np.array(value) for key, value in flatten_dict(stats).items()}
    return unflatten_dict(stats)


def load_stats(local_dir: Path) -> dict[str, dict[str, np.ndarray]]:
    if not (local_dir / STATS_PATH).exists():
        return None
    stats = load_json(local_dir / STATS_PATH)
    return cast_stats_to_numpy(stats)


def write_hf_dataset(hf_dataset: Dataset, local_dir: Path):
    if get_hf_dataset_size_in_mb(hf_dataset) > DEFAULT_FILE_SIZE_IN_MB:
        raise NotImplementedError("Contact a maintainer.")
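
`write_stats` and `cast_stats_to_numpy` above rely on `serialize_dict`, `flatten_dict` and `unflatten_dict`, none of which appear in this diff. The sketch below is a self-contained approximation of that flatten/round-trip idea only; the real helpers may use a different key separator and handle more edge cases:

import numpy as np


def flatten_dict(d: dict, prefix: str = "", sep: str = "/") -> dict:
    """Flatten nested dicts into {'a/b': value} style keys (illustrative only)."""
    out = {}
    for key, value in d.items():
        new_key = f"{prefix}{sep}{key}" if prefix else key
        if isinstance(value, dict):
            out.update(flatten_dict(value, new_key, sep))
        else:
            out[new_key] = value
    return out


def unflatten_dict(d: dict, sep: str = "/") -> dict:
    """Inverse of flatten_dict."""
    out = {}
    for flat_key, value in d.items():
        node = out
        *parents, leaf = flat_key.split(sep)
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return out


stats = {"observation.state": {"mean": [0.0, 1.0], "std": [1.0, 1.0]}}
as_numpy = {k: np.array(v) for k, v in flatten_dict(stats).items()}
roundtrip = unflatten_dict(as_numpy)
assert np.array_equal(roundtrip["observation.state"]["mean"], np.array([0.0, 1.0]))
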
@@ -292,7 +299,6 @@ def write_tasks(tasks: pandas.DataFrame, local_dir: Path):
    tasks.to_parquet(path)


def load_tasks(local_dir: Path):
    tasks = pd.read_parquet(local_dir / DEFAULT_TASKS_PATH)
    return tasks
@@ -30,7 +30,7 @@ from huggingface_hub import HfApi, snapshot_download

from lerobot.common.constants import HF_LEROBOT_HOME
from lerobot.common.datasets.compute_stats import aggregate_stats
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
from lerobot.common.datasets.utils import (
    DEFAULT_CHUNK_SIZE,
    DEFAULT_DATA_PATH,
@@ -43,11 +43,7 @@ from lerobot.common.datasets.utils import (
    get_parquet_num_frames,
    get_video_duration_in_s,
    get_video_size_in_mb,
    legacy_load_episodes,
    legacy_load_episodes_stats,
    legacy_load_tasks,
    load_info,
    serialize_dict,
    update_chunk_file_indices,
    write_episodes,
    write_info,
@@ -111,10 +107,12 @@ def load_jsonlines(fpath: Path) -> list[Any]:
    with jsonlines.open(fpath, "r") as reader:
        return list(reader)


def legacy_load_episodes(local_dir: Path) -> dict:
    episodes = load_jsonlines(local_dir / LEGACY_EPISODES_PATH)
    return {item["episode_index"]: item for item in sorted(episodes, key=lambda x: x["episode_index"])}


def legacy_load_episodes_stats(local_dir: Path) -> dict:
    episodes_stats = load_jsonlines(local_dir / LEGACY_EPISODES_STATS_PATH)
    return {
@@ -122,6 +120,7 @@ def legacy_load_episodes_stats(local_dir: Path) -> dict:
        for item in sorted(episodes_stats, key=lambda x: x["episode_index"])
    }


def legacy_load_tasks(local_dir: Path) -> tuple[dict, dict]:
    tasks = load_jsonlines(local_dir / LEGACY_TASKS_PATH)
    tasks = {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
@@ -355,8 +354,6 @@ def convert_dataset(
    branch: str | None = None,
    num_workers: int = 4,
):
    root = HF_LEROBOT_HOME / repo_id
    old_root = HF_LEROBOT_HOME / f"{repo_id}_old"
    new_root = HF_LEROBOT_HOME / f"{repo_id}_v30"
@@ -21,9 +21,9 @@ def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
        repo_ids=[ds_0.repo_id, ds_1.repo_id],
        roots=[ds_0.root, ds_1.root],
        aggr_repo_id=f"{DUMMY_REPO_ID}_aggr",
        aggr_root=tmp_path / "test_aggr"
        aggr_root=tmp_path / "test_aggr",
    )

    aggr_ds = LeRobotDataset(f"{DUMMY_REPO_ID}_aggr", root=tmp_path / "test_aggr")
    for item in aggr_ds:
    for _ in aggr_ds:
        pass
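
Renaming the unused loop variable to `_` in the test silences the unused-variable lint while still iterating the aggregated dataset end to end. The same pattern in isolation:

samples = [{"index": i} for i in range(3)]  # stand-in for a dataset

# The loop body never uses the item, so `_` signals that intentionally.
for _ in samples:
    pass
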
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from pathlib import Path

import datasets