From 08ec971086488277fc8745bc5c11a445e46c51ea Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Mon, 23 Dec 2024 14:12:03 +0100 Subject: [PATCH] added optimizer and sac to factory.py --- lerobot/common/policies/factory.py | 6 ++++++ lerobot/common/policies/sac/configuration_sac.py | 1 + lerobot/scripts/train.py | 9 +++++++++ 3 files changed, 16 insertions(+) diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py index 5cb2fd52..7f550d90 100644 --- a/lerobot/common/policies/factory.py +++ b/lerobot/common/policies/factory.py @@ -66,6 +66,12 @@ def get_policy_and_config_classes(name: str) -> tuple[Policy, object]: from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTPolicy return VQBeTPolicy, VQBeTConfig + elif name == "sac": + from lerobot.common.policies.sac.configuration_sac import SACConfig + from lerobot.common.policies.sac.modeling_sac import SACPolicy + + return SACPolicy, SACConfig + else: raise NotImplementedError(f"Policy with name {name} is not implemented.") diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py index d324462e..6db198e8 100644 --- a/lerobot/common/policies/sac/configuration_sac.py +++ b/lerobot/common/policies/sac/configuration_sac.py @@ -26,6 +26,7 @@ class SACConfig: num_subsample_critics = None critic_lr = 3e-4 actor_lr = 3e-4 + temperature_lr = 3e-4 critic_target_update_weight = 0.005 utd_ratio = 2 critic_network_kwargs = { diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 9a0b7e4c..346c3acd 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -93,6 +93,15 @@ def make_optimizer_and_scheduler(cfg, policy): elif policy.name == "tdmpc": optimizer = torch.optim.Adam(policy.parameters(), cfg.training.lr) lr_scheduler = None + + elif policy.name == "sac": + optimizer = torch.optim.Adam([ + {'params': policy.actor.parameters(), 'lr': policy.config.actor_lr}, + {'params': policy.critic_ensemble.parameters(), 'lr': policy.config.critic_lr}, + {'params': policy.temperature.parameters(), 'lr': policy.config.temperature_lr}, + ]) + lr_scheduler = None + elif cfg.policy.name == "vqbet": from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTOptimizer, VQBeTScheduler