wip
This commit is contained in:
parent
63e5ec6483
commit
2298ddf226
|
@ -160,9 +160,6 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
|||
if self.cfg.n_action_steps is not None:
|
||||
self._action_queue = deque([], maxlen=self.cfg.n_action_steps)
|
||||
|
||||
def forward(self, batch: dict[str, Tensor], **_) -> Tensor:
|
||||
return self.select_action(self, batch)
|
||||
|
||||
@torch.no_grad
|
||||
def select_action(self, batch: dict[str, Tensor], **_) -> Tensor:
|
||||
"""Select a single action given environment observations.
|
||||
|
@ -178,14 +175,15 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
|||
self._action_queue.extend(self._forward(batch)[0][: self.cfg.n_action_steps].transpose(0, 1))
|
||||
return self._action_queue.popleft()
|
||||
|
||||
def compute_loss(self, batch, **_) -> float:
|
||||
"""Runs the batch through the model and computes the loss for training or validation."""
|
||||
def forward(self, batch, **_) -> dict[str, Tensor]:
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
actions_hat, (mu_hat, log_sigma_x2_hat) = self._forward(batch)
|
||||
|
||||
l1_loss = (
|
||||
F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
|
||||
).mean()
|
||||
|
||||
loss_dict = {"l1_loss": l1_loss}
|
||||
if self.cfg.use_vae:
|
||||
# Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
|
||||
# each dimension independently, we sum over the latent dimension to get the total
|
||||
|
@ -194,23 +192,23 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
|||
mean_kld = (
|
||||
(-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
|
||||
)
|
||||
loss = l1_loss + mean_kld * self.cfg.kl_weight
|
||||
loss_dict["kld_loss"] = mean_kld
|
||||
loss_dict["loss"] = l1_loss + mean_kld * self.cfg.kl_weight
|
||||
else:
|
||||
loss = l1_loss
|
||||
loss_dict["loss"] = l1_loss
|
||||
|
||||
return loss
|
||||
return loss_dict
|
||||
|
||||
def update(self, batch, **_) -> dict:
|
||||
"""Run the model in train mode, compute the loss, and do an optimization step."""
|
||||
start_time = time.time()
|
||||
self.train()
|
||||
loss = self.compute_loss(batch)
|
||||
loss_dict = self.forward(batch)
|
||||
loss = loss_dict["loss"]
|
||||
loss.backward()
|
||||
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
self.parameters(),
|
||||
self.cfg.grad_clip_norm,
|
||||
error_if_nonfinite=False,
|
||||
self.parameters(), self.cfg.grad_clip_norm, error_if_nonfinite=False
|
||||
)
|
||||
|
||||
self.optimizer.step()
|
||||
|
|
|
@ -43,12 +43,12 @@ class DiffusionPolicy(nn.Module):
|
|||
name = "diffusion"
|
||||
|
||||
def __init__(self, cfg: DiffusionConfig | None, lr_scheduler_num_training_steps: int = 0):
|
||||
super().__init__()
|
||||
"""
|
||||
Args:
|
||||
cfg: Policy configuration class instance or None, in which case the default instantiation of the
|
||||
configuration class is used.
|
||||
"""
|
||||
super().__init__()
|
||||
# TODO(alexander-soare): LR scheduler will be removed.
|
||||
assert lr_scheduler_num_training_steps > 0
|
||||
if cfg is None:
|
||||
|
@ -140,12 +140,18 @@ class DiffusionPolicy(nn.Module):
|
|||
action = self._queues["action"].popleft()
|
||||
return action
|
||||
|
||||
def forward(self, batch, **_):
|
||||
def forward(self, batch: dict[str, Tensor], **_) -> dict[str, Tensor]:
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
loss = self.diffusion.compute_loss(batch)
|
||||
return {"loss": loss}
|
||||
|
||||
def update(self, batch: dict[str, Tensor], **_) -> dict:
|
||||
"""Run the model in train mode, compute the loss, and do an optimization step."""
|
||||
start_time = time.time()
|
||||
|
||||
self.diffusion.train()
|
||||
|
||||
loss = self.diffusion.compute_loss(batch)
|
||||
loss = self.forward(batch)["loss"]
|
||||
loss.backward()
|
||||
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
|
|
|
@ -24,20 +24,20 @@ class Policy(Protocol):
|
|||
Does things like clearing caches.
|
||||
"""
|
||||
|
||||
def forward(self, batch: dict[str, Tensor], **kwargs):
|
||||
"""Wired to `select_action`."""
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict:
|
||||
"""Run the batch through the model and compute the loss for training or validation.
|
||||
|
||||
def select_action(self, batch: dict[str, Tensor], **kwargs):
|
||||
Returns a dictionary with "loss" and maybe other information.
|
||||
"""
|
||||
|
||||
def select_action(self, batch: dict[str, Tensor]):
|
||||
"""Return one action to run in the environment (potentially in batch mode).
|
||||
|
||||
When the model uses a history of observations, or outputs a sequence of actions, this method deals
|
||||
with caching.
|
||||
"""
|
||||
|
||||
def compute_loss(self, batch: dict[str, Tensor], **kwargs):
|
||||
"""Runs the batch through the model and computes the loss for training or validation."""
|
||||
|
||||
def update(self, batch, **kwargs):
|
||||
def update(self, batch):
|
||||
"""Does compute_loss then an optimization step.
|
||||
|
||||
TODO(alexander-soare): We will move the optimization step back into the training loop, so this will
|
||||
|
|
Loading…
Reference in New Issue