2024-07-15 23:43:10 +08:00
"""
2024-08-16 00:11:33 +08:00
Utilities to control a robot .
Useful to record a dataset , replay a recorded episode , run the policy on your robot
and record an evaluation dataset , and to recalibrate your robot if needed .
2024-07-15 23:43:10 +08:00
Examples of usage :
2024-08-16 00:11:33 +08:00
- Recalibrate your robot :
` ` ` bash
python lerobot / scripts / control_robot . py calibrate
` ` `
2024-07-15 23:43:10 +08:00
- Unlimited teleoperation at highest frequency ( ~ 200 Hz is expected ) , to exit with CTRL + C :
` ` ` bash
python lerobot / scripts / control_robot . py teleoperate
2024-08-16 00:11:33 +08:00
# Remove the cameras from the robot definition. They are not used in 'teleoperate' anyway.
python lerobot / scripts / control_robot . py teleoperate - - robot - overrides ' ~cameras '
2024-07-15 23:43:10 +08:00
` ` `
- Unlimited teleoperation at a limited frequency of 30 Hz , to simulate data recording frequency :
` ` ` bash
python lerobot / scripts / control_robot . py teleoperate \
- - fps 30
` ` `
- Record one episode in order to test replay :
` ` ` bash
2024-08-16 00:11:33 +08:00
python lerobot / scripts / control_robot . py record \
2024-07-15 23:43:10 +08:00
- - fps 30 \
- - root tmp / data \
- - repo - id $ USER / koch_test \
- - num - episodes 1 \
- - run - compute - stats 0
` ` `
- Visualize dataset :
` ` ` bash
python lerobot / scripts / visualize_dataset . py \
- - root tmp / data \
- - repo - id $ USER / koch_test \
- - episode - index 0
` ` `
- Replay this test episode :
` ` ` bash
2024-08-16 00:11:33 +08:00
python lerobot / scripts / control_robot . py replay \
2024-07-15 23:43:10 +08:00
- - fps 30 \
- - root tmp / data \
- - repo - id $ USER / koch_test \
- - episode 0
` ` `
- Record a full dataset in order to train a policy , with 2 seconds of warmup ,
30 seconds of recording for each episode , and 10 seconds to reset the environment in between episodes :
` ` ` bash
2024-08-16 00:11:33 +08:00
python lerobot / scripts / control_robot . py record \
2024-07-15 23:43:10 +08:00
- - fps 30 \
- - root data \
- - repo - id $ USER / koch_pick_place_lego \
- - num - episodes 50 \
- - warmup - time - s 2 \
- - episode - time - s 30 \
- - reset - time - s 10
` ` `
* * NOTE * * : You can use your keyboard to control data recording flow .
- Tap right arrow key ' -> ' to early exit while recording an episode and go to resetting the environment .
- Tap right arrow key ' -> ' to early exit while resetting the environment and go to recording the next episode .
- Tap left arrow key ' <- ' to early exit and re - record the current episode .
- Tap escape key ' esc ' to stop the data recording .
This might require a sudo permission to allow your terminal to monitor keyboard events .
* * NOTE * * : You can resume / continue data recording by running the same data recording command twice .
To avoid resuming by deleting the dataset , use ` - - force - override 1 ` .
- Train on this dataset with the ACT policy :
` ` ` bash
DATA_DIR = data python lerobot / scripts / train . py \
policy = act_koch_real \
env = koch_real \
dataset_repo_id = $ USER / koch_pick_place_lego \
hydra . run . dir = outputs / train / act_koch_real
` ` `
- Run the pretrained policy on the robot :
` ` ` bash
2024-08-16 00:11:33 +08:00
python lerobot / scripts / control_robot . py record \
- - fps 30 \
- - root data \
- - repo - id $ USER / eval_act_koch_real \
- - num - episodes 10 \
- - warmup - time - s 2 \
- - episode - time - s 30 \
- - reset - time - s 10 \
2024-07-15 23:43:10 +08:00
- p outputs / train / act_koch_real / checkpoints / 080000 / pretrained_model
` ` `
"""
import argparse
import concurrent . futures
import json
import logging
import os
import platform
import shutil
import time
2024-08-16 00:11:33 +08:00
import traceback
2024-07-15 23:43:10 +08:00
from contextlib import nullcontext
2024-08-16 00:11:33 +08:00
from functools import cache
2024-07-15 23:43:10 +08:00
from pathlib import Path
2024-08-16 00:11:33 +08:00
import cv2
2024-07-15 23:43:10 +08:00
import torch
import tqdm
from omegaconf import DictConfig
from PIL import Image
from termcolor import colored
# from safetensors.torch import load_file, save_file
from lerobot . common . datasets . compute_stats import compute_stats
from lerobot . common . datasets . lerobot_dataset import CODEBASE_VERSION , LeRobotDataset
from lerobot . common . datasets . push_dataset_to_hub . aloha_hdf5_format import to_hf_dataset
2024-07-23 02:08:59 +08:00
from lerobot . common . datasets . push_dataset_to_hub . utils import concatenate_episodes , get_default_encoding
2024-08-16 00:11:33 +08:00
from lerobot . common . datasets . utils import calculate_episode_data_index , create_branch
2024-07-15 23:43:10 +08:00
from lerobot . common . datasets . video_utils import encode_video_frames
from lerobot . common . policies . factory import make_policy
from lerobot . common . robot_devices . robots . factory import make_robot
2024-09-05 01:28:05 +08:00
from lerobot . common . robot_devices . robots . utils import Robot , get_arm_id
2024-10-05 00:56:42 +08:00
from lerobot . common . robot_devices . utils import busy_wait , safe_disconnect
2024-07-15 23:43:10 +08:00
from lerobot . common . utils . utils import get_safe_torch_device , init_hydra_config , init_logging , set_global_seed
from lerobot . scripts . eval import get_pretrained_policy_path
2024-08-16 16:08:44 +08:00
from lerobot . scripts . push_dataset_to_hub import (
push_dataset_card_to_hub ,
push_meta_data_to_hub ,
push_videos_to_hub ,
save_meta_data ,
)
2024-07-15 23:43:10 +08:00
########################################################################################
# Utilities
########################################################################################
2024-08-16 00:11:33 +08:00
def say(text, blocking=False):
    """Speak `text` aloud using the text-to-speech command of the current OS.

    Uses `say` on macOS, `spd-say` on Linux and PowerShell's speech synthesizer
    on Windows. On any other platform no command is available, so a warning is
    logged and nothing is spoken.

    Args:
        text: Sentence to speak. It is interpolated as-is into a shell command,
            so it must not contain quote characters.
        blocking: When False, the command is run in the background on macOS and
            Linux. On Windows the call always blocks.
    """
    system = platform.system()

    if system == "Darwin":
        cmd = f'say "{text}"'
    elif system == "Linux":
        cmd = f'spd-say "{text}"'
    elif system == "Windows":
        cmd = (
            'PowerShell -Command "Add-Type -AssemblyName System.Speech; '
            f"(New-Object System.Speech.Synthesis.SpeechSynthesizer).Speak('{text}')\""
        )
    else:
        # Without this branch `cmd` would be unbound and `os.system(cmd)` would crash.
        logging.warning(f"Text-to-speech is not supported on platform '{system}'. Skipping: {text}")
        return

    if not blocking and system in ["Darwin", "Linux"]:
        # TODO(rcadene): Make it work for Windows
        # Use the ampersand to run command in the background
        cmd += " &"

    os.system(cmd)
2024-07-15 23:43:10 +08:00
def save_image(img_tensor, key, frame_index, episode_index, videos_dir):
    """Write one camera frame to disk as a PNG file.

    The frame is stored at
    `videos_dir/{key}_episode_{episode_index:06d}/frame_{frame_index:06d}.png`;
    the episode directory is created on demand.
    """
    frame = Image.fromarray(img_tensor.numpy())
    episode_frames_dir = videos_dir / f"{key}_episode_{episode_index:06d}"
    frame_path = episode_frames_dir / f"frame_{frame_index:06d}.png"
    episode_frames_dir.mkdir(parents=True, exist_ok=True)
    frame.save(str(frame_path), quality=100)
def none_or_int(value):
    """Parse a command-line value: the literal string "None" maps to None, anything else to int."""
    return None if value == "None" else int(value)
2024-10-05 00:56:42 +08:00
def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, fps=None):
    """Log one line summarizing the timings of the last control step.

    The line contains, in order: the optional episode and frame indices, the
    total step duration `dt_s`, and (for non-"stretch" robots) every timing
    found in `robot.logs` for leader arms, follower arms and cameras. Each
    timing is shown in milliseconds with its frequency, and is colored yellow
    when `fps` is given and the measured frequency is below `fps - 1`.
    """
    entries = []
    if episode_index is not None:
        entries.append(f"ep:{episode_index}")
    if frame_index is not None:
        entries.append(f"frame:{frame_index}")

    def add_timing(label, seconds):
        # Duration in milliseconds followed by the equivalent frequency.
        entry = f"{label}:{seconds * 1000:5.2f} ({1 / seconds:3.1f}hz)"
        if fps is not None and 1 / seconds < fps - 1:
            entry = colored(entry, "yellow")
        entries.append(entry)

    def add_logged_timing(label, log_key):
        # Only timings actually recorded by the robot are reported.
        if log_key in robot.logs:
            add_timing(label, robot.logs[log_key])

    # total step time displayed in milliseconds and its frequency
    add_timing("dt", dt_s)

    # TODO(aliberts): move robot-specific logs logic in robot.print_logs()
    if not robot.robot_type.startswith("stretch"):
        for name in robot.leader_arms:
            add_logged_timing("dtRlead", f"read_leader_{name}_pos_dt_s")

        for name in robot.follower_arms:
            add_logged_timing("dtWfoll", f"write_follower_{name}_goal_pos_dt_s")
            add_logged_timing("dtRfoll", f"read_follower_{name}_pos_dt_s")

        for name in robot.cameras:
            add_logged_timing(f"dtR{name}", f"read_camera_{name}_dt_s")

    logging.info(" ".join(entries))
2024-08-16 00:11:33 +08:00
@cache
def is_headless():
    """Return True when `pynput` cannot be imported, which is treated as running without a monitor.

    The result is cached, so the import is attempted (and the diagnostic
    message printed) at most once per process.
    """
    try:
        import pynput  # noqa
    except Exception:
        print(
            "Error trying to import pynput. Switching to headless mode. "
            "As a result, the video stream from the cameras won't be shown, "
            "and you won't be able to change the control flow with keyboards. "
            "For more info, see traceback below.\n"
        )
        traceback.print_exc()
        print()
        return True
    else:
        return False
2024-07-15 23:43:10 +08:00
2024-10-05 00:56:42 +08:00
def has_method(_object: object, method_name: str):
    """Return True when `_object` has an attribute named `method_name` and that attribute is callable."""
    candidate = getattr(_object, method_name, None)
    return callable(candidate)
2024-07-15 23:43:10 +08:00
2024-10-03 23:05:23 +08:00
def get_available_arms(robot):
    """Return the arm ids of all follower arms followed by those of all leader arms."""
    # TODO(rcadene): moves this function in manipulator class?
    follower_ids = [get_arm_id(name, "follower") for name in robot.follower_arms]
    leader_ids = [get_arm_id(name, "leader") for name in robot.leader_arms]
    return follower_ids + leader_ids
2024-09-05 01:28:05 +08:00
2024-10-05 00:56:42 +08:00
########################################################################################
# Control modes
########################################################################################
@safe_disconnect
def calibrate(robot: Robot, arms: list[str] | None):
    """Recalibrate the given arms of `robot`.

    For "stretch" robots, the robot is connected and homed if needed and
    `arms` is ignored. For other robots, the calibration file of each
    requested arm is deleted, then the robot is reconnected, which runs
    calibration for arms whose file is missing.

    Args:
        robot: Robot to calibrate.
        arms: Arm ids to recalibrate; each must be one of `get_available_arms(robot)`.

    Raises:
        ValueError: If `arms` is None or empty, or contains an unknown arm id.
    """
    # TODO(aliberts): move this code in robots' classes
    if robot.robot_type.startswith("stretch"):
        if not robot.is_connected:
            robot.connect()
        if not robot.is_homed():
            robot.home()
        return

    available_arms = get_available_arms(robot)
    available_arms_str = " ".join(available_arms)

    # Validate before iterating over `arms`: it is None when `--arms` is omitted,
    # and iterating over None would raise a TypeError instead of this helpful error.
    if arms is None or len(arms) == 0:
        raise ValueError(
            "No arm provided. Use `--arms` as argument with one or more available arms.\n"
            f"For instance, to recalibrate all arms add: `--arms {available_arms_str}`"
        )

    unknown_arms = [arm_id for arm_id in arms if arm_id not in available_arms]
    if len(unknown_arms) > 0:
        unknown_arms_str = " ".join(unknown_arms)
        raise ValueError(
            f"Unknown arms provided ('{unknown_arms_str}'). Available arms are `{available_arms_str}`."
        )

    for arm_id in arms:
        arm_calib_path = robot.calibration_dir / f"{arm_id}.json"
        if arm_calib_path.exists():
            print(f"Removing '{arm_calib_path}'")
            arm_calib_path.unlink()
        else:
            print(f"Calibration file not found '{arm_calib_path}'")

    if robot.is_connected:
        robot.disconnect()

    # Calling `connect` automatically runs calibration
    # when the calibration file is missing
    robot.connect()
    robot.disconnect()
    print("Calibration is done! You can now teleoperate and record datasets!")
2024-08-16 00:11:33 +08:00
2024-10-05 00:56:42 +08:00
@safe_disconnect
def teleoperate(robot: Robot, fps: int | None = None, teleop_time_s: float | None = None):
    """Run teleoperation steps in a loop, logging the timing of every step.

    Args:
        robot: Robot to teleoperate; it is connected first if needed.
        fps: When given, each step is padded with `busy_wait` to last `1 / fps`
            seconds. When None, steps run back to back.
        teleop_time_s: When given, the loop stops once this many seconds have
            elapsed. When None, the loop never exits on its own.
    """
    # TODO(rcadene): Add option to record logs
    if not robot.is_connected:
        robot.connect()

    session_start = time.perf_counter()
    time_is_up = False
    while not time_is_up:
        step_start = time.perf_counter()
        robot.teleop_step()

        if fps is not None:
            elapsed_s = time.perf_counter() - step_start
            busy_wait(1 / fps - elapsed_s)

        step_duration_s = time.perf_counter() - step_start
        log_control_info(robot, step_duration_s, fps=fps)

        if teleop_time_s is not None:
            time_is_up = time.perf_counter() - session_start > teleop_time_s
2024-10-05 00:56:42 +08:00
@safe_disconnect
def record(
    robot: Robot,
    policy: torch.nn.Module | None = None,
    hydra_cfg: DictConfig | None = None,
    fps: int | None = None,
    root="data",
    repo_id="lerobot/debug",
    warmup_time_s=2,
    episode_time_s=10,
    reset_time_s=5,
    num_episodes=50,
    video=True,
    run_compute_stats=True,
    push_to_hub=True,
    tags=None,
    num_image_writers_per_camera=4,
    force_override=False,
    display_cameras=True,
):
    """Record episodes with the robot and build a `LeRobotDataset` from them.

    Without a policy, the robot is teleoperated (`robot.teleop_step`). With a
    policy, observations are captured, the policy selects the action and the
    action returned by `robot.send_action` is stored. Recording resumes after
    the last episode listed in `episodes/data_recording_info.json` if it exists.

    Args:
        robot: Robot to control; it is connected first if needed.
        policy: Optional policy used to select actions.
        hydra_cfg: Policy config; `device`, `seed`, `env.fps` and `use_amp` are
            read from it when `policy` is given.
        fps: Target control frequency. Required when `policy` is None; replaced
            by `hydra_cfg.env.fps` when a policy is given.
        root: Local root directory; data is stored under `root/repo_id`.
        repo_id: Dataset identifier of the form `user/dataset_name`.
        warmup_time_s: Seconds of control without recording before the first episode.
        episode_time_s: Maximum duration of one episode in seconds.
        reset_time_s: Seconds to wait between episodes.
        num_episodes: Recording stops once the episode index reaches this value.
        video: When True, saved frames are encoded into mp4 files and the frame
            directories are removed; otherwise the dataset references the png files.
        run_compute_stats: Whether to compute the dataset statistics.
        push_to_hub: Whether to upload the dataset, meta data, card and videos.
        tags: Tags passed to `push_dataset_card_to_hub`.
        num_image_writers_per_camera: Number of image-writing threads per camera
            (at least one thread is always used).
        force_override: When True, an existing `root/repo_id` directory is deleted first.
        display_cameras: Whether to show camera frames on screen when not headless.

    Returns:
        The `LeRobotDataset` built from all recorded episodes.

    Raises:
        ValueError: If the dataset name starts with "eval_" but no policy is
            given, or if neither a policy nor `fps` is provided.
    """
    # TODO(rcadene): Add option to record logs
    # TODO(rcadene): Clean this function via decomposition in higher level functions

    _, dataset_name = repo_id.split("/")
    if dataset_name.startswith("eval_") and policy is None:
        raise ValueError(
            f"Your dataset name begins by 'eval_' ({dataset_name}) but no policy is provided ({policy})."
        )

    # Fail fast: the control loops and frame timestamps divide by `fps`, so recording
    # without it would otherwise crash with a TypeError after the robot is connected.
    if policy is None and fps is None:
        raise ValueError("`fps` must be provided to record a dataset when no policy is given.")

    if not robot.is_connected:
        robot.connect()

    local_dir = Path(root) / repo_id
    if local_dir.exists() and force_override:
        shutil.rmtree(local_dir)

    episodes_dir = local_dir / "episodes"
    episodes_dir.mkdir(parents=True, exist_ok=True)

    videos_dir = local_dir / "videos"
    videos_dir.mkdir(parents=True, exist_ok=True)

    # Logic to resume data recording
    rec_info_path = episodes_dir / "data_recording_info.json"
    if rec_info_path.exists():
        with open(rec_info_path) as f:
            rec_info = json.load(f)
        episode_index = rec_info["last_episode_index"] + 1
    else:
        episode_index = 0

    if is_headless():
        logging.warning(
            "Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
        )

    # Allow to exit early while recording an episode or resetting the environment,
    # by tapping the right arrow key '->'. This might require a sudo permission
    # to allow your terminal to monitor keyboard events.
    exit_early = False
    rerecord_episode = False
    stop_recording = False

    # Only import pynput if not in a headless environment
    if not is_headless():
        from pynput import keyboard

        def on_press(key):
            nonlocal exit_early, rerecord_episode, stop_recording
            try:
                if key == keyboard.Key.right:
                    print("Right arrow key pressed. Exiting loop...")
                    exit_early = True
                elif key == keyboard.Key.left:
                    print("Left arrow key pressed. Exiting loop and rerecord the last episode...")
                    rerecord_episode = True
                    exit_early = True
                elif key == keyboard.Key.esc:
                    print("Escape key pressed. Stopping data recording...")
                    stop_recording = True
                    exit_early = True
            except Exception as e:
                print(f"Error handling key press: {e}")

        listener = keyboard.Listener(on_press=on_press)
        listener.start()

    # Load policy if any
    if policy is not None:
        # Check device is available
        device = get_safe_torch_device(hydra_cfg.device, log=True)

        policy.eval()
        policy.to(device)

        torch.backends.cudnn.benchmark = True
        torch.backends.cuda.matmul.allow_tf32 = True
        set_global_seed(hydra_cfg.seed)

        # override fps using policy fps
        fps = hydra_cfg.env.fps

    # Execute a few seconds without recording data, to give times
    # to the robot devices to connect and start synchronizing.
    timestamp = 0
    start_warmup_t = time.perf_counter()
    is_warmup_print = False
    while timestamp < warmup_time_s:
        if not is_warmup_print:
            logging.info("Warming up (no data recording)")
            say("Warming up")
            is_warmup_print = True

        start_loop_t = time.perf_counter()

        if policy is None:
            observation, action = robot.teleop_step(record_data=True)
        else:
            observation = robot.capture_observation()

        if display_cameras and not is_headless():
            image_keys = [key for key in observation if "image" in key]
            for key in image_keys:
                cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
            cv2.waitKey(1)

        dt_s = time.perf_counter() - start_loop_t
        busy_wait(1 / fps - dt_s)

        dt_s = time.perf_counter() - start_loop_t
        log_control_info(robot, dt_s, fps=fps)

        timestamp = time.perf_counter() - start_warmup_t

    if has_method(robot, "teleop_safety_stop"):
        robot.teleop_safety_stop()

    # Save images using threads to reach high fps (30 and more)
    # Using `with` to exit smoothly if an exception is raised.
    futures = []
    num_image_writers = num_image_writers_per_camera * len(robot.cameras)
    num_image_writers = max(num_image_writers, 1)
    with concurrent.futures.ThreadPoolExecutor(max_workers=num_image_writers) as executor:
        # Start recording all episodes
        while episode_index < num_episodes:
            logging.info(f"Recording episode {episode_index}")
            say(f"Recording episode {episode_index}")
            ep_dict = {}
            frame_index = 0
            timestamp = 0
            start_episode_t = time.perf_counter()
            while timestamp < episode_time_s:
                start_loop_t = time.perf_counter()

                if policy is None:
                    observation, action = robot.teleop_step(record_data=True)
                else:
                    observation = robot.capture_observation()

                image_keys = [key for key in observation if "image" in key]
                not_image_keys = [key for key in observation if "image" not in key]

                # Frames are written to disk by the worker threads.
                for key in image_keys:
                    futures += [
                        executor.submit(
                            save_image, observation[key], key, frame_index, episode_index, videos_dir
                        )
                    ]

                if display_cameras and not is_headless():
                    for key in image_keys:
                        cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
                    cv2.waitKey(1)

                for key in not_image_keys:
                    if key not in ep_dict:
                        ep_dict[key] = []
                    ep_dict[key].append(observation[key])

                if policy is not None:
                    with (
                        torch.inference_mode(),
                        torch.autocast(device_type=device.type)
                        if device.type == "cuda" and hydra_cfg.use_amp
                        else nullcontext(),
                    ):
                        # Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
                        for name in observation:
                            if "image" in name:
                                observation[name] = observation[name].type(torch.float32) / 255
                                observation[name] = observation[name].permute(2, 0, 1).contiguous()
                            observation[name] = observation[name].unsqueeze(0)
                            observation[name] = observation[name].to(device)

                        # Compute the next action with the policy
                        # based on the current observation
                        action = policy.select_action(observation)

                        # Remove batch dimension
                        action = action.squeeze(0)

                        # Move to cpu, if not already the case
                        action = action.to("cpu")

                    # Order the robot to move
                    action_sent = robot.send_action(action)

                    # Action can eventually be clipped using `max_relative_target`,
                    # so action actually sent is saved in the dataset.
                    action = {"action": action_sent}

                for key in action:
                    if key not in ep_dict:
                        ep_dict[key] = []
                    ep_dict[key].append(action[key])

                frame_index += 1

                dt_s = time.perf_counter() - start_loop_t
                busy_wait(1 / fps - dt_s)

                dt_s = time.perf_counter() - start_loop_t
                log_control_info(robot, dt_s, fps=fps)

                timestamp = time.perf_counter() - start_episode_t
                if exit_early:
                    exit_early = False
                    break

            # TODO(alibets): allow for teleop during reset
            if has_method(robot, "teleop_safety_stop"):
                robot.teleop_safety_stop()

            if not stop_recording:
                # Start resetting env while the executor are finishing
                logging.info("Reset the environment")
                say("Reset the environment")

            timestamp = 0
            start_vencod_t = time.perf_counter()

            # During env reset we save the data and encode the videos
            num_frames = frame_index

            for key in image_keys:
                if video:
                    tmp_imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}"
                    fname = f"{key}_episode_{episode_index:06d}.mp4"
                    video_path = local_dir / "videos" / fname
                    if video_path.exists():
                        video_path.unlink()
                    # Store the reference to the video frame, even tho the videos are not yet encoded
                    ep_dict[key] = []
                    for i in range(num_frames):
                        ep_dict[key].append({"path": f"videos/{fname}", "timestamp": i / fps})
                else:
                    imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}"
                    ep_dict[key] = []
                    for i in range(num_frames):
                        img_path = imgs_dir / f"frame_{i:06d}.png"
                        ep_dict[key].append({"path": str(img_path)})

            for key in not_image_keys:
                ep_dict[key] = torch.stack(ep_dict[key])

            for key in action:
                ep_dict[key] = torch.stack(ep_dict[key])

            ep_dict["episode_index"] = torch.tensor([episode_index] * num_frames)
            ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
            ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps

            # Only the last frame of the episode is flagged as done.
            done = torch.zeros(num_frames, dtype=torch.bool)
            done[-1] = True
            ep_dict["next.done"] = done

            ep_path = episodes_dir / f"episode_{episode_index}.pth"
            print("Saving episode dictionary...")
            torch.save(ep_dict, ep_path)

            # Persist the progress so that a later run can resume from the next episode.
            rec_info = {
                "last_episode_index": episode_index,
            }
            with open(rec_info_path, "w") as f:
                json.dump(rec_info, f)

            is_last_episode = stop_recording or (episode_index == (num_episodes - 1))

            # Wait if necessary
            with tqdm.tqdm(total=reset_time_s, desc="Waiting") as pbar:
                while timestamp < reset_time_s and not is_last_episode:
                    time.sleep(1)
                    timestamp = time.perf_counter() - start_vencod_t
                    pbar.update(1)
                    if exit_early:
                        exit_early = False
                        break

            # Skip updating episode index which forces re-recording episode
            if rerecord_episode:
                rerecord_episode = False
                continue

            episode_index += 1

            if is_last_episode:
                logging.info("Done recording")
                say("Done recording", blocking=True)
                if not is_headless():
                    listener.stop()

                logging.info("Waiting for threads writing the images on disk to terminate...")
                for _ in tqdm.tqdm(
                    concurrent.futures.as_completed(futures), total=len(futures), desc="Writting images"
                ):
                    pass
                break

    robot.disconnect()
    if display_cameras and not is_headless():
        cv2.destroyAllWindows()

    num_episodes = episode_index

    if video:
        logging.info("Encoding videos")
        say("Encoding videos")
        # Use ffmpeg to convert frames stored as png into mp4 videos
        # NOTE(review): `image_keys` is only assigned inside the warmup/recording loops.
        # If no loop iteration ran (e.g. resuming an already complete recording with
        # no camera display during warmup), it is unbound here — confirm the intended behavior.
        for episode_index in tqdm.tqdm(range(num_episodes)):
            for key in image_keys:
                tmp_imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}"
                fname = f"{key}_episode_{episode_index:06d}.mp4"
                video_path = local_dir / "videos" / fname
                if video_path.exists():
                    # Skip if video is already encoded. Could be the case when resuming data recording.
                    continue
                # note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
                # since video encoding with ffmpeg is already using multithreading.
                encode_video_frames(tmp_imgs_dir, video_path, fps, overwrite=True)
                shutil.rmtree(tmp_imgs_dir)

    logging.info("Concatenating episodes")
    ep_dicts = []
    for episode_index in tqdm.tqdm(range(num_episodes)):
        ep_path = episodes_dir / f"episode_{episode_index}.pth"
        ep_dict = torch.load(ep_path)
        ep_dicts.append(ep_dict)
    data_dict = concatenate_episodes(ep_dicts)

    total_frames = data_dict["frame_index"].shape[0]
    data_dict["index"] = torch.arange(0, total_frames, 1)

    hf_dataset = to_hf_dataset(data_dict, video)
    episode_data_index = calculate_episode_data_index(hf_dataset)
    info = {
        "codebase_version": CODEBASE_VERSION,
        "fps": fps,
        "video": video,
    }
    if video:
        info["encoding"] = get_default_encoding()

    lerobot_dataset = LeRobotDataset.from_preloaded(
        repo_id=repo_id,
        hf_dataset=hf_dataset,
        episode_data_index=episode_data_index,
        info=info,
        videos_dir=videos_dir,
    )
    if run_compute_stats:
        logging.info("Computing dataset statistics")
        say("Computing dataset statistics")
        stats = compute_stats(lerobot_dataset)
        lerobot_dataset.stats = stats
    else:
        stats = {}
        logging.info("Skipping computation of the dataset statistics")

    hf_dataset = hf_dataset.with_format(None)  # to remove transforms that cant be saved
    hf_dataset.save_to_disk(str(local_dir / "train"))

    meta_data_dir = local_dir / "meta_data"
    save_meta_data(info, stats, episode_data_index, meta_data_dir)

    if push_to_hub:
        hf_dataset.push_to_hub(repo_id, revision="main")
        push_meta_data_to_hub(repo_id, meta_data_dir, revision="main")
        push_dataset_card_to_hub(repo_id, revision="main", tags=tags)
        if video:
            push_videos_to_hub(repo_id, videos_dir, revision="main")
        create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION)

    logging.info("Exiting")
    say("Exiting")
    return lerobot_dataset
2024-08-16 00:11:33 +08:00
def replay(robot: Robot, episode: int, fps: int | None = None, root="data", repo_id="lerobot/debug"):
    """Replay the recorded actions of one episode on the robot.

    Args:
        robot: Robot receiving the actions; it is connected first if needed.
        episode: Index of the episode to replay.
        fps: When given, each step is padded with `busy_wait` to last `1 / fps`
            seconds. When None, actions are sent back to back without throttling,
            as in `teleoperate`.
        root: Local root directory containing the dataset at `root/repo_id`.
        repo_id: Dataset identifier.

    Raises:
        ValueError: If `root/repo_id` does not exist.
    """
    # TODO(rcadene): Add option to record logs
    local_dir = Path(root) / repo_id
    if not local_dir.exists():
        raise ValueError(local_dir)

    dataset = LeRobotDataset(repo_id, root=root)
    items = dataset.hf_dataset.select_columns("action")
    from_idx = dataset.episode_data_index["from"][episode].item()
    to_idx = dataset.episode_data_index["to"][episode].item()

    if not robot.is_connected:
        robot.connect()

    logging.info("Replaying episode")
    say("Replaying episode", blocking=True)
    for idx in range(from_idx, to_idx):
        start_episode_t = time.perf_counter()

        action = items[idx]["action"]
        robot.send_action(action)

        # `fps` defaults to None: without this guard `1 / fps` raises a TypeError
        # right after the first action has been sent.
        if fps is not None:
            dt_s = time.perf_counter() - start_episode_t
            busy_wait(1 / fps - dt_s)

        dt_s = time.perf_counter() - start_episode_t
        log_control_info(robot, dt_s, fps=fps)
if __name__ == "__main__":
    # Command-line entry point: one sub-command per control mode
    # (calibrate, teleoperate, record, replay).

    # Help texts shared by several sub-commands.
    fps_help = "Frames per second (set to None to disable)"
    root_help = (
        "Root directory where the dataset will be stored locally at '{root}/{repo_id}' "
        "(e.g. 'data/hf_username/dataset_name')."
    )
    repo_id_help = (
        "Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' "
        "(e.g. `lerobot/test`)."
    )
    overrides_help = "Any key=value arguments to override config values (use dots for.nested=overrides)"

    parser = argparse.ArgumentParser()
    subparsers = parser.add_subparsers(dest="mode", required=True)

    # Set common options for all the subparsers
    base_parser = argparse.ArgumentParser(add_help=False)
    base_parser.add_argument(
        "--robot-path",
        type=str,
        default="lerobot/configs/robot/koch.yaml",
        help="Path to robot yaml file used to instantiate the robot using `make_robot` factory function.",
    )
    base_parser.add_argument("--robot-overrides", type=str, nargs="*", help=overrides_help)

    parser_calib = subparsers.add_parser("calibrate", parents=[base_parser])
    parser_teleop = subparsers.add_parser("teleoperate", parents=[base_parser])
    parser_record = subparsers.add_parser("record", parents=[base_parser])
    parser_replay = subparsers.add_parser("replay", parents=[base_parser])

    parser_calib.add_argument(
        "--arms",
        type=str,
        nargs="*",
        help="List of arms to calibrate (e.g. `--arms left_follower right_follower left_leader`)",
    )

    # Options shared by several modes are registered first, so that each parser
    # lists them ahead of its mode-specific options.
    for fps_parser in (parser_teleop, parser_record, parser_replay):
        fps_parser.add_argument("--fps", type=none_or_int, default=None, help=fps_help)
    for dataset_parser in (parser_record, parser_replay):
        dataset_parser.add_argument("--root", type=Path, default="data", help=root_help)
        dataset_parser.add_argument("--repo-id", type=str, default="lerobot/test", help=repo_id_help)

    # Options specific to `record`
    parser_record.add_argument(
        "--warmup-time-s",
        type=int,
        default=10,
        help="Number of seconds before starting data collection. It allows the robot devices to warmup and synchronize.",
    )
    parser_record.add_argument(
        "--episode-time-s",
        type=int,
        default=60,
        help="Number of seconds for data recording for each episode.",
    )
    parser_record.add_argument(
        "--reset-time-s",
        type=int,
        default=60,
        help="Number of seconds for resetting the environment after each episode.",
    )
    parser_record.add_argument("--num-episodes", type=int, default=50, help="Number of episodes to record.")
    parser_record.add_argument(
        "--run-compute-stats",
        type=int,
        default=1,
        help="By default, run the computation of the data statistics at the end of data collection. Compute intensive and not required to just replay an episode.",
    )
    parser_record.add_argument(
        "--push-to-hub",
        type=int,
        default=1,
        help="Upload dataset to Hugging Face hub.",
    )
    parser_record.add_argument(
        "--tags",
        type=str,
        nargs="*",
        help="Add tags to your dataset on the hub.",
    )
    parser_record.add_argument(
        "--num-image-writers-per-camera",
        type=int,
        default=4,
        help=(
            "Number of threads writing the frames as png images on disk, per camera. "
            "Too much threads might cause unstable teleoperation fps due to main thread being blocked. "
            "Not enough threads might cause low camera fps."
        ),
    )
    parser_record.add_argument(
        "--force-override",
        type=int,
        default=0,
        help="By default, data recording is resumed. When set to 1, delete the local directory and start data recording from scratch.",
    )
    parser_record.add_argument(
        "-p",
        "--pretrained-policy-name-or-path",
        type=str,
        help=(
            "Either the repo ID of a model hosted on the Hub or a path to a directory containing weights "
            "saved using `Policy.save_pretrained`."
        ),
    )
    parser_record.add_argument("--policy-overrides", type=str, nargs="*", help=overrides_help)

    # Options specific to `replay`
    parser_replay.add_argument("--episode", type=int, default=0, help="Index of the episode to replay.")

    args = parser.parse_args()

    init_logging()

    # Everything left in `kwargs` after removing the generic options is forwarded
    # as keyword arguments to the function implementing the selected mode.
    kwargs = vars(args)
    control_mode = kwargs.pop("mode")
    robot_path = kwargs.pop("robot_path")
    robot_overrides = kwargs.pop("robot_overrides")

    robot_cfg = init_hydra_config(robot_path, robot_overrides)
    robot = make_robot(robot_cfg)

    if control_mode == "calibrate":
        calibrate(robot, **kwargs)

    elif control_mode == "teleoperate":
        teleoperate(robot, **kwargs)

    elif control_mode == "record":
        pretrained_policy_name_or_path = kwargs.pop("pretrained_policy_name_or_path")
        policy_overrides = kwargs.pop("policy_overrides")

        # Without a pretrained policy, `record` receives None for both the policy
        # and its config, which are also its default values.
        policy = None
        policy_cfg = None
        if pretrained_policy_name_or_path is not None:
            pretrained_policy_path = get_pretrained_policy_path(pretrained_policy_name_or_path)
            policy_cfg = init_hydra_config(pretrained_policy_path / "config.yaml", policy_overrides)
            policy = make_policy(hydra_cfg=policy_cfg, pretrained_policy_name_or_path=pretrained_policy_path)
        record(robot, policy, policy_cfg, **kwargs)

    elif control_mode == "replay":
        replay(robot, **kwargs)

    if robot.is_connected:
        # Disconnect manually to avoid a "Core dump" during process
        # termination due to camera threads not properly exiting.
        robot.disconnect()