From eeb215540b7c8007ae6372b914f296ccc7457c79 Mon Sep 17 00:00:00 2001
From: Remi Cadene
Date: Fri, 31 May 2024 10:08:11 +0000
Subject: [PATCH] Address Alexander comments

---
 lerobot/__init__.py                                |  4 +--
 .../common/policies/act/configuration_act.py       |  6 ++--
 lerobot/common/policies/act/modeling_act.py        | 31 ++++++++++---------
 .../diffusion/configuration_diffusion.py           |  5 +--
 tests/test_policies.py                             |  3 ++
 5 files changed, 24 insertions(+), 25 deletions(-)

diff --git a/lerobot/__init__.py b/lerobot/__init__.py
index a5a90fb4..8e554998 100644
--- a/lerobot/__init__.py
+++ b/lerobot/__init__.py
@@ -138,8 +138,8 @@ available_policies = [
 
 # keys and values refer to yaml files
 available_policies_per_env = {
-    "aloha": ["act"],
-    "pusht": ["diffusion"],
+    "aloha": ["act", "diffusion"],
+    "pusht": ["act", "diffusion"],
     "xarm": ["tdmpc"],
     "dora_aloha_real": ["act_real"],
 }
diff --git a/lerobot/common/policies/act/configuration_act.py b/lerobot/common/policies/act/configuration_act.py
index 82bc6d8e..d33e8bd9 100644
--- a/lerobot/common/policies/act/configuration_act.py
+++ b/lerobot/common/policies/act/configuration_act.py
@@ -26,11 +26,9 @@ class ACTConfig:
     Those are: `input_shapes` and `output_shapes`.
 
     Notes on the inputs and outputs:
-        - "observation.state" is required as an input key.
         - At least one key starting with "observation.image" is required as an input.
-        - If there are multiple keys beginning with "observation.image" they are treated as multiple camera
-          views.
-          Right now we only support all images having the same shape.
+        - If there are multiple keys beginning with "observation.images." they are treated as multiple camera
+          views. Right now we only support all images having the same shape.
         - "action" is required as an output key.
 
     Args:
diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py
index 141dc862..bef59bec 100644
--- a/lerobot/common/policies/act/modeling_act.py
+++ b/lerobot/common/policies/act/modeling_act.py
@@ -200,13 +200,12 @@ class ACT(nn.Module):
         self.config = config
         # BERT style VAE encoder with input tokens [cls, robot_state, *action_sequence].
         # The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
-        self.has_state = "observation.state" in config.input_shapes
-        self.latent_dim = config.latent_dim
+        self.use_input_state = "observation.state" in config.input_shapes
         if self.config.use_vae:
             self.vae_encoder = ACTEncoder(config)
             self.vae_encoder_cls_embed = nn.Embedding(1, config.dim_model)
             # Projection layer for joint-space configuration to hidden dimension.
-            if self.has_state:
+            if self.use_input_state:
                 self.vae_encoder_robot_state_input_proj = nn.Linear(
                     config.input_shapes["observation.state"][0], config.dim_model
                 )
@@ -215,10 +214,12 @@
                 config.output_shapes["action"][0], config.dim_model
             )
             # Projection layer from the VAE encoder's output to the latent distribution's parameter space.
-            self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, self.latent_dim * 2)
+            self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, config.latent_dim * 2)
             # Fixed sinusoidal positional embedding for the input to the VAE encoder. Unsqueeze for batch
             # dimension.
-            num_input_token_encoder = 1 + 1 + config.chunk_size if self.has_state else 1 + config.chunk_size
+            num_input_token_encoder = 1 + config.chunk_size
+            if self.use_input_state:
+                num_input_token_encoder += 1
             self.register_buffer(
                 "vae_encoder_pos_enc",
                 create_sinusoidal_pos_embedding(num_input_token_encoder, config.dim_model).unsqueeze(0),
             )
@@ -241,16 +242,16 @@
 
         # Transformer encoder input projections. The tokens will be structured like
         # [latent, robot_state, image_feature_map_pixels].
-        if self.has_state:
+        if self.use_input_state:
             self.encoder_robot_state_input_proj = nn.Linear(
                 config.input_shapes["observation.state"][0], config.dim_model
             )
-        self.encoder_latent_input_proj = nn.Linear(self.latent_dim, config.dim_model)
+        self.encoder_latent_input_proj = nn.Linear(config.latent_dim, config.dim_model)
         self.encoder_img_feat_input_proj = nn.Conv2d(
             backbone_model.fc.in_features, config.dim_model, kernel_size=1
         )
         # Transformer encoder positional embeddings.
-        num_input_token_decoder = 2 if self.has_state else 1
+        num_input_token_decoder = 2 if self.use_input_state else 1
         self.encoder_robot_and_latent_pos_embed = nn.Embedding(num_input_token_decoder, config.dim_model)
         self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)
 
@@ -298,12 +299,12 @@
             cls_embed = einops.repeat(
                 self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
             )  # (B, 1, D)
-            if self.has_state:
+            if self.use_input_state:
                 robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"])
                 robot_state_embed = robot_state_embed.unsqueeze(1)  # (B, 1, D)
             action_embed = self.vae_encoder_action_input_proj(batch["action"])  # (B, S, D)
-            if self.has_state:
+            if self.use_input_state:
                 vae_encoder_input = [cls_embed, robot_state_embed, action_embed]  # (B, S+2, D)
             else:
                 vae_encoder_input = [cls_embed, action_embed]
@@ -318,9 +319,9 @@
                 vae_encoder_input.permute(1, 0, 2), pos_embed=pos_embed.permute(1, 0, 2)
             )[0]  # select the class token, with shape (B, D)
             latent_pdf_params = self.vae_encoder_latent_output_proj(cls_token_out)
-            mu = latent_pdf_params[:, : self.latent_dim]
+            mu = latent_pdf_params[:, : self.config.latent_dim]
             # This is 2log(sigma). Done this way to match the original implementation.
-            log_sigma_x2 = latent_pdf_params[:, self.latent_dim :]
+            log_sigma_x2 = latent_pdf_params[:, self.config.latent_dim :]
             # Sample the latent with the reparameterization trick.
             latent_sample = mu + log_sigma_x2.div(2).exp() * torch.randn_like(mu)
@@ -328,7 +329,7 @@
             # When not using the VAE encoder, we set the latent to be all zeros.
             mu = log_sigma_x2 = None
             # TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
-            latent_sample = torch.zeros([batch_size, self.latent_dim], dtype=torch.float32).to(
+            latent_sample = torch.zeros([batch_size, self.config.latent_dim], dtype=torch.float32).to(
                 batch["observation.state"].device
             )
@@ -350,12 +351,12 @@
             cam_pos_embed = torch.cat(all_cam_pos_embeds, axis=-1)
 
         # Get positional embeddings for robot state and latent.
-        if self.has_state:
+        if self.use_input_state:
             robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"])  # (B, C)
         latent_embed = self.encoder_latent_input_proj(latent_sample)  # (B, C)
 
         # Stack encoder input and positional embeddings moving to (S, B, C).
-        encoder_in_feats = [latent_embed, robot_state_embed] if self.has_state else [latent_embed]
+        encoder_in_feats = [latent_embed, robot_state_embed] if self.use_input_state else [latent_embed]
         encoder_in = torch.cat(
             [
                 torch.stack(encoder_in_feats, axis=0),
diff --git a/lerobot/common/policies/diffusion/configuration_diffusion.py b/lerobot/common/policies/diffusion/configuration_diffusion.py
index 48783d89..59ed1656 100644
--- a/lerobot/common/policies/diffusion/configuration_diffusion.py
+++ b/lerobot/common/policies/diffusion/configuration_diffusion.py
@@ -28,10 +28,7 @@ class DiffusionConfig:
 
     Notes on the inputs and outputs:
         - "observation.state" is required as an input key.
-        - At least one key starting with "observation.image is required as an input.
-        - If there are multiple keys beginning with "observation.image" they are treated as multiple camera
-          views.
-          Right now we only support all images having the same shape.
+        - A key starting with "observation.image" is required as an input.
         - "action" is required as an output key.
 
     Args:
diff --git a/tests/test_policies.py b/tests/test_policies.py
index 6378e254..22d9c294 100644
--- a/tests/test_policies.py
+++ b/tests/test_policies.py
@@ -86,6 +86,9 @@ def test_policy(env_name, policy_name, extra_overrides):
     - Updating the policy.
     - Using the policy to select actions at inference time.
     - Test the action can be applied to the policy
+
+    Note: We test various combinations of policy and dataset. The combinations are by no means exhaustive,
+    and for now we add tests as we see fit.
     """
     cfg = init_hydra_config(
         DEFAULT_CONFIG_PATH,
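
For readers unfamiliar with the VAE head these hunks touch: `vae_encoder_latent_output_proj`
maps the encoder's class token to `2 * latent_dim` values, which are split into means and
`2*log(sigma)`, and a latent is then drawn with the reparameterization trick. Below is a
minimal standalone sketch of just that step; the `batch_size` and `latent_dim` values and the
random stand-in for the projection output are illustrative, not lerobot defaults.

    import torch

    batch_size, latent_dim = 8, 32

    # Stand-in for the projected class token; in the model this comes from
    # vae_encoder_latent_output_proj, an nn.Linear(dim_model, latent_dim * 2).
    latent_pdf_params = torch.randn(batch_size, latent_dim * 2)

    mu = latent_pdf_params[:, :latent_dim]            # means
    log_sigma_x2 = latent_pdf_params[:, latent_dim:]  # 2 * log(sigma)

    # Reparameterization trick: sigma = exp(log_sigma_x2 / 2), so the sample is a
    # differentiable function of mu and log_sigma_x2 plus independent noise.
    latent_sample = mu + log_sigma_x2.div(2).exp() * torch.randn_like(mu)
    assert latent_sample.shape == (batch_size, latent_dim)

Parameterizing `2*log(sigma)` rather than `sigma` keeps the scale strictly positive after the
`exp`, and, as the comment in the hunk notes, matches the original ACT implementation.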