Nest ACT model in ACT Policy (#122)
This commit is contained in:
parent
9d60dce6f3
commit
986583dc5c
|
@ -26,6 +26,108 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
|||
"""
|
||||
Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost
|
||||
Hardware (paper: https://arxiv.org/abs/2304.13705, code: https://github.com/tonyzhaozh/act)
|
||||
"""
|
||||
|
||||
name = "act"
|
||||
|
||||
def __init__(self, cfg: ActionChunkingTransformerConfig | None = None, dataset_stats=None):
|
||||
"""
|
||||
Args:
|
||||
cfg: Policy configuration class instance or None, in which case the default instantiation of the
|
||||
configuration class is used.
|
||||
"""
|
||||
super().__init__()
|
||||
if cfg is None:
|
||||
cfg = ActionChunkingTransformerConfig()
|
||||
self.cfg = cfg
|
||||
self.normalize_inputs = Normalize(cfg.input_shapes, cfg.input_normalization_modes, dataset_stats)
|
||||
self.normalize_targets = Normalize(cfg.output_shapes, cfg.output_normalization_modes, dataset_stats)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
cfg.output_shapes, cfg.output_normalization_modes, dataset_stats
|
||||
)
|
||||
self.model = _ActionChunkingTransformer(cfg)
|
||||
|
||||
def reset(self):
|
||||
"""This should be called whenever the environment is reset."""
|
||||
if self.cfg.n_action_steps is not None:
|
||||
self._action_queue = deque([], maxlen=self.cfg.n_action_steps)
|
||||
|
||||
@torch.no_grad
|
||||
def select_action(self, batch: dict[str, Tensor], **_) -> Tensor:
|
||||
"""Select a single action given environment observations.
|
||||
|
||||
This method wraps `select_actions` in order to return one action at a time for execution in the
|
||||
environment. It works by managing the actions in a queue and only calling `select_actions` when the
|
||||
queue is empty.
|
||||
"""
|
||||
self.eval()
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
self._stack_images(batch)
|
||||
|
||||
if len(self._action_queue) == 0:
|
||||
# `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
|
||||
# effectively has shape (n_action_steps, batch_size, *), hence the transpose.
|
||||
actions = self.model(batch)[0][: self.cfg.n_action_steps]
|
||||
|
||||
# TODO(rcadene): make _forward return output dictionary?
|
||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||
|
||||
self._action_queue.extend(actions.transpose(0, 1))
|
||||
return self._action_queue.popleft()
|
||||
|
||||
def forward(self, batch, **_) -> dict[str, Tensor]:
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = self.normalize_targets(batch)
|
||||
self._stack_images(batch)
|
||||
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
|
||||
|
||||
l1_loss = (
|
||||
F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
|
||||
).mean()
|
||||
|
||||
loss_dict = {"l1_loss": l1_loss}
|
||||
if self.cfg.use_vae:
|
||||
# Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
|
||||
# each dimension independently, we sum over the latent dimension to get the total
|
||||
# KL-divergence per batch element, then take the mean over the batch.
|
||||
# (See App. B of https://arxiv.org/abs/1312.6114 for more details).
|
||||
mean_kld = (
|
||||
(-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
|
||||
)
|
||||
loss_dict["kld_loss"] = mean_kld
|
||||
loss_dict["loss"] = l1_loss + mean_kld * self.cfg.kl_weight
|
||||
else:
|
||||
loss_dict["loss"] = l1_loss
|
||||
|
||||
return loss_dict
|
||||
|
||||
def _stack_images(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
"""Stacks all the images in a batch and puts them in a new key: "observation.images".
|
||||
|
||||
This function expects `batch` to have (at least):
|
||||
{
|
||||
"observation.state": (B, state_dim) batch of robot states.
|
||||
"observation.images.{name}": (B, C, H, W) tensor of images.
|
||||
}
|
||||
"""
|
||||
# Stack images in the order dictated by input_shapes.
|
||||
batch["observation.images"] = torch.stack(
|
||||
[batch[k] for k in self.cfg.input_shapes if k.startswith("observation.images.")],
|
||||
dim=-4,
|
||||
)
|
||||
|
||||
def save(self, fp):
|
||||
torch.save(self.state_dict(), fp)
|
||||
|
||||
def load(self, fp):
|
||||
d = torch.load(fp)
|
||||
self.load_state_dict(d)
|
||||
|
||||
|
||||
class _ActionChunkingTransformer(nn.Module):
|
||||
"""Action Chunking Transformer: The underlying neural network for ActionChunkingTransformerPolicy.
|
||||
|
||||
Note: In this code we use the terms `vae_encoder`, 'encoder', `decoder`. The meanings are as follows.
|
||||
- The `vae_encoder` is, as per the literature around variational auto-encoders (VAE), the part of the
|
||||
|
@ -59,24 +161,9 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
|||
└───────────────────────┘
|
||||
"""
|
||||
|
||||
name = "act"
|
||||
|
||||
def __init__(self, cfg: ActionChunkingTransformerConfig | None = None, dataset_stats=None):
|
||||
"""
|
||||
Args:
|
||||
cfg: Policy configuration class instance or None, in which case the default instantiation of the
|
||||
configuration class is used.
|
||||
"""
|
||||
def __init__(self, cfg: ActionChunkingTransformerConfig):
|
||||
super().__init__()
|
||||
if cfg is None:
|
||||
cfg = ActionChunkingTransformerConfig()
|
||||
self.cfg = cfg
|
||||
self.normalize_inputs = Normalize(cfg.input_shapes, cfg.input_normalization_modes, dataset_stats)
|
||||
self.normalize_targets = Normalize(cfg.output_shapes, cfg.output_normalization_modes, dataset_stats)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
cfg.output_shapes, cfg.output_normalization_modes, dataset_stats
|
||||
)
|
||||
|
||||
# BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence].
|
||||
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
|
||||
if self.cfg.use_vae:
|
||||
|
@ -141,76 +228,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
|||
if p.dim() > 1:
|
||||
nn.init.xavier_uniform_(p)
|
||||
|
||||
def reset(self):
|
||||
"""This should be called whenever the environment is reset."""
|
||||
if self.cfg.n_action_steps is not None:
|
||||
self._action_queue = deque([], maxlen=self.cfg.n_action_steps)
|
||||
|
||||
@torch.no_grad
|
||||
def select_action(self, batch: dict[str, Tensor], **_) -> Tensor:
|
||||
"""Select a single action given environment observations.
|
||||
|
||||
This method wraps `select_actions` in order to return one action at a time for execution in the
|
||||
environment. It works by managing the actions in a queue and only calling `select_actions` when the
|
||||
queue is empty.
|
||||
"""
|
||||
self.eval()
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
|
||||
if len(self._action_queue) == 0:
|
||||
# `_forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue effectively
|
||||
# has shape (n_action_steps, batch_size, *), hence the transpose.
|
||||
actions = self._forward(batch)[0][: self.cfg.n_action_steps]
|
||||
|
||||
# TODO(rcadene): make _forward return output dictionary?
|
||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||
|
||||
self._action_queue.extend(actions.transpose(0, 1))
|
||||
return self._action_queue.popleft()
|
||||
|
||||
def forward(self, batch, **_) -> dict[str, Tensor]:
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = self.normalize_targets(batch)
|
||||
actions_hat, (mu_hat, log_sigma_x2_hat) = self._forward(batch)
|
||||
|
||||
l1_loss = (
|
||||
F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
|
||||
).mean()
|
||||
|
||||
loss_dict = {"l1_loss": l1_loss}
|
||||
if self.cfg.use_vae:
|
||||
# Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
|
||||
# each dimension independently, we sum over the latent dimension to get the total
|
||||
# KL-divergence per batch element, then take the mean over the batch.
|
||||
# (See App. B of https://arxiv.org/abs/1312.6114 for more details).
|
||||
mean_kld = (
|
||||
(-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
|
||||
)
|
||||
loss_dict["kld_loss"] = mean_kld
|
||||
loss_dict["loss"] = l1_loss + mean_kld * self.cfg.kl_weight
|
||||
else:
|
||||
loss_dict["loss"] = l1_loss
|
||||
|
||||
return loss_dict
|
||||
|
||||
def _stack_images(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
"""Stacks all the images in a batch and puts them in a new key: "observation.images".
|
||||
|
||||
This function expects `batch` to have (at least):
|
||||
{
|
||||
"observation.state": (B, state_dim) batch of robot states.
|
||||
"observation.images.{name}": (B, C, H, W) tensor of images.
|
||||
}
|
||||
"""
|
||||
# Stack images in the order dictated by input_shapes.
|
||||
batch["observation.images"] = torch.stack(
|
||||
[batch[k] for k in self.cfg.input_shapes if k.startswith("observation.images.")],
|
||||
dim=-4,
|
||||
)
|
||||
|
||||
def _forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tensor] | tuple[None, None]]:
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tensor] | tuple[None, None]]:
|
||||
"""A forward pass through the Action Chunking Transformer (with optional VAE encoder).
|
||||
|
||||
`batch` should have the following structure:
|
||||
|
@ -231,8 +249,6 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
|||
"action" in batch
|
||||
), "actions must be provided when using the variational objective in training mode."
|
||||
|
||||
self._stack_images(batch)
|
||||
|
||||
batch_size = batch["observation.state"].shape[0]
|
||||
|
||||
# Prepare the latent for input to the transformer encoder.
|
||||
|
@ -324,13 +340,6 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
|||
|
||||
return actions, (mu, log_sigma_x2)
|
||||
|
||||
def save(self, fp):
|
||||
torch.save(self.state_dict(), fp)
|
||||
|
||||
def load(self, fp):
|
||||
d = torch.load(fp)
|
||||
self.load_state_dict(d)
|
||||
|
||||
|
||||
class _TransformerEncoder(nn.Module):
|
||||
"""Convenience module for running multiple encoder layers, maybe followed by normalization."""
|
||||
|
|
Loading…
Reference in New Issue