Add make_optimizer
parent b2cda12f87
commit a541a7f3cf
@@ -25,6 +25,51 @@ from lerobot.common.utils.utils import (
 from lerobot.scripts.eval import eval_policy
 
 
+def make_optimizer(cfg, policy):
+    if cfg.policy.name == "act":
+        optimizer_params_dicts = [
+            {
+                "params": [
+                    p
+                    for n, p in policy.named_parameters()
+                    if not n.startswith("backbone") and p.requires_grad
+                ]
+            },
+            {
+                "params": [
+                    p for n, p in policy.named_parameters() if n.startswith("backbone") and p.requires_grad
+                ],
+                "lr": cfg.training.lr_backbone,
+            },
+        ]
+        optimizer = torch.optim.AdamW(
+            optimizer_params_dicts, lr=cfg.training.lr, weight_decay=cfg.training.weight_decay
+        )
+        lr_scheduler = None
+    elif cfg.policy.name == "diffusion":
+        optimizer = torch.optim.Adam(
+            policy.diffusion.parameters(),
+            cfg.training.lr,
+            cfg.training.adam_betas,
+            cfg.training.adam_eps,
+            cfg.training.adam_weight_decay,
+        )
+        assert cfg.training.online_steps == 0, "Diffusion Policy does not handle online training."
+        lr_scheduler = get_scheduler(
+            cfg.training.lr_scheduler,
+            optimizer=optimizer,
+            num_warmup_steps=cfg.training.lr_warmup_steps,
+            num_training_steps=cfg.training.offline_steps,
+        )
+    elif policy.name == "tdmpc":
+        optimizer = torch.optim.Adam(policy.parameters(), cfg.training.lr)
+        lr_scheduler = None
+    else:
+        raise NotImplementedError()
+
+    return optimizer, lr_scheduler
+
+
 def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
     start_time = time.time()
     policy.train()
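The ACT branch above relies on PyTorch parameter groups so the pretrained vision backbone can be fine-tuned at its own learning rate while the rest of the policy uses the main one. A minimal, self-contained sketch of that pattern follows; the toy module and the learning-rate values are illustrative assumptions, not taken from the LeRobot configs:

import torch
from torch import nn

# Toy stand-in for a policy that contains a pretrained "backbone" submodule.
policy = nn.ModuleDict({"backbone": nn.Linear(8, 8), "head": nn.Linear(8, 2)})

param_groups = [
    # Non-backbone parameters: no "lr" key, so they use the optimizer-level lr below.
    {"params": [p for n, p in policy.named_parameters() if not n.startswith("backbone")]},
    # Backbone parameters: override with a (typically smaller) learning rate.
    {"params": [p for n, p in policy.named_parameters() if n.startswith("backbone")], "lr": 1e-5},
]
optimizer = torch.optim.AdamW(param_groups, lr=1e-4, weight_decay=1e-4)

for group in optimizer.param_groups:
    print(len(group["params"]), group["lr"])  # 2 params at 1e-4, then 2 params at 1e-5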
@@ -276,46 +321,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
 
     # Create optimizer and scheduler
     # Temporary hack to move optimizer out of policy
-    if cfg.policy.name == "act":
-        optimizer_params_dicts = [
-            {
-                "params": [
-                    p
-                    for n, p in policy.named_parameters()
-                    if not n.startswith("backbone") and p.requires_grad
-                ]
-            },
-            {
-                "params": [
-                    p for n, p in policy.named_parameters() if n.startswith("backbone") and p.requires_grad
-                ],
-                "lr": cfg.training.lr_backbone,
-            },
-        ]
-        optimizer = torch.optim.AdamW(
-            optimizer_params_dicts, lr=cfg.training.lr, weight_decay=cfg.training.weight_decay
-        )
-        lr_scheduler = None
-    elif cfg.policy.name == "diffusion":
-        optimizer = torch.optim.Adam(
-            policy.diffusion.parameters(),
-            cfg.training.lr,
-            cfg.training.adam_betas,
-            cfg.training.adam_eps,
-            cfg.training.adam_weight_decay,
-        )
-        assert cfg.training.online_steps == 0, "Diffusion Policy does not handle online training."
-        lr_scheduler = get_scheduler(
-            cfg.training.lr_scheduler,
-            optimizer=optimizer,
-            num_warmup_steps=cfg.training.lr_warmup_steps,
-            num_training_steps=cfg.training.offline_steps,
-        )
-    elif policy.name == "tdmpc":
-        optimizer = torch.optim.Adam(policy.parameters(), cfg.training.lr)
-        lr_scheduler = None
-    else:
-        raise NotImplementedError()
+    optimizer, lr_scheduler = make_optimizer(cfg, policy)
 
     num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
     num_total_params = sum(p.numel() for p in policy.parameters())
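For context, the (optimizer, lr_scheduler) pair returned by make_optimizer is consumed by the training loop via update_policy, whose signature appears in the first hunk. The following is a hedged sketch of how such a pair is typically stepped; it is an illustration of the pattern, not the actual update_policy body from this file:

import torch


def training_step(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
    # Assumes the policy's forward pass returns a dict containing a scalar "loss".
    loss = policy.forward(batch)["loss"]
    loss.backward()
    torch.nn.utils.clip_grad_norm_(policy.parameters(), grad_clip_norm)
    optimizer.step()
    optimizer.zero_grad()
    # Only the diffusion branch returns a scheduler; the ACT and TD-MPC branches return None.
    if lr_scheduler is not None:
        lr_scheduler.step()
    return loss.item()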