From bbce0eaeafc93238430025e373a494819d6f0128 Mon Sep 17 00:00:00 2001
From: Michel Aractingi
Date: Mon, 2 Sep 2024 07:53:10 +0000
Subject: [PATCH] Move make_optimizer_and_scheduler into the policy classes

---
 lerobot/common/policies/act/modeling_act.py  | 24 ++++++++
 .../policies/diffusion/modeling_diffusion.py | 30 ++++++++++
 .../common/policies/tdmpc/modeling_tdmpc.py  |  7 +++
 lerobot/scripts/train.py                     | 55 +------------------
 4 files changed, 62 insertions(+), 54 deletions(-)

diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py
index 3427c482..ea8fdc2d 100644
--- a/lerobot/common/policies/act/modeling_act.py
+++ b/lerobot/common/policies/act/modeling_act.py
@@ -160,6 +160,30 @@ class ACTPolicy(
 
         return loss_dict
 
+    def make_optimizer_and_scheduler(self, **kwargs):
+        """Create the optimizer and learning rate scheduler for ACT."""
+        lr, lr_backbone, weight_decay = kwargs["lr"], kwargs["lr_backbone"], kwargs["weight_decay"]
+        optimizer_params_dicts = [
+            {
+                "params": [
+                    p
+                    for n, p in self.named_parameters()
+                    if not n.startswith("model.backbone") and p.requires_grad
+                ]
+            },
+            {
+                "params": [
+                    p
+                    for n, p in self.named_parameters()
+                    if n.startswith("model.backbone") and p.requires_grad
+                ],
+                "lr": lr_backbone,
+            },
+        ]
+        optimizer = torch.optim.AdamW(optimizer_params_dicts, lr=lr, weight_decay=weight_decay)
+        lr_scheduler = None
+        return optimizer, lr_scheduler
+
 
 class ACTTemporalEnsembler:
     def __init__(self, temporal_ensemble_coeff: float, chunk_size: int) -> None:
diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py
index 308a8be3..0093e451 100644
--- a/lerobot/common/policies/diffusion/modeling_diffusion.py
+++ b/lerobot/common/policies/diffusion/modeling_diffusion.py
@@ -156,6 +156,36 @@ class DiffusionPolicy(
         loss = self.diffusion.compute_loss(batch)
         return {"loss": loss}
 
+    def make_optimizer_and_scheduler(self, **kwargs):
+        """Create the optimizer and learning rate scheduler for the diffusion policy."""
+        lr, adam_betas, adam_eps, adam_weight_decay = (
+            kwargs["lr"],
+            kwargs["adam_betas"],
+            kwargs["adam_eps"],
+            kwargs["adam_weight_decay"],
+        )
+        lr_scheduler_name, lr_warmup_steps, offline_steps = (
+            kwargs["lr_scheduler"],
+            kwargs["lr_warmup_steps"],
+            kwargs["offline_steps"],
+        )
+        optimizer = torch.optim.Adam(
+            self.diffusion.parameters(),
+            lr,
+            adam_betas,
+            adam_eps,
+            adam_weight_decay,
+        )
+        from diffusers.optimization import get_scheduler
+
+        lr_scheduler = get_scheduler(
+            lr_scheduler_name,
+            optimizer=optimizer,
+            num_warmup_steps=lr_warmup_steps,
+            num_training_steps=offline_steps,
+        )
+        return optimizer, lr_scheduler
+
 
 def _make_noise_scheduler(name: str, **kwargs: dict) -> DDPMScheduler | DDIMScheduler:
     """
diff --git a/lerobot/common/policies/tdmpc/modeling_tdmpc.py b/lerobot/common/policies/tdmpc/modeling_tdmpc.py
index d97c4824..9e988c20 100644
--- a/lerobot/common/policies/tdmpc/modeling_tdmpc.py
+++ b/lerobot/common/policies/tdmpc/modeling_tdmpc.py
@@ -534,6 +534,13 @@ class TDMPCPolicy(
         # we update every step and adjust the decay parameter `alpha` accordingly (0.99 -> 0.995)
         update_ema_parameters(self.model_target, self.model, self.config.target_model_momentum)
 
+    def make_optimizer_and_scheduler(self, **kwargs):
+        """Create the optimizer and learning rate scheduler for TD-MPC."""
+        lr = kwargs["lr"]
+        optimizer = torch.optim.Adam(self.parameters(), lr)
+        lr_scheduler = None
+        return optimizer, lr_scheduler
+
 
 class TDMPCTOLD(nn.Module):
     """Task-Oriented Latent Dynamics (TOLD) model used in TD-MPC."""
diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index 45807503..e2cf55d6 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -51,59 +51,6 @@ from lerobot.common.utils.utils import (
 from lerobot.scripts.eval import eval_policy
 
 
-def make_optimizer_and_scheduler(cfg, policy):
-    if cfg.policy.name == "act":
-        optimizer_params_dicts = [
-            {
-                "params": [
-                    p
-                    for n, p in policy.named_parameters()
-                    if not n.startswith("model.backbone") and p.requires_grad
-                ]
-            },
-            {
-                "params": [
-                    p
-                    for n, p in policy.named_parameters()
-                    if n.startswith("model.backbone") and p.requires_grad
-                ],
-                "lr": cfg.training.lr_backbone,
-            },
-        ]
-        optimizer = torch.optim.AdamW(
-            optimizer_params_dicts, lr=cfg.training.lr, weight_decay=cfg.training.weight_decay
-        )
-        lr_scheduler = None
-    elif cfg.policy.name == "diffusion":
-        optimizer = torch.optim.Adam(
-            policy.diffusion.parameters(),
-            cfg.training.lr,
-            cfg.training.adam_betas,
-            cfg.training.adam_eps,
-            cfg.training.adam_weight_decay,
-        )
-        from diffusers.optimization import get_scheduler
-
-        lr_scheduler = get_scheduler(
-            cfg.training.lr_scheduler,
-            optimizer=optimizer,
-            num_warmup_steps=cfg.training.lr_warmup_steps,
-            num_training_steps=cfg.training.offline_steps,
-        )
-    elif policy.name == "tdmpc":
-        optimizer = torch.optim.Adam(policy.parameters(), cfg.training.lr)
-        lr_scheduler = None
-    elif cfg.policy.name == "vqbet":
-        from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTOptimizer, VQBeTScheduler
-
-        optimizer = VQBeTOptimizer(policy, cfg)
-        lr_scheduler = VQBeTScheduler(optimizer, cfg)
-    else:
-        raise NotImplementedError()
-
-    return optimizer, lr_scheduler
-
-
 def update_policy(
     policy,
     batch,
@@ -334,7 +281,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
     assert isinstance(policy, nn.Module)
     # Create optimizer and scheduler
     # Temporary hack to move optimizer out of policy
-    optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
+    optimizer, lr_scheduler = policy.make_optimizer_and_scheduler(**cfg.training)
     grad_scaler = GradScaler(enabled=cfg.use_amp)
 
     step = 0  # number of policy updates (forward + backward + optim)
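Reviewer note (not part of the patch): the refactor above gives every policy a make_optimizer_and_scheduler(**kwargs) method that returns an (optimizer, lr_scheduler) pair, where the scheduler may be None (ACT, TD-MPC), and train.py now forwards cfg.training as keyword arguments instead of switching on cfg.policy.name; because **cfg.training expands the whole training config, each policy simply reads the keys it needs and ignores the rest. The sketch below exercises that contract with a toy module; ToyPolicy and its parameter names are illustrative stand-ins and do not exist in lerobot.

import torch
from torch import nn


class ToyPolicy(nn.Module):
    """Stand-in policy (illustrative only) implementing the new interface."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 8)  # plays the role of model.backbone in ACT
        self.head = nn.Linear(8, 2)

    def make_optimizer_and_scheduler(self, **kwargs):
        """Mirror the ACT branch: a separate learning rate for backbone parameters."""
        lr, lr_backbone, weight_decay = kwargs["lr"], kwargs["lr_backbone"], kwargs["weight_decay"]
        param_groups = [
            {"params": [p for n, p in self.named_parameters() if not n.startswith("backbone")]},
            {"params": [p for n, p in self.named_parameters() if n.startswith("backbone")], "lr": lr_backbone},
        ]
        optimizer = torch.optim.AdamW(param_groups, lr=lr, weight_decay=weight_decay)
        return optimizer, None  # no scheduler, as in the ACT and TD-MPC branches


policy = ToyPolicy()
# train.py now delegates construction to the policy: policy.make_optimizer_and_scheduler(**cfg.training)
optimizer, lr_scheduler = policy.make_optimizer_and_scheduler(lr=1e-5, lr_backbone=1e-5, weight_decay=1e-4)
if lr_scheduler is not None:  # schedulers stay optional under this contract
    lr_scheduler.step()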