Update info.json format
This commit is contained in:
parent
21ba4b5263
commit
2d75b93ba0
|
@ -18,7 +18,8 @@ Single-task dataset:
|
|||
python convert_dataset_16_to_20.py \
|
||||
--repo-id lerobot/aloha_sim_insertion_human_image \
|
||||
--task "Insert the peg into the socket." \
|
||||
--robot-config lerobot/configs/robot/aloha.yaml
|
||||
--robot-config lerobot/configs/robot/aloha.yaml \
|
||||
--local-dir data
|
||||
```
|
||||
|
||||
```bash
|
||||
|
@ -50,7 +51,7 @@ from huggingface_hub.errors import EntryNotFoundError
|
|||
from PIL import Image
|
||||
from safetensors.torch import load_file
|
||||
|
||||
from lerobot.common.datasets.utils import create_branch
|
||||
from lerobot.common.datasets.utils import create_branch, flatten_dict, unflatten_dict
|
||||
from lerobot.common.utils.utils import init_hydra_config
|
||||
from lerobot.scripts.push_dataset_to_hub import push_dataset_card_to_hub
|
||||
|
||||
|
@ -58,7 +59,7 @@ V1_6 = "v1.6"
|
|||
V2_0 = "v2.0"
|
||||
|
||||
PARQUET_PATH = "data/train-{episode_index:05d}-of-{total_episodes:05d}.parquet"
|
||||
VIDEO_PATH = "videos/{image_key}_episode_{episode_index:06d}.mp4"
|
||||
VIDEO_PATH = "videos/{video_key}_episode_{episode_index:06d}.mp4"
|
||||
|
||||
|
||||
def parse_robot_config(config_path: Path, config_overrides: list[str] | None = None) -> tuple[str, dict]:
|
||||
|
@ -104,17 +105,19 @@ def write_json(data: dict, fpath: Path) -> None:
|
|||
def convert_stats_to_json(input_dir: Path, output_dir: Path) -> None:
|
||||
safetensor_path = input_dir / "stats.safetensors"
|
||||
stats = load_file(safetensor_path)
|
||||
serializable_stats = {key: value.tolist() for key, value in stats.items()}
|
||||
serialized_stats = {key: value.tolist() for key, value in stats.items()}
|
||||
serialized_stats = unflatten_dict(serialized_stats)
|
||||
|
||||
json_path = output_dir / "stats.json"
|
||||
json_path.parent.mkdir(exist_ok=True, parents=True)
|
||||
with open(json_path, "w") as f:
|
||||
json.dump(serializable_stats, f, indent=4)
|
||||
json.dump(serialized_stats, f, indent=4)
|
||||
|
||||
# Sanity check
|
||||
with open(json_path) as f:
|
||||
stats_json = json.load(f)
|
||||
|
||||
stats_json = flatten_dict(stats_json)
|
||||
stats_json = {key: torch.tensor(value) for key, value in stats_json.items()}
|
||||
for key in stats:
|
||||
torch.testing.assert_close(stats_json[key], stats[key])
|
||||
|
@ -270,6 +273,7 @@ def _get_video_info(video_path: Path | str) -> dict:
|
|||
"video.channels": pixel_channels,
|
||||
"video.codec": video_stream_info["codec_name"],
|
||||
"video.pix_fmt": video_stream_info["pix_fmt"],
|
||||
"video.is_depth_map": False,
|
||||
**_get_audio_info(video_path),
|
||||
}
|
||||
|
||||
|
@ -278,20 +282,13 @@ def _get_video_info(video_path: Path | str) -> dict:
|
|||
|
||||
def get_videos_info(repo_id: str, local_dir: Path, video_keys: list[str]) -> dict:
|
||||
hub_api = HfApi()
|
||||
videos_info_dict = {
|
||||
"videos_path": VIDEO_PATH,
|
||||
"has_audio": False,
|
||||
"has_depth": False,
|
||||
}
|
||||
videos_info_dict = {"videos_path": VIDEO_PATH}
|
||||
for vid_key in video_keys:
|
||||
video_path = VIDEO_PATH.format(image_key=vid_key, episode_index=0)
|
||||
video_path = VIDEO_PATH.format(video_key=vid_key, episode_index=0)
|
||||
video_path = hub_api.hf_hub_download(
|
||||
repo_id=repo_id, repo_type="dataset", local_dir=local_dir, filename=video_path
|
||||
)
|
||||
videos_info_dict[vid_key] = _get_video_info(video_path)
|
||||
videos_info_dict["has_audio"] = (
|
||||
videos_info_dict["has_audio"] or videos_info_dict[vid_key]["has_audio"]
|
||||
)
|
||||
|
||||
return videos_info_dict
|
||||
|
||||
|
@ -359,8 +356,8 @@ def convert_dataset(
|
|||
tasks_by_episodes: dict | None = None,
|
||||
robot_config: dict | None = None,
|
||||
):
|
||||
v1_6_dir = local_dir / repo_id / V1_6
|
||||
v2_0_dir = local_dir / repo_id / V2_0
|
||||
v1_6_dir = local_dir / V1_6 / repo_id
|
||||
v2_0_dir = local_dir / V2_0 / repo_id
|
||||
v1_6_dir.mkdir(parents=True, exist_ok=True)
|
||||
v2_0_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
@ -434,9 +431,10 @@ def convert_dataset(
|
|||
"total_tasks": len(tasks),
|
||||
"fps": metadata_v1_6["fps"],
|
||||
"splits": {"train": f"0:{total_episodes}"},
|
||||
"image_keys": keys["video"] + keys["image"],
|
||||
"keys": keys["sequence"],
|
||||
"shapes": {**image_shapes, **video_shapes, **sequence_shapes},
|
||||
"video_keys": keys["video"],
|
||||
"image_keys": keys["image"],
|
||||
"shapes": {**sequence_shapes, **video_shapes, **image_shapes},
|
||||
"names": names,
|
||||
"videos": videos_info,
|
||||
"episodes": episodes,
|
||||
|
|
Loading…
Reference in New Issue