[WIP] correct sac implementation
parent 921ed960fb
commit 1e9bafc852
@@ -63,6 +63,13 @@ class SACPolicy(

```python
                "max": torch.tensor([1.0, 1.0, 1.0, 1.0]),
            }
        }
        # HACK: we need to pass the dataset_stats to the normalization functions
        dataset_stats = dataset_stats or {
            "action": {
                "min": torch.tensor([-1.0, -1.0, -1.0, -1.0]),
                "max": torch.tensor([1.0, 1.0, 1.0, 1.0]),
            }
        }
        self.normalize_targets = Normalize(
            config.output_shapes, config.output_normalization_modes, dataset_stats
        )
```
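The hard-coded `dataset_stats` fallback above only exists so that `Normalize` has per-dimension statistics even when none are supplied. As a rough illustration of what min/max normalization of the action amounts to (a sketch of the math only, not of the repo's actual `Normalize` module):

```python
import torch

def normalize_min_max(action: torch.Tensor, stats: dict) -> torch.Tensor:
    """Rescale an action tensor to [-1, 1] given per-dimension min/max stats.

    Sketch of a typical min/max normalization mode; the real Normalize module
    in the repo may differ in details.
    """
    low, high = stats["action"]["min"], stats["action"]["max"]
    return 2.0 * (action - low) / (high - low) - 1.0

# With the fallback bounds from the diff, actions already in [-1, 1] pass
# through unchanged:
stats = {
    "action": {
        "min": torch.tensor([-1.0, -1.0, -1.0, -1.0]),
        "max": torch.tensor([1.0, 1.0, 1.0, 1.0]),
    }
}
a = torch.tensor([0.5, -0.25, 1.0, 0.0])
assert torch.allclose(normalize_min_max(a, stats), a)
```

Assuming `Normalize` applies this kind of rescaling, the ±1 fallback bounds make it an identity mapping, so the HACK leaves targets effectively unchanged until real dataset statistics are passed in.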
@@ -98,6 +105,7 @@ class SACPolicy(

```python
        self.critic_ensemble = create_critic_ensemble(critic_nets, config.num_critics)
        self.critic_target = create_critic_ensemble(target_critic_nets, config.num_critics)
        self.critic_target.load_state_dict(self.critic_ensemble.state_dict())

        self.actor = Policy(
            encoder=encoder_actor,
```
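This hunk wires up the clipped double-Q machinery: an ensemble of online critics, a structurally identical target ensemble, and a hard copy of the online weights into the target. `create_critic_ensemble`, `critic_nets`, and `target_critic_nets` come from code outside the hunk; the sketch below substitutes a plain `nn.ModuleList` for the helper and adds the Polyak soft update that SAC target critics are usually maintained with (an assumption, since that update is not shown in this diff):

```python
import copy
import torch
import torch.nn as nn

def create_critic_ensemble(critic_nets: list[nn.Module], num_critics: int) -> nn.ModuleList:
    # Stand-in for the repo's helper: just group the critics so they can be
    # iterated and checkpointed together.
    assert len(critic_nets) == num_critics
    return nn.ModuleList(critic_nets)

# Online and target ensembles start from identical weights, as in the diff.
critics = [nn.Linear(8, 1) for _ in range(2)]
critic_ensemble = create_critic_ensemble(critics, 2)
critic_target = copy.deepcopy(critic_ensemble)
critic_target.load_state_dict(critic_ensemble.state_dict())

@torch.no_grad()
def soft_update(target: nn.Module, online: nn.Module, tau: float = 0.005) -> None:
    # Standard Polyak averaging; the diff only shows the hard copy at init time.
    for p_t, p in zip(target.parameters(), online.parameters()):
        p_t.mul_(1.0 - tau).add_(tau * p)

soft_update(critic_target, critic_ensemble)
```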
@@ -159,12 +167,15 @@ class SACPolicy(

```python
        # We have to actualize the value of the temperature because in the previous
        self.temperature = self.log_alpha.exp().item()
        temperature = self.temperature

        batch = self.normalize_inputs(batch)
        # batch shape is (b, 2, ...) where index 1 returns the current observation and
        # the next observation for calculating the right td index.
        # actions = batch["action"][:, 0]
        actions = batch["action"]
        rewards = batch["next.reward"][:, 0]
        observations = {}
        next_observations = {}
```
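`self.log_alpha.exp().item()` refreshes the cached temperature scalar after the previous optimization step updated `log_alpha`. The snippet below sketches the standard SAC temperature auto-tuning objective that such a `log_alpha` parameter is typically trained with; the loss form, optimizer, and `target_entropy` value are illustrative assumptions, not code from this commit:

```python
import torch

# log_alpha is the learnable log of the entropy coefficient (temperature).
log_alpha = torch.zeros(1, requires_grad=True)
target_entropy = -4.0  # common default: -dim(action_space), here 4-D actions
optimizer = torch.optim.Adam([log_alpha], lr=3e-4)

def temperature_loss(log_probs: torch.Tensor) -> torch.Tensor:
    # log_probs: log pi(a|s) for actions freshly sampled from the policy.
    # Alpha is pushed up when the policy entropy falls below the target,
    # and down when it exceeds it.
    return (-log_alpha.exp() * (log_probs + target_entropy).detach()).mean()

loss = temperature_loss(torch.randn(256))
optimizer.zero_grad()
loss.backward()
optimizer.step()

# After the optimizer step, the scalar used elsewhere has to be refreshed,
# which is what `self.temperature = self.log_alpha.exp().item()` does above.
temperature = log_alpha.exp().item()
```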
@@ -191,6 +202,7 @@ class SACPolicy(

```python
            if self.config.use_backup_entropy:
                min_q -= self.temperature * next_log_probs
            td_target = rewards + self.config.discount * min_q * ~done

        # 3- compute predicted qs
        q_preds = self.critic_forward(observations, actions, use_target=False)
```
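This hunk is the soft Bellman backup: the minimum over the target critics, optionally reduced by the entropy term when `use_backup_entropy` is set, then discounted and masked by `~done`. A self-contained sketch of the same computation with explicit shapes (the function and argument names are illustrative):

```python
import torch

def soft_td_target(
    rewards: torch.Tensor,         # (b,)
    q_targets: torch.Tensor,       # (num_critics, b), target critics at (s', a' ~ pi)
    next_log_probs: torch.Tensor,  # (b,), log pi(a'|s')
    done: torch.Tensor,            # (b,), bool
    discount: float,
    temperature: float,
    use_backup_entropy: bool = True,
) -> torch.Tensor:
    # Clipped double-Q: take the minimum over the critic ensemble.
    min_q = q_targets.min(dim=0)[0]
    if use_backup_entropy:
        # Soft backup: subtract the entropy bonus, as in the diff.
        min_q = min_q - temperature * next_log_probs
    # ~done masks out the bootstrap term on terminal transitions.
    return rewards + discount * min_q * ~done

td = soft_td_target(
    rewards=torch.zeros(4),
    q_targets=torch.randn(2, 4),
    next_log_probs=torch.randn(4),
    done=torch.zeros(4, dtype=torch.bool),
    discount=0.99,
    temperature=0.1,
)
```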
@@ -207,9 +219,11 @@ class SACPolicy(

```python
            ).mean(1)
        ).sum()

        actions_pi, log_probs, _ = self.actor(observations)
        with torch.inference_mode():
            q_preds = self.critic_forward(observations, actions_pi, use_target=False)
            min_q_preds = q_preds.min(dim=0)[0]

        actor_loss = ((temperature * log_probs) - min_q_preds).mean()
```
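The last hunk covers both objectives: the critic loss is a per-critic error against the TD target, averaged over the batch (`.mean(1)`) and summed over the ensemble (`.sum()`), while the actor loss maximizes the minimum ensemble Q value minus the entropy penalty. A minimal sketch of both with toy shapes (names are illustrative, and the repo's loss may differ in how the TD target is broadcast):

```python
import torch
import torch.nn.functional as F

def critic_loss(q_preds: torch.Tensor, td_target: torch.Tensor) -> torch.Tensor:
    # q_preds: (num_critics, b); td_target: (b,), computed without gradients.
    # Per-critic MSE averaged over the batch, then summed over the ensemble,
    # mirroring the `.mean(1)` / `.sum()` structure in the diff.
    target = td_target.detach().expand_as(q_preds)
    return F.mse_loss(q_preds, target, reduction="none").mean(1).sum()

def actor_loss(log_probs: torch.Tensor, q_preds_pi: torch.Tensor, temperature: float) -> torch.Tensor:
    # Clipped double-Q for the policy step, then the usual SAC objective:
    # maximize Q while keeping the policy entropy high.
    min_q_preds = q_preds_pi.min(dim=0)[0]
    return ((temperature * log_probs) - min_q_preds).mean()

# Toy shapes: 2 critics, batch of 4.
q = torch.randn(2, 4)
td = torch.randn(4)
print(critic_loss(q, td), actor_loss(torch.randn(4), q, 0.1))
```

Summing over the ensemble rather than averaging keeps the gradient magnitude per critic independent of the ensemble size, which is a common choice in ensembled Q-learning implementations.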