Fix convergence of sac, multiple torch compile on the same model caused divergence
This commit is contained in:
parent
8494634d48
commit
026ad463a9
|
@ -112,7 +112,6 @@ class SACPolicy(
|
||||||
|
|
||||||
self.critic_ensemble = torch.compile(self.critic_ensemble)
|
self.critic_ensemble = torch.compile(self.critic_ensemble)
|
||||||
self.critic_target = torch.compile(self.critic_target)
|
self.critic_target = torch.compile(self.critic_target)
|
||||||
|
|
||||||
self.actor = Policy(
|
self.actor = Policy(
|
||||||
encoder=encoder_actor,
|
encoder=encoder_actor,
|
||||||
network=MLP(input_dim=encoder_actor.output_dim, **asdict(config.actor_network_kwargs)),
|
network=MLP(input_dim=encoder_actor.output_dim, **asdict(config.actor_network_kwargs)),
|
||||||
|
|
|
@ -231,7 +231,6 @@ def act_with_policy(
|
||||||
cfg=cfg.policy,
|
cfg=cfg.policy,
|
||||||
env_cfg=cfg.env,
|
env_cfg=cfg.env,
|
||||||
)
|
)
|
||||||
policy = torch.compile(policy)
|
|
||||||
assert isinstance(policy, nn.Module)
|
assert isinstance(policy, nn.Module)
|
||||||
|
|
||||||
obs, info = online_env.reset()
|
obs, info = online_env.reset()
|
||||||
|
|
|
@ -285,9 +285,7 @@ def add_actor_information_and_train(
|
||||||
# ds_meta=cfg.dataset,
|
# ds_meta=cfg.dataset,
|
||||||
env_cfg=cfg.env,
|
env_cfg=cfg.env,
|
||||||
)
|
)
|
||||||
|
|
||||||
# compile policy
|
|
||||||
policy = torch.compile(policy)
|
|
||||||
assert isinstance(policy, nn.Module)
|
assert isinstance(policy, nn.Module)
|
||||||
policy.train()
|
policy.train()
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue