Add policy/act_aloha_real.yaml + env/act_real.yaml (#429)

Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
This commit is contained in:
Remi 2024-10-10 17:12:45 +02:00 committed by GitHub
parent c29e70e5a1
commit 97b1feb0b3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
20 changed files with 464 additions and 188 deletions

179
examples/9_use_aloha.md Normal file
View File

@ -0,0 +1,179 @@
This tutorial explains how to use [Aloha and Aloha 2 stationary](https://www.trossenrobotics.com/aloha-stationary) with LeRobot.
## Setup
Follow the [documentation from Trossen Robotics](https://docs.trossenrobotics.com/aloha_docs/getting_started/stationary/hardware_setup.html) for setting up the hardware and plugging the 4 arms and 4 cameras to your computer.
## Install LeRobot
On your computer:
1. [Install Miniconda](https://docs.anaconda.com/miniconda/#quick-command-line-install):
```bash
mkdir -p ~/miniconda3
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh
bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
rm ~/miniconda3/miniconda.sh
~/miniconda3/bin/conda init bash
```
2. Restart shell or `source ~/.bashrc`
3. Create and activate a fresh conda environment for lerobot
```bash
conda create -y -n lerobot python=3.10 && conda activate lerobot
```
4. Clone LeRobot:
```bash
git clone https://github.com/huggingface/lerobot.git ~/lerobot
```
5. Install LeRobot with dependencies for the Aloha motors (dynamixel) and cameras (intelrealsense):
```bash
cd ~/lerobot && pip install -e ".[dynamixel intelrealsense]"
```
And install extra dependencies for recording datasets on Linux:
```bash
conda install -y -c conda-forge ffmpeg
pip uninstall -y opencv-python
conda install -y -c conda-forge "opencv>=4.10.0"
```
## Teleoperate
**/!\ FOR SAFETY, READ THIS /!\**
Teleoperation consists in manually operating the leader arms to move the follower arms. Importantly:
1. Make sure your leader arms are in the same position as the follower arms, so that the follower arms don't move too fast to match the leader arms,
2. Our code assumes that your robot has been assembled following Trossen Robotics instructions. This allows us to skip calibration, as we use the pre-defined calibration files in `.cache/calibration/aloha_default`. If you replace a motor, make sure you follow the exact instructions from Trossen Robotics.
By running the following code, you can start your first **SAFE** teleoperation:
```bash
python lerobot/scripts/control_robot.py teleoperate \
--robot-path lerobot/configs/robot/aloha.yaml \
--robot-overrides max_relative_target=5
```
By adding `--robot-overrides max_relative_target=5`, we override the default value for `max_relative_target` defined in `lerobot/configs/robot/aloha.yaml`. It is expected to be `5` to limit the magnitude of the movement for more safety, but the teloperation won't be smooth. When you feel confident, you can disable this limit by adding `--robot-overrides max_relative_target=null` to the command line:
```bash
python lerobot/scripts/control_robot.py teleoperate \
--robot-path lerobot/configs/robot/aloha.yaml \
--robot-overrides max_relative_target=null
```
## Record a dataset
Once you're familiar with teleoperation, you can record your first dataset with Aloha.
If you want to use the Hugging Face hub features for uploading your dataset and you haven't previously done it, make sure you've logged in using a write-access token, which can be generated from the [Hugging Face settings](https://huggingface.co/settings/tokens):
```bash
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
```
Store your Hugging Face repository name in a variable to run these commands:
```bash
HF_USER=$(huggingface-cli whoami | head -n 1)
echo $HF_USER
```
Record 2 episodes and upload your dataset to the hub:
```bash
python lerobot/scripts/control_robot.py record \
--robot-path lerobot/configs/robot/aloha.yaml \
--robot-overrides max_relative_target=null \
--fps 30 \
--root data \
--repo-id ${HF_USER}/aloha_test \
--tags aloha tutorial \
--warmup-time-s 5 \
--episode-time-s 40 \
--reset-time-s 10 \
--num-episodes 2 \
--push-to-hub 1
```
## Visualize a dataset
If you uploaded your dataset to the hub with `--push-to-hub 1`, you can [visualize your dataset online](https://huggingface.co/spaces/lerobot/visualize_dataset) by copy pasting your repo id given by:
```bash
echo ${HF_USER}/aloha_test
```
If you didn't upload with `--push-to-hub 0`, you can also visualize it locally with:
```bash
python lerobot/scripts/visualize_dataset_html.py \
--root data \
--repo-id ${HF_USER}/aloha_test
```
## Replay an episode
**/!\ FOR SAFETY, READ THIS /!\**
Replay consists in automatically replaying the sequence of actions (i.e. goal positions for your motors) recorded in a given dataset episode. Make sure the current initial position of your robot is similar to the one in your episode, so that your follower arms don't move too fast to go to the first goal positions. For safety, you might want to add `--robot-overrides max_relative_target=5` to your command line as explained above.
Now try to replay the first episode on your robot:
```bash
python lerobot/scripts/control_robot.py replay \
--robot-path lerobot/configs/robot/aloha.yaml \
--robot-overrides max_relative_target=null \
--fps 30 \
--root data \
--repo-id ${HF_USER}/aloha_test \
--episode 0
```
## Train a policy
To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
```bash
DATA_DIR=data python lerobot/scripts/train.py \
dataset_repo_id=${HF_USER}/aloha_test \
policy=act_aloha_real \
env=aloha_real \
hydra.run.dir=outputs/train/act_aloha_test \
hydra.job.name=act_aloha_test \
device=cuda \
wandb.enable=true
```
Let's explain it:
1. We provided the dataset as argument with `dataset_repo_id=${HF_USER}/aloha_test`.
2. We provided the policy with `policy=act_aloha_real`. This loads configurations from [`lerobot/configs/policy/act_aloha_real.yaml`](../lerobot/configs/policy/act_aloha_real.yaml). Importantly, this policy uses 4 cameras as input `cam_right_wrist`, `cam_left_wrist`, `cam_high`, and `cam_low`.
3. We provided an environment as argument with `env=aloha_real`. This loads configurations from [`lerobot/configs/env/aloha_real.yaml`](../lerobot/configs/env/aloha_real.yaml). Note: this yaml defines 18 dimensions for the `state_dim` and `action_dim`, corresponding to 18 motors, not 14 motors as used in previous Aloha work. This is because, we include the `shoulder_shadow` and `elbow_shadow` motors for simplicity.
4. We provided `device=cuda` since we are training on a Nvidia GPU.
5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
6. We added `DATA_DIR=data` to access your dataset stored in your local `data` directory. If you dont provide `DATA_DIR`, your dataset will be downloaded from Hugging Face hub to your cache folder `$HOME/.cache/hugginface`. In future versions of `lerobot`, both directories will be in sync.
Training should take several hours. You will find checkpoints in `outputs/train/act_aloha_test/checkpoints`.
## Evaluate your policy
You can use the `record` function from [`lerobot/scripts/control_robot.py`](../lerobot/scripts/control_robot.py) but with a policy checkpoint as input. For instance, run this command to record 10 evaluation episodes:
```bash
python lerobot/scripts/control_robot.py record \
--robot-path lerobot/configs/robot/aloha.yaml \
--robot-overrides max_relative_target=null \
--fps 30 \
--root data \
--repo-id ${HF_USER}/eval_act_aloha_test \
--tags aloha tutorial eval \
--warmup-time-s 5 \
--episode-time-s 40 \
--reset-time-s 10 \
--num-episodes 10 \
--num-image-writer-processes 1 \
-p outputs/train/act_aloha_test/checkpoints/last/pretrained_model
```
As you can see, it's almost the same command as previously used to record your training dataset. Two things changed:
1. There is an additional `-p` argument which indicates the path to your policy checkpoint with (e.g. `-p outputs/train/eval_aloha_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `-p ${HF_USER}/act_aloha_test`).
2. The name of dataset begins by `eval` to reflect that you are running inference (e.g. `--repo-id ${HF_USER}/eval_act_aloha_test`).
3. We use `--num-image-writer-processes 1` instead of the default value (`0`). On our computer, using a dedicated process to write images from the 4 cameras on disk allows to reach constent 30 fps during inference. Feel free to explore different values for `--num-image-writer-processes`.
## More
Follow this [previous tutorial](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#4-train-a-policy-on-your-data) for a more in-depth explaination.
If you have any question or need help, please reach out on Discord in the channel `#aloha-arm`.

View File

@ -216,7 +216,9 @@ available_policies_per_env = {
"aloha": ["act"],
"pusht": ["diffusion", "vqbet"],
"xarm": ["tdmpc"],
"dora_aloha_real": ["act_real"],
"koch_real": ["act_koch_real"],
"aloha_real": ["act_aloha_real"],
"dora_aloha_real": ["act_aloha_real"],
}
env_task_pairs = [(env, task) for env, tasks in available_tasks_per_env.items() for task in tasks]

View File

@ -364,6 +364,7 @@ class ManipulatorRobot:
for name in self.follower_arms:
print(f"Connecting {name} follower arm.")
self.follower_arms[name].connect()
for name in self.leader_arms:
print(f"Connecting {name} leader arm.")
self.leader_arms[name].connect()

10
lerobot/configs/env/aloha_real.yaml vendored Normal file
View File

@ -0,0 +1,10 @@
# @package _global_
fps: 30
env:
name: real_world
task: null
state_dim: 18
action_dim: 18
fps: ${fps}

View File

@ -1,16 +1,22 @@
# @package _global_
# Use `act_real.yaml` to train on real-world Aloha/Aloha2 datasets.
# Compared to `act.yaml`, it contains 4 cameras (i.e. cam_right_wrist, cam_left_wrist, images,
# cam_low) instead of 1 camera (i.e. top). Also, `training.eval_freq` is set to -1. This config is used
# to evaluate checkpoints at a certain frequency of training steps. When it is set to -1, it deactivates evaluation.
# This is because real-world evaluation is done through [dora-lerobot](https://github.com/dora-rs/dora-lerobot).
# Look at its README for more information on how to evaluate a checkpoint in the real-world.
# Use `act_aloha_real.yaml` to train on real-world datasets collected on Aloha or Aloha-2 robots.
# Compared to `act.yaml`, it contains 4 cameras (i.e. cam_right_wrist, cam_left_wrist, cam_high, cam_low) instead of 1 camera (i.e. top).
# Also, `training.eval_freq` is set to -1. This config is used to evaluate checkpoints at a certain frequency of training steps.
# When it is set to -1, it deactivates evaluation. This is because real-world evaluation is done through our `control_robot.py` script.
# Look at the documentation in header of `control_robot.py` for more information on how to collect data , train and evaluate a policy.
#
# Example of usage for training:
# Example of usage for training and inference with `control_robot.py`:
# ```bash
# python lerobot/scripts/train.py \
# policy=act_real \
# policy=act_aloha_real \
# env=aloha_real
# ```
#
# Example of usage for training and inference with [Dora-rs](https://github.com/dora-rs/dora-lerobot):
# ```bash
# python lerobot/scripts/train.py \
# policy=act_aloha_real \
# env=dora_aloha_real
# ```
@ -36,10 +42,11 @@ override_dataset_stats:
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
training:
offline_steps: 100000
offline_steps: 80000
online_steps: 0
eval_freq: -1
save_freq: 20000
save_freq: 10000
log_freq: 100
save_checkpoint: true
batch_size: 8
@ -62,7 +69,7 @@ policy:
# Input / output structure.
n_obs_steps: 1
chunk_size: 100 # chunk_size
chunk_size: 100
n_action_steps: 100
input_shapes:
@ -107,7 +114,7 @@ policy:
n_vae_encoder_layers: 4
# Inference.
temporal_ensemble_coeff: null
temporal_ensemble_momentum: null
# Training and loss computation.
dropout: 0.1

View File

@ -1,110 +0,0 @@
# @package _global_
# Use `act_real_no_state.yaml` to train on real-world Aloha/Aloha2 datasets when cameras are moving (e.g. wrist cameras)
# Compared to `act_real.yaml`, it is camera only and does not use the state as input which is vector of robot joint positions.
# We validated experimentaly that not using state reaches better success rate. Our hypothesis is that `act_real.yaml` might
# overfits to the state, because the images are more complex to learn from since they are moving.
#
# Example of usage for training:
# ```bash
# python lerobot/scripts/train.py \
# policy=act_real_no_state \
# env=dora_aloha_real
# ```
seed: 1000
dataset_repo_id: lerobot/aloha_static_vinh_cup
override_dataset_stats:
observation.images.cam_right_wrist:
# stats from imagenet, since we use a pretrained vision model
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
observation.images.cam_left_wrist:
# stats from imagenet, since we use a pretrained vision model
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
observation.images.cam_high:
# stats from imagenet, since we use a pretrained vision model
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
observation.images.cam_low:
# stats from imagenet, since we use a pretrained vision model
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
training:
offline_steps: 100000
online_steps: 0
eval_freq: -1
save_freq: 20000
save_checkpoint: true
batch_size: 8
lr: 1e-5
lr_backbone: 1e-5
weight_decay: 1e-4
grad_clip_norm: 10
online_steps_between_rollouts: 1
delta_timestamps:
action: "[i / ${fps} for i in range(${policy.chunk_size})]"
eval:
n_episodes: 50
batch_size: 50
# See `configuration_act.py` for more details.
policy:
name: act
# Input / output structure.
n_obs_steps: 1
chunk_size: 100 # chunk_size
n_action_steps: 100
input_shapes:
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
observation.images.cam_right_wrist: [3, 480, 640]
observation.images.cam_left_wrist: [3, 480, 640]
observation.images.cam_high: [3, 480, 640]
observation.images.cam_low: [3, 480, 640]
output_shapes:
action: ["${env.action_dim}"]
# Normalization / Unnormalization
input_normalization_modes:
observation.images.cam_right_wrist: mean_std
observation.images.cam_left_wrist: mean_std
observation.images.cam_high: mean_std
observation.images.cam_low: mean_std
output_normalization_modes:
action: mean_std
# Architecture.
# Vision backbone.
vision_backbone: resnet18
pretrained_backbone_weights: ResNet18_Weights.IMAGENET1K_V1
replace_final_stride_with_dilation: false
# Transformer layers.
pre_norm: false
dim_model: 512
n_heads: 8
dim_feedforward: 3200
feedforward_activation: relu
n_encoder_layers: 4
# Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
# that means only the first layer is used. Here we match the original implementation by setting this to 1.
# See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
n_decoder_layers: 1
# VAE.
use_vae: true
latent_dim: 32
n_vae_encoder_layers: 4
# Inference.
temporal_ensemble_coeff: null
# Training and loss computation.
dropout: 0.1
kl_weight: 10.0

View File

@ -102,6 +102,7 @@ import argparse
import concurrent.futures
import json
import logging
import multiprocessing
import os
import platform
import shutil
@ -163,9 +164,9 @@ def say(text, blocking=False):
os.system(cmd)
def save_image(img_tensor, key, frame_index, episode_index, videos_dir):
def save_image(img_tensor, key, frame_index, episode_index, videos_dir: str):
img = Image.fromarray(img_tensor.numpy())
path = videos_dir / f"{key}_episode_{episode_index:06d}" / f"frame_{frame_index:06d}.png"
path = Path(videos_dir) / f"{key}_episode_{episode_index:06d}" / f"frame_{frame_index:06d}.png"
path.parent.mkdir(parents=True, exist_ok=True)
img.save(str(path), quality=100)
@ -255,6 +256,129 @@ def get_available_arms(robot):
return available_arms
########################################################################################
# Asynchrounous saving of images on disk
########################################################################################
def loop_to_save_images_in_threads(image_queue, num_threads):
if num_threads < 1:
raise NotImplementedError(f"Only `num_threads>=1` is supported for now, but {num_threads=} given.")
with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor:
futures = []
while True:
# Blocks until a frame is available
frame_data = image_queue.get()
# As usually done, exit loop when receiving None to stop the worker
if frame_data is None:
break
image, key, frame_index, episode_index, videos_dir = frame_data
futures.append(executor.submit(save_image, image, key, frame_index, episode_index, videos_dir))
# Before exiting function, wait for all threads to complete
with tqdm.tqdm(total=len(futures), desc="Writing images") as progress_bar:
concurrent.futures.wait(futures)
progress_bar.update(len(futures))
def start_image_writer_processes(image_queue, num_processes, num_threads_per_process):
if num_processes < 1:
raise ValueError(f"Only `num_processes>=1` is supported, but {num_processes=} given.")
if num_threads_per_process < 1:
raise NotImplementedError(
"Only `num_threads_per_process>=1` is supported for now, but {num_threads_per_process=} given."
)
processes = []
for _ in range(num_processes):
process = multiprocessing.Process(
target=loop_to_save_images_in_threads,
args=(image_queue, num_threads_per_process),
)
process.start()
processes.append(process)
return processes
def stop_processes(processes, queue, timeout):
# Send None to each process to signal them to stop
for _ in processes:
queue.put(None)
# Close the queue, no more items can be put in the queue
queue.close()
# Wait maximum 20 seconds for all processes to terminate
for process in processes:
process.join(timeout=timeout)
# If not terminated after 20 seconds, force termination
if process.is_alive():
process.terminate()
# Ensure all background queue threads have finished
queue.join_thread()
def start_image_writer(num_processes, num_threads):
"""This function abstract away the initialisation of processes or/and threads to
save images on disk asynchrounously, which is critical to control a robot and record data
at a high frame rate.
When `num_processes=0`, it returns a dictionary containing a threads pool of size `num_threads`.
When `num_processes>0`, it returns a dictionary containing a processes pool of size `num_processes`,
where each subprocess starts their own threads pool of size `num_threads`.
The optimal number of processes and threads depends on your computer capabilities.
We advise to use 4 threads per camera with 0 processes. If the fps is not stable, try to increase or lower
the number of threads. If it is still not stable, try to use 1 subprocess, or more.
"""
image_writer = {}
if num_processes == 0:
futures = []
threads_pool = concurrent.futures.ThreadPoolExecutor(max_workers=num_threads)
image_writer["threads_pool"], image_writer["futures"] = threads_pool, futures
else:
# TODO(rcadene): When using num_processes>1, `multiprocessing.Manager().Queue()`
# might be better than `multiprocessing.Queue()`. Source: https://www.geeksforgeeks.org/python-multiprocessing-queue-vs-multiprocessing-manager-queue
image_queue = multiprocessing.Queue()
processes_pool = start_image_writer_processes(
image_queue, num_processes=num_processes, num_threads_per_process=num_threads
)
image_writer["processes_pool"], image_writer["image_queue"] = processes_pool, image_queue
return image_writer
def async_save_image(image_writer, image, key, frame_index, episode_index, videos_dir):
"""This function abstract away the saving of an image on disk asynchrounously. It uses a dictionary
called image writer which contains either a pool of processes or a pool of threads.
"""
if "threads_pool" in image_writer:
threads_pool, futures = image_writer["threads_pool"], image_writer["futures"]
futures.append(threads_pool.submit(save_image, image, key, frame_index, episode_index, videos_dir))
else:
image_queue = image_writer["image_queue"]
image_queue.put((image, key, frame_index, episode_index, videos_dir))
def stop_image_writer(image_writer, timeout):
if "threads_pool" in image_writer:
futures = image_writer["futures"]
# Before exiting function, wait for all threads to complete
with tqdm.tqdm(total=len(futures), desc="Writing images") as progress_bar:
concurrent.futures.wait(futures, timeout=timeout)
progress_bar.update(len(futures))
else:
processes_pool, image_queue = image_writer["processes_pool"], image_writer["image_queue"]
stop_processes(processes_pool, image_queue, timeout=timeout)
########################################################################################
# Control modes
########################################################################################
@ -342,9 +466,11 @@ def record(
run_compute_stats=True,
push_to_hub=True,
tags=None,
num_image_writers_per_camera=4,
num_image_writer_processes=0,
num_image_writer_threads_per_camera=4,
force_override=False,
display_cameras=True,
play_sounds=True,
):
# TODO(rcadene): Add option to record logs
# TODO(rcadene): Clean this function via decomposition in higher level functions
@ -436,7 +562,8 @@ def record(
while timestamp < warmup_time_s:
if not is_warmup_print:
logging.info("Warming up (no data recording)")
say("Warming up")
if play_sounds:
say("Warming up")
is_warmup_print = True
start_loop_t = time.perf_counter()
@ -463,16 +590,22 @@ def record(
if has_method(robot, "teleop_safety_stop"):
robot.teleop_safety_stop()
# Save images using threads to reach high fps (30 and more)
# Using `with` to exist smoothly if an execption is raised.
futures = []
num_image_writers = num_image_writers_per_camera * len(robot.cameras)
num_image_writers = max(num_image_writers, 1)
with concurrent.futures.ThreadPoolExecutor(max_workers=num_image_writers) as executor:
has_camera = len(robot.cameras) > 0
if has_camera:
# Initialize processes or/and threads dedicated to save images on disk asynchronously,
# which is critical to control a robot and record data at a high frame rate.
image_writer = start_image_writer(
num_processes=num_image_writer_processes,
num_threads=num_image_writer_threads_per_camera * len(robot.cameras),
)
# Using `try` to exist smoothly if an exception is raised
try:
# Start recording all episodes
while episode_index < num_episodes:
logging.info(f"Recording episode {episode_index}")
say(f"Recording episode {episode_index}")
if play_sounds:
say(f"Recording episode {episode_index}")
ep_dict = {}
frame_index = 0
timestamp = 0
@ -488,12 +621,16 @@ def record(
image_keys = [key for key in observation if "image" in key]
not_image_keys = [key for key in observation if "image" not in key]
for key in image_keys:
futures += [
executor.submit(
save_image, observation[key], key, frame_index, episode_index, videos_dir
if has_camera > 0:
for key in image_keys:
async_save_image(
image_writer,
image=observation[key],
key=key,
frame_index=frame_index,
episode_index=episode_index,
videos_dir=str(videos_dir),
)
]
if display_cameras and not is_headless():
image_keys = [key for key in observation if "image" in key]
@ -563,7 +700,8 @@ def record(
if not stop_recording:
# Start resetting env while the executor are finishing
logging.info("Reset the environment")
say("Reset the environment")
if play_sounds:
say("Reset the environment")
timestamp = 0
start_vencod_t = time.perf_counter()
@ -635,18 +773,23 @@ def record(
if is_last_episode:
logging.info("Done recording")
say("Done recording", blocking=True)
if play_sounds:
say("Done recording", blocking=True)
if not is_headless():
listener.stop()
logging.info("Waiting for threads writing the images on disk to terminate...")
for _ in tqdm.tqdm(
concurrent.futures.as_completed(futures), total=len(futures), desc="Writting images"
):
pass
break
if has_camera > 0:
logging.info("Waiting for image writer to terminate...")
stop_image_writer(image_writer, timeout=20)
except Exception as e:
if has_camera > 0:
logging.info("Waiting for image writer to terminate...")
stop_image_writer(image_writer, timeout=20)
raise e
robot.disconnect()
if display_cameras and not is_headless():
cv2.destroyAllWindows()
@ -654,7 +797,8 @@ def record(
if video:
logging.info("Encoding videos")
say("Encoding videos")
if play_sounds:
say("Encoding videos")
# Use ffmpeg to convert frames stored as png into mp4 videos
for episode_index in tqdm.tqdm(range(num_episodes)):
for key in image_keys:
@ -699,7 +843,8 @@ def record(
)
if run_compute_stats:
logging.info("Computing dataset statistics")
say("Computing dataset statistics")
if play_sounds:
say("Computing dataset statistics")
stats = compute_stats(lerobot_dataset)
lerobot_dataset.stats = stats
else:
@ -721,11 +866,14 @@ def record(
create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION)
logging.info("Exiting")
say("Exiting")
if play_sounds:
say("Exiting")
return lerobot_dataset
def replay(robot: Robot, episode: int, fps: int | None = None, root="data", repo_id="lerobot/debug"):
def replay(
robot: Robot, episode: int, fps: int | None = None, root="data", repo_id="lerobot/debug", play_sounds=True
):
# TODO(rcadene): Add option to record logs
local_dir = Path(root) / repo_id
if not local_dir.exists():
@ -740,7 +888,8 @@ def replay(robot: Robot, episode: int, fps: int | None = None, root="data", repo
robot.connect()
logging.info("Replaying episode")
say("Replaying episode", blocking=True)
if play_sounds:
say("Replaying episode", blocking=True)
for idx in range(from_idx, to_idx):
start_episode_t = time.perf_counter()
@ -840,12 +989,23 @@ if __name__ == "__main__":
help="Add tags to your dataset on the hub.",
)
parser_record.add_argument(
"--num-image-writers-per-camera",
"--num-image-writer-processes",
type=int,
default=0,
help=(
"Number of subprocesses handling the saving of frames as PNGs. Set to 0 to use threads only; "
"set to ≥1 to use subprocesses, each using threads to write images. The best number of processes "
"and threads depends on your system. We recommend 4 threads per camera with 0 processes. "
"If fps is unstable, adjust the thread count. If still unstable, try using 1 or more subprocesses."
),
)
parser_record.add_argument(
"--num-image-writer-threads-per-camera",
type=int,
default=4,
help=(
"Number of threads writing the frames as png images on disk, per camera. "
"Too much threads might cause unstable teleoperation fps due to main thread being blocked. "
"Too many threads might cause unstable teleoperation fps due to main thread being blocked. "
"Not enough threads might cause low camera fps."
),
)

2
poetry.lock generated
View File

@ -5245,7 +5245,7 @@ docs = ["sphinx", "sphinx-automodapi", "sphinx-rtd-theme"]
name = "pyserial"
version = "3.5"
description = "Python Serial Port Extension"
optional = true
optional = false
python-versions = "*"
files = [
{file = "pyserial-3.5-py2.py3-none-any.whl", hash = "sha256:c4451db6ba391ca6ca299fb3ec7bae67a5c55dde170964c7a14ceefec02f2cf0"},

View File

@ -52,8 +52,9 @@ def is_robot_available(robot_type):
print(f"\nInstall module '{e.name}'")
elif isinstance(e, SerialException):
print("\nNo physical motors bus detected.")
else:
traceback.print_exc()
traceback.print_exc()
return False
@ -77,8 +78,9 @@ def is_camera_available(camera_type):
print(f"\nInstall module '{e.name}'")
elif isinstance(e, ValueError) and "camera_index" in e.args[0]:
print("\nNo physical camera detected.")
else:
traceback.print_exc()
traceback.print_exc()
return False
@ -102,8 +104,9 @@ def is_motor_available(motor_type):
print(f"\nInstall module '{e.name}'")
elif isinstance(e, SerialException):
print("\nNo physical motors bus detected.")
else:
traceback.print_exc()
traceback.print_exc()
return False

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b5a9f73a2356aff9c717cdfd0d37a6da08b0cf2cc09c98edbc9492501b7f64a5
size 5104

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:28738b3cfad17af0ac5181effdd796acdf7953cd5bcca3f421a11ddfd6b0076f
size 30800

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4bb8a197a40456fdbc16029126268e6bcef3eca1837d88235165dc7e14618bea
size 68

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:bea60cce42d324f539dd3bca1e66b5ba6391838fdcadb00efc25f3240edb529a
size 33600

View File

@ -23,6 +23,7 @@ pytest -sx 'tests/test_control_robot.py::test_teleoperate[aloha-True]'
```
"""
import multiprocessing
from pathlib import Path
import pytest
@ -37,7 +38,7 @@ from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, TEST_ROBOT_TYPES, require_r
@pytest.mark.parametrize("robot_type, mock", TEST_ROBOT_TYPES)
@require_robot
def test_teleoperate(tmpdir, request, robot_type, mock):
if mock:
if mock and robot_type != "aloha":
request.getfixturevalue("patch_builtins_input")
# Create an empty calibration directory to trigger manual calibration
@ -78,7 +79,7 @@ def test_record_without_cameras(tmpdir, request, robot_type, mock):
# Avoid using cameras
overrides = ["~cameras"]
if mock:
if mock and robot_type != "aloha":
request.getfixturevalue("patch_builtins_input")
# Create an empty calibration directory to trigger manual calibration
@ -101,13 +102,14 @@ def test_record_without_cameras(tmpdir, request, robot_type, mock):
run_compute_stats=False,
push_to_hub=False,
video=False,
play_sounds=False,
)
@pytest.mark.parametrize("robot_type, mock", TEST_ROBOT_TYPES)
@require_robot
def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
if mock:
if mock and robot_type != "aloha":
request.getfixturevalue("patch_builtins_input")
# Create an empty calibration directory to trigger manual calibration
@ -115,12 +117,9 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
calibration_dir = Path(tmpdir) / robot_type
overrides = [f"calibration_dir={calibration_dir}"]
else:
# Use the default .cache/calibration folder when mock=False
# Use the default .cache/calibration folder when mock=False or for aloha
overrides = None
if robot_type == "aloha":
pytest.skip("TODO(rcadene): enable test once aloha_real and act_aloha_real are merged")
env_name = "koch_real"
policy_name = "act_koch_real"
@ -141,21 +140,58 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
video=False,
# TODO(rcadene): display cameras through cv2 sometimes crashes on mac
display_cameras=False,
play_sounds=False,
)
replay(robot, episode=0, fps=30, root=root, repo_id=repo_id)
replay(robot, episode=0, fps=30, root=root, repo_id=repo_id, play_sounds=False)
# TODO(rcadene, aliberts): rethink this design
if robot_type == "aloha":
env_name = "aloha_real"
policy_name = "act_aloha_real"
elif robot_type in ["koch", "koch_bimanual"]:
env_name = "koch_real"
policy_name = "act_koch_real"
else:
raise NotImplementedError(robot_type)
overrides = [
f"env={env_name}",
f"policy={policy_name}",
f"device={DEVICE}",
]
if robot_type == "koch_bimanual":
overrides += ["env.state_dim=12", "env.action_dim=12"]
cfg = init_hydra_config(
DEFAULT_CONFIG_PATH,
overrides=[
f"env={env_name}",
f"policy={policy_name}",
f"device={DEVICE}",
],
overrides=overrides,
)
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)
# In `examples/9_use_aloha.md`, we advise using `num_image_writer_processes=1`
# during inference, to reach constent fps, so we test this here.
if robot_type == "aloha":
num_image_writer_processes = 1
# `multiprocessing.set_start_method("spawn", force=True)` avoids a hanging issue
# before exiting pytest. However, it outputs the following error in the log:
# Traceback (most recent call last):
# File "<string>", line 1, in <module>
# File "/Users/rcadene/miniconda3/envs/lerobot/lib/python3.10/multiprocessing/spawn.py", line 116, in spawn_main
# exitcode = _main(fd, parent_sentinel)
# File "/Users/rcadene/miniconda3/envs/lerobot/lib/python3.10/multiprocessing/spawn.py", line 126, in _main
# self = reduction.pickle.load(from_parent)
# File "/Users/rcadene/miniconda3/envs/lerobot/lib/python3.10/multiprocessing/synchronize.py", line 110, in __setstate__
# self._semlock = _multiprocessing.SemLock._rebuild(*state)
# FileNotFoundError: [Errno 2] No such file or directory
# TODO(rcadene, aliberts): fix FileNotFoundError in multiprocessing
multiprocessing.set_start_method("spawn", force=True)
else:
num_image_writer_processes = 0
record(
robot,
policy,
@ -167,6 +203,8 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
push_to_hub=False,
video=False,
display_cameras=False,
play_sounds=False,
num_image_writer_processes=num_image_writer_processes,
)
del robot

View File

@ -308,12 +308,11 @@ def test_flatten_unflatten_dict():
# "lerobot/cmu_stretch",
],
)
# TODO(rcadene, aliberts): all these tests fail locally on Mac M1, but not on Linux
def test_backward_compatibility(repo_id):
"""The artifacts for this test have been generated by `tests/scripts/save_dataset_to_safetensors.py`."""
dataset = LeRobotDataset(
repo_id,
)
dataset = LeRobotDataset(repo_id)
test_dir = Path("tests/data/save_dataset_to_safetensors") / repo_id

View File

@ -367,8 +367,7 @@ def test_normalize(insert_temporal_dim):
),
("aloha", "act", ["policy.n_action_steps=10"], ""),
("aloha", "act", ["policy.n_action_steps=1000", "policy.chunk_size=1000"], "_1000_steps"),
("dora_aloha_real", "act_real", ["policy.n_action_steps=10"], ""),
("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"], ""),
("dora_aloha_real", "act_aloha_real", ["policy.n_action_steps=10"], ""),
],
)
# As artifacts have been generated on an x86_64 kernel, this test won't