Write episodes as jsonlines

This commit is contained in:
Simon Alibert 2024-10-17 10:17:27 +02:00
parent c146ba936f
commit 50a75ad3fe
3 changed files with 33 additions and 8 deletions

View File

@ -93,6 +93,7 @@ import warnings
from pathlib import Path from pathlib import Path
import datasets import datasets
import jsonlines
import pyarrow.compute as pc import pyarrow.compute as pc
import pyarrow.parquet as pq import pyarrow.parquet as pq
import torch import torch
@ -160,6 +161,11 @@ def write_json(data: dict, fpath: Path) -> None:
json.dump(data, f, indent=4) json.dump(data, f, indent=4)
def write_jsonlines(data: dict, fpath: Path) -> None:
with jsonlines.open(fpath, "w") as writer:
writer.write_all(data)
def convert_stats_to_json(input_dir: Path, output_dir: Path) -> None: def convert_stats_to_json(input_dir: Path, output_dir: Path) -> None:
safetensor_path = input_dir / "stats.safetensors" safetensor_path = input_dir / "stats.safetensors"
stats = load_file(safetensor_path) stats = load_file(safetensor_path)
@ -617,7 +623,7 @@ def convert_dataset(
{"episode_index": ep_idx, "tasks": [tasks_by_episodes[ep_idx]], "length": episode_lengths[ep_idx]} {"episode_index": ep_idx, "tasks": [tasks_by_episodes[ep_idx]], "length": episode_lengths[ep_idx]}
for ep_idx in episode_indices for ep_idx in episode_indices
] ]
write_json(episodes, v20_dir / "meta" / "episodes.json") write_jsonlines(episodes, v20_dir / "meta" / "episodes.jsonl")
# Assemble metadata v2.0 # Assemble metadata v2.0
metadata_v2_0 = { metadata_v2_0 = {
@ -648,6 +654,9 @@ def convert_dataset(
with contextlib.suppress(EntryNotFoundError): with contextlib.suppress(EntryNotFoundError):
hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta_data", repo_type="dataset", revision=branch) hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta_data", repo_type="dataset", revision=branch)
with contextlib.suppress(EntryNotFoundError):
hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta", repo_type="dataset", revision=branch)
hub_api.upload_folder( hub_api.upload_folder(
repo_id=repo_id, repo_id=repo_id,
path_in_repo="data", path_in_repo="data",
@ -681,6 +690,7 @@ def convert_dataset(
# - [X] Handle multitask datasets # - [X] Handle multitask datasets
# - [X] Handle hf hub repo limits (add chunks logic) # - [X] Handle hf hub repo limits (add chunks logic)
# - [X] Add test-branch # - [X] Add test-branch
# - [X] Use jsonlines for episodes
# - [X] Add sanity checks (encoding, shapes) # - [X] Add sanity checks (encoding, shapes)

28
poetry.lock generated
View File

@ -2620,6 +2620,20 @@ files = [
{file = "json5-0.9.25.tar.gz", hash = "sha256:548e41b9be043f9426776f05df8635a00fe06104ea51ed24b67f908856e151ae"}, {file = "json5-0.9.25.tar.gz", hash = "sha256:548e41b9be043f9426776f05df8635a00fe06104ea51ed24b67f908856e151ae"},
] ]
[[package]]
name = "jsonlines"
version = "4.0.0"
description = "Library with helpers for the jsonlines file format"
optional = false
python-versions = ">=3.8"
files = [
{file = "jsonlines-4.0.0-py3-none-any.whl", hash = "sha256:185b334ff2ca5a91362993f42e83588a360cf95ce4b71a73548502bda52a7c55"},
{file = "jsonlines-4.0.0.tar.gz", hash = "sha256:0c6d2c09117550c089995247f605ae4cf77dd1533041d366351f6f298822ea74"},
]
[package.dependencies]
attrs = ">=19.2.0"
[[package]] [[package]]
name = "jsonpointer" name = "jsonpointer"
version = "3.0.0" version = "3.0.0"
@ -4216,10 +4230,10 @@ files = [
[package.dependencies] [package.dependencies]
numpy = [ numpy = [
{version = ">=1.26.0", markers = "python_version >= \"3.12\""},
{version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""},
{version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\" and python_version < \"3.11\""}, {version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\" and python_version < \"3.11\""},
{version = ">=1.21.2", markers = "platform_system != \"Darwin\" and python_version >= \"3.10\" and python_version < \"3.11\""}, {version = ">=1.21.2", markers = "platform_system != \"Darwin\" and python_version >= \"3.10\" and python_version < \"3.11\""},
{version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""},
{version = ">=1.26.0", markers = "python_version >= \"3.12\""},
] ]
[[package]] [[package]]
@ -4240,10 +4254,10 @@ files = [
[package.dependencies] [package.dependencies]
numpy = [ numpy = [
{version = ">=1.26.0", markers = "python_version >= \"3.12\""},
{version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""},
{version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\" and python_version < \"3.11\""}, {version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\" and python_version < \"3.11\""},
{version = ">=1.21.2", markers = "platform_system != \"Darwin\" and python_version >= \"3.10\" and python_version < \"3.11\""}, {version = ">=1.21.2", markers = "platform_system != \"Darwin\" and python_version >= \"3.10\" and python_version < \"3.11\""},
{version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""},
{version = ">=1.26.0", markers = "python_version >= \"3.12\""},
] ]
[[package]] [[package]]
@ -4332,9 +4346,9 @@ files = [
[package.dependencies] [package.dependencies]
numpy = [ numpy = [
{version = ">=1.26.0", markers = "python_version >= \"3.12\""},
{version = ">=1.23.2", markers = "python_version == \"3.11\""},
{version = ">=1.22.4", markers = "python_version < \"3.11\""}, {version = ">=1.22.4", markers = "python_version < \"3.11\""},
{version = ">=1.23.2", markers = "python_version == \"3.11\""},
{version = ">=1.26.0", markers = "python_version >= \"3.12\""},
] ]
python-dateutil = ">=2.8.2" python-dateutil = ">=2.8.2"
pytz = ">=2020.1" pytz = ">=2020.1"
@ -7562,4 +7576,4 @@ xarm = ["gym-xarm"]
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = ">=3.10,<3.13" python-versions = ">=3.10,<3.13"
content-hash = "f64e01ce021ae77baa2c9bb82cbd2dd6035ab01a1500207da7acdb7f9d0772e1" content-hash = "b79d32bec01c53a3ca48548b85e6f991c9d8fc091f3f528e0b54c6e9fac63ff9"

View File

@ -69,6 +69,7 @@ pyrealsense2 = {version = ">=2.55.1.6486", markers = "sys_platform != 'darwin'",
pyrender = {git = "https://github.com/mmatl/pyrender.git", markers = "sys_platform == 'linux'", optional = true} pyrender = {git = "https://github.com/mmatl/pyrender.git", markers = "sys_platform == 'linux'", optional = true}
hello-robot-stretch-body = {version = ">=0.7.27", markers = "sys_platform == 'linux'", optional = true} hello-robot-stretch-body = {version = ">=0.7.27", markers = "sys_platform == 'linux'", optional = true}
pyserial = {version = ">=3.5", optional = true} pyserial = {version = ">=3.5", optional = true}
jsonlines = "^4.0.0"
[tool.poetry.extras] [tool.poetry.extras]