From 77478d50e5d5642112d1bad6f1657f08eb1c3378 Mon Sep 17 00:00:00 2001 From: Remi Date: Wed, 16 Oct 2024 20:51:35 +0200 Subject: [PATCH] Refactor `record` with `add_frame` (#468) Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com> --- .github/workflows/test.yml | 2 + lerobot/common/datasets/populate_dataset.py | 468 ++++++++++ lerobot/common/logger.py | 2 +- lerobot/common/robot_devices/control_utils.py | 330 +++++++ .../robot_devices/robots/manipulator.py | 19 + lerobot/common/utils/utils.py | 34 + lerobot/scripts/control_robot.py | 817 +++--------------- lerobot/scripts/train.py | 2 +- .../templates/visualize_dataset_template.html | 2 +- tests/test_control_robot.py | 260 +++++- tests/test_robots.py | 1 + 11 files changed, 1232 insertions(+), 705 deletions(-) create mode 100644 lerobot/common/datasets/populate_dataset.py create mode 100644 lerobot/common/robot_devices/control_utils.py diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 10c90f84..9d51def4 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -47,6 +47,7 @@ jobs: pipx install poetry && poetry config virtualenvs.in-project true echo "${{ github.workspace }}/.venv/bin" >> $GITHUB_PATH + # TODO(rcadene, aliberts): python 3.12 seems to be used in the tests, not python 3.10 - name: Set up Python 3.10 uses: actions/setup-python@v5 with: @@ -84,6 +85,7 @@ jobs: pipx install poetry && poetry config virtualenvs.in-project true echo "${{ github.workspace }}/.venv/bin" >> $GITHUB_PATH + # TODO(rcadene, aliberts): python 3.12 seems to be used in the tests, not python 3.10 - name: Set up Python 3.10 uses: actions/setup-python@v5 with: diff --git a/lerobot/common/datasets/populate_dataset.py b/lerobot/common/datasets/populate_dataset.py new file mode 100644 index 00000000..df5d20e5 --- /dev/null +++ b/lerobot/common/datasets/populate_dataset.py @@ -0,0 +1,468 @@ +"""Functions to create an empty dataset, and populate it with frames.""" +# TODO(rcadene, aliberts): to adapt as class methods of next version of LeRobotDataset + +import concurrent +import json +import logging +import multiprocessing +import shutil +from pathlib import Path + +import torch +import tqdm +from PIL import Image + +from lerobot.common.datasets.compute_stats import compute_stats +from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset +from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import to_hf_dataset +from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, get_default_encoding +from lerobot.common.datasets.utils import calculate_episode_data_index, create_branch +from lerobot.common.datasets.video_utils import encode_video_frames +from lerobot.common.utils.utils import log_say +from lerobot.scripts.push_dataset_to_hub import ( + push_dataset_card_to_hub, + push_meta_data_to_hub, + push_videos_to_hub, + save_meta_data, +) + +######################################################################################## +# Asynchrounous saving of images on disk +######################################################################################## + + +def safe_stop_image_writer(func): + # TODO(aliberts): Allow to pass custom exceptions + # (e.g. 
ThreadServiceExit, KeyboardInterrupt, SystemExit, UnpluggedError, DynamixelCommError) + def wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except Exception as e: + image_writer = kwargs.get("dataset", {}).get("image_writer") + if image_writer is not None: + print("Waiting for image writer to terminate...") + stop_image_writer(image_writer, timeout=20) + raise e + + return wrapper + + +def save_image(img_tensor, key, frame_index, episode_index, videos_dir: str): + img = Image.fromarray(img_tensor.numpy()) + path = Path(videos_dir) / f"{key}_episode_{episode_index:06d}" / f"frame_{frame_index:06d}.png" + path.parent.mkdir(parents=True, exist_ok=True) + img.save(str(path), quality=100) + + +def loop_to_save_images_in_threads(image_queue, num_threads): + if num_threads < 1: + raise NotImplementedError(f"Only `num_threads>=1` is supported for now, but {num_threads=} given.") + + with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [] + while True: + # Blocks until a frame is available + frame_data = image_queue.get() + + # As usually done, exit loop when receiving None to stop the worker + if frame_data is None: + break + + image, key, frame_index, episode_index, videos_dir = frame_data + futures.append(executor.submit(save_image, image, key, frame_index, episode_index, videos_dir)) + + # Before exiting function, wait for all threads to complete + with tqdm.tqdm(total=len(futures), desc="Writing images") as progress_bar: + concurrent.futures.wait(futures) + progress_bar.update(len(futures)) + + +def start_image_writer_processes(image_queue, num_processes, num_threads_per_process): + if num_processes < 1: + raise ValueError(f"Only `num_processes>=1` is supported, but {num_processes=} given.") + + if num_threads_per_process < 1: + raise NotImplementedError( + "Only `num_threads_per_process>=1` is supported for now, but {num_threads_per_process=} given." + ) + + processes = [] + for _ in range(num_processes): + process = multiprocessing.Process( + target=loop_to_save_images_in_threads, + args=(image_queue, num_threads_per_process), + ) + process.start() + processes.append(process) + return processes + + +def stop_processes(processes, queue, timeout): + # Send None to each process to signal them to stop + for _ in processes: + queue.put(None) + + # Wait maximum 20 seconds for all processes to terminate + for process in processes: + process.join(timeout=timeout) + + # If not terminated after 20 seconds, force termination + if process.is_alive(): + process.terminate() + + # Close the queue, no more items can be put in the queue + queue.close() + + # Ensure all background queue threads have finished + queue.join_thread() + + +def start_image_writer(num_processes, num_threads): + """This function abstract away the initialisation of processes or/and threads to + save images on disk asynchrounously, which is critical to control a robot and record data + at a high frame rate. + + When `num_processes=0`, it returns a dictionary containing a threads pool of size `num_threads`. + When `num_processes>0`, it returns a dictionary containing a processes pool of size `num_processes`, + where each subprocess starts their own threads pool of size `num_threads`. + + The optimal number of processes and threads depends on your computer capabilities. + We advise to use 4 threads per camera with 0 processes. If the fps is not stable, try to increase or lower + the number of threads. If it is still not stable, try to use 1 subprocess, or more. 
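+
+    Illustrative call only (the values follow the advice above of 4 threads per
+    camera and 0 processes; they are an example, not a requirement):
+
+        image_writer = start_image_writer(num_processes=0, num_threads=4)
+        # ... enqueue frames with `async_save_image(image_writer, ...)` ...
+        stop_image_writer(image_writer, timeout=20)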
+ """ + image_writer = {} + + if num_processes == 0: + futures = [] + threads_pool = concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) + image_writer["threads_pool"], image_writer["futures"] = threads_pool, futures + else: + # TODO(rcadene): When using num_processes>1, `multiprocessing.Manager().Queue()` + # might be better than `multiprocessing.Queue()`. Source: https://www.geeksforgeeks.org/python-multiprocessing-queue-vs-multiprocessing-manager-queue + image_queue = multiprocessing.Queue() + processes_pool = start_image_writer_processes( + image_queue, num_processes=num_processes, num_threads_per_process=num_threads + ) + image_writer["processes_pool"], image_writer["image_queue"] = processes_pool, image_queue + + return image_writer + + +def async_save_image(image_writer, image, key, frame_index, episode_index, videos_dir): + """This function abstract away the saving of an image on disk asynchrounously. It uses a dictionary + called image writer which contains either a pool of processes or a pool of threads. + """ + if "threads_pool" in image_writer: + threads_pool, futures = image_writer["threads_pool"], image_writer["futures"] + futures.append(threads_pool.submit(save_image, image, key, frame_index, episode_index, videos_dir)) + else: + image_queue = image_writer["image_queue"] + image_queue.put((image, key, frame_index, episode_index, videos_dir)) + + +def stop_image_writer(image_writer, timeout): + if "threads_pool" in image_writer: + futures = image_writer["futures"] + # Before exiting function, wait for all threads to complete + with tqdm.tqdm(total=len(futures), desc="Writing images") as progress_bar: + concurrent.futures.wait(futures, timeout=timeout) + progress_bar.update(len(futures)) + else: + processes_pool, image_queue = image_writer["processes_pool"], image_writer["image_queue"] + stop_processes(processes_pool, image_queue, timeout=timeout) + + +######################################################################################## +# Functions to initialize, resume and populate a dataset +######################################################################################## + + +def init_dataset( + repo_id, + root, + force_override, + fps, + video, + write_images, + num_image_writer_processes, + num_image_writer_threads, +): + local_dir = Path(root) / repo_id + if local_dir.exists() and force_override: + shutil.rmtree(local_dir) + + episodes_dir = local_dir / "episodes" + episodes_dir.mkdir(parents=True, exist_ok=True) + + videos_dir = local_dir / "videos" + videos_dir.mkdir(parents=True, exist_ok=True) + + # Logic to resume data recording + rec_info_path = episodes_dir / "data_recording_info.json" + if rec_info_path.exists(): + with open(rec_info_path) as f: + rec_info = json.load(f) + num_episodes = rec_info["last_episode_index"] + 1 + else: + num_episodes = 0 + + dataset = { + "repo_id": repo_id, + "local_dir": local_dir, + "videos_dir": videos_dir, + "episodes_dir": episodes_dir, + "fps": fps, + "video": video, + "rec_info_path": rec_info_path, + "num_episodes": num_episodes, + } + + if write_images: + # Initialize processes or/and threads dedicated to save images on disk asynchronously, + # which is critical to control a robot and record data at a high frame rate. 
+ image_writer = start_image_writer( + num_processes=num_image_writer_processes, + num_threads=num_image_writer_threads, + ) + dataset["image_writer"] = image_writer + + return dataset + + +def add_frame(dataset, observation, action): + if "current_episode" not in dataset: + # initialize episode dictionary + ep_dict = {} + for key in observation: + if key not in ep_dict: + ep_dict[key] = [] + for key in action: + if key not in ep_dict: + ep_dict[key] = [] + + ep_dict["episode_index"] = [] + ep_dict["frame_index"] = [] + ep_dict["timestamp"] = [] + ep_dict["next.done"] = [] + + dataset["current_episode"] = ep_dict + dataset["current_frame_index"] = 0 + + ep_dict = dataset["current_episode"] + episode_index = dataset["num_episodes"] + frame_index = dataset["current_frame_index"] + videos_dir = dataset["videos_dir"] + video = dataset["video"] + fps = dataset["fps"] + + ep_dict["episode_index"].append(episode_index) + ep_dict["frame_index"].append(frame_index) + ep_dict["timestamp"].append(frame_index / fps) + ep_dict["next.done"].append(False) + + img_keys = [key for key in observation if "image" in key] + non_img_keys = [key for key in observation if "image" not in key] + + # Save all observed modalities except images + for key in non_img_keys: + ep_dict[key].append(observation[key]) + + # Save actions + for key in action: + ep_dict[key].append(action[key]) + + if "image_writer" not in dataset: + dataset["current_frame_index"] += 1 + return + + # Save images + image_writer = dataset["image_writer"] + for key in img_keys: + imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}" + async_save_image( + image_writer, + image=observation[key], + key=key, + frame_index=frame_index, + episode_index=episode_index, + videos_dir=str(videos_dir), + ) + + if video: + fname = f"{key}_episode_{episode_index:06d}.mp4" + frame_info = {"path": f"videos/{fname}", "timestamp": frame_index / fps} + else: + frame_info = str(imgs_dir / f"frame_{frame_index:06d}.png") + + ep_dict[key].append(frame_info) + + dataset["current_frame_index"] += 1 + + +def delete_current_episode(dataset): + del dataset["current_episode"] + del dataset["current_frame_index"] + + # delete temporary images + episode_index = dataset["num_episodes"] + videos_dir = dataset["videos_dir"] + for tmp_imgs_dir in videos_dir.glob(f"*_episode_{episode_index:06d}"): + shutil.rmtree(tmp_imgs_dir) + + +def save_current_episode(dataset): + episode_index = dataset["num_episodes"] + ep_dict = dataset["current_episode"] + episodes_dir = dataset["episodes_dir"] + rec_info_path = dataset["rec_info_path"] + + ep_dict["next.done"][-1] = True + + for key in ep_dict: + if "observation" in key and "image" not in key: + ep_dict[key] = torch.stack(ep_dict[key]) + + ep_dict["action"] = torch.stack(ep_dict["action"]) + ep_dict["episode_index"] = torch.tensor(ep_dict["episode_index"]) + ep_dict["frame_index"] = torch.tensor(ep_dict["frame_index"]) + ep_dict["timestamp"] = torch.tensor(ep_dict["timestamp"]) + ep_dict["next.done"] = torch.tensor(ep_dict["next.done"]) + + ep_path = episodes_dir / f"episode_{episode_index}.pth" + torch.save(ep_dict, ep_path) + + rec_info = { + "last_episode_index": episode_index, + } + with open(rec_info_path, "w") as f: + json.dump(rec_info, f) + + # force re-initialization of episode dictionnary during add_frame + del dataset["current_episode"] + + dataset["num_episodes"] += 1 + + +def encode_videos(dataset, image_keys, play_sounds): + log_say("Encoding videos", play_sounds) + + num_episodes = dataset["num_episodes"] + videos_dir 
= dataset["videos_dir"] + local_dir = dataset["local_dir"] + fps = dataset["fps"] + + # Use ffmpeg to convert frames stored as png into mp4 videos + for episode_index in tqdm.tqdm(range(num_episodes)): + for key in image_keys: + # key = f"observation.images.{name}" + tmp_imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}" + fname = f"{key}_episode_{episode_index:06d}.mp4" + video_path = local_dir / "videos" / fname + if video_path.exists(): + # Skip if video is already encoded. Could be the case when resuming data recording. + continue + # note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding, + # since video encoding with ffmpeg is already using multithreading. + encode_video_frames(tmp_imgs_dir, video_path, fps, overwrite=True) + shutil.rmtree(tmp_imgs_dir) + + +def from_dataset_to_lerobot_dataset(dataset, play_sounds): + log_say("Consolidate episodes", play_sounds) + + num_episodes = dataset["num_episodes"] + episodes_dir = dataset["episodes_dir"] + videos_dir = dataset["videos_dir"] + video = dataset["video"] + fps = dataset["fps"] + repo_id = dataset["repo_id"] + + ep_dicts = [] + for episode_index in tqdm.tqdm(range(num_episodes)): + ep_path = episodes_dir / f"episode_{episode_index}.pth" + ep_dict = torch.load(ep_path) + ep_dicts.append(ep_dict) + data_dict = concatenate_episodes(ep_dicts) + + if video: + image_keys = [key for key in data_dict if "image" in key] + encode_videos(dataset, image_keys, play_sounds) + + hf_dataset = to_hf_dataset(data_dict, video) + episode_data_index = calculate_episode_data_index(hf_dataset) + + info = { + "codebase_version": CODEBASE_VERSION, + "fps": fps, + "video": video, + } + if video: + info["encoding"] = get_default_encoding() + + lerobot_dataset = LeRobotDataset.from_preloaded( + repo_id=repo_id, + hf_dataset=hf_dataset, + episode_data_index=episode_data_index, + info=info, + videos_dir=videos_dir, + ) + + return lerobot_dataset + + +def save_lerobot_dataset_on_disk(lerobot_dataset): + hf_dataset = lerobot_dataset.hf_dataset + info = lerobot_dataset.info + stats = lerobot_dataset.stats + episode_data_index = lerobot_dataset.episode_data_index + local_dir = lerobot_dataset.videos_dir.parent + meta_data_dir = local_dir / "meta_data" + + hf_dataset = hf_dataset.with_format(None) # to remove transforms that cant be saved + hf_dataset.save_to_disk(str(local_dir / "train")) + + save_meta_data(info, stats, episode_data_index, meta_data_dir) + + +def push_lerobot_dataset_to_hub(lerobot_dataset, tags): + hf_dataset = lerobot_dataset.hf_dataset + local_dir = lerobot_dataset.videos_dir.parent + videos_dir = lerobot_dataset.videos_dir + repo_id = lerobot_dataset.repo_id + video = lerobot_dataset.video + meta_data_dir = local_dir / "meta_data" + + if not (local_dir / "train").exists(): + raise ValueError( + "You need to run `save_lerobot_dataset_on_disk(lerobot_dataset)` before pushing to the hub." 
+ ) + + hf_dataset.push_to_hub(repo_id, revision="main") + push_meta_data_to_hub(repo_id, meta_data_dir, revision="main") + push_dataset_card_to_hub(repo_id, revision="main", tags=tags) + if video: + push_videos_to_hub(repo_id, videos_dir, revision="main") + create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION) + + +def create_lerobot_dataset(dataset, run_compute_stats, push_to_hub, tags, play_sounds): + if "image_writer" in dataset: + logging.info("Waiting for image writer to terminate...") + image_writer = dataset["image_writer"] + stop_image_writer(image_writer, timeout=20) + + lerobot_dataset = from_dataset_to_lerobot_dataset(dataset, play_sounds) + + if run_compute_stats: + log_say("Computing dataset statistics", play_sounds) + lerobot_dataset.stats = compute_stats(lerobot_dataset) + else: + logging.info("Skipping computation of the dataset statistics") + lerobot_dataset.stats = {} + + save_lerobot_dataset_on_disk(lerobot_dataset) + + if push_to_hub: + push_lerobot_dataset_to_hub(lerobot_dataset, tags) + + return lerobot_dataset diff --git a/lerobot/common/logger.py b/lerobot/common/logger.py index bf578fcc..3bd2df89 100644 --- a/lerobot/common/logger.py +++ b/lerobot/common/logger.py @@ -189,7 +189,7 @@ class Logger: training_state["scheduler"] = scheduler.state_dict() torch.save(training_state, save_dir / self.training_state_file_name) - def save_checkpont( + def save_checkpoint( self, train_step: int, policy: Policy, diff --git a/lerobot/common/robot_devices/control_utils.py b/lerobot/common/robot_devices/control_utils.py new file mode 100644 index 00000000..08bcec2e --- /dev/null +++ b/lerobot/common/robot_devices/control_utils.py @@ -0,0 +1,330 @@ +######################################################################################## +# Utilities +######################################################################################## + + +import logging +import time +import traceback +from contextlib import nullcontext +from copy import copy +from functools import cache + +import cv2 +import torch +import tqdm +from termcolor import colored + +from lerobot.common.datasets.populate_dataset import add_frame, safe_stop_image_writer +from lerobot.common.policies.factory import make_policy +from lerobot.common.robot_devices.robots.utils import Robot +from lerobot.common.robot_devices.utils import busy_wait +from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, set_global_seed +from lerobot.scripts.eval import get_pretrained_policy_path + + +def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, fps=None): + log_items = [] + if episode_index is not None: + log_items.append(f"ep:{episode_index}") + if frame_index is not None: + log_items.append(f"frame:{frame_index}") + + def log_dt(shortname, dt_val_s): + nonlocal log_items, fps + info_str = f"{shortname}:{dt_val_s * 1000:5.2f} ({1/ dt_val_s:3.1f}hz)" + if fps is not None: + actual_fps = 1 / dt_val_s + if actual_fps < fps - 1: + info_str = colored(info_str, "yellow") + log_items.append(info_str) + + # total step time displayed in milliseconds and its frequency + log_dt("dt", dt_s) + + # TODO(aliberts): move robot-specific logs logic in robot.print_logs() + if not robot.robot_type.startswith("stretch"): + for name in robot.leader_arms: + key = f"read_leader_{name}_pos_dt_s" + if key in robot.logs: + log_dt("dtRlead", robot.logs[key]) + + for name in robot.follower_arms: + key = f"write_follower_{name}_goal_pos_dt_s" + if key in robot.logs: + log_dt("dtWfoll", 
robot.logs[key]) + + key = f"read_follower_{name}_pos_dt_s" + if key in robot.logs: + log_dt("dtRfoll", robot.logs[key]) + + for name in robot.cameras: + key = f"read_camera_{name}_dt_s" + if key in robot.logs: + log_dt(f"dtR{name}", robot.logs[key]) + + info_str = " ".join(log_items) + logging.info(info_str) + + +@cache +def is_headless(): + """Detects if python is running without a monitor.""" + try: + import pynput # noqa + + return False + except Exception: + print( + "Error trying to import pynput. Switching to headless mode. " + "As a result, the video stream from the cameras won't be shown, " + "and you won't be able to change the control flow with keyboards. " + "For more info, see traceback below.\n" + ) + traceback.print_exc() + print() + return True + + +def has_method(_object: object, method_name: str): + return hasattr(_object, method_name) and callable(getattr(_object, method_name)) + + +def predict_action(observation, policy, device, use_amp): + observation = copy(observation) + with ( + torch.inference_mode(), + torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(), + ): + # Convert to pytorch format: channel first and float32 in [0,1] with batch dimension + for name in observation: + if "image" in name: + observation[name] = observation[name].type(torch.float32) / 255 + observation[name] = observation[name].permute(2, 0, 1).contiguous() + observation[name] = observation[name].unsqueeze(0) + observation[name] = observation[name].to(device) + + # Compute the next action with the policy + # based on the current observation + action = policy.select_action(observation) + + # Remove batch dimension + action = action.squeeze(0) + + # Move to cpu, if not already the case + action = action.to("cpu") + + return action + + +def init_keyboard_listener(): + # Allow to exit early while recording an episode or resetting the environment, + # by tapping the right arrow key '->'. This might require a sudo permission + # to allow your terminal to monitor keyboard events. + events = {} + events["exit_early"] = False + events["rerecord_episode"] = False + events["stop_recording"] = False + + if is_headless(): + logging.warning( + "Headless environment detected. On-screen cameras display and keyboard inputs will not be available." + ) + listener = None + return listener, events + + # Only import pynput if not in a headless environment + from pynput import keyboard + + def on_press(key): + try: + if key == keyboard.Key.right: + print("Right arrow key pressed. Exiting loop...") + events["exit_early"] = True + elif key == keyboard.Key.left: + print("Left arrow key pressed. Exiting loop and rerecord the last episode...") + events["rerecord_episode"] = True + events["exit_early"] = True + elif key == keyboard.Key.esc: + print("Escape key pressed. 
Stopping data recording...") + events["stop_recording"] = True + events["exit_early"] = True + except Exception as e: + print(f"Error handling key press: {e}") + + listener = keyboard.Listener(on_press=on_press) + listener.start() + + return listener, events + + +def init_policy(pretrained_policy_name_or_path, policy_overrides): + """Instantiate the policy and load fps, device and use_amp from config yaml""" + pretrained_policy_path = get_pretrained_policy_path(pretrained_policy_name_or_path) + hydra_cfg = init_hydra_config(pretrained_policy_path / "config.yaml", policy_overrides) + policy = make_policy(hydra_cfg=hydra_cfg, pretrained_policy_name_or_path=pretrained_policy_path) + + # Check device is available + device = get_safe_torch_device(hydra_cfg.device, log=True) + use_amp = hydra_cfg.use_amp + policy_fps = hydra_cfg.env.fps + + policy.eval() + policy.to(device) + + torch.backends.cudnn.benchmark = True + torch.backends.cuda.matmul.allow_tf32 = True + set_global_seed(hydra_cfg.seed) + return policy, policy_fps, device, use_amp + + +def warmup_record( + robot, + events, + enable_teloperation, + warmup_time_s, + display_cameras, + fps, +): + control_loop( + robot=robot, + control_time_s=warmup_time_s, + display_cameras=display_cameras, + events=events, + fps=fps, + teleoperate=enable_teloperation, + ) + + +def record_episode( + robot, + dataset, + events, + episode_time_s, + display_cameras, + policy, + device, + use_amp, + fps, +): + control_loop( + robot=robot, + control_time_s=episode_time_s, + display_cameras=display_cameras, + dataset=dataset, + events=events, + policy=policy, + device=device, + use_amp=use_amp, + fps=fps, + teleoperate=policy is None, + ) + + +@safe_stop_image_writer +def control_loop( + robot, + control_time_s=None, + teleoperate=False, + display_cameras=False, + dataset=None, + events=None, + policy=None, + device=None, + use_amp=None, + fps=None, +): + # TODO(rcadene): Add option to record logs + if not robot.is_connected: + robot.connect() + + if events is None: + events = {"exit_early": False} + + if control_time_s is None: + control_time_s = float("inf") + + if teleoperate and policy is not None: + raise ValueError("When `teleoperate` is True, `policy` should be None.") + + if dataset is not None and fps is not None and dataset["fps"] != fps: + raise ValueError(f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps}).") + + timestamp = 0 + start_episode_t = time.perf_counter() + while timestamp < control_time_s: + start_loop_t = time.perf_counter() + + if teleoperate: + observation, action = robot.teleop_step(record_data=True) + else: + observation = robot.capture_observation() + + if policy is not None: + pred_action = predict_action(observation, policy, device, use_amp) + # Action can eventually be clipped using `max_relative_target`, + # so action actually sent is saved in the dataset. 
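+            # The returned (possibly clipped) action is wrapped in a dict under the
+            # "action" key below, which is the key that `add_frame` and
+            # `save_current_episode` expect when storing actions.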
+ action = robot.send_action(pred_action) + action = {"action": action} + + if dataset is not None: + add_frame(dataset, observation, action) + + if display_cameras and not is_headless(): + image_keys = [key for key in observation if "image" in key] + for key in image_keys: + cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)) + cv2.waitKey(1) + + if fps is not None: + dt_s = time.perf_counter() - start_loop_t + busy_wait(1 / fps - dt_s) + + dt_s = time.perf_counter() - start_loop_t + log_control_info(robot, dt_s, fps=fps) + + timestamp = time.perf_counter() - start_episode_t + if events["exit_early"]: + events["exit_early"] = False + break + + +def reset_environment(robot, events, reset_time_s): + # TODO(rcadene): refactor warmup_record and reset_environment + # TODO(alibets): allow for teleop during reset + if has_method(robot, "teleop_safety_stop"): + robot.teleop_safety_stop() + + timestamp = 0 + start_vencod_t = time.perf_counter() + + # Wait if necessary + with tqdm.tqdm(total=reset_time_s, desc="Waiting") as pbar: + while timestamp < reset_time_s: + time.sleep(1) + timestamp = time.perf_counter() - start_vencod_t + pbar.update(1) + if events["exit_early"]: + events["exit_early"] = False + break + + +def stop_recording(robot, listener, display_cameras): + robot.disconnect() + + if not is_headless(): + if listener is not None: + listener.stop() + + if display_cameras: + cv2.destroyAllWindows() + + +def sanity_check_dataset_name(repo_id, policy): + _, dataset_name = repo_id.split("/") + # either repo_id doesnt start with "eval_" and there is no policy + # or repo_id starts with "eval_" and there is a policy + if dataset_name.startswith("eval_") == (policy is None): + raise ValueError( + f"Your dataset name begins by 'eval_' ({dataset_name}) but no policy is provided ({policy})." + ) diff --git a/lerobot/common/robot_devices/robots/manipulator.py b/lerobot/common/robot_devices/robots/manipulator.py index 6ab900fb..20969c30 100644 --- a/lerobot/common/robot_devices/robots/manipulator.py +++ b/lerobot/common/robot_devices/robots/manipulator.py @@ -349,6 +349,25 @@ class ManipulatorRobot: self.is_connected = False self.logs = {} + @property + def has_camera(self): + return len(self.cameras) > 0 + + @property + def num_cameras(self): + return len(self.cameras) + + @property + def available_arms(self): + available_arms = [] + for name in self.follower_arms: + arm_id = get_arm_id(name, "follower") + available_arms.append(arm_id) + for name in self.leader_arms: + arm_id = get_arm_id(name, "leader") + available_arms.append(arm_id) + return available_arms + def connect(self): if self.is_connected: raise RobotDeviceAlreadyConnectedError( diff --git a/lerobot/common/utils/utils.py b/lerobot/common/utils/utils.py index 1aa0bc2d..554e054e 100644 --- a/lerobot/common/utils/utils.py +++ b/lerobot/common/utils/utils.py @@ -16,6 +16,7 @@ import logging import os import os.path as osp +import platform import random from contextlib import contextmanager from datetime import datetime, timezone @@ -28,6 +29,12 @@ import torch from omegaconf import DictConfig +def none_or_int(value): + if value == "None": + return None + return int(value) + + def inside_slurm(): """Check whether the python process was launched through slurm""" # TODO(rcadene): return False for interactive mode `--pty bash` @@ -183,3 +190,30 @@ def print_cuda_memory_usage(): def capture_timestamp_utc(): return datetime.now(timezone.utc) + + +def say(text, blocking=False): + # Check if mac, linux, or windows. 
+ if platform.system() == "Darwin": + cmd = f'say "{text}"' + elif platform.system() == "Linux": + cmd = f'spd-say "{text}"' + elif platform.system() == "Windows": + cmd = ( + 'PowerShell -Command "Add-Type -AssemblyName System.Speech; ' + f"(New-Object System.Speech.Synthesis.SpeechSynthesizer).Speak('{text}')\"" + ) + + if not blocking and platform.system() in ["Darwin", "Linux"]: + # TODO(rcadene): Make it work for Windows + # Use the ampersand to run command in the background + cmd += " &" + + os.system(cmd) + + +def log_say(text, play_sounds, blocking=False): + logging.info(text) + + if play_sounds: + say(text, blocking) diff --git a/lerobot/scripts/control_robot.py b/lerobot/scripts/control_robot.py index 3b6345b4..425247e6 100644 --- a/lerobot/scripts/control_robot.py +++ b/lerobot/scripts/control_robot.py @@ -99,285 +99,35 @@ python lerobot/scripts/control_robot.py record \ """ import argparse -import concurrent.futures -import json import logging -import multiprocessing -import os -import platform -import shutil import time -import traceback -from contextlib import nullcontext -from functools import cache from pathlib import Path - -import cv2 -import torch -import tqdm -from omegaconf import DictConfig -from PIL import Image -from termcolor import colored +from typing import List # from safetensors.torch import load_file, save_file -from lerobot.common.datasets.compute_stats import compute_stats -from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset -from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import to_hf_dataset -from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, get_default_encoding -from lerobot.common.datasets.utils import calculate_episode_data_index, create_branch -from lerobot.common.datasets.video_utils import encode_video_frames -from lerobot.common.policies.factory import make_policy -from lerobot.common.robot_devices.robots.factory import make_robot -from lerobot.common.robot_devices.robots.utils import Robot, get_arm_id -from lerobot.common.robot_devices.utils import busy_wait, safe_disconnect -from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed -from lerobot.scripts.eval import get_pretrained_policy_path -from lerobot.scripts.push_dataset_to_hub import ( - push_dataset_card_to_hub, - push_meta_data_to_hub, - push_videos_to_hub, - save_meta_data, +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset +from lerobot.common.datasets.populate_dataset import ( + create_lerobot_dataset, + delete_current_episode, + init_dataset, + save_current_episode, ) - -######################################################################################## -# Utilities -######################################################################################## - - -def say(text, blocking=False): - # Check if mac, linux, or windows. 
- if platform.system() == "Darwin": - cmd = f'say "{text}"' - elif platform.system() == "Linux": - cmd = f'spd-say "{text}"' - elif platform.system() == "Windows": - cmd = ( - 'PowerShell -Command "Add-Type -AssemblyName System.Speech; ' - f"(New-Object System.Speech.Synthesis.SpeechSynthesizer).Speak('{text}')\"" - ) - - if not blocking and platform.system() in ["Darwin", "Linux"]: - # TODO(rcadene): Make it work for Windows - # Use the ampersand to run command in the background - cmd += " &" - - os.system(cmd) - - -def save_image(img_tensor, key, frame_index, episode_index, videos_dir: str): - img = Image.fromarray(img_tensor.numpy()) - path = Path(videos_dir) / f"{key}_episode_{episode_index:06d}" / f"frame_{frame_index:06d}.png" - path.parent.mkdir(parents=True, exist_ok=True) - img.save(str(path), quality=100) - - -def none_or_int(value): - if value == "None": - return None - return int(value) - - -def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, fps=None): - log_items = [] - if episode_index is not None: - log_items.append(f"ep:{episode_index}") - if frame_index is not None: - log_items.append(f"frame:{frame_index}") - - def log_dt(shortname, dt_val_s): - nonlocal log_items, fps - info_str = f"{shortname}:{dt_val_s * 1000:5.2f} ({1/ dt_val_s:3.1f}hz)" - if fps is not None: - actual_fps = 1 / dt_val_s - if actual_fps < fps - 1: - info_str = colored(info_str, "yellow") - log_items.append(info_str) - - # total step time displayed in milliseconds and its frequency - log_dt("dt", dt_s) - - # TODO(aliberts): move robot-specific logs logic in robot.print_logs() - if not robot.robot_type.startswith("stretch"): - for name in robot.leader_arms: - key = f"read_leader_{name}_pos_dt_s" - if key in robot.logs: - log_dt("dtRlead", robot.logs[key]) - - for name in robot.follower_arms: - key = f"write_follower_{name}_goal_pos_dt_s" - if key in robot.logs: - log_dt("dtWfoll", robot.logs[key]) - - key = f"read_follower_{name}_pos_dt_s" - if key in robot.logs: - log_dt("dtRfoll", robot.logs[key]) - - for name in robot.cameras: - key = f"read_camera_{name}_dt_s" - if key in robot.logs: - log_dt(f"dtR{name}", robot.logs[key]) - - info_str = " ".join(log_items) - logging.info(info_str) - - -@cache -def is_headless(): - """Detects if python is running without a monitor.""" - try: - import pynput # noqa - - return False - except Exception: - print( - "Error trying to import pynput. Switching to headless mode. " - "As a result, the video stream from the cameras won't be shown, " - "and you won't be able to change the control flow with keyboards. " - "For more info, see traceback below.\n" - ) - traceback.print_exc() - print() - return True - - -def has_method(_object: object, method_name: str): - return hasattr(_object, method_name) and callable(getattr(_object, method_name)) - - -def get_available_arms(robot): - # TODO(rcadene): moves this function in manipulator class? 
- available_arms = [] - for name in robot.follower_arms: - arm_id = get_arm_id(name, "follower") - available_arms.append(arm_id) - for name in robot.leader_arms: - arm_id = get_arm_id(name, "leader") - available_arms.append(arm_id) - return available_arms - - -######################################################################################## -# Asynchrounous saving of images on disk -######################################################################################## - - -def loop_to_save_images_in_threads(image_queue, num_threads): - if num_threads < 1: - raise NotImplementedError(f"Only `num_threads>=1` is supported for now, but {num_threads=} given.") - - with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: - futures = [] - while True: - # Blocks until a frame is available - frame_data = image_queue.get() - - # As usually done, exit loop when receiving None to stop the worker - if frame_data is None: - break - - image, key, frame_index, episode_index, videos_dir = frame_data - futures.append(executor.submit(save_image, image, key, frame_index, episode_index, videos_dir)) - - # Before exiting function, wait for all threads to complete - with tqdm.tqdm(total=len(futures), desc="Writing images") as progress_bar: - concurrent.futures.wait(futures) - progress_bar.update(len(futures)) - - -def start_image_writer_processes(image_queue, num_processes, num_threads_per_process): - if num_processes < 1: - raise ValueError(f"Only `num_processes>=1` is supported, but {num_processes=} given.") - - if num_threads_per_process < 1: - raise NotImplementedError( - "Only `num_threads_per_process>=1` is supported for now, but {num_threads_per_process=} given." - ) - - processes = [] - for _ in range(num_processes): - process = multiprocessing.Process( - target=loop_to_save_images_in_threads, - args=(image_queue, num_threads_per_process), - ) - process.start() - processes.append(process) - return processes - - -def stop_processes(processes, queue, timeout): - # Send None to each process to signal them to stop - for _ in processes: - queue.put(None) - - # Close the queue, no more items can be put in the queue - queue.close() - - # Wait maximum 20 seconds for all processes to terminate - for process in processes: - process.join(timeout=timeout) - - # If not terminated after 20 seconds, force termination - if process.is_alive(): - process.terminate() - - # Ensure all background queue threads have finished - queue.join_thread() - - -def start_image_writer(num_processes, num_threads): - """This function abstract away the initialisation of processes or/and threads to - save images on disk asynchrounously, which is critical to control a robot and record data - at a high frame rate. - - When `num_processes=0`, it returns a dictionary containing a threads pool of size `num_threads`. - When `num_processes>0`, it returns a dictionary containing a processes pool of size `num_processes`, - where each subprocess starts their own threads pool of size `num_threads`. - - The optimal number of processes and threads depends on your computer capabilities. - We advise to use 4 threads per camera with 0 processes. If the fps is not stable, try to increase or lower - the number of threads. If it is still not stable, try to use 1 subprocess, or more. 
- """ - image_writer = {} - - if num_processes == 0: - futures = [] - threads_pool = concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) - image_writer["threads_pool"], image_writer["futures"] = threads_pool, futures - else: - # TODO(rcadene): When using num_processes>1, `multiprocessing.Manager().Queue()` - # might be better than `multiprocessing.Queue()`. Source: https://www.geeksforgeeks.org/python-multiprocessing-queue-vs-multiprocessing-manager-queue - image_queue = multiprocessing.Queue() - processes_pool = start_image_writer_processes( - image_queue, num_processes=num_processes, num_threads_per_process=num_threads - ) - image_writer["processes_pool"], image_writer["image_queue"] = processes_pool, image_queue - - return image_writer - - -def async_save_image(image_writer, image, key, frame_index, episode_index, videos_dir): - """This function abstract away the saving of an image on disk asynchrounously. It uses a dictionary - called image writer which contains either a pool of processes or a pool of threads. - """ - if "threads_pool" in image_writer: - threads_pool, futures = image_writer["threads_pool"], image_writer["futures"] - futures.append(threads_pool.submit(save_image, image, key, frame_index, episode_index, videos_dir)) - else: - image_queue = image_writer["image_queue"] - image_queue.put((image, key, frame_index, episode_index, videos_dir)) - - -def stop_image_writer(image_writer, timeout): - if "threads_pool" in image_writer: - futures = image_writer["futures"] - # Before exiting function, wait for all threads to complete - with tqdm.tqdm(total=len(futures), desc="Writing images") as progress_bar: - concurrent.futures.wait(futures, timeout=timeout) - progress_bar.update(len(futures)) - else: - processes_pool, image_queue = image_writer["processes_pool"], image_writer["image_queue"] - stop_processes(processes_pool, image_queue, timeout=timeout) - +from lerobot.common.robot_devices.control_utils import ( + control_loop, + has_method, + init_keyboard_listener, + init_policy, + log_control_info, + record_episode, + reset_environment, + sanity_check_dataset_name, + stop_recording, + warmup_record, +) +from lerobot.common.robot_devices.robots.factory import make_robot +from lerobot.common.robot_devices.robots.utils import Robot +from lerobot.common.robot_devices.utils import busy_wait, safe_disconnect +from lerobot.common.utils.utils import init_hydra_config, init_logging, log_say, none_or_int ######################################################################################## # Control modes @@ -394,9 +144,8 @@ def calibrate(robot: Robot, arms: list[str] | None): robot.home() return - available_arms = get_available_arms(robot) - unknown_arms = [arm_id for arm_id in arms if arm_id not in available_arms] - available_arms_str = " ".join(available_arms) + unknown_arms = [arm_id for arm_id in arms if arm_id not in robot.available_arms] + available_arms_str = " ".join(robot.available_arms) unknown_arms_str = " ".join(unknown_arms) if arms is None or len(arms) == 0: @@ -429,35 +178,26 @@ def calibrate(robot: Robot, arms: list[str] | None): @safe_disconnect -def teleoperate(robot: Robot, fps: int | None = None, teleop_time_s: float | None = None): - # TODO(rcadene): Add option to record logs - if not robot.is_connected: - robot.connect() - - start_teleop_t = time.perf_counter() - while True: - start_loop_t = time.perf_counter() - robot.teleop_step() - - if fps is not None: - dt_s = time.perf_counter() - start_loop_t - busy_wait(1 / fps - dt_s) - - dt_s = 
time.perf_counter() - start_loop_t - log_control_info(robot, dt_s, fps=fps) - - if teleop_time_s is not None and time.perf_counter() - start_teleop_t > teleop_time_s: - break +def teleoperate( + robot: Robot, fps: int | None = None, teleop_time_s: float | None = None, display_cameras: bool = False +): + control_loop( + robot, + control_time_s=teleop_time_s, + fps=fps, + teleoperate=True, + display_cameras=display_cameras, + ) @safe_disconnect def record( robot: Robot, - policy: torch.nn.Module | None = None, - hydra_cfg: DictConfig | None = None, + root: str, + repo_id: str, + pretrained_policy_name_or_path: str | None = None, + policy_overrides: List[str] | None = None, fps: int | None = None, - root="data", - repo_id="lerobot/debug", warmup_time_s=2, episode_time_s=10, reset_time_s=5, @@ -473,407 +213,108 @@ def record( play_sounds=True, ): # TODO(rcadene): Add option to record logs - # TODO(rcadene): Clean this function via decomposition in higher level functions + listener = None + events = None + policy = None + device = None + use_amp = None - _, dataset_name = repo_id.split("/") - if dataset_name.startswith("eval_") and policy is None: - raise ValueError( - f"Your dataset name begins by 'eval_' ({dataset_name}) but no policy is provided ({policy})." - ) + # Load pretrained policy + if pretrained_policy_name_or_path is not None: + policy, policy_fps, device, use_amp = init_policy(pretrained_policy_name_or_path, policy_overrides) + + if fps is None: + fps = policy_fps + logging.warning(f"No fps provided, so using the fps from policy config ({policy_fps}).") + elif fps != policy_fps: + logging.warning( + f"There is a mismatch between the provided fps ({fps}) and the one from policy config ({policy_fps})." + ) + + # Create empty dataset or load existing saved episodes + sanity_check_dataset_name(repo_id, policy) + dataset = init_dataset( + repo_id, + root, + force_override, + fps, + video, + write_images=robot.has_camera, + num_image_writer_processes=num_image_writer_processes, + num_image_writer_threads=num_image_writer_threads_per_camera * robot.num_cameras, + ) if not robot.is_connected: robot.connect() - local_dir = Path(root) / repo_id - if local_dir.exists() and force_override: - shutil.rmtree(local_dir) + listener, events = init_keyboard_listener() - episodes_dir = local_dir / "episodes" - episodes_dir.mkdir(parents=True, exist_ok=True) - - videos_dir = local_dir / "videos" - videos_dir.mkdir(parents=True, exist_ok=True) - - # Logic to resume data recording - rec_info_path = episodes_dir / "data_recording_info.json" - if rec_info_path.exists(): - with open(rec_info_path) as f: - rec_info = json.load(f) - episode_index = rec_info["last_episode_index"] + 1 - else: - episode_index = 0 - - if is_headless(): - logging.warning( - "Headless environment detected. On-screen cameras display and keyboard inputs will not be available." - ) - - # Allow to exit early while recording an episode or resetting the environment, - # by tapping the right arrow key '->'. This might require a sudo permission - # to allow your terminal to monitor keyboard events. - exit_early = False - rerecord_episode = False - stop_recording = False - - # Only import pynput if not in a headless environment - if not is_headless(): - from pynput import keyboard - - def on_press(key): - nonlocal exit_early, rerecord_episode, stop_recording - try: - if key == keyboard.Key.right: - print("Right arrow key pressed. Exiting loop...") - exit_early = True - elif key == keyboard.Key.left: - print("Left arrow key pressed. 
Exiting loop and rerecord the last episode...") - rerecord_episode = True - exit_early = True - elif key == keyboard.Key.esc: - print("Escape key pressed. Stopping data recording...") - stop_recording = True - exit_early = True - except Exception as e: - print(f"Error handling key press: {e}") - - listener = keyboard.Listener(on_press=on_press) - listener.start() - - # Load policy if any - if policy is not None: - # Check device is available - device = get_safe_torch_device(hydra_cfg.device, log=True) - - policy.eval() - policy.to(device) - - torch.backends.cudnn.benchmark = True - torch.backends.cuda.matmul.allow_tf32 = True - set_global_seed(hydra_cfg.seed) - - # override fps using policy fps - fps = hydra_cfg.env.fps - - # Execute a few seconds without recording data, to give times - # to the robot devices to connect and start synchronizing. - timestamp = 0 - start_warmup_t = time.perf_counter() - is_warmup_print = False - while timestamp < warmup_time_s: - if not is_warmup_print: - logging.info("Warming up (no data recording)") - if play_sounds: - say("Warming up") - is_warmup_print = True - - start_loop_t = time.perf_counter() - - if policy is None: - observation, action = robot.teleop_step(record_data=True) - else: - observation = robot.capture_observation() - - if display_cameras and not is_headless(): - image_keys = [key for key in observation if "image" in key] - for key in image_keys: - cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)) - cv2.waitKey(1) - - dt_s = time.perf_counter() - start_loop_t - busy_wait(1 / fps - dt_s) - - dt_s = time.perf_counter() - start_loop_t - log_control_info(robot, dt_s, fps=fps) - - timestamp = time.perf_counter() - start_warmup_t + # Execute a few seconds without recording to: + # 1. teleoperate the robot to move it in starting position if no policy provided, + # 2. give times to the robot devices to connect and start synchronizing, + # 3. place the cameras windows on screen + enable_teleoperation = policy is None + log_say("Warmup record", play_sounds) + warmup_record(robot, events, enable_teleoperation, warmup_time_s, display_cameras, fps) if has_method(robot, "teleop_safety_stop"): robot.teleop_safety_stop() - has_camera = len(robot.cameras) > 0 - if has_camera: - # Initialize processes or/and threads dedicated to save images on disk asynchronously, - # which is critical to control a robot and record data at a high frame rate. - image_writer = start_image_writer( - num_processes=num_image_writer_processes, - num_threads=num_image_writer_threads_per_camera * len(robot.cameras), + while True: + if dataset["num_episodes"] >= num_episodes: + break + + episode_index = dataset["num_episodes"] + log_say(f"Recording episode {episode_index}", play_sounds) + record_episode( + dataset=dataset, + robot=robot, + events=events, + episode_time_s=episode_time_s, + display_cameras=display_cameras, + policy=policy, + device=device, + use_amp=use_amp, + fps=fps, ) - # Using `try` to exist smoothly if an exception is raised - try: - # Start recording all episodes - while episode_index < num_episodes: - logging.info(f"Recording episode {episode_index}") - if play_sounds: - say(f"Recording episode {episode_index}") - ep_dict = {} - frame_index = 0 - timestamp = 0 - start_episode_t = time.perf_counter() - while timestamp < episode_time_s: - start_loop_t = time.perf_counter() + # Execute a few seconds without recording to give time to manually reset the environment + # Current code logic doesn't allow to teleoperate during this time. 
+ # TODO(rcadene): add an option to enable teleoperation during reset + # Skip reset for the last episode to be recorded + if not events["stop_recording"] and ( + (episode_index < num_episodes - 1) or events["rerecord_episode"] + ): + log_say("Reset the environment", play_sounds) + reset_environment(robot, events, reset_time_s) - if policy is None: - observation, action = robot.teleop_step(record_data=True) - else: - observation = robot.capture_observation() + if events["rerecord_episode"]: + log_say("Re-record episode", play_sounds) + events["rerecord_episode"] = False + events["exit_early"] = False + delete_current_episode(dataset) + continue - image_keys = [key for key in observation if "image" in key] - not_image_keys = [key for key in observation if "image" not in key] + # Increment by one dataset["current_episode_index"] + save_current_episode(dataset) - if has_camera > 0: - for key in image_keys: - async_save_image( - image_writer, - image=observation[key], - key=key, - frame_index=frame_index, - episode_index=episode_index, - videos_dir=str(videos_dir), - ) + if events["stop_recording"]: + break - if display_cameras and not is_headless(): - image_keys = [key for key in observation if "image" in key] - for key in image_keys: - cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)) - cv2.waitKey(1) + log_say("Stop recording", play_sounds, blocking=True) + stop_recording(robot, listener, display_cameras) - for key in not_image_keys: - if key not in ep_dict: - ep_dict[key] = [] - ep_dict[key].append(observation[key]) + lerobot_dataset = create_lerobot_dataset(dataset, run_compute_stats, push_to_hub, tags, play_sounds) - if policy is not None: - with ( - torch.inference_mode(), - torch.autocast(device_type=device.type) - if device.type == "cuda" and hydra_cfg.use_amp - else nullcontext(), - ): - # Convert to pytorch format: channel first and float32 in [0,1] with batch dimension - for name in observation: - if "image" in name: - observation[name] = observation[name].type(torch.float32) / 255 - observation[name] = observation[name].permute(2, 0, 1).contiguous() - observation[name] = observation[name].unsqueeze(0) - observation[name] = observation[name].to(device) - - # Compute the next action with the policy - # based on the current observation - action = policy.select_action(observation) - - # Remove batch dimension - action = action.squeeze(0) - - # Move to cpu, if not already the case - action = action.to("cpu") - - # Order the robot to move - action_sent = robot.send_action(action) - - # Action can eventually be clipped using `max_relative_target`, - # so action actually sent is saved in the dataset. 
- action = {"action": action_sent} - - for key in action: - if key not in ep_dict: - ep_dict[key] = [] - ep_dict[key].append(action[key]) - - frame_index += 1 - - dt_s = time.perf_counter() - start_loop_t - busy_wait(1 / fps - dt_s) - - dt_s = time.perf_counter() - start_loop_t - log_control_info(robot, dt_s, fps=fps) - - timestamp = time.perf_counter() - start_episode_t - if exit_early: - exit_early = False - break - - # TODO(alibets): allow for teleop during reset - if has_method(robot, "teleop_safety_stop"): - robot.teleop_safety_stop() - - if not stop_recording: - # Start resetting env while the executor are finishing - logging.info("Reset the environment") - if play_sounds: - say("Reset the environment") - - timestamp = 0 - start_vencod_t = time.perf_counter() - - # During env reset we save the data and encode the videos - num_frames = frame_index - - for key in image_keys: - if video: - tmp_imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}" - fname = f"{key}_episode_{episode_index:06d}.mp4" - video_path = local_dir / "videos" / fname - if video_path.exists(): - video_path.unlink() - # Store the reference to the video frame, even tho the videos are not yet encoded - ep_dict[key] = [] - for i in range(num_frames): - ep_dict[key].append({"path": f"videos/{fname}", "timestamp": i / fps}) - - else: - imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}" - ep_dict[key] = [] - for i in range(num_frames): - img_path = imgs_dir / f"frame_{i:06d}.png" - ep_dict[key].append({"path": str(img_path)}) - - for key in not_image_keys: - ep_dict[key] = torch.stack(ep_dict[key]) - - for key in action: - ep_dict[key] = torch.stack(ep_dict[key]) - - ep_dict["episode_index"] = torch.tensor([episode_index] * num_frames) - ep_dict["frame_index"] = torch.arange(0, num_frames, 1) - ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps - - done = torch.zeros(num_frames, dtype=torch.bool) - done[-1] = True - ep_dict["next.done"] = done - - ep_path = episodes_dir / f"episode_{episode_index}.pth" - print("Saving episode dictionary...") - torch.save(ep_dict, ep_path) - - rec_info = { - "last_episode_index": episode_index, - } - with open(rec_info_path, "w") as f: - json.dump(rec_info, f) - - is_last_episode = stop_recording or (episode_index == (num_episodes - 1)) - - # Wait if necessary - with tqdm.tqdm(total=reset_time_s, desc="Waiting") as pbar: - while timestamp < reset_time_s and not is_last_episode: - time.sleep(1) - timestamp = time.perf_counter() - start_vencod_t - pbar.update(1) - if exit_early: - exit_early = False - break - - # Skip updating episode index which forces re-recording episode - if rerecord_episode: - rerecord_episode = False - continue - - episode_index += 1 - - if is_last_episode: - logging.info("Done recording") - if play_sounds: - say("Done recording", blocking=True) - if not is_headless(): - listener.stop() - - if has_camera > 0: - logging.info("Waiting for image writer to terminate...") - stop_image_writer(image_writer, timeout=20) - - except Exception as e: - if has_camera > 0: - logging.info("Waiting for image writer to terminate...") - stop_image_writer(image_writer, timeout=20) - raise e - - robot.disconnect() - - if display_cameras and not is_headless(): - cv2.destroyAllWindows() - - num_episodes = episode_index - - if video: - logging.info("Encoding videos") - if play_sounds: - say("Encoding videos") - # Use ffmpeg to convert frames stored as png into mp4 videos - for episode_index in tqdm.tqdm(range(num_episodes)): - for key in image_keys: - tmp_imgs_dir = 
videos_dir / f"{key}_episode_{episode_index:06d}" - fname = f"{key}_episode_{episode_index:06d}.mp4" - video_path = local_dir / "videos" / fname - if video_path.exists(): - # Skip if video is already encoded. Could be the case when resuming data recording. - continue - # note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding, - # since video encoding with ffmpeg is already using multithreading. - encode_video_frames(tmp_imgs_dir, video_path, fps, overwrite=True) - shutil.rmtree(tmp_imgs_dir) - - logging.info("Concatenating episodes") - ep_dicts = [] - for episode_index in tqdm.tqdm(range(num_episodes)): - ep_path = episodes_dir / f"episode_{episode_index}.pth" - ep_dict = torch.load(ep_path) - ep_dicts.append(ep_dict) - data_dict = concatenate_episodes(ep_dicts) - - total_frames = data_dict["frame_index"].shape[0] - data_dict["index"] = torch.arange(0, total_frames, 1) - - hf_dataset = to_hf_dataset(data_dict, video) - episode_data_index = calculate_episode_data_index(hf_dataset) - info = { - "codebase_version": CODEBASE_VERSION, - "fps": fps, - "video": video, - } - if video: - info["encoding"] = get_default_encoding() - - lerobot_dataset = LeRobotDataset.from_preloaded( - repo_id=repo_id, - hf_dataset=hf_dataset, - episode_data_index=episode_data_index, - info=info, - videos_dir=videos_dir, - ) - if run_compute_stats: - logging.info("Computing dataset statistics") - if play_sounds: - say("Computing dataset statistics") - stats = compute_stats(lerobot_dataset) - lerobot_dataset.stats = stats - else: - stats = {} - logging.info("Skipping computation of the dataset statistics") - - hf_dataset = hf_dataset.with_format(None) # to remove transforms that cant be saved - hf_dataset.save_to_disk(str(local_dir / "train")) - - meta_data_dir = local_dir / "meta_data" - save_meta_data(info, stats, episode_data_index, meta_data_dir) - - if push_to_hub: - hf_dataset.push_to_hub(repo_id, revision="main") - push_meta_data_to_hub(repo_id, meta_data_dir, revision="main") - push_dataset_card_to_hub(repo_id, revision="main", tags=tags) - if video: - push_videos_to_hub(repo_id, videos_dir, revision="main") - create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION) - - logging.info("Exiting") - if play_sounds: - say("Exiting") + log_say("Exiting", play_sounds) return lerobot_dataset +@safe_disconnect def replay( robot: Robot, episode: int, fps: int | None = None, root="data", repo_id="lerobot/debug", play_sounds=True ): + # TODO(rcadene, aliberts): refactor with control_loop, once `dataset` is an instance of LeRobotDataset # TODO(rcadene): Add option to record logs local_dir = Path(root) / repo_id if not local_dir.exists(): @@ -887,9 +328,7 @@ def replay( if not robot.is_connected: robot.connect() - logging.info("Replaying episode") - if play_sounds: - say("Replaying episode", blocking=True) + log_say("Replaying episode", play_sounds, blocking=True) for idx in range(from_idx, to_idx): start_episode_t = time.perf_counter() @@ -934,6 +373,12 @@ if __name__ == "__main__": parser_teleop.add_argument( "--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)" ) + parser_teleop.add_argument( + "--display-cameras", + type=int, + default=1, + help="Display all cameras on screen (set to 1 to display or 0).", + ) parser_record = subparsers.add_parser("record", parents=[base_parser]) parser_record.add_argument( @@ -1071,19 +516,7 @@ if __name__ == "__main__": teleoperate(robot, **kwargs) elif control_mode == "record": - 
-        pretrained_policy_name_or_path = args.pretrained_policy_name_or_path
-        policy_overrides = args.policy_overrides
-        del kwargs["pretrained_policy_name_or_path"]
-        del kwargs["policy_overrides"]
-
-        policy_cfg = None
-        if pretrained_policy_name_or_path is not None:
-            pretrained_policy_path = get_pretrained_policy_path(pretrained_policy_name_or_path)
-            policy_cfg = init_hydra_config(pretrained_policy_path / "config.yaml", policy_overrides)
-            policy = make_policy(hydra_cfg=policy_cfg, pretrained_policy_name_or_path=pretrained_policy_path)
-            record(robot, policy, policy_cfg, **kwargs)
-        else:
-            record(robot, **kwargs)
+        record(robot, **kwargs)
 
     elif control_mode == "replay":
         replay(robot, **kwargs)
 
diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index 45807503..f60f904e 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -383,7 +383,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
             logging.info(f"Checkpoint policy after step {step}")
             # Note: Save with step as the identifier, and format it to have at least 6 digits but more if
             # needed (choose 6 as a minimum for consistency without being overkill).
-            logger.save_checkpont(
+            logger.save_checkpoint(
                 step,
                 policy,
                 optimizer,
diff --git a/lerobot/templates/visualize_dataset_template.html b/lerobot/templates/visualize_dataset_template.html
index 4f0bd343..658d6ba6 100644
--- a/lerobot/templates/visualize_dataset_template.html
+++ b/lerobot/templates/visualize_dataset_template.html
@@ -250,7 +250,7 @@
         if(!canPlayVideos){
           this.videoCodecError = true;
         }
-
+        // process CSV data
         this.videos = document.querySelectorAll('video');
         this.video = this.videos[0];
diff --git a/tests/test_control_robot.py b/tests/test_control_robot.py
index 770b3489..2c0bca9b 100644
--- a/tests/test_control_robot.py
+++ b/tests/test_control_robot.py
@@ -25,12 +25,16 @@ pytest -sx 'tests/test_control_robot.py::test_teleoperate[aloha-True]'
 
 import multiprocessing
 from pathlib import Path
+from unittest.mock import patch
 
 import pytest
 
+from lerobot.common.datasets.populate_dataset import add_frame, init_dataset
+from lerobot.common.logger import Logger
 from lerobot.common.policies.factory import make_policy
 from lerobot.common.utils.utils import init_hydra_config
-from lerobot.scripts.control_robot import calibrate, get_available_arms, record, replay, teleoperate
+from lerobot.scripts.control_robot import calibrate, record, replay, teleoperate
+from lerobot.scripts.train import make_optimizer_and_scheduler
 from tests.test_robots import make_robot
 from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, TEST_ROBOT_TYPES, require_robot
 
@@ -69,7 +73,7 @@ def test_calibrate(tmpdir, request, robot_type, mock):
 
     overrides_calibration_dir = [f"calibration_dir={calibration_dir}"]
     robot = make_robot(robot_type, overrides=overrides_calibration_dir, mock=mock)
-    calibrate(robot, arms=get_available_arms(robot))
+    calibrate(robot, arms=robot.available_arms)
     del robot
 
 
@@ -109,12 +113,14 @@ def test_record_without_cameras(tmpdir, request, robot_type, mock):
 @pytest.mark.parametrize("robot_type, mock", TEST_ROBOT_TYPES)
 @require_robot
 def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
+    tmpdir = Path(tmpdir)
+
     if mock and robot_type != "aloha":
         request.getfixturevalue("patch_builtins_input")
 
         # Create an empty calibration directory to trigger manual calibration
         # and avoid writing calibration files in user .cache/calibration folder
-        calibration_dir = Path(tmpdir) / robot_type
+        calibration_dir = tmpdir / robot_type
         overrides = [f"calibration_dir={calibration_dir}"]
     else:
         # Use the default .cache/calibration folder when mock=False or for aloha
@@ -123,17 +129,19 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
     env_name = "koch_real"
     policy_name = "act_koch_real"
 
-    root = Path(tmpdir) / "data"
+    root = tmpdir / "data"
     repo_id = "lerobot/debug"
+    eval_repo_id = "lerobot/eval_debug"
 
     robot = make_robot(robot_type, overrides=overrides, mock=mock)
     dataset = record(
         robot,
-        fps=30,
-        root=root,
-        repo_id=repo_id,
+        root,
+        repo_id,
+        fps=1,
         warmup_time_s=1,
         episode_time_s=1,
+        reset_time_s=1,
         num_episodes=2,
         push_to_hub=False,
         # TODO(rcadene, aliberts): test video=True
@@ -142,8 +150,10 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
         video=False,
         display_cameras=False,
         play_sounds=False,
     )
+    assert dataset.num_episodes == 2
+    assert len(dataset) == 2
 
-    replay(robot, episode=0, fps=30, root=root, repo_id=repo_id, play_sounds=False)
+    replay(robot, episode=0, fps=1, root=root, repo_id=repo_id, play_sounds=False)
     # TODO(rcadene, aliberts): rethink this design
     if robot_type == "aloha":
@@ -164,12 +174,26 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
     if robot_type == "koch_bimanual":
         overrides += ["env.state_dim=12", "env.action_dim=12"]
 
+    overrides += ["wandb.enable=false"]
+    overrides += ["env.fps=1"]
+
     cfg = init_hydra_config(
         DEFAULT_CONFIG_PATH,
         overrides=overrides,
     )
 
     policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)
+    optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
+    out_dir = tmpdir / "logger"
+    logger = Logger(cfg, out_dir, wandb_job_name="debug")
+    logger.save_checkpoint(
+        0,
+        policy,
+        optimizer,
+        lr_scheduler,
+        identifier=0,
+    )
+    pretrained_policy_name_or_path = out_dir / "checkpoints/last/pretrained_model"
 
     # In `examples/9_use_aloha.md`, we advise using `num_image_writer_processes=1`
     # during inference, to reach constent fps, so we test this here.
@@ -194,10 +218,12 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
 
     record(
         robot,
-        policy,
-        cfg,
+        root,
+        eval_repo_id,
+        pretrained_policy_name_or_path,
         warmup_time_s=1,
         episode_time_s=1,
+        reset_time_s=1,
         num_episodes=2,
         run_compute_stats=False,
         push_to_hub=False,
@@ -207,4 +233,218 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
         num_image_writer_processes=num_image_writer_processes,
     )
+    assert dataset.num_episodes == 2
+    assert len(dataset) == 2
+    del robot
+
+
+@pytest.mark.parametrize("robot_type, mock", [("koch", True)])
+@require_robot
+def test_resume_record(tmpdir, request, robot_type, mock):
+    if mock and robot_type != "aloha":
+        request.getfixturevalue("patch_builtins_input")
+
+        # Create an empty calibration directory to trigger manual calibration
+        # and avoid writing calibration files in user .cache/calibration folder
+        calibration_dir = tmpdir / robot_type
+        overrides = [f"calibration_dir={calibration_dir}"]
+    else:
+        # Use the default .cache/calibration folder when mock=False or for aloha
+        overrides = []
+
+    robot = make_robot(robot_type, overrides=overrides, mock=mock)
+
+    root = Path(tmpdir) / "data"
+    repo_id = "lerobot/debug"
+
+    dataset = record(
+        robot,
+        root,
+        repo_id,
+        fps=1,
+        warmup_time_s=0,
+        episode_time_s=1,
+        num_episodes=1,
+        push_to_hub=False,
+        video=False,
+        display_cameras=False,
+        play_sounds=False,
+        run_compute_stats=False,
+    )
+    assert len(dataset) == 1, "`dataset` should contain only 1 frame"
+
+    init_dataset_return_value = {}
+
+    def wrapped_init_dataset(*args, **kwargs):
+        nonlocal init_dataset_return_value
+        init_dataset_return_value = init_dataset(*args, **kwargs)
+        return init_dataset_return_value
+
+    with patch("lerobot.scripts.control_robot.init_dataset", wraps=wrapped_init_dataset):
+        dataset = record(
+            robot,
+            root,
+            repo_id,
+            fps=1,
+            warmup_time_s=0,
+            episode_time_s=1,
+            num_episodes=2,
+            push_to_hub=False,
+            video=False,
+            display_cameras=False,
+            play_sounds=False,
+            run_compute_stats=False,
+        )
+    assert len(dataset) == 2, "`dataset` should contain 2 frames"
+    assert (
+        init_dataset_return_value["num_episodes"] == 2
+    ), "`init_dataset` should load the previous episode"
+
+
+@pytest.mark.parametrize("robot_type, mock", [("koch", True)])
+@require_robot
+def test_record_with_event_rerecord_episode(tmpdir, request, robot_type, mock):
+    if mock and robot_type != "aloha":
+        request.getfixturevalue("patch_builtins_input")
+
+        # Create an empty calibration directory to trigger manual calibration
+        # and avoid writing calibration files in user .cache/calibration folder
+        calibration_dir = tmpdir / robot_type
+        overrides = [f"calibration_dir={calibration_dir}"]
+    else:
+        # Use the default .cache/calibration folder when mock=False or for aloha
+        overrides = []
+
+    robot = make_robot(robot_type, overrides=overrides, mock=mock)
+    with (
+        patch("lerobot.scripts.control_robot.init_keyboard_listener") as mock_listener,
+        patch("lerobot.common.robot_devices.control_utils.add_frame", wraps=add_frame) as mock_add_frame,
+    ):
+        mock_events = {}
+        mock_events["exit_early"] = True
+        mock_events["rerecord_episode"] = True
+        mock_events["stop_recording"] = False
+        mock_listener.return_value = (None, mock_events)
+
+        root = Path(tmpdir) / "data"
+        repo_id = "lerobot/debug"
+
+        dataset = record(
+            robot,
+            root,
+            repo_id,
+            fps=1,
+            warmup_time_s=0,
+            episode_time_s=1,
+            num_episodes=1,
+            push_to_hub=False,
+            video=False,
+            display_cameras=False,
+            play_sounds=False,
+            run_compute_stats=False,
+        )
+
+    assert not mock_events["rerecord_episode"], "`rerecord_episode` wasn't properly reset to False"
+    assert not mock_events["exit_early"], "`exit_early` wasn't properly reset to False"
+    assert mock_add_frame.call_count == 2, "`add_frame` should have been called 2 times"
+    assert len(dataset) == 1, "`dataset` should contain only 1 frame"
+
+
+@pytest.mark.parametrize("robot_type, mock", [("koch", True)])
+@require_robot
+def test_record_with_event_exit_early(tmpdir, request, robot_type, mock):
+    if mock:
+        request.getfixturevalue("patch_builtins_input")
+
+        # Create an empty calibration directory to trigger manual calibration
+        # and avoid writing calibration files in user .cache/calibration folder
+        calibration_dir = tmpdir / robot_type
+        overrides = [f"calibration_dir={calibration_dir}"]
+    else:
+        # Use the default .cache/calibration folder when mock=False or for aloha
+        overrides = []
+
+    robot = make_robot(robot_type, overrides=overrides, mock=mock)
+    with (
+        patch("lerobot.scripts.control_robot.init_keyboard_listener") as mock_listener,
+        patch("lerobot.common.robot_devices.control_utils.add_frame", wraps=add_frame) as mock_add_frame,
+    ):
+        mock_events = {}
+        mock_events["exit_early"] = True
+        mock_events["rerecord_episode"] = False
+        mock_events["stop_recording"] = False
+        mock_listener.return_value = (None, mock_events)
+
+        root = Path(tmpdir) / "data"
+        repo_id = "lerobot/debug"
+
+        dataset = record(
+            robot,
+            fps=2,
+            root=root,
+            repo_id=repo_id,
+            warmup_time_s=0,
+            episode_time_s=1,
+            num_episodes=1,
+            push_to_hub=False,
+            video=False,
+            display_cameras=False,
+            play_sounds=False,
+            run_compute_stats=False,
+        )
+
+    assert not mock_events["exit_early"], "`exit_early` wasn't properly reset to False"
+    assert mock_add_frame.call_count == 1, "`add_frame` should have been called 1 time"
+    assert len(dataset) == 1, "`dataset` should contain only 1 frame"
+
+
+@pytest.mark.parametrize(
+    "robot_type, mock, num_image_writer_processes", [("koch", True, 0), ("koch", True, 1)]
+)
+@require_robot
+def test_record_with_event_stop_recording(tmpdir, request, robot_type, mock, num_image_writer_processes):
+    if mock:
+        request.getfixturevalue("patch_builtins_input")
+
+        # Create an empty calibration directory to trigger manual calibration
+        # and avoid writing calibration files in user .cache/calibration folder
+        calibration_dir = tmpdir / robot_type
+        overrides = [f"calibration_dir={calibration_dir}"]
+    else:
+        # Use the default .cache/calibration folder when mock=False or for aloha
+        overrides = []
+
+    robot = make_robot(robot_type, overrides=overrides, mock=mock)
+    with (
+        patch("lerobot.scripts.control_robot.init_keyboard_listener") as mock_listener,
+        patch("lerobot.common.robot_devices.control_utils.add_frame", wraps=add_frame) as mock_add_frame,
+    ):
+        mock_events = {}
+        mock_events["exit_early"] = True
+        mock_events["rerecord_episode"] = False
+        mock_events["stop_recording"] = True
+        mock_listener.return_value = (None, mock_events)
+
+        root = Path(tmpdir) / "data"
+        repo_id = "lerobot/debug"
+
+        dataset = record(
+            robot,
+            root,
+            repo_id,
+            fps=1,
+            warmup_time_s=0,
+            episode_time_s=1,
+            num_episodes=2,
+            push_to_hub=False,
+            video=False,
+            display_cameras=False,
+            play_sounds=False,
+            run_compute_stats=False,
+            num_image_writer_processes=num_image_writer_processes,
+        )
+
+    assert not mock_events["exit_early"], "`exit_early` wasn't properly reset to False"
+    assert mock_add_frame.call_count == 1, "`add_frame` should have been called 1 time"
+    assert len(dataset) == 1, "`dataset` should contain only 1 frame"
diff --git a/tests/test_robots.py b/tests/test_robots.py
index 72f0c944..13ad8c45 100644
--- a/tests/test_robots.py
+++ b/tests/test_robots.py
@@ -127,6 +127,7 @@ def test_robot(tmpdir, request, robot_type, mock):
             # TODO(rcadene): skipping image for now as it's challenging to assess equality between two consecutive frames
             continue
         assert torch.allclose(captured_observation[name], observation[name], atol=1)
+        assert captured_observation[name].shape == observation[name].shape
 
     # Test send_action can run
     robot.send_action(action["action"])
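
The tests above pin down the calling convention of the refactored entry points: `root` and `repo_id` follow `robot` as positional arguments, a policy is referenced by a checkpoint path rather than passed as an object, and keyboard events plus `add_frame` can be patched to script a session. A minimal sketch of that convention outside pytest follows; it is illustrative only and not part of the patch, assuming a mocked `koch` robot and the placeholder `lerobot/debug` repo id used by the tests, with timing values chosen arbitrarily.

# Illustrative sketch, not part of the patch: mirrors the record/replay calls
# exercised by the tests above.
from pathlib import Path

from lerobot.scripts.control_robot import record, replay
from tests.test_robots import make_robot

# A mocked robot, as in the tests; a real setup would pass calibration overrides.
robot = make_robot("koch", overrides=[], mock=True)

root = Path("data")
repo_id = "lerobot/debug"

# `root` and `repo_id` are positional in the refactored signature; timing and
# output options are passed as keywords, as the tests do.
dataset = record(
    robot,
    root,
    repo_id,
    fps=30,
    warmup_time_s=2,
    episode_time_s=10,
    reset_time_s=5,
    num_episodes=2,
    video=False,
    push_to_hub=False,
    display_cameras=False,
    play_sounds=False,
    run_compute_stats=False,
)

# Replay the first recorded episode on the same robot.
replay(robot, episode=0, fps=30, root=root, repo_id=repo_id, play_sounds=False)

When a checkpoint path is needed, the policy test above passes it as the fourth positional argument (`record(robot, root, eval_repo_id, pretrained_policy_name_or_path, ...)`), which is why the CLI dispatch collapses to `record(robot, **kwargs)` instead of building the policy in `__main__`.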