added optimizer and sac to factory.py

This commit is contained in:
Michel Aractingi 2024-12-23 14:12:03 +01:00
parent b53d6e0ff2
commit 08ec971086
3 changed files with 16 additions and 0 deletions

View File

@ -66,6 +66,12 @@ def get_policy_and_config_classes(name: str) -> tuple[Policy, object]:
from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTPolicy from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTPolicy
return VQBeTPolicy, VQBeTConfig return VQBeTPolicy, VQBeTConfig
elif name == "sac":
from lerobot.common.policies.sac.configuration_sac import SACConfig
from lerobot.common.policies.sac.modeling_sac import SACPolicy
return SACPolicy, SACConfig
else: else:
raise NotImplementedError(f"Policy with name {name} is not implemented.") raise NotImplementedError(f"Policy with name {name} is not implemented.")

View File

@ -26,6 +26,7 @@ class SACConfig:
num_subsample_critics = None num_subsample_critics = None
critic_lr = 3e-4 critic_lr = 3e-4
actor_lr = 3e-4 actor_lr = 3e-4
temperature_lr = 3e-4
critic_target_update_weight = 0.005 critic_target_update_weight = 0.005
utd_ratio = 2 utd_ratio = 2
critic_network_kwargs = { critic_network_kwargs = {

View File

@ -93,6 +93,15 @@ def make_optimizer_and_scheduler(cfg, policy):
elif policy.name == "tdmpc": elif policy.name == "tdmpc":
optimizer = torch.optim.Adam(policy.parameters(), cfg.training.lr) optimizer = torch.optim.Adam(policy.parameters(), cfg.training.lr)
lr_scheduler = None lr_scheduler = None
elif policy.name == "sac":
optimizer = torch.optim.Adam([
{'params': policy.actor.parameters(), 'lr': policy.config.actor_lr},
{'params': policy.critic_ensemble.parameters(), 'lr': policy.config.critic_lr},
{'params': policy.temperature.parameters(), 'lr': policy.config.temperature_lr},
])
lr_scheduler = None
elif cfg.policy.name == "vqbet": elif cfg.policy.name == "vqbet":
from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTOptimizer, VQBeTScheduler from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTOptimizer, VQBeTScheduler