Add more options to img factories
This commit is contained in:
parent
6b2ec1ed77
commit
7a342db9c4
|
@ -41,16 +41,24 @@ def get_task_index(tasks_dicts: dict, task: str) -> int:
|
|||
|
||||
@pytest.fixture(scope="session")
|
||||
def img_tensor_factory():
|
||||
def _create_img_tensor(width=100, height=100) -> torch.Tensor:
|
||||
return torch.rand((3, height, width), dtype=torch.float32)
|
||||
def _create_img_tensor(width=100, height=100, channels=3, dtype=torch.float32) -> torch.Tensor:
|
||||
return torch.rand((channels, height, width), dtype=dtype)
|
||||
|
||||
return _create_img_tensor
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def img_array_factory():
|
||||
def _create_img_array(width=100, height=100) -> np.ndarray:
|
||||
return np.random.randint(0, 256, size=(height, width, 3), dtype=np.uint8)
|
||||
def _create_img_array(width=100, height=100, channels=3, dtype=np.uint8) -> np.ndarray:
|
||||
if np.issubdtype(dtype, np.unsignedinteger):
|
||||
# Int array in [0, 255] range
|
||||
img_array = np.random.randint(0, 256, size=(height, width, channels), dtype=dtype)
|
||||
elif np.issubdtype(dtype, np.floating):
|
||||
# Float array in [0, 1] range
|
||||
img_array = np.random.rand(height, width, channels).astype(dtype)
|
||||
else:
|
||||
raise ValueError(dtype)
|
||||
return img_array
|
||||
|
||||
return _create_img_array
|
||||
|
||||
|
|
Loading…
Reference in New Issue