remove beta annealing

parent 17d12db7c4
commit 4e9b4dd380
@@ -132,11 +132,9 @@ class PrioritizedSampler(Sampler[int]):
         self,
         data_len: int,
         alpha: float = 0.6,
+        beta: float = 0.4, # For important sampling
         eps: float = 1e-6,
         num_samples_per_epoch: Optional[int] = None,
-        beta_start: float = 0.4,
-        beta_end: float = 1.0,
-        total_steps: int = 1,
     ):
         """
         Args:
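The beta added here is the importance-sampling exponent that replaces the annealed schedule. Assuming this class follows the usual prioritized-sampling formulation (the alpha usage itself is not visible in this diff), items are drawn with probability P(i) proportional to priority_i ** alpha, and the bias of that non-uniform sampling is corrected with weights w_i = (N * P(i)) ** (-beta), which is what the compute_is_weights hunk further down evaluates with the now-fixed beta.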
@@ -148,12 +146,9 @@ class PrioritizedSampler(Sampler[int]):
         """
         self.data_len = data_len
         self.alpha = alpha
+        self.beta = beta
         self.eps = eps
         self.num_samples_per_epoch = num_samples_per_epoch or data_len
-        self.beta_start = beta_start
-        self.beta_end = beta_end
-        self.total_steps = total_steps
-        self._beta = self.beta_start

         # Initialize difficulties and sum-tree
         self.difficulties = [1.0] * data_len
@@ -165,10 +160,6 @@ class PrioritizedSampler(Sampler[int]):
             for i, p in enumerate(initial_priorities):
                 self.priorities[i] = p

-    def update_beta(self, current_step: int):
-        frac = min(1.0, current_step / self.total_steps)
-        self._beta = self.beta_start + (self.beta_end - self.beta_start) * frac
-
     def update_priorities(self, indices: List[int], difficulties: List[float]):
         """
         Updates the priorities in the sum-tree.
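For context, the deleted update_beta implemented a linear anneal of the importance-sampling exponent:

    beta(t) = beta_start + (beta_end - beta_start) * min(1, t / total_steps)

so beta rose from 0.4 toward 1.0 over total_steps. After this commit the exponent is simply the fixed constructor value beta.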
@@ -199,6 +190,6 @@ class PrioritizedSampler(Sampler[int]):
         total_p = self.sumtree.total_priority()
         for idx in indices:
             p = self.priorities[idx] / total_p
-            w.append((p * self.data_len) ** (-self._beta))
+            w.append((p * self.data_len) ** (-self.beta))
         w = torch.tensor(w, dtype=torch.float32)
         return w / w.max()
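A self-contained sketch of the weight computation after this change. In the real method, total_p comes from self.sumtree.total_priority() and the per-item values from self.priorities; here a plain list stands in for both:

    import torch

    def compute_is_weights(priorities, indices, data_len, beta=0.4):
        # w_i = (N * P(i)) ** (-beta), then normalize by the max so weights lie in (0, 1].
        total_p = sum(priorities)
        w = []
        for idx in indices:
            p = priorities[idx] / total_p
            w.append((p * data_len) ** (-beta))
        w = torch.tensor(w, dtype=torch.float32)
        return w / w.max()

    # Example: item 0 has the highest priority, so it receives the smallest IS weight.
    print(compute_is_weights([4.0, 1.0, 2.0, 1.0], indices=[0, 1, 2], data_len=4))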
@@ -191,11 +191,9 @@ def train(cfg: TrainPipelineConfig):
     sampler = PrioritizedSampler(
         data_len=data_len,
         alpha=0.6,
+        beta=0.4, # For important sampling
         eps=1e-6,
         num_samples_per_epoch=data_len,
-        beta_start=0.4,
-        beta_end=1.0,
-        total_steps=cfg.steps,
     )

     dataloader = torch.utils.data.DataLoader(
@@ -234,7 +232,6 @@ def train(cfg: TrainPipelineConfig):
                 batch[key] = batch[key].to(device, non_blocking=True)

             if "indices" in batch:
-                sampler.update_beta(step)
                 is_weights = sampler.compute_is_weights(batch["indices"].cpu().tolist())
                 batch["is_weights"] = is_weights

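The diff does not show how batch["is_weights"] is consumed downstream. A common pattern (an assumption for illustration, not taken from this repository) is to scale the per-sample loss by the weights before reducing to a scalar:

    import torch

    # Hypothetical example: pred/target stand in for the model output and labels.
    pred = torch.randn(4, 3, requires_grad=True)
    target = torch.randn(4, 3)
    is_weights = torch.tensor([1.0, 0.8, 0.5, 0.9])  # as returned by compute_is_weights

    # Weight each sample's loss by its IS correction, then reduce and backprop.
    per_sample_loss = torch.nn.functional.mse_loss(pred, target, reduction="none").mean(dim=1)
    loss = (is_weights * per_sample_loss).mean()
    loss.backward()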