Fix tests
This commit is contained in:
parent
9ee8711504
commit
f43e5d07f5
|
@ -136,7 +136,10 @@ def append_jsonlines(data: dict, fpath: Path) -> None:
|
||||||
|
|
||||||
|
|
||||||
def load_info(local_dir: Path) -> dict:
|
def load_info(local_dir: Path) -> dict:
|
||||||
return load_json(local_dir / INFO_PATH)
|
info = load_json(local_dir / INFO_PATH)
|
||||||
|
for ft in info["features"].values():
|
||||||
|
ft["shape"] = tuple(ft["shape"])
|
||||||
|
return info
|
||||||
|
|
||||||
|
|
||||||
def load_stats(local_dir: Path) -> dict:
|
def load_stats(local_dir: Path) -> dict:
|
||||||
|
|
|
@ -13,6 +13,7 @@ from functools import cache
|
||||||
import cv2
|
import cv2
|
||||||
import torch
|
import torch
|
||||||
import tqdm
|
import tqdm
|
||||||
|
from deepdiff import DeepDiff
|
||||||
from termcolor import colored
|
from termcolor import colored
|
||||||
|
|
||||||
from lerobot.common.datasets.image_writer import safe_stop_image_writer
|
from lerobot.common.datasets.image_writer import safe_stop_image_writer
|
||||||
|
@ -333,16 +334,19 @@ def sanity_check_dataset_name(repo_id, policy):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def sanity_check_dataset_robot_compatibility(dataset, robot, fps, use_videos):
|
def sanity_check_dataset_robot_compatibility(
|
||||||
|
dataset: LeRobotDataset, robot: Robot, fps: int, use_videos: bool
|
||||||
|
) -> None:
|
||||||
fields = [
|
fields = [
|
||||||
("robot_type", dataset.meta.info["robot_type"], robot.robot_type),
|
("robot_type", dataset.meta.robot_type, robot.robot_type),
|
||||||
("fps", dataset.meta.info["fps"], fps),
|
("fps", dataset.fps, fps),
|
||||||
("features", dataset.features, get_features_from_robot(robot, use_videos)),
|
("features", dataset.features, get_features_from_robot(robot, use_videos)),
|
||||||
]
|
]
|
||||||
|
|
||||||
mismatches = []
|
mismatches = []
|
||||||
for field, dataset_value, present_value in fields:
|
for field, dataset_value, present_value in fields:
|
||||||
if dataset_value != present_value:
|
diff = DeepDiff(dataset_value, present_value)
|
||||||
|
if diff:
|
||||||
mismatches.append(f"{field}: expected {present_value}, got {dataset_value}")
|
mismatches.append(f"{field}: expected {present_value}, got {dataset_value}")
|
||||||
|
|
||||||
if mismatches:
|
if mismatches:
|
||||||
|
|
Loading…
Reference in New Issue