Write episodes as jsonlines
This commit is contained in:
parent
c146ba936f
commit
50a75ad3fe
|
@ -93,6 +93,7 @@ import warnings
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import datasets
|
import datasets
|
||||||
|
import jsonlines
|
||||||
import pyarrow.compute as pc
|
import pyarrow.compute as pc
|
||||||
import pyarrow.parquet as pq
|
import pyarrow.parquet as pq
|
||||||
import torch
|
import torch
|
||||||
|
@ -160,6 +161,11 @@ def write_json(data: dict, fpath: Path) -> None:
|
||||||
json.dump(data, f, indent=4)
|
json.dump(data, f, indent=4)
|
||||||
|
|
||||||
|
|
||||||
|
def write_jsonlines(data: dict, fpath: Path) -> None:
|
||||||
|
with jsonlines.open(fpath, "w") as writer:
|
||||||
|
writer.write_all(data)
|
||||||
|
|
||||||
|
|
||||||
def convert_stats_to_json(input_dir: Path, output_dir: Path) -> None:
|
def convert_stats_to_json(input_dir: Path, output_dir: Path) -> None:
|
||||||
safetensor_path = input_dir / "stats.safetensors"
|
safetensor_path = input_dir / "stats.safetensors"
|
||||||
stats = load_file(safetensor_path)
|
stats = load_file(safetensor_path)
|
||||||
|
@ -617,7 +623,7 @@ def convert_dataset(
|
||||||
{"episode_index": ep_idx, "tasks": [tasks_by_episodes[ep_idx]], "length": episode_lengths[ep_idx]}
|
{"episode_index": ep_idx, "tasks": [tasks_by_episodes[ep_idx]], "length": episode_lengths[ep_idx]}
|
||||||
for ep_idx in episode_indices
|
for ep_idx in episode_indices
|
||||||
]
|
]
|
||||||
write_json(episodes, v20_dir / "meta" / "episodes.json")
|
write_jsonlines(episodes, v20_dir / "meta" / "episodes.jsonl")
|
||||||
|
|
||||||
# Assemble metadata v2.0
|
# Assemble metadata v2.0
|
||||||
metadata_v2_0 = {
|
metadata_v2_0 = {
|
||||||
|
@ -648,6 +654,9 @@ def convert_dataset(
|
||||||
with contextlib.suppress(EntryNotFoundError):
|
with contextlib.suppress(EntryNotFoundError):
|
||||||
hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta_data", repo_type="dataset", revision=branch)
|
hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta_data", repo_type="dataset", revision=branch)
|
||||||
|
|
||||||
|
with contextlib.suppress(EntryNotFoundError):
|
||||||
|
hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta", repo_type="dataset", revision=branch)
|
||||||
|
|
||||||
hub_api.upload_folder(
|
hub_api.upload_folder(
|
||||||
repo_id=repo_id,
|
repo_id=repo_id,
|
||||||
path_in_repo="data",
|
path_in_repo="data",
|
||||||
|
@ -681,6 +690,7 @@ def convert_dataset(
|
||||||
# - [X] Handle multitask datasets
|
# - [X] Handle multitask datasets
|
||||||
# - [X] Handle hf hub repo limits (add chunks logic)
|
# - [X] Handle hf hub repo limits (add chunks logic)
|
||||||
# - [X] Add test-branch
|
# - [X] Add test-branch
|
||||||
|
# - [X] Use jsonlines for episodes
|
||||||
# - [X] Add sanity checks (encoding, shapes)
|
# - [X] Add sanity checks (encoding, shapes)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -2620,6 +2620,20 @@ files = [
|
||||||
{file = "json5-0.9.25.tar.gz", hash = "sha256:548e41b9be043f9426776f05df8635a00fe06104ea51ed24b67f908856e151ae"},
|
{file = "json5-0.9.25.tar.gz", hash = "sha256:548e41b9be043f9426776f05df8635a00fe06104ea51ed24b67f908856e151ae"},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "jsonlines"
|
||||||
|
version = "4.0.0"
|
||||||
|
description = "Library with helpers for the jsonlines file format"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.8"
|
||||||
|
files = [
|
||||||
|
{file = "jsonlines-4.0.0-py3-none-any.whl", hash = "sha256:185b334ff2ca5a91362993f42e83588a360cf95ce4b71a73548502bda52a7c55"},
|
||||||
|
{file = "jsonlines-4.0.0.tar.gz", hash = "sha256:0c6d2c09117550c089995247f605ae4cf77dd1533041d366351f6f298822ea74"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
attrs = ">=19.2.0"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "jsonpointer"
|
name = "jsonpointer"
|
||||||
version = "3.0.0"
|
version = "3.0.0"
|
||||||
|
@ -4216,10 +4230,10 @@ files = [
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
numpy = [
|
numpy = [
|
||||||
{version = ">=1.26.0", markers = "python_version >= \"3.12\""},
|
|
||||||
{version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""},
|
|
||||||
{version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\" and python_version < \"3.11\""},
|
{version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\" and python_version < \"3.11\""},
|
||||||
{version = ">=1.21.2", markers = "platform_system != \"Darwin\" and python_version >= \"3.10\" and python_version < \"3.11\""},
|
{version = ">=1.21.2", markers = "platform_system != \"Darwin\" and python_version >= \"3.10\" and python_version < \"3.11\""},
|
||||||
|
{version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""},
|
||||||
|
{version = ">=1.26.0", markers = "python_version >= \"3.12\""},
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -4240,10 +4254,10 @@ files = [
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
numpy = [
|
numpy = [
|
||||||
{version = ">=1.26.0", markers = "python_version >= \"3.12\""},
|
|
||||||
{version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""},
|
|
||||||
{version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\" and python_version < \"3.11\""},
|
{version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\" and python_version < \"3.11\""},
|
||||||
{version = ">=1.21.2", markers = "platform_system != \"Darwin\" and python_version >= \"3.10\" and python_version < \"3.11\""},
|
{version = ">=1.21.2", markers = "platform_system != \"Darwin\" and python_version >= \"3.10\" and python_version < \"3.11\""},
|
||||||
|
{version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""},
|
||||||
|
{version = ">=1.26.0", markers = "python_version >= \"3.12\""},
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -4332,9 +4346,9 @@ files = [
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
numpy = [
|
numpy = [
|
||||||
{version = ">=1.26.0", markers = "python_version >= \"3.12\""},
|
|
||||||
{version = ">=1.23.2", markers = "python_version == \"3.11\""},
|
|
||||||
{version = ">=1.22.4", markers = "python_version < \"3.11\""},
|
{version = ">=1.22.4", markers = "python_version < \"3.11\""},
|
||||||
|
{version = ">=1.23.2", markers = "python_version == \"3.11\""},
|
||||||
|
{version = ">=1.26.0", markers = "python_version >= \"3.12\""},
|
||||||
]
|
]
|
||||||
python-dateutil = ">=2.8.2"
|
python-dateutil = ">=2.8.2"
|
||||||
pytz = ">=2020.1"
|
pytz = ">=2020.1"
|
||||||
|
@ -7562,4 +7576,4 @@ xarm = ["gym-xarm"]
|
||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = ">=3.10,<3.13"
|
python-versions = ">=3.10,<3.13"
|
||||||
content-hash = "f64e01ce021ae77baa2c9bb82cbd2dd6035ab01a1500207da7acdb7f9d0772e1"
|
content-hash = "b79d32bec01c53a3ca48548b85e6f991c9d8fc091f3f528e0b54c6e9fac63ff9"
|
||||||
|
|
|
@ -69,6 +69,7 @@ pyrealsense2 = {version = ">=2.55.1.6486", markers = "sys_platform != 'darwin'",
|
||||||
pyrender = {git = "https://github.com/mmatl/pyrender.git", markers = "sys_platform == 'linux'", optional = true}
|
pyrender = {git = "https://github.com/mmatl/pyrender.git", markers = "sys_platform == 'linux'", optional = true}
|
||||||
hello-robot-stretch-body = {version = ">=0.7.27", markers = "sys_platform == 'linux'", optional = true}
|
hello-robot-stretch-body = {version = ">=0.7.27", markers = "sys_platform == 'linux'", optional = true}
|
||||||
pyserial = {version = ">=3.5", optional = true}
|
pyserial = {version = ">=3.5", optional = true}
|
||||||
|
jsonlines = "^4.0.0"
|
||||||
|
|
||||||
|
|
||||||
[tool.poetry.extras]
|
[tool.poetry.extras]
|
||||||
|
|
Loading…
Reference in New Issue