remove important sampling
This commit is contained in:
parent
4bdbf2f6e0
commit
6e97876e81
|
@ -184,12 +184,3 @@ class PrioritizedSampler(Sampler[int]):
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
return self.num_samples_per_epoch
|
return self.num_samples_per_epoch
|
||||||
|
|
||||||
def compute_is_weights(self, indices: List[int]) -> torch.Tensor:
|
|
||||||
w = []
|
|
||||||
total_p = self.sumtree.total_priority()
|
|
||||||
for idx in indices:
|
|
||||||
p = self.priorities[idx] / total_p
|
|
||||||
w.append((p * self.data_len) ** (-self.beta))
|
|
||||||
w = torch.tensor(w, dtype=torch.float32)
|
|
||||||
return w / w.max()
|
|
||||||
|
|
|
@ -71,16 +71,6 @@ def update_policy(
|
||||||
loss, output_dict = policy.forward(batch)
|
loss, output_dict = policy.forward(batch)
|
||||||
# TODO(rcadene): policy.unnormalize_outputs(out_dict)
|
# TODO(rcadene): policy.unnormalize_outputs(out_dict)
|
||||||
|
|
||||||
# Apply importance-sampling if available
|
|
||||||
if "is_weights" in batch and "per_sample_l1" in output_dict:
|
|
||||||
per_sample_l1 = output_dict["per_sample_l1"]
|
|
||||||
l1_per_item = per_sample_l1.mean(dim=-1)
|
|
||||||
w = batch["is_weights"].to(device)
|
|
||||||
weighted_loss = (l1_per_item * w).mean()
|
|
||||||
if policy.config.use_vae and "kld_loss" in output_dict:
|
|
||||||
weighted_loss += output_dict["kld_loss"] * policy.config.kl_weight
|
|
||||||
loss = weighted_loss
|
|
||||||
|
|
||||||
grad_scaler.scale(loss).backward()
|
grad_scaler.scale(loss).backward()
|
||||||
|
|
||||||
# Unscale the gradient of the optimizer's assigned params in-place **prior to gradient clipping**.
|
# Unscale the gradient of the optimizer's assigned params in-place **prior to gradient clipping**.
|
||||||
|
@ -231,10 +221,6 @@ def train(cfg: TrainPipelineConfig):
|
||||||
if isinstance(batch[key], torch.Tensor):
|
if isinstance(batch[key], torch.Tensor):
|
||||||
batch[key] = batch[key].to(device, non_blocking=True)
|
batch[key] = batch[key].to(device, non_blocking=True)
|
||||||
|
|
||||||
if "indices" in batch:
|
|
||||||
is_weights = sampler.compute_is_weights(batch["indices"].cpu().tolist())
|
|
||||||
batch["is_weights"] = is_weights
|
|
||||||
|
|
||||||
train_tracker, output_dict = update_policy(
|
train_tracker, output_dict = update_policy(
|
||||||
train_tracker,
|
train_tracker,
|
||||||
policy,
|
policy,
|
||||||
|
|
Loading…
Reference in New Issue