Add json/jsonl io functions

Simon Alibert 2024-10-24 11:49:53 +02:00
parent 8bcf81fa24
commit 18ffa4248b
3 changed files with 28 additions and 32 deletions

View File

@@ -36,7 +36,7 @@ from lerobot.common.datasets.utils import (
     STATS_PATH,
     TASKS_PATH,
     _get_info_from_robot,
-    append_jsonl,
+    append_jsonlines,
     check_delta_timestamps,
     check_timestamps_sync,
     check_version_compatibility,
@@ -648,7 +648,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
             "task_index": task_index,
             "task": task,
         }
-        append_jsonl(task_dict, self.root / TASKS_PATH)
+        append_jsonlines(task_dict, self.root / TASKS_PATH)

         chunk = self.get_episode_chunk(episode_index)
         if chunk >= self.total_chunks:
@@ -664,7 +664,7 @@
             "length": episode_length,
         }
         self.episode_dicts.append(episode_dict)
-        append_jsonl(episode_dict, self.root / EPISODES_PATH)
+        append_jsonlines(episode_dict, self.root / EPISODES_PATH)

     def clear_episode_buffer(self) -> None:
         episode_index = self.episode_buffer["episode_index"]
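For context, a quick sketch (not part of the diff) of what the renamed helper produces: each call to append_jsonlines writes exactly one JSON object per line, so the tasks and episodes files grow record by record. The file path and record contents below are hypothetical, mirroring the task_dict shape above.

# Hypothetical usage sketch; append_jsonlines is copied from the diff below.
from pathlib import Path
import jsonlines

def append_jsonlines(data: dict, fpath: Path) -> None:
    fpath.parent.mkdir(exist_ok=True, parents=True)
    with jsonlines.open(fpath, "a") as writer:
        writer.write(data)

append_jsonlines({"task_index": 0, "task": "pick up the cube"}, Path("meta/tasks.jsonl"))
append_jsonlines({"task_index": 1, "task": "stack the blocks"}, Path("meta/tasks.jsonl"))
# meta/tasks.jsonl now holds:
# {"task_index": 0, "task": "pick up the cube"}
# {"task_index": 1, "task": "stack the blocks"}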

View File

@@ -18,7 +18,7 @@ import warnings
 from itertools import accumulate
 from pathlib import Path
 from pprint import pformat
-from typing import Dict
+from typing import Any, Dict

 import datasets
 import jsonlines
@ -80,13 +80,29 @@ def unflatten_dict(d: dict, sep: str = "/") -> dict:
return outdict return outdict
def load_json(fpath: Path) -> Any:
with open(fpath) as f:
return json.load(f)
def write_json(data: dict, fpath: Path) -> None: def write_json(data: dict, fpath: Path) -> None:
fpath.parent.mkdir(exist_ok=True, parents=True) fpath.parent.mkdir(exist_ok=True, parents=True)
with open(fpath, "w") as f: with open(fpath, "w") as f:
json.dump(data, f, indent=4, ensure_ascii=False) json.dump(data, f, indent=4, ensure_ascii=False)
def append_jsonl(data: dict, fpath: Path) -> None: def load_jsonlines(fpath: Path) -> list[Any]:
with jsonlines.open(fpath, "r") as reader:
return list(reader)
def write_jsonlines(data: dict, fpath: Path) -> None:
fpath.parent.mkdir(exist_ok=True, parents=True)
with jsonlines.open(fpath, "w") as writer:
writer.write_all(data)
def append_jsonlines(data: dict, fpath: Path) -> None:
fpath.parent.mkdir(exist_ok=True, parents=True) fpath.parent.mkdir(exist_ok=True, parents=True)
with jsonlines.open(fpath, "a") as writer: with jsonlines.open(fpath, "a") as writer:
writer.write(data) writer.write(data)
@@ -170,27 +186,22 @@ def get_hub_safe_version(repo_id: str, version: str, enforce_v2: bool = True) ->
 def load_info(local_dir: Path) -> dict:
-    with open(local_dir / INFO_PATH) as f:
-        return json.load(f)
+    return load_json(local_dir / INFO_PATH)


 def load_stats(local_dir: Path) -> dict:
-    with open(local_dir / STATS_PATH) as f:
-        stats = json.load(f)
+    stats = load_json(local_dir / STATS_PATH)
     stats = {key: torch.tensor(value) for key, value in flatten_dict(stats).items()}
     return unflatten_dict(stats)


 def load_tasks(local_dir: Path) -> dict:
-    with jsonlines.open(local_dir / TASKS_PATH, "r") as reader:
-        tasks = list(reader)
+    tasks = load_jsonlines(local_dir / TASKS_PATH)
     return {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}


 def load_episode_dicts(local_dir: Path) -> dict:
-    with jsonlines.open(local_dir / EPISODES_PATH, "r") as reader:
-        return list(reader)
+    return load_jsonlines(local_dir / EPISODES_PATH)


 def _get_info_from_robot(robot: Robot, use_videos: bool) -> tuple[list | dict]:
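Taken together, the new helpers give symmetric read/write pairs for both formats. Note that write_jsonlines calls writer.write_all, so despite its dict annotation it expects an iterable of records. A hedged round-trip sketch (paths and contents are illustrative, not from the commit):

# Illustrative round trip using the helpers added in this commit.
from pathlib import Path
from lerobot.common.datasets.utils import (
    load_json, write_json, load_jsonlines, write_jsonlines,
)

root = Path("/tmp/demo_dataset")

# json pair: whole-document write, then read back.
write_json({"codebase_version": "v2.0", "fps": 30}, root / "info.json")
info = load_json(root / "info.json")  # -> {"codebase_version": "v2.0", "fps": 30}

# jsonlines pair: write a list of records (one JSON object per line), read back as a list.
episodes = [
    {"episode_index": 0, "tasks": ["pick"], "length": 120},
    {"episode_index": 1, "tasks": ["place"], "length": 95},
]
write_jsonlines(episodes, root / "episodes.jsonl")
assert load_jsonlines(root / "episodes.jsonl") == episodes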

View File

@@ -110,7 +110,6 @@ import warnings
 from pathlib import Path

 import datasets
-import jsonlines
 import pyarrow.compute as pc
 import pyarrow.parquet as pq
 import torch
@@ -132,7 +131,10 @@ from lerobot.common.datasets.utils import (
     create_lerobot_dataset_card,
     flatten_dict,
     get_hub_safe_version,
+    load_json,
     unflatten_dict,
+    write_json,
+    write_jsonlines,
 )
 from lerobot.common.datasets.video_utils import VideoFrame  # noqa: F401
 from lerobot.common.utils.utils import init_hydra_config
@@ -175,23 +177,6 @@ def parse_robot_config(config_path: Path, config_overrides: list[str] | None = N
     }


-def load_json(fpath: Path) -> dict:
-    with open(fpath) as f:
-        return json.load(f)
-
-
-def write_json(data: dict, fpath: Path) -> None:
-    fpath.parent.mkdir(exist_ok=True, parents=True)
-    with open(fpath, "w") as f:
-        json.dump(data, f, indent=4, ensure_ascii=False)
-
-
-def write_jsonlines(data: dict, fpath: Path) -> None:
-    fpath.parent.mkdir(exist_ok=True, parents=True)
-    with jsonlines.open(fpath, "w") as writer:
-        writer.write_all(data)
-
-
 def convert_stats_to_json(v1_dir: Path, v2_dir: Path) -> None:
     safetensor_path = v1_dir / V1_STATS_PATH
     stats = load_file(safetensor_path)
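With the duplicated local copies deleted, this script now imports load_json, write_json, and write_jsonlines from lerobot.common.datasets.utils. As a hedged sketch of why write_json is still needed here (the rest of convert_stats_to_json is outside this diff), the safetensors stats loaded above can be serialized and written with the shared helper:

# Hypothetical continuation of convert_stats_to_json; not shown in this diff.
# load_file returns a flat dict of tensors; convert them to plain lists,
# restore the nested structure, and write with the shared json helper.
serialized_stats = {key: value.tolist() for key, value in stats.items()}
write_json(unflatten_dict(serialized_stats), v2_dir / STATS_PATH)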