overall improve, fix some issues with events, add some tests for events

This commit is contained in:
Remi Cadene 2024-10-13 15:44:14 +02:00
parent 9f5c586c1a
commit 63bd5013bf
4 changed files with 228 additions and 66 deletions

View File

@ -202,9 +202,9 @@ def init_dataset(
if rec_info_path.exists(): if rec_info_path.exists():
with open(rec_info_path) as f: with open(rec_info_path) as f:
rec_info = json.load(f) rec_info = json.load(f)
episode_index = rec_info["last_episode_index"] + 1 num_episodes = rec_info["last_episode_index"] + 1
else: else:
episode_index = 0 num_episodes = 0
dataset = { dataset = {
"repo_id": repo_id, "repo_id": repo_id,
@ -214,7 +214,7 @@ def init_dataset(
"fps": fps, "fps": fps,
"video": video, "video": video,
"rec_info_path": rec_info_path, "rec_info_path": rec_info_path,
"current_episode_index": episode_index, "num_episodes": num_episodes,
} }
if write_images: if write_images:
@ -249,7 +249,7 @@ def add_frame(dataset, observation, action):
dataset["current_frame_index"] = 0 dataset["current_frame_index"] = 0
ep_dict = dataset["current_episode"] ep_dict = dataset["current_episode"]
episode_index = dataset["current_episode_index"] episode_index = dataset["num_episodes"]
frame_index = dataset["current_frame_index"] frame_index = dataset["current_frame_index"]
videos_dir = dataset["videos_dir"] videos_dir = dataset["videos_dir"]
video = dataset["video"] video = dataset["video"]
@ -300,15 +300,19 @@ def add_frame(dataset, observation, action):
dataset["current_frame_index"] += 1 dataset["current_frame_index"] += 1
def delete_episode(dataset): def delete_current_episode(dataset):
del dataset["current_episode"] del dataset["current_episode"]
del dataset["current_frame_index"] del dataset["current_frame_index"]
# TODO(rcadene): remove images and videos, etc. # delete temporary images
episode_index = dataset["num_episodes"]
videos_dir = dataset["videos_dir"]
for tmp_imgs_dir in videos_dir.glob(f"*_episode_{episode_index:06d}"):
shutil.rmtree(tmp_imgs_dir)
def save_episode(dataset): def save_current_episode(dataset):
episode_index = dataset["current_episode_index"] episode_index = dataset["num_episodes"]
ep_dict = dataset["current_episode"] ep_dict = dataset["current_episode"]
episodes_dir = dataset["episodes_dir"] episodes_dir = dataset["episodes_dir"]
rec_info_path = dataset["rec_info_path"] rec_info_path = dataset["rec_info_path"]
@ -337,14 +341,13 @@ def save_episode(dataset):
# force re-initialization of episode dictionnary during add_frame # force re-initialization of episode dictionnary during add_frame
del dataset["current_episode"] del dataset["current_episode"]
dataset["current_episode_index"] += 1 dataset["num_episodes"] += 1
def encode_videos(dataset, play_sounds): def encode_videos(dataset, image_keys, play_sounds):
log_say("Encoding videos", play_sounds) log_say("Encoding videos", play_sounds)
num_episodes = dataset["current_episode_index"] num_episodes = dataset["num_episodes"]
image_keys = dataset["image_keys"]
videos_dir = dataset["videos_dir"] videos_dir = dataset["videos_dir"]
local_dir = dataset["local_dir"] local_dir = dataset["local_dir"]
fps = dataset["fps"] fps = dataset["fps"]
@ -368,7 +371,7 @@ def encode_videos(dataset, play_sounds):
def from_dataset_to_lerobot_dataset(dataset, play_sounds): def from_dataset_to_lerobot_dataset(dataset, play_sounds):
log_say("Consolidate episodes", play_sounds) log_say("Consolidate episodes", play_sounds)
num_episodes = dataset["current_episode_index"] num_episodes = dataset["num_episodes"]
episodes_dir = dataset["episodes_dir"] episodes_dir = dataset["episodes_dir"]
videos_dir = dataset["videos_dir"] videos_dir = dataset["videos_dir"]
video = dataset["video"] video = dataset["video"]
@ -382,6 +385,10 @@ def from_dataset_to_lerobot_dataset(dataset, play_sounds):
ep_dicts.append(ep_dict) ep_dicts.append(ep_dict)
data_dict = concatenate_episodes(ep_dicts) data_dict = concatenate_episodes(ep_dicts)
if video:
image_keys = [key for key in data_dict if "image" in key]
encode_videos(dataset, image_keys, play_sounds)
total_frames = data_dict["frame_index"].shape[0] total_frames = data_dict["frame_index"].shape[0]
data_dict["index"] = torch.arange(0, total_frames, 1) data_dict["index"] = torch.arange(0, total_frames, 1)
@ -443,16 +450,11 @@ def push_lerobot_dataset_to_hub(lerobot_dataset, tags):
def create_lerobot_dataset(dataset, run_compute_stats, push_to_hub, tags, play_sounds): def create_lerobot_dataset(dataset, run_compute_stats, push_to_hub, tags, play_sounds):
video = dataset["video"]
if "image_writer" in dataset: if "image_writer" in dataset:
logging.info("Waiting for image writer to terminate...") logging.info("Waiting for image writer to terminate...")
image_writer = dataset["image_writer"] image_writer = dataset["image_writer"]
stop_image_writer(image_writer, timeout=20) stop_image_writer(image_writer, timeout=20)
if video:
encode_videos(dataset, play_sounds)
lerobot_dataset = from_dataset_to_lerobot_dataset(dataset, play_sounds) lerobot_dataset = from_dataset_to_lerobot_dataset(dataset, play_sounds)
if run_compute_stats: if run_compute_stats:

View File

@ -116,9 +116,6 @@ def predict_action(observation, policy, device, use_amp):
def init_keyboard_listener(): def init_keyboard_listener():
# Only import pynput if not in a headless environment
from pynput import keyboard
# Allow to exit early while recording an episode or resetting the environment, # Allow to exit early while recording an episode or resetting the environment,
# by tapping the right arrow key '->'. This might require a sudo permission # by tapping the right arrow key '->'. This might require a sudo permission
# to allow your terminal to monitor keyboard events. # to allow your terminal to monitor keyboard events.
@ -127,6 +124,16 @@ def init_keyboard_listener():
events["rerecord_episode"] = False events["rerecord_episode"] = False
events["stop_recording"] = False events["stop_recording"] = False
if is_headless():
logging.warning(
"Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
)
listener = None
return listener, events
# Only import pynput if not in a headless environment
from pynput import keyboard
def on_press(key): def on_press(key):
try: try:
if key == keyboard.Key.right: if key == keyboard.Key.right:
@ -175,7 +182,8 @@ def init_policy(pretrained_policy_name_or_path, policy_overrides, fps):
return policy, fps, device, use_amp return policy, fps, device, use_amp
def warmup_record(robot, enable_teloperation, warmup_time_s, display_cameras, play_sounds, fps): def warmup_record(robot, events, enable_teloperation, warmup_time_s, display_cameras, play_sounds, fps):
# TODO(rcadene): refactor warmup_record and reset_environment
timestamp = 0 timestamp = 0
start_warmup_t = time.perf_counter() start_warmup_t = time.perf_counter()
@ -203,24 +211,23 @@ def warmup_record(robot, enable_teloperation, warmup_time_s, display_cameras, pl
log_control_info(robot, dt_s, fps=fps) log_control_info(robot, dt_s, fps=fps)
timestamp = time.perf_counter() - start_warmup_t timestamp = time.perf_counter() - start_warmup_t
if events is not None and events["exit_early"]:
events["exit_early"] = False
break
@safe_stop_image_writer @safe_stop_image_writer
def record_episode( def record_episode(
dataset, dataset,
robot, robot,
episode_index,
events, events,
episode_time_s, episode_time_s,
display_cameras, display_cameras,
play_sounds,
policy, policy,
device, device,
use_amp, use_amp,
fps, fps,
): ):
log_say(f"Recording episode {episode_index}", play_sounds)
timestamp = 0 timestamp = 0
start_episode_t = time.perf_counter() start_episode_t = time.perf_counter()
while timestamp < episode_time_s: while timestamp < episode_time_s:
@ -258,9 +265,8 @@ def record_episode(
break break
def reset_environment(robot, events, reset_time_s, play_sounds): def reset_environment(robot, events, reset_time_s):
log_say("Reset the environment", play_sounds) # TODO(rcadene): refactor warmup_record and reset_environment
# TODO(alibets): allow for teleop during reset # TODO(alibets): allow for teleop during reset
if has_method(robot, "teleop_safety_stop"): if has_method(robot, "teleop_safety_stop"):
robot.teleop_safety_stop() robot.teleop_safety_stop()
@ -279,9 +285,7 @@ def reset_environment(robot, events, reset_time_s, play_sounds):
break break
def done_recording(robot, listener, display_cameras, play_sounds): def stop_recording(robot, listener, display_cameras):
log_say("Done recording", play_sounds, blocking=True)
robot.disconnect() robot.disconnect()
if not is_headless(): if not is_headless():

View File

@ -99,7 +99,6 @@ python lerobot/scripts/control_robot.py record \
""" """
import argparse import argparse
import logging
import time import time
from pathlib import Path from pathlib import Path
from typing import List from typing import List
@ -108,21 +107,20 @@ from typing import List
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.populate_dataset import ( from lerobot.common.datasets.populate_dataset import (
create_lerobot_dataset, create_lerobot_dataset,
delete_episode, delete_current_episode,
init_dataset, init_dataset,
save_episode, save_current_episode,
) )
from lerobot.common.robot_devices.control_robot import ( from lerobot.common.robot_devices.control_robot import (
done_recording,
has_method, has_method,
init_keyboard_listener, init_keyboard_listener,
init_policy, init_policy,
is_headless,
log_control_info, log_control_info,
log_say, log_say,
record_episode, record_episode,
reset_environment, reset_environment,
sanity_check_dataset_name, sanity_check_dataset_name,
stop_recording,
warmup_record, warmup_record,
) )
from lerobot.common.robot_devices.robots.factory import make_robot from lerobot.common.robot_devices.robots.factory import make_robot
@ -229,9 +227,6 @@ def record(
device = None device = None
use_amp = None use_amp = None
if not robot.is_connected:
robot.connect()
# Load pretrained policy # Load pretrained policy
if pretrained_policy_name_or_path is not None: if pretrained_policy_name_or_path is not None:
policy, fps, device, use_amp = init_policy(pretrained_policy_name_or_path, policy_overrides, fps) policy, fps, device, use_amp = init_policy(pretrained_policy_name_or_path, policy_overrides, fps)
@ -249,58 +244,72 @@ def record(
num_image_writer_threads=num_image_writer_threads_per_camera * robot.num_cameras, num_image_writer_threads=num_image_writer_threads_per_camera * robot.num_cameras,
) )
if is_headless(): if not robot.is_connected:
logging.warning( robot.connect()
"Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
)
else:
listener, events = init_keyboard_listener() listener, events = init_keyboard_listener()
# Execute a few seconds without recording data to: # Execute a few seconds without recording to:
# 1. teleoperate the robot to move it in starting position if no policy provided, # 1. teleoperate the robot to move it in starting position if no policy provided,
# 2. give times to the robot devices to connect and start synchronizing, # 2. give times to the robot devices to connect and start synchronizing,
# 3. place the cameras windows on screen # 3. place the cameras windows on screen
enable_teleoperation = policy is None enable_teleoperation = policy is None
warmup_record(robot, enable_teleoperation, warmup_time_s, display_cameras, play_sounds, fps) warmup_record(robot, events, enable_teleoperation, warmup_time_s, display_cameras, play_sounds, fps)
if has_method(robot, "teleop_safety_stop"): if has_method(robot, "teleop_safety_stop"):
robot.teleop_safety_stop() robot.teleop_safety_stop()
while True: while True:
episode_index = dataset["current_episode_index"] if dataset["num_episodes"] >= num_episodes:
if episode_index >= num_episodes:
break break
episode_index = dataset["num_episodes"]
log_say(f"Recording episode {episode_index}", play_sounds)
record_episode( record_episode(
dataset=dataset, dataset=dataset,
robot=robot, robot=robot,
episode_index=episode_index,
events=events, events=events,
episode_time_s=episode_time_s, episode_time_s=episode_time_s,
display_cameras=display_cameras, display_cameras=display_cameras,
play_sounds=play_sounds,
policy=policy, policy=policy,
device=device, device=device,
use_amp=use_amp, use_amp=use_amp,
fps=fps, fps=fps,
) )
# Do not reset for the last episode to be recorded # In case stop recording is requested during `record_episode`
if episode_index < num_episodes - 1: if events is not None and events["stop_recording"]:
reset_environment(robot, events, reset_time_s, play_sounds) save_current_episode(dataset)
if events is not None and events["rerecord_episode"]:
events["rerecord_episode"] = False
delete_episode(dataset)
continue
save_episode(dataset)
if events is not None and events["exit_early"]:
num_episodes = episode_index
break break
done_recording(robot, listener, display_cameras, play_sounds) # Execute a few seconds without recording to give time to manually reset the environment
# Current code logic doesn't allow to teleoperate during this time.
# TODO(rcadene): add an option to enable teleoperation during reset
# Skip reset for the last episode to be recorded
if episode_index < num_episodes - 1:
log_say("Reset the environment", play_sounds)
reset_environment(robot, events, reset_time_s)
# In case stop recording is requested during `reset_environment`
if events is not None and events["stop_recording"]:
save_current_episode(dataset)
break
if events is not None and events["rerecord_episode"]:
log_say("Re-record episode", play_sounds)
events["rerecord_episode"] = False
events["exit_early"] = False
delete_current_episode(dataset)
# Force reset
log_say("Reset the environment", play_sounds)
reset_environment(robot, events, reset_time_s)
continue
# Increment by one dataset["current_episode_index"]
save_current_episode(dataset)
log_say("Stop recording", play_sounds, blocking=True)
stop_recording(robot, listener, display_cameras)
lerobot_dataset = create_lerobot_dataset(dataset, run_compute_stats, push_to_hub, tags, play_sounds) lerobot_dataset = create_lerobot_dataset(dataset, run_compute_stats, push_to_hub, tags, play_sounds)

View File

@ -25,9 +25,11 @@ pytest -sx 'tests/test_control_robot.py::test_teleoperate[aloha-True]'
import multiprocessing import multiprocessing
from pathlib import Path from pathlib import Path
from unittest.mock import patch
import pytest import pytest
from lerobot.common.datasets.populate_dataset import add_frame
from lerobot.common.logger import Logger from lerobot.common.logger import Logger
from lerobot.common.policies.factory import make_policy from lerobot.common.policies.factory import make_policy
from lerobot.common.utils.utils import init_hydra_config from lerobot.common.utils.utils import init_hydra_config
@ -222,3 +224,148 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
) )
del robot del robot
@pytest.mark.parametrize("robot_type, mock", [("koch", True)])
@require_robot
def test_record_with_event_rerecord_episode(tmpdir, request, robot_type, mock):
if mock and robot_type != "aloha":
request.getfixturevalue("patch_builtins_input")
# Create an empty calibration directory to trigger manual calibration
# and avoid writing calibration files in user .cache/calibration folder
calibration_dir = tmpdir / robot_type
overrides = [f"calibration_dir={calibration_dir}"]
else:
# Use the default .cache/calibration folder when mock=False or for aloha
overrides = []
robot = make_robot(robot_type, overrides=overrides, mock=mock)
with (
patch("lerobot.scripts.control_robot.init_keyboard_listener") as mock_listener,
patch("lerobot.common.robot_devices.control_robot.add_frame", wraps=add_frame) as mock_add_frame,
):
mock_events = {}
mock_events["exit_early"] = True
mock_events["rerecord_episode"] = True
mock_events["stop_recording"] = False
mock_listener.return_value = (None, mock_events)
root = Path(tmpdir) / "data"
repo_id = "lerobot/debug"
dataset = record(
robot,
fps=1,
root=root,
repo_id=repo_id,
warmup_time_s=0,
episode_time_s=1,
num_episodes=1,
push_to_hub=False,
video=False,
display_cameras=False,
play_sounds=False,
)
assert not mock_events["rerecord_episode"], "`rerecord_episode` wasn't properly reset to False"
assert not mock_events["exit_early"], "`exit_early` wasn't properly reset to False"
assert mock_add_frame.call_count == 2, "`add_frame` should have been called 2 times"
assert len(dataset) == 1, "`dataset` should contain only 1 frame"
@pytest.mark.parametrize("robot_type, mock", [("koch", True)])
@require_robot
def test_record_with_event_exit_early(tmpdir, request, robot_type, mock):
if mock:
request.getfixturevalue("patch_builtins_input")
# Create an empty calibration directory to trigger manual calibration
# and avoid writing calibration files in user .cache/calibration folder
calibration_dir = tmpdir / robot_type
overrides = [f"calibration_dir={calibration_dir}"]
else:
# Use the default .cache/calibration folder when mock=False or for aloha
overrides = []
robot = make_robot(robot_type, overrides=overrides, mock=mock)
with (
patch("lerobot.scripts.control_robot.init_keyboard_listener") as mock_listener,
patch("lerobot.common.robot_devices.control_robot.add_frame", wraps=add_frame) as mock_add_frame,
):
mock_events = {}
mock_events["exit_early"] = True
mock_events["rerecord_episode"] = False
mock_events["stop_recording"] = False
mock_listener.return_value = (None, mock_events)
root = Path(tmpdir) / "data"
repo_id = "lerobot/debug"
dataset = record(
robot,
fps=2,
root=root,
repo_id=repo_id,
warmup_time_s=0,
episode_time_s=1,
num_episodes=1,
push_to_hub=False,
video=False,
display_cameras=False,
play_sounds=False,
)
assert not mock_events["exit_early"], "`exit_early` wasn't properly reset to False"
assert mock_add_frame.call_count == 1, "`add_frame` should have been called 1 time"
assert len(dataset) == 1, "`dataset` should contain only 1 frame"
@pytest.mark.parametrize(
"robot_type, mock, num_image_writer_processes", [("koch", True, 0), ("koch", True, 1)]
)
@require_robot
def test_record_with_event_stop_recording(tmpdir, request, robot_type, mock, num_image_writer_processes):
if mock:
request.getfixturevalue("patch_builtins_input")
# Create an empty calibration directory to trigger manual calibration
# and avoid writing calibration files in user .cache/calibration folder
calibration_dir = tmpdir / robot_type
overrides = [f"calibration_dir={calibration_dir}"]
else:
# Use the default .cache/calibration folder when mock=False or for aloha
overrides = []
robot = make_robot(robot_type, overrides=overrides, mock=mock)
with (
patch("lerobot.scripts.control_robot.init_keyboard_listener") as mock_listener,
patch("lerobot.common.robot_devices.control_robot.add_frame", wraps=add_frame) as mock_add_frame,
):
mock_events = {}
mock_events["exit_early"] = True
mock_events["rerecord_episode"] = False
mock_events["stop_recording"] = True
mock_listener.return_value = (None, mock_events)
root = Path(tmpdir) / "data"
repo_id = "lerobot/debug"
dataset = record(
robot,
fps=1,
root=root,
repo_id=repo_id,
warmup_time_s=0,
episode_time_s=1,
num_episodes=2,
push_to_hub=False,
video=False,
display_cameras=False,
play_sounds=False,
num_image_writer_processes=num_image_writer_processes,
)
assert not mock_events["exit_early"], "`exit_early` wasn't properly reset to False"
assert mock_add_frame.call_count == 1, "`add_frame` should have been called 2 times"
assert len(dataset) == 1, "`dataset` should contain only 1 frame"