pre-commit run --all-files

parent 5a6ea09248
commit 4acf99f622

@@ -26,7 +26,12 @@ from datatrove.pipeline.base import PipelineStep
 from examples.port_datasets.droid_rlds.port_droid import DROID_SHARDS
 from lerobot.common.datasets.aggregate import validate_all_metadata
 from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
-from lerobot.common.datasets.utils import write_episode, legacy_write_episode_stats, write_info, legacy_write_task
+from lerobot.common.datasets.utils import (
+    legacy_write_episode_stats,
+    legacy_write_task,
+    write_episode,
+    write_info,
+)
 from lerobot.common.utils.utils import init_logging
|
|
|
@@ -1,13 +1,21 @@
 import logging
-from pathlib import Path
 import shutil
+from pathlib import Path
 
 import pandas as pd
 import tqdm
 
 from lerobot.common.datasets.compute_stats import aggregate_stats
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
-from lerobot.common.datasets.utils import DEFAULT_CHUNK_SIZE, DEFAULT_DATA_PATH, DEFAULT_EPISODES_PATH, DEFAULT_VIDEO_PATH, write_episode, legacy_write_episode_stats, write_info, legacy_write_task, write_stats, write_tasks
+from lerobot.common.datasets.utils import (
+    DEFAULT_CHUNK_SIZE,
+    DEFAULT_DATA_PATH,
+    DEFAULT_EPISODES_PATH,
+    DEFAULT_VIDEO_PATH,
+    write_info,
+    write_stats,
+    write_tasks,
+)
 from lerobot.common.utils.utils import init_logging
|
@@ -32,14 +40,17 @@ def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]):
 
     return fps, robot_type, features
 
 
 def get_update_episode_and_task_func(episode_index_to_add, old_tasks, new_tasks):
     def _update(row):
         row["episode_index"] = row["episode_index"] + episode_index_to_add
         task = old_tasks.iloc[row["task_index"]].name
         row["task_index"] = new_tasks.loc[task].task_index.item()
         return row
 
     return _update
 
 
 def get_update_meta_func(
     meta_chunk_index_to_add,
     meta_file_index_to_add,
|
@@ -55,20 +66,26 @@ def get_update_meta_func(
         row["data/chunk_index"] = row["data/chunk_index"] + data_chunk_index_to_add
         row["data/file_index"] = row["data/file_index"] + data_file_index_to_add
         for key in videos_chunk_index_to_add:
-            row[f"videos/{key}/chunk_index"] = row[f"videos/{key}/chunk_index"] + videos_chunk_index_to_add[key]
+            row[f"videos/{key}/chunk_index"] = (
+                row[f"videos/{key}/chunk_index"] + videos_chunk_index_to_add[key]
+            )
             row[f"videos/{key}/file_index"] = row[f"videos/{key}/file_index"] + videos_file_index_to_add[key]
         row["dataset_from_index"] = row["dataset_from_index"] + frame_index_to_add
         row["dataset_to_index"] = row["dataset_to_index"] + frame_index_to_add
         return row
 
     return _update
 
 
-def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path]=None, aggr_root=None):
+def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path] = None, aggr_root=None):
     logging.info("Start aggregate_datasets")
 
     if roots is None:
         all_metadata = [LeRobotDatasetMetadata(repo_id) for repo_id in repo_ids]
     else:
-        all_metadata = [LeRobotDatasetMetadata(repo_id, root=root) for repo_id, root in zip(repo_ids, roots)]
+        all_metadata = [
+            LeRobotDatasetMetadata(repo_id, root=root) for repo_id, root in zip(repo_ids, roots, strict=False)
+        ]
 
     fps, robot_type, features = validate_all_metadata(all_metadata)
     video_keys = [key for key in features if features[key]["dtype"] == "video"]
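
The strict=False added to zip(...) above makes the long-standing truncation behaviour explicit rather than implicit (this is what Ruff's zip-without-explicit-strict check, B905, asks for). A minimal sketch of the difference, using made-up repo IDs and roots rather than anything from this repository:

repo_ids = ["user/ds_0", "user/ds_1"]
roots = ["/data/ds_0"]  # deliberately one entry short

# strict=False keeps the old behaviour: zip silently stops at the shorter input.
print(list(zip(repo_ids, roots, strict=False)))  # [('user/ds_0', '/data/ds_0')]

# strict=True (Python 3.10+) turns the length mismatch into an error instead.
try:
    list(zip(repo_ids, roots, strict=True))
except ValueError as err:
    print(err)  # zip() argument 2 is shorter than argument 1
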
|
@@ -96,12 +113,18 @@ def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path]
     aggr_data_chunk_idx = 0
     aggr_data_file_idx = 0
 
-    aggr_videos_chunk_idx = {key: 0 for key in video_keys}
-    aggr_videos_file_idx = {key: 0 for key in video_keys}
+    aggr_videos_chunk_idx = dict.fromkeys(video_keys, 0)
+    aggr_videos_file_idx = dict.fromkeys(video_keys, 0)
 
     for meta in tqdm.tqdm(all_metadata, desc="Copy data and videos"):
-        meta_chunk_file_ids = set([(c,f) for c, f in zip(meta.episodes["meta/episodes/chunk_index"], meta.episodes["meta/episodes/file_index"])])
+        meta_chunk_file_ids = {
+            (c, f)
+            for c, f in zip(
+                meta.episodes["meta/episodes/chunk_index"],
+                meta.episodes["meta/episodes/file_index"],
+                strict=False,
+            )
+        }
         for chunk_idx, file_idx in meta_chunk_file_ids:
             path = meta.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
             df = pd.read_parquet(path)
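
Two of the rewrites above are pure refactors with identical results: dict.fromkeys(keys, value) replaces a dict comprehension that assigns the same constant to every key, and a set comprehension replaces set([...]). A small illustration with hypothetical keys (not the real video keys of any dataset):

video_keys = ["observation.images.top", "observation.images.wrist"]  # hypothetical keys

# Equivalent when the shared value is immutable, as with the 0 counters here.
assert dict.fromkeys(video_keys, 0) == {key: 0 for key in video_keys}

# Caveat worth knowing: with a mutable default, fromkeys shares ONE object.
shared = dict.fromkeys(video_keys, [])
shared[video_keys[0]].append(1)
print(shared[video_keys[1]])  # [1] -- both keys point at the same list

# Set comprehension vs. building a throwaway list and wrapping it in set().
pairs = [(0, 0), (0, 1), (0, 0)]
assert {(c, f) for c, f in pairs} == set([(c, f) for c, f in pairs])
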
|
@@ -115,11 +138,13 @@ def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path]
                 num_frames,
             )
             df = df.apply(update_meta_func, axis=1)
 
-            aggr_path = aggr_root / DEFAULT_EPISODES_PATH.format(chunk_index=aggr_meta_chunk_idx, file_index=aggr_meta_file_idx)
+            aggr_path = aggr_root / DEFAULT_EPISODES_PATH.format(
+                chunk_index=aggr_meta_chunk_idx, file_index=aggr_meta_file_idx
+            )
             aggr_path.parent.mkdir(parents=True, exist_ok=True)
             df.to_parquet(aggr_path)
 
             aggr_meta_file_idx += 1
             if aggr_meta_file_idx >= DEFAULT_CHUNK_SIZE:
                 aggr_meta_file_idx = 0
|
@@ -127,10 +152,23 @@ def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path]
 
         # cp videos
         for key in video_keys:
-            video_chunk_file_ids = set([(c,f) for c, f in zip(meta.episodes[f"videos/{key}/chunk_index"], meta.episodes[f"videos/{key}/file_index"])])
+            video_chunk_file_ids = {
+                (c, f)
+                for c, f in zip(
+                    meta.episodes[f"videos/{key}/chunk_index"],
+                    meta.episodes[f"videos/{key}/file_index"],
+                    strict=False,
+                )
+            }
             for chunk_idx, file_idx in video_chunk_file_ids:
-                path = meta.root / DEFAULT_VIDEO_PATH.format(video_key=key, chunk_index=chunk_idx, file_index=file_idx)
-                aggr_path = aggr_root / DEFAULT_VIDEO_PATH.format(video_key=key, chunk_index=aggr_videos_chunk_idx[key], file_index=aggr_videos_file_idx[key])
+                path = meta.root / DEFAULT_VIDEO_PATH.format(
+                    video_key=key, chunk_index=chunk_idx, file_index=file_idx
+                )
+                aggr_path = aggr_root / DEFAULT_VIDEO_PATH.format(
+                    video_key=key,
+                    chunk_index=aggr_videos_chunk_idx[key],
+                    file_index=aggr_videos_file_idx[key],
+                )
                 aggr_path.parent.mkdir(parents=True, exist_ok=True)
                 shutil.copy(str(path), str(aggr_path))
 
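
For context, DEFAULT_VIDEO_PATH (like DEFAULT_DATA_PATH and DEFAULT_EPISODES_PATH) is a str.format template keyed by video_key, chunk_index and file_index; the wrapped calls above only change layout, not the resulting path. A sketch with an assumed template, since the real constant lives in lerobot.common.datasets.utils and may differ:

from pathlib import Path

# Assumed template for illustration only; the actual DEFAULT_VIDEO_PATH may differ.
DEFAULT_VIDEO_PATH = "videos/{video_key}/chunk-{chunk_index:03d}/file-{file_index:03d}.mp4"

aggr_root = Path("/tmp/aggr_dataset")  # hypothetical aggregation root
aggr_path = aggr_root / DEFAULT_VIDEO_PATH.format(
    video_key="observation.images.top",  # hypothetical video key
    chunk_index=0,
    file_index=12,
)
print(aggr_path)  # /tmp/aggr_dataset/videos/observation.images.top/chunk-000/file-012.mp4
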
|
@@ -142,14 +180,19 @@ def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path]
                     aggr_videos_file_idx[key] = 0
                     aggr_videos_chunk_idx[key] += 1
 
-        data_chunk_file_ids = set([(c,f) for c, f in zip(meta.episodes["data/chunk_index"], meta.episodes["data/file_index"])])
+        data_chunk_file_ids = {
+            (c, f)
+            for c, f in zip(meta.episodes["data/chunk_index"], meta.episodes["data/file_index"], strict=False)
+        }
         for chunk_idx, file_idx in data_chunk_file_ids:
             path = meta.root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
             df = pd.read_parquet(path)
             update_data_func = get_update_episode_and_task_func(num_episodes, meta.tasks, aggr_meta.tasks)
             df = df.apply(update_data_func, axis=1)
 
-            aggr_path = aggr_root / DEFAULT_DATA_PATH.format(chunk_index=aggr_data_chunk_idx, file_index=aggr_data_file_idx)
+            aggr_path = aggr_root / DEFAULT_DATA_PATH.format(
+                chunk_index=aggr_data_chunk_idx, file_index=aggr_data_file_idx
+            )
             aggr_path.parent.mkdir(parents=True, exist_ok=True)
             df.to_parquet(aggr_path)
 
|
@@ -157,7 +200,7 @@ def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path]
             if aggr_data_file_idx >= DEFAULT_CHUNK_SIZE:
                 aggr_data_file_idx = 0
                 aggr_data_chunk_idx += 1
 
         num_episodes += meta.total_episodes
         num_frames += meta.total_frames
 
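
The bookkeeping above (bump the file index, wrap into the next chunk once DEFAULT_CHUNK_SIZE is reached) is the same rollover pattern used for metadata, videos, and data files in aggregate_datasets. A standalone sketch of that pattern, independent of the real helpers in the repo:

DEFAULT_CHUNK_SIZE = 1000  # max number of files per chunk, value taken from the diff below

def advance(chunk_idx: int, file_idx: int) -> tuple[int, int]:
    """Move to the next (chunk, file) slot, wrapping file_idx into a new chunk."""
    file_idx += 1
    if file_idx >= DEFAULT_CHUNK_SIZE:
        file_idx = 0
        chunk_idx += 1
    return chunk_idx, file_idx

assert advance(0, 998) == (0, 999)
assert advance(0, 999) == (1, 0)  # rollover into chunk 1
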
|
|
|
@@ -75,7 +75,9 @@ class BackwardCompatibilityError(CompatibilityError):
         elif version.major == 2:
             message = V2_MESSAGE.format(repo_id=repo_id, version=version)
         else:
-            raise NotImplementedError("Contact the maintainer on [Discord](https://discord.com/invite/s3KuuzsPFb).")
+            raise NotImplementedError(
+                "Contact the maintainer on [Discord](https://discord.com/invite/s3KuuzsPFb)."
+            )
         super().__init__(message)
 
 
|
|
|
@@ -24,7 +24,7 @@ from collections.abc import Iterator
 from pathlib import Path
 from pprint import pformat
 from types import SimpleNamespace
-from typing import Any, Tuple
+from typing import Any
 
 import datasets
 import numpy as np
|
@@ -47,7 +47,7 @@ from lerobot.common.datasets.backward_compatibility import (
 )
 from lerobot.common.robot_devices.robots.utils import Robot
 from lerobot.common.utils.utils import is_valid_numpy_dtype_string
-from lerobot.configs.types import DictLike, FeatureType, PolicyFeature
+from lerobot.configs.types import FeatureType, PolicyFeature
 
 DEFAULT_CHUNK_SIZE = 1000  # Max number of files per chunk
 DEFAULT_FILE_SIZE_IN_MB = 500.0  # Max size per file
|
@@ -249,34 +249,41 @@ def load_json(fpath: Path) -> Any:
     with open(fpath) as f:
         return json.load(f)
 
 
 def write_json(data: dict, fpath: Path) -> None:
     fpath.parent.mkdir(exist_ok=True, parents=True)
     with open(fpath, "w") as f:
         json.dump(data, f, indent=4, ensure_ascii=False)
 
 
 def write_info(info: dict, local_dir: Path):
     write_json(info, local_dir / INFO_PATH)
 
 
 def load_info(local_dir: Path) -> dict:
     info = load_json(local_dir / INFO_PATH)
     for ft in info["features"].values():
         ft["shape"] = tuple(ft["shape"])
     return info
 
 
 def write_stats(stats: dict, local_dir: Path):
     serialized_stats = serialize_dict(stats)
     write_json(serialized_stats, local_dir / STATS_PATH)
 
 
 def cast_stats_to_numpy(stats) -> dict[str, dict[str, np.ndarray]]:
     stats = {key: np.array(value) for key, value in flatten_dict(stats).items()}
     return unflatten_dict(stats)
 
 
 def load_stats(local_dir: Path) -> dict[str, dict[str, np.ndarray]]:
     if not (local_dir / STATS_PATH).exists():
         return None
     stats = load_json(local_dir / STATS_PATH)
     return cast_stats_to_numpy(stats)
 
 
 def write_hf_dataset(hf_dataset: Dataset, local_dir: Path):
     if get_hf_dataset_size_in_mb(hf_dataset) > DEFAULT_FILE_SIZE_IN_MB:
         raise NotImplementedError("Contact a maintainer.")
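
cast_stats_to_numpy above flattens the nested stats dict, converts every leaf to an np.ndarray, and nests it back. A self-contained sketch of that round trip; flatten_dict and unflatten_dict here are simplified stand-ins for the real helpers in lerobot.common.datasets.utils, which may behave differently:

import numpy as np

def flatten_dict(d: dict, parent: str = "", sep: str = "/") -> dict:
    """Flatten nested dicts into {"outer/inner": leaf} pairs (simplified stand-in)."""
    out = {}
    for k, v in d.items():
        key = f"{parent}{sep}{k}" if parent else k
        if isinstance(v, dict):
            out.update(flatten_dict(v, key, sep))
        else:
            out[key] = v
    return out

def unflatten_dict(d: dict, sep: str = "/") -> dict:
    """Rebuild the nested structure from flattened keys (simplified stand-in)."""
    out = {}
    for key, v in d.items():
        *parents, leaf = key.split(sep)
        node = out
        for p in parents:
            node = node.setdefault(p, {})
        node[leaf] = v
    return out

stats = {"observation.state": {"mean": [0.1, 0.2], "std": [1.0, 1.0]}}  # toy stats
as_numpy = unflatten_dict({k: np.array(v) for k, v in flatten_dict(stats).items()})
print(type(as_numpy["observation.state"]["mean"]))  # <class 'numpy.ndarray'>
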
|
@@ -292,7 +299,6 @@ def write_tasks(tasks: pandas.DataFrame, local_dir: Path):
     tasks.to_parquet(path)
 
 
-
 def load_tasks(local_dir: Path):
     tasks = pd.read_parquet(local_dir / DEFAULT_TASKS_PATH)
     return tasks
|
|
|
@@ -30,7 +30,7 @@ from huggingface_hub import HfApi, snapshot_download
 
 from lerobot.common.constants import HF_LEROBOT_HOME
 from lerobot.common.datasets.compute_stats import aggregate_stats
-from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
+from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
 from lerobot.common.datasets.utils import (
     DEFAULT_CHUNK_SIZE,
     DEFAULT_DATA_PATH,
|
@@ -43,11 +43,7 @@ from lerobot.common.datasets.utils import (
     get_parquet_num_frames,
     get_video_duration_in_s,
     get_video_size_in_mb,
-    legacy_load_episodes,
-    legacy_load_episodes_stats,
-    legacy_load_tasks,
     load_info,
-    serialize_dict,
     update_chunk_file_indices,
     write_episodes,
     write_info,
|
@@ -111,10 +107,12 @@ def load_jsonlines(fpath: Path) -> list[Any]:
     with jsonlines.open(fpath, "r") as reader:
         return list(reader)
 
 
 def legacy_load_episodes(local_dir: Path) -> dict:
     episodes = load_jsonlines(local_dir / LEGACY_EPISODES_PATH)
     return {item["episode_index"]: item for item in sorted(episodes, key=lambda x: x["episode_index"])}
 
 
 def legacy_load_episodes_stats(local_dir: Path) -> dict:
     episodes_stats = load_jsonlines(local_dir / LEGACY_EPISODES_STATS_PATH)
     return {
|
@@ -122,6 +120,7 @@ def legacy_load_episodes_stats(local_dir: Path) -> dict:
         for item in sorted(episodes_stats, key=lambda x: x["episode_index"])
     }
 
 
 def legacy_load_tasks(local_dir: Path) -> tuple[dict, dict]:
     tasks = load_jsonlines(local_dir / LEGACY_TASKS_PATH)
     tasks = {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
|
@@ -355,8 +354,6 @@ def convert_dataset(
     branch: str | None = None,
     num_workers: int = 4,
 ):
-
-
     root = HF_LEROBOT_HOME / repo_id
     old_root = HF_LEROBOT_HOME / f"{repo_id}_old"
     new_root = HF_LEROBOT_HOME / f"{repo_id}_v30"
|
|
|
@@ -21,9 +21,9 @@ def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
         repo_ids=[ds_0.repo_id, ds_1.repo_id],
         roots=[ds_0.root, ds_1.root],
         aggr_repo_id=f"{DUMMY_REPO_ID}_aggr",
-        aggr_root=tmp_path / "test_aggr"
+        aggr_root=tmp_path / "test_aggr",
     )
 
     aggr_ds = LeRobotDataset(f"{DUMMY_REPO_ID}_aggr", root=tmp_path / "test_aggr")
-    for item in aggr_ds:
+    for _ in aggr_ds:
         pass
|
|
|
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import json
 from pathlib import Path
 
 import datasets
|
|