Add test_image_writer, accept PIL images, improve ImageWriter perf in main process
This commit is contained in:
parent
375abd3020
commit
6b2ec1ed77
|
@ -19,8 +19,8 @@ import threading
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import PIL.Image
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
DEFAULT_IMAGE_PATH = "{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"
|
DEFAULT_IMAGE_PATH = "{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"
|
||||||
|
|
||||||
|
@ -40,10 +40,27 @@ def safe_stop_image_writer(func):
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
def write_image(image_array: np.ndarray, fpath: Path):
|
def image_array_to_image(image_array: np.ndarray) -> PIL.Image.Image:
|
||||||
|
# TODO(aliberts): handle 1 channel and 4 for depth images
|
||||||
|
if image_array.ndim == 3 and image_array.shape[0] in [1, 3]:
|
||||||
|
# Transpose from pytorch convention (C, H, W) to (H, W, C)
|
||||||
|
image_array = image_array.transpose(1, 2, 0)
|
||||||
|
if image_array.dtype != np.uint8:
|
||||||
|
# Assume the image is in [0, 1] range for floating-point data
|
||||||
|
image_array = np.clip(image_array, 0, 1)
|
||||||
|
image_array = (image_array * 255).astype(np.uint8)
|
||||||
|
return PIL.Image.fromarray(image_array)
|
||||||
|
|
||||||
|
|
||||||
|
def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path):
|
||||||
try:
|
try:
|
||||||
image = Image.fromarray(image_array)
|
if isinstance(image, np.ndarray):
|
||||||
image.save(fpath)
|
img = image_array_to_image(image)
|
||||||
|
elif isinstance(image, PIL.Image.Image):
|
||||||
|
img = image
|
||||||
|
else:
|
||||||
|
raise TypeError(f"Unsupported image type: {type(image)}")
|
||||||
|
img.save(fpath)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error writing image {fpath}: {e}")
|
print(f"Error writing image {fpath}: {e}")
|
||||||
|
|
||||||
|
@ -63,7 +80,6 @@ def worker_process(queue: queue.Queue, num_threads: int):
|
||||||
threads = []
|
threads = []
|
||||||
for _ in range(num_threads):
|
for _ in range(num_threads):
|
||||||
t = threading.Thread(target=worker_thread_process, args=(queue,))
|
t = threading.Thread(target=worker_thread_process, args=(queue,))
|
||||||
t.daemon = True
|
|
||||||
t.start()
|
t.start()
|
||||||
threads.append(t)
|
threads.append(t)
|
||||||
for t in threads:
|
for t in threads:
|
||||||
|
@ -95,6 +111,10 @@ class ImageWriter:
|
||||||
self.queue = None
|
self.queue = None
|
||||||
self.threads = []
|
self.threads = []
|
||||||
self.processes = []
|
self.processes = []
|
||||||
|
self._stopped = False
|
||||||
|
|
||||||
|
if num_threads <= 0 and num_processes <= 0:
|
||||||
|
raise ValueError("Number of threads and processes must be greater than zero.")
|
||||||
|
|
||||||
if self.num_processes == 0:
|
if self.num_processes == 0:
|
||||||
# Use threading
|
# Use threading
|
||||||
|
@ -109,7 +129,6 @@ class ImageWriter:
|
||||||
self.queue = multiprocessing.JoinableQueue()
|
self.queue = multiprocessing.JoinableQueue()
|
||||||
for _ in range(self.num_processes):
|
for _ in range(self.num_processes):
|
||||||
p = multiprocessing.Process(target=worker_process, args=(self.queue, self.num_threads))
|
p = multiprocessing.Process(target=worker_process, args=(self.queue, self.num_threads))
|
||||||
p.daemon = True
|
|
||||||
p.start()
|
p.start()
|
||||||
self.processes.append(p)
|
self.processes.append(p)
|
||||||
|
|
||||||
|
@ -124,27 +143,33 @@ class ImageWriter:
|
||||||
episode_index=episode_index, image_key=image_key, frame_index=0
|
episode_index=episode_index, image_key=image_key, frame_index=0
|
||||||
).parent
|
).parent
|
||||||
|
|
||||||
def save_image(self, image_array: torch.Tensor | np.ndarray, fpath: Path):
|
def save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path):
|
||||||
if isinstance(image_array, torch.Tensor):
|
if isinstance(image, torch.Tensor):
|
||||||
image_array = image_array.numpy()
|
# Convert tensor to numpy array to minimize main process time
|
||||||
self.queue.put((image_array, fpath))
|
image = image.cpu().numpy()
|
||||||
|
self.queue.put((image, fpath))
|
||||||
|
|
||||||
def wait_until_done(self):
|
def wait_until_done(self):
|
||||||
self.queue.join()
|
self.queue.join()
|
||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
|
if self._stopped:
|
||||||
|
return
|
||||||
|
|
||||||
if self.num_processes == 0:
|
if self.num_processes == 0:
|
||||||
# For threading
|
|
||||||
for _ in self.threads:
|
for _ in self.threads:
|
||||||
self.queue.put(None)
|
self.queue.put(None)
|
||||||
for t in self.threads:
|
for t in self.threads:
|
||||||
t.join()
|
t.join()
|
||||||
else:
|
else:
|
||||||
# For multiprocessing
|
|
||||||
num_nones = self.num_processes * self.num_threads
|
num_nones = self.num_processes * self.num_threads
|
||||||
for _ in range(num_nones):
|
for _ in range(num_nones):
|
||||||
self.queue.put(None)
|
self.queue.put(None)
|
||||||
self.queue.close()
|
|
||||||
self.queue.join_thread()
|
|
||||||
for p in self.processes:
|
for p in self.processes:
|
||||||
p.join()
|
p.join()
|
||||||
|
if p.is_alive():
|
||||||
|
p.terminate()
|
||||||
|
self.queue.close()
|
||||||
|
self.queue.join_thread()
|
||||||
|
|
||||||
|
self._stopped = True
|
||||||
|
|
|
@ -590,7 +590,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
img_path.parent.mkdir(parents=True, exist_ok=True)
|
img_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
self.image_writer.save_image(
|
self.image_writer.save_image(
|
||||||
image_array=frame[cam_key],
|
image=frame[cam_key],
|
||||||
fpath=img_path,
|
fpath=img_path,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -59,7 +59,7 @@ def img_array_factory():
|
||||||
def img_factory(img_array_factory):
|
def img_factory(img_array_factory):
|
||||||
def _create_img(width=100, height=100) -> PIL.Image.Image:
|
def _create_img(width=100, height=100) -> PIL.Image.Image:
|
||||||
img_array = img_array_factory(width=width, height=height)
|
img_array = img_array_factory(width=width, height=height)
|
||||||
return PIL.Image.Image.fromarray(img_array)
|
return PIL.Image.fromarray(img_array)
|
||||||
|
|
||||||
return _create_img
|
return _create_img
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,385 @@
|
||||||
|
import queue
|
||||||
|
import time
|
||||||
|
from multiprocessing import queues
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from lerobot.common.datasets.image_writer import (
|
||||||
|
ImageWriter,
|
||||||
|
image_array_to_image,
|
||||||
|
safe_stop_image_writer,
|
||||||
|
write_image,
|
||||||
|
)
|
||||||
|
|
||||||
|
DUMMY_IMAGE = "test_image.png"
|
||||||
|
|
||||||
|
|
||||||
|
def test_init_threading(tmp_path):
|
||||||
|
writer = ImageWriter(write_dir=tmp_path, num_processes=0, num_threads=2)
|
||||||
|
try:
|
||||||
|
assert writer.num_processes == 0
|
||||||
|
assert writer.num_threads == 2
|
||||||
|
assert isinstance(writer.queue, queue.Queue)
|
||||||
|
assert len(writer.threads) == 2
|
||||||
|
assert len(writer.processes) == 0
|
||||||
|
assert all(t.is_alive() for t in writer.threads)
|
||||||
|
finally:
|
||||||
|
writer.stop()
|
||||||
|
|
||||||
|
|
||||||
|
def test_init_multiprocessing(tmp_path):
|
||||||
|
writer = ImageWriter(write_dir=tmp_path, num_processes=2, num_threads=2)
|
||||||
|
try:
|
||||||
|
assert writer.num_processes == 2
|
||||||
|
assert writer.num_threads == 2
|
||||||
|
assert isinstance(writer.queue, queues.JoinableQueue)
|
||||||
|
assert len(writer.threads) == 0
|
||||||
|
assert len(writer.processes) == 2
|
||||||
|
assert all(p.is_alive() for p in writer.processes)
|
||||||
|
finally:
|
||||||
|
writer.stop()
|
||||||
|
|
||||||
|
|
||||||
|
def test_write_dir_created(tmp_path):
|
||||||
|
write_dir = tmp_path / "non_existent_dir"
|
||||||
|
assert not write_dir.exists()
|
||||||
|
writer = ImageWriter(write_dir=write_dir)
|
||||||
|
try:
|
||||||
|
assert write_dir.exists()
|
||||||
|
finally:
|
||||||
|
writer.stop()
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_image_file_path_and_episode_dir(tmp_path):
|
||||||
|
writer = ImageWriter(write_dir=tmp_path)
|
||||||
|
try:
|
||||||
|
episode_index = 1
|
||||||
|
image_key = "test_key"
|
||||||
|
frame_index = 10
|
||||||
|
expected_episode_dir = tmp_path / f"{image_key}/episode_{episode_index:06d}"
|
||||||
|
expected_path = expected_episode_dir / f"frame_{frame_index:06d}.png"
|
||||||
|
image_file_path = writer.get_image_file_path(episode_index, image_key, frame_index)
|
||||||
|
assert image_file_path == expected_path
|
||||||
|
episode_dir = writer.get_episode_dir(episode_index, image_key)
|
||||||
|
assert episode_dir == expected_episode_dir
|
||||||
|
finally:
|
||||||
|
writer.stop()
|
||||||
|
|
||||||
|
|
||||||
|
def test_zero_threads(tmp_path):
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
ImageWriter(write_dir=tmp_path, num_processes=0, num_threads=0)
|
||||||
|
|
||||||
|
|
||||||
|
def test_image_array_to_image_rgb(img_array_factory):
|
||||||
|
img_array = img_array_factory(100, 100)
|
||||||
|
result_image = image_array_to_image(img_array)
|
||||||
|
assert isinstance(result_image, Image.Image)
|
||||||
|
assert result_image.size == (100, 100)
|
||||||
|
assert result_image.mode == "RGB"
|
||||||
|
|
||||||
|
|
||||||
|
def test_image_array_to_image_pytorch_format(img_array_factory):
|
||||||
|
img_array = img_array_factory(100, 100).transpose(2, 0, 1)
|
||||||
|
result_image = image_array_to_image(img_array)
|
||||||
|
assert isinstance(result_image, Image.Image)
|
||||||
|
assert result_image.size == (100, 100)
|
||||||
|
assert result_image.mode == "RGB"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip("TODO: implement")
|
||||||
|
def test_image_array_to_image_single_channel(img_array_factory):
|
||||||
|
img_array = img_array_factory(channels=1)
|
||||||
|
result_image = image_array_to_image(img_array)
|
||||||
|
assert isinstance(result_image, Image.Image)
|
||||||
|
assert result_image.size == (100, 100)
|
||||||
|
assert result_image.mode == "L"
|
||||||
|
|
||||||
|
|
||||||
|
def test_image_array_to_image_float_array(img_array_factory):
|
||||||
|
img_array = img_array_factory(dtype=np.float32)
|
||||||
|
result_image = image_array_to_image(img_array)
|
||||||
|
assert isinstance(result_image, Image.Image)
|
||||||
|
assert result_image.size == (100, 100)
|
||||||
|
assert result_image.mode == "RGB"
|
||||||
|
assert np.array(result_image).dtype == np.uint8
|
||||||
|
|
||||||
|
|
||||||
|
def test_image_array_to_image_out_of_bounds_float():
|
||||||
|
# Float array with values out of [0, 1]
|
||||||
|
img_array = np.random.uniform(-1, 2, size=(100, 100, 3)).astype(np.float32)
|
||||||
|
result_image = image_array_to_image(img_array)
|
||||||
|
assert isinstance(result_image, Image.Image)
|
||||||
|
assert result_image.size == (100, 100)
|
||||||
|
assert result_image.mode == "RGB"
|
||||||
|
assert np.array(result_image).dtype == np.uint8
|
||||||
|
assert np.array(result_image).min() >= 0 and np.array(result_image).max() <= 255
|
||||||
|
|
||||||
|
|
||||||
|
def test_write_image_numpy(tmp_path, img_array_factory):
|
||||||
|
image_array = img_array_factory()
|
||||||
|
fpath = tmp_path / DUMMY_IMAGE
|
||||||
|
write_image(image_array, fpath)
|
||||||
|
assert fpath.exists()
|
||||||
|
saved_image = np.array(Image.open(fpath))
|
||||||
|
assert np.array_equal(image_array, saved_image)
|
||||||
|
|
||||||
|
|
||||||
|
def test_write_image_image(tmp_path, img_factory):
|
||||||
|
image_pil = img_factory()
|
||||||
|
fpath = tmp_path / DUMMY_IMAGE
|
||||||
|
write_image(image_pil, fpath)
|
||||||
|
assert fpath.exists()
|
||||||
|
saved_image = Image.open(fpath)
|
||||||
|
assert list(saved_image.getdata()) == list(image_pil.getdata())
|
||||||
|
assert np.array_equal(image_pil, saved_image)
|
||||||
|
|
||||||
|
|
||||||
|
def test_write_image_exception(tmp_path):
|
||||||
|
image_array = "invalid data"
|
||||||
|
fpath = tmp_path / DUMMY_IMAGE
|
||||||
|
with patch("builtins.print") as mock_print:
|
||||||
|
write_image(image_array, fpath)
|
||||||
|
mock_print.assert_called()
|
||||||
|
assert not fpath.exists()
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_image_numpy(tmp_path, img_array_factory):
|
||||||
|
writer = ImageWriter(write_dir=tmp_path)
|
||||||
|
try:
|
||||||
|
image_array = img_array_factory()
|
||||||
|
fpath = tmp_path / DUMMY_IMAGE
|
||||||
|
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
writer.save_image(image_array, fpath)
|
||||||
|
writer.wait_until_done()
|
||||||
|
assert fpath.exists()
|
||||||
|
saved_image = np.array(Image.open(fpath))
|
||||||
|
assert np.array_equal(image_array, saved_image)
|
||||||
|
finally:
|
||||||
|
writer.stop()
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_image_numpy_multiprocessing(tmp_path, img_array_factory):
|
||||||
|
writer = ImageWriter(write_dir=tmp_path, num_processes=2, num_threads=2)
|
||||||
|
try:
|
||||||
|
image_array = img_array_factory()
|
||||||
|
fpath = tmp_path / DUMMY_IMAGE
|
||||||
|
writer.save_image(image_array, fpath)
|
||||||
|
writer.wait_until_done()
|
||||||
|
assert fpath.exists()
|
||||||
|
saved_image = np.array(Image.open(fpath))
|
||||||
|
assert np.array_equal(image_array, saved_image)
|
||||||
|
finally:
|
||||||
|
writer.stop()
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_image_torch(tmp_path, img_tensor_factory):
|
||||||
|
writer = ImageWriter(write_dir=tmp_path)
|
||||||
|
try:
|
||||||
|
image_tensor = img_tensor_factory()
|
||||||
|
fpath = tmp_path / DUMMY_IMAGE
|
||||||
|
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
writer.save_image(image_tensor, fpath)
|
||||||
|
writer.wait_until_done()
|
||||||
|
assert fpath.exists()
|
||||||
|
saved_image = np.array(Image.open(fpath))
|
||||||
|
expected_image = (image_tensor.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
|
||||||
|
assert np.array_equal(expected_image, saved_image)
|
||||||
|
finally:
|
||||||
|
writer.stop()
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_image_torch_multiprocessing(tmp_path, img_tensor_factory):
|
||||||
|
writer = ImageWriter(write_dir=tmp_path, num_processes=2, num_threads=2)
|
||||||
|
try:
|
||||||
|
image_tensor = img_tensor_factory()
|
||||||
|
fpath = tmp_path / DUMMY_IMAGE
|
||||||
|
writer.save_image(image_tensor, fpath)
|
||||||
|
writer.wait_until_done()
|
||||||
|
assert fpath.exists()
|
||||||
|
saved_image = np.array(Image.open(fpath))
|
||||||
|
expected_image = (image_tensor.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
|
||||||
|
assert np.array_equal(expected_image, saved_image)
|
||||||
|
finally:
|
||||||
|
writer.stop()
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_image_pil(tmp_path, img_factory):
|
||||||
|
writer = ImageWriter(write_dir=tmp_path)
|
||||||
|
try:
|
||||||
|
image_pil = img_factory()
|
||||||
|
fpath = tmp_path / DUMMY_IMAGE
|
||||||
|
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
writer.save_image(image_pil, fpath)
|
||||||
|
writer.wait_until_done()
|
||||||
|
assert fpath.exists()
|
||||||
|
saved_image = Image.open(fpath)
|
||||||
|
assert list(saved_image.getdata()) == list(image_pil.getdata())
|
||||||
|
finally:
|
||||||
|
writer.stop()
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_image_pil_multiprocessing(tmp_path, img_factory):
|
||||||
|
writer = ImageWriter(write_dir=tmp_path, num_processes=2, num_threads=2)
|
||||||
|
try:
|
||||||
|
image_pil = img_factory()
|
||||||
|
fpath = tmp_path / DUMMY_IMAGE
|
||||||
|
writer.save_image(image_pil, fpath)
|
||||||
|
writer.wait_until_done()
|
||||||
|
assert fpath.exists()
|
||||||
|
saved_image = Image.open(fpath)
|
||||||
|
assert list(saved_image.getdata()) == list(image_pil.getdata())
|
||||||
|
finally:
|
||||||
|
writer.stop()
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_image_invalid_data(tmp_path):
|
||||||
|
writer = ImageWriter(write_dir=tmp_path)
|
||||||
|
try:
|
||||||
|
image_array = "invalid data"
|
||||||
|
fpath = writer.get_image_file_path(0, "test_key", 0)
|
||||||
|
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
with patch("builtins.print") as mock_print:
|
||||||
|
writer.save_image(image_array, fpath)
|
||||||
|
writer.wait_until_done()
|
||||||
|
mock_print.assert_called()
|
||||||
|
assert not fpath.exists()
|
||||||
|
finally:
|
||||||
|
writer.stop()
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_image_after_stop(tmp_path, img_array_factory):
|
||||||
|
writer = ImageWriter(write_dir=tmp_path)
|
||||||
|
writer.stop()
|
||||||
|
image_array = img_array_factory()
|
||||||
|
fpath = writer.get_image_file_path(0, "test_key", 0)
|
||||||
|
writer.save_image(image_array, fpath)
|
||||||
|
time.sleep(1)
|
||||||
|
assert not fpath.exists()
|
||||||
|
|
||||||
|
|
||||||
|
def test_stop(tmp_path):
|
||||||
|
writer = ImageWriter(write_dir=tmp_path, num_processes=0, num_threads=2)
|
||||||
|
writer.stop()
|
||||||
|
assert not any(t.is_alive() for t in writer.threads)
|
||||||
|
|
||||||
|
|
||||||
|
def test_stop_multiprocessing(tmp_path):
|
||||||
|
writer = ImageWriter(write_dir=tmp_path, num_processes=2, num_threads=2)
|
||||||
|
writer.stop()
|
||||||
|
assert not any(p.is_alive() for p in writer.processes)
|
||||||
|
|
||||||
|
|
||||||
|
def test_multiple_stops(tmp_path):
|
||||||
|
writer = ImageWriter(write_dir=tmp_path)
|
||||||
|
writer.stop()
|
||||||
|
writer.stop() # Should not raise an exception
|
||||||
|
assert not any(t.is_alive() for t in writer.threads)
|
||||||
|
|
||||||
|
|
||||||
|
def test_multiple_stops_multiprocessing(tmp_path):
|
||||||
|
writer = ImageWriter(write_dir=tmp_path, num_processes=2, num_threads=2)
|
||||||
|
writer.stop()
|
||||||
|
writer.stop() # Should not raise an exception
|
||||||
|
assert not any(t.is_alive() for t in writer.threads)
|
||||||
|
|
||||||
|
|
||||||
|
def test_wait_until_done(tmp_path, img_array_factory):
|
||||||
|
writer = ImageWriter(write_dir=tmp_path, num_processes=0, num_threads=4)
|
||||||
|
try:
|
||||||
|
num_images = 100
|
||||||
|
image_arrays = [img_array_factory(width=500, height=500) for _ in range(num_images)]
|
||||||
|
fpaths = [writer.get_image_file_path(0, "test_key", i) for i in range(num_images)]
|
||||||
|
for image_array, fpath in zip(image_arrays, fpaths, strict=True):
|
||||||
|
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
writer.save_image(image_array, fpath)
|
||||||
|
writer.wait_until_done()
|
||||||
|
for i, fpath in enumerate(fpaths):
|
||||||
|
assert fpath.exists()
|
||||||
|
saved_image = np.array(Image.open(fpath))
|
||||||
|
assert np.array_equal(saved_image, image_arrays[i])
|
||||||
|
finally:
|
||||||
|
writer.stop()
|
||||||
|
|
||||||
|
|
||||||
|
def test_wait_until_done_multiprocessing(tmp_path, img_array_factory):
|
||||||
|
writer = ImageWriter(write_dir=tmp_path, num_processes=2, num_threads=2)
|
||||||
|
try:
|
||||||
|
num_images = 100
|
||||||
|
image_arrays = [img_array_factory() for _ in range(num_images)]
|
||||||
|
fpaths = [writer.get_image_file_path(0, "test_key", i) for i in range(num_images)]
|
||||||
|
for image_array, fpath in zip(image_arrays, fpaths, strict=True):
|
||||||
|
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
writer.save_image(image_array, fpath)
|
||||||
|
writer.wait_until_done()
|
||||||
|
for i, fpath in enumerate(fpaths):
|
||||||
|
assert fpath.exists()
|
||||||
|
saved_image = np.array(Image.open(fpath))
|
||||||
|
assert np.array_equal(saved_image, image_arrays[i])
|
||||||
|
finally:
|
||||||
|
writer.stop()
|
||||||
|
|
||||||
|
|
||||||
|
def test_exception_handling(tmp_path, img_array_factory):
|
||||||
|
writer = ImageWriter(write_dir=tmp_path)
|
||||||
|
try:
|
||||||
|
image_array = img_array_factory()
|
||||||
|
with (
|
||||||
|
patch.object(writer.queue, "put", side_effect=queue.Full("Queue is full")),
|
||||||
|
pytest.raises(queue.Full) as exc_info,
|
||||||
|
):
|
||||||
|
writer.save_image(image_array, tmp_path / "test.png")
|
||||||
|
assert str(exc_info.value) == "Queue is full"
|
||||||
|
finally:
|
||||||
|
writer.stop()
|
||||||
|
|
||||||
|
|
||||||
|
def test_with_different_image_formats(tmp_path, img_array_factory):
|
||||||
|
writer = ImageWriter(write_dir=tmp_path)
|
||||||
|
try:
|
||||||
|
image_array = img_array_factory()
|
||||||
|
formats = ["png", "jpeg", "bmp"]
|
||||||
|
for fmt in formats:
|
||||||
|
fpath = tmp_path / f"test_image.{fmt}"
|
||||||
|
write_image(image_array, fpath)
|
||||||
|
assert fpath.exists()
|
||||||
|
finally:
|
||||||
|
writer.stop()
|
||||||
|
|
||||||
|
|
||||||
|
def test_safe_stop_image_writer_decorator():
|
||||||
|
class MockDataset:
|
||||||
|
def __init__(self):
|
||||||
|
self.image_writer = MagicMock(spec=ImageWriter)
|
||||||
|
|
||||||
|
@safe_stop_image_writer
|
||||||
|
def function_that_raises_exception(dataset=None):
|
||||||
|
raise Exception("Test exception")
|
||||||
|
|
||||||
|
dataset = MockDataset()
|
||||||
|
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
function_that_raises_exception(dataset=dataset)
|
||||||
|
|
||||||
|
assert str(exc_info.value) == "Test exception"
|
||||||
|
dataset.image_writer.stop.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
def test_main_process_time(tmp_path, img_tensor_factory):
|
||||||
|
writer = ImageWriter(write_dir=tmp_path)
|
||||||
|
try:
|
||||||
|
image_tensor = img_tensor_factory()
|
||||||
|
fpath = tmp_path / "test_main_process_time.png"
|
||||||
|
start_time = time.perf_counter()
|
||||||
|
writer.save_image(image_tensor, fpath)
|
||||||
|
end_time = time.perf_counter()
|
||||||
|
time_spent = end_time - start_time
|
||||||
|
# Might need to adjust this threshold depending on hardware
|
||||||
|
assert time_spent < 0.01, f"Main process time exceeded threshold: {time_spent}s"
|
||||||
|
writer.wait_until_done()
|
||||||
|
assert fpath.exists()
|
||||||
|
finally:
|
||||||
|
writer.stop()
|
Loading…
Reference in New Issue