Address Alexander comments

commit eeb215540b (parent b60f4efa7c)
Author: Remi Cadene
Date: 2024-05-31 10:08:11 +00:00

5 changed files with 24 additions and 25 deletions


@@ -138,8 +138,8 @@ available_policies = [

 # keys and values refer to yaml files
 available_policies_per_env = {
-    "aloha": ["act"],
-    "pusht": ["diffusion"],
+    "aloha": ["act", "diffusion"],
+    "pusht": ["act", "diffusion"],
     "xarm": ["tdmpc"],
     "dora_aloha_real": ["act_real"],
 }
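For context, a mapping like this is typically consumed by a parametrized test. A minimal sketch of that pattern, assuming pytest; `test_env_policy_pair` is a hypothetical name, not the repo's actual test (the real one, `test_policy`, appears in the last file below):

import pytest

available_policies_per_env = {
    "aloha": ["act", "diffusion"],
    "pusht": ["act", "diffusion"],
    "xarm": ["tdmpc"],
    "dora_aloha_real": ["act_real"],
}

@pytest.mark.parametrize(
    "env_name,policy_name",
    [(env, p) for env, policies in available_policies_per_env.items() for p in policies],
)
def test_env_policy_pair(env_name, policy_name):
    # A real test would build the env/policy pair and run a short rollout.
    assert env_name and policy_name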


@@ -26,11 +26,9 @@ class ACTConfig:
     Those are: `input_shapes` and `output_shapes`.

     Notes on the inputs and outputs:
-        - "observation.state" is required as an input key.
         - At least one key starting with "observation.image" is required as an input.
-        - If there are multiple keys beginning with "observation.image" they are treated as multiple camera
-          views.
-          Right now we only support all images having the same shape.
+        - If there are multiple keys beginning with "observation.images." they are treated as multiple camera
+          views. Right now we only support all images having the same shape.
         - "action" is required as an output key.

     Args:
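To make the docstring concrete, a conforming `input_shapes`/`output_shapes` pair could look like the following. The dimensions are illustrative only, not taken from a shipped yaml config:

# Illustrative shapes; real values come from the env's yaml config.
input_shapes = {
    "observation.images.top": [3, 480, 640],    # at least one image key is required
    "observation.images.wrist": [3, 480, 640],  # extra views must share the same shape
    "observation.state": [14],                  # optional as of this change
}
output_shapes = {
    "action": [14],  # required
}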


@@ -200,13 +200,12 @@ class ACT(nn.Module):
         self.config = config
         # BERT style VAE encoder with input tokens [cls, robot_state, *action_sequence].
         # The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
-        self.has_state = "observation.state" in config.input_shapes
-        self.latent_dim = config.latent_dim
+        self.use_input_state = "observation.state" in config.input_shapes
         if self.config.use_vae:
             self.vae_encoder = ACTEncoder(config)
             self.vae_encoder_cls_embed = nn.Embedding(1, config.dim_model)
             # Projection layer for joint-space configuration to hidden dimension.
-            if self.has_state:
+            if self.use_input_state:
                 self.vae_encoder_robot_state_input_proj = nn.Linear(
                     config.input_shapes["observation.state"][0], config.dim_model
                 )
@@ -215,10 +214,12 @@ class ACT(nn.Module):
                 config.output_shapes["action"][0], config.dim_model
             )
             # Projection layer from the VAE encoder's output to the latent distribution's parameter space.
-            self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, self.latent_dim * 2)
+            self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, config.latent_dim * 2)
             # Fixed sinusoidal positional embedding for the input to the VAE encoder. Unsqueeze for batch
             # dimension.
-            num_input_token_encoder = 1 + 1 + config.chunk_size if self.has_state else 1 + config.chunk_size
+            num_input_token_encoder = 1 + config.chunk_size
+            if self.use_input_state:
+                num_input_token_encoder += 1
             self.register_buffer(
                 "vae_encoder_pos_enc",
                 create_sinusoidal_pos_embedding(num_input_token_encoder, config.dim_model).unsqueeze(0),
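The rewritten count reads: one cls token, plus one token per action in the chunk, plus an optional robot-state token. A standalone sketch of the same arithmetic (chunk_size value chosen only for illustration):

# Token layout for the VAE encoder input: [cls, (robot_state), *action_chunk].
def num_vae_encoder_tokens(chunk_size: int, use_input_state: bool) -> int:
    num_tokens = 1 + chunk_size  # one cls token + one token per action
    if use_input_state:
        num_tokens += 1  # one extra token for the proprioceptive state
    return num_tokens

assert num_vae_encoder_tokens(chunk_size=100, use_input_state=True) == 102
assert num_vae_encoder_tokens(chunk_size=100, use_input_state=False) == 101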
@@ -241,16 +242,16 @@ class ACT(nn.Module):
         # Transformer encoder input projections. The tokens will be structured like
         # [latent, robot_state, image_feature_map_pixels].
-        if self.has_state:
+        if self.use_input_state:
             self.encoder_robot_state_input_proj = nn.Linear(
                 config.input_shapes["observation.state"][0], config.dim_model
             )
-        self.encoder_latent_input_proj = nn.Linear(self.latent_dim, config.dim_model)
+        self.encoder_latent_input_proj = nn.Linear(config.latent_dim, config.dim_model)
         self.encoder_img_feat_input_proj = nn.Conv2d(
             backbone_model.fc.in_features, config.dim_model, kernel_size=1
         )

         # Transformer encoder positional embeddings.
-        num_input_token_decoder = 2 if self.has_state else 1
+        num_input_token_decoder = 2 if self.use_input_state else 1
         self.encoder_robot_and_latent_pos_embed = nn.Embedding(num_input_token_decoder, config.dim_model)
         self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)
@@ -298,12 +299,12 @@ class ACT(nn.Module):
             cls_embed = einops.repeat(
                 self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
             )  # (B, 1, D)
-            if self.has_state:
+            if self.use_input_state:
                 robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"])
                 robot_state_embed = robot_state_embed.unsqueeze(1)  # (B, 1, D)
             action_embed = self.vae_encoder_action_input_proj(batch["action"])  # (B, S, D)

-            if self.has_state:
+            if self.use_input_state:
                 vae_encoder_input = [cls_embed, robot_state_embed, action_embed]  # (B, S+2, D)
             else:
                 vae_encoder_input = [cls_embed, action_embed]
@@ -318,9 +319,9 @@ class ACT(nn.Module):
                 vae_encoder_input.permute(1, 0, 2), pos_embed=pos_embed.permute(1, 0, 2)
             )[0]  # select the class token, with shape (B, D)
             latent_pdf_params = self.vae_encoder_latent_output_proj(cls_token_out)
-            mu = latent_pdf_params[:, : self.latent_dim]
+            mu = latent_pdf_params[:, : self.config.latent_dim]
             # This is 2log(sigma). Done this way to match the original implementation.
-            log_sigma_x2 = latent_pdf_params[:, self.latent_dim :]
+            log_sigma_x2 = latent_pdf_params[:, self.config.latent_dim :]

             # Sample the latent with the reparameterization trick.
             latent_sample = mu + log_sigma_x2.div(2).exp() * torch.randn_like(mu)
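For readers unfamiliar with the pattern: the projection output is split into means and 2·log(sigma), and the latent is sampled with the reparameterization trick. A self-contained sketch with made-up dimensions:

import torch

batch_size, latent_dim = 8, 32
# Stand-in for the VAE encoder's projection output: [*means, *log_variances].
latent_pdf_params = torch.randn(batch_size, latent_dim * 2)

mu = latent_pdf_params[:, :latent_dim]
log_sigma_x2 = latent_pdf_params[:, latent_dim:]  # this is 2*log(sigma)

# Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I),
# where exp(log_sigma_x2 / 2) recovers sigma.
latent_sample = mu + log_sigma_x2.div(2).exp() * torch.randn_like(mu)
assert latent_sample.shape == (batch_size, latent_dim)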
@@ -328,7 +329,7 @@ class ACT(nn.Module):
             # When not using the VAE encoder, we set the latent to be all zeros.
             mu = log_sigma_x2 = None
             # TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
-            latent_sample = torch.zeros([batch_size, self.latent_dim], dtype=torch.float32).to(
+            latent_sample = torch.zeros([batch_size, self.config.latent_dim], dtype=torch.float32).to(
                 batch["observation.state"].device
             )
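One possible shape of the fix the TODO hints at, sketched as a hypothetical standalone module rather than the committed code: registering the zeros as a buffer lets them follow the model's device, removing the per-forward `.to` call.

import torch
from torch import nn

class LatentZeros(nn.Module):
    """Hypothetical sketch for the TODO: keep the zero latent in a buffer so it
    moves with the module's device and no per-forward `.to()` is needed."""

    def __init__(self, latent_dim: int):
        super().__init__()
        self.register_buffer("zero_latent", torch.zeros(1, latent_dim, dtype=torch.float32))

    def forward(self, batch_size: int) -> torch.Tensor:
        # expand() returns a view; no allocation or device transfer per call.
        return self.zero_latent.expand(batch_size, -1)

# e.g. LatentZeros(32).cuda()(8) already lives on the GPU, shape (8, 32).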
@@ -350,12 +351,12 @@ class ACT(nn.Module):
         cam_pos_embed = torch.cat(all_cam_pos_embeds, axis=-1)

         # Get positional embeddings for robot state and latent.
-        if self.has_state:
+        if self.use_input_state:
             robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"])  # (B, C)
         latent_embed = self.encoder_latent_input_proj(latent_sample)  # (B, C)

         # Stack encoder input and positional embeddings moving to (S, B, C).
-        encoder_in_feats = [latent_embed, robot_state_embed] if self.has_state else [latent_embed]
+        encoder_in_feats = [latent_embed, robot_state_embed] if self.use_input_state else [latent_embed]
         encoder_in = torch.cat(
             [
                 torch.stack(encoder_in_feats, axis=0),
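The stacking moves a list of (B, C) feature vectors into the (S, B, C) layout that torch's transformer modules expect. A toy sketch of the shape change, with illustrative sizes:

import torch

batch_size, dim_model = 8, 512
latent_embed = torch.randn(batch_size, dim_model)       # (B, C)
robot_state_embed = torch.randn(batch_size, dim_model)  # (B, C)

encoder_in_feats = [latent_embed, robot_state_embed]
stacked = torch.stack(encoder_in_feats, dim=0)  # (S, B, C), here S == 2 tokens
assert stacked.shape == (2, batch_size, dim_model)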


@@ -28,10 +28,7 @@ class DiffusionConfig:
     Notes on the inputs and outputs:
         - "observation.state" is required as an input key.
-        - At least one key starting with "observation.image" is required as an input.
-        - If there are multiple keys beginning with "observation.image" they are treated as multiple camera
-          views.
-          Right now we only support all images having the same shape.
+        - A key starting with "observation.image" is required as an input.
         - "action" is required as an output key.

     Args:
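Again for concreteness, a minimal example of shapes satisfying this docstring; the values are illustrative (a pusht-like single-camera setup), not from a shipped config:

# Illustrative shapes; real values come from the env's yaml config.
input_shapes = {
    "observation.image": [3, 96, 96],  # a single camera view
    "observation.state": [2],          # required for this policy
}
output_shapes = {
    "action": [2],  # required
}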


@@ -86,6 +86,9 @@ def test_policy(env_name, policy_name, extra_overrides):
     - Updating the policy.
     - Using the policy to select actions at inference time.
     - Test the action can be applied to the policy
+
+    Note: We test various combinations of policy and dataset. The combinations are by no means exhaustive,
+    and for now we add tests as we see fit.
     """
     cfg = init_hydra_config(
         DEFAULT_CONFIG_PATH,