[WIP] correct sac implementation
This commit is contained in:
parent
be965019bd
commit
956c547254
|
@ -63,6 +63,13 @@ class SACPolicy(
|
|||
"max": torch.tensor([1.0, 1.0, 1.0, 1.0]),
|
||||
}
|
||||
}
|
||||
# HACK: we need to pass the dataset_stats to the normalization functions
|
||||
dataset_stats = dataset_stats or {
|
||||
"action": {
|
||||
"min": torch.tensor([-1.0, -1.0, -1.0, -1.0]),
|
||||
"max": torch.tensor([1.0, 1.0, 1.0, 1.0]),
|
||||
}
|
||||
}
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_shapes, config.output_normalization_modes, dataset_stats
|
||||
)
|
||||
|
@ -98,6 +105,7 @@ class SACPolicy(
|
|||
self.critic_ensemble = create_critic_ensemble(critic_nets, config.num_critics)
|
||||
self.critic_target = create_critic_ensemble(target_critic_nets, config.num_critics)
|
||||
self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
|
||||
self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
|
||||
|
||||
self.actor = Policy(
|
||||
encoder=encoder_actor,
|
||||
|
@ -159,12 +167,15 @@ class SACPolicy(
|
|||
# We have to actualize the value of the temperature because in the previous
|
||||
self.temperature = self.log_alpha.exp().item()
|
||||
temperature = self.temperature
|
||||
temperature = self.temperature
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
# batch shape is (b, 2, ...) where index 1 returns the current observation and
|
||||
# the next observation for calculating the right td index.
|
||||
# actions = batch["action"][:, 0]
|
||||
actions = batch["action"]
|
||||
# actions = batch["action"][:, 0]
|
||||
actions = batch["action"]
|
||||
rewards = batch["next.reward"][:, 0]
|
||||
observations = {}
|
||||
next_observations = {}
|
||||
|
@ -191,6 +202,7 @@ class SACPolicy(
|
|||
if self.config.use_backup_entropy:
|
||||
min_q -= self.temperature * next_log_probs
|
||||
td_target = rewards + self.config.discount * min_q * ~done
|
||||
td_target = rewards + self.config.discount * min_q * ~done
|
||||
|
||||
# 3- compute predicted qs
|
||||
q_preds = self.critic_forward(observations, actions, use_target=False)
|
||||
|
@ -207,9 +219,11 @@ class SACPolicy(
|
|||
).mean(1)
|
||||
).sum()
|
||||
|
||||
actions_pi, log_probs, _ = self.actor(observations)
|
||||
actions_pi, log_probs, _ = self.actor(observations)
|
||||
with torch.inference_mode():
|
||||
q_preds = self.critic_forward(observations, actions_pi, use_target=False)
|
||||
q_preds = self.critic_forward(observations, actions_pi, use_target=False)
|
||||
min_q_preds = q_preds.min(dim=0)[0]
|
||||
|
||||
actor_loss = ((temperature * log_probs) - min_q_preds).mean()
|
||||
|
|
Loading…
Reference in New Issue