overall improve, fix some issues with events, add some tests for events
This commit is contained in:
parent
9f5c586c1a
commit
63bd5013bf
|
@ -202,9 +202,9 @@ def init_dataset(
|
|||
if rec_info_path.exists():
|
||||
with open(rec_info_path) as f:
|
||||
rec_info = json.load(f)
|
||||
episode_index = rec_info["last_episode_index"] + 1
|
||||
num_episodes = rec_info["last_episode_index"] + 1
|
||||
else:
|
||||
episode_index = 0
|
||||
num_episodes = 0
|
||||
|
||||
dataset = {
|
||||
"repo_id": repo_id,
|
||||
|
@ -214,7 +214,7 @@ def init_dataset(
|
|||
"fps": fps,
|
||||
"video": video,
|
||||
"rec_info_path": rec_info_path,
|
||||
"current_episode_index": episode_index,
|
||||
"num_episodes": num_episodes,
|
||||
}
|
||||
|
||||
if write_images:
|
||||
|
@ -249,7 +249,7 @@ def add_frame(dataset, observation, action):
|
|||
dataset["current_frame_index"] = 0
|
||||
|
||||
ep_dict = dataset["current_episode"]
|
||||
episode_index = dataset["current_episode_index"]
|
||||
episode_index = dataset["num_episodes"]
|
||||
frame_index = dataset["current_frame_index"]
|
||||
videos_dir = dataset["videos_dir"]
|
||||
video = dataset["video"]
|
||||
|
@ -300,15 +300,19 @@ def add_frame(dataset, observation, action):
|
|||
dataset["current_frame_index"] += 1
|
||||
|
||||
|
||||
def delete_episode(dataset):
|
||||
def delete_current_episode(dataset):
|
||||
del dataset["current_episode"]
|
||||
del dataset["current_frame_index"]
|
||||
|
||||
# TODO(rcadene): remove images and videos, etc.
|
||||
# delete temporary images
|
||||
episode_index = dataset["num_episodes"]
|
||||
videos_dir = dataset["videos_dir"]
|
||||
for tmp_imgs_dir in videos_dir.glob(f"*_episode_{episode_index:06d}"):
|
||||
shutil.rmtree(tmp_imgs_dir)
|
||||
|
||||
|
||||
def save_episode(dataset):
|
||||
episode_index = dataset["current_episode_index"]
|
||||
def save_current_episode(dataset):
|
||||
episode_index = dataset["num_episodes"]
|
||||
ep_dict = dataset["current_episode"]
|
||||
episodes_dir = dataset["episodes_dir"]
|
||||
rec_info_path = dataset["rec_info_path"]
|
||||
|
@ -337,14 +341,13 @@ def save_episode(dataset):
|
|||
# force re-initialization of episode dictionnary during add_frame
|
||||
del dataset["current_episode"]
|
||||
|
||||
dataset["current_episode_index"] += 1
|
||||
dataset["num_episodes"] += 1
|
||||
|
||||
|
||||
def encode_videos(dataset, play_sounds):
|
||||
def encode_videos(dataset, image_keys, play_sounds):
|
||||
log_say("Encoding videos", play_sounds)
|
||||
|
||||
num_episodes = dataset["current_episode_index"]
|
||||
image_keys = dataset["image_keys"]
|
||||
num_episodes = dataset["num_episodes"]
|
||||
videos_dir = dataset["videos_dir"]
|
||||
local_dir = dataset["local_dir"]
|
||||
fps = dataset["fps"]
|
||||
|
@ -368,7 +371,7 @@ def encode_videos(dataset, play_sounds):
|
|||
def from_dataset_to_lerobot_dataset(dataset, play_sounds):
|
||||
log_say("Consolidate episodes", play_sounds)
|
||||
|
||||
num_episodes = dataset["current_episode_index"]
|
||||
num_episodes = dataset["num_episodes"]
|
||||
episodes_dir = dataset["episodes_dir"]
|
||||
videos_dir = dataset["videos_dir"]
|
||||
video = dataset["video"]
|
||||
|
@ -382,6 +385,10 @@ def from_dataset_to_lerobot_dataset(dataset, play_sounds):
|
|||
ep_dicts.append(ep_dict)
|
||||
data_dict = concatenate_episodes(ep_dicts)
|
||||
|
||||
if video:
|
||||
image_keys = [key for key in data_dict if "image" in key]
|
||||
encode_videos(dataset, image_keys, play_sounds)
|
||||
|
||||
total_frames = data_dict["frame_index"].shape[0]
|
||||
data_dict["index"] = torch.arange(0, total_frames, 1)
|
||||
|
||||
|
@ -443,16 +450,11 @@ def push_lerobot_dataset_to_hub(lerobot_dataset, tags):
|
|||
|
||||
|
||||
def create_lerobot_dataset(dataset, run_compute_stats, push_to_hub, tags, play_sounds):
|
||||
video = dataset["video"]
|
||||
|
||||
if "image_writer" in dataset:
|
||||
logging.info("Waiting for image writer to terminate...")
|
||||
image_writer = dataset["image_writer"]
|
||||
stop_image_writer(image_writer, timeout=20)
|
||||
|
||||
if video:
|
||||
encode_videos(dataset, play_sounds)
|
||||
|
||||
lerobot_dataset = from_dataset_to_lerobot_dataset(dataset, play_sounds)
|
||||
|
||||
if run_compute_stats:
|
||||
|
|
|
@ -116,9 +116,6 @@ def predict_action(observation, policy, device, use_amp):
|
|||
|
||||
|
||||
def init_keyboard_listener():
|
||||
# Only import pynput if not in a headless environment
|
||||
from pynput import keyboard
|
||||
|
||||
# Allow to exit early while recording an episode or resetting the environment,
|
||||
# by tapping the right arrow key '->'. This might require a sudo permission
|
||||
# to allow your terminal to monitor keyboard events.
|
||||
|
@ -127,6 +124,16 @@ def init_keyboard_listener():
|
|||
events["rerecord_episode"] = False
|
||||
events["stop_recording"] = False
|
||||
|
||||
if is_headless():
|
||||
logging.warning(
|
||||
"Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
|
||||
)
|
||||
listener = None
|
||||
return listener, events
|
||||
|
||||
# Only import pynput if not in a headless environment
|
||||
from pynput import keyboard
|
||||
|
||||
def on_press(key):
|
||||
try:
|
||||
if key == keyboard.Key.right:
|
||||
|
@ -175,7 +182,8 @@ def init_policy(pretrained_policy_name_or_path, policy_overrides, fps):
|
|||
return policy, fps, device, use_amp
|
||||
|
||||
|
||||
def warmup_record(robot, enable_teloperation, warmup_time_s, display_cameras, play_sounds, fps):
|
||||
def warmup_record(robot, events, enable_teloperation, warmup_time_s, display_cameras, play_sounds, fps):
|
||||
# TODO(rcadene): refactor warmup_record and reset_environment
|
||||
timestamp = 0
|
||||
start_warmup_t = time.perf_counter()
|
||||
|
||||
|
@ -203,24 +211,23 @@ def warmup_record(robot, enable_teloperation, warmup_time_s, display_cameras, pl
|
|||
log_control_info(robot, dt_s, fps=fps)
|
||||
|
||||
timestamp = time.perf_counter() - start_warmup_t
|
||||
if events is not None and events["exit_early"]:
|
||||
events["exit_early"] = False
|
||||
break
|
||||
|
||||
|
||||
@safe_stop_image_writer
|
||||
def record_episode(
|
||||
dataset,
|
||||
robot,
|
||||
episode_index,
|
||||
events,
|
||||
episode_time_s,
|
||||
display_cameras,
|
||||
play_sounds,
|
||||
policy,
|
||||
device,
|
||||
use_amp,
|
||||
fps,
|
||||
):
|
||||
log_say(f"Recording episode {episode_index}", play_sounds)
|
||||
|
||||
timestamp = 0
|
||||
start_episode_t = time.perf_counter()
|
||||
while timestamp < episode_time_s:
|
||||
|
@ -258,9 +265,8 @@ def record_episode(
|
|||
break
|
||||
|
||||
|
||||
def reset_environment(robot, events, reset_time_s, play_sounds):
|
||||
log_say("Reset the environment", play_sounds)
|
||||
|
||||
def reset_environment(robot, events, reset_time_s):
|
||||
# TODO(rcadene): refactor warmup_record and reset_environment
|
||||
# TODO(alibets): allow for teleop during reset
|
||||
if has_method(robot, "teleop_safety_stop"):
|
||||
robot.teleop_safety_stop()
|
||||
|
@ -279,9 +285,7 @@ def reset_environment(robot, events, reset_time_s, play_sounds):
|
|||
break
|
||||
|
||||
|
||||
def done_recording(robot, listener, display_cameras, play_sounds):
|
||||
log_say("Done recording", play_sounds, blocking=True)
|
||||
|
||||
def stop_recording(robot, listener, display_cameras):
|
||||
robot.disconnect()
|
||||
|
||||
if not is_headless():
|
||||
|
|
|
@ -99,7 +99,6 @@ python lerobot/scripts/control_robot.py record \
|
|||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
@ -108,21 +107,20 @@ from typing import List
|
|||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.datasets.populate_dataset import (
|
||||
create_lerobot_dataset,
|
||||
delete_episode,
|
||||
delete_current_episode,
|
||||
init_dataset,
|
||||
save_episode,
|
||||
save_current_episode,
|
||||
)
|
||||
from lerobot.common.robot_devices.control_robot import (
|
||||
done_recording,
|
||||
has_method,
|
||||
init_keyboard_listener,
|
||||
init_policy,
|
||||
is_headless,
|
||||
log_control_info,
|
||||
log_say,
|
||||
record_episode,
|
||||
reset_environment,
|
||||
sanity_check_dataset_name,
|
||||
stop_recording,
|
||||
warmup_record,
|
||||
)
|
||||
from lerobot.common.robot_devices.robots.factory import make_robot
|
||||
|
@ -229,9 +227,6 @@ def record(
|
|||
device = None
|
||||
use_amp = None
|
||||
|
||||
if not robot.is_connected:
|
||||
robot.connect()
|
||||
|
||||
# Load pretrained policy
|
||||
if pretrained_policy_name_or_path is not None:
|
||||
policy, fps, device, use_amp = init_policy(pretrained_policy_name_or_path, policy_overrides, fps)
|
||||
|
@ -249,58 +244,72 @@ def record(
|
|||
num_image_writer_threads=num_image_writer_threads_per_camera * robot.num_cameras,
|
||||
)
|
||||
|
||||
if is_headless():
|
||||
logging.warning(
|
||||
"Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
|
||||
)
|
||||
else:
|
||||
listener, events = init_keyboard_listener()
|
||||
if not robot.is_connected:
|
||||
robot.connect()
|
||||
|
||||
# Execute a few seconds without recording data to:
|
||||
listener, events = init_keyboard_listener()
|
||||
|
||||
# Execute a few seconds without recording to:
|
||||
# 1. teleoperate the robot to move it in starting position if no policy provided,
|
||||
# 2. give times to the robot devices to connect and start synchronizing,
|
||||
# 3. place the cameras windows on screen
|
||||
enable_teleoperation = policy is None
|
||||
warmup_record(robot, enable_teleoperation, warmup_time_s, display_cameras, play_sounds, fps)
|
||||
warmup_record(robot, events, enable_teleoperation, warmup_time_s, display_cameras, play_sounds, fps)
|
||||
|
||||
if has_method(robot, "teleop_safety_stop"):
|
||||
robot.teleop_safety_stop()
|
||||
|
||||
while True:
|
||||
episode_index = dataset["current_episode_index"]
|
||||
if episode_index >= num_episodes:
|
||||
if dataset["num_episodes"] >= num_episodes:
|
||||
break
|
||||
|
||||
episode_index = dataset["num_episodes"]
|
||||
log_say(f"Recording episode {episode_index}", play_sounds)
|
||||
record_episode(
|
||||
dataset=dataset,
|
||||
robot=robot,
|
||||
episode_index=episode_index,
|
||||
events=events,
|
||||
episode_time_s=episode_time_s,
|
||||
display_cameras=display_cameras,
|
||||
play_sounds=play_sounds,
|
||||
policy=policy,
|
||||
device=device,
|
||||
use_amp=use_amp,
|
||||
fps=fps,
|
||||
)
|
||||
|
||||
# Do not reset for the last episode to be recorded
|
||||
if episode_index < num_episodes - 1:
|
||||
reset_environment(robot, events, reset_time_s, play_sounds)
|
||||
|
||||
if events is not None and events["rerecord_episode"]:
|
||||
events["rerecord_episode"] = False
|
||||
delete_episode(dataset)
|
||||
continue
|
||||
|
||||
save_episode(dataset)
|
||||
|
||||
if events is not None and events["exit_early"]:
|
||||
num_episodes = episode_index
|
||||
# In case stop recording is requested during `record_episode`
|
||||
if events is not None and events["stop_recording"]:
|
||||
save_current_episode(dataset)
|
||||
break
|
||||
|
||||
done_recording(robot, listener, display_cameras, play_sounds)
|
||||
# Execute a few seconds without recording to give time to manually reset the environment
|
||||
# Current code logic doesn't allow to teleoperate during this time.
|
||||
# TODO(rcadene): add an option to enable teleoperation during reset
|
||||
# Skip reset for the last episode to be recorded
|
||||
if episode_index < num_episodes - 1:
|
||||
log_say("Reset the environment", play_sounds)
|
||||
reset_environment(robot, events, reset_time_s)
|
||||
|
||||
# In case stop recording is requested during `reset_environment`
|
||||
if events is not None and events["stop_recording"]:
|
||||
save_current_episode(dataset)
|
||||
break
|
||||
|
||||
if events is not None and events["rerecord_episode"]:
|
||||
log_say("Re-record episode", play_sounds)
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
delete_current_episode(dataset)
|
||||
# Force reset
|
||||
log_say("Reset the environment", play_sounds)
|
||||
reset_environment(robot, events, reset_time_s)
|
||||
continue
|
||||
|
||||
# Increment by one dataset["current_episode_index"]
|
||||
save_current_episode(dataset)
|
||||
|
||||
log_say("Stop recording", play_sounds, blocking=True)
|
||||
stop_recording(robot, listener, display_cameras)
|
||||
|
||||
lerobot_dataset = create_lerobot_dataset(dataset, run_compute_stats, push_to_hub, tags, play_sounds)
|
||||
|
||||
|
|
|
@ -25,9 +25,11 @@ pytest -sx 'tests/test_control_robot.py::test_teleoperate[aloha-True]'
|
|||
|
||||
import multiprocessing
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from lerobot.common.datasets.populate_dataset import add_frame
|
||||
from lerobot.common.logger import Logger
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.utils.utils import init_hydra_config
|
||||
|
@ -222,3 +224,148 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
|
|||
)
|
||||
|
||||
del robot
|
||||
|
||||
|
||||
@pytest.mark.parametrize("robot_type, mock", [("koch", True)])
|
||||
@require_robot
|
||||
def test_record_with_event_rerecord_episode(tmpdir, request, robot_type, mock):
|
||||
if mock and robot_type != "aloha":
|
||||
request.getfixturevalue("patch_builtins_input")
|
||||
|
||||
# Create an empty calibration directory to trigger manual calibration
|
||||
# and avoid writing calibration files in user .cache/calibration folder
|
||||
calibration_dir = tmpdir / robot_type
|
||||
overrides = [f"calibration_dir={calibration_dir}"]
|
||||
else:
|
||||
# Use the default .cache/calibration folder when mock=False or for aloha
|
||||
overrides = []
|
||||
|
||||
robot = make_robot(robot_type, overrides=overrides, mock=mock)
|
||||
with (
|
||||
patch("lerobot.scripts.control_robot.init_keyboard_listener") as mock_listener,
|
||||
patch("lerobot.common.robot_devices.control_robot.add_frame", wraps=add_frame) as mock_add_frame,
|
||||
):
|
||||
mock_events = {}
|
||||
mock_events["exit_early"] = True
|
||||
mock_events["rerecord_episode"] = True
|
||||
mock_events["stop_recording"] = False
|
||||
mock_listener.return_value = (None, mock_events)
|
||||
|
||||
root = Path(tmpdir) / "data"
|
||||
repo_id = "lerobot/debug"
|
||||
|
||||
dataset = record(
|
||||
robot,
|
||||
fps=1,
|
||||
root=root,
|
||||
repo_id=repo_id,
|
||||
warmup_time_s=0,
|
||||
episode_time_s=1,
|
||||
num_episodes=1,
|
||||
push_to_hub=False,
|
||||
video=False,
|
||||
display_cameras=False,
|
||||
play_sounds=False,
|
||||
)
|
||||
|
||||
assert not mock_events["rerecord_episode"], "`rerecord_episode` wasn't properly reset to False"
|
||||
assert not mock_events["exit_early"], "`exit_early` wasn't properly reset to False"
|
||||
assert mock_add_frame.call_count == 2, "`add_frame` should have been called 2 times"
|
||||
assert len(dataset) == 1, "`dataset` should contain only 1 frame"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("robot_type, mock", [("koch", True)])
|
||||
@require_robot
|
||||
def test_record_with_event_exit_early(tmpdir, request, robot_type, mock):
|
||||
if mock:
|
||||
request.getfixturevalue("patch_builtins_input")
|
||||
|
||||
# Create an empty calibration directory to trigger manual calibration
|
||||
# and avoid writing calibration files in user .cache/calibration folder
|
||||
calibration_dir = tmpdir / robot_type
|
||||
overrides = [f"calibration_dir={calibration_dir}"]
|
||||
else:
|
||||
# Use the default .cache/calibration folder when mock=False or for aloha
|
||||
overrides = []
|
||||
|
||||
robot = make_robot(robot_type, overrides=overrides, mock=mock)
|
||||
with (
|
||||
patch("lerobot.scripts.control_robot.init_keyboard_listener") as mock_listener,
|
||||
patch("lerobot.common.robot_devices.control_robot.add_frame", wraps=add_frame) as mock_add_frame,
|
||||
):
|
||||
mock_events = {}
|
||||
mock_events["exit_early"] = True
|
||||
mock_events["rerecord_episode"] = False
|
||||
mock_events["stop_recording"] = False
|
||||
mock_listener.return_value = (None, mock_events)
|
||||
|
||||
root = Path(tmpdir) / "data"
|
||||
repo_id = "lerobot/debug"
|
||||
|
||||
dataset = record(
|
||||
robot,
|
||||
fps=2,
|
||||
root=root,
|
||||
repo_id=repo_id,
|
||||
warmup_time_s=0,
|
||||
episode_time_s=1,
|
||||
num_episodes=1,
|
||||
push_to_hub=False,
|
||||
video=False,
|
||||
display_cameras=False,
|
||||
play_sounds=False,
|
||||
)
|
||||
|
||||
assert not mock_events["exit_early"], "`exit_early` wasn't properly reset to False"
|
||||
assert mock_add_frame.call_count == 1, "`add_frame` should have been called 1 time"
|
||||
assert len(dataset) == 1, "`dataset` should contain only 1 frame"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"robot_type, mock, num_image_writer_processes", [("koch", True, 0), ("koch", True, 1)]
|
||||
)
|
||||
@require_robot
|
||||
def test_record_with_event_stop_recording(tmpdir, request, robot_type, mock, num_image_writer_processes):
|
||||
if mock:
|
||||
request.getfixturevalue("patch_builtins_input")
|
||||
|
||||
# Create an empty calibration directory to trigger manual calibration
|
||||
# and avoid writing calibration files in user .cache/calibration folder
|
||||
calibration_dir = tmpdir / robot_type
|
||||
overrides = [f"calibration_dir={calibration_dir}"]
|
||||
else:
|
||||
# Use the default .cache/calibration folder when mock=False or for aloha
|
||||
overrides = []
|
||||
|
||||
robot = make_robot(robot_type, overrides=overrides, mock=mock)
|
||||
with (
|
||||
patch("lerobot.scripts.control_robot.init_keyboard_listener") as mock_listener,
|
||||
patch("lerobot.common.robot_devices.control_robot.add_frame", wraps=add_frame) as mock_add_frame,
|
||||
):
|
||||
mock_events = {}
|
||||
mock_events["exit_early"] = True
|
||||
mock_events["rerecord_episode"] = False
|
||||
mock_events["stop_recording"] = True
|
||||
mock_listener.return_value = (None, mock_events)
|
||||
|
||||
root = Path(tmpdir) / "data"
|
||||
repo_id = "lerobot/debug"
|
||||
|
||||
dataset = record(
|
||||
robot,
|
||||
fps=1,
|
||||
root=root,
|
||||
repo_id=repo_id,
|
||||
warmup_time_s=0,
|
||||
episode_time_s=1,
|
||||
num_episodes=2,
|
||||
push_to_hub=False,
|
||||
video=False,
|
||||
display_cameras=False,
|
||||
play_sounds=False,
|
||||
num_image_writer_processes=num_image_writer_processes,
|
||||
)
|
||||
|
||||
assert not mock_events["exit_early"], "`exit_early` wasn't properly reset to False"
|
||||
assert mock_add_frame.call_count == 1, "`add_frame` should have been called 2 times"
|
||||
assert len(dataset) == 1, "`dataset` should contain only 1 frame"
|
||||
|
|
Loading…
Reference in New Issue