Add make_optimizer

Simon Alibert 2024-04-30 11:55:37 +02:00
parent b2cda12f87
commit a541a7f3cf
1 changed file with 46 additions and 40 deletions


@@ -25,6 +25,51 @@ from lerobot.common.utils.utils import (
from lerobot.scripts.eval import eval_policy


def make_optimizer(cfg, policy):
    if cfg.policy.name == "act":
        optimizer_params_dicts = [
            {
                "params": [
                    p
                    for n, p in policy.named_parameters()
                    if not n.startswith("backbone") and p.requires_grad
                ]
            },
            {
                "params": [
                    p for n, p in policy.named_parameters() if n.startswith("backbone") and p.requires_grad
                ],
                "lr": cfg.training.lr_backbone,
            },
        ]
        optimizer = torch.optim.AdamW(
            optimizer_params_dicts, lr=cfg.training.lr, weight_decay=cfg.training.weight_decay
        )
        lr_scheduler = None
    elif cfg.policy.name == "diffusion":
        optimizer = torch.optim.Adam(
            policy.diffusion.parameters(),
            cfg.training.lr,
            cfg.training.adam_betas,
            cfg.training.adam_eps,
            cfg.training.adam_weight_decay,
        )
        assert cfg.training.online_steps == 0, "Diffusion Policy does not handle online training."
        lr_scheduler = get_scheduler(
            cfg.training.lr_scheduler,
            optimizer=optimizer,
            num_warmup_steps=cfg.training.lr_warmup_steps,
            num_training_steps=cfg.training.offline_steps,
        )
    elif policy.name == "tdmpc":
        optimizer = torch.optim.Adam(policy.parameters(), cfg.training.lr)
        lr_scheduler = None
    else:
        raise NotImplementedError()
    return optimizer, lr_scheduler


def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
    start_time = time.time()
    policy.train()
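
For context (not part of this commit): a minimal sketch of how the new helper behaves for the `act` branch. The `DummyACT` module and the OmegaConf config below are illustrative stand-ins, and the import path for `make_optimizer` is assumed from the `lerobot.scripts` imports shown above; the point is only to exercise the backbone/non-backbone parameter split and the `(optimizer, lr_scheduler)` return value.

```python
import torch
from omegaconf import OmegaConf
from torch import nn

from lerobot.scripts.train import make_optimizer  # import path assumed, not confirmed by the diff


class DummyACT(nn.Module):
    """Illustrative stand-in: any module whose vision weights live under a `backbone.` prefix."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 8)  # names start with "backbone" -> gets lr_backbone
        self.head = nn.Linear(8, 2)      # everything else -> gets the main lr


# Hypothetical config containing only the keys the "act" branch reads.
cfg = OmegaConf.create(
    {
        "policy": {"name": "act"},
        "training": {"lr": 1e-4, "lr_backbone": 1e-5, "weight_decay": 1e-4},
    }
)

optimizer, lr_scheduler = make_optimizer(cfg, DummyACT())
assert lr_scheduler is None  # the ACT branch returns no scheduler
print([group["lr"] for group in optimizer.param_groups])  # [0.0001, 1e-05]
```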
@@ -276,46 +321,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
    # Create optimizer and scheduler
    # Temporary hack to move optimizer out of policy
    if cfg.policy.name == "act":
        optimizer_params_dicts = [
            {
                "params": [
                    p
                    for n, p in policy.named_parameters()
                    if not n.startswith("backbone") and p.requires_grad
                ]
            },
            {
                "params": [
                    p for n, p in policy.named_parameters() if n.startswith("backbone") and p.requires_grad
                ],
                "lr": cfg.training.lr_backbone,
            },
        ]
        optimizer = torch.optim.AdamW(
            optimizer_params_dicts, lr=cfg.training.lr, weight_decay=cfg.training.weight_decay
        )
        lr_scheduler = None
    elif cfg.policy.name == "diffusion":
        optimizer = torch.optim.Adam(
            policy.diffusion.parameters(),
            cfg.training.lr,
            cfg.training.adam_betas,
            cfg.training.adam_eps,
            cfg.training.adam_weight_decay,
        )
        assert cfg.training.online_steps == 0, "Diffusion Policy does not handle online training."
        lr_scheduler = get_scheduler(
            cfg.training.lr_scheduler,
            optimizer=optimizer,
            num_warmup_steps=cfg.training.lr_warmup_steps,
            num_training_steps=cfg.training.offline_steps,
        )
    elif policy.name == "tdmpc":
        optimizer = torch.optim.Adam(policy.parameters(), cfg.training.lr)
        lr_scheduler = None
    else:
        raise NotImplementedError()
    optimizer, lr_scheduler = make_optimizer(cfg, policy)
    num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
    num_total_params = sum(p.numel() for p in policy.parameters())
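
As a side note, a tiny sketch (not from the commit) of what these two counters report when part of a model is frozen, e.g. a backbone with `requires_grad` turned off:

```python
from torch import nn

model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
for p in model[0].parameters():  # freeze the first layer, as one might freeze a backbone
    p.requires_grad_(False)

num_learnable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
num_total_params = sum(p.numel() for p in model.parameters())
print(num_learnable_params, num_total_params)  # 10 30
```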