remove important sampling

This commit is contained in:
Pepijn 2025-03-17 08:27:17 +01:00
parent 4bdbf2f6e0
commit 6e97876e81
2 changed files with 0 additions and 23 deletions

View File

@ -184,12 +184,3 @@ class PrioritizedSampler(Sampler[int]):
def __len__(self) -> int:
return self.num_samples_per_epoch
def compute_is_weights(self, indices: List[int]) -> torch.Tensor:
w = []
total_p = self.sumtree.total_priority()
for idx in indices:
p = self.priorities[idx] / total_p
w.append((p * self.data_len) ** (-self.beta))
w = torch.tensor(w, dtype=torch.float32)
return w / w.max()

View File

@ -71,16 +71,6 @@ def update_policy(
loss, output_dict = policy.forward(batch)
# TODO(rcadene): policy.unnormalize_outputs(out_dict)
# Apply importance-sampling if available
if "is_weights" in batch and "per_sample_l1" in output_dict:
per_sample_l1 = output_dict["per_sample_l1"]
l1_per_item = per_sample_l1.mean(dim=-1)
w = batch["is_weights"].to(device)
weighted_loss = (l1_per_item * w).mean()
if policy.config.use_vae and "kld_loss" in output_dict:
weighted_loss += output_dict["kld_loss"] * policy.config.kl_weight
loss = weighted_loss
grad_scaler.scale(loss).backward()
# Unscale the gradient of the optimizer's assigned params in-place **prior to gradient clipping**.
@ -231,10 +221,6 @@ def train(cfg: TrainPipelineConfig):
if isinstance(batch[key], torch.Tensor):
batch[key] = batch[key].to(device, non_blocking=True)
if "indices" in batch:
is_weights = sampler.compute_is_weights(batch["indices"].cpu().tolist())
batch["is_weights"] = is_weights
train_tracker, output_dict = update_policy(
train_tracker,
policy,