Add auto_downsample_height_width
This commit is contained in:
parent
ff0029f84b
commit
e2e6f6e666
|
@ -43,16 +43,32 @@ def sample_indices(data_len: int) -> list[int]:
|
||||||
return np.round(np.linspace(0, data_len - 1, num_samples)).astype(int).tolist()
|
return np.round(np.linspace(0, data_len - 1, num_samples)).astype(int).tolist()
|
||||||
|
|
||||||
|
|
||||||
|
def auto_downsample_height_width(
    img: np.ndarray, target_size: int = 150, max_size_threshold: int = 300
) -> np.ndarray:
    """Subsample the last two axes of a 3-D image array by an integer stride.

    The array's shape is unpacked as ``(_, height, width)``. If the larger of
    ``height`` and ``width`` is below ``max_size_threshold`` the input is
    returned unchanged (the very same object). Otherwise every
    ``downsample_factor``-th row and column is kept, where
    ``downsample_factor`` is the larger side divided by ``target_size``,
    truncated to an integer.

    Args:
        img: 3-D array whose last two axes are treated as height and width.
        target_size: Approximate size wanted for the larger side after
            subsampling.
        max_size_threshold: Images whose larger side is below this value are
            left untouched.

    Returns:
        ``img`` itself when no downsampling is needed, otherwise the strided
        slice ``img[:, ::factor, ::factor]`` (basic slicing, so no pixel data
        is copied).
    """
    _, height, width = img.shape

    larger_side = max(width, height)
    if larger_side < max_size_threshold:
        # no downsampling needed
        return img

    # Clamp to 1: when target_size exceeds the larger side the truncated ratio
    # is 0, and a slice step of 0 raises ValueError.
    downsample_factor = max(1, int(larger_side / target_size))
    return img[:, ::downsample_factor, ::downsample_factor]
|
||||||
|
|
||||||
|
|
||||||
def sample_images(image_paths: list[str]) -> np.ndarray:
    """Load a sampled subset of images into a single uint8 array.

    Indices are chosen by ``sample_indices(len(image_paths))``. Each selected
    image is loaded with ``load_image_as_numpy`` (uint8, channel-first) and
    passed through ``auto_downsample_height_width`` before being written into
    a preallocated output array.

    Args:
        image_paths: Paths of the candidate images.

    Returns:
        Array of shape ``(num_sampled, *image_shape)`` with dtype ``uint8``,
        where ``image_shape`` is the shape of the first processed image.

    Raises:
        ValueError: If no image was sampled, or if a processed image does not
            have the same shape as the first one.
    """
    sampled_indices = sample_indices(len(image_paths))

    images = None
    for i, idx in enumerate(sampled_indices):
        path = image_paths[idx]
        # we load as uint8 to reduce memory usage
        img = load_image_as_numpy(path, dtype=np.uint8, channel_first=True)
        img = auto_downsample_height_width(img)

        if images is None:
            # Allocate the output once, sized from the first processed image,
            # instead of collecting a list and stacking it afterwards.
            images = np.empty((len(sampled_indices), *img.shape), dtype=np.uint8)
        elif img.shape != images.shape[1:]:
            # Plain assignment would silently broadcast compatible shapes.
            raise ValueError(
                f"Image at {path!r} has shape {img.shape}, expected {images.shape[1:]}."
            )

        images[i] = img

    if images is None:
        # Nothing was sampled; fail loudly rather than return None from a
        # function annotated to return an array.
        raise ValueError("Cannot sample images: no indices were sampled from `image_paths`.")

    return images
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue