added optimizer and sac to factory.py
This commit is contained in:
parent
b53d6e0ff2
commit
08ec971086
|
@ -66,6 +66,12 @@ def get_policy_and_config_classes(name: str) -> tuple[Policy, object]:
|
||||||
from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTPolicy
|
from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTPolicy
|
||||||
|
|
||||||
return VQBeTPolicy, VQBeTConfig
|
return VQBeTPolicy, VQBeTConfig
|
||||||
|
elif name == "sac":
|
||||||
|
from lerobot.common.policies.sac.configuration_sac import SACConfig
|
||||||
|
from lerobot.common.policies.sac.modeling_sac import SACPolicy
|
||||||
|
|
||||||
|
return SACPolicy, SACConfig
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"Policy with name {name} is not implemented.")
|
raise NotImplementedError(f"Policy with name {name} is not implemented.")
|
||||||
|
|
||||||
|
|
|
@ -26,6 +26,7 @@ class SACConfig:
|
||||||
num_subsample_critics = None
|
num_subsample_critics = None
|
||||||
critic_lr = 3e-4
|
critic_lr = 3e-4
|
||||||
actor_lr = 3e-4
|
actor_lr = 3e-4
|
||||||
|
temperature_lr = 3e-4
|
||||||
critic_target_update_weight = 0.005
|
critic_target_update_weight = 0.005
|
||||||
utd_ratio = 2
|
utd_ratio = 2
|
||||||
critic_network_kwargs = {
|
critic_network_kwargs = {
|
||||||
|
|
|
@ -93,6 +93,15 @@ def make_optimizer_and_scheduler(cfg, policy):
|
||||||
elif policy.name == "tdmpc":
|
elif policy.name == "tdmpc":
|
||||||
optimizer = torch.optim.Adam(policy.parameters(), cfg.training.lr)
|
optimizer = torch.optim.Adam(policy.parameters(), cfg.training.lr)
|
||||||
lr_scheduler = None
|
lr_scheduler = None
|
||||||
|
|
||||||
|
elif policy.name == "sac":
|
||||||
|
optimizer = torch.optim.Adam([
|
||||||
|
{'params': policy.actor.parameters(), 'lr': policy.config.actor_lr},
|
||||||
|
{'params': policy.critic_ensemble.parameters(), 'lr': policy.config.critic_lr},
|
||||||
|
{'params': policy.temperature.parameters(), 'lr': policy.config.temperature_lr},
|
||||||
|
])
|
||||||
|
lr_scheduler = None
|
||||||
|
|
||||||
elif cfg.policy.name == "vqbet":
|
elif cfg.policy.name == "vqbet":
|
||||||
from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTOptimizer, VQBeTScheduler
|
from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTOptimizer, VQBeTScheduler
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue