Refactor `record` with `add_frame` (#468)
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
parent 97b1feb0b3
commit 77478d50e5

@@ -47,6 +47,7 @@ jobs:
           pipx install poetry && poetry config virtualenvs.in-project true
           echo "${{ github.workspace }}/.venv/bin" >> $GITHUB_PATH

+      # TODO(rcadene, aliberts): python 3.12 seems to be used in the tests, not python 3.10
       - name: Set up Python 3.10
         uses: actions/setup-python@v5
         with:

@@ -84,6 +85,7 @@ jobs:
           pipx install poetry && poetry config virtualenvs.in-project true
           echo "${{ github.workspace }}/.venv/bin" >> $GITHUB_PATH

+      # TODO(rcadene, aliberts): python 3.12 seems to be used in the tests, not python 3.10
       - name: Set up Python 3.10
         uses: actions/setup-python@v5
         with:

lerobot/common/datasets/populate_dataset.py (new file)
@@ -0,0 +1,468 @@
"""Functions to create an empty dataset, and populate it with frames."""

# TODO(rcadene, aliberts): to adapt as class methods of next version of LeRobotDataset

import concurrent.futures
import json
import logging
import multiprocessing
import shutil
from pathlib import Path

import torch
import tqdm
from PIL import Image

from lerobot.common.datasets.compute_stats import compute_stats
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import to_hf_dataset
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, get_default_encoding
from lerobot.common.datasets.utils import calculate_episode_data_index, create_branch
from lerobot.common.datasets.video_utils import encode_video_frames
from lerobot.common.utils.utils import log_say
from lerobot.scripts.push_dataset_to_hub import (
    push_dataset_card_to_hub,
    push_meta_data_to_hub,
    push_videos_to_hub,
    save_meta_data,
)


########################################################################################
# Asynchronous saving of images on disk
########################################################################################


def safe_stop_image_writer(func):
    # TODO(aliberts): Allow to pass custom exceptions
    # (e.g. ThreadServiceExit, KeyboardInterrupt, SystemExit, UnpluggedError, DynamixelCommError)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            image_writer = kwargs.get("dataset", {}).get("image_writer")
            if image_writer is not None:
                print("Waiting for image writer to terminate...")
                stop_image_writer(image_writer, timeout=20)
            raise e

    return wrapper


def save_image(img_tensor, key, frame_index, episode_index, videos_dir: str):
    img = Image.fromarray(img_tensor.numpy())
    path = Path(videos_dir) / f"{key}_episode_{episode_index:06d}" / f"frame_{frame_index:06d}.png"
    path.parent.mkdir(parents=True, exist_ok=True)
    img.save(str(path), quality=100)


def loop_to_save_images_in_threads(image_queue, num_threads):
    if num_threads < 1:
        raise NotImplementedError(f"Only `num_threads>=1` is supported for now, but {num_threads=} given.")

    with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor:
        futures = []
        while True:
            # Blocks until a frame is available
            frame_data = image_queue.get()

            # As usually done, exit loop when receiving None to stop the worker
            if frame_data is None:
                break

            image, key, frame_index, episode_index, videos_dir = frame_data
            futures.append(executor.submit(save_image, image, key, frame_index, episode_index, videos_dir))

        # Before exiting function, wait for all threads to complete
        with tqdm.tqdm(total=len(futures), desc="Writing images") as progress_bar:
            concurrent.futures.wait(futures)
            progress_bar.update(len(futures))


def start_image_writer_processes(image_queue, num_processes, num_threads_per_process):
    if num_processes < 1:
        raise ValueError(f"Only `num_processes>=1` is supported, but {num_processes=} given.")

    if num_threads_per_process < 1:
        raise NotImplementedError(
            f"Only `num_threads_per_process>=1` is supported for now, but {num_threads_per_process=} given."
        )

    processes = []
    for _ in range(num_processes):
        process = multiprocessing.Process(
            target=loop_to_save_images_in_threads,
            args=(image_queue, num_threads_per_process),
        )
        process.start()
        processes.append(process)
    return processes


def stop_processes(processes, queue, timeout):
    # Send None to each process to signal them to stop
    for _ in processes:
        queue.put(None)

    # Wait up to `timeout` seconds for all processes to terminate
    for process in processes:
        process.join(timeout=timeout)

        # If a process is still alive after the timeout, force termination
        if process.is_alive():
            process.terminate()

    # Close the queue, no more items can be put in the queue
    queue.close()

    # Ensure all background queue threads have finished
    queue.join_thread()


def start_image_writer(num_processes, num_threads):
    """This function abstracts away the initialization of processes and/or threads used to
    save images on disk asynchronously, which is critical to control a robot and record data
    at a high frame rate.

    When `num_processes=0`, it returns a dictionary containing a threads pool of size `num_threads`.
    When `num_processes>0`, it returns a dictionary containing a processes pool of size `num_processes`,
    where each subprocess starts its own threads pool of size `num_threads`.

    The optimal number of processes and threads depends on your computer capabilities.
    We advise using 4 threads per camera with 0 processes. If the fps is not stable, try increasing or
    lowering the number of threads. If it is still not stable, try using 1 subprocess, or more.
    """
    image_writer = {}

    if num_processes == 0:
        futures = []
        threads_pool = concurrent.futures.ThreadPoolExecutor(max_workers=num_threads)
        image_writer["threads_pool"], image_writer["futures"] = threads_pool, futures
    else:
        # TODO(rcadene): When using num_processes>1, `multiprocessing.Manager().Queue()`
        # might be better than `multiprocessing.Queue()`. Source: https://www.geeksforgeeks.org/python-multiprocessing-queue-vs-multiprocessing-manager-queue
        image_queue = multiprocessing.Queue()
        processes_pool = start_image_writer_processes(
            image_queue, num_processes=num_processes, num_threads_per_process=num_threads
        )
        image_writer["processes_pool"], image_writer["image_queue"] = processes_pool, image_queue

    return image_writer


def async_save_image(image_writer, image, key, frame_index, episode_index, videos_dir):
    """This function abstracts away the asynchronous saving of an image on disk. It uses a dictionary
    called image writer which contains either a pool of processes or a pool of threads.
    """
    if "threads_pool" in image_writer:
        threads_pool, futures = image_writer["threads_pool"], image_writer["futures"]
        futures.append(threads_pool.submit(save_image, image, key, frame_index, episode_index, videos_dir))
    else:
        image_queue = image_writer["image_queue"]
        image_queue.put((image, key, frame_index, episode_index, videos_dir))


def stop_image_writer(image_writer, timeout):
    if "threads_pool" in image_writer:
        futures = image_writer["futures"]
        # Before exiting function, wait for all threads to complete
        with tqdm.tqdm(total=len(futures), desc="Writing images") as progress_bar:
            concurrent.futures.wait(futures, timeout=timeout)
            progress_bar.update(len(futures))
    else:
        processes_pool, image_queue = image_writer["processes_pool"], image_writer["image_queue"]
        stop_processes(processes_pool, image_queue, timeout=timeout)
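

# [Illustrative sketch, not part of the diff] A minimal lifecycle of the image-writer API
# above, assuming an HWC uint8 image tensor and a scratch directory (both hypothetical):
#
#     image_writer = start_image_writer(num_processes=0, num_threads=4)  # threads-only mode
#     img = torch.randint(0, 256, (480, 640, 3), dtype=torch.uint8)  # hypothetical camera frame
#     async_save_image(
#         image_writer,
#         image=img,
#         key="observation.images.laptop",  # hypothetical camera key
#         frame_index=0,
#         episode_index=0,
#         videos_dir="/tmp/videos",
#     )
#     stop_image_writer(image_writer, timeout=20)  # wait for pending writes before exiting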


########################################################################################
# Functions to initialize, resume and populate a dataset
########################################################################################


def init_dataset(
    repo_id,
    root,
    force_override,
    fps,
    video,
    write_images,
    num_image_writer_processes,
    num_image_writer_threads,
):
    local_dir = Path(root) / repo_id
    if local_dir.exists() and force_override:
        shutil.rmtree(local_dir)

    episodes_dir = local_dir / "episodes"
    episodes_dir.mkdir(parents=True, exist_ok=True)

    videos_dir = local_dir / "videos"
    videos_dir.mkdir(parents=True, exist_ok=True)

    # Logic to resume data recording
    rec_info_path = episodes_dir / "data_recording_info.json"
    if rec_info_path.exists():
        with open(rec_info_path) as f:
            rec_info = json.load(f)
        num_episodes = rec_info["last_episode_index"] + 1
    else:
        num_episodes = 0

    dataset = {
        "repo_id": repo_id,
        "local_dir": local_dir,
        "videos_dir": videos_dir,
        "episodes_dir": episodes_dir,
        "fps": fps,
        "video": video,
        "rec_info_path": rec_info_path,
        "num_episodes": num_episodes,
    }

    if write_images:
        # Initialize processes or/and threads dedicated to save images on disk asynchronously,
        # which is critical to control a robot and record data at a high frame rate.
        image_writer = start_image_writer(
            num_processes=num_image_writer_processes,
            num_threads=num_image_writer_threads,
        )
        dataset["image_writer"] = image_writer

    return dataset
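

# [Illustrative sketch, not part of the diff] How `init_dataset` might be called; the
# repo_id, root and writer sizes are hypothetical values:
#
#     dataset = init_dataset(
#         repo_id="myuser/my_task",  # hypothetical hub id
#         root="data",
#         force_override=False,
#         fps=30,
#         video=True,
#         write_images=True,
#         num_image_writer_processes=0,
#         num_image_writer_threads=4,
#     )
#     # Resuming: if data/myuser/my_task/episodes/data_recording_info.json exists,
#     # dataset["num_episodes"] starts from last_episode_index + 1 instead of 0.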


def add_frame(dataset, observation, action):
    if "current_episode" not in dataset:
        # initialize episode dictionary
        ep_dict = {}
        for key in observation:
            if key not in ep_dict:
                ep_dict[key] = []
        for key in action:
            if key not in ep_dict:
                ep_dict[key] = []

        ep_dict["episode_index"] = []
        ep_dict["frame_index"] = []
        ep_dict["timestamp"] = []
        ep_dict["next.done"] = []

        dataset["current_episode"] = ep_dict
        dataset["current_frame_index"] = 0

    ep_dict = dataset["current_episode"]
    episode_index = dataset["num_episodes"]
    frame_index = dataset["current_frame_index"]
    videos_dir = dataset["videos_dir"]
    video = dataset["video"]
    fps = dataset["fps"]

    ep_dict["episode_index"].append(episode_index)
    ep_dict["frame_index"].append(frame_index)
    ep_dict["timestamp"].append(frame_index / fps)
    ep_dict["next.done"].append(False)

    img_keys = [key for key in observation if "image" in key]
    non_img_keys = [key for key in observation if "image" not in key]

    # Save all observed modalities except images
    for key in non_img_keys:
        ep_dict[key].append(observation[key])

    # Save actions
    for key in action:
        ep_dict[key].append(action[key])

    if "image_writer" not in dataset:
        dataset["current_frame_index"] += 1
        return

    # Save images
    image_writer = dataset["image_writer"]
    for key in img_keys:
        imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}"
        async_save_image(
            image_writer,
            image=observation[key],
            key=key,
            frame_index=frame_index,
            episode_index=episode_index,
            videos_dir=str(videos_dir),
        )

        if video:
            fname = f"{key}_episode_{episode_index:06d}.mp4"
            frame_info = {"path": f"videos/{fname}", "timestamp": frame_index / fps}
        else:
            frame_info = str(imgs_dir / f"frame_{frame_index:06d}.png")

        ep_dict[key].append(frame_info)

    dataset["current_frame_index"] += 1
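

# [Illustrative sketch, not part of the diff] One recording step, assuming observation and
# action dicts shaped like those returned by `robot.teleop_step(record_data=True)`:
#
#     observation = {
#         "observation.state": torch.zeros(6),  # hypothetical proprioception vector
#         "observation.images.laptop": torch.randint(0, 256, (480, 640, 3), dtype=torch.uint8),
#     }
#     action = {"action": torch.zeros(6)}
#     add_frame(dataset, observation, action)  # appends scalars, queues image writes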


def delete_current_episode(dataset):
    del dataset["current_episode"]
    del dataset["current_frame_index"]

    # delete temporary images
    episode_index = dataset["num_episodes"]
    videos_dir = dataset["videos_dir"]
    for tmp_imgs_dir in videos_dir.glob(f"*_episode_{episode_index:06d}"):
        shutil.rmtree(tmp_imgs_dir)


def save_current_episode(dataset):
    episode_index = dataset["num_episodes"]
    ep_dict = dataset["current_episode"]
    episodes_dir = dataset["episodes_dir"]
    rec_info_path = dataset["rec_info_path"]

    ep_dict["next.done"][-1] = True

    for key in ep_dict:
        if "observation" in key and "image" not in key:
            ep_dict[key] = torch.stack(ep_dict[key])

    ep_dict["action"] = torch.stack(ep_dict["action"])
    ep_dict["episode_index"] = torch.tensor(ep_dict["episode_index"])
    ep_dict["frame_index"] = torch.tensor(ep_dict["frame_index"])
    ep_dict["timestamp"] = torch.tensor(ep_dict["timestamp"])
    ep_dict["next.done"] = torch.tensor(ep_dict["next.done"])

    ep_path = episodes_dir / f"episode_{episode_index}.pth"
    torch.save(ep_dict, ep_path)

    rec_info = {
        "last_episode_index": episode_index,
    }
    with open(rec_info_path, "w") as f:
        json.dump(rec_info, f)

    # force re-initialization of episode dictionary during add_frame
    del dataset["current_episode"]

    dataset["num_episodes"] += 1


def encode_videos(dataset, image_keys, play_sounds):
    log_say("Encoding videos", play_sounds)

    num_episodes = dataset["num_episodes"]
    videos_dir = dataset["videos_dir"]
    local_dir = dataset["local_dir"]
    fps = dataset["fps"]

    # Use ffmpeg to convert frames stored as png into mp4 videos
    for episode_index in tqdm.tqdm(range(num_episodes)):
        for key in image_keys:
            # key = f"observation.images.{name}"
            tmp_imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}"
            fname = f"{key}_episode_{episode_index:06d}.mp4"
            video_path = local_dir / "videos" / fname
            if video_path.exists():
                # Skip if video is already encoded. Could be the case when resuming data recording.
                continue
            # note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
            # since video encoding with ffmpeg is already using multithreading.
            encode_video_frames(tmp_imgs_dir, video_path, fps, overwrite=True)
            shutil.rmtree(tmp_imgs_dir)


def from_dataset_to_lerobot_dataset(dataset, play_sounds):
    log_say("Consolidate episodes", play_sounds)

    num_episodes = dataset["num_episodes"]
    episodes_dir = dataset["episodes_dir"]
    videos_dir = dataset["videos_dir"]
    video = dataset["video"]
    fps = dataset["fps"]
    repo_id = dataset["repo_id"]

    ep_dicts = []
    for episode_index in tqdm.tqdm(range(num_episodes)):
        ep_path = episodes_dir / f"episode_{episode_index}.pth"
        ep_dict = torch.load(ep_path)
        ep_dicts.append(ep_dict)
    data_dict = concatenate_episodes(ep_dicts)

    if video:
        image_keys = [key for key in data_dict if "image" in key]
        encode_videos(dataset, image_keys, play_sounds)

    hf_dataset = to_hf_dataset(data_dict, video)
    episode_data_index = calculate_episode_data_index(hf_dataset)

    info = {
        "codebase_version": CODEBASE_VERSION,
        "fps": fps,
        "video": video,
    }
    if video:
        info["encoding"] = get_default_encoding()

    lerobot_dataset = LeRobotDataset.from_preloaded(
        repo_id=repo_id,
        hf_dataset=hf_dataset,
        episode_data_index=episode_data_index,
        info=info,
        videos_dir=videos_dir,
    )

    return lerobot_dataset


def save_lerobot_dataset_on_disk(lerobot_dataset):
    hf_dataset = lerobot_dataset.hf_dataset
    info = lerobot_dataset.info
    stats = lerobot_dataset.stats
    episode_data_index = lerobot_dataset.episode_data_index
    local_dir = lerobot_dataset.videos_dir.parent
    meta_data_dir = local_dir / "meta_data"

    hf_dataset = hf_dataset.with_format(None)  # to remove transforms that can't be saved
    hf_dataset.save_to_disk(str(local_dir / "train"))

    save_meta_data(info, stats, episode_data_index, meta_data_dir)


def push_lerobot_dataset_to_hub(lerobot_dataset, tags):
    hf_dataset = lerobot_dataset.hf_dataset
    local_dir = lerobot_dataset.videos_dir.parent
    videos_dir = lerobot_dataset.videos_dir
    repo_id = lerobot_dataset.repo_id
    video = lerobot_dataset.video
    meta_data_dir = local_dir / "meta_data"

    if not (local_dir / "train").exists():
        raise ValueError(
            "You need to run `save_lerobot_dataset_on_disk(lerobot_dataset)` before pushing to the hub."
        )

    hf_dataset.push_to_hub(repo_id, revision="main")
    push_meta_data_to_hub(repo_id, meta_data_dir, revision="main")
    push_dataset_card_to_hub(repo_id, revision="main", tags=tags)
    if video:
        push_videos_to_hub(repo_id, videos_dir, revision="main")
    create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION)


def create_lerobot_dataset(dataset, run_compute_stats, push_to_hub, tags, play_sounds):
    if "image_writer" in dataset:
        logging.info("Waiting for image writer to terminate...")
        image_writer = dataset["image_writer"]
        stop_image_writer(image_writer, timeout=20)

    lerobot_dataset = from_dataset_to_lerobot_dataset(dataset, play_sounds)

    if run_compute_stats:
        log_say("Computing dataset statistics", play_sounds)
        lerobot_dataset.stats = compute_stats(lerobot_dataset)
    else:
        logging.info("Skipping computation of the dataset statistics")
        lerobot_dataset.stats = {}

    save_lerobot_dataset_on_disk(lerobot_dataset)

    if push_to_hub:
        push_lerobot_dataset_to_hub(lerobot_dataset, tags)

    return lerobot_dataset
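

# [Illustrative sketch, not part of the diff] End-to-end lifecycle of the helpers above,
# with hypothetical arguments (see init_dataset and add_frame sketches earlier):
#
#     dataset = init_dataset("myuser/my_task", "data", False, 30, True, True, 0, 4)
#     for _ in range(300):  # one 10 s episode at 30 fps
#         add_frame(dataset, observation, action)
#     save_current_episode(dataset)  # or delete_current_episode(dataset) to discard it
#     lerobot_dataset = create_lerobot_dataset(
#         dataset, run_compute_stats=True, push_to_hub=False, tags=None, play_sounds=False
#     )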

@@ -189,7 +189,7 @@ class Logger:
             training_state["scheduler"] = scheduler.state_dict()
         torch.save(training_state, save_dir / self.training_state_file_name)

-    def save_checkpont(
+    def save_checkpoint(
         self,
         train_step: int,
         policy: Policy,

lerobot/common/robot_devices/control_utils.py (new file)
@@ -0,0 +1,330 @@
########################################################################################
# Utilities
########################################################################################


import logging
import time
import traceback
from contextlib import nullcontext
from copy import copy
from functools import cache

import cv2
import torch
import tqdm
from termcolor import colored

from lerobot.common.datasets.populate_dataset import add_frame, safe_stop_image_writer
from lerobot.common.policies.factory import make_policy
from lerobot.common.robot_devices.robots.utils import Robot
from lerobot.common.robot_devices.utils import busy_wait
from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, set_global_seed
from lerobot.scripts.eval import get_pretrained_policy_path


def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, fps=None):
    log_items = []
    if episode_index is not None:
        log_items.append(f"ep:{episode_index}")
    if frame_index is not None:
        log_items.append(f"frame:{frame_index}")

    def log_dt(shortname, dt_val_s):
        nonlocal log_items, fps
        info_str = f"{shortname}:{dt_val_s * 1000:5.2f} ({1 / dt_val_s:3.1f}hz)"
        if fps is not None:
            actual_fps = 1 / dt_val_s
            if actual_fps < fps - 1:
                info_str = colored(info_str, "yellow")
        log_items.append(info_str)

    # total step time displayed in milliseconds and its frequency
    log_dt("dt", dt_s)

    # TODO(aliberts): move robot-specific logs logic in robot.print_logs()
    if not robot.robot_type.startswith("stretch"):
        for name in robot.leader_arms:
            key = f"read_leader_{name}_pos_dt_s"
            if key in robot.logs:
                log_dt("dtRlead", robot.logs[key])

        for name in robot.follower_arms:
            key = f"write_follower_{name}_goal_pos_dt_s"
            if key in robot.logs:
                log_dt("dtWfoll", robot.logs[key])

            key = f"read_follower_{name}_pos_dt_s"
            if key in robot.logs:
                log_dt("dtRfoll", robot.logs[key])

        for name in robot.cameras:
            key = f"read_camera_{name}_dt_s"
            if key in robot.logs:
                log_dt(f"dtR{name}", robot.logs[key])

    info_str = " ".join(log_items)
    logging.info(info_str)


@cache
def is_headless():
    """Detects if python is running without a monitor."""
    try:
        import pynput  # noqa

        return False
    except Exception:
        print(
            "Error trying to import pynput. Switching to headless mode. "
            "As a result, the video stream from the cameras won't be shown, "
            "and you won't be able to change the control flow with keyboards. "
            "For more info, see traceback below.\n"
        )
        traceback.print_exc()
        print()
        return True


def has_method(_object: object, method_name: str):
    return hasattr(_object, method_name) and callable(getattr(_object, method_name))


def predict_action(observation, policy, device, use_amp):
    observation = copy(observation)
    with (
        torch.inference_mode(),
        torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(),
    ):
        # Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
        for name in observation:
            if "image" in name:
                observation[name] = observation[name].type(torch.float32) / 255
                observation[name] = observation[name].permute(2, 0, 1).contiguous()
            observation[name] = observation[name].unsqueeze(0)
            observation[name] = observation[name].to(device)

        # Compute the next action with the policy
        # based on the current observation
        action = policy.select_action(observation)

        # Remove batch dimension
        action = action.squeeze(0)

        # Move to cpu, if not already the case
        action = action.to("cpu")

    return action
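

# [Illustrative sketch, not part of the diff] `predict_action` consumes a raw observation
# dict (HWC uint8 images, float state) and returns a CPU action tensor:
#
#     observation = robot.capture_observation()
#     action = predict_action(observation, policy, device=torch.device("cuda"), use_amp=False)
#     robot.send_action(action)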


def init_keyboard_listener():
    # Allow to exit early while recording an episode or resetting the environment,
    # by tapping the right arrow key '->'. This might require a sudo permission
    # to allow your terminal to monitor keyboard events.
    events = {}
    events["exit_early"] = False
    events["rerecord_episode"] = False
    events["stop_recording"] = False

    if is_headless():
        logging.warning(
            "Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
        )
        listener = None
        return listener, events

    # Only import pynput if not in a headless environment
    from pynput import keyboard

    def on_press(key):
        try:
            if key == keyboard.Key.right:
                print("Right arrow key pressed. Exiting loop...")
                events["exit_early"] = True
            elif key == keyboard.Key.left:
                print("Left arrow key pressed. Exiting loop and re-recording the last episode...")
                events["rerecord_episode"] = True
                events["exit_early"] = True
            elif key == keyboard.Key.esc:
                print("Escape key pressed. Stopping data recording...")
                events["stop_recording"] = True
                events["exit_early"] = True
        except Exception as e:
            print(f"Error handling key press: {e}")

    listener = keyboard.Listener(on_press=on_press)
    listener.start()

    return listener, events
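

# [Illustrative sketch, not part of the diff] Polling the events dict that the listener
# mutates from its background thread:
#
#     listener, events = init_keyboard_listener()
#     while not events["stop_recording"]:
#         if events["exit_early"]:  # set by the right arrow key
#             events["exit_early"] = False
#             break
#     if listener is not None:
#         listener.stop()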


def init_policy(pretrained_policy_name_or_path, policy_overrides):
    """Instantiate the policy and load fps, device and use_amp from config yaml"""
    pretrained_policy_path = get_pretrained_policy_path(pretrained_policy_name_or_path)
    hydra_cfg = init_hydra_config(pretrained_policy_path / "config.yaml", policy_overrides)
    policy = make_policy(hydra_cfg=hydra_cfg, pretrained_policy_name_or_path=pretrained_policy_path)

    # Check device is available
    device = get_safe_torch_device(hydra_cfg.device, log=True)
    use_amp = hydra_cfg.use_amp
    policy_fps = hydra_cfg.env.fps

    policy.eval()
    policy.to(device)

    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
    set_global_seed(hydra_cfg.seed)
    return policy, policy_fps, device, use_amp
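

# [Illustrative sketch, not part of the diff] Loading a pretrained policy; the hub id and
# the hydra override string are hypothetical placeholders:
#
#     policy, policy_fps, device, use_amp = init_policy(
#         "myuser/act_my_task",  # hub repo or local folder containing config.yaml
#         policy_overrides=["device=cuda"],
#     )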


def warmup_record(
    robot,
    events,
    enable_teleoperation,
    warmup_time_s,
    display_cameras,
    fps,
):
    control_loop(
        robot=robot,
        control_time_s=warmup_time_s,
        display_cameras=display_cameras,
        events=events,
        fps=fps,
        teleoperate=enable_teleoperation,
    )


def record_episode(
    robot,
    dataset,
    events,
    episode_time_s,
    display_cameras,
    policy,
    device,
    use_amp,
    fps,
):
    control_loop(
        robot=robot,
        control_time_s=episode_time_s,
        display_cameras=display_cameras,
        dataset=dataset,
        events=events,
        policy=policy,
        device=device,
        use_amp=use_amp,
        fps=fps,
        teleoperate=policy is None,
    )


@safe_stop_image_writer
def control_loop(
    robot,
    control_time_s=None,
    teleoperate=False,
    display_cameras=False,
    dataset=None,
    events=None,
    policy=None,
    device=None,
    use_amp=None,
    fps=None,
):
    # TODO(rcadene): Add option to record logs
    if not robot.is_connected:
        robot.connect()

    if events is None:
        events = {"exit_early": False}

    if control_time_s is None:
        control_time_s = float("inf")

    if teleoperate and policy is not None:
        raise ValueError("When `teleoperate` is True, `policy` should be None.")

    if dataset is not None and fps is not None and dataset["fps"] != fps:
        raise ValueError(f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps}).")

    timestamp = 0
    start_episode_t = time.perf_counter()
    while timestamp < control_time_s:
        start_loop_t = time.perf_counter()

        if teleoperate:
            observation, action = robot.teleop_step(record_data=True)
        else:
            observation = robot.capture_observation()

            if policy is not None:
                pred_action = predict_action(observation, policy, device, use_amp)
                # Action can eventually be clipped using `max_relative_target`,
                # so the action actually sent is saved in the dataset.
                action = robot.send_action(pred_action)
                action = {"action": action}

        if dataset is not None:
            add_frame(dataset, observation, action)

        if display_cameras and not is_headless():
            image_keys = [key for key in observation if "image" in key]
            for key in image_keys:
                cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
            cv2.waitKey(1)

        if fps is not None:
            dt_s = time.perf_counter() - start_loop_t
            busy_wait(1 / fps - dt_s)

        dt_s = time.perf_counter() - start_loop_t
        log_control_info(robot, dt_s, fps=fps)

        timestamp = time.perf_counter() - start_episode_t
        if events["exit_early"]:
            events["exit_early"] = False
            break
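

# [Illustrative sketch, not part of the diff] The two main ways `control_loop` is driven:
# pure teleoperation, or policy rollout while recording into a dataset:
#
#     control_loop(robot, control_time_s=60, teleoperate=True, fps=30)  # teleop only
#     control_loop(
#         robot,
#         control_time_s=60,
#         dataset=dataset,  # as returned by init_dataset
#         events=events,    # as returned by init_keyboard_listener
#         policy=policy,
#         device=device,
#         use_amp=use_amp,
#         fps=30,
#     )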


def reset_environment(robot, events, reset_time_s):
    # TODO(rcadene): refactor warmup_record and reset_environment
    # TODO(aliberts): allow for teleop during reset
    if has_method(robot, "teleop_safety_stop"):
        robot.teleop_safety_stop()

    timestamp = 0
    start_reset_t = time.perf_counter()

    # Wait if necessary
    with tqdm.tqdm(total=reset_time_s, desc="Waiting") as pbar:
        while timestamp < reset_time_s:
            time.sleep(1)
            timestamp = time.perf_counter() - start_reset_t
            pbar.update(1)
            if events["exit_early"]:
                events["exit_early"] = False
                break


def stop_recording(robot, listener, display_cameras):
    robot.disconnect()

    if not is_headless():
        if listener is not None:
            listener.stop()

        if display_cameras:
            cv2.destroyAllWindows()


def sanity_check_dataset_name(repo_id, policy):
    _, dataset_name = repo_id.split("/")
    # Either repo_id doesn't start with "eval_" and there is no policy,
    # or repo_id starts with "eval_" and there is a policy.
    if dataset_name.startswith("eval_") == (policy is None):
        raise ValueError(
            f"Your dataset name ({dataset_name}) must start with 'eval_' if and only if "
            f"a policy is provided (policy={policy})."
        )

@@ -349,6 +349,25 @@ class ManipulatorRobot:
         self.is_connected = False
         self.logs = {}

+    @property
+    def has_camera(self):
+        return len(self.cameras) > 0
+
+    @property
+    def num_cameras(self):
+        return len(self.cameras)
+
+    @property
+    def available_arms(self):
+        available_arms = []
+        for name in self.follower_arms:
+            arm_id = get_arm_id(name, "follower")
+            available_arms.append(arm_id)
+        for name in self.leader_arms:
+            arm_id = get_arm_id(name, "leader")
+            available_arms.append(arm_id)
+        return available_arms
+
     def connect(self):
         if self.is_connected:
             raise RobotDeviceAlreadyConnectedError(

lerobot/common/utils/utils.py
@@ -16,6 +16,7 @@
 import logging
 import os
 import os.path as osp
+import platform
 import random
 from contextlib import contextmanager
 from datetime import datetime, timezone

@@ -28,6 +29,12 @@ import torch
 from omegaconf import DictConfig


+def none_or_int(value):
+    if value == "None":
+        return None
+    return int(value)
+
+
 def inside_slurm():
     """Check whether the python process was launched through slurm"""
     # TODO(rcadene): return False for interactive mode `--pty bash`
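
# [Illustrative sketch, not part of the diff] `none_or_int` is intended as an argparse
# type so CLI flags can accept either an integer or the literal string "None":
#
#     parser.add_argument("--seed", type=none_or_int, default=None)
#     # `--seed None` -> None ; `--seed 42` -> 42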

@@ -183,3 +190,30 @@ def print_cuda_memory_usage():

 def capture_timestamp_utc():
     return datetime.now(timezone.utc)
+
+
+def say(text, blocking=False):
+    # Check if mac, linux, or windows.
+    if platform.system() == "Darwin":
+        cmd = f'say "{text}"'
+    elif platform.system() == "Linux":
+        cmd = f'spd-say "{text}"'
+    elif platform.system() == "Windows":
+        cmd = (
+            'PowerShell -Command "Add-Type -AssemblyName System.Speech; '
+            f"(New-Object System.Speech.Synthesis.SpeechSynthesizer).Speak('{text}')\""
+        )
+
+    if not blocking and platform.system() in ["Darwin", "Linux"]:
+        # TODO(rcadene): Make it work for Windows
+        # Use the ampersand to run command in the background
+        cmd += " &"
+
+    os.system(cmd)
+
+
+def log_say(text, play_sounds, blocking=False):
+    logging.info(text)
+
+    if play_sounds:
+        say(text, blocking)
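
# [Illustrative sketch, not part of the diff] `log_say` both logs and, optionally,
# vocalizes a message; `blocking=True` waits for the speech command to finish
# (macOS/Linux):
#
#     log_say("Recording episode 3", play_sounds=True)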

lerobot/scripts/control_robot.py
@@ -99,285 +99,35 @@ python lerobot/scripts/control_robot.py record \
 """

 import argparse
-import concurrent.futures
-import json
 import logging
-import multiprocessing
-import os
-import platform
-import shutil
 import time
-import traceback
-from contextlib import nullcontext
-from functools import cache
 from pathlib import Path
+from typing import List

-import cv2
-import torch
-import tqdm
-from omegaconf import DictConfig
-from PIL import Image
-from termcolor import colored

 # from safetensors.torch import load_file, save_file
-from lerobot.common.datasets.compute_stats import compute_stats
-from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
-from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import to_hf_dataset
-from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, get_default_encoding
-from lerobot.common.datasets.utils import calculate_episode_data_index, create_branch
-from lerobot.common.datasets.video_utils import encode_video_frames
-from lerobot.common.policies.factory import make_policy
-from lerobot.common.robot_devices.robots.factory import make_robot
-from lerobot.common.robot_devices.robots.utils import Robot, get_arm_id
-from lerobot.common.robot_devices.utils import busy_wait, safe_disconnect
-from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed
-from lerobot.scripts.eval import get_pretrained_policy_path
-from lerobot.scripts.push_dataset_to_hub import (
-    push_dataset_card_to_hub,
-    push_meta_data_to_hub,
-    push_videos_to_hub,
-    save_meta_data,
-)
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.common.datasets.populate_dataset import (
+    create_lerobot_dataset,
+    delete_current_episode,
+    init_dataset,
+    save_current_episode,
+)
+from lerobot.common.robot_devices.control_utils import (
+    control_loop,
+    has_method,
+    init_keyboard_listener,
+    init_policy,
+    log_control_info,
+    record_episode,
+    reset_environment,
+    sanity_check_dataset_name,
+    stop_recording,
+    warmup_record,
+)
+from lerobot.common.robot_devices.robots.factory import make_robot
+from lerobot.common.robot_devices.robots.utils import Robot
+from lerobot.common.robot_devices.utils import busy_wait, safe_disconnect
+from lerobot.common.utils.utils import init_hydra_config, init_logging, log_say, none_or_int

-
-########################################################################################
-# Utilities
-########################################################################################
-
-
-def say(text, blocking=False):
-    # Check if mac, linux, or windows.
-    if platform.system() == "Darwin":
-        cmd = f'say "{text}"'
-    elif platform.system() == "Linux":
-        cmd = f'spd-say "{text}"'
-    elif platform.system() == "Windows":
-        cmd = (
-            'PowerShell -Command "Add-Type -AssemblyName System.Speech; '
-            f"(New-Object System.Speech.Synthesis.SpeechSynthesizer).Speak('{text}')\""
-        )
-
-    if not blocking and platform.system() in ["Darwin", "Linux"]:
-        # TODO(rcadene): Make it work for Windows
-        # Use the ampersand to run command in the background
-        cmd += " &"
-
-    os.system(cmd)
-
-
-def save_image(img_tensor, key, frame_index, episode_index, videos_dir: str):
-    img = Image.fromarray(img_tensor.numpy())
-    path = Path(videos_dir) / f"{key}_episode_{episode_index:06d}" / f"frame_{frame_index:06d}.png"
-    path.parent.mkdir(parents=True, exist_ok=True)
-    img.save(str(path), quality=100)
-
-
-def none_or_int(value):
-    if value == "None":
-        return None
-    return int(value)
-
-
-def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, fps=None):
-    log_items = []
-    if episode_index is not None:
-        log_items.append(f"ep:{episode_index}")
-    if frame_index is not None:
-        log_items.append(f"frame:{frame_index}")
-
-    def log_dt(shortname, dt_val_s):
-        nonlocal log_items, fps
-        info_str = f"{shortname}:{dt_val_s * 1000:5.2f} ({1/ dt_val_s:3.1f}hz)"
-        if fps is not None:
-            actual_fps = 1 / dt_val_s
-            if actual_fps < fps - 1:
-                info_str = colored(info_str, "yellow")
-        log_items.append(info_str)
-
-    # total step time displayed in milliseconds and its frequency
-    log_dt("dt", dt_s)
-
-    # TODO(aliberts): move robot-specific logs logic in robot.print_logs()
-    if not robot.robot_type.startswith("stretch"):
-        for name in robot.leader_arms:
-            key = f"read_leader_{name}_pos_dt_s"
-            if key in robot.logs:
-                log_dt("dtRlead", robot.logs[key])
-
-        for name in robot.follower_arms:
-            key = f"write_follower_{name}_goal_pos_dt_s"
-            if key in robot.logs:
-                log_dt("dtWfoll", robot.logs[key])
-
-            key = f"read_follower_{name}_pos_dt_s"
-            if key in robot.logs:
-                log_dt("dtRfoll", robot.logs[key])
-
-        for name in robot.cameras:
-            key = f"read_camera_{name}_dt_s"
-            if key in robot.logs:
-                log_dt(f"dtR{name}", robot.logs[key])
-
-    info_str = " ".join(log_items)
-    logging.info(info_str)
-
-
-@cache
-def is_headless():
-    """Detects if python is running without a monitor."""
-    try:
-        import pynput  # noqa
-
-        return False
-    except Exception:
-        print(
-            "Error trying to import pynput. Switching to headless mode. "
-            "As a result, the video stream from the cameras won't be shown, "
-            "and you won't be able to change the control flow with keyboards. "
-            "For more info, see traceback below.\n"
-        )
-        traceback.print_exc()
-        print()
-        return True
-
-
-def has_method(_object: object, method_name: str):
-    return hasattr(_object, method_name) and callable(getattr(_object, method_name))
-
-
-def get_available_arms(robot):
-    # TODO(rcadene): moves this function in manipulator class?
-    available_arms = []
-    for name in robot.follower_arms:
-        arm_id = get_arm_id(name, "follower")
-        available_arms.append(arm_id)
-    for name in robot.leader_arms:
-        arm_id = get_arm_id(name, "leader")
-        available_arms.append(arm_id)
-    return available_arms
-
-
-########################################################################################
-# Asynchrounous saving of images on disk
-########################################################################################
-
-
-def loop_to_save_images_in_threads(image_queue, num_threads):
-    if num_threads < 1:
-        raise NotImplementedError(f"Only `num_threads>=1` is supported for now, but {num_threads=} given.")
-
-    with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor:
-        futures = []
-        while True:
-            # Blocks until a frame is available
-            frame_data = image_queue.get()
-
-            # As usually done, exit loop when receiving None to stop the worker
-            if frame_data is None:
-                break
-
-            image, key, frame_index, episode_index, videos_dir = frame_data
-            futures.append(executor.submit(save_image, image, key, frame_index, episode_index, videos_dir))
-
-        # Before exiting function, wait for all threads to complete
-        with tqdm.tqdm(total=len(futures), desc="Writing images") as progress_bar:
-            concurrent.futures.wait(futures)
-            progress_bar.update(len(futures))
-
-
-def start_image_writer_processes(image_queue, num_processes, num_threads_per_process):
-    if num_processes < 1:
-        raise ValueError(f"Only `num_processes>=1` is supported, but {num_processes=} given.")
-
-    if num_threads_per_process < 1:
-        raise NotImplementedError(
-            "Only `num_threads_per_process>=1` is supported for now, but {num_threads_per_process=} given."
-        )
-
-    processes = []
-    for _ in range(num_processes):
-        process = multiprocessing.Process(
-            target=loop_to_save_images_in_threads,
-            args=(image_queue, num_threads_per_process),
-        )
-        process.start()
-        processes.append(process)
-    return processes
-
-
-def stop_processes(processes, queue, timeout):
-    # Send None to each process to signal them to stop
-    for _ in processes:
-        queue.put(None)
-
-    # Close the queue, no more items can be put in the queue
-    queue.close()
-
-    # Wait maximum 20 seconds for all processes to terminate
-    for process in processes:
-        process.join(timeout=timeout)
-
-        # If not terminated after 20 seconds, force termination
-        if process.is_alive():
-            process.terminate()
-
-    # Ensure all background queue threads have finished
-    queue.join_thread()
-
-
-def start_image_writer(num_processes, num_threads):
-    """This function abstract away the initialisation of processes or/and threads to
-    save images on disk asynchrounously, which is critical to control a robot and record data
-    at a high frame rate.
-
-    When `num_processes=0`, it returns a dictionary containing a threads pool of size `num_threads`.
-    When `num_processes>0`, it returns a dictionary containing a processes pool of size `num_processes`,
-    where each subprocess starts their own threads pool of size `num_threads`.
-
-    The optimal number of processes and threads depends on your computer capabilities.
-    We advise to use 4 threads per camera with 0 processes. If the fps is not stable, try to increase or lower
-    the number of threads. If it is still not stable, try to use 1 subprocess, or more.
-    """
-    image_writer = {}
-
-    if num_processes == 0:
-        futures = []
-        threads_pool = concurrent.futures.ThreadPoolExecutor(max_workers=num_threads)
-        image_writer["threads_pool"], image_writer["futures"] = threads_pool, futures
-    else:
-        # TODO(rcadene): When using num_processes>1, `multiprocessing.Manager().Queue()`
-        # might be better than `multiprocessing.Queue()`. Source: https://www.geeksforgeeks.org/python-multiprocessing-queue-vs-multiprocessing-manager-queue
-        image_queue = multiprocessing.Queue()
-        processes_pool = start_image_writer_processes(
-            image_queue, num_processes=num_processes, num_threads_per_process=num_threads
-        )
-        image_writer["processes_pool"], image_writer["image_queue"] = processes_pool, image_queue
-
-    return image_writer
-
-
-def async_save_image(image_writer, image, key, frame_index, episode_index, videos_dir):
-    """This function abstract away the saving of an image on disk asynchrounously. It uses a dictionary
-    called image writer which contains either a pool of processes or a pool of threads.
-    """
-    if "threads_pool" in image_writer:
-        threads_pool, futures = image_writer["threads_pool"], image_writer["futures"]
-        futures.append(threads_pool.submit(save_image, image, key, frame_index, episode_index, videos_dir))
-    else:
-        image_queue = image_writer["image_queue"]
-        image_queue.put((image, key, frame_index, episode_index, videos_dir))
-
-
-def stop_image_writer(image_writer, timeout):
-    if "threads_pool" in image_writer:
-        futures = image_writer["futures"]
-        # Before exiting function, wait for all threads to complete
-        with tqdm.tqdm(total=len(futures), desc="Writing images") as progress_bar:
-            concurrent.futures.wait(futures, timeout=timeout)
-            progress_bar.update(len(futures))
-    else:
-        processes_pool, image_queue = image_writer["processes_pool"], image_writer["image_queue"]
-        stop_processes(processes_pool, image_queue, timeout=timeout)
-
-
 ########################################################################################
 # Control modes

@@ -394,9 +144,8 @@ def calibrate(robot: Robot, arms: list[str] | None):
         robot.home()
         return

-    available_arms = get_available_arms(robot)
-    unknown_arms = [arm_id for arm_id in arms if arm_id not in available_arms]
-    available_arms_str = " ".join(available_arms)
+    unknown_arms = [arm_id for arm_id in arms if arm_id not in robot.available_arms]
+    available_arms_str = " ".join(robot.available_arms)
     unknown_arms_str = " ".join(unknown_arms)

     if arms is None or len(arms) == 0:

@@ -429,35 +178,26 @@ def calibrate(robot: Robot, arms: list[str] | None):


 @safe_disconnect
-def teleoperate(robot: Robot, fps: int | None = None, teleop_time_s: float | None = None):
-    # TODO(rcadene): Add option to record logs
-    if not robot.is_connected:
-        robot.connect()
-
-    start_teleop_t = time.perf_counter()
-    while True:
-        start_loop_t = time.perf_counter()
-        robot.teleop_step()
-
-        if fps is not None:
-            dt_s = time.perf_counter() - start_loop_t
-            busy_wait(1 / fps - dt_s)
-
-        dt_s = time.perf_counter() - start_loop_t
-        log_control_info(robot, dt_s, fps=fps)
-
-        if teleop_time_s is not None and time.perf_counter() - start_teleop_t > teleop_time_s:
-            break
+def teleoperate(
+    robot: Robot, fps: int | None = None, teleop_time_s: float | None = None, display_cameras: bool = False
+):
+    control_loop(
+        robot,
+        control_time_s=teleop_time_s,
+        fps=fps,
+        teleoperate=True,
+        display_cameras=display_cameras,
+    )


 @safe_disconnect
 def record(
     robot: Robot,
-    policy: torch.nn.Module | None = None,
-    hydra_cfg: DictConfig | None = None,
+    root: str,
+    repo_id: str,
+    pretrained_policy_name_or_path: str | None = None,
+    policy_overrides: List[str] | None = None,
     fps: int | None = None,
-    root="data",
-    repo_id="lerobot/debug",
     warmup_time_s=2,
     episode_time_s=10,
     reset_time_s=5,
@ -473,407 +213,108 @@ def record(
|
||||||
play_sounds=True,
|
play_sounds=True,
|
||||||
):
|
):
|
||||||
# TODO(rcadene): Add option to record logs
|
# TODO(rcadene): Add option to record logs
|
||||||
# TODO(rcadene): Clean this function via decomposition in higher level functions
|
listener = None
|
||||||
|
events = None
|
||||||
|
policy = None
|
||||||
|
device = None
|
||||||
|
use_amp = None
|
||||||
|
|
||||||
_, dataset_name = repo_id.split("/")
|
# Load pretrained policy
|
||||||
if dataset_name.startswith("eval_") and policy is None:
|
if pretrained_policy_name_or_path is not None:
|
||||||
raise ValueError(
|
policy, policy_fps, device, use_amp = init_policy(pretrained_policy_name_or_path, policy_overrides)
|
||||||
f"Your dataset name begins by 'eval_' ({dataset_name}) but no policy is provided ({policy})."
|
|
||||||
)
|
if fps is None:
|
||||||
|
fps = policy_fps
|
||||||
|
logging.warning(f"No fps provided, so using the fps from policy config ({policy_fps}).")
|
||||||
|
elif fps != policy_fps:
|
||||||
|
logging.warning(
|
||||||
|
f"There is a mismatch between the provided fps ({fps}) and the one from policy config ({policy_fps})."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create empty dataset or load existing saved episodes
|
||||||
|
sanity_check_dataset_name(repo_id, policy)
|
||||||
|
dataset = init_dataset(
|
||||||
|
repo_id,
|
||||||
|
root,
|
||||||
|
force_override,
|
||||||
|
fps,
|
||||||
|
video,
|
||||||
|
write_images=robot.has_camera,
|
||||||
|
num_image_writer_processes=num_image_writer_processes,
|
||||||
|
num_image_writer_threads=num_image_writer_threads_per_camera * robot.num_cameras,
|
||||||
|
)
|
||||||

     if not robot.is_connected:
         robot.connect()

-    local_dir = Path(root) / repo_id
-    if local_dir.exists() and force_override:
-        shutil.rmtree(local_dir)
-
-    episodes_dir = local_dir / "episodes"
-    episodes_dir.mkdir(parents=True, exist_ok=True)
-
-    videos_dir = local_dir / "videos"
-    videos_dir.mkdir(parents=True, exist_ok=True)
-
-    # Logic to resume data recording
-    rec_info_path = episodes_dir / "data_recording_info.json"
-    if rec_info_path.exists():
-        with open(rec_info_path) as f:
-            rec_info = json.load(f)
-        episode_index = rec_info["last_episode_index"] + 1
-    else:
-        episode_index = 0
-
-    if is_headless():
-        logging.warning(
-            "Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
-        )
-
-    # Allow to exit early while recording an episode or resetting the environment,
-    # by tapping the right arrow key '->'. This might require a sudo permission
-    # to allow your terminal to monitor keyboard events.
-    exit_early = False
-    rerecord_episode = False
-    stop_recording = False
-
-    # Only import pynput if not in a headless environment
-    if not is_headless():
-        from pynput import keyboard
-
-        def on_press(key):
-            nonlocal exit_early, rerecord_episode, stop_recording
-            try:
-                if key == keyboard.Key.right:
-                    print("Right arrow key pressed. Exiting loop...")
-                    exit_early = True
-                elif key == keyboard.Key.left:
-                    print("Left arrow key pressed. Exiting loop and rerecord the last episode...")
-                    rerecord_episode = True
-                    exit_early = True
-                elif key == keyboard.Key.esc:
-                    print("Escape key pressed. Stopping data recording...")
-                    stop_recording = True
-                    exit_early = True
-            except Exception as e:
-                print(f"Error handling key press: {e}")
-
-        listener = keyboard.Listener(on_press=on_press)
-        listener.start()
-
+    listener, events = init_keyboard_listener()
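Review note: the inline pynput handler deleted above presumably moved into `init_keyboard_listener()`, which the new code calls right after `robot.connect()`. The three event keys are confirmed by the tests at the end of this diff; the body below is otherwise an assumed sketch, with the old `nonlocal` flags replaced by the mutable `events` dict the new code reads:

    def init_keyboard_listener():
        # Assumed sketch; the real helper presumably lives in
        # lerobot.common.robot_devices.control_utils.
        events = {"exit_early": False, "rerecord_episode": False, "stop_recording": False}
        if is_headless():
            return None, events  # no keyboard monitoring without a display

        from pynput import keyboard

        def on_press(key):
            if key == keyboard.Key.right:
                events["exit_early"] = True
            elif key == keyboard.Key.left:
                events["rerecord_episode"] = True
                events["exit_early"] = True
            elif key == keyboard.Key.esc:
                events["stop_recording"] = True
                events["exit_early"] = True

        listener = keyboard.Listener(on_press=on_press)
        listener.start()
        return listener, events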
-    # Load policy if any
-    if policy is not None:
-        # Check device is available
-        device = get_safe_torch_device(hydra_cfg.device, log=True)
-
-        policy.eval()
-        policy.to(device)
-
-        torch.backends.cudnn.benchmark = True
-        torch.backends.cuda.matmul.allow_tf32 = True
-        set_global_seed(hydra_cfg.seed)
-
-        # override fps using policy fps
-        fps = hydra_cfg.env.fps
-
-    # Execute a few seconds without recording data, to give times
-    # to the robot devices to connect and start synchronizing.
-    timestamp = 0
-    start_warmup_t = time.perf_counter()
-    is_warmup_print = False
-    while timestamp < warmup_time_s:
-        if not is_warmup_print:
-            logging.info("Warming up (no data recording)")
-            if play_sounds:
-                say("Warming up")
-            is_warmup_print = True
-
-        start_loop_t = time.perf_counter()
-
-        if policy is None:
-            observation, action = robot.teleop_step(record_data=True)
-        else:
-            observation = robot.capture_observation()
-
-        if display_cameras and not is_headless():
-            image_keys = [key for key in observation if "image" in key]
-            for key in image_keys:
-                cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
-            cv2.waitKey(1)
-
-        dt_s = time.perf_counter() - start_loop_t
-        busy_wait(1 / fps - dt_s)
-
-        dt_s = time.perf_counter() - start_loop_t
-        log_control_info(robot, dt_s, fps=fps)
-
-        timestamp = time.perf_counter() - start_warmup_t
+    # Execute a few seconds without recording to:
+    # 1. teleoperate the robot to move it in starting position if no policy provided,
+    # 2. give times to the robot devices to connect and start synchronizing,
+    # 3. place the cameras windows on screen
+    enable_teleoperation = policy is None
+    log_say("Warmup record", play_sounds)
+    warmup_record(robot, events, enable_teleoperation, warmup_time_s, display_cameras, fps)
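The warmup loop deleted above is the generic fixed-rate pattern that `warmup_record` now encapsulates: do one step of work, then wait out the rest of the 1/fps budget so the loop period stays constant. `busy_wait` is the project's existing helper; its body here is an assumption, and `run_warmup_sketch` is purely illustrative:

    import time

    def busy_wait(seconds: float) -> None:
        # Assumed body: spin rather than sleep, since time.sleep() can overshoot
        # by milliseconds, which matters when holding a steady 30 Hz control loop.
        end = time.perf_counter() + seconds
        while time.perf_counter() < end:
            pass

    def run_warmup_sketch(step, warmup_time_s: float, fps: int) -> None:
        start = time.perf_counter()
        while time.perf_counter() - start < warmup_time_s:
            t0 = time.perf_counter()
            step()  # teleop_step() or capture_observation(), as in the deleted loop
            busy_wait(1 / fps - (time.perf_counter() - t0))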

     if has_method(robot, "teleop_safety_stop"):
         robot.teleop_safety_stop()

-    has_camera = len(robot.cameras) > 0
-    if has_camera:
-        # Initialize processes or/and threads dedicated to save images on disk asynchronously,
-        # which is critical to control a robot and record data at a high frame rate.
-        image_writer = start_image_writer(
-            num_processes=num_image_writer_processes,
-            num_threads=num_image_writer_threads_per_camera * len(robot.cameras),
-        )
+    while True:
+        if dataset["num_episodes"] >= num_episodes:
+            break
+
+        episode_index = dataset["num_episodes"]
+        log_say(f"Recording episode {episode_index}", play_sounds)
+        record_episode(
+            dataset=dataset,
+            robot=robot,
+            events=events,
+            episode_time_s=episode_time_s,
+            display_cameras=display_cameras,
+            policy=policy,
+            device=device,
+            use_amp=use_amp,
+            fps=fps,
+        )

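This is the change the PR title names: instead of hand-building `ep_dict` inside `record` (deleted below), the per-step loop now lives in a helper that pushes every observation/action pair into the dataset through `add_frame`. The tests patch `lerobot.common.robot_devices.control_utils.add_frame`, which pins down its module and rough signature; the loop body below is an assumed sketch, not the committed implementation:

    def record_episode_sketch(robot, dataset, events, episode_time_s, fps, policy=None):
        # Assumed shape of the loop wrapped by record_episode(); add_frame() is
        # the call the tests at the bottom of this diff patch and count.
        import time
        start = time.perf_counter()
        while time.perf_counter() - start < episode_time_s:
            t0 = time.perf_counter()
            if policy is None:
                observation, action = robot.teleop_step(record_data=True)
            else:
                observation = robot.capture_observation()
                # tensor preprocessing and device transfers elided; see the deleted code
                action = {"action": policy.select_action(observation)}
            add_frame(dataset, observation, action)
            busy_wait(1 / fps - (time.perf_counter() - t0))
            if events["exit_early"]:
                events["exit_early"] = False
                break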
-    # Using `try` to exist smoothly if an exception is raised
-    try:
-        # Start recording all episodes
-        while episode_index < num_episodes:
-            logging.info(f"Recording episode {episode_index}")
-            if play_sounds:
-                say(f"Recording episode {episode_index}")
-            ep_dict = {}
-            frame_index = 0
-            timestamp = 0
-            start_episode_t = time.perf_counter()
-            while timestamp < episode_time_s:
-                start_loop_t = time.perf_counter()
-
-                if policy is None:
-                    observation, action = robot.teleop_step(record_data=True)
-                else:
-                    observation = robot.capture_observation()
-
-                image_keys = [key for key in observation if "image" in key]
-                not_image_keys = [key for key in observation if "image" not in key]
-
-                if has_camera > 0:
-                    for key in image_keys:
-                        async_save_image(
-                            image_writer,
-                            image=observation[key],
-                            key=key,
-                            frame_index=frame_index,
-                            episode_index=episode_index,
-                            videos_dir=str(videos_dir),
-                        )
-
-                if display_cameras and not is_headless():
-                    image_keys = [key for key in observation if "image" in key]
-                    for key in image_keys:
-                        cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
-                    cv2.waitKey(1)
-
-                for key in not_image_keys:
-                    if key not in ep_dict:
-                        ep_dict[key] = []
-                    ep_dict[key].append(observation[key])
-
-                if policy is not None:
-                    with (
-                        torch.inference_mode(),
-                        torch.autocast(device_type=device.type)
-                        if device.type == "cuda" and hydra_cfg.use_amp
-                        else nullcontext(),
-                    ):
-                        # Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
-                        for name in observation:
-                            if "image" in name:
-                                observation[name] = observation[name].type(torch.float32) / 255
-                                observation[name] = observation[name].permute(2, 0, 1).contiguous()
-                            observation[name] = observation[name].unsqueeze(0)
-                            observation[name] = observation[name].to(device)
-
-                        # Compute the next action with the policy
-                        # based on the current observation
-                        action = policy.select_action(observation)
-
-                        # Remove batch dimension
-                        action = action.squeeze(0)
-
-                        # Move to cpu, if not already the case
-                        action = action.to("cpu")
-
-                    # Order the robot to move
-                    action_sent = robot.send_action(action)
-
-                    # Action can eventually be clipped using `max_relative_target`,
-                    # so action actually sent is saved in the dataset.
-                    action = {"action": action_sent}
-
-                for key in action:
-                    if key not in ep_dict:
-                        ep_dict[key] = []
-                    ep_dict[key].append(action[key])
-
-                frame_index += 1
-
-                dt_s = time.perf_counter() - start_loop_t
-                busy_wait(1 / fps - dt_s)
-
-                dt_s = time.perf_counter() - start_loop_t
-                log_control_info(robot, dt_s, fps=fps)
-
-                timestamp = time.perf_counter() - start_episode_t
-                if exit_early:
-                    exit_early = False
-                    break
-
-            # TODO(alibets): allow for teleop during reset
-            if has_method(robot, "teleop_safety_stop"):
-                robot.teleop_safety_stop()
-
-            if not stop_recording:
-                # Start resetting env while the executor are finishing
-                logging.info("Reset the environment")
-                if play_sounds:
-                    say("Reset the environment")
-
-            timestamp = 0
-            start_vencod_t = time.perf_counter()
-
-            # During env reset we save the data and encode the videos
-            num_frames = frame_index
-
-            for key in image_keys:
-                if video:
-                    tmp_imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}"
-                    fname = f"{key}_episode_{episode_index:06d}.mp4"
-                    video_path = local_dir / "videos" / fname
-                    if video_path.exists():
-                        video_path.unlink()
-                    # Store the reference to the video frame, even tho the videos are not yet encoded
-                    ep_dict[key] = []
-                    for i in range(num_frames):
-                        ep_dict[key].append({"path": f"videos/{fname}", "timestamp": i / fps})
-
-                else:
-                    imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}"
-                    ep_dict[key] = []
-                    for i in range(num_frames):
-                        img_path = imgs_dir / f"frame_{i:06d}.png"
-                        ep_dict[key].append({"path": str(img_path)})
-
-            for key in not_image_keys:
-                ep_dict[key] = torch.stack(ep_dict[key])
-
-            for key in action:
-                ep_dict[key] = torch.stack(ep_dict[key])
-
-            ep_dict["episode_index"] = torch.tensor([episode_index] * num_frames)
-            ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
-            ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
-
-            done = torch.zeros(num_frames, dtype=torch.bool)
-            done[-1] = True
-            ep_dict["next.done"] = done
-
-            ep_path = episodes_dir / f"episode_{episode_index}.pth"
-            print("Saving episode dictionary...")
-            torch.save(ep_dict, ep_path)
-
-            rec_info = {
-                "last_episode_index": episode_index,
-            }
-            with open(rec_info_path, "w") as f:
-                json.dump(rec_info, f)
-
-            is_last_episode = stop_recording or (episode_index == (num_episodes - 1))
-
-            # Wait if necessary
-            with tqdm.tqdm(total=reset_time_s, desc="Waiting") as pbar:
-                while timestamp < reset_time_s and not is_last_episode:
-                    time.sleep(1)
-                    timestamp = time.perf_counter() - start_vencod_t
-                    pbar.update(1)
-                    if exit_early:
-                        exit_early = False
-                        break
-
-            # Skip updating episode index which forces re-recording episode
-            if rerecord_episode:
-                rerecord_episode = False
-                continue
-
-            episode_index += 1
-
-            if is_last_episode:
-                logging.info("Done recording")
-                if play_sounds:
-                    say("Done recording", blocking=True)
-                if not is_headless():
-                    listener.stop()
-
-                if has_camera > 0:
-                    logging.info("Waiting for image writer to terminate...")
-                    stop_image_writer(image_writer, timeout=20)
-
-    except Exception as e:
-        if has_camera > 0:
-            logging.info("Waiting for image writer to terminate...")
-            stop_image_writer(image_writer, timeout=20)
-        raise e
-
-    robot.disconnect()
-
-    if display_cameras and not is_headless():
-        cv2.destroyAllWindows()
-
-    num_episodes = episode_index
-
-    if video:
-        logging.info("Encoding videos")
-        if play_sounds:
-            say("Encoding videos")
-        # Use ffmpeg to convert frames stored as png into mp4 videos
-        for episode_index in tqdm.tqdm(range(num_episodes)):
-            for key in image_keys:
-                tmp_imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}"
-                fname = f"{key}_episode_{episode_index:06d}.mp4"
-                video_path = local_dir / "videos" / fname
-                if video_path.exists():
-                    # Skip if video is already encoded. Could be the case when resuming data recording.
-                    continue
-                # note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
-                # since video encoding with ffmpeg is already using multithreading.
-                encode_video_frames(tmp_imgs_dir, video_path, fps, overwrite=True)
-                shutil.rmtree(tmp_imgs_dir)
-
-    logging.info("Concatenating episodes")
-    ep_dicts = []
-    for episode_index in tqdm.tqdm(range(num_episodes)):
-        ep_path = episodes_dir / f"episode_{episode_index}.pth"
-        ep_dict = torch.load(ep_path)
-        ep_dicts.append(ep_dict)
-    data_dict = concatenate_episodes(ep_dicts)
-
-    total_frames = data_dict["frame_index"].shape[0]
-    data_dict["index"] = torch.arange(0, total_frames, 1)
-
-    hf_dataset = to_hf_dataset(data_dict, video)
-    episode_data_index = calculate_episode_data_index(hf_dataset)
-    info = {
-        "codebase_version": CODEBASE_VERSION,
-        "fps": fps,
-        "video": video,
-    }
-    if video:
-        info["encoding"] = get_default_encoding()
-
-    lerobot_dataset = LeRobotDataset.from_preloaded(
-        repo_id=repo_id,
-        hf_dataset=hf_dataset,
-        episode_data_index=episode_data_index,
-        info=info,
-        videos_dir=videos_dir,
-    )
-    if run_compute_stats:
-        logging.info("Computing dataset statistics")
-        if play_sounds:
-            say("Computing dataset statistics")
-        stats = compute_stats(lerobot_dataset)
-        lerobot_dataset.stats = stats
-    else:
-        stats = {}
-        logging.info("Skipping computation of the dataset statistics")
-
-    hf_dataset = hf_dataset.with_format(None) # to remove transforms that cant be saved
-    hf_dataset.save_to_disk(str(local_dir / "train"))
-
-    meta_data_dir = local_dir / "meta_data"
-    save_meta_data(info, stats, episode_data_index, meta_data_dir)
-
-    if push_to_hub:
-        hf_dataset.push_to_hub(repo_id, revision="main")
-        push_meta_data_to_hub(repo_id, meta_data_dir, revision="main")
-        push_dataset_card_to_hub(repo_id, revision="main", tags=tags)
-        if video:
-            push_videos_to_hub(repo_id, videos_dir, revision="main")
-        create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION)
-
-    logging.info("Exiting")
-    if play_sounds:
-        say("Exiting")
+        # Execute a few seconds without recording to give time to manually reset the environment
+        # Current code logic doesn't allow to teleoperate during this time.
+        # TODO(rcadene): add an option to enable teleoperation during reset
+        # Skip reset for the last episode to be recorded
+        if not events["stop_recording"] and (
+            (episode_index < num_episodes - 1) or events["rerecord_episode"]
+        ):
+            log_say("Reset the environment", play_sounds)
+            reset_environment(robot, events, reset_time_s)
+
+        if events["rerecord_episode"]:
+            log_say("Re-record episode", play_sounds)
+            events["rerecord_episode"] = False
+            events["exit_early"] = False
+            delete_current_episode(dataset)
+            continue
+
+        # Increment by one dataset["current_episode_index"]
+        save_current_episode(dataset)
+
+        if events["stop_recording"]:
+            break
+
+    log_say("Stop recording", play_sounds, blocking=True)
+    stop_recording(robot, listener, display_cameras)
+
+    lerobot_dataset = create_lerobot_dataset(dataset, run_compute_stats, push_to_hub, tags, play_sounds)
+
+    log_say("Exiting", play_sounds)
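Two small helpers now carry the episode lifecycle that the deleted code managed with `episode_index` and `rec_info`: `save_current_episode` commits the in-progress frames and advances the episode counter, `delete_current_episode` drops them so the episode can be re-recorded. Hypothetical semantics, with assumed key names:

    def save_current_episode_sketch(dataset: dict) -> None:
        # Assumed: freeze the in-progress buffer as the next episode.
        dataset["episodes"].append(dataset["current_episode"])
        dataset["num_episodes"] += 1
        dataset["current_episode"] = {}

    def delete_current_episode_sketch(dataset: dict) -> None:
        # Assumed: discard the in-progress buffer for a clean re-record.
        dataset["current_episode"] = {}

The long tail deleted above (video encoding, episode concatenation, stats, saving to disk, and the hub pushes) presumably collapses into the single `create_lerobot_dataset(dataset, run_compute_stats, push_to_hub, tags, play_sounds)` call.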

     return lerobot_dataset


+@safe_disconnect
 def replay(
     robot: Robot, episode: int, fps: int | None = None, root="data", repo_id="lerobot/debug", play_sounds=True
 ):
+    # TODO(rcadene, aliberts): refactor with control_loop, once `dataset` is an instance of LeRobotDataset
     # TODO(rcadene): Add option to record logs
     local_dir = Path(root) / repo_id
     if not local_dir.exists():

@@ -887,9 +328,7 @@ def replay(
     if not robot.is_connected:
         robot.connect()

-    logging.info("Replaying episode")
-    if play_sounds:
-        say("Replaying episode", blocking=True)
+    log_say("Replaying episode", play_sounds, blocking=True)
     for idx in range(from_idx, to_idx):
         start_episode_t = time.perf_counter()

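Review note: throughout this commit the `logging.info(...)` plus optional `say(...)` pair becomes a single `log_say` call. Assuming the obvious body (this sketch is not the committed source):

    import logging

    def log_say(text: str, play_sounds: bool, blocking: bool = False) -> None:
        # say() is the project's text-to-speech helper.
        logging.info(text)
        if play_sounds:
            say(text, blocking)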
@@ -934,6 +373,12 @@ if __name__ == "__main__":
     parser_teleop.add_argument(
         "--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
     )
+    parser_teleop.add_argument(
+        "--display-cameras",
+        type=int,
+        default=1,
+        help="Display all cameras on screen (set to 1 to display or 0).",
+    )

     parser_record = subparsers.add_parser("record", parents=[base_parser])
     parser_record.add_argument(

@@ -1071,19 +516,7 @@ if __name__ == "__main__":
         teleoperate(robot, **kwargs)

     elif control_mode == "record":
-        pretrained_policy_name_or_path = args.pretrained_policy_name_or_path
-        policy_overrides = args.policy_overrides
-        del kwargs["pretrained_policy_name_or_path"]
-        del kwargs["policy_overrides"]
-
-        policy_cfg = None
-        if pretrained_policy_name_or_path is not None:
-            pretrained_policy_path = get_pretrained_policy_path(pretrained_policy_name_or_path)
-            policy_cfg = init_hydra_config(pretrained_policy_path / "config.yaml", policy_overrides)
-            policy = make_policy(hydra_cfg=policy_cfg, pretrained_policy_name_or_path=pretrained_policy_path)
-            record(robot, policy, policy_cfg, **kwargs)
-        else:
-            record(robot, **kwargs)
+        record(robot, **kwargs)

     elif control_mode == "replay":
         replay(robot, **kwargs)

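With the dispatch reduced to `record(robot, **kwargs)`, policy loading is driven entirely by arguments and happens inside `record` via `init_policy`. A hypothetical call mirroring the new signature (the argument names are the ones visible in this diff; the values, including the checkpoint path, are illustrative only):

    # Teleoperated data collection (no policy):
    dataset = record(robot, root="data", repo_id="lerobot/debug", fps=30, num_episodes=50)

    # Evaluation rollout: record() loads the policy itself.
    dataset = record(
        robot,
        root="data",
        repo_id="lerobot/eval_debug",
        pretrained_policy_name_or_path="outputs/train/checkpoints/last/pretrained_model",
    )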
@@ -383,7 +383,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
             logging.info(f"Checkpoint policy after step {step}")
             # Note: Save with step as the identifier, and format it to have at least 6 digits but more if
             # needed (choose 6 as a minimum for consistency without being overkill).
-            logger.save_checkpont(
+            logger.save_checkpoint(
                 step,
                 policy,
                 optimizer,

@@ -250,7 +250,7 @@
             if(!canPlayVideos){
               this.videoCodecError = true;
             }

             // process CSV data
             this.videos = document.querySelectorAll('video');
             this.video = this.videos[0];

@@ -25,12 +25,16 @@ pytest -sx 'tests/test_control_robot.py::test_teleoperate[aloha-True]'

 import multiprocessing
 from pathlib import Path
+from unittest.mock import patch

 import pytest

+from lerobot.common.datasets.populate_dataset import add_frame, init_dataset
+from lerobot.common.logger import Logger
 from lerobot.common.policies.factory import make_policy
 from lerobot.common.utils.utils import init_hydra_config
-from lerobot.scripts.control_robot import calibrate, get_available_arms, record, replay, teleoperate
+from lerobot.scripts.control_robot import calibrate, record, replay, teleoperate
+from lerobot.scripts.train import make_optimizer_and_scheduler
 from tests.test_robots import make_robot
 from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, TEST_ROBOT_TYPES, require_robot

@@ -69,7 +73,7 @@ def test_calibrate(tmpdir, request, robot_type, mock):
     overrides_calibration_dir = [f"calibration_dir={calibration_dir}"]

     robot = make_robot(robot_type, overrides=overrides_calibration_dir, mock=mock)
-    calibrate(robot, arms=get_available_arms(robot))
+    calibrate(robot, arms=robot.available_arms)
     del robot

@@ -109,12 +113,14 @@ def test_record_without_cameras(tmpdir, request, robot_type, mock):
 @pytest.mark.parametrize("robot_type, mock", TEST_ROBOT_TYPES)
 @require_robot
 def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
+    tmpdir = Path(tmpdir)

     if mock and robot_type != "aloha":
         request.getfixturevalue("patch_builtins_input")

         # Create an empty calibration directory to trigger manual calibration
         # and avoid writing calibration files in user .cache/calibration folder
-        calibration_dir = Path(tmpdir) / robot_type
+        calibration_dir = tmpdir / robot_type
         overrides = [f"calibration_dir={calibration_dir}"]
     else:
         # Use the default .cache/calibration folder when mock=False or for aloha

@@ -123,17 +129,19 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
     env_name = "koch_real"
     policy_name = "act_koch_real"

-    root = Path(tmpdir) / "data"
+    root = tmpdir / "data"
     repo_id = "lerobot/debug"
+    eval_repo_id = "lerobot/eval_debug"

     robot = make_robot(robot_type, overrides=overrides, mock=mock)
     dataset = record(
         robot,
-        fps=30,
-        root=root,
-        repo_id=repo_id,
+        root,
+        repo_id,
+        fps=1,
         warmup_time_s=1,
         episode_time_s=1,
+        reset_time_s=1,
         num_episodes=2,
         push_to_hub=False,
         # TODO(rcadene, aliberts): test video=True

@@ -142,8 +150,10 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
         display_cameras=False,
         play_sounds=False,
     )
+    assert dataset.num_episodes == 2
+    assert len(dataset) == 2

-    replay(robot, episode=0, fps=30, root=root, repo_id=repo_id, play_sounds=False)
+    replay(robot, episode=0, fps=1, root=root, repo_id=repo_id, play_sounds=False)

     # TODO(rcadene, aliberts): rethink this design
     if robot_type == "aloha":

@@ -164,12 +174,26 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
     if robot_type == "koch_bimanual":
         overrides += ["env.state_dim=12", "env.action_dim=12"]

+    overrides += ["wandb.enable=false"]
+    overrides += ["env.fps=1"]
+
     cfg = init_hydra_config(
         DEFAULT_CONFIG_PATH,
         overrides=overrides,
     )

     policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)
+    optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
+    out_dir = tmpdir / "logger"
+    logger = Logger(cfg, out_dir, wandb_job_name="debug")
+    logger.save_checkpoint(
+        0,
+        policy,
+        optimizer,
+        lr_scheduler,
+        identifier=0,
+    )
+    pretrained_policy_name_or_path = out_dir / "checkpoints/last/pretrained_model"

     # In `examples/9_use_aloha.md`, we advise using `num_image_writer_processes=1`
     # during inference, to reach constent fps, so we test this here.

@@ -194,10 +218,12 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):

     record(
         robot,
-        policy,
-        cfg,
+        root,
+        eval_repo_id,
+        pretrained_policy_name_or_path,
         warmup_time_s=1,
         episode_time_s=1,
+        reset_time_s=1,
         num_episodes=2,
         run_compute_stats=False,
         push_to_hub=False,

@@ -207,4 +233,218 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
         num_image_writer_processes=num_image_writer_processes,
     )
+    assert dataset.num_episodes == 2
+    assert len(dataset) == 2

     del robot


+@pytest.mark.parametrize("robot_type, mock", [("koch", True)])
+@require_robot
+def test_resume_record(tmpdir, request, robot_type, mock):
+    if mock and robot_type != "aloha":
+        request.getfixturevalue("patch_builtins_input")
+
+        # Create an empty calibration directory to trigger manual calibration
+        # and avoid writing calibration files in user .cache/calibration folder
+        calibration_dir = tmpdir / robot_type
+        overrides = [f"calibration_dir={calibration_dir}"]
+    else:
+        # Use the default .cache/calibration folder when mock=False or for aloha
+        overrides = []
+
+    robot = make_robot(robot_type, overrides=overrides, mock=mock)
+
+    root = Path(tmpdir) / "data"
+    repo_id = "lerobot/debug"
+
+    dataset = record(
+        robot,
+        root,
+        repo_id,
+        fps=1,
+        warmup_time_s=0,
+        episode_time_s=1,
+        num_episodes=1,
+        push_to_hub=False,
+        video=False,
+        display_cameras=False,
+        play_sounds=False,
+        run_compute_stats=False,
+    )
+    assert len(dataset) == 1, "`dataset` should contain only 1 frame"
+
+    init_dataset_return_value = {}
+
+    def wrapped_init_dataset(*args, **kwargs):
+        nonlocal init_dataset_return_value
+        init_dataset_return_value = init_dataset(*args, **kwargs)
+        return init_dataset_return_value
+
+    with patch("lerobot.scripts.control_robot.init_dataset", wraps=wrapped_init_dataset):
+        dataset = record(
+            robot,
+            root,
+            repo_id,
+            fps=1,
+            warmup_time_s=0,
+            episode_time_s=1,
+            num_episodes=2,
+            push_to_hub=False,
+            video=False,
+            display_cameras=False,
+            play_sounds=False,
+            run_compute_stats=False,
+        )
+        assert len(dataset) == 2, "`dataset` should contain 2 frames"
+        assert (
+            init_dataset_return_value["num_episodes"] == 2
+        ), "`init_dataset` should load the previous episode"
+
+
+@pytest.mark.parametrize("robot_type, mock", [("koch", True)])
+@require_robot
+def test_record_with_event_rerecord_episode(tmpdir, request, robot_type, mock):
+    if mock and robot_type != "aloha":
+        request.getfixturevalue("patch_builtins_input")
+
+        # Create an empty calibration directory to trigger manual calibration
+        # and avoid writing calibration files in user .cache/calibration folder
+        calibration_dir = tmpdir / robot_type
+        overrides = [f"calibration_dir={calibration_dir}"]
+    else:
+        # Use the default .cache/calibration folder when mock=False or for aloha
+        overrides = []
+
+    robot = make_robot(robot_type, overrides=overrides, mock=mock)
+    with (
+        patch("lerobot.scripts.control_robot.init_keyboard_listener") as mock_listener,
+        patch("lerobot.common.robot_devices.control_utils.add_frame", wraps=add_frame) as mock_add_frame,
+    ):
+        mock_events = {}
+        mock_events["exit_early"] = True
+        mock_events["rerecord_episode"] = True
+        mock_events["stop_recording"] = False
+        mock_listener.return_value = (None, mock_events)
+
+        root = Path(tmpdir) / "data"
+        repo_id = "lerobot/debug"
+
+        dataset = record(
+            robot,
+            root,
+            repo_id,
+            fps=1,
+            warmup_time_s=0,
+            episode_time_s=1,
+            num_episodes=1,
+            push_to_hub=False,
+            video=False,
+            display_cameras=False,
+            play_sounds=False,
+            run_compute_stats=False,
+        )
+
+    assert not mock_events["rerecord_episode"], "`rerecord_episode` wasn't properly reset to False"
+    assert not mock_events["exit_early"], "`exit_early` wasn't properly reset to False"
+    assert mock_add_frame.call_count == 2, "`add_frame` should have been called 2 times"
+    assert len(dataset) == 1, "`dataset` should contain only 1 frame"
+
+
+@pytest.mark.parametrize("robot_type, mock", [("koch", True)])
+@require_robot
+def test_record_with_event_exit_early(tmpdir, request, robot_type, mock):
+    if mock:
+        request.getfixturevalue("patch_builtins_input")
+
+        # Create an empty calibration directory to trigger manual calibration
+        # and avoid writing calibration files in user .cache/calibration folder
+        calibration_dir = tmpdir / robot_type
+        overrides = [f"calibration_dir={calibration_dir}"]
+    else:
+        # Use the default .cache/calibration folder when mock=False or for aloha
+        overrides = []
+
+    robot = make_robot(robot_type, overrides=overrides, mock=mock)
+    with (
+        patch("lerobot.scripts.control_robot.init_keyboard_listener") as mock_listener,
+        patch("lerobot.common.robot_devices.control_utils.add_frame", wraps=add_frame) as mock_add_frame,
+    ):
+        mock_events = {}
+        mock_events["exit_early"] = True
+        mock_events["rerecord_episode"] = False
+        mock_events["stop_recording"] = False
+        mock_listener.return_value = (None, mock_events)
+
+        root = Path(tmpdir) / "data"
+        repo_id = "lerobot/debug"
+
+        dataset = record(
+            robot,
+            fps=2,
+            root=root,
+            repo_id=repo_id,
+            warmup_time_s=0,
+            episode_time_s=1,
+            num_episodes=1,
+            push_to_hub=False,
+            video=False,
+            display_cameras=False,
+            play_sounds=False,
+            run_compute_stats=False,
+        )
+
+    assert not mock_events["exit_early"], "`exit_early` wasn't properly reset to False"
+    assert mock_add_frame.call_count == 1, "`add_frame` should have been called 1 time"
+    assert len(dataset) == 1, "`dataset` should contain only 1 frame"
+
+
+@pytest.mark.parametrize(
+    "robot_type, mock, num_image_writer_processes", [("koch", True, 0), ("koch", True, 1)]
+)
+@require_robot
+def test_record_with_event_stop_recording(tmpdir, request, robot_type, mock, num_image_writer_processes):
+    if mock:
+        request.getfixturevalue("patch_builtins_input")
+
+        # Create an empty calibration directory to trigger manual calibration
+        # and avoid writing calibration files in user .cache/calibration folder
+        calibration_dir = tmpdir / robot_type
+        overrides = [f"calibration_dir={calibration_dir}"]
+    else:
+        # Use the default .cache/calibration folder when mock=False or for aloha
+        overrides = []
+
+    robot = make_robot(robot_type, overrides=overrides, mock=mock)
+    with (
+        patch("lerobot.scripts.control_robot.init_keyboard_listener") as mock_listener,
+        patch("lerobot.common.robot_devices.control_utils.add_frame", wraps=add_frame) as mock_add_frame,
+    ):
+        mock_events = {}
+        mock_events["exit_early"] = True
+        mock_events["rerecord_episode"] = False
+        mock_events["stop_recording"] = True
+        mock_listener.return_value = (None, mock_events)
+
+        root = Path(tmpdir) / "data"
+        repo_id = "lerobot/debug"
+
+        dataset = record(
+            robot,
+            root,
+            repo_id,
+            fps=1,
+            warmup_time_s=0,
+            episode_time_s=1,
+            num_episodes=2,
+            push_to_hub=False,
+            video=False,
+            display_cameras=False,
+            play_sounds=False,
+            run_compute_stats=False,
+            num_image_writer_processes=num_image_writer_processes,
+        )
+
+    assert not mock_events["exit_early"], "`exit_early` wasn't properly reset to False"
+    assert mock_add_frame.call_count == 1, "`add_frame` should have been called 1 time"
+    assert len(dataset) == 1, "`dataset` should contain only 1 frame"
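A note on the mocking pattern shared by the three tests above: `patch(..., wraps=...)` replaces the name in the module under test with a Mock that still forwards to the real function, so call counts can be asserted without changing behavior. A self-contained illustration of the same stdlib mechanism:

    from unittest import mock

    def helper(x):
        return x + 1

    def do_work():
        return helper(1)

    with mock.patch(f"{__name__}.helper", wraps=helper) as spy:
        assert do_work() == 2      # real behavior is preserved
    assert spy.call_count == 1     # ...while every call was recorded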

@@ -127,6 +127,7 @@ def test_robot(tmpdir, request, robot_type, mock):
             # TODO(rcadene): skipping image for now as it's challenging to assess equality between two consecutive frames
             continue
         assert torch.allclose(captured_observation[name], observation[name], atol=1)
+        assert captured_observation[name].shape == observation[name].shape

     # Test send_action can run
     robot.send_action(action["action"])