moved make optimizer and scheduler function to inside policy
This commit is contained in:
parent
c0da806232
commit
bbce0eaeaf
|
@ -160,6 +160,30 @@ class ACTPolicy(
|
|||
|
||||
return loss_dict
|
||||
|
||||
def make_optimizer_and_scheduler(self, **kwargs):
|
||||
"""Create the optimizer and learning rate scheduler for ACT"""
|
||||
lr, lr_backbone, weight_decay = kwargs["lr"], kwargs["lr_backbone"], kwargs["weight_decay"]
|
||||
optimizer_params_dicts = [
|
||||
{
|
||||
"params": [
|
||||
p
|
||||
for n, p in self.named_parameters()
|
||||
if not n.startswith("model.backbone") and p.requires_grad
|
||||
]
|
||||
},
|
||||
{
|
||||
"params": [
|
||||
p
|
||||
for n, p in self.named_parameters()
|
||||
if n.startswith("model.backbone") and p.requires_grad
|
||||
],
|
||||
"lr": lr_backbone,
|
||||
},
|
||||
]
|
||||
optimizer = torch.optim.AdamW(optimizer_params_dicts, lr=lr, weight_decay=weight_decay)
|
||||
lr_scheduler = None
|
||||
return optimizer, lr_scheduler
|
||||
|
||||
|
||||
class ACTTemporalEnsembler:
|
||||
def __init__(self, temporal_ensemble_coeff: float, chunk_size: int) -> None:
|
||||
|
|
|
@ -156,6 +156,36 @@ class DiffusionPolicy(
|
|||
loss = self.diffusion.compute_loss(batch)
|
||||
return {"loss": loss}
|
||||
|
||||
def make_optimizer_and_scheduler(self, **kwargs):
|
||||
"""Create the optimizer and learning rate scheduler for Diffusion policy"""
|
||||
lr, adam_betas, adam_eps, adam_weight_decay = (
|
||||
kwargs["lr"],
|
||||
kwargs["adam_betas"],
|
||||
kwargs["adam_eps"],
|
||||
kwargs["adam_weight_decay"],
|
||||
)
|
||||
lr_scheduler_name, lr_warmup_steps, offline_steps = (
|
||||
kwargs["lr_scheduler"],
|
||||
kwargs["lr_warmup_steps"],
|
||||
kwargs["offline_steps"],
|
||||
)
|
||||
optimizer = torch.optim.Adam(
|
||||
self.diffusion.parameters(),
|
||||
lr,
|
||||
adam_betas,
|
||||
adam_eps,
|
||||
adam_weight_decay,
|
||||
)
|
||||
from diffusers.optimization import get_scheduler
|
||||
|
||||
lr_scheduler = get_scheduler(
|
||||
lr_scheduler_name,
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=lr_warmup_steps,
|
||||
num_training_steps=offline_steps,
|
||||
)
|
||||
return optimizer, lr_scheduler
|
||||
|
||||
|
||||
def _make_noise_scheduler(name: str, **kwargs: dict) -> DDPMScheduler | DDIMScheduler:
|
||||
"""
|
||||
|
|
|
@ -534,6 +534,13 @@ class TDMPCPolicy(
|
|||
# we update every step and adjust the decay parameter `alpha` accordingly (0.99 -> 0.995)
|
||||
update_ema_parameters(self.model_target, self.model, self.config.target_model_momentum)
|
||||
|
||||
def make_optimizer_and_scheduler(self, **kwargs):
|
||||
"""Create the optimizer and learning rate scheduler for TD-MPC"""
|
||||
lr = kwargs["lr"]
|
||||
optimizer = torch.optim.Adam(self.parameters(), lr)
|
||||
lr_scheduler = None
|
||||
return optimizer, lr_scheduler
|
||||
|
||||
|
||||
class TDMPCTOLD(nn.Module):
|
||||
"""Task-Oriented Latent Dynamics (TOLD) model used in TD-MPC."""
|
||||
|
|
|
@ -51,59 +51,6 @@ from lerobot.common.utils.utils import (
|
|||
from lerobot.scripts.eval import eval_policy
|
||||
|
||||
|
||||
def make_optimizer_and_scheduler(cfg, policy):
|
||||
if cfg.policy.name == "act":
|
||||
optimizer_params_dicts = [
|
||||
{
|
||||
"params": [
|
||||
p
|
||||
for n, p in policy.named_parameters()
|
||||
if not n.startswith("model.backbone") and p.requires_grad
|
||||
]
|
||||
},
|
||||
{
|
||||
"params": [
|
||||
p
|
||||
for n, p in policy.named_parameters()
|
||||
if n.startswith("model.backbone") and p.requires_grad
|
||||
],
|
||||
"lr": cfg.training.lr_backbone,
|
||||
},
|
||||
]
|
||||
optimizer = torch.optim.AdamW(
|
||||
optimizer_params_dicts, lr=cfg.training.lr, weight_decay=cfg.training.weight_decay
|
||||
)
|
||||
lr_scheduler = None
|
||||
elif cfg.policy.name == "diffusion":
|
||||
optimizer = torch.optim.Adam(
|
||||
policy.diffusion.parameters(),
|
||||
cfg.training.lr,
|
||||
cfg.training.adam_betas,
|
||||
cfg.training.adam_eps,
|
||||
cfg.training.adam_weight_decay,
|
||||
)
|
||||
from diffusers.optimization import get_scheduler
|
||||
|
||||
lr_scheduler = get_scheduler(
|
||||
cfg.training.lr_scheduler,
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=cfg.training.lr_warmup_steps,
|
||||
num_training_steps=cfg.training.offline_steps,
|
||||
)
|
||||
elif policy.name == "tdmpc":
|
||||
optimizer = torch.optim.Adam(policy.parameters(), cfg.training.lr)
|
||||
lr_scheduler = None
|
||||
elif cfg.policy.name == "vqbet":
|
||||
from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTOptimizer, VQBeTScheduler
|
||||
|
||||
optimizer = VQBeTOptimizer(policy, cfg)
|
||||
lr_scheduler = VQBeTScheduler(optimizer, cfg)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
return optimizer, lr_scheduler
|
||||
|
||||
|
||||
def update_policy(
|
||||
policy,
|
||||
batch,
|
||||
|
@ -334,7 +281,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
|||
assert isinstance(policy, nn.Module)
|
||||
# Create optimizer and scheduler
|
||||
# Temporary hack to move optimizer out of policy
|
||||
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
|
||||
optimizer, lr_scheduler = policy.make_optimizer_and_scheduler(**cfg.training)
|
||||
grad_scaler = GradScaler(enabled=cfg.use_amp)
|
||||
|
||||
step = 0 # number of policy updates (forward + backward + optim)
|
||||
|
|
Loading…
Reference in New Issue