diff --git a/lerobot/common/datasets/sampler.py b/lerobot/common/datasets/sampler.py
index 08c3ac13..31597e77 100644
--- a/lerobot/common/datasets/sampler.py
+++ b/lerobot/common/datasets/sampler.py
@@ -80,7 +80,7 @@ class SumTree:
 
     def initialize_tree(self, priorities: List[float]):
         """
-        Efficiently initializes the sum tree in O(n).
+        Initializes the sum tree
         """
         # Set leaf values
         for i, priority in enumerate(priorities):
@@ -132,32 +132,42 @@ class PrioritizedSampler(Sampler[int]):
         self,
         data_len: int,
         alpha: float = 0.6,
-        beta: float = 0.1,
         eps: float = 1e-6,
-        replacement: bool = True,
         num_samples_per_epoch: Optional[int] = None,
+        beta_start: float = 0.4,
+        beta_end: float = 1.0,
+        total_steps: int = 1,
     ):
         """
         Args:
            data_len: Total number of samples in the dataset.
            alpha: Exponent for priority scaling. Default is 0.6.
-            beta: Smoothing offset to avoid excluding low-priority samples.
            eps: Small constant to avoid zero priorities.
            replacement: Whether to sample with replacement.
            num_samples_per_epoch: Number of samples per epoch (default is data_len).
        """
        self.data_len = data_len
        self.alpha = alpha
-        self.beta = beta
        self.eps = eps
-        self.replacement = replacement
        self.num_samples_per_epoch = num_samples_per_epoch or data_len
+        self.beta_start = beta_start
+        self.beta_end = beta_end
+        self.total_steps = total_steps
+        self._beta = self.beta_start
 
        # Initialize difficulties and sum-tree
-        self.difficulties = [1.0] * data_len  # Default difficulty = 1.0
-        initial_priorities = [(1.0 + eps) ** alpha + beta] * data_len  # Compute initial priorities
+        self.difficulties = [1.0] * data_len
+        self.priorities = [0.0] * data_len
+        initial_priorities = [(1.0 + eps) ** alpha] * data_len
+
        self.sumtree = SumTree(data_len)
-        self.sumtree.initialize_tree(initial_priorities)  # Bulk load in O(n)
+        self.sumtree.initialize_tree(initial_priorities)
+        for i, p in enumerate(initial_priorities):
+            self.priorities[i] = p
+
+    def update_beta(self, current_step: int):
+        frac = min(1.0, current_step / self.total_steps)
+        self._beta = self.beta_start + (self.beta_end - self.beta_start) * frac
 
    def update_priorities(self, indices: List[int], difficulties: List[float]):
        """
@@ -165,7 +175,8 @@ class PrioritizedSampler(Sampler[int]):
        """
        for idx, diff in zip(indices, difficulties, strict=False):
            self.difficulties[idx] = diff
-            new_priority = (diff + self.eps) ** self.alpha + self.beta
+            new_priority = (diff + self.eps) ** self.alpha
+            self.priorities[idx] = new_priority
            self.sumtree.update(idx, new_priority)
 
@@ -173,19 +184,21 @@ class PrioritizedSampler(Sampler[int]):
    def __iter__(self) -> Iterator[int]:
        """
        Samples indices based on their priority weights.
        """
        total_p = self.sumtree.total_priority()
-        sampled_indices = set() if not self.replacement else None
        for _ in range(self.num_samples_per_epoch):
            r = random.random() * total_p
            idx = self.sumtree.sample(r)
-            if not self.replacement:
-                while idx in sampled_indices:
-                    r = random.random() * total_p
-                    idx = self.sumtree.sample(r)
-                sampled_indices.add(idx)
-
            yield idx
 
    def __len__(self) -> int:
        return self.num_samples_per_epoch
+
+    def compute_is_weights(self, indices: List[int]) -> torch.Tensor:
+        w = []
+        total_p = self.sumtree.total_priority()
+        for idx in indices:
+            p = self.priorities[idx] / total_p
+            w.append((p * self.data_len) ** (-self._beta))
+        w = torch.tensor(w, dtype=torch.float32)
+        return w / w.max()
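For reference, `compute_is_weights` applies the standard prioritized-sampling correction: an index drawn with probability `p_i = priority_i / total_priority` is scaled by `(data_len * p_i) ** (-beta)`, normalized by the batch maximum, with `beta` annealed from `beta_start` to `beta_end` by `update_beta`. A minimal self-contained sketch of that arithmetic (illustrative priority values only, no SumTree involved):

```python
import torch

# Illustrative priorities for a 4-sample dataset, i.e. (difficulty + eps) ** alpha.
priorities = [0.9, 0.3, 0.1, 0.7]
data_len = len(priorities)
beta = 0.4  # annealed from beta_start toward beta_end over training

total_p = sum(priorities)
probs = [p / total_p for p in priorities]  # sampling probability p_i of each index

# Importance-sampling weights: (N * p_i) ** (-beta), normalized so the largest weight is 1.
w = torch.tensor([(p * data_len) ** (-beta) for p in probs], dtype=torch.float32)
w = w / w.max()

print(probs)  # high-priority samples are drawn more often ...
print(w)      # ... and receive smaller weights, correcting the bias of non-uniform sampling
```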
""" total_p = self.sumtree.total_priority() - sampled_indices = set() if not self.replacement else None for _ in range(self.num_samples_per_epoch): r = random.random() * total_p idx = self.sumtree.sample(r) - if not self.replacement: - while idx in sampled_indices: - r = random.random() * total_p - idx = self.sumtree.sample(r) - sampled_indices.add(idx) - yield idx def __len__(self) -> int: return self.num_samples_per_epoch + + def compute_is_weights(self, indices: List[int]) -> torch.Tensor: + w = [] + total_p = self.sumtree.total_priority() + for idx in indices: + p = self.priorities[idx] / total_p + w.append((p * self.data_len) ** (-self._beta)) + w = torch.tensor(w, dtype=torch.float32) + return w / w.max() diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py index 99270c29..42f0a74b 100644 --- a/lerobot/common/policies/act/modeling_act.py +++ b/lerobot/common/policies/act/modeling_act.py @@ -161,7 +161,6 @@ class ACTPolicy(PreTrainedPolicy): l1_loss = elementwise_l1.mean() - # mean over time+action_dim => per-sample array of shape (B,) l1_per_sample = elementwise_l1.mean(dim=(1, 2)) if self.config.use_vae: @@ -175,13 +174,13 @@ class ACTPolicy(PreTrainedPolicy): loss_dict = { "l1_loss": l1_loss.item(), "kld_loss": mean_kld.item(), - "per_sample_l1": l1_per_sample, # shape (B,) + "per_sample_l1": l1_per_sample, } loss = l1_loss + mean_kld * self.config.kl_weight else: loss_dict = { "l1_loss": l1_loss.item(), - "per_sample_l1": l1_per_sample, # shape (B,) + "per_sample_l1": l1_per_sample, } loss = l1_loss diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 1ad16a15..6292a3f6 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -70,6 +70,17 @@ def update_policy( with torch.autocast(device_type=device.type) if use_amp else nullcontext(): loss, output_dict = policy.forward(batch) # TODO(rcadene): policy.unnormalize_outputs(out_dict) + + # Apply importance-sampling if available + if "is_weights" in batch and "per_sample_l1" in output_dict: + per_sample_l1 = output_dict["per_sample_l1"] + l1_per_item = per_sample_l1.mean(dim=-1) + w = batch["is_weights"].to(device) + weighted_loss = (l1_per_item * w).mean() + if policy.config.use_vae and "kld_loss" in output_dict: + weighted_loss += output_dict["kld_loss"] * policy.config.kl_weight + loss = weighted_loss + grad_scaler.scale(loss).backward() # Unscale the gradient of the optimizer's assigned params in-place **prior to gradient clipping**. 
diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index 1ad16a15..6292a3f6 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -70,6 +70,17 @@ def update_policy(
    with torch.autocast(device_type=device.type) if use_amp else nullcontext():
        loss, output_dict = policy.forward(batch)
        # TODO(rcadene): policy.unnormalize_outputs(out_dict)
+
+    # Apply importance-sampling weights if available
+    if "is_weights" in batch and "per_sample_l1" in output_dict:
+        # per-sample L1 loss from the policy, shape (B,)
+        per_sample_l1 = output_dict["per_sample_l1"]
+        w = batch["is_weights"].to(device)
+        weighted_loss = (per_sample_l1 * w).mean()
+        if policy.config.use_vae and "kld_loss" in output_dict:
+            weighted_loss += output_dict["kld_loss"] * policy.config.kl_weight
+        loss = weighted_loss
+
    grad_scaler.scale(loss).backward()
 
    # Unscale the gradient of the optimizer's assigned params in-place **prior to gradient clipping**.
@@ -180,10 +191,11 @@ def train(cfg: TrainPipelineConfig):
    sampler = PrioritizedSampler(
        data_len=data_len,
        alpha=0.6,
-        beta=0.1,
        eps=1e-6,
-        replacement=True,
        num_samples_per_epoch=data_len,
+        beta_start=0.4,
+        beta_end=1.0,
+        total_steps=cfg.steps,
    )
 
    dataloader = torch.utils.data.DataLoader(
@@ -221,6 +233,11 @@ def train(cfg: TrainPipelineConfig):
            if isinstance(batch[key], torch.Tensor):
                batch[key] = batch[key].to(device, non_blocking=True)
 
+        if "indices" in batch:
+            sampler.update_beta(step)
+            is_weights = sampler.compute_is_weights(batch["indices"].cpu().tolist())
+            batch["is_weights"] = is_weights
+
        train_tracker, output_dict = update_policy(
            train_tracker,
            policy,
@@ -232,11 +249,11 @@ def train(cfg: TrainPipelineConfig):
            use_amp=cfg.policy.use_amp,
        )
 
-        # If we have 'indices' and 'per_sample_l1' then update sampler
+        # Update sampler priorities from the per-sample losses
        if "indices" in batch and "per_sample_l1" in output_dict:
-            indices = batch["indices"].detach().cpu().tolist()  # shape (B,)
-            difficulties = output_dict["per_sample_l1"].detach().cpu().tolist()  # shape (B,)
-            sampler.update_priorities(indices, difficulties)
+            idxs = batch["indices"].cpu().tolist()
+            diffs = output_dict["per_sample_l1"].detach().cpu().tolist()
+            sampler.update_priorities(idxs, diffs)
 
        # Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
        # increment `step` here.
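Taken together, the pieces form the usual prioritized-replay loop: sample indices by priority, correct the loss with importance weights, then feed the fresh per-sample losses back as priorities. The training-loop wiring above assumes each batch exposes an "indices" key naming the dataset rows that were drawn; the sketch below sidesteps the DataLoader and drives the patched sampler directly with toy losses, just to show the intended call order (values are illustrative, not from the patch):

```python
import torch
from lerobot.common.datasets.sampler import PrioritizedSampler

data_len, batch_size, total_steps = 100, 8, 50
sampler = PrioritizedSampler(
    data_len=data_len,
    alpha=0.6,
    eps=1e-6,
    num_samples_per_epoch=batch_size * total_steps,
    beta_start=0.4,
    beta_end=1.0,
    total_steps=total_steps,
)

index_stream = iter(sampler)
for step in range(total_steps):
    indices = [next(index_stream) for _ in range(batch_size)]  # normally read from batch["indices"]

    sampler.update_beta(step)
    is_weights = sampler.compute_is_weights(indices)       # shape (batch_size,)

    per_sample_l1 = torch.rand(batch_size)                  # stand-in for output_dict["per_sample_l1"]
    weighted_loss = (per_sample_l1 * is_weights).mean()     # what update_policy now backpropagates

    # Harder samples get larger priorities and are drawn more often afterwards.
    sampler.update_priorities(indices, per_sample_l1.tolist())
```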