From 6e97876e81e70f282b3b184b8fa0ae17182da3dd Mon Sep 17 00:00:00 2001
From: Pepijn
Date: Mon, 17 Mar 2025 08:27:17 +0100
Subject: [PATCH] remove importance sampling

---
 lerobot/common/datasets/sampler.py |  9 ---------
 lerobot/scripts/train.py           | 14 --------------
 2 files changed, 23 deletions(-)

diff --git a/lerobot/common/datasets/sampler.py b/lerobot/common/datasets/sampler.py
index ddd1176e..ff34aefa 100644
--- a/lerobot/common/datasets/sampler.py
+++ b/lerobot/common/datasets/sampler.py
@@ -184,12 +184,3 @@ class PrioritizedSampler(Sampler[int]):
 
     def __len__(self) -> int:
         return self.num_samples_per_epoch
-
-    def compute_is_weights(self, indices: List[int]) -> torch.Tensor:
-        w = []
-        total_p = self.sumtree.total_priority()
-        for idx in indices:
-            p = self.priorities[idx] / total_p
-            w.append((p * self.data_len) ** (-self.beta))
-        w = torch.tensor(w, dtype=torch.float32)
-        return w / w.max()
diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index e6890073..2ffcf8a9 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -71,16 +71,6 @@ def update_policy(
         loss, output_dict = policy.forward(batch)
         # TODO(rcadene): policy.unnormalize_outputs(out_dict)
 
-        # Apply importance-sampling if available
-        if "is_weights" in batch and "per_sample_l1" in output_dict:
-            per_sample_l1 = output_dict["per_sample_l1"]
-            l1_per_item = per_sample_l1.mean(dim=-1)
-            w = batch["is_weights"].to(device)
-            weighted_loss = (l1_per_item * w).mean()
-            if policy.config.use_vae and "kld_loss" in output_dict:
-                weighted_loss += output_dict["kld_loss"] * policy.config.kl_weight
-            loss = weighted_loss
-
     grad_scaler.scale(loss).backward()
 
     # Unscale the gradient of the optimizer's assigned params in-place **prior to gradient clipping**.
@@ -231,10 +221,6 @@ def train(cfg: TrainPipelineConfig):
             if isinstance(batch[key], torch.Tensor):
                 batch[key] = batch[key].to(device, non_blocking=True)
 
-        if "indices" in batch:
-            is_weights = sampler.compute_is_weights(batch["indices"].cpu().tolist())
-            batch["is_weights"] = is_weights
-
         train_tracker, output_dict = update_policy(
             train_tracker,
             policy,
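
For reference, the code removed above computed prioritized-replay importance-sampling weights of the form w_i = (N * P(i)) ** (-beta), normalized by the largest weight. Below is a minimal standalone sketch of that correction; the helper name, signature, and default beta are illustrative assumptions and are not part of lerobot's API.

# Minimal sketch of prioritized-replay importance-sampling weights,
# w_i = (N * P(i)) ** (-beta), scaled so the largest weight is 1.
# Standalone illustration; names and signature are hypothetical.
import torch


def compute_is_weights(priorities: torch.Tensor, beta: float = 0.4) -> torch.Tensor:
    probs = priorities / priorities.sum()  # P(i): sampling probability of each item
    n = priorities.numel()                 # N: number of items in the buffer
    w = (probs * n) ** (-beta)             # correct for non-uniform sampling
    return w / w.max()                     # normalize by the maximum weight


if __name__ == "__main__":
    # Higher-priority items are sampled more often, so they receive smaller weights.
    print(compute_is_weights(torch.tensor([1.0, 2.0, 4.0])))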