Fix tests
This commit is contained in:
parent
9ee8711504
commit
f43e5d07f5
|
@ -136,7 +136,10 @@ def append_jsonlines(data: dict, fpath: Path) -> None:
|
|||
|
||||
|
||||
def load_info(local_dir: Path) -> dict:
|
||||
return load_json(local_dir / INFO_PATH)
|
||||
info = load_json(local_dir / INFO_PATH)
|
||||
for ft in info["features"].values():
|
||||
ft["shape"] = tuple(ft["shape"])
|
||||
return info
|
||||
|
||||
|
||||
def load_stats(local_dir: Path) -> dict:
|
||||
|
|
|
@ -13,6 +13,7 @@ from functools import cache
|
|||
import cv2
|
||||
import torch
|
||||
import tqdm
|
||||
from deepdiff import DeepDiff
|
||||
from termcolor import colored
|
||||
|
||||
from lerobot.common.datasets.image_writer import safe_stop_image_writer
|
||||
|
@ -333,16 +334,19 @@ def sanity_check_dataset_name(repo_id, policy):
|
|||
)
|
||||
|
||||
|
||||
def sanity_check_dataset_robot_compatibility(dataset, robot, fps, use_videos):
|
||||
def sanity_check_dataset_robot_compatibility(
|
||||
dataset: LeRobotDataset, robot: Robot, fps: int, use_videos: bool
|
||||
) -> None:
|
||||
fields = [
|
||||
("robot_type", dataset.meta.info["robot_type"], robot.robot_type),
|
||||
("fps", dataset.meta.info["fps"], fps),
|
||||
("robot_type", dataset.meta.robot_type, robot.robot_type),
|
||||
("fps", dataset.fps, fps),
|
||||
("features", dataset.features, get_features_from_robot(robot, use_videos)),
|
||||
]
|
||||
|
||||
mismatches = []
|
||||
for field, dataset_value, present_value in fields:
|
||||
if dataset_value != present_value:
|
||||
diff = DeepDiff(dataset_value, present_value)
|
||||
if diff:
|
||||
mismatches.append(f"{field}: expected {present_value}, got {dataset_value}")
|
||||
|
||||
if mismatches:
|
||||
|
|
Loading…
Reference in New Issue