Add importance sampling, only use replacement, remove beta smoothing
commit 17d12db7c4
parent 6a8be97bb5
@@ -80,7 +80,7 @@ class SumTree:
     def initialize_tree(self, priorities: List[float]):
         """
-        Efficiently initializes the sum tree in O(n).
+        Initializes the sum tree
         """
         # Set leaf values
         for i, priority in enumerate(priorities):
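For orientation, here is a minimal sum-tree sketch consistent with the interface the sampler relies on (initialize_tree, update, sample, total_priority). It is an illustrative stand-in, not the class in this repository; the MiniSumTree name and internal layout are assumptions.

from typing import List


class MiniSumTree:
    def __init__(self, capacity: int):
        self.capacity = capacity
        # Internal nodes live in [1, capacity); leaves live in [capacity, 2 * capacity).
        self.tree = [0.0] * (2 * capacity)

    def initialize_tree(self, priorities: List[float]):
        # Set leaves, then build parents bottom-up in O(n).
        for i, p in enumerate(priorities):
            self.tree[self.capacity + i] = p
        for i in range(self.capacity - 1, 0, -1):
            self.tree[i] = self.tree[2 * i] + self.tree[2 * i + 1]

    def update(self, idx: int, priority: float):
        # O(log n): overwrite the leaf, then refresh its ancestors.
        i = self.capacity + idx
        self.tree[i] = priority
        i //= 2
        while i >= 1:
            self.tree[i] = self.tree[2 * i] + self.tree[2 * i + 1]
            i //= 2

    def total_priority(self) -> float:
        return self.tree[1]

    def sample(self, r: float) -> int:
        # Descend from the root, following the child whose cumulative mass contains r.
        i = 1
        while i < self.capacity:
            left = 2 * i
            if r <= self.tree[left]:
                i = left
            else:
                r -= self.tree[left]
                i = left + 1
        return i - self.capacity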
@@ -132,32 +132,42 @@ class PrioritizedSampler(Sampler[int]):
         self,
         data_len: int,
         alpha: float = 0.6,
-        beta: float = 0.1,
         eps: float = 1e-6,
-        replacement: bool = True,
         num_samples_per_epoch: Optional[int] = None,
+        beta_start: float = 0.4,
+        beta_end: float = 1.0,
+        total_steps: int = 1,
     ):
         """
         Args:
            data_len: Total number of samples in the dataset.
            alpha: Exponent for priority scaling. Default is 0.6.
-            beta: Smoothing offset to avoid excluding low-priority samples.
            eps: Small constant to avoid zero priorities.
            replacement: Whether to sample with replacement.
            num_samples_per_epoch: Number of samples per epoch (default is data_len).
         """
         self.data_len = data_len
         self.alpha = alpha
-        self.beta = beta
         self.eps = eps
-        self.replacement = replacement
         self.num_samples_per_epoch = num_samples_per_epoch or data_len
+        self.beta_start = beta_start
+        self.beta_end = beta_end
+        self.total_steps = total_steps
+        self._beta = self.beta_start

         # Initialize difficulties and sum-tree
-        self.difficulties = [1.0] * data_len  # Default difficulty = 1.0
-        initial_priorities = [(1.0 + eps) ** alpha + beta] * data_len  # Compute initial priorities
+        self.difficulties = [1.0] * data_len
+        self.priorities = [0.0] * data_len
+        initial_priorities = [(1.0 + eps) ** alpha] * data_len
+
         self.sumtree = SumTree(data_len)
-        self.sumtree.initialize_tree(initial_priorities)  # Bulk load in O(n)
+        self.sumtree.initialize_tree(initial_priorities)
+        for i, p in enumerate(initial_priorities):
+            self.priorities[i] = p
+
+    def update_beta(self, current_step: int):
+        frac = min(1.0, current_step / self.total_steps)
+        self._beta = self.beta_start + (self.beta_end - self.beta_start) * frac

     def update_priorities(self, indices: List[int], difficulties: List[float]):
         """
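The new update_beta method anneals the importance-sampling exponent linearly from beta_start to beta_end. A quick illustration of the schedule, using the defaults from the signature and an assumed total_steps of 100_000 (not taken from a real run):

# Illustrative only: the linear beta schedule implemented by update_beta().
beta_start, beta_end, total_steps = 0.4, 1.0, 100_000  # total_steps is an assumed example value

def beta_at(step: int) -> float:
    frac = min(1.0, step / total_steps)
    return beta_start + (beta_end - beta_start) * frac

print(beta_at(0))        # 0.4 -> weak correction early in training
print(beta_at(50_000))   # 0.7 -> halfway through the schedule
print(beta_at(200_000))  # 1.0 -> clamped at beta_end once total_steps is exceeded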
@@ -165,7 +175,8 @@ class PrioritizedSampler(Sampler[int]):
         """
         for idx, diff in zip(indices, difficulties, strict=False):
             self.difficulties[idx] = diff
-            new_priority = (diff + self.eps) ** self.alpha + self.beta
+            new_priority = (diff + self.eps) ** self.alpha
+            self.priorities[idx] = new_priority
             self.sumtree.update(idx, new_priority)

     def __iter__(self) -> Iterator[int]:
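With the additive beta smoothing removed, a sample's priority is simply (difficulty + eps) ** alpha. A toy calculation (numbers invented) shows how alpha = 0.6 compresses the gap between easy and hard samples:

# Toy numbers, not from a real run: effect of the alpha exponent on priorities.
alpha, eps = 0.6, 1e-6
difficulties = [0.01, 0.1, 1.0]  # per-sample L1 losses reported by the policy

priorities = [(d + eps) ** alpha for d in difficulties]
print(priorities)  # ~[0.063, 0.251, 1.0]: a 100x loss gap becomes roughly 16x in sampling mass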
@@ -173,19 +184,21 @@ class PrioritizedSampler(Sampler[int]):
         Samples indices based on their priority weights.
         """
         total_p = self.sumtree.total_priority()
-        sampled_indices = set() if not self.replacement else None

         for _ in range(self.num_samples_per_epoch):
             r = random.random() * total_p
             idx = self.sumtree.sample(r)
-
-            if not self.replacement:
-                while idx in sampled_indices:
-                    r = random.random() * total_p
-                    idx = self.sumtree.sample(r)
-                sampled_indices.add(idx)

             yield idx

     def __len__(self) -> int:
         return self.num_samples_per_epoch
+
+    def compute_is_weights(self, indices: List[int]) -> torch.Tensor:
+        w = []
+        total_p = self.sumtree.total_priority()
+        for idx in indices:
+            p = self.priorities[idx] / total_p
+            w.append((p * self.data_len) ** (-self._beta))
+        w = torch.tensor(w, dtype=torch.float32)
+        return w / w.max()
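compute_is_weights applies the usual prioritized-replay correction w_i = (N * P(i)) ** (-beta) and normalizes by the batch maximum. A small worked example with invented numbers:

import torch

# Invented numbers: three sampled indices with priorities p_i out of a total
# priority mass of 10.0, a dataset of N = 100 samples, and the current beta = 0.4.
priorities = torch.tensor([4.0, 1.0, 0.5])
total_p, N, beta = 10.0, 100, 0.4

P = priorities / total_p   # sampling probabilities P(i)
w = (N * P) ** (-beta)     # raw importance-sampling weights
w = w / w.max()            # normalize so the largest weight is 1, as in compute_is_weights
print(w)                   # roughly [0.435, 0.758, 1.000] -- rarely sampled items get the largest weight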
@@ -161,7 +161,6 @@ class ACTPolicy(PreTrainedPolicy):

         l1_loss = elementwise_l1.mean()

-        # mean over time+action_dim => per-sample array of shape (B,)
         l1_per_sample = elementwise_l1.mean(dim=(1, 2))

         if self.config.use_vae:
@@ -175,13 +174,13 @@ class ACTPolicy(PreTrainedPolicy):
             loss_dict = {
                 "l1_loss": l1_loss.item(),
                 "kld_loss": mean_kld.item(),
-                "per_sample_l1": l1_per_sample,  # shape (B,)
+                "per_sample_l1": l1_per_sample,
             }
             loss = l1_loss + mean_kld * self.config.kl_weight
         else:
             loss_dict = {
                 "l1_loss": l1_loss.item(),
-                "per_sample_l1": l1_per_sample,  # shape (B,)
+                "per_sample_l1": l1_per_sample,
             }
             loss = l1_loss

@@ -70,6 +70,17 @@ def update_policy(
     with torch.autocast(device_type=device.type) if use_amp else nullcontext():
         loss, output_dict = policy.forward(batch)
         # TODO(rcadene): policy.unnormalize_outputs(out_dict)

+        # Apply importance-sampling if available
+        if "is_weights" in batch and "per_sample_l1" in output_dict:
+            per_sample_l1 = output_dict["per_sample_l1"]
+            l1_per_item = per_sample_l1.mean(dim=-1)
+            w = batch["is_weights"].to(device)
+            weighted_loss = (l1_per_item * w).mean()
+            if policy.config.use_vae and "kld_loss" in output_dict:
+                weighted_loss += output_dict["kld_loss"] * policy.config.kl_weight
+            loss = weighted_loss
+
     grad_scaler.scale(loss).backward()

     # Unscale the gradient of the optimizer's assigned params in-place **prior to gradient clipping**.
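The idea behind the block above is to backpropagate an importance-weighted mean of the per-sample loss instead of the plain mean. A standalone sketch with toy tensors (shape (B,) is assumed for per_sample_l1; values are invented):

import torch

# Toy values: per-sample reconstruction losses and matching IS weights for a batch of 3.
per_sample_l1 = torch.tensor([0.20, 0.05, 0.10])  # shape (B,)
is_weights = torch.tensor([1.00, 0.43, 0.76])     # shape (B,), from compute_is_weights

weighted_loss = (per_sample_l1 * is_weights).mean()
print(weighted_loss)  # ~0.0992; this weighted mean is what gets backpropagated instead of the plain mean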
@@ -180,10 +191,11 @@ def train(cfg: TrainPipelineConfig):
     sampler = PrioritizedSampler(
         data_len=data_len,
         alpha=0.6,
-        beta=0.1,
         eps=1e-6,
-        replacement=True,
         num_samples_per_epoch=data_len,
+        beta_start=0.4,
+        beta_end=1.0,
+        total_steps=cfg.steps,
     )

     dataloader = torch.utils.data.DataLoader(
@@ -221,6 +233,11 @@ def train(cfg: TrainPipelineConfig):
            if isinstance(batch[key], torch.Tensor):
                batch[key] = batch[key].to(device, non_blocking=True)

+        if "indices" in batch:
+            sampler.update_beta(step)
+            is_weights = sampler.compute_is_weights(batch["indices"].cpu().tolist())
+            batch["is_weights"] = is_weights
+
         train_tracker, output_dict = update_policy(
             train_tracker,
             policy,
@@ -232,11 +249,11 @@ def train(cfg: TrainPipelineConfig):
            use_amp=cfg.policy.use_amp,
        )

-        # If we have 'indices' and 'per_sample_l1' then update sampler
+        # Update sampler
        if "indices" in batch and "per_sample_l1" in output_dict:
-            indices = batch["indices"].detach().cpu().tolist()  # shape (B,)
-            difficulties = output_dict["per_sample_l1"].detach().cpu().tolist()  # shape (B,)
-            sampler.update_priorities(indices, difficulties)
+            idxs = batch["indices"].cpu().tolist()
+            diffs = output_dict["per_sample_l1"].detach().cpu().tolist()
+            sampler.update_priorities(idxs, diffs)

        # Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
        # increment `step` here.
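Putting the pieces together, a minimal usage sketch of the sampler in a training loop. It assumes PrioritizedSampler is importable from this repository (its module path is not visible in this diff) and substitutes a toy dataset and loss for the real policy:

import torch
from torch.utils.data import DataLoader, Dataset

# PrioritizedSampler is assumed importable from this repository's sampler module.


class ToyDataset(Dataset):
    """Stand-in dataset: the only requirement is that batches carry an 'indices' field."""

    def __len__(self) -> int:
        return 64

    def __getitem__(self, idx: int):
        return {"x": torch.randn(4), "indices": idx}


def toy_train_loop(sampler, steps: int = 100):
    loader = DataLoader(ToyDataset(), batch_size=8, sampler=sampler)
    step = 0
    while step < steps:
        for batch in loader:
            # Anneal beta and attach importance-sampling weights, mirroring train() above.
            sampler.update_beta(step)
            batch["is_weights"] = sampler.compute_is_weights(batch["indices"].tolist())

            # Stand-in for the policy forward pass: any per-sample loss would do here.
            per_sample_loss = (batch["x"] ** 2).mean(dim=-1)                # shape (B,)
            weighted_loss = (per_sample_loss * batch["is_weights"]).mean()  # what would be backpropagated
            print(step, weighted_loss.item())

            # Feed difficulties back so that hard samples are drawn more often.
            sampler.update_priorities(batch["indices"].tolist(), per_sample_loss.tolist())

            step += 1
            if step >= steps:
                return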