Fix RandomSubsetApply weighted sampling
This commit is contained in:
parent
644e77e413
commit
82e32f1fcd
|
@ -45,14 +45,15 @@ class RandomSubsetApply(Transform):
|
||||||
raise ValueError(f"n_subset should be in the interval [0, {len(transforms)}]")
|
raise ValueError(f"n_subset should be in the interval [0, {len(transforms)}]")
|
||||||
|
|
||||||
self.transforms = transforms
|
self.transforms = transforms
|
||||||
|
total = sum(p)
|
||||||
|
self.p = [prob / total for prob in p]
|
||||||
self.n_subset = n_subset
|
self.n_subset = n_subset
|
||||||
self.random_order = random_order
|
self.random_order = random_order
|
||||||
|
|
||||||
def forward(self, *inputs: Any) -> Any:
|
def forward(self, *inputs: Any) -> Any:
|
||||||
needs_unpacking = len(inputs) > 1
|
needs_unpacking = len(inputs) > 1
|
||||||
|
|
||||||
indices = torch.arange(len(self.transforms))
|
selected_indices = torch.multinomial(torch.tensor(self.p), self.n_subset)
|
||||||
selected_indices = torch.randperm(len(indices))[: self.n_subset]
|
|
||||||
if not self.random_order:
|
if not self.random_order:
|
||||||
selected_indices = selected_indices.sort().values
|
selected_indices = selected_indices.sort().values
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue