Update info.json format
parent 21ba4b5263
commit 2d75b93ba0
@@ -18,7 +18,8 @@ Single-task dataset:
 python convert_dataset_16_to_20.py \
     --repo-id lerobot/aloha_sim_insertion_human_image \
     --task "Insert the peg into the socket." \
-    --robot-config lerobot/configs/robot/aloha.yaml
+    --robot-config lerobot/configs/robot/aloha.yaml \
+    --local-dir data
 ```
 
 ```bash
@@ -50,7 +51,7 @@ from huggingface_hub.errors import EntryNotFoundError
 from PIL import Image
 from safetensors.torch import load_file
 
-from lerobot.common.datasets.utils import create_branch
+from lerobot.common.datasets.utils import create_branch, flatten_dict, unflatten_dict
 from lerobot.common.utils.utils import init_hydra_config
 from lerobot.scripts.push_dataset_to_hub import push_dataset_card_to_hub
 
@@ -58,7 +59,7 @@ V1_6 = "v1.6"
 V2_0 = "v2.0"
 
 PARQUET_PATH = "data/train-{episode_index:05d}-of-{total_episodes:05d}.parquet"
-VIDEO_PATH = "videos/{image_key}_episode_{episode_index:06d}.mp4"
+VIDEO_PATH = "videos/{video_key}_episode_{episode_index:06d}.mp4"
 
 
 def parse_robot_config(config_path: Path, config_overrides: list[str] | None = None) -> tuple[str, dict]:
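The rename from `{image_key}` to `{video_key}` matters because the template is filled with keyword arguments elsewhere in the script (see `get_videos_info` below); a minimal sketch of the usage, with an illustrative key:

```python
VIDEO_PATH = "videos/{video_key}_episode_{episode_index:06d}.mp4"

# The keyword name must match the placeholder, hence the coordinated rename
# in get_videos_info(). "observation.images.top" is an illustrative key.
path = VIDEO_PATH.format(video_key="observation.images.top", episode_index=0)
assert path == "videos/observation.images.top_episode_000000.mp4"
```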
@@ -104,17 +105,19 @@ def write_json(data: dict, fpath: Path) -> None:
 def convert_stats_to_json(input_dir: Path, output_dir: Path) -> None:
     safetensor_path = input_dir / "stats.safetensors"
     stats = load_file(safetensor_path)
-    serializable_stats = {key: value.tolist() for key, value in stats.items()}
+    serialized_stats = {key: value.tolist() for key, value in stats.items()}
+    serialized_stats = unflatten_dict(serialized_stats)
 
     json_path = output_dir / "stats.json"
     json_path.parent.mkdir(exist_ok=True, parents=True)
     with open(json_path, "w") as f:
-        json.dump(serializable_stats, f, indent=4)
+        json.dump(serialized_stats, f, indent=4)
 
     # Sanity check
     with open(json_path) as f:
         stats_json = json.load(f)
 
+    stats_json = flatten_dict(stats_json)
     stats_json = {key: torch.tensor(value) for key, value in stats_json.items()}
     for key in stats:
         torch.testing.assert_close(stats_json[key], stats[key])
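For context, `unflatten_dict` nests the flat safetensors keys before the stats are dumped to JSON, and `flatten_dict` reverses that for the sanity check. A minimal self-contained sketch of what these helpers are assumed to do (lerobot's real implementations live in `lerobot.common.datasets.utils`; the `/` separator is an assumption):

```python
def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict:
    # Collapse nested dicts into single-level keys joined by `sep`.
    items = {}
    for k, v in d.items():
        key = f"{parent_key}{sep}{k}" if parent_key else k
        if isinstance(v, dict):
            items.update(flatten_dict(v, key, sep))
        else:
            items[key] = v
    return items


def unflatten_dict(d: dict, sep: str = "/") -> dict:
    # Rebuild the nested structure, one level per `sep` segment.
    out: dict = {}
    for key, value in d.items():
        node = out
        *parents, leaf = key.split(sep)
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return out


# Round trip: safetensors stores flat "observation.state/mean"-style keys,
# while stats.json nests them.
flat = {"observation.state/mean": [0.0, 0.0], "observation.state/std": [1.0, 1.0]}
nested = unflatten_dict(flat)
assert nested == {"observation.state": {"mean": [0.0, 0.0], "std": [1.0, 1.0]}}
assert flatten_dict(nested) == flat
```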
@@ -270,6 +273,7 @@ def _get_video_info(video_path: Path | str) -> dict:
         "video.channels": pixel_channels,
         "video.codec": video_stream_info["codec_name"],
         "video.pix_fmt": video_stream_info["pix_fmt"],
+        "video.is_depth_map": False,
         **_get_audio_info(video_path),
     }
 
@@ -278,20 +282,13 @@ def _get_video_info(video_path: Path | str) -> dict:
 
 def get_videos_info(repo_id: str, local_dir: Path, video_keys: list[str]) -> dict:
     hub_api = HfApi()
-    videos_info_dict = {
-        "videos_path": VIDEO_PATH,
-        "has_audio": False,
-        "has_depth": False,
-    }
+    videos_info_dict = {"videos_path": VIDEO_PATH}
     for vid_key in video_keys:
-        video_path = VIDEO_PATH.format(image_key=vid_key, episode_index=0)
+        video_path = VIDEO_PATH.format(video_key=vid_key, episode_index=0)
         video_path = hub_api.hf_hub_download(
             repo_id=repo_id, repo_type="dataset", local_dir=local_dir, filename=video_path
         )
         videos_info_dict[vid_key] = _get_video_info(video_path)
-        videos_info_dict["has_audio"] = (
-            videos_info_dict["has_audio"] or videos_info_dict[vid_key]["has_audio"]
-        )
 
     return videos_info_dict
 
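Audio and depth flags no longer live at the top level of the videos info: each video key now carries its own metadata dict built by `_get_video_info` (which merges in `_get_audio_info` and the new `"video.is_depth_map"` field). A hedged sketch of the assumed resulting shape, with illustrative key and values:

```python
# Illustrative only: field names come from the diff above, values are made up.
videos_info = {
    "videos_path": "videos/{video_key}_episode_{episode_index:06d}.mp4",
    "observation.images.top": {
        "video.channels": 3,
        "video.codec": "av1",
        "video.pix_fmt": "yuv420p",
        "video.is_depth_map": False,
        "has_audio": False,  # merged in via **_get_audio_info(video_path)
    },
}
```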
@@ -359,8 +356,8 @@ def convert_dataset(
     tasks_by_episodes: dict | None = None,
     robot_config: dict | None = None,
 ):
-    v1_6_dir = local_dir / repo_id / V1_6
-    v2_0_dir = local_dir / repo_id / V2_0
+    v1_6_dir = local_dir / V1_6 / repo_id
+    v2_0_dir = local_dir / V2_0 / repo_id
     v1_6_dir.mkdir(parents=True, exist_ok=True)
     v2_0_dir.mkdir(parents=True, exist_ok=True)
 
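Swapping the path components groups the local working tree by version first, then repo; a quick illustration, assuming the `--local-dir data` value from the usage example above:

```python
from pathlib import Path

local_dir, repo_id = Path("data"), "lerobot/aloha_sim_insertion_human_image"
# before: data/lerobot/aloha_sim_insertion_human_image/v1.6
# after:  data/v1.6/lerobot/aloha_sim_insertion_human_image
print(local_dir / "v1.6" / repo_id)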
@@ -434,9 +431,10 @@ def convert_dataset(
         "total_tasks": len(tasks),
         "fps": metadata_v1_6["fps"],
         "splits": {"train": f"0:{total_episodes}"},
-        "image_keys": keys["video"] + keys["image"],
         "keys": keys["sequence"],
-        "shapes": {**image_shapes, **video_shapes, **sequence_shapes},
+        "video_keys": keys["video"],
+        "image_keys": keys["image"],
+        "shapes": {**sequence_shapes, **video_shapes, **image_shapes},
         "names": names,
         "videos": videos_info,
         "episodes": episodes,
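Taken together, the commit splits the old combined `image_keys` into separate `video_keys` and `image_keys` lists. A hedged Python rendering of the new `info.json` top level, restricted to keys visible in this hunk, with made-up values:

```python
# Illustrative sketch of the v2.0 info.json layout after this commit.
info = {
    "total_tasks": 1,
    "fps": 50,
    "splits": {"train": "0:50"},
    "keys": ["observation.state", "action"],   # sequence keys
    "video_keys": ["observation.images.top"],  # split out from the old
    "image_keys": [],                          # combined "image_keys" list
    "shapes": {"observation.state": 14, "action": 14},
    "names": {},    # structure not shown in this diff
    "videos": {},   # see get_videos_info() above
    "episodes": {}, # structure not shown in this diff
}
```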