diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py
index f2b16a1e..72d4df03 100644
--- a/lerobot/common/policies/act/modeling_act.py
+++ b/lerobot/common/policies/act/modeling_act.py
@@ -119,9 +119,7 @@ class ACTPolicy(PreTrainedPolicy):
         batch = self.normalize_inputs(batch)
         if self.config.image_features:
             batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
-            batch["observation.images"] = torch.stack(
-                [batch[key] for key in self.config.image_features], dim=-4
-            )
+            batch["observation.images"] = [batch[key] for key in self.config.image_features]
 
         # If we are doing temporal ensembling, do online updates where we keep track of the number of actions
         # we are ensembling over.
@@ -149,9 +147,8 @@ class ACTPolicy(PreTrainedPolicy):
         batch = self.normalize_inputs(batch)
         if self.config.image_features:
             batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
-            batch["observation.images"] = torch.stack(
-                [batch[key] for key in self.config.image_features], dim=-4
-            )
+            batch["observation.images"] = [batch[key] for key in self.config.image_features]
+
         batch = self.normalize_targets(batch)
         actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
 
@@ -413,11 +410,10 @@ class ACT(nn.Module):
                 "actions must be provided when using the variational objective in training mode."
            )
 
-        batch_size = (
-            batch["observation.images"]
-            if "observation.images" in batch
-            else batch["observation.environment_state"]
-        ).shape[0]
+        if "observation.images" in batch:
+            batch_size = batch["observation.images"][0].shape[0]
+        else:
+            batch_size = batch["observation.environment_state"].shape[0]
 
         # Prepare the latent for input to the transformer encoder.
         if self.config.use_vae and "action" in batch:
@@ -490,20 +486,21 @@ class ACT(nn.Module):
             all_cam_features = []
             all_cam_pos_embeds = []
 
-            for cam_index in range(batch["observation.images"].shape[-4]):
-                cam_features = self.backbone(batch["observation.images"][:, cam_index])["feature_map"]
-                # TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use
-                # buffer
+            # For a list of images, the H and W may vary but H*W is constant.
+            for img in batch["observation.images"]:
+                cam_features = self.backbone(img)["feature_map"]
                 cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
-                cam_features = self.encoder_img_feat_input_proj(cam_features)  # (B, C, h, w)
+                cam_features = self.encoder_img_feat_input_proj(cam_features)
+
+                # Rearrange features to (sequence, batch, dim).
+                cam_features = einops.rearrange(cam_features, "b c h w -> (h w) b c")
+                cam_pos_embed = einops.rearrange(cam_pos_embed, "b c h w -> (h w) b c")
+
                 all_cam_features.append(cam_features)
                 all_cam_pos_embeds.append(cam_pos_embed)
-            # Concatenate camera observation feature maps and positional embeddings along the width dimension,
-            # and move to (sequence, batch, dim).
-            all_cam_features = torch.cat(all_cam_features, axis=-1)
-            encoder_in_tokens.extend(einops.rearrange(all_cam_features, "b c h w -> (h w) b c"))
-            all_cam_pos_embeds = torch.cat(all_cam_pos_embeds, axis=-1)
-            encoder_in_pos_embed.extend(einops.rearrange(all_cam_pos_embeds, "b c h w -> (h w) b c"))
+
+            encoder_in_tokens.extend(torch.cat(all_cam_features, axis=0))
+            encoder_in_pos_embed.extend(torch.cat(all_cam_pos_embeds, axis=0))
 
         # Stack all tokens along the sequence dimension.
         encoder_in_tokens = torch.stack(encoder_in_tokens, axis=0)
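
A minimal standalone sketch of the idea behind this patch, not part of the diff itself. The names `proj` and `backbone_channels`, the 1x1 projection, and the dummy feature-map shapes are assumptions made up for illustration (the real model uses a ResNet backbone and a sinusoidal camera-feature position embedding). It shows why a list of per-camera images can replace the stacked `(num_cams, B, C, H, W)` tensor: once each camera's feature map is rearranged to `(h*w, B, D)`, the tokens concatenate along the sequence dimension even when the spatial sizes differ per camera.

```python
import torch
from torch import nn
import einops

batch_size, backbone_channels, model_dim = 2, 32, 16
proj = nn.Conv2d(backbone_channels, model_dim, kernel_size=1)  # stand-in for encoder_img_feat_input_proj

# Two "cameras" whose backbone feature maps have different spatial sizes (h, w).
feature_maps = [
    torch.randn(batch_size, backbone_channels, 8, 10),  # e.g. a wider camera
    torch.randn(batch_size, backbone_channels, 6, 6),   # e.g. a square camera
]

all_cam_tokens = []
for cam_features in feature_maps:
    cam_features = proj(cam_features)                                     # (B, D, h, w)
    cam_tokens = einops.rearrange(cam_features, "b c h w -> (h w) b c")   # (h*w, B, D)
    all_cam_tokens.append(cam_tokens)

# Concatenating along the sequence dimension works even though h and w differ
# per camera, which stacking into one (num_cams, B, C, H, W) tensor could not do.
tokens = torch.cat(all_cam_tokens, dim=0)
print(tokens.shape)  # torch.Size([116, 2, 16]); 8*10 + 6*6 = 116 tokens
```

The token count becomes the sum of h*w over cameras, which is why the patch moves the `(h w) b c` rearrange inside the per-camera loop instead of rearranging one concatenated feature map afterwards.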