This commit is contained in:
Simon Alibert 2024-10-03 20:00:44 +02:00
parent 92573486a8
commit ad115b6c27
1 changed file with 484 additions and 0 deletions

convert_dataset_16_to_20.py Normal file

@@ -0,0 +1,484 @@
"""
This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 1.6 to
2.0. You will be required to provide the 'tasks', which is a short but accurate description in plain English
for each of the tasks performed in the dataset. This will allow you to easily train models with task-conditioning.
If your dataset contains a single task, you can provide it directly via the CLI with the '--task' option (see
examples below).
If your dataset is a multi-task dataset, TODO
In any case, keep in mind that there should only be one task per episode. Multi-task episodes are not
supported for now.
Usage examples
Single-task dataset:
```bash
python convert_dataset_16_to_20.py \
--repo-id lerobot/aloha_sim_insertion_human_image \
--task "Insert the peg into the socket." \
--robot-config lerobot/configs/robot/aloha.yaml
```
```bash
python convert_dataset_16_to_20.py \
--repo-id aliberts/koch_tutorial \
--task "Pick the Lego block and drop it in the box on the right." \
--robot-config lerobot/configs/robot/koch.yaml \
--local-dir data
```
Multi-task dataset:
TODO
"""
import argparse
import json
import math
import subprocess
from io import BytesIO
from pathlib import Path
import pyarrow as pa
import pyarrow.compute as pc
import pyarrow.parquet as pq
import torch
from huggingface_hub import HfApi
from PIL import Image
from safetensors.torch import load_file
from lerobot.common.utils.utils import init_hydra_config
V1_6 = "v1.6"
V2_0 = "v2.0"
PARQUET_PATH = "data/train-{episode_index:05d}-of-{total_episodes:05d}.parquet"
VIDEO_PATH = "videos/{image_key}_episode_{episode_index:06d}.mp4"
def parse_robot_config(config_path: Path, config_overrides: list[str] | None = None) -> dict:
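    """Extract per-motor state/action names from a robot's Hydra config.

    Returns a dict of the form {"robot_type": ..., "names": {"observation.state": [...], "action": [...]}}.
    Motor names are prefixed with the arm name only when the robot has more than one arm.
    """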
robot_cfg = init_hydra_config(config_path, config_overrides)
if robot_cfg["robot_type"] in ["aloha", "koch"]:
state_names = [
f"{arm}_{motor}" if len(robot_cfg["follower_arms"]) > 1 else motor
for arm in robot_cfg["follower_arms"]
for motor in robot_cfg["follower_arms"][arm]["motors"]
]
action_names = [
# f"{arm}_{motor}" for arm in ["left", "right"] for motor in robot_cfg["leader_arms"][arm]["motors"]
f"{arm}_{motor}" if len(robot_cfg["leader_arms"]) > 1 else motor
for arm in robot_cfg["leader_arms"]
for motor in robot_cfg["leader_arms"][arm]["motors"]
]
# elif robot_cfg["robot_type"] == "stretch3": TODO
else:
raise NotImplementedError(
"Please provide robot_config={'robot_type': ..., 'names': ...} directly to convert_dataset()."
)
return {
"robot_type": robot_cfg["robot_type"],
"names": {
"observation.state": state_names,
"action": action_names,
},
}
def load_json(fpath: Path) -> dict:
with open(fpath) as f:
return json.load(f)
def write_json(data: dict, fpath: Path) -> None:
fpath.parent.mkdir(exist_ok=True, parents=True)
with open(fpath, "w") as f:
json.dump(data, f, indent=4)
def convert_stats_to_json(input_dir: Path, output_dir: Path) -> None:
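    """Convert stats.safetensors into a stats.json file and check that the values survive the round-trip."""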
safetensor_path = input_dir / "stats.safetensors"
stats = load_file(safetensor_path)
serializable_stats = {key: value.tolist() for key, value in stats.items()}
json_path = output_dir / "stats.json"
json_path.parent.mkdir(exist_ok=True, parents=True)
with open(json_path, "w") as f:
json.dump(serializable_stats, f, indent=4)
# Sanity check
with open(json_path) as f:
stats_json = json.load(f)
stats_json = {key: torch.tensor(value) for key, value in stats_json.items()}
for key in stats:
torch.testing.assert_close(stats_json[key], stats[key])
def get_keys(table: pa.Table) -> dict[str, list]:
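    """Group the dataset's feature names by their Hugging Face 'datasets' type (Sequence, Image, VideoFrame)."""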
table_metadata = json.loads(table.schema.metadata[b"huggingface"].decode("utf-8"))
sequence_keys, image_keys, video_keys = [], [], []
for key, val in table_metadata["info"]["features"].items():
if val["_type"] == "Sequence":
sequence_keys.append(key)
elif val["_type"] == "Image":
image_keys.append(key)
elif val["_type"] == "VideoFrame":
video_keys.append(key)
return {
"sequence": sequence_keys,
"image": image_keys,
"video": video_keys,
}
def remove_hf_metadata_features(table: pa.Table, features: list[str]) -> pa.Table:
# HACK: the 'datasets' feature definitions live in the Arrow schema metadata under the "huggingface" key, so we edit that JSON blob directly.
schema = table.schema
# decode bytes dict
table_metadata = json.loads(schema.metadata[b"huggingface"].decode("utf-8"))
for key in features:
table_metadata["info"]["features"].pop(key)
# re-encode bytes dict
table_metadata = {b"huggingface": json.dumps(table_metadata).encode("utf-8")}
new_schema = schema.with_metadata(table_metadata)
return table.replace_schema_metadata(new_schema.metadata)
def add_hf_metadata_features(table: pa.Table, features: dict[str, dict]) -> pa.Table:
# HACK: same schema-metadata surgery as remove_hf_metadata_features, but adding feature entries instead.
schema = table.schema
# decode bytes dict
table_metadata = json.loads(schema.metadata[b"huggingface"].decode("utf-8"))
for key, val in features.items():
table_metadata["info"]["features"][key] = val
# re-encode bytes dict
table_metadata = {b"huggingface": json.dumps(table_metadata).encode("utf-8")}
new_schema = schema.with_metadata(table_metadata)
return table.replace_schema_metadata(new_schema.metadata)
def remove_videoframe_from_table(table: pa.Table, image_columns: list) -> pa.Table:
table = table.drop(image_columns)
table = remove_hf_metadata_features(table, image_columns)
return table
def add_tasks(table: pa.Table, tasks_by_episodes: dict) -> pa.Table:
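    """Append a per-frame 'task_index' column derived from each row's episode_index."""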
tasks_index = pa.array([tasks_by_episodes.get(key.as_py(), None) for key in table["episode_index"]])
table = table.append_column("task_index", tasks_index)
hf_feature = {"task_index": {"dtype": "int64", "_type": "Value"}}
table = add_hf_metadata_features(table, hf_feature)
return table
def split_parquet_by_episodes(
table: pa.Table, keys: dict[str, list], total_episodes: int, episode_indices: list, output_dir: Path
) -> list:
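    """Write one parquet file per episode (dropping VideoFrame columns) and return the episode lengths."""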
(output_dir / "data").mkdir(exist_ok=True, parents=True)
if len(keys["video"]) > 0:
table = remove_videoframe_from_table(table, keys["video"])
episode_lengths = []
for episode_index in sorted(episode_indices):
# Write each episode_index to a new parquet file
filtered_table = table.filter(pc.equal(table["episode_index"], episode_index))
episode_lengths.insert(episode_index, len(filtered_table))
output_file = output_dir / PARQUET_PATH.format(
episode_index=episode_index, total_episodes=total_episodes
)
pq.write_table(filtered_table, output_file)
return episode_lengths
def _get_audio_info(video_path: Path | str) -> dict:
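    """Probe the first audio stream of a video file with ffprobe."""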
ffprobe_audio_cmd = [
"ffprobe",
"-v",
"error",
"-select_streams",
"a:0",
"-show_entries",
"stream=channels,codec_name,bit_rate,sample_rate,bit_depth,channel_layout,duration",
"-of",
"json",
str(video_path),
]
result = subprocess.run(ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
if result.returncode != 0:
raise RuntimeError(f"Error running ffprobe: {result.stderr}")
info = json.loads(result.stdout)
audio_stream_info = info["streams"][0] if info.get("streams") else None
if audio_stream_info is None:
return {"has_audio": False}
# Return the audio stream information, defaulting to None for any missing field
return {
"has_audio": True,
"audio.channels": audio_stream_info.get("channels", None),
"audio.codec": audio_stream_info.get("codec_name", None),
"audio.bit_rate": int(audio_stream_info["bit_rate"]) if audio_stream_info.get("bit_rate") else None,
"audio.sample_rate": int(audio_stream_info["sample_rate"])
if audio_stream_info.get("sample_rate")
else None,
"audio.bit_depth": audio_stream_info.get("bit_depth", None),
"audio.channel_layout": audio_stream_info.get("channel_layout", None),
}
def _get_video_info(video_path: Path | str) -> dict:
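    """Probe the first video stream with ffprobe and merge in the audio stream info."""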
ffprobe_video_cmd = [
"ffprobe",
"-v",
"error",
"-select_streams",
"v:0",
"-show_entries",
"stream=r_frame_rate,width,height,codec_name,nb_frames,duration,pix_fmt",
"-of",
"json",
str(video_path),
]
result = subprocess.run(ffprobe_video_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
if result.returncode != 0:
raise RuntimeError(f"Error running ffprobe: {result.stderr}")
info = json.loads(result.stdout)
video_stream_info = info["streams"][0]
# Calculate fps from r_frame_rate
r_frame_rate = video_stream_info["r_frame_rate"]
num, denom = map(int, r_frame_rate.split("/"))
fps = num / denom
video_info = {
"video.fps": fps,
"video.width": video_stream_info["width"],
"video.height": video_stream_info["height"],
"video.codec": video_stream_info["codec_name"],
"video.pix_fmt": video_stream_info["pix_fmt"],
**_get_audio_info(video_path),
}
return video_info
def get_videos_info(repo_id: str, local_dir: Path, video_keys: list[str]) -> dict:
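    """Download the first episode of each video key and probe it for encoding info."""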
hub_api = HfApi()
videos_info_dict = {
"videos_path": VIDEO_PATH,
"has_audio": False,
"has_depth": False,
}
for vid_key in video_keys:
video_path = VIDEO_PATH.format(image_key=vid_key, episode_index=0)
video_path = hub_api.hf_hub_download(
repo_id=repo_id, repo_type="dataset", local_dir=local_dir, filename=video_path
)
videos_info_dict[vid_key] = _get_video_info(video_path)
videos_info_dict["has_audio"] = (
videos_info_dict["has_audio"] or videos_info_dict[vid_key]["has_audio"]
)
return videos_info_dict
def get_video_shapes(videos_info: dict, video_keys: list) -> dict:
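    """Map each video key to its width/height as reported by ffprobe."""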
video_shapes = {}
for img_key in video_keys:
video_shapes[img_key] = {
"width": videos_info[img_key]["video.width"],
"height": videos_info[img_key]["video.height"],
}
return video_shapes
def get_image_shapes(table: pa.Table, image_keys: list) -> dict:
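    """Decode the first image of each image column to recover its width/height."""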
image_shapes = {}
for img_key in image_keys:
image_bytes = table[img_key][0].as_py() # Assuming first row
image = Image.open(BytesIO(image_bytes["bytes"]))
image_shapes[img_key] = {
"width": image.width,
"height": image.height,
}
return image_shapes
def get_generic_motor_names(sequence_shapes: dict) -> dict:
return {key: [f"motor_{i}" for i in range(length)] for key, length in sequence_shapes.items()}
def convert_dataset(
repo_id: str,
local_dir: Path,
tasks: dict,
tasks_by_episodes: dict | None = None,
robot_config: dict | None = None,
):
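    """Convert a v1.6 dataset snapshot into the v2.0 layout.

    Downloads the v1.6 snapshot (without videos), adds a task_index column, splits the data into
    one parquet file per episode, gathers shapes/names/video info, then writes meta/info.json,
    meta/tasks.json and meta/stats.json into the v2.0 directory.
    """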
v1_6_dir = local_dir / repo_id / V1_6
v2_0_dir = local_dir / repo_id / V2_0
v1_6_dir.mkdir(parents=True, exist_ok=True)
v2_0_dir.mkdir(parents=True, exist_ok=True)
hub_api = HfApi()
hub_api.snapshot_download(
repo_id=repo_id, repo_type="dataset", revision=V1_6, local_dir=v1_6_dir, ignore_patterns="videos/"
)
metadata_v1_6 = load_json(v1_6_dir / "meta_data" / "info.json")
table = pq.read_table(v1_6_dir / "data")
keys = get_keys(table)
# Episodes
episode_indices = sorted(table["episode_index"].unique().to_pylist())
total_episodes = len(episode_indices)
assert episode_indices == list(range(total_episodes))
# Tasks
if tasks_by_episodes is None: # Single task dataset
tasks_by_episodes = {ep_idx: 0 for ep_idx in episode_indices}
assert set(tasks) == set(tasks_by_episodes.values())
table = add_tasks(table, tasks_by_episodes)
write_json(tasks, v2_0_dir / "meta" / "tasks.json")
# Split data into 1 parquet file by episode
episode_lengths = split_parquet_by_episodes(table, keys, total_episodes, episode_indices, v2_0_dir)
# Shapes
sequence_shapes = {key: len(table[key][0]) for key in keys["sequence"]}
image_shapes = get_image_shapes(table, keys["image"]) if len(keys["image"]) > 0 else {}
if len(keys["video"]) > 0:
assert metadata_v1_6.get("video", False)
videos_info = get_videos_info(repo_id, v1_6_dir, video_keys=keys["video"])
video_shapes = get_video_shapes(videos_info, keys["video"])
for img_key in keys["video"]:
assert videos_info[img_key]["video.pix_fmt"] == metadata_v1_6["encoding"]["pix_fmt"]
assert math.isclose(videos_info[img_key]["video.fps"], metadata_v1_6["fps"], rel_tol=1e-3)
else:
assert len(keys["video"]) == 0
videos_info = None
video_shapes = {}
# Names
if robot_config is not None:
robot_type = robot_config["robot_type"]
names = robot_config["names"]
else:
robot_type = "unknown"
names = get_generic_motor_names(sequence_shapes)
assert set(names) == set(keys["sequence"])
for key in sequence_shapes:
assert len(names[key]) == sequence_shapes[key]
# Episodes info
episodes = [
{"index": ep_idx, "task": tasks_by_episodes[ep_idx], "length": episode_lengths[ep_idx]}
for ep_idx in episode_indices
]
# Assemble metadata v2.0
metadata_v2_0 = {
"codebase_version": V2_0,
"data_path": PARQUET_PATH,
"robot_type": robot_type,
"total_episodes": total_episodes,
"total_tasks": len(tasks),
"fps": metadata_v1_6["fps"],
"image_keys": keys["video"] + keys["image"],
"keys": keys["sequence"],
"shapes": {**image_shapes, **video_shapes, **sequence_shapes},
"names": names,
"videos": videos_info,
"episodes": episodes,
}
write_json(metadata_v2_0, v2_0_dir / "meta" / "info.json")
convert_stats_to_json(v1_6_dir / "meta_data", v2_0_dir / "meta")
# test_repo_id = f"aliberts/{repo_id.split('/')[1]}"
# if hub_api.repo_exists(test_repo_id, repo_type="dataset"):
# hub_api.delete_repo(test_repo_id, repo_type="dataset")
# hub_api.create_repo(test_repo_id, repo_type="dataset", exist_ok=True)
# hub_api.upload_folder(repo_id=test_repo_id, folder_path=v2_0_dir, repo_type="dataset")
# TODO:
# - [X] Add shapes
# - [X] Add keys
# - [X] Add paths
# - [X] convert stats.json
# - [X] Add task.json
# - [X] Add names
# - [X] Add robot_type
# - [/] Add sanity checks (encoding, shapes)
# - [ ] Handle multitask datasets
# - [ ] Push properly to branch v2.0 and delete v1.6 stuff from that branch
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--repo-id",
type=str,
required=True,
help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset (e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
)
parser.add_argument(
"--task",
type=str,
required=True,
help="A short but accurate description of the task performed in the dataset.",
)
parser.add_argument(
"--robot-config",
type=Path,
default=None,
help="Path to the robot's config yaml the dataset was recorded with.",
)
parser.add_argument(
"--robot-overrides",
type=str,
nargs="*",
help="Any key=value arguments to override the robot config values (use dots for nested=overrides)",
)
parser.add_argument(
"--local-dir",
type=Path,
default=None,
help="Local directory to store the dataset during conversion. Defaults to /tmp/{repo_id}",
)
args = parser.parse_args()
if args.local_dir is None:
args.local_dir = Path(f"/tmp/{args.repo_id}")
tasks = {0: args.task}
del args.task
if args.robot_config is not None:
robot_config = parse_robot_config(args.robot_config, args.robot_overrides)
else:
robot_config = None
del args.robot_config, args.robot_overrides
convert_dataset(**vars(args), tasks=tasks, robot_config=robot_config)
if __name__ == "__main__":
from time import sleep
sleep(1)
main()