remove unused parameter explanations, fix some names of parameters
This commit is contained in:
parent d209c0f73c
commit e301caf182
@@ -3,7 +3,7 @@ from dataclasses import dataclass, field
 @dataclass
 class VQBeTConfig:
-    """Configuration class for DiffusionPolicy.
+    """Configuration class for VQ-BeT.
 
     Defaults are configured for training with PushT providing proprioceptive and single camera observations.
 
@@ -13,9 +13,8 @@ class VQBeTConfig:
     Args:
         n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
            current step and additional steps going back).
-        horizon: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`.
-        n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
-            See `DiffusionPolicy.select_action` for more details.
+        n_action_pred_token: TODO(jayLEE0301)
+        n_action_pred_chunk: TODO(jayLEE0301)
         input_shapes: A dictionary defining the shapes of the input data for the policy.
            The key represents the input data name, and the value is a list indicating the dimensions
            of the corresponding data. For example, "observation.image" refers to an input from
@@ -41,32 +40,21 @@ class VQBeTConfig:
         use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
            The group sizes are set to be about 16 (to be precise, feature_dim // 16).
         spatial_softmax_num_keypoints: Number of keypoints for SpatialSoftmax.
-        down_dims: Feature dimension for each stage of temporal downsampling in the diffusion modeling Unet.
-            You may provide a variable number of dimensions, therefore also controlling the degree of
-            downsampling.
-        kernel_size: The convolutional kernel size of the diffusion modeling Unet.
-        n_groups: Number of groups used in the group norm of the Unet's convolutional blocks.
-        diffusion_step_embed_dim: The Unet is conditioned on the diffusion timestep via a small non-linear
-            network. This is the output dimension of that network, i.e., the embedding dimension.
-        use_film_scale_modulation: FiLM (https://arxiv.org/abs/1709.07871) is used for the Unet conditioning.
-            Bias modulation is used be default, while this parameter indicates whether to also use scale
-            modulation.
-        num_train_timesteps: Number of diffusion steps for the forward diffusion schedule.
-        beta_schedule: Name of the diffusion beta schedule as per DDPMScheduler from Hugging Face diffusers.
-        beta_start: Beta value for the first forward-diffusion step.
-        beta_end: Beta value for the last forward-diffusion step.
-        prediction_type: The type of prediction that the diffusion modeling Unet makes. Choose from "epsilon"
-            or "sample". These have equivalent outcomes from a latent variable modeling perspective, but
-            "epsilon" has been shown to work better in many deep neural network settings.
-        clip_sample: Whether to clip the sample to [-`clip_sample_range`, +`clip_sample_range`] for each
-            denoising step at inference time. WARNING: you will need to make sure your action-space is
-            normalized to fit within this range.
-        clip_sample_range: The magnitude of the clipping range as described above.
-        num_inference_steps: Number of reverse diffusion steps to use at inference time (steps are evenly
-            spaced). If not provided, this defaults to be the same as `num_train_timesteps`.
-        do_mask_loss_for_padding: Whether to mask the loss when there are copy-padded actions. See
-            `LeRobotDataset` and `load_previous_and_future_frames` for mor information. Note, this defaults
-            to False as the original Diffusion Policy implementation does the same.
+        discretize_step: TODO(jayLEE0301)
+        vqvae_groups: TODO(jayLEE0301)
+        vqvae_n_embed: TODO(jayLEE0301)
+        vqvae_embedding_dim: TODO(jayLEE0301)
+        vqvae_enc_hidden_dim: TODO(jayLEE0301)
+        gpt_block_size: TODO(jayLEE0301)
+        gpt_input_dim: TODO(jayLEE0301)
+        gpt_output_dim: TODO(jayLEE0301)
+        gpt_n_layer: TODO(jayLEE0301)
+        gpt_n_head: TODO(jayLEE0301)
+        gpt_hidden_dim: TODO(jayLEE0301)
+        dropout: TODO(jayLEE0301)
+        mlp_hidden_dim: TODO(jayLEE0301)
+        offset_loss_weight: TODO(jayLEE0301)
+        secondary_code_loss_weight: TODO(jayLEE0301)
     """
 
     # Inputs / output structure.
@@ -115,11 +103,11 @@ class VQBeTConfig:
     gpt_output_dim: int = 512
     gpt_n_layer: int = 8
     gpt_n_head: int = 8
-    gpt_n_embed: int = 512
+    gpt_hidden_dim: int = 512
     dropout: float = 0.1
     mlp_hidden_dim: int = 1024
     offset_loss_weight: float = 10000.
-    secondary_code_multiplier: float = 0.5
+    secondary_code_loss_weight: float = 0.5
 
     def __post_init__(self):
         """Input validation (not exhaustive)."""
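A quick sanity check of the renamed config fields. This is a minimal sketch, not code from the commit: the import path is an assumption based on where the other lerobot policy configs live, and the values simply repeat the defaults above.

```python
# Illustrative only; assumes lerobot/common/policies/vqbet/configuration_vqbet.py.
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig

cfg = VQBeTConfig(
    gpt_hidden_dim=512,              # previously gpt_n_embed
    secondary_code_loss_weight=0.5,  # previously secondary_code_multiplier
)
print(cfg.gpt_hidden_dim, cfg.secondary_code_loss_weight)
```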
@@ -173,7 +173,7 @@ class VQBeTModel(nn.Module):
         ], dim=-2).view(batch_size, -1, self.config.gpt_input_dim)
         if img_features.shape[1] != n_obs_steps:
             raise NotImplementedError
-        # eos_token = self._eos_token.repeat(batch_size, 1, 1) # TODO remove EOS token
+        # eos_token = self._eos_token.repeat(batch_size, 1, 1) # TODO(jayLEE0301) remove EOS token
         len_additional_action_token = self.config.n_action_pred_token-1
         action_token = self._action_token.repeat(batch_size, len_additional_action_token, 1)
@@ -183,7 +183,7 @@ class VQBeTModel(nn.Module):
 
         # get action features
         features = self._policy(global_cond)
-        historical_act_pred_index = np.arange(0, n_obs_steps) * 3 + 2 # TODO make it compatible with other values
+        historical_act_pred_index = np.arange(0, n_obs_steps) * 3 + 2 # TODO(jayLEE0301) make it compatible with other values
         features = torch.cat([
             features[:, historical_act_pred_index],
             features[:, -len_additional_action_token:]
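For readers following the token layout in the hunk above: each observation step contributes three tokens (hence the hard-coded `* 3`, which the TODO flags), and the index `+ 2` picks the position within each triplet that is read out as an action prediction. A small worked example, with `n_obs_steps=2` chosen arbitrarily for illustration:

```python
import numpy as np

n_obs_steps = 2  # example value, not from the diff
# 3 tokens per observation step; offset 2 is the action-prediction slot.
historical_act_pred_index = np.arange(0, n_obs_steps) * 3 + 2
print(historical_act_pred_index)  # [2 5]
```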
@@ -225,7 +225,7 @@ class VQBeTHead(nn.Module):
         self.output_size = config.output_shapes["action"][0]
         self.hidden_size = config.mlp_hidden_dim
         self.offset_loss_weight = config.offset_loss_weight
-        self.secondary_code_multiplier = config.secondary_code_multiplier
+        self.secondary_code_loss_weight = config.secondary_code_loss_weight
 
         self.vqvae_groups = config.vqvae_groups
         self.vqvae_n_embed = config.vqvae_n_embed # C(number of code integers)
@@ -358,7 +358,7 @@ class VQBeTHead(nn.Module):
             cbet_logits[:, 1, :],
             action_bins[:, 1],
         )
-        cbet_loss = cbet_loss1 * 5 + cbet_loss2 * self.secondary_code_multiplier
+        cbet_loss = cbet_loss1 * 5 + cbet_loss2 * self.secondary_code_loss_weight
 
         equal_primary_code_rate = torch.sum(
             (action_bins[:, 0] == sampled_centers[:, 0]).int()
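The rename only changes which attribute scales the secondary-code term; the arithmetic is unchanged: the primary-code cross-entropy keeps its fixed weight of 5 and the secondary-code cross-entropy is scaled by `secondary_code_loss_weight` (0.5 by default). A toy illustration with made-up loss values, not taken from the diff:

```python
cbet_loss1 = 0.8                  # cross-entropy on the primary code (example value)
cbet_loss2 = 1.2                  # cross-entropy on the secondary code (example value)
secondary_code_loss_weight = 0.5  # config default shown above

cbet_loss = cbet_loss1 * 5 + cbet_loss2 * secondary_code_loss_weight
print(cbet_loss)  # 4.6
```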
@@ -2058,11 +2058,11 @@ class MLP(torch.nn.Sequential):
 class CausalSelfAttention(nn.Module):
     def __init__(self, config):
         super().__init__()
-        assert config.gpt_n_embed % config.gpt_n_head == 0
+        assert config.gpt_hidden_dim % config.gpt_n_head == 0
         # key, query, value projections for all heads, but in a batch
-        self.c_attn = nn.Linear(config.gpt_n_embed, 3 * config.gpt_n_embed)
+        self.c_attn = nn.Linear(config.gpt_hidden_dim, 3 * config.gpt_hidden_dim)
         # output projection
-        self.c_proj = nn.Linear(config.gpt_n_embed, config.gpt_n_embed)
+        self.c_proj = nn.Linear(config.gpt_hidden_dim, config.gpt_hidden_dim)
         # regularization
         self.attn_dropout = nn.Dropout(config.dropout)
         self.resid_dropout = nn.Dropout(config.dropout)
@@ -2074,17 +2074,17 @@ class CausalSelfAttention(nn.Module):
             ),
         )
         self.gpt_n_head = config.gpt_n_head
-        self.gpt_n_embed = config.gpt_n_embed
+        self.gpt_hidden_dim = config.gpt_hidden_dim
 
     def forward(self, x):
         (
             B,
             T,
             C,
-        ) = x.size() # batch size, sequence length, embedding dimensionality (gpt_n_embed)
+        ) = x.size() # batch size, sequence length, embedding dimensionality (gpt_hidden_dim)
 
         # calculate query, key, values for all heads in batch and move head forward to be the batch dim
-        q, k, v = self.c_attn(x).split(self.gpt_n_embed, dim=2)
+        q, k, v = self.c_attn(x).split(self.gpt_hidden_dim, dim=2)
         k = k.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(
             1, 2
         ) # (B, nh, T, hs)
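Shape-wise nothing changes with the rename: `c_attn` still projects the hidden size to three times itself, the result is split back into query, key, and value, and each is reshaped per head. A standalone sketch of the tensor shapes, with assumed example values matching the defaults above:

```python
import torch
import torch.nn as nn

B, T = 4, 10                       # example batch size and sequence length
gpt_hidden_dim, gpt_n_head = 512, 8

c_attn = nn.Linear(gpt_hidden_dim, 3 * gpt_hidden_dim)
x = torch.randn(B, T, gpt_hidden_dim)

q, k, v = c_attn(x).split(gpt_hidden_dim, dim=2)  # each (B, T, gpt_hidden_dim)
k = k.view(B, T, gpt_n_head, gpt_hidden_dim // gpt_n_head).transpose(1, 2)
print(q.shape, k.shape)  # torch.Size([4, 10, 512]) torch.Size([4, 8, 10, 64])
```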
@@ -2114,13 +2114,13 @@ class CausalSelfAttention(nn.Module):
 class Block(nn.Module):
     def __init__(self, config):
         super().__init__()
-        self.ln_1 = nn.LayerNorm(config.gpt_n_embed)
+        self.ln_1 = nn.LayerNorm(config.gpt_hidden_dim)
         self.attn = CausalSelfAttention(config)
-        self.ln_2 = nn.LayerNorm(config.gpt_n_embed)
+        self.ln_2 = nn.LayerNorm(config.gpt_hidden_dim)
         self.mlp = nn.Sequential(
-            nn.Linear(config.gpt_n_embed, 4 * config.gpt_n_embed),
+            nn.Linear(config.gpt_hidden_dim, 4 * config.gpt_hidden_dim),
             nn.GELU(),
-            nn.Linear(4 * config.gpt_n_embed, config.gpt_n_embed),
+            nn.Linear(4 * config.gpt_hidden_dim, config.gpt_hidden_dim),
             nn.Dropout(config.dropout)
         )
 
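For context, `Block.forward` is not touched by this hunk; after the rename both the attention branch and the MLP branch simply operate at `gpt_hidden_dim`. A minimal sketch of the usual pre-LayerNorm residual pattern this minGPT-style file follows (assumed, not shown in the diff):

```python
def forward(self, x):
    # x: (B, T, gpt_hidden_dim); both residual branches preserve the hidden size
    x = x + self.attn(self.ln_1(x))
    x = x + self.mlp(self.ln_2(x))
    return x
```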
@@ -2178,14 +2178,14 @@ class GPT(nn.Module):
 
         self.transformer = nn.ModuleDict(
             dict(
-                wte=nn.Linear(config.gpt_input_dim, config.gpt_n_embed),
-                wpe=nn.Embedding(config.gpt_block_size, config.gpt_n_embed),
+                wte=nn.Linear(config.gpt_input_dim, config.gpt_hidden_dim),
+                wpe=nn.Embedding(config.gpt_block_size, config.gpt_hidden_dim),
                 drop=nn.Dropout(config.dropout),
                 h=nn.ModuleList([Block(config) for _ in range(config.gpt_n_layer)]),
-                ln_f=nn.LayerNorm(config.gpt_n_embed),
+                ln_f=nn.LayerNorm(config.gpt_hidden_dim),
             )
         )
-        self.lm_head = nn.Linear(config.gpt_n_embed, config.gpt_output_dim, bias=False)
+        self.lm_head = nn.Linear(config.gpt_hidden_dim, config.gpt_output_dim, bias=False)
         # init all weights, and apply a special scaled init to the residual projections, per GPT-2 paper
         self.apply(self._init_weights)
         for pn, p in self.named_parameters():
@@ -2211,10 +2211,10 @@ class GPT(nn.Module):
         # forward the GPT model itself
         tok_emb = self.transformer.wte(
             input
-        ) # token embeddings of shape (b, t, gpt_n_embed)
+        ) # token embeddings of shape (b, t, gpt_hidden_dim)
         pos_emb = self.transformer.wpe(
             pos
-        ) # position embeddings of shape (1, t, gpt_n_embed)
+        ) # position embeddings of shape (1, t, gpt_hidden_dim)
         x = self.transformer.drop(tok_emb + pos_emb)
         for block in self.transformer.h:
             x = block(x)
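The renamed comments above describe the same computation as before: `wte` is a linear projection of the continuous input tokens (not a lookup table), `wpe` is a learned positional embedding, and their broadcast sum feeds the transformer blocks at `gpt_hidden_dim`. A self-contained shape sketch with assumed example values:

```python
import torch
import torch.nn as nn

gpt_input_dim, gpt_hidden_dim, gpt_block_size = 512, 512, 500  # illustrative values

wte = nn.Linear(gpt_input_dim, gpt_hidden_dim)      # "token embedding" for continuous tokens
wpe = nn.Embedding(gpt_block_size, gpt_hidden_dim)  # learned positional embedding

tokens = torch.randn(4, 12, gpt_input_dim)  # (B, T, gpt_input_dim)
pos = torch.arange(12).unsqueeze(0)          # (1, T)
x = wte(tokens) + wpe(pos)                   # broadcasts to (4, 12, gpt_hidden_dim)
print(x.shape)
```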
@@ -1,8 +1,6 @@
 # @package _global_
 
 # Defaults for training for the PushT dataset as per https://github.com/real-stanford/diffusion_policy.
-# Note: We do not track EMA model weights as we discovered it does not improve the results. See
-# https://github.com/huggingface/lerobot/pull/134 for more details.
 
 seed: 100000
 dataset_repo_id: lerobot/pusht
@@ -97,8 +95,8 @@ policy:
   gpt_output_dim: 512
   gpt_n_layer: 8
   gpt_n_head: 8
-  gpt_n_embed: 512
+  gpt_hidden_dim: 512
   dropout: 0.1
   mlp_hidden_dim: 1024
   offset_loss_weight: 10000.
-  secondary_code_multiplier: 0.5
+  secondary_code_loss_weight: 0.5