diff --git a/lerobot/common/policies/act/configuration_act.py b/lerobot/common/policies/act/configuration_act.py
index a4b0b7d2..92a52eac 100644
--- a/lerobot/common/policies/act/configuration_act.py
+++ b/lerobot/common/policies/act/configuration_act.py
@@ -26,7 +26,10 @@ class ACTConfig:
     Those are: `input_shapes` and 'output_shapes`.
 
     Notes on the inputs and outputs:
-        - At least one key starting with "observation.image is required as an input.
+        - Either:
+            - At least one key starting with "observation.image is required as an input.
+              AND/OR
+            - The key "observation.environment_state" is required as input.
         - If there are multiple keys beginning with "observation.images." they are treated as multiple camera
           views. Right now we only support all images having the same shape.
        - May optionally work without an "observation.state" key for the proprioceptive robot state.
@@ -162,3 +165,8 @@ class ACTConfig:
             raise ValueError(
                 f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
             )
+        if (
+            not any(k.startswith("observation.image") for k in self.input_shapes)
+            and "observation.environment_state" not in self.input_shapes
+        ):
+            raise ValueError("You must provide at least one image or the environment state among the inputs.")
diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py
index 5f302bc7..0a236100 100644
--- a/lerobot/common/policies/act/modeling_act.py
+++ b/lerobot/common/policies/act/modeling_act.py
@@ -97,7 +97,8 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
         self.eval()
 
         batch = self.normalize_inputs(batch)
-        batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
+        if len(self.expected_image_keys) > 0:
+            batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
 
         # If we are doing temporal ensembling, keep track of the exponential moving average (EMA), and return
         # the first action.
@@ -135,7 +136,8 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
     def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
         """Run the batch through the model and compute the loss for training or validation."""
         batch = self.normalize_inputs(batch)
-        batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
+        if len(self.expected_image_keys) > 0:
+            batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
         batch = self.normalize_targets(batch)
         actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
 
@@ -200,12 +202,14 @@ class ACT(nn.Module):
         self.config = config
         # BERT style VAE encoder with input tokens [cls, robot_state, *action_sequence].
         # The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
-        self.use_input_state = "observation.state" in config.input_shapes
+        self.use_robot_state = "observation.state" in config.input_shapes
+        self.use_images = any(k.startswith("observation.image") for k in config.input_shapes)
+        self.use_env_state = "observation.environment_state" in config.input_shapes
         if self.config.use_vae:
             self.vae_encoder = ACTEncoder(config)
             self.vae_encoder_cls_embed = nn.Embedding(1, config.dim_model)
             # Projection layer for joint-space configuration to hidden dimension.
-            if self.use_input_state:
+            if self.use_robot_state:
                 self.vae_encoder_robot_state_input_proj = nn.Linear(
                     config.input_shapes["observation.state"][0], config.dim_model
                 )
@@ -218,7 +222,7 @@ class ACT(nn.Module):
             # Fixed sinusoidal positional embedding for the input to the VAE encoder. Unsqueeze for batch
             # dimension.
             num_input_token_encoder = 1 + config.chunk_size
-            if self.use_input_state:
+            if self.use_robot_state:
                 num_input_token_encoder += 1
             self.register_buffer(
                 "vae_encoder_pos_enc",
@@ -226,34 +230,45 @@ class ACT(nn.Module):
             )
 
         # Backbone for image feature extraction.
-        backbone_model = getattr(torchvision.models, config.vision_backbone)(
-            replace_stride_with_dilation=[False, False, config.replace_final_stride_with_dilation],
-            weights=config.pretrained_backbone_weights,
-            norm_layer=FrozenBatchNorm2d,
-        )
-        # Note: The assumption here is that we are using a ResNet model (and hence layer4 is the final feature
-        # map).
-        # Note: The forward method of this returns a dict: {"feature_map": output}.
-        self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"})
+        if self.use_images:
+            backbone_model = getattr(torchvision.models, config.vision_backbone)(
+                replace_stride_with_dilation=[False, False, config.replace_final_stride_with_dilation],
+                weights=config.pretrained_backbone_weights,
+                norm_layer=FrozenBatchNorm2d,
+            )
+            # Note: The assumption here is that we are using a ResNet model (and hence layer4 is the final
+            # feature map).
+            # Note: The forward method of this returns a dict: {"feature_map": output}.
+            self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"})
 
         # Transformer (acts as VAE decoder when training with the variational objective).
         self.encoder = ACTEncoder(config)
         self.decoder = ACTDecoder(config)
 
         # Transformer encoder input projections. The tokens will be structured like
-        # [latent, robot_state, image_feature_map_pixels].
-        if self.use_input_state:
+        # [latent, (robot_state), (env_state), (image_feature_map_pixels)].
+        if self.use_robot_state:
             self.encoder_robot_state_input_proj = nn.Linear(
                 config.input_shapes["observation.state"][0], config.dim_model
             )
+        if self.use_env_state:
+            self.encoder_env_state_input_proj = nn.Linear(
+                config.input_shapes["observation.environment_state"][0], config.dim_model
+            )
         self.encoder_latent_input_proj = nn.Linear(config.latent_dim, config.dim_model)
-        self.encoder_img_feat_input_proj = nn.Conv2d(
-            backbone_model.fc.in_features, config.dim_model, kernel_size=1
-        )
+        if self.use_images:
+            self.encoder_img_feat_input_proj = nn.Conv2d(
+                backbone_model.fc.in_features, config.dim_model, kernel_size=1
+            )
 
         # Transformer encoder positional embeddings.
-        num_input_token_decoder = 2 if self.use_input_state else 1
-        self.encoder_robot_and_latent_pos_embed = nn.Embedding(num_input_token_decoder, config.dim_model)
-        self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)
+        n_1d_tokens = 1  # for the latent
+        if self.use_robot_state:
+            n_1d_tokens += 1
+        if self.use_env_state:
+            n_1d_tokens += 1
+        self.encoder_1d_feature_pos_embed = nn.Embedding(n_1d_tokens, config.dim_model)
+        if self.use_images:
+            self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)
 
         # Transformer decoder.
         # Learnable positional embedding for the transformer's decoder (in the style of DETR object queries).
@@ -274,10 +289,13 @@ class ACT(nn.Module):
         """A forward pass through the Action Chunking Transformer (with optional VAE encoder).
 
         `batch` should have the following structure:
-
         {
-            "observation.state": (B, state_dim) batch of robot states.
+            "observation.state" (optional): (B, state_dim) batch of robot states.
+
             "observation.images": (B, n_cameras, C, H, W) batch of images.
+                AND/OR
+            "observation.environment_state": (B, env_dim) batch of environment states.
+
             "action" (optional, only if training with VAE): (B, chunk_size, action dim) batch of actions.
         }
 
@@ -291,7 +309,11 @@ class ACT(nn.Module):
                 "action" in batch
             ), "actions must be provided when using the variational objective in training mode."
 
-        batch_size = batch["observation.images"].shape[0]
+        batch_size = (
+            batch["observation.images"]
+            if "observation.images" in batch
+            else batch["observation.environment_state"]
+        ).shape[0]
 
         # Prepare the latent for input to the transformer encoder.
         if self.config.use_vae and "action" in batch:
@@ -299,12 +321,12 @@ class ACT(nn.Module):
             cls_embed = einops.repeat(
                 self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
             )  # (B, 1, D)
-            if self.use_input_state:
+            if self.use_robot_state:
                 robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"])
                 robot_state_embed = robot_state_embed.unsqueeze(1)  # (B, 1, D)
             action_embed = self.vae_encoder_action_input_proj(batch["action"])  # (B, S, D)
 
-            if self.use_input_state:
+            if self.use_robot_state:
                 vae_encoder_input = [cls_embed, robot_state_embed, action_embed]  # (B, S+2, D)
             else:
                 vae_encoder_input = [cls_embed, action_embed]
@@ -318,7 +340,7 @@ class ACT(nn.Module):
             # sequence depending whether we use the input states or not (cls and robot state)
             # False means not a padding token.
             cls_joint_is_pad = torch.full(
-                (batch_size, 2 if self.use_input_state else 1),
+                (batch_size, 2 if self.use_robot_state else 1),
                 False,
                 device=batch["observation.state"].device,
             )
@@ -347,56 +369,55 @@ class ACT(nn.Module):
                 batch["observation.state"].device
             )
 
-        # Prepare all other transformer encoder inputs.
+        # Prepare transformer encoder inputs.
+        encoder_in_tokens = [self.encoder_latent_input_proj(latent_sample)]
+        encoder_in_pos_embed = list(self.encoder_1d_feature_pos_embed.weight.unsqueeze(1))
+        # Robot state token.
+        if self.use_robot_state:
+            encoder_in_tokens.append(self.encoder_robot_state_input_proj(batch["observation.state"]))
+        # Environment state token.
+        if self.use_env_state:
+            encoder_in_tokens.append(
+                self.encoder_env_state_input_proj(batch["observation.environment_state"])
+            )
+
         # Camera observation features and positional embeddings.
-        all_cam_features = []
-        all_cam_pos_embeds = []
-        images = batch["observation.images"]
+        if self.use_images:
+            all_cam_features = []
+            all_cam_pos_embeds = []
+            images = batch["observation.images"]
 
-        for cam_index in range(images.shape[-4]):
-            cam_features = self.backbone(images[:, cam_index])["feature_map"]
-            # TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
-            cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
-            cam_features = self.encoder_img_feat_input_proj(cam_features)  # (B, C, h, w)
-            all_cam_features.append(cam_features)
-            all_cam_pos_embeds.append(cam_pos_embed)
-        # Concatenate camera observation feature maps and positional embeddings along the width dimension.
-        encoder_in = torch.cat(all_cam_features, axis=-1)
-        cam_pos_embed = torch.cat(all_cam_pos_embeds, axis=-1)
+            for cam_index in range(images.shape[-4]):
+                cam_features = self.backbone(images[:, cam_index])["feature_map"]
+                # TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use
+                # buffer
+                cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
+                cam_features = self.encoder_img_feat_input_proj(cam_features)  # (B, C, h, w)
+                all_cam_features.append(cam_features)
+                all_cam_pos_embeds.append(cam_pos_embed)
+            # Concatenate camera observation feature maps and positional embeddings along the width dimension,
+            # and move to (sequence, batch, dim).
+            all_cam_features = torch.cat(all_cam_features, axis=-1)
+            encoder_in_tokens.extend(einops.rearrange(all_cam_features, "b c h w -> (h w) b c"))
+            all_cam_pos_embeds = torch.cat(all_cam_pos_embeds, axis=-1)
+            encoder_in_pos_embed.extend(einops.rearrange(all_cam_pos_embeds, "b c h w -> (h w) b c"))
 
-        # Get positional embeddings for robot state and latent.
-        if self.use_input_state:
-            robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"])  # (B, C)
-        latent_embed = self.encoder_latent_input_proj(latent_sample)  # (B, C)
-
-        # Stack encoder input and positional embeddings moving to (S, B, C).
-        encoder_in_feats = [latent_embed, robot_state_embed] if self.use_input_state else [latent_embed]
-        encoder_in = torch.cat(
-            [
-                torch.stack(encoder_in_feats, axis=0),
-                einops.rearrange(encoder_in, "b c h w -> (h w) b c"),
-            ]
-        )
-        pos_embed = torch.cat(
-            [
-                self.encoder_robot_and_latent_pos_embed.weight.unsqueeze(1),
-                cam_pos_embed.flatten(2).permute(2, 0, 1),
-            ],
-            axis=0,
-        )
+        # Stack all tokens along the sequence dimension.
+        encoder_in_tokens = torch.stack(encoder_in_tokens, axis=0)
+        encoder_in_pos_embed = torch.stack(encoder_in_pos_embed, axis=0)
 
         # Forward pass through the transformer modules.
-        encoder_out = self.encoder(encoder_in, pos_embed=pos_embed)
+        encoder_out = self.encoder(encoder_in_tokens, pos_embed=encoder_in_pos_embed)
         # TODO(rcadene, alexander-soare): remove call to `device` ; precompute and use buffer
         decoder_in = torch.zeros(
             (self.config.chunk_size, batch_size, self.config.dim_model),
-            dtype=pos_embed.dtype,
-            device=pos_embed.device,
+            dtype=encoder_in_pos_embed.dtype,
+            device=encoder_in_pos_embed.device,
         )
         decoder_out = self.decoder(
             decoder_in,
             encoder_out,
-            encoder_pos_embed=pos_embed,
+            encoder_pos_embed=encoder_in_pos_embed,
             decoder_pos_embed=self.decoder_pos_embed.weight.unsqueeze(1),
         )
diff --git a/tests/data/save_policy_to_safetensors/aloha_act/actions.safetensors b/tests/data/save_policy_to_safetensors/aloha_act/actions.safetensors
index c5176423..583ab588 100644
--- a/tests/data/save_policy_to_safetensors/aloha_act/actions.safetensors
+++ b/tests/data/save_policy_to_safetensors/aloha_act/actions.safetensors
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:523f220f3acbab0cd4aef8a13c77916634488b1af08a06e4e65d1aecafdc2cae
+oid sha256:28444747a9cb3876f86ae86fed72e587dbcacfccd87c5c24b8ecac30c3ce3077
 size 5104
diff --git a/tests/data/save_policy_to_safetensors/aloha_act/grad_stats.safetensors b/tests/data/save_policy_to_safetensors/aloha_act/grad_stats.safetensors
index bdecb18b..1e5a8475 100644
--- a/tests/data/save_policy_to_safetensors/aloha_act/grad_stats.safetensors
+++ b/tests/data/save_policy_to_safetensors/aloha_act/grad_stats.safetensors
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:95dd049b4386030ced4505586b874f16906f8d89f29b570201782eebcbe4f402
-size 31688
+oid sha256:a43a9ddaf8527e3344b22bd21276e1f561e83599d720933b28725b00d94823c0
+size 31672
diff --git a/tests/data/save_policy_to_safetensors/aloha_act/param_stats.safetensors b/tests/data/save_policy_to_safetensors/aloha_act/param_stats.safetensors
index 26d91924..d7e14d50 100644
--- a/tests/data/save_policy_to_safetensors/aloha_act/param_stats.safetensors
+++ b/tests/data/save_policy_to_safetensors/aloha_act/param_stats.safetensors
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:806851d60b6c492620b7269876eef9ce17756ec03da93f36b351f8aa75be0954
-size 33408
+oid sha256:093bff1fbc3bde2547bccbbefc277d02368a8d4a9100b3e4bd47c755798cad68
+size 33400
diff --git a/tests/data/save_policy_to_safetensors/aloha_act_1000_steps/actions.safetensors b/tests/data/save_policy_to_safetensors/aloha_act_1000_steps/actions.safetensors
index 1529153d..eae674a2 100644
--- a/tests/data/save_policy_to_safetensors/aloha_act_1000_steps/actions.safetensors
+++ b/tests/data/save_policy_to_safetensors/aloha_act_1000_steps/actions.safetensors
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:3f4e0e525aeb22ea94b79e26b39a87e6f2da9fbee33e493906aaf2aad9a7c1ef
+oid sha256:85bed637e90f15c64e4af01d2dbc5d9c3a370215f2c8c379494fa3acb413bc2e
 size 515400
diff --git a/tests/data/save_policy_to_safetensors/aloha_act_1000_steps/grad_stats.safetensors b/tests/data/save_policy_to_safetensors/aloha_act_1000_steps/grad_stats.safetensors
index 6a359f4e..fedfc7bc 100644
--- a/tests/data/save_policy_to_safetensors/aloha_act_1000_steps/grad_stats.safetensors
+++ b/tests/data/save_policy_to_safetensors/aloha_act_1000_steps/grad_stats.safetensors
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:6dc658a1c1616c7d1c211eb8f87cec3d44f7b67d6b3cea7a6ce12b32d74674da
-size 31688
+oid sha256:00cf8e548d7ea23aa70de79e05c39990a32a790def824f729e6c98bea31c69bc
+size 31672
diff --git a/tests/data/save_policy_to_safetensors/aloha_act_1000_steps/param_stats.safetensors b/tests/data/save_policy_to_safetensors/aloha_act_1000_steps/param_stats.safetensors
index 157c382c..87deccc9 100644
--- a/tests/data/save_policy_to_safetensors/aloha_act_1000_steps/param_stats.safetensors
+++ b/tests/data/save_policy_to_safetensors/aloha_act_1000_steps/param_stats.safetensors
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:01d993c67a9267032fe9fbeff20b4359c209464976ea503040a0a76ae213450a
-size 33408
+oid sha256:b3a4c2581f48229312a582d91f0adea8078c0c5b744c34d76723edf4731f9003
+size 33400
diff --git a/tests/data/save_policy_to_safetensors/dora_aloha_real_act_real/actions.safetensors b/tests/data/save_policy_to_safetensors/dora_aloha_real_act_real/actions.safetensors
index 2373f1ee..2dd4a9b8 100644
--- a/tests/data/save_policy_to_safetensors/dora_aloha_real_act_real/actions.safetensors
+++ b/tests/data/save_policy_to_safetensors/dora_aloha_real_act_real/actions.safetensors
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:2fff6294b94cf42d4dd1249dcc5c3b0269d6d9c697f894e61b867d7ab81a94e4
+oid sha256:aab00b0349901450adbb8e0d7d4af1f743dd88e7e19f1bcfef821de8bdcc957d
 size 5104
diff --git a/tests/data/save_policy_to_safetensors/dora_aloha_real_act_real/grad_stats.safetensors b/tests/data/save_policy_to_safetensors/dora_aloha_real_act_real/grad_stats.safetensors
index de40a20e..9b4dbdcc 100644
--- a/tests/data/save_policy_to_safetensors/dora_aloha_real_act_real/grad_stats.safetensors
+++ b/tests/data/save_policy_to_safetensors/dora_aloha_real_act_real/grad_stats.safetensors
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:4aa23e51607604a18b70fa42edbbe1af34f119d985628fc27cc1bbb0efbc8901
-size 31688
+oid sha256:de70c3055aa052f5b811ec7c2994ec6861efe645c6caee41e04a3460598500d5
+size 31672
diff --git a/tests/data/save_policy_to_safetensors/dora_aloha_real_act_real/output_dict.safetensors b/tests/data/save_policy_to_safetensors/dora_aloha_real_act_real/output_dict.safetensors
index 8602cc56..f0b5cccc 100644
--- a/tests/data/save_policy_to_safetensors/dora_aloha_real_act_real/output_dict.safetensors
+++ b/tests/data/save_policy_to_safetensors/dora_aloha_real_act_real/output_dict.safetensors
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:6fd368406c93cb562a69ff11cf7adf34a4b223507dcb2b9e9b8f44ee1036988a
+oid sha256:d4070bd1f1cd8c72bc2daf628088e42b8ef113f6df0bfd9e91be052bc90038c3
 size 68
diff --git a/tests/data/save_policy_to_safetensors/dora_aloha_real_act_real/param_stats.safetensors b/tests/data/save_policy_to_safetensors/dora_aloha_real_act_real/param_stats.safetensors
index a6612b7f..cf09e1dc 100644
--- a/tests/data/save_policy_to_safetensors/dora_aloha_real_act_real/param_stats.safetensors
+++ b/tests/data/save_policy_to_safetensors/dora_aloha_real_act_real/param_stats.safetensors
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:5663ee79a13bb70a1604b887dd21bf89d18482287442419c6cc6c5bf0e753e99
-size 34928
+oid sha256:19fdc1edf327e04132c1917024289b3d16e25a1ec2130f3df797fe07434dfbbd
+size 34920
diff --git a/tests/data/save_policy_to_safetensors/dora_aloha_real_act_real_no_state/actions.safetensors b/tests/data/save_policy_to_safetensors/dora_aloha_real_act_real_no_state/actions.safetensors
index 9f0ba883..11fa4eb8 100644
--- a/tests/data/save_policy_to_safetensors/dora_aloha_real_act_real_no_state/actions.safetensors
+++ b/tests/data/save_policy_to_safetensors/dora_aloha_real_act_real_no_state/actions.safetensors
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:fb1a45463efd860af2ca22c16c77d55a18bd96fef080ae77978845a2f22ef716
+oid sha256:dcd8ebaefd3ff267eb24654135d1efb179d713e6cfe6917f793a3e2483efd501
 size 5104
diff --git a/tests/data/save_policy_to_safetensors/dora_aloha_real_act_real_no_state/grad_stats.safetensors b/tests/data/save_policy_to_safetensors/dora_aloha_real_act_real_no_state/grad_stats.safetensors
index 2b01b94c..d0b98443 100644
--- a/tests/data/save_policy_to_safetensors/dora_aloha_real_act_real_no_state/grad_stats.safetensors
+++ b/tests/data/save_policy_to_safetensors/dora_aloha_real_act_real_no_state/grad_stats.safetensors
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:aa5a43e22f01d8e2f8d19f31753608794f1edbd74aaf71660091ab80ea58dc9b
-size 30808
+oid sha256:107e98647ed1081745476b250df8848c0c430b2aff51d614f6b2db95684467aa
+size 30800
diff --git a/tests/data/save_policy_to_safetensors/dora_aloha_real_act_real_no_state/param_stats.safetensors b/tests/data/save_policy_to_safetensors/dora_aloha_real_act_real_no_state/param_stats.safetensors
index 335d2a55..e00dec82 100644
--- a/tests/data/save_policy_to_safetensors/dora_aloha_real_act_real_no_state/param_stats.safetensors
+++ b/tests/data/save_policy_to_safetensors/dora_aloha_real_act_real_no_state/param_stats.safetensors
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:54d1f75cf67a7b1d7a7c6865ecb9b1cc86a2f032d1890245f8996789ab6e0df6
-size 33608
+oid sha256:adbae737c987f912509d3fba06f332bda700bfc2c6d83a09c969e9d7a3ca75f7
+size 33600
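Usage note (not part of the patch): a minimal sketch of the behaviour this diff enables, assuming the default values of the remaining `ACTConfig` fields and purely illustrative feature dimensions. With this change, a config whose inputs contain `observation.environment_state` but no `observation.image*` key passes validation, while a config with neither now fails fast in `__post_init__`.

```python
from lerobot.common.policies.act.configuration_act import ACTConfig

# State-only setup: no camera keys, only proprioceptive and environment state.
# The 14/16-dimensional shapes are placeholders, not values taken from the patch.
config = ACTConfig(
    input_shapes={
        "observation.state": [14],
        "observation.environment_state": [16],
    },
    output_shapes={"action": [14]},
)

# Providing neither an image key nor the environment state is rejected at construction time.
try:
    ACTConfig(input_shapes={"observation.state": [14]}, output_shapes={"action": [14]})
except ValueError as err:
    print(err)  # You must provide at least one image or the environment state among the inputs.
```

Inside `ACT.__init__`, the same keys decide which modules are built: the ResNet backbone, the image feature projection, and the 2D sinusoidal positional embedding are only instantiated when an image key is present, and the learned 1D positional embedding is sized to cover the latent token plus the optional robot-state and environment-state tokens.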