Nest ACT model in ACT Policy (#122)
parent 9d60dce6f3
commit 986583dc5c
@@ -26,6 +26,108 @@ class ActionChunkingTransformerPolicy(nn.Module):
     """
     Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost
     Hardware (paper: https://arxiv.org/abs/2304.13705, code: https://github.com/tonyzhaozh/act)
+    """
+
+    name = "act"
+
+    def __init__(self, cfg: ActionChunkingTransformerConfig | None = None, dataset_stats=None):
+        """
+        Args:
+            cfg: Policy configuration class instance or None, in which case the default instantiation of the
+                configuration class is used.
+        """
+        super().__init__()
+        if cfg is None:
+            cfg = ActionChunkingTransformerConfig()
+        self.cfg = cfg
+        self.normalize_inputs = Normalize(cfg.input_shapes, cfg.input_normalization_modes, dataset_stats)
+        self.normalize_targets = Normalize(cfg.output_shapes, cfg.output_normalization_modes, dataset_stats)
+        self.unnormalize_outputs = Unnormalize(
+            cfg.output_shapes, cfg.output_normalization_modes, dataset_stats
+        )
+        self.model = _ActionChunkingTransformer(cfg)
+
+    def reset(self):
+        """This should be called whenever the environment is reset."""
+        if self.cfg.n_action_steps is not None:
+            self._action_queue = deque([], maxlen=self.cfg.n_action_steps)
+
+    @torch.no_grad
+    def select_action(self, batch: dict[str, Tensor], **_) -> Tensor:
+        """Select a single action given environment observations.
+
+        This method wraps `select_actions` in order to return one action at a time for execution in the
+        environment. It works by managing the actions in a queue and only calling `select_actions` when the
+        queue is empty.
+        """
+        self.eval()
+
+        batch = self.normalize_inputs(batch)
+        self._stack_images(batch)
+
+        if len(self._action_queue) == 0:
+            # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
+            # effectively has shape (n_action_steps, batch_size, *), hence the transpose.
+            actions = self.model(batch)[0][: self.cfg.n_action_steps]
+
+            # TODO(rcadene): make _forward return output dictionary?
+            actions = self.unnormalize_outputs({"action": actions})["action"]
+
+            self._action_queue.extend(actions.transpose(0, 1))
+        return self._action_queue.popleft()
+
+    def forward(self, batch, **_) -> dict[str, Tensor]:
+        """Run the batch through the model and compute the loss for training or validation."""
+        batch = self.normalize_inputs(batch)
+        batch = self.normalize_targets(batch)
+        self._stack_images(batch)
+        actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
+
+        l1_loss = (
+            F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
+        ).mean()
+
+        loss_dict = {"l1_loss": l1_loss}
+        if self.cfg.use_vae:
+            # Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
+            # each dimension independently, we sum over the latent dimension to get the total
+            # KL-divergence per batch element, then take the mean over the batch.
+            # (See App. B of https://arxiv.org/abs/1312.6114 for more details).
+            mean_kld = (
+                (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
+            )
+            loss_dict["kld_loss"] = mean_kld
+            loss_dict["loss"] = l1_loss + mean_kld * self.cfg.kl_weight
+        else:
+            loss_dict["loss"] = l1_loss
+
+        return loss_dict
+
+    def _stack_images(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
+        """Stacks all the images in a batch and puts them in a new key: "observation.images".
+
+        This function expects `batch` to have (at least):
+        {
+            "observation.state": (B, state_dim) batch of robot states.
+            "observation.images.{name}": (B, C, H, W) tensor of images.
+        }
+        """
+        # Stack images in the order dictated by input_shapes.
+        batch["observation.images"] = torch.stack(
+            [batch[k] for k in self.cfg.input_shapes if k.startswith("observation.images.")],
+            dim=-4,
+        )
+
+    def save(self, fp):
+        torch.save(self.state_dict(), fp)
+
+    def load(self, fp):
+        d = torch.load(fp)
+        self.load_state_dict(d)
+
+
+class _ActionChunkingTransformer(nn.Module):
+    """Action Chunking Transformer: The underlying neural network for ActionChunkingTransformerPolicy.

     Note: In this code we use the terms `vae_encoder`, 'encoder', `decoder`. The meanings are as follows.
         - The `vae_encoder` is, as per the literature around variational auto-encoders (VAE), the part of the
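Note on the `select_action` logic added above: the model predicts a whole chunk of actions at once while the environment consumes one action per step, so a queue bridges the two rates. A minimal sketch of the pattern, assuming only PyTorch (`make_select_action` and `chunk_policy` are illustrative stand-ins, not lerobot API):

    from collections import deque

    import torch


    def make_select_action(chunk_policy, n_action_steps: int):
        queue = deque([], maxlen=n_action_steps)

        def select_action(obs: torch.Tensor) -> torch.Tensor:
            if len(queue) == 0:
                # chunk_policy returns (batch_size, n_action_steps, action_dim);
                # the queue holds per-step entries of shape (batch_size, action_dim),
                # hence the transpose before extending.
                actions = chunk_policy(obs)
                queue.extend(actions.transpose(0, 1))
            return queue.popleft()

        return select_action


    # Dummy policy emitting 2-step chunks of 4-dim zero actions for a batch of 1.
    select = make_select_action(lambda obs: torch.zeros(1, 2, 4), n_action_steps=2)
    a0 = select(torch.zeros(1, 8))  # triggers one chunk prediction
    a1 = select(torch.zeros(1, 8))  # served from the queue, no new prediction

Each `popleft()` yields a (batch_size, action_dim) tensor, and a fresh model call happens only once every `n_action_steps` environment steps.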
@@ -59,24 +161,9 @@ class ActionChunkingTransformerPolicy(nn.Module):
     └───────────────────────┘
     """

-    name = "act"
-
-    def __init__(self, cfg: ActionChunkingTransformerConfig | None = None, dataset_stats=None):
-        """
-        Args:
-            cfg: Policy configuration class instance or None, in which case the default instantiation of the
-                configuration class is used.
-        """
+    def __init__(self, cfg: ActionChunkingTransformerConfig):
         super().__init__()
-        if cfg is None:
-            cfg = ActionChunkingTransformerConfig()
         self.cfg = cfg
-        self.normalize_inputs = Normalize(cfg.input_shapes, cfg.input_normalization_modes, dataset_stats)
-        self.normalize_targets = Normalize(cfg.output_shapes, cfg.output_normalization_modes, dataset_stats)
-        self.unnormalize_outputs = Unnormalize(
-            cfg.output_shapes, cfg.output_normalization_modes, dataset_stats
-        )
-
         # BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence].
         # The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
         if self.cfg.use_vae:
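The `# BERT style VAE encoder` comment kept as context above is the key design note for the VAE branch: a single cls-token embedding is projected to `[*means, *log_variances]` of the latent distribution. A sketch of that split under assumed dimensions (`hidden_dim=512` and `latent_dim=32` are placeholders, not values read from ActionChunkingTransformerConfig):

    import torch
    import torch.nn as nn

    hidden_dim, latent_dim, batch_size = 512, 32, 8  # illustrative sizes

    cls_embedding = torch.randn(batch_size, hidden_dim)   # cls token output
    to_pdf_params = nn.Linear(hidden_dim, 2 * latent_dim)
    pdf_params = to_pdf_params(cls_embedding)             # [*means, *log_variances]

    mu = pdf_params[:, :latent_dim]
    log_sigma_x2 = pdf_params[:, latent_dim:]             # log(sigma^2)

    # Reparameterization trick: latent = mu + sigma * eps, with sigma = exp(0.5 * log(sigma^2)).
    latent = mu + (0.5 * log_sigma_x2).exp() * torch.randn_like(mu)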
@@ -141,76 +228,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
             if p.dim() > 1:
                 nn.init.xavier_uniform_(p)

-    def reset(self):
-        """This should be called whenever the environment is reset."""
-        if self.cfg.n_action_steps is not None:
-            self._action_queue = deque([], maxlen=self.cfg.n_action_steps)
-
-    @torch.no_grad
-    def select_action(self, batch: dict[str, Tensor], **_) -> Tensor:
-        """Select a single action given environment observations.
-
-        This method wraps `select_actions` in order to return one action at a time for execution in the
-        environment. It works by managing the actions in a queue and only calling `select_actions` when the
-        queue is empty.
-        """
-        self.eval()
-
-        batch = self.normalize_inputs(batch)
-
-        if len(self._action_queue) == 0:
-            # `_forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue effectively
-            # has shape (n_action_steps, batch_size, *), hence the transpose.
-            actions = self._forward(batch)[0][: self.cfg.n_action_steps]
-
-            # TODO(rcadene): make _forward return output dictionary?
-            actions = self.unnormalize_outputs({"action": actions})["action"]
-
-            self._action_queue.extend(actions.transpose(0, 1))
-        return self._action_queue.popleft()
-
-    def forward(self, batch, **_) -> dict[str, Tensor]:
-        """Run the batch through the model and compute the loss for training or validation."""
-        batch = self.normalize_inputs(batch)
-        batch = self.normalize_targets(batch)
-        actions_hat, (mu_hat, log_sigma_x2_hat) = self._forward(batch)
-
-        l1_loss = (
-            F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
-        ).mean()
-
-        loss_dict = {"l1_loss": l1_loss}
-        if self.cfg.use_vae:
-            # Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
-            # each dimension independently, we sum over the latent dimension to get the total
-            # KL-divergence per batch element, then take the mean over the batch.
-            # (See App. B of https://arxiv.org/abs/1312.6114 for more details).
-            mean_kld = (
-                (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
-            )
-            loss_dict["kld_loss"] = mean_kld
-            loss_dict["loss"] = l1_loss + mean_kld * self.cfg.kl_weight
-        else:
-            loss_dict["loss"] = l1_loss
-
-        return loss_dict
-
-    def _stack_images(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
-        """Stacks all the images in a batch and puts them in a new key: "observation.images".
-
-        This function expects `batch` to have (at least):
-        {
-            "observation.state": (B, state_dim) batch of robot states.
-            "observation.images.{name}": (B, C, H, W) tensor of images.
-        }
-        """
-        # Stack images in the order dictated by input_shapes.
-        batch["observation.images"] = torch.stack(
-            [batch[k] for k in self.cfg.input_shapes if k.startswith("observation.images.")],
-            dim=-4,
-        )
-
-    def _forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tensor] | tuple[None, None]]:
+    def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tensor] | tuple[None, None]]:
         """A forward pass through the Action Chunking Transformer (with optional VAE encoder).

         `batch` should have the following structure:
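The KL-divergence comment that moves in the hunks above references App. B of https://arxiv.org/abs/1312.6114: for a diagonal Gaussian posterior against a standard normal prior, the divergence has the closed form -0.5 * (1 + log(sigma^2) - mu^2 - sigma^2) per dimension. A quick numerical check of that identity (not part of the commit), comparing against torch.distributions:

    import torch
    from torch.distributions import Normal, kl_divergence

    mu = torch.randn(4, 32)            # latent means
    log_sigma_x2 = torch.randn(4, 32)  # latent log variances

    closed_form = (-0.5 * (1 + log_sigma_x2 - mu.pow(2) - log_sigma_x2.exp())).sum(-1).mean()
    reference = kl_divergence(
        Normal(mu, (0.5 * log_sigma_x2).exp()),             # N(mu, sigma)
        Normal(torch.zeros_like(mu), torch.ones_like(mu)),  # standard normal
    ).sum(-1).mean()

    assert torch.allclose(closed_form, reference, atol=1e-5)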
@@ -231,8 +249,6 @@ class ActionChunkingTransformerPolicy(nn.Module):
                 "action" in batch
             ), "actions must be provided when using the variational objective in training mode."

-        self._stack_images(batch)
-
         batch_size = batch["observation.state"].shape[0]

         # Prepare the latent for input to the transformer encoder.
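This hunk removes the `self._stack_images(batch)` call from the network's forward pass; after the commit, the policy wrapper stacks images before delegating to `self.model`. What `_stack_images` produces, illustrated with made-up camera keys and shapes:

    import torch

    batch = {
        "observation.images.top": torch.rand(8, 3, 96, 96),    # (B, C, H, W)
        "observation.images.wrist": torch.rand(8, 3, 96, 96),  # keys are invented
    }
    # dim=-4 stacks along a new "camera" dimension just before (C, H, W).
    batch["observation.images"] = torch.stack(list(batch.values()), dim=-4)
    assert batch["observation.images"].shape == (8, 2, 3, 96, 96)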
@@ -324,13 +340,6 @@ class ActionChunkingTransformerPolicy(nn.Module):
         return actions, (mu, log_sigma_x2)

-    def save(self, fp):
-        torch.save(self.state_dict(), fp)
-
-    def load(self, fp):
-        d = torch.load(fp)
-        self.load_state_dict(d)
-

 class _TransformerEncoder(nn.Module):
     """Convenience module for running multiple encoder layers, maybe followed by normalization."""
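Taken together, the commit turns `ActionChunkingTransformerPolicy` into a thin wrapper (normalization, action queue, save/load) around the nested `_ActionChunkingTransformer` network. A rough usage sketch of the resulting interface (observation shapes are invented, and in practice `dataset_stats` should be passed so normalization has real statistics):

    import torch

    policy = ActionChunkingTransformerPolicy()  # default config; the network lives in policy.model
    policy.reset()  # must be called whenever the environment is reset

    batch = {
        "observation.state": torch.zeros(1, 14),                # (B, state_dim), shape invented
        "observation.images.top": torch.zeros(1, 3, 480, 640),  # camera key invented
    }
    action = policy.select_action(batch)  # first call runs policy.model once and fills the queue
    action = policy.select_action(batch)  # subsequent calls pop from the queue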