remove beta annealing

Pepijn 2025-03-14 13:22:22 +01:00
parent 17d12db7c4
commit 4e9b4dd380
2 changed files with 4 additions and 16 deletions


@@ -132,11 +132,9 @@ class PrioritizedSampler(Sampler[int]):
         self,
         data_len: int,
         alpha: float = 0.6,
+        beta: float = 0.4,  # For importance sampling
         eps: float = 1e-6,
         num_samples_per_epoch: Optional[int] = None,
-        beta_start: float = 0.4,
-        beta_end: float = 1.0,
-        total_steps: int = 1,
     ):
         """
         Args:
@@ -148,12 +146,9 @@ class PrioritizedSampler(Sampler[int]):
         """
         self.data_len = data_len
         self.alpha = alpha
+        self.beta = beta
         self.eps = eps
         self.num_samples_per_epoch = num_samples_per_epoch or data_len
-        self.beta_start = beta_start
-        self.beta_end = beta_end
-        self.total_steps = total_steps
-        self._beta = self.beta_start
 
         # Initialize difficulties and sum-tree
         self.difficulties = [1.0] * data_len
@@ -165,10 +160,6 @@ class PrioritizedSampler(Sampler[int]):
         for i, p in enumerate(initial_priorities):
             self.priorities[i] = p
 
-    def update_beta(self, current_step: int):
-        frac = min(1.0, current_step / self.total_steps)
-        self._beta = self.beta_start + (self.beta_end - self.beta_start) * frac
-
     def update_priorities(self, indices: List[int], difficulties: List[float]):
         """
         Updates the priorities in the sum-tree.
@@ -199,6 +190,6 @@ class PrioritizedSampler(Sampler[int]):
         total_p = self.sumtree.total_priority()
         for idx in indices:
             p = self.priorities[idx] / total_p
-            w.append((p * self.data_len) ** (-self._beta))
+            w.append((p * self.data_len) ** (-self.beta))
         w = torch.tensor(w, dtype=torch.float32)
         return w / w.max()
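With annealing removed, compute_is_weights uses the fixed exponent self.beta, i.e. w_i = (N * p_i)^(-beta), normalized by the batch maximum. A minimal standalone sketch of that computation (the function name and signature below are illustrative, not the repository's API; only the formula and the names priorities, data_len, and beta mirror the diff above):

import torch

def is_weights(priorities: list[float], indices: list[int], beta: float = 0.4) -> torch.Tensor:
    # w_i = (N * p_i)^(-beta), normalized so the largest weight is 1.
    total_p = sum(priorities)  # stand-in for sumtree.total_priority()
    n = len(priorities)
    w = [((priorities[i] / total_p) * n) ** (-beta) for i in indices]
    w = torch.tensor(w, dtype=torch.float32)
    return w / w.max()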


@ -191,11 +191,9 @@ def train(cfg: TrainPipelineConfig):
sampler = PrioritizedSampler( sampler = PrioritizedSampler(
data_len=data_len, data_len=data_len,
alpha=0.6, alpha=0.6,
beta=0.4, # For important sampling
eps=1e-6, eps=1e-6,
num_samples_per_epoch=data_len, num_samples_per_epoch=data_len,
beta_start=0.4,
beta_end=1.0,
total_steps=cfg.steps,
) )
dataloader = torch.utils.data.DataLoader( dataloader = torch.utils.data.DataLoader(
@ -234,7 +232,6 @@ def train(cfg: TrainPipelineConfig):
batch[key] = batch[key].to(device, non_blocking=True) batch[key] = batch[key].to(device, non_blocking=True)
if "indices" in batch: if "indices" in batch:
sampler.update_beta(step)
is_weights = sampler.compute_is_weights(batch["indices"].cpu().tolist()) is_weights = sampler.compute_is_weights(batch["indices"].cpu().tolist())
batch["is_weights"] = is_weights batch["is_weights"] = is_weights