WIP

parent 92573486a8  commit ad115b6c27

@@ -0,0 +1,484 @@
"""
|
||||
This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 1.6 to
|
||||
2.0. You will be required to provide the 'tasks', which is a short but accurate description in plain English
|
||||
for each of the task performed in the dataset. This will allow to easily train models with task-conditionning.
|
||||
|
||||
If your dataset contains a single task, you can provide it directly via the CLI with the '--task' option (see
|
||||
examples below).
|
||||
|
||||
If your dataset is a multi-task dataset, TODO
|
||||
|
||||
In any case, keep in mind that there should only be one task per episode. Multi-task episodes are not
|
||||
supported for now.
|
||||
|
||||
Usage examples
|
||||
|
||||
Single-task dataset:
|
||||
```bash
|
||||
python convert_dataset_16_to_20.py \
|
||||
--repo-id lerobot/aloha_sim_insertion_human_image \
|
||||
--task "Insert the peg into the socket." \
|
||||
--robot-config lerobot/configs/robot/aloha.yaml
|
||||
```
|
||||
|
||||
```bash
|
||||
python convert_dataset_16_to_20.py \
|
||||
--repo-id aliberts/koch_tutorial \
|
||||
--task "Pick the Lego block and drop it in the box on the right." \
|
||||
--robot-config lerobot/configs/robot/koch.yaml \
|
||||
--local-dir data
|
||||
```
|
||||
|
||||
Multi-task dataset:
|
||||
TODO
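
A CLI path for multi-task datasets is not wired up yet (see the TODO above). As a stop-gap, and purely as an
illustrative sketch (the repo id, task strings and episode-to-task mapping below are made up), the conversion
can be driven programmatically by passing an explicit 'tasks_by_episodes' dict to convert_dataset(), mapping
each episode_index to a task_index:
```python
from pathlib import Path

tasks = {0: "Pick the red cube.", 1: "Open the drawer."}
tasks_by_episodes = {0: 0, 1: 0, 2: 1}  # episode_index -> task_index
convert_dataset(
    repo_id="user/my_multitask_dataset",
    local_dir=Path("/tmp/user/my_multitask_dataset"),
    tasks=tasks,
    tasks_by_episodes=tasks_by_episodes,
)
```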
"""

import argparse
import json
import math
import subprocess
from io import BytesIO
from pathlib import Path

import pyarrow as pa
import pyarrow.compute as pc
import pyarrow.parquet as pq
import torch
from huggingface_hub import HfApi
from PIL import Image
from safetensors.torch import load_file

from lerobot.common.utils.utils import init_hydra_config

V1_6 = "v1.6"
V2_0 = "v2.0"

PARQUET_PATH = "data/train-{episode_index:05d}-of-{total_episodes:05d}.parquet"
VIDEO_PATH = "videos/{image_key}_episode_{episode_index:06d}.mp4"

def parse_robot_config(config_path: Path, config_overrides: list[str] | None = None) -> dict:
    robot_cfg = init_hydra_config(config_path, config_overrides)
    if robot_cfg["robot_type"] in ["aloha", "koch"]:
        state_names = [
            f"{arm}_{motor}" if len(robot_cfg["follower_arms"]) > 1 else motor
            for arm in robot_cfg["follower_arms"]
            for motor in robot_cfg["follower_arms"][arm]["motors"]
        ]
        action_names = [
            # f"{arm}_{motor}" for arm in ["left", "right"] for motor in robot_cfg["leader_arms"][arm]["motors"]
            f"{arm}_{motor}" if len(robot_cfg["leader_arms"]) > 1 else motor
            for arm in robot_cfg["leader_arms"]
            for motor in robot_cfg["leader_arms"][arm]["motors"]
        ]
    # elif robot_cfg["robot_type"] == "stretch3": TODO
    else:
        raise NotImplementedError(
            "Please provide robot_config={'robot_type': ..., 'names': ...} directly to convert_dataset()."
        )

    return {
        "robot_type": robot_cfg["robot_type"],
        "names": {
            "observation.state": state_names,
            "action": action_names,
        },
    }
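
# For reference, the returned mapping looks like this (motor names below are illustrative placeholders, not
# values read from an actual config). With a single follower/leader arm the motor names are kept as-is, while
# with several arms each name is prefixed with its arm, e.g. "left_gripper":
#
#     {
#         "robot_type": "koch",
#         "names": {
#             "observation.state": ["shoulder_pan", ..., "gripper"],
#             "action": ["shoulder_pan", ..., "gripper"],
#         },
#     }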


def load_json(fpath: Path) -> dict:
    with open(fpath) as f:
        return json.load(f)


def write_json(data: dict, fpath: Path) -> None:
    fpath.parent.mkdir(exist_ok=True, parents=True)
    with open(fpath, "w") as f:
        json.dump(data, f, indent=4)


def convert_stats_to_json(input_dir: Path, output_dir: Path) -> None:
    safetensor_path = input_dir / "stats.safetensors"
    stats = load_file(safetensor_path)
    serializable_stats = {key: value.tolist() for key, value in stats.items()}

    json_path = output_dir / "stats.json"
    json_path.parent.mkdir(exist_ok=True, parents=True)
    with open(json_path, "w") as f:
        json.dump(serializable_stats, f, indent=4)

    # Sanity check
    with open(json_path) as f:
        stats_json = json.load(f)

    stats_json = {key: torch.tensor(value) for key, value in stats_json.items()}
    for key in stats:
        torch.testing.assert_close(stats_json[key], stats[key])
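
# Note: stats.safetensors stores flattened keys (e.g. "observation.state/mean", assuming the usual v1.6
# layout), so stats.json mirrors that flat structure rather than nesting per-feature dicts.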


def get_keys(table: pa.Table) -> dict[str, list]:
    table_metadata = json.loads(table.schema.metadata[b"huggingface"].decode("utf-8"))
    sequence_keys, image_keys, video_keys = [], [], []
    for key, val in table_metadata["info"]["features"].items():
        if val["_type"] == "Sequence":
            sequence_keys.append(key)
        elif val["_type"] == "Image":
            image_keys.append(key)
        elif val["_type"] == "VideoFrame":
            video_keys.append(key)

    return {
        "sequence": sequence_keys,
        "image": image_keys,
        "video": video_keys,
    }
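
# For a typical v1.6 dataset, get_keys() returns something like the following (feature names are
# illustrative):
#
#     {
#         "sequence": ["observation.state", "action"],
#         "image": [],
#         "video": ["observation.images.top"],
#     }
#
# Plain "Value" features (episode_index, timestamp, ...) are deliberately not collected here.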


def remove_hf_metadata_features(table: pa.Table, features: list[str]) -> pa.Table:
    # HACK: edit the huggingface metadata stored in the parquet schema directly
    schema = table.schema
    # decode bytes dict
    table_metadata = json.loads(schema.metadata[b"huggingface"].decode("utf-8"))
    for key in features:
        table_metadata["info"]["features"].pop(key)

    # re-encode bytes dict
    table_metadata = {b"huggingface": json.dumps(table_metadata).encode("utf-8")}
    new_schema = schema.with_metadata(table_metadata)
    return table.replace_schema_metadata(new_schema.metadata)


def add_hf_metadata_features(table: pa.Table, features: dict[str, dict]) -> pa.Table:
    # HACK: edit the huggingface metadata stored in the parquet schema directly
    schema = table.schema
    # decode bytes dict
    table_metadata = json.loads(schema.metadata[b"huggingface"].decode("utf-8"))
    for key, val in features.items():
        table_metadata["info"]["features"][key] = val

    # re-encode bytes dict
    table_metadata = {b"huggingface": json.dumps(table_metadata).encode("utf-8")}
    new_schema = schema.with_metadata(table_metadata)
    return table.replace_schema_metadata(new_schema.metadata)


def remove_videoframe_from_table(table: pa.Table, image_columns: list) -> pa.Table:
    table = table.drop(image_columns)
    table = remove_hf_metadata_features(table, image_columns)
    return table


def add_tasks(table: pa.Table, tasks_by_episodes: dict) -> pa.Table:
    tasks_index = pa.array([tasks_by_episodes.get(key.as_py(), None) for key in table["episode_index"]])
    table = table.append_column("task_index", tasks_index)
    hf_feature = {"task_index": {"dtype": "int64", "_type": "Value"}}
    table = add_hf_metadata_features(table, hf_feature)
    return table
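
# Example (illustrative): with tasks_by_episodes = {0: 0, 1: 0, 2: 1}, every frame of episodes 0 and 1 gets
# task_index 0 and every frame of episode 2 gets task_index 1.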


def split_parquet_by_episodes(
    table: pa.Table, keys: dict[str, list], total_episodes: int, episode_indices: list, output_dir: Path
) -> list:
    (output_dir / "data").mkdir(exist_ok=True, parents=True)
    if len(keys["video"]) > 0:
        table = remove_videoframe_from_table(table, keys["video"])

    episode_lengths = []
    for episode_index in sorted(episode_indices):
        # Write each episode_index to a new parquet file
        filtered_table = table.filter(pc.equal(table["episode_index"], episode_index))
        episode_lengths.insert(episode_index, len(filtered_table))
        output_file = output_dir / PARQUET_PATH.format(
            episode_index=episode_index, total_episodes=total_episodes
        )
        pq.write_table(filtered_table, output_file)

    return episode_lengths
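
# With the PARQUET_PATH template defined above, a 50-episode dataset ends up laid out as:
#     data/train-00000-of-00050.parquet
#     data/train-00001-of-00050.parquet
#     ...
#     data/train-00049-of-00050.parquet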


def _get_audio_info(video_path: Path | str) -> dict:
    ffprobe_audio_cmd = [
        "ffprobe",
        "-v",
        "error",
        "-select_streams",
        "a:0",
        "-show_entries",
        "stream=channels,codec_name,bit_rate,sample_rate,bit_depth,channel_layout,duration",
        "-of",
        "json",
        str(video_path),
    ]
    result = subprocess.run(ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
    if result.returncode != 0:
        raise RuntimeError(f"Error running ffprobe: {result.stderr}")

    info = json.loads(result.stdout)
    audio_stream_info = info["streams"][0] if info.get("streams") else None
    if audio_stream_info is None:
        return {"has_audio": False}

    # Return the audio stream information, defaulting to None for any field ffprobe did not report
    return {
        "has_audio": True,
        "audio.channels": audio_stream_info.get("channels", None),
        "audio.codec": audio_stream_info.get("codec_name", None),
        "audio.bit_rate": int(audio_stream_info["bit_rate"]) if audio_stream_info.get("bit_rate") else None,
        "audio.sample_rate": int(audio_stream_info["sample_rate"])
        if audio_stream_info.get("sample_rate")
        else None,
        "audio.bit_depth": audio_stream_info.get("bit_depth", None),
        "audio.channel_layout": audio_stream_info.get("channel_layout", None),
    }


def _get_video_info(video_path: Path | str) -> dict:
    ffprobe_video_cmd = [
        "ffprobe",
        "-v",
        "error",
        "-select_streams",
        "v:0",
        "-show_entries",
        "stream=r_frame_rate,width,height,codec_name,nb_frames,duration,pix_fmt",
        "-of",
        "json",
        str(video_path),
    ]
    result = subprocess.run(ffprobe_video_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
    if result.returncode != 0:
        raise RuntimeError(f"Error running ffprobe: {result.stderr}")

    info = json.loads(result.stdout)
    video_stream_info = info["streams"][0]

    # Calculate fps from r_frame_rate
    r_frame_rate = video_stream_info["r_frame_rate"]
    num, denom = map(int, r_frame_rate.split("/"))
    fps = num / denom

    video_info = {
        "video.fps": fps,
        "video.width": video_stream_info["width"],
        "video.height": video_stream_info["height"],
        "video.codec": video_stream_info["codec_name"],
        "video.pix_fmt": video_stream_info["pix_fmt"],
        **_get_audio_info(video_path),
    }

    return video_info
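
# Example return value (values are illustrative):
#     {"video.fps": 30.0, "video.width": 640, "video.height": 480, "video.codec": "h264",
#      "video.pix_fmt": "yuv420p", "has_audio": False}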


def get_videos_info(repo_id: str, local_dir: Path, video_keys: list[str]) -> dict:
    hub_api = HfApi()
    videos_info_dict = {
        "videos_path": VIDEO_PATH,
        "has_audio": False,
        "has_depth": False,
    }
    for vid_key in video_keys:
        video_path = VIDEO_PATH.format(image_key=vid_key, episode_index=0)
        video_path = hub_api.hf_hub_download(
            repo_id=repo_id, repo_type="dataset", local_dir=local_dir, filename=video_path
        )
        videos_info_dict[vid_key] = _get_video_info(video_path)
        videos_info_dict["has_audio"] = (
            videos_info_dict["has_audio"] or videos_info_dict[vid_key]["has_audio"]
        )

    return videos_info_dict


def get_video_shapes(videos_info: dict, video_keys: list) -> dict:
    video_shapes = {}
    for img_key in video_keys:
        video_shapes[img_key] = {
            "width": videos_info[img_key]["video.width"],
            "height": videos_info[img_key]["video.height"],
        }

    return video_shapes


def get_image_shapes(table: pa.Table, image_keys: list) -> dict:
    image_shapes = {}
    for img_key in image_keys:
        image_bytes = table[img_key][0].as_py()  # Assuming first row
        image = Image.open(BytesIO(image_bytes["bytes"]))
        image_shapes[img_key] = {
            "width": image.width,
            "height": image.height,
        }

    return image_shapes


def get_generic_motor_names(sequence_shapes: dict) -> dict:
    return {key: [f"motor_{i}" for i in range(length)] for key, length in sequence_shapes.items()}
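
# Example: {"observation.state": 2, "action": 2} ->
#     {"observation.state": ["motor_0", "motor_1"], "action": ["motor_0", "motor_1"]}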


def convert_dataset(
    repo_id: str,
    local_dir: Path,
    tasks: dict,
    tasks_by_episodes: dict | None = None,
    robot_config: dict | None = None,
):
    v1_6_dir = local_dir / repo_id / V1_6
    v2_0_dir = local_dir / repo_id / V2_0
    v1_6_dir.mkdir(parents=True, exist_ok=True)
    v2_0_dir.mkdir(parents=True, exist_ok=True)

    hub_api = HfApi()
    hub_api.snapshot_download(
        repo_id=repo_id, repo_type="dataset", revision=V1_6, local_dir=v1_6_dir, ignore_patterns="videos/"
    )

    metadata_v1_6 = load_json(v1_6_dir / "meta_data" / "info.json")

    table = pq.read_table(v1_6_dir / "data")
    keys = get_keys(table)

    # Episodes
    episode_indices = sorted(table["episode_index"].unique().to_pylist())
    total_episodes = len(episode_indices)
    assert episode_indices == list(range(total_episodes))

    # Tasks
    if tasks_by_episodes is None:  # Single task dataset
        tasks_by_episodes = {ep_idx: 0 for ep_idx in episode_indices}

    assert set(tasks) == set(tasks_by_episodes.values())
    table = add_tasks(table, tasks_by_episodes)
    write_json(tasks, v2_0_dir / "meta" / "tasks.json")

    # Split data into one parquet file per episode
    episode_lengths = split_parquet_by_episodes(table, keys, total_episodes, episode_indices, v2_0_dir)

    # Shapes
    sequence_shapes = {key: len(table[key][0]) for key in keys["sequence"]}
    image_shapes = get_image_shapes(table, keys["image"]) if len(keys["image"]) > 0 else {}
    if len(keys["video"]) > 0:
        assert metadata_v1_6.get("video", False)
        videos_info = get_videos_info(repo_id, v1_6_dir, video_keys=keys["video"])
        video_shapes = get_video_shapes(videos_info, keys["video"])
        for img_key in keys["video"]:
            assert videos_info[img_key]["video.pix_fmt"] == metadata_v1_6["encoding"]["pix_fmt"]
            assert math.isclose(videos_info[img_key]["video.fps"], metadata_v1_6["fps"], rel_tol=1e-3)
    else:
        assert len(keys["video"]) == 0
        videos_info = None
        video_shapes = {}

    # Names
    if robot_config is not None:
        robot_type = robot_config["robot_type"]
        names = robot_config["names"]
    else:
        robot_type = "unknown"
        names = get_generic_motor_names(sequence_shapes)

    assert set(names) == set(keys["sequence"])
    for key in sequence_shapes:
        assert len(names[key]) == sequence_shapes[key]

    # Episodes info
    episodes = [
        {"index": ep_idx, "task": tasks_by_episodes[ep_idx], "length": episode_lengths[ep_idx]}
        for ep_idx in episode_indices
    ]

    # Assemble metadata v2.0
    metadata_v2_0 = {
        "codebase_version": V2_0,
        "data_path": PARQUET_PATH,
        "robot_type": robot_type,
        "total_episodes": total_episodes,
        "total_tasks": len(tasks),
        "fps": metadata_v1_6["fps"],
        "image_keys": keys["video"] + keys["image"],
        "keys": keys["sequence"],
        "shapes": {**image_shapes, **video_shapes, **sequence_shapes},
        "names": names,
        "videos": videos_info,
        "episodes": episodes,
    }
    write_json(metadata_v2_0, v2_0_dir / "meta" / "info.json")

    convert_stats_to_json(v1_6_dir / "meta_data", v2_0_dir / "meta")

    # test_repo_id = f"aliberts/{repo_id.split('/')[1]}"
    # if hub_api.repo_exists(test_repo_id, repo_type="dataset"):
    #     hub_api.delete_repo(test_repo_id, repo_type="dataset")

    # hub_api.create_repo(test_repo_id, repo_type="dataset", exist_ok=True)
    # hub_api.upload_folder(repo_id=test_repo_id, folder_path=v2_0_dir, repo_type="dataset")


# TODO:
# - [X] Add shapes
# - [X] Add keys
# - [X] Add paths
# - [X] Convert stats.json
# - [X] Add tasks.json
# - [X] Add names
# - [X] Add robot_type
# - [/] Add sanity checks (encoding, shapes)
# - [ ] Handle multi-task datasets
# - [ ] Push properly to branch v2.0 and delete v1.6 stuff from that branch


def main():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--repo-id",
        type=str,
        required=True,
        help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset (e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
    )
    parser.add_argument(
        "--task",
        type=str,
        required=True,
        help="A short but accurate description of the task performed in the dataset.",
    )
    parser.add_argument(
        "--robot-config",
        type=Path,
        default=None,
        help="Path to the robot's config yaml used to record the dataset. Used to recover the motor names written in the converted metadata.",
    )
    parser.add_argument(
        "--robot-overrides",
        type=str,
        nargs="*",
        help="Any key=value arguments to override the robot config values (use dots for nested overrides, e.g. `key.nested=value`).",
    )
    parser.add_argument(
        "--local-dir",
        type=Path,
        default=None,
        help="Local directory to store the dataset during conversion. Defaults to /tmp/{repo_id}.",
    )

    args = parser.parse_args()
    if args.local_dir is None:
        args.local_dir = Path(f"/tmp/{args.repo_id}")

    tasks = {0: args.task}
    del args.task

    if args.robot_config is not None:
        robot_config = parse_robot_config(args.robot_config, args.robot_overrides)
    else:
        robot_config = None
    del args.robot_config, args.robot_overrides

    convert_dataset(**vars(args), tasks=tasks, robot_config=robot_config)


if __name__ == "__main__":
    main()