Add ImageWriter

Simon Alibert 2024-10-21 00:15:09 +02:00
parent e46bdb9d30
commit 3b925c3dce
2 changed files with 231 additions and 251 deletions

@@ -0,0 +1,130 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import multiprocessing
from concurrent.futures import ThreadPoolExecutor, wait
from pathlib import Path
import torch
import tqdm
from PIL import Image
DEFAULT_IMAGE_PATH = "images/{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"
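# For illustration, a hypothetical image_key of "cam_high" with episode_index=1
# and frame_index=42 resolves the template above to:
#   images/cam_high/episode_000001/frame_000042.png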
def safe_stop_image_writer(func):
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except Exception as e:
dataset = kwargs.get("dataset", None)
image_writer = getattr(dataset, "image_writer", None) if dataset else None
if image_writer is not None:
print("Waiting for image writer to terminate...")
image_writer.stop()
raise e
return wrapper
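# Usage sketch (hypothetical recording function): `dataset` must be passed as a
# keyword argument, since the decorator looks it up via `kwargs.get("dataset")`:
#
#   @safe_stop_image_writer
#   def record_episode(robot, dataset=None):
#       ...  # on any exception, dataset.image_writer.stop() runs before re-raising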
class ImageWriter:
    """This class abstracts away the initialisation of processes and/or threads to
    save images on disk asynchronously, which is critical to control a robot and record data
    at a high frame rate.

    When `num_processes=0`, it creates a thread pool of size `num_threads`.
    When `num_processes>0`, it creates a process pool of size `num_processes`, where each
    subprocess starts its own thread pool of size `num_threads`.

    The optimal number of processes and threads depends on your computer's capabilities.
    We advise using 4 threads per camera with 0 processes. If the fps is not stable, try increasing
    or lowering the number of threads. If it is still not stable, try using 1 or more subprocesses.
    """
def __init__(self, write_dir: Path, num_processes: int = 0, num_threads: int = 1):
self.dir = write_dir
self.image_path = DEFAULT_IMAGE_PATH
self.num_processes = num_processes
self.num_threads = self.num_threads_per_process = num_threads
if self.num_processes <= 0:
self.type = "threads"
self.threads = ThreadPoolExecutor(max_workers=self.num_threads)
self.futures = []
else:
self.type = "processes"
self.image_queue = multiprocessing.Queue()
self.processes: list[multiprocessing.Process] = []
for _ in range(num_processes):
process = multiprocessing.Process(target=self._loop_to_save_images_in_threads)
process.start()
self.processes.append(process)
def _loop_to_save_images_in_threads(self) -> None:
with ThreadPoolExecutor(max_workers=self.num_threads) as executor:
futures = []
            while True:
                # Blocks until a frame is available
                frame_data = self.image_queue.get()
                # A None sentinel signals the worker to exit the loop
                if frame_data is None:
                    break
                image, file_path = frame_data
                futures.append(executor.submit(self._save_image, image, file_path))
            # Before exiting, wait for all pending writes to complete
            with tqdm.tqdm(total=len(futures), desc="Writing images") as progress_bar:
                wait(futures)
                progress_bar.update(len(futures))
def async_save_image(self, image: torch.Tensor, file_path: Path) -> None:
"""Save an image asynchronously using threads or processes."""
if self.type == "threads":
self.futures.append(self.threads.submit(self._save_image, image, file_path))
else:
self.image_queue.put((image, file_path))
    def _save_image(self, image: torch.Tensor, file_path: Path) -> None:
        file_path = Path(file_path)
        # Ensure the target episode directory exists before writing
        file_path.parent.mkdir(parents=True, exist_ok=True)
        img = Image.fromarray(image.numpy())
        img.save(str(file_path), quality=100)
def get_image_file_path(
self, episode_index: int, image_key: str, frame_index: int, return_str: bool = True
) -> str | Path:
fpath = self.image_path.format(
image_key=image_key, episode_index=episode_index, frame_index=frame_index
)
return str(self.dir / fpath) if return_str else self.dir / fpath
def stop(self, timeout=20) -> None:
"""Stop the image writer, waiting for all processes or threads to finish."""
if self.type == "threads":
with tqdm.tqdm(total=len(self.futures), desc="Writing images") as progress_bar:
wait(self.futures, timeout=timeout)
progress_bar.update(len(self.futures))
else:
            self._stop_processes(timeout)
    def _stop_processes(self, timeout) -> None:
        # Send a None sentinel to each process to signal it to stop
        for _ in self.processes:
            self.image_queue.put(None)
        # Wait up to `timeout` seconds for each process to terminate
        for process in self.processes:
            process.join(timeout=timeout)
            # If a process is still alive after the timeout, force termination
            if process.is_alive():
                process.terminate()
        # Close the queue: no more items can be put in it
        self.image_queue.close()
        # Ensure all background queue threads have finished
        self.image_queue.join_thread()
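A minimal usage sketch of the class above (the write directory, camera key, and frame shape are illustrative assumptions; the threads-only configuration follows the docstring's advice of 4 threads per camera and 0 processes):

    import torch
    from pathlib import Path

    writer = ImageWriter(write_dir=Path("data"), num_processes=0, num_threads=4)
    frame = torch.randint(0, 256, (480, 640, 3), dtype=torch.uint8)  # dummy HWC uint8 frame
    fpath = writer.get_image_file_path(episode_index=0, image_key="cam_high", frame_index=0, return_str=False)
    writer.async_save_image(frame, file_path=fpath)
    writer.stop()  # waits for pending writes (default timeout: 20s)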

@@ -1,16 +1,12 @@
"""Functions to create an empty dataset, and populate it with frames."""
# TODO(rcadene, aliberts): to adapt as class methods of next version of LeRobotDataset
import concurrent
import json
import logging
import multiprocessing
import shutil
from pathlib import Path
import torch
import tqdm
from PIL import Image
from lerobot.common.datasets.compute_stats import compute_stats
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
@@ -26,277 +22,131 @@ from lerobot.scripts.push_dataset_to_hub import (
save_meta_data,
)
########################################################################################
# Asynchronous saving of images on disk
########################################################################################
def safe_stop_image_writer(func):
# TODO(aliberts): Allow to pass custom exceptions
# (e.g. ThreadServiceExit, KeyboardInterrupt, SystemExit, UnpluggedError, DynamixelCommError)
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except Exception as e:
image_writer = kwargs.get("dataset", {}).get("image_writer")
if image_writer is not None:
print("Waiting for image writer to terminate...")
stop_image_writer(image_writer, timeout=20)
raise e
return wrapper
def save_image(img_tensor, key, frame_index, episode_index, videos_dir: str):
img = Image.fromarray(img_tensor.numpy())
path = Path(videos_dir) / f"{key}_episode_{episode_index:06d}" / f"frame_{frame_index:06d}.png"
path.parent.mkdir(parents=True, exist_ok=True)
img.save(str(path), quality=100)
def loop_to_save_images_in_threads(image_queue, num_threads):
if num_threads < 1:
raise NotImplementedError(f"Only `num_threads>=1` is supported for now, but {num_threads=} given.")
with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor:
futures = []
while True:
# Blocks until a frame is available
frame_data = image_queue.get()
# As usually done, exit loop when receiving None to stop the worker
if frame_data is None:
break
image, key, frame_index, episode_index, videos_dir = frame_data
futures.append(executor.submit(save_image, image, key, frame_index, episode_index, videos_dir))
# Before exiting function, wait for all threads to complete
with tqdm.tqdm(total=len(futures), desc="Writing images") as progress_bar:
concurrent.futures.wait(futures)
progress_bar.update(len(futures))
def start_image_writer_processes(image_queue, num_processes, num_threads_per_process):
if num_processes < 1:
raise ValueError(f"Only `num_processes>=1` is supported, but {num_processes=} given.")
if num_threads_per_process < 1:
        raise NotImplementedError(
            f"Only `num_threads_per_process>=1` is supported for now, but {num_threads_per_process=} given."
)
processes = []
for _ in range(num_processes):
process = multiprocessing.Process(
target=loop_to_save_images_in_threads,
args=(image_queue, num_threads_per_process),
)
process.start()
processes.append(process)
return processes
def stop_processes(processes, queue, timeout):
# Send None to each process to signal them to stop
for _ in processes:
queue.put(None)
    # Wait up to `timeout` seconds for all processes to terminate
for process in processes:
process.join(timeout=timeout)
        # If a process is still alive after the timeout, force termination
if process.is_alive():
process.terminate()
# Close the queue, no more items can be put in the queue
queue.close()
# Ensure all background queue threads have finished
queue.join_thread()
def start_image_writer(num_processes, num_threads):
    """This function abstracts away the initialisation of processes and/or threads to
    save images on disk asynchronously, which is critical to control a robot and record data
    at a high frame rate.

    When `num_processes=0`, it returns a dictionary containing a thread pool of size `num_threads`.
    When `num_processes>0`, it returns a dictionary containing a process pool of size `num_processes`,
    where each subprocess starts its own thread pool of size `num_threads`.

    The optimal number of processes and threads depends on your computer's capabilities.
    We advise using 4 threads per camera with 0 processes. If the fps is not stable, try increasing
    or lowering the number of threads. If it is still not stable, try using 1 or more subprocesses.
    """
image_writer = {}
if num_processes == 0:
futures = []
threads_pool = concurrent.futures.ThreadPoolExecutor(max_workers=num_threads)
image_writer["threads_pool"], image_writer["futures"] = threads_pool, futures
else:
# TODO(rcadene): When using num_processes>1, `multiprocessing.Manager().Queue()`
# might be better than `multiprocessing.Queue()`. Source: https://www.geeksforgeeks.org/python-multiprocessing-queue-vs-multiprocessing-manager-queue
image_queue = multiprocessing.Queue()
processes_pool = start_image_writer_processes(
image_queue, num_processes=num_processes, num_threads_per_process=num_threads
)
image_writer["processes_pool"], image_writer["image_queue"] = processes_pool, image_queue
return image_writer
def async_save_image(image_writer, image, key, frame_index, episode_index, videos_dir):
    """This function abstracts away saving an image to disk asynchronously. It uses the
    image_writer dictionary, which contains either a pool of processes or a pool of threads.
    """
if "threads_pool" in image_writer:
threads_pool, futures = image_writer["threads_pool"], image_writer["futures"]
futures.append(threads_pool.submit(save_image, image, key, frame_index, episode_index, videos_dir))
else:
image_queue = image_writer["image_queue"]
image_queue.put((image, key, frame_index, episode_index, videos_dir))
def stop_image_writer(image_writer, timeout):
if "threads_pool" in image_writer:
futures = image_writer["futures"]
# Before exiting function, wait for all threads to complete
with tqdm.tqdm(total=len(futures), desc="Writing images") as progress_bar:
concurrent.futures.wait(futures, timeout=timeout)
progress_bar.update(len(futures))
else:
processes_pool, image_queue = image_writer["processes_pool"], image_writer["image_queue"]
stop_processes(processes_pool, image_queue, timeout=timeout)
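# For comparison, a minimal sketch of this dict-based API, which the new
# `ImageWriter` class above replaces (the tensor, key, and directory are
# illustrative assumptions):
#
#   image_writer = start_image_writer(num_processes=0, num_threads=4)
#   async_save_image(image_writer, image=frame, key="cam_high",
#                    frame_index=0, episode_index=0, videos_dir="videos")
#   stop_image_writer(image_writer, timeout=20)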
########################################################################################
# Functions to initialize, resume and populate a dataset
########################################################################################
def init_dataset(
repo_id,
root,
force_override,
fps,
video,
write_images,
num_image_writer_processes,
num_image_writer_threads,
):
local_dir = Path(root) / repo_id
if local_dir.exists() and force_override:
shutil.rmtree(local_dir)
# def init_dataset(
# repo_id,
# root,
# force_override,
# fps,
# video,
# write_images,
# num_image_writer_processes,
# num_image_writer_threads,
# ):
# local_dir = Path(root) / repo_id
# if local_dir.exists() and force_override:
# shutil.rmtree(local_dir)
episodes_dir = local_dir / "episodes"
episodes_dir.mkdir(parents=True, exist_ok=True)
# episodes_dir = local_dir / "episodes"
# episodes_dir.mkdir(parents=True, exist_ok=True)
videos_dir = local_dir / "videos"
videos_dir.mkdir(parents=True, exist_ok=True)
# videos_dir = local_dir / "videos"
# videos_dir.mkdir(parents=True, exist_ok=True)
# Logic to resume data recording
rec_info_path = episodes_dir / "data_recording_info.json"
if rec_info_path.exists():
with open(rec_info_path) as f:
rec_info = json.load(f)
num_episodes = rec_info["last_episode_index"] + 1
else:
num_episodes = 0
# # Logic to resume data recording
# rec_info_path = episodes_dir / "data_recording_info.json"
# if rec_info_path.exists():
# with open(rec_info_path) as f:
# rec_info = json.load(f)
# num_episodes = rec_info["last_episode_index"] + 1
# else:
# num_episodes = 0
dataset = {
"repo_id": repo_id,
"local_dir": local_dir,
"videos_dir": videos_dir,
"episodes_dir": episodes_dir,
"fps": fps,
"video": video,
"rec_info_path": rec_info_path,
"num_episodes": num_episodes,
}
# dataset = {
# "repo_id": repo_id,
# "local_dir": local_dir,
# "videos_dir": videos_dir,
# "episodes_dir": episodes_dir,
# "fps": fps,
# "video": video,
# "rec_info_path": rec_info_path,
# "num_episodes": num_episodes,
# }
if write_images:
        # Initialize processes and/or threads dedicated to saving images on disk asynchronously,
# which is critical to control a robot and record data at a high frame rate.
image_writer = start_image_writer(
num_processes=num_image_writer_processes,
num_threads=num_image_writer_threads,
)
dataset["image_writer"] = image_writer
# if write_images:
#     # Initialize processes and/or threads dedicated to saving images on disk asynchronously,
# # which is critical to control a robot and record data at a high frame rate.
# image_writer = start_image_writer(
# num_processes=num_image_writer_processes,
# num_threads=num_image_writer_threads,
# )
# dataset["image_writer"] = image_writer
return dataset
# return dataset
def add_frame(dataset, observation, action):
if "current_episode" not in dataset:
# initialize episode dictionary
ep_dict = {}
for key in observation:
if key not in ep_dict:
ep_dict[key] = []
for key in action:
if key not in ep_dict:
ep_dict[key] = []
# def add_frame(dataset, observation, action):
# if "current_episode" not in dataset:
# # initialize episode dictionary
# ep_dict = {}
# for key in observation:
# if key not in ep_dict:
# ep_dict[key] = []
# for key in action:
# if key not in ep_dict:
# ep_dict[key] = []
ep_dict["episode_index"] = []
ep_dict["frame_index"] = []
ep_dict["timestamp"] = []
ep_dict["next.done"] = []
# ep_dict["episode_index"] = []
# ep_dict["frame_index"] = []
# ep_dict["timestamp"] = []
# ep_dict["next.done"] = []
dataset["current_episode"] = ep_dict
dataset["current_frame_index"] = 0
# dataset["current_episode"] = ep_dict
# dataset["current_frame_index"] = 0
ep_dict = dataset["current_episode"]
episode_index = dataset["num_episodes"]
frame_index = dataset["current_frame_index"]
videos_dir = dataset["videos_dir"]
video = dataset["video"]
fps = dataset["fps"]
# ep_dict = dataset["current_episode"]
# episode_index = dataset["num_episodes"]
# frame_index = dataset["current_frame_index"]
# videos_dir = dataset["videos_dir"]
# video = dataset["video"]
# fps = dataset["fps"]
ep_dict["episode_index"].append(episode_index)
ep_dict["frame_index"].append(frame_index)
ep_dict["timestamp"].append(frame_index / fps)
ep_dict["next.done"].append(False)
# ep_dict["episode_index"].append(episode_index)
# ep_dict["frame_index"].append(frame_index)
# ep_dict["timestamp"].append(frame_index / fps)
# ep_dict["next.done"].append(False)
img_keys = [key for key in observation if "image" in key]
non_img_keys = [key for key in observation if "image" not in key]
# img_keys = [key for key in observation if "image" in key]
# non_img_keys = [key for key in observation if "image" not in key]
# Save all observed modalities except images
for key in non_img_keys:
ep_dict[key].append(observation[key])
# # Save all observed modalities except images
# for key in non_img_keys:
# ep_dict[key].append(observation[key])
# Save actions
for key in action:
ep_dict[key].append(action[key])
# # Save actions
# for key in action:
# ep_dict[key].append(action[key])
if "image_writer" not in dataset:
dataset["current_frame_index"] += 1
return
# if "image_writer" not in dataset:
# dataset["current_frame_index"] += 1
# return
# Save images
image_writer = dataset["image_writer"]
for key in img_keys:
imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}"
async_save_image(
image_writer,
image=observation[key],
key=key,
frame_index=frame_index,
episode_index=episode_index,
videos_dir=str(videos_dir),
)
# # Save images
# image_writer = dataset["image_writer"]
# for key in img_keys:
# imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}"
# async_save_image(
# image_writer,
# image=observation[key],
# key=key,
# frame_index=frame_index,
# episode_index=episode_index,
# videos_dir=str(videos_dir),
# )
if video:
fname = f"{key}_episode_{episode_index:06d}.mp4"
frame_info = {"path": f"videos/{fname}", "timestamp": frame_index / fps}
else:
frame_info = str(imgs_dir / f"frame_{frame_index:06d}.png")
# if video:
# fname = f"{key}_episode_{episode_index:06d}.mp4"
# frame_info = {"path": f"videos/{fname}", "timestamp": frame_index / fps}
# else:
# frame_info = str(imgs_dir / f"frame_{frame_index:06d}.png")
ep_dict[key].append(frame_info)
# ep_dict[key].append(frame_info)
dataset["current_frame_index"] += 1
# dataset["current_frame_index"] += 1
def delete_current_episode(dataset):
@@ -449,7 +299,7 @@ def create_lerobot_dataset(dataset, run_compute_stats, push_to_hub, tags, play_s
if "image_writer" in dataset:
logging.info("Waiting for image writer to terminate...")
image_writer = dataset["image_writer"]
stop_image_writer(image_writer, timeout=20)
image_writer.stop()
lerobot_dataset = from_dataset_to_lerobot_dataset(dataset, play_sounds)