Refactor `record` with `add_frame` (#468)

Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
This commit is contained in:
Remi 2024-10-16 20:51:35 +02:00 committed by GitHub
parent 97b1feb0b3
commit 77478d50e5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 1232 additions and 705 deletions

View File

@ -47,6 +47,7 @@ jobs:
pipx install poetry && poetry config virtualenvs.in-project true pipx install poetry && poetry config virtualenvs.in-project true
echo "${{ github.workspace }}/.venv/bin" >> $GITHUB_PATH echo "${{ github.workspace }}/.venv/bin" >> $GITHUB_PATH
# TODO(rcadene, aliberts): python 3.12 seems to be used in the tests, not python 3.10
- name: Set up Python 3.10 - name: Set up Python 3.10
uses: actions/setup-python@v5 uses: actions/setup-python@v5
with: with:
@ -84,6 +85,7 @@ jobs:
pipx install poetry && poetry config virtualenvs.in-project true pipx install poetry && poetry config virtualenvs.in-project true
echo "${{ github.workspace }}/.venv/bin" >> $GITHUB_PATH echo "${{ github.workspace }}/.venv/bin" >> $GITHUB_PATH
# TODO(rcadene, aliberts): python 3.12 seems to be used in the tests, not python 3.10
- name: Set up Python 3.10 - name: Set up Python 3.10
uses: actions/setup-python@v5 uses: actions/setup-python@v5
with: with:

View File

@ -0,0 +1,468 @@
"""Functions to create an empty dataset, and populate it with frames."""
# TODO(rcadene, aliberts): to adapt as class methods of next version of LeRobotDataset
import concurrent
import json
import logging
import multiprocessing
import shutil
from pathlib import Path
import torch
import tqdm
from PIL import Image
from lerobot.common.datasets.compute_stats import compute_stats
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import to_hf_dataset
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, get_default_encoding
from lerobot.common.datasets.utils import calculate_episode_data_index, create_branch
from lerobot.common.datasets.video_utils import encode_video_frames
from lerobot.common.utils.utils import log_say
from lerobot.scripts.push_dataset_to_hub import (
push_dataset_card_to_hub,
push_meta_data_to_hub,
push_videos_to_hub,
save_meta_data,
)
########################################################################################
# Asynchrounous saving of images on disk
########################################################################################
def safe_stop_image_writer(func):
# TODO(aliberts): Allow to pass custom exceptions
# (e.g. ThreadServiceExit, KeyboardInterrupt, SystemExit, UnpluggedError, DynamixelCommError)
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except Exception as e:
image_writer = kwargs.get("dataset", {}).get("image_writer")
if image_writer is not None:
print("Waiting for image writer to terminate...")
stop_image_writer(image_writer, timeout=20)
raise e
return wrapper
def save_image(img_tensor, key, frame_index, episode_index, videos_dir: str):
img = Image.fromarray(img_tensor.numpy())
path = Path(videos_dir) / f"{key}_episode_{episode_index:06d}" / f"frame_{frame_index:06d}.png"
path.parent.mkdir(parents=True, exist_ok=True)
img.save(str(path), quality=100)
def loop_to_save_images_in_threads(image_queue, num_threads):
if num_threads < 1:
raise NotImplementedError(f"Only `num_threads>=1` is supported for now, but {num_threads=} given.")
with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor:
futures = []
while True:
# Blocks until a frame is available
frame_data = image_queue.get()
# As usually done, exit loop when receiving None to stop the worker
if frame_data is None:
break
image, key, frame_index, episode_index, videos_dir = frame_data
futures.append(executor.submit(save_image, image, key, frame_index, episode_index, videos_dir))
# Before exiting function, wait for all threads to complete
with tqdm.tqdm(total=len(futures), desc="Writing images") as progress_bar:
concurrent.futures.wait(futures)
progress_bar.update(len(futures))
def start_image_writer_processes(image_queue, num_processes, num_threads_per_process):
if num_processes < 1:
raise ValueError(f"Only `num_processes>=1` is supported, but {num_processes=} given.")
if num_threads_per_process < 1:
raise NotImplementedError(
"Only `num_threads_per_process>=1` is supported for now, but {num_threads_per_process=} given."
)
processes = []
for _ in range(num_processes):
process = multiprocessing.Process(
target=loop_to_save_images_in_threads,
args=(image_queue, num_threads_per_process),
)
process.start()
processes.append(process)
return processes
def stop_processes(processes, queue, timeout):
# Send None to each process to signal them to stop
for _ in processes:
queue.put(None)
# Wait maximum 20 seconds for all processes to terminate
for process in processes:
process.join(timeout=timeout)
# If not terminated after 20 seconds, force termination
if process.is_alive():
process.terminate()
# Close the queue, no more items can be put in the queue
queue.close()
# Ensure all background queue threads have finished
queue.join_thread()
def start_image_writer(num_processes, num_threads):
"""This function abstract away the initialisation of processes or/and threads to
save images on disk asynchrounously, which is critical to control a robot and record data
at a high frame rate.
When `num_processes=0`, it returns a dictionary containing a threads pool of size `num_threads`.
When `num_processes>0`, it returns a dictionary containing a processes pool of size `num_processes`,
where each subprocess starts their own threads pool of size `num_threads`.
The optimal number of processes and threads depends on your computer capabilities.
We advise to use 4 threads per camera with 0 processes. If the fps is not stable, try to increase or lower
the number of threads. If it is still not stable, try to use 1 subprocess, or more.
"""
image_writer = {}
if num_processes == 0:
futures = []
threads_pool = concurrent.futures.ThreadPoolExecutor(max_workers=num_threads)
image_writer["threads_pool"], image_writer["futures"] = threads_pool, futures
else:
# TODO(rcadene): When using num_processes>1, `multiprocessing.Manager().Queue()`
# might be better than `multiprocessing.Queue()`. Source: https://www.geeksforgeeks.org/python-multiprocessing-queue-vs-multiprocessing-manager-queue
image_queue = multiprocessing.Queue()
processes_pool = start_image_writer_processes(
image_queue, num_processes=num_processes, num_threads_per_process=num_threads
)
image_writer["processes_pool"], image_writer["image_queue"] = processes_pool, image_queue
return image_writer
def async_save_image(image_writer, image, key, frame_index, episode_index, videos_dir):
"""This function abstract away the saving of an image on disk asynchrounously. It uses a dictionary
called image writer which contains either a pool of processes or a pool of threads.
"""
if "threads_pool" in image_writer:
threads_pool, futures = image_writer["threads_pool"], image_writer["futures"]
futures.append(threads_pool.submit(save_image, image, key, frame_index, episode_index, videos_dir))
else:
image_queue = image_writer["image_queue"]
image_queue.put((image, key, frame_index, episode_index, videos_dir))
def stop_image_writer(image_writer, timeout):
if "threads_pool" in image_writer:
futures = image_writer["futures"]
# Before exiting function, wait for all threads to complete
with tqdm.tqdm(total=len(futures), desc="Writing images") as progress_bar:
concurrent.futures.wait(futures, timeout=timeout)
progress_bar.update(len(futures))
else:
processes_pool, image_queue = image_writer["processes_pool"], image_writer["image_queue"]
stop_processes(processes_pool, image_queue, timeout=timeout)
########################################################################################
# Functions to initialize, resume and populate a dataset
########################################################################################
def init_dataset(
repo_id,
root,
force_override,
fps,
video,
write_images,
num_image_writer_processes,
num_image_writer_threads,
):
local_dir = Path(root) / repo_id
if local_dir.exists() and force_override:
shutil.rmtree(local_dir)
episodes_dir = local_dir / "episodes"
episodes_dir.mkdir(parents=True, exist_ok=True)
videos_dir = local_dir / "videos"
videos_dir.mkdir(parents=True, exist_ok=True)
# Logic to resume data recording
rec_info_path = episodes_dir / "data_recording_info.json"
if rec_info_path.exists():
with open(rec_info_path) as f:
rec_info = json.load(f)
num_episodes = rec_info["last_episode_index"] + 1
else:
num_episodes = 0
dataset = {
"repo_id": repo_id,
"local_dir": local_dir,
"videos_dir": videos_dir,
"episodes_dir": episodes_dir,
"fps": fps,
"video": video,
"rec_info_path": rec_info_path,
"num_episodes": num_episodes,
}
if write_images:
# Initialize processes or/and threads dedicated to save images on disk asynchronously,
# which is critical to control a robot and record data at a high frame rate.
image_writer = start_image_writer(
num_processes=num_image_writer_processes,
num_threads=num_image_writer_threads,
)
dataset["image_writer"] = image_writer
return dataset
def add_frame(dataset, observation, action):
if "current_episode" not in dataset:
# initialize episode dictionary
ep_dict = {}
for key in observation:
if key not in ep_dict:
ep_dict[key] = []
for key in action:
if key not in ep_dict:
ep_dict[key] = []
ep_dict["episode_index"] = []
ep_dict["frame_index"] = []
ep_dict["timestamp"] = []
ep_dict["next.done"] = []
dataset["current_episode"] = ep_dict
dataset["current_frame_index"] = 0
ep_dict = dataset["current_episode"]
episode_index = dataset["num_episodes"]
frame_index = dataset["current_frame_index"]
videos_dir = dataset["videos_dir"]
video = dataset["video"]
fps = dataset["fps"]
ep_dict["episode_index"].append(episode_index)
ep_dict["frame_index"].append(frame_index)
ep_dict["timestamp"].append(frame_index / fps)
ep_dict["next.done"].append(False)
img_keys = [key for key in observation if "image" in key]
non_img_keys = [key for key in observation if "image" not in key]
# Save all observed modalities except images
for key in non_img_keys:
ep_dict[key].append(observation[key])
# Save actions
for key in action:
ep_dict[key].append(action[key])
if "image_writer" not in dataset:
dataset["current_frame_index"] += 1
return
# Save images
image_writer = dataset["image_writer"]
for key in img_keys:
imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}"
async_save_image(
image_writer,
image=observation[key],
key=key,
frame_index=frame_index,
episode_index=episode_index,
videos_dir=str(videos_dir),
)
if video:
fname = f"{key}_episode_{episode_index:06d}.mp4"
frame_info = {"path": f"videos/{fname}", "timestamp": frame_index / fps}
else:
frame_info = str(imgs_dir / f"frame_{frame_index:06d}.png")
ep_dict[key].append(frame_info)
dataset["current_frame_index"] += 1
def delete_current_episode(dataset):
del dataset["current_episode"]
del dataset["current_frame_index"]
# delete temporary images
episode_index = dataset["num_episodes"]
videos_dir = dataset["videos_dir"]
for tmp_imgs_dir in videos_dir.glob(f"*_episode_{episode_index:06d}"):
shutil.rmtree(tmp_imgs_dir)
def save_current_episode(dataset):
episode_index = dataset["num_episodes"]
ep_dict = dataset["current_episode"]
episodes_dir = dataset["episodes_dir"]
rec_info_path = dataset["rec_info_path"]
ep_dict["next.done"][-1] = True
for key in ep_dict:
if "observation" in key and "image" not in key:
ep_dict[key] = torch.stack(ep_dict[key])
ep_dict["action"] = torch.stack(ep_dict["action"])
ep_dict["episode_index"] = torch.tensor(ep_dict["episode_index"])
ep_dict["frame_index"] = torch.tensor(ep_dict["frame_index"])
ep_dict["timestamp"] = torch.tensor(ep_dict["timestamp"])
ep_dict["next.done"] = torch.tensor(ep_dict["next.done"])
ep_path = episodes_dir / f"episode_{episode_index}.pth"
torch.save(ep_dict, ep_path)
rec_info = {
"last_episode_index": episode_index,
}
with open(rec_info_path, "w") as f:
json.dump(rec_info, f)
# force re-initialization of episode dictionnary during add_frame
del dataset["current_episode"]
dataset["num_episodes"] += 1
def encode_videos(dataset, image_keys, play_sounds):
log_say("Encoding videos", play_sounds)
num_episodes = dataset["num_episodes"]
videos_dir = dataset["videos_dir"]
local_dir = dataset["local_dir"]
fps = dataset["fps"]
# Use ffmpeg to convert frames stored as png into mp4 videos
for episode_index in tqdm.tqdm(range(num_episodes)):
for key in image_keys:
# key = f"observation.images.{name}"
tmp_imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}"
fname = f"{key}_episode_{episode_index:06d}.mp4"
video_path = local_dir / "videos" / fname
if video_path.exists():
# Skip if video is already encoded. Could be the case when resuming data recording.
continue
# note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
# since video encoding with ffmpeg is already using multithreading.
encode_video_frames(tmp_imgs_dir, video_path, fps, overwrite=True)
shutil.rmtree(tmp_imgs_dir)
def from_dataset_to_lerobot_dataset(dataset, play_sounds):
log_say("Consolidate episodes", play_sounds)
num_episodes = dataset["num_episodes"]
episodes_dir = dataset["episodes_dir"]
videos_dir = dataset["videos_dir"]
video = dataset["video"]
fps = dataset["fps"]
repo_id = dataset["repo_id"]
ep_dicts = []
for episode_index in tqdm.tqdm(range(num_episodes)):
ep_path = episodes_dir / f"episode_{episode_index}.pth"
ep_dict = torch.load(ep_path)
ep_dicts.append(ep_dict)
data_dict = concatenate_episodes(ep_dicts)
if video:
image_keys = [key for key in data_dict if "image" in key]
encode_videos(dataset, image_keys, play_sounds)
hf_dataset = to_hf_dataset(data_dict, video)
episode_data_index = calculate_episode_data_index(hf_dataset)
info = {
"codebase_version": CODEBASE_VERSION,
"fps": fps,
"video": video,
}
if video:
info["encoding"] = get_default_encoding()
lerobot_dataset = LeRobotDataset.from_preloaded(
repo_id=repo_id,
hf_dataset=hf_dataset,
episode_data_index=episode_data_index,
info=info,
videos_dir=videos_dir,
)
return lerobot_dataset
def save_lerobot_dataset_on_disk(lerobot_dataset):
hf_dataset = lerobot_dataset.hf_dataset
info = lerobot_dataset.info
stats = lerobot_dataset.stats
episode_data_index = lerobot_dataset.episode_data_index
local_dir = lerobot_dataset.videos_dir.parent
meta_data_dir = local_dir / "meta_data"
hf_dataset = hf_dataset.with_format(None) # to remove transforms that cant be saved
hf_dataset.save_to_disk(str(local_dir / "train"))
save_meta_data(info, stats, episode_data_index, meta_data_dir)
def push_lerobot_dataset_to_hub(lerobot_dataset, tags):
hf_dataset = lerobot_dataset.hf_dataset
local_dir = lerobot_dataset.videos_dir.parent
videos_dir = lerobot_dataset.videos_dir
repo_id = lerobot_dataset.repo_id
video = lerobot_dataset.video
meta_data_dir = local_dir / "meta_data"
if not (local_dir / "train").exists():
raise ValueError(
"You need to run `save_lerobot_dataset_on_disk(lerobot_dataset)` before pushing to the hub."
)
hf_dataset.push_to_hub(repo_id, revision="main")
push_meta_data_to_hub(repo_id, meta_data_dir, revision="main")
push_dataset_card_to_hub(repo_id, revision="main", tags=tags)
if video:
push_videos_to_hub(repo_id, videos_dir, revision="main")
create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION)
def create_lerobot_dataset(dataset, run_compute_stats, push_to_hub, tags, play_sounds):
if "image_writer" in dataset:
logging.info("Waiting for image writer to terminate...")
image_writer = dataset["image_writer"]
stop_image_writer(image_writer, timeout=20)
lerobot_dataset = from_dataset_to_lerobot_dataset(dataset, play_sounds)
if run_compute_stats:
log_say("Computing dataset statistics", play_sounds)
lerobot_dataset.stats = compute_stats(lerobot_dataset)
else:
logging.info("Skipping computation of the dataset statistics")
lerobot_dataset.stats = {}
save_lerobot_dataset_on_disk(lerobot_dataset)
if push_to_hub:
push_lerobot_dataset_to_hub(lerobot_dataset, tags)
return lerobot_dataset

View File

@ -189,7 +189,7 @@ class Logger:
training_state["scheduler"] = scheduler.state_dict() training_state["scheduler"] = scheduler.state_dict()
torch.save(training_state, save_dir / self.training_state_file_name) torch.save(training_state, save_dir / self.training_state_file_name)
def save_checkpont( def save_checkpoint(
self, self,
train_step: int, train_step: int,
policy: Policy, policy: Policy,

View File

@ -0,0 +1,330 @@
########################################################################################
# Utilities
########################################################################################
import logging
import time
import traceback
from contextlib import nullcontext
from copy import copy
from functools import cache
import cv2
import torch
import tqdm
from termcolor import colored
from lerobot.common.datasets.populate_dataset import add_frame, safe_stop_image_writer
from lerobot.common.policies.factory import make_policy
from lerobot.common.robot_devices.robots.utils import Robot
from lerobot.common.robot_devices.utils import busy_wait
from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, set_global_seed
from lerobot.scripts.eval import get_pretrained_policy_path
def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, fps=None):
log_items = []
if episode_index is not None:
log_items.append(f"ep:{episode_index}")
if frame_index is not None:
log_items.append(f"frame:{frame_index}")
def log_dt(shortname, dt_val_s):
nonlocal log_items, fps
info_str = f"{shortname}:{dt_val_s * 1000:5.2f} ({1/ dt_val_s:3.1f}hz)"
if fps is not None:
actual_fps = 1 / dt_val_s
if actual_fps < fps - 1:
info_str = colored(info_str, "yellow")
log_items.append(info_str)
# total step time displayed in milliseconds and its frequency
log_dt("dt", dt_s)
# TODO(aliberts): move robot-specific logs logic in robot.print_logs()
if not robot.robot_type.startswith("stretch"):
for name in robot.leader_arms:
key = f"read_leader_{name}_pos_dt_s"
if key in robot.logs:
log_dt("dtRlead", robot.logs[key])
for name in robot.follower_arms:
key = f"write_follower_{name}_goal_pos_dt_s"
if key in robot.logs:
log_dt("dtWfoll", robot.logs[key])
key = f"read_follower_{name}_pos_dt_s"
if key in robot.logs:
log_dt("dtRfoll", robot.logs[key])
for name in robot.cameras:
key = f"read_camera_{name}_dt_s"
if key in robot.logs:
log_dt(f"dtR{name}", robot.logs[key])
info_str = " ".join(log_items)
logging.info(info_str)
@cache
def is_headless():
"""Detects if python is running without a monitor."""
try:
import pynput # noqa
return False
except Exception:
print(
"Error trying to import pynput. Switching to headless mode. "
"As a result, the video stream from the cameras won't be shown, "
"and you won't be able to change the control flow with keyboards. "
"For more info, see traceback below.\n"
)
traceback.print_exc()
print()
return True
def has_method(_object: object, method_name: str):
return hasattr(_object, method_name) and callable(getattr(_object, method_name))
def predict_action(observation, policy, device, use_amp):
observation = copy(observation)
with (
torch.inference_mode(),
torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(),
):
# Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
for name in observation:
if "image" in name:
observation[name] = observation[name].type(torch.float32) / 255
observation[name] = observation[name].permute(2, 0, 1).contiguous()
observation[name] = observation[name].unsqueeze(0)
observation[name] = observation[name].to(device)
# Compute the next action with the policy
# based on the current observation
action = policy.select_action(observation)
# Remove batch dimension
action = action.squeeze(0)
# Move to cpu, if not already the case
action = action.to("cpu")
return action
def init_keyboard_listener():
# Allow to exit early while recording an episode or resetting the environment,
# by tapping the right arrow key '->'. This might require a sudo permission
# to allow your terminal to monitor keyboard events.
events = {}
events["exit_early"] = False
events["rerecord_episode"] = False
events["stop_recording"] = False
if is_headless():
logging.warning(
"Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
)
listener = None
return listener, events
# Only import pynput if not in a headless environment
from pynput import keyboard
def on_press(key):
try:
if key == keyboard.Key.right:
print("Right arrow key pressed. Exiting loop...")
events["exit_early"] = True
elif key == keyboard.Key.left:
print("Left arrow key pressed. Exiting loop and rerecord the last episode...")
events["rerecord_episode"] = True
events["exit_early"] = True
elif key == keyboard.Key.esc:
print("Escape key pressed. Stopping data recording...")
events["stop_recording"] = True
events["exit_early"] = True
except Exception as e:
print(f"Error handling key press: {e}")
listener = keyboard.Listener(on_press=on_press)
listener.start()
return listener, events
def init_policy(pretrained_policy_name_or_path, policy_overrides):
"""Instantiate the policy and load fps, device and use_amp from config yaml"""
pretrained_policy_path = get_pretrained_policy_path(pretrained_policy_name_or_path)
hydra_cfg = init_hydra_config(pretrained_policy_path / "config.yaml", policy_overrides)
policy = make_policy(hydra_cfg=hydra_cfg, pretrained_policy_name_or_path=pretrained_policy_path)
# Check device is available
device = get_safe_torch_device(hydra_cfg.device, log=True)
use_amp = hydra_cfg.use_amp
policy_fps = hydra_cfg.env.fps
policy.eval()
policy.to(device)
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
set_global_seed(hydra_cfg.seed)
return policy, policy_fps, device, use_amp
def warmup_record(
robot,
events,
enable_teloperation,
warmup_time_s,
display_cameras,
fps,
):
control_loop(
robot=robot,
control_time_s=warmup_time_s,
display_cameras=display_cameras,
events=events,
fps=fps,
teleoperate=enable_teloperation,
)
def record_episode(
robot,
dataset,
events,
episode_time_s,
display_cameras,
policy,
device,
use_amp,
fps,
):
control_loop(
robot=robot,
control_time_s=episode_time_s,
display_cameras=display_cameras,
dataset=dataset,
events=events,
policy=policy,
device=device,
use_amp=use_amp,
fps=fps,
teleoperate=policy is None,
)
@safe_stop_image_writer
def control_loop(
robot,
control_time_s=None,
teleoperate=False,
display_cameras=False,
dataset=None,
events=None,
policy=None,
device=None,
use_amp=None,
fps=None,
):
# TODO(rcadene): Add option to record logs
if not robot.is_connected:
robot.connect()
if events is None:
events = {"exit_early": False}
if control_time_s is None:
control_time_s = float("inf")
if teleoperate and policy is not None:
raise ValueError("When `teleoperate` is True, `policy` should be None.")
if dataset is not None and fps is not None and dataset["fps"] != fps:
raise ValueError(f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps}).")
timestamp = 0
start_episode_t = time.perf_counter()
while timestamp < control_time_s:
start_loop_t = time.perf_counter()
if teleoperate:
observation, action = robot.teleop_step(record_data=True)
else:
observation = robot.capture_observation()
if policy is not None:
pred_action = predict_action(observation, policy, device, use_amp)
# Action can eventually be clipped using `max_relative_target`,
# so action actually sent is saved in the dataset.
action = robot.send_action(pred_action)
action = {"action": action}
if dataset is not None:
add_frame(dataset, observation, action)
if display_cameras and not is_headless():
image_keys = [key for key in observation if "image" in key]
for key in image_keys:
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
cv2.waitKey(1)
if fps is not None:
dt_s = time.perf_counter() - start_loop_t
busy_wait(1 / fps - dt_s)
dt_s = time.perf_counter() - start_loop_t
log_control_info(robot, dt_s, fps=fps)
timestamp = time.perf_counter() - start_episode_t
if events["exit_early"]:
events["exit_early"] = False
break
def reset_environment(robot, events, reset_time_s):
# TODO(rcadene): refactor warmup_record and reset_environment
# TODO(alibets): allow for teleop during reset
if has_method(robot, "teleop_safety_stop"):
robot.teleop_safety_stop()
timestamp = 0
start_vencod_t = time.perf_counter()
# Wait if necessary
with tqdm.tqdm(total=reset_time_s, desc="Waiting") as pbar:
while timestamp < reset_time_s:
time.sleep(1)
timestamp = time.perf_counter() - start_vencod_t
pbar.update(1)
if events["exit_early"]:
events["exit_early"] = False
break
def stop_recording(robot, listener, display_cameras):
robot.disconnect()
if not is_headless():
if listener is not None:
listener.stop()
if display_cameras:
cv2.destroyAllWindows()
def sanity_check_dataset_name(repo_id, policy):
_, dataset_name = repo_id.split("/")
# either repo_id doesnt start with "eval_" and there is no policy
# or repo_id starts with "eval_" and there is a policy
if dataset_name.startswith("eval_") == (policy is None):
raise ValueError(
f"Your dataset name begins by 'eval_' ({dataset_name}) but no policy is provided ({policy})."
)

View File

@ -349,6 +349,25 @@ class ManipulatorRobot:
self.is_connected = False self.is_connected = False
self.logs = {} self.logs = {}
@property
def has_camera(self):
return len(self.cameras) > 0
@property
def num_cameras(self):
return len(self.cameras)
@property
def available_arms(self):
available_arms = []
for name in self.follower_arms:
arm_id = get_arm_id(name, "follower")
available_arms.append(arm_id)
for name in self.leader_arms:
arm_id = get_arm_id(name, "leader")
available_arms.append(arm_id)
return available_arms
def connect(self): def connect(self):
if self.is_connected: if self.is_connected:
raise RobotDeviceAlreadyConnectedError( raise RobotDeviceAlreadyConnectedError(

View File

@ -16,6 +16,7 @@
import logging import logging
import os import os
import os.path as osp import os.path as osp
import platform
import random import random
from contextlib import contextmanager from contextlib import contextmanager
from datetime import datetime, timezone from datetime import datetime, timezone
@ -28,6 +29,12 @@ import torch
from omegaconf import DictConfig from omegaconf import DictConfig
def none_or_int(value):
if value == "None":
return None
return int(value)
def inside_slurm(): def inside_slurm():
"""Check whether the python process was launched through slurm""" """Check whether the python process was launched through slurm"""
# TODO(rcadene): return False for interactive mode `--pty bash` # TODO(rcadene): return False for interactive mode `--pty bash`
@ -183,3 +190,30 @@ def print_cuda_memory_usage():
def capture_timestamp_utc(): def capture_timestamp_utc():
return datetime.now(timezone.utc) return datetime.now(timezone.utc)
def say(text, blocking=False):
# Check if mac, linux, or windows.
if platform.system() == "Darwin":
cmd = f'say "{text}"'
elif platform.system() == "Linux":
cmd = f'spd-say "{text}"'
elif platform.system() == "Windows":
cmd = (
'PowerShell -Command "Add-Type -AssemblyName System.Speech; '
f"(New-Object System.Speech.Synthesis.SpeechSynthesizer).Speak('{text}')\""
)
if not blocking and platform.system() in ["Darwin", "Linux"]:
# TODO(rcadene): Make it work for Windows
# Use the ampersand to run command in the background
cmd += " &"
os.system(cmd)
def log_say(text, play_sounds, blocking=False):
logging.info(text)
if play_sounds:
say(text, blocking)

View File

@ -99,285 +99,35 @@ python lerobot/scripts/control_robot.py record \
""" """
import argparse import argparse
import concurrent.futures
import json
import logging import logging
import multiprocessing
import os
import platform
import shutil
import time import time
import traceback
from contextlib import nullcontext
from functools import cache
from pathlib import Path from pathlib import Path
from typing import List
import cv2
import torch
import tqdm
from omegaconf import DictConfig
from PIL import Image
from termcolor import colored
# from safetensors.torch import load_file, save_file # from safetensors.torch import load_file, save_file
from lerobot.common.datasets.compute_stats import compute_stats from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset from lerobot.common.datasets.populate_dataset import (
from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import to_hf_dataset create_lerobot_dataset,
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, get_default_encoding delete_current_episode,
from lerobot.common.datasets.utils import calculate_episode_data_index, create_branch init_dataset,
from lerobot.common.datasets.video_utils import encode_video_frames save_current_episode,
from lerobot.common.policies.factory import make_policy
from lerobot.common.robot_devices.robots.factory import make_robot
from lerobot.common.robot_devices.robots.utils import Robot, get_arm_id
from lerobot.common.robot_devices.utils import busy_wait, safe_disconnect
from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed
from lerobot.scripts.eval import get_pretrained_policy_path
from lerobot.scripts.push_dataset_to_hub import (
push_dataset_card_to_hub,
push_meta_data_to_hub,
push_videos_to_hub,
save_meta_data,
) )
from lerobot.common.robot_devices.control_utils import (
######################################################################################## control_loop,
# Utilities has_method,
######################################################################################## init_keyboard_listener,
init_policy,
log_control_info,
def say(text, blocking=False): record_episode,
# Check if mac, linux, or windows. reset_environment,
if platform.system() == "Darwin": sanity_check_dataset_name,
cmd = f'say "{text}"' stop_recording,
elif platform.system() == "Linux": warmup_record,
cmd = f'spd-say "{text}"' )
elif platform.system() == "Windows": from lerobot.common.robot_devices.robots.factory import make_robot
cmd = ( from lerobot.common.robot_devices.robots.utils import Robot
'PowerShell -Command "Add-Type -AssemblyName System.Speech; ' from lerobot.common.robot_devices.utils import busy_wait, safe_disconnect
f"(New-Object System.Speech.Synthesis.SpeechSynthesizer).Speak('{text}')\"" from lerobot.common.utils.utils import init_hydra_config, init_logging, log_say, none_or_int
)
if not blocking and platform.system() in ["Darwin", "Linux"]:
# TODO(rcadene): Make it work for Windows
# Use the ampersand to run command in the background
cmd += " &"
os.system(cmd)
def save_image(img_tensor, key, frame_index, episode_index, videos_dir: str):
img = Image.fromarray(img_tensor.numpy())
path = Path(videos_dir) / f"{key}_episode_{episode_index:06d}" / f"frame_{frame_index:06d}.png"
path.parent.mkdir(parents=True, exist_ok=True)
img.save(str(path), quality=100)
def none_or_int(value):
if value == "None":
return None
return int(value)
def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, fps=None):
log_items = []
if episode_index is not None:
log_items.append(f"ep:{episode_index}")
if frame_index is not None:
log_items.append(f"frame:{frame_index}")
def log_dt(shortname, dt_val_s):
nonlocal log_items, fps
info_str = f"{shortname}:{dt_val_s * 1000:5.2f} ({1/ dt_val_s:3.1f}hz)"
if fps is not None:
actual_fps = 1 / dt_val_s
if actual_fps < fps - 1:
info_str = colored(info_str, "yellow")
log_items.append(info_str)
# total step time displayed in milliseconds and its frequency
log_dt("dt", dt_s)
# TODO(aliberts): move robot-specific logs logic in robot.print_logs()
if not robot.robot_type.startswith("stretch"):
for name in robot.leader_arms:
key = f"read_leader_{name}_pos_dt_s"
if key in robot.logs:
log_dt("dtRlead", robot.logs[key])
for name in robot.follower_arms:
key = f"write_follower_{name}_goal_pos_dt_s"
if key in robot.logs:
log_dt("dtWfoll", robot.logs[key])
key = f"read_follower_{name}_pos_dt_s"
if key in robot.logs:
log_dt("dtRfoll", robot.logs[key])
for name in robot.cameras:
key = f"read_camera_{name}_dt_s"
if key in robot.logs:
log_dt(f"dtR{name}", robot.logs[key])
info_str = " ".join(log_items)
logging.info(info_str)
@cache
def is_headless():
"""Detects if python is running without a monitor."""
try:
import pynput # noqa
return False
except Exception:
print(
"Error trying to import pynput. Switching to headless mode. "
"As a result, the video stream from the cameras won't be shown, "
"and you won't be able to change the control flow with keyboards. "
"For more info, see traceback below.\n"
)
traceback.print_exc()
print()
return True
def has_method(_object: object, method_name: str):
return hasattr(_object, method_name) and callable(getattr(_object, method_name))
def get_available_arms(robot):
# TODO(rcadene): moves this function in manipulator class?
available_arms = []
for name in robot.follower_arms:
arm_id = get_arm_id(name, "follower")
available_arms.append(arm_id)
for name in robot.leader_arms:
arm_id = get_arm_id(name, "leader")
available_arms.append(arm_id)
return available_arms
########################################################################################
# Asynchrounous saving of images on disk
########################################################################################
def loop_to_save_images_in_threads(image_queue, num_threads):
if num_threads < 1:
raise NotImplementedError(f"Only `num_threads>=1` is supported for now, but {num_threads=} given.")
with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor:
futures = []
while True:
# Blocks until a frame is available
frame_data = image_queue.get()
# As usually done, exit loop when receiving None to stop the worker
if frame_data is None:
break
image, key, frame_index, episode_index, videos_dir = frame_data
futures.append(executor.submit(save_image, image, key, frame_index, episode_index, videos_dir))
# Before exiting function, wait for all threads to complete
with tqdm.tqdm(total=len(futures), desc="Writing images") as progress_bar:
concurrent.futures.wait(futures)
progress_bar.update(len(futures))
def start_image_writer_processes(image_queue, num_processes, num_threads_per_process):
if num_processes < 1:
raise ValueError(f"Only `num_processes>=1` is supported, but {num_processes=} given.")
if num_threads_per_process < 1:
raise NotImplementedError(
"Only `num_threads_per_process>=1` is supported for now, but {num_threads_per_process=} given."
)
processes = []
for _ in range(num_processes):
process = multiprocessing.Process(
target=loop_to_save_images_in_threads,
args=(image_queue, num_threads_per_process),
)
process.start()
processes.append(process)
return processes
def stop_processes(processes, queue, timeout):
# Send None to each process to signal them to stop
for _ in processes:
queue.put(None)
# Close the queue, no more items can be put in the queue
queue.close()
# Wait maximum 20 seconds for all processes to terminate
for process in processes:
process.join(timeout=timeout)
# If not terminated after 20 seconds, force termination
if process.is_alive():
process.terminate()
# Ensure all background queue threads have finished
queue.join_thread()
def start_image_writer(num_processes, num_threads):
"""This function abstract away the initialisation of processes or/and threads to
save images on disk asynchrounously, which is critical to control a robot and record data
at a high frame rate.
When `num_processes=0`, it returns a dictionary containing a threads pool of size `num_threads`.
When `num_processes>0`, it returns a dictionary containing a processes pool of size `num_processes`,
where each subprocess starts their own threads pool of size `num_threads`.
The optimal number of processes and threads depends on your computer capabilities.
We advise to use 4 threads per camera with 0 processes. If the fps is not stable, try to increase or lower
the number of threads. If it is still not stable, try to use 1 subprocess, or more.
"""
image_writer = {}
if num_processes == 0:
futures = []
threads_pool = concurrent.futures.ThreadPoolExecutor(max_workers=num_threads)
image_writer["threads_pool"], image_writer["futures"] = threads_pool, futures
else:
# TODO(rcadene): When using num_processes>1, `multiprocessing.Manager().Queue()`
# might be better than `multiprocessing.Queue()`. Source: https://www.geeksforgeeks.org/python-multiprocessing-queue-vs-multiprocessing-manager-queue
image_queue = multiprocessing.Queue()
processes_pool = start_image_writer_processes(
image_queue, num_processes=num_processes, num_threads_per_process=num_threads
)
image_writer["processes_pool"], image_writer["image_queue"] = processes_pool, image_queue
return image_writer
def async_save_image(image_writer, image, key, frame_index, episode_index, videos_dir):
"""This function abstract away the saving of an image on disk asynchrounously. It uses a dictionary
called image writer which contains either a pool of processes or a pool of threads.
"""
if "threads_pool" in image_writer:
threads_pool, futures = image_writer["threads_pool"], image_writer["futures"]
futures.append(threads_pool.submit(save_image, image, key, frame_index, episode_index, videos_dir))
else:
image_queue = image_writer["image_queue"]
image_queue.put((image, key, frame_index, episode_index, videos_dir))
def stop_image_writer(image_writer, timeout):
if "threads_pool" in image_writer:
futures = image_writer["futures"]
# Before exiting function, wait for all threads to complete
with tqdm.tqdm(total=len(futures), desc="Writing images") as progress_bar:
concurrent.futures.wait(futures, timeout=timeout)
progress_bar.update(len(futures))
else:
processes_pool, image_queue = image_writer["processes_pool"], image_writer["image_queue"]
stop_processes(processes_pool, image_queue, timeout=timeout)
######################################################################################## ########################################################################################
# Control modes # Control modes
@ -394,9 +144,8 @@ def calibrate(robot: Robot, arms: list[str] | None):
robot.home() robot.home()
return return
available_arms = get_available_arms(robot) unknown_arms = [arm_id for arm_id in arms if arm_id not in robot.available_arms]
unknown_arms = [arm_id for arm_id in arms if arm_id not in available_arms] available_arms_str = " ".join(robot.available_arms)
available_arms_str = " ".join(available_arms)
unknown_arms_str = " ".join(unknown_arms) unknown_arms_str = " ".join(unknown_arms)
if arms is None or len(arms) == 0: if arms is None or len(arms) == 0:
@ -429,35 +178,26 @@ def calibrate(robot: Robot, arms: list[str] | None):
@safe_disconnect @safe_disconnect
def teleoperate(robot: Robot, fps: int | None = None, teleop_time_s: float | None = None): def teleoperate(
# TODO(rcadene): Add option to record logs robot: Robot, fps: int | None = None, teleop_time_s: float | None = None, display_cameras: bool = False
if not robot.is_connected: ):
robot.connect() control_loop(
robot,
start_teleop_t = time.perf_counter() control_time_s=teleop_time_s,
while True: fps=fps,
start_loop_t = time.perf_counter() teleoperate=True,
robot.teleop_step() display_cameras=display_cameras,
)
if fps is not None:
dt_s = time.perf_counter() - start_loop_t
busy_wait(1 / fps - dt_s)
dt_s = time.perf_counter() - start_loop_t
log_control_info(robot, dt_s, fps=fps)
if teleop_time_s is not None and time.perf_counter() - start_teleop_t > teleop_time_s:
break
@safe_disconnect @safe_disconnect
def record( def record(
robot: Robot, robot: Robot,
policy: torch.nn.Module | None = None, root: str,
hydra_cfg: DictConfig | None = None, repo_id: str,
pretrained_policy_name_or_path: str | None = None,
policy_overrides: List[str] | None = None,
fps: int | None = None, fps: int | None = None,
root="data",
repo_id="lerobot/debug",
warmup_time_s=2, warmup_time_s=2,
episode_time_s=10, episode_time_s=10,
reset_time_s=5, reset_time_s=5,
@ -473,407 +213,108 @@ def record(
play_sounds=True, play_sounds=True,
): ):
# TODO(rcadene): Add option to record logs # TODO(rcadene): Add option to record logs
# TODO(rcadene): Clean this function via decomposition in higher level functions listener = None
events = None
policy = None
device = None
use_amp = None
_, dataset_name = repo_id.split("/") # Load pretrained policy
if dataset_name.startswith("eval_") and policy is None: if pretrained_policy_name_or_path is not None:
raise ValueError( policy, policy_fps, device, use_amp = init_policy(pretrained_policy_name_or_path, policy_overrides)
f"Your dataset name begins by 'eval_' ({dataset_name}) but no policy is provided ({policy})."
if fps is None:
fps = policy_fps
logging.warning(f"No fps provided, so using the fps from policy config ({policy_fps}).")
elif fps != policy_fps:
logging.warning(
f"There is a mismatch between the provided fps ({fps}) and the one from policy config ({policy_fps})."
)
# Create empty dataset or load existing saved episodes
sanity_check_dataset_name(repo_id, policy)
dataset = init_dataset(
repo_id,
root,
force_override,
fps,
video,
write_images=robot.has_camera,
num_image_writer_processes=num_image_writer_processes,
num_image_writer_threads=num_image_writer_threads_per_camera * robot.num_cameras,
) )
if not robot.is_connected: if not robot.is_connected:
robot.connect() robot.connect()
local_dir = Path(root) / repo_id listener, events = init_keyboard_listener()
if local_dir.exists() and force_override:
shutil.rmtree(local_dir)
episodes_dir = local_dir / "episodes" # Execute a few seconds without recording to:
episodes_dir.mkdir(parents=True, exist_ok=True) # 1. teleoperate the robot to move it in starting position if no policy provided,
# 2. give times to the robot devices to connect and start synchronizing,
videos_dir = local_dir / "videos" # 3. place the cameras windows on screen
videos_dir.mkdir(parents=True, exist_ok=True) enable_teleoperation = policy is None
log_say("Warmup record", play_sounds)
# Logic to resume data recording warmup_record(robot, events, enable_teleoperation, warmup_time_s, display_cameras, fps)
rec_info_path = episodes_dir / "data_recording_info.json"
if rec_info_path.exists():
with open(rec_info_path) as f:
rec_info = json.load(f)
episode_index = rec_info["last_episode_index"] + 1
else:
episode_index = 0
if is_headless():
logging.warning(
"Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
)
# Allow to exit early while recording an episode or resetting the environment,
# by tapping the right arrow key '->'. This might require a sudo permission
# to allow your terminal to monitor keyboard events.
exit_early = False
rerecord_episode = False
stop_recording = False
# Only import pynput if not in a headless environment
if not is_headless():
from pynput import keyboard
def on_press(key):
nonlocal exit_early, rerecord_episode, stop_recording
try:
if key == keyboard.Key.right:
print("Right arrow key pressed. Exiting loop...")
exit_early = True
elif key == keyboard.Key.left:
print("Left arrow key pressed. Exiting loop and rerecord the last episode...")
rerecord_episode = True
exit_early = True
elif key == keyboard.Key.esc:
print("Escape key pressed. Stopping data recording...")
stop_recording = True
exit_early = True
except Exception as e:
print(f"Error handling key press: {e}")
listener = keyboard.Listener(on_press=on_press)
listener.start()
# Load policy if any
if policy is not None:
# Check device is available
device = get_safe_torch_device(hydra_cfg.device, log=True)
policy.eval()
policy.to(device)
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
set_global_seed(hydra_cfg.seed)
# override fps using policy fps
fps = hydra_cfg.env.fps
# Execute a few seconds without recording data, to give times
# to the robot devices to connect and start synchronizing.
timestamp = 0
start_warmup_t = time.perf_counter()
is_warmup_print = False
while timestamp < warmup_time_s:
if not is_warmup_print:
logging.info("Warming up (no data recording)")
if play_sounds:
say("Warming up")
is_warmup_print = True
start_loop_t = time.perf_counter()
if policy is None:
observation, action = robot.teleop_step(record_data=True)
else:
observation = robot.capture_observation()
if display_cameras and not is_headless():
image_keys = [key for key in observation if "image" in key]
for key in image_keys:
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
cv2.waitKey(1)
dt_s = time.perf_counter() - start_loop_t
busy_wait(1 / fps - dt_s)
dt_s = time.perf_counter() - start_loop_t
log_control_info(robot, dt_s, fps=fps)
timestamp = time.perf_counter() - start_warmup_t
if has_method(robot, "teleop_safety_stop"): if has_method(robot, "teleop_safety_stop"):
robot.teleop_safety_stop() robot.teleop_safety_stop()
has_camera = len(robot.cameras) > 0 while True:
if has_camera: if dataset["num_episodes"] >= num_episodes:
# Initialize processes or/and threads dedicated to save images on disk asynchronously, break
# which is critical to control a robot and record data at a high frame rate.
image_writer = start_image_writer( episode_index = dataset["num_episodes"]
num_processes=num_image_writer_processes, log_say(f"Recording episode {episode_index}", play_sounds)
num_threads=num_image_writer_threads_per_camera * len(robot.cameras), record_episode(
dataset=dataset,
robot=robot,
events=events,
episode_time_s=episode_time_s,
display_cameras=display_cameras,
policy=policy,
device=device,
use_amp=use_amp,
fps=fps,
) )
# Using `try` to exist smoothly if an exception is raised # Execute a few seconds without recording to give time to manually reset the environment
try: # Current code logic doesn't allow to teleoperate during this time.
# Start recording all episodes # TODO(rcadene): add an option to enable teleoperation during reset
while episode_index < num_episodes: # Skip reset for the last episode to be recorded
logging.info(f"Recording episode {episode_index}") if not events["stop_recording"] and (
if play_sounds: (episode_index < num_episodes - 1) or events["rerecord_episode"]
say(f"Recording episode {episode_index}")
ep_dict = {}
frame_index = 0
timestamp = 0
start_episode_t = time.perf_counter()
while timestamp < episode_time_s:
start_loop_t = time.perf_counter()
if policy is None:
observation, action = robot.teleop_step(record_data=True)
else:
observation = robot.capture_observation()
image_keys = [key for key in observation if "image" in key]
not_image_keys = [key for key in observation if "image" not in key]
if has_camera > 0:
for key in image_keys:
async_save_image(
image_writer,
image=observation[key],
key=key,
frame_index=frame_index,
episode_index=episode_index,
videos_dir=str(videos_dir),
)
if display_cameras and not is_headless():
image_keys = [key for key in observation if "image" in key]
for key in image_keys:
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
cv2.waitKey(1)
for key in not_image_keys:
if key not in ep_dict:
ep_dict[key] = []
ep_dict[key].append(observation[key])
if policy is not None:
with (
torch.inference_mode(),
torch.autocast(device_type=device.type)
if device.type == "cuda" and hydra_cfg.use_amp
else nullcontext(),
): ):
# Convert to pytorch format: channel first and float32 in [0,1] with batch dimension log_say("Reset the environment", play_sounds)
for name in observation: reset_environment(robot, events, reset_time_s)
if "image" in name:
observation[name] = observation[name].type(torch.float32) / 255
observation[name] = observation[name].permute(2, 0, 1).contiguous()
observation[name] = observation[name].unsqueeze(0)
observation[name] = observation[name].to(device)
# Compute the next action with the policy if events["rerecord_episode"]:
# based on the current observation log_say("Re-record episode", play_sounds)
action = policy.select_action(observation) events["rerecord_episode"] = False
events["exit_early"] = False
# Remove batch dimension delete_current_episode(dataset)
action = action.squeeze(0)
# Move to cpu, if not already the case
action = action.to("cpu")
# Order the robot to move
action_sent = robot.send_action(action)
# Action can eventually be clipped using `max_relative_target`,
# so action actually sent is saved in the dataset.
action = {"action": action_sent}
for key in action:
if key not in ep_dict:
ep_dict[key] = []
ep_dict[key].append(action[key])
frame_index += 1
dt_s = time.perf_counter() - start_loop_t
busy_wait(1 / fps - dt_s)
dt_s = time.perf_counter() - start_loop_t
log_control_info(robot, dt_s, fps=fps)
timestamp = time.perf_counter() - start_episode_t
if exit_early:
exit_early = False
break
# TODO(alibets): allow for teleop during reset
if has_method(robot, "teleop_safety_stop"):
robot.teleop_safety_stop()
if not stop_recording:
# Start resetting env while the executor are finishing
logging.info("Reset the environment")
if play_sounds:
say("Reset the environment")
timestamp = 0
start_vencod_t = time.perf_counter()
# During env reset we save the data and encode the videos
num_frames = frame_index
for key in image_keys:
if video:
tmp_imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}"
fname = f"{key}_episode_{episode_index:06d}.mp4"
video_path = local_dir / "videos" / fname
if video_path.exists():
video_path.unlink()
# Store the reference to the video frame, even tho the videos are not yet encoded
ep_dict[key] = []
for i in range(num_frames):
ep_dict[key].append({"path": f"videos/{fname}", "timestamp": i / fps})
else:
imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}"
ep_dict[key] = []
for i in range(num_frames):
img_path = imgs_dir / f"frame_{i:06d}.png"
ep_dict[key].append({"path": str(img_path)})
for key in not_image_keys:
ep_dict[key] = torch.stack(ep_dict[key])
for key in action:
ep_dict[key] = torch.stack(ep_dict[key])
ep_dict["episode_index"] = torch.tensor([episode_index] * num_frames)
ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
done = torch.zeros(num_frames, dtype=torch.bool)
done[-1] = True
ep_dict["next.done"] = done
ep_path = episodes_dir / f"episode_{episode_index}.pth"
print("Saving episode dictionary...")
torch.save(ep_dict, ep_path)
rec_info = {
"last_episode_index": episode_index,
}
with open(rec_info_path, "w") as f:
json.dump(rec_info, f)
is_last_episode = stop_recording or (episode_index == (num_episodes - 1))
# Wait if necessary
with tqdm.tqdm(total=reset_time_s, desc="Waiting") as pbar:
while timestamp < reset_time_s and not is_last_episode:
time.sleep(1)
timestamp = time.perf_counter() - start_vencod_t
pbar.update(1)
if exit_early:
exit_early = False
break
# Skip updating episode index which forces re-recording episode
if rerecord_episode:
rerecord_episode = False
continue continue
episode_index += 1 # Increment by one dataset["current_episode_index"]
save_current_episode(dataset)
if is_last_episode: if events["stop_recording"]:
logging.info("Done recording") break
if play_sounds:
say("Done recording", blocking=True)
if not is_headless():
listener.stop()
if has_camera > 0: log_say("Stop recording", play_sounds, blocking=True)
logging.info("Waiting for image writer to terminate...") stop_recording(robot, listener, display_cameras)
stop_image_writer(image_writer, timeout=20)
except Exception as e: lerobot_dataset = create_lerobot_dataset(dataset, run_compute_stats, push_to_hub, tags, play_sounds)
if has_camera > 0:
logging.info("Waiting for image writer to terminate...")
stop_image_writer(image_writer, timeout=20)
raise e
robot.disconnect() log_say("Exiting", play_sounds)
if display_cameras and not is_headless():
cv2.destroyAllWindows()
num_episodes = episode_index
if video:
logging.info("Encoding videos")
if play_sounds:
say("Encoding videos")
# Use ffmpeg to convert frames stored as png into mp4 videos
for episode_index in tqdm.tqdm(range(num_episodes)):
for key in image_keys:
tmp_imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}"
fname = f"{key}_episode_{episode_index:06d}.mp4"
video_path = local_dir / "videos" / fname
if video_path.exists():
# Skip if video is already encoded. Could be the case when resuming data recording.
continue
# note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
# since video encoding with ffmpeg is already using multithreading.
encode_video_frames(tmp_imgs_dir, video_path, fps, overwrite=True)
shutil.rmtree(tmp_imgs_dir)
logging.info("Concatenating episodes")
ep_dicts = []
for episode_index in tqdm.tqdm(range(num_episodes)):
ep_path = episodes_dir / f"episode_{episode_index}.pth"
ep_dict = torch.load(ep_path)
ep_dicts.append(ep_dict)
data_dict = concatenate_episodes(ep_dicts)
total_frames = data_dict["frame_index"].shape[0]
data_dict["index"] = torch.arange(0, total_frames, 1)
hf_dataset = to_hf_dataset(data_dict, video)
episode_data_index = calculate_episode_data_index(hf_dataset)
info = {
"codebase_version": CODEBASE_VERSION,
"fps": fps,
"video": video,
}
if video:
info["encoding"] = get_default_encoding()
lerobot_dataset = LeRobotDataset.from_preloaded(
repo_id=repo_id,
hf_dataset=hf_dataset,
episode_data_index=episode_data_index,
info=info,
videos_dir=videos_dir,
)
if run_compute_stats:
logging.info("Computing dataset statistics")
if play_sounds:
say("Computing dataset statistics")
stats = compute_stats(lerobot_dataset)
lerobot_dataset.stats = stats
else:
stats = {}
logging.info("Skipping computation of the dataset statistics")
hf_dataset = hf_dataset.with_format(None) # to remove transforms that cant be saved
hf_dataset.save_to_disk(str(local_dir / "train"))
meta_data_dir = local_dir / "meta_data"
save_meta_data(info, stats, episode_data_index, meta_data_dir)
if push_to_hub:
hf_dataset.push_to_hub(repo_id, revision="main")
push_meta_data_to_hub(repo_id, meta_data_dir, revision="main")
push_dataset_card_to_hub(repo_id, revision="main", tags=tags)
if video:
push_videos_to_hub(repo_id, videos_dir, revision="main")
create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION)
logging.info("Exiting")
if play_sounds:
say("Exiting")
return lerobot_dataset return lerobot_dataset
@safe_disconnect
def replay( def replay(
robot: Robot, episode: int, fps: int | None = None, root="data", repo_id="lerobot/debug", play_sounds=True robot: Robot, episode: int, fps: int | None = None, root="data", repo_id="lerobot/debug", play_sounds=True
): ):
# TODO(rcadene, aliberts): refactor with control_loop, once `dataset` is an instance of LeRobotDataset
# TODO(rcadene): Add option to record logs # TODO(rcadene): Add option to record logs
local_dir = Path(root) / repo_id local_dir = Path(root) / repo_id
if not local_dir.exists(): if not local_dir.exists():
@ -887,9 +328,7 @@ def replay(
if not robot.is_connected: if not robot.is_connected:
robot.connect() robot.connect()
logging.info("Replaying episode") log_say("Replaying episode", play_sounds, blocking=True)
if play_sounds:
say("Replaying episode", blocking=True)
for idx in range(from_idx, to_idx): for idx in range(from_idx, to_idx):
start_episode_t = time.perf_counter() start_episode_t = time.perf_counter()
@ -934,6 +373,12 @@ if __name__ == "__main__":
parser_teleop.add_argument( parser_teleop.add_argument(
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)" "--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
) )
parser_teleop.add_argument(
"--display-cameras",
type=int,
default=1,
help="Display all cameras on screen (set to 1 to display or 0).",
)
parser_record = subparsers.add_parser("record", parents=[base_parser]) parser_record = subparsers.add_parser("record", parents=[base_parser])
parser_record.add_argument( parser_record.add_argument(
@ -1071,18 +516,6 @@ if __name__ == "__main__":
teleoperate(robot, **kwargs) teleoperate(robot, **kwargs)
elif control_mode == "record": elif control_mode == "record":
pretrained_policy_name_or_path = args.pretrained_policy_name_or_path
policy_overrides = args.policy_overrides
del kwargs["pretrained_policy_name_or_path"]
del kwargs["policy_overrides"]
policy_cfg = None
if pretrained_policy_name_or_path is not None:
pretrained_policy_path = get_pretrained_policy_path(pretrained_policy_name_or_path)
policy_cfg = init_hydra_config(pretrained_policy_path / "config.yaml", policy_overrides)
policy = make_policy(hydra_cfg=policy_cfg, pretrained_policy_name_or_path=pretrained_policy_path)
record(robot, policy, policy_cfg, **kwargs)
else:
record(robot, **kwargs) record(robot, **kwargs)
elif control_mode == "replay": elif control_mode == "replay":

View File

@ -383,7 +383,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
logging.info(f"Checkpoint policy after step {step}") logging.info(f"Checkpoint policy after step {step}")
# Note: Save with step as the identifier, and format it to have at least 6 digits but more if # Note: Save with step as the identifier, and format it to have at least 6 digits but more if
# needed (choose 6 as a minimum for consistency without being overkill). # needed (choose 6 as a minimum for consistency without being overkill).
logger.save_checkpont( logger.save_checkpoint(
step, step,
policy, policy,
optimizer, optimizer,

View File

@ -25,12 +25,16 @@ pytest -sx 'tests/test_control_robot.py::test_teleoperate[aloha-True]'
import multiprocessing import multiprocessing
from pathlib import Path from pathlib import Path
from unittest.mock import patch
import pytest import pytest
from lerobot.common.datasets.populate_dataset import add_frame, init_dataset
from lerobot.common.logger import Logger
from lerobot.common.policies.factory import make_policy from lerobot.common.policies.factory import make_policy
from lerobot.common.utils.utils import init_hydra_config from lerobot.common.utils.utils import init_hydra_config
from lerobot.scripts.control_robot import calibrate, get_available_arms, record, replay, teleoperate from lerobot.scripts.control_robot import calibrate, record, replay, teleoperate
from lerobot.scripts.train import make_optimizer_and_scheduler
from tests.test_robots import make_robot from tests.test_robots import make_robot
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, TEST_ROBOT_TYPES, require_robot from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, TEST_ROBOT_TYPES, require_robot
@ -69,7 +73,7 @@ def test_calibrate(tmpdir, request, robot_type, mock):
overrides_calibration_dir = [f"calibration_dir={calibration_dir}"] overrides_calibration_dir = [f"calibration_dir={calibration_dir}"]
robot = make_robot(robot_type, overrides=overrides_calibration_dir, mock=mock) robot = make_robot(robot_type, overrides=overrides_calibration_dir, mock=mock)
calibrate(robot, arms=get_available_arms(robot)) calibrate(robot, arms=robot.available_arms)
del robot del robot
@ -109,12 +113,14 @@ def test_record_without_cameras(tmpdir, request, robot_type, mock):
@pytest.mark.parametrize("robot_type, mock", TEST_ROBOT_TYPES) @pytest.mark.parametrize("robot_type, mock", TEST_ROBOT_TYPES)
@require_robot @require_robot
def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock): def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
tmpdir = Path(tmpdir)
if mock and robot_type != "aloha": if mock and robot_type != "aloha":
request.getfixturevalue("patch_builtins_input") request.getfixturevalue("patch_builtins_input")
# Create an empty calibration directory to trigger manual calibration # Create an empty calibration directory to trigger manual calibration
# and avoid writing calibration files in user .cache/calibration folder # and avoid writing calibration files in user .cache/calibration folder
calibration_dir = Path(tmpdir) / robot_type calibration_dir = tmpdir / robot_type
overrides = [f"calibration_dir={calibration_dir}"] overrides = [f"calibration_dir={calibration_dir}"]
else: else:
# Use the default .cache/calibration folder when mock=False or for aloha # Use the default .cache/calibration folder when mock=False or for aloha
@ -123,17 +129,19 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
env_name = "koch_real" env_name = "koch_real"
policy_name = "act_koch_real" policy_name = "act_koch_real"
root = Path(tmpdir) / "data" root = tmpdir / "data"
repo_id = "lerobot/debug" repo_id = "lerobot/debug"
eval_repo_id = "lerobot/eval_debug"
robot = make_robot(robot_type, overrides=overrides, mock=mock) robot = make_robot(robot_type, overrides=overrides, mock=mock)
dataset = record( dataset = record(
robot, robot,
fps=30, root,
root=root, repo_id,
repo_id=repo_id, fps=1,
warmup_time_s=1, warmup_time_s=1,
episode_time_s=1, episode_time_s=1,
reset_time_s=1,
num_episodes=2, num_episodes=2,
push_to_hub=False, push_to_hub=False,
# TODO(rcadene, aliberts): test video=True # TODO(rcadene, aliberts): test video=True
@ -142,8 +150,10 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
display_cameras=False, display_cameras=False,
play_sounds=False, play_sounds=False,
) )
assert dataset.num_episodes == 2
assert len(dataset) == 2
replay(robot, episode=0, fps=30, root=root, repo_id=repo_id, play_sounds=False) replay(robot, episode=0, fps=1, root=root, repo_id=repo_id, play_sounds=False)
# TODO(rcadene, aliberts): rethink this design # TODO(rcadene, aliberts): rethink this design
if robot_type == "aloha": if robot_type == "aloha":
@ -164,12 +174,26 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
if robot_type == "koch_bimanual": if robot_type == "koch_bimanual":
overrides += ["env.state_dim=12", "env.action_dim=12"] overrides += ["env.state_dim=12", "env.action_dim=12"]
overrides += ["wandb.enable=false"]
overrides += ["env.fps=1"]
cfg = init_hydra_config( cfg = init_hydra_config(
DEFAULT_CONFIG_PATH, DEFAULT_CONFIG_PATH,
overrides=overrides, overrides=overrides,
) )
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats) policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
out_dir = tmpdir / "logger"
logger = Logger(cfg, out_dir, wandb_job_name="debug")
logger.save_checkpoint(
0,
policy,
optimizer,
lr_scheduler,
identifier=0,
)
pretrained_policy_name_or_path = out_dir / "checkpoints/last/pretrained_model"
# In `examples/9_use_aloha.md`, we advise using `num_image_writer_processes=1` # In `examples/9_use_aloha.md`, we advise using `num_image_writer_processes=1`
# during inference, to reach constent fps, so we test this here. # during inference, to reach constent fps, so we test this here.
@ -194,10 +218,12 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
record( record(
robot, robot,
policy, root,
cfg, eval_repo_id,
pretrained_policy_name_or_path,
warmup_time_s=1, warmup_time_s=1,
episode_time_s=1, episode_time_s=1,
reset_time_s=1,
num_episodes=2, num_episodes=2,
run_compute_stats=False, run_compute_stats=False,
push_to_hub=False, push_to_hub=False,
@ -207,4 +233,218 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
num_image_writer_processes=num_image_writer_processes, num_image_writer_processes=num_image_writer_processes,
) )
assert dataset.num_episodes == 2
assert len(dataset) == 2
del robot del robot
@pytest.mark.parametrize("robot_type, mock", [("koch", True)])
@require_robot
def test_resume_record(tmpdir, request, robot_type, mock):
if mock and robot_type != "aloha":
request.getfixturevalue("patch_builtins_input")
# Create an empty calibration directory to trigger manual calibration
# and avoid writing calibration files in user .cache/calibration folder
calibration_dir = tmpdir / robot_type
overrides = [f"calibration_dir={calibration_dir}"]
else:
# Use the default .cache/calibration folder when mock=False or for aloha
overrides = []
robot = make_robot(robot_type, overrides=overrides, mock=mock)
root = Path(tmpdir) / "data"
repo_id = "lerobot/debug"
dataset = record(
robot,
root,
repo_id,
fps=1,
warmup_time_s=0,
episode_time_s=1,
num_episodes=1,
push_to_hub=False,
video=False,
display_cameras=False,
play_sounds=False,
run_compute_stats=False,
)
assert len(dataset) == 1, "`dataset` should contain only 1 frame"
init_dataset_return_value = {}
def wrapped_init_dataset(*args, **kwargs):
nonlocal init_dataset_return_value
init_dataset_return_value = init_dataset(*args, **kwargs)
return init_dataset_return_value
with patch("lerobot.scripts.control_robot.init_dataset", wraps=wrapped_init_dataset):
dataset = record(
robot,
root,
repo_id,
fps=1,
warmup_time_s=0,
episode_time_s=1,
num_episodes=2,
push_to_hub=False,
video=False,
display_cameras=False,
play_sounds=False,
run_compute_stats=False,
)
assert len(dataset) == 2, "`dataset` should contain only 1 frame"
assert (
init_dataset_return_value["num_episodes"] == 2
), "`init_dataset` should load the previous episode"
@pytest.mark.parametrize("robot_type, mock", [("koch", True)])
@require_robot
def test_record_with_event_rerecord_episode(tmpdir, request, robot_type, mock):
if mock and robot_type != "aloha":
request.getfixturevalue("patch_builtins_input")
# Create an empty calibration directory to trigger manual calibration
# and avoid writing calibration files in user .cache/calibration folder
calibration_dir = tmpdir / robot_type
overrides = [f"calibration_dir={calibration_dir}"]
else:
# Use the default .cache/calibration folder when mock=False or for aloha
overrides = []
robot = make_robot(robot_type, overrides=overrides, mock=mock)
with (
patch("lerobot.scripts.control_robot.init_keyboard_listener") as mock_listener,
patch("lerobot.common.robot_devices.control_utils.add_frame", wraps=add_frame) as mock_add_frame,
):
mock_events = {}
mock_events["exit_early"] = True
mock_events["rerecord_episode"] = True
mock_events["stop_recording"] = False
mock_listener.return_value = (None, mock_events)
root = Path(tmpdir) / "data"
repo_id = "lerobot/debug"
dataset = record(
robot,
root,
repo_id,
fps=1,
warmup_time_s=0,
episode_time_s=1,
num_episodes=1,
push_to_hub=False,
video=False,
display_cameras=False,
play_sounds=False,
run_compute_stats=False,
)
assert not mock_events["rerecord_episode"], "`rerecord_episode` wasn't properly reset to False"
assert not mock_events["exit_early"], "`exit_early` wasn't properly reset to False"
assert mock_add_frame.call_count == 2, "`add_frame` should have been called 2 times"
assert len(dataset) == 1, "`dataset` should contain only 1 frame"
@pytest.mark.parametrize("robot_type, mock", [("koch", True)])
@require_robot
def test_record_with_event_exit_early(tmpdir, request, robot_type, mock):
if mock:
request.getfixturevalue("patch_builtins_input")
# Create an empty calibration directory to trigger manual calibration
# and avoid writing calibration files in user .cache/calibration folder
calibration_dir = tmpdir / robot_type
overrides = [f"calibration_dir={calibration_dir}"]
else:
# Use the default .cache/calibration folder when mock=False or for aloha
overrides = []
robot = make_robot(robot_type, overrides=overrides, mock=mock)
with (
patch("lerobot.scripts.control_robot.init_keyboard_listener") as mock_listener,
patch("lerobot.common.robot_devices.control_utils.add_frame", wraps=add_frame) as mock_add_frame,
):
mock_events = {}
mock_events["exit_early"] = True
mock_events["rerecord_episode"] = False
mock_events["stop_recording"] = False
mock_listener.return_value = (None, mock_events)
root = Path(tmpdir) / "data"
repo_id = "lerobot/debug"
dataset = record(
robot,
fps=2,
root=root,
repo_id=repo_id,
warmup_time_s=0,
episode_time_s=1,
num_episodes=1,
push_to_hub=False,
video=False,
display_cameras=False,
play_sounds=False,
run_compute_stats=False,
)
assert not mock_events["exit_early"], "`exit_early` wasn't properly reset to False"
assert mock_add_frame.call_count == 1, "`add_frame` should have been called 1 time"
assert len(dataset) == 1, "`dataset` should contain only 1 frame"
@pytest.mark.parametrize(
"robot_type, mock, num_image_writer_processes", [("koch", True, 0), ("koch", True, 1)]
)
@require_robot
def test_record_with_event_stop_recording(tmpdir, request, robot_type, mock, num_image_writer_processes):
if mock:
request.getfixturevalue("patch_builtins_input")
# Create an empty calibration directory to trigger manual calibration
# and avoid writing calibration files in user .cache/calibration folder
calibration_dir = tmpdir / robot_type
overrides = [f"calibration_dir={calibration_dir}"]
else:
# Use the default .cache/calibration folder when mock=False or for aloha
overrides = []
robot = make_robot(robot_type, overrides=overrides, mock=mock)
with (
patch("lerobot.scripts.control_robot.init_keyboard_listener") as mock_listener,
patch("lerobot.common.robot_devices.control_utils.add_frame", wraps=add_frame) as mock_add_frame,
):
mock_events = {}
mock_events["exit_early"] = True
mock_events["rerecord_episode"] = False
mock_events["stop_recording"] = True
mock_listener.return_value = (None, mock_events)
root = Path(tmpdir) / "data"
repo_id = "lerobot/debug"
dataset = record(
robot,
root,
repo_id,
fps=1,
warmup_time_s=0,
episode_time_s=1,
num_episodes=2,
push_to_hub=False,
video=False,
display_cameras=False,
play_sounds=False,
run_compute_stats=False,
num_image_writer_processes=num_image_writer_processes,
)
assert not mock_events["exit_early"], "`exit_early` wasn't properly reset to False"
assert mock_add_frame.call_count == 1, "`add_frame` should have been called 1 time"
assert len(dataset) == 1, "`dataset` should contain only 1 frame"

View File

@ -127,6 +127,7 @@ def test_robot(tmpdir, request, robot_type, mock):
# TODO(rcadene): skipping image for now as it's challenging to assess equality between two consecutive frames # TODO(rcadene): skipping image for now as it's challenging to assess equality between two consecutive frames
continue continue
assert torch.allclose(captured_observation[name], observation[name], atol=1) assert torch.allclose(captured_observation[name], observation[name], atol=1)
assert captured_observation[name].shape == observation[name].shape
# Test send_action can run # Test send_action can run
robot.send_action(action["action"]) robot.send_action(action["action"])