Fix convergence of sac, multiple torch compile on the same model caused divergence
This commit is contained in:
parent
8494634d48
commit
026ad463a9
|
@ -112,7 +112,6 @@ class SACPolicy(
|
|||
|
||||
self.critic_ensemble = torch.compile(self.critic_ensemble)
|
||||
self.critic_target = torch.compile(self.critic_target)
|
||||
|
||||
self.actor = Policy(
|
||||
encoder=encoder_actor,
|
||||
network=MLP(input_dim=encoder_actor.output_dim, **asdict(config.actor_network_kwargs)),
|
||||
|
|
|
@ -231,7 +231,6 @@ def act_with_policy(
|
|||
cfg=cfg.policy,
|
||||
env_cfg=cfg.env,
|
||||
)
|
||||
policy = torch.compile(policy)
|
||||
assert isinstance(policy, nn.Module)
|
||||
|
||||
obs, info = online_env.reset()
|
||||
|
|
|
@ -286,8 +286,6 @@ def add_actor_information_and_train(
|
|||
env_cfg=cfg.env,
|
||||
)
|
||||
|
||||
# compile policy
|
||||
policy = torch.compile(policy)
|
||||
assert isinstance(policy, nn.Module)
|
||||
policy.train()
|
||||
|
||||
|
|
Loading…
Reference in New Issue