Make ACT compatible with "observation.environment_state" (#314)

Alexander Soare 2024-07-11 13:12:22 +01:00 committed by GitHub
parent 64425d5e00
commit 471eab3d7e
15 changed files with 117 additions and 88 deletions

View File

@@ -26,7 +26,10 @@ class ACTConfig:
     Those are: `input_shapes` and `output_shapes`.

     Notes on the inputs and outputs:
-        - At least one key starting with "observation.image" is required as an input.
+        - Either:
+            - At least one key starting with "observation.image" is required as an input.
+              AND/OR
+            - The key "observation.environment_state" is required as input.
         - If there are multiple keys beginning with "observation.images." they are treated as multiple camera
           views. Right now we only support all images having the same shape.
         - May optionally work without an "observation.state" key for the proprioceptive robot state.
@@ -162,3 +165,8 @@ class ACTConfig:
             raise ValueError(
                 f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
             )
+        if (
+            not any(k.startswith("observation.image") for k in self.input_shapes)
+            and "observation.environment_state" not in self.input_shapes
+        ):
+            raise ValueError("You must provide at least one image or the environment state among the inputs.")

View File

@@ -97,6 +97,7 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
         self.eval()

         batch = self.normalize_inputs(batch)
+        if len(self.expected_image_keys) > 0:
             batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)

         # If we are doing temporal ensembling, keep track of the exponential moving average (EMA), and return
@@ -135,6 +136,7 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
     def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
         """Run the batch through the model and compute the loss for training or validation."""
         batch = self.normalize_inputs(batch)
+        if len(self.expected_image_keys) > 0:
             batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
         batch = self.normalize_targets(batch)
         actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
@@ -200,12 +202,14 @@ class ACT(nn.Module):
         self.config = config
         # BERT style VAE encoder with input tokens [cls, robot_state, *action_sequence].
         # The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
-        self.use_input_state = "observation.state" in config.input_shapes
+        self.use_robot_state = "observation.state" in config.input_shapes
+        self.use_images = any(k.startswith("observation.image") for k in config.input_shapes)
+        self.use_env_state = "observation.environment_state" in config.input_shapes
         if self.config.use_vae:
             self.vae_encoder = ACTEncoder(config)
             self.vae_encoder_cls_embed = nn.Embedding(1, config.dim_model)
             # Projection layer for joint-space configuration to hidden dimension.
-            if self.use_input_state:
+            if self.use_robot_state:
                 self.vae_encoder_robot_state_input_proj = nn.Linear(
                     config.input_shapes["observation.state"][0], config.dim_model
                 )
@@ -218,7 +222,7 @@ class ACT(nn.Module):
             # Fixed sinusoidal positional embedding for the input to the VAE encoder. Unsqueeze for batch
             # dimension.
             num_input_token_encoder = 1 + config.chunk_size
-            if self.use_input_state:
+            if self.use_robot_state:
                 num_input_token_encoder += 1
             self.register_buffer(
                 "vae_encoder_pos_enc",
@@ -226,13 +230,14 @@ class ACT(nn.Module):
             )

         # Backbone for image feature extraction.
+        if self.use_images:
             backbone_model = getattr(torchvision.models, config.vision_backbone)(
                 replace_stride_with_dilation=[False, False, config.replace_final_stride_with_dilation],
                 weights=config.pretrained_backbone_weights,
                 norm_layer=FrozenBatchNorm2d,
             )
-            # Note: The assumption here is that we are using a ResNet model (and hence layer4 is the final feature
-            # map).
+            # Note: The assumption here is that we are using a ResNet model (and hence layer4 is the final
+            # feature map).
             # Note: The forward method of this returns a dict: {"feature_map": output}.
             self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"})
@@ -241,18 +246,28 @@ class ACT(nn.Module):
         self.decoder = ACTDecoder(config)

         # Transformer encoder input projections. The tokens will be structured like
-        # [latent, robot_state, image_feature_map_pixels].
-        if self.use_input_state:
+        # [latent, (robot_state), (env_state), (image_feature_map_pixels)].
+        if self.use_robot_state:
             self.encoder_robot_state_input_proj = nn.Linear(
                 config.input_shapes["observation.state"][0], config.dim_model
             )
+        if self.use_env_state:
+            self.encoder_env_state_input_proj = nn.Linear(
+                config.input_shapes["observation.environment_state"][0], config.dim_model
+            )
         self.encoder_latent_input_proj = nn.Linear(config.latent_dim, config.dim_model)
+        if self.use_images:
             self.encoder_img_feat_input_proj = nn.Conv2d(
                 backbone_model.fc.in_features, config.dim_model, kernel_size=1
             )
         # Transformer encoder positional embeddings.
-        num_input_token_decoder = 2 if self.use_input_state else 1
-        self.encoder_robot_and_latent_pos_embed = nn.Embedding(num_input_token_decoder, config.dim_model)
+        n_1d_tokens = 1  # for the latent
+        if self.use_robot_state:
+            n_1d_tokens += 1
+        if self.use_env_state:
+            n_1d_tokens += 1
+        self.encoder_1d_feature_pos_embed = nn.Embedding(n_1d_tokens, config.dim_model)
+        if self.use_images:
             self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)

         # Transformer decoder.
@@ -274,10 +289,13 @@ class ACT(nn.Module):
         """A forward pass through the Action Chunking Transformer (with optional VAE encoder).

         `batch` should have the following structure:
         {
-            "observation.state": (B, state_dim) batch of robot states.
+            "observation.state" (optional): (B, state_dim) batch of robot states.
             "observation.images": (B, n_cameras, C, H, W) batch of images.
+                AND/OR
+            "observation.environment_state": (B, env_dim) batch of environment states.
             "action" (optional, only if training with VAE): (B, chunk_size, action dim) batch of actions.
         }
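
Concretely, the docstring above corresponds to batches like the following. This is only an illustration, not part of the commit: the batch size and dimensions are made up, and a real batch would come from the LeRobot dataloader with shapes matching `input_shapes`:

    import torch

    B, state_dim, env_dim, chunk_size, action_dim = 8, 14, 16, 100, 14

    # A state-only batch: no camera images, just the environment state (plus the optional robot state).
    batch = {
        "observation.state": torch.randn(B, state_dim),
        "observation.environment_state": torch.randn(B, env_dim),
        # Only needed when training with the VAE objective.
        "action": torch.randn(B, chunk_size, action_dim),
    }

    # An image-based batch with two cameras would instead (or additionally) carry:
    # batch["observation.images"] = torch.randn(B, 2, 3, 480, 640)
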
@@ -291,7 +309,11 @@ class ACT(nn.Module):
                 "action" in batch
             ), "actions must be provided when using the variational objective in training mode."

-        batch_size = batch["observation.images"].shape[0]
+        batch_size = (
+            batch["observation.images"]
+            if "observation.images" in batch
+            else batch["observation.environment_state"]
+        ).shape[0]

         # Prepare the latent for input to the transformer encoder.
         if self.config.use_vae and "action" in batch:
@@ -299,12 +321,12 @@ class ACT(nn.Module):
             cls_embed = einops.repeat(
                 self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
             )  # (B, 1, D)
-            if self.use_input_state:
+            if self.use_robot_state:
                 robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"])
                 robot_state_embed = robot_state_embed.unsqueeze(1)  # (B, 1, D)
             action_embed = self.vae_encoder_action_input_proj(batch["action"])  # (B, S, D)

-            if self.use_input_state:
+            if self.use_robot_state:
                 vae_encoder_input = [cls_embed, robot_state_embed, action_embed]  # (B, S+2, D)
             else:
                 vae_encoder_input = [cls_embed, action_embed]
@@ -318,7 +340,7 @@ class ACT(nn.Module):
             # sequence depending whether we use the input states or not (cls and robot state)
             # False means not a padding token.
             cls_joint_is_pad = torch.full(
-                (batch_size, 2 if self.use_input_state else 1),
+                (batch_size, 2 if self.use_robot_state else 1),
                 False,
                 device=batch["observation.state"].device,
             )
@@ -347,56 +369,55 @@ class ACT(nn.Module):
                 batch["observation.state"].device
             )

-        # Prepare all other transformer encoder inputs.
+        # Prepare transformer encoder inputs.
+        encoder_in_tokens = [self.encoder_latent_input_proj(latent_sample)]
+        encoder_in_pos_embed = list(self.encoder_1d_feature_pos_embed.weight.unsqueeze(1))
+        # Robot state token.
+        if self.use_robot_state:
+            encoder_in_tokens.append(self.encoder_robot_state_input_proj(batch["observation.state"]))
+        # Environment state token.
+        if self.use_env_state:
+            encoder_in_tokens.append(
+                self.encoder_env_state_input_proj(batch["observation.environment_state"])
+            )
+
         # Camera observation features and positional embeddings.
+        if self.use_images:
             all_cam_features = []
             all_cam_pos_embeds = []
             images = batch["observation.images"]

             for cam_index in range(images.shape[-4]):
                 cam_features = self.backbone(images[:, cam_index])["feature_map"]
-                # TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
+                # TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use
+                # buffer
                 cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
                 cam_features = self.encoder_img_feat_input_proj(cam_features)  # (B, C, h, w)
                 all_cam_features.append(cam_features)
                 all_cam_pos_embeds.append(cam_pos_embed)
-            # Concatenate camera observation feature maps and positional embeddings along the width dimension.
-            encoder_in = torch.cat(all_cam_features, axis=-1)
-            cam_pos_embed = torch.cat(all_cam_pos_embeds, axis=-1)
+            # Concatenate camera observation feature maps and positional embeddings along the width dimension,
+            # and move to (sequence, batch, dim).
+            all_cam_features = torch.cat(all_cam_features, axis=-1)
+            encoder_in_tokens.extend(einops.rearrange(all_cam_features, "b c h w -> (h w) b c"))
+            all_cam_pos_embeds = torch.cat(all_cam_pos_embeds, axis=-1)
+            encoder_in_pos_embed.extend(einops.rearrange(all_cam_pos_embeds, "b c h w -> (h w) b c"))

-        # Get positional embeddings for robot state and latent.
-        if self.use_input_state:
-            robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"])  # (B, C)
-        latent_embed = self.encoder_latent_input_proj(latent_sample)  # (B, C)
-
-        # Stack encoder input and positional embeddings moving to (S, B, C).
-        encoder_in_feats = [latent_embed, robot_state_embed] if self.use_input_state else [latent_embed]
-        encoder_in = torch.cat(
-            [
-                torch.stack(encoder_in_feats, axis=0),
-                einops.rearrange(encoder_in, "b c h w -> (h w) b c"),
-            ]
-        )
-        pos_embed = torch.cat(
-            [
-                self.encoder_robot_and_latent_pos_embed.weight.unsqueeze(1),
-                cam_pos_embed.flatten(2).permute(2, 0, 1),
-            ],
-            axis=0,
-        )
+        # Stack all tokens along the sequence dimension.
+        encoder_in_tokens = torch.stack(encoder_in_tokens, axis=0)
+        encoder_in_pos_embed = torch.stack(encoder_in_pos_embed, axis=0)

         # Forward pass through the transformer modules.
-        encoder_out = self.encoder(encoder_in, pos_embed=pos_embed)
+        encoder_out = self.encoder(encoder_in_tokens, pos_embed=encoder_in_pos_embed)
         # TODO(rcadene, alexander-soare): remove call to `device` ; precompute and use buffer
         decoder_in = torch.zeros(
             (self.config.chunk_size, batch_size, self.config.dim_model),
-            dtype=pos_embed.dtype,
-            device=pos_embed.device,
+            dtype=encoder_in_pos_embed.dtype,
+            device=encoder_in_pos_embed.device,
         )
         decoder_out = self.decoder(
             decoder_in,
             encoder_out,
-            encoder_pos_embed=pos_embed,
+            encoder_pos_embed=encoder_in_pos_embed,
             decoder_pos_embed=self.decoder_pos_embed.weight.unsqueeze(1),
         )
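
To make the new token layout concrete, here is a toy, self-contained sketch of the [latent, (robot_state), (env_state), (camera feature tokens)] assembly that the rewritten forward performs. It is not the library code: the projection layers, dimensions, and the flatten/permute used in place of `einops.rearrange` are all illustrative stand-ins:

    import torch
    from torch import nn

    dim_model, latent_dim, state_dim, env_dim, B = 512, 32, 14, 16, 8

    # Stand-ins for the encoder_*_input_proj layers.
    latent_proj = nn.Linear(latent_dim, dim_model)
    robot_state_proj = nn.Linear(state_dim, dim_model)
    env_state_proj = nn.Linear(env_dim, dim_model)

    latent = torch.zeros(B, latent_dim)              # zeroed latent, as when the VAE path is skipped
    robot_state = torch.randn(B, state_dim)          # optional
    env_state = torch.randn(B, env_dim)              # optional
    cam_feats = torch.randn(B, dim_model, 15, 20)    # optional, already projected to dim_model

    tokens = [latent_proj(latent)]                   # the latent is always the first 1D token
    if robot_state is not None:
        tokens.append(robot_state_proj(robot_state))
    if env_state is not None:
        tokens.append(env_state_proj(env_state))
    if cam_feats is not None:
        # Each spatial location becomes one token: (B, D, h, w) -> h*w tensors of shape (B, D).
        tokens.extend(cam_feats.flatten(2).permute(2, 0, 1).unbind(0))

    encoder_in_tokens = torch.stack(tokens, dim=0)   # (S, B, D), with S = 1 (+1) (+1) (+ h*w)
    print(encoder_in_tokens.shape)                   # torch.Size([303, 8, 512]) for these numbers
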

View File

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:523f220f3acbab0cd4aef8a13c77916634488b1af08a06e4e65d1aecafdc2cae
+oid sha256:28444747a9cb3876f86ae86fed72e587dbcacfccd87c5c24b8ecac30c3ce3077
 size 5104

View File

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:95dd049b4386030ced4505586b874f16906f8d89f29b570201782eebcbe4f402
-size 31688
+oid sha256:a43a9ddaf8527e3344b22bd21276e1f561e83599d720933b28725b00d94823c0
+size 31672

View File

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:806851d60b6c492620b7269876eef9ce17756ec03da93f36b351f8aa75be0954
-size 33408
+oid sha256:093bff1fbc3bde2547bccbbefc277d02368a8d4a9100b3e4bd47c755798cad68
+size 33400

View File

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:3f4e0e525aeb22ea94b79e26b39a87e6f2da9fbee33e493906aaf2aad9a7c1ef
+oid sha256:85bed637e90f15c64e4af01d2dbc5d9c3a370215f2c8c379494fa3acb413bc2e
 size 515400

View File

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:6dc658a1c1616c7d1c211eb8f87cec3d44f7b67d6b3cea7a6ce12b32d74674da
-size 31688
+oid sha256:00cf8e548d7ea23aa70de79e05c39990a32a790def824f729e6c98bea31c69bc
+size 31672

View File

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:01d993c67a9267032fe9fbeff20b4359c209464976ea503040a0a76ae213450a
-size 33408
+oid sha256:b3a4c2581f48229312a582d91f0adea8078c0c5b744c34d76723edf4731f9003
+size 33400

View File

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:2fff6294b94cf42d4dd1249dcc5c3b0269d6d9c697f894e61b867d7ab81a94e4
+oid sha256:aab00b0349901450adbb8e0d7d4af1f743dd88e7e19f1bcfef821de8bdcc957d
 size 5104

View File

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:4aa23e51607604a18b70fa42edbbe1af34f119d985628fc27cc1bbb0efbc8901
-size 31688
+oid sha256:de70c3055aa052f5b811ec7c2994ec6861efe645c6caee41e04a3460598500d5
+size 31672

View File

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:6fd368406c93cb562a69ff11cf7adf34a4b223507dcb2b9e9b8f44ee1036988a
+oid sha256:d4070bd1f1cd8c72bc2daf628088e42b8ef113f6df0bfd9e91be052bc90038c3
 size 68

View File

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:5663ee79a13bb70a1604b887dd21bf89d18482287442419c6cc6c5bf0e753e99
-size 34928
+oid sha256:19fdc1edf327e04132c1917024289b3d16e25a1ec2130f3df797fe07434dfbbd
+size 34920

View File

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:fb1a45463efd860af2ca22c16c77d55a18bd96fef080ae77978845a2f22ef716
+oid sha256:dcd8ebaefd3ff267eb24654135d1efb179d713e6cfe6917f793a3e2483efd501
 size 5104

View File

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:aa5a43e22f01d8e2f8d19f31753608794f1edbd74aaf71660091ab80ea58dc9b
-size 30808
+oid sha256:107e98647ed1081745476b250df8848c0c430b2aff51d614f6b2db95684467aa
+size 30800

View File

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:54d1f75cf67a7b1d7a7c6865ecb9b1cc86a2f032d1890245f8996789ab6e0df6
-size 33608
+oid sha256:adbae737c987f912509d3fba06f332bda700bfc2c6d83a09c969e9d7a3ca75f7
+size 33600