Update & fix conversion script
commit c3c0141738
parent c72dc23c43
@@ -124,19 +124,26 @@ from lerobot.common.datasets.utils import (
     DEFAULT_CHUNK_SIZE,
     DEFAULT_PARQUET_PATH,
     DEFAULT_VIDEO_PATH,
+    EPISODES_PATH,
+    INFO_PATH,
+    STATS_PATH,
+    TASKS_PATH,
     create_branch,
     create_lerobot_dataset_card,
     flatten_dict,
     get_hub_safe_version,
     unflatten_dict,
 )
+from lerobot.common.datasets.video_utils import VideoFrame  # noqa: F401
 from lerobot.common.utils.utils import init_hydra_config
 
 V16 = "v1.6"
 V20 = "v2.0"
 
 GITATTRIBUTES_REF = "aliberts/gitattributes_reference"
-VIDEO_FILE = "{video_key}_episode_{episode_index:06d}.mp4"
+V1_VIDEO_FILE = "{video_key}_episode_{episode_index:06d}.mp4"
+V1_INFO_PATH = "meta_data/info.json"
+V1_STATS_PATH = "meta_data/stats.safetensors"
 
 
 def parse_robot_config(config_path: Path, config_overrides: list[str] | None = None) -> tuple[str, dict]:
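
Note on the new constants: the four `*_PATH` imports are the chunked v2.0 metadata locations, while the `V1_*` constants pin down the flat v1.x layout the script reads from. The `VideoFrame` import carries `# noqa: F401`, i.e. it is kept despite being unused at module level, presumably so `datasets` can decode the custom feature type. A minimal sketch of the path mapping, assuming the v2.0 values below (they are not taken from this diff; verify against lerobot.common.datasets.utils):

    from pathlib import Path

    # Assumed v2.0 values; verify against lerobot.common.datasets.utils:
    INFO_PATH = "meta/info.json"
    STATS_PATH = "meta/stats.json"
    TASKS_PATH = "meta/tasks.jsonl"
    EPISODES_PATH = "meta/episodes.jsonl"

    v1x_dir = Path("datasets/my_dataset_v1")  # hypothetical roots
    v20_dir = Path("datasets/my_dataset_v2")
    old_info = v1x_dir / "meta_data/info.json"  # i.e. v1x_dir / V1_INFO_PATH
    new_info = v20_dir / INFO_PATH              # datasets/my_dataset_v2/meta/info.json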
@@ -180,17 +187,18 @@ def write_json(data: dict, fpath: Path) -> None:
 
 
 def write_jsonlines(data: dict, fpath: Path) -> None:
+    fpath.parent.mkdir(exist_ok=True, parents=True)
     with jsonlines.open(fpath, "w") as writer:
         writer.write_all(data)
 
 
-def convert_stats_to_json(input_dir: Path, output_dir: Path) -> None:
-    safetensor_path = input_dir / "stats.safetensors"
+def convert_stats_to_json(v1_dir: Path, v2_dir: Path) -> None:
+    safetensor_path = v1_dir / V1_STATS_PATH
     stats = load_file(safetensor_path)
     serialized_stats = {key: value.tolist() for key, value in stats.items()}
     serialized_stats = unflatten_dict(serialized_stats)
 
-    json_path = output_dir / "stats.json"
+    json_path = v2_dir / STATS_PATH
     json_path.parent.mkdir(exist_ok=True, parents=True)
     with open(json_path, "w") as f:
         json.dump(serialized_stats, f, indent=4)
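
With the reworked signature, callers pass the two dataset roots and `convert_stats_to_json` resolves both the v1 safetensors source and the v2 JSON target itself; `write_jsonlines` now also creates missing parent directories before opening the file. A usage sketch (paths are illustrative):

    from pathlib import Path

    v1x_dir = Path("datasets/my_dataset_v1")  # contains meta_data/stats.safetensors
    v20_dir = Path("datasets/my_dataset_v2")  # receives the v2.0 stats JSON

    # Reads v1x_dir / V1_STATS_PATH, writes v20_dir / STATS_PATH,
    # creating parent directories as needed.
    convert_stats_to_json(v1x_dir, v20_dir)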
@@ -279,7 +287,7 @@ def split_parquet_by_episodes(
         ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
         episode_lengths.insert(ep_idx, len(ep_table))
         output_file = output_dir / DEFAULT_PARQUET_PATH.format(
-            episode_chunk=ep_chunk, episode_index=ep_idx, total_episodes=total_episodes
+            episode_chunk=ep_chunk, episode_index=ep_idx
         )
         pq.write_table(ep_table, output_file)
 
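
Dropping `total_episodes` matches a parquet template keyed only on chunk and episode index; since `str.format` silently ignores unused keyword arguments, the old argument was dead rather than harmful. A sketch of the path arithmetic, with both constant values assumed:

    DEFAULT_CHUNK_SIZE = 1000  # assumed chunk size
    # Assumed template shape; verify against lerobot.common.datasets.utils:
    DEFAULT_PARQUET_PATH = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"

    ep_idx = 1042
    ep_chunk = ep_idx // DEFAULT_CHUNK_SIZE  # -> 1
    print(DEFAULT_PARQUET_PATH.format(episode_chunk=ep_chunk, episode_index=ep_idx))
    # -> data/chunk-001/episode_001042.parquet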
@@ -336,7 +344,7 @@ def move_videos(
             target_path = DEFAULT_VIDEO_PATH.format(
                 episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_idx
             )
-            video_file = VIDEO_FILE.format(video_key=vid_key, episode_index=ep_idx)
+            video_file = V1_VIDEO_FILE.format(video_key=vid_key, episode_index=ep_idx)
             if len(video_dirs) == 1:
                 video_path = video_dirs[0] / video_file
             else:
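
The rename to `V1_VIDEO_FILE` makes explicit that this template matches the flat v1.x filenames, which `move_videos` re-homes into the chunked v2.0 tree. A sketch of the mapping for one file, with the v2.0 template value assumed:

    V1_VIDEO_FILE = "{video_key}_episode_{episode_index:06d}.mp4"
    # Assumed v2.0 template; verify against lerobot.common.datasets.utils:
    DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"

    vid_key, ep_idx, ep_chunk = "observation.images.top", 7, 0
    src = V1_VIDEO_FILE.format(video_key=vid_key, episode_index=ep_idx)
    dst = DEFAULT_VIDEO_PATH.format(episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_idx)
    # src: observation.images.top_episode_000007.mp4
    # dst: videos/chunk-000/observation.images.top/episode_000007.mp4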
@@ -572,7 +580,7 @@ def convert_dataset(
         branch = test_branch
         create_branch(repo_id=repo_id, branch=test_branch, repo_type="dataset")
 
-    metadata_v1 = load_json(v1x_dir / "meta_data" / "info.json")
+    metadata_v1 = load_json(v1x_dir / V1_INFO_PATH)
     dataset = datasets.load_dataset("parquet", data_dir=v1x_dir / "data", split="train")
     keys = get_keys(dataset)
 
@@ -611,7 +619,7 @@ def convert_dataset(
 
     assert set(tasks) == {task for ep_tasks in tasks_by_episodes.values() for task in ep_tasks}
     tasks = [{"task_index": task_idx, "task": task} for task_idx, task in enumerate(tasks)]
-    write_jsonlines(tasks, v20_dir / "meta" / "tasks.json")
+    write_jsonlines(tasks, v20_dir / TASKS_PATH)
 
     # Shapes
     sequence_shapes = {key: dataset.features[key].length for key in keys["sequence"]}
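
Writing through `TASKS_PATH` also retires the stale `tasks.json` name for what `write_jsonlines` actually emits: one JSON object per line. A sketch of the resulting records (task strings are made up):

    tasks = ["pick the red block", "place it in the bin"]
    records = [{"task_index": i, "task": t} for i, t in enumerate(tasks)]
    # write_jsonlines(records, v20_dir / TASKS_PATH) produces:
    # {"task_index": 0, "task": "pick the red block"}
    # {"task_index": 1, "task": "place it in the bin"}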
@@ -667,7 +675,7 @@ def convert_dataset(
         {"episode_index": ep_idx, "tasks": tasks_by_episodes[ep_idx], "length": episode_lengths[ep_idx]}
         for ep_idx in episode_indices
     ]
-    write_jsonlines(episodes, v20_dir / "meta" / "episodes.jsonl")
+    write_jsonlines(episodes, v20_dir / EPISODES_PATH)
 
     # Assemble metadata v2.0
     metadata_v2_0 = {
@@ -689,8 +697,8 @@ def convert_dataset(
         "names": names,
         "videos": videos_info,
     }
-    write_json(metadata_v2_0, v20_dir / "meta" / "info.json")
-    convert_stats_to_json(v1x_dir / "meta_data", v20_dir / "meta")
+    write_json(metadata_v2_0, v20_dir / INFO_PATH)
+    convert_stats_to_json(v1x_dir, v20_dir)
 
     with contextlib.suppress(EntryNotFoundError):
         hub_api.delete_folder(repo_id=repo_id, path_in_repo="data", repo_type="dataset", revision=branch)
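
Taken together, every metadata file is now written through the shared path constants. A hypothetical post-conversion sanity check (`check_v2_layout` is not part of the script; the imports exist per the first hunk):

    from pathlib import Path

    from lerobot.common.datasets.utils import EPISODES_PATH, INFO_PATH, STATS_PATH, TASKS_PATH

    def check_v2_layout(v20_dir: Path) -> None:
        # Verify the converted dataset exposes all v2.0 metadata files.
        for rel in (INFO_PATH, STATS_PATH, TASKS_PATH, EPISODES_PATH):
            assert (v20_dir / rel).is_file(), f"missing {rel}"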