From 4e9b4dd38039e9807b0bdfb5e0cc1ff83e32e62c Mon Sep 17 00:00:00 2001 From: Pepijn Date: Fri, 14 Mar 2025 13:22:22 +0100 Subject: [PATCH] remove beta annealing --- lerobot/common/datasets/sampler.py | 15 +++------------ lerobot/scripts/train.py | 5 +---- 2 files changed, 4 insertions(+), 16 deletions(-) diff --git a/lerobot/common/datasets/sampler.py b/lerobot/common/datasets/sampler.py index 31597e77..f26fba53 100644 --- a/lerobot/common/datasets/sampler.py +++ b/lerobot/common/datasets/sampler.py @@ -132,11 +132,9 @@ class PrioritizedSampler(Sampler[int]): self, data_len: int, alpha: float = 0.6, + beta: float = 0.4, # For importance sampling eps: float = 1e-6, num_samples_per_epoch: Optional[int] = None, - beta_start: float = 0.4, - beta_end: float = 1.0, - total_steps: int = 1, ): """ Args: @@ -148,12 +146,9 @@ class PrioritizedSampler(Sampler[int]): """ self.data_len = data_len self.alpha = alpha + self.beta = beta self.eps = eps self.num_samples_per_epoch = num_samples_per_epoch or data_len - self.beta_start = beta_start - self.beta_end = beta_end - self.total_steps = total_steps - self._beta = self.beta_start # Initialize difficulties and sum-tree self.difficulties = [1.0] * data_len @@ -165,10 +160,6 @@ class PrioritizedSampler(Sampler[int]): for i, p in enumerate(initial_priorities): self.priorities[i] = p - def update_beta(self, current_step: int): - frac = min(1.0, current_step / self.total_steps) - self._beta = self.beta_start + (self.beta_end - self.beta_start) * frac - def update_priorities(self, indices: List[int], difficulties: List[float]): """ Updates the priorities in the sum-tree. 
@@ -199,6 +190,6 @@ class PrioritizedSampler(Sampler[int]): total_p = self.sumtree.total_priority() for idx in indices: p = self.priorities[idx] / total_p - w.append((p * self.data_len) ** (-self._beta)) + w.append((p * self.data_len) ** (-self.beta)) w = torch.tensor(w, dtype=torch.float32) return w / w.max() diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 6292a3f6..e6890073 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -191,11 +191,9 @@ def train(cfg: TrainPipelineConfig): sampler = PrioritizedSampler( data_len=data_len, alpha=0.6, + beta=0.4, # For importance sampling eps=1e-6, num_samples_per_epoch=data_len, - beta_start=0.4, - beta_end=1.0, - total_steps=cfg.steps, ) dataloader = torch.utils.data.DataLoader( @@ -234,7 +232,6 @@ def train(cfg: TrainPipelineConfig): batch[key] = batch[key].to(device, non_blocking=True) if "indices" in batch: - sampler.update_beta(step) is_weights = sampler.compute_is_weights(batch["indices"].cpu().tolist()) batch["is_weights"] = is_weights