User/pepijn/2025 03 17 act different image shapes (#870)

Pepijn 2025-03-18 11:09:05 +01:00 committed by GitHub
parent 1c15bab70f
commit e8159997c7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed file with 19 additions and 22 deletions

@@ -119,9 +119,7 @@ class ACTPolicy(PreTrainedPolicy):
         batch = self.normalize_inputs(batch)
         if self.config.image_features:
             batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
-            batch["observation.images"] = torch.stack(
-                [batch[key] for key in self.config.image_features], dim=-4
-            )
+            batch["observation.images"] = [batch[key] for key in self.config.image_features]
 
         # If we are doing temporal ensembling, do online updates where we keep track of the number of actions
         # we are ensembling over.
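
A minimal sketch (not part of the diff, with made-up tensor shapes) of why a plain list replaces torch.stack here: stacking along dim=-4 requires every camera to produce images of identical shape, while a list leaves each camera's tensor untouched.

import torch

# Hypothetical cameras with different resolutions (shapes are illustrative only).
wrist_cam = torch.zeros(8, 3, 480, 640)  # (B, C, H, W)
top_cam = torch.zeros(8, 3, 240, 320)

# Old behaviour: torch.stack needs identical shapes, so mixed resolutions raise an error.
try:
    torch.stack([wrist_cam, top_cam], dim=-4)
except RuntimeError as err:
    print(f"stack fails: {err}")

# New behaviour: keep a plain list so each camera tensor keeps its own shape.
images = [wrist_cam, top_cam]
print([tuple(img.shape) for img in images])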
@@ -149,9 +147,8 @@ class ACTPolicy(PreTrainedPolicy):
         batch = self.normalize_inputs(batch)
         if self.config.image_features:
             batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
-            batch["observation.images"] = torch.stack(
-                [batch[key] for key in self.config.image_features], dim=-4
-            )
+            batch["observation.images"] = [batch[key] for key in self.config.image_features]
 
         batch = self.normalize_targets(batch)
         actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
@@ -413,11 +410,10 @@ class ACT(nn.Module):
                 "actions must be provided when using the variational objective in training mode."
             )
 
-        batch_size = (
-            batch["observation.images"]
-            if "observation.images" in batch
-            else batch["observation.environment_state"]
-        ).shape[0]
+        if "observation.images" in batch:
+            batch_size = batch["observation.images"][0].shape[0]
+        else:
+            batch_size = batch["observation.environment_state"].shape[0]
 
         # Prepare the latent for input to the transformer encoder.
         if self.config.use_vae and "action" in batch:
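
A quick illustration of the new batch-size lookup (a sketch with assumed shapes, not code from the repository): with a list of per-camera tensors, the batch size is read from the first entry, and every camera tensor shares that leading dimension.

import torch

batch = {"observation.images": [torch.zeros(8, 3, 480, 640), torch.zeros(8, 3, 240, 320)]}
batch_size = batch["observation.images"][0].shape[0]
assert all(img.shape[0] == batch_size for img in batch["observation.images"])
print(batch_size)  # 8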
@@ -490,20 +486,21 @@ class ACT(nn.Module):
             all_cam_features = []
             all_cam_pos_embeds = []
-            for cam_index in range(batch["observation.images"].shape[-4]):
-                cam_features = self.backbone(batch["observation.images"][:, cam_index])["feature_map"]
-                # TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use
-                # buffer
+
+            # For a list of images, the H and W may vary but H*W is constant.
+            for img in batch["observation.images"]:
+                cam_features = self.backbone(img)["feature_map"]
                 cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
-                cam_features = self.encoder_img_feat_input_proj(cam_features)  # (B, C, h, w)
+                cam_features = self.encoder_img_feat_input_proj(cam_features)
+
+                # Rearrange features to (sequence, batch, dim).
+                cam_features = einops.rearrange(cam_features, "b c h w -> (h w) b c")
+                cam_pos_embed = einops.rearrange(cam_pos_embed, "b c h w -> (h w) b c")
+
                 all_cam_features.append(cam_features)
                 all_cam_pos_embeds.append(cam_pos_embed)
-            # Concatenate camera observation feature maps and positional embeddings along the width dimension,
-            # and move to (sequence, batch, dim).
-            all_cam_features = torch.cat(all_cam_features, axis=-1)
-            encoder_in_tokens.extend(einops.rearrange(all_cam_features, "b c h w -> (h w) b c"))
-            all_cam_pos_embeds = torch.cat(all_cam_pos_embeds, axis=-1)
-            encoder_in_pos_embed.extend(einops.rearrange(all_cam_pos_embeds, "b c h w -> (h w) b c"))
+
+            encoder_in_tokens.extend(torch.cat(all_cam_features, axis=0))
+            encoder_in_pos_embed.extend(torch.cat(all_cam_pos_embeds, axis=0))
 
         # Stack all tokens along the sequence dimension.
         encoder_in_tokens = torch.stack(encoder_in_tokens, axis=0)