This commit is contained in:
Alexander Soare 2024-04-17 16:21:37 +01:00
parent 63e5ec6483
commit 2298ddf226
3 changed files with 26 additions and 22 deletions

View File

@@ -160,9 +160,6 @@ class ActionChunkingTransformerPolicy(nn.Module):
if self.cfg.n_action_steps is not None: if self.cfg.n_action_steps is not None:
self._action_queue = deque([], maxlen=self.cfg.n_action_steps) self._action_queue = deque([], maxlen=self.cfg.n_action_steps)
def forward(self, batch: dict[str, Tensor], **_) -> Tensor:
return self.select_action(self, batch)
@torch.no_grad @torch.no_grad
def select_action(self, batch: dict[str, Tensor], **_) -> Tensor: def select_action(self, batch: dict[str, Tensor], **_) -> Tensor:
"""Select a single action given environment observations. """Select a single action given environment observations.
@@ -178,14 +175,15 @@ class ActionChunkingTransformerPolicy(nn.Module):
self._action_queue.extend(self._forward(batch)[0][: self.cfg.n_action_steps].transpose(0, 1)) self._action_queue.extend(self._forward(batch)[0][: self.cfg.n_action_steps].transpose(0, 1))
return self._action_queue.popleft() return self._action_queue.popleft()
def compute_loss(self, batch, **_) -> float: def forward(self, batch, **_) -> dict[str, Tensor]:
"""Runs the batch through the model and computes the loss for training or validation.""" """Run the batch through the model and compute the loss for training or validation."""
actions_hat, (mu_hat, log_sigma_x2_hat) = self._forward(batch) actions_hat, (mu_hat, log_sigma_x2_hat) = self._forward(batch)
l1_loss = ( l1_loss = (
F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1) F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
).mean() ).mean()
loss_dict = {"l1_loss": l1_loss}
if self.cfg.use_vae: if self.cfg.use_vae:
# Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for # Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
# each dimension independently, we sum over the latent dimension to get the total # each dimension independently, we sum over the latent dimension to get the total
@@ -194,23 +192,23 @@ class ActionChunkingTransformerPolicy(nn.Module):
mean_kld = ( mean_kld = (
(-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean() (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
) )
loss = l1_loss + mean_kld * self.cfg.kl_weight loss_dict["kld_loss"] = mean_kld
loss_dict["loss"] = l1_loss + mean_kld * self.cfg.kl_weight
else: else:
loss = l1_loss loss_dict["loss"] = l1_loss
return loss return loss_dict
def update(self, batch, **_) -> dict: def update(self, batch, **_) -> dict:
"""Run the model in train mode, compute the loss, and do an optimization step.""" """Run the model in train mode, compute the loss, and do an optimization step."""
start_time = time.time() start_time = time.time()
self.train() self.train()
loss = self.compute_loss(batch) loss_dict = self.forward(batch)
loss = loss_dict["loss"]
loss.backward() loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_( grad_norm = torch.nn.utils.clip_grad_norm_(
self.parameters(), self.parameters(), self.cfg.grad_clip_norm, error_if_nonfinite=False
self.cfg.grad_clip_norm,
error_if_nonfinite=False,
) )
self.optimizer.step() self.optimizer.step()

View File

@@ -43,12 +43,12 @@ class DiffusionPolicy(nn.Module):
name = "diffusion" name = "diffusion"
def __init__(self, cfg: DiffusionConfig | None, lr_scheduler_num_training_steps: int = 0): def __init__(self, cfg: DiffusionConfig | None, lr_scheduler_num_training_steps: int = 0):
super().__init__()
""" """
Args: Args:
cfg: Policy configuration class instance or None, in which case the default instantiation of the cfg: Policy configuration class instance or None, in which case the default instantiation of the
configuration class is used. configuration class is used.
""" """
super().__init__()
# TODO(alexander-soare): LR scheduler will be removed. # TODO(alexander-soare): LR scheduler will be removed.
assert lr_scheduler_num_training_steps > 0 assert lr_scheduler_num_training_steps > 0
if cfg is None: if cfg is None:
@@ -140,12 +140,18 @@ class DiffusionPolicy(nn.Module):
action = self._queues["action"].popleft() action = self._queues["action"].popleft()
return action return action
def forward(self, batch, **_): def forward(self, batch: dict[str, Tensor], **_) -> dict[str, Tensor]:
"""Run the batch through the model and compute the loss for training or validation."""
loss = self.diffusion.compute_loss(batch)
return {"loss": loss}
def update(self, batch: dict[str, Tensor], **_) -> dict:
"""Run the model in train mode, compute the loss, and do an optimization step."""
start_time = time.time() start_time = time.time()
self.diffusion.train() self.diffusion.train()
loss = self.diffusion.compute_loss(batch) loss = self.forward(batch)["loss"]
loss.backward() loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_( grad_norm = torch.nn.utils.clip_grad_norm_(

View File

@@ -24,20 +24,20 @@ class Policy(Protocol):
Does things like clearing caches. Does things like clearing caches.
""" """
def forward(self, batch: dict[str, Tensor], **kwargs): def forward(self, batch: dict[str, Tensor]) -> dict:
"""Wired to `select_action`.""" """Run the batch through the model and compute the loss for training or validation.
def select_action(self, batch: dict[str, Tensor], **kwargs): Returns a dictionary with "loss" and maybe other information.
"""
def select_action(self, batch: dict[str, Tensor]):
"""Return one action to run in the environment (potentially in batch mode). """Return one action to run in the environment (potentially in batch mode).
When the model uses a history of observations, or outputs a sequence of actions, this method deals When the model uses a history of observations, or outputs a sequence of actions, this method deals
with caching. with caching.
""" """
def compute_loss(self, batch: dict[str, Tensor], **kwargs): def update(self, batch):
"""Runs the batch through the model and computes the loss for training or validation."""
def update(self, batch, **kwargs):
"""Does compute_loss then an optimization step. """Does compute_loss then an optimization step.
TODO(alexander-soare): We will move the optimization step back into the training loop, so this will TODO(alexander-soare): We will move the optimization step back into the training loop, so this will