Make ACT compatible with "observation.environment_state" (#314)
This commit is contained in:
parent
64425d5e00
commit
471eab3d7e
|
@ -26,7 +26,10 @@ class ACTConfig:
|
|||
Those are: `input_shapes` and 'output_shapes`.
|
||||
|
||||
Notes on the inputs and outputs:
|
||||
- At least one key starting with "observation.image is required as an input.
|
||||
- Either:
|
||||
- At least one key starting with "observation.image is required as an input.
|
||||
AND/OR
|
||||
- The key "observation.environment_state" is required as input.
|
||||
- If there are multiple keys beginning with "observation.images." they are treated as multiple camera
|
||||
views. Right now we only support all images having the same shape.
|
||||
- May optionally work without an "observation.state" key for the proprioceptive robot state.
|
||||
|
@ -162,3 +165,8 @@ class ACTConfig:
|
|||
raise ValueError(
|
||||
f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
|
||||
)
|
||||
if (
|
||||
not any(k.startswith("observation.image") for k in self.input_shapes)
|
||||
and "observation.environment_state" not in self.input_shapes
|
||||
):
|
||||
raise ValueError("You must provide at least one image or the environment state among the inputs.")
|
||||
|
|
|
@ -97,7 +97,8 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
|
|||
self.eval()
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
|
||||
if len(self.expected_image_keys) > 0:
|
||||
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
|
||||
|
||||
# If we are doing temporal ensembling, keep track of the exponential moving average (EMA), and return
|
||||
# the first action.
|
||||
|
@ -135,7 +136,8 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
|
|||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
|
||||
if len(self.expected_image_keys) > 0:
|
||||
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
|
||||
batch = self.normalize_targets(batch)
|
||||
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
|
||||
|
||||
|
@ -200,12 +202,14 @@ class ACT(nn.Module):
|
|||
self.config = config
|
||||
# BERT style VAE encoder with input tokens [cls, robot_state, *action_sequence].
|
||||
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
|
||||
self.use_input_state = "observation.state" in config.input_shapes
|
||||
self.use_robot_state = "observation.state" in config.input_shapes
|
||||
self.use_images = any(k.startswith("observation.image") for k in config.input_shapes)
|
||||
self.use_env_state = "observation.environment_state" in config.input_shapes
|
||||
if self.config.use_vae:
|
||||
self.vae_encoder = ACTEncoder(config)
|
||||
self.vae_encoder_cls_embed = nn.Embedding(1, config.dim_model)
|
||||
# Projection layer for joint-space configuration to hidden dimension.
|
||||
if self.use_input_state:
|
||||
if self.use_robot_state:
|
||||
self.vae_encoder_robot_state_input_proj = nn.Linear(
|
||||
config.input_shapes["observation.state"][0], config.dim_model
|
||||
)
|
||||
|
@ -218,7 +222,7 @@ class ACT(nn.Module):
|
|||
# Fixed sinusoidal positional embedding for the input to the VAE encoder. Unsqueeze for batch
|
||||
# dimension.
|
||||
num_input_token_encoder = 1 + config.chunk_size
|
||||
if self.use_input_state:
|
||||
if self.use_robot_state:
|
||||
num_input_token_encoder += 1
|
||||
self.register_buffer(
|
||||
"vae_encoder_pos_enc",
|
||||
|
@ -226,34 +230,45 @@ class ACT(nn.Module):
|
|||
)
|
||||
|
||||
# Backbone for image feature extraction.
|
||||
backbone_model = getattr(torchvision.models, config.vision_backbone)(
|
||||
replace_stride_with_dilation=[False, False, config.replace_final_stride_with_dilation],
|
||||
weights=config.pretrained_backbone_weights,
|
||||
norm_layer=FrozenBatchNorm2d,
|
||||
)
|
||||
# Note: The assumption here is that we are using a ResNet model (and hence layer4 is the final feature
|
||||
# map).
|
||||
# Note: The forward method of this returns a dict: {"feature_map": output}.
|
||||
self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"})
|
||||
if self.use_images:
|
||||
backbone_model = getattr(torchvision.models, config.vision_backbone)(
|
||||
replace_stride_with_dilation=[False, False, config.replace_final_stride_with_dilation],
|
||||
weights=config.pretrained_backbone_weights,
|
||||
norm_layer=FrozenBatchNorm2d,
|
||||
)
|
||||
# Note: The assumption here is that we are using a ResNet model (and hence layer4 is the final
|
||||
# feature map).
|
||||
# Note: The forward method of this returns a dict: {"feature_map": output}.
|
||||
self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"})
|
||||
|
||||
# Transformer (acts as VAE decoder when training with the variational objective).
|
||||
self.encoder = ACTEncoder(config)
|
||||
self.decoder = ACTDecoder(config)
|
||||
|
||||
# Transformer encoder input projections. The tokens will be structured like
|
||||
# [latent, robot_state, image_feature_map_pixels].
|
||||
if self.use_input_state:
|
||||
# [latent, (robot_state), (env_state), (image_feature_map_pixels)].
|
||||
if self.use_robot_state:
|
||||
self.encoder_robot_state_input_proj = nn.Linear(
|
||||
config.input_shapes["observation.state"][0], config.dim_model
|
||||
)
|
||||
if self.use_env_state:
|
||||
self.encoder_env_state_input_proj = nn.Linear(
|
||||
config.input_shapes["observation.environment_state"][0], config.dim_model
|
||||
)
|
||||
self.encoder_latent_input_proj = nn.Linear(config.latent_dim, config.dim_model)
|
||||
self.encoder_img_feat_input_proj = nn.Conv2d(
|
||||
backbone_model.fc.in_features, config.dim_model, kernel_size=1
|
||||
)
|
||||
if self.use_images:
|
||||
self.encoder_img_feat_input_proj = nn.Conv2d(
|
||||
backbone_model.fc.in_features, config.dim_model, kernel_size=1
|
||||
)
|
||||
# Transformer encoder positional embeddings.
|
||||
num_input_token_decoder = 2 if self.use_input_state else 1
|
||||
self.encoder_robot_and_latent_pos_embed = nn.Embedding(num_input_token_decoder, config.dim_model)
|
||||
self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)
|
||||
n_1d_tokens = 1 # for the latent
|
||||
if self.use_robot_state:
|
||||
n_1d_tokens += 1
|
||||
if self.use_env_state:
|
||||
n_1d_tokens += 1
|
||||
self.encoder_1d_feature_pos_embed = nn.Embedding(n_1d_tokens, config.dim_model)
|
||||
if self.use_images:
|
||||
self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)
|
||||
|
||||
# Transformer decoder.
|
||||
# Learnable positional embedding for the transformer's decoder (in the style of DETR object queries).
|
||||
|
@ -274,10 +289,13 @@ class ACT(nn.Module):
|
|||
"""A forward pass through the Action Chunking Transformer (with optional VAE encoder).
|
||||
|
||||
`batch` should have the following structure:
|
||||
|
||||
{
|
||||
"observation.state": (B, state_dim) batch of robot states.
|
||||
"observation.state" (optional): (B, state_dim) batch of robot states.
|
||||
|
||||
"observation.images": (B, n_cameras, C, H, W) batch of images.
|
||||
AND/OR
|
||||
"observation.environment_state": (B, env_dim) batch of environment states.
|
||||
|
||||
"action" (optional, only if training with VAE): (B, chunk_size, action dim) batch of actions.
|
||||
}
|
||||
|
||||
|
@ -291,7 +309,11 @@ class ACT(nn.Module):
|
|||
"action" in batch
|
||||
), "actions must be provided when using the variational objective in training mode."
|
||||
|
||||
batch_size = batch["observation.images"].shape[0]
|
||||
batch_size = (
|
||||
batch["observation.images"]
|
||||
if "observation.images" in batch
|
||||
else batch["observation.environment_state"]
|
||||
).shape[0]
|
||||
|
||||
# Prepare the latent for input to the transformer encoder.
|
||||
if self.config.use_vae and "action" in batch:
|
||||
|
@ -299,12 +321,12 @@ class ACT(nn.Module):
|
|||
cls_embed = einops.repeat(
|
||||
self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
|
||||
) # (B, 1, D)
|
||||
if self.use_input_state:
|
||||
if self.use_robot_state:
|
||||
robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"])
|
||||
robot_state_embed = robot_state_embed.unsqueeze(1) # (B, 1, D)
|
||||
action_embed = self.vae_encoder_action_input_proj(batch["action"]) # (B, S, D)
|
||||
|
||||
if self.use_input_state:
|
||||
if self.use_robot_state:
|
||||
vae_encoder_input = [cls_embed, robot_state_embed, action_embed] # (B, S+2, D)
|
||||
else:
|
||||
vae_encoder_input = [cls_embed, action_embed]
|
||||
|
@ -318,7 +340,7 @@ class ACT(nn.Module):
|
|||
# sequence depending whether we use the input states or not (cls and robot state)
|
||||
# False means not a padding token.
|
||||
cls_joint_is_pad = torch.full(
|
||||
(batch_size, 2 if self.use_input_state else 1),
|
||||
(batch_size, 2 if self.use_robot_state else 1),
|
||||
False,
|
||||
device=batch["observation.state"].device,
|
||||
)
|
||||
|
@ -347,56 +369,55 @@ class ACT(nn.Module):
|
|||
batch["observation.state"].device
|
||||
)
|
||||
|
||||
# Prepare all other transformer encoder inputs.
|
||||
# Prepare transformer encoder inputs.
|
||||
encoder_in_tokens = [self.encoder_latent_input_proj(latent_sample)]
|
||||
encoder_in_pos_embed = list(self.encoder_1d_feature_pos_embed.weight.unsqueeze(1))
|
||||
# Robot state token.
|
||||
if self.use_robot_state:
|
||||
encoder_in_tokens.append(self.encoder_robot_state_input_proj(batch["observation.state"]))
|
||||
# Environment state token.
|
||||
if self.use_env_state:
|
||||
encoder_in_tokens.append(
|
||||
self.encoder_env_state_input_proj(batch["observation.environment_state"])
|
||||
)
|
||||
|
||||
# Camera observation features and positional embeddings.
|
||||
all_cam_features = []
|
||||
all_cam_pos_embeds = []
|
||||
images = batch["observation.images"]
|
||||
if self.use_images:
|
||||
all_cam_features = []
|
||||
all_cam_pos_embeds = []
|
||||
images = batch["observation.images"]
|
||||
|
||||
for cam_index in range(images.shape[-4]):
|
||||
cam_features = self.backbone(images[:, cam_index])["feature_map"]
|
||||
# TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
|
||||
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
|
||||
cam_features = self.encoder_img_feat_input_proj(cam_features) # (B, C, h, w)
|
||||
all_cam_features.append(cam_features)
|
||||
all_cam_pos_embeds.append(cam_pos_embed)
|
||||
# Concatenate camera observation feature maps and positional embeddings along the width dimension.
|
||||
encoder_in = torch.cat(all_cam_features, axis=-1)
|
||||
cam_pos_embed = torch.cat(all_cam_pos_embeds, axis=-1)
|
||||
for cam_index in range(images.shape[-4]):
|
||||
cam_features = self.backbone(images[:, cam_index])["feature_map"]
|
||||
# TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use
|
||||
# buffer
|
||||
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
|
||||
cam_features = self.encoder_img_feat_input_proj(cam_features) # (B, C, h, w)
|
||||
all_cam_features.append(cam_features)
|
||||
all_cam_pos_embeds.append(cam_pos_embed)
|
||||
# Concatenate camera observation feature maps and positional embeddings along the width dimension,
|
||||
# and move to (sequence, batch, dim).
|
||||
all_cam_features = torch.cat(all_cam_features, axis=-1)
|
||||
encoder_in_tokens.extend(einops.rearrange(all_cam_features, "b c h w -> (h w) b c"))
|
||||
all_cam_pos_embeds = torch.cat(all_cam_pos_embeds, axis=-1)
|
||||
encoder_in_pos_embed.extend(einops.rearrange(all_cam_pos_embeds, "b c h w -> (h w) b c"))
|
||||
|
||||
# Get positional embeddings for robot state and latent.
|
||||
if self.use_input_state:
|
||||
robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"]) # (B, C)
|
||||
latent_embed = self.encoder_latent_input_proj(latent_sample) # (B, C)
|
||||
|
||||
# Stack encoder input and positional embeddings moving to (S, B, C).
|
||||
encoder_in_feats = [latent_embed, robot_state_embed] if self.use_input_state else [latent_embed]
|
||||
encoder_in = torch.cat(
|
||||
[
|
||||
torch.stack(encoder_in_feats, axis=0),
|
||||
einops.rearrange(encoder_in, "b c h w -> (h w) b c"),
|
||||
]
|
||||
)
|
||||
pos_embed = torch.cat(
|
||||
[
|
||||
self.encoder_robot_and_latent_pos_embed.weight.unsqueeze(1),
|
||||
cam_pos_embed.flatten(2).permute(2, 0, 1),
|
||||
],
|
||||
axis=0,
|
||||
)
|
||||
# Stack all tokens along the sequence dimension.
|
||||
encoder_in_tokens = torch.stack(encoder_in_tokens, axis=0)
|
||||
encoder_in_pos_embed = torch.stack(encoder_in_pos_embed, axis=0)
|
||||
|
||||
# Forward pass through the transformer modules.
|
||||
encoder_out = self.encoder(encoder_in, pos_embed=pos_embed)
|
||||
encoder_out = self.encoder(encoder_in_tokens, pos_embed=encoder_in_pos_embed)
|
||||
# TODO(rcadene, alexander-soare): remove call to `device` ; precompute and use buffer
|
||||
decoder_in = torch.zeros(
|
||||
(self.config.chunk_size, batch_size, self.config.dim_model),
|
||||
dtype=pos_embed.dtype,
|
||||
device=pos_embed.device,
|
||||
dtype=encoder_in_pos_embed.dtype,
|
||||
device=encoder_in_pos_embed.device,
|
||||
)
|
||||
decoder_out = self.decoder(
|
||||
decoder_in,
|
||||
encoder_out,
|
||||
encoder_pos_embed=pos_embed,
|
||||
encoder_pos_embed=encoder_in_pos_embed,
|
||||
decoder_pos_embed=self.decoder_pos_embed.weight.unsqueeze(1),
|
||||
)
|
||||
|
||||
|
|
|
@ -1,3 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:523f220f3acbab0cd4aef8a13c77916634488b1af08a06e4e65d1aecafdc2cae
|
||||
oid sha256:28444747a9cb3876f86ae86fed72e587dbcacfccd87c5c24b8ecac30c3ce3077
|
||||
size 5104
|
||||
|
|
|
@ -1,3 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:95dd049b4386030ced4505586b874f16906f8d89f29b570201782eebcbe4f402
|
||||
size 31688
|
||||
oid sha256:a43a9ddaf8527e3344b22bd21276e1f561e83599d720933b28725b00d94823c0
|
||||
size 31672
|
||||
|
|
|
@ -1,3 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:806851d60b6c492620b7269876eef9ce17756ec03da93f36b351f8aa75be0954
|
||||
size 33408
|
||||
oid sha256:093bff1fbc3bde2547bccbbefc277d02368a8d4a9100b3e4bd47c755798cad68
|
||||
size 33400
|
||||
|
|
|
@ -1,3 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:3f4e0e525aeb22ea94b79e26b39a87e6f2da9fbee33e493906aaf2aad9a7c1ef
|
||||
oid sha256:85bed637e90f15c64e4af01d2dbc5d9c3a370215f2c8c379494fa3acb413bc2e
|
||||
size 515400
|
||||
|
|
|
@ -1,3 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:6dc658a1c1616c7d1c211eb8f87cec3d44f7b67d6b3cea7a6ce12b32d74674da
|
||||
size 31688
|
||||
oid sha256:00cf8e548d7ea23aa70de79e05c39990a32a790def824f729e6c98bea31c69bc
|
||||
size 31672
|
||||
|
|
|
@ -1,3 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:01d993c67a9267032fe9fbeff20b4359c209464976ea503040a0a76ae213450a
|
||||
size 33408
|
||||
oid sha256:b3a4c2581f48229312a582d91f0adea8078c0c5b744c34d76723edf4731f9003
|
||||
size 33400
|
||||
|
|
|
@ -1,3 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:2fff6294b94cf42d4dd1249dcc5c3b0269d6d9c697f894e61b867d7ab81a94e4
|
||||
oid sha256:aab00b0349901450adbb8e0d7d4af1f743dd88e7e19f1bcfef821de8bdcc957d
|
||||
size 5104
|
||||
|
|
|
@ -1,3 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:4aa23e51607604a18b70fa42edbbe1af34f119d985628fc27cc1bbb0efbc8901
|
||||
size 31688
|
||||
oid sha256:de70c3055aa052f5b811ec7c2994ec6861efe645c6caee41e04a3460598500d5
|
||||
size 31672
|
||||
|
|
|
@ -1,3 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:6fd368406c93cb562a69ff11cf7adf34a4b223507dcb2b9e9b8f44ee1036988a
|
||||
oid sha256:d4070bd1f1cd8c72bc2daf628088e42b8ef113f6df0bfd9e91be052bc90038c3
|
||||
size 68
|
||||
|
|
|
@ -1,3 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:5663ee79a13bb70a1604b887dd21bf89d18482287442419c6cc6c5bf0e753e99
|
||||
size 34928
|
||||
oid sha256:19fdc1edf327e04132c1917024289b3d16e25a1ec2130f3df797fe07434dfbbd
|
||||
size 34920
|
||||
|
|
|
@ -1,3 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:fb1a45463efd860af2ca22c16c77d55a18bd96fef080ae77978845a2f22ef716
|
||||
oid sha256:dcd8ebaefd3ff267eb24654135d1efb179d713e6cfe6917f793a3e2483efd501
|
||||
size 5104
|
||||
|
|
|
@ -1,3 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:aa5a43e22f01d8e2f8d19f31753608794f1edbd74aaf71660091ab80ea58dc9b
|
||||
size 30808
|
||||
oid sha256:107e98647ed1081745476b250df8848c0c430b2aff51d614f6b2db95684467aa
|
||||
size 30800
|
||||
|
|
|
@ -1,3 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:54d1f75cf67a7b1d7a7c6865ecb9b1cc86a2f032d1890245f8996789ab6e0df6
|
||||
size 33608
|
||||
oid sha256:adbae737c987f912509d3fba06f332bda700bfc2c6d83a09c969e9d7a3ca75f7
|
||||
size 33600
|
||||
|
|
Loading…
Reference in New Issue