remove beta annealing

This commit is contained in:
Pepijn 2025-03-14 13:22:22 +01:00
parent 17d12db7c4
commit 4e9b4dd380
2 changed files with 4 additions and 16 deletions

View File

@ -132,11 +132,9 @@ class PrioritizedSampler(Sampler[int]):
self,
data_len: int,
alpha: float = 0.6,
beta: float = 0.4, # For importance sampling
eps: float = 1e-6,
num_samples_per_epoch: Optional[int] = None,
beta_start: float = 0.4,
beta_end: float = 1.0,
total_steps: int = 1,
):
"""
Args:
@ -148,12 +146,9 @@ class PrioritizedSampler(Sampler[int]):
"""
self.data_len = data_len
self.alpha = alpha
self.beta = beta
self.eps = eps
self.num_samples_per_epoch = num_samples_per_epoch or data_len
self.beta_start = beta_start
self.beta_end = beta_end
self.total_steps = total_steps
self._beta = self.beta_start
# Initialize difficulties and sum-tree
self.difficulties = [1.0] * data_len
@ -165,10 +160,6 @@ class PrioritizedSampler(Sampler[int]):
for i, p in enumerate(initial_priorities):
self.priorities[i] = p
def update_beta(self, current_step: int):
    """
    Anneal the importance-sampling exponent linearly with training progress.

    The result is stored in ``self._beta`` and moves from ``self.beta_start``
    (at step 0) to ``self.beta_end`` (at ``self.total_steps`` and beyond).

    Args:
        current_step: Index of the current training step.
    """
    if self.total_steps <= 0:
        # No annealing horizon configured: avoid dividing by zero and treat
        # the schedule as already finished.
        frac = 1.0
    else:
        # Clamp to [0, 1] so beta never leaves [beta_start, beta_end], even
        # for a negative step or a step past the end of the schedule.
        frac = min(1.0, max(0.0, current_step / self.total_steps))
    self._beta = self.beta_start + (self.beta_end - self.beta_start) * frac
def update_priorities(self, indices: List[int], difficulties: List[float]):
"""
Updates the priorities in the sum-tree.
@ -199,6 +190,6 @@ class PrioritizedSampler(Sampler[int]):
total_p = self.sumtree.total_priority()
for idx in indices:
p = self.priorities[idx] / total_p
w.append((p * self.data_len) ** (-self._beta))
w.append((p * self.data_len) ** (-self.beta))
w = torch.tensor(w, dtype=torch.float32)
return w / w.max()

View File

@ -191,11 +191,9 @@ def train(cfg: TrainPipelineConfig):
sampler = PrioritizedSampler(
data_len=data_len,
alpha=0.6,
beta=0.4, # For importance sampling
eps=1e-6,
num_samples_per_epoch=data_len,
beta_start=0.4,
beta_end=1.0,
total_steps=cfg.steps,
)
dataloader = torch.utils.data.DataLoader(
@ -234,7 +232,6 @@ def train(cfg: TrainPipelineConfig):
batch[key] = batch[key].to(device, non_blocking=True)
if "indices" in batch:
sampler.update_beta(step)
is_weights = sampler.compute_is_weights(batch["indices"].cpu().tolist())
batch["is_weights"] = is_weights