Fix unit tests

This commit is contained in:
Remi Cadene 2024-10-13 18:31:34 +02:00
parent d02e204e10
commit eed7b55fe3
6 changed files with 94 additions and 38 deletions

View File

@ -296,7 +296,6 @@ def add_frame(dataset, observation, action):
ep_dict[key].append(frame_info) ep_dict[key].append(frame_info)
dataset["image_keys"] = img_keys # used for video generation
dataset["current_frame_index"] += 1 dataset["current_frame_index"] += 1
@ -389,9 +388,6 @@ def from_dataset_to_lerobot_dataset(dataset, play_sounds):
image_keys = [key for key in data_dict if "image" in key] image_keys = [key for key in data_dict if "image" in key]
encode_videos(dataset, image_keys, play_sounds) encode_videos(dataset, image_keys, play_sounds)
total_frames = data_dict["frame_index"].shape[0]
data_dict["index"] = torch.arange(0, total_frames, 1)
hf_dataset = to_hf_dataset(data_dict, video) hf_dataset = to_hf_dataset(data_dict, video)
episode_data_index = calculate_episode_data_index(hf_dataset) episode_data_index = calculate_episode_data_index(hf_dataset)

View File

@ -200,28 +200,6 @@ def to_hf_dataset(data_dict, video) -> Dataset:
features["next.done"] = Value(dtype="bool", id=None) features["next.done"] = Value(dtype="bool", id=None)
features["index"] = Value(dtype="int64", id=None) features["index"] = Value(dtype="int64", id=None)
for key in data_dict:
if isinstance(data_dict[key], list):
print(key, len(data_dict[key]))
elif isinstance(data_dict[key], torch.Tensor):
print(key, data_dict[key].shape)
else:
print(key, data_dict[key])
data_dict["episode_index"] = data_dict["episode_index"].tolist()
data_dict["frame_index"] = data_dict["frame_index"].tolist()
data_dict["timestamp"] = data_dict["timestamp"].tolist()
data_dict["next.done"] = data_dict["next.done"].tolist()
data_dict["index"] = data_dict["index"].tolist()
for key in data_dict:
if isinstance(data_dict[key], list):
print(key, len(data_dict[key]))
elif isinstance(data_dict[key], torch.Tensor):
print(key, data_dict[key].shape)
else:
print(key, data_dict[key])
hf_dataset = Dataset.from_dict(data_dict, features=Features(features)) hf_dataset = Dataset.from_dict(data_dict, features=Features(features))
hf_dataset.set_transform(hf_transform_to_torch) hf_dataset.set_transform(hf_transform_to_torch)
return hf_dataset return hf_dataset

View File

@ -7,6 +7,7 @@ import logging
import time import time
import traceback import traceback
from contextlib import nullcontext from contextlib import nullcontext
from copy import copy
from functools import cache from functools import cache
import cv2 import cv2
@ -90,6 +91,7 @@ def has_method(_object: object, method_name: str):
def predict_action(observation, policy, device, use_amp): def predict_action(observation, policy, device, use_amp):
observation = copy(observation)
with ( with (
torch.inference_mode(), torch.inference_mode(),
torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(), torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(),
@ -297,7 +299,9 @@ def stop_recording(robot, listener, display_cameras):
def sanity_check_dataset_name(repo_id, policy): def sanity_check_dataset_name(repo_id, policy):
_, dataset_name = repo_id.split("/") _, dataset_name = repo_id.split("/")
if dataset_name.startswith("eval_") and policy is None: # either repo_id doesnt start with "eval_" and there is no policy
# or repo_id starts with "eval_" and there is a policy
if dataset_name.startswith("eval_") == (policy is None):
raise ValueError( raise ValueError(
f"Your dataset name begins by 'eval_' ({dataset_name}) but no policy is provided ({policy})." f"Your dataset name begins by 'eval_' ({dataset_name}) but no policy is provided ({policy})."
) )

View File

@ -201,11 +201,11 @@ def teleoperate(robot: Robot, fps: int | None = None, teleop_time_s: float | Non
@safe_disconnect @safe_disconnect
def record( def record(
robot: Robot, robot: Robot,
root: str,
repo_id: str,
pretrained_policy_name_or_path: str | None = None, pretrained_policy_name_or_path: str | None = None,
policy_overrides: List[str] | None = None, policy_overrides: List[str] | None = None,
fps: int | None = None, fps: int | None = None,
root="data",
repo_id="lerobot/debug",
warmup_time_s=2, warmup_time_s=2,
episode_time_s=10, episode_time_s=10,
reset_time_s=5, reset_time_s=5,

View File

@ -29,7 +29,7 @@ from unittest.mock import patch
import pytest import pytest
from lerobot.common.datasets.populate_dataset import add_frame from lerobot.common.datasets.populate_dataset import add_frame, init_dataset
from lerobot.common.logger import Logger from lerobot.common.logger import Logger
from lerobot.common.policies.factory import make_policy from lerobot.common.policies.factory import make_policy
from lerobot.common.utils.utils import init_hydra_config from lerobot.common.utils.utils import init_hydra_config
@ -131,13 +131,14 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
root = tmpdir / "data" root = tmpdir / "data"
repo_id = "lerobot/debug" repo_id = "lerobot/debug"
eval_repo_id = "lerobot/eval_debug"
robot = make_robot(robot_type, overrides=overrides, mock=mock) robot = make_robot(robot_type, overrides=overrides, mock=mock)
dataset = record( dataset = record(
robot, robot,
fps=30, root,
root=root, repo_id,
repo_id=repo_id, fps=1,
warmup_time_s=1, warmup_time_s=1,
episode_time_s=1, episode_time_s=1,
reset_time_s=1, reset_time_s=1,
@ -149,8 +150,10 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
display_cameras=False, display_cameras=False,
play_sounds=False, play_sounds=False,
) )
assert dataset.num_episodes == 2
assert len(dataset) == 2
replay(robot, episode=0, fps=30, root=root, repo_id=repo_id, play_sounds=False) replay(robot, episode=0, fps=1, root=root, repo_id=repo_id, play_sounds=False)
# TODO(rcadene, aliberts): rethink this design # TODO(rcadene, aliberts): rethink this design
if robot_type == "aloha": if robot_type == "aloha":
@ -171,6 +174,9 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
if robot_type == "koch_bimanual": if robot_type == "koch_bimanual":
overrides += ["env.state_dim=12", "env.action_dim=12"] overrides += ["env.state_dim=12", "env.action_dim=12"]
overrides += ["wandb.enable=false"]
overrides += ["env.fps=1"]
cfg = init_hydra_config( cfg = init_hydra_config(
DEFAULT_CONFIG_PATH, DEFAULT_CONFIG_PATH,
overrides=overrides, overrides=overrides,
@ -212,6 +218,8 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
record( record(
robot, robot,
root,
eval_repo_id,
pretrained_policy_name_or_path, pretrained_policy_name_or_path,
warmup_time_s=1, warmup_time_s=1,
episode_time_s=1, episode_time_s=1,
@ -225,9 +233,75 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
num_image_writer_processes=num_image_writer_processes, num_image_writer_processes=num_image_writer_processes,
) )
assert dataset.num_episodes == 2
assert len(dataset) == 2
del robot del robot
@pytest.mark.parametrize("robot_type, mock", [("koch", True)])
@require_robot
def test_resume_record(tmpdir, request, robot_type, mock):
if mock and robot_type != "aloha":
request.getfixturevalue("patch_builtins_input")
# Create an empty calibration directory to trigger manual calibration
# and avoid writing calibration files in user .cache/calibration folder
calibration_dir = tmpdir / robot_type
overrides = [f"calibration_dir={calibration_dir}"]
else:
# Use the default .cache/calibration folder when mock=False or for aloha
overrides = []
robot = make_robot(robot_type, overrides=overrides, mock=mock)
root = Path(tmpdir) / "data"
repo_id = "lerobot/debug"
dataset = record(
robot,
root,
repo_id,
fps=1,
warmup_time_s=0,
episode_time_s=1,
num_episodes=1,
push_to_hub=False,
video=False,
display_cameras=False,
play_sounds=False,
run_compute_stats=False,
)
assert len(dataset) == 1, "`dataset` should contain only 1 frame"
init_dataset_return_value = {}
def wrapped_init_dataset(*args, **kwargs):
nonlocal init_dataset_return_value
init_dataset_return_value = init_dataset(*args, **kwargs)
return init_dataset_return_value
with patch("lerobot.scripts.control_robot.init_dataset", wraps=wrapped_init_dataset):
dataset = record(
robot,
root,
repo_id,
fps=1,
warmup_time_s=0,
episode_time_s=1,
num_episodes=2,
push_to_hub=False,
video=False,
display_cameras=False,
play_sounds=False,
run_compute_stats=False,
)
assert len(dataset) == 2, "`dataset` should contain only 1 frame"
assert (
init_dataset_return_value["num_episodes"] == 2
), "`init_dataset` should load the previous episode"
@pytest.mark.parametrize("robot_type, mock", [("koch", True)]) @pytest.mark.parametrize("robot_type, mock", [("koch", True)])
@require_robot @require_robot
def test_record_with_event_rerecord_episode(tmpdir, request, robot_type, mock): def test_record_with_event_rerecord_episode(tmpdir, request, robot_type, mock):
@ -258,9 +332,9 @@ def test_record_with_event_rerecord_episode(tmpdir, request, robot_type, mock):
dataset = record( dataset = record(
robot, robot,
root,
repo_id,
fps=1, fps=1,
root=root,
repo_id=repo_id,
warmup_time_s=0, warmup_time_s=0,
episode_time_s=1, episode_time_s=1,
num_episodes=1, num_episodes=1,
@ -268,6 +342,7 @@ def test_record_with_event_rerecord_episode(tmpdir, request, robot_type, mock):
video=False, video=False,
display_cameras=False, display_cameras=False,
play_sounds=False, play_sounds=False,
run_compute_stats=False,
) )
assert not mock_events["rerecord_episode"], "`rerecord_episode` wasn't properly reset to False" assert not mock_events["rerecord_episode"], "`rerecord_episode` wasn't properly reset to False"
@ -316,6 +391,7 @@ def test_record_with_event_exit_early(tmpdir, request, robot_type, mock):
video=False, video=False,
display_cameras=False, display_cameras=False,
play_sounds=False, play_sounds=False,
run_compute_stats=False,
) )
assert not mock_events["exit_early"], "`exit_early` wasn't properly reset to False" assert not mock_events["exit_early"], "`exit_early` wasn't properly reset to False"
@ -355,9 +431,9 @@ def test_record_with_event_stop_recording(tmpdir, request, robot_type, mock, num
dataset = record( dataset = record(
robot, robot,
root,
repo_id,
fps=1, fps=1,
root=root,
repo_id=repo_id,
warmup_time_s=0, warmup_time_s=0,
episode_time_s=1, episode_time_s=1,
num_episodes=2, num_episodes=2,
@ -365,6 +441,7 @@ def test_record_with_event_stop_recording(tmpdir, request, robot_type, mock, num
video=False, video=False,
display_cameras=False, display_cameras=False,
play_sounds=False, play_sounds=False,
run_compute_stats=False,
num_image_writer_processes=num_image_writer_processes, num_image_writer_processes=num_image_writer_processes,
) )

View File

@ -127,6 +127,7 @@ def test_robot(tmpdir, request, robot_type, mock):
# TODO(rcadene): skipping image for now as it's challenging to assess equality between two consecutive frames # TODO(rcadene): skipping image for now as it's challenging to assess equality between two consecutive frames
continue continue
assert torch.allclose(captured_observation[name], observation[name], atol=1) assert torch.allclose(captured_observation[name], observation[name], atol=1)
assert captured_observation[name].shape == observation[name].shape
# Test send_action can run # Test send_action can run
robot.send_action(action["action"]) robot.send_action(action["action"])