diff --git a/tests/fixtures/dataset_factories.py b/tests/fixtures/dataset_factories.py
index b1136ffc..b489792a 100644
--- a/tests/fixtures/dataset_factories.py
+++ b/tests/fixtures/dataset_factories.py
@@ -41,16 +41,24 @@ def get_task_index(tasks_dicts: dict, task: str) -> int:
 
 @pytest.fixture(scope="session")
 def img_tensor_factory():
-    def _create_img_tensor(width=100, height=100) -> torch.Tensor:
-        return torch.rand((3, height, width), dtype=torch.float32)
+    def _create_img_tensor(width=100, height=100, channels=3, dtype=torch.float32) -> torch.Tensor:
+        return torch.rand((channels, height, width), dtype=dtype)
 
     return _create_img_tensor
 
 
 @pytest.fixture(scope="session")
 def img_array_factory():
-    def _create_img_array(width=100, height=100) -> np.ndarray:
-        return np.random.randint(0, 256, size=(height, width, 3), dtype=np.uint8)
+    def _create_img_array(width=100, height=100, channels=3, dtype=np.uint8) -> np.ndarray:
+        if np.issubdtype(dtype, np.unsignedinteger):
+            # Int array in [0, 255] range
+            img_array = np.random.randint(0, 256, size=(height, width, channels), dtype=dtype)
+        elif np.issubdtype(dtype, np.floating):
+            # Float array in [0, 1] range
+            img_array = np.random.rand(height, width, channels).astype(dtype)
+        else:
+            raise ValueError(dtype)
+        return img_array
 
     return _create_img_array