From 82e32f1fcdb2fca761adbe57e4181327ea00b022 Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Wed, 5 Jun 2024 14:18:18 +0000 Subject: [PATCH] Fix RandomSubsetApply weighted sampling --- lerobot/common/datasets/transforms.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/lerobot/common/datasets/transforms.py b/lerobot/common/datasets/transforms.py index 463c3f95..cf1baced 100644 --- a/lerobot/common/datasets/transforms.py +++ b/lerobot/common/datasets/transforms.py @@ -45,14 +45,15 @@ class RandomSubsetApply(Transform): raise ValueError(f"n_subset should be in the interval [0, {len(transforms)}]") self.transforms = transforms + total = sum(p) + self.p = [prob / total for prob in p] self.n_subset = n_subset self.random_order = random_order def forward(self, *inputs: Any) -> Any: needs_unpacking = len(inputs) > 1 - indices = torch.arange(len(self.transforms)) - selected_indices = torch.randperm(len(indices))[: self.n_subset] + selected_indices = torch.multinomial(torch.tensor(self.p), self.n_subset) if not self.random_order: selected_indices = selected_indices.sort().values