Add load_metadata
parent 1a51505ec6
commit bce3dc3bfa
@@ -15,16 +15,16 @@
 # limitations under the License.
 import json
 import warnings
 from functools import cache
 from itertools import accumulate
 from pathlib import Path
 from pprint import pformat
 from typing import Dict

 import datasets
 import jsonlines
 import torch
 from datasets import load_dataset
-from huggingface_hub import DatasetCard, HfApi, hf_hub_download
+from huggingface_hub import DatasetCard, HfApi
 from PIL import Image as PILImage
 from torchvision import transforms
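`jsonlines` in the imports above is what the new `load_metadata` below relies on to read `meta/episodes.jsonl`, where each line is one JSON object describing an episode. A minimal standalone sketch of that read pattern; the file name and keys here are illustrative, not a fixed schema:

```python
import jsonlines

# Write a small illustrative episodes.jsonl: one JSON object per line.
# The keys are hypothetical examples, not a guaranteed schema.
with jsonlines.open("episodes.jsonl", mode="w") as writer:
    writer.write({"episode_index": 0, "length": 57})
    writer.write({"episode_index": 1, "length": 88})

# Reading mirrors what load_metadata does below: each line becomes one dict.
with jsonlines.open("episodes.jsonl", mode="r") as reader:
    episode_dicts = list(reader)

print(episode_dicts)
# [{'episode_index': 0, 'length': 57}, {'episode_index': 1, 'length': 88}]
```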
@@ -96,7 +96,6 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
     return items_dict


 @cache
 def get_hub_safe_version(repo_id: str, version: str, enforce_v2: bool = True) -> str:
     num_version = float(version.strip("v"))
     if num_version < 2 and enforce_v2:
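For reference, the check in `get_hub_safe_version` above parses the numeric part of a `vX.Y` version string before comparing it. A minimal sketch of that parsing with illustrative inputs; the helper name is hypothetical:

```python
def is_v2_or_newer(version: str) -> bool:
    # Same parsing as in get_hub_safe_version: "v1.6" -> 1.6, "v2.0" -> 2.0.
    num_version = float(version.strip("v"))
    return num_version >= 2


print(is_v2_or_newer("v1.6"))  # False -> would hit the enforce_v2 branch
print(is_v2_or_newer("v2.0"))  # True
```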
@@ -144,50 +143,30 @@ def load_hf_dataset(
     return hf_dataset


-def load_stats(repo_id: str, version: str, local_dir: Path) -> dict[str, dict[str, torch.Tensor]]:
-    """stats contains the statistics per modality computed over the full dataset, such as max, min, mean, std
-
-    Example:
-    ```python
-    normalized_action = (action - stats["action"]["mean"]) / stats["action"]["std"]
-    ```
-    """
-    fpath = hf_hub_download(
-        repo_id, filename="meta/stats.json", local_dir=local_dir, repo_type="dataset", revision=version
-    )
-    with open(fpath) as f:
-        stats = json.load(f)
-
-    stats = flatten_dict(stats)
-    stats = {key: torch.tensor(value) for key, value in stats.items()}
-    return unflatten_dict(stats)
-
-
-def load_info(repo_id: str, version: str, local_dir: Path) -> dict:
-    """info contains structural information about the dataset. It should be the reference and
-    act as the 'source of thruth' for what's inside the dataset.
-
-    Example:
-    ```python
-    print("frame per second used to collect the video", info["fps"])
-    ```
-    """
-    fpath = hf_hub_download(
-        repo_id, filename="meta/info.json", local_dir=local_dir, repo_type="dataset", revision=version
-    )
-    with open(fpath) as f:
-        return json.load(f)
-
-
-def load_tasks(repo_id: str, version: str, local_dir: Path) -> dict:
-    """tasks contains all the tasks of the dataset, indexed by their task_index."""
-    fpath = hf_hub_download(
-        repo_id, filename="meta/tasks.json", local_dir=local_dir, repo_type="dataset", revision=version
-    )
-    with open(fpath) as f:
-        tasks = json.load(f)
-
-    return {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
+def load_metadata(local_dir: Path) -> tuple[dict | list]:
+    """Loads metadata files from a dataset."""
+    info_path = local_dir / "meta/info.json"
+    episodes_path = local_dir / "meta/episodes.jsonl"
+    stats_path = local_dir / "meta/stats.json"
+    tasks_path = local_dir / "meta/tasks.json"
+
+    with open(info_path) as f:
+        info = json.load(f)
+
+    with jsonlines.open(episodes_path, "r") as reader:
+        episode_dicts = list(reader)
+
+    with open(stats_path) as f:
+        stats = json.load(f)
+
+    with open(tasks_path) as f:
+        tasks = json.load(f)
+
+    stats = {key: torch.tensor(value) for key, value in flatten_dict(stats).items()}
+    stats = unflatten_dict(stats)
+    tasks = {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
+
+    return info, episode_dicts, stats, tasks


 def get_episode_data_index(episodes: list, episode_dicts: list[dict]) -> dict[str, torch.Tensor]:
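A usage sketch for the new `load_metadata`, assuming the function is in scope and a dataset has already been downloaded locally with the four `meta/` files listed above; the directory path is hypothetical and the `"action"` key mirrors the example from the removed docstrings:

```python
from pathlib import Path

import torch

# Hypothetical local copy of a dataset containing meta/info.json,
# meta/episodes.jsonl, meta/stats.json and meta/tasks.json.
local_dir = Path("data/my_dataset")

info, episode_dicts, stats, tasks = load_metadata(local_dir)

print(info["fps"])         # structural info, e.g. frames per second
print(len(episode_dicts))  # one dict per line of episodes.jsonl
print(tasks)               # {task_index: task string}, sorted by task_index

# stats is a nested dict of tensors, so it can normalize a sample directly
# (the "action" key is illustrative, taken from the old load_stats docstring):
action = torch.zeros_like(stats["action"]["mean"])
normalized_action = (action - stats["action"]["mean"]) / stats["action"]["std"]
```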