Fix tests

This commit is contained in:
Simon Alibert 2024-11-20 00:26:31 +01:00
parent 9ee8711504
commit f43e5d07f5
2 changed files with 12 additions and 5 deletions

View File

@ -136,7 +136,10 @@ def append_jsonlines(data: dict, fpath: Path) -> None:
def load_info(local_dir: Path) -> dict:
return load_json(local_dir / INFO_PATH)
info = load_json(local_dir / INFO_PATH)
for ft in info["features"].values():
ft["shape"] = tuple(ft["shape"])
return info
def load_stats(local_dir: Path) -> dict:

View File

@ -13,6 +13,7 @@ from functools import cache
import cv2
import torch
import tqdm
from deepdiff import DeepDiff
from termcolor import colored
from lerobot.common.datasets.image_writer import safe_stop_image_writer
@ -333,16 +334,19 @@ def sanity_check_dataset_name(repo_id, policy):
)
def sanity_check_dataset_robot_compatibility(dataset, robot, fps, use_videos):
def sanity_check_dataset_robot_compatibility(
dataset: LeRobotDataset, robot: Robot, fps: int, use_videos: bool
) -> None:
fields = [
("robot_type", dataset.meta.info["robot_type"], robot.robot_type),
("fps", dataset.meta.info["fps"], fps),
("robot_type", dataset.meta.robot_type, robot.robot_type),
("fps", dataset.fps, fps),
("features", dataset.features, get_features_from_robot(robot, use_videos)),
]
mismatches = []
for field, dataset_value, present_value in fields:
if dataset_value != present_value:
diff = DeepDiff(dataset_value, present_value)
if diff:
mismatches.append(f"{field}: expected {present_value}, got {dataset_value}")
if mismatches: