Refactor SACPolicy for improved type annotations and readability

- Enhanced type annotations for variables in the `SACPolicy` class to improve code clarity.
- Updated method calls to use keyword arguments for better readability.
- Streamlined the extraction of batch components, ensuring consistent typing across the class methods (see the sketch after the file summary below).
AdilZouitine 2025-03-28 16:46:21 +00:00
parent b3ad63cf6e
commit 0150139668
1 changed file with 13 additions and 13 deletions
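For context, a minimal, self-contained sketch of the typed-extraction pattern this commit applies. The function name `extract_batch` and the example shapes are illustrative, not from the repo; note that `dict.get(...)` can return `None`, so the sketch hedges with `Tensor | None` where the diff annotates plain `Tensor`:

```python
from typing import Any

import torch
from torch import Tensor


def extract_batch(batch: dict[str, Any]) -> tuple[Tensor, dict[str, Tensor], Tensor | None]:
    """Typed extraction of batch components, mirroring the pattern in the diff."""
    actions: Tensor = batch["action"]
    observations: dict[str, Tensor] = batch["state"]
    # .get may return None for optional keys, hence the Optional annotation here.
    observation_features: Tensor | None = batch.get("observation_feature")
    return actions, observations, observation_features


batch = {
    "action": torch.zeros(4, 2),
    "state": {"observation.state": torch.zeros(4, 8)},
}
actions, observations, observation_features = extract_batch(batch)
```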

View File

@@ -200,16 +200,16 @@ class SACPolicy(
         """ """
         # TODO: (maractingi, azouitine) Respect the function signature: we output tensors
         # Extract common components from batch
-        actions = batch["action"]
-        observations = batch["state"]
-        observation_features = batch.get("observation_feature")
+        actions: Tensor = batch["action"]
+        observations: dict[str, Tensor] = batch["state"]
+        observation_features: Tensor = batch.get("observation_feature")
         if model == "critic":
             # Extract critic-specific components
-            rewards = batch["reward"]
-            next_observations = batch["next_state"]
-            done = batch["done"]
-            next_observation_features = batch.get("next_observation_feature")
+            rewards: Tensor = batch["reward"]
+            next_observations: dict[str, Tensor] = batch["next_state"]
+            done: Tensor = batch["done"]
+            next_observation_features: Tensor = batch.get("next_observation_feature")
             return self.compute_loss_critic(
                 observations=observations,
@@ -287,8 +287,8 @@ class SACPolicy(
         # 3- compute predicted qs
         q_preds = self.critic_forward(
-            observations,
-            actions,
+            observations=observations,
+            actions=actions,
             use_target=False,
             observation_features=observation_features,
         )
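The keyword-argument change above assumes a `critic_forward` signature along these lines, inferred from the call sites in this diff; the real `SACPolicy.critic_forward` may differ, and the class and body below are stand-ins for illustration only:

```python
import torch
from torch import Tensor


class CriticForwardSketch:
    """Stand-in whose signature is inferred from the keyword calls in this diff."""

    def critic_forward(
        self,
        observations: dict[str, Tensor],
        actions: Tensor,
        use_target: bool = False,
        observation_features: Tensor | None = None,
    ) -> Tensor:
        # Stand-in body: a real implementation would run an ensemble of critics.
        return torch.zeros(2, actions.shape[0])


policy = CriticForwardSketch()
# Keyword arguments make each slot explicit at the call site, which is the
# readability gain the commit message describes.
q_preds = policy.critic_forward(
    observations={"observation.state": torch.zeros(4, 8)},
    actions=torch.zeros(4, 2),
    use_target=False,
)
```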
@@ -302,7 +302,7 @@ class SACPolicy(
                 input=q_preds,
                 target=td_target_duplicate,
                 reduction="none",
-            ).mean(1)
+            ).mean(dim=1)
         ).sum()
         return critics_loss
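The `.mean(1)` → `.mean(dim=1)` change is purely cosmetic; both average over the second dimension. A hedged sketch of the reduction chain this hunk sits in, with shapes that are assumptions inferred from the hunk rather than taken from the repo:

```python
import torch
import torch.nn.functional as F

# Assumed shapes: the .mean(dim=1) suggests q_preds and td_target_duplicate
# are (num_critics, batch_size).
num_critics, batch_size = 2, 4
q_preds = torch.randn(num_critics, batch_size)
td_target_duplicate = torch.randn(num_critics, batch_size)

critics_loss = (
    F.mse_loss(input=q_preds, target=td_target_duplicate, reduction="none")
    .mean(dim=1)  # average over the batch for each critic
    .sum()  # then sum the per-critic losses into one scalar
)
```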
@@ -324,11 +324,11 @@ class SACPolicy(
         actions_pi, log_probs, _ = self.actor(observations, observation_features)
         # TODO: (maractingi, azouitine) This is too slow, we should find a way to do this in a more efficient way
-        actions_pi = self.unnormalize_outputs({"action": actions_pi})["action"]
+        actions_pi: Tensor = self.unnormalize_outputs({"action": actions_pi})["action"]
         q_preds = self.critic_forward(
-            observations,
-            actions_pi,
+            observations=observations,
+            actions=actions_pi,
             use_target=False,
             observation_features=observation_features,
         )
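The hunk ends before the actor loss itself. For orientation, in standard SAC the `q_preds` and `log_probs` above combine roughly as below; this is a generic textbook sketch, not this repository's verified formula, and `temperature` is a hypothetical scalar:

```python
import torch

# Illustrative shapes only; not taken from the diff.
num_critics, batch_size = 2, 4
q_preds = torch.randn(num_critics, batch_size)
log_probs = torch.randn(batch_size)
temperature = 0.2

min_q = q_preds.min(dim=0).values  # pessimistic estimate across the critic ensemble
actor_loss = (temperature * log_probs - min_q).mean()
```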