Extend v1 compatibility

This commit is contained in:
Simon Alibert 2024-10-14 10:14:27 +02:00
parent cf633344be
commit cbc51e1341
1 changed files with 39 additions and 29 deletions

View File

@ -12,7 +12,7 @@ We support 3 different scenarios for these tasks:
# 1. Single task dataset # 1. Single task dataset
If your dataset contains a single task, you can simply provide it directly via the CLI with the If your dataset contains a single task, you can simply provide it directly via the CLI with the
'--single-task' option (see examples below). '--single-task' option.
Examples: Examples:
@ -67,7 +67,15 @@ If your dataset is a multi-task dataset, you have two options to provide the tas
# 3. Multi task episodes # 3. Multi task episodes
If you have multiple tasks per episodes, your dataset should contain a language instruction column in its If you have multiple tasks per episodes, your dataset should contain a language instruction column in its
parquet file, and you must provide this column's name with the '--tasks-col' arg. parquet file, and you must provide this column's name with the '--tasks-col' arg.
TODO
Example:
```bash
python convert_dataset_v1_to_v2.py \
--repo-id lerobot/stanford_kuka_multimodal_dataset \
--tasks-col "language_instruction" \
--local-dir data
```
""" """
import argparse import argparse
@ -87,12 +95,12 @@ from huggingface_hub.errors import EntryNotFoundError
from PIL import Image from PIL import Image
from safetensors.torch import load_file from safetensors.torch import load_file
from lerobot.common.datasets.utils import create_branch, flatten_dict, unflatten_dict from lerobot.common.datasets.utils import create_branch, flatten_dict, get_hub_safe_version, unflatten_dict
from lerobot.common.utils.utils import init_hydra_config from lerobot.common.utils.utils import init_hydra_config
from lerobot.scripts.push_dataset_to_hub import push_dataset_card_to_hub from lerobot.scripts.push_dataset_to_hub import push_dataset_card_to_hub
V1_6 = "v1.6" V16 = "v1.6"
V2_0 = "v2.0" V20 = "v2.0"
PARQUET_PATH = "data/train-{episode_index:05d}-of-{total_episodes:05d}.parquet" PARQUET_PATH = "data/train-{episode_index:05d}-of-{total_episodes:05d}.parquet"
VIDEO_PATH = "videos/{video_key}_episode_{episode_index:06d}.mp4" VIDEO_PATH = "videos/{video_key}_episode_{episode_index:06d}.mp4"
@ -385,18 +393,19 @@ def convert_dataset(
tasks_col: Path | None = None, tasks_col: Path | None = None,
robot_config: dict | None = None, robot_config: dict | None = None,
): ):
v1_6_dir = local_dir / V1_6 / repo_id v1 = get_hub_safe_version(repo_id, V16)
v2_0_dir = local_dir / V2_0 / repo_id v1x_dir = local_dir / v1 / repo_id
v1_6_dir.mkdir(parents=True, exist_ok=True) v20_dir = local_dir / V20 / repo_id
v2_0_dir.mkdir(parents=True, exist_ok=True) v1x_dir.mkdir(parents=True, exist_ok=True)
v20_dir.mkdir(parents=True, exist_ok=True)
hub_api = HfApi() hub_api = HfApi()
hub_api.snapshot_download( hub_api.snapshot_download(
repo_id=repo_id, repo_type="dataset", revision=V1_6, local_dir=v1_6_dir, ignore_patterns="videos/" repo_id=repo_id, repo_type="dataset", revision=v1, local_dir=v1x_dir, ignore_patterns="videos/"
) )
metadata_v1_6 = load_json(v1_6_dir / "meta_data" / "info.json") metadata_v1 = load_json(v1x_dir / "meta_data" / "info.json")
dataset = datasets.load_dataset("parquet", data_dir=v1_6_dir / "data", split="train") dataset = datasets.load_dataset("parquet", data_dir=v1x_dir / "data", split="train")
keys = get_keys(dataset) keys = get_keys(dataset)
# Episodes # Episodes
@ -422,21 +431,22 @@ def convert_dataset(
assert set(tasks) == {task for ep_tasks in tasks_by_episodes.values() for task in ep_tasks} assert set(tasks) == {task for ep_tasks in tasks_by_episodes.values() for task in ep_tasks}
task_json = [{"task_index": task_idx, "task": task} for task_idx, task in enumerate(tasks)] task_json = [{"task_index": task_idx, "task": task} for task_idx, task in enumerate(tasks)]
write_json(task_json, v2_0_dir / "meta" / "tasks.json") write_json(task_json, v20_dir / "meta" / "tasks.json")
# Split data into 1 parquet file by episode # Split data into 1 parquet file by episode
episode_lengths = split_parquet_by_episodes(dataset, keys, total_episodes, episode_indices, v2_0_dir) episode_lengths = split_parquet_by_episodes(dataset, keys, total_episodes, episode_indices, v20_dir)
# Shapes # Shapes
sequence_shapes = {key: len(dataset[key][0]) for key in keys["sequence"]} sequence_shapes = {key: len(dataset[key][0]) for key in keys["sequence"]}
image_shapes = get_image_shapes(dataset, keys["image"]) if len(keys["image"]) > 0 else {} image_shapes = get_image_shapes(dataset, keys["image"]) if len(keys["image"]) > 0 else {}
if len(keys["video"]) > 0: if len(keys["video"]) > 0:
assert metadata_v1_6.get("video", False) assert metadata_v1.get("video", False)
videos_info = get_videos_info(repo_id, v1_6_dir, video_keys=keys["video"]) videos_info = get_videos_info(repo_id, v1x_dir, video_keys=keys["video"])
video_shapes = get_video_shapes(videos_info, keys["video"]) video_shapes = get_video_shapes(videos_info, keys["video"])
for img_key in keys["video"]: for img_key in keys["video"]:
assert videos_info[img_key]["video.pix_fmt"] == metadata_v1_6["encoding"]["pix_fmt"] assert math.isclose(videos_info[img_key]["video.fps"], metadata_v1["fps"], rel_tol=1e-3)
assert math.isclose(videos_info[img_key]["video.fps"], metadata_v1_6["fps"], rel_tol=1e-3) if "encoding" in metadata_v1:
assert videos_info[img_key]["video.pix_fmt"] == metadata_v1["encoding"]["pix_fmt"]
else: else:
assert len(keys["video"]) == 0 assert len(keys["video"]) == 0
videos_info = None videos_info = None
@ -461,16 +471,16 @@ def convert_dataset(
{"episode_index": ep_idx, "tasks": [tasks_by_episodes[ep_idx]], "length": episode_lengths[ep_idx]} {"episode_index": ep_idx, "tasks": [tasks_by_episodes[ep_idx]], "length": episode_lengths[ep_idx]}
for ep_idx in episode_indices for ep_idx in episode_indices
] ]
write_json(episodes, v2_0_dir / "meta" / "episodes.json") write_json(episodes, v20_dir / "meta" / "episodes.json")
# Assemble metadata v2.0 # Assemble metadata v2.0
metadata_v2_0 = { metadata_v2_0 = {
"codebase_version": V2_0, "codebase_version": V20,
"data_path": PARQUET_PATH, "data_path": PARQUET_PATH,
"robot_type": robot_type, "robot_type": robot_type,
"total_episodes": total_episodes, "total_episodes": total_episodes,
"total_tasks": len(tasks), "total_tasks": len(tasks),
"fps": metadata_v1_6["fps"], "fps": metadata_v1["fps"],
"splits": {"train": f"0:{total_episodes}"}, "splits": {"train": f"0:{total_episodes}"},
"keys": keys["sequence"], "keys": keys["sequence"],
"video_keys": keys["video"], "video_keys": keys["video"],
@ -479,14 +489,14 @@ def convert_dataset(
"names": names, "names": names,
"videos": videos_info, "videos": videos_info,
} }
write_json(metadata_v2_0, v2_0_dir / "meta" / "info.json") write_json(metadata_v2_0, v20_dir / "meta" / "info.json")
convert_stats_to_json(v1_6_dir / "meta_data", v2_0_dir / "meta") convert_stats_to_json(v1x_dir / "meta_data", v20_dir / "meta")
#### TODO: delete #### TODO: delete
repo_id = f"aliberts/{repo_id.split('/')[1]}" # repo_id = f"aliberts/{repo_id.split('/')[1]}"
# if hub_api.repo_exists(repo_id=repo_id, repo_type="dataset"): # if hub_api.repo_exists(repo_id=repo_id, repo_type="dataset"):
# hub_api.delete_repo(repo_id=repo_id, repo_type="dataset") # hub_api.delete_repo(repo_id=repo_id, repo_type="dataset")
hub_api.create_repo(repo_id=repo_id, repo_type="dataset", exist_ok=True) # hub_api.create_repo(repo_id=repo_id, repo_type="dataset", exist_ok=True)
#### ####
with contextlib.suppress(EntryNotFoundError): with contextlib.suppress(EntryNotFoundError):
@ -498,28 +508,28 @@ def convert_dataset(
hub_api.upload_folder( hub_api.upload_folder(
repo_id=repo_id, repo_id=repo_id,
path_in_repo="data", path_in_repo="data",
folder_path=v2_0_dir / "data", folder_path=v20_dir / "data",
repo_type="dataset", repo_type="dataset",
revision="main", revision="main",
) )
hub_api.upload_folder( hub_api.upload_folder(
repo_id=repo_id, repo_id=repo_id,
path_in_repo="videos", path_in_repo="videos",
folder_path=v1_6_dir / "videos", folder_path=v1x_dir / "videos",
repo_type="dataset", repo_type="dataset",
revision="main", revision="main",
) )
hub_api.upload_folder( hub_api.upload_folder(
repo_id=repo_id, repo_id=repo_id,
path_in_repo="meta", path_in_repo="meta",
folder_path=v2_0_dir / "meta", folder_path=v20_dir / "meta",
repo_type="dataset", repo_type="dataset",
revision="main", revision="main",
) )
card_text = f"[meta/info.json](meta/info.json)\n```json\n{json.dumps(metadata_v2_0, indent=4)}\n```" card_text = f"[meta/info.json](meta/info.json)\n```json\n{json.dumps(metadata_v2_0, indent=4)}\n```"
push_dataset_card_to_hub(repo_id=repo_id, revision="main", tags=repo_tags, text=card_text) push_dataset_card_to_hub(repo_id=repo_id, revision="main", tags=repo_tags, text=card_text)
create_branch(repo_id=repo_id, branch=V2_0, repo_type="dataset") create_branch(repo_id=repo_id, branch=V20, repo_type="dataset")
# TODO: # TODO:
# - [X] Add shapes # - [X] Add shapes