Fix image writer

This commit is contained in:
Simon Alibert 2024-10-28 12:01:32 +01:00
parent df3d2ec5df
commit 51e87f6f97
3 changed files with 118 additions and 130 deletions

View File

@ -14,11 +14,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import multiprocessing
from concurrent.futures import ThreadPoolExecutor, wait
import queue
import threading
from pathlib import Path
import numpy as np
import torch
import tqdm
from PIL import Image
DEFAULT_IMAGE_PATH = "{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"
@ -39,8 +40,39 @@ def safe_stop_image_writer(func):
return wrapper
def write_image(image_array: np.ndarray, fpath: Path):
    """Convert ``image_array`` to a PIL image and save it at ``fpath``.

    Any exception raised while converting or saving is caught and reported
    on stdout; nothing is re-raised to the caller.
    """
    try:
        Image.fromarray(image_array).save(fpath)
    except Exception as e:
        print(f"Error writing image {fpath}: {e}")
def worker_thread_process(queue: queue.Queue):
    """Consume ``(image_array, fpath)`` items from ``queue`` and write them to disk.

    The loop runs until a ``None`` sentinel is retrieved. ``task_done()`` is
    called exactly once for every retrieved item (sentinel included), even
    when unpacking or writing the item raises, so a ``queue.join()`` issued
    elsewhere cannot block forever on an item this worker already took.

    Args:
        queue: Queue yielding ``(image_array, fpath)`` tuples, or ``None`` to stop.
    """
    while True:
        item = queue.get()
        try:
            if item is None:
                break
            image_array, fpath = item
            write_image(image_array, fpath)
        finally:
            # Runs on normal completion, on `break`, and when an exception propagates.
            queue.task_done()
def worker_process(queue: queue.Queue, num_threads: int):
    """Run ``num_threads`` daemon threads on ``queue`` and wait for all of them.

    Each thread executes ``worker_thread_process(queue)``. The call returns
    only once every started thread has exited.
    """
    workers = []
    for _ in range(num_threads):
        worker = threading.Thread(target=worker_thread_process, args=(queue,), daemon=True)
        worker.start()
        workers.append(worker)

    for worker in workers:
        worker.join()
class ImageWriter:
"""This class abstract away the initialisation of processes or/and threads to
"""
This class abstracts away the initialisation of processes and/or threads to
save images on disk asynchronously, which is critical to control a robot and record data
at a high frame rate.
@ -53,113 +85,66 @@ class ImageWriter:
the number of threads. If it is still not stable, try to use 1 subprocess, or more.
"""
def __init__(self, write_dir: Path, num_processes: int = 0, num_threads: int = 1, timeout: int = 10):
self.dir = write_dir
self.dir.mkdir(parents=True, exist_ok=True)
def __init__(self, write_dir: Path, num_processes: int = 0, num_threads: int = 1):
self.write_dir = write_dir
self.write_dir.mkdir(parents=True, exist_ok=True)
self.image_path = DEFAULT_IMAGE_PATH
self.num_processes = num_processes
self.num_threads = num_threads
self.timeout = timeout
self.queue = None
self.threads = []
self.processes = []
if self.num_processes == 0 and self.num_threads == 0:
self.type = "synchronous"
elif self.num_processes == 0 and self.num_threads > 0:
self.type = "threads"
self.threads = ThreadPoolExecutor(max_workers=self.num_threads)
self.futures = []
if self.num_processes == 0:
# Use threading
self.queue = queue.Queue()
for _ in range(self.num_threads):
t = threading.Thread(target=worker_thread_process, args=(self.queue,))
t.daemon = True
t.start()
self.threads.append(t)
else:
self.type = "processes"
self.main_event = multiprocessing.Event()
self.image_queue = multiprocessing.Queue()
self.processes: list[multiprocessing.Process] = []
self.events: list[multiprocessing.Event] = []
# Use multiprocessing
self.queue = multiprocessing.JoinableQueue()
for _ in range(self.num_processes):
event = multiprocessing.Event()
process = multiprocessing.Process(target=self._loop_to_save_images_in_threads, args=(event,))
process.start()
self.processes.append(process)
self.events.append(event)
def _loop_to_save_images_in_threads(self, event: multiprocessing.Event) -> None:
with ThreadPoolExecutor(max_workers=self.num_threads) as executor:
futures = []
while True:
frame_data = self.image_queue.get()
if frame_data is None:
self._wait_threads(self.futures, 10)
return
image, file_path = frame_data
futures.append(executor.submit(self._save_image, image, file_path))
if self.main_event.is_set():
self._wait_threads(self.futures, 10)
event.set()
def async_save_image(self, image: torch.Tensor, file_path: Path) -> None:
"""Save an image asynchronously using threads or processes."""
if self.type == "synchronous":
self._save_image(image, file_path)
elif self.type == "threads":
self.futures.append(self.threads.submit(self._save_image, image, file_path))
else:
self.image_queue.put((image, file_path))
def _save_image(self, image: torch.Tensor, file_path: Path) -> None:
img = Image.fromarray(image.numpy())
img.save(str(file_path), quality=100)
p = multiprocessing.Process(target=worker_process, args=(self.queue, self.num_threads))
p.daemon = True
p.start()
self.processes.append(p)
def get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path:
    """Build the path of a single frame image.

    ``self.image_path`` is used as a format template filled with
    ``image_key``, ``episode_index`` and ``frame_index``; the resulting
    relative path is joined onto ``self.write_dir``.
    """
    fpath = self.image_path.format(
        image_key=image_key, episode_index=episode_index, frame_index=frame_index
    )
    # Single exit point: a second `return` after this one could never execute.
    return self.write_dir / fpath
def get_episode_dir(self, episode_index: int, image_key: str) -> Path:
    """Return the directory containing the frames of one episode for ``image_key``.

    The directory is the parent of the path computed for frame 0.
    """
    first_frame_path = self.get_image_file_path(
        episode_index=episode_index, image_key=image_key, frame_index=0
    )
    return first_frame_path.parent
def wait(self) -> None:
"""Wait for the thread/processes to finish writing."""
if self.type == "synchronous":
return
elif self.type == "threads":
self._wait_threads(self.futures)
def save_image(self, image_array: torch.Tensor | np.ndarray, fpath: Path):
if isinstance(image_array, torch.Tensor):
image_array = image_array.numpy()
self.queue.put((image_array, fpath))
def wait_until_done(self):
    """Block until every item put on ``self.queue`` has been marked done."""
    pending = self.queue
    pending.join()
def stop(self):
if self.num_processes == 0:
# For threading
for _ in self.threads:
self.queue.put(None)
for t in self.threads:
t.join()
else:
self._wait_processes()
def _wait_threads(self, futures) -> None:
with tqdm.tqdm(total=len(futures), desc="Writing images") as progress_bar:
wait(futures, timeout=self.timeout)
progress_bar.update(len(futures))
def _wait_processes(self) -> None:
self.main_event.set()
for event in self.events:
event.wait()
self.main_event.clear()
def shutdown(self, timeout=20) -> None:
"""Stop the image writer, waiting for all processes or threads to finish."""
if self.type == "synchronous":
return
elif self.type == "threads":
self.threads.shutdown(wait=True)
else:
self._stop_processes(timeout)
def _stop_processes(self, timeout) -> None:
for _ in self.processes:
self.image_queue.put(None)
for process in self.processes:
process.join(timeout=timeout)
for process in self.processes:
if process.is_alive():
process.terminate()
self.image_queue.close()
self.image_queue.join_thread()
# For multiprocessing
num_nones = self.num_processes * self.num_threads
for _ in range(num_nones):
self.queue.put(None)
self.queue.close()
self.queue.join_thread()
for p in self.processes:
p.join()

View File

@ -25,7 +25,6 @@ import datasets
import torch
import torch.utils
from datasets import load_dataset
from datasets.table import embed_table_storage
from huggingface_hub import snapshot_download, upload_folder
from lerobot.common.datasets.compute_stats import aggregate_stats, compute_stats
@ -51,6 +50,7 @@ from lerobot.common.datasets.utils import (
load_stats,
load_tasks,
write_json,
write_parquet,
write_stats,
)
from lerobot.common.datasets.video_utils import (
@ -354,7 +354,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
@property
def num_samples(self) -> int:
    """Number of samples/frames in selected episodes.

    Falls back to ``self.total_frames`` when ``self.hf_dataset`` is ``None``.
    """
    # Single guarded return: an unconditional `return len(self.hf_dataset)` ahead of
    # the None check would both fail on None and make the fallback unreachable.
    if self.hf_dataset is None:
        return self.total_frames
    return len(self.hf_dataset)
@property
def num_episodes(self) -> int:
@ -584,9 +584,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
if frame_index == 0:
img_path.parent.mkdir(parents=True, exist_ok=True)
self.image_writer.async_save_image(
image=frame[cam_key],
file_path=img_path,
self.image_writer.save_image(
image_array=frame[cam_key],
fpath=img_path,
)
if cam_key in self.image_keys:
@ -640,14 +640,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
ep_dataset = datasets.Dataset.from_dict(self.episode_buffer, features=self._features, split="train")
ep_data_path = self.root / self.get_data_file_path(ep_index=episode_index)
ep_data_path.parent.mkdir(parents=True, exist_ok=True)
# Embed image bytes into the table before saving to parquet
format = ep_dataset.format
ep_dataset = ep_dataset.with_format("arrow")
ep_dataset = ep_dataset.map(embed_table_storage, batched=False)
ep_dataset = ep_dataset.with_format(**format)
ep_dataset.to_parquet(ep_data_path)
write_parquet(ep_dataset, ep_data_path)
def _save_episode_to_metadata(
self, episode_index: int, episode_length: int, task: str, task_index: int
@ -709,13 +702,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
remove the image_write in order for the LeRobotDataset object to be pickleable and parallelized.
"""
if self.image_writer is not None:
self.image_writer.shutdown()
self.image_writer.stop()
self.image_writer = None
def _wait_image_writer(self) -> None:
"""Wait for asynchronous image writer to finish."""
if self.image_writer is not None:
self.image_writer.wait()
self.image_writer.wait_until_done()
def encode_videos(self) -> None:
# Use ffmpeg to convert frames stored as png into mp4 videos
@ -754,7 +747,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
self._write_video_info()
if not keep_image_files and self.image_writer is not None:
shutil.rmtree(self.image_writer.dir)
shutil.rmtree(self.image_writer.write_dir)
if run_compute_stats:
self.stop_image_writer()

View File

@ -23,6 +23,7 @@ from typing import Any, Dict
import datasets
import jsonlines
import torch
from datasets.table import embed_table_storage
from huggingface_hub import DatasetCard, HfApi
from PIL import Image as PILImage
from torchvision import transforms
@ -80,6 +81,15 @@ def unflatten_dict(d: dict, sep: str = "/") -> dict:
return outdict
def write_parquet(dataset: datasets.Dataset, fpath: Path) -> None:
    """Write ``dataset`` to a parquet file at ``fpath``.

    The dataset is temporarily switched to the "arrow" format so that
    ``embed_table_storage`` can be mapped over it, then its previous format
    is restored before the parquet file is written.
    """
    # Embed image bytes into the table before saving to parquet
    original_format = dataset.format  # named so the `format` builtin is not shadowed
    dataset = dataset.with_format("arrow")
    dataset = dataset.map(embed_table_storage, batched=False)
    dataset = dataset.with_format(**original_format)
    dataset.to_parquet(fpath)
def load_json(fpath: Path) -> Any:
    """Open the file at ``fpath`` and return its parsed JSON content."""
    with open(fpath) as json_file:
        content = json.load(json_file)
    return content
@ -114,6 +124,25 @@ def write_stats(stats: dict[str, torch.Tensor | dict], fpath: Path) -> None:
write_json(serialized_stats, fpath)
def load_info(local_dir: Path) -> dict:
    """Load the JSON file located at ``INFO_PATH`` under ``local_dir``."""
    info_file = local_dir / INFO_PATH
    return load_json(info_file)
def load_stats(local_dir: Path) -> dict:
    """Load the JSON file at ``STATS_PATH`` under ``local_dir`` as nested tensors.

    The loaded dict is flattened, each value is wrapped in ``torch.tensor``,
    and the result is unflattened again before being returned.
    """
    flat_stats = flatten_dict(load_json(local_dir / STATS_PATH))
    tensor_stats = {}
    for key, value in flat_stats.items():
        tensor_stats[key] = torch.tensor(value)
    return unflatten_dict(tensor_stats)
def load_tasks(local_dir: Path) -> dict:
    """Load the records at ``TASKS_PATH`` under ``local_dir`` as a ``task_index -> task`` dict.

    Records are inserted in ascending ``task_index`` order.
    """
    records = load_jsonlines(local_dir / TASKS_PATH)
    ordered_records = sorted(records, key=lambda x: x["task_index"])
    return {record["task_index"]: record["task"] for record in ordered_records}
def load_episode_dicts(local_dir: Path) -> dict:
    """Load the JSON-lines file located at ``EPISODES_PATH`` under ``local_dir``."""
    # NOTE(review): the return annotation says `dict`; confirm against what
    # `load_jsonlines` actually returns.
    episodes_file = local_dir / EPISODES_PATH
    return load_jsonlines(episodes_file)
def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
"""Get a transform function that convert items from Hugging Face dataset (pyarrow)
to torch tensors. Importantly, images are converted from PIL, which corresponds to
@ -185,25 +214,6 @@ def get_hub_safe_version(repo_id: str, version: str, enforce_v2: bool = True) ->
return version
def load_info(local_dir: Path) -> dict:
    """Load the JSON file located at ``INFO_PATH`` under ``local_dir``."""
    info_file = local_dir / INFO_PATH
    return load_json(info_file)
def load_stats(local_dir: Path) -> dict:
    """Load the JSON file at ``STATS_PATH`` under ``local_dir`` as nested tensors.

    The loaded dict is flattened, each value is wrapped in ``torch.tensor``,
    and the result is unflattened again before being returned.
    """
    flat_stats = flatten_dict(load_json(local_dir / STATS_PATH))
    tensor_stats = {}
    for key, value in flat_stats.items():
        tensor_stats[key] = torch.tensor(value)
    return unflatten_dict(tensor_stats)
def load_tasks(local_dir: Path) -> dict:
    """Load the records at ``TASKS_PATH`` under ``local_dir`` as a ``task_index -> task`` dict.

    Records are inserted in ascending ``task_index`` order.
    """
    records = load_jsonlines(local_dir / TASKS_PATH)
    ordered_records = sorted(records, key=lambda x: x["task_index"])
    return {record["task_index"]: record["task"] for record in ordered_records}
def load_episode_dicts(local_dir: Path) -> dict:
    """Load the JSON-lines file located at ``EPISODES_PATH`` under ``local_dir``."""
    # NOTE(review): the return annotation says `dict`; confirm against what
    # `load_jsonlines` actually returns.
    episodes_file = local_dir / EPISODES_PATH
    return load_jsonlines(episodes_file)
def _get_info_from_robot(robot: Robot, use_videos: bool) -> tuple[list | dict]:
shapes = {key: len(names) for key, names in robot.names.items()}
camera_shapes = {}