Fix unit tests
This commit is contained in:
parent
d02e204e10
commit
eed7b55fe3
|
@ -296,7 +296,6 @@ def add_frame(dataset, observation, action):
|
||||||
|
|
||||||
ep_dict[key].append(frame_info)
|
ep_dict[key].append(frame_info)
|
||||||
|
|
||||||
dataset["image_keys"] = img_keys # used for video generation
|
|
||||||
dataset["current_frame_index"] += 1
|
dataset["current_frame_index"] += 1
|
||||||
|
|
||||||
|
|
||||||
|
@ -389,9 +388,6 @@ def from_dataset_to_lerobot_dataset(dataset, play_sounds):
|
||||||
image_keys = [key for key in data_dict if "image" in key]
|
image_keys = [key for key in data_dict if "image" in key]
|
||||||
encode_videos(dataset, image_keys, play_sounds)
|
encode_videos(dataset, image_keys, play_sounds)
|
||||||
|
|
||||||
total_frames = data_dict["frame_index"].shape[0]
|
|
||||||
data_dict["index"] = torch.arange(0, total_frames, 1)
|
|
||||||
|
|
||||||
hf_dataset = to_hf_dataset(data_dict, video)
|
hf_dataset = to_hf_dataset(data_dict, video)
|
||||||
episode_data_index = calculate_episode_data_index(hf_dataset)
|
episode_data_index = calculate_episode_data_index(hf_dataset)
|
||||||
|
|
||||||
|
|
|
@ -200,28 +200,6 @@ def to_hf_dataset(data_dict, video) -> Dataset:
|
||||||
features["next.done"] = Value(dtype="bool", id=None)
|
features["next.done"] = Value(dtype="bool", id=None)
|
||||||
features["index"] = Value(dtype="int64", id=None)
|
features["index"] = Value(dtype="int64", id=None)
|
||||||
|
|
||||||
for key in data_dict:
|
|
||||||
if isinstance(data_dict[key], list):
|
|
||||||
print(key, len(data_dict[key]))
|
|
||||||
elif isinstance(data_dict[key], torch.Tensor):
|
|
||||||
print(key, data_dict[key].shape)
|
|
||||||
else:
|
|
||||||
print(key, data_dict[key])
|
|
||||||
|
|
||||||
data_dict["episode_index"] = data_dict["episode_index"].tolist()
|
|
||||||
data_dict["frame_index"] = data_dict["frame_index"].tolist()
|
|
||||||
data_dict["timestamp"] = data_dict["timestamp"].tolist()
|
|
||||||
data_dict["next.done"] = data_dict["next.done"].tolist()
|
|
||||||
data_dict["index"] = data_dict["index"].tolist()
|
|
||||||
|
|
||||||
for key in data_dict:
|
|
||||||
if isinstance(data_dict[key], list):
|
|
||||||
print(key, len(data_dict[key]))
|
|
||||||
elif isinstance(data_dict[key], torch.Tensor):
|
|
||||||
print(key, data_dict[key].shape)
|
|
||||||
else:
|
|
||||||
print(key, data_dict[key])
|
|
||||||
|
|
||||||
hf_dataset = Dataset.from_dict(data_dict, features=Features(features))
|
hf_dataset = Dataset.from_dict(data_dict, features=Features(features))
|
||||||
hf_dataset.set_transform(hf_transform_to_torch)
|
hf_dataset.set_transform(hf_transform_to_torch)
|
||||||
return hf_dataset
|
return hf_dataset
|
||||||
|
|
|
@ -7,6 +7,7 @@ import logging
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
|
from copy import copy
|
||||||
from functools import cache
|
from functools import cache
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
|
@ -90,6 +91,7 @@ def has_method(_object: object, method_name: str):
|
||||||
|
|
||||||
|
|
||||||
def predict_action(observation, policy, device, use_amp):
|
def predict_action(observation, policy, device, use_amp):
|
||||||
|
observation = copy(observation)
|
||||||
with (
|
with (
|
||||||
torch.inference_mode(),
|
torch.inference_mode(),
|
||||||
torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(),
|
torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(),
|
||||||
|
@ -297,7 +299,9 @@ def stop_recording(robot, listener, display_cameras):
|
||||||
|
|
||||||
def sanity_check_dataset_name(repo_id, policy):
|
def sanity_check_dataset_name(repo_id, policy):
|
||||||
_, dataset_name = repo_id.split("/")
|
_, dataset_name = repo_id.split("/")
|
||||||
if dataset_name.startswith("eval_") and policy is None:
|
# either repo_id doesnt start with "eval_" and there is no policy
|
||||||
|
# or repo_id starts with "eval_" and there is a policy
|
||||||
|
if dataset_name.startswith("eval_") == (policy is None):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Your dataset name begins by 'eval_' ({dataset_name}) but no policy is provided ({policy})."
|
f"Your dataset name begins by 'eval_' ({dataset_name}) but no policy is provided ({policy})."
|
||||||
)
|
)
|
||||||
|
|
|
@ -201,11 +201,11 @@ def teleoperate(robot: Robot, fps: int | None = None, teleop_time_s: float | Non
|
||||||
@safe_disconnect
|
@safe_disconnect
|
||||||
def record(
|
def record(
|
||||||
robot: Robot,
|
robot: Robot,
|
||||||
|
root: str,
|
||||||
|
repo_id: str,
|
||||||
pretrained_policy_name_or_path: str | None = None,
|
pretrained_policy_name_or_path: str | None = None,
|
||||||
policy_overrides: List[str] | None = None,
|
policy_overrides: List[str] | None = None,
|
||||||
fps: int | None = None,
|
fps: int | None = None,
|
||||||
root="data",
|
|
||||||
repo_id="lerobot/debug",
|
|
||||||
warmup_time_s=2,
|
warmup_time_s=2,
|
||||||
episode_time_s=10,
|
episode_time_s=10,
|
||||||
reset_time_s=5,
|
reset_time_s=5,
|
||||||
|
|
|
@ -29,7 +29,7 @@ from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from lerobot.common.datasets.populate_dataset import add_frame
|
from lerobot.common.datasets.populate_dataset import add_frame, init_dataset
|
||||||
from lerobot.common.logger import Logger
|
from lerobot.common.logger import Logger
|
||||||
from lerobot.common.policies.factory import make_policy
|
from lerobot.common.policies.factory import make_policy
|
||||||
from lerobot.common.utils.utils import init_hydra_config
|
from lerobot.common.utils.utils import init_hydra_config
|
||||||
|
@ -131,13 +131,14 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
|
||||||
|
|
||||||
root = tmpdir / "data"
|
root = tmpdir / "data"
|
||||||
repo_id = "lerobot/debug"
|
repo_id = "lerobot/debug"
|
||||||
|
eval_repo_id = "lerobot/eval_debug"
|
||||||
|
|
||||||
robot = make_robot(robot_type, overrides=overrides, mock=mock)
|
robot = make_robot(robot_type, overrides=overrides, mock=mock)
|
||||||
dataset = record(
|
dataset = record(
|
||||||
robot,
|
robot,
|
||||||
fps=30,
|
root,
|
||||||
root=root,
|
repo_id,
|
||||||
repo_id=repo_id,
|
fps=1,
|
||||||
warmup_time_s=1,
|
warmup_time_s=1,
|
||||||
episode_time_s=1,
|
episode_time_s=1,
|
||||||
reset_time_s=1,
|
reset_time_s=1,
|
||||||
|
@ -149,8 +150,10 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
|
||||||
display_cameras=False,
|
display_cameras=False,
|
||||||
play_sounds=False,
|
play_sounds=False,
|
||||||
)
|
)
|
||||||
|
assert dataset.num_episodes == 2
|
||||||
|
assert len(dataset) == 2
|
||||||
|
|
||||||
replay(robot, episode=0, fps=30, root=root, repo_id=repo_id, play_sounds=False)
|
replay(robot, episode=0, fps=1, root=root, repo_id=repo_id, play_sounds=False)
|
||||||
|
|
||||||
# TODO(rcadene, aliberts): rethink this design
|
# TODO(rcadene, aliberts): rethink this design
|
||||||
if robot_type == "aloha":
|
if robot_type == "aloha":
|
||||||
|
@ -171,6 +174,9 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
|
||||||
if robot_type == "koch_bimanual":
|
if robot_type == "koch_bimanual":
|
||||||
overrides += ["env.state_dim=12", "env.action_dim=12"]
|
overrides += ["env.state_dim=12", "env.action_dim=12"]
|
||||||
|
|
||||||
|
overrides += ["wandb.enable=false"]
|
||||||
|
overrides += ["env.fps=1"]
|
||||||
|
|
||||||
cfg = init_hydra_config(
|
cfg = init_hydra_config(
|
||||||
DEFAULT_CONFIG_PATH,
|
DEFAULT_CONFIG_PATH,
|
||||||
overrides=overrides,
|
overrides=overrides,
|
||||||
|
@ -212,6 +218,8 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
|
||||||
|
|
||||||
record(
|
record(
|
||||||
robot,
|
robot,
|
||||||
|
root,
|
||||||
|
eval_repo_id,
|
||||||
pretrained_policy_name_or_path,
|
pretrained_policy_name_or_path,
|
||||||
warmup_time_s=1,
|
warmup_time_s=1,
|
||||||
episode_time_s=1,
|
episode_time_s=1,
|
||||||
|
@ -225,9 +233,75 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
|
||||||
num_image_writer_processes=num_image_writer_processes,
|
num_image_writer_processes=num_image_writer_processes,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
assert dataset.num_episodes == 2
|
||||||
|
assert len(dataset) == 2
|
||||||
|
|
||||||
del robot
|
del robot
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("robot_type, mock", [("koch", True)])
|
||||||
|
@require_robot
|
||||||
|
def test_resume_record(tmpdir, request, robot_type, mock):
|
||||||
|
if mock and robot_type != "aloha":
|
||||||
|
request.getfixturevalue("patch_builtins_input")
|
||||||
|
|
||||||
|
# Create an empty calibration directory to trigger manual calibration
|
||||||
|
# and avoid writing calibration files in user .cache/calibration folder
|
||||||
|
calibration_dir = tmpdir / robot_type
|
||||||
|
overrides = [f"calibration_dir={calibration_dir}"]
|
||||||
|
else:
|
||||||
|
# Use the default .cache/calibration folder when mock=False or for aloha
|
||||||
|
overrides = []
|
||||||
|
|
||||||
|
robot = make_robot(robot_type, overrides=overrides, mock=mock)
|
||||||
|
|
||||||
|
root = Path(tmpdir) / "data"
|
||||||
|
repo_id = "lerobot/debug"
|
||||||
|
|
||||||
|
dataset = record(
|
||||||
|
robot,
|
||||||
|
root,
|
||||||
|
repo_id,
|
||||||
|
fps=1,
|
||||||
|
warmup_time_s=0,
|
||||||
|
episode_time_s=1,
|
||||||
|
num_episodes=1,
|
||||||
|
push_to_hub=False,
|
||||||
|
video=False,
|
||||||
|
display_cameras=False,
|
||||||
|
play_sounds=False,
|
||||||
|
run_compute_stats=False,
|
||||||
|
)
|
||||||
|
assert len(dataset) == 1, "`dataset` should contain only 1 frame"
|
||||||
|
|
||||||
|
init_dataset_return_value = {}
|
||||||
|
|
||||||
|
def wrapped_init_dataset(*args, **kwargs):
|
||||||
|
nonlocal init_dataset_return_value
|
||||||
|
init_dataset_return_value = init_dataset(*args, **kwargs)
|
||||||
|
return init_dataset_return_value
|
||||||
|
|
||||||
|
with patch("lerobot.scripts.control_robot.init_dataset", wraps=wrapped_init_dataset):
|
||||||
|
dataset = record(
|
||||||
|
robot,
|
||||||
|
root,
|
||||||
|
repo_id,
|
||||||
|
fps=1,
|
||||||
|
warmup_time_s=0,
|
||||||
|
episode_time_s=1,
|
||||||
|
num_episodes=2,
|
||||||
|
push_to_hub=False,
|
||||||
|
video=False,
|
||||||
|
display_cameras=False,
|
||||||
|
play_sounds=False,
|
||||||
|
run_compute_stats=False,
|
||||||
|
)
|
||||||
|
assert len(dataset) == 2, "`dataset` should contain only 1 frame"
|
||||||
|
assert (
|
||||||
|
init_dataset_return_value["num_episodes"] == 2
|
||||||
|
), "`init_dataset` should load the previous episode"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("robot_type, mock", [("koch", True)])
|
@pytest.mark.parametrize("robot_type, mock", [("koch", True)])
|
||||||
@require_robot
|
@require_robot
|
||||||
def test_record_with_event_rerecord_episode(tmpdir, request, robot_type, mock):
|
def test_record_with_event_rerecord_episode(tmpdir, request, robot_type, mock):
|
||||||
|
@ -258,9 +332,9 @@ def test_record_with_event_rerecord_episode(tmpdir, request, robot_type, mock):
|
||||||
|
|
||||||
dataset = record(
|
dataset = record(
|
||||||
robot,
|
robot,
|
||||||
|
root,
|
||||||
|
repo_id,
|
||||||
fps=1,
|
fps=1,
|
||||||
root=root,
|
|
||||||
repo_id=repo_id,
|
|
||||||
warmup_time_s=0,
|
warmup_time_s=0,
|
||||||
episode_time_s=1,
|
episode_time_s=1,
|
||||||
num_episodes=1,
|
num_episodes=1,
|
||||||
|
@ -268,6 +342,7 @@ def test_record_with_event_rerecord_episode(tmpdir, request, robot_type, mock):
|
||||||
video=False,
|
video=False,
|
||||||
display_cameras=False,
|
display_cameras=False,
|
||||||
play_sounds=False,
|
play_sounds=False,
|
||||||
|
run_compute_stats=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert not mock_events["rerecord_episode"], "`rerecord_episode` wasn't properly reset to False"
|
assert not mock_events["rerecord_episode"], "`rerecord_episode` wasn't properly reset to False"
|
||||||
|
@ -316,6 +391,7 @@ def test_record_with_event_exit_early(tmpdir, request, robot_type, mock):
|
||||||
video=False,
|
video=False,
|
||||||
display_cameras=False,
|
display_cameras=False,
|
||||||
play_sounds=False,
|
play_sounds=False,
|
||||||
|
run_compute_stats=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert not mock_events["exit_early"], "`exit_early` wasn't properly reset to False"
|
assert not mock_events["exit_early"], "`exit_early` wasn't properly reset to False"
|
||||||
|
@ -355,9 +431,9 @@ def test_record_with_event_stop_recording(tmpdir, request, robot_type, mock, num
|
||||||
|
|
||||||
dataset = record(
|
dataset = record(
|
||||||
robot,
|
robot,
|
||||||
|
root,
|
||||||
|
repo_id,
|
||||||
fps=1,
|
fps=1,
|
||||||
root=root,
|
|
||||||
repo_id=repo_id,
|
|
||||||
warmup_time_s=0,
|
warmup_time_s=0,
|
||||||
episode_time_s=1,
|
episode_time_s=1,
|
||||||
num_episodes=2,
|
num_episodes=2,
|
||||||
|
@ -365,6 +441,7 @@ def test_record_with_event_stop_recording(tmpdir, request, robot_type, mock, num
|
||||||
video=False,
|
video=False,
|
||||||
display_cameras=False,
|
display_cameras=False,
|
||||||
play_sounds=False,
|
play_sounds=False,
|
||||||
|
run_compute_stats=False,
|
||||||
num_image_writer_processes=num_image_writer_processes,
|
num_image_writer_processes=num_image_writer_processes,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -127,6 +127,7 @@ def test_robot(tmpdir, request, robot_type, mock):
|
||||||
# TODO(rcadene): skipping image for now as it's challenging to assess equality between two consecutive frames
|
# TODO(rcadene): skipping image for now as it's challenging to assess equality between two consecutive frames
|
||||||
continue
|
continue
|
||||||
assert torch.allclose(captured_observation[name], observation[name], atol=1)
|
assert torch.allclose(captured_observation[name], observation[name], atol=1)
|
||||||
|
assert captured_observation[name].shape == observation[name].shape
|
||||||
|
|
||||||
# Test send_action can run
|
# Test send_action can run
|
||||||
robot.send_action(action["action"])
|
robot.send_action(action["action"])
|
||||||
|
|
Loading…
Reference in New Issue