From 3b925c3dce5a3b2f741b4335ff075f4d52152697 Mon Sep 17 00:00:00 2001
From: Simon Alibert
Date: Mon, 21 Oct 2024 00:15:09 +0200
Subject: [PATCH] Add ImageWriter

---
 lerobot/common/datasets/image_writer.py     | 130 ++++++++
 lerobot/common/datasets/populate_dataset.py | 352 ++++++--------------
 2 files changed, 231 insertions(+), 251 deletions(-)
 create mode 100644 lerobot/common/datasets/image_writer.py

diff --git a/lerobot/common/datasets/image_writer.py b/lerobot/common/datasets/image_writer.py
new file mode 100644
index 00000000..c87e342b
--- /dev/null
+++ b/lerobot/common/datasets/image_writer.py
@@ -0,0 +1,130 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import multiprocessing
+from concurrent.futures import ThreadPoolExecutor, wait
+from pathlib import Path
+
+import torch
+import tqdm
+from PIL import Image
+
+DEFAULT_IMAGE_PATH = "images/{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"
+
+
+def safe_stop_image_writer(func):
+    def wrapper(*args, **kwargs):
+        try:
+            return func(*args, **kwargs)
+        except Exception as e:
+            dataset = kwargs.get("dataset", None)
+            image_writer = getattr(dataset, "image_writer", None) if dataset else None
+            if image_writer is not None:
+                print("Waiting for image writer to terminate...")
+                image_writer.stop()
+            raise e
+
+    return wrapper
+
+
+class ImageWriter:
+    """This class abstracts away the initialisation of processes and/or threads to
+    save images on disk asynchronously, which is critical to control a robot and record data
+    at a high frame rate.
+
+    When `num_processes=0`, it creates a thread pool of size `num_threads`.
+    When `num_processes>0`, it creates a process pool of size `num_processes`, where each
+    subprocess starts its own thread pool of size `num_threads`.
+
+    The optimal number of processes and threads depends on your computer's capabilities.
+    We advise using 4 threads per camera with 0 processes. If the fps is not stable, try increasing
+    or lowering the number of threads. If it is still not stable, try using 1 subprocess, or more.
+ """ + + def __init__(self, write_dir: Path, num_processes: int = 0, num_threads: int = 1): + self.dir = write_dir + self.image_path = DEFAULT_IMAGE_PATH + self.num_processes = num_processes + self.num_threads = self.num_threads_per_process = num_threads + + if self.num_processes <= 0: + self.type = "threads" + self.threads = ThreadPoolExecutor(max_workers=self.num_threads) + self.futures = [] + else: + self.type = "processes" + self.num_threads_per_process = self.num_threads + self.image_queue = multiprocessing.Queue() + self.processes: list[multiprocessing.Process] = [] + for _ in range(num_processes): + process = multiprocessing.Process(target=self._loop_to_save_images_in_threads) + process.start() + self.processes.append(process) + + def _loop_to_save_images_in_threads(self) -> None: + with ThreadPoolExecutor(max_workers=self.num_threads) as executor: + futures = [] + while True: + frame_data = self.image_queue.get() + if frame_data is None: + break + + image, file_path = frame_data + futures.append(executor.submit(self._save_image, image, file_path)) + + with tqdm.tqdm(total=len(futures), desc="Writing images") as progress_bar: + wait(futures) + progress_bar.update(len(futures)) + + def async_save_image(self, image: torch.Tensor, file_path: Path) -> None: + """Save an image asynchronously using threads or processes.""" + if self.type == "threads": + self.futures.append(self.threads.submit(self._save_image, image, file_path)) + else: + self.image_queue.put((image, file_path)) + + def _save_image(self, image: torch.Tensor, file_path: Path) -> None: + img = Image.fromarray(image.numpy()) + img.save(str(file_path), quality=100) + + def get_image_file_path( + self, episode_index: int, image_key: str, frame_index: int, return_str: bool = True + ) -> str | Path: + fpath = self.image_path.format( + image_key=image_key, episode_index=episode_index, frame_index=frame_index + ) + return str(self.dir / fpath) if return_str else self.dir / fpath + + def stop(self, timeout=20) -> None: + """Stop the image writer, waiting for all processes or threads to finish.""" + if self.type == "threads": + with tqdm.tqdm(total=len(self.futures), desc="Writing images") as progress_bar: + wait(self.futures, timeout=timeout) + progress_bar.update(len(self.futures)) + else: + self._stop_processes(self.processes, self.image_queue, timeout) + + def _stop_processes(self, timeout) -> None: + for _ in self.processes: + self.image_queue.put(None) + + for process in self.processes: + process.join(timeout=timeout) + + if process.is_alive(): + process.terminate() + + self.image_queue.close() + self.image_queue.join_thread() diff --git a/lerobot/common/datasets/populate_dataset.py b/lerobot/common/datasets/populate_dataset.py index df5d20e5..854b639e 100644 --- a/lerobot/common/datasets/populate_dataset.py +++ b/lerobot/common/datasets/populate_dataset.py @@ -1,16 +1,12 @@ """Functions to create an empty dataset, and populate it with frames.""" # TODO(rcadene, aliberts): to adapt as class methods of next version of LeRobotDataset -import concurrent import json import logging -import multiprocessing import shutil -from pathlib import Path import torch import tqdm -from PIL import Image from lerobot.common.datasets.compute_stats import compute_stats from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset @@ -26,277 +22,131 @@ from lerobot.scripts.push_dataset_to_hub import ( save_meta_data, ) -######################################################################################## -# Asynchrounous 
saving of images on disk -######################################################################################## - - -def safe_stop_image_writer(func): - # TODO(aliberts): Allow to pass custom exceptions - # (e.g. ThreadServiceExit, KeyboardInterrupt, SystemExit, UnpluggedError, DynamixelCommError) - def wrapper(*args, **kwargs): - try: - return func(*args, **kwargs) - except Exception as e: - image_writer = kwargs.get("dataset", {}).get("image_writer") - if image_writer is not None: - print("Waiting for image writer to terminate...") - stop_image_writer(image_writer, timeout=20) - raise e - - return wrapper - - -def save_image(img_tensor, key, frame_index, episode_index, videos_dir: str): - img = Image.fromarray(img_tensor.numpy()) - path = Path(videos_dir) / f"{key}_episode_{episode_index:06d}" / f"frame_{frame_index:06d}.png" - path.parent.mkdir(parents=True, exist_ok=True) - img.save(str(path), quality=100) - - -def loop_to_save_images_in_threads(image_queue, num_threads): - if num_threads < 1: - raise NotImplementedError(f"Only `num_threads>=1` is supported for now, but {num_threads=} given.") - - with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: - futures = [] - while True: - # Blocks until a frame is available - frame_data = image_queue.get() - - # As usually done, exit loop when receiving None to stop the worker - if frame_data is None: - break - - image, key, frame_index, episode_index, videos_dir = frame_data - futures.append(executor.submit(save_image, image, key, frame_index, episode_index, videos_dir)) - - # Before exiting function, wait for all threads to complete - with tqdm.tqdm(total=len(futures), desc="Writing images") as progress_bar: - concurrent.futures.wait(futures) - progress_bar.update(len(futures)) - - -def start_image_writer_processes(image_queue, num_processes, num_threads_per_process): - if num_processes < 1: - raise ValueError(f"Only `num_processes>=1` is supported, but {num_processes=} given.") - - if num_threads_per_process < 1: - raise NotImplementedError( - "Only `num_threads_per_process>=1` is supported for now, but {num_threads_per_process=} given." - ) - - processes = [] - for _ in range(num_processes): - process = multiprocessing.Process( - target=loop_to_save_images_in_threads, - args=(image_queue, num_threads_per_process), - ) - process.start() - processes.append(process) - return processes - - -def stop_processes(processes, queue, timeout): - # Send None to each process to signal them to stop - for _ in processes: - queue.put(None) - - # Wait maximum 20 seconds for all processes to terminate - for process in processes: - process.join(timeout=timeout) - - # If not terminated after 20 seconds, force termination - if process.is_alive(): - process.terminate() - - # Close the queue, no more items can be put in the queue - queue.close() - - # Ensure all background queue threads have finished - queue.join_thread() - - -def start_image_writer(num_processes, num_threads): - """This function abstract away the initialisation of processes or/and threads to - save images on disk asynchrounously, which is critical to control a robot and record data - at a high frame rate. - - When `num_processes=0`, it returns a dictionary containing a threads pool of size `num_threads`. - When `num_processes>0`, it returns a dictionary containing a processes pool of size `num_processes`, - where each subprocess starts their own threads pool of size `num_threads`. 
- - The optimal number of processes and threads depends on your computer capabilities. - We advise to use 4 threads per camera with 0 processes. If the fps is not stable, try to increase or lower - the number of threads. If it is still not stable, try to use 1 subprocess, or more. - """ - image_writer = {} - - if num_processes == 0: - futures = [] - threads_pool = concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) - image_writer["threads_pool"], image_writer["futures"] = threads_pool, futures - else: - # TODO(rcadene): When using num_processes>1, `multiprocessing.Manager().Queue()` - # might be better than `multiprocessing.Queue()`. Source: https://www.geeksforgeeks.org/python-multiprocessing-queue-vs-multiprocessing-manager-queue - image_queue = multiprocessing.Queue() - processes_pool = start_image_writer_processes( - image_queue, num_processes=num_processes, num_threads_per_process=num_threads - ) - image_writer["processes_pool"], image_writer["image_queue"] = processes_pool, image_queue - - return image_writer - - -def async_save_image(image_writer, image, key, frame_index, episode_index, videos_dir): - """This function abstract away the saving of an image on disk asynchrounously. It uses a dictionary - called image writer which contains either a pool of processes or a pool of threads. - """ - if "threads_pool" in image_writer: - threads_pool, futures = image_writer["threads_pool"], image_writer["futures"] - futures.append(threads_pool.submit(save_image, image, key, frame_index, episode_index, videos_dir)) - else: - image_queue = image_writer["image_queue"] - image_queue.put((image, key, frame_index, episode_index, videos_dir)) - - -def stop_image_writer(image_writer, timeout): - if "threads_pool" in image_writer: - futures = image_writer["futures"] - # Before exiting function, wait for all threads to complete - with tqdm.tqdm(total=len(futures), desc="Writing images") as progress_bar: - concurrent.futures.wait(futures, timeout=timeout) - progress_bar.update(len(futures)) - else: - processes_pool, image_queue = image_writer["processes_pool"], image_writer["image_queue"] - stop_processes(processes_pool, image_queue, timeout=timeout) - - ######################################################################################## # Functions to initialize, resume and populate a dataset ######################################################################################## -def init_dataset( - repo_id, - root, - force_override, - fps, - video, - write_images, - num_image_writer_processes, - num_image_writer_threads, -): - local_dir = Path(root) / repo_id - if local_dir.exists() and force_override: - shutil.rmtree(local_dir) +# def init_dataset( +# repo_id, +# root, +# force_override, +# fps, +# video, +# write_images, +# num_image_writer_processes, +# num_image_writer_threads, +# ): +# local_dir = Path(root) / repo_id +# if local_dir.exists() and force_override: +# shutil.rmtree(local_dir) - episodes_dir = local_dir / "episodes" - episodes_dir.mkdir(parents=True, exist_ok=True) +# episodes_dir = local_dir / "episodes" +# episodes_dir.mkdir(parents=True, exist_ok=True) - videos_dir = local_dir / "videos" - videos_dir.mkdir(parents=True, exist_ok=True) +# videos_dir = local_dir / "videos" +# videos_dir.mkdir(parents=True, exist_ok=True) - # Logic to resume data recording - rec_info_path = episodes_dir / "data_recording_info.json" - if rec_info_path.exists(): - with open(rec_info_path) as f: - rec_info = json.load(f) - num_episodes = rec_info["last_episode_index"] + 1 - else: - 
num_episodes = 0 +# # Logic to resume data recording +# rec_info_path = episodes_dir / "data_recording_info.json" +# if rec_info_path.exists(): +# with open(rec_info_path) as f: +# rec_info = json.load(f) +# num_episodes = rec_info["last_episode_index"] + 1 +# else: +# num_episodes = 0 - dataset = { - "repo_id": repo_id, - "local_dir": local_dir, - "videos_dir": videos_dir, - "episodes_dir": episodes_dir, - "fps": fps, - "video": video, - "rec_info_path": rec_info_path, - "num_episodes": num_episodes, - } +# dataset = { +# "repo_id": repo_id, +# "local_dir": local_dir, +# "videos_dir": videos_dir, +# "episodes_dir": episodes_dir, +# "fps": fps, +# "video": video, +# "rec_info_path": rec_info_path, +# "num_episodes": num_episodes, +# } - if write_images: - # Initialize processes or/and threads dedicated to save images on disk asynchronously, - # which is critical to control a robot and record data at a high frame rate. - image_writer = start_image_writer( - num_processes=num_image_writer_processes, - num_threads=num_image_writer_threads, - ) - dataset["image_writer"] = image_writer +# if write_images: +# # Initialize processes or/and threads dedicated to save images on disk asynchronously, +# # which is critical to control a robot and record data at a high frame rate. +# image_writer = start_image_writer( +# num_processes=num_image_writer_processes, +# num_threads=num_image_writer_threads, +# ) +# dataset["image_writer"] = image_writer - return dataset +# return dataset -def add_frame(dataset, observation, action): - if "current_episode" not in dataset: - # initialize episode dictionary - ep_dict = {} - for key in observation: - if key not in ep_dict: - ep_dict[key] = [] - for key in action: - if key not in ep_dict: - ep_dict[key] = [] +# def add_frame(dataset, observation, action): +# if "current_episode" not in dataset: +# # initialize episode dictionary +# ep_dict = {} +# for key in observation: +# if key not in ep_dict: +# ep_dict[key] = [] +# for key in action: +# if key not in ep_dict: +# ep_dict[key] = [] - ep_dict["episode_index"] = [] - ep_dict["frame_index"] = [] - ep_dict["timestamp"] = [] - ep_dict["next.done"] = [] +# ep_dict["episode_index"] = [] +# ep_dict["frame_index"] = [] +# ep_dict["timestamp"] = [] +# ep_dict["next.done"] = [] - dataset["current_episode"] = ep_dict - dataset["current_frame_index"] = 0 +# dataset["current_episode"] = ep_dict +# dataset["current_frame_index"] = 0 - ep_dict = dataset["current_episode"] - episode_index = dataset["num_episodes"] - frame_index = dataset["current_frame_index"] - videos_dir = dataset["videos_dir"] - video = dataset["video"] - fps = dataset["fps"] +# ep_dict = dataset["current_episode"] +# episode_index = dataset["num_episodes"] +# frame_index = dataset["current_frame_index"] +# videos_dir = dataset["videos_dir"] +# video = dataset["video"] +# fps = dataset["fps"] - ep_dict["episode_index"].append(episode_index) - ep_dict["frame_index"].append(frame_index) - ep_dict["timestamp"].append(frame_index / fps) - ep_dict["next.done"].append(False) +# ep_dict["episode_index"].append(episode_index) +# ep_dict["frame_index"].append(frame_index) +# ep_dict["timestamp"].append(frame_index / fps) +# ep_dict["next.done"].append(False) - img_keys = [key for key in observation if "image" in key] - non_img_keys = [key for key in observation if "image" not in key] +# img_keys = [key for key in observation if "image" in key] +# non_img_keys = [key for key in observation if "image" not in key] - # Save all observed modalities except images - for 
key in non_img_keys: - ep_dict[key].append(observation[key]) +# # Save all observed modalities except images +# for key in non_img_keys: +# ep_dict[key].append(observation[key]) - # Save actions - for key in action: - ep_dict[key].append(action[key]) +# # Save actions +# for key in action: +# ep_dict[key].append(action[key]) - if "image_writer" not in dataset: - dataset["current_frame_index"] += 1 - return +# if "image_writer" not in dataset: +# dataset["current_frame_index"] += 1 +# return - # Save images - image_writer = dataset["image_writer"] - for key in img_keys: - imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}" - async_save_image( - image_writer, - image=observation[key], - key=key, - frame_index=frame_index, - episode_index=episode_index, - videos_dir=str(videos_dir), - ) +# # Save images +# image_writer = dataset["image_writer"] +# for key in img_keys: +# imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}" +# async_save_image( +# image_writer, +# image=observation[key], +# key=key, +# frame_index=frame_index, +# episode_index=episode_index, +# videos_dir=str(videos_dir), +# ) - if video: - fname = f"{key}_episode_{episode_index:06d}.mp4" - frame_info = {"path": f"videos/{fname}", "timestamp": frame_index / fps} - else: - frame_info = str(imgs_dir / f"frame_{frame_index:06d}.png") +# if video: +# fname = f"{key}_episode_{episode_index:06d}.mp4" +# frame_info = {"path": f"videos/{fname}", "timestamp": frame_index / fps} +# else: +# frame_info = str(imgs_dir / f"frame_{frame_index:06d}.png") - ep_dict[key].append(frame_info) +# ep_dict[key].append(frame_info) - dataset["current_frame_index"] += 1 +# dataset["current_frame_index"] += 1 def delete_current_episode(dataset): @@ -449,7 +299,7 @@ def create_lerobot_dataset(dataset, run_compute_stats, push_to_hub, tags, play_s if "image_writer" in dataset: logging.info("Waiting for image writer to terminate...") image_writer = dataset["image_writer"] - stop_image_writer(image_writer, timeout=20) + image_writer.stop() lerobot_dataset = from_dataset_to_lerobot_dataset(dataset, play_sounds)
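
Usage note (not part of the patch): below is a minimal sketch of how the new ImageWriter might be driven from a recording loop, following the advice in the class docstring (0 processes, 4 threads per camera). The "outputs/demo" directory, the "observation.images.cam_high" key, and the random uint8 frames are illustrative assumptions, not values taken from this diff.

    from pathlib import Path

    import torch

    from lerobot.common.datasets.image_writer import ImageWriter

    # 0 processes + 4 threads is the starting configuration advised in the docstring.
    writer = ImageWriter(write_dir=Path("outputs/demo"), num_processes=0, num_threads=4)

    for frame_index in range(10):
        # Stand-in for a camera frame (H, W, C, uint8); real frames come from the robot's cameras.
        image = (torch.rand(480, 640, 3) * 255).to(torch.uint8)
        fpath = writer.get_image_file_path(
            episode_index=0,
            image_key="observation.images.cam_high",  # hypothetical camera key
            frame_index=frame_index,
            return_str=False,
        )
        # Returns immediately; the PNG write happens on a worker thread.
        writer.async_save_image(image, fpath)

    # Block until every queued image has been flushed to disk.
    writer.stop()

With num_processes=0 everything stays in-process and frames go straight to a thread pool; with num_processes>0 frames are pickled through a multiprocessing.Queue to subprocesses, trading serialization overhead for CPU parallelism.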