Add policy/act_aloha_real.yaml + env/act_real.yaml (#429)
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
This commit is contained in:
parent
c29e70e5a1
commit
97b1feb0b3
|
@ -0,0 +1,179 @@
|
|||
This tutorial explains how to use [Aloha and Aloha 2 stationary](https://www.trossenrobotics.com/aloha-stationary) with LeRobot.
|
||||
|
||||
## Setup
|
||||
|
||||
Follow the [documentation from Trossen Robotics](https://docs.trossenrobotics.com/aloha_docs/getting_started/stationary/hardware_setup.html) for setting up the hardware and plugging the 4 arms and 4 cameras to your computer.
|
||||
|
||||
|
||||
## Install LeRobot
|
||||
|
||||
On your computer:
|
||||
|
||||
1. [Install Miniconda](https://docs.anaconda.com/miniconda/#quick-command-line-install):
|
||||
```bash
|
||||
mkdir -p ~/miniconda3
|
||||
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh
|
||||
bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
|
||||
rm ~/miniconda3/miniconda.sh
|
||||
~/miniconda3/bin/conda init bash
|
||||
```
|
||||
|
||||
2. Restart shell or `source ~/.bashrc`
|
||||
|
||||
3. Create and activate a fresh conda environment for lerobot
|
||||
```bash
|
||||
conda create -y -n lerobot python=3.10 && conda activate lerobot
|
||||
```
|
||||
|
||||
4. Clone LeRobot:
|
||||
```bash
|
||||
git clone https://github.com/huggingface/lerobot.git ~/lerobot
|
||||
```
|
||||
|
||||
5. Install LeRobot with dependencies for the Aloha motors (dynamixel) and cameras (intelrealsense):
|
||||
```bash
|
||||
cd ~/lerobot && pip install -e ".[dynamixel intelrealsense]"
|
||||
```
|
||||
|
||||
And install extra dependencies for recording datasets on Linux:
|
||||
```bash
|
||||
conda install -y -c conda-forge ffmpeg
|
||||
pip uninstall -y opencv-python
|
||||
conda install -y -c conda-forge "opencv>=4.10.0"
|
||||
```
|
||||
|
||||
## Teleoperate
|
||||
|
||||
**/!\ FOR SAFETY, READ THIS /!\**
|
||||
Teleoperation consists in manually operating the leader arms to move the follower arms. Importantly:
|
||||
1. Make sure your leader arms are in the same position as the follower arms, so that the follower arms don't move too fast to match the leader arms,
|
||||
2. Our code assumes that your robot has been assembled following Trossen Robotics instructions. This allows us to skip calibration, as we use the pre-defined calibration files in `.cache/calibration/aloha_default`. If you replace a motor, make sure you follow the exact instructions from Trossen Robotics.
|
||||
|
||||
By running the following code, you can start your first **SAFE** teleoperation:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py teleoperate \
|
||||
--robot-path lerobot/configs/robot/aloha.yaml \
|
||||
--robot-overrides max_relative_target=5
|
||||
```
|
||||
|
||||
By adding `--robot-overrides max_relative_target=5`, we override the default value for `max_relative_target` defined in `lerobot/configs/robot/aloha.yaml`. It is expected to be `5` to limit the magnitude of the movement for more safety, but the teloperation won't be smooth. When you feel confident, you can disable this limit by adding `--robot-overrides max_relative_target=null` to the command line:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py teleoperate \
|
||||
--robot-path lerobot/configs/robot/aloha.yaml \
|
||||
--robot-overrides max_relative_target=null
|
||||
```
|
||||
|
||||
## Record a dataset
|
||||
|
||||
Once you're familiar with teleoperation, you can record your first dataset with Aloha.
|
||||
|
||||
If you want to use the Hugging Face hub features for uploading your dataset and you haven't previously done it, make sure you've logged in using a write-access token, which can be generated from the [Hugging Face settings](https://huggingface.co/settings/tokens):
|
||||
```bash
|
||||
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
|
||||
```
|
||||
|
||||
Store your Hugging Face repository name in a variable to run these commands:
|
||||
```bash
|
||||
HF_USER=$(huggingface-cli whoami | head -n 1)
|
||||
echo $HF_USER
|
||||
```
|
||||
|
||||
Record 2 episodes and upload your dataset to the hub:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py record \
|
||||
--robot-path lerobot/configs/robot/aloha.yaml \
|
||||
--robot-overrides max_relative_target=null \
|
||||
--fps 30 \
|
||||
--root data \
|
||||
--repo-id ${HF_USER}/aloha_test \
|
||||
--tags aloha tutorial \
|
||||
--warmup-time-s 5 \
|
||||
--episode-time-s 40 \
|
||||
--reset-time-s 10 \
|
||||
--num-episodes 2 \
|
||||
--push-to-hub 1
|
||||
```
|
||||
|
||||
## Visualize a dataset
|
||||
|
||||
If you uploaded your dataset to the hub with `--push-to-hub 1`, you can [visualize your dataset online](https://huggingface.co/spaces/lerobot/visualize_dataset) by copy pasting your repo id given by:
|
||||
```bash
|
||||
echo ${HF_USER}/aloha_test
|
||||
```
|
||||
|
||||
If you didn't upload with `--push-to-hub 0`, you can also visualize it locally with:
|
||||
```bash
|
||||
python lerobot/scripts/visualize_dataset_html.py \
|
||||
--root data \
|
||||
--repo-id ${HF_USER}/aloha_test
|
||||
```
|
||||
|
||||
## Replay an episode
|
||||
|
||||
**/!\ FOR SAFETY, READ THIS /!\**
|
||||
Replay consists in automatically replaying the sequence of actions (i.e. goal positions for your motors) recorded in a given dataset episode. Make sure the current initial position of your robot is similar to the one in your episode, so that your follower arms don't move too fast to go to the first goal positions. For safety, you might want to add `--robot-overrides max_relative_target=5` to your command line as explained above.
|
||||
|
||||
Now try to replay the first episode on your robot:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py replay \
|
||||
--robot-path lerobot/configs/robot/aloha.yaml \
|
||||
--robot-overrides max_relative_target=null \
|
||||
--fps 30 \
|
||||
--root data \
|
||||
--repo-id ${HF_USER}/aloha_test \
|
||||
--episode 0
|
||||
```
|
||||
|
||||
## Train a policy
|
||||
|
||||
To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
|
||||
```bash
|
||||
DATA_DIR=data python lerobot/scripts/train.py \
|
||||
dataset_repo_id=${HF_USER}/aloha_test \
|
||||
policy=act_aloha_real \
|
||||
env=aloha_real \
|
||||
hydra.run.dir=outputs/train/act_aloha_test \
|
||||
hydra.job.name=act_aloha_test \
|
||||
device=cuda \
|
||||
wandb.enable=true
|
||||
```
|
||||
|
||||
Let's explain it:
|
||||
1. We provided the dataset as argument with `dataset_repo_id=${HF_USER}/aloha_test`.
|
||||
2. We provided the policy with `policy=act_aloha_real`. This loads configurations from [`lerobot/configs/policy/act_aloha_real.yaml`](../lerobot/configs/policy/act_aloha_real.yaml). Importantly, this policy uses 4 cameras as input `cam_right_wrist`, `cam_left_wrist`, `cam_high`, and `cam_low`.
|
||||
3. We provided an environment as argument with `env=aloha_real`. This loads configurations from [`lerobot/configs/env/aloha_real.yaml`](../lerobot/configs/env/aloha_real.yaml). Note: this yaml defines 18 dimensions for the `state_dim` and `action_dim`, corresponding to 18 motors, not 14 motors as used in previous Aloha work. This is because, we include the `shoulder_shadow` and `elbow_shadow` motors for simplicity.
|
||||
4. We provided `device=cuda` since we are training on a Nvidia GPU.
|
||||
5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
|
||||
6. We added `DATA_DIR=data` to access your dataset stored in your local `data` directory. If you dont provide `DATA_DIR`, your dataset will be downloaded from Hugging Face hub to your cache folder `$HOME/.cache/hugginface`. In future versions of `lerobot`, both directories will be in sync.
|
||||
|
||||
Training should take several hours. You will find checkpoints in `outputs/train/act_aloha_test/checkpoints`.
|
||||
|
||||
## Evaluate your policy
|
||||
|
||||
You can use the `record` function from [`lerobot/scripts/control_robot.py`](../lerobot/scripts/control_robot.py) but with a policy checkpoint as input. For instance, run this command to record 10 evaluation episodes:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py record \
|
||||
--robot-path lerobot/configs/robot/aloha.yaml \
|
||||
--robot-overrides max_relative_target=null \
|
||||
--fps 30 \
|
||||
--root data \
|
||||
--repo-id ${HF_USER}/eval_act_aloha_test \
|
||||
--tags aloha tutorial eval \
|
||||
--warmup-time-s 5 \
|
||||
--episode-time-s 40 \
|
||||
--reset-time-s 10 \
|
||||
--num-episodes 10 \
|
||||
--num-image-writer-processes 1 \
|
||||
-p outputs/train/act_aloha_test/checkpoints/last/pretrained_model
|
||||
```
|
||||
|
||||
As you can see, it's almost the same command as previously used to record your training dataset. Two things changed:
|
||||
1. There is an additional `-p` argument which indicates the path to your policy checkpoint with (e.g. `-p outputs/train/eval_aloha_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `-p ${HF_USER}/act_aloha_test`).
|
||||
2. The name of dataset begins by `eval` to reflect that you are running inference (e.g. `--repo-id ${HF_USER}/eval_act_aloha_test`).
|
||||
3. We use `--num-image-writer-processes 1` instead of the default value (`0`). On our computer, using a dedicated process to write images from the 4 cameras on disk allows to reach constent 30 fps during inference. Feel free to explore different values for `--num-image-writer-processes`.
|
||||
|
||||
## More
|
||||
|
||||
Follow this [previous tutorial](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#4-train-a-policy-on-your-data) for a more in-depth explaination.
|
||||
|
||||
If you have any question or need help, please reach out on Discord in the channel `#aloha-arm`.
|
|
@ -216,7 +216,9 @@ available_policies_per_env = {
|
|||
"aloha": ["act"],
|
||||
"pusht": ["diffusion", "vqbet"],
|
||||
"xarm": ["tdmpc"],
|
||||
"dora_aloha_real": ["act_real"],
|
||||
"koch_real": ["act_koch_real"],
|
||||
"aloha_real": ["act_aloha_real"],
|
||||
"dora_aloha_real": ["act_aloha_real"],
|
||||
}
|
||||
|
||||
env_task_pairs = [(env, task) for env, tasks in available_tasks_per_env.items() for task in tasks]
|
||||
|
|
|
@ -364,6 +364,7 @@ class ManipulatorRobot:
|
|||
for name in self.follower_arms:
|
||||
print(f"Connecting {name} follower arm.")
|
||||
self.follower_arms[name].connect()
|
||||
for name in self.leader_arms:
|
||||
print(f"Connecting {name} leader arm.")
|
||||
self.leader_arms[name].connect()
|
||||
|
||||
|
|
|
@ -0,0 +1,10 @@
|
|||
# @package _global_
|
||||
|
||||
fps: 30
|
||||
|
||||
env:
|
||||
name: real_world
|
||||
task: null
|
||||
state_dim: 18
|
||||
action_dim: 18
|
||||
fps: ${fps}
|
|
@ -1,16 +1,22 @@
|
|||
# @package _global_
|
||||
|
||||
# Use `act_real.yaml` to train on real-world Aloha/Aloha2 datasets.
|
||||
# Compared to `act.yaml`, it contains 4 cameras (i.e. cam_right_wrist, cam_left_wrist, images,
|
||||
# cam_low) instead of 1 camera (i.e. top). Also, `training.eval_freq` is set to -1. This config is used
|
||||
# to evaluate checkpoints at a certain frequency of training steps. When it is set to -1, it deactivates evaluation.
|
||||
# This is because real-world evaluation is done through [dora-lerobot](https://github.com/dora-rs/dora-lerobot).
|
||||
# Look at its README for more information on how to evaluate a checkpoint in the real-world.
|
||||
# Use `act_aloha_real.yaml` to train on real-world datasets collected on Aloha or Aloha-2 robots.
|
||||
# Compared to `act.yaml`, it contains 4 cameras (i.e. cam_right_wrist, cam_left_wrist, cam_high, cam_low) instead of 1 camera (i.e. top).
|
||||
# Also, `training.eval_freq` is set to -1. This config is used to evaluate checkpoints at a certain frequency of training steps.
|
||||
# When it is set to -1, it deactivates evaluation. This is because real-world evaluation is done through our `control_robot.py` script.
|
||||
# Look at the documentation in header of `control_robot.py` for more information on how to collect data , train and evaluate a policy.
|
||||
#
|
||||
# Example of usage for training:
|
||||
# Example of usage for training and inference with `control_robot.py`:
|
||||
# ```bash
|
||||
# python lerobot/scripts/train.py \
|
||||
# policy=act_real \
|
||||
# policy=act_aloha_real \
|
||||
# env=aloha_real
|
||||
# ```
|
||||
#
|
||||
# Example of usage for training and inference with [Dora-rs](https://github.com/dora-rs/dora-lerobot):
|
||||
# ```bash
|
||||
# python lerobot/scripts/train.py \
|
||||
# policy=act_aloha_real \
|
||||
# env=dora_aloha_real
|
||||
# ```
|
||||
|
||||
|
@ -36,10 +42,11 @@ override_dataset_stats:
|
|||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||
|
||||
training:
|
||||
offline_steps: 100000
|
||||
offline_steps: 80000
|
||||
online_steps: 0
|
||||
eval_freq: -1
|
||||
save_freq: 20000
|
||||
save_freq: 10000
|
||||
log_freq: 100
|
||||
save_checkpoint: true
|
||||
|
||||
batch_size: 8
|
||||
|
@ -62,7 +69,7 @@ policy:
|
|||
|
||||
# Input / output structure.
|
||||
n_obs_steps: 1
|
||||
chunk_size: 100 # chunk_size
|
||||
chunk_size: 100
|
||||
n_action_steps: 100
|
||||
|
||||
input_shapes:
|
||||
|
@ -107,7 +114,7 @@ policy:
|
|||
n_vae_encoder_layers: 4
|
||||
|
||||
# Inference.
|
||||
temporal_ensemble_coeff: null
|
||||
temporal_ensemble_momentum: null
|
||||
|
||||
# Training and loss computation.
|
||||
dropout: 0.1
|
|
@ -1,110 +0,0 @@
|
|||
# @package _global_
|
||||
|
||||
# Use `act_real_no_state.yaml` to train on real-world Aloha/Aloha2 datasets when cameras are moving (e.g. wrist cameras)
|
||||
# Compared to `act_real.yaml`, it is camera only and does not use the state as input which is vector of robot joint positions.
|
||||
# We validated experimentaly that not using state reaches better success rate. Our hypothesis is that `act_real.yaml` might
|
||||
# overfits to the state, because the images are more complex to learn from since they are moving.
|
||||
#
|
||||
# Example of usage for training:
|
||||
# ```bash
|
||||
# python lerobot/scripts/train.py \
|
||||
# policy=act_real_no_state \
|
||||
# env=dora_aloha_real
|
||||
# ```
|
||||
|
||||
seed: 1000
|
||||
dataset_repo_id: lerobot/aloha_static_vinh_cup
|
||||
|
||||
override_dataset_stats:
|
||||
observation.images.cam_right_wrist:
|
||||
# stats from imagenet, since we use a pretrained vision model
|
||||
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||
observation.images.cam_left_wrist:
|
||||
# stats from imagenet, since we use a pretrained vision model
|
||||
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||
observation.images.cam_high:
|
||||
# stats from imagenet, since we use a pretrained vision model
|
||||
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||
observation.images.cam_low:
|
||||
# stats from imagenet, since we use a pretrained vision model
|
||||
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||
|
||||
training:
|
||||
offline_steps: 100000
|
||||
online_steps: 0
|
||||
eval_freq: -1
|
||||
save_freq: 20000
|
||||
save_checkpoint: true
|
||||
|
||||
batch_size: 8
|
||||
lr: 1e-5
|
||||
lr_backbone: 1e-5
|
||||
weight_decay: 1e-4
|
||||
grad_clip_norm: 10
|
||||
online_steps_between_rollouts: 1
|
||||
|
||||
delta_timestamps:
|
||||
action: "[i / ${fps} for i in range(${policy.chunk_size})]"
|
||||
|
||||
eval:
|
||||
n_episodes: 50
|
||||
batch_size: 50
|
||||
|
||||
# See `configuration_act.py` for more details.
|
||||
policy:
|
||||
name: act
|
||||
|
||||
# Input / output structure.
|
||||
n_obs_steps: 1
|
||||
chunk_size: 100 # chunk_size
|
||||
n_action_steps: 100
|
||||
|
||||
input_shapes:
|
||||
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
|
||||
observation.images.cam_right_wrist: [3, 480, 640]
|
||||
observation.images.cam_left_wrist: [3, 480, 640]
|
||||
observation.images.cam_high: [3, 480, 640]
|
||||
observation.images.cam_low: [3, 480, 640]
|
||||
output_shapes:
|
||||
action: ["${env.action_dim}"]
|
||||
|
||||
# Normalization / Unnormalization
|
||||
input_normalization_modes:
|
||||
observation.images.cam_right_wrist: mean_std
|
||||
observation.images.cam_left_wrist: mean_std
|
||||
observation.images.cam_high: mean_std
|
||||
observation.images.cam_low: mean_std
|
||||
output_normalization_modes:
|
||||
action: mean_std
|
||||
|
||||
# Architecture.
|
||||
# Vision backbone.
|
||||
vision_backbone: resnet18
|
||||
pretrained_backbone_weights: ResNet18_Weights.IMAGENET1K_V1
|
||||
replace_final_stride_with_dilation: false
|
||||
# Transformer layers.
|
||||
pre_norm: false
|
||||
dim_model: 512
|
||||
n_heads: 8
|
||||
dim_feedforward: 3200
|
||||
feedforward_activation: relu
|
||||
n_encoder_layers: 4
|
||||
# Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
|
||||
# that means only the first layer is used. Here we match the original implementation by setting this to 1.
|
||||
# See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
|
||||
n_decoder_layers: 1
|
||||
# VAE.
|
||||
use_vae: true
|
||||
latent_dim: 32
|
||||
n_vae_encoder_layers: 4
|
||||
|
||||
# Inference.
|
||||
temporal_ensemble_coeff: null
|
||||
|
||||
# Training and loss computation.
|
||||
dropout: 0.1
|
||||
kl_weight: 10.0
|
|
@ -102,6 +102,7 @@ import argparse
|
|||
import concurrent.futures
|
||||
import json
|
||||
import logging
|
||||
import multiprocessing
|
||||
import os
|
||||
import platform
|
||||
import shutil
|
||||
|
@ -163,9 +164,9 @@ def say(text, blocking=False):
|
|||
os.system(cmd)
|
||||
|
||||
|
||||
def save_image(img_tensor, key, frame_index, episode_index, videos_dir):
|
||||
def save_image(img_tensor, key, frame_index, episode_index, videos_dir: str):
|
||||
img = Image.fromarray(img_tensor.numpy())
|
||||
path = videos_dir / f"{key}_episode_{episode_index:06d}" / f"frame_{frame_index:06d}.png"
|
||||
path = Path(videos_dir) / f"{key}_episode_{episode_index:06d}" / f"frame_{frame_index:06d}.png"
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
img.save(str(path), quality=100)
|
||||
|
||||
|
@ -255,6 +256,129 @@ def get_available_arms(robot):
|
|||
return available_arms
|
||||
|
||||
|
||||
########################################################################################
|
||||
# Asynchrounous saving of images on disk
|
||||
########################################################################################
|
||||
|
||||
|
||||
def loop_to_save_images_in_threads(image_queue, num_threads):
|
||||
if num_threads < 1:
|
||||
raise NotImplementedError(f"Only `num_threads>=1` is supported for now, but {num_threads=} given.")
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor:
|
||||
futures = []
|
||||
while True:
|
||||
# Blocks until a frame is available
|
||||
frame_data = image_queue.get()
|
||||
|
||||
# As usually done, exit loop when receiving None to stop the worker
|
||||
if frame_data is None:
|
||||
break
|
||||
|
||||
image, key, frame_index, episode_index, videos_dir = frame_data
|
||||
futures.append(executor.submit(save_image, image, key, frame_index, episode_index, videos_dir))
|
||||
|
||||
# Before exiting function, wait for all threads to complete
|
||||
with tqdm.tqdm(total=len(futures), desc="Writing images") as progress_bar:
|
||||
concurrent.futures.wait(futures)
|
||||
progress_bar.update(len(futures))
|
||||
|
||||
|
||||
def start_image_writer_processes(image_queue, num_processes, num_threads_per_process):
|
||||
if num_processes < 1:
|
||||
raise ValueError(f"Only `num_processes>=1` is supported, but {num_processes=} given.")
|
||||
|
||||
if num_threads_per_process < 1:
|
||||
raise NotImplementedError(
|
||||
"Only `num_threads_per_process>=1` is supported for now, but {num_threads_per_process=} given."
|
||||
)
|
||||
|
||||
processes = []
|
||||
for _ in range(num_processes):
|
||||
process = multiprocessing.Process(
|
||||
target=loop_to_save_images_in_threads,
|
||||
args=(image_queue, num_threads_per_process),
|
||||
)
|
||||
process.start()
|
||||
processes.append(process)
|
||||
return processes
|
||||
|
||||
|
||||
def stop_processes(processes, queue, timeout):
|
||||
# Send None to each process to signal them to stop
|
||||
for _ in processes:
|
||||
queue.put(None)
|
||||
|
||||
# Close the queue, no more items can be put in the queue
|
||||
queue.close()
|
||||
|
||||
# Wait maximum 20 seconds for all processes to terminate
|
||||
for process in processes:
|
||||
process.join(timeout=timeout)
|
||||
|
||||
# If not terminated after 20 seconds, force termination
|
||||
if process.is_alive():
|
||||
process.terminate()
|
||||
|
||||
# Ensure all background queue threads have finished
|
||||
queue.join_thread()
|
||||
|
||||
|
||||
def start_image_writer(num_processes, num_threads):
|
||||
"""This function abstract away the initialisation of processes or/and threads to
|
||||
save images on disk asynchrounously, which is critical to control a robot and record data
|
||||
at a high frame rate.
|
||||
|
||||
When `num_processes=0`, it returns a dictionary containing a threads pool of size `num_threads`.
|
||||
When `num_processes>0`, it returns a dictionary containing a processes pool of size `num_processes`,
|
||||
where each subprocess starts their own threads pool of size `num_threads`.
|
||||
|
||||
The optimal number of processes and threads depends on your computer capabilities.
|
||||
We advise to use 4 threads per camera with 0 processes. If the fps is not stable, try to increase or lower
|
||||
the number of threads. If it is still not stable, try to use 1 subprocess, or more.
|
||||
"""
|
||||
image_writer = {}
|
||||
|
||||
if num_processes == 0:
|
||||
futures = []
|
||||
threads_pool = concurrent.futures.ThreadPoolExecutor(max_workers=num_threads)
|
||||
image_writer["threads_pool"], image_writer["futures"] = threads_pool, futures
|
||||
else:
|
||||
# TODO(rcadene): When using num_processes>1, `multiprocessing.Manager().Queue()`
|
||||
# might be better than `multiprocessing.Queue()`. Source: https://www.geeksforgeeks.org/python-multiprocessing-queue-vs-multiprocessing-manager-queue
|
||||
image_queue = multiprocessing.Queue()
|
||||
processes_pool = start_image_writer_processes(
|
||||
image_queue, num_processes=num_processes, num_threads_per_process=num_threads
|
||||
)
|
||||
image_writer["processes_pool"], image_writer["image_queue"] = processes_pool, image_queue
|
||||
|
||||
return image_writer
|
||||
|
||||
|
||||
def async_save_image(image_writer, image, key, frame_index, episode_index, videos_dir):
|
||||
"""This function abstract away the saving of an image on disk asynchrounously. It uses a dictionary
|
||||
called image writer which contains either a pool of processes or a pool of threads.
|
||||
"""
|
||||
if "threads_pool" in image_writer:
|
||||
threads_pool, futures = image_writer["threads_pool"], image_writer["futures"]
|
||||
futures.append(threads_pool.submit(save_image, image, key, frame_index, episode_index, videos_dir))
|
||||
else:
|
||||
image_queue = image_writer["image_queue"]
|
||||
image_queue.put((image, key, frame_index, episode_index, videos_dir))
|
||||
|
||||
|
||||
def stop_image_writer(image_writer, timeout):
|
||||
if "threads_pool" in image_writer:
|
||||
futures = image_writer["futures"]
|
||||
# Before exiting function, wait for all threads to complete
|
||||
with tqdm.tqdm(total=len(futures), desc="Writing images") as progress_bar:
|
||||
concurrent.futures.wait(futures, timeout=timeout)
|
||||
progress_bar.update(len(futures))
|
||||
else:
|
||||
processes_pool, image_queue = image_writer["processes_pool"], image_writer["image_queue"]
|
||||
stop_processes(processes_pool, image_queue, timeout=timeout)
|
||||
|
||||
|
||||
########################################################################################
|
||||
# Control modes
|
||||
########################################################################################
|
||||
|
@ -342,9 +466,11 @@ def record(
|
|||
run_compute_stats=True,
|
||||
push_to_hub=True,
|
||||
tags=None,
|
||||
num_image_writers_per_camera=4,
|
||||
num_image_writer_processes=0,
|
||||
num_image_writer_threads_per_camera=4,
|
||||
force_override=False,
|
||||
display_cameras=True,
|
||||
play_sounds=True,
|
||||
):
|
||||
# TODO(rcadene): Add option to record logs
|
||||
# TODO(rcadene): Clean this function via decomposition in higher level functions
|
||||
|
@ -436,7 +562,8 @@ def record(
|
|||
while timestamp < warmup_time_s:
|
||||
if not is_warmup_print:
|
||||
logging.info("Warming up (no data recording)")
|
||||
say("Warming up")
|
||||
if play_sounds:
|
||||
say("Warming up")
|
||||
is_warmup_print = True
|
||||
|
||||
start_loop_t = time.perf_counter()
|
||||
|
@ -463,16 +590,22 @@ def record(
|
|||
if has_method(robot, "teleop_safety_stop"):
|
||||
robot.teleop_safety_stop()
|
||||
|
||||
# Save images using threads to reach high fps (30 and more)
|
||||
# Using `with` to exist smoothly if an execption is raised.
|
||||
futures = []
|
||||
num_image_writers = num_image_writers_per_camera * len(robot.cameras)
|
||||
num_image_writers = max(num_image_writers, 1)
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=num_image_writers) as executor:
|
||||
has_camera = len(robot.cameras) > 0
|
||||
if has_camera:
|
||||
# Initialize processes or/and threads dedicated to save images on disk asynchronously,
|
||||
# which is critical to control a robot and record data at a high frame rate.
|
||||
image_writer = start_image_writer(
|
||||
num_processes=num_image_writer_processes,
|
||||
num_threads=num_image_writer_threads_per_camera * len(robot.cameras),
|
||||
)
|
||||
|
||||
# Using `try` to exist smoothly if an exception is raised
|
||||
try:
|
||||
# Start recording all episodes
|
||||
while episode_index < num_episodes:
|
||||
logging.info(f"Recording episode {episode_index}")
|
||||
say(f"Recording episode {episode_index}")
|
||||
if play_sounds:
|
||||
say(f"Recording episode {episode_index}")
|
||||
ep_dict = {}
|
||||
frame_index = 0
|
||||
timestamp = 0
|
||||
|
@ -488,12 +621,16 @@ def record(
|
|||
image_keys = [key for key in observation if "image" in key]
|
||||
not_image_keys = [key for key in observation if "image" not in key]
|
||||
|
||||
for key in image_keys:
|
||||
futures += [
|
||||
executor.submit(
|
||||
save_image, observation[key], key, frame_index, episode_index, videos_dir
|
||||
if has_camera > 0:
|
||||
for key in image_keys:
|
||||
async_save_image(
|
||||
image_writer,
|
||||
image=observation[key],
|
||||
key=key,
|
||||
frame_index=frame_index,
|
||||
episode_index=episode_index,
|
||||
videos_dir=str(videos_dir),
|
||||
)
|
||||
]
|
||||
|
||||
if display_cameras and not is_headless():
|
||||
image_keys = [key for key in observation if "image" in key]
|
||||
|
@ -563,7 +700,8 @@ def record(
|
|||
if not stop_recording:
|
||||
# Start resetting env while the executor are finishing
|
||||
logging.info("Reset the environment")
|
||||
say("Reset the environment")
|
||||
if play_sounds:
|
||||
say("Reset the environment")
|
||||
|
||||
timestamp = 0
|
||||
start_vencod_t = time.perf_counter()
|
||||
|
@ -635,18 +773,23 @@ def record(
|
|||
|
||||
if is_last_episode:
|
||||
logging.info("Done recording")
|
||||
say("Done recording", blocking=True)
|
||||
if play_sounds:
|
||||
say("Done recording", blocking=True)
|
||||
if not is_headless():
|
||||
listener.stop()
|
||||
|
||||
logging.info("Waiting for threads writing the images on disk to terminate...")
|
||||
for _ in tqdm.tqdm(
|
||||
concurrent.futures.as_completed(futures), total=len(futures), desc="Writting images"
|
||||
):
|
||||
pass
|
||||
break
|
||||
if has_camera > 0:
|
||||
logging.info("Waiting for image writer to terminate...")
|
||||
stop_image_writer(image_writer, timeout=20)
|
||||
|
||||
except Exception as e:
|
||||
if has_camera > 0:
|
||||
logging.info("Waiting for image writer to terminate...")
|
||||
stop_image_writer(image_writer, timeout=20)
|
||||
raise e
|
||||
|
||||
robot.disconnect()
|
||||
|
||||
if display_cameras and not is_headless():
|
||||
cv2.destroyAllWindows()
|
||||
|
||||
|
@ -654,7 +797,8 @@ def record(
|
|||
|
||||
if video:
|
||||
logging.info("Encoding videos")
|
||||
say("Encoding videos")
|
||||
if play_sounds:
|
||||
say("Encoding videos")
|
||||
# Use ffmpeg to convert frames stored as png into mp4 videos
|
||||
for episode_index in tqdm.tqdm(range(num_episodes)):
|
||||
for key in image_keys:
|
||||
|
@ -699,7 +843,8 @@ def record(
|
|||
)
|
||||
if run_compute_stats:
|
||||
logging.info("Computing dataset statistics")
|
||||
say("Computing dataset statistics")
|
||||
if play_sounds:
|
||||
say("Computing dataset statistics")
|
||||
stats = compute_stats(lerobot_dataset)
|
||||
lerobot_dataset.stats = stats
|
||||
else:
|
||||
|
@ -721,11 +866,14 @@ def record(
|
|||
create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION)
|
||||
|
||||
logging.info("Exiting")
|
||||
say("Exiting")
|
||||
if play_sounds:
|
||||
say("Exiting")
|
||||
return lerobot_dataset
|
||||
|
||||
|
||||
def replay(robot: Robot, episode: int, fps: int | None = None, root="data", repo_id="lerobot/debug"):
|
||||
def replay(
|
||||
robot: Robot, episode: int, fps: int | None = None, root="data", repo_id="lerobot/debug", play_sounds=True
|
||||
):
|
||||
# TODO(rcadene): Add option to record logs
|
||||
local_dir = Path(root) / repo_id
|
||||
if not local_dir.exists():
|
||||
|
@ -740,7 +888,8 @@ def replay(robot: Robot, episode: int, fps: int | None = None, root="data", repo
|
|||
robot.connect()
|
||||
|
||||
logging.info("Replaying episode")
|
||||
say("Replaying episode", blocking=True)
|
||||
if play_sounds:
|
||||
say("Replaying episode", blocking=True)
|
||||
for idx in range(from_idx, to_idx):
|
||||
start_episode_t = time.perf_counter()
|
||||
|
||||
|
@ -840,12 +989,23 @@ if __name__ == "__main__":
|
|||
help="Add tags to your dataset on the hub.",
|
||||
)
|
||||
parser_record.add_argument(
|
||||
"--num-image-writers-per-camera",
|
||||
"--num-image-writer-processes",
|
||||
type=int,
|
||||
default=0,
|
||||
help=(
|
||||
"Number of subprocesses handling the saving of frames as PNGs. Set to 0 to use threads only; "
|
||||
"set to ≥1 to use subprocesses, each using threads to write images. The best number of processes "
|
||||
"and threads depends on your system. We recommend 4 threads per camera with 0 processes. "
|
||||
"If fps is unstable, adjust the thread count. If still unstable, try using 1 or more subprocesses."
|
||||
),
|
||||
)
|
||||
parser_record.add_argument(
|
||||
"--num-image-writer-threads-per-camera",
|
||||
type=int,
|
||||
default=4,
|
||||
help=(
|
||||
"Number of threads writing the frames as png images on disk, per camera. "
|
||||
"Too much threads might cause unstable teleoperation fps due to main thread being blocked. "
|
||||
"Too many threads might cause unstable teleoperation fps due to main thread being blocked. "
|
||||
"Not enough threads might cause low camera fps."
|
||||
),
|
||||
)
|
||||
|
|
|
@ -5245,7 +5245,7 @@ docs = ["sphinx", "sphinx-automodapi", "sphinx-rtd-theme"]
|
|||
name = "pyserial"
|
||||
version = "3.5"
|
||||
description = "Python Serial Port Extension"
|
||||
optional = true
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "pyserial-3.5-py2.py3-none-any.whl", hash = "sha256:c4451db6ba391ca6ca299fb3ec7bae67a5c55dde170964c7a14ceefec02f2cf0"},
|
||||
|
|
|
@ -52,8 +52,9 @@ def is_robot_available(robot_type):
|
|||
print(f"\nInstall module '{e.name}'")
|
||||
elif isinstance(e, SerialException):
|
||||
print("\nNo physical motors bus detected.")
|
||||
else:
|
||||
traceback.print_exc()
|
||||
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
|
||||
|
@ -77,8 +78,9 @@ def is_camera_available(camera_type):
|
|||
print(f"\nInstall module '{e.name}'")
|
||||
elif isinstance(e, ValueError) and "camera_index" in e.args[0]:
|
||||
print("\nNo physical camera detected.")
|
||||
else:
|
||||
traceback.print_exc()
|
||||
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
|
||||
|
@ -102,8 +104,9 @@ def is_motor_available(motor_type):
|
|||
print(f"\nInstall module '{e.name}'")
|
||||
elif isinstance(e, SerialException):
|
||||
print("\nNo physical motors bus detected.")
|
||||
else:
|
||||
traceback.print_exc()
|
||||
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
|
||||
|
|
|
@ -1,3 +0,0 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:b5a9f73a2356aff9c717cdfd0d37a6da08b0cf2cc09c98edbc9492501b7f64a5
|
||||
size 5104
|
|
@ -1,3 +0,0 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:28738b3cfad17af0ac5181effdd796acdf7953cd5bcca3f421a11ddfd6b0076f
|
||||
size 30800
|
|
@ -1,3 +0,0 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:4bb8a197a40456fdbc16029126268e6bcef3eca1837d88235165dc7e14618bea
|
||||
size 68
|
|
@ -1,3 +0,0 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:bea60cce42d324f539dd3bca1e66b5ba6391838fdcadb00efc25f3240edb529a
|
||||
size 33600
|
|
@ -23,6 +23,7 @@ pytest -sx 'tests/test_control_robot.py::test_teleoperate[aloha-True]'
|
|||
```
|
||||
"""
|
||||
|
||||
import multiprocessing
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
@ -37,7 +38,7 @@ from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, TEST_ROBOT_TYPES, require_r
|
|||
@pytest.mark.parametrize("robot_type, mock", TEST_ROBOT_TYPES)
|
||||
@require_robot
|
||||
def test_teleoperate(tmpdir, request, robot_type, mock):
|
||||
if mock:
|
||||
if mock and robot_type != "aloha":
|
||||
request.getfixturevalue("patch_builtins_input")
|
||||
|
||||
# Create an empty calibration directory to trigger manual calibration
|
||||
|
@ -78,7 +79,7 @@ def test_record_without_cameras(tmpdir, request, robot_type, mock):
|
|||
# Avoid using cameras
|
||||
overrides = ["~cameras"]
|
||||
|
||||
if mock:
|
||||
if mock and robot_type != "aloha":
|
||||
request.getfixturevalue("patch_builtins_input")
|
||||
|
||||
# Create an empty calibration directory to trigger manual calibration
|
||||
|
@ -101,13 +102,14 @@ def test_record_without_cameras(tmpdir, request, robot_type, mock):
|
|||
run_compute_stats=False,
|
||||
push_to_hub=False,
|
||||
video=False,
|
||||
play_sounds=False,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("robot_type, mock", TEST_ROBOT_TYPES)
|
||||
@require_robot
|
||||
def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
|
||||
if mock:
|
||||
if mock and robot_type != "aloha":
|
||||
request.getfixturevalue("patch_builtins_input")
|
||||
|
||||
# Create an empty calibration directory to trigger manual calibration
|
||||
|
@ -115,12 +117,9 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
|
|||
calibration_dir = Path(tmpdir) / robot_type
|
||||
overrides = [f"calibration_dir={calibration_dir}"]
|
||||
else:
|
||||
# Use the default .cache/calibration folder when mock=False
|
||||
# Use the default .cache/calibration folder when mock=False or for aloha
|
||||
overrides = None
|
||||
|
||||
if robot_type == "aloha":
|
||||
pytest.skip("TODO(rcadene): enable test once aloha_real and act_aloha_real are merged")
|
||||
|
||||
env_name = "koch_real"
|
||||
policy_name = "act_koch_real"
|
||||
|
||||
|
@ -141,21 +140,58 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
|
|||
video=False,
|
||||
# TODO(rcadene): display cameras through cv2 sometimes crashes on mac
|
||||
display_cameras=False,
|
||||
play_sounds=False,
|
||||
)
|
||||
|
||||
replay(robot, episode=0, fps=30, root=root, repo_id=repo_id)
|
||||
replay(robot, episode=0, fps=30, root=root, repo_id=repo_id, play_sounds=False)
|
||||
|
||||
# TODO(rcadene, aliberts): rethink this design
|
||||
if robot_type == "aloha":
|
||||
env_name = "aloha_real"
|
||||
policy_name = "act_aloha_real"
|
||||
elif robot_type in ["koch", "koch_bimanual"]:
|
||||
env_name = "koch_real"
|
||||
policy_name = "act_koch_real"
|
||||
else:
|
||||
raise NotImplementedError(robot_type)
|
||||
|
||||
overrides = [
|
||||
f"env={env_name}",
|
||||
f"policy={policy_name}",
|
||||
f"device={DEVICE}",
|
||||
]
|
||||
|
||||
if robot_type == "koch_bimanual":
|
||||
overrides += ["env.state_dim=12", "env.action_dim=12"]
|
||||
|
||||
cfg = init_hydra_config(
|
||||
DEFAULT_CONFIG_PATH,
|
||||
overrides=[
|
||||
f"env={env_name}",
|
||||
f"policy={policy_name}",
|
||||
f"device={DEVICE}",
|
||||
],
|
||||
overrides=overrides,
|
||||
)
|
||||
|
||||
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)
|
||||
|
||||
# In `examples/9_use_aloha.md`, we advise using `num_image_writer_processes=1`
|
||||
# during inference, to reach constent fps, so we test this here.
|
||||
if robot_type == "aloha":
|
||||
num_image_writer_processes = 1
|
||||
|
||||
# `multiprocessing.set_start_method("spawn", force=True)` avoids a hanging issue
|
||||
# before exiting pytest. However, it outputs the following error in the log:
|
||||
# Traceback (most recent call last):
|
||||
# File "<string>", line 1, in <module>
|
||||
# File "/Users/rcadene/miniconda3/envs/lerobot/lib/python3.10/multiprocessing/spawn.py", line 116, in spawn_main
|
||||
# exitcode = _main(fd, parent_sentinel)
|
||||
# File "/Users/rcadene/miniconda3/envs/lerobot/lib/python3.10/multiprocessing/spawn.py", line 126, in _main
|
||||
# self = reduction.pickle.load(from_parent)
|
||||
# File "/Users/rcadene/miniconda3/envs/lerobot/lib/python3.10/multiprocessing/synchronize.py", line 110, in __setstate__
|
||||
# self._semlock = _multiprocessing.SemLock._rebuild(*state)
|
||||
# FileNotFoundError: [Errno 2] No such file or directory
|
||||
# TODO(rcadene, aliberts): fix FileNotFoundError in multiprocessing
|
||||
multiprocessing.set_start_method("spawn", force=True)
|
||||
else:
|
||||
num_image_writer_processes = 0
|
||||
|
||||
record(
|
||||
robot,
|
||||
policy,
|
||||
|
@ -167,6 +203,8 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
|
|||
push_to_hub=False,
|
||||
video=False,
|
||||
display_cameras=False,
|
||||
play_sounds=False,
|
||||
num_image_writer_processes=num_image_writer_processes,
|
||||
)
|
||||
|
||||
del robot
|
||||
|
|
|
@ -308,12 +308,11 @@ def test_flatten_unflatten_dict():
|
|||
# "lerobot/cmu_stretch",
|
||||
],
|
||||
)
|
||||
# TODO(rcadene, aliberts): all these tests fail locally on Mac M1, but not on Linux
|
||||
def test_backward_compatibility(repo_id):
|
||||
"""The artifacts for this test have been generated by `tests/scripts/save_dataset_to_safetensors.py`."""
|
||||
|
||||
dataset = LeRobotDataset(
|
||||
repo_id,
|
||||
)
|
||||
dataset = LeRobotDataset(repo_id)
|
||||
|
||||
test_dir = Path("tests/data/save_dataset_to_safetensors") / repo_id
|
||||
|
||||
|
|
|
@ -367,8 +367,7 @@ def test_normalize(insert_temporal_dim):
|
|||
),
|
||||
("aloha", "act", ["policy.n_action_steps=10"], ""),
|
||||
("aloha", "act", ["policy.n_action_steps=1000", "policy.chunk_size=1000"], "_1000_steps"),
|
||||
("dora_aloha_real", "act_real", ["policy.n_action_steps=10"], ""),
|
||||
("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"], ""),
|
||||
("dora_aloha_real", "act_aloha_real", ["policy.n_action_steps=10"], ""),
|
||||
],
|
||||
)
|
||||
# As artifacts have been generated on an x86_64 kernel, this test won't
|
||||
|
|
Loading…
Reference in New Issue