[WIP] correct sac implementation

Adil Zouitine 2025-01-13 17:54:11 +01:00 committed by Michel Aractingi
parent be965019bd
commit 956c547254
1 changed file with 14 additions and 0 deletions


@@ -63,6 +63,13 @@ class SACPolicy(
                 "max": torch.tensor([1.0, 1.0, 1.0, 1.0]),
             }
         }
+        # HACK: we need to pass the dataset_stats to the normalization functions
+        dataset_stats = dataset_stats or {
+            "action": {
+                "min": torch.tensor([-1.0, -1.0, -1.0, -1.0]),
+                "max": torch.tensor([1.0, 1.0, 1.0, 1.0]),
+            }
+        }
         self.normalize_targets = Normalize(
             config.output_shapes, config.output_normalization_modes, dataset_stats
         )
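For context, the added block supplies fallback action stats (min = -1, max = +1 per dimension) so that min/max normalization still has something to work with when no dataset stats are provided. A minimal sketch of what such a mapping of [min, max] onto [-1, 1] would do to a raw action; this is an illustration, not lerobot's Normalize class:

import torch

def min_max_normalize(action: torch.Tensor, stats: dict) -> torch.Tensor:
    # Map each action dimension from [min, max] onto [-1, 1].
    low, high = stats["action"]["min"], stats["action"]["max"]
    return 2.0 * (action - low) / (high - low) - 1.0

stats = {"action": {"min": torch.tensor([-1.0] * 4), "max": torch.tensor([1.0] * 4)}}
print(min_max_normalize(torch.tensor([0.5, -0.5, 1.0, 0.0]), stats))
# With min=-1 and max=+1 this mapping is the identity, so the fallback stats leave actions unchanged.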
@@ -98,6 +105,7 @@ class SACPolicy(
         self.critic_ensemble = create_critic_ensemble(critic_nets, config.num_critics)
         self.critic_target = create_critic_ensemble(target_critic_nets, config.num_critics)
         self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
+        self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
 
         self.actor = Policy(
             encoder=encoder_actor,
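The load_state_dict call copies the online critic ensemble's weights into the target ensemble at construction time; targets are then typically refreshed with a soft (Polyak) update during training. A minimal sketch of that hard-copy-then-soft-update pattern, using throwaway nn.Linear critics and a hypothetical soft_update helper rather than lerobot's code:

import torch
import torch.nn as nn

def soft_update(target: nn.Module, online: nn.Module, tau: float = 0.005) -> None:
    # Polyak averaging: target <- tau * online + (1 - tau) * target
    with torch.no_grad():
        for tp, op in zip(target.parameters(), online.parameters()):
            tp.data.mul_(1.0 - tau).add_(op.data, alpha=tau)

online_critic = nn.Linear(8, 1)
target_critic = nn.Linear(8, 1)
target_critic.load_state_dict(online_critic.state_dict())  # hard copy at init, as in the hunk above
soft_update(target_critic, online_critic)                   # periodic soft updates during training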
@@ -159,12 +167,15 @@ class SACPolicy(
         # We have to actualize the value of the temperature because in the previous
         self.temperature = self.log_alpha.exp().item()
         temperature = self.temperature
+        temperature = self.temperature
 
         batch = self.normalize_inputs(batch)
         # batch shape is (b, 2, ...) where index 1 returns the current observation and
         # the next observation for calculating the right td index.
         # actions = batch["action"][:, 0]
         actions = batch["action"]
+        # actions = batch["action"][:, 0]
+        actions = batch["action"]
         rewards = batch["next.reward"][:, 0]
         observations = {}
         next_observations = {}
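Here the entropy coefficient is read as alpha = exp(log_alpha), and .item() detaches it so the critic and actor losses see a plain scalar while log_alpha itself stays learnable. A minimal sketch of SAC-style automatic entropy tuning, with a made-up target_entropy and stand-in log-probs (an illustration, not lerobot's temperature update):

import torch

log_alpha = torch.zeros(1, requires_grad=True)   # learnable log of the entropy coefficient
target_entropy = -4.0                            # commonly -dim(action space); 4-D actions assumed here

temperature = log_alpha.exp().item()             # detached scalar, as in `self.log_alpha.exp().item()`

log_probs = torch.randn(32)                      # stand-in for log pi(a|s) from the actor
# Temperature objective: only log_alpha receives gradients; the policy term is detached.
temperature_loss = (-log_alpha.exp() * (log_probs + target_entropy).detach()).mean()
temperature_loss.backward()
print(temperature, log_alpha.grad)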
@@ -191,6 +202,7 @@ class SACPolicy(
         if self.config.use_backup_entropy:
             min_q -= self.temperature * next_log_probs
         td_target = rewards + self.config.discount * min_q * ~done
+        td_target = rewards + self.config.discount * min_q * ~done
 
         # 3- compute predicted qs
         q_preds = self.critic_forward(observations, actions, use_target=False)
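The td_target line is the soft Bellman backup y = r + gamma * (1 - done) * (min_i Q_target_i(s', a') - alpha * log pi(a'|s')), with the entropy term applied only when backup entropy is enabled. A small worked example with made-up numbers:

import torch

gamma, alpha = 0.99, 0.1
rewards = torch.tensor([1.0, 0.0])
done = torch.tensor([False, True])
min_q = torch.tensor([5.0, 5.0])                 # min over the target critic ensemble
next_log_probs = torch.tensor([-2.0, -2.0])      # log pi(a'|s') for the sampled next actions

min_q = min_q - alpha * next_log_probs           # backup entropy term (use_backup_entropy)
td_target = rewards + gamma * min_q * ~done      # bootstrap only on non-terminal transitions
print(td_target)                                 # tensor([6.1480, 0.0000])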
@@ -207,9 +219,11 @@ class SACPolicy(
             ).mean(1)
         ).sum()
 
+        actions_pi, log_probs, _ = self.actor(observations)
         actions_pi, log_probs, _ = self.actor(observations)
         with torch.inference_mode():
             q_preds = self.critic_forward(observations, actions_pi, use_target=False)
+            q_preds = self.critic_forward(observations, actions_pi, use_target=False)
         min_q_preds = q_preds.min(dim=0)[0]
 
         actor_loss = ((temperature * log_probs) - min_q_preds).mean()
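This last hunk samples actions from the current policy, scores them with the critic ensemble, and minimizes alpha * log pi(a|s) - min_i Q_i(s, a). A small worked example of that loss with made-up numbers:

import torch

alpha = 0.1
log_probs = torch.tensor([-1.0, -3.0])           # log pi(a|s) for a batch of 2
q_preds = torch.tensor([[2.0, 4.0],              # critic 1
                        [3.0, 5.0]])             # critic 2, shape (num_critics, batch)
min_q_preds = q_preds.min(dim=0)[0]              # elementwise min over the ensemble -> [2., 4.]
actor_loss = ((alpha * log_probs) - min_q_preds).mean()
print(actor_loss)                                # tensor(-3.2000)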