From f8cfe2e1a923ec8108be04e35506184bbae4eb39 Mon Sep 17 00:00:00 2001 From: Tavish Date: Mon, 7 Apr 2025 21:54:33 +0800 Subject: [PATCH] support depth image when aggregating stats --- lerobot/common/datasets/compute_stats.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/lerobot/common/datasets/compute_stats.py b/lerobot/common/datasets/compute_stats.py index 1149ec83..f3234e3b 100644 --- a/lerobot/common/datasets/compute_stats.py +++ b/lerobot/common/datasets/compute_stats.py @@ -119,8 +119,11 @@ def _assert_type_and_shape(stats_list: list[dict[str, dict]]): raise ValueError("Number of dimensions must be at least 1, and is 0 instead.") if k == "count" and v.shape != (1,): raise ValueError(f"Shape of 'count' must be (1), but is {v.shape} instead.") - if "image" in fkey and k != "count" and v.shape != (3, 1, 1): - raise ValueError(f"Shape of '{k}' must be (3,1,1), but is {v.shape} instead.") + if "image" in fkey and k != "count": + if "depth" not in fkey and v.shape != (3, 1, 1): + raise ValueError(f"Shape of '{k}' must be (3,1,1), but is {v.shape} instead.") + if "depth" in fkey and v.shape != (1, 1, 1): + raise ValueError(f"Shape of '{k}' must be (1,1,1), but is {v.shape} instead.") def aggregate_feature_stats(stats_ft_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]: