From f43e5d07f5602c5b7b0200cdbe0f068c374dc5a1 Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Wed, 20 Nov 2024 00:26:31 +0100 Subject: [PATCH] Fix tests --- lerobot/common/datasets/utils.py | 5 ++++- lerobot/common/robot_devices/control_utils.py | 12 ++++++++---- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py index dc43d112..0ad3dfae 100644 --- a/lerobot/common/datasets/utils.py +++ b/lerobot/common/datasets/utils.py @@ -136,7 +136,10 @@ def append_jsonlines(data: dict, fpath: Path) -> None: def load_info(local_dir: Path) -> dict: - return load_json(local_dir / INFO_PATH) + info = load_json(local_dir / INFO_PATH) + for ft in info["features"].values(): + ft["shape"] = tuple(ft["shape"]) + return info def load_stats(local_dir: Path) -> dict: diff --git a/lerobot/common/robot_devices/control_utils.py b/lerobot/common/robot_devices/control_utils.py index 3ede0c38..d55116aa 100644 --- a/lerobot/common/robot_devices/control_utils.py +++ b/lerobot/common/robot_devices/control_utils.py @@ -13,6 +13,7 @@ from functools import cache import cv2 import torch import tqdm +from deepdiff import DeepDiff from termcolor import colored from lerobot.common.datasets.image_writer import safe_stop_image_writer @@ -333,16 +334,19 @@ def sanity_check_dataset_name(repo_id, policy): ) -def sanity_check_dataset_robot_compatibility(dataset, robot, fps, use_videos): +def sanity_check_dataset_robot_compatibility( + dataset: LeRobotDataset, robot: Robot, fps: int, use_videos: bool +) -> None: fields = [ - ("robot_type", dataset.meta.info["robot_type"], robot.robot_type), - ("fps", dataset.meta.info["fps"], fps), + ("robot_type", dataset.meta.robot_type, robot.robot_type), + ("fps", dataset.fps, fps), ("features", dataset.features, get_features_from_robot(robot, use_videos)), ] mismatches = [] for field, dataset_value, present_value in fields: - if dataset_value != present_value: + diff = DeepDiff(dataset_value, present_value) + if diff: mismatches.append(f"{field}: expected {present_value}, got {dataset_value}") if mismatches: