Fix tests

This commit is contained in:
Simon Alibert 2024-10-22 20:14:06 +02:00
parent ee52b8b782
commit b46db7ea73
2 changed files with 58 additions and 55 deletions

View File

@ -120,7 +120,7 @@ class ImageWriter:
wait(self.futures, timeout=timeout) wait(self.futures, timeout=timeout)
progress_bar.update(len(self.futures)) progress_bar.update(len(self.futures))
else: else:
self._stop_processes(self.processes, self.image_queue, timeout) self._stop_processes(timeout)
def _stop_processes(self, timeout) -> None: def _stop_processes(self, timeout) -> None:
for _ in self.processes: for _ in self.processes:

View File

@ -29,7 +29,6 @@ from unittest.mock import patch
import pytest import pytest
from lerobot.common.datasets.populate_dataset import add_frame, init_dataset
from lerobot.common.logger import Logger from lerobot.common.logger import Logger
from lerobot.common.policies.factory import make_policy from lerobot.common.policies.factory import make_policy
from lerobot.common.utils.utils import init_hydra_config from lerobot.common.utils.utils import init_hydra_config
@ -91,8 +90,9 @@ def test_record_without_cameras(tmpdir, request, robot_type, mock):
calibration_dir = Path(tmpdir) / robot_type calibration_dir = Path(tmpdir) / robot_type
overrides.append(f"calibration_dir={calibration_dir}") overrides.append(f"calibration_dir={calibration_dir}")
root = Path(tmpdir) / "data"
repo_id = "lerobot/debug" repo_id = "lerobot/debug"
root = Path(tmpdir) / "data" / repo_id
single_task = "Do something."
robot = make_robot(robot_type, overrides=overrides, mock=mock) robot = make_robot(robot_type, overrides=overrides, mock=mock)
record( record(
@ -100,6 +100,7 @@ def test_record_without_cameras(tmpdir, request, robot_type, mock):
fps=30, fps=30,
root=root, root=root,
repo_id=repo_id, repo_id=repo_id,
single_task=single_task,
warmup_time_s=1, warmup_time_s=1,
episode_time_s=1, episode_time_s=1,
num_episodes=2, num_episodes=2,
@ -129,17 +130,18 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
env_name = "koch_real" env_name = "koch_real"
policy_name = "act_koch_real" policy_name = "act_koch_real"
root = tmpdir / "data"
repo_id = "lerobot/debug" repo_id = "lerobot/debug"
eval_repo_id = "lerobot/eval_debug" root = tmpdir / "data" / repo_id
single_task = "Do something."
robot = make_robot(robot_type, overrides=overrides, mock=mock) robot = make_robot(robot_type, overrides=overrides, mock=mock)
dataset = record( dataset = record(
robot, robot,
root, root,
repo_id, repo_id,
fps=1, single_task,
warmup_time_s=1, fps=5,
warmup_time_s=0.5,
episode_time_s=1, episode_time_s=1,
reset_time_s=1, reset_time_s=1,
num_episodes=2, num_episodes=2,
@ -150,10 +152,10 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
display_cameras=False, display_cameras=False,
play_sounds=False, play_sounds=False,
) )
assert dataset.num_episodes == 2 assert dataset.total_episodes == 2
assert len(dataset) == 2 assert len(dataset) == 10
replay(robot, episode=0, fps=1, root=root, repo_id=repo_id, play_sounds=False) replay(robot, episode=0, fps=5, root=root, repo_id=repo_id, play_sounds=False)
# TODO(rcadene, aliberts): rethink this design # TODO(rcadene, aliberts): rethink this design
if robot_type == "aloha": if robot_type == "aloha":
@ -216,10 +218,14 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
else: else:
num_image_writer_processes = 0 num_image_writer_processes = 0
record( eval_repo_id = "lerobot/eval_debug"
eval_root = tmpdir / "data" / eval_repo_id
dataset = record(
robot, robot,
root, eval_root,
eval_repo_id, eval_repo_id,
single_task,
pretrained_policy_name_or_path, pretrained_policy_name_or_path,
warmup_time_s=1, warmup_time_s=1,
episode_time_s=1, episode_time_s=1,
@ -255,13 +261,15 @@ def test_resume_record(tmpdir, request, robot_type, mock):
robot = make_robot(robot_type, overrides=overrides, mock=mock) robot = make_robot(robot_type, overrides=overrides, mock=mock)
root = Path(tmpdir) / "data"
repo_id = "lerobot/debug" repo_id = "lerobot/debug"
root = Path(tmpdir) / "data" / repo_id
single_task = "Do something."
dataset = record( dataset = record(
robot, robot,
root, root,
repo_id, repo_id,
single_task,
fps=1, fps=1,
warmup_time_s=0, warmup_time_s=0,
episode_time_s=1, episode_time_s=1,
@ -274,32 +282,33 @@ def test_resume_record(tmpdir, request, robot_type, mock):
) )
assert len(dataset) == 1, "`dataset` should contain only 1 frame" assert len(dataset) == 1, "`dataset` should contain only 1 frame"
init_dataset_return_value = {} # init_dataset_return_value = {}
def wrapped_init_dataset(*args, **kwargs): # def wrapped_init_dataset(*args, **kwargs):
nonlocal init_dataset_return_value # nonlocal init_dataset_return_value
init_dataset_return_value = init_dataset(*args, **kwargs) # init_dataset_return_value = init_dataset(*args, **kwargs)
return init_dataset_return_value # return init_dataset_return_value
with patch("lerobot.scripts.control_robot.init_dataset", wraps=wrapped_init_dataset): # with patch("lerobot.scripts.control_robot.init_dataset", wraps=wrapped_init_dataset):
dataset = record( dataset = record(
robot, robot,
root, root,
repo_id, repo_id,
fps=1, single_task,
warmup_time_s=0, fps=1,
episode_time_s=1, warmup_time_s=0,
num_episodes=2, episode_time_s=1,
push_to_hub=False, num_episodes=2,
video=False, push_to_hub=False,
display_cameras=False, video=False,
play_sounds=False, display_cameras=False,
run_compute_stats=False, play_sounds=False,
) run_compute_stats=False,
assert len(dataset) == 2, "`dataset` should contain only 1 frame" )
assert ( assert len(dataset) == 2, "`dataset` should contain only 1 frame"
init_dataset_return_value["num_episodes"] == 2 # assert (
), "`init_dataset` should load the previous episode" # init_dataset_return_value["num_episodes"] == 2
# ), "`init_dataset` should load the previous episode"
@pytest.mark.parametrize("robot_type, mock", [("koch", True)]) @pytest.mark.parametrize("robot_type, mock", [("koch", True)])
@ -317,23 +326,22 @@ def test_record_with_event_rerecord_episode(tmpdir, request, robot_type, mock):
overrides = [] overrides = []
robot = make_robot(robot_type, overrides=overrides, mock=mock) robot = make_robot(robot_type, overrides=overrides, mock=mock)
with ( with patch("lerobot.scripts.control_robot.init_keyboard_listener") as mock_listener:
patch("lerobot.scripts.control_robot.init_keyboard_listener") as mock_listener,
patch("lerobot.common.robot_devices.control_utils.add_frame", wraps=add_frame) as mock_add_frame,
):
mock_events = {} mock_events = {}
mock_events["exit_early"] = True mock_events["exit_early"] = True
mock_events["rerecord_episode"] = True mock_events["rerecord_episode"] = True
mock_events["stop_recording"] = False mock_events["stop_recording"] = False
mock_listener.return_value = (None, mock_events) mock_listener.return_value = (None, mock_events)
root = Path(tmpdir) / "data"
repo_id = "lerobot/debug" repo_id = "lerobot/debug"
root = Path(tmpdir) / "data" / repo_id
single_task = "Do something."
dataset = record( dataset = record(
robot, robot,
root, root,
repo_id, repo_id,
single_task,
fps=1, fps=1,
warmup_time_s=0, warmup_time_s=0,
episode_time_s=1, episode_time_s=1,
@ -347,7 +355,6 @@ def test_record_with_event_rerecord_episode(tmpdir, request, robot_type, mock):
assert not mock_events["rerecord_episode"], "`rerecord_episode` wasn't properly reset to False" assert not mock_events["rerecord_episode"], "`rerecord_episode` wasn't properly reset to False"
assert not mock_events["exit_early"], "`exit_early` wasn't properly reset to False" assert not mock_events["exit_early"], "`exit_early` wasn't properly reset to False"
assert mock_add_frame.call_count == 2, "`add_frame` should have been called 2 times"
assert len(dataset) == 1, "`dataset` should contain only 1 frame" assert len(dataset) == 1, "`dataset` should contain only 1 frame"
@ -366,23 +373,22 @@ def test_record_with_event_exit_early(tmpdir, request, robot_type, mock):
overrides = [] overrides = []
robot = make_robot(robot_type, overrides=overrides, mock=mock) robot = make_robot(robot_type, overrides=overrides, mock=mock)
with ( with patch("lerobot.scripts.control_robot.init_keyboard_listener") as mock_listener:
patch("lerobot.scripts.control_robot.init_keyboard_listener") as mock_listener,
patch("lerobot.common.robot_devices.control_utils.add_frame", wraps=add_frame) as mock_add_frame,
):
mock_events = {} mock_events = {}
mock_events["exit_early"] = True mock_events["exit_early"] = True
mock_events["rerecord_episode"] = False mock_events["rerecord_episode"] = False
mock_events["stop_recording"] = False mock_events["stop_recording"] = False
mock_listener.return_value = (None, mock_events) mock_listener.return_value = (None, mock_events)
root = Path(tmpdir) / "data"
repo_id = "lerobot/debug" repo_id = "lerobot/debug"
root = Path(tmpdir) / "data" / repo_id
single_task = "Do something."
dataset = record( dataset = record(
robot, robot,
fps=2, fps=2,
root=root, root=root,
single_task=single_task,
repo_id=repo_id, repo_id=repo_id,
warmup_time_s=0, warmup_time_s=0,
episode_time_s=1, episode_time_s=1,
@ -395,7 +401,6 @@ def test_record_with_event_exit_early(tmpdir, request, robot_type, mock):
) )
assert not mock_events["exit_early"], "`exit_early` wasn't properly reset to False" assert not mock_events["exit_early"], "`exit_early` wasn't properly reset to False"
assert mock_add_frame.call_count == 1, "`add_frame` should have been called 1 time"
assert len(dataset) == 1, "`dataset` should contain only 1 frame" assert len(dataset) == 1, "`dataset` should contain only 1 frame"
@ -416,23 +421,22 @@ def test_record_with_event_stop_recording(tmpdir, request, robot_type, mock, num
overrides = [] overrides = []
robot = make_robot(robot_type, overrides=overrides, mock=mock) robot = make_robot(robot_type, overrides=overrides, mock=mock)
with ( with patch("lerobot.scripts.control_robot.init_keyboard_listener") as mock_listener:
patch("lerobot.scripts.control_robot.init_keyboard_listener") as mock_listener,
patch("lerobot.common.robot_devices.control_utils.add_frame", wraps=add_frame) as mock_add_frame,
):
mock_events = {} mock_events = {}
mock_events["exit_early"] = True mock_events["exit_early"] = True
mock_events["rerecord_episode"] = False mock_events["rerecord_episode"] = False
mock_events["stop_recording"] = True mock_events["stop_recording"] = True
mock_listener.return_value = (None, mock_events) mock_listener.return_value = (None, mock_events)
root = Path(tmpdir) / "data"
repo_id = "lerobot/debug" repo_id = "lerobot/debug"
root = Path(tmpdir) / "data" / repo_id
single_task = "Do something."
dataset = record( dataset = record(
robot, robot,
root, root,
repo_id, repo_id,
single_task=single_task,
fps=1, fps=1,
warmup_time_s=0, warmup_time_s=0,
episode_time_s=1, episode_time_s=1,
@ -446,5 +450,4 @@ def test_record_with_event_stop_recording(tmpdir, request, robot_type, mock, num
) )
assert not mock_events["exit_early"], "`exit_early` wasn't properly reset to False" assert not mock_events["exit_early"], "`exit_early` wasn't properly reset to False"
assert mock_add_frame.call_count == 1, "`add_frame` should have been called 1 time"
assert len(dataset) == 1, "`dataset` should contain only 1 frame" assert len(dataset) == 1, "`dataset` should contain only 1 frame"