Bug fix: fix setting different learning rates between backbone and main model in ACT policy (#280)

This commit is contained in:
Thomas Wolf 2024-06-18 14:31:35 +02:00 committed by GitHub
parent b72d574891
commit 11f1cb5dc9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 32 additions and 2 deletions

View File

@@ -53,12 +53,14 @@ def make_optimizer_and_scheduler(cfg, policy):
"params": [
p
for n, p in policy.named_parameters()
if not n.startswith("backbone") and p.requires_grad
if not n.startswith("model.backbone") and p.requires_grad
]
},
{
"params": [
p for n, p in policy.named_parameters() if n.startswith("backbone") and p.requires_grad
p
for n, p in policy.named_parameters()
if n.startswith("model.backbone") and p.requires_grad
],
"lr": cfg.training.lr_backbone,
},

View File

@@ -30,6 +30,7 @@ from lerobot.common.policies.factory import get_policy_and_config_classes, make_
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.policy_protocol import Policy
from lerobot.common.utils.utils import init_hydra_config
from lerobot.scripts.train import make_optimizer_and_scheduler
from tests.scripts.save_policy_to_safetensors import get_policy_stats
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_cpu, require_env, require_x86_64_kernel
@@ -174,6 +175,33 @@ def test_policy(env_name, policy_name, extra_overrides):
env.step(action)
def test_act_backbone_lr():
    """Check that the ACT optimizer uses a separate learning rate for the backbone.

    Builds a Hydra config for ACT on the aloha env with distinct values for
    ``training.lr`` and ``training.lr_backbone``, creates the policy and its
    optimizer, and verifies that the optimizer exposes exactly two parameter
    groups carrying the two configured learning rates.
    """
    hydra_overrides = [
        "env=aloha",
        "policy=act",
        f"device={DEVICE}",
        "training.lr_backbone=0.001",
        "training.lr=0.01",
    ]
    cfg = init_hydra_config(DEFAULT_CONFIG_PATH, overrides=hydra_overrides)

    # Sanity check: the overrides must have reached the resolved config,
    # otherwise the learning-rate comparisons below would prove nothing.
    assert cfg.training.lr == 0.01
    assert cfg.training.lr_backbone == 0.001

    stats = make_dataset(cfg).stats
    act_policy = make_policy(hydra_cfg=cfg, dataset_stats=stats)
    optim, _scheduler = make_optimizer_and_scheduler(cfg, act_policy)

    groups = optim.param_groups
    assert len(groups) == 2

    # Only unpack after the length assertion so that a wrong group count is
    # reported as an assertion failure rather than an unpacking error.
    first_group, second_group = groups
    assert first_group["lr"] == cfg.training.lr
    assert second_group["lr"] == cfg.training.lr_backbone

    # Expected parameter counts per group; presumably tied to the ACT
    # architecture built from this config -- update if the model changes.
    assert len(first_group["params"]) == 133
    assert len(second_group["params"]) == 20
@pytest.mark.parametrize("policy_name", available_policies)
def test_policy_defaults(policy_name: str):
"""Check that the policy can be instantiated with defaults."""