remove beta
This commit is contained in:
parent
6e97876e81
commit
3b6fff70e1
|
@ -132,7 +132,6 @@ class PrioritizedSampler(Sampler[int]):
|
||||||
self,
|
self,
|
||||||
data_len: int,
|
data_len: int,
|
||||||
alpha: float = 0.6,
|
alpha: float = 0.6,
|
||||||
beta: float = 0.4, # For important sampling
|
|
||||||
eps: float = 1e-6,
|
eps: float = 1e-6,
|
||||||
num_samples_per_epoch: Optional[int] = None,
|
num_samples_per_epoch: Optional[int] = None,
|
||||||
):
|
):
|
||||||
|
@ -146,7 +145,6 @@ class PrioritizedSampler(Sampler[int]):
|
||||||
"""
|
"""
|
||||||
self.data_len = data_len
|
self.data_len = data_len
|
||||||
self.alpha = alpha
|
self.alpha = alpha
|
||||||
self.beta = beta
|
|
||||||
self.eps = eps
|
self.eps = eps
|
||||||
self.num_samples_per_epoch = num_samples_per_epoch or data_len
|
self.num_samples_per_epoch = num_samples_per_epoch or data_len
|
||||||
|
|
||||||
|
|
|
@ -181,7 +181,6 @@ def train(cfg: TrainPipelineConfig):
|
||||||
sampler = PrioritizedSampler(
|
sampler = PrioritizedSampler(
|
||||||
data_len=data_len,
|
data_len=data_len,
|
||||||
alpha=0.6,
|
alpha=0.6,
|
||||||
beta=0.4, # For important sampling
|
|
||||||
eps=1e-6,
|
eps=1e-6,
|
||||||
num_samples_per_epoch=data_len,
|
num_samples_per_epoch=data_len,
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in New Issue