diff --git a/lerobot/common/datasets/image_writer.py b/lerobot/common/datasets/image_writer.py
index 705fe73b..7b9ed0e4 100644
--- a/lerobot/common/datasets/image_writer.py
+++ b/lerobot/common/datasets/image_writer.py
@@ -19,8 +19,8 @@ import threading
 from pathlib import Path
 
 import numpy as np
+import PIL.Image
 import torch
-from PIL import Image
 
 DEFAULT_IMAGE_PATH = "{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"
 
@@ -40,10 +40,27 @@ def safe_stop_image_writer(func):
     return wrapper
 
 
-def write_image(image_array: np.ndarray, fpath: Path):
+def image_array_to_image(image_array: np.ndarray) -> PIL.Image.Image:
+    # TODO(aliberts): handle 1 channel and 4 for depth images
+    if image_array.ndim == 3 and image_array.shape[0] in [1, 3]:
+        # Transpose from pytorch convention (C, H, W) to (H, W, C)
+        image_array = image_array.transpose(1, 2, 0)
+    if image_array.dtype != np.uint8:
+        # Assume the image is in [0, 1] range for floating-point data
+        image_array = np.clip(image_array, 0, 1)
+        image_array = (image_array * 255).astype(np.uint8)
+    return PIL.Image.fromarray(image_array)
+
+
+def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path):
     try:
-        image = Image.fromarray(image_array)
-        image.save(fpath)
+        if isinstance(image, np.ndarray):
+            img = image_array_to_image(image)
+        elif isinstance(image, PIL.Image.Image):
+            img = image
+        else:
+            raise TypeError(f"Unsupported image type: {type(image)}")
+        img.save(fpath)
     except Exception as e:
         print(f"Error writing image {fpath}: {e}")
 
@@ -63,7 +80,6 @@ def worker_process(queue: queue.Queue, num_threads: int):
     threads = []
     for _ in range(num_threads):
         t = threading.Thread(target=worker_thread_process, args=(queue,))
-        t.daemon = True
         t.start()
         threads.append(t)
     for t in threads:
@@ -95,6 +111,10 @@ class ImageWriter:
         self.queue = None
         self.threads = []
         self.processes = []
+        self._stopped = False
+
+        if num_threads <= 0 and num_processes <= 0:
+            raise ValueError("Number of threads and processes must be greater than zero.")
 
         if self.num_processes == 0:
             # Use threading
@@ -109,7 +129,6 @@ class ImageWriter:
             self.queue = multiprocessing.JoinableQueue()
             for _ in range(self.num_processes):
                 p = multiprocessing.Process(target=worker_process, args=(self.queue, self.num_threads))
-                p.daemon = True
                 p.start()
                 self.processes.append(p)
 
@@ -124,27 +143,33 @@
             episode_index=episode_index, image_key=image_key, frame_index=0
         ).parent
 
-    def save_image(self, image_array: torch.Tensor | np.ndarray, fpath: Path):
-        if isinstance(image_array, torch.Tensor):
-            image_array = image_array.numpy()
-        self.queue.put((image_array, fpath))
+    def save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path):
+        if isinstance(image, torch.Tensor):
+            # Convert tensor to numpy array to minimize main process time
+            image = image.cpu().numpy()
+        self.queue.put((image, fpath))
 
     def wait_until_done(self):
         self.queue.join()
 
     def stop(self):
+        if self._stopped:
+            return
+
         if self.num_processes == 0:
-            # For threading
             for _ in self.threads:
                 self.queue.put(None)
             for t in self.threads:
                 t.join()
         else:
-            # For multiprocessing
             num_nones = self.num_processes * self.num_threads
             for _ in range(num_nones):
                 self.queue.put(None)
-            self.queue.close()
-            self.queue.join_thread()
             for p in self.processes:
                 p.join()
+                if p.is_alive():
+                    p.terminate()
+            self.queue.close()
+            self.queue.join_thread()
+
+        self._stopped = True
diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py
index a2b889c8..bbeb25d6 100644
--- a/lerobot/common/datasets/lerobot_dataset.py
+++ b/lerobot/common/datasets/lerobot_dataset.py
@@ -590,7 +590,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
             img_path.parent.mkdir(parents=True, exist_ok=True)
             self.image_writer.save_image(
-                image_array=frame[cam_key],
+                image=frame[cam_key],
                 fpath=img_path,
             )
 
diff --git a/tests/fixtures/dataset_factories.py b/tests/fixtures/dataset_factories.py
index 52e6411e..b1136ffc 100644
--- a/tests/fixtures/dataset_factories.py
+++ b/tests/fixtures/dataset_factories.py
@@ -59,7 +59,7 @@ def img_array_factory():
 def img_factory(img_array_factory):
     def _create_img(width=100, height=100) -> PIL.Image.Image:
         img_array = img_array_factory(width=width, height=height)
-        return PIL.Image.Image.fromarray(img_array)
+        return PIL.Image.fromarray(img_array)
 
     return _create_img
 
diff --git a/tests/test_image_writer.py b/tests/test_image_writer.py
new file mode 100644
index 00000000..3f3045d0
--- /dev/null
+++ b/tests/test_image_writer.py
@@ -0,0 +1,385 @@
+import queue
+import time
+from multiprocessing import queues
+from unittest.mock import MagicMock, patch
+
+import numpy as np
+import pytest
+from PIL import Image
+
+from lerobot.common.datasets.image_writer import (
+    ImageWriter,
+    image_array_to_image,
+    safe_stop_image_writer,
+    write_image,
+)
+
+DUMMY_IMAGE = "test_image.png"
+
+
+def test_init_threading(tmp_path):
+    writer = ImageWriter(write_dir=tmp_path, num_processes=0, num_threads=2)
+    try:
+        assert writer.num_processes == 0
+        assert writer.num_threads == 2
+        assert isinstance(writer.queue, queue.Queue)
+        assert len(writer.threads) == 2
+        assert len(writer.processes) == 0
+        assert all(t.is_alive() for t in writer.threads)
+    finally:
+        writer.stop()
+
+
+def test_init_multiprocessing(tmp_path):
+    writer = ImageWriter(write_dir=tmp_path, num_processes=2, num_threads=2)
+    try:
+        assert writer.num_processes == 2
+        assert writer.num_threads == 2
+        assert isinstance(writer.queue, queues.JoinableQueue)
+        assert len(writer.threads) == 0
+        assert len(writer.processes) == 2
+        assert all(p.is_alive() for p in writer.processes)
+    finally:
+        writer.stop()
+
+
+def test_write_dir_created(tmp_path):
+    write_dir = tmp_path / "non_existent_dir"
+    assert not write_dir.exists()
+    writer = ImageWriter(write_dir=write_dir)
+    try:
+        assert write_dir.exists()
+    finally:
+        writer.stop()
+
+
+def test_get_image_file_path_and_episode_dir(tmp_path):
+    writer = ImageWriter(write_dir=tmp_path)
+    try:
+        episode_index = 1
+        image_key = "test_key"
+        frame_index = 10
+        expected_episode_dir = tmp_path / f"{image_key}/episode_{episode_index:06d}"
+        expected_path = expected_episode_dir / f"frame_{frame_index:06d}.png"
+        image_file_path = writer.get_image_file_path(episode_index, image_key, frame_index)
+        assert image_file_path == expected_path
+        episode_dir = writer.get_episode_dir(episode_index, image_key)
+        assert episode_dir == expected_episode_dir
+    finally:
+        writer.stop()
+
+
+def test_zero_threads(tmp_path):
+    with pytest.raises(ValueError):
+        ImageWriter(write_dir=tmp_path, num_processes=0, num_threads=0)
+
+
+def test_image_array_to_image_rgb(img_array_factory):
+    img_array = img_array_factory(100, 100)
+    result_image = image_array_to_image(img_array)
+    assert isinstance(result_image, Image.Image)
+    assert result_image.size == (100, 100)
+    assert result_image.mode == "RGB"
+
+
+def test_image_array_to_image_pytorch_format(img_array_factory):
+    img_array = img_array_factory(100, 100).transpose(2, 0, 1)
+    result_image = image_array_to_image(img_array)
+    assert isinstance(result_image, Image.Image)
+    assert result_image.size == (100, 100)
+    assert result_image.mode == "RGB"
+
+
+@pytest.mark.skip("TODO: implement")
+def test_image_array_to_image_single_channel(img_array_factory):
+    img_array = img_array_factory(channels=1)
+    result_image = image_array_to_image(img_array)
+    assert isinstance(result_image, Image.Image)
+    assert result_image.size == (100, 100)
+    assert result_image.mode == "L"
+
+
+def test_image_array_to_image_float_array(img_array_factory):
+    img_array = img_array_factory(dtype=np.float32)
+    result_image = image_array_to_image(img_array)
+    assert isinstance(result_image, Image.Image)
+    assert result_image.size == (100, 100)
+    assert result_image.mode == "RGB"
+    assert np.array(result_image).dtype == np.uint8
+
+
+def test_image_array_to_image_out_of_bounds_float():
+    # Float array with values out of [0, 1]
+    img_array = np.random.uniform(-1, 2, size=(100, 100, 3)).astype(np.float32)
+    result_image = image_array_to_image(img_array)
+    assert isinstance(result_image, Image.Image)
+    assert result_image.size == (100, 100)
+    assert result_image.mode == "RGB"
+    assert np.array(result_image).dtype == np.uint8
+    assert np.array(result_image).min() >= 0 and np.array(result_image).max() <= 255
+
+
+def test_write_image_numpy(tmp_path, img_array_factory):
+    image_array = img_array_factory()
+    fpath = tmp_path / DUMMY_IMAGE
+    write_image(image_array, fpath)
+    assert fpath.exists()
+    saved_image = np.array(Image.open(fpath))
+    assert np.array_equal(image_array, saved_image)
+
+
+def test_write_image_image(tmp_path, img_factory):
+    image_pil = img_factory()
+    fpath = tmp_path / DUMMY_IMAGE
+    write_image(image_pil, fpath)
+    assert fpath.exists()
+    saved_image = Image.open(fpath)
+    assert list(saved_image.getdata()) == list(image_pil.getdata())
+    assert np.array_equal(image_pil, saved_image)
+
+
+def test_write_image_exception(tmp_path):
+    image_array = "invalid data"
+    fpath = tmp_path / DUMMY_IMAGE
+    with patch("builtins.print") as mock_print:
+        write_image(image_array, fpath)
+        mock_print.assert_called()
+        assert not fpath.exists()
+
+
+def test_save_image_numpy(tmp_path, img_array_factory):
+    writer = ImageWriter(write_dir=tmp_path)
+    try:
+        image_array = img_array_factory()
+        fpath = tmp_path / DUMMY_IMAGE
+        fpath.parent.mkdir(parents=True, exist_ok=True)
+        writer.save_image(image_array, fpath)
+        writer.wait_until_done()
+        assert fpath.exists()
+        saved_image = np.array(Image.open(fpath))
+        assert np.array_equal(image_array, saved_image)
+    finally:
+        writer.stop()
+
+
+def test_save_image_numpy_multiprocessing(tmp_path, img_array_factory):
+    writer = ImageWriter(write_dir=tmp_path, num_processes=2, num_threads=2)
+    try:
+        image_array = img_array_factory()
+        fpath = tmp_path / DUMMY_IMAGE
+        writer.save_image(image_array, fpath)
+        writer.wait_until_done()
+        assert fpath.exists()
+        saved_image = np.array(Image.open(fpath))
+        assert np.array_equal(image_array, saved_image)
+    finally:
+        writer.stop()
+
+
+def test_save_image_torch(tmp_path, img_tensor_factory):
+    writer = ImageWriter(write_dir=tmp_path)
+    try:
+        image_tensor = img_tensor_factory()
+        fpath = tmp_path / DUMMY_IMAGE
+        fpath.parent.mkdir(parents=True, exist_ok=True)
+        writer.save_image(image_tensor, fpath)
+        writer.wait_until_done()
+        assert fpath.exists()
+        saved_image = np.array(Image.open(fpath))
+        expected_image = (image_tensor.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
+        assert np.array_equal(expected_image, saved_image)
+    finally:
+        writer.stop()
+
+
+def test_save_image_torch_multiprocessing(tmp_path, img_tensor_factory):
+    writer = ImageWriter(write_dir=tmp_path, num_processes=2, num_threads=2)
+    try:
+        image_tensor = img_tensor_factory()
+        fpath = tmp_path / DUMMY_IMAGE
+        writer.save_image(image_tensor, fpath)
+        writer.wait_until_done()
+        assert fpath.exists()
+        saved_image = np.array(Image.open(fpath))
+        expected_image = (image_tensor.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
+        assert np.array_equal(expected_image, saved_image)
+    finally:
+        writer.stop()
+
+
+def test_save_image_pil(tmp_path, img_factory):
+    writer = ImageWriter(write_dir=tmp_path)
+    try:
+        image_pil = img_factory()
+        fpath = tmp_path / DUMMY_IMAGE
+        fpath.parent.mkdir(parents=True, exist_ok=True)
+        writer.save_image(image_pil, fpath)
+        writer.wait_until_done()
+        assert fpath.exists()
+        saved_image = Image.open(fpath)
+        assert list(saved_image.getdata()) == list(image_pil.getdata())
+    finally:
+        writer.stop()
+
+
+def test_save_image_pil_multiprocessing(tmp_path, img_factory):
+    writer = ImageWriter(write_dir=tmp_path, num_processes=2, num_threads=2)
+    try:
+        image_pil = img_factory()
+        fpath = tmp_path / DUMMY_IMAGE
+        writer.save_image(image_pil, fpath)
+        writer.wait_until_done()
+        assert fpath.exists()
+        saved_image = Image.open(fpath)
+        assert list(saved_image.getdata()) == list(image_pil.getdata())
+    finally:
+        writer.stop()
+
+
+def test_save_image_invalid_data(tmp_path):
+    writer = ImageWriter(write_dir=tmp_path)
+    try:
+        image_array = "invalid data"
+        fpath = writer.get_image_file_path(0, "test_key", 0)
+        fpath.parent.mkdir(parents=True, exist_ok=True)
+        with patch("builtins.print") as mock_print:
+            writer.save_image(image_array, fpath)
+            writer.wait_until_done()
+            mock_print.assert_called()
+            assert not fpath.exists()
+    finally:
+        writer.stop()
+
+
+def test_save_image_after_stop(tmp_path, img_array_factory):
+    writer = ImageWriter(write_dir=tmp_path)
+    writer.stop()
+    image_array = img_array_factory()
+    fpath = writer.get_image_file_path(0, "test_key", 0)
+    writer.save_image(image_array, fpath)
+    time.sleep(1)
+    assert not fpath.exists()
+
+
+def test_stop(tmp_path):
+    writer = ImageWriter(write_dir=tmp_path, num_processes=0, num_threads=2)
+    writer.stop()
+    assert not any(t.is_alive() for t in writer.threads)
+
+
+def test_stop_multiprocessing(tmp_path):
+    writer = ImageWriter(write_dir=tmp_path, num_processes=2, num_threads=2)
+    writer.stop()
+    assert not any(p.is_alive() for p in writer.processes)
+
+
+def test_multiple_stops(tmp_path):
+    writer = ImageWriter(write_dir=tmp_path)
+    writer.stop()
+    writer.stop()  # Should not raise an exception
+    assert not any(t.is_alive() for t in writer.threads)
+
+
+def test_multiple_stops_multiprocessing(tmp_path):
+    writer = ImageWriter(write_dir=tmp_path, num_processes=2, num_threads=2)
+    writer.stop()
+    writer.stop()  # Should not raise an exception
+    assert not any(t.is_alive() for t in writer.threads)
+
+
+def test_wait_until_done(tmp_path, img_array_factory):
+    writer = ImageWriter(write_dir=tmp_path, num_processes=0, num_threads=4)
+    try:
+        num_images = 100
+        image_arrays = [img_array_factory(width=500, height=500) for _ in range(num_images)]
+        fpaths = [writer.get_image_file_path(0, "test_key", i) for i in range(num_images)]
+        for image_array, fpath in zip(image_arrays, fpaths, strict=True):
+            fpath.parent.mkdir(parents=True, exist_ok=True)
+            writer.save_image(image_array, fpath)
+        writer.wait_until_done()
+        for i, fpath in enumerate(fpaths):
+            assert fpath.exists()
+            saved_image = np.array(Image.open(fpath))
+            assert np.array_equal(saved_image, image_arrays[i])
+    finally:
+        writer.stop()
+
+
+def test_wait_until_done_multiprocessing(tmp_path, img_array_factory):
+    writer = ImageWriter(write_dir=tmp_path, num_processes=2, num_threads=2)
+    try:
+        num_images = 100
+        image_arrays = [img_array_factory() for _ in range(num_images)]
+        fpaths = [writer.get_image_file_path(0, "test_key", i) for i in range(num_images)]
+        for image_array, fpath in zip(image_arrays, fpaths, strict=True):
+            fpath.parent.mkdir(parents=True, exist_ok=True)
+            writer.save_image(image_array, fpath)
+        writer.wait_until_done()
+        for i, fpath in enumerate(fpaths):
+            assert fpath.exists()
+            saved_image = np.array(Image.open(fpath))
+            assert np.array_equal(saved_image, image_arrays[i])
+    finally:
+        writer.stop()
+
+
+def test_exception_handling(tmp_path, img_array_factory):
+    writer = ImageWriter(write_dir=tmp_path)
+    try:
+        image_array = img_array_factory()
+        with (
+            patch.object(writer.queue, "put", side_effect=queue.Full("Queue is full")),
+            pytest.raises(queue.Full) as exc_info,
+        ):
+            writer.save_image(image_array, tmp_path / "test.png")
+        assert str(exc_info.value) == "Queue is full"
+    finally:
+        writer.stop()
+
+
+def test_with_different_image_formats(tmp_path, img_array_factory):
+    writer = ImageWriter(write_dir=tmp_path)
+    try:
+        image_array = img_array_factory()
+        formats = ["png", "jpeg", "bmp"]
+        for fmt in formats:
+            fpath = tmp_path / f"test_image.{fmt}"
+            write_image(image_array, fpath)
+            assert fpath.exists()
+    finally:
+        writer.stop()
+
+
+def test_safe_stop_image_writer_decorator():
+    class MockDataset:
+        def __init__(self):
+            self.image_writer = MagicMock(spec=ImageWriter)
+
+    @safe_stop_image_writer
+    def function_that_raises_exception(dataset=None):
+        raise Exception("Test exception")
+
+    dataset = MockDataset()
+
+    with pytest.raises(Exception) as exc_info:
+        function_that_raises_exception(dataset=dataset)
+
+    assert str(exc_info.value) == "Test exception"
+    dataset.image_writer.stop.assert_called_once()
+
+
+def test_main_process_time(tmp_path, img_tensor_factory):
+    writer = ImageWriter(write_dir=tmp_path)
+    try:
+        image_tensor = img_tensor_factory()
+        fpath = tmp_path / "test_main_process_time.png"
+        start_time = time.perf_counter()
+        writer.save_image(image_tensor, fpath)
+        end_time = time.perf_counter()
+        time_spent = end_time - start_time
+        # Might need to adjust this threshold depending on hardware
+        assert time_spent < 0.01, f"Main process time exceeded threshold: {time_spent}s"
+        writer.wait_until_done()
+        assert fpath.exists()
+    finally:
+        writer.stop()