Fix RandomSubsetApply weighted sampling

This commit is contained in:
Simon Alibert 2024-06-05 14:18:18 +00:00
parent 644e77e413
commit 82e32f1fcd
1 changed files with 3 additions and 2 deletions

View File

@ -45,14 +45,15 @@ class RandomSubsetApply(Transform):
raise ValueError(f"n_subset should be in the interval [0, {len(transforms)}]") raise ValueError(f"n_subset should be in the interval [0, {len(transforms)}]")
self.transforms = transforms self.transforms = transforms
total = sum(p)
self.p = [prob / total for prob in p]
self.n_subset = n_subset self.n_subset = n_subset
self.random_order = random_order self.random_order = random_order
def forward(self, *inputs: Any) -> Any: def forward(self, *inputs: Any) -> Any:
needs_unpacking = len(inputs) > 1 needs_unpacking = len(inputs) > 1
indices = torch.arange(len(self.transforms)) selected_indices = torch.multinomial(torch.tensor(self.p), self.n_subset)
selected_indices = torch.randperm(len(indices))[: self.n_subset]
if not self.random_order: if not self.random_order:
selected_indices = selected_indices.sort().values selected_indices = selected_indices.sort().values