backup wip
parent 0eb899de73
commit cb3978b5f3
@@ -165,8 +165,8 @@ class ActionChunkingTransformerPolicy(nn.Module):
         if self.cfg.n_action_steps is not None:
             self._action_queue = deque([], maxlen=self.cfg.n_action_steps)
 
-    # def forward(self, batch: dict[str, Tensor], **_) -> Tensor:
-    #     return self.select_action(self, batch)
+    def forward(self, batch: dict[str, Tensor], **_) -> Tensor:
+        return self.select_action(batch)
 
     @torch.no_grad
     def select_action(self, batch: dict[str, Tensor], **_) -> Tensor:
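For context on the `_action_queue` above: with action chunking, the policy predicts a whole sequence of actions at once and then serves them one per environment step. A minimal sketch of that caching pattern (hypothetical `ChunkedPolicy` with a stubbed model call; not the code from this commit):

import torch
from collections import deque

class ChunkedPolicy:
    def __init__(self, n_action_steps: int):
        # Holds up to n_action_steps actions from the last predicted chunk.
        self.n_action_steps = n_action_steps
        self._action_queue = deque([], maxlen=n_action_steps)

    def _select_actions(self, batch: dict) -> torch.Tensor:
        # Stub for the model call: (n_action_steps, batch_size, action_dim).
        return torch.zeros(self.n_action_steps, batch["observation.state"].shape[0], 2)

    @torch.no_grad()
    def select_action(self, batch: dict) -> torch.Tensor:
        # Only run inference when the queue is empty; otherwise serve the next
        # cached action, so the model is called once per chunk.
        if len(self._action_queue) == 0:
            self._action_queue.extend(self._select_actions(batch))
        return self._action_queue.popleft()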
@@ -186,20 +186,28 @@ class ActionChunkingTransformerPolicy(nn.Module):
     def _select_actions(self, batch: dict[str, Tensor]) -> Tensor:
         """Use the action chunking transformer to generate a sequence of actions."""
         self.eval()
-        self._preprocess_batch(batch, add_obs_steps_dim=True)
-        action = self.forward(batch, return_loss=False)
-        return action[: self.cfg.n_action_steps]
+        batch = self._reshape_batch(batch, add_obs_steps_dim=True)
+        actions, _ = self._forward(
+            batch["observation.state"], self.image_normalizer(batch["observation.images.top"])
+        )
+        return actions[: self.cfg.n_action_steps]
 
+    def _reshape_batch(self, batch: dict[str, Tensor], add_obs_steps_dim: bool = False) -> dict[str, Tensor]:
+        """Reshapes the batch items to account for various requirements of this policy.
+
-    def _preprocess_batch(
-        self, batch: dict[str, Tensor], add_obs_steps_dim: bool = False
-    ) -> dict[str, Tensor]:
-        """
         This function expects `batch` to have (at least):
         {
             "observation.state": (B, 1, J) OR (B, J) tensor of robot states (joint configuration).
             "observation.images.top": (B, 1, C, H, W) OR (B, C, H, W) tensor of images.
         }
 
+        TODO(alexander-soare): Right now this method does and undoes reshaping operations. This is just to
+        separate out the core logic from the temporary logic. See comments below.
         """
+        # Create a shallow copy.
+        batch = dict(batch)
 
+        # Add a dimension for observation steps.
         if add_obs_steps_dim:
+            # Add a dimension for the observation steps. Since n_obs_steps > 1 is not supported right now,
+            # this just amounts to an unsqueeze.
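A note on `batch = dict(batch)` above: the shallow copy lets `_reshape_batch` reassign keys without mutating the caller's dict, though the underlying tensors are still shared. A quick illustration with dummy shapes:

import torch

batch = {"observation.state": torch.zeros(2, 4)}
reshaped = dict(batch)  # shallow copy: new dict, same tensors
reshaped["observation.state"] = reshaped["observation.state"].unsqueeze(1)

assert reshaped["observation.state"].shape == (2, 1, 4)
assert batch["observation.state"].shape == (2, 4)  # caller's dict is untouched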
@@ -207,18 +215,29 @@ class ActionChunkingTransformerPolicy(nn.Module):
                 if k.startswith("observation."):
                     batch[k] = batch[k].unsqueeze(1)
 
-        if batch["observation.state"].shape[1] != 1:
         # Temporary logic to remove the observation step dimension as the policy does not yet handle it.
+        # TODO(alexander-soare): generalize this to multiple observation steps.
+        # Check that there is only 1 observation step (policy does not yet handle more).
+        if not all(batch[k].shape[1] == 1 for k in batch if k.startswith("observation.")):
+            raise ValueError(self._multiple_obs_steps_not_handled_msg)
-            batch["observation.state"] = batch["observation.state"].squeeze(1)
-        # TODO(alexander-soare): generalize this to multiple images.
+        # Remove observation steps dimension.
+        for k in batch:
+            if k.startswith("observation."):
+                batch[k] = batch[k].squeeze(1)
+
+        # Temporary logic to add the multiple image dimension back in.
+        # TODO(alexander-soare): generalize this to multiple images. Once resolved, this logic will stack all
+        # images.
         assert (
             sum(k.startswith("observation.images.") and not k.endswith("is_pad") for k in batch) == 1
-        ), "ACT only handles one image for now."
-        # Note: no squeeze is required for "observation.images.top" because then we'd have to unsqueeze to get
-        # the image index dimension.
+        ), f"{self.__class__.__name__} only handles one image for now."
+        # Since we only handle one image, just unsqueeze instead of stacking.
+        batch["observation.images.top"] = batch["observation.images.top"].unsqueeze(1)
+
         return batch
 
     def compute_loss(self, batch, **_) -> float:
-        self._preprocess_batch(batch)
+        batch = self._reshape_batch(batch)
 
         self.train()
+
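To make the shape contract concrete, here is roughly what the new `_reshape_batch` does to a raw batch, as a sketch with dummy tensors (B=2, J=4, and the image size are arbitrary):

import torch

B, J, C, H, W = 2, 4, 3, 96, 96
batch = {
    "observation.state": torch.zeros(B, J),
    "observation.images.top": torch.zeros(B, C, H, W),
}

# add_obs_steps_dim=True: unsqueeze every observation to add a steps axis.
for k in batch:
    batch[k] = batch[k].unsqueeze(1)
assert batch["observation.state"].shape == (B, 1, J)

# Temporary logic: squeeze the steps axis back out (only 1 step is handled)...
for k in batch:
    batch[k] = batch[k].squeeze(1)
# ...then add a camera-index axis for the single image.
batch["observation.images.top"] = batch["observation.images.top"].unsqueeze(1)
assert batch["observation.images.top"].shape == (B, 1, C, H, W)
assert batch["observation.state"].shape == (B, J)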
@@ -228,7 +247,28 @@ class ActionChunkingTransformerPolicy(nn.Module):
         assert batch_size % self.cfg.chunk_size == 0
         assert batch_size % num_slices == 0
 
-        loss = self.forward(batch, return_loss=True)["loss"]
+        actions_hat, (mu_hat, log_sigma_x2_hat) = self._forward(
+            batch["observation.state"],
+            self.image_normalizer(batch["observation.images.top"]),
+            batch["action"],
+        )
+
+        l1_loss = (
+            F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
+        ).mean()
+
+        if self.cfg.use_vae:
+            # Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
+            # each dimension independently, we sum over the latent dimension to get the total
+            # KL-divergence per batch element, then take the mean over the batch.
+            # (See App. B of https://arxiv.org/abs/1312.6114 for more details.)
+            mean_kld = (
+                (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
+            )
+            loss = l1_loss + mean_kld * self.cfg.kl_weight
+        else:
+            loss = l1_loss
+
         return loss
 
     def update(self, batch, **_) -> dict:
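As a sanity check on the loss terms above: the `~action_is_pad` mask zeroes the L1 contribution of padded action steps, and the KL term is the closed form for D_KL(N(mu, sigma^2) || N(0, 1)) from App. B of the cited VAE paper. A standalone sketch with dummy tensors (shapes arbitrary):

import torch
import torch.nn.functional as F

B, S, A, L = 2, 10, 4, 32  # batch, chunk length, action dim, latent dim
actions = torch.randn(B, S, A)
actions_hat = torch.randn(B, S, A)
action_is_pad = torch.zeros(B, S, dtype=torch.bool)
action_is_pad[:, 7:] = True  # pretend the last 3 steps are padding

# Elementwise L1, zeroed on padded steps. Note that .mean() still divides by
# the total element count, padded entries included, exactly as in the diff.
l1_loss = (
    F.l1_loss(actions, actions_hat, reduction="none") * ~action_is_pad.unsqueeze(-1)
).mean()

# KL(N(mu, sigma^2) || N(0, 1)) per latent dim, summed over dims, batch-averaged.
mu_hat = torch.randn(B, L)
log_sigma_x2_hat = torch.randn(B, L)  # log(sigma^2)
mean_kld = (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - log_sigma_x2_hat.exp())).sum(-1).mean()

loss = l1_loss + 1.0 * mean_kld  # kl_weight = 1.0 as a placeholder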
@@ -255,39 +295,6 @@ class ActionChunkingTransformerPolicy(nn.Module):
 
         return info
 
-    def forward(self, batch: dict[str, Tensor], return_loss: bool = False) -> dict | Tensor:
-        """A forward pass through the DNN part of this policy with optional loss computation."""
-        images = self.image_normalizer(batch["observation.images.top"])
-
-        if return_loss:  # training time
-            actions_hat, (mu_hat, log_sigma_x2_hat) = self._forward(
-                batch["observation.state"], images, batch["action"]
-            )
-
-            l1_loss = (
-                F.l1_loss(batch["action"], actions_hat, reduction="none")
-                * ~batch["action_is_pad"].unsqueeze(-1)
-            ).mean()
-
-            loss_dict = {}
-            loss_dict["l1"] = l1_loss
-            if self.cfg.use_vae:
-                # Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
-                # each dimension independently, we sum over the latent dimension to get the total
-                # KL-divergence per batch element, then take the mean over the batch.
-                # (See App. B of https://arxiv.org/abs/1312.6114 for more details.)
-                mean_kld = (
-                    (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
-                )
-                loss_dict["kl"] = mean_kld
-                loss_dict["loss"] = loss_dict["l1"] + loss_dict["kl"] * self.cfg.kl_weight
-            else:
-                loss_dict["loss"] = loss_dict["l1"]
-            return loss_dict
-        else:
-            action, _ = self._forward(batch["observation.state"], images)
-            return action
-
     def _forward(
         self, robot_state: Tensor, image: Tensor, actions: Tensor | None = None
     ) -> tuple[Tensor, tuple[Tensor | None, Tensor | None]]:
@@ -0,0 +1,45 @@
+"""A protocol that all policies should follow.
+
+This provides a mechanism for type-hinting and isinstance checks without requiring the policy classes to
+subclass a base class.
+
+The protocol structure, method signatures, and docstrings should be used by developers as a reference for
+how to implement new policies.
+"""
+
+from typing import Protocol, runtime_checkable
+
+from torch import Tensor
+
+
+@runtime_checkable
+class Policy(Protocol):
+    """The required interface for implementing a policy."""
+
+    name: str
+
+    def reset(self):
+        """To be called whenever the environment is reset.
+
+        Does things like clearing caches.
+        """
+
+    def forward(self, batch: dict[str, Tensor], **kwargs):
+        """Wired to `select_action`."""
+
+    def select_action(self, batch: dict[str, Tensor], **kwargs):
+        """Return one action to run in the environment (potentially in batch mode).
+
+        When the model uses a history of observations, or outputs a sequence of actions, this method deals
+        with caching.
+        """
+
+    def compute_loss(self, batch: dict[str, Tensor], **kwargs):
+        """Runs the batch through the model and computes the loss for training or validation."""
+
+    def update(self, batch, **kwargs):
+        """Does compute_loss then an optimization step.
+
+        TODO(alexander-soare): We will move the optimization step back into the training loop, so this will
+        disappear.
+        """
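Because `Policy` is decorated with `@runtime_checkable`, `isinstance` checks verify only that the required attributes and methods exist, not their signatures. A minimal structural implementation (hypothetical toy class, assuming the `Policy` protocol above is importable):

import torch
from torch import Tensor

class NoopPolicy:
    """Toy policy that structurally satisfies the Policy protocol."""

    name = "noop"

    def reset(self):
        pass

    def forward(self, batch: dict[str, Tensor], **kwargs):
        return self.select_action(batch, **kwargs)

    def select_action(self, batch: dict[str, Tensor], **kwargs):
        # One zero action per environment in the batch.
        return torch.zeros(batch["observation.state"].shape[0], 2)

    def compute_loss(self, batch: dict[str, Tensor], **kwargs):
        return torch.tensor(0.0)

    def update(self, batch, **kwargs):
        return {"loss": float(self.compute_loss(batch))}

assert isinstance(NoopPolicy(), Policy)  # passes without subclassing Policy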