Bug fix: missing attention mask in VAE encoder in ACT policy (#279)

Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
Thomas Wolf authored 2024-06-19 13:07:21 +02:00, committed by GitHub
parent 56199fb76f
commit 48951662f2
7 changed files with 60 additions and 25 deletions


@@ -314,9 +314,23 @@ class ACT(nn.Module):
             # Note: detach() shouldn't be necessary but leaving it the same as the original code just in case.
             pos_embed = self.vae_encoder_pos_enc.clone().detach()  # (1, S+2, D)
 
+            # Prepare key padding mask for the transformer encoder. We have 1 or 2 extra tokens at the start of the
+            # sequence depending whether we use the input states or not (cls and robot state)
+            # False means not a padding token.
+            cls_joint_is_pad = torch.full(
+                (batch_size, 2 if self.use_input_state else 1),
+                False,
+                device=batch["observation.state"].device,
+            )
+            key_padding_mask = torch.cat(
+                [cls_joint_is_pad, batch["action_is_pad"]], axis=1
+            )  # (bs, seq+1 or 2)
+
             # Forward pass through VAE encoder to get the latent PDF parameters.
             cls_token_out = self.vae_encoder(
-                vae_encoder_input.permute(1, 0, 2), pos_embed=pos_embed.permute(1, 0, 2)
+                vae_encoder_input.permute(1, 0, 2),
+                pos_embed=pos_embed.permute(1, 0, 2),
+                key_padding_mask=key_padding_mask,
             )[0]  # select the class token, with shape (B, D)
             latent_pdf_params = self.vae_encoder_latent_output_proj(cls_token_out)
             mu = latent_pdf_params[:, : self.config.latent_dim]
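For orientation, here is a small standalone sketch of the mask the hunk above builds. The VAE encoder input is [cls, (robot state,) action sequence], so one or two always-valid tokens are prepended to `batch["action_is_pad"]`; all sizes below are illustrative and not taken from the repo's config.

```python
import torch

# Illustrative sizes only; the real values come from the batch and the policy config.
batch_size, chunk_size = 8, 100
use_input_state = True  # cls token plus a robot-state token

# The prepended tokens are never padding, hence False.
cls_joint_is_pad = torch.full((batch_size, 2 if use_input_state else 1), False)

# True marks action slots that run past the end of the episode (made up here).
action_is_pad = torch.zeros(batch_size, chunk_size, dtype=torch.bool)
action_is_pad[:, 90:] = True

key_padding_mask = torch.cat([cls_joint_is_pad, action_is_pad], dim=1)
print(key_padding_mask.shape)  # torch.Size([8, 102]) -> (B, S + 2) when the robot state is used
```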
@@ -402,9 +416,11 @@ class ACTEncoder(nn.Module):
         self.layers = nn.ModuleList([ACTEncoderLayer(config) for _ in range(config.n_encoder_layers)])
         self.norm = nn.LayerNorm(config.dim_model) if config.pre_norm else nn.Identity()
 
-    def forward(self, x: Tensor, pos_embed: Tensor | None = None) -> Tensor:
+    def forward(
+        self, x: Tensor, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None
+    ) -> Tensor:
         for layer in self.layers:
-            x = layer(x, pos_embed=pos_embed)
+            x = layer(x, pos_embed=pos_embed, key_padding_mask=key_padding_mask)
         x = self.norm(x)
         return x
@@ -427,12 +443,13 @@ class ACTEncoderLayer(nn.Module):
         self.activation = get_activation_fn(config.feedforward_activation)
         self.pre_norm = config.pre_norm
 
-    def forward(self, x, pos_embed: Tensor | None = None) -> Tensor:
+    def forward(self, x, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None) -> Tensor:
         skip = x
         if self.pre_norm:
             x = self.norm1(x)
         q = k = x if pos_embed is None else x + pos_embed
-        x = self.self_attn(q, k, value=x)[0]  # select just the output, not the attention weights
+        x = self.self_attn(q, k, value=x, key_padding_mask=key_padding_mask)
+        x = x[0]  # note: [0] to select just the output, not the attention weights
         x = skip + self.dropout1(x)
         if self.pre_norm:
             skip = x
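The fix hinges on the `key_padding_mask` argument of `torch.nn.MultiheadAttention` (which is what `self_attn` is in the ACT encoder layer): keys at positions marked `True` are excluded from attention, so padded action tokens can no longer leak into the cls token used for the latent. A minimal standalone demonstration of that behavior, in the sequence-first layout the ACT encoder uses (not repo code):

```python
import torch
from torch import nn

torch.manual_seed(0)
attn = nn.MultiheadAttention(embed_dim=4, num_heads=1)  # default batch_first=False -> (S, B, D)

x = torch.randn(6, 2, 4)                    # 6 tokens, batch of 2, feature dim 4
mask = torch.zeros(2, 6, dtype=torch.bool)  # (B, S); True = ignore this key
mask[:, 4:] = True                          # pretend the last two tokens are padding

out, _ = attn(x, x, x, key_padding_mask=mask)

# Replacing the padded tokens with garbage leaves the unmasked outputs unchanged,
# because masked positions are never attended to as keys/values.
x2 = x.clone()
x2[4:] = torch.randn(2, 2, 4)
out2, _ = attn(x2, x2, x2, key_padding_mask=mask)
print(torch.allclose(out[:4], out2[:4], atol=1e-6))  # True
```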


@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3f4e0e525aeb22ea94b79e26b39a87e6f2da9fbee33e493906aaf2aad9a7c1ef
+size 515400


@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6dc658a1c1616c7d1c211eb8f87cec3d44f7b67d6b3cea7a6ce12b32d74674da
+size 31688


@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:03971f92b7907b6b7e6ac207f508666104cd84c26c5276f510c431db604e188b
+size 68


@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:01d993c67a9267032fe9fbeff20b4359c209464976ea503040a0a76ae213450a
+size 33408


@@ -89,8 +89,8 @@ def get_policy_stats(env_name, policy_name, extra_overrides):
     return output_dict, grad_stats, param_stats, actions
 
 
-def save_policy_to_safetensors(output_dir, env_name, policy_name, extra_overrides):
-    env_policy_dir = Path(output_dir) / f"{env_name}_{policy_name}"
+def save_policy_to_safetensors(output_dir, env_name, policy_name, extra_overrides, file_name_extra):
+    env_policy_dir = Path(output_dir) / f"{env_name}_{policy_name}{file_name_extra}"
     if env_policy_dir.exists():
         print(f"Overwrite existing safetensors in '{env_policy_dir}':")
@@ -108,15 +108,17 @@ save_policy_to_safetensors(output_dir, env_name, policy_name, extra_override
 if __name__ == "__main__":
     env_policies = [
-        ("xarm", "tdmpc", []),
-        (
-            "pusht",
-            "diffusion",
-            ["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
-        ),
-        ("aloha", "act", ["policy.n_action_steps=10"]),
-        ("dora_aloha_real", "act_real", ["policy.n_action_steps=10"]),
-        ("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"]),
+        # ("xarm", "tdmpc", []),
+        # (
+        #     "pusht",
+        #     "diffusion",
+        #     ["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
+        # ),
+        ("aloha", "act", ["policy.n_action_steps=1000", "policy.chunk_size=1000"], "_1000_steps"),
+        # ("dora_aloha_real", "act_real", ["policy.n_action_steps=10"]),
+        # ("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"]),
     ]
-    for env, policy, extra_overrides in env_policies:
-        save_policy_to_safetensors("tests/data/save_policy_to_safetensors", env, policy, extra_overrides)
+    for env, policy, extra_overrides, file_name_extra in env_policies:
+        save_policy_to_safetensors(
+            "tests/data/save_policy_to_safetensors", env, policy, extra_overrides, file_name_extra
+        )


@@ -315,24 +315,26 @@ def test_normalize(insert_temporal_dim):
 @pytest.mark.parametrize(
-    "env_name, policy_name, extra_overrides",
+    "env_name, policy_name, extra_overrides, file_name_extra",
     [
-        ("xarm", "tdmpc", []),
+        ("xarm", "tdmpc", [], ""),
         (
             "pusht",
             "diffusion",
             ["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
+            "",
         ),
-        ("aloha", "act", ["policy.n_action_steps=10"]),
-        ("dora_aloha_real", "act_real", ["policy.n_action_steps=10"]),
-        ("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"]),
+        ("aloha", "act", ["policy.n_action_steps=10"], ""),
+        ("aloha", "act", ["policy.n_action_steps=1000", "policy.chunk_size=1000"], "_1000_steps"),
+        ("dora_aloha_real", "act_real", ["policy.n_action_steps=10"], ""),
+        ("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"], ""),
     ],
 )
 # As artifacts have been generated on an x86_64 kernel, this test won't
 # pass if it's run on another platform due to floating point errors
 @require_x86_64_kernel
 @require_cpu
-def test_backward_compatibility(env_name, policy_name, extra_overrides):
+def test_backward_compatibility(env_name, policy_name, extra_overrides, file_name_extra):
     """
     NOTE: If this test does not pass, and you have intentionally changed something in the policy:
     1. Inspect the differences in policy outputs and make sure you can account for them. Your PR should
@@ -344,7 +346,9 @@ def test_backward_compatibility(env_name, policy_name, extra_overrides):
     5. Remember to restore `tests/scripts/save_policy_to_safetensors.py` to its original state.
     6. Remember to stage and commit the resulting changes to `tests/data`.
     """
-    env_policy_dir = Path("tests/data/save_policy_to_safetensors") / f"{env_name}_{policy_name}"
+    env_policy_dir = (
+        Path("tests/data/save_policy_to_safetensors") / f"{env_name}_{policy_name}{file_name_extra}"
+    )
     saved_output_dict = load_file(env_policy_dir / "output_dict.safetensors")
     saved_grad_stats = load_file(env_policy_dir / "grad_stats.safetensors")
     saved_param_stats = load_file(env_policy_dir / "param_stats.safetensors")
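A note on the new `_1000_steps` fixture: with `policy.chunk_size=1000` the action chunk presumably runs past the end of the aloha episodes, so `batch["action_is_pad"]` actually contains `True` entries and the backward-compatibility artifacts exercise the new masking path. A hedged sketch of how such a padding flag is typically derived (illustrative names and numbers, not the repo's dataloader code):

```python
import torch

chunk_size = 1000
remaining_steps = 400  # e.g. only 400 future steps left in the episode (made-up number)

# True flags the action slots that extend past the end of the episode.
action_is_pad = (torch.arange(chunk_size) >= remaining_steps).unsqueeze(0)  # (1, chunk_size)
print(action_is_pad.sum().item())  # 600 padded slots
```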