Add more options to img factories

This commit is contained in:
Simon Alibert 2024-11-02 20:01:02 +01:00
parent 6b2ec1ed77
commit 7a342db9c4
1 changed files with 12 additions and 4 deletions

View File

@ -41,16 +41,24 @@ def get_task_index(tasks_dicts: dict, task: str) -> int:
@pytest.fixture(scope="session")
def img_tensor_factory():
def _create_img_tensor(width=100, height=100) -> torch.Tensor:
return torch.rand((3, height, width), dtype=torch.float32)
def _create_img_tensor(width=100, height=100, channels=3, dtype=torch.float32) -> torch.Tensor:
return torch.rand((channels, height, width), dtype=dtype)
return _create_img_tensor
@pytest.fixture(scope="session")
def img_array_factory():
def _create_img_array(width=100, height=100) -> np.ndarray:
return np.random.randint(0, 256, size=(height, width, 3), dtype=np.uint8)
def _create_img_array(width=100, height=100, channels=3, dtype=np.uint8) -> np.ndarray:
if np.issubdtype(dtype, np.unsignedinteger):
# Int array in [0, 255] range
img_array = np.random.randint(0, 256, size=(height, width, channels), dtype=dtype)
elif np.issubdtype(dtype, np.floating):
# Float array in [0, 1] range
img_array = np.random.rand(height, width, channels).astype(dtype)
else:
raise ValueError(dtype)
return img_array
return _create_img_array