Refactor SACPolicy for improved type annotations and readability

- Added explicit type annotations for local variables in the `SACPolicy` class to improve code clarity.
- Updated method calls to use keyword arguments for better readability.
- Made the extraction of batch components consistently typed across the class methods.
AdilZouitine 2025-03-28 16:46:21 +00:00
parent b3ad63cf6e
commit 0150139668
1 changed file with 13 additions and 13 deletions

@@ -200,16 +200,16 @@ class SACPolicy(
         """
         # TODO: (maractingi, azouitine) Respect the function signature we output tensors
         # Extract common components from batch
-        actions = batch["action"]
-        observations = batch["state"]
-        observation_features = batch.get("observation_feature")
+        actions: Tensor = batch["action"]
+        observations: dict[str, Tensor] = batch["state"]
+        observation_features: Tensor = batch.get("observation_feature")

         if model == "critic":
             # Extract critic-specific components
-            rewards = batch["reward"]
-            next_observations = batch["next_state"]
-            done = batch["done"]
-            next_observation_features = batch.get("next_observation_feature")
+            rewards: Tensor = batch["reward"]
+            next_observations: dict[str, Tensor] = batch["next_state"]
+            done: Tensor = batch["done"]
+            next_observation_features: Tensor = batch.get("next_observation_feature")

             return self.compute_loss_critic(
                 observations=observations,
@@ -287,8 +287,8 @@ class SACPolicy(
         # 3- compute predicted qs
         q_preds = self.critic_forward(
-            observations,
-            actions,
+            observations=observations,
+            actions=actions,
             use_target=False,
             observation_features=observation_features,
         )
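
The keyword-argument change in this hunk is purely about call-site readability. A minimal, self-contained sketch contrasting the two call styles, using a hypothetical stand-in for `critic_forward` (the real signature lives in `SACPolicy`):

```python
import torch
from torch import Tensor


# Hypothetical stand-in for SACPolicy.critic_forward, only to contrast call styles.
def critic_forward(
    observations: dict[str, Tensor],
    actions: Tensor,
    use_target: bool = False,
    observation_features: Tensor | None = None,
) -> Tensor:
    return actions.sum(dim=-1)  # dummy Q-values, shape (batch,)


obs = {"state": torch.zeros(4, 3)}
acts = torch.zeros(4, 2)

q_old = critic_forward(obs, acts)                       # positional: relies on argument order
q_new = critic_forward(observations=obs, actions=acts)  # keyword: order-independent, self-documenting
```
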
@@ -302,7 +302,7 @@ class SACPolicy(
                 input=q_preds,
                 target=td_target_duplicate,
                 reduction="none",
-            ).mean(1)
+            ).mean(dim=1)
         ).sum()
         return critics_loss
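
The `.mean(1)` to `.mean(dim=1)` change is behaviorally a no-op; both average over dimension 1. For context, here is the loss pattern from this hunk with made-up shapes, assuming `q_preds` stacks one row of Q-values per critic and `td_target_duplicate` is the shared TD target duplicated across critics:

```python
import torch
import torch.nn.functional as F

num_critics, batch_size = 2, 4                           # hypothetical sizes
q_preds = torch.randn(num_critics, batch_size)           # one row of Q-values per critic
td_target = torch.randn(batch_size)                      # shared TD target
td_target_duplicate = td_target.expand(num_critics, -1)  # duplicated across critics

critics_loss = (
    F.mse_loss(
        input=q_preds,
        target=td_target_duplicate,
        reduction="none",  # keep per-element squared errors
    ).mean(dim=1)          # average over the batch dimension
).sum()                    # sum the per-critic losses into a scalar
```
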
@@ -324,11 +324,11 @@ class SACPolicy(
         actions_pi, log_probs, _ = self.actor(observations, observation_features)
         # TODO: (maractingi, azouitine) This is to slow, we should find a way to do this in a more efficient way
-        actions_pi = self.unnormalize_outputs({"action": actions_pi})["action"]
+        actions_pi: Tensor = self.unnormalize_outputs({"action": actions_pi})["action"]
         q_preds = self.critic_forward(
-            observations,
-            actions_pi,
+            observations=observations,
+            actions=actions_pi,
             use_target=False,
             observation_features=observation_features,
         )
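
One follow-up the new annotations invite: `dict.get` returns `None` when the key is missing, so `Tensor | None` is what a type checker will actually infer for `observation_features` and `next_observation_features`. A hedged sketch of the common extraction with the Optional made explicit (the `extract_common` helper is illustrative, not part of this commit):

```python
from torch import Tensor


def extract_common(batch: dict) -> tuple[Tensor, dict[str, Tensor], Tensor | None]:
    """Mirrors the batch extraction in the first hunk, for illustration only."""
    actions: Tensor = batch["action"]
    observations: dict[str, Tensor] = batch["state"]
    # dict.get yields None for a missing key, hence Tensor | None.
    observation_features: Tensor | None = batch.get("observation_feature")
    return actions, observations, observation_features
```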