From 7a342db9c43bad9c1ab43dd8d17836bdc95a1254 Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Sat, 2 Nov 2024 20:01:02 +0100 Subject: [PATCH] Add more options to img factories --- tests/fixtures/dataset_factories.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/tests/fixtures/dataset_factories.py b/tests/fixtures/dataset_factories.py index b1136ffc..b489792a 100644 --- a/tests/fixtures/dataset_factories.py +++ b/tests/fixtures/dataset_factories.py @@ -41,16 +41,24 @@ def get_task_index(tasks_dicts: dict, task: str) -> int: @pytest.fixture(scope="session") def img_tensor_factory(): - def _create_img_tensor(width=100, height=100) -> torch.Tensor: - return torch.rand((3, height, width), dtype=torch.float32) + def _create_img_tensor(width=100, height=100, channels=3, dtype=torch.float32) -> torch.Tensor: + return torch.rand((channels, height, width), dtype=dtype) return _create_img_tensor @pytest.fixture(scope="session") def img_array_factory(): - def _create_img_array(width=100, height=100) -> np.ndarray: - return np.random.randint(0, 256, size=(height, width, 3), dtype=np.uint8) + def _create_img_array(width=100, height=100, channels=3, dtype=np.uint8) -> np.ndarray: + if np.issubdtype(dtype, np.unsignedinteger): + # Int array in [0, 255] range + img_array = np.random.randint(0, 256, size=(height, width, channels), dtype=dtype) + elif np.issubdtype(dtype, np.floating): + # Float array in [0, 1] range + img_array = np.random.rand(height, width, channels).astype(dtype) + else: + raise ValueError(dtype) + return img_array return _create_img_array