overall improve, fix some issues with events, add some tests for events
This commit is contained in:
parent
9f5c586c1a
commit
63bd5013bf
|
@ -202,9 +202,9 @@ def init_dataset(
|
||||||
if rec_info_path.exists():
|
if rec_info_path.exists():
|
||||||
with open(rec_info_path) as f:
|
with open(rec_info_path) as f:
|
||||||
rec_info = json.load(f)
|
rec_info = json.load(f)
|
||||||
episode_index = rec_info["last_episode_index"] + 1
|
num_episodes = rec_info["last_episode_index"] + 1
|
||||||
else:
|
else:
|
||||||
episode_index = 0
|
num_episodes = 0
|
||||||
|
|
||||||
dataset = {
|
dataset = {
|
||||||
"repo_id": repo_id,
|
"repo_id": repo_id,
|
||||||
|
@ -214,7 +214,7 @@ def init_dataset(
|
||||||
"fps": fps,
|
"fps": fps,
|
||||||
"video": video,
|
"video": video,
|
||||||
"rec_info_path": rec_info_path,
|
"rec_info_path": rec_info_path,
|
||||||
"current_episode_index": episode_index,
|
"num_episodes": num_episodes,
|
||||||
}
|
}
|
||||||
|
|
||||||
if write_images:
|
if write_images:
|
||||||
|
@ -249,7 +249,7 @@ def add_frame(dataset, observation, action):
|
||||||
dataset["current_frame_index"] = 0
|
dataset["current_frame_index"] = 0
|
||||||
|
|
||||||
ep_dict = dataset["current_episode"]
|
ep_dict = dataset["current_episode"]
|
||||||
episode_index = dataset["current_episode_index"]
|
episode_index = dataset["num_episodes"]
|
||||||
frame_index = dataset["current_frame_index"]
|
frame_index = dataset["current_frame_index"]
|
||||||
videos_dir = dataset["videos_dir"]
|
videos_dir = dataset["videos_dir"]
|
||||||
video = dataset["video"]
|
video = dataset["video"]
|
||||||
|
@ -300,15 +300,19 @@ def add_frame(dataset, observation, action):
|
||||||
dataset["current_frame_index"] += 1
|
dataset["current_frame_index"] += 1
|
||||||
|
|
||||||
|
|
||||||
def delete_episode(dataset):
|
def delete_current_episode(dataset):
|
||||||
del dataset["current_episode"]
|
del dataset["current_episode"]
|
||||||
del dataset["current_frame_index"]
|
del dataset["current_frame_index"]
|
||||||
|
|
||||||
# TODO(rcadene): remove images and videos, etc.
|
# delete temporary images
|
||||||
|
episode_index = dataset["num_episodes"]
|
||||||
|
videos_dir = dataset["videos_dir"]
|
||||||
|
for tmp_imgs_dir in videos_dir.glob(f"*_episode_{episode_index:06d}"):
|
||||||
|
shutil.rmtree(tmp_imgs_dir)
|
||||||
|
|
||||||
|
|
||||||
def save_episode(dataset):
|
def save_current_episode(dataset):
|
||||||
episode_index = dataset["current_episode_index"]
|
episode_index = dataset["num_episodes"]
|
||||||
ep_dict = dataset["current_episode"]
|
ep_dict = dataset["current_episode"]
|
||||||
episodes_dir = dataset["episodes_dir"]
|
episodes_dir = dataset["episodes_dir"]
|
||||||
rec_info_path = dataset["rec_info_path"]
|
rec_info_path = dataset["rec_info_path"]
|
||||||
|
@ -337,14 +341,13 @@ def save_episode(dataset):
|
||||||
# force re-initialization of episode dictionnary during add_frame
|
# force re-initialization of episode dictionnary during add_frame
|
||||||
del dataset["current_episode"]
|
del dataset["current_episode"]
|
||||||
|
|
||||||
dataset["current_episode_index"] += 1
|
dataset["num_episodes"] += 1
|
||||||
|
|
||||||
|
|
||||||
def encode_videos(dataset, play_sounds):
|
def encode_videos(dataset, image_keys, play_sounds):
|
||||||
log_say("Encoding videos", play_sounds)
|
log_say("Encoding videos", play_sounds)
|
||||||
|
|
||||||
num_episodes = dataset["current_episode_index"]
|
num_episodes = dataset["num_episodes"]
|
||||||
image_keys = dataset["image_keys"]
|
|
||||||
videos_dir = dataset["videos_dir"]
|
videos_dir = dataset["videos_dir"]
|
||||||
local_dir = dataset["local_dir"]
|
local_dir = dataset["local_dir"]
|
||||||
fps = dataset["fps"]
|
fps = dataset["fps"]
|
||||||
|
@ -368,7 +371,7 @@ def encode_videos(dataset, play_sounds):
|
||||||
def from_dataset_to_lerobot_dataset(dataset, play_sounds):
|
def from_dataset_to_lerobot_dataset(dataset, play_sounds):
|
||||||
log_say("Consolidate episodes", play_sounds)
|
log_say("Consolidate episodes", play_sounds)
|
||||||
|
|
||||||
num_episodes = dataset["current_episode_index"]
|
num_episodes = dataset["num_episodes"]
|
||||||
episodes_dir = dataset["episodes_dir"]
|
episodes_dir = dataset["episodes_dir"]
|
||||||
videos_dir = dataset["videos_dir"]
|
videos_dir = dataset["videos_dir"]
|
||||||
video = dataset["video"]
|
video = dataset["video"]
|
||||||
|
@ -382,6 +385,10 @@ def from_dataset_to_lerobot_dataset(dataset, play_sounds):
|
||||||
ep_dicts.append(ep_dict)
|
ep_dicts.append(ep_dict)
|
||||||
data_dict = concatenate_episodes(ep_dicts)
|
data_dict = concatenate_episodes(ep_dicts)
|
||||||
|
|
||||||
|
if video:
|
||||||
|
image_keys = [key for key in data_dict if "image" in key]
|
||||||
|
encode_videos(dataset, image_keys, play_sounds)
|
||||||
|
|
||||||
total_frames = data_dict["frame_index"].shape[0]
|
total_frames = data_dict["frame_index"].shape[0]
|
||||||
data_dict["index"] = torch.arange(0, total_frames, 1)
|
data_dict["index"] = torch.arange(0, total_frames, 1)
|
||||||
|
|
||||||
|
@ -443,16 +450,11 @@ def push_lerobot_dataset_to_hub(lerobot_dataset, tags):
|
||||||
|
|
||||||
|
|
||||||
def create_lerobot_dataset(dataset, run_compute_stats, push_to_hub, tags, play_sounds):
|
def create_lerobot_dataset(dataset, run_compute_stats, push_to_hub, tags, play_sounds):
|
||||||
video = dataset["video"]
|
|
||||||
|
|
||||||
if "image_writer" in dataset:
|
if "image_writer" in dataset:
|
||||||
logging.info("Waiting for image writer to terminate...")
|
logging.info("Waiting for image writer to terminate...")
|
||||||
image_writer = dataset["image_writer"]
|
image_writer = dataset["image_writer"]
|
||||||
stop_image_writer(image_writer, timeout=20)
|
stop_image_writer(image_writer, timeout=20)
|
||||||
|
|
||||||
if video:
|
|
||||||
encode_videos(dataset, play_sounds)
|
|
||||||
|
|
||||||
lerobot_dataset = from_dataset_to_lerobot_dataset(dataset, play_sounds)
|
lerobot_dataset = from_dataset_to_lerobot_dataset(dataset, play_sounds)
|
||||||
|
|
||||||
if run_compute_stats:
|
if run_compute_stats:
|
||||||
|
|
|
@ -116,9 +116,6 @@ def predict_action(observation, policy, device, use_amp):
|
||||||
|
|
||||||
|
|
||||||
def init_keyboard_listener():
|
def init_keyboard_listener():
|
||||||
# Only import pynput if not in a headless environment
|
|
||||||
from pynput import keyboard
|
|
||||||
|
|
||||||
# Allow to exit early while recording an episode or resetting the environment,
|
# Allow to exit early while recording an episode or resetting the environment,
|
||||||
# by tapping the right arrow key '->'. This might require a sudo permission
|
# by tapping the right arrow key '->'. This might require a sudo permission
|
||||||
# to allow your terminal to monitor keyboard events.
|
# to allow your terminal to monitor keyboard events.
|
||||||
|
@ -127,6 +124,16 @@ def init_keyboard_listener():
|
||||||
events["rerecord_episode"] = False
|
events["rerecord_episode"] = False
|
||||||
events["stop_recording"] = False
|
events["stop_recording"] = False
|
||||||
|
|
||||||
|
if is_headless():
|
||||||
|
logging.warning(
|
||||||
|
"Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
|
||||||
|
)
|
||||||
|
listener = None
|
||||||
|
return listener, events
|
||||||
|
|
||||||
|
# Only import pynput if not in a headless environment
|
||||||
|
from pynput import keyboard
|
||||||
|
|
||||||
def on_press(key):
|
def on_press(key):
|
||||||
try:
|
try:
|
||||||
if key == keyboard.Key.right:
|
if key == keyboard.Key.right:
|
||||||
|
@ -175,7 +182,8 @@ def init_policy(pretrained_policy_name_or_path, policy_overrides, fps):
|
||||||
return policy, fps, device, use_amp
|
return policy, fps, device, use_amp
|
||||||
|
|
||||||
|
|
||||||
def warmup_record(robot, enable_teloperation, warmup_time_s, display_cameras, play_sounds, fps):
|
def warmup_record(robot, events, enable_teloperation, warmup_time_s, display_cameras, play_sounds, fps):
|
||||||
|
# TODO(rcadene): refactor warmup_record and reset_environment
|
||||||
timestamp = 0
|
timestamp = 0
|
||||||
start_warmup_t = time.perf_counter()
|
start_warmup_t = time.perf_counter()
|
||||||
|
|
||||||
|
@ -203,24 +211,23 @@ def warmup_record(robot, enable_teloperation, warmup_time_s, display_cameras, pl
|
||||||
log_control_info(robot, dt_s, fps=fps)
|
log_control_info(robot, dt_s, fps=fps)
|
||||||
|
|
||||||
timestamp = time.perf_counter() - start_warmup_t
|
timestamp = time.perf_counter() - start_warmup_t
|
||||||
|
if events is not None and events["exit_early"]:
|
||||||
|
events["exit_early"] = False
|
||||||
|
break
|
||||||
|
|
||||||
|
|
||||||
@safe_stop_image_writer
|
@safe_stop_image_writer
|
||||||
def record_episode(
|
def record_episode(
|
||||||
dataset,
|
dataset,
|
||||||
robot,
|
robot,
|
||||||
episode_index,
|
|
||||||
events,
|
events,
|
||||||
episode_time_s,
|
episode_time_s,
|
||||||
display_cameras,
|
display_cameras,
|
||||||
play_sounds,
|
|
||||||
policy,
|
policy,
|
||||||
device,
|
device,
|
||||||
use_amp,
|
use_amp,
|
||||||
fps,
|
fps,
|
||||||
):
|
):
|
||||||
log_say(f"Recording episode {episode_index}", play_sounds)
|
|
||||||
|
|
||||||
timestamp = 0
|
timestamp = 0
|
||||||
start_episode_t = time.perf_counter()
|
start_episode_t = time.perf_counter()
|
||||||
while timestamp < episode_time_s:
|
while timestamp < episode_time_s:
|
||||||
|
@ -258,9 +265,8 @@ def record_episode(
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|
||||||
def reset_environment(robot, events, reset_time_s, play_sounds):
|
def reset_environment(robot, events, reset_time_s):
|
||||||
log_say("Reset the environment", play_sounds)
|
# TODO(rcadene): refactor warmup_record and reset_environment
|
||||||
|
|
||||||
# TODO(alibets): allow for teleop during reset
|
# TODO(alibets): allow for teleop during reset
|
||||||
if has_method(robot, "teleop_safety_stop"):
|
if has_method(robot, "teleop_safety_stop"):
|
||||||
robot.teleop_safety_stop()
|
robot.teleop_safety_stop()
|
||||||
|
@ -279,9 +285,7 @@ def reset_environment(robot, events, reset_time_s, play_sounds):
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|
||||||
def done_recording(robot, listener, display_cameras, play_sounds):
|
def stop_recording(robot, listener, display_cameras):
|
||||||
log_say("Done recording", play_sounds, blocking=True)
|
|
||||||
|
|
||||||
robot.disconnect()
|
robot.disconnect()
|
||||||
|
|
||||||
if not is_headless():
|
if not is_headless():
|
||||||
|
|
|
@ -99,7 +99,6 @@ python lerobot/scripts/control_robot.py record \
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List
|
from typing import List
|
||||||
|
@ -108,21 +107,20 @@ from typing import List
|
||||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||||
from lerobot.common.datasets.populate_dataset import (
|
from lerobot.common.datasets.populate_dataset import (
|
||||||
create_lerobot_dataset,
|
create_lerobot_dataset,
|
||||||
delete_episode,
|
delete_current_episode,
|
||||||
init_dataset,
|
init_dataset,
|
||||||
save_episode,
|
save_current_episode,
|
||||||
)
|
)
|
||||||
from lerobot.common.robot_devices.control_robot import (
|
from lerobot.common.robot_devices.control_robot import (
|
||||||
done_recording,
|
|
||||||
has_method,
|
has_method,
|
||||||
init_keyboard_listener,
|
init_keyboard_listener,
|
||||||
init_policy,
|
init_policy,
|
||||||
is_headless,
|
|
||||||
log_control_info,
|
log_control_info,
|
||||||
log_say,
|
log_say,
|
||||||
record_episode,
|
record_episode,
|
||||||
reset_environment,
|
reset_environment,
|
||||||
sanity_check_dataset_name,
|
sanity_check_dataset_name,
|
||||||
|
stop_recording,
|
||||||
warmup_record,
|
warmup_record,
|
||||||
)
|
)
|
||||||
from lerobot.common.robot_devices.robots.factory import make_robot
|
from lerobot.common.robot_devices.robots.factory import make_robot
|
||||||
|
@ -229,9 +227,6 @@ def record(
|
||||||
device = None
|
device = None
|
||||||
use_amp = None
|
use_amp = None
|
||||||
|
|
||||||
if not robot.is_connected:
|
|
||||||
robot.connect()
|
|
||||||
|
|
||||||
# Load pretrained policy
|
# Load pretrained policy
|
||||||
if pretrained_policy_name_or_path is not None:
|
if pretrained_policy_name_or_path is not None:
|
||||||
policy, fps, device, use_amp = init_policy(pretrained_policy_name_or_path, policy_overrides, fps)
|
policy, fps, device, use_amp = init_policy(pretrained_policy_name_or_path, policy_overrides, fps)
|
||||||
|
@ -249,58 +244,72 @@ def record(
|
||||||
num_image_writer_threads=num_image_writer_threads_per_camera * robot.num_cameras,
|
num_image_writer_threads=num_image_writer_threads_per_camera * robot.num_cameras,
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_headless():
|
if not robot.is_connected:
|
||||||
logging.warning(
|
robot.connect()
|
||||||
"Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
listener, events = init_keyboard_listener()
|
listener, events = init_keyboard_listener()
|
||||||
|
|
||||||
# Execute a few seconds without recording data to:
|
# Execute a few seconds without recording to:
|
||||||
# 1. teleoperate the robot to move it in starting position if no policy provided,
|
# 1. teleoperate the robot to move it in starting position if no policy provided,
|
||||||
# 2. give times to the robot devices to connect and start synchronizing,
|
# 2. give times to the robot devices to connect and start synchronizing,
|
||||||
# 3. place the cameras windows on screen
|
# 3. place the cameras windows on screen
|
||||||
enable_teleoperation = policy is None
|
enable_teleoperation = policy is None
|
||||||
warmup_record(robot, enable_teleoperation, warmup_time_s, display_cameras, play_sounds, fps)
|
warmup_record(robot, events, enable_teleoperation, warmup_time_s, display_cameras, play_sounds, fps)
|
||||||
|
|
||||||
if has_method(robot, "teleop_safety_stop"):
|
if has_method(robot, "teleop_safety_stop"):
|
||||||
robot.teleop_safety_stop()
|
robot.teleop_safety_stop()
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
episode_index = dataset["current_episode_index"]
|
if dataset["num_episodes"] >= num_episodes:
|
||||||
if episode_index >= num_episodes:
|
|
||||||
break
|
break
|
||||||
|
|
||||||
|
episode_index = dataset["num_episodes"]
|
||||||
|
log_say(f"Recording episode {episode_index}", play_sounds)
|
||||||
record_episode(
|
record_episode(
|
||||||
dataset=dataset,
|
dataset=dataset,
|
||||||
robot=robot,
|
robot=robot,
|
||||||
episode_index=episode_index,
|
|
||||||
events=events,
|
events=events,
|
||||||
episode_time_s=episode_time_s,
|
episode_time_s=episode_time_s,
|
||||||
display_cameras=display_cameras,
|
display_cameras=display_cameras,
|
||||||
play_sounds=play_sounds,
|
|
||||||
policy=policy,
|
policy=policy,
|
||||||
device=device,
|
device=device,
|
||||||
use_amp=use_amp,
|
use_amp=use_amp,
|
||||||
fps=fps,
|
fps=fps,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Do not reset for the last episode to be recorded
|
# In case stop recording is requested during `record_episode`
|
||||||
if episode_index < num_episodes - 1:
|
if events is not None and events["stop_recording"]:
|
||||||
reset_environment(robot, events, reset_time_s, play_sounds)
|
save_current_episode(dataset)
|
||||||
|
|
||||||
if events is not None and events["rerecord_episode"]:
|
|
||||||
events["rerecord_episode"] = False
|
|
||||||
delete_episode(dataset)
|
|
||||||
continue
|
|
||||||
|
|
||||||
save_episode(dataset)
|
|
||||||
|
|
||||||
if events is not None and events["exit_early"]:
|
|
||||||
num_episodes = episode_index
|
|
||||||
break
|
break
|
||||||
|
|
||||||
done_recording(robot, listener, display_cameras, play_sounds)
|
# Execute a few seconds without recording to give time to manually reset the environment
|
||||||
|
# Current code logic doesn't allow to teleoperate during this time.
|
||||||
|
# TODO(rcadene): add an option to enable teleoperation during reset
|
||||||
|
# Skip reset for the last episode to be recorded
|
||||||
|
if episode_index < num_episodes - 1:
|
||||||
|
log_say("Reset the environment", play_sounds)
|
||||||
|
reset_environment(robot, events, reset_time_s)
|
||||||
|
|
||||||
|
# In case stop recording is requested during `reset_environment`
|
||||||
|
if events is not None and events["stop_recording"]:
|
||||||
|
save_current_episode(dataset)
|
||||||
|
break
|
||||||
|
|
||||||
|
if events is not None and events["rerecord_episode"]:
|
||||||
|
log_say("Re-record episode", play_sounds)
|
||||||
|
events["rerecord_episode"] = False
|
||||||
|
events["exit_early"] = False
|
||||||
|
delete_current_episode(dataset)
|
||||||
|
# Force reset
|
||||||
|
log_say("Reset the environment", play_sounds)
|
||||||
|
reset_environment(robot, events, reset_time_s)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Increment by one dataset["current_episode_index"]
|
||||||
|
save_current_episode(dataset)
|
||||||
|
|
||||||
|
log_say("Stop recording", play_sounds, blocking=True)
|
||||||
|
stop_recording(robot, listener, display_cameras)
|
||||||
|
|
||||||
lerobot_dataset = create_lerobot_dataset(dataset, run_compute_stats, push_to_hub, tags, play_sounds)
|
lerobot_dataset = create_lerobot_dataset(dataset, run_compute_stats, push_to_hub, tags, play_sounds)
|
||||||
|
|
||||||
|
|
|
@ -25,9 +25,11 @@ pytest -sx 'tests/test_control_robot.py::test_teleoperate[aloha-True]'
|
||||||
|
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from lerobot.common.datasets.populate_dataset import add_frame
|
||||||
from lerobot.common.logger import Logger
|
from lerobot.common.logger import Logger
|
||||||
from lerobot.common.policies.factory import make_policy
|
from lerobot.common.policies.factory import make_policy
|
||||||
from lerobot.common.utils.utils import init_hydra_config
|
from lerobot.common.utils.utils import init_hydra_config
|
||||||
|
@ -222,3 +224,148 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
|
||||||
)
|
)
|
||||||
|
|
||||||
del robot
|
del robot
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("robot_type, mock", [("koch", True)])
|
||||||
|
@require_robot
|
||||||
|
def test_record_with_event_rerecord_episode(tmpdir, request, robot_type, mock):
|
||||||
|
if mock and robot_type != "aloha":
|
||||||
|
request.getfixturevalue("patch_builtins_input")
|
||||||
|
|
||||||
|
# Create an empty calibration directory to trigger manual calibration
|
||||||
|
# and avoid writing calibration files in user .cache/calibration folder
|
||||||
|
calibration_dir = tmpdir / robot_type
|
||||||
|
overrides = [f"calibration_dir={calibration_dir}"]
|
||||||
|
else:
|
||||||
|
# Use the default .cache/calibration folder when mock=False or for aloha
|
||||||
|
overrides = []
|
||||||
|
|
||||||
|
robot = make_robot(robot_type, overrides=overrides, mock=mock)
|
||||||
|
with (
|
||||||
|
patch("lerobot.scripts.control_robot.init_keyboard_listener") as mock_listener,
|
||||||
|
patch("lerobot.common.robot_devices.control_robot.add_frame", wraps=add_frame) as mock_add_frame,
|
||||||
|
):
|
||||||
|
mock_events = {}
|
||||||
|
mock_events["exit_early"] = True
|
||||||
|
mock_events["rerecord_episode"] = True
|
||||||
|
mock_events["stop_recording"] = False
|
||||||
|
mock_listener.return_value = (None, mock_events)
|
||||||
|
|
||||||
|
root = Path(tmpdir) / "data"
|
||||||
|
repo_id = "lerobot/debug"
|
||||||
|
|
||||||
|
dataset = record(
|
||||||
|
robot,
|
||||||
|
fps=1,
|
||||||
|
root=root,
|
||||||
|
repo_id=repo_id,
|
||||||
|
warmup_time_s=0,
|
||||||
|
episode_time_s=1,
|
||||||
|
num_episodes=1,
|
||||||
|
push_to_hub=False,
|
||||||
|
video=False,
|
||||||
|
display_cameras=False,
|
||||||
|
play_sounds=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert not mock_events["rerecord_episode"], "`rerecord_episode` wasn't properly reset to False"
|
||||||
|
assert not mock_events["exit_early"], "`exit_early` wasn't properly reset to False"
|
||||||
|
assert mock_add_frame.call_count == 2, "`add_frame` should have been called 2 times"
|
||||||
|
assert len(dataset) == 1, "`dataset` should contain only 1 frame"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("robot_type, mock", [("koch", True)])
|
||||||
|
@require_robot
|
||||||
|
def test_record_with_event_exit_early(tmpdir, request, robot_type, mock):
|
||||||
|
if mock:
|
||||||
|
request.getfixturevalue("patch_builtins_input")
|
||||||
|
|
||||||
|
# Create an empty calibration directory to trigger manual calibration
|
||||||
|
# and avoid writing calibration files in user .cache/calibration folder
|
||||||
|
calibration_dir = tmpdir / robot_type
|
||||||
|
overrides = [f"calibration_dir={calibration_dir}"]
|
||||||
|
else:
|
||||||
|
# Use the default .cache/calibration folder when mock=False or for aloha
|
||||||
|
overrides = []
|
||||||
|
|
||||||
|
robot = make_robot(robot_type, overrides=overrides, mock=mock)
|
||||||
|
with (
|
||||||
|
patch("lerobot.scripts.control_robot.init_keyboard_listener") as mock_listener,
|
||||||
|
patch("lerobot.common.robot_devices.control_robot.add_frame", wraps=add_frame) as mock_add_frame,
|
||||||
|
):
|
||||||
|
mock_events = {}
|
||||||
|
mock_events["exit_early"] = True
|
||||||
|
mock_events["rerecord_episode"] = False
|
||||||
|
mock_events["stop_recording"] = False
|
||||||
|
mock_listener.return_value = (None, mock_events)
|
||||||
|
|
||||||
|
root = Path(tmpdir) / "data"
|
||||||
|
repo_id = "lerobot/debug"
|
||||||
|
|
||||||
|
dataset = record(
|
||||||
|
robot,
|
||||||
|
fps=2,
|
||||||
|
root=root,
|
||||||
|
repo_id=repo_id,
|
||||||
|
warmup_time_s=0,
|
||||||
|
episode_time_s=1,
|
||||||
|
num_episodes=1,
|
||||||
|
push_to_hub=False,
|
||||||
|
video=False,
|
||||||
|
display_cameras=False,
|
||||||
|
play_sounds=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert not mock_events["exit_early"], "`exit_early` wasn't properly reset to False"
|
||||||
|
assert mock_add_frame.call_count == 1, "`add_frame` should have been called 1 time"
|
||||||
|
assert len(dataset) == 1, "`dataset` should contain only 1 frame"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"robot_type, mock, num_image_writer_processes", [("koch", True, 0), ("koch", True, 1)]
|
||||||
|
)
|
||||||
|
@require_robot
|
||||||
|
def test_record_with_event_stop_recording(tmpdir, request, robot_type, mock, num_image_writer_processes):
|
||||||
|
if mock:
|
||||||
|
request.getfixturevalue("patch_builtins_input")
|
||||||
|
|
||||||
|
# Create an empty calibration directory to trigger manual calibration
|
||||||
|
# and avoid writing calibration files in user .cache/calibration folder
|
||||||
|
calibration_dir = tmpdir / robot_type
|
||||||
|
overrides = [f"calibration_dir={calibration_dir}"]
|
||||||
|
else:
|
||||||
|
# Use the default .cache/calibration folder when mock=False or for aloha
|
||||||
|
overrides = []
|
||||||
|
|
||||||
|
robot = make_robot(robot_type, overrides=overrides, mock=mock)
|
||||||
|
with (
|
||||||
|
patch("lerobot.scripts.control_robot.init_keyboard_listener") as mock_listener,
|
||||||
|
patch("lerobot.common.robot_devices.control_robot.add_frame", wraps=add_frame) as mock_add_frame,
|
||||||
|
):
|
||||||
|
mock_events = {}
|
||||||
|
mock_events["exit_early"] = True
|
||||||
|
mock_events["rerecord_episode"] = False
|
||||||
|
mock_events["stop_recording"] = True
|
||||||
|
mock_listener.return_value = (None, mock_events)
|
||||||
|
|
||||||
|
root = Path(tmpdir) / "data"
|
||||||
|
repo_id = "lerobot/debug"
|
||||||
|
|
||||||
|
dataset = record(
|
||||||
|
robot,
|
||||||
|
fps=1,
|
||||||
|
root=root,
|
||||||
|
repo_id=repo_id,
|
||||||
|
warmup_time_s=0,
|
||||||
|
episode_time_s=1,
|
||||||
|
num_episodes=2,
|
||||||
|
push_to_hub=False,
|
||||||
|
video=False,
|
||||||
|
display_cameras=False,
|
||||||
|
play_sounds=False,
|
||||||
|
num_image_writer_processes=num_image_writer_processes,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert not mock_events["exit_early"], "`exit_early` wasn't properly reset to False"
|
||||||
|
assert mock_add_frame.call_count == 1, "`add_frame` should have been called 2 times"
|
||||||
|
assert len(dataset) == 1, "`dataset` should contain only 1 frame"
|
||||||
|
|
Loading…
Reference in New Issue