draft
This commit is contained in:
parent
18dd8f32cd
commit
c50a13ab31
|
@ -103,12 +103,21 @@ class ActionChunkingTransformerConfig:
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
"""Input validation (not exhaustive)."""
|
"""Input validation (not exhaustive)."""
|
||||||
if not self.vision_backbone.startswith("resnet"):
|
if not self.vision_backbone.startswith("resnet"):
|
||||||
raise ValueError("`vision_backbone` must be one of the ResNet variants.")
|
raise ValueError(
|
||||||
|
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
|
||||||
|
)
|
||||||
if self.use_temporal_aggregation:
|
if self.use_temporal_aggregation:
|
||||||
raise NotImplementedError("Temporal aggregation is not yet implemented.")
|
raise NotImplementedError("Temporal aggregation is not yet implemented.")
|
||||||
if self.n_action_steps > self.chunk_size:
|
if self.n_action_steps > self.chunk_size:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"The chunk size is the upper bound for the number of action steps per model invocation."
|
f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
|
||||||
|
f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
|
||||||
|
)
|
||||||
|
if self.n_obs_steps != 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
|
||||||
)
|
)
|
||||||
if self.camera_names != ["top"]:
|
if self.camera_names != ["top"]:
|
||||||
raise ValueError("For now, `camera_names` can only be ['top']")
|
raise ValueError(f"For now, `camera_names` can only be ['top']. Got {self.camera_names}.")
|
||||||
|
if len(set(self.camera_names)) != len(self.camera_names):
|
||||||
|
raise ValueError(f"`camera_names` should not have any repeated entries. Got {self.camera_names}.")
|
||||||
|
|
|
@ -20,7 +20,9 @@ from torch import Tensor, nn
|
||||||
from torchvision.models._utils import IntermediateLayerGetter
|
from torchvision.models._utils import IntermediateLayerGetter
|
||||||
from torchvision.ops.misc import FrozenBatchNorm2d
|
from torchvision.ops.misc import FrozenBatchNorm2d
|
||||||
|
|
||||||
from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
|
from lerobot.common.policies.act.configuration_act import (
|
||||||
|
ActionChunkingTransformerConfig,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ActionChunkingTransformerPolicy(nn.Module):
|
class ActionChunkingTransformerPolicy(nn.Module):
|
||||||
|
@ -61,9 +63,6 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
name = "act"
|
name = "act"
|
||||||
_multiple_obs_steps_not_handled_msg = (
|
|
||||||
"ActionChunkingTransformerPolicy does not handle multiple observation steps."
|
|
||||||
)
|
|
||||||
|
|
||||||
def __init__(self, cfg: ActionChunkingTransformerConfig | None = None):
|
def __init__(self, cfg: ActionChunkingTransformerConfig | None = None):
|
||||||
"""
|
"""
|
||||||
|
@ -74,8 +73,6 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if cfg is None:
|
if cfg is None:
|
||||||
cfg = ActionChunkingTransformerConfig()
|
cfg = ActionChunkingTransformerConfig()
|
||||||
if cfg.n_obs_steps != 1:
|
|
||||||
raise ValueError(self._multiple_obs_steps_not_handled_msg)
|
|
||||||
self.cfg = cfg
|
self.cfg = cfg
|
||||||
|
|
||||||
# BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence].
|
# BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence].
|
||||||
|
@ -102,7 +99,11 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||||
mean=cfg.image_normalization_mean, std=cfg.image_normalization_std
|
mean=cfg.image_normalization_mean, std=cfg.image_normalization_std
|
||||||
)
|
)
|
||||||
backbone_model = getattr(torchvision.models, cfg.vision_backbone)(
|
backbone_model = getattr(torchvision.models, cfg.vision_backbone)(
|
||||||
replace_stride_with_dilation=[False, False, cfg.replace_final_stride_with_dilation],
|
replace_stride_with_dilation=[
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
cfg.replace_final_stride_with_dilation,
|
||||||
|
],
|
||||||
pretrained=cfg.use_pretrained_backbone,
|
pretrained=cfg.use_pretrained_backbone,
|
||||||
norm_layer=FrozenBatchNorm2d,
|
norm_layer=FrozenBatchNorm2d,
|
||||||
)
|
)
|
||||||
|
@ -176,82 +177,16 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||||
environment. It works by managing the actions in a queue and only calling `select_actions` when the
|
environment. It works by managing the actions in a queue and only calling `select_actions` when the
|
||||||
queue is empty.
|
queue is empty.
|
||||||
"""
|
"""
|
||||||
|
self.eval()
|
||||||
if len(self._action_queue) == 0:
|
if len(self._action_queue) == 0:
|
||||||
# `select_actions` returns a (batch_size, n_action_steps, *) tensor, but the queue effectively has
|
# `_forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue effectively
|
||||||
# shape (n_action_steps, batch_size, *), hence the transpose.
|
# has shape (n_action_steps, batch_size, *), hence the transpose.
|
||||||
self._action_queue.extend(self._select_actions(batch).transpose(0, 1))
|
self._action_queue.extend(self._forward(batch)[0][: self.cfg.n_action_steps].transpose(0, 1))
|
||||||
return self._action_queue.popleft()
|
return self._action_queue.popleft()
|
||||||
|
|
||||||
@torch.no_grad
|
|
||||||
def _select_actions(self, batch: dict[str, Tensor]) -> Tensor:
|
|
||||||
"""Use the action chunking transformer to generate a sequence of actions."""
|
|
||||||
self.eval()
|
|
||||||
batch = self._reshape_batch(batch, add_obs_steps_dim=True)
|
|
||||||
actions, _ = self._forward(
|
|
||||||
batch["observation.state"], self.image_normalizer(batch["observation.images.top"])
|
|
||||||
)
|
|
||||||
return actions[: self.cfg.n_action_steps]
|
|
||||||
|
|
||||||
def _reshape_batch(self, batch: dict[str, Tensor], add_obs_steps_dim: bool = False) -> dict[str, Tensor]:
|
|
||||||
"""Reshapes the batch items to account for various requirements of this policy.
|
|
||||||
|
|
||||||
This function expects `batch` to have (at least):
|
|
||||||
{
|
|
||||||
"observation.state": (B, 1, J) OR (B, J) tensor of robot states (joint configuration).
|
|
||||||
"observation.images.top": (B, 1, C, H, W) OR (B, C, H, W) tensor of images.
|
|
||||||
}
|
|
||||||
|
|
||||||
TODO(alexander-soare): Right now this method does and undoes reshaping operations. This is just to
|
|
||||||
separate out the core logic from the temporary logic. See comments below.
|
|
||||||
"""
|
|
||||||
# Create a shallow copy.
|
|
||||||
batch = dict(batch)
|
|
||||||
|
|
||||||
# Add a dimension for observation steps.
|
|
||||||
if add_obs_steps_dim:
|
|
||||||
# Add a dimension for the observations steps. Since n_obs_steps > 1 is not supported right now,
|
|
||||||
# this just amounts to an unsqueeze.
|
|
||||||
for k in batch:
|
|
||||||
if k.startswith("observation."):
|
|
||||||
batch[k] = batch[k].unsqueeze(1)
|
|
||||||
|
|
||||||
# Temporary logic to remove the observation step dimension as the policy does not yet handle it.
|
|
||||||
# TODO(alexander-soare): generalize this to multiple observations steps.
|
|
||||||
# Check that there is only 1 observation step (policy does not yet handle more).
|
|
||||||
if not all(batch[k].shape[1] == 1 for k in batch if k.startswith("observation.")):
|
|
||||||
raise ValueError(self._multiple_obs_steps_not_handled_msg)
|
|
||||||
# Remove observation steps dimension.
|
|
||||||
for k in batch:
|
|
||||||
if k.startswith("observation."):
|
|
||||||
batch[k] = batch[k].squeeze(1)
|
|
||||||
|
|
||||||
# Temporary logic to add the multiple image dimension back in.
|
|
||||||
# TODO(alexander-soare): generalize this to multiple images. Once resolved, this logic will stack all
|
|
||||||
# images.
|
|
||||||
assert (
|
|
||||||
sum(k.startswith("observation.images.") and not k.endswith("is_pad") for k in batch) == 1
|
|
||||||
), f"{self.__class__.__name__} only handles one image for now."
|
|
||||||
# Since we only handle one image, just unsqueeze instead of stacking.
|
|
||||||
batch["observation.images.top"] = batch["observation.images.top"].unsqueeze(1)
|
|
||||||
|
|
||||||
return batch
|
|
||||||
|
|
||||||
def compute_loss(self, batch, **_) -> float:
|
def compute_loss(self, batch, **_) -> float:
|
||||||
batch = self._reshape_batch(batch)
|
"""Runs the batch through the model and computes the loss for training or validation."""
|
||||||
|
actions_hat, (mu_hat, log_sigma_x2_hat) = self._forward(batch)
|
||||||
self.train()
|
|
||||||
|
|
||||||
num_slices = self.cfg.batch_size
|
|
||||||
batch_size = self.cfg.chunk_size * num_slices
|
|
||||||
|
|
||||||
assert batch_size % self.cfg.chunk_size == 0
|
|
||||||
assert batch_size % num_slices == 0
|
|
||||||
|
|
||||||
actions_hat, (mu_hat, log_sigma_x2_hat) = self._forward(
|
|
||||||
batch["observation.state"],
|
|
||||||
self.image_normalizer(batch["observation.images.top"]),
|
|
||||||
batch["action"],
|
|
||||||
)
|
|
||||||
|
|
||||||
l1_loss = (
|
l1_loss = (
|
||||||
F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
|
F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
|
||||||
|
@ -274,6 +209,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||||
def update(self, batch, **_) -> dict:
|
def update(self, batch, **_) -> dict:
|
||||||
"""Run the model in train mode, compute the loss, and do an optimization step."""
|
"""Run the model in train mode, compute the loss, and do an optimization step."""
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
self.train()
|
||||||
loss = self.compute_loss(batch)
|
loss = self.compute_loss(batch)
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
|
@ -295,35 +231,64 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||||
|
|
||||||
return info
|
return info
|
||||||
|
|
||||||
def _forward(
|
def _stack_images(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||||
self, robot_state: Tensor, image: Tensor, actions: Tensor | None = None
|
"""Stacks all the images in a batch and puts them in a new key: "observation.images".
|
||||||
) -> tuple[Tensor, tuple[Tensor | None, Tensor | None]]:
|
|
||||||
|
This function expects `batch` to have (at least):
|
||||||
|
{
|
||||||
|
"observation.state": (B, state_dim) batch of robot states.
|
||||||
|
"observation.images.{name}": (B, C, H, W) tensor of images.
|
||||||
|
}
|
||||||
"""
|
"""
|
||||||
Args:
|
# Check that there is only one image.
|
||||||
robot_state: (B, J) batch of robot joint configurations.
|
# TODO(alexander-soare): generalize this to multiple images.
|
||||||
image: (B, N, C, H, W) batch of N camera frames.
|
provided_cameras = {k.rsplit(".", 1)[-1] for k in batch if k.startswith("observation.images.")}
|
||||||
actions: (B, S, A) batch of actions from the target dataset which must be provided if the
|
if len(missing := set(self.cfg.camera_names).difference(provided_cameras)) > 0:
|
||||||
VAE is enabled and the model is in training mode.
|
raise ValueError(
|
||||||
|
f"The following camera images are missing from the provided batch: {missing}. Check the "
|
||||||
|
"configuration parameter: `camera_names`."
|
||||||
|
)
|
||||||
|
# Stack images in the order dictated by the camera names.
|
||||||
|
batch["observation.images"] = torch.stack(
|
||||||
|
[batch[f"observation.images.{name}"] for name in self.cfg.camera_names],
|
||||||
|
dim=-4,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tensor] | tuple[None, None]]:
|
||||||
|
"""A forward pass through the Action Chunking Transformer (with optional VAE encoder).
|
||||||
|
|
||||||
|
`batch` should have the following structure:
|
||||||
|
|
||||||
|
{
|
||||||
|
"observation.state": (B, state_dim) batch of robot states.
|
||||||
|
"observation.images": (B, n_cameras, C, H, W) batch of images.
|
||||||
|
"action" (optional, only if training with VAE): (B, chunk_size, action dim) batch of actions.
|
||||||
|
}
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(B, S, A) batch of action sequences
|
(B, chunk_size, action_dim) batch of action sequences
|
||||||
Tuple containing the latent PDF's parameters (mean, log(σ²)) both as (B, L) tensors where L is the
|
Tuple containing the latent PDF's parameters (mean, log(σ²)) both as (B, L) tensors where L is the
|
||||||
latent dimension.
|
latent dimension.
|
||||||
"""
|
"""
|
||||||
if self.cfg.use_vae and self.training:
|
if self.cfg.use_vae and self.training:
|
||||||
assert (
|
assert (
|
||||||
actions is not None
|
"action" in batch
|
||||||
), "actions must be provided when using the variational objective in training mode."
|
), "actions must be provided when using the variational objective in training mode."
|
||||||
|
|
||||||
batch_size = robot_state.shape[0]
|
self._stack_images(batch)
|
||||||
|
|
||||||
|
batch_size = batch["observation.state"].shape[0]
|
||||||
|
|
||||||
# Prepare the latent for input to the transformer encoder.
|
# Prepare the latent for input to the transformer encoder.
|
||||||
if self.cfg.use_vae and actions is not None:
|
if self.cfg.use_vae and "action" in batch:
|
||||||
# Prepare the input to the VAE encoder: [cls, *joint_space_configuration, *action_sequence].
|
# Prepare the input to the VAE encoder: [cls, *joint_space_configuration, *action_sequence].
|
||||||
cls_embed = einops.repeat(
|
cls_embed = einops.repeat(
|
||||||
self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
|
self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
|
||||||
) # (B, 1, D)
|
) # (B, 1, D)
|
||||||
robot_state_embed = self.vae_encoder_robot_state_input_proj(robot_state).unsqueeze(1) # (B, 1, D)
|
robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"]).unsqueeze(
|
||||||
action_embed = self.vae_encoder_action_input_proj(actions) # (B, S, D)
|
1
|
||||||
|
) # (B, 1, D)
|
||||||
|
action_embed = self.vae_encoder_action_input_proj(batch["action"]) # (B, S, D)
|
||||||
vae_encoder_input = torch.cat([cls_embed, robot_state_embed, action_embed], axis=1) # (B, S+2, D)
|
vae_encoder_input = torch.cat([cls_embed, robot_state_embed, action_embed], axis=1) # (B, S+2, D)
|
||||||
|
|
||||||
# Prepare fixed positional embedding.
|
# Prepare fixed positional embedding.
|
||||||
|
@ -345,15 +310,16 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||||
# When not using the VAE encoder, we set the latent to be all zeros.
|
# When not using the VAE encoder, we set the latent to be all zeros.
|
||||||
mu = log_sigma_x2 = None
|
mu = log_sigma_x2 = None
|
||||||
latent_sample = torch.zeros([batch_size, self.latent_dim], dtype=torch.float32).to(
|
latent_sample = torch.zeros([batch_size, self.latent_dim], dtype=torch.float32).to(
|
||||||
robot_state.device
|
batch["observation.state"].device
|
||||||
)
|
)
|
||||||
|
|
||||||
# Prepare all other transformer encoder inputs.
|
# Prepare all other transformer encoder inputs.
|
||||||
# Camera observation features and positional embeddings.
|
# Camera observation features and positional embeddings.
|
||||||
all_cam_features = []
|
all_cam_features = []
|
||||||
all_cam_pos_embeds = []
|
all_cam_pos_embeds = []
|
||||||
for cam_id, _ in enumerate(self.cfg.camera_names):
|
images = self.image_normalizer(batch["observation.images"])
|
||||||
cam_features = self.backbone(image[:, cam_id])["feature_map"]
|
for cam_index in range(len(self.cfg.camera_names)):
|
||||||
|
cam_features = self.backbone(images[:, cam_index])["feature_map"]
|
||||||
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
|
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
|
||||||
cam_features = self.encoder_img_feat_input_proj(cam_features) # (B, C, h, w)
|
cam_features = self.encoder_img_feat_input_proj(cam_features) # (B, C, h, w)
|
||||||
all_cam_features.append(cam_features)
|
all_cam_features.append(cam_features)
|
||||||
|
@ -363,7 +329,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||||
cam_pos_embed = torch.cat(all_cam_pos_embeds, axis=3)
|
cam_pos_embed = torch.cat(all_cam_pos_embeds, axis=3)
|
||||||
|
|
||||||
# Get positional embeddings for robot state and latent.
|
# Get positional embeddings for robot state and latent.
|
||||||
robot_state_embed = self.encoder_robot_state_input_proj(robot_state)
|
robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"])
|
||||||
latent_embed = self.encoder_latent_input_proj(latent_sample)
|
latent_embed = self.encoder_latent_input_proj(latent_sample)
|
||||||
|
|
||||||
# Stack encoder input and positional embeddings moving to (S, B, C).
|
# Stack encoder input and positional embeddings moving to (S, B, C).
|
||||||
|
@ -479,7 +445,10 @@ class _TransformerDecoder(nn.Module):
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
for layer in self.layers:
|
for layer in self.layers:
|
||||||
x = layer(
|
x = layer(
|
||||||
x, encoder_out, decoder_pos_embed=decoder_pos_embed, encoder_pos_embed=encoder_pos_embed
|
x,
|
||||||
|
encoder_out,
|
||||||
|
decoder_pos_embed=decoder_pos_embed,
|
||||||
|
encoder_pos_embed=encoder_pos_embed,
|
||||||
)
|
)
|
||||||
if self.norm is not None:
|
if self.norm is not None:
|
||||||
x = self.norm(x)
|
x = self.norm(x)
|
||||||
|
|
|
@ -67,6 +67,4 @@ policy:
|
||||||
utd: 1
|
utd: 1
|
||||||
|
|
||||||
delta_timestamps:
|
delta_timestamps:
|
||||||
observation.images.top: "[i / ${fps} for i in range(1 - ${n_obs_steps}, 1)]"
|
|
||||||
observation.state: "[i / ${fps} for i in range(1 - ${n_obs_steps}, 1)]"
|
|
||||||
action: "[i / ${fps} for i in range(${policy.chunk_size})]"
|
action: "[i / ${fps} for i in range(${policy.chunk_size})]"
|
||||||
|
|
Loading…
Reference in New Issue