Make ACT compatible with "observation.environment_state" (#314)

This commit is contained in:
Alexander Soare 2024-07-11 13:12:22 +01:00 committed by GitHub
parent 64425d5e00
commit 471eab3d7e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 117 additions and 88 deletions

View File

@ -26,7 +26,10 @@ class ACTConfig:
Those are: `input_shapes` and 'output_shapes`.
Notes on the inputs and outputs:
- Either:
- At least one key starting with "observation.image is required as an input.
AND/OR
- The key "observation.environment_state" is required as input.
- If there are multiple keys beginning with "observation.images." they are treated as multiple camera
views. Right now we only support all images having the same shape.
- May optionally work without an "observation.state" key for the proprioceptive robot state.
@ -162,3 +165,8 @@ class ACTConfig:
raise ValueError(
f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
)
if (
not any(k.startswith("observation.image") for k in self.input_shapes)
and "observation.environment_state" not in self.input_shapes
):
raise ValueError("You must provide at least one image or the environment state among the inputs.")

View File

@ -97,6 +97,7 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
self.eval()
batch = self.normalize_inputs(batch)
if len(self.expected_image_keys) > 0:
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
# If we are doing temporal ensembling, keep track of the exponential moving average (EMA), and return
@ -135,6 +136,7 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
"""Run the batch through the model and compute the loss for training or validation."""
batch = self.normalize_inputs(batch)
if len(self.expected_image_keys) > 0:
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
batch = self.normalize_targets(batch)
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
@ -200,12 +202,14 @@ class ACT(nn.Module):
self.config = config
# BERT style VAE encoder with input tokens [cls, robot_state, *action_sequence].
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
self.use_input_state = "observation.state" in config.input_shapes
self.use_robot_state = "observation.state" in config.input_shapes
self.use_images = any(k.startswith("observation.image") for k in config.input_shapes)
self.use_env_state = "observation.environment_state" in config.input_shapes
if self.config.use_vae:
self.vae_encoder = ACTEncoder(config)
self.vae_encoder_cls_embed = nn.Embedding(1, config.dim_model)
# Projection layer for joint-space configuration to hidden dimension.
if self.use_input_state:
if self.use_robot_state:
self.vae_encoder_robot_state_input_proj = nn.Linear(
config.input_shapes["observation.state"][0], config.dim_model
)
@ -218,7 +222,7 @@ class ACT(nn.Module):
# Fixed sinusoidal positional embedding for the input to the VAE encoder. Unsqueeze for batch
# dimension.
num_input_token_encoder = 1 + config.chunk_size
if self.use_input_state:
if self.use_robot_state:
num_input_token_encoder += 1
self.register_buffer(
"vae_encoder_pos_enc",
@ -226,13 +230,14 @@ class ACT(nn.Module):
)
# Backbone for image feature extraction.
if self.use_images:
backbone_model = getattr(torchvision.models, config.vision_backbone)(
replace_stride_with_dilation=[False, False, config.replace_final_stride_with_dilation],
weights=config.pretrained_backbone_weights,
norm_layer=FrozenBatchNorm2d,
)
# Note: The assumption here is that we are using a ResNet model (and hence layer4 is the final feature
# map).
# Note: The assumption here is that we are using a ResNet model (and hence layer4 is the final
# feature map).
# Note: The forward method of this returns a dict: {"feature_map": output}.
self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"})
@ -241,18 +246,28 @@ class ACT(nn.Module):
self.decoder = ACTDecoder(config)
# Transformer encoder input projections. The tokens will be structured like
# [latent, robot_state, image_feature_map_pixels].
if self.use_input_state:
# [latent, (robot_state), (env_state), (image_feature_map_pixels)].
if self.use_robot_state:
self.encoder_robot_state_input_proj = nn.Linear(
config.input_shapes["observation.state"][0], config.dim_model
)
if self.use_env_state:
self.encoder_env_state_input_proj = nn.Linear(
config.input_shapes["observation.environment_state"][0], config.dim_model
)
self.encoder_latent_input_proj = nn.Linear(config.latent_dim, config.dim_model)
if self.use_images:
self.encoder_img_feat_input_proj = nn.Conv2d(
backbone_model.fc.in_features, config.dim_model, kernel_size=1
)
# Transformer encoder positional embeddings.
num_input_token_decoder = 2 if self.use_input_state else 1
self.encoder_robot_and_latent_pos_embed = nn.Embedding(num_input_token_decoder, config.dim_model)
n_1d_tokens = 1 # for the latent
if self.use_robot_state:
n_1d_tokens += 1
if self.use_env_state:
n_1d_tokens += 1
self.encoder_1d_feature_pos_embed = nn.Embedding(n_1d_tokens, config.dim_model)
if self.use_images:
self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)
# Transformer decoder.
@ -274,10 +289,13 @@ class ACT(nn.Module):
"""A forward pass through the Action Chunking Transformer (with optional VAE encoder).
`batch` should have the following structure:
{
"observation.state": (B, state_dim) batch of robot states.
"observation.state" (optional): (B, state_dim) batch of robot states.
"observation.images": (B, n_cameras, C, H, W) batch of images.
AND/OR
"observation.environment_state": (B, env_dim) batch of environment states.
"action" (optional, only if training with VAE): (B, chunk_size, action dim) batch of actions.
}
@ -291,7 +309,11 @@ class ACT(nn.Module):
"action" in batch
), "actions must be provided when using the variational objective in training mode."
batch_size = batch["observation.images"].shape[0]
batch_size = (
batch["observation.images"]
if "observation.images" in batch
else batch["observation.environment_state"]
).shape[0]
# Prepare the latent for input to the transformer encoder.
if self.config.use_vae and "action" in batch:
@ -299,12 +321,12 @@ class ACT(nn.Module):
cls_embed = einops.repeat(
self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
) # (B, 1, D)
if self.use_input_state:
if self.use_robot_state:
robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"])
robot_state_embed = robot_state_embed.unsqueeze(1) # (B, 1, D)
action_embed = self.vae_encoder_action_input_proj(batch["action"]) # (B, S, D)
if self.use_input_state:
if self.use_robot_state:
vae_encoder_input = [cls_embed, robot_state_embed, action_embed] # (B, S+2, D)
else:
vae_encoder_input = [cls_embed, action_embed]
@ -318,7 +340,7 @@ class ACT(nn.Module):
# sequence depending whether we use the input states or not (cls and robot state)
# False means not a padding token.
cls_joint_is_pad = torch.full(
(batch_size, 2 if self.use_input_state else 1),
(batch_size, 2 if self.use_robot_state else 1),
False,
device=batch["observation.state"].device,
)
@ -347,56 +369,55 @@ class ACT(nn.Module):
batch["observation.state"].device
)
# Prepare all other transformer encoder inputs.
# Prepare transformer encoder inputs.
encoder_in_tokens = [self.encoder_latent_input_proj(latent_sample)]
encoder_in_pos_embed = list(self.encoder_1d_feature_pos_embed.weight.unsqueeze(1))
# Robot state token.
if self.use_robot_state:
encoder_in_tokens.append(self.encoder_robot_state_input_proj(batch["observation.state"]))
# Environment state token.
if self.use_env_state:
encoder_in_tokens.append(
self.encoder_env_state_input_proj(batch["observation.environment_state"])
)
# Camera observation features and positional embeddings.
if self.use_images:
all_cam_features = []
all_cam_pos_embeds = []
images = batch["observation.images"]
for cam_index in range(images.shape[-4]):
cam_features = self.backbone(images[:, cam_index])["feature_map"]
# TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
# TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use
# buffer
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
cam_features = self.encoder_img_feat_input_proj(cam_features) # (B, C, h, w)
all_cam_features.append(cam_features)
all_cam_pos_embeds.append(cam_pos_embed)
# Concatenate camera observation feature maps and positional embeddings along the width dimension.
encoder_in = torch.cat(all_cam_features, axis=-1)
cam_pos_embed = torch.cat(all_cam_pos_embeds, axis=-1)
# Concatenate camera observation feature maps and positional embeddings along the width dimension,
# and move to (sequence, batch, dim).
all_cam_features = torch.cat(all_cam_features, axis=-1)
encoder_in_tokens.extend(einops.rearrange(all_cam_features, "b c h w -> (h w) b c"))
all_cam_pos_embeds = torch.cat(all_cam_pos_embeds, axis=-1)
encoder_in_pos_embed.extend(einops.rearrange(all_cam_pos_embeds, "b c h w -> (h w) b c"))
# Get positional embeddings for robot state and latent.
if self.use_input_state:
robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"]) # (B, C)
latent_embed = self.encoder_latent_input_proj(latent_sample) # (B, C)
# Stack encoder input and positional embeddings moving to (S, B, C).
encoder_in_feats = [latent_embed, robot_state_embed] if self.use_input_state else [latent_embed]
encoder_in = torch.cat(
[
torch.stack(encoder_in_feats, axis=0),
einops.rearrange(encoder_in, "b c h w -> (h w) b c"),
]
)
pos_embed = torch.cat(
[
self.encoder_robot_and_latent_pos_embed.weight.unsqueeze(1),
cam_pos_embed.flatten(2).permute(2, 0, 1),
],
axis=0,
)
# Stack all tokens along the sequence dimension.
encoder_in_tokens = torch.stack(encoder_in_tokens, axis=0)
encoder_in_pos_embed = torch.stack(encoder_in_pos_embed, axis=0)
# Forward pass through the transformer modules.
encoder_out = self.encoder(encoder_in, pos_embed=pos_embed)
encoder_out = self.encoder(encoder_in_tokens, pos_embed=encoder_in_pos_embed)
# TODO(rcadene, alexander-soare): remove call to `device` ; precompute and use buffer
decoder_in = torch.zeros(
(self.config.chunk_size, batch_size, self.config.dim_model),
dtype=pos_embed.dtype,
device=pos_embed.device,
dtype=encoder_in_pos_embed.dtype,
device=encoder_in_pos_embed.device,
)
decoder_out = self.decoder(
decoder_in,
encoder_out,
encoder_pos_embed=pos_embed,
encoder_pos_embed=encoder_in_pos_embed,
decoder_pos_embed=self.decoder_pos_embed.weight.unsqueeze(1),
)

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:523f220f3acbab0cd4aef8a13c77916634488b1af08a06e4e65d1aecafdc2cae
oid sha256:28444747a9cb3876f86ae86fed72e587dbcacfccd87c5c24b8ecac30c3ce3077
size 5104

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:95dd049b4386030ced4505586b874f16906f8d89f29b570201782eebcbe4f402
size 31688
oid sha256:a43a9ddaf8527e3344b22bd21276e1f561e83599d720933b28725b00d94823c0
size 31672

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:806851d60b6c492620b7269876eef9ce17756ec03da93f36b351f8aa75be0954
size 33408
oid sha256:093bff1fbc3bde2547bccbbefc277d02368a8d4a9100b3e4bd47c755798cad68
size 33400

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3f4e0e525aeb22ea94b79e26b39a87e6f2da9fbee33e493906aaf2aad9a7c1ef
oid sha256:85bed637e90f15c64e4af01d2dbc5d9c3a370215f2c8c379494fa3acb413bc2e
size 515400

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6dc658a1c1616c7d1c211eb8f87cec3d44f7b67d6b3cea7a6ce12b32d74674da
size 31688
oid sha256:00cf8e548d7ea23aa70de79e05c39990a32a790def824f729e6c98bea31c69bc
size 31672

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:01d993c67a9267032fe9fbeff20b4359c209464976ea503040a0a76ae213450a
size 33408
oid sha256:b3a4c2581f48229312a582d91f0adea8078c0c5b744c34d76723edf4731f9003
size 33400

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2fff6294b94cf42d4dd1249dcc5c3b0269d6d9c697f894e61b867d7ab81a94e4
oid sha256:aab00b0349901450adbb8e0d7d4af1f743dd88e7e19f1bcfef821de8bdcc957d
size 5104

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4aa23e51607604a18b70fa42edbbe1af34f119d985628fc27cc1bbb0efbc8901
size 31688
oid sha256:de70c3055aa052f5b811ec7c2994ec6861efe645c6caee41e04a3460598500d5
size 31672

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6fd368406c93cb562a69ff11cf7adf34a4b223507dcb2b9e9b8f44ee1036988a
oid sha256:d4070bd1f1cd8c72bc2daf628088e42b8ef113f6df0bfd9e91be052bc90038c3
size 68

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5663ee79a13bb70a1604b887dd21bf89d18482287442419c6cc6c5bf0e753e99
size 34928
oid sha256:19fdc1edf327e04132c1917024289b3d16e25a1ec2130f3df797fe07434dfbbd
size 34920

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fb1a45463efd860af2ca22c16c77d55a18bd96fef080ae77978845a2f22ef716
oid sha256:dcd8ebaefd3ff267eb24654135d1efb179d713e6cfe6917f793a3e2483efd501
size 5104

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:aa5a43e22f01d8e2f8d19f31753608794f1edbd74aaf71660091ab80ea58dc9b
size 30808
oid sha256:107e98647ed1081745476b250df8848c0c430b2aff51d614f6b2db95684467aa
size 30800

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:54d1f75cf67a7b1d7a7c6865ecb9b1cc86a2f032d1890245f8996789ab6e0df6
size 33608
oid sha256:adbae737c987f912509d3fba06f332bda700bfc2c6d83a09c969e9d7a3ca75f7
size 33600