diff --git a/lerobot/common/datasets/image_writer.py b/lerobot/common/datasets/image_writer.py index 8f368ef2..705fe73b 100644 --- a/lerobot/common/datasets/image_writer.py +++ b/lerobot/common/datasets/image_writer.py @@ -14,11 +14,12 @@ # See the License for the specific language governing permissions and # limitations under the License. import multiprocessing -from concurrent.futures import ThreadPoolExecutor, wait +import queue +import threading from pathlib import Path +import numpy as np import torch -import tqdm from PIL import Image DEFAULT_IMAGE_PATH = "{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png" @@ -39,8 +40,39 @@ def safe_stop_image_writer(func): return wrapper +def write_image(image_array: np.ndarray, fpath: Path): + try: + image = Image.fromarray(image_array) + image.save(fpath) + except Exception as e: + print(f"Error writing image {fpath}: {e}") + + +def worker_thread_process(queue: queue.Queue): + while True: + item = queue.get() + if item is None: + queue.task_done() + break + image_array, fpath = item + write_image(image_array, fpath) + queue.task_done() + + +def worker_process(queue: queue.Queue, num_threads: int): + threads = [] + for _ in range(num_threads): + t = threading.Thread(target=worker_thread_process, args=(queue,)) + t.daemon = True + t.start() + threads.append(t) + for t in threads: + t.join() + + class ImageWriter: - """This class abstract away the initialisation of processes or/and threads to + """ + This class abstract away the initialisation of processes or/and threads to save images on disk asynchrounously, which is critical to control a robot and record data at a high frame rate. @@ -53,113 +85,66 @@ class ImageWriter: the number of threads. If it is still not stable, try to use 1 subprocess, or more. """ - def __init__(self, write_dir: Path, num_processes: int = 0, num_threads: int = 1, timeout: int = 10): - self.dir = write_dir - self.dir.mkdir(parents=True, exist_ok=True) + def __init__(self, write_dir: Path, num_processes: int = 0, num_threads: int = 1): + self.write_dir = write_dir + self.write_dir.mkdir(parents=True, exist_ok=True) self.image_path = DEFAULT_IMAGE_PATH + self.num_processes = num_processes self.num_threads = num_threads - self.timeout = timeout + self.queue = None + self.threads = [] + self.processes = [] - if self.num_processes == 0 and self.num_threads == 0: - self.type = "synchronous" - elif self.num_processes == 0 and self.num_threads > 0: - self.type = "threads" - self.threads = ThreadPoolExecutor(max_workers=self.num_threads) - self.futures = [] + if self.num_processes == 0: + # Use threading + self.queue = queue.Queue() + for _ in range(self.num_threads): + t = threading.Thread(target=worker_thread_process, args=(self.queue,)) + t.daemon = True + t.start() + self.threads.append(t) else: - self.type = "processes" - self.main_event = multiprocessing.Event() - self.image_queue = multiprocessing.Queue() - self.processes: list[multiprocessing.Process] = [] - self.events: list[multiprocessing.Event] = [] + # Use multiprocessing + self.queue = multiprocessing.JoinableQueue() for _ in range(self.num_processes): - event = multiprocessing.Event() - process = multiprocessing.Process(target=self._loop_to_save_images_in_threads, args=(event,)) - process.start() - self.processes.append(process) - self.events.append(event) - - def _loop_to_save_images_in_threads(self, event: multiprocessing.Event) -> None: - with ThreadPoolExecutor(max_workers=self.num_threads) as executor: - futures = [] - while True: - frame_data = self.image_queue.get() - if frame_data is None: - self._wait_threads(self.futures, 10) - return - - image, file_path = frame_data - futures.append(executor.submit(self._save_image, image, file_path)) - - if self.main_event.is_set(): - self._wait_threads(self.futures, 10) - event.set() - - def async_save_image(self, image: torch.Tensor, file_path: Path) -> None: - """Save an image asynchronously using threads or processes.""" - if self.type == "synchronous": - self._save_image(image, file_path) - elif self.type == "threads": - self.futures.append(self.threads.submit(self._save_image, image, file_path)) - else: - self.image_queue.put((image, file_path)) - - def _save_image(self, image: torch.Tensor, file_path: Path) -> None: - img = Image.fromarray(image.numpy()) - img.save(str(file_path), quality=100) + p = multiprocessing.Process(target=worker_process, args=(self.queue, self.num_threads)) + p.daemon = True + p.start() + self.processes.append(p) def get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path: fpath = self.image_path.format( image_key=image_key, episode_index=episode_index, frame_index=frame_index ) - return self.dir / fpath + return self.write_dir / fpath def get_episode_dir(self, episode_index: int, image_key: str) -> Path: return self.get_image_file_path( episode_index=episode_index, image_key=image_key, frame_index=0 ).parent - def wait(self) -> None: - """Wait for the thread/processes to finish writing.""" - if self.type == "synchronous": - return - elif self.type == "threads": - self._wait_threads(self.futures) + def save_image(self, image_array: torch.Tensor | np.ndarray, fpath: Path): + if isinstance(image_array, torch.Tensor): + image_array = image_array.numpy() + self.queue.put((image_array, fpath)) + + def wait_until_done(self): + self.queue.join() + + def stop(self): + if self.num_processes == 0: + # For threading + for _ in self.threads: + self.queue.put(None) + for t in self.threads: + t.join() else: - self._wait_processes() - - def _wait_threads(self, futures) -> None: - with tqdm.tqdm(total=len(futures), desc="Writing images") as progress_bar: - wait(futures, timeout=self.timeout) - progress_bar.update(len(futures)) - - def _wait_processes(self) -> None: - self.main_event.set() - for event in self.events: - event.wait() - - self.main_event.clear() - - def shutdown(self, timeout=20) -> None: - """Stop the image writer, waiting for all processes or threads to finish.""" - if self.type == "synchronous": - return - elif self.type == "threads": - self.threads.shutdown(wait=True) - else: - self._stop_processes(timeout) - - def _stop_processes(self, timeout) -> None: - for _ in self.processes: - self.image_queue.put(None) - - for process in self.processes: - process.join(timeout=timeout) - - for process in self.processes: - if process.is_alive(): - process.terminate() - - self.image_queue.close() - self.image_queue.join_thread() + # For multiprocessing + num_nones = self.num_processes * self.num_threads + for _ in range(num_nones): + self.queue.put(None) + self.queue.close() + self.queue.join_thread() + for p in self.processes: + p.join() diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index f451be28..4b1e58e9 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -25,7 +25,6 @@ import datasets import torch import torch.utils from datasets import load_dataset -from datasets.table import embed_table_storage from huggingface_hub import snapshot_download, upload_folder from lerobot.common.datasets.compute_stats import aggregate_stats, compute_stats @@ -51,6 +50,7 @@ from lerobot.common.datasets.utils import ( load_stats, load_tasks, write_json, + write_parquet, write_stats, ) from lerobot.common.datasets.video_utils import ( @@ -354,7 +354,7 @@ class LeRobotDataset(torch.utils.data.Dataset): @property def num_samples(self) -> int: """Number of samples/frames in selected episodes.""" - return len(self.hf_dataset) + return len(self.hf_dataset) if self.hf_dataset is not None else self.total_frames @property def num_episodes(self) -> int: @@ -584,9 +584,9 @@ class LeRobotDataset(torch.utils.data.Dataset): if frame_index == 0: img_path.parent.mkdir(parents=True, exist_ok=True) - self.image_writer.async_save_image( - image=frame[cam_key], - file_path=img_path, + self.image_writer.save_image( + image_array=frame[cam_key], + fpath=img_path, ) if cam_key in self.image_keys: @@ -640,14 +640,7 @@ class LeRobotDataset(torch.utils.data.Dataset): ep_dataset = datasets.Dataset.from_dict(self.episode_buffer, features=self._features, split="train") ep_data_path = self.root / self.get_data_file_path(ep_index=episode_index) ep_data_path.parent.mkdir(parents=True, exist_ok=True) - - # Embed image bytes into the table before saving to parquet - format = ep_dataset.format - ep_dataset = ep_dataset.with_format("arrow") - ep_dataset = ep_dataset.map(embed_table_storage, batched=False) - ep_dataset = ep_dataset.with_format(**format) - - ep_dataset.to_parquet(ep_data_path) + write_parquet(ep_dataset, ep_data_path) def _save_episode_to_metadata( self, episode_index: int, episode_length: int, task: str, task_index: int @@ -709,13 +702,13 @@ class LeRobotDataset(torch.utils.data.Dataset): remove the image_write in order for the LeRobotDataset object to be pickleable and parallelized. """ if self.image_writer is not None: - self.image_writer.shutdown() + self.image_writer.stop() self.image_writer = None def _wait_image_writer(self) -> None: """Wait for asynchronous image writer to finish.""" if self.image_writer is not None: - self.image_writer.wait() + self.image_writer.wait_until_done() def encode_videos(self) -> None: # Use ffmpeg to convert frames stored as png into mp4 videos @@ -754,7 +747,7 @@ class LeRobotDataset(torch.utils.data.Dataset): self._write_video_info() if not keep_image_files and self.image_writer is not None: - shutil.rmtree(self.image_writer.dir) + shutil.rmtree(self.image_writer.write_dir) if run_compute_stats: self.stop_image_writer() diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py index 008d7843..6d941ecf 100644 --- a/lerobot/common/datasets/utils.py +++ b/lerobot/common/datasets/utils.py @@ -23,6 +23,7 @@ from typing import Any, Dict import datasets import jsonlines import torch +from datasets.table import embed_table_storage from huggingface_hub import DatasetCard, HfApi from PIL import Image as PILImage from torchvision import transforms @@ -80,6 +81,15 @@ def unflatten_dict(d: dict, sep: str = "/") -> dict: return outdict +def write_parquet(dataset: datasets.Dataset, fpath: Path) -> None: + # Embed image bytes into the table before saving to parquet + format = dataset.format + dataset = dataset.with_format("arrow") + dataset = dataset.map(embed_table_storage, batched=False) + dataset = dataset.with_format(**format) + dataset.to_parquet(fpath) + + def load_json(fpath: Path) -> Any: with open(fpath) as f: return json.load(f) @@ -114,6 +124,25 @@ def write_stats(stats: dict[str, torch.Tensor | dict], fpath: Path) -> None: write_json(serialized_stats, fpath) +def load_info(local_dir: Path) -> dict: + return load_json(local_dir / INFO_PATH) + + +def load_stats(local_dir: Path) -> dict: + stats = load_json(local_dir / STATS_PATH) + stats = {key: torch.tensor(value) for key, value in flatten_dict(stats).items()} + return unflatten_dict(stats) + + +def load_tasks(local_dir: Path) -> dict: + tasks = load_jsonlines(local_dir / TASKS_PATH) + return {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])} + + +def load_episode_dicts(local_dir: Path) -> dict: + return load_jsonlines(local_dir / EPISODES_PATH) + + def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]): """Get a transform function that convert items from Hugging Face dataset (pyarrow) to torch tensors. Importantly, images are converted from PIL, which corresponds to @@ -185,25 +214,6 @@ def get_hub_safe_version(repo_id: str, version: str, enforce_v2: bool = True) -> return version -def load_info(local_dir: Path) -> dict: - return load_json(local_dir / INFO_PATH) - - -def load_stats(local_dir: Path) -> dict: - stats = load_json(local_dir / STATS_PATH) - stats = {key: torch.tensor(value) for key, value in flatten_dict(stats).items()} - return unflatten_dict(stats) - - -def load_tasks(local_dir: Path) -> dict: - tasks = load_jsonlines(local_dir / TASKS_PATH) - return {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])} - - -def load_episode_dicts(local_dir: Path) -> dict: - return load_jsonlines(local_dir / EPISODES_PATH) - - def _get_info_from_robot(robot: Robot, use_videos: bool) -> tuple[list | dict]: shapes = {key: len(names) for key, names in robot.names.items()} camera_shapes = {}