diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index c15b8f02..385efacc 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -200,16 +200,16 @@ class SACPolicy( """ # TODO: (maractingi, azouitine) Respect the function signature we output tensors # Extract common components from batch - actions = batch["action"] - observations = batch["state"] - observation_features = batch.get("observation_feature") + actions: Tensor = batch["action"] + observations: dict[str, Tensor] = batch["state"] + observation_features: Tensor = batch.get("observation_feature") if model == "critic": # Extract critic-specific components - rewards = batch["reward"] - next_observations = batch["next_state"] - done = batch["done"] - next_observation_features = batch.get("next_observation_feature") + rewards: Tensor = batch["reward"] + next_observations: dict[str, Tensor] = batch["next_state"] + done: Tensor = batch["done"] + next_observation_features: Tensor = batch.get("next_observation_feature") return self.compute_loss_critic( observations=observations, @@ -287,8 +287,8 @@ class SACPolicy( # 3- compute predicted qs q_preds = self.critic_forward( - observations, - actions, + observations=observations, + actions=actions, use_target=False, observation_features=observation_features, ) @@ -302,7 +302,7 @@ class SACPolicy( input=q_preds, target=td_target_duplicate, reduction="none", - ).mean(1) + ).mean(dim=1) ).sum() return critics_loss @@ -324,11 +324,11 @@ class SACPolicy( actions_pi, log_probs, _ = self.actor(observations, observation_features) # TODO: (maractingi, azouitine) This is to slow, we should find a way to do this in a more efficient way - actions_pi = self.unnormalize_outputs({"action": actions_pi})["action"] + actions_pi: Tensor = self.unnormalize_outputs({"action": actions_pi})["action"] q_preds = self.critic_forward( - observations, - actions_pi, + observations=observations, + actions=actions_pi, use_target=False, observation_features=observation_features, )