change discretize_step -> n_vqvae_training_steps

jayLEE0301 2024-06-05 14:05:50 -04:00
parent fd8fc11342
commit bc10e34700
3 changed files with 18 additions and 18 deletions


@@ -40,7 +40,7 @@ class VQBeTConfig:
use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
The group sizes are set to be about 16 (to be precise, feature_dim // 16).
spatial_softmax_num_keypoints: Number of keypoints for SpatialSoftmax.
-discretize_step: Number of optimization steps for training Residual VQ.
+n_vqvae_training_steps: Number of optimization steps for training Residual VQ.
vqvae_groups: Number of layers in Residual VQ.
vqvae_n_embed: Number of embedding vectors in the RVQ dictionary (each layer).
vqvae_embedding_dim: Dimension of each embedding vector in the RVQ dictionary.
@@ -97,7 +97,7 @@ class VQBeTConfig:
use_group_norm: bool = True
spatial_softmax_num_keypoints: int = 32
# VQ-VAE
-discretize_step: int = 3000
+n_vqvae_training_steps: int = 3000
vqvae_groups: int = 2
vqvae_n_embed: int = 16
vqvae_embedding_dim: int = 256


@@ -113,7 +113,7 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
# loss: total loss of training RVQ
# n_different_codes: how many of the total possible codes are being used (max: vqvae_n_embed).
# n_different_combinations: how many different code combinations are being used out of all possible combinations (max: vqvae_n_embed ^ vqvae_groups).
-loss, n_different_codes, n_different_combinations = self.vqbet.discretize(self.config.discretize_step, batch['action'])
+loss, n_different_codes, n_different_combinations = self.vqbet.discretize(self.config.n_vqvae_training_steps, batch['action'])
return {"loss": loss, "n_different_codes": n_different_codes, "n_different_combinations": n_different_combinations}
# if Residual VQ is already trained, VQ-BeT trains its GPT and bin prediction head / offset prediction head parts.
_, loss_dict = self.vqbet(batch, rollout=False)
@@ -208,7 +208,7 @@ class VQBeTModel(nn.Module):
--------------------------------------------------------------------------
-Training Phase 1. Discretize action using Residual VQ (for config.discretize_step steps)
+Training Phase 1. Discretize action using Residual VQ (for config.n_vqvae_training_steps steps)
@@ -280,8 +280,8 @@ class VQBeTModel(nn.Module):
),
)
-def discretize(self, discretize_step, actions):
-return self.action_head.discretize(discretize_step, actions)
+def discretize(self, n_vqvae_training_steps, actions):
+return self.action_head.discretize(n_vqvae_training_steps, actions)
def forward(self, batch: dict[str, Tensor], rollout: bool) -> Tensor:
# Input validation.
@@ -378,8 +378,8 @@ class VQBeTHead(nn.Module):
# loss
self._focal_loss_fn = FocalLoss(gamma=2.0)
-def discretize(self, discretize_step, actions):
-loss, n_different_codes, n_different_combinations = pretrain_vqvae(self.vqvae_model, discretize_step, actions)
+def discretize(self, n_vqvae_training_steps, actions):
+loss, n_different_codes, n_different_combinations = pretrain_vqvae(self.vqvae_model, n_vqvae_training_steps, actions)
return loss, n_different_codes, n_different_combinations
def forward(self, x, **kwargs):
@@ -554,7 +554,7 @@ class VQBeTHead(nn.Module):
class VQBeTOptimizer:
def __init__(self, policy, cfg):
-self.discretize_step = cfg.training.discretize_step
+self.n_vqvae_training_steps = cfg.training.n_vqvae_training_steps
self.offline_steps = cfg.training.offline_steps
self.optimizing_step = 0
@@ -616,7 +616,7 @@ class VQBeTOptimizer:
def step(self):
self.optimizing_step +=1
# pretraining VQ-VAE (Training Phase 1)
-if self.optimizing_step < self.discretize_step:
+if self.optimizing_step < self.n_vqvae_training_steps:
self.vqvae_optimizer.step()
# training BeT (Training Phase 2)
else:
@@ -626,7 +626,7 @@ class VQBeTOptimizer:
def zero_grad(self):
# pretraining VQ-VAE (Training Phase 1)
-if self.optimizing_step < self.discretize_step:
+if self.optimizing_step < self.n_vqvae_training_steps:
self.vqvae_optimizer.zero_grad()
# training BeT (Training Phase 2)
else:
@@ -638,7 +638,7 @@ class VQBeTScheduler:
def __init__(self, optimizer, cfg):
# VQ-BeT uses a scheduler only for the rgb encoder. Since the rgb encoder is taken from Diffusion Policy, we follow the same scheduler as well.
from diffusers.optimization import get_scheduler
-self.discretize_step = cfg.training.discretize_step
+self.n_vqvae_training_steps = cfg.training.n_vqvae_training_steps
self.optimizing_step = 0
self.lr_scheduler = get_scheduler(
@@ -651,7 +651,7 @@ class VQBeTScheduler:
def step(self):
self.optimizing_step +=1
-if self.optimizing_step >= self.discretize_step:
+if self.optimizing_step >= self.n_vqvae_training_steps:
self.lr_scheduler.step()
class VQBeTRgbEncoder(nn.Module):
@@ -907,7 +907,7 @@ class VqVae(nn.Module):
-def pretrain_vqvae(vqvae_model, discretize_step, actions):
+def pretrain_vqvae(vqvae_model, n_vqvae_training_steps, actions):
if vqvae_model.config.action_chunk_size == 1:
# not using action chunk
actions = actions.reshape(-1, 1, actions.shape[-1])
@@ -926,8 +926,8 @@ def pretrain_vqvae(vqvae_model, discretize_step, actions):
n_different_codes = len(torch.unique(metric[2]))
n_different_combinations = len(torch.unique(metric[2], dim=0))
vqvae_model.optimized_steps += 1
-# if we updated RVQ more than `discretize_step` steps,
-if vqvae_model.optimized_steps >= discretize_step:
+# if we updated RVQ more than `n_vqvae_training_steps` steps,
+if vqvae_model.optimized_steps >= n_vqvae_training_steps:
vqvae_model.toggle_discretized(True)
print("Finished discretizing action data!")
vqvae_model.eval()
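
Taken together, the hunks above keep the same two-phase schedule and only rename the threshold: for the first n_vqvae_training_steps optimizer steps only the Residual VQ is updated, after which the GPT and the bin/offset prediction heads are trained. A minimal training-loop sketch of that gating, where the policy, VQBeTOptimizer, VQBeTScheduler, and dataloader construction are assumed rather than taken from this diff:

    # Sketch of the two-phase schedule gated by `n_vqvae_training_steps`.
    # `policy`, `optimizer` (a VQBeTOptimizer), `scheduler` (a VQBeTScheduler) and
    # `dataloader` are assumed to be built elsewhere; only the gating mirrors the hunks above.
    for batch in dataloader:
        output = policy.forward(batch)  # Phase 1 returns the RVQ loss until discretization finishes
        output["loss"].backward()
        optimizer.step()    # steps vqvae_optimizer before the threshold, the BeT optimizers after it
        scheduler.step()    # the rgb-encoder lr scheduler only steps once the threshold is passed
        optimizer.zero_grad()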


@@ -39,7 +39,7 @@ training:
# VQ-BeT specific
vqvae_lr: 1.0e-3
-discretize_step: 20000
+n_vqvae_training_steps: 20000
bet_weight_decay: 2e-4
bet_learning_rate: 5.5e-5
bet_betas: [0.9, 0.999]
@@ -84,7 +84,7 @@ policy:
use_group_norm: True
spatial_softmax_num_keypoints: 32
# VQ-VAE
-discretize_step: ${training.discretize_step}
+n_vqvae_training_steps: ${training.n_vqvae_training_steps}
vqvae_groups: 2
vqvae_n_embed: 16
vqvae_embedding_dim: 256
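
In the YAML, policy.n_vqvae_training_steps is a Hydra/OmegaConf interpolation of the training-level key, so both occurrences have to be renamed together or config resolution fails. An illustrative resolution of the interpolation (plain OmegaConf, not the project's full Hydra setup):

    # Illustrative: `${training.n_vqvae_training_steps}` resolves to the training value on access.
    from omegaconf import OmegaConf

    cfg = OmegaConf.create(
        {
            "training": {"n_vqvae_training_steps": 20000},
            "policy": {"n_vqvae_training_steps": "${training.n_vqvae_training_steps}"},
        }
    )
    assert cfg.policy.n_vqvae_training_steps == 20000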