wip
This commit is contained in:
parent
63e5ec6483
commit
2298ddf226
|
@ -160,9 +160,6 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||||
if self.cfg.n_action_steps is not None:
|
if self.cfg.n_action_steps is not None:
|
||||||
self._action_queue = deque([], maxlen=self.cfg.n_action_steps)
|
self._action_queue = deque([], maxlen=self.cfg.n_action_steps)
|
||||||
|
|
||||||
def forward(self, batch: dict[str, Tensor], **_) -> Tensor:
|
|
||||||
return self.select_action(self, batch)
|
|
||||||
|
|
||||||
@torch.no_grad
|
@torch.no_grad
|
||||||
def select_action(self, batch: dict[str, Tensor], **_) -> Tensor:
|
def select_action(self, batch: dict[str, Tensor], **_) -> Tensor:
|
||||||
"""Select a single action given environment observations.
|
"""Select a single action given environment observations.
|
||||||
|
@ -178,14 +175,15 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||||
self._action_queue.extend(self._forward(batch)[0][: self.cfg.n_action_steps].transpose(0, 1))
|
self._action_queue.extend(self._forward(batch)[0][: self.cfg.n_action_steps].transpose(0, 1))
|
||||||
return self._action_queue.popleft()
|
return self._action_queue.popleft()
|
||||||
|
|
||||||
def compute_loss(self, batch, **_) -> float:
|
def forward(self, batch, **_) -> dict[str, Tensor]:
|
||||||
"""Runs the batch through the model and computes the loss for training or validation."""
|
"""Run the batch through the model and compute the loss for training or validation."""
|
||||||
actions_hat, (mu_hat, log_sigma_x2_hat) = self._forward(batch)
|
actions_hat, (mu_hat, log_sigma_x2_hat) = self._forward(batch)
|
||||||
|
|
||||||
l1_loss = (
|
l1_loss = (
|
||||||
F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
|
F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
|
||||||
).mean()
|
).mean()
|
||||||
|
|
||||||
|
loss_dict = {"l1_loss": l1_loss}
|
||||||
if self.cfg.use_vae:
|
if self.cfg.use_vae:
|
||||||
# Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
|
# Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
|
||||||
# each dimension independently, we sum over the latent dimension to get the total
|
# each dimension independently, we sum over the latent dimension to get the total
|
||||||
|
@ -194,23 +192,23 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||||
mean_kld = (
|
mean_kld = (
|
||||||
(-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
|
(-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
|
||||||
)
|
)
|
||||||
loss = l1_loss + mean_kld * self.cfg.kl_weight
|
loss_dict["kld_loss"] = mean_kld
|
||||||
|
loss_dict["loss"] = l1_loss + mean_kld * self.cfg.kl_weight
|
||||||
else:
|
else:
|
||||||
loss = l1_loss
|
loss_dict["loss"] = l1_loss
|
||||||
|
|
||||||
return loss
|
return loss_dict
|
||||||
|
|
||||||
def update(self, batch, **_) -> dict:
|
def update(self, batch, **_) -> dict:
|
||||||
"""Run the model in train mode, compute the loss, and do an optimization step."""
|
"""Run the model in train mode, compute the loss, and do an optimization step."""
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
self.train()
|
self.train()
|
||||||
loss = self.compute_loss(batch)
|
loss_dict = self.forward(batch)
|
||||||
|
loss = loss_dict["loss"]
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||||
self.parameters(),
|
self.parameters(), self.cfg.grad_clip_norm, error_if_nonfinite=False
|
||||||
self.cfg.grad_clip_norm,
|
|
||||||
error_if_nonfinite=False,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.optimizer.step()
|
self.optimizer.step()
|
||||||
|
|
|
@ -43,12 +43,12 @@ class DiffusionPolicy(nn.Module):
|
||||||
name = "diffusion"
|
name = "diffusion"
|
||||||
|
|
||||||
def __init__(self, cfg: DiffusionConfig | None, lr_scheduler_num_training_steps: int = 0):
|
def __init__(self, cfg: DiffusionConfig | None, lr_scheduler_num_training_steps: int = 0):
|
||||||
super().__init__()
|
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
cfg: Policy configuration class instance or None, in which case the default instantiation of the
|
cfg: Policy configuration class instance or None, in which case the default instantiation of the
|
||||||
configuration class is used.
|
configuration class is used.
|
||||||
"""
|
"""
|
||||||
|
super().__init__()
|
||||||
# TODO(alexander-soare): LR scheduler will be removed.
|
# TODO(alexander-soare): LR scheduler will be removed.
|
||||||
assert lr_scheduler_num_training_steps > 0
|
assert lr_scheduler_num_training_steps > 0
|
||||||
if cfg is None:
|
if cfg is None:
|
||||||
|
@ -140,12 +140,18 @@ class DiffusionPolicy(nn.Module):
|
||||||
action = self._queues["action"].popleft()
|
action = self._queues["action"].popleft()
|
||||||
return action
|
return action
|
||||||
|
|
||||||
def forward(self, batch, **_):
|
def forward(self, batch: dict[str, Tensor], **_) -> dict[str, Tensor]:
|
||||||
|
"""Run the batch through the model and compute the loss for training or validation."""
|
||||||
|
loss = self.diffusion.compute_loss(batch)
|
||||||
|
return {"loss": loss}
|
||||||
|
|
||||||
|
def update(self, batch: dict[str, Tensor], **_) -> dict:
|
||||||
|
"""Run the model in train mode, compute the loss, and do an optimization step."""
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
self.diffusion.train()
|
self.diffusion.train()
|
||||||
|
|
||||||
loss = self.diffusion.compute_loss(batch)
|
loss = self.forward(batch)["loss"]
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||||
|
|
|
@ -24,20 +24,20 @@ class Policy(Protocol):
|
||||||
Does things like clearing caches.
|
Does things like clearing caches.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def forward(self, batch: dict[str, Tensor], **kwargs):
|
def forward(self, batch: dict[str, Tensor]) -> dict:
|
||||||
"""Wired to `select_action`."""
|
"""Run the batch through the model and compute the loss for training or validation.
|
||||||
|
|
||||||
def select_action(self, batch: dict[str, Tensor], **kwargs):
|
Returns a dictionary with "loss" and maybe other information.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def select_action(self, batch: dict[str, Tensor]):
|
||||||
"""Return one action to run in the environment (potentially in batch mode).
|
"""Return one action to run in the environment (potentially in batch mode).
|
||||||
|
|
||||||
When the model uses a history of observations, or outputs a sequence of actions, this method deals
|
When the model uses a history of observations, or outputs a sequence of actions, this method deals
|
||||||
with caching.
|
with caching.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def compute_loss(self, batch: dict[str, Tensor], **kwargs):
|
def update(self, batch):
|
||||||
"""Runs the batch through the model and computes the loss for training or validation."""
|
|
||||||
|
|
||||||
def update(self, batch, **kwargs):
|
|
||||||
"""Does compute_loss then an optimization step.
|
"""Computes the loss via `forward`, then does an optimization step.
|
||||||
|
|
||||||
TODO(alexander-soare): We will move the optimization step back into the training loop, so this will
|
TODO(alexander-soare): We will move the optimization step back into the training loop, so this will
|
||||||
|
|
Loading…
Reference in New Issue