Write episodes as jsonlines

This commit is contained in:
Simon Alibert 2024-10-17 10:17:27 +02:00
parent c146ba936f
commit 50a75ad3fe
3 changed files with 33 additions and 8 deletions

View File

@ -93,6 +93,7 @@ import warnings
from pathlib import Path
import datasets
import jsonlines
import pyarrow.compute as pc
import pyarrow.parquet as pq
import torch
@ -160,6 +161,11 @@ def write_json(data: dict, fpath: Path) -> None:
json.dump(data, f, indent=4)
def write_jsonlines(data: list[dict], fpath: Path) -> None:
    """Write a collection of records to ``fpath`` using the jsonlines writer.

    Args:
        data: Records to serialize. The whole collection is handed to the
            writer's ``write_all``; the caller in this file passes a list of
            per-episode dicts, hence ``list[dict]`` rather than a single dict.
        fpath: Destination file, opened in write mode ("w").
    """
    # The context manager takes care of closing the writer on exit.
    with jsonlines.open(fpath, "w") as writer:
        writer.write_all(data)
def convert_stats_to_json(input_dir: Path, output_dir: Path) -> None:
safetensor_path = input_dir / "stats.safetensors"
stats = load_file(safetensor_path)
@ -617,7 +623,7 @@ def convert_dataset(
{"episode_index": ep_idx, "tasks": [tasks_by_episodes[ep_idx]], "length": episode_lengths[ep_idx]}
for ep_idx in episode_indices
]
write_json(episodes, v20_dir / "meta" / "episodes.json")
write_jsonlines(episodes, v20_dir / "meta" / "episodes.jsonl")
# Assemble metadata v2.0
metadata_v2_0 = {
@ -648,6 +654,9 @@ def convert_dataset(
with contextlib.suppress(EntryNotFoundError):
hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta_data", repo_type="dataset", revision=branch)
with contextlib.suppress(EntryNotFoundError):
hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta", repo_type="dataset", revision=branch)
hub_api.upload_folder(
repo_id=repo_id,
path_in_repo="data",
@ -681,6 +690,7 @@ def convert_dataset(
# - [X] Handle multitask datasets
# - [X] Handle hf hub repo limits (add chunks logic)
# - [X] Add test-branch
# - [X] Use jsonlines for episodes
# - [X] Add sanity checks (encoding, shapes)

28
poetry.lock generated
View File

@ -2620,6 +2620,20 @@ files = [
{file = "json5-0.9.25.tar.gz", hash = "sha256:548e41b9be043f9426776f05df8635a00fe06104ea51ed24b67f908856e151ae"},
]
[[package]]
name = "jsonlines"
version = "4.0.0"
description = "Library with helpers for the jsonlines file format"
optional = false
python-versions = ">=3.8"
files = [
{file = "jsonlines-4.0.0-py3-none-any.whl", hash = "sha256:185b334ff2ca5a91362993f42e83588a360cf95ce4b71a73548502bda52a7c55"},
{file = "jsonlines-4.0.0.tar.gz", hash = "sha256:0c6d2c09117550c089995247f605ae4cf77dd1533041d366351f6f298822ea74"},
]
[package.dependencies]
attrs = ">=19.2.0"
[[package]]
name = "jsonpointer"
version = "3.0.0"
@ -4216,10 +4230,10 @@ files = [
[package.dependencies]
numpy = [
{version = ">=1.26.0", markers = "python_version >= \"3.12\""},
{version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""},
{version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\" and python_version < \"3.11\""},
{version = ">=1.21.2", markers = "platform_system != \"Darwin\" and python_version >= \"3.10\" and python_version < \"3.11\""},
{version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""},
{version = ">=1.26.0", markers = "python_version >= \"3.12\""},
]
[[package]]
@ -4240,10 +4254,10 @@ files = [
[package.dependencies]
numpy = [
{version = ">=1.26.0", markers = "python_version >= \"3.12\""},
{version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""},
{version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\" and python_version < \"3.11\""},
{version = ">=1.21.2", markers = "platform_system != \"Darwin\" and python_version >= \"3.10\" and python_version < \"3.11\""},
{version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""},
{version = ">=1.26.0", markers = "python_version >= \"3.12\""},
]
[[package]]
@ -4332,9 +4346,9 @@ files = [
[package.dependencies]
numpy = [
{version = ">=1.26.0", markers = "python_version >= \"3.12\""},
{version = ">=1.23.2", markers = "python_version == \"3.11\""},
{version = ">=1.22.4", markers = "python_version < \"3.11\""},
{version = ">=1.23.2", markers = "python_version == \"3.11\""},
{version = ">=1.26.0", markers = "python_version >= \"3.12\""},
]
python-dateutil = ">=2.8.2"
pytz = ">=2020.1"
@ -7562,4 +7576,4 @@ xarm = ["gym-xarm"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.10,<3.13"
content-hash = "f64e01ce021ae77baa2c9bb82cbd2dd6035ab01a1500207da7acdb7f9d0772e1"
content-hash = "b79d32bec01c53a3ca48548b85e6f991c9d8fc091f3f528e0b54c6e9fac63ff9"

View File

@ -69,6 +69,7 @@ pyrealsense2 = {version = ">=2.55.1.6486", markers = "sys_platform != 'darwin'",
pyrender = {git = "https://github.com/mmatl/pyrender.git", markers = "sys_platform == 'linux'", optional = true}
hello-robot-stretch-body = {version = ">=0.7.27", markers = "sys_platform == 'linux'", optional = true}
pyserial = {version = ">=3.5", optional = true}
jsonlines = "^4.0.0"
[tool.poetry.extras]