From a541a7f3cffefc49cf776a5f864400cd7c6d1cdf Mon Sep 17 00:00:00 2001
From: Simon Alibert
Date: Tue, 30 Apr 2024 11:55:37 +0200
Subject: [PATCH] Add make_optimizer

---
 lerobot/scripts/train.py | 86 +++++++++++++++++++++-------------------
 1 file changed, 46 insertions(+), 40 deletions(-)

diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index bd27b28a..a7c5d168 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -25,6 +25,51 @@ from lerobot.common.utils.utils import (
 from lerobot.scripts.eval import eval_policy
 
 
+def make_optimizer(cfg, policy):
+    if cfg.policy.name == "act":
+        optimizer_params_dicts = [
+            {
+                "params": [
+                    p
+                    for n, p in policy.named_parameters()
+                    if not n.startswith("backbone") and p.requires_grad
+                ]
+            },
+            {
+                "params": [
+                    p for n, p in policy.named_parameters() if n.startswith("backbone") and p.requires_grad
+                ],
+                "lr": cfg.training.lr_backbone,
+            },
+        ]
+        optimizer = torch.optim.AdamW(
+            optimizer_params_dicts, lr=cfg.training.lr, weight_decay=cfg.training.weight_decay
+        )
+        lr_scheduler = None
+    elif cfg.policy.name == "diffusion":
+        optimizer = torch.optim.Adam(
+            policy.diffusion.parameters(),
+            cfg.training.lr,
+            cfg.training.adam_betas,
+            cfg.training.adam_eps,
+            cfg.training.adam_weight_decay,
+        )
+        assert cfg.training.online_steps == 0, "Diffusion Policy does not handle online training."
+        lr_scheduler = get_scheduler(
+            cfg.training.lr_scheduler,
+            optimizer=optimizer,
+            num_warmup_steps=cfg.training.lr_warmup_steps,
+            num_training_steps=cfg.training.offline_steps,
+        )
+    elif policy.name == "tdmpc":
+        optimizer = torch.optim.Adam(policy.parameters(), cfg.training.lr)
+        lr_scheduler = None
+    else:
+        raise NotImplementedError()
+
+    return optimizer, lr_scheduler
+
+
 def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
     start_time = time.time()
     policy.train()
@@ -276,46 +321,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
 
     # Create optimizer and scheduler
     # Temporary hack to move optimizer out of policy
-    if cfg.policy.name == "act":
-        optimizer_params_dicts = [
-            {
-                "params": [
-                    p
-                    for n, p in policy.named_parameters()
-                    if not n.startswith("backbone") and p.requires_grad
-                ]
-            },
-            {
-                "params": [
-                    p for n, p in policy.named_parameters() if n.startswith("backbone") and p.requires_grad
-                ],
-                "lr": cfg.training.lr_backbone,
-            },
-        ]
-        optimizer = torch.optim.AdamW(
-            optimizer_params_dicts, lr=cfg.training.lr, weight_decay=cfg.training.weight_decay
-        )
-        lr_scheduler = None
-    elif cfg.policy.name == "diffusion":
-        optimizer = torch.optim.Adam(
-            policy.diffusion.parameters(),
-            cfg.training.lr,
-            cfg.training.adam_betas,
-            cfg.training.adam_eps,
-            cfg.training.adam_weight_decay,
-        )
-        assert cfg.training.online_steps == 0, "Diffusion Policy does not handle online training."
-        lr_scheduler = get_scheduler(
-            cfg.training.lr_scheduler,
-            optimizer=optimizer,
-            num_warmup_steps=cfg.training.lr_warmup_steps,
-            num_training_steps=cfg.training.offline_steps,
-        )
-    elif policy.name == "tdmpc":
-        optimizer = torch.optim.Adam(policy.parameters(), cfg.training.lr)
-        lr_scheduler = None
-    else:
-        raise NotImplementedError()
+    optimizer, lr_scheduler = make_optimizer(cfg, policy)
 
     num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
     num_total_params = sum(p.numel() for p in policy.parameters())