From e2e6f6e666dc3d85bde12b642643259b08cc19d1 Mon Sep 17 00:00:00 2001
From: Remi Cadene
Date: Sun, 23 Feb 2025 18:15:39 +0000
Subject: [PATCH] Add auto_downsample_height_width

---
 lerobot/common/datasets/compute_stats.py | 62 +++++++++++++++++++++++--
 1 file changed, 58 insertions(+), 4 deletions(-)

diff --git a/lerobot/common/datasets/compute_stats.py b/lerobot/common/datasets/compute_stats.py
index 7519c743..a029f892 100644
--- a/lerobot/common/datasets/compute_stats.py
+++ b/lerobot/common/datasets/compute_stats.py
@@ -43,14 +43,68 @@ def sample_indices(data_len: int) -> list[int]:
     return np.round(np.linspace(0, data_len - 1, num_samples)).astype(int).tolist()
 
 
+def auto_downsample_height_width(
+    img: np.ndarray, target_size: int = 150, max_size_threshold: int = 300
+) -> np.ndarray:
+    """Downsample a channel-first image by strided slicing when it is large.
+
+    Images whose largest spatial side is below ``max_size_threshold`` are returned
+    unchanged. Otherwise only every ``k``-th row and column is kept, where ``k`` is
+    the largest side divided by ``target_size``, rounded down and never below 1.
+
+    Args:
+        img: Array whose shape unpacks as ``(channels, height, width)``.
+        target_size: Approximate size wanted for the largest spatial side.
+        max_size_threshold: Images strictly smaller than this are left untouched.
+
+    Returns:
+        ``img`` itself, or a strided slice of it.
+    """
+    _, height, width = img.shape
+    largest_side = max(width, height)
+
+    if largest_side < max_size_threshold:
+        # no downsampling needed
+        return img
+
+    # A slice step of 0 raises ValueError, so never go below 1. This only matters
+    # when max_size_threshold is set lower than target_size.
+    downsample_factor = max(1, largest_side // target_size)
+    return img[:, ::downsample_factor, ::downsample_factor]
+
+
 def sample_images(image_paths: list[str]) -> np.ndarray:
+    """Load a subset of the images into a single uint8 array.
+
+    The indices to load come from ``sample_indices``; each loaded image goes through
+    ``auto_downsample_height_width`` before being stored.
+
+    Args:
+        image_paths: Paths of the candidate images.
+
+    Returns:
+        Array of shape ``(num_sampled, *image_shape)`` with dtype ``uint8``.
+
+    Raises:
+        ValueError: If ``sample_indices`` returns no index.
+    """
     sampled_indices = sample_indices(len(image_paths))
-    images = []
-    for idx in sampled_indices:
+    if not sampled_indices:
+        # np.stack([]) used to raise ValueError here; keep failing loudly instead of
+        # returning None from a function annotated to return an array.
+        raise ValueError("no image to sample")
+
+    images = None
+    for i, idx in enumerate(sampled_indices):
         path = image_paths[idx]
         # we load as uint8 to reduce memory usage
         img = load_image_as_numpy(path, dtype=np.uint8, channel_first=True)
-        images.append(img)
+        img = auto_downsample_height_width(img)
+
+        if images is None:
+            # allocate the output once, sized from the first downsampled image
+            images = np.empty((len(sampled_indices), *img.shape), dtype=np.uint8)
+
+        images[i] = img
 
-    images = np.stack(images)
     return images