add explanations of each part of vq-bet

jayLEE0301 2024-05-24 15:58:46 -04:00
parent d71db341bc
commit 71ec76fe2a
2 changed files with 73 additions and 19 deletions


@@ -13,8 +13,8 @@ class VQBeTConfig:
     Args:
         n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
             current step and additional steps going back).
-        n_action_pred_token: TODO(jayLEE0301)
-        n_action_pred_chunk: TODO(jayLEE0301)
+        n_action_pred_token: Number of future tokens that VQ-BeT predicts.
+        n_action_pred_chunk: Action chunk size of each action prediction token.
         input_shapes: A dictionary defining the shapes of the input data for the policy.
             The key represents the input data name, and the value is a list indicating the dimensions
             of the corresponding data. For example, "observation.image" refers to an input from
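As a concrete reading of `n_obs_steps`, here is a minimal sketch (illustrative only, not part of this commit) of the sliding observation window the policy consumes; each of the `n_action_pred_token` prediction tokens then maps to a chunk of `n_action_pred_chunk` consecutive actions, as the Phase 2 diagram in the second file shows.

    from collections import deque

    # Illustrative sketch (not from this commit): `n_obs_steps` is just the
    # length of a sliding window over the most recent environment steps.
    n_obs_steps = 5
    window = deque(maxlen=n_obs_steps)

    for t, obs in enumerate(["o0", "o1", "o2", "o3", "o4", "o5", "o6"]):
        window.append(obs)

    print(list(window))  # ['o2', 'o3', 'o4', 'o5', 'o6'], i.e. [o_{t-n+1}, ..., o_{t}]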
@@ -40,21 +40,23 @@ class VQBeTConfig:
         use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
             The group sizes are set to be about 16 (to be precise, feature_dim // 16).
         spatial_softmax_num_keypoints: Number of keypoints for SpatialSoftmax.
-        discretize_step: TODO(jayLEE0301)
-        vqvae_groups: TODO(jayLEE0301)
-        vqvae_n_embed: TODO(jayLEE0301)
-        vqvae_embedding_dim: TODO(jayLEE0301)
-        vqvae_enc_hidden_dim: TODO(jayLEE0301)
-        gpt_block_size: TODO(jayLEE0301)
-        gpt_input_dim: TODO(jayLEE0301)
-        gpt_output_dim: TODO(jayLEE0301)
-        gpt_n_layer: TODO(jayLEE0301)
-        gpt_n_head: TODO(jayLEE0301)
-        gpt_hidden_dim: TODO(jayLEE0301)
-        dropout: TODO(jayLEE0301)
-        mlp_hidden_dim: TODO(jayLEE0301)
-        offset_loss_weight: TODO(jayLEE0301)
-        secondary_code_loss_weight: TODO(jayLEE0301)
+        discretize_step: Number of optimization steps for training the Residual VQ.
+        vqvae_groups: Number of layers in the Residual VQ.
+        vqvae_n_embed: Number of embedding vectors in the RVQ dictionary (per layer).
+        vqvae_embedding_dim: Dimension of each embedding vector in the RVQ dictionary.
+        vqvae_enc_hidden_dim: Size of the hidden dimensions of the encoder / decoder parts of the Residual VQ-VAE.
+        gpt_block_size: Maximum block size of minGPT (should be larger than the number of input tokens).
+        gpt_input_dim: Size of the input dimension of the GPT. This is also used as the dimension of observation features.
+        gpt_output_dim: Size of the output dimension of the GPT. This is also used as the input dimension of the offset / bin prediction heads.
+        gpt_n_layer: Number of layers of the GPT.
+        gpt_n_head: Number of attention heads of the GPT.
+        gpt_hidden_dim: Size of the hidden dimensions of the GPT.
+        gpt_num_obs_mode: Number of different observation modes (e.g., the PushT env has state and image observations, thus 2).
+        dropout: Dropout rate for the GPT.
+        mlp_hidden_dim: Size of the hidden dimensions of the offset / bin prediction heads of VQ-BeT.
+        offset_loss_weight: Weight applied to the offset loss.
+        secondary_code_loss_weight: Weight applied to the secondary code loss.
+        bet_softmax_temperature: Softmax temperature used when sampling codes during VQ-BeT rollout.
     """
     # Inputs / output structure.

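Since `vqvae_groups`, `vqvae_n_embed`, and `vqvae_embedding_dim` jointly define the residual quantizer, a minimal residual-VQ sketch may help (illustrative only; the commit itself uses a full Residual VQ-VAE with trained codebooks):

    import torch

    # Illustrative residual VQ (not this commit's implementation).
    vqvae_groups = 2           # number of RVQ layers
    vqvae_n_embed = 16         # codebook size per layer
    vqvae_embedding_dim = 256  # dimension of each code vector

    # One randomly initialized codebook per layer (trained in practice).
    codebooks = [torch.randn(vqvae_n_embed, vqvae_embedding_dim) for _ in range(vqvae_groups)]

    def residual_quantize(z: torch.Tensor):
        """Quantize (batch, embedding_dim) latents layer by layer."""
        residual, quantized, codes = z, torch.zeros_like(z), []
        for codebook in codebooks:
            dists = torch.cdist(residual, codebook)  # (batch, n_embed)
            idx = dists.argmin(dim=-1)               # nearest code per sample
            codes.append(idx)
            selected = codebook[idx]
            quantized = quantized + selected         # accumulate the decode
            residual = residual - selected           # next layer sees the remainder
        return quantized, codes

    quantized, codes = residual_quantize(torch.randn(8, vqvae_embedding_dim))

With `vqvae_groups = 2`, each action chunk is represented by a primary and a secondary code, which is where `secondary_code_loss_weight` comes in.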

@@ -130,8 +130,60 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
 class VQBeTModel(nn.Module):
-    """
-    TODO(jayLEE0301)
-    """
+    """VQ-BeT: The underlying neural network for VQ-BeT.
+
+    Note: In this code we use the terms `rgb_encoder`, `policy`, and `action_head`. The meanings are as follows.
+        - The `rgb_encoder` processes RGB image observations into one-dimensional embedding vectors.
+        - A `policy` is a minGPT architecture that takes observation sequences and action query tokens and
+          generates `features`.
+        - These `features` pass through the `action_head`, which runs the code prediction and offset prediction
+          heads, and finally generates a prediction for the action chunks.
+
+    -------------------------------** legend **-------------------------------
+    n = n_obs_steps, p = n_action_pred_token, c = n_action_pred_chunk
+    o_{t} : visual observation at timestep {t}
+    s_{t} : state observation at timestep {t}
+    a_{t} : action at timestep {t}
+    A_Q   : action query token
+    ---------------------------------------------------------------------------
+
+    Phase 1. Discretize actions using a Residual VQ (for config.discretize_step steps).
+
+        RVQ encoder      ->   Residual          ->   RVQ Decoder
+      (a_{t}~a_{t+p})         Code Quantizer
+
+    Phase 2.
+
+      o_{t-n+1}     o_{t-n+2}     ...    o_{t}
+      s_{t-n+1}     s_{t-n+2}     ...    s_{t}            |---- p ----|
+         A_Q           A_Q        ...     A_Q       ...        A_Q
+          |             |                  |                    |
+          v             v                  v                    v
+      ----------------------------- GPT -----------------------------   => policy
+          |             |                  |                    |
+          v             v                  v                    v
+     code offset   code offset    ...  code offset         code offset
+          |             |                  |                    |        => action_head
+          v             v                  v                    v
+     RVQ Decoder   RVQ Decoder    ...  RVQ Decoder         RVQ Decoder
+       + offset      + offset            + offset            + offset
+          |             |                  |                    |
+          v             v                  v                    v
+    action chunk   action chunk        action chunk        action chunk
+    a_{t-n+1} ~    a_{t-n+2} ~    ...  a_{t} ~        ...  a_{t+p-1} ~
+     a_{t-n+c}      a_{t-n+c+1}         a_{t+c-1}           a_{t+p+c-1}
+                                           ^
+                                           ONLY this chunk is used in rollout!
+    """
     def __init__(self, config: VQBeTConfig):
         super().__init__()
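To make the `code -> RVQ Decoder -> + offset -> action chunk` step of Phase 2 concrete, here is a hedged sketch of decoding a single prediction token (names such as `rvq_decoder` are hypothetical stand-ins, not this commit's code):

    import torch

    # Illustrative decode of one prediction token (not this commit's code).
    vqvae_groups, vqvae_n_embed, embedding_dim = 2, 16, 256
    chunk_size, action_dim = 5, 2

    codebooks = [torch.randn(vqvae_n_embed, embedding_dim) for _ in range(vqvae_groups)]
    rvq_decoder = torch.nn.Linear(embedding_dim, chunk_size * action_dim)  # stand-in decoder

    code_logits = torch.randn(vqvae_groups, vqvae_n_embed)  # from the bin prediction head
    offset = torch.randn(chunk_size * action_dim)           # from the offset head
    bet_softmax_temperature = 0.1

    # Sample one code per RVQ layer; the temperature controls rollout stochasticity.
    probs = torch.softmax(code_logits / bet_softmax_temperature, dim=-1)
    codes = torch.multinomial(probs, num_samples=1).squeeze(-1)  # (vqvae_groups,)

    # Sum the selected code vectors across layers, decode to a coarse action
    # chunk, then refine it with the continuous offset.
    latent = sum(cb[i] for cb, i in zip(codebooks, codes))
    action_chunk = (rvq_decoder(latent) + offset).reshape(chunk_size, action_dim)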