Add load_metadata

This commit is contained in:
Simon Alibert 2024-10-18 14:59:09 +02:00
parent 1a51505ec6
commit bce3dc3bfa
1 changed file with 21 additions and 42 deletions

View File

@ -15,16 +15,16 @@
# limitations under the License.
import json
import warnings
from functools import cache
from itertools import accumulate
from pathlib import Path
from pprint import pformat
from typing import Dict
import datasets
import jsonlines
import torch
from datasets import load_dataset
from huggingface_hub import DatasetCard, HfApi, hf_hub_download
from huggingface_hub import DatasetCard, HfApi
from PIL import Image as PILImage
from torchvision import transforms
@ -96,7 +96,6 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
return items_dict
@cache
def get_hub_safe_version(repo_id: str, version: str, enforce_v2: bool = True) -> str:
num_version = float(version.strip("v"))
if num_version < 2 and enforce_v2:
@ -144,50 +143,30 @@ def load_hf_dataset(
return hf_dataset
def load_metadata(
    local_dir: Path,
) -> tuple[dict, list[dict], dict[str, dict[str, torch.Tensor]], dict]:
    """Load the metadata files of a dataset from a local directory.

    Four files are read from `local_dir / "meta"`:

    - `info.json`: structural information about the dataset. It should be the reference
      and act as the 'source of truth' for what's inside the dataset.
    - `episodes.jsonl`: one JSON object per line, returned as a list of dicts.
    - `stats.json`: the statistics per modality computed over the full dataset, such as
      max, min, mean, std. Every leaf value is converted to a `torch.Tensor`.
    - `tasks.json`: all the tasks of the dataset, returned as a dict indexed by their
      `task_index`.

    Example:
    ```python
    info, episode_dicts, stats, tasks = load_metadata(local_dir)
    print("frame per second used to collect the video", info["fps"])
    normalized_action = (action - stats["action"]["mean"]) / stats["action"]["std"]
    ```

    Args:
        local_dir: Root directory of the dataset; the metadata files are looked up in
            its `meta` sub-directory.

    Returns:
        A 4-tuple `(info, episode_dicts, stats, tasks)`.

    Raises:
        FileNotFoundError: If one of the JSON metadata files does not exist.
    """
    meta_dir = local_dir / "meta"

    with open(meta_dir / "info.json") as f:
        info = json.load(f)

    with jsonlines.open(meta_dir / "episodes.jsonl", "r") as reader:
        episode_dicts = list(reader)

    with open(meta_dir / "stats.json") as f:
        raw_stats = json.load(f)

    with open(meta_dir / "tasks.json") as f:
        raw_tasks = json.load(f)

    # Flatten so every leaf can be converted to a tensor in one pass, then restore
    # the original nested layout.
    stats = unflatten_dict({key: torch.tensor(value) for key, value in flatten_dict(raw_stats).items()})

    # Sorting first makes the mapping's iteration order follow increasing task_index.
    tasks = {item["task_index"]: item["task"] for item in sorted(raw_tasks, key=lambda x: x["task_index"])}

    return info, episode_dicts, stats, tasks
def get_episode_data_index(episodes: list, episode_dicts: list[dict]) -> dict[str, torch.Tensor]: