Consolidate the optimizers into one, and make the scheduler update the lr of all parameters in phase 2 together

This commit is contained in:
jayLEE0301 2024-06-08 15:34:29 -04:00
parent 930c0cf86a
commit 913489ead0
1 changed file with 75 additions and 100 deletions
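For context, the pattern this commit moves to is a single torch.optim.Adam that holds several param groups, each carrying its own lr and weight_decay. A minimal sketch of that layout (module names and lr values below are placeholders, not the actual lerobot config keys):

import torch
from torch import nn

model = nn.ModuleDict(
    {
        "gpt": nn.Linear(8, 8),    # stands in for the decay-regularized BeT parameters
        "vqvae": nn.Linear(8, 8),  # stands in for the VQ-VAE parameters
        "norm": nn.LayerNorm(8),   # stands in for the no-decay parameters
    }
)

optimizer = torch.optim.Adam(
    [
        {"params": model["gpt"].parameters(), "lr": 1e-4, "weight_decay": 1e-4},
        {"params": model["vqvae"].parameters(), "lr": 1e-3, "weight_decay": 1e-4},
        {"params": model["norm"].parameters(), "lr": 1e-4, "weight_decay": 0.0},
    ],
    lr=1e-4,  # default lr; the per-group values above take precedence
)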


@@ -449,14 +449,15 @@ class VQBeTHead(nn.Module):
sampled_offsets = cbet_offsets[indices]
# Then, sum the offsets over the RVQ layers to get a net offset for the bin prediction
sampled_offsets = sampled_offsets.sum(dim=1)
# Get the centroids of each layer to pass it through RVQ decoder
return_decoder_input = self.vqvae_model.draw_code_forward(sampled_centers).clone().detach()
# pass the centroids through decoder to get actions.
decoded_action = (
self.vqvae_model.get_action_from_latent(return_decoder_input)
.clone()
.detach()
)
with torch.no_grad():
# Get the centroids of each layer to pass it through RVQ decoder
return_decoder_input = self.vqvae_model.draw_code_forward(sampled_centers).clone().detach()
# pass the centroids through decoder to get actions.
decoded_action = (
self.vqvae_model.get_action_from_latent(return_decoder_input)
.clone()
.detach()
)
# reshaped extracted offset to match with decoded centroids
sampled_offsets = einops.rearrange(
sampled_offsets, "NT (W A) -> NT W A", W=self.config.action_chunk_size
@@ -504,9 +505,10 @@ class VQBeTHead(nn.Module):
action_seq = einops.rearrange(action_seq, "N T W A -> (N T) W A")
# Figure out the loss for the actions.
# First, we need to find the closest cluster center for each ground truth action.
state_vq, action_bins = self.vqvae_model.get_code(
action_seq
) # action_bins: NT, G
with torch.no_grad():
state_vq, action_bins = self.vqvae_model.get_code(
action_seq
) # action_bins: NT, G
# Now we can compute the loss.
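The two hunks above wrap every call into the already-trained VQ-VAE (code lookup and decoding) in torch.no_grad(): its outputs were already detached, and its weights are not updated during BeT training, so no autograd graph is needed for those passes. A self-contained sketch of the effect (frozen_decoder is a stand-in, not the actual vqvae_model):

import torch

frozen_decoder = torch.nn.Linear(16, 4)
frozen_decoder.requires_grad_(False)  # phase 2: the VQ-VAE is no longer updated

latents = torch.randn(32, 16)
with torch.no_grad():
    decoded_action = frozen_decoder(latents)  # no graph is recorded, saving memory

assert not decoded_action.requires_grad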
@@ -554,89 +556,59 @@ class VQBeTHead(nn.Module):
}
return loss_dict
class VQBeTOptimizer(nn.Module):
class VQBeTOptimizer(torch.optim.Adam):
def __init__(self, policy, cfg):
super().__init__()
self.n_vqvae_training_steps = cfg.training.n_vqvae_training_steps
self.offline_steps = cfg.training.offline_steps
self.optimizing_step = 0
vqvae_params = (
list(policy.vqbet.action_head.vqvae_model.encoder.parameters())
+ list(policy.vqbet.action_head.vqvae_model.decoder.parameters())
+ list(policy.vqbet.action_head.vqvae_model.vq_layer.parameters())
)
self.vqvae_optimizer = torch.optim.Adam(
vqvae_params, lr=cfg.training.vqvae_lr, weight_decay=0.0001
decay_params, no_decay_params = policy.vqbet.policy.configure_parameters()
decay_params = (
decay_params
+ list(policy.vqbet.rgb_encoder.parameters())
+ list(policy.vqbet.state_projector.parameters())
+ list(policy.vqbet.rgb_feature_projector.parameters())
+ [policy.vqbet._action_token]
+ list(policy.vqbet.action_head.map_to_cbet_preds_offset.parameters())
)
self.encoder_optimizer = torch.optim.Adam(
policy.vqbet.rgb_encoder.parameters(),
if cfg.policy.sequentially_select:
decay_params = (
decay_params
+ list(policy.vqbet.action_head.map_to_cbet_preds_primary_bin.parameters())
+ list(policy.vqbet.action_head.map_to_cbet_preds_secondary_bin.parameters())
)
else:
decay_params = (
decay_params
+ list(policy.vqbet.action_head.map_to_cbet_preds_bin.parameters())
)
optim_groups = [
{
"params": decay_params,
"weight_decay": cfg.training.adam_weight_decay,
"lr": cfg.training.lr,
},
{
"params": vqvae_params,
"weight_decay": 0.0001,
"lr": cfg.training.vqvae_lr,
},
{
"params": no_decay_params,
"weight_decay": 0.0,
"lr": cfg.training.lr,
},
]
super(VQBeTOptimizer, self).__init__(
optim_groups,
cfg.training.lr,
cfg.training.adam_betas,
cfg.training.adam_eps,
cfg.training.adam_weight_decay,
)
self.bet_optimizer1 = policy.vqbet.policy.configure_optimizers(
weight_decay=cfg.training.bet_weight_decay,
learning_rate=cfg.training.bet_learning_rate,
betas=cfg.training.bet_betas,
)
self.bet_optimizer1.add_param_group(
{"params": policy.vqbet._action_token}
)
self.bet_optimizer1.add_param_group(
{"params": policy.vqbet.state_projector.parameters()}
)
self.bet_optimizer1.add_param_group(
{"params": policy.vqbet.rgb_feature_projector.parameters()}
)
if cfg.policy.sequentially_select:
vqbet_head_params = (
list(policy.vqbet.action_head.map_to_cbet_preds_primary_bin.parameters())
+ list(policy.vqbet.action_head.map_to_cbet_preds_secondary_bin.parameters())
+ list(policy.vqbet.action_head.map_to_cbet_preds_offset.parameters())
)
else:
vqbet_head_params = (
list(policy.vqbet.action_head.map_to_cbet_preds_bin.parameters())
+ list(policy.vqbet.action_head.map_to_cbet_preds_offset.parameters())
)
self.bet_optimizer2 = torch.optim.AdamW(
vqbet_head_params,
lr=cfg.training.bet_learning_rate,
weight_decay=cfg.training.bet_weight_decay,
betas=cfg.training.bet_betas,
)
self.param_groups = self.encoder_optimizer.param_groups
def step(self):
self.optimizing_step +=1
# pretraining VQ-VAE (Training Phase 1)
if self.optimizing_step < self.n_vqvae_training_steps:
self.vqvae_optimizer.step()
# training BeT (Training Phase 2)
else:
self.encoder_optimizer.step()
self.bet_optimizer1.step()
self.bet_optimizer2.step()
def zero_grad(self):
# pretraining VQ-VAE (Training Phase 1)
if self.optimizing_step < self.n_vqvae_training_steps:
self.vqvae_optimizer.zero_grad()
# training BeT (Training Phase 2)
else:
self.encoder_optimizer.zero_grad()
self.bet_optimizer1.zero_grad()
self.bet_optimizer2.zero_grad()
class VQBeTScheduler(nn.Module):
def __init__(self, optimizer, cfg):
super().__init__()
@@ -647,7 +619,7 @@ class VQBeTScheduler(nn.Module):
self.lr_scheduler = get_scheduler(
cfg.training.lr_scheduler,
optimizer=optimizer.encoder_optimizer,
optimizer=optimizer,
num_warmup_steps=cfg.training.lr_warmup_steps,
num_training_steps=cfg.training.offline_steps,
)
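Pointing the scheduler at the consolidated optimizer (rather than only at the old encoder_optimizer) means every param group's lr follows the same schedule in phase 2. Assuming get_scheduler returns a standard PyTorch scheduler such as LambdaLR, each group's base lr is multiplied by the same factor, so groups with different base lrs keep their ratio. A small sketch of that behavior:

import torch

optimizer = torch.optim.Adam(
    [
        {"params": [torch.nn.Parameter(torch.zeros(1))], "lr": 1e-4},
        {"params": [torch.nn.Parameter(torch.zeros(1))], "lr": 1e-3},
    ]
)
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=lambda step: min(1.0, (step + 1) / 500)  # linear warmup
)

for _ in range(10):
    optimizer.step()
    scheduler.step()

# Both groups have been warmed up by the same factor; their 10x lr ratio is preserved.
print([group["lr"] for group in optimizer.param_groups])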
@@ -2403,7 +2375,7 @@ class GPT(nn.Module):
for block in self.transformer.h:
block.attn.bias = block.attn.bias[:, :, :gpt_block_size, :gpt_block_size]
def configure_optimizers(self, weight_decay, learning_rate, betas, optimizer="Adamw", eps=None):
def configure_parameters(self):
"""
This long function is unfortunately doing something very simple and is being very defensive:
We are separating out all parameters of the model into two buckets: those that will experience
@@ -2443,20 +2415,23 @@ class GPT(nn.Module):
)
# create the pytorch optimizer object
optim_groups = [
{
"params": [param_dict[pn] for pn in sorted(list(decay))],
"weight_decay": weight_decay,
},
{
"params": [param_dict[pn] for pn in sorted(list(no_decay))],
"weight_decay": 0.0,
},
]
if optimizer=="Adamw":
optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas)
elif optimizer=="Adam":
optimizer = torch.optim.Adam(optim_groups, lr=learning_rate, betas=betas, eps=eps)
else:
raise NotImplementedError
return optimizer
# optim_groups = [
# {
# "params": [param_dict[pn] for pn in sorted(list(decay))],
# "weight_decay": weight_decay,
# },
# {
# "params": [param_dict[pn] for pn in sorted(list(no_decay))],
# "weight_decay": 0.0,
# },
# ]
decay = [param_dict[pn] for pn in sorted(list(decay))]
no_decay = [param_dict[pn] for pn in sorted(list(no_decay))]
return decay, no_decay
# if optimizer=="Adamw":
# optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas)
# elif optimizer=="Adam":
# optimizer = torch.optim.Adam(optim_groups, lr=learning_rate, betas=betas, eps=eps)
# else:
# raise NotImplementedError
# return optimizer
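configure_parameters keeps the decay / no-decay split described in the docstring above, but now returns the two parameter lists instead of building an optimizer, so VQBeTOptimizer can fold them into its own param groups. A simplified sketch of that kind of split (a minGPT-style rule, not the exact logic of this function):

import torch
from torch import nn

def split_decay_params(model: nn.Module):
    decay, no_decay = [], []
    for module in model.modules():
        for name, param in module.named_parameters(recurse=False):
            if not param.requires_grad:
                continue
            if name.endswith("bias") or isinstance(module, (nn.LayerNorm, nn.Embedding)):
                no_decay.append(param)  # biases and norm/embedding weights: no weight decay
            elif isinstance(module, nn.Linear):
                decay.append(param)  # linear weights: weight decay
            else:
                no_decay.append(param)  # conservative default for anything else
    return decay, no_decay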