Update & fix conversion script
parent c72dc23c43
commit c3c0141738
@@ -124,19 +124,26 @@ from lerobot.common.datasets.utils import (
     DEFAULT_CHUNK_SIZE,
+    DEFAULT_PARQUET_PATH,
     DEFAULT_VIDEO_PATH,
+    EPISODES_PATH,
+    INFO_PATH,
+    STATS_PATH,
+    TASKS_PATH,
     create_branch,
     create_lerobot_dataset_card,
     flatten_dict,
     get_hub_safe_version,
     unflatten_dict,
 )
 from lerobot.common.datasets.video_utils import VideoFrame  # noqa: F401
 from lerobot.common.utils.utils import init_hydra_config

 V16 = "v1.6"
 V20 = "v2.0"

 GITATTRIBUTES_REF = "aliberts/gitattributes_reference"
-VIDEO_FILE = "{video_key}_episode_{episode_index:06d}.mp4"
+V1_VIDEO_FILE = "{video_key}_episode_{episode_index:06d}.mp4"
+V1_INFO_PATH = "meta_data/info.json"
+V1_STATS_PATH = "meta_data/stats.safetensors"


 def parse_robot_config(config_path: Path, config_overrides: list[str] | None = None) -> tuple[str, dict]:
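The new V1_* constants make the v1.x source layout explicit instead of hard-coding paths at each call site. A minimal sketch of how they resolve, mirroring the format calls further down in this diff (the dataset root is hypothetical):

from pathlib import Path

V1_VIDEO_FILE = "{video_key}_episode_{episode_index:06d}.mp4"
V1_INFO_PATH = "meta_data/info.json"
V1_STATS_PATH = "meta_data/stats.safetensors"

v1x_dir = Path("datasets/pusht_v1")  # hypothetical v1.x dataset root

print(V1_VIDEO_FILE.format(video_key="observation.image", episode_index=42))
# observation.image_episode_000042.mp4
print(v1x_dir / V1_INFO_PATH)   # datasets/pusht_v1/meta_data/info.json
print(v1x_dir / V1_STATS_PATH)  # datasets/pusht_v1/meta_data/stats.safetensors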
@@ -180,17 +187,18 @@ def write_json(data: dict, fpath: Path) -> None:


 def write_jsonlines(data: dict, fpath: Path) -> None:
     fpath.parent.mkdir(exist_ok=True, parents=True)
     with jsonlines.open(fpath, "w") as writer:
         writer.write_all(data)


-def convert_stats_to_json(input_dir: Path, output_dir: Path) -> None:
-    safetensor_path = input_dir / "stats.safetensors"
+def convert_stats_to_json(v1_dir: Path, v2_dir: Path) -> None:
+    safetensor_path = v1_dir / V1_STATS_PATH
     stats = load_file(safetensor_path)
     serialized_stats = {key: value.tolist() for key, value in stats.items()}
     serialized_stats = unflatten_dict(serialized_stats)

-    json_path = output_dir / "stats.json"
+    json_path = v2_dir / STATS_PATH
     json_path.parent.mkdir(exist_ok=True, parents=True)
     with open(json_path, "w") as f:
         json.dump(serialized_stats, f, indent=4)
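convert_stats_to_json now takes the two dataset roots and derives the file locations from V1_STATS_PATH and STATS_PATH, which keeps it in sync with the call-site change at the bottom of this diff. For reference, a sketch of the nesting the unflatten_dict call is assumed to perform (flattened "/"-separated keys become nested dicts); this is an illustration, not lerobot's implementation:

def unflatten(d: dict, sep: str = "/") -> dict:
    # "a/b": v  ->  {"a": {"b": v}}
    out: dict = {}
    for key, value in d.items():
        node = out
        *parents, leaf = key.split(sep)
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return out

print(unflatten({"observation.state/mean": [0.1], "observation.state/std": [0.2]}))
# {'observation.state': {'mean': [0.1], 'std': [0.2]}}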
@@ -279,7 +287,7 @@ def split_parquet_by_episodes(
         ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
         episode_lengths.insert(ep_idx, len(ep_table))
         output_file = output_dir / DEFAULT_PARQUET_PATH.format(
-            episode_chunk=ep_chunk, episode_index=ep_idx, total_episodes=total_episodes
+            episode_chunk=ep_chunk, episode_index=ep_idx
         )
         pq.write_table(ep_table, output_file)
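Note that str.format silently ignores unused keyword arguments, so the dropped total_episodes kwarg was dead weight rather than a crash; this is cleanup consistent with a DEFAULT_PARQUET_PATH template that no longer references the field. A quick check (the template value is an assumption, not shown in this diff):

template = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"  # assumed DEFAULT_PARQUET_PATH
print(template.format(episode_chunk=0, episode_index=7, total_episodes=50))  # extra kwarg is ignored
print(template.format(episode_chunk=0, episode_index=7))
# both print: data/chunk-000/episode_000007.parquet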
@@ -336,7 +344,7 @@ def move_videos(
             target_path = DEFAULT_VIDEO_PATH.format(
                 episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_idx
             )
-            video_file = VIDEO_FILE.format(video_key=vid_key, episode_index=ep_idx)
+            video_file = V1_VIDEO_FILE.format(video_key=vid_key, episode_index=ep_idx)
             if len(video_dirs) == 1:
                 video_path = video_dirs[0] / video_file
             else:
@@ -572,7 +580,7 @@ def convert_dataset(
         branch = test_branch
         create_branch(repo_id=repo_id, branch=test_branch, repo_type="dataset")

-    metadata_v1 = load_json(v1x_dir / "meta_data" / "info.json")
+    metadata_v1 = load_json(v1x_dir / V1_INFO_PATH)
     dataset = datasets.load_dataset("parquet", data_dir=v1x_dir / "data", split="train")
     keys = get_keys(dataset)
@@ -611,7 +619,7 @@ def convert_dataset(

     assert set(tasks) == {task for ep_tasks in tasks_by_episodes.values() for task in ep_tasks}
     tasks = [{"task_index": task_idx, "task": task} for task_idx, task in enumerate(tasks)]
-    write_jsonlines(tasks, v20_dir / "meta" / "tasks.json")
+    write_jsonlines(tasks, v20_dir / TASKS_PATH)

     # Shapes
     sequence_shapes = {key: dataset.features[key].length for key in keys["sequence"]}
@@ -667,7 +675,7 @@ def convert_dataset(
         {"episode_index": ep_idx, "tasks": tasks_by_episodes[ep_idx], "length": episode_lengths[ep_idx]}
         for ep_idx in episode_indices
     ]
-    write_jsonlines(episodes, v20_dir / "meta" / "episodes.jsonl")
+    write_jsonlines(episodes, v20_dir / EPISODES_PATH)

     # Assemble metadata v2.0
     metadata_v2_0 = {
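The episode records are written with write_jsonlines from earlier in this diff, one JSON object per line. A usage sketch with made-up episode data:

import jsonlines
from pathlib import Path

episodes = [
    {"episode_index": 0, "tasks": ["Push the T into place."], "length": 161},  # hypothetical records
    {"episode_index": 1, "tasks": ["Push the T into place."], "length": 118},
]
fpath = Path("out/meta/episodes.jsonl")  # hypothetical destination
fpath.parent.mkdir(exist_ok=True, parents=True)
with jsonlines.open(fpath, "w") as writer:
    writer.write_all(episodes)  # one dict per line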
@@ -689,8 +697,8 @@ def convert_dataset(
         "names": names,
         "videos": videos_info,
     }
-    write_json(metadata_v2_0, v20_dir / "meta" / "info.json")
-    convert_stats_to_json(v1x_dir / "meta_data", v20_dir / "meta")
+    write_json(metadata_v2_0, v20_dir / INFO_PATH)
+    convert_stats_to_json(v1x_dir, v20_dir)

     with contextlib.suppress(EntryNotFoundError):
         hub_api.delete_folder(repo_id=repo_id, path_in_repo="data", repo_type="dataset", revision=branch)
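Net effect: every v2.0 metadata path now flows from the shared constants in lerobot.common.datasets.utils rather than ad-hoc "meta" / ... joins, and convert_stats_to_json receives the dataset roots directly. The values below are what these constants are assumed to resolve to; the diff only shows the names. Under that assumption, TASKS_PATH also fixes the old tasks.json extension to .jsonl, matching the JSON Lines content:

# assumed constant values (not shown in this diff)
INFO_PATH = "meta/info.json"
STATS_PATH = "meta/stats.json"
TASKS_PATH = "meta/tasks.jsonl"
EPISODES_PATH = "meta/episodes.jsonl"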