diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py index fbf4dd5f..bebc3c6f 100644 --- a/lerobot/common/datasets/utils.py +++ b/lerobot/common/datasets/utils.py @@ -15,16 +15,16 @@ # limitations under the License. import json import warnings -from functools import cache from itertools import accumulate from pathlib import Path from pprint import pformat from typing import Dict import datasets +import jsonlines import torch from datasets import load_dataset -from huggingface_hub import DatasetCard, HfApi, hf_hub_download +from huggingface_hub import DatasetCard, HfApi from PIL import Image as PILImage from torchvision import transforms @@ -96,7 +96,6 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]): return items_dict -@cache def get_hub_safe_version(repo_id: str, version: str, enforce_v2: bool = True) -> str: num_version = float(version.strip("v")) if num_version < 2 and enforce_v2: @@ -144,50 +143,30 @@ def load_hf_dataset( return hf_dataset -def load_stats(repo_id: str, version: str, local_dir: Path) -> dict[str, dict[str, torch.Tensor]]: - """stats contains the statistics per modality computed over the full dataset, such as max, min, mean, std +def load_metadata(local_dir: Path) -> tuple[dict | list]: + """Loads metadata files from a dataset.""" + info_path = local_dir / "meta/info.json" + episodes_path = local_dir / "meta/episodes.jsonl" + stats_path = local_dir / "meta/stats.json" + tasks_path = local_dir / "meta/tasks.json" - Example: - ```python - normalized_action = (action - stats["action"]["mean"]) / stats["action"]["std"] - ``` - """ - fpath = hf_hub_download( - repo_id, filename="meta/stats.json", local_dir=local_dir, repo_type="dataset", revision=version - ) - with open(fpath) as f: + with open(info_path) as f: + info = json.load(f) + + with jsonlines.open(episodes_path, "r") as reader: + episode_dicts = list(reader) + + with open(stats_path) as f: stats = json.load(f) - stats = flatten_dict(stats) - stats = {key: torch.tensor(value) for key, value in stats.items()} - return unflatten_dict(stats) - - -def load_info(repo_id: str, version: str, local_dir: Path) -> dict: - """info contains structural information about the dataset. It should be the reference and - act as the 'source of thruth' for what's inside the dataset. - - Example: - ```python - print("frame per second used to collect the video", info["fps"]) - ``` - """ - fpath = hf_hub_download( - repo_id, filename="meta/info.json", local_dir=local_dir, repo_type="dataset", revision=version - ) - with open(fpath) as f: - return json.load(f) - - -def load_tasks(repo_id: str, version: str, local_dir: Path) -> dict: - """tasks contains all the tasks of the dataset, indexed by their task_index.""" - fpath = hf_hub_download( - repo_id, filename="meta/tasks.json", local_dir=local_dir, repo_type="dataset", revision=version - ) - with open(fpath) as f: + with open(tasks_path) as f: tasks = json.load(f) - return {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])} + stats = {key: torch.tensor(value) for key, value in flatten_dict(stats).items()} + stats = unflatten_dict(stats) + tasks = {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])} + + return info, episode_dicts, stats, tasks def get_episode_data_index(episodes: list, episode_dicts: list[dict]) -> dict[str, torch.Tensor]: