Merge remote-tracking branch 'origin/main' into user/aliberts/2024_09_25_reshape_dataset
This commit is contained in:
commit
354f37aed3
|
@ -47,6 +47,7 @@ jobs:
|
|||
pipx install poetry && poetry config virtualenvs.in-project true
|
||||
echo "${{ github.workspace }}/.venv/bin" >> $GITHUB_PATH
|
||||
|
||||
# TODO(rcadene, aliberts): python 3.12 seems to be used in the tests, not python 3.10
|
||||
- name: Set up Python 3.10
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
|
@ -84,6 +85,7 @@ jobs:
|
|||
pipx install poetry && poetry config virtualenvs.in-project true
|
||||
echo "${{ github.workspace }}/.venv/bin" >> $GITHUB_PATH
|
||||
|
||||
# TODO(rcadene, aliberts): python 3.12 seems to be used in the tests, not python 3.10
|
||||
- name: Set up Python 3.10
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
|
|
|
@ -0,0 +1,179 @@
|
|||
This tutorial explains how to use [Aloha and Aloha 2 stationary](https://www.trossenrobotics.com/aloha-stationary) with LeRobot.
|
||||
|
||||
## Setup
|
||||
|
||||
Follow the [documentation from Trossen Robotics](https://docs.trossenrobotics.com/aloha_docs/getting_started/stationary/hardware_setup.html) for setting up the hardware and plugging the 4 arms and 4 cameras to your computer.
|
||||
|
||||
|
||||
## Install LeRobot
|
||||
|
||||
On your computer:
|
||||
|
||||
1. [Install Miniconda](https://docs.anaconda.com/miniconda/#quick-command-line-install):
|
||||
```bash
|
||||
mkdir -p ~/miniconda3
|
||||
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh
|
||||
bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
|
||||
rm ~/miniconda3/miniconda.sh
|
||||
~/miniconda3/bin/conda init bash
|
||||
```
|
||||
|
||||
2. Restart shell or `source ~/.bashrc`
|
||||
|
||||
3. Create and activate a fresh conda environment for lerobot
|
||||
```bash
|
||||
conda create -y -n lerobot python=3.10 && conda activate lerobot
|
||||
```
|
||||
|
||||
4. Clone LeRobot:
|
||||
```bash
|
||||
git clone https://github.com/huggingface/lerobot.git ~/lerobot
|
||||
```
|
||||
|
||||
5. Install LeRobot with dependencies for the Aloha motors (dynamixel) and cameras (intelrealsense):
|
||||
```bash
|
||||
cd ~/lerobot && pip install -e ".[dynamixel intelrealsense]"
|
||||
```
|
||||
|
||||
And install extra dependencies for recording datasets on Linux:
|
||||
```bash
|
||||
conda install -y -c conda-forge ffmpeg
|
||||
pip uninstall -y opencv-python
|
||||
conda install -y -c conda-forge "opencv>=4.10.0"
|
||||
```
|
||||
|
||||
## Teleoperate
|
||||
|
||||
**/!\ FOR SAFETY, READ THIS /!\**
|
||||
Teleoperation consists in manually operating the leader arms to move the follower arms. Importantly:
|
||||
1. Make sure your leader arms are in the same position as the follower arms, so that the follower arms don't move too fast to match the leader arms,
|
||||
2. Our code assumes that your robot has been assembled following Trossen Robotics instructions. This allows us to skip calibration, as we use the pre-defined calibration files in `.cache/calibration/aloha_default`. If you replace a motor, make sure you follow the exact instructions from Trossen Robotics.
|
||||
|
||||
By running the following code, you can start your first **SAFE** teleoperation:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py teleoperate \
|
||||
--robot-path lerobot/configs/robot/aloha.yaml \
|
||||
--robot-overrides max_relative_target=5
|
||||
```
|
||||
|
||||
By adding `--robot-overrides max_relative_target=5`, we override the default value for `max_relative_target` defined in `lerobot/configs/robot/aloha.yaml`. It is expected to be `5` to limit the magnitude of the movement for more safety, but the teloperation won't be smooth. When you feel confident, you can disable this limit by adding `--robot-overrides max_relative_target=null` to the command line:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py teleoperate \
|
||||
--robot-path lerobot/configs/robot/aloha.yaml \
|
||||
--robot-overrides max_relative_target=null
|
||||
```
|
||||
|
||||
## Record a dataset
|
||||
|
||||
Once you're familiar with teleoperation, you can record your first dataset with Aloha.
|
||||
|
||||
If you want to use the Hugging Face hub features for uploading your dataset and you haven't previously done it, make sure you've logged in using a write-access token, which can be generated from the [Hugging Face settings](https://huggingface.co/settings/tokens):
|
||||
```bash
|
||||
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
|
||||
```
|
||||
|
||||
Store your Hugging Face repository name in a variable to run these commands:
|
||||
```bash
|
||||
HF_USER=$(huggingface-cli whoami | head -n 1)
|
||||
echo $HF_USER
|
||||
```
|
||||
|
||||
Record 2 episodes and upload your dataset to the hub:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py record \
|
||||
--robot-path lerobot/configs/robot/aloha.yaml \
|
||||
--robot-overrides max_relative_target=null \
|
||||
--fps 30 \
|
||||
--root data \
|
||||
--repo-id ${HF_USER}/aloha_test \
|
||||
--tags aloha tutorial \
|
||||
--warmup-time-s 5 \
|
||||
--episode-time-s 40 \
|
||||
--reset-time-s 10 \
|
||||
--num-episodes 2 \
|
||||
--push-to-hub 1
|
||||
```
|
||||
|
||||
## Visualize a dataset
|
||||
|
||||
If you uploaded your dataset to the hub with `--push-to-hub 1`, you can [visualize your dataset online](https://huggingface.co/spaces/lerobot/visualize_dataset) by copy pasting your repo id given by:
|
||||
```bash
|
||||
echo ${HF_USER}/aloha_test
|
||||
```
|
||||
|
||||
If you didn't upload with `--push-to-hub 0`, you can also visualize it locally with:
|
||||
```bash
|
||||
python lerobot/scripts/visualize_dataset_html.py \
|
||||
--root data \
|
||||
--repo-id ${HF_USER}/aloha_test
|
||||
```
|
||||
|
||||
## Replay an episode
|
||||
|
||||
**/!\ FOR SAFETY, READ THIS /!\**
|
||||
Replay consists in automatically replaying the sequence of actions (i.e. goal positions for your motors) recorded in a given dataset episode. Make sure the current initial position of your robot is similar to the one in your episode, so that your follower arms don't move too fast to go to the first goal positions. For safety, you might want to add `--robot-overrides max_relative_target=5` to your command line as explained above.
|
||||
|
||||
Now try to replay the first episode on your robot:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py replay \
|
||||
--robot-path lerobot/configs/robot/aloha.yaml \
|
||||
--robot-overrides max_relative_target=null \
|
||||
--fps 30 \
|
||||
--root data \
|
||||
--repo-id ${HF_USER}/aloha_test \
|
||||
--episode 0
|
||||
```
|
||||
|
||||
## Train a policy
|
||||
|
||||
To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
|
||||
```bash
|
||||
DATA_DIR=data python lerobot/scripts/train.py \
|
||||
dataset_repo_id=${HF_USER}/aloha_test \
|
||||
policy=act_aloha_real \
|
||||
env=aloha_real \
|
||||
hydra.run.dir=outputs/train/act_aloha_test \
|
||||
hydra.job.name=act_aloha_test \
|
||||
device=cuda \
|
||||
wandb.enable=true
|
||||
```
|
||||
|
||||
Let's explain it:
|
||||
1. We provided the dataset as argument with `dataset_repo_id=${HF_USER}/aloha_test`.
|
||||
2. We provided the policy with `policy=act_aloha_real`. This loads configurations from [`lerobot/configs/policy/act_aloha_real.yaml`](../lerobot/configs/policy/act_aloha_real.yaml). Importantly, this policy uses 4 cameras as input `cam_right_wrist`, `cam_left_wrist`, `cam_high`, and `cam_low`.
|
||||
3. We provided an environment as argument with `env=aloha_real`. This loads configurations from [`lerobot/configs/env/aloha_real.yaml`](../lerobot/configs/env/aloha_real.yaml). Note: this yaml defines 18 dimensions for the `state_dim` and `action_dim`, corresponding to 18 motors, not 14 motors as used in previous Aloha work. This is because, we include the `shoulder_shadow` and `elbow_shadow` motors for simplicity.
|
||||
4. We provided `device=cuda` since we are training on a Nvidia GPU.
|
||||
5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
|
||||
6. We added `DATA_DIR=data` to access your dataset stored in your local `data` directory. If you dont provide `DATA_DIR`, your dataset will be downloaded from Hugging Face hub to your cache folder `$HOME/.cache/hugginface`. In future versions of `lerobot`, both directories will be in sync.
|
||||
|
||||
Training should take several hours. You will find checkpoints in `outputs/train/act_aloha_test/checkpoints`.
|
||||
|
||||
## Evaluate your policy
|
||||
|
||||
You can use the `record` function from [`lerobot/scripts/control_robot.py`](../lerobot/scripts/control_robot.py) but with a policy checkpoint as input. For instance, run this command to record 10 evaluation episodes:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py record \
|
||||
--robot-path lerobot/configs/robot/aloha.yaml \
|
||||
--robot-overrides max_relative_target=null \
|
||||
--fps 30 \
|
||||
--root data \
|
||||
--repo-id ${HF_USER}/eval_act_aloha_test \
|
||||
--tags aloha tutorial eval \
|
||||
--warmup-time-s 5 \
|
||||
--episode-time-s 40 \
|
||||
--reset-time-s 10 \
|
||||
--num-episodes 10 \
|
||||
--num-image-writer-processes 1 \
|
||||
-p outputs/train/act_aloha_test/checkpoints/last/pretrained_model
|
||||
```
|
||||
|
||||
As you can see, it's almost the same command as previously used to record your training dataset. Two things changed:
|
||||
1. There is an additional `-p` argument which indicates the path to your policy checkpoint with (e.g. `-p outputs/train/eval_aloha_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `-p ${HF_USER}/act_aloha_test`).
|
||||
2. The name of dataset begins by `eval` to reflect that you are running inference (e.g. `--repo-id ${HF_USER}/eval_act_aloha_test`).
|
||||
3. We use `--num-image-writer-processes 1` instead of the default value (`0`). On our computer, using a dedicated process to write images from the 4 cameras on disk allows to reach constent 30 fps during inference. Feel free to explore different values for `--num-image-writer-processes`.
|
||||
|
||||
## More
|
||||
|
||||
Follow this [previous tutorial](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#4-train-a-policy-on-your-data) for a more in-depth explaination.
|
||||
|
||||
If you have any question or need help, please reach out on Discord in the channel `#aloha-arm`.
|
|
@ -216,7 +216,9 @@ available_policies_per_env = {
|
|||
"aloha": ["act"],
|
||||
"pusht": ["diffusion", "vqbet"],
|
||||
"xarm": ["tdmpc"],
|
||||
"dora_aloha_real": ["act_real"],
|
||||
"koch_real": ["act_koch_real"],
|
||||
"aloha_real": ["act_aloha_real"],
|
||||
"dora_aloha_real": ["act_aloha_real"],
|
||||
}
|
||||
|
||||
env_task_pairs = [(env, task) for env, tasks in available_tasks_per_env.items() for task in tasks]
|
||||
|
|
|
@ -0,0 +1,468 @@
|
|||
"""Functions to create an empty dataset, and populate it with frames."""
|
||||
# TODO(rcadene, aliberts): to adapt as class methods of next version of LeRobotDataset
|
||||
|
||||
import concurrent
|
||||
import json
|
||||
import logging
|
||||
import multiprocessing
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import tqdm
|
||||
from PIL import Image
|
||||
|
||||
from lerobot.common.datasets.compute_stats import compute_stats
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
|
||||
from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import to_hf_dataset
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, get_default_encoding
|
||||
from lerobot.common.datasets.utils import calculate_episode_data_index, create_branch
|
||||
from lerobot.common.datasets.video_utils import encode_video_frames
|
||||
from lerobot.common.utils.utils import log_say
|
||||
from lerobot.scripts.push_dataset_to_hub import (
|
||||
push_dataset_card_to_hub,
|
||||
push_meta_data_to_hub,
|
||||
push_videos_to_hub,
|
||||
save_meta_data,
|
||||
)
|
||||
|
||||
########################################################################################
|
||||
# Asynchrounous saving of images on disk
|
||||
########################################################################################
|
||||
|
||||
|
||||
def safe_stop_image_writer(func):
|
||||
# TODO(aliberts): Allow to pass custom exceptions
|
||||
# (e.g. ThreadServiceExit, KeyboardInterrupt, SystemExit, UnpluggedError, DynamixelCommError)
|
||||
def wrapper(*args, **kwargs):
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
image_writer = kwargs.get("dataset", {}).get("image_writer")
|
||||
if image_writer is not None:
|
||||
print("Waiting for image writer to terminate...")
|
||||
stop_image_writer(image_writer, timeout=20)
|
||||
raise e
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def save_image(img_tensor, key, frame_index, episode_index, videos_dir: str):
|
||||
img = Image.fromarray(img_tensor.numpy())
|
||||
path = Path(videos_dir) / f"{key}_episode_{episode_index:06d}" / f"frame_{frame_index:06d}.png"
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
img.save(str(path), quality=100)
|
||||
|
||||
|
||||
def loop_to_save_images_in_threads(image_queue, num_threads):
|
||||
if num_threads < 1:
|
||||
raise NotImplementedError(f"Only `num_threads>=1` is supported for now, but {num_threads=} given.")
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor:
|
||||
futures = []
|
||||
while True:
|
||||
# Blocks until a frame is available
|
||||
frame_data = image_queue.get()
|
||||
|
||||
# As usually done, exit loop when receiving None to stop the worker
|
||||
if frame_data is None:
|
||||
break
|
||||
|
||||
image, key, frame_index, episode_index, videos_dir = frame_data
|
||||
futures.append(executor.submit(save_image, image, key, frame_index, episode_index, videos_dir))
|
||||
|
||||
# Before exiting function, wait for all threads to complete
|
||||
with tqdm.tqdm(total=len(futures), desc="Writing images") as progress_bar:
|
||||
concurrent.futures.wait(futures)
|
||||
progress_bar.update(len(futures))
|
||||
|
||||
|
||||
def start_image_writer_processes(image_queue, num_processes, num_threads_per_process):
|
||||
if num_processes < 1:
|
||||
raise ValueError(f"Only `num_processes>=1` is supported, but {num_processes=} given.")
|
||||
|
||||
if num_threads_per_process < 1:
|
||||
raise NotImplementedError(
|
||||
"Only `num_threads_per_process>=1` is supported for now, but {num_threads_per_process=} given."
|
||||
)
|
||||
|
||||
processes = []
|
||||
for _ in range(num_processes):
|
||||
process = multiprocessing.Process(
|
||||
target=loop_to_save_images_in_threads,
|
||||
args=(image_queue, num_threads_per_process),
|
||||
)
|
||||
process.start()
|
||||
processes.append(process)
|
||||
return processes
|
||||
|
||||
|
||||
def stop_processes(processes, queue, timeout):
|
||||
# Send None to each process to signal them to stop
|
||||
for _ in processes:
|
||||
queue.put(None)
|
||||
|
||||
# Wait maximum 20 seconds for all processes to terminate
|
||||
for process in processes:
|
||||
process.join(timeout=timeout)
|
||||
|
||||
# If not terminated after 20 seconds, force termination
|
||||
if process.is_alive():
|
||||
process.terminate()
|
||||
|
||||
# Close the queue, no more items can be put in the queue
|
||||
queue.close()
|
||||
|
||||
# Ensure all background queue threads have finished
|
||||
queue.join_thread()
|
||||
|
||||
|
||||
def start_image_writer(num_processes, num_threads):
|
||||
"""This function abstract away the initialisation of processes or/and threads to
|
||||
save images on disk asynchrounously, which is critical to control a robot and record data
|
||||
at a high frame rate.
|
||||
|
||||
When `num_processes=0`, it returns a dictionary containing a threads pool of size `num_threads`.
|
||||
When `num_processes>0`, it returns a dictionary containing a processes pool of size `num_processes`,
|
||||
where each subprocess starts their own threads pool of size `num_threads`.
|
||||
|
||||
The optimal number of processes and threads depends on your computer capabilities.
|
||||
We advise to use 4 threads per camera with 0 processes. If the fps is not stable, try to increase or lower
|
||||
the number of threads. If it is still not stable, try to use 1 subprocess, or more.
|
||||
"""
|
||||
image_writer = {}
|
||||
|
||||
if num_processes == 0:
|
||||
futures = []
|
||||
threads_pool = concurrent.futures.ThreadPoolExecutor(max_workers=num_threads)
|
||||
image_writer["threads_pool"], image_writer["futures"] = threads_pool, futures
|
||||
else:
|
||||
# TODO(rcadene): When using num_processes>1, `multiprocessing.Manager().Queue()`
|
||||
# might be better than `multiprocessing.Queue()`. Source: https://www.geeksforgeeks.org/python-multiprocessing-queue-vs-multiprocessing-manager-queue
|
||||
image_queue = multiprocessing.Queue()
|
||||
processes_pool = start_image_writer_processes(
|
||||
image_queue, num_processes=num_processes, num_threads_per_process=num_threads
|
||||
)
|
||||
image_writer["processes_pool"], image_writer["image_queue"] = processes_pool, image_queue
|
||||
|
||||
return image_writer
|
||||
|
||||
|
||||
def async_save_image(image_writer, image, key, frame_index, episode_index, videos_dir):
|
||||
"""This function abstract away the saving of an image on disk asynchrounously. It uses a dictionary
|
||||
called image writer which contains either a pool of processes or a pool of threads.
|
||||
"""
|
||||
if "threads_pool" in image_writer:
|
||||
threads_pool, futures = image_writer["threads_pool"], image_writer["futures"]
|
||||
futures.append(threads_pool.submit(save_image, image, key, frame_index, episode_index, videos_dir))
|
||||
else:
|
||||
image_queue = image_writer["image_queue"]
|
||||
image_queue.put((image, key, frame_index, episode_index, videos_dir))
|
||||
|
||||
|
||||
def stop_image_writer(image_writer, timeout):
|
||||
if "threads_pool" in image_writer:
|
||||
futures = image_writer["futures"]
|
||||
# Before exiting function, wait for all threads to complete
|
||||
with tqdm.tqdm(total=len(futures), desc="Writing images") as progress_bar:
|
||||
concurrent.futures.wait(futures, timeout=timeout)
|
||||
progress_bar.update(len(futures))
|
||||
else:
|
||||
processes_pool, image_queue = image_writer["processes_pool"], image_writer["image_queue"]
|
||||
stop_processes(processes_pool, image_queue, timeout=timeout)
|
||||
|
||||
|
||||
########################################################################################
|
||||
# Functions to initialize, resume and populate a dataset
|
||||
########################################################################################
|
||||
|
||||
|
||||
def init_dataset(
|
||||
repo_id,
|
||||
root,
|
||||
force_override,
|
||||
fps,
|
||||
video,
|
||||
write_images,
|
||||
num_image_writer_processes,
|
||||
num_image_writer_threads,
|
||||
):
|
||||
local_dir = Path(root) / repo_id
|
||||
if local_dir.exists() and force_override:
|
||||
shutil.rmtree(local_dir)
|
||||
|
||||
episodes_dir = local_dir / "episodes"
|
||||
episodes_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
videos_dir = local_dir / "videos"
|
||||
videos_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Logic to resume data recording
|
||||
rec_info_path = episodes_dir / "data_recording_info.json"
|
||||
if rec_info_path.exists():
|
||||
with open(rec_info_path) as f:
|
||||
rec_info = json.load(f)
|
||||
num_episodes = rec_info["last_episode_index"] + 1
|
||||
else:
|
||||
num_episodes = 0
|
||||
|
||||
dataset = {
|
||||
"repo_id": repo_id,
|
||||
"local_dir": local_dir,
|
||||
"videos_dir": videos_dir,
|
||||
"episodes_dir": episodes_dir,
|
||||
"fps": fps,
|
||||
"video": video,
|
||||
"rec_info_path": rec_info_path,
|
||||
"num_episodes": num_episodes,
|
||||
}
|
||||
|
||||
if write_images:
|
||||
# Initialize processes or/and threads dedicated to save images on disk asynchronously,
|
||||
# which is critical to control a robot and record data at a high frame rate.
|
||||
image_writer = start_image_writer(
|
||||
num_processes=num_image_writer_processes,
|
||||
num_threads=num_image_writer_threads,
|
||||
)
|
||||
dataset["image_writer"] = image_writer
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
def add_frame(dataset, observation, action):
|
||||
if "current_episode" not in dataset:
|
||||
# initialize episode dictionary
|
||||
ep_dict = {}
|
||||
for key in observation:
|
||||
if key not in ep_dict:
|
||||
ep_dict[key] = []
|
||||
for key in action:
|
||||
if key not in ep_dict:
|
||||
ep_dict[key] = []
|
||||
|
||||
ep_dict["episode_index"] = []
|
||||
ep_dict["frame_index"] = []
|
||||
ep_dict["timestamp"] = []
|
||||
ep_dict["next.done"] = []
|
||||
|
||||
dataset["current_episode"] = ep_dict
|
||||
dataset["current_frame_index"] = 0
|
||||
|
||||
ep_dict = dataset["current_episode"]
|
||||
episode_index = dataset["num_episodes"]
|
||||
frame_index = dataset["current_frame_index"]
|
||||
videos_dir = dataset["videos_dir"]
|
||||
video = dataset["video"]
|
||||
fps = dataset["fps"]
|
||||
|
||||
ep_dict["episode_index"].append(episode_index)
|
||||
ep_dict["frame_index"].append(frame_index)
|
||||
ep_dict["timestamp"].append(frame_index / fps)
|
||||
ep_dict["next.done"].append(False)
|
||||
|
||||
img_keys = [key for key in observation if "image" in key]
|
||||
non_img_keys = [key for key in observation if "image" not in key]
|
||||
|
||||
# Save all observed modalities except images
|
||||
for key in non_img_keys:
|
||||
ep_dict[key].append(observation[key])
|
||||
|
||||
# Save actions
|
||||
for key in action:
|
||||
ep_dict[key].append(action[key])
|
||||
|
||||
if "image_writer" not in dataset:
|
||||
dataset["current_frame_index"] += 1
|
||||
return
|
||||
|
||||
# Save images
|
||||
image_writer = dataset["image_writer"]
|
||||
for key in img_keys:
|
||||
imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}"
|
||||
async_save_image(
|
||||
image_writer,
|
||||
image=observation[key],
|
||||
key=key,
|
||||
frame_index=frame_index,
|
||||
episode_index=episode_index,
|
||||
videos_dir=str(videos_dir),
|
||||
)
|
||||
|
||||
if video:
|
||||
fname = f"{key}_episode_{episode_index:06d}.mp4"
|
||||
frame_info = {"path": f"videos/{fname}", "timestamp": frame_index / fps}
|
||||
else:
|
||||
frame_info = str(imgs_dir / f"frame_{frame_index:06d}.png")
|
||||
|
||||
ep_dict[key].append(frame_info)
|
||||
|
||||
dataset["current_frame_index"] += 1
|
||||
|
||||
|
||||
def delete_current_episode(dataset):
|
||||
del dataset["current_episode"]
|
||||
del dataset["current_frame_index"]
|
||||
|
||||
# delete temporary images
|
||||
episode_index = dataset["num_episodes"]
|
||||
videos_dir = dataset["videos_dir"]
|
||||
for tmp_imgs_dir in videos_dir.glob(f"*_episode_{episode_index:06d}"):
|
||||
shutil.rmtree(tmp_imgs_dir)
|
||||
|
||||
|
||||
def save_current_episode(dataset):
|
||||
episode_index = dataset["num_episodes"]
|
||||
ep_dict = dataset["current_episode"]
|
||||
episodes_dir = dataset["episodes_dir"]
|
||||
rec_info_path = dataset["rec_info_path"]
|
||||
|
||||
ep_dict["next.done"][-1] = True
|
||||
|
||||
for key in ep_dict:
|
||||
if "observation" in key and "image" not in key:
|
||||
ep_dict[key] = torch.stack(ep_dict[key])
|
||||
|
||||
ep_dict["action"] = torch.stack(ep_dict["action"])
|
||||
ep_dict["episode_index"] = torch.tensor(ep_dict["episode_index"])
|
||||
ep_dict["frame_index"] = torch.tensor(ep_dict["frame_index"])
|
||||
ep_dict["timestamp"] = torch.tensor(ep_dict["timestamp"])
|
||||
ep_dict["next.done"] = torch.tensor(ep_dict["next.done"])
|
||||
|
||||
ep_path = episodes_dir / f"episode_{episode_index}.pth"
|
||||
torch.save(ep_dict, ep_path)
|
||||
|
||||
rec_info = {
|
||||
"last_episode_index": episode_index,
|
||||
}
|
||||
with open(rec_info_path, "w") as f:
|
||||
json.dump(rec_info, f)
|
||||
|
||||
# force re-initialization of episode dictionnary during add_frame
|
||||
del dataset["current_episode"]
|
||||
|
||||
dataset["num_episodes"] += 1
|
||||
|
||||
|
||||
def encode_videos(dataset, image_keys, play_sounds):
|
||||
log_say("Encoding videos", play_sounds)
|
||||
|
||||
num_episodes = dataset["num_episodes"]
|
||||
videos_dir = dataset["videos_dir"]
|
||||
local_dir = dataset["local_dir"]
|
||||
fps = dataset["fps"]
|
||||
|
||||
# Use ffmpeg to convert frames stored as png into mp4 videos
|
||||
for episode_index in tqdm.tqdm(range(num_episodes)):
|
||||
for key in image_keys:
|
||||
# key = f"observation.images.{name}"
|
||||
tmp_imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}"
|
||||
fname = f"{key}_episode_{episode_index:06d}.mp4"
|
||||
video_path = local_dir / "videos" / fname
|
||||
if video_path.exists():
|
||||
# Skip if video is already encoded. Could be the case when resuming data recording.
|
||||
continue
|
||||
# note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
|
||||
# since video encoding with ffmpeg is already using multithreading.
|
||||
encode_video_frames(tmp_imgs_dir, video_path, fps, overwrite=True)
|
||||
shutil.rmtree(tmp_imgs_dir)
|
||||
|
||||
|
||||
def from_dataset_to_lerobot_dataset(dataset, play_sounds):
|
||||
log_say("Consolidate episodes", play_sounds)
|
||||
|
||||
num_episodes = dataset["num_episodes"]
|
||||
episodes_dir = dataset["episodes_dir"]
|
||||
videos_dir = dataset["videos_dir"]
|
||||
video = dataset["video"]
|
||||
fps = dataset["fps"]
|
||||
repo_id = dataset["repo_id"]
|
||||
|
||||
ep_dicts = []
|
||||
for episode_index in tqdm.tqdm(range(num_episodes)):
|
||||
ep_path = episodes_dir / f"episode_{episode_index}.pth"
|
||||
ep_dict = torch.load(ep_path)
|
||||
ep_dicts.append(ep_dict)
|
||||
data_dict = concatenate_episodes(ep_dicts)
|
||||
|
||||
if video:
|
||||
image_keys = [key for key in data_dict if "image" in key]
|
||||
encode_videos(dataset, image_keys, play_sounds)
|
||||
|
||||
hf_dataset = to_hf_dataset(data_dict, video)
|
||||
episode_data_index = calculate_episode_data_index(hf_dataset)
|
||||
|
||||
info = {
|
||||
"codebase_version": CODEBASE_VERSION,
|
||||
"fps": fps,
|
||||
"video": video,
|
||||
}
|
||||
if video:
|
||||
info["encoding"] = get_default_encoding()
|
||||
|
||||
lerobot_dataset = LeRobotDataset.from_preloaded(
|
||||
repo_id=repo_id,
|
||||
hf_dataset=hf_dataset,
|
||||
episode_data_index=episode_data_index,
|
||||
info=info,
|
||||
videos_dir=videos_dir,
|
||||
)
|
||||
|
||||
return lerobot_dataset
|
||||
|
||||
|
||||
def save_lerobot_dataset_on_disk(lerobot_dataset):
|
||||
hf_dataset = lerobot_dataset.hf_dataset
|
||||
info = lerobot_dataset.info
|
||||
stats = lerobot_dataset.stats
|
||||
episode_data_index = lerobot_dataset.episode_data_index
|
||||
local_dir = lerobot_dataset.videos_dir.parent
|
||||
meta_data_dir = local_dir / "meta_data"
|
||||
|
||||
hf_dataset = hf_dataset.with_format(None) # to remove transforms that cant be saved
|
||||
hf_dataset.save_to_disk(str(local_dir / "train"))
|
||||
|
||||
save_meta_data(info, stats, episode_data_index, meta_data_dir)
|
||||
|
||||
|
||||
def push_lerobot_dataset_to_hub(lerobot_dataset, tags):
|
||||
hf_dataset = lerobot_dataset.hf_dataset
|
||||
local_dir = lerobot_dataset.videos_dir.parent
|
||||
videos_dir = lerobot_dataset.videos_dir
|
||||
repo_id = lerobot_dataset.repo_id
|
||||
video = lerobot_dataset.video
|
||||
meta_data_dir = local_dir / "meta_data"
|
||||
|
||||
if not (local_dir / "train").exists():
|
||||
raise ValueError(
|
||||
"You need to run `save_lerobot_dataset_on_disk(lerobot_dataset)` before pushing to the hub."
|
||||
)
|
||||
|
||||
hf_dataset.push_to_hub(repo_id, revision="main")
|
||||
push_meta_data_to_hub(repo_id, meta_data_dir, revision="main")
|
||||
push_dataset_card_to_hub(repo_id, revision="main", tags=tags)
|
||||
if video:
|
||||
push_videos_to_hub(repo_id, videos_dir, revision="main")
|
||||
create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION)
|
||||
|
||||
|
||||
def create_lerobot_dataset(dataset, run_compute_stats, push_to_hub, tags, play_sounds):
|
||||
if "image_writer" in dataset:
|
||||
logging.info("Waiting for image writer to terminate...")
|
||||
image_writer = dataset["image_writer"]
|
||||
stop_image_writer(image_writer, timeout=20)
|
||||
|
||||
lerobot_dataset = from_dataset_to_lerobot_dataset(dataset, play_sounds)
|
||||
|
||||
if run_compute_stats:
|
||||
log_say("Computing dataset statistics", play_sounds)
|
||||
lerobot_dataset.stats = compute_stats(lerobot_dataset)
|
||||
else:
|
||||
logging.info("Skipping computation of the dataset statistics")
|
||||
lerobot_dataset.stats = {}
|
||||
|
||||
save_lerobot_dataset_on_disk(lerobot_dataset)
|
||||
|
||||
if push_to_hub:
|
||||
push_lerobot_dataset_to_hub(lerobot_dataset, tags)
|
||||
|
||||
return lerobot_dataset
|
|
@ -189,7 +189,7 @@ class Logger:
|
|||
training_state["scheduler"] = scheduler.state_dict()
|
||||
torch.save(training_state, save_dir / self.training_state_file_name)
|
||||
|
||||
def save_checkpont(
|
||||
def save_checkpoint(
|
||||
self,
|
||||
train_step: int,
|
||||
policy: Policy,
|
||||
|
|
|
@ -0,0 +1,330 @@
|
|||
########################################################################################
|
||||
# Utilities
|
||||
########################################################################################
|
||||
|
||||
|
||||
import logging
|
||||
import time
|
||||
import traceback
|
||||
from contextlib import nullcontext
|
||||
from copy import copy
|
||||
from functools import cache
|
||||
|
||||
import cv2
|
||||
import torch
|
||||
import tqdm
|
||||
from termcolor import colored
|
||||
|
||||
from lerobot.common.datasets.populate_dataset import add_frame, safe_stop_image_writer
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.robot_devices.robots.utils import Robot
|
||||
from lerobot.common.robot_devices.utils import busy_wait
|
||||
from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, set_global_seed
|
||||
from lerobot.scripts.eval import get_pretrained_policy_path
|
||||
|
||||
|
||||
def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, fps=None):
|
||||
log_items = []
|
||||
if episode_index is not None:
|
||||
log_items.append(f"ep:{episode_index}")
|
||||
if frame_index is not None:
|
||||
log_items.append(f"frame:{frame_index}")
|
||||
|
||||
def log_dt(shortname, dt_val_s):
|
||||
nonlocal log_items, fps
|
||||
info_str = f"{shortname}:{dt_val_s * 1000:5.2f} ({1/ dt_val_s:3.1f}hz)"
|
||||
if fps is not None:
|
||||
actual_fps = 1 / dt_val_s
|
||||
if actual_fps < fps - 1:
|
||||
info_str = colored(info_str, "yellow")
|
||||
log_items.append(info_str)
|
||||
|
||||
# total step time displayed in milliseconds and its frequency
|
||||
log_dt("dt", dt_s)
|
||||
|
||||
# TODO(aliberts): move robot-specific logs logic in robot.print_logs()
|
||||
if not robot.robot_type.startswith("stretch"):
|
||||
for name in robot.leader_arms:
|
||||
key = f"read_leader_{name}_pos_dt_s"
|
||||
if key in robot.logs:
|
||||
log_dt("dtRlead", robot.logs[key])
|
||||
|
||||
for name in robot.follower_arms:
|
||||
key = f"write_follower_{name}_goal_pos_dt_s"
|
||||
if key in robot.logs:
|
||||
log_dt("dtWfoll", robot.logs[key])
|
||||
|
||||
key = f"read_follower_{name}_pos_dt_s"
|
||||
if key in robot.logs:
|
||||
log_dt("dtRfoll", robot.logs[key])
|
||||
|
||||
for name in robot.cameras:
|
||||
key = f"read_camera_{name}_dt_s"
|
||||
if key in robot.logs:
|
||||
log_dt(f"dtR{name}", robot.logs[key])
|
||||
|
||||
info_str = " ".join(log_items)
|
||||
logging.info(info_str)
|
||||
|
||||
|
||||
@cache
|
||||
def is_headless():
|
||||
"""Detects if python is running without a monitor."""
|
||||
try:
|
||||
import pynput # noqa
|
||||
|
||||
return False
|
||||
except Exception:
|
||||
print(
|
||||
"Error trying to import pynput. Switching to headless mode. "
|
||||
"As a result, the video stream from the cameras won't be shown, "
|
||||
"and you won't be able to change the control flow with keyboards. "
|
||||
"For more info, see traceback below.\n"
|
||||
)
|
||||
traceback.print_exc()
|
||||
print()
|
||||
return True
|
||||
|
||||
|
||||
def has_method(_object: object, method_name: str):
|
||||
return hasattr(_object, method_name) and callable(getattr(_object, method_name))
|
||||
|
||||
|
||||
def predict_action(observation, policy, device, use_amp):
|
||||
observation = copy(observation)
|
||||
with (
|
||||
torch.inference_mode(),
|
||||
torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(),
|
||||
):
|
||||
# Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
|
||||
for name in observation:
|
||||
if "image" in name:
|
||||
observation[name] = observation[name].type(torch.float32) / 255
|
||||
observation[name] = observation[name].permute(2, 0, 1).contiguous()
|
||||
observation[name] = observation[name].unsqueeze(0)
|
||||
observation[name] = observation[name].to(device)
|
||||
|
||||
# Compute the next action with the policy
|
||||
# based on the current observation
|
||||
action = policy.select_action(observation)
|
||||
|
||||
# Remove batch dimension
|
||||
action = action.squeeze(0)
|
||||
|
||||
# Move to cpu, if not already the case
|
||||
action = action.to("cpu")
|
||||
|
||||
return action
|
||||
|
||||
|
||||
def init_keyboard_listener():
|
||||
# Allow to exit early while recording an episode or resetting the environment,
|
||||
# by tapping the right arrow key '->'. This might require a sudo permission
|
||||
# to allow your terminal to monitor keyboard events.
|
||||
events = {}
|
||||
events["exit_early"] = False
|
||||
events["rerecord_episode"] = False
|
||||
events["stop_recording"] = False
|
||||
|
||||
if is_headless():
|
||||
logging.warning(
|
||||
"Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
|
||||
)
|
||||
listener = None
|
||||
return listener, events
|
||||
|
||||
# Only import pynput if not in a headless environment
|
||||
from pynput import keyboard
|
||||
|
||||
def on_press(key):
|
||||
try:
|
||||
if key == keyboard.Key.right:
|
||||
print("Right arrow key pressed. Exiting loop...")
|
||||
events["exit_early"] = True
|
||||
elif key == keyboard.Key.left:
|
||||
print("Left arrow key pressed. Exiting loop and rerecord the last episode...")
|
||||
events["rerecord_episode"] = True
|
||||
events["exit_early"] = True
|
||||
elif key == keyboard.Key.esc:
|
||||
print("Escape key pressed. Stopping data recording...")
|
||||
events["stop_recording"] = True
|
||||
events["exit_early"] = True
|
||||
except Exception as e:
|
||||
print(f"Error handling key press: {e}")
|
||||
|
||||
listener = keyboard.Listener(on_press=on_press)
|
||||
listener.start()
|
||||
|
||||
return listener, events
|
||||
|
||||
|
||||
def init_policy(pretrained_policy_name_or_path, policy_overrides):
|
||||
"""Instantiate the policy and load fps, device and use_amp from config yaml"""
|
||||
pretrained_policy_path = get_pretrained_policy_path(pretrained_policy_name_or_path)
|
||||
hydra_cfg = init_hydra_config(pretrained_policy_path / "config.yaml", policy_overrides)
|
||||
policy = make_policy(hydra_cfg=hydra_cfg, pretrained_policy_name_or_path=pretrained_policy_path)
|
||||
|
||||
# Check device is available
|
||||
device = get_safe_torch_device(hydra_cfg.device, log=True)
|
||||
use_amp = hydra_cfg.use_amp
|
||||
policy_fps = hydra_cfg.env.fps
|
||||
|
||||
policy.eval()
|
||||
policy.to(device)
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
set_global_seed(hydra_cfg.seed)
|
||||
return policy, policy_fps, device, use_amp
|
||||
|
||||
|
||||
def warmup_record(
|
||||
robot,
|
||||
events,
|
||||
enable_teloperation,
|
||||
warmup_time_s,
|
||||
display_cameras,
|
||||
fps,
|
||||
):
|
||||
control_loop(
|
||||
robot=robot,
|
||||
control_time_s=warmup_time_s,
|
||||
display_cameras=display_cameras,
|
||||
events=events,
|
||||
fps=fps,
|
||||
teleoperate=enable_teloperation,
|
||||
)
|
||||
|
||||
|
||||
def record_episode(
|
||||
robot,
|
||||
dataset,
|
||||
events,
|
||||
episode_time_s,
|
||||
display_cameras,
|
||||
policy,
|
||||
device,
|
||||
use_amp,
|
||||
fps,
|
||||
):
|
||||
control_loop(
|
||||
robot=robot,
|
||||
control_time_s=episode_time_s,
|
||||
display_cameras=display_cameras,
|
||||
dataset=dataset,
|
||||
events=events,
|
||||
policy=policy,
|
||||
device=device,
|
||||
use_amp=use_amp,
|
||||
fps=fps,
|
||||
teleoperate=policy is None,
|
||||
)
|
||||
|
||||
|
||||
@safe_stop_image_writer
|
||||
def control_loop(
|
||||
robot,
|
||||
control_time_s=None,
|
||||
teleoperate=False,
|
||||
display_cameras=False,
|
||||
dataset=None,
|
||||
events=None,
|
||||
policy=None,
|
||||
device=None,
|
||||
use_amp=None,
|
||||
fps=None,
|
||||
):
|
||||
# TODO(rcadene): Add option to record logs
|
||||
if not robot.is_connected:
|
||||
robot.connect()
|
||||
|
||||
if events is None:
|
||||
events = {"exit_early": False}
|
||||
|
||||
if control_time_s is None:
|
||||
control_time_s = float("inf")
|
||||
|
||||
if teleoperate and policy is not None:
|
||||
raise ValueError("When `teleoperate` is True, `policy` should be None.")
|
||||
|
||||
if dataset is not None and fps is not None and dataset["fps"] != fps:
|
||||
raise ValueError(f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps}).")
|
||||
|
||||
timestamp = 0
|
||||
start_episode_t = time.perf_counter()
|
||||
while timestamp < control_time_s:
|
||||
start_loop_t = time.perf_counter()
|
||||
|
||||
if teleoperate:
|
||||
observation, action = robot.teleop_step(record_data=True)
|
||||
else:
|
||||
observation = robot.capture_observation()
|
||||
|
||||
if policy is not None:
|
||||
pred_action = predict_action(observation, policy, device, use_amp)
|
||||
# Action can eventually be clipped using `max_relative_target`,
|
||||
# so action actually sent is saved in the dataset.
|
||||
action = robot.send_action(pred_action)
|
||||
action = {"action": action}
|
||||
|
||||
if dataset is not None:
|
||||
add_frame(dataset, observation, action)
|
||||
|
||||
if display_cameras and not is_headless():
|
||||
image_keys = [key for key in observation if "image" in key]
|
||||
for key in image_keys:
|
||||
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
|
||||
cv2.waitKey(1)
|
||||
|
||||
if fps is not None:
|
||||
dt_s = time.perf_counter() - start_loop_t
|
||||
busy_wait(1 / fps - dt_s)
|
||||
|
||||
dt_s = time.perf_counter() - start_loop_t
|
||||
log_control_info(robot, dt_s, fps=fps)
|
||||
|
||||
timestamp = time.perf_counter() - start_episode_t
|
||||
if events["exit_early"]:
|
||||
events["exit_early"] = False
|
||||
break
|
||||
|
||||
|
||||
def reset_environment(robot, events, reset_time_s):
|
||||
# TODO(rcadene): refactor warmup_record and reset_environment
|
||||
# TODO(alibets): allow for teleop during reset
|
||||
if has_method(robot, "teleop_safety_stop"):
|
||||
robot.teleop_safety_stop()
|
||||
|
||||
timestamp = 0
|
||||
start_vencod_t = time.perf_counter()
|
||||
|
||||
# Wait if necessary
|
||||
with tqdm.tqdm(total=reset_time_s, desc="Waiting") as pbar:
|
||||
while timestamp < reset_time_s:
|
||||
time.sleep(1)
|
||||
timestamp = time.perf_counter() - start_vencod_t
|
||||
pbar.update(1)
|
||||
if events["exit_early"]:
|
||||
events["exit_early"] = False
|
||||
break
|
||||
|
||||
|
||||
def stop_recording(robot, listener, display_cameras):
|
||||
robot.disconnect()
|
||||
|
||||
if not is_headless():
|
||||
if listener is not None:
|
||||
listener.stop()
|
||||
|
||||
if display_cameras:
|
||||
cv2.destroyAllWindows()
|
||||
|
||||
|
||||
def sanity_check_dataset_name(repo_id, policy):
|
||||
_, dataset_name = repo_id.split("/")
|
||||
# either repo_id doesnt start with "eval_" and there is no policy
|
||||
# or repo_id starts with "eval_" and there is a policy
|
||||
if dataset_name.startswith("eval_") == (policy is None):
|
||||
raise ValueError(
|
||||
f"Your dataset name begins by 'eval_' ({dataset_name}) but no policy is provided ({policy})."
|
||||
)
|
|
@ -349,6 +349,25 @@ class ManipulatorRobot:
|
|||
self.is_connected = False
|
||||
self.logs = {}
|
||||
|
||||
@property
|
||||
def has_camera(self):
|
||||
return len(self.cameras) > 0
|
||||
|
||||
@property
|
||||
def num_cameras(self):
|
||||
return len(self.cameras)
|
||||
|
||||
@property
|
||||
def available_arms(self):
|
||||
available_arms = []
|
||||
for name in self.follower_arms:
|
||||
arm_id = get_arm_id(name, "follower")
|
||||
available_arms.append(arm_id)
|
||||
for name in self.leader_arms:
|
||||
arm_id = get_arm_id(name, "leader")
|
||||
available_arms.append(arm_id)
|
||||
return available_arms
|
||||
|
||||
def connect(self):
|
||||
if self.is_connected:
|
||||
raise RobotDeviceAlreadyConnectedError(
|
||||
|
@ -364,6 +383,7 @@ class ManipulatorRobot:
|
|||
for name in self.follower_arms:
|
||||
print(f"Connecting {name} follower arm.")
|
||||
self.follower_arms[name].connect()
|
||||
for name in self.leader_arms:
|
||||
print(f"Connecting {name} leader arm.")
|
||||
self.leader_arms[name].connect()
|
||||
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
import logging
|
||||
import os
|
||||
import os.path as osp
|
||||
import platform
|
||||
import random
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime, timezone
|
||||
|
@ -28,6 +29,12 @@ import torch
|
|||
from omegaconf import DictConfig
|
||||
|
||||
|
||||
def none_or_int(value):
|
||||
if value == "None":
|
||||
return None
|
||||
return int(value)
|
||||
|
||||
|
||||
def inside_slurm():
|
||||
"""Check whether the python process was launched through slurm"""
|
||||
# TODO(rcadene): return False for interactive mode `--pty bash`
|
||||
|
@ -183,3 +190,30 @@ def print_cuda_memory_usage():
|
|||
|
||||
def capture_timestamp_utc():
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
def say(text, blocking=False):
|
||||
# Check if mac, linux, or windows.
|
||||
if platform.system() == "Darwin":
|
||||
cmd = f'say "{text}"'
|
||||
if not blocking:
|
||||
cmd += " &"
|
||||
elif platform.system() == "Linux":
|
||||
cmd = f'spd-say "{text}"'
|
||||
if blocking:
|
||||
cmd += " --wait"
|
||||
elif platform.system() == "Windows":
|
||||
# TODO(rcadene): Make blocking option work for Windows
|
||||
cmd = (
|
||||
'PowerShell -Command "Add-Type -AssemblyName System.Speech; '
|
||||
f"(New-Object System.Speech.Synthesis.SpeechSynthesizer).Speak('{text}')\""
|
||||
)
|
||||
|
||||
os.system(cmd)
|
||||
|
||||
|
||||
def log_say(text, play_sounds, blocking=False):
|
||||
logging.info(text)
|
||||
|
||||
if play_sounds:
|
||||
say(text, blocking)
|
||||
|
|
|
@ -0,0 +1,10 @@
|
|||
# @package _global_
|
||||
|
||||
fps: 30
|
||||
|
||||
env:
|
||||
name: real_world
|
||||
task: null
|
||||
state_dim: 18
|
||||
action_dim: 18
|
||||
fps: ${fps}
|
|
@ -1,16 +1,22 @@
|
|||
# @package _global_
|
||||
|
||||
# Use `act_real.yaml` to train on real-world Aloha/Aloha2 datasets.
|
||||
# Compared to `act.yaml`, it contains 4 cameras (i.e. cam_right_wrist, cam_left_wrist, images,
|
||||
# cam_low) instead of 1 camera (i.e. top). Also, `training.eval_freq` is set to -1. This config is used
|
||||
# to evaluate checkpoints at a certain frequency of training steps. When it is set to -1, it deactivates evaluation.
|
||||
# This is because real-world evaluation is done through [dora-lerobot](https://github.com/dora-rs/dora-lerobot).
|
||||
# Look at its README for more information on how to evaluate a checkpoint in the real-world.
|
||||
# Use `act_aloha_real.yaml` to train on real-world datasets collected on Aloha or Aloha-2 robots.
|
||||
# Compared to `act.yaml`, it contains 4 cameras (i.e. cam_right_wrist, cam_left_wrist, cam_high, cam_low) instead of 1 camera (i.e. top).
|
||||
# Also, `training.eval_freq` is set to -1. This config is used to evaluate checkpoints at a certain frequency of training steps.
|
||||
# When it is set to -1, it deactivates evaluation. This is because real-world evaluation is done through our `control_robot.py` script.
|
||||
# Look at the documentation in header of `control_robot.py` for more information on how to collect data , train and evaluate a policy.
|
||||
#
|
||||
# Example of usage for training:
|
||||
# Example of usage for training and inference with `control_robot.py`:
|
||||
# ```bash
|
||||
# python lerobot/scripts/train.py \
|
||||
# policy=act_real \
|
||||
# policy=act_aloha_real \
|
||||
# env=aloha_real
|
||||
# ```
|
||||
#
|
||||
# Example of usage for training and inference with [Dora-rs](https://github.com/dora-rs/dora-lerobot):
|
||||
# ```bash
|
||||
# python lerobot/scripts/train.py \
|
||||
# policy=act_aloha_real \
|
||||
# env=dora_aloha_real
|
||||
# ```
|
||||
|
||||
|
@ -36,10 +42,11 @@ override_dataset_stats:
|
|||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||
|
||||
training:
|
||||
offline_steps: 100000
|
||||
offline_steps: 80000
|
||||
online_steps: 0
|
||||
eval_freq: -1
|
||||
save_freq: 20000
|
||||
save_freq: 10000
|
||||
log_freq: 100
|
||||
save_checkpoint: true
|
||||
|
||||
batch_size: 8
|
||||
|
@ -62,7 +69,7 @@ policy:
|
|||
|
||||
# Input / output structure.
|
||||
n_obs_steps: 1
|
||||
chunk_size: 100 # chunk_size
|
||||
chunk_size: 100
|
||||
n_action_steps: 100
|
||||
|
||||
input_shapes:
|
||||
|
@ -107,7 +114,7 @@ policy:
|
|||
n_vae_encoder_layers: 4
|
||||
|
||||
# Inference.
|
||||
temporal_ensemble_coeff: null
|
||||
temporal_ensemble_momentum: null
|
||||
|
||||
# Training and loss computation.
|
||||
dropout: 0.1
|
|
@ -1,110 +0,0 @@
|
|||
# @package _global_
|
||||
|
||||
# Use `act_real_no_state.yaml` to train on real-world Aloha/Aloha2 datasets when cameras are moving (e.g. wrist cameras)
|
||||
# Compared to `act_real.yaml`, it is camera only and does not use the state as input which is vector of robot joint positions.
|
||||
# We validated experimentaly that not using state reaches better success rate. Our hypothesis is that `act_real.yaml` might
|
||||
# overfits to the state, because the images are more complex to learn from since they are moving.
|
||||
#
|
||||
# Example of usage for training:
|
||||
# ```bash
|
||||
# python lerobot/scripts/train.py \
|
||||
# policy=act_real_no_state \
|
||||
# env=dora_aloha_real
|
||||
# ```
|
||||
|
||||
seed: 1000
|
||||
dataset_repo_id: lerobot/aloha_static_vinh_cup
|
||||
|
||||
override_dataset_stats:
|
||||
observation.images.cam_right_wrist:
|
||||
# stats from imagenet, since we use a pretrained vision model
|
||||
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||
observation.images.cam_left_wrist:
|
||||
# stats from imagenet, since we use a pretrained vision model
|
||||
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||
observation.images.cam_high:
|
||||
# stats from imagenet, since we use a pretrained vision model
|
||||
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||
observation.images.cam_low:
|
||||
# stats from imagenet, since we use a pretrained vision model
|
||||
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||
|
||||
training:
|
||||
offline_steps: 100000
|
||||
online_steps: 0
|
||||
eval_freq: -1
|
||||
save_freq: 20000
|
||||
save_checkpoint: true
|
||||
|
||||
batch_size: 8
|
||||
lr: 1e-5
|
||||
lr_backbone: 1e-5
|
||||
weight_decay: 1e-4
|
||||
grad_clip_norm: 10
|
||||
online_steps_between_rollouts: 1
|
||||
|
||||
delta_timestamps:
|
||||
action: "[i / ${fps} for i in range(${policy.chunk_size})]"
|
||||
|
||||
eval:
|
||||
n_episodes: 50
|
||||
batch_size: 50
|
||||
|
||||
# See `configuration_act.py` for more details.
|
||||
policy:
|
||||
name: act
|
||||
|
||||
# Input / output structure.
|
||||
n_obs_steps: 1
|
||||
chunk_size: 100 # chunk_size
|
||||
n_action_steps: 100
|
||||
|
||||
input_shapes:
|
||||
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
|
||||
observation.images.cam_right_wrist: [3, 480, 640]
|
||||
observation.images.cam_left_wrist: [3, 480, 640]
|
||||
observation.images.cam_high: [3, 480, 640]
|
||||
observation.images.cam_low: [3, 480, 640]
|
||||
output_shapes:
|
||||
action: ["${env.action_dim}"]
|
||||
|
||||
# Normalization / Unnormalization
|
||||
input_normalization_modes:
|
||||
observation.images.cam_right_wrist: mean_std
|
||||
observation.images.cam_left_wrist: mean_std
|
||||
observation.images.cam_high: mean_std
|
||||
observation.images.cam_low: mean_std
|
||||
output_normalization_modes:
|
||||
action: mean_std
|
||||
|
||||
# Architecture.
|
||||
# Vision backbone.
|
||||
vision_backbone: resnet18
|
||||
pretrained_backbone_weights: ResNet18_Weights.IMAGENET1K_V1
|
||||
replace_final_stride_with_dilation: false
|
||||
# Transformer layers.
|
||||
pre_norm: false
|
||||
dim_model: 512
|
||||
n_heads: 8
|
||||
dim_feedforward: 3200
|
||||
feedforward_activation: relu
|
||||
n_encoder_layers: 4
|
||||
# Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
|
||||
# that means only the first layer is used. Here we match the original implementation by setting this to 1.
|
||||
# See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
|
||||
n_decoder_layers: 1
|
||||
# VAE.
|
||||
use_vae: true
|
||||
latent_dim: 32
|
||||
n_vae_encoder_layers: 4
|
||||
|
||||
# Inference.
|
||||
temporal_ensemble_coeff: null
|
||||
|
||||
# Training and loss computation.
|
||||
dropout: 0.1
|
||||
kl_weight: 10.0
|
|
@ -99,161 +99,35 @@ python lerobot/scripts/control_robot.py record \
|
|||
"""
|
||||
|
||||
import argparse
|
||||
import concurrent.futures
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import shutil
|
||||
import time
|
||||
import traceback
|
||||
from contextlib import nullcontext
|
||||
from functools import cache
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import torch
|
||||
import tqdm
|
||||
from omegaconf import DictConfig
|
||||
from PIL import Image
|
||||
from termcolor import colored
|
||||
from typing import List
|
||||
|
||||
# from safetensors.torch import load_file, save_file
|
||||
from lerobot.common.datasets.compute_stats import compute_stats
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
|
||||
from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import to_hf_dataset
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, get_default_encoding
|
||||
from lerobot.common.datasets.utils import calculate_episode_data_index, create_branch
|
||||
from lerobot.common.datasets.video_utils import encode_video_frames
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.robot_devices.robots.factory import make_robot
|
||||
from lerobot.common.robot_devices.robots.utils import Robot, get_arm_id
|
||||
from lerobot.common.robot_devices.utils import busy_wait, safe_disconnect
|
||||
from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed
|
||||
from lerobot.scripts.eval import get_pretrained_policy_path
|
||||
from lerobot.scripts.push_dataset_to_hub import (
|
||||
push_dataset_card_to_hub,
|
||||
push_meta_data_to_hub,
|
||||
push_videos_to_hub,
|
||||
save_meta_data,
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.datasets.populate_dataset import (
|
||||
create_lerobot_dataset,
|
||||
delete_current_episode,
|
||||
init_dataset,
|
||||
save_current_episode,
|
||||
)
|
||||
|
||||
########################################################################################
|
||||
# Utilities
|
||||
########################################################################################
|
||||
|
||||
|
||||
def say(text, blocking=False):
|
||||
# Check if mac, linux, or windows.
|
||||
if platform.system() == "Darwin":
|
||||
cmd = f'say "{text}"'
|
||||
elif platform.system() == "Linux":
|
||||
cmd = f'spd-say "{text}"'
|
||||
elif platform.system() == "Windows":
|
||||
cmd = (
|
||||
'PowerShell -Command "Add-Type -AssemblyName System.Speech; '
|
||||
f"(New-Object System.Speech.Synthesis.SpeechSynthesizer).Speak('{text}')\""
|
||||
)
|
||||
|
||||
if not blocking and platform.system() in ["Darwin", "Linux"]:
|
||||
# TODO(rcadene): Make it work for Windows
|
||||
# Use the ampersand to run command in the background
|
||||
cmd += " &"
|
||||
|
||||
os.system(cmd)
|
||||
|
||||
|
||||
def save_image(img_tensor, key, frame_index, episode_index, videos_dir):
|
||||
img = Image.fromarray(img_tensor.numpy())
|
||||
path = videos_dir / f"{key}_episode_{episode_index:06d}" / f"frame_{frame_index:06d}.png"
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
img.save(str(path), quality=100)
|
||||
|
||||
|
||||
def none_or_int(value):
|
||||
if value == "None":
|
||||
return None
|
||||
return int(value)
|
||||
|
||||
|
||||
def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, fps=None):
|
||||
log_items = []
|
||||
if episode_index is not None:
|
||||
log_items.append(f"ep:{episode_index}")
|
||||
if frame_index is not None:
|
||||
log_items.append(f"frame:{frame_index}")
|
||||
|
||||
def log_dt(shortname, dt_val_s):
|
||||
nonlocal log_items, fps
|
||||
info_str = f"{shortname}:{dt_val_s * 1000:5.2f} ({1/ dt_val_s:3.1f}hz)"
|
||||
if fps is not None:
|
||||
actual_fps = 1 / dt_val_s
|
||||
if actual_fps < fps - 1:
|
||||
info_str = colored(info_str, "yellow")
|
||||
log_items.append(info_str)
|
||||
|
||||
# total step time displayed in milliseconds and its frequency
|
||||
log_dt("dt", dt_s)
|
||||
|
||||
# TODO(aliberts): move robot-specific logs logic in robot.print_logs()
|
||||
if not robot.robot_type.startswith("stretch"):
|
||||
for name in robot.leader_arms:
|
||||
key = f"read_leader_{name}_pos_dt_s"
|
||||
if key in robot.logs:
|
||||
log_dt("dtRlead", robot.logs[key])
|
||||
|
||||
for name in robot.follower_arms:
|
||||
key = f"write_follower_{name}_goal_pos_dt_s"
|
||||
if key in robot.logs:
|
||||
log_dt("dtWfoll", robot.logs[key])
|
||||
|
||||
key = f"read_follower_{name}_pos_dt_s"
|
||||
if key in robot.logs:
|
||||
log_dt("dtRfoll", robot.logs[key])
|
||||
|
||||
for name in robot.cameras:
|
||||
key = f"read_camera_{name}_dt_s"
|
||||
if key in robot.logs:
|
||||
log_dt(f"dtR{name}", robot.logs[key])
|
||||
|
||||
info_str = " ".join(log_items)
|
||||
logging.info(info_str)
|
||||
|
||||
|
||||
@cache
|
||||
def is_headless():
|
||||
"""Detects if python is running without a monitor."""
|
||||
try:
|
||||
import pynput # noqa
|
||||
|
||||
return False
|
||||
except Exception:
|
||||
print(
|
||||
"Error trying to import pynput. Switching to headless mode. "
|
||||
"As a result, the video stream from the cameras won't be shown, "
|
||||
"and you won't be able to change the control flow with keyboards. "
|
||||
"For more info, see traceback below.\n"
|
||||
)
|
||||
traceback.print_exc()
|
||||
print()
|
||||
return True
|
||||
|
||||
|
||||
def has_method(_object: object, method_name: str):
|
||||
return hasattr(_object, method_name) and callable(getattr(_object, method_name))
|
||||
|
||||
|
||||
def get_available_arms(robot):
|
||||
# TODO(rcadene): moves this function in manipulator class?
|
||||
available_arms = []
|
||||
for name in robot.follower_arms:
|
||||
arm_id = get_arm_id(name, "follower")
|
||||
available_arms.append(arm_id)
|
||||
for name in robot.leader_arms:
|
||||
arm_id = get_arm_id(name, "leader")
|
||||
available_arms.append(arm_id)
|
||||
return available_arms
|
||||
|
||||
from lerobot.common.robot_devices.control_utils import (
|
||||
control_loop,
|
||||
has_method,
|
||||
init_keyboard_listener,
|
||||
init_policy,
|
||||
log_control_info,
|
||||
record_episode,
|
||||
reset_environment,
|
||||
sanity_check_dataset_name,
|
||||
stop_recording,
|
||||
warmup_record,
|
||||
)
|
||||
from lerobot.common.robot_devices.robots.factory import make_robot
|
||||
from lerobot.common.robot_devices.robots.utils import Robot
|
||||
from lerobot.common.robot_devices.utils import busy_wait, safe_disconnect
|
||||
from lerobot.common.utils.utils import init_hydra_config, init_logging, log_say, none_or_int
|
||||
|
||||
########################################################################################
|
||||
# Control modes
|
||||
|
@ -270,9 +144,8 @@ def calibrate(robot: Robot, arms: list[str] | None):
|
|||
robot.home()
|
||||
return
|
||||
|
||||
available_arms = get_available_arms(robot)
|
||||
unknown_arms = [arm_id for arm_id in arms if arm_id not in available_arms]
|
||||
available_arms_str = " ".join(available_arms)
|
||||
unknown_arms = [arm_id for arm_id in arms if arm_id not in robot.available_arms]
|
||||
available_arms_str = " ".join(robot.available_arms)
|
||||
unknown_arms_str = " ".join(unknown_arms)
|
||||
|
||||
if arms is None or len(arms) == 0:
|
||||
|
@ -305,35 +178,26 @@ def calibrate(robot: Robot, arms: list[str] | None):
|
|||
|
||||
|
||||
@safe_disconnect
|
||||
def teleoperate(robot: Robot, fps: int | None = None, teleop_time_s: float | None = None):
|
||||
# TODO(rcadene): Add option to record logs
|
||||
if not robot.is_connected:
|
||||
robot.connect()
|
||||
|
||||
start_teleop_t = time.perf_counter()
|
||||
while True:
|
||||
start_loop_t = time.perf_counter()
|
||||
robot.teleop_step()
|
||||
|
||||
if fps is not None:
|
||||
dt_s = time.perf_counter() - start_loop_t
|
||||
busy_wait(1 / fps - dt_s)
|
||||
|
||||
dt_s = time.perf_counter() - start_loop_t
|
||||
log_control_info(robot, dt_s, fps=fps)
|
||||
|
||||
if teleop_time_s is not None and time.perf_counter() - start_teleop_t > teleop_time_s:
|
||||
break
|
||||
def teleoperate(
|
||||
robot: Robot, fps: int | None = None, teleop_time_s: float | None = None, display_cameras: bool = False
|
||||
):
|
||||
control_loop(
|
||||
robot,
|
||||
control_time_s=teleop_time_s,
|
||||
fps=fps,
|
||||
teleoperate=True,
|
||||
display_cameras=display_cameras,
|
||||
)
|
||||
|
||||
|
||||
@safe_disconnect
|
||||
def record(
|
||||
robot: Robot,
|
||||
policy: torch.nn.Module | None = None,
|
||||
hydra_cfg: DictConfig | None = None,
|
||||
root: str,
|
||||
repo_id: str,
|
||||
pretrained_policy_name_or_path: str | None = None,
|
||||
policy_overrides: List[str] | None = None,
|
||||
fps: int | None = None,
|
||||
root="data",
|
||||
repo_id="lerobot/debug",
|
||||
warmup_time_s=2,
|
||||
episode_time_s=10,
|
||||
reset_time_s=5,
|
||||
|
@ -342,390 +206,115 @@ def record(
|
|||
run_compute_stats=True,
|
||||
push_to_hub=True,
|
||||
tags=None,
|
||||
num_image_writers_per_camera=4,
|
||||
num_image_writer_processes=0,
|
||||
num_image_writer_threads_per_camera=4,
|
||||
force_override=False,
|
||||
display_cameras=True,
|
||||
play_sounds=True,
|
||||
):
|
||||
# TODO(rcadene): Add option to record logs
|
||||
# TODO(rcadene): Clean this function via decomposition in higher level functions
|
||||
listener = None
|
||||
events = None
|
||||
policy = None
|
||||
device = None
|
||||
use_amp = None
|
||||
|
||||
_, dataset_name = repo_id.split("/")
|
||||
if dataset_name.startswith("eval_") and policy is None:
|
||||
raise ValueError(
|
||||
f"Your dataset name begins by 'eval_' ({dataset_name}) but no policy is provided ({policy})."
|
||||
)
|
||||
# Load pretrained policy
|
||||
if pretrained_policy_name_or_path is not None:
|
||||
policy, policy_fps, device, use_amp = init_policy(pretrained_policy_name_or_path, policy_overrides)
|
||||
|
||||
if fps is None:
|
||||
fps = policy_fps
|
||||
logging.warning(f"No fps provided, so using the fps from policy config ({policy_fps}).")
|
||||
elif fps != policy_fps:
|
||||
logging.warning(
|
||||
f"There is a mismatch between the provided fps ({fps}) and the one from policy config ({policy_fps})."
|
||||
)
|
||||
|
||||
# Create empty dataset or load existing saved episodes
|
||||
sanity_check_dataset_name(repo_id, policy)
|
||||
dataset = init_dataset(
|
||||
repo_id,
|
||||
root,
|
||||
force_override,
|
||||
fps,
|
||||
video,
|
||||
write_images=robot.has_camera,
|
||||
num_image_writer_processes=num_image_writer_processes,
|
||||
num_image_writer_threads=num_image_writer_threads_per_camera * robot.num_cameras,
|
||||
)
|
||||
|
||||
if not robot.is_connected:
|
||||
robot.connect()
|
||||
|
||||
local_dir = Path(root) / repo_id
|
||||
if local_dir.exists() and force_override:
|
||||
shutil.rmtree(local_dir)
|
||||
listener, events = init_keyboard_listener()
|
||||
|
||||
episodes_dir = local_dir / "episodes"
|
||||
episodes_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
videos_dir = local_dir / "videos"
|
||||
videos_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Logic to resume data recording
|
||||
rec_info_path = episodes_dir / "data_recording_info.json"
|
||||
if rec_info_path.exists():
|
||||
with open(rec_info_path) as f:
|
||||
rec_info = json.load(f)
|
||||
episode_index = rec_info["last_episode_index"] + 1
|
||||
else:
|
||||
episode_index = 0
|
||||
|
||||
if is_headless():
|
||||
logging.warning(
|
||||
"Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
|
||||
)
|
||||
|
||||
# Allow to exit early while recording an episode or resetting the environment,
|
||||
# by tapping the right arrow key '->'. This might require a sudo permission
|
||||
# to allow your terminal to monitor keyboard events.
|
||||
exit_early = False
|
||||
rerecord_episode = False
|
||||
stop_recording = False
|
||||
|
||||
# Only import pynput if not in a headless environment
|
||||
if not is_headless():
|
||||
from pynput import keyboard
|
||||
|
||||
def on_press(key):
|
||||
nonlocal exit_early, rerecord_episode, stop_recording
|
||||
try:
|
||||
if key == keyboard.Key.right:
|
||||
print("Right arrow key pressed. Exiting loop...")
|
||||
exit_early = True
|
||||
elif key == keyboard.Key.left:
|
||||
print("Left arrow key pressed. Exiting loop and rerecord the last episode...")
|
||||
rerecord_episode = True
|
||||
exit_early = True
|
||||
elif key == keyboard.Key.esc:
|
||||
print("Escape key pressed. Stopping data recording...")
|
||||
stop_recording = True
|
||||
exit_early = True
|
||||
except Exception as e:
|
||||
print(f"Error handling key press: {e}")
|
||||
|
||||
listener = keyboard.Listener(on_press=on_press)
|
||||
listener.start()
|
||||
|
||||
# Load policy if any
|
||||
if policy is not None:
|
||||
# Check device is available
|
||||
device = get_safe_torch_device(hydra_cfg.device, log=True)
|
||||
|
||||
policy.eval()
|
||||
policy.to(device)
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
set_global_seed(hydra_cfg.seed)
|
||||
|
||||
# override fps using policy fps
|
||||
fps = hydra_cfg.env.fps
|
||||
|
||||
# Execute a few seconds without recording data, to give times
|
||||
# to the robot devices to connect and start synchronizing.
|
||||
timestamp = 0
|
||||
start_warmup_t = time.perf_counter()
|
||||
is_warmup_print = False
|
||||
while timestamp < warmup_time_s:
|
||||
if not is_warmup_print:
|
||||
logging.info("Warming up (no data recording)")
|
||||
say("Warming up")
|
||||
is_warmup_print = True
|
||||
|
||||
start_loop_t = time.perf_counter()
|
||||
|
||||
if policy is None:
|
||||
observation, action = robot.teleop_step(record_data=True)
|
||||
else:
|
||||
observation = robot.capture_observation()
|
||||
|
||||
if display_cameras and not is_headless():
|
||||
image_keys = [key for key in observation if "image" in key]
|
||||
for key in image_keys:
|
||||
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
|
||||
cv2.waitKey(1)
|
||||
|
||||
dt_s = time.perf_counter() - start_loop_t
|
||||
busy_wait(1 / fps - dt_s)
|
||||
|
||||
dt_s = time.perf_counter() - start_loop_t
|
||||
log_control_info(robot, dt_s, fps=fps)
|
||||
|
||||
timestamp = time.perf_counter() - start_warmup_t
|
||||
# Execute a few seconds without recording to:
|
||||
# 1. teleoperate the robot to move it in starting position if no policy provided,
|
||||
# 2. give times to the robot devices to connect and start synchronizing,
|
||||
# 3. place the cameras windows on screen
|
||||
enable_teleoperation = policy is None
|
||||
log_say("Warmup record", play_sounds)
|
||||
warmup_record(robot, events, enable_teleoperation, warmup_time_s, display_cameras, fps)
|
||||
|
||||
if has_method(robot, "teleop_safety_stop"):
|
||||
robot.teleop_safety_stop()
|
||||
|
||||
# Save images using threads to reach high fps (30 and more)
|
||||
# Using `with` to exist smoothly if an execption is raised.
|
||||
futures = []
|
||||
num_image_writers = num_image_writers_per_camera * len(robot.cameras)
|
||||
num_image_writers = max(num_image_writers, 1)
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=num_image_writers) as executor:
|
||||
# Start recording all episodes
|
||||
while episode_index < num_episodes:
|
||||
logging.info(f"Recording episode {episode_index}")
|
||||
say(f"Recording episode {episode_index}")
|
||||
ep_dict = {}
|
||||
frame_index = 0
|
||||
timestamp = 0
|
||||
start_episode_t = time.perf_counter()
|
||||
while timestamp < episode_time_s:
|
||||
start_loop_t = time.perf_counter()
|
||||
while True:
|
||||
if dataset["num_episodes"] >= num_episodes:
|
||||
break
|
||||
|
||||
if policy is None:
|
||||
observation, action = robot.teleop_step(record_data=True)
|
||||
else:
|
||||
observation = robot.capture_observation()
|
||||
episode_index = dataset["num_episodes"]
|
||||
log_say(f"Recording episode {episode_index}", play_sounds)
|
||||
record_episode(
|
||||
dataset=dataset,
|
||||
robot=robot,
|
||||
events=events,
|
||||
episode_time_s=episode_time_s,
|
||||
display_cameras=display_cameras,
|
||||
policy=policy,
|
||||
device=device,
|
||||
use_amp=use_amp,
|
||||
fps=fps,
|
||||
)
|
||||
|
||||
image_keys = [key for key in observation if "image" in key]
|
||||
not_image_keys = [key for key in observation if "image" not in key]
|
||||
# Execute a few seconds without recording to give time to manually reset the environment
|
||||
# Current code logic doesn't allow to teleoperate during this time.
|
||||
# TODO(rcadene): add an option to enable teleoperation during reset
|
||||
# Skip reset for the last episode to be recorded
|
||||
if not events["stop_recording"] and (
|
||||
(episode_index < num_episodes - 1) or events["rerecord_episode"]
|
||||
):
|
||||
log_say("Reset the environment", play_sounds)
|
||||
reset_environment(robot, events, reset_time_s)
|
||||
|
||||
for key in image_keys:
|
||||
futures += [
|
||||
executor.submit(
|
||||
save_image, observation[key], key, frame_index, episode_index, videos_dir
|
||||
)
|
||||
]
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode", play_sounds)
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
delete_current_episode(dataset)
|
||||
continue
|
||||
|
||||
if display_cameras and not is_headless():
|
||||
image_keys = [key for key in observation if "image" in key]
|
||||
for key in image_keys:
|
||||
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
|
||||
cv2.waitKey(1)
|
||||
# Increment by one dataset["current_episode_index"]
|
||||
save_current_episode(dataset)
|
||||
|
||||
for key in not_image_keys:
|
||||
if key not in ep_dict:
|
||||
ep_dict[key] = []
|
||||
ep_dict[key].append(observation[key])
|
||||
if events["stop_recording"]:
|
||||
break
|
||||
|
||||
if policy is not None:
|
||||
with (
|
||||
torch.inference_mode(),
|
||||
torch.autocast(device_type=device.type)
|
||||
if device.type == "cuda" and hydra_cfg.use_amp
|
||||
else nullcontext(),
|
||||
):
|
||||
# Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
|
||||
for name in observation:
|
||||
if "image" in name:
|
||||
observation[name] = observation[name].type(torch.float32) / 255
|
||||
observation[name] = observation[name].permute(2, 0, 1).contiguous()
|
||||
observation[name] = observation[name].unsqueeze(0)
|
||||
observation[name] = observation[name].to(device)
|
||||
log_say("Stop recording", play_sounds, blocking=True)
|
||||
stop_recording(robot, listener, display_cameras)
|
||||
|
||||
# Compute the next action with the policy
|
||||
# based on the current observation
|
||||
action = policy.select_action(observation)
|
||||
lerobot_dataset = create_lerobot_dataset(dataset, run_compute_stats, push_to_hub, tags, play_sounds)
|
||||
|
||||
# Remove batch dimension
|
||||
action = action.squeeze(0)
|
||||
|
||||
# Move to cpu, if not already the case
|
||||
action = action.to("cpu")
|
||||
|
||||
# Order the robot to move
|
||||
action_sent = robot.send_action(action)
|
||||
|
||||
# Action can eventually be clipped using `max_relative_target`,
|
||||
# so action actually sent is saved in the dataset.
|
||||
action = {"action": action_sent}
|
||||
|
||||
for key in action:
|
||||
if key not in ep_dict:
|
||||
ep_dict[key] = []
|
||||
ep_dict[key].append(action[key])
|
||||
|
||||
frame_index += 1
|
||||
|
||||
dt_s = time.perf_counter() - start_loop_t
|
||||
busy_wait(1 / fps - dt_s)
|
||||
|
||||
dt_s = time.perf_counter() - start_loop_t
|
||||
log_control_info(robot, dt_s, fps=fps)
|
||||
|
||||
timestamp = time.perf_counter() - start_episode_t
|
||||
if exit_early:
|
||||
exit_early = False
|
||||
break
|
||||
|
||||
# TODO(alibets): allow for teleop during reset
|
||||
if has_method(robot, "teleop_safety_stop"):
|
||||
robot.teleop_safety_stop()
|
||||
|
||||
if not stop_recording:
|
||||
# Start resetting env while the executor are finishing
|
||||
logging.info("Reset the environment")
|
||||
say("Reset the environment")
|
||||
|
||||
timestamp = 0
|
||||
start_vencod_t = time.perf_counter()
|
||||
|
||||
# During env reset we save the data and encode the videos
|
||||
num_frames = frame_index
|
||||
|
||||
for key in image_keys:
|
||||
if video:
|
||||
tmp_imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}"
|
||||
fname = f"{key}_episode_{episode_index:06d}.mp4"
|
||||
video_path = local_dir / "videos" / fname
|
||||
if video_path.exists():
|
||||
video_path.unlink()
|
||||
# Store the reference to the video frame, even tho the videos are not yet encoded
|
||||
ep_dict[key] = []
|
||||
for i in range(num_frames):
|
||||
ep_dict[key].append({"path": f"videos/{fname}", "timestamp": i / fps})
|
||||
|
||||
else:
|
||||
imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}"
|
||||
ep_dict[key] = []
|
||||
for i in range(num_frames):
|
||||
img_path = imgs_dir / f"frame_{i:06d}.png"
|
||||
ep_dict[key].append({"path": str(img_path)})
|
||||
|
||||
for key in not_image_keys:
|
||||
ep_dict[key] = torch.stack(ep_dict[key])
|
||||
|
||||
for key in action:
|
||||
ep_dict[key] = torch.stack(ep_dict[key])
|
||||
|
||||
ep_dict["episode_index"] = torch.tensor([episode_index] * num_frames)
|
||||
ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
|
||||
ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
|
||||
|
||||
done = torch.zeros(num_frames, dtype=torch.bool)
|
||||
done[-1] = True
|
||||
ep_dict["next.done"] = done
|
||||
|
||||
ep_path = episodes_dir / f"episode_{episode_index}.pth"
|
||||
print("Saving episode dictionary...")
|
||||
torch.save(ep_dict, ep_path)
|
||||
|
||||
rec_info = {
|
||||
"last_episode_index": episode_index,
|
||||
}
|
||||
with open(rec_info_path, "w") as f:
|
||||
json.dump(rec_info, f)
|
||||
|
||||
is_last_episode = stop_recording or (episode_index == (num_episodes - 1))
|
||||
|
||||
# Wait if necessary
|
||||
with tqdm.tqdm(total=reset_time_s, desc="Waiting") as pbar:
|
||||
while timestamp < reset_time_s and not is_last_episode:
|
||||
time.sleep(1)
|
||||
timestamp = time.perf_counter() - start_vencod_t
|
||||
pbar.update(1)
|
||||
if exit_early:
|
||||
exit_early = False
|
||||
break
|
||||
|
||||
# Skip updating episode index which forces re-recording episode
|
||||
if rerecord_episode:
|
||||
rerecord_episode = False
|
||||
continue
|
||||
|
||||
episode_index += 1
|
||||
|
||||
if is_last_episode:
|
||||
logging.info("Done recording")
|
||||
say("Done recording", blocking=True)
|
||||
if not is_headless():
|
||||
listener.stop()
|
||||
|
||||
logging.info("Waiting for threads writing the images on disk to terminate...")
|
||||
for _ in tqdm.tqdm(
|
||||
concurrent.futures.as_completed(futures), total=len(futures), desc="Writting images"
|
||||
):
|
||||
pass
|
||||
break
|
||||
|
||||
robot.disconnect()
|
||||
if display_cameras and not is_headless():
|
||||
cv2.destroyAllWindows()
|
||||
|
||||
num_episodes = episode_index
|
||||
|
||||
if video:
|
||||
logging.info("Encoding videos")
|
||||
say("Encoding videos")
|
||||
# Use ffmpeg to convert frames stored as png into mp4 videos
|
||||
for episode_index in tqdm.tqdm(range(num_episodes)):
|
||||
for key in image_keys:
|
||||
tmp_imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}"
|
||||
fname = f"{key}_episode_{episode_index:06d}.mp4"
|
||||
video_path = local_dir / "videos" / fname
|
||||
if video_path.exists():
|
||||
# Skip if video is already encoded. Could be the case when resuming data recording.
|
||||
continue
|
||||
# note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
|
||||
# since video encoding with ffmpeg is already using multithreading.
|
||||
encode_video_frames(tmp_imgs_dir, video_path, fps, overwrite=True)
|
||||
shutil.rmtree(tmp_imgs_dir)
|
||||
|
||||
logging.info("Concatenating episodes")
|
||||
ep_dicts = []
|
||||
for episode_index in tqdm.tqdm(range(num_episodes)):
|
||||
ep_path = episodes_dir / f"episode_{episode_index}.pth"
|
||||
ep_dict = torch.load(ep_path)
|
||||
ep_dicts.append(ep_dict)
|
||||
data_dict = concatenate_episodes(ep_dicts)
|
||||
|
||||
total_frames = data_dict["frame_index"].shape[0]
|
||||
data_dict["index"] = torch.arange(0, total_frames, 1)
|
||||
|
||||
hf_dataset = to_hf_dataset(data_dict, video)
|
||||
episode_data_index = calculate_episode_data_index(hf_dataset)
|
||||
info = {
|
||||
"codebase_version": CODEBASE_VERSION,
|
||||
"fps": fps,
|
||||
"video": video,
|
||||
}
|
||||
if video:
|
||||
info["encoding"] = get_default_encoding()
|
||||
|
||||
lerobot_dataset = LeRobotDataset.from_preloaded(
|
||||
repo_id=repo_id,
|
||||
hf_dataset=hf_dataset,
|
||||
episode_data_index=episode_data_index,
|
||||
info=info,
|
||||
videos_dir=videos_dir,
|
||||
)
|
||||
if run_compute_stats:
|
||||
logging.info("Computing dataset statistics")
|
||||
say("Computing dataset statistics")
|
||||
stats = compute_stats(lerobot_dataset)
|
||||
lerobot_dataset.stats = stats
|
||||
else:
|
||||
stats = {}
|
||||
logging.info("Skipping computation of the dataset statistics")
|
||||
|
||||
hf_dataset = hf_dataset.with_format(None) # to remove transforms that cant be saved
|
||||
hf_dataset.save_to_disk(str(local_dir / "train"))
|
||||
|
||||
meta_data_dir = local_dir / "meta_data"
|
||||
save_meta_data(info, stats, episode_data_index, meta_data_dir)
|
||||
|
||||
if push_to_hub:
|
||||
hf_dataset.push_to_hub(repo_id, revision="main")
|
||||
push_meta_data_to_hub(repo_id, meta_data_dir, revision="main")
|
||||
push_dataset_card_to_hub(repo_id, revision="main", tags=tags)
|
||||
if video:
|
||||
push_videos_to_hub(repo_id, videos_dir, revision="main")
|
||||
create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION)
|
||||
|
||||
logging.info("Exiting")
|
||||
say("Exiting")
|
||||
log_say("Exiting", play_sounds)
|
||||
return lerobot_dataset
|
||||
|
||||
|
||||
def replay(robot: Robot, episode: int, fps: int | None = None, root="data", repo_id="lerobot/debug"):
|
||||
@safe_disconnect
|
||||
def replay(
|
||||
robot: Robot, episode: int, fps: int | None = None, root="data", repo_id="lerobot/debug", play_sounds=True
|
||||
):
|
||||
# TODO(rcadene, aliberts): refactor with control_loop, once `dataset` is an instance of LeRobotDataset
|
||||
# TODO(rcadene): Add option to record logs
|
||||
local_dir = Path(root) / repo_id
|
||||
if not local_dir.exists():
|
||||
|
@ -739,8 +328,7 @@ def replay(robot: Robot, episode: int, fps: int | None = None, root="data", repo
|
|||
if not robot.is_connected:
|
||||
robot.connect()
|
||||
|
||||
logging.info("Replaying episode")
|
||||
say("Replaying episode", blocking=True)
|
||||
log_say("Replaying episode", play_sounds, blocking=True)
|
||||
for idx in range(from_idx, to_idx):
|
||||
start_episode_t = time.perf_counter()
|
||||
|
||||
|
@ -785,6 +373,12 @@ if __name__ == "__main__":
|
|||
parser_teleop.add_argument(
|
||||
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
|
||||
)
|
||||
parser_teleop.add_argument(
|
||||
"--display-cameras",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Display all cameras on screen (set to 1 to display or 0).",
|
||||
)
|
||||
|
||||
parser_record = subparsers.add_parser("record", parents=[base_parser])
|
||||
parser_record.add_argument(
|
||||
|
@ -840,12 +434,23 @@ if __name__ == "__main__":
|
|||
help="Add tags to your dataset on the hub.",
|
||||
)
|
||||
parser_record.add_argument(
|
||||
"--num-image-writers-per-camera",
|
||||
"--num-image-writer-processes",
|
||||
type=int,
|
||||
default=0,
|
||||
help=(
|
||||
"Number of subprocesses handling the saving of frames as PNGs. Set to 0 to use threads only; "
|
||||
"set to ≥1 to use subprocesses, each using threads to write images. The best number of processes "
|
||||
"and threads depends on your system. We recommend 4 threads per camera with 0 processes. "
|
||||
"If fps is unstable, adjust the thread count. If still unstable, try using 1 or more subprocesses."
|
||||
),
|
||||
)
|
||||
parser_record.add_argument(
|
||||
"--num-image-writer-threads-per-camera",
|
||||
type=int,
|
||||
default=4,
|
||||
help=(
|
||||
"Number of threads writing the frames as png images on disk, per camera. "
|
||||
"Too much threads might cause unstable teleoperation fps due to main thread being blocked. "
|
||||
"Too many threads might cause unstable teleoperation fps due to main thread being blocked. "
|
||||
"Not enough threads might cause low camera fps."
|
||||
),
|
||||
)
|
||||
|
@ -911,19 +516,7 @@ if __name__ == "__main__":
|
|||
teleoperate(robot, **kwargs)
|
||||
|
||||
elif control_mode == "record":
|
||||
pretrained_policy_name_or_path = args.pretrained_policy_name_or_path
|
||||
policy_overrides = args.policy_overrides
|
||||
del kwargs["pretrained_policy_name_or_path"]
|
||||
del kwargs["policy_overrides"]
|
||||
|
||||
policy_cfg = None
|
||||
if pretrained_policy_name_or_path is not None:
|
||||
pretrained_policy_path = get_pretrained_policy_path(pretrained_policy_name_or_path)
|
||||
policy_cfg = init_hydra_config(pretrained_policy_path / "config.yaml", policy_overrides)
|
||||
policy = make_policy(hydra_cfg=policy_cfg, pretrained_policy_name_or_path=pretrained_policy_path)
|
||||
record(robot, policy, policy_cfg, **kwargs)
|
||||
else:
|
||||
record(robot, **kwargs)
|
||||
record(robot, **kwargs)
|
||||
|
||||
elif control_mode == "replay":
|
||||
replay(robot, **kwargs)
|
||||
|
|
|
@ -383,7 +383,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
|||
logging.info(f"Checkpoint policy after step {step}")
|
||||
# Note: Save with step as the identifier, and format it to have at least 6 digits but more if
|
||||
# needed (choose 6 as a minimum for consistency without being overkill).
|
||||
logger.save_checkpont(
|
||||
logger.save_checkpoint(
|
||||
step,
|
||||
policy,
|
||||
optimizer,
|
||||
|
|
|
@ -250,7 +250,7 @@
|
|||
if(!canPlayVideos){
|
||||
this.videoCodecError = true;
|
||||
}
|
||||
|
||||
|
||||
// process CSV data
|
||||
this.videos = document.querySelectorAll('video');
|
||||
this.video = this.videos[0];
|
||||
|
|
|
@ -5389,7 +5389,7 @@ docs = ["sphinx", "sphinx-automodapi", "sphinx-rtd-theme"]
|
|||
name = "pyserial"
|
||||
version = "3.5"
|
||||
description = "Python Serial Port Extension"
|
||||
optional = true
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "pyserial-3.5-py2.py3-none-any.whl", hash = "sha256:c4451db6ba391ca6ca299fb3ec7bae67a5c55dde170964c7a14ceefec02f2cf0"},
|
||||
|
|
|
@ -52,8 +52,9 @@ def is_robot_available(robot_type):
|
|||
print(f"\nInstall module '{e.name}'")
|
||||
elif isinstance(e, SerialException):
|
||||
print("\nNo physical motors bus detected.")
|
||||
else:
|
||||
traceback.print_exc()
|
||||
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
|
||||
|
@ -77,8 +78,9 @@ def is_camera_available(camera_type):
|
|||
print(f"\nInstall module '{e.name}'")
|
||||
elif isinstance(e, ValueError) and "camera_index" in e.args[0]:
|
||||
print("\nNo physical camera detected.")
|
||||
else:
|
||||
traceback.print_exc()
|
||||
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
|
||||
|
@ -102,8 +104,9 @@ def is_motor_available(motor_type):
|
|||
print(f"\nInstall module '{e.name}'")
|
||||
elif isinstance(e, SerialException):
|
||||
print("\nNo physical motors bus detected.")
|
||||
else:
|
||||
traceback.print_exc()
|
||||
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
|
||||
|
|
|
@ -1,3 +0,0 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:b5a9f73a2356aff9c717cdfd0d37a6da08b0cf2cc09c98edbc9492501b7f64a5
|
||||
size 5104
|
|
@ -1,3 +0,0 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:28738b3cfad17af0ac5181effdd796acdf7953cd5bcca3f421a11ddfd6b0076f
|
||||
size 30800
|
|
@ -1,3 +0,0 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:4bb8a197a40456fdbc16029126268e6bcef3eca1837d88235165dc7e14618bea
|
||||
size 68
|
|
@ -1,3 +0,0 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:bea60cce42d324f539dd3bca1e66b5ba6391838fdcadb00efc25f3240edb529a
|
||||
size 33600
|
|
@ -23,13 +23,18 @@ pytest -sx 'tests/test_control_robot.py::test_teleoperate[aloha-True]'
|
|||
```
|
||||
"""
|
||||
|
||||
import multiprocessing
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from lerobot.common.datasets.populate_dataset import add_frame, init_dataset
|
||||
from lerobot.common.logger import Logger
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.utils.utils import init_hydra_config
|
||||
from lerobot.scripts.control_robot import calibrate, get_available_arms, record, replay, teleoperate
|
||||
from lerobot.scripts.control_robot import calibrate, record, replay, teleoperate
|
||||
from lerobot.scripts.train import make_optimizer_and_scheduler
|
||||
from tests.test_robots import make_robot
|
||||
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, TEST_ROBOT_TYPES, require_robot
|
||||
|
||||
|
@ -37,7 +42,7 @@ from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, TEST_ROBOT_TYPES, require_r
|
|||
@pytest.mark.parametrize("robot_type, mock", TEST_ROBOT_TYPES)
|
||||
@require_robot
|
||||
def test_teleoperate(tmpdir, request, robot_type, mock):
|
||||
if mock:
|
||||
if mock and robot_type != "aloha":
|
||||
request.getfixturevalue("patch_builtins_input")
|
||||
|
||||
# Create an empty calibration directory to trigger manual calibration
|
||||
|
@ -68,7 +73,7 @@ def test_calibrate(tmpdir, request, robot_type, mock):
|
|||
overrides_calibration_dir = [f"calibration_dir={calibration_dir}"]
|
||||
|
||||
robot = make_robot(robot_type, overrides=overrides_calibration_dir, mock=mock)
|
||||
calibrate(robot, arms=get_available_arms(robot))
|
||||
calibrate(robot, arms=robot.available_arms)
|
||||
del robot
|
||||
|
||||
|
||||
|
@ -78,7 +83,7 @@ def test_record_without_cameras(tmpdir, request, robot_type, mock):
|
|||
# Avoid using cameras
|
||||
overrides = ["~cameras"]
|
||||
|
||||
if mock:
|
||||
if mock and robot_type != "aloha":
|
||||
request.getfixturevalue("patch_builtins_input")
|
||||
|
||||
# Create an empty calibration directory to trigger manual calibration
|
||||
|
@ -101,72 +106,345 @@ def test_record_without_cameras(tmpdir, request, robot_type, mock):
|
|||
run_compute_stats=False,
|
||||
push_to_hub=False,
|
||||
video=False,
|
||||
play_sounds=False,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("robot_type, mock", TEST_ROBOT_TYPES)
|
||||
@require_robot
|
||||
def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
|
||||
if mock:
|
||||
tmpdir = Path(tmpdir)
|
||||
|
||||
if mock and robot_type != "aloha":
|
||||
request.getfixturevalue("patch_builtins_input")
|
||||
|
||||
# Create an empty calibration directory to trigger manual calibration
|
||||
# and avoid writing calibration files in user .cache/calibration folder
|
||||
calibration_dir = Path(tmpdir) / robot_type
|
||||
calibration_dir = tmpdir / robot_type
|
||||
overrides = [f"calibration_dir={calibration_dir}"]
|
||||
else:
|
||||
# Use the default .cache/calibration folder when mock=False
|
||||
# Use the default .cache/calibration folder when mock=False or for aloha
|
||||
overrides = None
|
||||
|
||||
if robot_type == "aloha":
|
||||
pytest.skip("TODO(rcadene): enable test once aloha_real and act_aloha_real are merged")
|
||||
|
||||
env_name = "koch_real"
|
||||
policy_name = "act_koch_real"
|
||||
|
||||
root = Path(tmpdir) / "data"
|
||||
root = tmpdir / "data"
|
||||
repo_id = "lerobot/debug"
|
||||
eval_repo_id = "lerobot/eval_debug"
|
||||
|
||||
robot = make_robot(robot_type, overrides=overrides, mock=mock)
|
||||
dataset = record(
|
||||
robot,
|
||||
fps=30,
|
||||
root=root,
|
||||
repo_id=repo_id,
|
||||
root,
|
||||
repo_id,
|
||||
fps=1,
|
||||
warmup_time_s=1,
|
||||
episode_time_s=1,
|
||||
reset_time_s=1,
|
||||
num_episodes=2,
|
||||
push_to_hub=False,
|
||||
# TODO(rcadene, aliberts): test video=True
|
||||
video=False,
|
||||
# TODO(rcadene): display cameras through cv2 sometimes crashes on mac
|
||||
display_cameras=False,
|
||||
play_sounds=False,
|
||||
)
|
||||
assert dataset.num_episodes == 2
|
||||
assert len(dataset) == 2
|
||||
|
||||
replay(robot, episode=0, fps=30, root=root, repo_id=repo_id)
|
||||
replay(robot, episode=0, fps=1, root=root, repo_id=repo_id, play_sounds=False)
|
||||
|
||||
# TODO(rcadene, aliberts): rethink this design
|
||||
if robot_type == "aloha":
|
||||
env_name = "aloha_real"
|
||||
policy_name = "act_aloha_real"
|
||||
elif robot_type in ["koch", "koch_bimanual"]:
|
||||
env_name = "koch_real"
|
||||
policy_name = "act_koch_real"
|
||||
else:
|
||||
raise NotImplementedError(robot_type)
|
||||
|
||||
overrides = [
|
||||
f"env={env_name}",
|
||||
f"policy={policy_name}",
|
||||
f"device={DEVICE}",
|
||||
]
|
||||
|
||||
if robot_type == "koch_bimanual":
|
||||
overrides += ["env.state_dim=12", "env.action_dim=12"]
|
||||
|
||||
overrides += ["wandb.enable=false"]
|
||||
overrides += ["env.fps=1"]
|
||||
|
||||
cfg = init_hydra_config(
|
||||
DEFAULT_CONFIG_PATH,
|
||||
overrides=[
|
||||
f"env={env_name}",
|
||||
f"policy={policy_name}",
|
||||
f"device={DEVICE}",
|
||||
],
|
||||
overrides=overrides,
|
||||
)
|
||||
|
||||
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)
|
||||
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
|
||||
out_dir = tmpdir / "logger"
|
||||
logger = Logger(cfg, out_dir, wandb_job_name="debug")
|
||||
logger.save_checkpoint(
|
||||
0,
|
||||
policy,
|
||||
optimizer,
|
||||
lr_scheduler,
|
||||
identifier=0,
|
||||
)
|
||||
pretrained_policy_name_or_path = out_dir / "checkpoints/last/pretrained_model"
|
||||
|
||||
# In `examples/9_use_aloha.md`, we advise using `num_image_writer_processes=1`
|
||||
# during inference, to reach constent fps, so we test this here.
|
||||
if robot_type == "aloha":
|
||||
num_image_writer_processes = 1
|
||||
|
||||
# `multiprocessing.set_start_method("spawn", force=True)` avoids a hanging issue
|
||||
# before exiting pytest. However, it outputs the following error in the log:
|
||||
# Traceback (most recent call last):
|
||||
# File "<string>", line 1, in <module>
|
||||
# File "/Users/rcadene/miniconda3/envs/lerobot/lib/python3.10/multiprocessing/spawn.py", line 116, in spawn_main
|
||||
# exitcode = _main(fd, parent_sentinel)
|
||||
# File "/Users/rcadene/miniconda3/envs/lerobot/lib/python3.10/multiprocessing/spawn.py", line 126, in _main
|
||||
# self = reduction.pickle.load(from_parent)
|
||||
# File "/Users/rcadene/miniconda3/envs/lerobot/lib/python3.10/multiprocessing/synchronize.py", line 110, in __setstate__
|
||||
# self._semlock = _multiprocessing.SemLock._rebuild(*state)
|
||||
# FileNotFoundError: [Errno 2] No such file or directory
|
||||
# TODO(rcadene, aliberts): fix FileNotFoundError in multiprocessing
|
||||
multiprocessing.set_start_method("spawn", force=True)
|
||||
else:
|
||||
num_image_writer_processes = 0
|
||||
|
||||
record(
|
||||
robot,
|
||||
policy,
|
||||
cfg,
|
||||
root,
|
||||
eval_repo_id,
|
||||
pretrained_policy_name_or_path,
|
||||
warmup_time_s=1,
|
||||
episode_time_s=1,
|
||||
reset_time_s=1,
|
||||
num_episodes=2,
|
||||
run_compute_stats=False,
|
||||
push_to_hub=False,
|
||||
video=False,
|
||||
display_cameras=False,
|
||||
play_sounds=False,
|
||||
num_image_writer_processes=num_image_writer_processes,
|
||||
)
|
||||
|
||||
assert dataset.num_episodes == 2
|
||||
assert len(dataset) == 2
|
||||
|
||||
del robot
|
||||
|
||||
|
||||
@pytest.mark.parametrize("robot_type, mock", [("koch", True)])
|
||||
@require_robot
|
||||
def test_resume_record(tmpdir, request, robot_type, mock):
|
||||
if mock and robot_type != "aloha":
|
||||
request.getfixturevalue("patch_builtins_input")
|
||||
|
||||
# Create an empty calibration directory to trigger manual calibration
|
||||
# and avoid writing calibration files in user .cache/calibration folder
|
||||
calibration_dir = tmpdir / robot_type
|
||||
overrides = [f"calibration_dir={calibration_dir}"]
|
||||
else:
|
||||
# Use the default .cache/calibration folder when mock=False or for aloha
|
||||
overrides = []
|
||||
|
||||
robot = make_robot(robot_type, overrides=overrides, mock=mock)
|
||||
|
||||
root = Path(tmpdir) / "data"
|
||||
repo_id = "lerobot/debug"
|
||||
|
||||
dataset = record(
|
||||
robot,
|
||||
root,
|
||||
repo_id,
|
||||
fps=1,
|
||||
warmup_time_s=0,
|
||||
episode_time_s=1,
|
||||
num_episodes=1,
|
||||
push_to_hub=False,
|
||||
video=False,
|
||||
display_cameras=False,
|
||||
play_sounds=False,
|
||||
run_compute_stats=False,
|
||||
)
|
||||
assert len(dataset) == 1, "`dataset` should contain only 1 frame"
|
||||
|
||||
init_dataset_return_value = {}
|
||||
|
||||
def wrapped_init_dataset(*args, **kwargs):
|
||||
nonlocal init_dataset_return_value
|
||||
init_dataset_return_value = init_dataset(*args, **kwargs)
|
||||
return init_dataset_return_value
|
||||
|
||||
with patch("lerobot.scripts.control_robot.init_dataset", wraps=wrapped_init_dataset):
|
||||
dataset = record(
|
||||
robot,
|
||||
root,
|
||||
repo_id,
|
||||
fps=1,
|
||||
warmup_time_s=0,
|
||||
episode_time_s=1,
|
||||
num_episodes=2,
|
||||
push_to_hub=False,
|
||||
video=False,
|
||||
display_cameras=False,
|
||||
play_sounds=False,
|
||||
run_compute_stats=False,
|
||||
)
|
||||
assert len(dataset) == 2, "`dataset` should contain only 1 frame"
|
||||
assert (
|
||||
init_dataset_return_value["num_episodes"] == 2
|
||||
), "`init_dataset` should load the previous episode"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("robot_type, mock", [("koch", True)])
|
||||
@require_robot
|
||||
def test_record_with_event_rerecord_episode(tmpdir, request, robot_type, mock):
|
||||
if mock and robot_type != "aloha":
|
||||
request.getfixturevalue("patch_builtins_input")
|
||||
|
||||
# Create an empty calibration directory to trigger manual calibration
|
||||
# and avoid writing calibration files in user .cache/calibration folder
|
||||
calibration_dir = tmpdir / robot_type
|
||||
overrides = [f"calibration_dir={calibration_dir}"]
|
||||
else:
|
||||
# Use the default .cache/calibration folder when mock=False or for aloha
|
||||
overrides = []
|
||||
|
||||
robot = make_robot(robot_type, overrides=overrides, mock=mock)
|
||||
with (
|
||||
patch("lerobot.scripts.control_robot.init_keyboard_listener") as mock_listener,
|
||||
patch("lerobot.common.robot_devices.control_utils.add_frame", wraps=add_frame) as mock_add_frame,
|
||||
):
|
||||
mock_events = {}
|
||||
mock_events["exit_early"] = True
|
||||
mock_events["rerecord_episode"] = True
|
||||
mock_events["stop_recording"] = False
|
||||
mock_listener.return_value = (None, mock_events)
|
||||
|
||||
root = Path(tmpdir) / "data"
|
||||
repo_id = "lerobot/debug"
|
||||
|
||||
dataset = record(
|
||||
robot,
|
||||
root,
|
||||
repo_id,
|
||||
fps=1,
|
||||
warmup_time_s=0,
|
||||
episode_time_s=1,
|
||||
num_episodes=1,
|
||||
push_to_hub=False,
|
||||
video=False,
|
||||
display_cameras=False,
|
||||
play_sounds=False,
|
||||
run_compute_stats=False,
|
||||
)
|
||||
|
||||
assert not mock_events["rerecord_episode"], "`rerecord_episode` wasn't properly reset to False"
|
||||
assert not mock_events["exit_early"], "`exit_early` wasn't properly reset to False"
|
||||
assert mock_add_frame.call_count == 2, "`add_frame` should have been called 2 times"
|
||||
assert len(dataset) == 1, "`dataset` should contain only 1 frame"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("robot_type, mock", [("koch", True)])
|
||||
@require_robot
|
||||
def test_record_with_event_exit_early(tmpdir, request, robot_type, mock):
|
||||
if mock:
|
||||
request.getfixturevalue("patch_builtins_input")
|
||||
|
||||
# Create an empty calibration directory to trigger manual calibration
|
||||
# and avoid writing calibration files in user .cache/calibration folder
|
||||
calibration_dir = tmpdir / robot_type
|
||||
overrides = [f"calibration_dir={calibration_dir}"]
|
||||
else:
|
||||
# Use the default .cache/calibration folder when mock=False or for aloha
|
||||
overrides = []
|
||||
|
||||
robot = make_robot(robot_type, overrides=overrides, mock=mock)
|
||||
with (
|
||||
patch("lerobot.scripts.control_robot.init_keyboard_listener") as mock_listener,
|
||||
patch("lerobot.common.robot_devices.control_utils.add_frame", wraps=add_frame) as mock_add_frame,
|
||||
):
|
||||
mock_events = {}
|
||||
mock_events["exit_early"] = True
|
||||
mock_events["rerecord_episode"] = False
|
||||
mock_events["stop_recording"] = False
|
||||
mock_listener.return_value = (None, mock_events)
|
||||
|
||||
root = Path(tmpdir) / "data"
|
||||
repo_id = "lerobot/debug"
|
||||
|
||||
dataset = record(
|
||||
robot,
|
||||
fps=2,
|
||||
root=root,
|
||||
repo_id=repo_id,
|
||||
warmup_time_s=0,
|
||||
episode_time_s=1,
|
||||
num_episodes=1,
|
||||
push_to_hub=False,
|
||||
video=False,
|
||||
display_cameras=False,
|
||||
play_sounds=False,
|
||||
run_compute_stats=False,
|
||||
)
|
||||
|
||||
assert not mock_events["exit_early"], "`exit_early` wasn't properly reset to False"
|
||||
assert mock_add_frame.call_count == 1, "`add_frame` should have been called 1 time"
|
||||
assert len(dataset) == 1, "`dataset` should contain only 1 frame"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"robot_type, mock, num_image_writer_processes", [("koch", True, 0), ("koch", True, 1)]
|
||||
)
|
||||
@require_robot
|
||||
def test_record_with_event_stop_recording(tmpdir, request, robot_type, mock, num_image_writer_processes):
|
||||
if mock:
|
||||
request.getfixturevalue("patch_builtins_input")
|
||||
|
||||
# Create an empty calibration directory to trigger manual calibration
|
||||
# and avoid writing calibration files in user .cache/calibration folder
|
||||
calibration_dir = tmpdir / robot_type
|
||||
overrides = [f"calibration_dir={calibration_dir}"]
|
||||
else:
|
||||
# Use the default .cache/calibration folder when mock=False or for aloha
|
||||
overrides = []
|
||||
|
||||
robot = make_robot(robot_type, overrides=overrides, mock=mock)
|
||||
with (
|
||||
patch("lerobot.scripts.control_robot.init_keyboard_listener") as mock_listener,
|
||||
patch("lerobot.common.robot_devices.control_utils.add_frame", wraps=add_frame) as mock_add_frame,
|
||||
):
|
||||
mock_events = {}
|
||||
mock_events["exit_early"] = True
|
||||
mock_events["rerecord_episode"] = False
|
||||
mock_events["stop_recording"] = True
|
||||
mock_listener.return_value = (None, mock_events)
|
||||
|
||||
root = Path(tmpdir) / "data"
|
||||
repo_id = "lerobot/debug"
|
||||
|
||||
dataset = record(
|
||||
robot,
|
||||
root,
|
||||
repo_id,
|
||||
fps=1,
|
||||
warmup_time_s=0,
|
||||
episode_time_s=1,
|
||||
num_episodes=2,
|
||||
push_to_hub=False,
|
||||
video=False,
|
||||
display_cameras=False,
|
||||
play_sounds=False,
|
||||
run_compute_stats=False,
|
||||
num_image_writer_processes=num_image_writer_processes,
|
||||
)
|
||||
|
||||
assert not mock_events["exit_early"], "`exit_early` wasn't properly reset to False"
|
||||
assert mock_add_frame.call_count == 1, "`add_frame` should have been called 1 time"
|
||||
assert len(dataset) == 1, "`dataset` should contain only 1 frame"
|
||||
|
|
|
@ -308,12 +308,11 @@ def test_flatten_unflatten_dict():
|
|||
# "lerobot/cmu_stretch",
|
||||
],
|
||||
)
|
||||
# TODO(rcadene, aliberts): all these tests fail locally on Mac M1, but not on Linux
|
||||
def test_backward_compatibility(repo_id):
|
||||
"""The artifacts for this test have been generated by `tests/scripts/save_dataset_to_safetensors.py`."""
|
||||
|
||||
dataset = LeRobotDataset(
|
||||
repo_id,
|
||||
)
|
||||
dataset = LeRobotDataset(repo_id)
|
||||
|
||||
test_dir = Path("tests/data/save_dataset_to_safetensors") / repo_id
|
||||
|
||||
|
|
|
@ -367,8 +367,7 @@ def test_normalize(insert_temporal_dim):
|
|||
),
|
||||
("aloha", "act", ["policy.n_action_steps=10"], ""),
|
||||
("aloha", "act", ["policy.n_action_steps=1000", "policy.chunk_size=1000"], "_1000_steps"),
|
||||
("dora_aloha_real", "act_real", ["policy.n_action_steps=10"], ""),
|
||||
("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"], ""),
|
||||
("dora_aloha_real", "act_aloha_real", ["policy.n_action_steps=10"], ""),
|
||||
],
|
||||
)
|
||||
# As artifacts have been generated on an x86_64 kernel, this test won't
|
||||
|
|
|
@ -127,6 +127,7 @@ def test_robot(tmpdir, request, robot_type, mock):
|
|||
# TODO(rcadene): skipping image for now as it's challenging to assess equality between two consecutive frames
|
||||
continue
|
||||
assert torch.allclose(captured_observation[name], observation[name], atol=1)
|
||||
assert captured_observation[name].shape == observation[name].shape
|
||||
|
||||
# Test send_action can run
|
||||
robot.send_action(action["action"])
|
||||
|
|
Loading…
Reference in New Issue