change discretize_step -> n_vqvae_training_steps
parent fd8fc11342
commit bc10e34700
@@ -40,7 +40,7 @@ class VQBeTConfig:
         use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
             The group sizes are set to be about 16 (to be precise, feature_dim // 16).
         spatial_softmax_num_keypoints: Number of keypoints for SpatialSoftmax.
-        discretize_step: Number of optimization steps for training Residual VQ.
+        n_vqvae_training_steps: Number of optimization steps for training Residual VQ.
         vqvae_groups: Number of layers in Residual VQ.
         vqvae_n_embed: Number of embedding vectors in the RVQ dictionary (each layer).
         vqvae_embedding_dim: Dimension of each embedding vector in the RVQ dictionary.
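The docstring above implies a simple capacity bound: each of the vqvae_groups RVQ layers selects one of vqvae_n_embed codes, so at most vqvae_n_embed ** vqvae_groups distinct code combinations exist. A quick sketch with the defaults from the next hunk:

# Capacity implied by the config docstring (values are the defaults below).
vqvae_groups = 2    # number of RVQ layers
vqvae_n_embed = 16  # dictionary size of each layer

max_codes = vqvae_n_embed                         # ceiling for n_different_codes
max_combinations = vqvae_n_embed ** vqvae_groups  # ceiling for n_different_combinations
print(max_codes, max_combinations)  # 16 256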
@@ -97,7 +97,7 @@ class VQBeTConfig:
     use_group_norm: bool = True
     spatial_softmax_num_keypoints: int = 32
     # VQ-VAE
-    discretize_step: int = 3000
+    n_vqvae_training_steps: int = 3000
     vqvae_groups: int = 2
     vqvae_n_embed: int = 16
     vqvae_embedding_dim: int = 256
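A minimal usage sketch for the renamed field, assuming VQBeTConfig is a plain dataclass as the annotated defaults above suggest (the import path is omitted here):

# Hypothetical construction; only the field name comes from this commit.
config = VQBeTConfig(n_vqvae_training_steps=20000)  # override the 3000 default
# After this commit the old attribute no longer exists:
# config.discretize_step -> AttributeError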
@@ -113,7 +113,7 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
             # loss: total loss of training RVQ
             # n_different_codes: how many of total possible codes are being used (max: vqvae_n_embed).
             # n_different_combinations: how many different code combinations you are using out of all possible code combinations (max: vqvae_n_embed ^ vqvae_groups).
-            loss, n_different_codes, n_different_combinations = self.vqbet.discretize(self.config.discretize_step, batch['action'])
+            loss, n_different_codes, n_different_combinations = self.vqbet.discretize(self.config.n_vqvae_training_steps, batch['action'])
             return {"loss": loss, "n_different_codes": n_different_codes, "n_different_combinations": n_different_combinations}
         # if Residual VQ is already trained, VQ-BeT trains its GPT and bin prediction head / offset prediction head parts.
         _, loss_dict = self.vqbet(batch, rollout=False)
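This branch is what makes training two-phase: until the residual VQ is trained, forward() routes batches into discretize() and reports codebook-usage metrics alongside the loss. A hedged sketch of the surrounding loop (policy, optimizer, and dataloader construction omitted; not the repo's actual training script):

for step, batch in enumerate(dataloader):
    out = policy.forward(batch)   # dict with "loss" in both phases
    out["loss"].backward()
    optimizer.step()              # VQBeTOptimizer picks the phase's optimizer
    optimizer.zero_grad()
    if "n_different_codes" in out:  # Phase 1 only: codebook usage metrics
        print(step, out["n_different_codes"], out["n_different_combinations"])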
@@ -208,7 +208,7 @@ class VQBeTModel(nn.Module):
     --------------------------------------------------------------------------
 
 
-    Training Phase 1. Discretize action using Residual VQ (for config.discretize_step steps)
+    Training Phase 1. Discretize action using Residual VQ (for config.n_vqvae_training_steps steps)
 
 
     ┌─────────────────┐        ┌─────────────────┐        ┌─────────────────┐
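For readers new to residual VQ, Phase 1 works layer by layer: the first codebook quantizes the encoded action, the second quantizes the leftover residual, and so on for vqvae_groups layers. A toy sketch with plain tensors (not the repo's VqVae class):

import torch

def toy_residual_vq(z, codebooks):
    # Quantize z through a stack of codebooks; return per-layer code indices.
    residual, codes = z, []
    for cb in codebooks:                                # cb: (n_embed, dim)
        idx = torch.cdist(residual, cb).argmin(dim=-1)  # nearest code per vector
        codes.append(idx)
        residual = residual - cb[idx]                   # next layer sees the residual
    return torch.stack(codes, dim=-1)                   # (batch, vqvae_groups)

codebooks = [torch.randn(16, 256) for _ in range(2)]    # n_embed=16, groups=2
print(toy_residual_vq(torch.randn(4, 256), codebooks).shape)  # torch.Size([4, 2])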
@@ -280,8 +280,8 @@ class VQBeTModel(nn.Module):
             ),
         )
 
-    def discretize(self, discretize_step, actions):
-        return self.action_head.discretize(discretize_step, actions)
+    def discretize(self, n_vqvae_training_steps, actions):
+        return self.action_head.discretize(n_vqvae_training_steps, actions)
 
     def forward(self, batch: dict[str, Tensor], rollout: bool) -> Tensor:
         # Input validation.
@@ -378,8 +378,8 @@ class VQBeTHead(nn.Module):
         # loss
         self._focal_loss_fn = FocalLoss(gamma=2.0)
 
-    def discretize(self, discretize_step, actions):
-        loss, n_different_codes, n_different_combinations = pretrain_vqvae(self.vqvae_model, discretize_step, actions)
+    def discretize(self, n_vqvae_training_steps, actions):
+        loss, n_different_codes, n_different_combinations = pretrain_vqvae(self.vqvae_model, n_vqvae_training_steps, actions)
         return loss, n_different_codes, n_different_combinations
 
     def forward(self, x, **kwargs):
@@ -554,7 +554,7 @@ class VQBeTHead(nn.Module):
 
 class VQBeTOptimizer:
     def __init__(self, policy, cfg):
-        self.discretize_step = cfg.training.discretize_step
+        self.n_vqvae_training_steps = cfg.training.n_vqvae_training_steps
         self.offline_steps = cfg.training.offline_steps
         self.optimizing_step = 0
 
@@ -616,7 +616,7 @@ class VQBeTOptimizer:
     def step(self):
         self.optimizing_step +=1
         # pretraining VQ-VAE (Training Phase 1)
-        if self.optimizing_step < self.discretize_step:
+        if self.optimizing_step < self.n_vqvae_training_steps:
             self.vqvae_optimizer.step()
         # training BeT (Training Phase 2)
         else:
@@ -626,7 +626,7 @@ class VQBeTOptimizer:
 
     def zero_grad(self):
         # pretraining VQ-VAE (Training Phase 1)
-        if self.optimizing_step < self.discretize_step:
+        if self.optimizing_step < self.n_vqvae_training_steps:
             self.vqvae_optimizer.zero_grad()
         # training BeT (Training Phase 2)
         else:
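step() and zero_grad() share the same switch, so the optimizer is effectively a phase dispatcher. A condensed sketch of the pattern (bet_optimizer stands in for the Phase 2 optimizers, which these hunks do not show):

class TwoPhaseOptimizer:
    # Route optimizer calls by global step: VQ-VAE first, then BeT.
    def __init__(self, vqvae_optimizer, bet_optimizer, n_vqvae_training_steps):
        self.vqvae_optimizer = vqvae_optimizer
        self.bet_optimizer = bet_optimizer
        self.n_vqvae_training_steps = n_vqvae_training_steps
        self.optimizing_step = 0

    def _current(self):
        if self.optimizing_step < self.n_vqvae_training_steps:
            return self.vqvae_optimizer  # Training Phase 1
        return self.bet_optimizer        # Training Phase 2

    def step(self):
        self.optimizing_step += 1
        self._current().step()

    def zero_grad(self):
        self._current().zero_grad()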
@@ -638,7 +638,7 @@ class VQBeTScheduler:
     def __init__(self, optimizer, cfg):
         # VQ-BeT use scheduler only for rgb encoder. Since we took rgb encoder part from diffusion policy, we also follow the same scheduler from it.
         from diffusers.optimization import get_scheduler
-        self.discretize_step = cfg.training.discretize_step
+        self.n_vqvae_training_steps = cfg.training.n_vqvae_training_steps
         self.optimizing_step = 0
 
         self.lr_scheduler = get_scheduler(
@@ -651,7 +651,7 @@ class VQBeTScheduler:
 
     def step(self):
         self.optimizing_step +=1
-        if self.optimizing_step >= self.discretize_step:
+        if self.optimizing_step >= self.n_vqvae_training_steps:
             self.lr_scheduler.step()
 
 class VQBeTRgbEncoder(nn.Module):
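The scheduler is gated the same way: lr_scheduler.step() only fires once the counter passes n_vqvae_training_steps, so the rgb-encoder warmup is not consumed during Phase 1. A sketch of the get_scheduler call these hunks truncate (all argument values here are illustrative, not taken from the repo):

import torch
from diffusers.optimization import get_scheduler

optimizer = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=1e-4)
lr_scheduler = get_scheduler(
    name="cosine",             # illustrative schedule choice
    optimizer=optimizer,
    num_warmup_steps=500,      # illustrative
    num_training_steps=50000,  # illustrative
)
lr_scheduler.step()  # in VQBeTScheduler this runs only after Phase 1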
@@ -907,7 +907,7 @@ class VqVae(nn.Module):
 
 
 
-def pretrain_vqvae(vqvae_model, discretize_step, actions):
+def pretrain_vqvae(vqvae_model, n_vqvae_training_steps, actions):
     if vqvae_model.config.action_chunk_size == 1:
         # not using action chunk
         actions = actions.reshape(-1, 1, actions.shape[-1])
@@ -926,8 +926,8 @@ def pretrain_vqvae(vqvae_model, discretize_step, actions):
     n_different_codes = len(torch.unique(metric[2]))
     n_different_combinations = len(torch.unique(metric[2], dim=0))
     vqvae_model.optimized_steps += 1
-    # if we updated RVQ more than `discretize_step` steps,
-    if vqvae_model.optimized_steps >= discretize_step:
+    # if we updated RVQ more than `n_vqvae_training_steps` steps,
+    if vqvae_model.optimized_steps >= n_vqvae_training_steps:
         vqvae_model.toggle_discretized(True)
         print("Finished discretizing action data!")
         vqvae_model.eval()
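The two usage metrics are plain torch.unique calls over the code indices the RVQ emits: distinct indices anywhere, and distinct per-sample index tuples. A standalone sketch with a fake (batch, vqvae_groups) tensor standing in for metric[2]:

import torch

codes = torch.tensor([[0, 1], [0, 1], [3, 1], [0, 2],
                      [3, 1], [5, 5], [0, 1], [3, 2]])

n_different_codes = len(torch.unique(codes))               # max: vqvae_n_embed
n_different_combinations = len(torch.unique(codes, dim=0)) # max: vqvae_n_embed ** vqvae_groups
print(n_different_codes, n_different_combinations)  # 5 5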
@@ -39,7 +39,7 @@ training:
 
   # VQ-BeT specific
   vqvae_lr: 1.0e-3
-  discretize_step: 20000
+  n_vqvae_training_steps: 20000
  bet_weight_decay: 2e-4
   bet_learning_rate: 5.5e-5
   bet_betas: [0.9, 0.999]
@@ -84,7 +84,7 @@ policy:
   use_group_norm: True
   spatial_softmax_num_keypoints: 32
   # VQ-VAE
-  discretize_step: ${training.discretize_step}
+  n_vqvae_training_steps: ${training.n_vqvae_training_steps}
   vqvae_groups: 2
   vqvae_n_embed: 16
   vqvae_embedding_dim: 256
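The ${training.n_vqvae_training_steps} entry is an OmegaConf/Hydra interpolation, which is why the key must be renamed in both sections at once: the policy block resolves its value from the training block. A small sketch of that resolution using OmegaConf directly:

from omegaconf import OmegaConf

cfg = OmegaConf.create(
    """
    training:
      n_vqvae_training_steps: 20000
    policy:
      n_vqvae_training_steps: ${training.n_vqvae_training_steps}
    """
)
print(cfg.policy.n_vqvae_training_steps)  # 20000, resolved from training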