backup wip

This commit is contained in:
Alexander Soare 2024-04-16 18:12:39 +01:00
parent 0eb899de73
commit cb3978b5f3
2 changed files with 102 additions and 50 deletions

View File

@ -165,8 +165,8 @@ class ActionChunkingTransformerPolicy(nn.Module):
if self.cfg.n_action_steps is not None:
self._action_queue = deque([], maxlen=self.cfg.n_action_steps)
# def forward(self, batch: dict[str, Tensor], **_) -> Tensor:
# return self.select_action(self, batch)
def forward(self, batch: dict[str, Tensor], **_) -> Tensor:
return self.select_action(self, batch)
@torch.no_grad
def select_action(self, batch: dict[str, Tensor], **_) -> Tensor:
@ -186,20 +186,28 @@ class ActionChunkingTransformerPolicy(nn.Module):
def _select_actions(self, batch: dict[str, Tensor]) -> Tensor:
"""Use the action chunking transformer to generate a sequence of actions."""
self.eval()
self._preprocess_batch(batch, add_obs_steps_dim=True)
action = self.forward(batch, return_loss=False)
return action[: self.cfg.n_action_steps]
batch = self._reshape_batch(batch, add_obs_steps_dim=True)
actions, _ = self._forward(
batch["observation.state"], self.image_normalizer(batch["observation.images.top"])
)
return actions[: self.cfg.n_action_steps]
def _reshape_batch(self, batch: dict[str, Tensor], add_obs_steps_dim: bool = False) -> dict[str, Tensor]:
"""Reshapes the batch items to account for various requirements of this policy.
def _preprocess_batch(
self, batch: dict[str, Tensor], add_obs_steps_dim: bool = False
) -> dict[str, Tensor]:
"""
This function expects `batch` to have (at least):
{
"observation.state": (B, 1, J) OR (B, J) tensor of robot states (joint configuration).
"observation.images.top": (B, 1, C, H, W) OR (B, C, H, W) tensor of images.
}
TODO(alexander-soare): Right now this method does and undoes reshaping operations. This is just to
separate out the core logic from the temporary logic. See comments below.
"""
# Create a shallow copy.
batch = dict(batch)
# Add a dimension for observation steps.
if add_obs_steps_dim:
# Add a dimension for the observations steps. Since n_obs_steps > 1 is not supported right now,
# this just amounts to an unsqueeze.
@ -207,18 +215,29 @@ class ActionChunkingTransformerPolicy(nn.Module):
if k.startswith("observation."):
batch[k] = batch[k].unsqueeze(1)
if batch["observation.state"].shape[1] != 1:
# Temporary logic to remove the observation step dimension as the policy does not yet handle it.
# TODO(alexander-soare): generalize this to multiple observations steps.
# Check that there is only 1 observation step (policy does not yet handle more).
if not all(batch[k].shape[1] == 1 for k in batch if k.startswith("observation.")):
raise ValueError(self._multiple_obs_steps_not_handled_msg)
batch["observation.state"] = batch["observation.state"].squeeze(1)
# TODO(alexander-soare): generalize this to multiple images.
# Remove observation steps dimension.
for k in batch:
if k.startswith("observation."):
batch[k] = batch[k].squeeze(1)
# Temporary logic to add the multiple image dimension back in.
# TODO(alexander-soare): generalize this to multiple images. Once resolved, this logic will stack all
# images.
assert (
sum(k.startswith("observation.images.") and not k.endswith("is_pad") for k in batch) == 1
), "ACT only handles one image for now."
# Note: no squeeze is required for "observation.images.top" because then we'd have to unsqueeze to get
# the image index dimension.
), f"{self.__class__.__name__} only handles one image for now."
# Since we only handle one image, just unsqueeze instead of stacking.
batch["observation.images.top"] = batch["observation.images.top"].unsqueeze(1)
return batch
def compute_loss(self, batch, **_) -> float:
self._preprocess_batch(batch)
batch = self._reshape_batch(batch)
self.train()
@ -228,7 +247,28 @@ class ActionChunkingTransformerPolicy(nn.Module):
assert batch_size % self.cfg.chunk_size == 0
assert batch_size % num_slices == 0
loss = self.forward(batch, return_loss=True)["loss"]
actions_hat, (mu_hat, log_sigma_x2_hat) = self._forward(
batch["observation.state"],
self.image_normalizer(batch["observation.images.top"]),
batch["action"],
)
l1_loss = (
F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
).mean()
if self.cfg.use_vae:
# Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
# each dimension independently, we sum over the latent dimension to get the total
# KL-divergence per batch element, then take the mean over the batch.
# (See App. B of https://arxiv.org/abs/1312.6114 for more details).
mean_kld = (
(-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
)
loss = l1_loss + mean_kld * self.cfg.kl_weight
else:
loss = l1_loss
return loss
def update(self, batch, **_) -> dict:
@ -255,39 +295,6 @@ class ActionChunkingTransformerPolicy(nn.Module):
return info
def forward(self, batch: dict[str, Tensor], return_loss: bool = False) -> dict | Tensor:
"""A forward pass through the DNN part of this policy with optional loss computation."""
images = self.image_normalizer(batch["observation.images.top"])
if return_loss: # training time
actions_hat, (mu_hat, log_sigma_x2_hat) = self._forward(
batch["observation.state"], images, batch["action"]
)
l1_loss = (
F.l1_loss(batch["action"], actions_hat, reduction="none")
* ~batch["action_is_pad"].unsqueeze(-1)
).mean()
loss_dict = {}
loss_dict["l1"] = l1_loss
if self.cfg.use_vae:
# Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
# each dimension independently, we sum over the latent dimension to get the total
# KL-divergence per batch element, then take the mean over the batch.
# (See App. B of https://arxiv.org/abs/1312.6114 for more details).
mean_kld = (
(-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
)
loss_dict["kl"] = mean_kld
loss_dict["loss"] = loss_dict["l1"] + loss_dict["kl"] * self.cfg.kl_weight
else:
loss_dict["loss"] = loss_dict["l1"]
return loss_dict
else:
action, _ = self._forward(batch["observation.state"], images)
return action
def _forward(
self, robot_state: Tensor, image: Tensor, actions: Tensor | None = None
) -> tuple[Tensor, tuple[Tensor | None, Tensor | None]]:

View File

@ -0,0 +1,45 @@
"""A protocol that all policies should follow.
This provides a mechanism for type-hinting and isinstance checks without requiring the policies classes
subclass a base class.
The protocol structure, method signatures, and docstrings should be used by developers as a reference for
how to implement new policies.
"""
from typing import Protocol, runtime_checkable
from torch import Tensor
@runtime_checkable
class Policy(Protocol):
"""The required interface for implementing a policy."""
name: str
def reset(self):
"""To be called whenever the environment is reset.
Does things like clearing caches.
"""
def forward(self, batch: dict[str, Tensor], **kwargs):
"""Wired to `select_action`."""
def select_action(self, batch: dict[str, Tensor], **kwargs):
"""Return one action to run in the environment (potentially in batch mode).
When the model uses a history of observations, or outputs a sequence of actions, this method deals
with caching.
"""
def compute_loss(self, batch: dict[str, Tensor], **kwargs):
"""Runs the batch through the model and computes the loss for training or validation."""
def update(self, batch, **kwargs):
"""Does compute_loss then an optimization step.
TODO(alexander-soare): We will move the optimization step back into the training loop, so this will
disappear.
"""