Update info.json format

Simon Alibert 2024-10-08 15:31:37 +02:00
parent 21ba4b5263
commit 2d75b93ba0
1 changed file with 16 additions and 18 deletions


@@ -18,7 +18,8 @@ Single-task dataset:
 python convert_dataset_16_to_20.py \
     --repo-id lerobot/aloha_sim_insertion_human_image \
     --task "Insert the peg into the socket." \
-    --robot-config lerobot/configs/robot/aloha.yaml
+    --robot-config lerobot/configs/robot/aloha.yaml \
+    --local-dir data
 ```
 ```bash
@@ -50,7 +51,7 @@ from huggingface_hub.errors import EntryNotFoundError
 from PIL import Image
 from safetensors.torch import load_file
-from lerobot.common.datasets.utils import create_branch
+from lerobot.common.datasets.utils import create_branch, flatten_dict, unflatten_dict
 from lerobot.common.utils.utils import init_hydra_config
 from lerobot.scripts.push_dataset_to_hub import push_dataset_card_to_hub
@@ -58,7 +59,7 @@ V1_6 = "v1.6"
 V2_0 = "v2.0"
 PARQUET_PATH = "data/train-{episode_index:05d}-of-{total_episodes:05d}.parquet"
-VIDEO_PATH = "videos/{image_key}_episode_{episode_index:06d}.mp4"
+VIDEO_PATH = "videos/{video_key}_episode_{episode_index:06d}.mp4"
 def parse_robot_config(config_path: Path, config_overrides: list[str] | None = None) -> tuple[str, dict]:
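
The placeholder rename above ({image_key} → {video_key}) propagates to every VIDEO_PATH.format(...) call site. A minimal sketch of the new template in use, with a hypothetical camera key chosen only for illustration:

```python
VIDEO_PATH = "videos/{video_key}_episode_{episode_index:06d}.mp4"

# "observation.images.top" is a hypothetical video key used for illustration.
path = VIDEO_PATH.format(video_key="observation.images.top", episode_index=0)
print(path)  # videos/observation.images.top_episode_000000.mp4
```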
@@ -104,17 +105,19 @@ def write_json(data: dict, fpath: Path) -> None:
 def convert_stats_to_json(input_dir: Path, output_dir: Path) -> None:
     safetensor_path = input_dir / "stats.safetensors"
     stats = load_file(safetensor_path)
-    serializable_stats = {key: value.tolist() for key, value in stats.items()}
+    serialized_stats = {key: value.tolist() for key, value in stats.items()}
+    serialized_stats = unflatten_dict(serialized_stats)
     json_path = output_dir / "stats.json"
     json_path.parent.mkdir(exist_ok=True, parents=True)
     with open(json_path, "w") as f:
-        json.dump(serializable_stats, f, indent=4)
+        json.dump(serialized_stats, f, indent=4)
     # Sanity check
     with open(json_path) as f:
         stats_json = json.load(f)
+    stats_json = flatten_dict(stats_json)
     stats_json = {key: torch.tensor(value) for key, value in stats_json.items()}
     for key in stats:
         torch.testing.assert_close(stats_json[key], stats[key])
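
The stats are now unflattened into a nested dict before being written to stats.json, then re-flattened for the sanity check. A rough sketch of that round-trip, assuming a simple separator-based flatten_dict/unflatten_dict pair; the real helpers are the ones imported from lerobot.common.datasets.utils and may differ in details such as the separator:

```python
def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict:
    # Collapse nested dicts into a single level with sep-joined keys.
    items = {}
    for k, v in d.items():
        key = f"{parent_key}{sep}{k}" if parent_key else k
        if isinstance(v, dict):
            items.update(flatten_dict(v, key, sep))
        else:
            items[key] = v
    return items


def unflatten_dict(d: dict, sep: str = "/") -> dict:
    # Rebuild the nested structure from sep-joined keys.
    out: dict = {}
    for flat_key, value in d.items():
        node = out
        *parents, leaf = flat_key.split(sep)
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return out


# Hypothetical flat stats entry, in the shape produced by value.tolist() above.
flat_stats = {"observation.state/mean": [0.0, 1.0], "observation.state/std": [1.0, 2.0]}
nested = unflatten_dict(flat_stats)  # {"observation.state": {"mean": [...], "std": [...]}}
assert flatten_dict(nested) == flat_stats
```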
@@ -270,6 +273,7 @@ def _get_video_info(video_path: Path | str) -> dict:
         "video.channels": pixel_channels,
         "video.codec": video_stream_info["codec_name"],
         "video.pix_fmt": video_stream_info["pix_fmt"],
+        "video.is_depth_map": False,
         **_get_audio_info(video_path),
     }
@@ -278,20 +282,13 @@ def _get_video_info(video_path: Path | str) -> dict:
 def get_videos_info(repo_id: str, local_dir: Path, video_keys: list[str]) -> dict:
     hub_api = HfApi()
-    videos_info_dict = {
-        "videos_path": VIDEO_PATH,
-        "has_audio": False,
-        "has_depth": False,
-    }
+    videos_info_dict = {"videos_path": VIDEO_PATH}
     for vid_key in video_keys:
-        video_path = VIDEO_PATH.format(image_key=vid_key, episode_index=0)
+        video_path = VIDEO_PATH.format(video_key=vid_key, episode_index=0)
         video_path = hub_api.hf_hub_download(
             repo_id=repo_id, repo_type="dataset", local_dir=local_dir, filename=video_path
         )
         videos_info_dict[vid_key] = _get_video_info(video_path)
-        videos_info_dict["has_audio"] = (
-            videos_info_dict["has_audio"] or videos_info_dict[vid_key]["has_audio"]
-        )
     return videos_info_dict
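
With this simplification the dataset-wide has_audio / has_depth flags are gone; audio and depth information now live per video key (has_audio via _get_audio_info, video.is_depth_map via _get_video_info). An illustrative shape of the returned dict, with a hypothetical key and placeholder values; only the field names visible in this diff are shown:

```python
# Illustrative only: the video key and the values are made up, and
# _get_video_info may return additional fields not shown here.
videos_info = {
    "videos_path": "videos/{video_key}_episode_{episode_index:06d}.mp4",
    "observation.images.top": {
        "video.channels": 3,
        "video.codec": "av1",
        "video.pix_fmt": "yuv420p",
        "video.is_depth_map": False,
        "has_audio": False,
    },
}
```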
@@ -359,8 +356,8 @@ def convert_dataset(
     tasks_by_episodes: dict | None = None,
     robot_config: dict | None = None,
 ):
-    v1_6_dir = local_dir / repo_id / V1_6
-    v2_0_dir = local_dir / repo_id / V2_0
+    v1_6_dir = local_dir / V1_6 / repo_id
+    v2_0_dir = local_dir / V2_0 / repo_id
     v1_6_dir.mkdir(parents=True, exist_ok=True)
     v2_0_dir.mkdir(parents=True, exist_ok=True)
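
Swapping the version and repo_id path levels groups converted datasets by format version first. A quick sketch of the resulting layout, reusing the --local-dir data and repo-id values from the example command at the top (the actual paths depend on the CLI arguments):

```python
from pathlib import Path

local_dir = Path("data")  # e.g. from --local-dir data
repo_id = "lerobot/aloha_sim_insertion_human_image"

# Old layout: data/lerobot/aloha_sim_insertion_human_image/v1.6
# New layout: data/v1.6/lerobot/aloha_sim_insertion_human_image
v1_6_dir = local_dir / "v1.6" / repo_id
v2_0_dir = local_dir / "v2.0" / repo_id
```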
@@ -434,9 +431,10 @@ def convert_dataset(
         "total_tasks": len(tasks),
         "fps": metadata_v1_6["fps"],
         "splits": {"train": f"0:{total_episodes}"},
-        "image_keys": keys["video"] + keys["image"],
         "keys": keys["sequence"],
-        "shapes": {**image_shapes, **video_shapes, **sequence_shapes},
+        "video_keys": keys["video"],
+        "image_keys": keys["image"],
+        "shapes": {**sequence_shapes, **video_shapes, **image_shapes},
         "names": names,
         "videos": videos_info,
         "episodes": episodes,