[WIP] Correct SAC implementation

Adil Zouitine 2025-01-13 17:54:11 +01:00 committed by Michel Aractingi
parent be965019bd
commit 956c547254
1 changed file with 14 additions and 0 deletions


@@ -63,6 +63,13 @@ class SACPolicy(
"max": torch.tensor([1.0, 1.0, 1.0, 1.0]),
}
}
# HACK: we need to pass the dataset_stats to the normalization functions
dataset_stats = dataset_stats or {
"action": {
"min": torch.tensor([-1.0, -1.0, -1.0, -1.0]),
"max": torch.tensor([1.0, 1.0, 1.0, 1.0]),
}
}
self.normalize_targets = Normalize(
config.output_shapes, config.output_normalization_modes, dataset_stats
)
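
Note: for reference, a minimal sketch of what min/max normalization with such a stats dict amounts to. min_max_normalize is a hypothetical helper written for illustration, not the repository's Normalize module, and it assumes actions are rescaled to [-1, 1]:

import torch

def min_max_normalize(action: torch.Tensor, stats: dict) -> torch.Tensor:
    # Rescale each action dimension from [min, max] to [-1, 1].
    low = stats["action"]["min"]
    high = stats["action"]["max"]
    return (action - low) / (high - low) * 2.0 - 1.0

dataset_stats = {
    "action": {
        "min": torch.tensor([-1.0, -1.0, -1.0, -1.0]),
        "max": torch.tensor([1.0, 1.0, 1.0, 1.0]),
    }
}
# With symmetric [-1, 1] bounds this is the identity on in-range actions.
print(min_max_normalize(torch.tensor([0.5, -0.5, 1.0, 0.0]), dataset_stats))
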
@@ -98,6 +105,7 @@ class SACPolicy(
self.critic_ensemble = create_critic_ensemble(critic_nets, config.num_critics)
self.critic_target = create_critic_ensemble(target_critic_nets, config.num_critics)
self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
self.actor = Policy(
encoder=encoder_actor,
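
Note: the target critic starts as a hard copy of the online ensemble (the single load_state_dict call above) and is typically trailed with a Polyak soft update during training. A minimal sketch; soft_update and the coefficient tau=0.005 are illustrative assumptions, not the repository's API:

import torch

@torch.no_grad()
def soft_update(target: torch.nn.Module, source: torch.nn.Module, tau: float = 0.005) -> None:
    # Polyak averaging: target <- (1 - tau) * target + tau * source.
    for target_param, source_param in zip(target.parameters(), source.parameters()):
        target_param.data.mul_(1.0 - tau).add_(source_param.data, alpha=tau)

critic = torch.nn.Linear(8, 1)
critic_target = torch.nn.Linear(8, 1)
critic_target.load_state_dict(critic.state_dict())  # hard copy at initialization
soft_update(critic_target, critic, tau=0.005)        # soft update on each training step
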
@@ -159,12 +167,15 @@ class SACPolicy(
# Refresh self.temperature here, since log_alpha may have been updated by the previous optimization step.
self.temperature = self.log_alpha.exp().item()
temperature = self.temperature
batch = self.normalize_inputs(batch)
# batch tensors are stacked along dim 1 with shape (b, 2, ...), pairing the current
# observation with the next observation so the TD target can be indexed correctly.
# actions = batch["action"][:, 0]
actions = batch["action"]
rewards = batch["next.reward"][:, 0]
observations = {}
next_observations = {}
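
Note: to illustrate the (b, 2, ...) convention mentioned in the comment above, a small standalone sketch of splitting a stacked batch into current and next observations. The key names and the index order (0 = current, 1 = next) are assumptions made for illustration:

import torch

batch = {
    "observation.state": torch.randn(32, 2, 6),  # (batch, 2, state_dim)
    "action": torch.randn(32, 4),
    "next.reward": torch.randn(32, 2),
}

observations = {}
next_observations = {}
for key, value in batch.items():
    if key.startswith("observation"):
        observations[key] = value[:, 0]       # current observation (assumed index)
        next_observations[key] = value[:, 1]  # next observation (assumed index)

log_alpha = torch.tensor(0.0)
temperature = log_alpha.exp().item()  # alpha = exp(log_alpha), as in the hunk above
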
@@ -191,6 +202,7 @@ class SACPolicy(
if self.config.use_backup_entropy:
min_q -= self.temperature * next_log_probs
td_target = rewards + self.config.discount * min_q * ~done
# 3- compute predicted qs
q_preds = self.critic_forward(observations, actions, use_target=False)
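
Note: the TD target above is the standard entropy-regularized (soft) Bellman backup, y = r + gamma * (min_i Q_i(s', a') - alpha * log pi(a'|s')) * (1 - done). A minimal numeric sketch with placeholder tensors:

import torch

discount = 0.99
temperature = 0.2
rewards = torch.tensor([1.0, 0.0])
done = torch.tensor([False, True])
next_log_probs = torch.tensor([-1.0, -0.5])

# Pessimistic value: minimum over the target critic ensemble (2 critics here).
next_q = torch.tensor([[2.0, 1.5], [3.0, 2.5]]).min(dim=0)[0]
min_q = next_q - temperature * next_log_probs       # backup entropy term
td_target = rewards + discount * min_q * ~done      # zero bootstrap on terminal steps
print(td_target)  # tensor([3.1780, 0.0000])
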
@@ -207,9 +219,11 @@
).mean(1)
).sum()
actions_pi, log_probs, _ = self.actor(observations)
with torch.inference_mode():
q_preds = self.critic_forward(observations, actions_pi, use_target=False)
min_q_preds = q_preds.min(dim=0)[0]
actor_loss = ((temperature * log_probs) - min_q_preds).mean()
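
Note: for completeness, the actor objective computed in the last lines, shown in isolation with placeholder tensors; the policy is trained to maximize the pessimistic Q estimate minus the entropy penalty, i.e. to minimize (alpha * log pi(a|s) - min_i Q_i(s, a)):

import torch

temperature = 0.2
log_probs = torch.tensor([-1.2, -0.7, -0.9])                 # log pi(a|s) for sampled actions
q_preds = torch.tensor([[1.0, 2.0, 0.5], [1.5, 1.0, 0.8]])   # (num_critics, batch)
min_q_preds = q_preds.min(dim=0)[0]                          # minimum over the ensemble

actor_loss = ((temperature * log_probs) - min_q_preds).mean()
print(actor_loss)  # tensor(-1.0200)
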