add explanations of each part of vq-bet

parent d71db341bc
commit 71ec76fe2a
@@ -13,8 +13,8 @@ class VQBeTConfig:
     Args:
         n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
             current step and additional steps going back).
-        n_action_pred_token: TODO(jayLEE0301)
-        n_action_pred_chunk: TODO(jayLEE0301)
+        n_action_pred_token: Number of future action tokens that VQ-BeT predicts.
+        n_action_pred_chunk: Action chunk size of each action prediction token.
         input_shapes: A dictionary defining the shapes of the input data for the policy.
             The key represents the input data name, and the value is a list indicating the dimensions
             of the corresponding data. For example, "observation.image" refers to an input from
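Reviewer note (illustrative, not part of the patch): the two new settings compose into the prediction horizon. Each predicted token decodes to a chunk of `n_action_pred_chunk` consecutive actions, and successive chunks start one timestep apart, as the `VQBeTModel` diagram further down shows. A minimal sketch with made-up values:

```python
# Hypothetical values, not from the patch.
n_action_pred_token = 4   # future action tokens predicted by the GPT
n_action_pred_chunk = 5   # consecutive actions decoded from each token

# Chunk i (0-indexed) covers a_{t+i} ... a_{t+i+n_action_pred_chunk-1}, so the
# furthest action a single forward pass reaches is:
last_step = (n_action_pred_token - 1) + (n_action_pred_chunk - 1)
print(f"prediction from a_t reaches a_(t+{last_step})")  # a_(t+7)
```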
@@ -40,21 +40,23 @@ class VQBeTConfig:
         use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
             The group sizes are set to be about 16 (to be precise, feature_dim // 16).
         spatial_softmax_num_keypoints: Number of keypoints for SpatialSoftmax.
-        discretize_step: TODO(jayLEE0301)
-        vqvae_groups: TODO(jayLEE0301)
-        vqvae_n_embed: TODO(jayLEE0301)
-        vqvae_embedding_dim: TODO(jayLEE0301)
-        vqvae_enc_hidden_dim: TODO(jayLEE0301)
-        gpt_block_size: TODO(jayLEE0301)
-        gpt_input_dim: TODO(jayLEE0301)
-        gpt_output_dim: TODO(jayLEE0301)
-        gpt_n_layer: TODO(jayLEE0301)
-        gpt_n_head: TODO(jayLEE0301)
-        gpt_hidden_dim: TODO(jayLEE0301)
-        dropout: TODO(jayLEE0301)
-        mlp_hidden_dim: TODO(jayLEE0301)
-        offset_loss_weight: TODO(jayLEE0301)
-        secondary_code_loss_weight: TODO(jayLEE0301)
+        discretize_step: Number of optimization steps for training the Residual VQ.
+        vqvae_groups: Number of layers in the Residual VQ.
+        vqvae_n_embed: Number of embedding vectors in the RVQ dictionary (per layer).
+        vqvae_embedding_dim: Dimension of each embedding vector in the RVQ dictionary.
+        vqvae_enc_hidden_dim: Size of the hidden dimensions of the encoder / decoder parts of the Residual VQ-VAE.
+        gpt_block_size: Maximum block size of minGPT (should be larger than the number of input tokens).
+        gpt_input_dim: Size of the input dimension of GPT. This is also used as the dimension of observation features.
+        gpt_output_dim: Size of the output dimension of GPT. This is also used as the input dimension of the offset / bin prediction heads.
+        gpt_n_layer: Number of layers of GPT.
+        gpt_n_head: Number of attention heads of GPT.
+        gpt_hidden_dim: Size of the hidden dimensions of GPT.
+        gpt_num_obs_mode: Number of different observation modes (e.g., the PushT env has {state, image observation}, thus 2).
+        dropout: Dropout rate for GPT.
+        mlp_hidden_dim: Size of the hidden dimensions of the offset / bin prediction heads of VQ-BeT.
+        offset_loss_weight: A constant that the offset loss is multiplied by.
+        secondary_code_loss_weight: A constant that the secondary code loss is multiplied by.
+        bet_softmax_temperature: Sampling temperature over codes during rollout with VQ-BeT.
     """

     # Inputs / output structure.
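Reviewer note: to make the three RVQ dictionary parameters concrete, here is a minimal residual-VQ sketch. It assumes the `vector-quantize-pytorch` package purely for illustration; the numbers are invented and this is not the repository's RVQ code.

```python
import torch
from vector_quantize_pytorch import ResidualVQ

# Hypothetical values, mapped onto the config names above.
rvq = ResidualVQ(
    dim=256,           # vqvae_embedding_dim: size of each codebook vector
    num_quantizers=2,  # vqvae_groups: number of residual VQ layers
    codebook_size=16,  # vqvae_n_embed: dictionary entries per layer
)

action_embeddings = torch.randn(1, 7, 256)  # (batch, tokens, embedding dim)
quantized, codes, commit_loss = rvq(action_embeddings)
# `codes` has shape (1, 7, 2): each token gets one code per RVQ layer, so the
# effective vocabulary is vqvae_n_embed ** vqvae_groups = 256 composite codes.
```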
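Reviewer note: a rough sanity check for `gpt_block_size`, based on my reading of the token layout in the `VQBeTModel` diagram below; the counting and all numbers here are assumptions, not values from the patch.

```python
# Hypothetical numbers; the counting follows my reading of the diagram: each of
# the n_obs_steps contributes one token per observation mode plus one action
# query token, and `p - 1` more query tokens trail the sequence.
n_obs_steps = 5
gpt_num_obs_mode = 2      # e.g., PushT: {state, image observation}
n_action_pred_token = 4

num_input_tokens = n_obs_steps * (gpt_num_obs_mode + 1) + (n_action_pred_token - 1)
gpt_block_size = 500      # must be larger than the number of input tokens
assert gpt_block_size > num_input_tokens
```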
@@ -130,8 +130,60 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):


 class VQBeTModel(nn.Module):
-    """
-    TODO(jayLEE0301)
+    """VQ-BeT: The underlying neural network for VQ-BeT
+
+    Note: In this code we use the terms `rgb_encoder`, `policy`, and `action_head`. The meanings are as follows.
+        - The `rgb_encoder` processes rgb-style image observations into one-dimensional embedding vectors.
+        - The `policy` is a minGPT architecture that takes observation sequences and action query tokens to
+          generate `features`.
+        - These `features` pass through the action head, i.e. the code prediction and offset prediction heads,
+          which finally generate a prediction for the action chunks.
+
+        -------------------------------** legend **-------------------------------
+        │ n = n_obs_steps, p = n_action_pred_token, c = n_action_pred_chunk      │
+        │ o_{t} : visual observation at timestep {t}                             │
+        │ s_{t} : state observation at timestep {t}                              │
+        │ a_{t} : action at timestep {t}                                         │
+        │ A_Q   : action_query_token                                             │
+        --------------------------------------------------------------------------
+
+
+        Phase 1. Discretize actions using the Residual VQ (for config.discretize_step steps)
+
+        ┌─────────────────┐            ┌─────────────────┐            ┌─────────────────┐
+        │                 │            │                 │            │                 │
+        │   RVQ encoder   │     ─►     │    Residual     │     ─►     │   RVQ Decoder   │
+        │ (a_{t}~a_{t+p}) │            │ Code Quantizer  │            │                 │
+        │                 │            │                 │            │                 │
+        └─────────────────┘            └─────────────────┘            └─────────────────┘
+
+        Phase 2. Predict codes and offsets with the GPT and the action head (using the RVQ
+        decoder learned in Phase 1)
+
+        o_{t-n+1}        o_{t-n+2}         ...        o_{t}
+            │                │                          │
+            │ s_{t-n+1}      │ s_{t-n+2}       ...      │ s_{t}                 p
+            │    │           │    │                     │    │          ┌───────┴───────┐
+            │    │    A_Q    │    │    A_Q     ...      │    │    A_Q  ...            A_Q
+            │    │     │     │    │     │               │    │     │                   │
+        ┌───▼────▼─────▼─────▼────▼─────▼───────────────▼────▼─────▼───────────────────▼───┐
+        │                                                                                  │
+        │                                      GPT                                         │  =>  policy
+        │                                                                                  │
+        └──────────────────▼────────────────▼───────────────────────▼───────────────▼──────┘
+                           │                │                       │               │
+                       ┌───┴───┐        ┌───┴───┐               ┌───┴───┐       ┌───┴───┐
+                     code    offset   code    offset          code    offset  code    offset
+                       ▼       │        ▼       │               ▼       │       ▼       │    =>  action_head
+                  RVQ Decoder  │   RVQ Decoder  │          RVQ Decoder  │  RVQ Decoder  │
+                       └── + ──┘        └── + ──┘               └── + ──┘       └── + ──┘
+                           ▼                ▼                       ▼               ▼
+                      action chunk     action chunk            action chunk    action chunk
+                      a_{t-n+1} ~      a_{t-n+2} ~             a_{t} ~    ...  a_{t+p-1} ~
+                      a_{t-n+c}        a_{t-n+c+1}             a_{t+c-1}       a_{t+p+c-2}
+
+                                                                    ▼
+                                                  ONLY this chunk is used in rollout!
     """
     def __init__(self, config: VQBeTConfig):
         super().__init__()
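Reviewer note: a sketch of how the Phase 2 input sequence could be assembled from the diagram above. The tensor names, sizes, and interleaving details are my assumptions, not the repository's implementation.

```python
import torch
import torch.nn as nn

# Hypothetical sizes.
batch, n_obs_steps, n_action_pred_token, gpt_input_dim = 8, 5, 4, 512

# Per-step observation tokens, both projected to gpt_input_dim:
img_tokens = torch.randn(batch, n_obs_steps, gpt_input_dim)    # o_{t-n+1} ... o_{t}
state_tokens = torch.randn(batch, n_obs_steps, gpt_input_dim)  # s_{t-n+1} ... s_{t}

# One learned action query token (A_Q), repeated for every A_Q slot in the
# diagram: one per observation step plus the trailing `p - 1` slots.
action_query = nn.Parameter(torch.randn(1, 1, gpt_input_dim))
queries = action_query.expand(batch, n_obs_steps + n_action_pred_token - 1, -1)

# Interleave (o, s, A_Q) per step, then append the trailing queries.
per_step = torch.stack(
    [img_tokens, state_tokens, queries[:, :n_obs_steps]], dim=2
)  # (batch, n_obs_steps, 3, gpt_input_dim)
sequence = torch.cat([per_step.flatten(1, 2), queries[:, n_obs_steps:]], dim=1)
print(sequence.shape)  # torch.Size([8, 18, 512]) -> input to the GPT (`policy`)
```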
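Reviewer note: a matching sketch of the action head at the bottom of the diagram: sample one code per RVQ layer from the bin prediction head (tempered by `bet_softmax_temperature`), decode with the RVQ decoder, and add the predicted offset. All module names and shapes are hypothetical, and `rvq_decode` is a stand-in for the decoder trained in Phase 1.

```python
import torch
import torch.nn as nn

# Hypothetical sizes, mapped onto the config names above.
gpt_output_dim, vqvae_n_embed, vqvae_groups = 512, 16, 2
n_action_pred_chunk, action_dim = 5, 2
bet_softmax_temperature = 0.1  # sampling temperature of codes at rollout

feature = torch.randn(8, gpt_output_dim)  # one GPT output at an A_Q position

# Bin (code) prediction head: one categorical distribution per RVQ layer.
code_head = nn.Linear(gpt_output_dim, vqvae_groups * vqvae_n_embed)
logits = code_head(feature).view(8, vqvae_groups, vqvae_n_embed)
codes = torch.distributions.Categorical(
    logits=logits / bet_softmax_temperature
).sample()  # (8, vqvae_groups): one code per RVQ layer

# Offset prediction head: a continuous correction for each action in the chunk.
offset_head = nn.Linear(gpt_output_dim, n_action_pred_chunk * action_dim)
offset = offset_head(feature).view(8, n_action_pred_chunk, action_dim)

def rvq_decode(codes: torch.Tensor) -> torch.Tensor:
    """Stand-in for the RVQ decoder (codebook lookup + decoding to a chunk)."""
    return torch.zeros(codes.shape[0], n_action_pred_chunk, action_dim)

action_chunk = rvq_decode(codes) + offset  # the chunk a_{t} ~ a_{t+c-1}
```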