From 63bd5013bf285c13ec5a1b336b48b539ed5a9bea Mon Sep 17 00:00:00 2001 From: Remi Cadene Date: Sun, 13 Oct 2024 15:44:14 +0200 Subject: [PATCH] overall improve, fix some issues with events, add some tests for events --- lerobot/common/datasets/populate_dataset.py | 38 ++--- lerobot/common/robot_devices/control_robot.py | 32 ++-- lerobot/scripts/control_robot.py | 77 +++++---- tests/test_control_robot.py | 147 ++++++++++++++++++ 4 files changed, 228 insertions(+), 66 deletions(-) diff --git a/lerobot/common/datasets/populate_dataset.py b/lerobot/common/datasets/populate_dataset.py index 68a1ad6e..0afc5fd0 100644 --- a/lerobot/common/datasets/populate_dataset.py +++ b/lerobot/common/datasets/populate_dataset.py @@ -202,9 +202,9 @@ def init_dataset( if rec_info_path.exists(): with open(rec_info_path) as f: rec_info = json.load(f) - episode_index = rec_info["last_episode_index"] + 1 + num_episodes = rec_info["last_episode_index"] + 1 else: - episode_index = 0 + num_episodes = 0 dataset = { "repo_id": repo_id, @@ -214,7 +214,7 @@ def init_dataset( "fps": fps, "video": video, "rec_info_path": rec_info_path, - "current_episode_index": episode_index, + "num_episodes": num_episodes, } if write_images: @@ -249,7 +249,7 @@ def add_frame(dataset, observation, action): dataset["current_frame_index"] = 0 ep_dict = dataset["current_episode"] - episode_index = dataset["current_episode_index"] + episode_index = dataset["num_episodes"] frame_index = dataset["current_frame_index"] videos_dir = dataset["videos_dir"] video = dataset["video"] @@ -300,15 +300,19 @@ def add_frame(dataset, observation, action): dataset["current_frame_index"] += 1 -def delete_episode(dataset): +def delete_current_episode(dataset): del dataset["current_episode"] del dataset["current_frame_index"] - # TODO(rcadene): remove images and videos, etc. 
+ # delete temporary images + episode_index = dataset["num_episodes"] + videos_dir = dataset["videos_dir"] + for tmp_imgs_dir in videos_dir.glob(f"*_episode_{episode_index:06d}"): + shutil.rmtree(tmp_imgs_dir) -def save_episode(dataset): - episode_index = dataset["current_episode_index"] +def save_current_episode(dataset): + episode_index = dataset["num_episodes"] ep_dict = dataset["current_episode"] episodes_dir = dataset["episodes_dir"] rec_info_path = dataset["rec_info_path"] @@ -337,14 +341,13 @@ def save_episode(dataset): # force re-initialization of episode dictionnary during add_frame del dataset["current_episode"] - dataset["current_episode_index"] += 1 + dataset["num_episodes"] += 1 -def encode_videos(dataset, play_sounds): +def encode_videos(dataset, image_keys, play_sounds): log_say("Encoding videos", play_sounds) - num_episodes = dataset["current_episode_index"] - image_keys = dataset["image_keys"] + num_episodes = dataset["num_episodes"] videos_dir = dataset["videos_dir"] local_dir = dataset["local_dir"] fps = dataset["fps"] @@ -368,7 +371,7 @@ def encode_videos(dataset, play_sounds): def from_dataset_to_lerobot_dataset(dataset, play_sounds): log_say("Consolidate episodes", play_sounds) - num_episodes = dataset["current_episode_index"] + num_episodes = dataset["num_episodes"] episodes_dir = dataset["episodes_dir"] videos_dir = dataset["videos_dir"] video = dataset["video"] @@ -382,6 +385,10 @@ def from_dataset_to_lerobot_dataset(dataset, play_sounds): ep_dicts.append(ep_dict) data_dict = concatenate_episodes(ep_dicts) + if video: + image_keys = [key for key in data_dict if "image" in key] + encode_videos(dataset, image_keys, play_sounds) + total_frames = data_dict["frame_index"].shape[0] data_dict["index"] = torch.arange(0, total_frames, 1) @@ -443,16 +450,11 @@ def push_lerobot_dataset_to_hub(lerobot_dataset, tags): def create_lerobot_dataset(dataset, run_compute_stats, push_to_hub, tags, play_sounds): - video = dataset["video"] - if "image_writer" in 
dataset: logging.info("Waiting for image writer to terminate...") image_writer = dataset["image_writer"] stop_image_writer(image_writer, timeout=20) - if video: - encode_videos(dataset, play_sounds) - lerobot_dataset = from_dataset_to_lerobot_dataset(dataset, play_sounds) if run_compute_stats: diff --git a/lerobot/common/robot_devices/control_robot.py b/lerobot/common/robot_devices/control_robot.py index 1bb6b293..f23a90b2 100644 --- a/lerobot/common/robot_devices/control_robot.py +++ b/lerobot/common/robot_devices/control_robot.py @@ -116,9 +116,6 @@ def predict_action(observation, policy, device, use_amp): def init_keyboard_listener(): - # Only import pynput if not in a headless environment - from pynput import keyboard - # Allow to exit early while recording an episode or resetting the environment, # by tapping the right arrow key '->'. This might require a sudo permission # to allow your terminal to monitor keyboard events. @@ -127,6 +124,16 @@ def init_keyboard_listener(): events["rerecord_episode"] = False events["stop_recording"] = False + if is_headless(): + logging.warning( + "Headless environment detected. On-screen cameras display and keyboard inputs will not be available." 
+ ) + listener = None + return listener, events + + # Only import pynput if not in a headless environment + from pynput import keyboard + def on_press(key): try: if key == keyboard.Key.right: @@ -175,7 +182,8 @@ def init_policy(pretrained_policy_name_or_path, policy_overrides, fps): return policy, fps, device, use_amp -def warmup_record(robot, enable_teloperation, warmup_time_s, display_cameras, play_sounds, fps): +def warmup_record(robot, events, enable_teloperation, warmup_time_s, display_cameras, play_sounds, fps): + # TODO(rcadene): refactor warmup_record and reset_environment timestamp = 0 start_warmup_t = time.perf_counter() @@ -203,24 +211,23 @@ def warmup_record(robot, enable_teloperation, warmup_time_s, display_cameras, pl log_control_info(robot, dt_s, fps=fps) timestamp = time.perf_counter() - start_warmup_t + if events is not None and events["exit_early"]: + events["exit_early"] = False + break @safe_stop_image_writer def record_episode( dataset, robot, - episode_index, events, episode_time_s, display_cameras, - play_sounds, policy, device, use_amp, fps, ): - log_say(f"Recording episode {episode_index}", play_sounds) - timestamp = 0 start_episode_t = time.perf_counter() while timestamp < episode_time_s: @@ -258,9 +265,8 @@ def record_episode( break -def reset_environment(robot, events, reset_time_s, play_sounds): - log_say("Reset the environment", play_sounds) - +def reset_environment(robot, events, reset_time_s): + # TODO(rcadene): refactor warmup_record and reset_environment # TODO(alibets): allow for teleop during reset if has_method(robot, "teleop_safety_stop"): robot.teleop_safety_stop() @@ -279,9 +285,7 @@ def reset_environment(robot, events, reset_time_s, play_sounds): break -def done_recording(robot, listener, display_cameras, play_sounds): - log_say("Done recording", play_sounds, blocking=True) - +def stop_recording(robot, listener, display_cameras): robot.disconnect() if not is_headless(): diff --git a/lerobot/scripts/control_robot.py 
b/lerobot/scripts/control_robot.py index 1f258ccb..2f9f8a73 100644 --- a/lerobot/scripts/control_robot.py +++ b/lerobot/scripts/control_robot.py @@ -99,7 +99,6 @@ python lerobot/scripts/control_robot.py record \ """ import argparse -import logging import time from pathlib import Path from typing import List @@ -108,21 +107,20 @@ from typing import List from lerobot.common.datasets.lerobot_dataset import LeRobotDataset from lerobot.common.datasets.populate_dataset import ( create_lerobot_dataset, - delete_episode, + delete_current_episode, init_dataset, - save_episode, + save_current_episode, ) from lerobot.common.robot_devices.control_robot import ( - done_recording, has_method, init_keyboard_listener, init_policy, - is_headless, log_control_info, log_say, record_episode, reset_environment, sanity_check_dataset_name, + stop_recording, warmup_record, ) from lerobot.common.robot_devices.robots.factory import make_robot @@ -229,9 +227,6 @@ def record( device = None use_amp = None - if not robot.is_connected: - robot.connect() - # Load pretrained policy if pretrained_policy_name_or_path is not None: policy, fps, device, use_amp = init_policy(pretrained_policy_name_or_path, policy_overrides, fps) @@ -249,58 +244,72 @@ def record( num_image_writer_threads=num_image_writer_threads_per_camera * robot.num_cameras, ) - if is_headless(): - logging.warning( - "Headless environment detected. On-screen cameras display and keyboard inputs will not be available." - ) - else: - listener, events = init_keyboard_listener() + if not robot.is_connected: + robot.connect() - # Execute a few seconds without recording data to: + listener, events = init_keyboard_listener() + + # Execute a few seconds without recording to: # 1. teleoperate the robot to move it in starting position if no policy provided, # 2. give times to the robot devices to connect and start synchronizing, # 3. 
place the cameras windows on screen enable_teleoperation = policy is None - warmup_record(robot, enable_teleoperation, warmup_time_s, display_cameras, play_sounds, fps) + warmup_record(robot, events, enable_teleoperation, warmup_time_s, display_cameras, play_sounds, fps) if has_method(robot, "teleop_safety_stop"): robot.teleop_safety_stop() while True: - episode_index = dataset["current_episode_index"] - if episode_index >= num_episodes: + if dataset["num_episodes"] >= num_episodes: break + episode_index = dataset["num_episodes"] + log_say(f"Recording episode {episode_index}", play_sounds) record_episode( dataset=dataset, robot=robot, - episode_index=episode_index, events=events, episode_time_s=episode_time_s, display_cameras=display_cameras, - play_sounds=play_sounds, policy=policy, device=device, use_amp=use_amp, fps=fps, ) - # Do not reset for the last episode to be recorded - if episode_index < num_episodes - 1: - reset_environment(robot, events, reset_time_s, play_sounds) - - if events is not None and events["rerecord_episode"]: - events["rerecord_episode"] = False - delete_episode(dataset) - continue - - save_episode(dataset) - - if events is not None and events["exit_early"]: - num_episodes = episode_index + # In case stop recording is requested during `record_episode` + if events is not None and events["stop_recording"]: + save_current_episode(dataset) break - done_recording(robot, listener, display_cameras, play_sounds) + # Execute a few seconds without recording to give time to manually reset the environment + # Current code logic doesn't allow to teleoperate during this time. 
+ # TODO(rcadene): add an option to enable teleoperation during reset + # Skip reset for the last episode to be recorded + if episode_index < num_episodes - 1: + log_say("Reset the environment", play_sounds) + reset_environment(robot, events, reset_time_s) + + # In case stop recording is requested during `reset_environment` + if events is not None and events["stop_recording"]: + save_current_episode(dataset) + break + + if events is not None and events["rerecord_episode"]: + log_say("Re-record episode", play_sounds) + events["rerecord_episode"] = False + events["exit_early"] = False + delete_current_episode(dataset) + # Force reset + log_say("Reset the environment", play_sounds) + reset_environment(robot, events, reset_time_s) + continue + + # Increment by one dataset["num_episodes"] + save_current_episode(dataset) + + log_say("Stop recording", play_sounds, blocking=True) + stop_recording(robot, listener, display_cameras) lerobot_dataset = create_lerobot_dataset(dataset, run_compute_stats, push_to_hub, tags, play_sounds) diff --git a/tests/test_control_robot.py b/tests/test_control_robot.py index 2cd961c3..00683cb9 100644 --- a/tests/test_control_robot.py +++ b/tests/test_control_robot.py @@ -25,9 +25,11 @@ pytest -sx 'tests/test_control_robot.py::test_teleoperate[aloha-True]' import multiprocessing from pathlib import Path +from unittest.mock import patch import pytest +from lerobot.common.datasets.populate_dataset import add_frame from lerobot.common.logger import Logger from lerobot.common.policies.factory import make_policy from lerobot.common.utils.utils import init_hydra_config @@ -222,3 +224,148 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock): ) del robot + + +@pytest.mark.parametrize("robot_type, mock", [("koch", True)]) +@require_robot +def test_record_with_event_rerecord_episode(tmpdir, request, robot_type, mock): + if mock and robot_type != "aloha": + request.getfixturevalue("patch_builtins_input") + + # Create an
empty calibration directory to trigger manual calibration + # and avoid writing calibration files in user .cache/calibration folder + calibration_dir = tmpdir / robot_type + overrides = [f"calibration_dir={calibration_dir}"] + else: + # Use the default .cache/calibration folder when mock=False or for aloha + overrides = [] + + robot = make_robot(robot_type, overrides=overrides, mock=mock) + with ( + patch("lerobot.scripts.control_robot.init_keyboard_listener") as mock_listener, + patch("lerobot.common.robot_devices.control_robot.add_frame", wraps=add_frame) as mock_add_frame, + ): + mock_events = {} + mock_events["exit_early"] = True + mock_events["rerecord_episode"] = True + mock_events["stop_recording"] = False + mock_listener.return_value = (None, mock_events) + + root = Path(tmpdir) / "data" + repo_id = "lerobot/debug" + + dataset = record( + robot, + fps=1, + root=root, + repo_id=repo_id, + warmup_time_s=0, + episode_time_s=1, + num_episodes=1, + push_to_hub=False, + video=False, + display_cameras=False, + play_sounds=False, + ) + + assert not mock_events["rerecord_episode"], "`rerecord_episode` wasn't properly reset to False" + assert not mock_events["exit_early"], "`exit_early` wasn't properly reset to False" + assert mock_add_frame.call_count == 2, "`add_frame` should have been called 2 times" + assert len(dataset) == 1, "`dataset` should contain only 1 frame" + + +@pytest.mark.parametrize("robot_type, mock", [("koch", True)]) +@require_robot +def test_record_with_event_exit_early(tmpdir, request, robot_type, mock): + if mock: + request.getfixturevalue("patch_builtins_input") + + # Create an empty calibration directory to trigger manual calibration + # and avoid writing calibration files in user .cache/calibration folder + calibration_dir = tmpdir / robot_type + overrides = [f"calibration_dir={calibration_dir}"] + else: + # Use the default .cache/calibration folder when mock=False or for aloha + overrides = [] + + robot = make_robot(robot_type, 
overrides=overrides, mock=mock) + with ( + patch("lerobot.scripts.control_robot.init_keyboard_listener") as mock_listener, + patch("lerobot.common.robot_devices.control_robot.add_frame", wraps=add_frame) as mock_add_frame, + ): + mock_events = {} + mock_events["exit_early"] = True + mock_events["rerecord_episode"] = False + mock_events["stop_recording"] = False + mock_listener.return_value = (None, mock_events) + + root = Path(tmpdir) / "data" + repo_id = "lerobot/debug" + + dataset = record( + robot, + fps=2, + root=root, + repo_id=repo_id, + warmup_time_s=0, + episode_time_s=1, + num_episodes=1, + push_to_hub=False, + video=False, + display_cameras=False, + play_sounds=False, + ) + + assert not mock_events["exit_early"], "`exit_early` wasn't properly reset to False" + assert mock_add_frame.call_count == 1, "`add_frame` should have been called 1 time" + assert len(dataset) == 1, "`dataset` should contain only 1 frame" + + +@pytest.mark.parametrize( + "robot_type, mock, num_image_writer_processes", [("koch", True, 0), ("koch", True, 1)] +) +@require_robot +def test_record_with_event_stop_recording(tmpdir, request, robot_type, mock, num_image_writer_processes): + if mock: + request.getfixturevalue("patch_builtins_input") + + # Create an empty calibration directory to trigger manual calibration + # and avoid writing calibration files in user .cache/calibration folder + calibration_dir = tmpdir / robot_type + overrides = [f"calibration_dir={calibration_dir}"] + else: + # Use the default .cache/calibration folder when mock=False or for aloha + overrides = [] + + robot = make_robot(robot_type, overrides=overrides, mock=mock) + with ( + patch("lerobot.scripts.control_robot.init_keyboard_listener") as mock_listener, + patch("lerobot.common.robot_devices.control_robot.add_frame", wraps=add_frame) as mock_add_frame, + ): + mock_events = {} + mock_events["exit_early"] = True + mock_events["rerecord_episode"] = False + mock_events["stop_recording"] = True + 
mock_listener.return_value = (None, mock_events) + + root = Path(tmpdir) / "data" + repo_id = "lerobot/debug" + + dataset = record( + robot, + fps=1, + root=root, + repo_id=repo_id, + warmup_time_s=0, + episode_time_s=1, + num_episodes=2, + push_to_hub=False, + video=False, + display_cameras=False, + play_sounds=False, + num_image_writer_processes=num_image_writer_processes, + ) + + assert not mock_events["exit_early"], "`exit_early` wasn't properly reset to False" + assert mock_add_frame.call_count == 1, "`add_frame` should have been called 1 time" + assert len(dataset) == 1, "`dataset` should contain only 1 frame"