Update info.json format

This commit is contained in:
Simon Alibert 2024-10-08 15:31:37 +02:00
parent 21ba4b5263
commit 2d75b93ba0
1 changed files with 16 additions and 18 deletions

View File

@ -18,7 +18,8 @@ Single-task dataset:
python convert_dataset_16_to_20.py \
--repo-id lerobot/aloha_sim_insertion_human_image \
--task "Insert the peg into the socket." \
--robot-config lerobot/configs/robot/aloha.yaml
--robot-config lerobot/configs/robot/aloha.yaml \
--local-dir data
```
```bash
@ -50,7 +51,7 @@ from huggingface_hub.errors import EntryNotFoundError
from PIL import Image
from safetensors.torch import load_file
from lerobot.common.datasets.utils import create_branch
from lerobot.common.datasets.utils import create_branch, flatten_dict, unflatten_dict
from lerobot.common.utils.utils import init_hydra_config
from lerobot.scripts.push_dataset_to_hub import push_dataset_card_to_hub
@ -58,7 +59,7 @@ V1_6 = "v1.6"
V2_0 = "v2.0"
PARQUET_PATH = "data/train-{episode_index:05d}-of-{total_episodes:05d}.parquet"
VIDEO_PATH = "videos/{image_key}_episode_{episode_index:06d}.mp4"
VIDEO_PATH = "videos/{video_key}_episode_{episode_index:06d}.mp4"
def parse_robot_config(config_path: Path, config_overrides: list[str] | None = None) -> tuple[str, dict]:
@ -104,17 +105,19 @@ def write_json(data: dict, fpath: Path) -> None:
def convert_stats_to_json(input_dir: Path, output_dir: Path) -> None:
safetensor_path = input_dir / "stats.safetensors"
stats = load_file(safetensor_path)
serializable_stats = {key: value.tolist() for key, value in stats.items()}
serialized_stats = {key: value.tolist() for key, value in stats.items()}
serialized_stats = unflatten_dict(serialized_stats)
json_path = output_dir / "stats.json"
json_path.parent.mkdir(exist_ok=True, parents=True)
with open(json_path, "w") as f:
json.dump(serializable_stats, f, indent=4)
json.dump(serialized_stats, f, indent=4)
# Sanity check
with open(json_path) as f:
stats_json = json.load(f)
stats_json = flatten_dict(stats_json)
stats_json = {key: torch.tensor(value) for key, value in stats_json.items()}
for key in stats:
torch.testing.assert_close(stats_json[key], stats[key])
@ -270,6 +273,7 @@ def _get_video_info(video_path: Path | str) -> dict:
"video.channels": pixel_channels,
"video.codec": video_stream_info["codec_name"],
"video.pix_fmt": video_stream_info["pix_fmt"],
"video.is_depth_map": False,
**_get_audio_info(video_path),
}
@ -278,20 +282,13 @@ def _get_video_info(video_path: Path | str) -> dict:
def get_videos_info(repo_id: str, local_dir: Path, video_keys: list[str]) -> dict:
hub_api = HfApi()
videos_info_dict = {
"videos_path": VIDEO_PATH,
"has_audio": False,
"has_depth": False,
}
videos_info_dict = {"videos_path": VIDEO_PATH}
for vid_key in video_keys:
video_path = VIDEO_PATH.format(image_key=vid_key, episode_index=0)
video_path = VIDEO_PATH.format(video_key=vid_key, episode_index=0)
video_path = hub_api.hf_hub_download(
repo_id=repo_id, repo_type="dataset", local_dir=local_dir, filename=video_path
)
videos_info_dict[vid_key] = _get_video_info(video_path)
videos_info_dict["has_audio"] = (
videos_info_dict["has_audio"] or videos_info_dict[vid_key]["has_audio"]
)
return videos_info_dict
@ -359,8 +356,8 @@ def convert_dataset(
tasks_by_episodes: dict | None = None,
robot_config: dict | None = None,
):
v1_6_dir = local_dir / repo_id / V1_6
v2_0_dir = local_dir / repo_id / V2_0
v1_6_dir = local_dir / V1_6 / repo_id
v2_0_dir = local_dir / V2_0 / repo_id
v1_6_dir.mkdir(parents=True, exist_ok=True)
v2_0_dir.mkdir(parents=True, exist_ok=True)
@ -434,9 +431,10 @@ def convert_dataset(
"total_tasks": len(tasks),
"fps": metadata_v1_6["fps"],
"splits": {"train": f"0:{total_episodes}"},
"image_keys": keys["video"] + keys["image"],
"keys": keys["sequence"],
"shapes": {**image_shapes, **video_shapes, **sequence_shapes},
"video_keys": keys["video"],
"image_keys": keys["image"],
"shapes": {**sequence_shapes, **video_shapes, **image_shapes},
"names": names,
"videos": videos_info,
"episodes": episodes,