remove importance sampling
This commit is contained in:
parent
4bdbf2f6e0
commit
6e97876e81
|
@ -184,12 +184,3 @@ class PrioritizedSampler(Sampler[int]):
|
|||
|
||||
def __len__(self) -> int:
|
||||
return self.num_samples_per_epoch
|
||||
|
||||
def compute_is_weights(self, indices: List[int]) -> torch.Tensor:
|
||||
w = []
|
||||
total_p = self.sumtree.total_priority()
|
||||
for idx in indices:
|
||||
p = self.priorities[idx] / total_p
|
||||
w.append((p * self.data_len) ** (-self.beta))
|
||||
w = torch.tensor(w, dtype=torch.float32)
|
||||
return w / w.max()
|
||||
|
|
|
@ -71,16 +71,6 @@ def update_policy(
|
|||
loss, output_dict = policy.forward(batch)
|
||||
# TODO(rcadene): policy.unnormalize_outputs(out_dict)
|
||||
|
||||
# Apply importance-sampling if available
|
||||
if "is_weights" in batch and "per_sample_l1" in output_dict:
|
||||
per_sample_l1 = output_dict["per_sample_l1"]
|
||||
l1_per_item = per_sample_l1.mean(dim=-1)
|
||||
w = batch["is_weights"].to(device)
|
||||
weighted_loss = (l1_per_item * w).mean()
|
||||
if policy.config.use_vae and "kld_loss" in output_dict:
|
||||
weighted_loss += output_dict["kld_loss"] * policy.config.kl_weight
|
||||
loss = weighted_loss
|
||||
|
||||
grad_scaler.scale(loss).backward()
|
||||
|
||||
# Unscale the gradient of the optimizer's assigned params in-place **prior to gradient clipping**.
|
||||
|
@ -231,10 +221,6 @@ def train(cfg: TrainPipelineConfig):
|
|||
if isinstance(batch[key], torch.Tensor):
|
||||
batch[key] = batch[key].to(device, non_blocking=True)
|
||||
|
||||
if "indices" in batch:
|
||||
is_weights = sampler.compute_is_weights(batch["indices"].cpu().tolist())
|
||||
batch["is_weights"] = is_weights
|
||||
|
||||
train_tracker, output_dict = update_policy(
|
||||
train_tracker,
|
||||
policy,
|
||||
|
|
Loading…
Reference in New Issue