From ad115b6c27b095b3d1e7c291c3adf0b34d188e19 Mon Sep 17 00:00:00 2001
From: Simon Alibert
Date: Thu, 3 Oct 2024 20:00:44 +0200
Subject: [PATCH] WIP

---
 convert_dataset_16_to_20.py | 484 ++++++++++++++++++++++++++++++++++++
 1 file changed, 484 insertions(+)
 create mode 100644 convert_dataset_16_to_20.py

diff --git a/convert_dataset_16_to_20.py b/convert_dataset_16_to_20.py
new file mode 100644
index 00000000..fdb5f233
--- /dev/null
+++ b/convert_dataset_16_to_20.py
@@ -0,0 +1,484 @@
+"""
+This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 1.6 to
+2.0. You will be required to provide the 'tasks', which are short but accurate descriptions in plain English
+for each task performed in the dataset. This allows training models with task-conditioning.
+
+If your dataset contains a single task, you can provide it directly via the CLI with the '--task' option (see
+examples below).
+
+If your dataset is a multi-task dataset, TODO
+
+In any case, keep in mind that there should only be one task per episode. Multi-task episodes are not
+supported for now.
+
+Usage examples
+
+Single-task dataset:
+```bash
+python convert_dataset_16_to_20.py \
+    --repo-id lerobot/aloha_sim_insertion_human_image \
+    --task "Insert the peg into the socket." \
+    --robot-config lerobot/configs/robot/aloha.yaml
+```
+
+```bash
+python convert_dataset_16_to_20.py \
+    --repo-id aliberts/koch_tutorial \
+    --task "Pick the Lego block and drop it in the box on the right." \
+    --robot-config lerobot/configs/robot/koch.yaml \
+    --local-dir data
+```
+
+Multi-task dataset:
+TODO
+"""
+
+import argparse
+import json
+import math
+import subprocess
+from io import BytesIO
+from pathlib import Path
+
+import pyarrow as pa
+import pyarrow.compute as pc
+import pyarrow.parquet as pq
+import torch
+from huggingface_hub import HfApi
+from PIL import Image
+from safetensors.torch import load_file
+
+from lerobot.common.utils.utils import init_hydra_config
+
+V1_6 = "v1.6"
+V2_0 = "v2.0"
+
+PARQUET_PATH = "data/train-{episode_index:05d}-of-{total_episodes:05d}.parquet"
+VIDEO_PATH = "videos/{image_key}_episode_{episode_index:06d}.mp4"
+
+
+def parse_robot_config(config_path: Path, config_overrides: list[str] | None = None) -> dict:
+    robot_cfg = init_hydra_config(config_path, config_overrides)
+    if robot_cfg["robot_type"] in ["aloha", "koch"]:
+        state_names = [
+            f"{arm}_{motor}" if len(robot_cfg["follower_arms"]) > 1 else motor
+            for arm in robot_cfg["follower_arms"]
+            for motor in robot_cfg["follower_arms"][arm]["motors"]
+        ]
+        action_names = [
+            # f"{arm}_{motor}" for arm in ["left", "right"] for motor in robot_cfg["leader_arms"][arm]["motors"]
+            f"{arm}_{motor}" if len(robot_cfg["leader_arms"]) > 1 else motor
+            for arm in robot_cfg["leader_arms"]
+            for motor in robot_cfg["leader_arms"][arm]["motors"]
+        ]
+    # elif robot_cfg["robot_type"] == "stretch3": TODO
+    else:
+        raise NotImplementedError(
+            "Please provide robot_config={'robot_type': ..., 'names': ...} directly to convert_dataset()."
+ ) + + return { + "robot_type": robot_cfg["robot_type"], + "names": { + "observation.state": state_names, + "action": action_names, + }, + } + + +def load_json(fpath: Path) -> dict: + with open(fpath) as f: + return json.load(f) + + +def write_json(data: dict, fpath: Path) -> None: + fpath.parent.mkdir(exist_ok=True, parents=True) + with open(fpath, "w") as f: + json.dump(data, f, indent=4) + + +def convert_stats_to_json(input_dir: Path, output_dir: Path) -> None: + safetensor_path = input_dir / "stats.safetensors" + stats = load_file(safetensor_path) + serializable_stats = {key: value.tolist() for key, value in stats.items()} + + json_path = output_dir / "stats.json" + json_path.parent.mkdir(exist_ok=True, parents=True) + with open(json_path, "w") as f: + json.dump(serializable_stats, f, indent=4) + + # Sanity check + with open(json_path) as f: + stats_json = json.load(f) + + stats_json = {key: torch.tensor(value) for key, value in stats_json.items()} + for key in stats: + torch.testing.assert_close(stats_json[key], stats[key]) + + +def get_keys(table: pa.Table) -> dict[str, list]: + table_metadata = json.loads(table.schema.metadata[b"huggingface"].decode("utf-8")) + sequence_keys, image_keys, video_keys = [], [], [] + for key, val in table_metadata["info"]["features"].items(): + if val["_type"] == "Sequence": + sequence_keys.append(key) + elif val["_type"] == "Image": + image_keys.append(key) + elif val["_type"] == "VideoFrame": + video_keys.append(key) + + return { + "sequence": sequence_keys, + "image": image_keys, + "video": video_keys, + } + + +def remove_hf_metadata_features(table: pa.Table, features: list[str]) -> pa.Table: + # HACK + schema = table.schema + # decode bytes dict + table_metadata = json.loads(schema.metadata[b"huggingface"].decode("utf-8")) + for key in features: + table_metadata["info"]["features"].pop(key) + + # re-encode bytes dict + table_metadata = {b"huggingface": json.dumps(table_metadata).encode("utf-8")} + new_schema = schema.with_metadata(table_metadata) + return table.replace_schema_metadata(new_schema.metadata) + + +def add_hf_metadata_features(table: pa.Table, features: dict[str, dict]) -> pa.Table: + # HACK + schema = table.schema + # decode bytes dict + table_metadata = json.loads(schema.metadata[b"huggingface"].decode("utf-8")) + for key, val in features.items(): + table_metadata["info"]["features"][key] = val + + # re-encode bytes dict + table_metadata = {b"huggingface": json.dumps(table_metadata).encode("utf-8")} + new_schema = schema.with_metadata(table_metadata) + return table.replace_schema_metadata(new_schema.metadata) + + +def remove_videoframe_from_table(table: pa.Table, image_columns: list) -> pa.Table: + table = table.drop(image_columns) + table = remove_hf_metadata_features(table, image_columns) + return table + + +def add_tasks(table: pa.Table, tasks_by_episodes: dict) -> pa.Table: + tasks_index = pa.array([tasks_by_episodes.get(key.as_py(), None) for key in table["episode_index"]]) + table = table.append_column("task_index", tasks_index) + hf_feature = {"task_index": {"dtype": "int64", "_type": "Value"}} + table = add_hf_metadata_features(table, hf_feature) + return table + + +def split_parquet_by_episodes( + table: pa.Table, keys: dict[str, list], total_episodes: int, episode_indices: list, output_dir: Path +) -> list: + (output_dir / "data").mkdir(exist_ok=True, parents=True) + if len(keys["video"]) > 0: + table = remove_videoframe_from_table(table, keys["video"]) + + episode_lengths = [] + for episode_index in sorted(episode_indices): 
+ # Write each episode_index to a new parquet file + filtered_table = table.filter(pc.equal(table["episode_index"], episode_index)) + episode_lengths.insert(episode_index, len(filtered_table)) + output_file = output_dir / PARQUET_PATH.format( + episode_index=episode_index, total_episodes=total_episodes + ) + pq.write_table(filtered_table, output_file) + + return episode_lengths + + +def _get_audio_info(video_path: Path | str) -> dict: + ffprobe_audio_cmd = [ + "ffprobe", + "-v", + "error", + "-select_streams", + "a:0", + "-show_entries", + "stream=channels,codec_name,bit_rate,sample_rate,bit_depth,channel_layout,duration", + "-of", + "json", + str(video_path), + ] + result = subprocess.run(ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) + if result.returncode != 0: + raise RuntimeError(f"Error running ffprobe: {result.stderr}") + + info = json.loads(result.stdout) + audio_stream_info = info["streams"][0] if info.get("streams") else None + if audio_stream_info is None: + return {"has_audio": False} + + # Return the information, defaulting to None if no audio stream is present + return { + "has_audio": True, + "audio.channels": audio_stream_info.get("channels", None), + "audio.codec": audio_stream_info.get("codec_name", None), + "audio.bit_rate": int(audio_stream_info["bit_rate"]) if audio_stream_info.get("bit_rate") else None, + "audio.sample_rate": int(audio_stream_info["sample_rate"]) + if audio_stream_info.get("sample_rate") + else None, + "audio.bit_depth": audio_stream_info.get("bit_depth", None), + "audio.channel_layout": audio_stream_info.get("channel_layout", None), + } + + +def _get_video_info(video_path: Path | str) -> dict: + ffprobe_video_cmd = [ + "ffprobe", + "-v", + "error", + "-select_streams", + "v:0", + "-show_entries", + "stream=r_frame_rate,width,height,codec_name,nb_frames,duration,pix_fmt", + "-of", + "json", + str(video_path), + ] + result = subprocess.run(ffprobe_video_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) + if result.returncode != 0: + raise RuntimeError(f"Error running ffprobe: {result.stderr}") + + info = json.loads(result.stdout) + video_stream_info = info["streams"][0] + + # Calculate fps from r_frame_rate + r_frame_rate = video_stream_info["r_frame_rate"] + num, denom = map(int, r_frame_rate.split("/")) + fps = num / denom + + video_info = { + "video.fps": fps, + "video.width": video_stream_info["width"], + "video.height": video_stream_info["height"], + "video.codec": video_stream_info["codec_name"], + "video.pix_fmt": video_stream_info["pix_fmt"], + **_get_audio_info(video_path), + } + + return video_info + + +def get_videos_info(repo_id: str, local_dir: Path, video_keys: list[str]) -> dict: + hub_api = HfApi() + videos_info_dict = { + "videos_path": VIDEO_PATH, + "has_audio": False, + "has_depth": False, + } + for vid_key in video_keys: + video_path = VIDEO_PATH.format(image_key=vid_key, episode_index=0) + video_path = hub_api.hf_hub_download( + repo_id=repo_id, repo_type="dataset", local_dir=local_dir, filename=video_path + ) + videos_info_dict[vid_key] = _get_video_info(video_path) + videos_info_dict["has_audio"] = ( + videos_info_dict["has_audio"] or videos_info_dict[vid_key]["has_audio"] + ) + + return videos_info_dict + + +def get_video_shapes(videos_info: dict, video_keys: list) -> dict: + video_shapes = {} + for img_key in video_keys: + video_shapes[img_key] = { + "width": videos_info[img_key]["video.width"], + "height": videos_info[img_key]["video.height"], + } + + return video_shapes + + +def 
get_image_shapes(table: pa.Table, image_keys: list) -> dict: + image_shapes = {} + for img_key in image_keys: + image_bytes = table[img_key][0].as_py() # Assuming first row + image = Image.open(BytesIO(image_bytes["bytes"])) + image_shapes[img_key] = { + "width": image.width, + "height": image.height, + } + + return image_shapes + + +def get_generic_motor_names(sequence_shapes: dict) -> dict: + return {key: [f"motor_{i}" for i in range(length)] for key, length in sequence_shapes.items()} + + +def convert_dataset( + repo_id: str, + local_dir: Path, + tasks: dict, + tasks_by_episodes: dict | None = None, + robot_config: dict | None = None, +): + v1_6_dir = local_dir / repo_id / V1_6 + v2_0_dir = local_dir / repo_id / V2_0 + v1_6_dir.mkdir(parents=True, exist_ok=True) + v2_0_dir.mkdir(parents=True, exist_ok=True) + + hub_api = HfApi() + hub_api.snapshot_download( + repo_id=repo_id, repo_type="dataset", revision=V1_6, local_dir=v1_6_dir, ignore_patterns="videos/" + ) + + metadata_v1_6 = load_json(v1_6_dir / "meta_data" / "info.json") + + table = pq.read_table(v1_6_dir / "data") + keys = get_keys(table) + + # Episodes + episode_indices = sorted(table["episode_index"].unique().to_pylist()) + total_episodes = len(episode_indices) + assert episode_indices == list(range(total_episodes)) + + # Tasks + if tasks_by_episodes is None: # Single task dataset + tasks_by_episodes = {ep_idx: 0 for ep_idx in episode_indices} + + assert set(tasks) == set(tasks_by_episodes.values()) + table = add_tasks(table, tasks_by_episodes) + write_json(tasks, v2_0_dir / "meta" / "tasks.json") + + # Split data into 1 parquet file by episode + episode_lengths = split_parquet_by_episodes(table, keys, total_episodes, episode_indices, v2_0_dir) + + # Shapes + sequence_shapes = {key: len(table[key][0]) for key in keys["sequence"]} + image_shapes = get_image_shapes(table, keys["image"]) if len(keys["image"]) > 0 else {} + if len(keys["video"]) > 0: + assert metadata_v1_6.get("video", False) + videos_info = get_videos_info(repo_id, v1_6_dir, video_keys=keys["video"]) + video_shapes = get_video_shapes(videos_info, keys["video"]) + for img_key in keys["video"]: + assert videos_info[img_key]["video.pix_fmt"] == metadata_v1_6["encoding"]["pix_fmt"] + assert math.isclose(videos_info[img_key]["video.fps"], metadata_v1_6["fps"], rel_tol=1e-3) + else: + assert len(keys["video"]) == 0 + videos_info = None + video_shapes = {} + + # Names + if robot_config is not None: + robot_type = robot_config["robot_type"] + names = robot_config["names"] + else: + robot_type = "unknown" + names = get_generic_motor_names(sequence_shapes) + + assert set(names) == set(keys["sequence"]) + for key in sequence_shapes: + assert len(names[key]) == sequence_shapes[key] + + # Episodes info + episodes = [ + {"index": ep_idx, "task": tasks_by_episodes[ep_idx], "length": episode_lengths[ep_idx]} + for ep_idx in episode_indices + ] + + # Assemble metadata v2.0 + metadata_v2_0 = { + "codebase_version": V2_0, + "data_path": PARQUET_PATH, + "robot_type": robot_type, + "total_episodes": total_episodes, + "total_tasks": len(tasks), + "fps": metadata_v1_6["fps"], + "image_keys": keys["video"] + keys["image"], + "keys": keys["sequence"], + "shapes": {**image_shapes, **video_shapes, **sequence_shapes}, + "names": names, + "videos": videos_info, + "episodes": episodes, + } + write_json(metadata_v2_0, v2_0_dir / "meta" / "info.json") + + convert_stats_to_json(v1_6_dir / "meta_data", v2_0_dir / "meta") + + # test_repo_id = f"aliberts/{repo_id.split('/')[1]}" + # if 
hub_api.repo_exists(test_repo_id, repo_type="dataset"):
+    #     hub_api.delete_repo(test_repo_id, repo_type="dataset")
+
+    # hub_api.create_repo(test_repo_id, repo_type="dataset", exist_ok=True)
+    # hub_api.upload_folder(repo_id=test_repo_id, folder_path=v2_0_dir, repo_type="dataset")
+
+    # TODO:
+    # - [X] Add shapes
+    # - [X] Add keys
+    # - [X] Add paths
+    # - [X] Convert stats.json
+    # - [X] Add tasks.json
+    # - [X] Add names
+    # - [X] Add robot_type
+    # - [/] Add sanity checks (encoding, shapes)
+    # - [ ] Handle multi-task datasets
+    # - [ ] Push properly to branch v2.0 and delete v1.6 stuff from that branch
+
+
+def main():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "--repo-id",
+        type=str,
+        required=True,
+        help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset (e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
+    )
+    parser.add_argument(
+        "--task",
+        type=str,
+        required=True,
+        help="A short but accurate description of the task performed in the dataset.",
+    )
+    parser.add_argument(
+        "--robot-config",
+        type=Path,
+        default=None,
+        help="Path to the robot's config yaml used during conversion (to get the robot type and motor names).",
+    )
+    parser.add_argument(
+        "--robot-overrides",
+        type=str,
+        nargs="*",
+        help="Any key=value arguments to override the robot config values (use dots for nested.overrides).",
+    )
+    parser.add_argument(
+        "--local-dir",
+        type=Path,
+        default=None,
+        help="Local directory to store the dataset during conversion. Defaults to /tmp/{repo_id}.",
+    )
+
+    args = parser.parse_args()
+    if args.local_dir is None:
+        args.local_dir = Path(f"/tmp/{args.repo_id}")
+
+    tasks = {0: args.task}
+    del args.task
+
+    if args.robot_config is not None:
+        robot_config = parse_robot_config(args.robot_config, args.robot_overrides)
+    else:
+        robot_config = None
+    del args.robot_config, args.robot_overrides
+
+    convert_dataset(**vars(args), tasks=tasks, robot_config=robot_config)
+
+
+if __name__ == "__main__":
+    main()
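For multi-task datasets (still marked TODO in the docstring and checklist above), `convert_dataset()` already accepts a `tasks` dict and a `tasks_by_episodes` mapping, so a conversion can be driven from Python directly rather than through the CLI. A minimal sketch of that path follows; the repo id and the `episode_tasks.json` file mapping episode indices to task descriptions are hypothetical placeholders, not part of this patch.

```python
import json
from pathlib import Path

from convert_dataset_16_to_20 import convert_dataset

repo_id = "user/my_multitask_dataset"  # placeholder repo id
local_dir = Path(f"/tmp/{repo_id}")

# Hypothetical file: {"0": "Pick up the cube.", "1": "Open the drawer.", ...}
# It must cover every episode index present in the dataset.
episode_to_task = {
    int(ep_idx): task
    for ep_idx, task in json.loads(Path("episode_tasks.json").read_text()).items()
}

# Build the task table (task_index -> description) and the per-episode mapping
# (episode_index -> task_index) expected by convert_dataset().
task_to_index = {task: i for i, task in enumerate(sorted(set(episode_to_task.values())))}
tasks = {i: task for task, i in task_to_index.items()}
tasks_by_episodes = {ep_idx: task_to_index[task] for ep_idx, task in episode_to_task.items()}

convert_dataset(
    repo_id=repo_id,
    local_dir=local_dir,
    tasks=tasks,
    tasks_by_episodes=tasks_by_episodes,
)
```

The mapping must cover every episode index, since `convert_dataset()` asserts that `set(tasks) == set(tasks_by_episodes.values())` and looks up a task index for each episode when writing the episodes metadata.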