Add load_metadata
This commit is contained in:
parent
1a51505ec6
commit
bce3dc3bfa
|
@ -15,16 +15,16 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import json
|
import json
|
||||||
import warnings
|
import warnings
|
||||||
from functools import cache
|
|
||||||
from itertools import accumulate
|
from itertools import accumulate
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from pprint import pformat
|
from pprint import pformat
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
import datasets
|
import datasets
|
||||||
|
import jsonlines
|
||||||
import torch
|
import torch
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
from huggingface_hub import DatasetCard, HfApi, hf_hub_download
|
from huggingface_hub import DatasetCard, HfApi
|
||||||
from PIL import Image as PILImage
|
from PIL import Image as PILImage
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
|
|
||||||
|
@ -96,7 +96,6 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
|
||||||
return items_dict
|
return items_dict
|
||||||
|
|
||||||
|
|
||||||
@cache
|
|
||||||
def get_hub_safe_version(repo_id: str, version: str, enforce_v2: bool = True) -> str:
|
def get_hub_safe_version(repo_id: str, version: str, enforce_v2: bool = True) -> str:
|
||||||
num_version = float(version.strip("v"))
|
num_version = float(version.strip("v"))
|
||||||
if num_version < 2 and enforce_v2:
|
if num_version < 2 and enforce_v2:
|
||||||
|
@ -144,50 +143,30 @@ def load_hf_dataset(
|
||||||
return hf_dataset
|
return hf_dataset
|
||||||
|
|
||||||
|
|
||||||
def load_stats(repo_id: str, version: str, local_dir: Path) -> dict[str, dict[str, torch.Tensor]]:
|
def load_metadata(local_dir: Path) -> tuple[dict | list]:
|
||||||
"""stats contains the statistics per modality computed over the full dataset, such as max, min, mean, std
|
"""Loads metadata files from a dataset."""
|
||||||
|
info_path = local_dir / "meta/info.json"
|
||||||
|
episodes_path = local_dir / "meta/episodes.jsonl"
|
||||||
|
stats_path = local_dir / "meta/stats.json"
|
||||||
|
tasks_path = local_dir / "meta/tasks.json"
|
||||||
|
|
||||||
Example:
|
with open(info_path) as f:
|
||||||
```python
|
info = json.load(f)
|
||||||
normalized_action = (action - stats["action"]["mean"]) / stats["action"]["std"]
|
|
||||||
```
|
with jsonlines.open(episodes_path, "r") as reader:
|
||||||
"""
|
episode_dicts = list(reader)
|
||||||
fpath = hf_hub_download(
|
|
||||||
repo_id, filename="meta/stats.json", local_dir=local_dir, repo_type="dataset", revision=version
|
with open(stats_path) as f:
|
||||||
)
|
|
||||||
with open(fpath) as f:
|
|
||||||
stats = json.load(f)
|
stats = json.load(f)
|
||||||
|
|
||||||
stats = flatten_dict(stats)
|
with open(tasks_path) as f:
|
||||||
stats = {key: torch.tensor(value) for key, value in stats.items()}
|
|
||||||
return unflatten_dict(stats)
|
|
||||||
|
|
||||||
|
|
||||||
def load_info(repo_id: str, version: str, local_dir: Path) -> dict:
|
|
||||||
"""info contains structural information about the dataset. It should be the reference and
|
|
||||||
act as the 'source of truth' for what's inside the dataset.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
```python
|
|
||||||
print("frame per second used to collect the video", info["fps"])
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
fpath = hf_hub_download(
|
|
||||||
repo_id, filename="meta/info.json", local_dir=local_dir, repo_type="dataset", revision=version
|
|
||||||
)
|
|
||||||
with open(fpath) as f:
|
|
||||||
return json.load(f)
|
|
||||||
|
|
||||||
|
|
||||||
def load_tasks(repo_id: str, version: str, local_dir: Path) -> dict:
|
|
||||||
"""tasks contains all the tasks of the dataset, indexed by their task_index."""
|
|
||||||
fpath = hf_hub_download(
|
|
||||||
repo_id, filename="meta/tasks.json", local_dir=local_dir, repo_type="dataset", revision=version
|
|
||||||
)
|
|
||||||
with open(fpath) as f:
|
|
||||||
tasks = json.load(f)
|
tasks = json.load(f)
|
||||||
|
|
||||||
return {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
|
stats = {key: torch.tensor(value) for key, value in flatten_dict(stats).items()}
|
||||||
|
stats = unflatten_dict(stats)
|
||||||
|
tasks = {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
|
||||||
|
|
||||||
|
return info, episode_dicts, stats, tasks
|
||||||
|
|
||||||
|
|
||||||
def get_episode_data_index(episodes: list, episode_dicts: list[dict]) -> dict[str, torch.Tensor]:
|
def get_episode_data_index(episodes: list, episode_dicts: list[dict]) -> dict[str, torch.Tensor]:
|
||||||
|
|
Loading…
Reference in New Issue