Fix SAC convergence: applying torch.compile multiple times to the same model caused training to diverge

This commit is contained in:
AdilZouitine 2025-03-31 13:54:21 +00:00
parent 8494634d48
commit 026ad463a9
3 changed files with 1 addition and 5 deletions

View File

@ -112,7 +112,6 @@ class SACPolicy(
self.critic_ensemble = torch.compile(self.critic_ensemble)
self.critic_target = torch.compile(self.critic_target)
self.actor = Policy(
encoder=encoder_actor,
network=MLP(input_dim=encoder_actor.output_dim, **asdict(config.actor_network_kwargs)),

View File

@ -231,7 +231,6 @@ def act_with_policy(
cfg=cfg.policy,
env_cfg=cfg.env,
)
policy = torch.compile(policy)
assert isinstance(policy, nn.Module)
obs, info = online_env.reset()

View File

@ -286,8 +286,6 @@ def add_actor_information_and_train(
env_cfg=cfg.env,
)
# compile policy
policy = torch.compile(policy)
assert isinstance(policy, nn.Module)
policy.train()