Consolidate the optimizer into one, and make the scheduler updates lr of all parameters in phase 2 together
This commit is contained in:
parent
930c0cf86a
commit
913489ead0
|
@ -449,14 +449,15 @@ class VQBeTHead(nn.Module):
|
|||
sampled_offsets = cbet_offsets[indices]
|
||||
# Then, sum the offsets over the RVQ layers to get a net offset for the bin prediction
|
||||
sampled_offsets = sampled_offsets.sum(dim=1)
|
||||
# Get the centroids of each layer to pass it through RVQ decoder
|
||||
return_decoder_input = self.vqvae_model.draw_code_forward(sampled_centers).clone().detach()
|
||||
# pass the centroids through decoder to get actions.
|
||||
decoded_action = (
|
||||
self.vqvae_model.get_action_from_latent(return_decoder_input)
|
||||
.clone()
|
||||
.detach()
|
||||
)
|
||||
with torch.no_grad():
|
||||
# Get the centroids of each layer to pass it through RVQ decoder
|
||||
return_decoder_input = self.vqvae_model.draw_code_forward(sampled_centers).clone().detach()
|
||||
# pass the centroids through decoder to get actions.
|
||||
decoded_action = (
|
||||
self.vqvae_model.get_action_from_latent(return_decoder_input)
|
||||
.clone()
|
||||
.detach()
|
||||
)
|
||||
# reshaped extracted offset to match with decoded centroids
|
||||
sampled_offsets = einops.rearrange(
|
||||
sampled_offsets, "NT (W A) -> NT W A", W=self.config.action_chunk_size
|
||||
|
@ -504,9 +505,10 @@ class VQBeTHead(nn.Module):
|
|||
action_seq = einops.rearrange(action_seq, "N T W A -> (N T) W A")
|
||||
# Figure out the loss for the actions.
|
||||
# First, we need to find the closest cluster center for each ground truth action.
|
||||
state_vq, action_bins = self.vqvae_model.get_code(
|
||||
action_seq
|
||||
) # action_bins: NT, G
|
||||
with torch.no_grad():
|
||||
state_vq, action_bins = self.vqvae_model.get_code(
|
||||
action_seq
|
||||
) # action_bins: NT, G
|
||||
|
||||
# Now we can compute the loss.
|
||||
|
||||
|
@ -554,89 +556,59 @@ class VQBeTHead(nn.Module):
|
|||
}
|
||||
return loss_dict
|
||||
|
||||
class VQBeTOptimizer(nn.Module):
|
||||
class VQBeTOptimizer(torch.optim.Adam):
|
||||
def __init__(self, policy, cfg):
|
||||
super().__init__()
|
||||
self.n_vqvae_training_steps = cfg.training.n_vqvae_training_steps
|
||||
self.offline_steps = cfg.training.offline_steps
|
||||
self.optimizing_step = 0
|
||||
|
||||
|
||||
vqvae_params = (
|
||||
list(policy.vqbet.action_head.vqvae_model.encoder.parameters())
|
||||
+ list(policy.vqbet.action_head.vqvae_model.decoder.parameters())
|
||||
+ list(policy.vqbet.action_head.vqvae_model.vq_layer.parameters())
|
||||
)
|
||||
self.vqvae_optimizer = torch.optim.Adam(
|
||||
vqvae_params, lr=cfg.training.vqvae_lr, weight_decay=0.0001
|
||||
decay_params, no_decay_params = policy.vqbet.policy.configure_parameters()
|
||||
decay_params = (
|
||||
decay_params
|
||||
+ list(policy.vqbet.rgb_encoder.parameters())
|
||||
+ list(policy.vqbet.state_projector.parameters())
|
||||
+ list(policy.vqbet.rgb_feature_projector.parameters())
|
||||
+ [policy.vqbet._action_token]
|
||||
+ list(policy.vqbet.action_head.map_to_cbet_preds_offset.parameters())
|
||||
)
|
||||
|
||||
self.encoder_optimizer = torch.optim.Adam(
|
||||
policy.vqbet.rgb_encoder.parameters(),
|
||||
if cfg.policy.sequentially_select:
|
||||
decay_params = (
|
||||
decay_params
|
||||
+ list(policy.vqbet.action_head.map_to_cbet_preds_primary_bin.parameters())
|
||||
+ list(policy.vqbet.action_head.map_to_cbet_preds_secondary_bin.parameters())
|
||||
)
|
||||
else:
|
||||
decay_params = (
|
||||
decay_params
|
||||
+ list(policy.vqbet.action_head.map_to_cbet_preds_bin.parameters())
|
||||
)
|
||||
|
||||
optim_groups = [
|
||||
{
|
||||
"params": decay_params,
|
||||
"weight_decay": cfg.training.adam_weight_decay,
|
||||
"lr": cfg.training.lr,
|
||||
},
|
||||
{
|
||||
"params": vqvae_params,
|
||||
"weight_decay": 0.0001,
|
||||
"lr": cfg.training.vqvae_lr,
|
||||
},
|
||||
{
|
||||
"params": no_decay_params,
|
||||
"weight_decay": 0.0,
|
||||
"lr": cfg.training.lr,
|
||||
},
|
||||
]
|
||||
super(VQBeTOptimizer, self).__init__(
|
||||
optim_groups,
|
||||
cfg.training.lr,
|
||||
cfg.training.adam_betas,
|
||||
cfg.training.adam_eps,
|
||||
cfg.training.adam_weight_decay,
|
||||
)
|
||||
|
||||
self.bet_optimizer1 = policy.vqbet.policy.configure_optimizers(
|
||||
weight_decay=cfg.training.bet_weight_decay,
|
||||
learning_rate=cfg.training.bet_learning_rate,
|
||||
betas=cfg.training.bet_betas,
|
||||
)
|
||||
|
||||
self.bet_optimizer1.add_param_group(
|
||||
{"params": policy.vqbet._action_token}
|
||||
)
|
||||
self.bet_optimizer1.add_param_group(
|
||||
{"params": policy.vqbet.state_projector.parameters()}
|
||||
)
|
||||
self.bet_optimizer1.add_param_group(
|
||||
{"params": policy.vqbet.rgb_feature_projector.parameters()}
|
||||
)
|
||||
|
||||
if cfg.policy.sequentially_select:
|
||||
vqbet_head_params = (
|
||||
list(policy.vqbet.action_head.map_to_cbet_preds_primary_bin.parameters())
|
||||
+ list(policy.vqbet.action_head.map_to_cbet_preds_secondary_bin.parameters())
|
||||
+ list(policy.vqbet.action_head.map_to_cbet_preds_offset.parameters())
|
||||
)
|
||||
|
||||
else:
|
||||
vqbet_head_params = (
|
||||
list(policy.vqbet.action_head.map_to_cbet_preds_bin.parameters())
|
||||
+ list(policy.vqbet.action_head.map_to_cbet_preds_offset.parameters())
|
||||
)
|
||||
self.bet_optimizer2 = torch.optim.AdamW(
|
||||
vqbet_head_params,
|
||||
lr=cfg.training.bet_learning_rate,
|
||||
weight_decay=cfg.training.bet_weight_decay,
|
||||
betas=cfg.training.bet_betas,
|
||||
)
|
||||
|
||||
self.param_groups = self.encoder_optimizer.param_groups
|
||||
|
||||
def step(self):
|
||||
self.optimizing_step +=1
|
||||
# pretraining VQ-VAE (Training Phase 1)
|
||||
if self.optimizing_step < self.n_vqvae_training_steps:
|
||||
self.vqvae_optimizer.step()
|
||||
# training BeT (Training Phase 2)
|
||||
else:
|
||||
self.encoder_optimizer.step()
|
||||
self.bet_optimizer1.step()
|
||||
self.bet_optimizer2.step()
|
||||
|
||||
def zero_grad(self):
|
||||
# pretraining VQ-VAE (Training Phase 1)
|
||||
if self.optimizing_step < self.n_vqvae_training_steps:
|
||||
self.vqvae_optimizer.zero_grad()
|
||||
# training BeT (Training Phase 2)
|
||||
else:
|
||||
self.encoder_optimizer.zero_grad()
|
||||
self.bet_optimizer1.zero_grad()
|
||||
self.bet_optimizer2.zero_grad()
|
||||
|
||||
class VQBeTScheduler(nn.Module):
|
||||
def __init__(self, optimizer, cfg):
|
||||
super().__init__()
|
||||
|
@ -647,7 +619,7 @@ class VQBeTScheduler(nn.Module):
|
|||
|
||||
self.lr_scheduler = get_scheduler(
|
||||
cfg.training.lr_scheduler,
|
||||
optimizer=optimizer.encoder_optimizer,
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=cfg.training.lr_warmup_steps,
|
||||
num_training_steps=cfg.training.offline_steps,
|
||||
)
|
||||
|
@ -2403,7 +2375,7 @@ class GPT(nn.Module):
|
|||
for block in self.transformer.h:
|
||||
block.attn.bias = block.attn.bias[:, :, :gpt_block_size, :gpt_block_size]
|
||||
|
||||
def configure_optimizers(self, weight_decay, learning_rate, betas, optimizer="Adamw", eps=None):
|
||||
def configure_parameters(self):
|
||||
"""
|
||||
This long function is unfortunately doing something very simple and is being very defensive:
|
||||
We are separating out all parameters of the model into two buckets: those that will experience
|
||||
|
@ -2443,20 +2415,23 @@ class GPT(nn.Module):
|
|||
)
|
||||
|
||||
# create the pytorch optimizer object
|
||||
optim_groups = [
|
||||
{
|
||||
"params": [param_dict[pn] for pn in sorted(list(decay))],
|
||||
"weight_decay": weight_decay,
|
||||
},
|
||||
{
|
||||
"params": [param_dict[pn] for pn in sorted(list(no_decay))],
|
||||
"weight_decay": 0.0,
|
||||
},
|
||||
]
|
||||
if optimizer=="Adamw":
|
||||
optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas)
|
||||
elif optimizer=="Adam":
|
||||
optimizer = torch.optim.Adam(optim_groups, lr=learning_rate, betas=betas, eps=eps)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return optimizer
|
||||
# optim_groups = [
|
||||
# {
|
||||
# "params": [param_dict[pn] for pn in sorted(list(decay))],
|
||||
# "weight_decay": weight_decay,
|
||||
# },
|
||||
# {
|
||||
# "params": [param_dict[pn] for pn in sorted(list(no_decay))],
|
||||
# "weight_decay": 0.0,
|
||||
# },
|
||||
# ]
|
||||
decay = [param_dict[pn] for pn in sorted(list(decay))]
|
||||
no_decay = [param_dict[pn] for pn in sorted(list(no_decay))]
|
||||
return decay, no_decay
|
||||
# if optimizer=="Adamw":
|
||||
# optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas)
|
||||
# elif optimizer=="Adam":
|
||||
# optimizer = torch.optim.Adam(optim_groups, lr=learning_rate, betas=betas, eps=eps)
|
||||
# else:
|
||||
# raise NotImplementedError
|
||||
# return optimizer
|
||||
|
|
Loading…
Reference in New Issue