remove beta annealing
This commit is contained in:
parent
17d12db7c4
commit
4e9b4dd380
|
@ -132,11 +132,9 @@ class PrioritizedSampler(Sampler[int]):
|
|||
self,
|
||||
data_len: int,
|
||||
alpha: float = 0.6,
|
||||
beta: float = 0.4, # For important sampling
|
||||
eps: float = 1e-6,
|
||||
num_samples_per_epoch: Optional[int] = None,
|
||||
beta_start: float = 0.4,
|
||||
beta_end: float = 1.0,
|
||||
total_steps: int = 1,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
|
@ -148,12 +146,9 @@ class PrioritizedSampler(Sampler[int]):
|
|||
"""
|
||||
self.data_len = data_len
|
||||
self.alpha = alpha
|
||||
self.beta = beta
|
||||
self.eps = eps
|
||||
self.num_samples_per_epoch = num_samples_per_epoch or data_len
|
||||
self.beta_start = beta_start
|
||||
self.beta_end = beta_end
|
||||
self.total_steps = total_steps
|
||||
self._beta = self.beta_start
|
||||
|
||||
# Initialize difficulties and sum-tree
|
||||
self.difficulties = [1.0] * data_len
|
||||
|
@ -165,10 +160,6 @@ class PrioritizedSampler(Sampler[int]):
|
|||
for i, p in enumerate(initial_priorities):
|
||||
self.priorities[i] = p
|
||||
|
||||
def update_beta(self, current_step: int):
|
||||
frac = min(1.0, current_step / self.total_steps)
|
||||
self._beta = self.beta_start + (self.beta_end - self.beta_start) * frac
|
||||
|
||||
def update_priorities(self, indices: List[int], difficulties: List[float]):
|
||||
"""
|
||||
Updates the priorities in the sum-tree.
|
||||
|
@ -199,6 +190,6 @@ class PrioritizedSampler(Sampler[int]):
|
|||
total_p = self.sumtree.total_priority()
|
||||
for idx in indices:
|
||||
p = self.priorities[idx] / total_p
|
||||
w.append((p * self.data_len) ** (-self._beta))
|
||||
w.append((p * self.data_len) ** (-self.beta))
|
||||
w = torch.tensor(w, dtype=torch.float32)
|
||||
return w / w.max()
|
||||
|
|
|
@ -191,11 +191,9 @@ def train(cfg: TrainPipelineConfig):
|
|||
sampler = PrioritizedSampler(
|
||||
data_len=data_len,
|
||||
alpha=0.6,
|
||||
beta=0.4, # For important sampling
|
||||
eps=1e-6,
|
||||
num_samples_per_epoch=data_len,
|
||||
beta_start=0.4,
|
||||
beta_end=1.0,
|
||||
total_steps=cfg.steps,
|
||||
)
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
|
@ -234,7 +232,6 @@ def train(cfg: TrainPipelineConfig):
|
|||
batch[key] = batch[key].to(device, non_blocking=True)
|
||||
|
||||
if "indices" in batch:
|
||||
sampler.update_beta(step)
|
||||
is_weights = sampler.compute_is_weights(batch["indices"].cpu().tolist())
|
||||
batch["is_weights"] = is_weights
|
||||
|
||||
|
|
Loading…
Reference in New Issue