Alexander Soare 2024-04-17 10:50:54 +01:00
parent 18dd8f32cd
commit c50a13ab31
3 changed files with 79 additions and 103 deletions
lerobot/common/policies/act
lerobot/configs/policy

View File

@@ -103,12 +103,21 @@ class ActionChunkingTransformerConfig:
def __post_init__(self):
"""Input validation (not exhaustive)."""
if not self.vision_backbone.startswith("resnet"):
raise ValueError("`vision_backbone` must be one of the ResNet variants.")
raise ValueError(
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
)
if self.use_temporal_aggregation:
raise NotImplementedError("Temporal aggregation is not yet implemented.")
if self.n_action_steps > self.chunk_size:
raise ValueError(
"The chunk size is the upper bound for the number of action steps per model invocation."
f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
)
if self.n_obs_steps != 1:
raise ValueError(
f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
)
if self.camera_names != ["top"]:
raise ValueError("For now, `camera_names` can only be ['top']")
raise ValueError(f"For now, `camera_names` can only be ['top']. Got {self.camera_names}.")
if len(set(self.camera_names)) != len(self.camera_names):
raise ValueError(f"`camera_names` should not have any repeated entries. Got {self.camera_names}.")

View File

@@ -20,7 +20,9 @@ from torch import Tensor, nn
from torchvision.models._utils import IntermediateLayerGetter
from torchvision.ops.misc import FrozenBatchNorm2d
from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
from lerobot.common.policies.act.configuration_act import (
ActionChunkingTransformerConfig,
)
class ActionChunkingTransformerPolicy(nn.Module):
@@ -61,9 +63,6 @@ class ActionChunkingTransformerPolicy(nn.Module):
"""
name = "act"
_multiple_obs_steps_not_handled_msg = (
"ActionChunkingTransformerPolicy does not handle multiple observation steps."
)
def __init__(self, cfg: ActionChunkingTransformerConfig | None = None):
"""
@@ -74,8 +73,6 @@ class ActionChunkingTransformerPolicy(nn.Module):
super().__init__()
if cfg is None:
cfg = ActionChunkingTransformerConfig()
if cfg.n_obs_steps != 1:
raise ValueError(self._multiple_obs_steps_not_handled_msg)
self.cfg = cfg
# BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence].
@@ -102,7 +99,11 @@ class ActionChunkingTransformerPolicy(nn.Module):
mean=cfg.image_normalization_mean, std=cfg.image_normalization_std
)
backbone_model = getattr(torchvision.models, cfg.vision_backbone)(
replace_stride_with_dilation=[False, False, cfg.replace_final_stride_with_dilation],
replace_stride_with_dilation=[
False,
False,
cfg.replace_final_stride_with_dilation,
],
pretrained=cfg.use_pretrained_backbone,
norm_layer=FrozenBatchNorm2d,
)
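# Standalone sketch (not part of the commit) of the backbone construction above and of how it is later
# queried for a "feature_map" (see the camera loop further down in this file). The "layer4" -> "feature_map"
# mapping and the resnet18 / random-weights choices are illustrative assumptions, not taken from the diff.
import torch
import torchvision
from torchvision.models._utils import IntermediateLayerGetter
from torchvision.ops.misc import FrozenBatchNorm2d

backbone_model = getattr(torchvision.models, "resnet18")(
    replace_stride_with_dilation=[False, False, False],
    pretrained=False,  # mirrors the diff; newer torchvision prefers `weights=None`
    norm_layer=FrozenBatchNorm2d,
)
backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"})
feature_map = backbone(torch.randn(2, 3, 480, 640))["feature_map"]  # (2, 512, 15, 20)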
@@ -176,82 +177,16 @@ class ActionChunkingTransformerPolicy(nn.Module):
environment. It works by managing the actions in a queue and only running a forward pass through the
model when the queue is empty.
"""
self.eval()
if len(self._action_queue) == 0:
# `select_actions` returns a (batch_size, n_action_steps, *) tensor, but the queue effectively has
# shape (n_action_steps, batch_size, *), hence the transpose.
self._action_queue.extend(self._select_actions(batch).transpose(0, 1))
# `_forward` returns a (batch_size, chunk_size, action_dim) tensor, but the queue effectively
# has shape (n_action_steps, batch_size, *), hence the slice over the chunk dimension and the transpose.
self._action_queue.extend(self._forward(batch)[0][:, : self.cfg.n_action_steps].transpose(0, 1))
return self._action_queue.popleft()
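# Standalone sketch (not part of the commit) of the action-queue pattern used in `select_action`: the
# model predicts a whole chunk at once, the first n_action_steps of it are stored transposed so the queue
# holds (batch_size, action_dim) entries, and one entry is popped per environment step.
# `fake_policy_forward` and all shapes are illustrative stand-ins.
from collections import deque
import torch

batch_size, chunk_size, n_action_steps, action_dim = 2, 100, 5, 4
queue: deque[torch.Tensor] = deque([], maxlen=n_action_steps)

def fake_policy_forward() -> torch.Tensor:
    return torch.randn(batch_size, chunk_size, action_dim)

for _ in range(12):  # environment steps
    if len(queue) == 0:
        # (batch_size, chunk_size, action_dim) -> keep n_action_steps -> (n_action_steps, batch_size, action_dim)
        queue.extend(fake_policy_forward()[:, :n_action_steps].transpose(0, 1))
    action = queue.popleft()
    assert action.shape == (batch_size, action_dim)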
@torch.no_grad
def _select_actions(self, batch: dict[str, Tensor]) -> Tensor:
"""Use the action chunking transformer to generate a sequence of actions."""
self.eval()
batch = self._reshape_batch(batch, add_obs_steps_dim=True)
actions, _ = self._forward(
batch["observation.state"], self.image_normalizer(batch["observation.images.top"])
)
return actions[: self.cfg.n_action_steps]
def _reshape_batch(self, batch: dict[str, Tensor], add_obs_steps_dim: bool = False) -> dict[str, Tensor]:
"""Reshapes the batch items to account for various requirements of this policy.
This function expects `batch` to have (at least):
{
"observation.state": (B, 1, J) OR (B, J) tensor of robot states (joint configuration).
"observation.images.top": (B, 1, C, H, W) OR (B, C, H, W) tensor of images.
}
TODO(alexander-soare): Right now this method does and undoes reshaping operations. This is just to
separate out the core logic from the temporary logic. See comments below.
"""
# Create a shallow copy.
batch = dict(batch)
# Add a dimension for observation steps.
if add_obs_steps_dim:
# Add a dimension for the observation steps. Since n_obs_steps > 1 is not supported right now,
# this just amounts to an unsqueeze.
for k in batch:
if k.startswith("observation."):
batch[k] = batch[k].unsqueeze(1)
# Temporary logic to remove the observation step dimension as the policy does not yet handle it.
# TODO(alexander-soare): generalize this to multiple observation steps.
# Check that there is only 1 observation step (policy does not yet handle more).
if not all(batch[k].shape[1] == 1 for k in batch if k.startswith("observation.")):
raise ValueError(self._multiple_obs_steps_not_handled_msg)
# Remove observation steps dimension.
for k in batch:
if k.startswith("observation."):
batch[k] = batch[k].squeeze(1)
# Temporary logic to add the multiple image dimension back in.
# TODO(alexander-soare): generalize this to multiple images. Once resolved, this logic will stack all
# images.
assert (
sum(k.startswith("observation.images.") and not k.endswith("is_pad") for k in batch) == 1
), f"{self.__class__.__name__} only handles one image for now."
# Since we only handle one image, just unsqueeze instead of stacking.
batch["observation.images.top"] = batch["observation.images.top"].unsqueeze(1)
return batch
def compute_loss(self, batch, **_) -> float:
batch = self._reshape_batch(batch)
self.train()
num_slices = self.cfg.batch_size
batch_size = self.cfg.chunk_size * num_slices
assert batch_size % self.cfg.chunk_size == 0
assert batch_size % num_slices == 0
actions_hat, (mu_hat, log_sigma_x2_hat) = self._forward(
batch["observation.state"],
self.image_normalizer(batch["observation.images.top"]),
batch["action"],
)
"""Runs the batch through the model and computes the loss for training or validation."""
actions_hat, (mu_hat, log_sigma_x2_hat) = self._forward(batch)
l1_loss = (
F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
@@ -274,6 +209,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
def update(self, batch, **_) -> dict:
"""Run the model in train mode, compute the loss, and do an optimization step."""
start_time = time.time()
self.train()
loss = self.compute_loss(batch)
loss.backward()
@@ -295,35 +231,64 @@ class ActionChunkingTransformerPolicy(nn.Module):
return info
def _forward(
self, robot_state: Tensor, image: Tensor, actions: Tensor | None = None
) -> tuple[Tensor, tuple[Tensor | None, Tensor | None]]:
def _stack_images(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
"""Stacks all the images in a batch and puts them in a new key: "observation.images".
This function expects `batch` to have (at least):
{
"observation.state": (B, state_dim) batch of robot states.
"observation.images.{name}": (B, C, H, W) tensor of images.
}
"""
Args:
robot_state: (B, J) batch of robot joint configurations.
image: (B, N, C, H, W) batch of N camera frames.
actions: (B, S, A) batch of actions from the target dataset which must be provided if the
VAE is enabled and the model is in training mode.
# Check that there is only one image.
# TODO(alexander-soare): generalize this to multiple images.
provided_cameras = {k.rsplit(".", 1)[-1] for k in batch if k.startswith("observation.images.")}
if len(missing := set(self.cfg.camera_names).difference(provided_cameras)) > 0:
raise ValueError(
f"The following camera images are missing from the provided batch: {missing}. Check the "
"configuration parameter: `camera_names`."
)
# Stack images in the order dictated by the camera names.
batch["observation.images"] = torch.stack(
[batch[f"observation.images.{name}"] for name in self.cfg.camera_names],
dim=-4,
)
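# Standalone sketch (not part of the commit) of what `_stack_images` produces: each per-camera
# (B, C, H, W) image is stacked along a new dim -4, giving one (B, n_cameras, C, H, W) tensor under
# "observation.images". The camera list, state dimension, and image shapes are illustrative.
import torch

camera_names = ["top"]  # this commit only supports ["top"], per the config validation above
batch = {
    "observation.state": torch.randn(2, 14),
    "observation.images.top": torch.randn(2, 3, 480, 640),
}
batch["observation.images"] = torch.stack(
    [batch[f"observation.images.{name}"] for name in camera_names], dim=-4
)
assert batch["observation.images"].shape == (2, 1, 3, 480, 640)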
def _forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tensor] | tuple[None, None]]:
"""A forward pass through the Action Chunking Transformer (with optional VAE encoder).
`batch` should have the following structure:
{
"observation.state": (B, state_dim) batch of robot states.
"observation.images": (B, n_cameras, C, H, W) batch of images.
"action" (optional, only if training with VAE): (B, chunk_size, action dim) batch of actions.
}
Returns:
(B, S, A) batch of action sequences
(B, chunk_size, action_dim) batch of action sequences
Tuple containing the latent PDF's parameters (mean, log(σ²)) both as (B, L) tensors where L is the
latent dimension.
"""
if self.cfg.use_vae and self.training:
assert (
actions is not None
"action" in batch
), "actions must be provided when using the variational objective in training mode."
batch_size = robot_state.shape[0]
self._stack_images(batch)
batch_size = batch["observation.state"].shape[0]
# Prepare the latent for input to the transformer encoder.
if self.cfg.use_vae and actions is not None:
if self.cfg.use_vae and "action" in batch:
# Prepare the input to the VAE encoder: [cls, *joint_space_configuration, *action_sequence].
cls_embed = einops.repeat(
self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
) # (B, 1, D)
robot_state_embed = self.vae_encoder_robot_state_input_proj(robot_state).unsqueeze(1) # (B, 1, D)
action_embed = self.vae_encoder_action_input_proj(actions) # (B, S, D)
robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"]).unsqueeze(
1
) # (B, 1, D)
action_embed = self.vae_encoder_action_input_proj(batch["action"]) # (B, S, D)
vae_encoder_input = torch.cat([cls_embed, robot_state_embed, action_embed], axis=1) # (B, S+2, D)
# Prepare fixed positional embedding.
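# Standalone sketch (not part of the commit) of the VAE-encoder input assembled above: a learned [cls]
# token, the projected robot state, and the projected action sequence are concatenated into a
# (B, S+2, D) sequence. Plain nn.Embedding/nn.Linear modules stand in for the policy's
# `vae_encoder_cls_embed` and `vae_encoder_*_input_proj` layers; all dimensions are illustrative.
import einops
import torch
from torch import nn

B, S, state_dim, action_dim, D = 2, 100, 14, 14, 512
vae_encoder_cls_embed = nn.Embedding(1, D)
robot_state_proj = nn.Linear(state_dim, D)
action_proj = nn.Linear(action_dim, D)

cls_embed = einops.repeat(vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=B)        # (B, 1, D)
robot_state_embed = robot_state_proj(torch.randn(B, state_dim)).unsqueeze(1)        # (B, 1, D)
action_embed = action_proj(torch.randn(B, S, action_dim))                           # (B, S, D)
vae_encoder_input = torch.cat([cls_embed, robot_state_embed, action_embed], dim=1)  # (B, S+2, D)
assert vae_encoder_input.shape == (B, S + 2, D)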
@@ -345,15 +310,16 @@ class ActionChunkingTransformerPolicy(nn.Module):
# When not using the VAE encoder, we set the latent to be all zeros.
mu = log_sigma_x2 = None
latent_sample = torch.zeros([batch_size, self.latent_dim], dtype=torch.float32).to(
robot_state.device
batch["observation.state"].device
)
# Prepare all other transformer encoder inputs.
# Camera observation features and positional embeddings.
all_cam_features = []
all_cam_pos_embeds = []
for cam_id, _ in enumerate(self.cfg.camera_names):
cam_features = self.backbone(image[:, cam_id])["feature_map"]
images = self.image_normalizer(batch["observation.images"])
for cam_index in range(len(self.cfg.camera_names)):
cam_features = self.backbone(images[:, cam_index])["feature_map"]
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
cam_features = self.encoder_img_feat_input_proj(cam_features) # (B, C, h, w)
all_cam_features.append(cam_features)
@@ -363,7 +329,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
cam_pos_embed = torch.cat(all_cam_pos_embeds, axis=3)
# Get positional embeddings for robot state and latent.
robot_state_embed = self.encoder_robot_state_input_proj(robot_state)
robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"])
latent_embed = self.encoder_latent_input_proj(latent_sample)
# Stack encoder input and positional embeddings moving to (S, B, C).
@@ -479,7 +445,10 @@ class _TransformerDecoder(nn.Module):
) -> Tensor:
for layer in self.layers:
x = layer(
x, encoder_out, decoder_pos_embed=decoder_pos_embed, encoder_pos_embed=encoder_pos_embed
x,
encoder_out,
decoder_pos_embed=decoder_pos_embed,
encoder_pos_embed=encoder_pos_embed,
)
if self.norm is not None:
x = self.norm(x)

View File

@@ -67,6 +67,4 @@ policy:
utd: 1
delta_timestamps:
observation.images.top: "[i / ${fps} for i in range(1 - ${n_obs_steps}, 1)]"
observation.state: "[i / ${fps} for i in range(1 - ${n_obs_steps}, 1)]"
action: "[i / ${fps} for i in range(${policy.chunk_size})]"