diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py
index e5173e04..23513916 100644
--- a/lerobot/common/policies/sac/modeling_sac.py
+++ b/lerobot/common/policies/sac/modeling_sac.py
@@ -63,6 +63,13 @@ class SACPolicy(
                 "max": torch.tensor([1.0, 1.0, 1.0, 1.0]),
             }
         }
+        # HACK: we need to pass the dataset_stats to the normalization functions
+        dataset_stats = dataset_stats or {
+            "action": {
+                "min": torch.tensor([-1.0, -1.0, -1.0, -1.0]),
+                "max": torch.tensor([1.0, 1.0, 1.0, 1.0]),
+            }
+        }
         self.normalize_targets = Normalize(
             config.output_shapes, config.output_normalization_modes, dataset_stats
         )
@@ -98,6 +105,7 @@ class SACPolicy(
         self.critic_ensemble = create_critic_ensemble(critic_nets, config.num_critics)
         self.critic_target = create_critic_ensemble(target_critic_nets, config.num_critics)
         self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
+        self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
 
         self.actor = Policy(
             encoder=encoder_actor,
@@ -159,12 +167,15 @@ class SACPolicy(
         # We have to actualize the value of the temperature because in the previous
         self.temperature = self.log_alpha.exp().item()
         temperature = self.temperature
+        temperature = self.temperature
 
         batch = self.normalize_inputs(batch)
         # batch shape is (b, 2, ...) where index 1 returns the current observation and
         # the next observation for calculating the right td index.
         # actions = batch["action"][:, 0]
         actions = batch["action"]
+        # actions = batch["action"][:, 0]
+        actions = batch["action"]
         rewards = batch["next.reward"][:, 0]
         observations = {}
         next_observations = {}
@@ -191,6 +202,7 @@ class SACPolicy(
         if self.config.use_backup_entropy:
             min_q -= self.temperature * next_log_probs
         td_target = rewards + self.config.discount * min_q * ~done
+        td_target = rewards + self.config.discount * min_q * ~done
 
         # 3- compute predicted qs
         q_preds = self.critic_forward(observations, actions, use_target=False)
@@ -207,9 +219,11 @@ class SACPolicy(
             ).mean(1)
         ).sum()
 
+        actions_pi, log_probs, _ = self.actor(observations)
         actions_pi, log_probs, _ = self.actor(observations)
         with torch.inference_mode():
             q_preds = self.critic_forward(observations, actions_pi, use_target=False)
+            q_preds = self.critic_forward(observations, actions_pi, use_target=False)
         min_q_preds = q_preds.min(dim=0)[0]
         actor_loss = ((temperature * log_probs) - min_q_preds).mean()
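For context on the loss computation this diff touches, below is a minimal, self-contained sketch of the SAC update pattern: TD target from the minimum over the target critic ensemble (with optional entropy backup), MSE critic loss summed over ensemble members, and an entropy-regularized actor loss against the minimum online critic value. This is not the lerobot implementation; the function name `sac_losses` and its arguments are hypothetical stand-ins for the tensors that `SACPolicy` obtains from `self.critic_forward`, `self.actor`, and the batch, and the `discount=0.99` default is illustrative only.

```python
import torch
import torch.nn.functional as F


def sac_losses(
    q_preds,          # (num_critics, batch): online critics at replay-buffer actions
    q_preds_pi,       # (num_critics, batch): online critics at actor-sampled actions
    target_q_preds,   # (num_critics, batch): target critics at next-state sampled actions
    next_log_probs,   # (batch,): log pi(a'|s') for the sampled next actions
    log_probs,        # (batch,): log pi(a|s) for the actor-sampled current actions
    rewards,          # (batch,)
    done,             # (batch,) bool
    temperature,
    discount=0.99,
    use_backup_entropy=True,
):
    # TD target: minimum over the target ensemble, optional entropy backup,
    # and masking of terminal transitions via ~done. Assumed to be computed
    # without gradients (target networks).
    min_next_q = target_q_preds.min(dim=0)[0]
    if use_backup_entropy:
        min_next_q = min_next_q - temperature * next_log_probs
    td_target = rewards + discount * min_next_q * ~done

    # Critic loss: MSE of every ensemble member against the shared TD target,
    # averaged over the batch and summed over critics (the .mean(1)).sum() pattern).
    critics_loss = (
        F.mse_loss(q_preds, td_target.expand_as(q_preds), reduction="none")
        .mean(1)
        .sum()
    )

    # Actor loss: entropy-regularized, against the minimum online critic value
    # at the actor's own samples (queried under inference_mode in the diff,
    # so no gradient flows back into the critic here).
    min_q_pi = q_preds_pi.min(dim=0)[0]
    actor_loss = (temperature * log_probs - min_q_pi).mean()
    return critics_loss, actor_loss


# Tiny smoke test with random tensors (2 critics, batch of 4):
nc, b = 2, 4
critics_loss, actor_loss = sac_losses(
    torch.randn(nc, b), torch.randn(nc, b), torch.randn(nc, b),
    torch.randn(b), torch.randn(b), torch.randn(b),
    done=torch.zeros(b, dtype=torch.bool), temperature=0.1,
)
```

Taking the minimum over the critic ensemble in both the TD target and the actor objective is the standard clipped double-Q device for reducing overestimation bias; the entropy backup term mirrors the `use_backup_entropy` branch in the diff.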