Extend v1 compatibility

parent cf633344be
commit cbc51e1341
@@ -12,7 +12,7 @@ We support 3 different scenarios for these tasks:
 # 1. Single task dataset
 If your dataset contains a single task, you can simply provide it directly via the CLI with the
-'--single-task' option (see examples below).
+'--single-task' option.

 Examples:

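For context on the single-task scenario described in this docstring hunk: every episode gets the same task string, so `tasks.json` ends up with a single entry. The sketch below only mirrors the `task_json` comprehension visible later in this diff; the task string and episode indices are made-up values, not taken from the commit.

```python
# Minimal sketch of the single-task case, with illustrative values.
single_task = "Pick up the cube and place it in the bin."  # illustrative --single-task value
episode_indices = [0, 1, 2]                                # illustrative

# Every episode maps to the same task; the unique set collapses to one entry.
tasks_by_episodes = {ep_idx: single_task for ep_idx in episode_indices}
tasks = list(set(tasks_by_episodes.values()))
task_json = [{"task_index": task_idx, "task": task} for task_idx, task in enumerate(tasks)]
print(task_json)  # [{'task_index': 0, 'task': 'Pick up the cube and place it in the bin.'}]
```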
@@ -67,7 +67,15 @@ If your dataset is a multi-task dataset, you have two options to provide the tasks
 # 3. Multi task episodes
 If you have multiple tasks per episode, your dataset should contain a language instruction column in its
 parquet file, and you must provide this column's name with the '--tasks-col' arg.
-TODO
+
+Example:
+
+```bash
+python convert_dataset_v1_to_v2.py \
+    --repo-id lerobot/stanford_kuka_multimodal_dataset \
+    --tasks-col "language_instruction" \
+    --local-dir data
+```
 """

 import argparse
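The new docstring example above relies on the v1.x parquet exposing a per-frame language column. Below is a hedged sketch of how such a column could be grouped into per-episode task lists; the toy table and the grouping logic are illustrative, not the script's actual code.

```python
import datasets

# Toy frames table standing in for a v1.x parquet with episode_index and a tasks column.
ds = datasets.Dataset.from_dict(
    {
        "episode_index": [0, 0, 1, 1],
        "language_instruction": ["pick the lego block", "pick the lego block", "push the block", "push the block"],
    }
)

# Group the tasks column (the one named by --tasks-col) by episode, keeping unique strings.
tasks_by_episodes: dict[int, list[str]] = {}
for ep_idx, task in zip(ds["episode_index"], ds["language_instruction"], strict=True):
    tasks_by_episodes.setdefault(ep_idx, []).append(task)
tasks_by_episodes = {ep: list(dict.fromkeys(t)) for ep, t in tasks_by_episodes.items()}
print(tasks_by_episodes)  # {0: ['pick the lego block'], 1: ['push the block']}
```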
@@ -87,12 +95,12 @@ from huggingface_hub.errors import EntryNotFoundError
 from PIL import Image
 from safetensors.torch import load_file

-from lerobot.common.datasets.utils import create_branch, flatten_dict, unflatten_dict
+from lerobot.common.datasets.utils import create_branch, flatten_dict, get_hub_safe_version, unflatten_dict
 from lerobot.common.utils.utils import init_hydra_config
 from lerobot.scripts.push_dataset_to_hub import push_dataset_card_to_hub

-V1_6 = "v1.6"
-V2_0 = "v2.0"
+V16 = "v1.6"
+V20 = "v2.0"

 PARQUET_PATH = "data/train-{episode_index:05d}-of-{total_episodes:05d}.parquet"
 VIDEO_PATH = "videos/{video_key}_episode_{episode_index:06d}.mp4"
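`get_hub_safe_version`, imported above, is what lets the converter accept datasets tagged with a v1.x revision other than v1.6. The sketch below is only a guess at that kind of branch lookup, not the function's actual implementation; the helper name and fallback policy are assumptions.

```python
# Hypothetical branch lookup, not lerobot's get_hub_safe_version: ask the Hub which
# branches the dataset repo has and fall back to the newest v1.x one if the expected
# tag is missing.
from huggingface_hub import HfApi

def guess_compatible_v1_branch(repo_id: str, expected: str = "v1.6") -> str:
    refs = HfApi().list_repo_refs(repo_id, repo_type="dataset")
    branches = [ref.name for ref in refs.branches]
    if expected in branches:
        return expected
    v1_branches = [b for b in branches if b.startswith("v1.")]
    if not v1_branches:
        raise ValueError(f"No v1.x branch found for {repo_id} (available: {branches})")
    return max(v1_branches, key=lambda b: float(b.removeprefix("v")))
```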
@@ -385,18 +393,19 @@ def convert_dataset(
     tasks_col: Path | None = None,
     robot_config: dict | None = None,
 ):
-    v1_6_dir = local_dir / V1_6 / repo_id
-    v2_0_dir = local_dir / V2_0 / repo_id
-    v1_6_dir.mkdir(parents=True, exist_ok=True)
-    v2_0_dir.mkdir(parents=True, exist_ok=True)
+    v1 = get_hub_safe_version(repo_id, V16)
+    v1x_dir = local_dir / v1 / repo_id
+    v20_dir = local_dir / V20 / repo_id
+    v1x_dir.mkdir(parents=True, exist_ok=True)
+    v20_dir.mkdir(parents=True, exist_ok=True)

     hub_api = HfApi()
     hub_api.snapshot_download(
-        repo_id=repo_id, repo_type="dataset", revision=V1_6, local_dir=v1_6_dir, ignore_patterns="videos/"
+        repo_id=repo_id, repo_type="dataset", revision=v1, local_dir=v1x_dir, ignore_patterns="videos/"
     )

-    metadata_v1_6 = load_json(v1_6_dir / "meta_data" / "info.json")
-    dataset = datasets.load_dataset("parquet", data_dir=v1_6_dir / "data", split="train")
+    metadata_v1 = load_json(v1x_dir / "meta_data" / "info.json")
+    dataset = datasets.load_dataset("parquet", data_dir=v1x_dir / "data", split="train")
     keys = get_keys(dataset)

     # Episodes
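With the hub-safe version in hand, the function keeps two working trees under `--local-dir`: the downloaded v1.x snapshot (videos excluded) and the converted v2.0 output. The values below are purely illustrative, not produced by the commit.

```python
from pathlib import Path

local_dir = Path("data")   # --local-dir
repo_id = "lerobot/pusht"  # illustrative repo id
v1 = "v1.6"                # e.g. the revision resolved from the Hub

v1x_dir = local_dir / v1 / repo_id      # data/v1.6/lerobot/pusht  (v1.x snapshot, no videos)
v20_dir = local_dir / "v2.0" / repo_id  # data/v2.0/lerobot/pusht  (converted dataset)
print(v1x_dir, v20_dir)
```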
@@ -422,21 +431,22 @@ def convert_dataset(

     assert set(tasks) == {task for ep_tasks in tasks_by_episodes.values() for task in ep_tasks}
     task_json = [{"task_index": task_idx, "task": task} for task_idx, task in enumerate(tasks)]
-    write_json(task_json, v2_0_dir / "meta" / "tasks.json")
+    write_json(task_json, v20_dir / "meta" / "tasks.json")

     # Split data into 1 parquet file by episode
-    episode_lengths = split_parquet_by_episodes(dataset, keys, total_episodes, episode_indices, v2_0_dir)
+    episode_lengths = split_parquet_by_episodes(dataset, keys, total_episodes, episode_indices, v20_dir)

     # Shapes
     sequence_shapes = {key: len(dataset[key][0]) for key in keys["sequence"]}
     image_shapes = get_image_shapes(dataset, keys["image"]) if len(keys["image"]) > 0 else {}
     if len(keys["video"]) > 0:
-        assert metadata_v1_6.get("video", False)
-        videos_info = get_videos_info(repo_id, v1_6_dir, video_keys=keys["video"])
+        assert metadata_v1.get("video", False)
+        videos_info = get_videos_info(repo_id, v1x_dir, video_keys=keys["video"])
         video_shapes = get_video_shapes(videos_info, keys["video"])
         for img_key in keys["video"]:
-            assert videos_info[img_key]["video.pix_fmt"] == metadata_v1_6["encoding"]["pix_fmt"]
-            assert math.isclose(videos_info[img_key]["video.fps"], metadata_v1_6["fps"], rel_tol=1e-3)
+            assert math.isclose(videos_info[img_key]["video.fps"], metadata_v1["fps"], rel_tol=1e-3)
+            if "encoding" in metadata_v1:
+                assert videos_info[img_key]["video.pix_fmt"] == metadata_v1["encoding"]["pix_fmt"]
     else:
         assert len(keys["video"]) == 0
         videos_info = None
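The assertions above check each video's fps against the v1.x info.json and, only when that metadata carries an "encoding" section, its pixel format as well. The snippet below is a hedged sketch of how such per-video values could be probed; it is not lerobot's `get_videos_info`, and it assumes `ffprobe` is on PATH.

```python
import json
import subprocess
from fractions import Fraction

def probe_video(path: str) -> dict:
    """Return fps and pixel format of the first video stream, via ffprobe."""
    out = subprocess.run(
        [
            "ffprobe", "-v", "error", "-select_streams", "v:0",
            "-show_entries", "stream=avg_frame_rate,pix_fmt",
            "-of", "json", path,
        ],
        capture_output=True, text=True, check=True,
    ).stdout
    stream = json.loads(out)["streams"][0]
    return {
        "video.fps": float(Fraction(stream["avg_frame_rate"])),  # "30000/1001" -> ~29.97
        "video.pix_fmt": stream["pix_fmt"],                      # e.g. "yuv420p"
    }
```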
@@ -461,16 +471,16 @@ def convert_dataset(
         {"episode_index": ep_idx, "tasks": [tasks_by_episodes[ep_idx]], "length": episode_lengths[ep_idx]}
         for ep_idx in episode_indices
     ]
-    write_json(episodes, v2_0_dir / "meta" / "episodes.json")
+    write_json(episodes, v20_dir / "meta" / "episodes.json")

     # Assemble metadata v2.0
     metadata_v2_0 = {
-        "codebase_version": V2_0,
+        "codebase_version": V20,
         "data_path": PARQUET_PATH,
         "robot_type": robot_type,
         "total_episodes": total_episodes,
         "total_tasks": len(tasks),
-        "fps": metadata_v1_6["fps"],
+        "fps": metadata_v1["fps"],
         "splits": {"train": f"0:{total_episodes}"},
         "keys": keys["sequence"],
         "video_keys": keys["video"],
@@ -479,14 +489,14 @@ def convert_dataset(
         "names": names,
         "videos": videos_info,
     }
-    write_json(metadata_v2_0, v2_0_dir / "meta" / "info.json")
-    convert_stats_to_json(v1_6_dir / "meta_data", v2_0_dir / "meta")
+    write_json(metadata_v2_0, v20_dir / "meta" / "info.json")
+    convert_stats_to_json(v1x_dir / "meta_data", v20_dir / "meta")

     #### TODO: delete
-    repo_id = f"aliberts/{repo_id.split('/')[1]}"
+    # repo_id = f"aliberts/{repo_id.split('/')[1]}"
     # if hub_api.repo_exists(repo_id=repo_id, repo_type="dataset"):
     #     hub_api.delete_repo(repo_id=repo_id, repo_type="dataset")
-    hub_api.create_repo(repo_id=repo_id, repo_type="dataset", exist_ok=True)
+    # hub_api.create_repo(repo_id=repo_id, repo_type="dataset", exist_ok=True)
     ####

     with contextlib.suppress(EntryNotFoundError):
@@ -498,28 +508,28 @@ def convert_dataset(
     hub_api.upload_folder(
         repo_id=repo_id,
         path_in_repo="data",
-        folder_path=v2_0_dir / "data",
+        folder_path=v20_dir / "data",
         repo_type="dataset",
         revision="main",
     )
     hub_api.upload_folder(
         repo_id=repo_id,
         path_in_repo="videos",
-        folder_path=v1_6_dir / "videos",
+        folder_path=v1x_dir / "videos",
         repo_type="dataset",
         revision="main",
     )
     hub_api.upload_folder(
         repo_id=repo_id,
         path_in_repo="meta",
-        folder_path=v2_0_dir / "meta",
+        folder_path=v20_dir / "meta",
         repo_type="dataset",
         revision="main",
     )

     card_text = f"[meta/info.json](meta/info.json)\n```json\n{json.dumps(metadata_v2_0, indent=4)}\n```"
     push_dataset_card_to_hub(repo_id=repo_id, revision="main", tags=repo_tags, text=card_text)
-    create_branch(repo_id=repo_id, branch=V2_0, repo_type="dataset")
+    create_branch(repo_id=repo_id, branch=V20, repo_type="dataset")

     # TODO:
     # - [X] Add shapes
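Once the three `upload_folder` calls and `create_branch` above have run, the converted dataset sits on a "v2.0" branch of the Hub repo. A hedged way to double-check the result (the repo id and local path are illustrative, not part of the commit):

```python
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="lerobot/stanford_kuka_multimodal_dataset",  # illustrative repo id
    repo_type="dataset",
    revision="v2.0",                                     # branch created by the script
    local_dir="check/lerobot/stanford_kuka_multimodal_dataset",
)
```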