Add json/jsonl io functions

Simon Alibert 2024-10-24 11:49:53 +02:00
parent 8bcf81fa24
commit 18ffa4248b
3 changed files with 28 additions and 32 deletions


@@ -36,7 +36,7 @@ from lerobot.common.datasets.utils import (
    STATS_PATH,
    TASKS_PATH,
    _get_info_from_robot,
    append_jsonl,
    append_jsonlines,
    check_delta_timestamps,
    check_timestamps_sync,
    check_version_compatibility,
@@ -648,7 +648,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
            "task_index": task_index,
            "task": task,
        }
        append_jsonl(task_dict, self.root / TASKS_PATH)
        append_jsonlines(task_dict, self.root / TASKS_PATH)

        chunk = self.get_episode_chunk(episode_index)
        if chunk >= self.total_chunks:
@@ -664,7 +664,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
            "length": episode_length,
        }
        self.episode_dicts.append(episode_dict)
        append_jsonl(episode_dict, self.root / EPISODES_PATH)
        append_jsonlines(episode_dict, self.root / EPISODES_PATH)

    def clear_episode_buffer(self) -> None:
        episode_index = self.episode_buffer["episode_index"]
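Both call sites append a single record per call, so the task and episode metadata files grow one JSON object per line rather than being rewritten. A minimal sketch (not part of the commit) of what that looks like for the episodes file, using an illustrative root and file layout rather than the real EPISODES_PATH constant:

from pathlib import Path

import jsonlines

root = Path("/tmp/demo_dataset")              # hypothetical dataset root
episodes_path = root / "meta/episodes.jsonl"  # illustrative layout only

# Appending writes exactly one JSON object per line, as append_jsonlines does.
episodes_path.parent.mkdir(exist_ok=True, parents=True)
with jsonlines.open(episodes_path, "a") as writer:
    writer.write({"episode_index": 0, "tasks": ["pick the cube"], "length": 512})
    writer.write({"episode_index": 1, "tasks": ["pick the cube"], "length": 487})

# Reading back yields the records in insertion order, one dict per line.
with jsonlines.open(episodes_path, "r") as reader:
    print(list(reader))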


@@ -18,7 +18,7 @@ import warnings
from itertools import accumulate
from pathlib import Path
from pprint import pformat
from typing import Dict
from typing import Any, Dict
import datasets
import jsonlines
@@ -80,13 +80,29 @@ def unflatten_dict(d: dict, sep: str = "/") -> dict:
    return outdict


def load_json(fpath: Path) -> Any:
    with open(fpath) as f:
        return json.load(f)


def write_json(data: dict, fpath: Path) -> None:
    fpath.parent.mkdir(exist_ok=True, parents=True)
    with open(fpath, "w") as f:
        json.dump(data, f, indent=4, ensure_ascii=False)


def append_jsonl(data: dict, fpath: Path) -> None:
def load_jsonlines(fpath: Path) -> list[Any]:
    with jsonlines.open(fpath, "r") as reader:
        return list(reader)


def write_jsonlines(data: dict, fpath: Path) -> None:
    fpath.parent.mkdir(exist_ok=True, parents=True)
    with jsonlines.open(fpath, "w") as writer:
        writer.write_all(data)


def append_jsonlines(data: dict, fpath: Path) -> None:
    fpath.parent.mkdir(exist_ok=True, parents=True)
    with jsonlines.open(fpath, "a") as writer:
        writer.write(data)
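A short usage sketch of the new helpers, assuming they are imported from lerobot.common.datasets.utils and using throwaway paths and made-up contents; note that write_jsonlines expects an iterable of records (writer.write_all) while append_jsonlines takes a single record (writer.write):

from pathlib import Path

demo = Path("/tmp/io_demo")                     # hypothetical scratch directory

info = {"codebase_version": "v2.0", "fps": 30}  # illustrative JSON content
write_json(info, demo / "info.json")            # pretty-printed, parent dirs created
assert load_json(demo / "info.json") == info

tasks = [{"task_index": 0, "task": "pick the cube"}]
write_jsonlines(tasks, demo / "tasks.jsonl")    # one line per record
append_jsonlines({"task_index": 1, "task": "open the drawer"}, demo / "tasks.jsonl")
print(load_jsonlines(demo / "tasks.jsonl"))
# [{'task_index': 0, 'task': 'pick the cube'}, {'task_index': 1, 'task': 'open the drawer'}]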
@@ -170,27 +186,22 @@ def get_hub_safe_version(repo_id: str, version: str, enforce_v2: bool = True) ->
def load_info(local_dir: Path) -> dict:
    with open(local_dir / INFO_PATH) as f:
        return json.load(f)
    return load_json(local_dir / INFO_PATH)


def load_stats(local_dir: Path) -> dict:
    with open(local_dir / STATS_PATH) as f:
        stats = json.load(f)
    stats = load_json(local_dir / STATS_PATH)
    stats = {key: torch.tensor(value) for key, value in flatten_dict(stats).items()}
    return unflatten_dict(stats)


def load_tasks(local_dir: Path) -> dict:
    with jsonlines.open(local_dir / TASKS_PATH, "r") as reader:
        tasks = list(reader)
    tasks = load_jsonlines(local_dir / TASKS_PATH)
    return {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}


def load_episode_dicts(local_dir: Path) -> dict:
    with jsonlines.open(local_dir / EPISODES_PATH, "r") as reader:
        return list(reader)
    return load_jsonlines(local_dir / EPISODES_PATH)


def _get_info_from_robot(robot: Robot, use_videos: bool) -> tuple[list | dict]:
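For reference, load_tasks turns those per-line task records into an index-to-description mapping sorted by task_index; a tiny sketch with made-up records:

tasks = [
    {"task_index": 1, "task": "open the drawer"},
    {"task_index": 0, "task": "pick the cube"},
]
task_map = {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
assert task_map == {0: "pick the cube", 1: "open the drawer"}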


@@ -110,7 +110,6 @@ import warnings
from pathlib import Path
import datasets
import jsonlines
import pyarrow.compute as pc
import pyarrow.parquet as pq
import torch
@@ -132,7 +131,10 @@ from lerobot.common.datasets.utils import (
    create_lerobot_dataset_card,
    flatten_dict,
    get_hub_safe_version,
    load_json,
    unflatten_dict,
    write_json,
    write_jsonlines,
)
from lerobot.common.datasets.video_utils import VideoFrame # noqa: F401
from lerobot.common.utils.utils import init_hydra_config
@@ -175,23 +177,6 @@ def parse_robot_config(config_path: Path, config_overrides: list[str] | None = None
    }


def load_json(fpath: Path) -> dict:
    with open(fpath) as f:
        return json.load(f)


def write_json(data: dict, fpath: Path) -> None:
    fpath.parent.mkdir(exist_ok=True, parents=True)
    with open(fpath, "w") as f:
        json.dump(data, f, indent=4, ensure_ascii=False)


def write_jsonlines(data: dict, fpath: Path) -> None:
    fpath.parent.mkdir(exist_ok=True, parents=True)
    with jsonlines.open(fpath, "w") as writer:
        writer.write_all(data)


def convert_stats_to_json(v1_dir: Path, v2_dir: Path) -> None:
    safetensor_path = v1_dir / V1_STATS_PATH
    stats = load_file(safetensor_path)