From 50a75ad3fe47437f09da71f0946c781e459d5547 Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Thu, 17 Oct 2024 10:17:27 +0200 Subject: [PATCH] Write episodes as jsonlines --- .../datasets/v2/convert_dataset_v1_to_v2.py | 12 +++++++- poetry.lock | 28 ++++++++++++++----- pyproject.toml | 1 + 3 files changed, 33 insertions(+), 8 deletions(-) diff --git a/lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py b/lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py index 6ddfd2a5..81131f3b 100644 --- a/lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py +++ b/lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py @@ -93,6 +93,7 @@ import warnings from pathlib import Path import datasets +import jsonlines import pyarrow.compute as pc import pyarrow.parquet as pq import torch @@ -160,6 +161,11 @@ def write_json(data: dict, fpath: Path) -> None: json.dump(data, f, indent=4) +def write_jsonlines(data: dict, fpath: Path) -> None: + with jsonlines.open(fpath, "w") as writer: + writer.write_all(data) + + def convert_stats_to_json(input_dir: Path, output_dir: Path) -> None: safetensor_path = input_dir / "stats.safetensors" stats = load_file(safetensor_path) @@ -617,7 +623,7 @@ def convert_dataset( {"episode_index": ep_idx, "tasks": [tasks_by_episodes[ep_idx]], "length": episode_lengths[ep_idx]} for ep_idx in episode_indices ] - write_json(episodes, v20_dir / "meta" / "episodes.json") + write_jsonlines(episodes, v20_dir / "meta" / "episodes.jsonl") # Assemble metadata v2.0 metadata_v2_0 = { @@ -648,6 +654,9 @@ def convert_dataset( with contextlib.suppress(EntryNotFoundError): hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta_data", repo_type="dataset", revision=branch) + with contextlib.suppress(EntryNotFoundError): + hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta", repo_type="dataset", revision=branch) + hub_api.upload_folder( repo_id=repo_id, path_in_repo="data", @@ -681,6 +690,7 @@ def convert_dataset( # - [X] Handle multitask datasets # - [X] Handle hf hub repo limits (add chunks logic) # - [X] Add test-branch + # - [X] Use jsonlines for episodes # - [X] Add sanity checks (encoding, shapes) diff --git a/poetry.lock b/poetry.lock index b4d491ae..011e76ef 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2620,6 +2620,20 @@ files = [ {file = "json5-0.9.25.tar.gz", hash = "sha256:548e41b9be043f9426776f05df8635a00fe06104ea51ed24b67f908856e151ae"}, ] +[[package]] +name = "jsonlines" +version = "4.0.0" +description = "Library with helpers for the jsonlines file format" +optional = false +python-versions = ">=3.8" +files = [ + {file = "jsonlines-4.0.0-py3-none-any.whl", hash = "sha256:185b334ff2ca5a91362993f42e83588a360cf95ce4b71a73548502bda52a7c55"}, + {file = "jsonlines-4.0.0.tar.gz", hash = "sha256:0c6d2c09117550c089995247f605ae4cf77dd1533041d366351f6f298822ea74"}, +] + +[package.dependencies] +attrs = ">=19.2.0" + [[package]] name = "jsonpointer" version = "3.0.0" @@ -4216,10 +4230,10 @@ files = [ [package.dependencies] numpy = [ - {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, - {version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, {version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\" and python_version < \"3.11\""}, {version = ">=1.21.2", markers = "platform_system != \"Darwin\" and python_version >= \"3.10\" and python_version < \"3.11\""}, + {version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, + {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, ] [[package]] @@ -4240,10 +4254,10 @@ files = [ [package.dependencies] numpy = [ - {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, - {version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, {version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\" and python_version < \"3.11\""}, {version = ">=1.21.2", markers = "platform_system != \"Darwin\" and python_version >= \"3.10\" and python_version < \"3.11\""}, + {version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, + {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, ] [[package]] @@ -4332,9 +4346,9 @@ files = [ [package.dependencies] numpy = [ - {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, - {version = ">=1.23.2", markers = "python_version == \"3.11\""}, {version = ">=1.22.4", markers = "python_version < \"3.11\""}, + {version = ">=1.23.2", markers = "python_version == \"3.11\""}, + {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, ] python-dateutil = ">=2.8.2" pytz = ">=2020.1" @@ -7562,4 +7576,4 @@ xarm = ["gym-xarm"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "f64e01ce021ae77baa2c9bb82cbd2dd6035ab01a1500207da7acdb7f9d0772e1" +content-hash = "b79d32bec01c53a3ca48548b85e6f991c9d8fc091f3f528e0b54c6e9fac63ff9" diff --git a/pyproject.toml b/pyproject.toml index 89ed7ff0..85390c19 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,6 +69,7 @@ pyrealsense2 = {version = ">=2.55.1.6486", markers = "sys_platform != 'darwin'", pyrender = {git = "https://github.com/mmatl/pyrender.git", markers = "sys_platform == 'linux'", optional = true} hello-robot-stretch-body = {version = ">=0.7.27", markers = "sys_platform == 'linux'", optional = true} pyserial = {version = ">=3.5", optional = true} +jsonlines = "^4.0.0" [tool.poetry.extras]