From eed7b55fe3451c48a99c74177e93190e1ffb578d Mon Sep 17 00:00:00 2001 From: Remi Cadene Date: Sun, 13 Oct 2024 18:31:34 +0200 Subject: [PATCH] Fix unit tests --- lerobot/common/datasets/populate_dataset.py | 4 - .../push_dataset_to_hub/aloha_hdf5_format.py | 22 ----- lerobot/common/robot_devices/control_robot.py | 6 +- lerobot/scripts/control_robot.py | 4 +- tests/test_control_robot.py | 95 +++++++++++++++++-- tests/test_robots.py | 1 + 6 files changed, 94 insertions(+), 38 deletions(-) diff --git a/lerobot/common/datasets/populate_dataset.py b/lerobot/common/datasets/populate_dataset.py index 0afc5fd0..e4fd6c1c 100644 --- a/lerobot/common/datasets/populate_dataset.py +++ b/lerobot/common/datasets/populate_dataset.py @@ -296,7 +296,6 @@ def add_frame(dataset, observation, action): ep_dict[key].append(frame_info) - dataset["image_keys"] = img_keys # used for video generation dataset["current_frame_index"] += 1 @@ -389,9 +388,6 @@ def from_dataset_to_lerobot_dataset(dataset, play_sounds): image_keys = [key for key in data_dict if "image" in key] encode_videos(dataset, image_keys, play_sounds) - total_frames = data_dict["frame_index"].shape[0] - data_dict["index"] = torch.arange(0, total_frames, 1) - hf_dataset = to_hf_dataset(data_dict, video) episode_data_index = calculate_episode_data_index(hf_dataset) diff --git a/lerobot/common/datasets/push_dataset_to_hub/aloha_hdf5_format.py b/lerobot/common/datasets/push_dataset_to_hub/aloha_hdf5_format.py index b5efc953..52c4bba3 100644 --- a/lerobot/common/datasets/push_dataset_to_hub/aloha_hdf5_format.py +++ b/lerobot/common/datasets/push_dataset_to_hub/aloha_hdf5_format.py @@ -200,28 +200,6 @@ def to_hf_dataset(data_dict, video) -> Dataset: features["next.done"] = Value(dtype="bool", id=None) features["index"] = Value(dtype="int64", id=None) - for key in data_dict: - if isinstance(data_dict[key], list): - print(key, len(data_dict[key])) - elif isinstance(data_dict[key], torch.Tensor): - print(key, data_dict[key].shape) - else: - print(key, data_dict[key]) - - data_dict["episode_index"] = data_dict["episode_index"].tolist() - data_dict["frame_index"] = data_dict["frame_index"].tolist() - data_dict["timestamp"] = data_dict["timestamp"].tolist() - data_dict["next.done"] = data_dict["next.done"].tolist() - data_dict["index"] = data_dict["index"].tolist() - - for key in data_dict: - if isinstance(data_dict[key], list): - print(key, len(data_dict[key])) - elif isinstance(data_dict[key], torch.Tensor): - print(key, data_dict[key].shape) - else: - print(key, data_dict[key]) - hf_dataset = Dataset.from_dict(data_dict, features=Features(features)) hf_dataset.set_transform(hf_transform_to_torch) return hf_dataset diff --git a/lerobot/common/robot_devices/control_robot.py b/lerobot/common/robot_devices/control_robot.py index f23a90b2..f7c91c75 100644 --- a/lerobot/common/robot_devices/control_robot.py +++ b/lerobot/common/robot_devices/control_robot.py @@ -7,6 +7,7 @@ import logging import time import traceback from contextlib import nullcontext +from copy import copy from functools import cache import cv2 @@ -90,6 +91,7 @@ def has_method(_object: object, method_name: str): def predict_action(observation, policy, device, use_amp): + observation = copy(observation) with ( torch.inference_mode(), torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(), @@ -297,7 +299,9 @@ def stop_recording(robot, listener, display_cameras): def sanity_check_dataset_name(repo_id, policy): _, dataset_name = repo_id.split("/") - if dataset_name.startswith("eval_") and policy is None: + # either repo_id doesnt start with "eval_" and there is no policy + # or repo_id starts with "eval_" and there is a policy + if dataset_name.startswith("eval_") == (policy is None): raise ValueError( f"Your dataset name begins by 'eval_' ({dataset_name}) but no policy is provided ({policy})." ) diff --git a/lerobot/scripts/control_robot.py b/lerobot/scripts/control_robot.py index 2f9f8a73..47e1f2b6 100644 --- a/lerobot/scripts/control_robot.py +++ b/lerobot/scripts/control_robot.py @@ -201,11 +201,11 @@ def teleoperate(robot: Robot, fps: int | None = None, teleop_time_s: float | Non @safe_disconnect def record( robot: Robot, + root: str, + repo_id: str, pretrained_policy_name_or_path: str | None = None, policy_overrides: List[str] | None = None, fps: int | None = None, - root="data", - repo_id="lerobot/debug", warmup_time_s=2, episode_time_s=10, reset_time_s=5, diff --git a/tests/test_control_robot.py b/tests/test_control_robot.py index 97379766..19a83780 100644 --- a/tests/test_control_robot.py +++ b/tests/test_control_robot.py @@ -29,7 +29,7 @@ from unittest.mock import patch import pytest -from lerobot.common.datasets.populate_dataset import add_frame +from lerobot.common.datasets.populate_dataset import add_frame, init_dataset from lerobot.common.logger import Logger from lerobot.common.policies.factory import make_policy from lerobot.common.utils.utils import init_hydra_config @@ -131,13 +131,14 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock): root = tmpdir / "data" repo_id = "lerobot/debug" + eval_repo_id = "lerobot/eval_debug" robot = make_robot(robot_type, overrides=overrides, mock=mock) dataset = record( robot, - fps=30, - root=root, - repo_id=repo_id, + root, + repo_id, + fps=1, warmup_time_s=1, episode_time_s=1, reset_time_s=1, @@ -149,8 +150,10 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock): display_cameras=False, play_sounds=False, ) + assert dataset.num_episodes == 2 + assert len(dataset) == 2 - replay(robot, episode=0, fps=30, root=root, repo_id=repo_id, play_sounds=False) + replay(robot, episode=0, fps=1, root=root, repo_id=repo_id, play_sounds=False) # TODO(rcadene, aliberts): rethink this design if robot_type == "aloha": @@ -171,6 +174,9 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock): if robot_type == "koch_bimanual": overrides += ["env.state_dim=12", "env.action_dim=12"] + overrides += ["wandb.enable=false"] + overrides += ["env.fps=1"] + cfg = init_hydra_config( DEFAULT_CONFIG_PATH, overrides=overrides, @@ -212,6 +218,8 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock): record( robot, + root, + eval_repo_id, pretrained_policy_name_or_path, warmup_time_s=1, episode_time_s=1, @@ -225,9 +233,75 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock): num_image_writer_processes=num_image_writer_processes, ) + assert dataset.num_episodes == 2 + assert len(dataset) == 2 + del robot +@pytest.mark.parametrize("robot_type, mock", [("koch", True)]) +@require_robot +def test_resume_record(tmpdir, request, robot_type, mock): + if mock and robot_type != "aloha": + request.getfixturevalue("patch_builtins_input") + + # Create an empty calibration directory to trigger manual calibration + # and avoid writing calibration files in user .cache/calibration folder + calibration_dir = tmpdir / robot_type + overrides = [f"calibration_dir={calibration_dir}"] + else: + # Use the default .cache/calibration folder when mock=False or for aloha + overrides = [] + + robot = make_robot(robot_type, overrides=overrides, mock=mock) + + root = Path(tmpdir) / "data" + repo_id = "lerobot/debug" + + dataset = record( + robot, + root, + repo_id, + fps=1, + warmup_time_s=0, + episode_time_s=1, + num_episodes=1, + push_to_hub=False, + video=False, + display_cameras=False, + play_sounds=False, + run_compute_stats=False, + ) + assert len(dataset) == 1, "`dataset` should contain only 1 frame" + + init_dataset_return_value = {} + + def wrapped_init_dataset(*args, **kwargs): + nonlocal init_dataset_return_value + init_dataset_return_value = init_dataset(*args, **kwargs) + return init_dataset_return_value + + with patch("lerobot.scripts.control_robot.init_dataset", wraps=wrapped_init_dataset): + dataset = record( + robot, + root, + repo_id, + fps=1, + warmup_time_s=0, + episode_time_s=1, + num_episodes=2, + push_to_hub=False, + video=False, + display_cameras=False, + play_sounds=False, + run_compute_stats=False, + ) + assert len(dataset) == 2, "`dataset` should contain only 1 frame" + assert ( + init_dataset_return_value["num_episodes"] == 2 + ), "`init_dataset` should load the previous episode" + + @pytest.mark.parametrize("robot_type, mock", [("koch", True)]) @require_robot def test_record_with_event_rerecord_episode(tmpdir, request, robot_type, mock): @@ -258,9 +332,9 @@ def test_record_with_event_rerecord_episode(tmpdir, request, robot_type, mock): dataset = record( robot, + root, + repo_id, fps=1, - root=root, - repo_id=repo_id, warmup_time_s=0, episode_time_s=1, num_episodes=1, @@ -268,6 +342,7 @@ def test_record_with_event_rerecord_episode(tmpdir, request, robot_type, mock): video=False, display_cameras=False, play_sounds=False, + run_compute_stats=False, ) assert not mock_events["rerecord_episode"], "`rerecord_episode` wasn't properly reset to False" @@ -316,6 +391,7 @@ def test_record_with_event_exit_early(tmpdir, request, robot_type, mock): video=False, display_cameras=False, play_sounds=False, + run_compute_stats=False, ) assert not mock_events["exit_early"], "`exit_early` wasn't properly reset to False" @@ -355,9 +431,9 @@ def test_record_with_event_stop_recording(tmpdir, request, robot_type, mock, num dataset = record( robot, + root, + repo_id, fps=1, - root=root, - repo_id=repo_id, warmup_time_s=0, episode_time_s=1, num_episodes=2, @@ -365,6 +441,7 @@ def test_record_with_event_stop_recording(tmpdir, request, robot_type, mock, num video=False, display_cameras=False, play_sounds=False, + run_compute_stats=False, num_image_writer_processes=num_image_writer_processes, ) diff --git a/tests/test_robots.py b/tests/test_robots.py index 72f0c944..13ad8c45 100644 --- a/tests/test_robots.py +++ b/tests/test_robots.py @@ -127,6 +127,7 @@ def test_robot(tmpdir, request, robot_type, mock): # TODO(rcadene): skipping image for now as it's challenging to assess equality between two consecutive frames continue assert torch.allclose(captured_observation[name], observation[name], atol=1) + assert captured_observation[name].shape == observation[name].shape # Test send_action can run robot.send_action(action["action"])