From 71ec76fe2ac36c6b135d4548985eaf371a8a3b9c Mon Sep 17 00:00:00 2001
From: jayLEE0301
Date: Fri, 24 May 2024 15:58:46 -0400
Subject: [PATCH] add explanations of each part of vq-bet

---
 .../policies/vqbet/configuration_vqbet.py     | 36 ++++++------
 .../common/policies/vqbet/modeling_vqbet.py   | 56 ++++++++++++++++++-
 2 files changed, 73 insertions(+), 19 deletions(-)

diff --git a/lerobot/common/policies/vqbet/configuration_vqbet.py b/lerobot/common/policies/vqbet/configuration_vqbet.py
index 7533a819..477bb789 100644
--- a/lerobot/common/policies/vqbet/configuration_vqbet.py
+++ b/lerobot/common/policies/vqbet/configuration_vqbet.py
@@ -13,8 +13,8 @@ class VQBeTConfig:
     Args:
         n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
             current step and additional steps going back).
-        n_action_pred_token: TODO(jayLEE0301)
-        n_action_pred_chunk: TODO(jayLEE0301)
+        n_action_pred_token: Number of future action tokens that VQ-BeT predicts.
+        n_action_pred_chunk: Action chunk size of each action prediction token.
         input_shapes: A dictionary defining the shapes of the input data for the policy.
             The key represents the input data name, and the value is a list indicating the dimensions
             of the corresponding data. For example, "observation.image" refers to an input from
@@ -40,21 +40,23 @@ class VQBeTConfig:
         use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
             The group sizes are set to be about 16 (to be precise, feature_dim // 16).
         spatial_softmax_num_keypoints: Number of keypoints for SpatialSoftmax.
-        discretize_step: TODO(jayLEE0301)
-        vqvae_groups: TODO(jayLEE0301)
-        vqvae_n_embed: TODO(jayLEE0301)
-        vqvae_embedding_dim: TODO(jayLEE0301)
-        vqvae_enc_hidden_dim: TODO(jayLEE0301)
-        gpt_block_size: TODO(jayLEE0301)
-        gpt_input_dim: TODO(jayLEE0301)
-        gpt_output_dim: TODO(jayLEE0301)
-        gpt_n_layer: TODO(jayLEE0301)
-        gpt_n_head: TODO(jayLEE0301)
-        gpt_hidden_dim: TODO(jayLEE0301)
-        dropout: TODO(jayLEE0301)
-        mlp_hidden_dim: TODO(jayLEE0301)
-        offset_loss_weight: TODO(jayLEE0301)
-        secondary_code_loss_weight: TODO(jayLEE0301)
+        discretize_step: Number of optimization steps for training the Residual VQ.
+        vqvae_groups: Number of layers in the Residual VQ.
+        vqvae_n_embed: Number of embedding vectors in the RVQ dictionary (per layer).
+        vqvae_embedding_dim: Dimension of each embedding vector in the RVQ dictionary.
+        vqvae_enc_hidden_dim: Size of the hidden dimensions of the encoder / decoder parts of the Residual VQ-VAE.
+        gpt_block_size: Maximum block size of minGPT (should be larger than the number of input tokens).
+        gpt_input_dim: Size of the input dimension of GPT. This is also used as the dimension of observation features.
+        gpt_output_dim: Size of the output dimension of GPT. This is also used as the input dimension of the offset / bin prediction heads.
+        gpt_n_layer: Number of layers of GPT.
+        gpt_n_head: Number of attention heads of GPT.
+        gpt_hidden_dim: Size of the hidden dimensions of GPT.
+        gpt_num_obs_mode: Number of different observation modes (e.g., the PushT env has {state, image observation}, thus 2).
+        dropout: Dropout rate for GPT.
+        mlp_hidden_dim: Size of the hidden dimensions of the offset / bin prediction heads of VQ-BeT.
+        offset_loss_weight: A constant that the offset loss is multiplied by.
+        secondary_code_loss_weight: A constant that the secondary code loss is multiplied by.
+        bet_softmax_temperature: Softmax temperature used when sampling codes during VQ-BeT rollout.
     """

     # Inputs / output structure.
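Reviewer note on the new RVQ docstrings above: a minimal sketch of how `vqvae_groups`, `vqvae_n_embed`, and `vqvae_embedding_dim` map onto a residual vector quantizer. It uses the open-source `vector-quantize-pytorch` package purely for illustration (not the classes in this repo), and the numeric values are made up, not the lerobot defaults:

    import torch
    from vector_quantize_pytorch import ResidualVQ

    # Illustrative values (hypothetical, not the lerobot defaults).
    vqvae_groups = 2           # number of RVQ layers
    vqvae_n_embed = 16         # codebook size per layer
    vqvae_embedding_dim = 256  # dimension of each code vector

    rvq = ResidualVQ(
        dim=vqvae_embedding_dim,
        num_quantizers=vqvae_groups,
        codebook_size=vqvae_n_embed,
    )

    # A batch of 8 latent action-chunk embeddings (stand-ins for RVQ encoder outputs).
    z = torch.randn(8, 1, vqvae_embedding_dim)
    quantized, indices, commit_loss = rvq(z)
    print(quantized.shape)  # torch.Size([8, 1, 256])
    print(indices.shape)    # torch.Size([8, 1, 2]) -- one code index per RVQ layer

The first layer's code is the primary code; each later layer quantizes the residual left by the previous one, which is why the config carries a separate `secondary_code_loss_weight` for the non-primary codes.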
diff --git a/lerobot/common/policies/vqbet/modeling_vqbet.py b/lerobot/common/policies/vqbet/modeling_vqbet.py
index 817926ec..8ea6353c 100644
--- a/lerobot/common/policies/vqbet/modeling_vqbet.py
+++ b/lerobot/common/policies/vqbet/modeling_vqbet.py
@@ -130,8 +130,60 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):


 class VQBeTModel(nn.Module):
-    """
-    TODO(jayLEE0301)
+    """VQ-BeT: The underlying neural network for the VQ-BeT policy.
+
+    Note: In this code we use the terms `rgb_encoder`, `policy`, and `action_head`. The meanings are as follows.
+        - The `rgb_encoder` processes RGB image observations into one-dimensional embedding vectors.
+        - The `policy` is a minGPT architecture that takes observation sequences and action query tokens
+          and generates `features`.
+        - These `features` pass through the `action_head` (the code prediction and offset prediction heads),
+          which finally generates a prediction for the action chunks.
+
+        -------------------------------** legend **-------------------------------
+        │   n = n_obs_steps, p = n_action_pred_token, c = n_action_pred_chunk    │
+        │   o_{t} : visual observation at timestep {t}                           │
+        │   s_{t} : state observation at timestep {t}                            │
+        │   a_{t} : action at timestep {t}                                       │
+        │   A_Q : action_query_token                                             │
+        ---------------------------------------------------------------------------
+
+
+    Phase 1. Discretize actions using the Residual VQ (for config.discretize_step steps)
+
+
+        ┌─────────────────┐        ┌─────────────────┐        ┌─────────────────┐
+        │                 │        │                 │        │                 │
+        │   RVQ encoder   │   ─►   │    Residual     │   ─►   │   RVQ Decoder   │
+        │ (a_{t}~a_{t+p}) │        │ Code Quantizer  │        │                 │
+        │                 │        │                 │        │                 │
+        └─────────────────┘        └─────────────────┘        └─────────────────┘
+
+    Phase 2. Train the GPT and the prediction heads to predict RVQ codes and offsets
+             (the Residual VQ from Phase 1 is kept frozen).
+
+
+          o_{t-n+1}        o_{t-n+2}        ...        o_{t}
+             │                │                          │
+             │ s_{t-n+1}      │ s_{t-n+2}   ...          │ s_{t}             p
+             │    │           │    │                     │    │      ┌───────┴───────┐
+             │    │    A_Q    │    │    A_Q ...          │    │     A_Q     ...     A_Q
+             │    │     │     │    │     │               │    │      │               │
+         ┌───▼────▼─────▼─────▼────▼─────▼───────────────▼────▼──────▼───────────────▼───┐
+         │                                                                               │
+         │                                    GPT                                        │        =>  policy
+         │                                                                               │
+         └──────────────▼──────────────▼─────────────────────────────▼───────────────▼───┘
+                        │              │                             │               │
+                    ┌───┴───┐      ┌───┴───┐                     ┌───┴───┐       ┌───┴───┐
+                  code    offset code    offset                code    offset  code    offset
+                    ▼       │      ▼       │                     ▼       │       ▼       │        =>  action_head
+               RVQ Decoder  │ RVQ Decoder  │                RVQ Decoder  │  RVQ Decoder  │
+                    └── + ──┘      └── + ──┘                     └── + ──┘       └── + ──┘
+                        ▼              ▼                             ▼               ▼
+                   action chunk   action chunk                  action chunk    action chunk
+                    a_{t-n+1} ~    a_{t-n+2} ~                   a_{t} ~    ...  a_{t+p-1} ~
+                    a_{t-n+c}      a_{t-n+c+1}                   a_{t+c-1}       a_{t+p+c-1}
+
+                                                                     ▼
+                                                       ONLY this chunk is used in rollout!
     """

     def __init__(self, config: VQBeTConfig):
         super().__init__()
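Reviewer note to make the Phase 2 diagram concrete: a rough, self-contained sketch of the token flow, not the lerobot implementation. A stock `nn.TransformerEncoder` stands in for minGPT, random tensors stand in for encoded observations, and the A_Q tokens are appended rather than interleaved as drawn above; all sizes are illustrative:

    import torch
    import torch.nn as nn

    # Hypothetical sizes: n, p, c from the legend, plus an action dimension.
    n_obs_steps, n_pred_token, chunk_size, action_dim = 5, 7, 1, 2
    gpt_hidden_dim, vqvae_n_embed = 512, 16

    gpt = nn.TransformerEncoder(  # stand-in for the minGPT `policy`
        nn.TransformerEncoderLayer(d_model=gpt_hidden_dim, nhead=8, batch_first=True),
        num_layers=2,
    )
    bin_head = nn.Linear(gpt_hidden_dim, vqvae_n_embed)  # code logits (one RVQ layer shown)
    offset_head = nn.Linear(gpt_hidden_dim, chunk_size * action_dim)

    obs_tokens = torch.randn(1, n_obs_steps, gpt_hidden_dim)       # encoded o_{t-n+1} .. o_{t}
    action_queries = torch.randn(1, n_pred_token, gpt_hidden_dim)  # learnable A_Q tokens
    tokens = torch.cat([obs_tokens, action_queries], dim=1)

    features = gpt(tokens)[:, -n_pred_token:]  # features at the A_Q positions
    code_logits = bin_head(features)           # sample / argmax a code, then RVQ-decode it
    offsets = offset_head(features).view(1, n_pred_token, chunk_size, action_dim)
    # action chunk = RVQ_decoder(code) + offset; during rollout only the chunk
    # aligned with the current timestep t is executed.

In the real model the GPT is causal, so with interleaved tokens each A_Q attends only to observations up to its own timestep; the sketch above skips that detail.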