Fix image writer

Simon Alibert 2024-10-28 12:01:32 +01:00
parent df3d2ec5df
commit 51e87f6f97
3 changed files with 118 additions and 130 deletions

View File: lerobot/common/datasets/image_writer.py

@@ -14,11 +14,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import multiprocessing
-from concurrent.futures import ThreadPoolExecutor, wait
+import queue
+import threading
 from pathlib import Path
 
+import numpy as np
 import torch
-import tqdm
 from PIL import Image
 
 DEFAULT_IMAGE_PATH = "{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"
@@ -39,8 +40,39 @@ def safe_stop_image_writer(func):
     return wrapper
 
 
+def write_image(image_array: np.ndarray, fpath: Path):
+    try:
+        image = Image.fromarray(image_array)
+        image.save(fpath)
+    except Exception as e:
+        print(f"Error writing image {fpath}: {e}")
+
+
+def worker_thread_process(queue: queue.Queue):
+    while True:
+        item = queue.get()
+        if item is None:
+            queue.task_done()
+            break
+        image_array, fpath = item
+        write_image(image_array, fpath)
+        queue.task_done()
+
+
+def worker_process(queue: queue.Queue, num_threads: int):
+    threads = []
+    for _ in range(num_threads):
+        t = threading.Thread(target=worker_thread_process, args=(queue,))
+        t.daemon = True
+        t.start()
+        threads.append(t)
+    for t in threads:
+        t.join()
+
+
 class ImageWriter:
-    """This class abstracts away the initialization of processes and/or threads to
+    """
+    This class abstracts away the initialization of processes and/or threads to
     save images on disk asynchronously, which is critical to control a robot and record data
     at a high frame rate.
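The three helpers above implement a sentinel-terminated consumer pattern: each worker thread blocks on queue.get(), writes the image, and calls task_done(); a None item tells exactly one thread to exit, so a clean shutdown must enqueue one None per thread. A minimal self-contained sketch of that pattern (the consumer body and item count are illustrative, not from this commit):

import queue
import threading

def consumer(q: queue.Queue) -> None:
    while True:
        item = q.get()
        if item is None:  # sentinel: exactly one thread consumes it and exits
            q.task_done()
            break
        # ... write the item to disk here ...
        q.task_done()

q = queue.Queue()
threads = [threading.Thread(target=consumer, args=(q,), daemon=True) for _ in range(4)]
for t in threads:
    t.start()
for i in range(10):
    q.put(i)
q.join()           # returns once task_done() has been called for every item
for _ in threads:  # one sentinel per thread, or some threads never exit
    q.put(None)
for t in threads:
    t.join()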
@@ -53,113 +85,66 @@
     the number of threads. If it is still not stable, try to use 1 subprocess, or more.
     """
 
-    def __init__(self, write_dir: Path, num_processes: int = 0, num_threads: int = 1, timeout: int = 10):
-        self.dir = write_dir
-        self.dir.mkdir(parents=True, exist_ok=True)
+    def __init__(self, write_dir: Path, num_processes: int = 0, num_threads: int = 1):
+        self.write_dir = write_dir
+        self.write_dir.mkdir(parents=True, exist_ok=True)
         self.image_path = DEFAULT_IMAGE_PATH
         self.num_processes = num_processes
         self.num_threads = num_threads
-        self.timeout = timeout
-
-        if self.num_processes == 0 and self.num_threads == 0:
-            self.type = "synchronous"
-        elif self.num_processes == 0 and self.num_threads > 0:
-            self.type = "threads"
-            self.threads = ThreadPoolExecutor(max_workers=self.num_threads)
-            self.futures = []
+        self.queue = None
+        self.threads = []
+        self.processes = []
+
+        if self.num_processes == 0:
+            # Use threading
+            self.queue = queue.Queue()
+            for _ in range(self.num_threads):
+                t = threading.Thread(target=worker_thread_process, args=(self.queue,))
+                t.daemon = True
+                t.start()
+                self.threads.append(t)
         else:
-            self.type = "processes"
-            self.main_event = multiprocessing.Event()
-            self.image_queue = multiprocessing.Queue()
-            self.processes: list[multiprocessing.Process] = []
-            self.events: list[multiprocessing.Event] = []
+            # Use multiprocessing
+            self.queue = multiprocessing.JoinableQueue()
             for _ in range(self.num_processes):
-                event = multiprocessing.Event()
-                process = multiprocessing.Process(target=self._loop_to_save_images_in_threads, args=(event,))
-                process.start()
-                self.processes.append(process)
-                self.events.append(event)
-
-    def _loop_to_save_images_in_threads(self, event: multiprocessing.Event) -> None:
-        with ThreadPoolExecutor(max_workers=self.num_threads) as executor:
-            futures = []
-            while True:
-                frame_data = self.image_queue.get()
-                if frame_data is None:
-                    self._wait_threads(self.futures, 10)
-                    return
-
-                image, file_path = frame_data
-                futures.append(executor.submit(self._save_image, image, file_path))
-
-                if self.main_event.is_set():
-                    self._wait_threads(self.futures, 10)
-                    event.set()
-
-    def async_save_image(self, image: torch.Tensor, file_path: Path) -> None:
-        """Save an image asynchronously using threads or processes."""
-        if self.type == "synchronous":
-            self._save_image(image, file_path)
-        elif self.type == "threads":
-            self.futures.append(self.threads.submit(self._save_image, image, file_path))
-        else:
-            self.image_queue.put((image, file_path))
-
-    def _save_image(self, image: torch.Tensor, file_path: Path) -> None:
-        img = Image.fromarray(image.numpy())
-        img.save(str(file_path), quality=100)
+                p = multiprocessing.Process(target=worker_process, args=(self.queue, self.num_threads))
+                p.daemon = True
+                p.start()
+                self.processes.append(p)
 
     def get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path:
         fpath = self.image_path.format(
             image_key=image_key, episode_index=episode_index, frame_index=frame_index
         )
-        return self.dir / fpath
+        return self.write_dir / fpath
 
     def get_episode_dir(self, episode_index: int, image_key: str) -> Path:
         return self.get_image_file_path(
             episode_index=episode_index, image_key=image_key, frame_index=0
         ).parent
 
-    def wait(self) -> None:
-        """Wait for the thread/processes to finish writing."""
-        if self.type == "synchronous":
-            return
-        elif self.type == "threads":
-            self._wait_threads(self.futures)
-        else:
-            self._wait_processes()
-
-    def _wait_threads(self, futures) -> None:
-        with tqdm.tqdm(total=len(futures), desc="Writing images") as progress_bar:
-            wait(futures, timeout=self.timeout)
-            progress_bar.update(len(futures))
-
-    def _wait_processes(self) -> None:
-        self.main_event.set()
-        for event in self.events:
-            event.wait()
-
-        self.main_event.clear()
-
-    def shutdown(self, timeout=20) -> None:
-        """Stop the image writer, waiting for all processes or threads to finish."""
-        if self.type == "synchronous":
-            return
-        elif self.type == "threads":
-            self.threads.shutdown(wait=True)
-        else:
-            self._stop_processes(timeout)
-
-    def _stop_processes(self, timeout) -> None:
-        for _ in self.processes:
-            self.image_queue.put(None)
-
-        for process in self.processes:
-            process.join(timeout=timeout)
-
-        for process in self.processes:
-            if process.is_alive():
-                process.terminate()
-
-        self.image_queue.close()
-        self.image_queue.join_thread()
+    def save_image(self, image_array: torch.Tensor | np.ndarray, fpath: Path):
+        if isinstance(image_array, torch.Tensor):
+            image_array = image_array.numpy()
+        self.queue.put((image_array, fpath))
+
+    def wait_until_done(self):
+        self.queue.join()
+
+    def stop(self):
+        if self.num_processes == 0:
+            # For threading
+            for _ in self.threads:
+                self.queue.put(None)
+            for t in self.threads:
+                t.join()
+        else:
+            # For multiprocessing
+            num_nones = self.num_processes * self.num_threads
+            for _ in range(num_nones):
+                self.queue.put(None)
+            self.queue.close()
+            self.queue.join_thread()
+            for p in self.processes:
+                p.join()
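With these changes the writer's public surface reduces to three calls: save_image() enqueues a frame, wait_until_done() blocks until the queue drains, and stop() retires the workers. A minimal usage sketch of the new API (the import path and the dummy frame are assumptions for illustration):

from pathlib import Path

import numpy as np
from lerobot.common.datasets.image_writer import ImageWriter  # assumed module path

writer = ImageWriter(write_dir=Path("data/images"), num_processes=0, num_threads=4)
frame = np.zeros((480, 640, 3), dtype=np.uint8)  # dummy HWC uint8 frame
for i in range(30):
    fpath = writer.get_image_file_path(episode_index=0, image_key="observation.image", frame_index=i)
    fpath.parent.mkdir(parents=True, exist_ok=True)  # write_image() does not create directories
    writer.save_image(frame, fpath)
writer.wait_until_done()  # queue fully drained, all frames on disk
writer.stop()             # enqueue one sentinel per worker and join them

Note that wait_until_done() relies on the queue's join()/task_done() bookkeeping rather than the removed timeout parameter, so it blocks exactly as long as writes are outstanding.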

View File: lerobot/common/datasets/lerobot_dataset.py

@@ -25,7 +25,6 @@ import datasets
 import torch
 import torch.utils
 from datasets import load_dataset
-from datasets.table import embed_table_storage
 from huggingface_hub import snapshot_download, upload_folder
 
 from lerobot.common.datasets.compute_stats import aggregate_stats, compute_stats
@@ -51,6 +50,7 @@ from lerobot.common.datasets.utils import (
     load_stats,
     load_tasks,
     write_json,
+    write_parquet,
     write_stats,
 )
 from lerobot.common.datasets.video_utils import (
@@ -354,7 +354,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
     @property
     def num_samples(self) -> int:
         """Number of samples/frames in selected episodes."""
-        return len(self.hf_dataset)
+        return len(self.hf_dataset) if self.hf_dataset is not None else self.total_frames
 
     @property
     def num_episodes(self) -> int:
@@ -584,9 +584,9 @@
         if frame_index == 0:
             img_path.parent.mkdir(parents=True, exist_ok=True)
 
-        self.image_writer.async_save_image(
-            image=frame[cam_key],
-            file_path=img_path,
+        self.image_writer.save_image(
+            image_array=frame[cam_key],
+            fpath=img_path,
         )
 
         if cam_key in self.image_keys:
@@ -640,14 +640,7 @@
         ep_dataset = datasets.Dataset.from_dict(self.episode_buffer, features=self._features, split="train")
         ep_data_path = self.root / self.get_data_file_path(ep_index=episode_index)
         ep_data_path.parent.mkdir(parents=True, exist_ok=True)
-
-        # Embed image bytes into the table before saving to parquet
-        format = ep_dataset.format
-        ep_dataset = ep_dataset.with_format("arrow")
-        ep_dataset = ep_dataset.map(embed_table_storage, batched=False)
-        ep_dataset = ep_dataset.with_format(**format)
-        ep_dataset.to_parquet(ep_data_path)
+        write_parquet(ep_dataset, ep_data_path)
 
     def _save_episode_to_metadata(
         self, episode_index: int, episode_length: int, task: str, task_index: int
@@ -709,13 +702,13 @@
         remove the image_writer in order for the LeRobotDataset object to be pickleable and parallelized.
         """
         if self.image_writer is not None:
-            self.image_writer.shutdown()
+            self.image_writer.stop()
             self.image_writer = None
 
     def _wait_image_writer(self) -> None:
         """Wait for asynchronous image writer to finish."""
         if self.image_writer is not None:
-            self.image_writer.wait()
+            self.image_writer.wait_until_done()
 
     def encode_videos(self) -> None:
         # Use ffmpeg to convert frames stored as png into mp4 videos
@@ -754,7 +747,7 @@
             self._write_video_info()
 
         if not keep_image_files and self.image_writer is not None:
-            shutil.rmtree(self.image_writer.dir)
+            shutil.rmtree(self.image_writer.write_dir)
 
         if run_compute_stats:
             self.stop_image_writer()

View File: lerobot/common/datasets/utils.py

@@ -23,6 +23,7 @@ from typing import Any, Dict
 import datasets
 import jsonlines
 import torch
+from datasets.table import embed_table_storage
 from huggingface_hub import DatasetCard, HfApi
 from PIL import Image as PILImage
 from torchvision import transforms
@@ -80,6 +81,15 @@ def unflatten_dict(d: dict, sep: str = "/") -> dict:
     return outdict
 
 
+def write_parquet(dataset: datasets.Dataset, fpath: Path) -> None:
+    # Embed image bytes into the table before saving to parquet
+    format = dataset.format
+    dataset = dataset.with_format("arrow")
+    dataset = dataset.map(embed_table_storage, batched=False)
+    dataset = dataset.with_format(**format)
+    dataset.to_parquet(fpath)
+
+
 def load_json(fpath: Path) -> Any:
     with open(fpath) as f:
         return json.load(f)
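Routing every save through this helper matters because a datasets table whose Image column was built from file paths stores references until the bytes are explicitly embedded; embedding makes the parquet file self-contained. A small sketch of how it might be called (the toy image and file names are illustrative assumptions):

from pathlib import Path

import datasets
import numpy as np
from PIL import Image

# Build a tiny one-row dataset whose image column points at a real file.
img_path = Path("frame_000000.png")
Image.fromarray(np.zeros((8, 8, 3), dtype=np.uint8)).save(img_path)
features = datasets.Features({"image": datasets.Image(), "frame_index": datasets.Value("int64")})
ds = datasets.Dataset.from_dict({"image": [str(img_path)], "frame_index": [0]}, features=features)

# The image bytes are embedded into the table before writing, so the parquet
# file stays readable even if frame_000000.png is later deleted.
write_parquet(ds, Path("episode_000000.parquet"))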
@@ -114,6 +124,25 @@ def write_stats(stats: dict[str, torch.Tensor | dict], fpath: Path) -> None:
     write_json(serialized_stats, fpath)
 
 
+def load_info(local_dir: Path) -> dict:
+    return load_json(local_dir / INFO_PATH)
+
+
+def load_stats(local_dir: Path) -> dict:
+    stats = load_json(local_dir / STATS_PATH)
+    stats = {key: torch.tensor(value) for key, value in flatten_dict(stats).items()}
+    return unflatten_dict(stats)
+
+
+def load_tasks(local_dir: Path) -> dict:
+    tasks = load_jsonlines(local_dir / TASKS_PATH)
+    return {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
+
+
+def load_episode_dicts(local_dir: Path) -> dict:
+    return load_jsonlines(local_dir / EPISODES_PATH)
+
+
 def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
     """Get a transform function that converts items from Hugging Face dataset (pyarrow)
     to torch tensors. Importantly, images are converted from PIL, which corresponds to
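These loaders mirror the metadata files written at record time; load_tasks, for instance, turns one JSON object per line into an index-to-string mapping. A toy illustration, assuming TASKS_PATH resolves to something like meta/tasks.jsonl under the dataset root:

# <local_dir>/meta/tasks.jsonl (illustrative contents):
#   {"task_index": 0, "task": "Pick up the cube."}
#   {"task_index": 1, "task": "Place the cube in the bin."}

tasks = load_tasks(Path("my_dataset_root"))
print(tasks)  # {0: 'Pick up the cube.', 1: 'Place the cube in the bin.'}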
@@ -185,25 +214,6 @@ def get_hub_safe_version(repo_id: str, version: str, enforce_v2: bool = True) ->
     return version
 
 
-def load_info(local_dir: Path) -> dict:
-    return load_json(local_dir / INFO_PATH)
-
-
-def load_stats(local_dir: Path) -> dict:
-    stats = load_json(local_dir / STATS_PATH)
-    stats = {key: torch.tensor(value) for key, value in flatten_dict(stats).items()}
-    return unflatten_dict(stats)
-
-
-def load_tasks(local_dir: Path) -> dict:
-    tasks = load_jsonlines(local_dir / TASKS_PATH)
-    return {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
-
-
-def load_episode_dicts(local_dir: Path) -> dict:
-    return load_jsonlines(local_dir / EPISODES_PATH)
-
-
 def _get_info_from_robot(robot: Robot, use_videos: bool) -> tuple[list | dict]:
     shapes = {key: len(names) for key, names in robot.names.items()}
     camera_shapes = {}