Bug fix: fix setting different learning rates between backbone and main model in ACT policy (#280)
This commit is contained in:
parent
b72d574891
commit
11f1cb5dc9
|
@ -53,12 +53,14 @@ def make_optimizer_and_scheduler(cfg, policy):
|
|||
"params": [
|
||||
p
|
||||
for n, p in policy.named_parameters()
|
||||
if not n.startswith("backbone") and p.requires_grad
|
||||
if not n.startswith("model.backbone") and p.requires_grad
|
||||
]
|
||||
},
|
||||
{
|
||||
"params": [
|
||||
p for n, p in policy.named_parameters() if n.startswith("backbone") and p.requires_grad
|
||||
p
|
||||
for n, p in policy.named_parameters()
|
||||
if n.startswith("model.backbone") and p.requires_grad
|
||||
],
|
||||
"lr": cfg.training.lr_backbone,
|
||||
},
|
||||
|
|
|
@ -30,6 +30,7 @@ from lerobot.common.policies.factory import get_policy_and_config_classes, make_
|
|||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.policy_protocol import Policy
|
||||
from lerobot.common.utils.utils import init_hydra_config
|
||||
from lerobot.scripts.train import make_optimizer_and_scheduler
|
||||
from tests.scripts.save_policy_to_safetensors import get_policy_stats
|
||||
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_cpu, require_env, require_x86_64_kernel
|
||||
|
||||
|
@ -174,6 +175,33 @@ def test_policy(env_name, policy_name, extra_overrides):
|
|||
env.step(action)
|
||||
|
||||
|
||||
def test_act_backbone_lr():
|
||||
"""
|
||||
Test that the ACT policy can be instantiated with a different learning rate for the backbone.
|
||||
"""
|
||||
cfg = init_hydra_config(
|
||||
DEFAULT_CONFIG_PATH,
|
||||
overrides=[
|
||||
"env=aloha",
|
||||
"policy=act",
|
||||
f"device={DEVICE}",
|
||||
"training.lr_backbone=0.001",
|
||||
"training.lr=0.01",
|
||||
],
|
||||
)
|
||||
assert cfg.training.lr == 0.01
|
||||
assert cfg.training.lr_backbone == 0.001
|
||||
|
||||
dataset = make_dataset(cfg)
|
||||
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)
|
||||
optimizer, _ = make_optimizer_and_scheduler(cfg, policy)
|
||||
assert len(optimizer.param_groups) == 2
|
||||
assert optimizer.param_groups[0]["lr"] == cfg.training.lr
|
||||
assert optimizer.param_groups[1]["lr"] == cfg.training.lr_backbone
|
||||
assert len(optimizer.param_groups[0]["params"]) == 133
|
||||
assert len(optimizer.param_groups[1]["params"]) == 20
|
||||
|
||||
|
||||
@pytest.mark.parametrize("policy_name", available_policies)
|
||||
def test_policy_defaults(policy_name: str):
|
||||
"""Check that the policy can be instantiated with defaults."""
|
||||
|
|
Loading…
Reference in New Issue