Refactor SACPolicy for improved type annotations and readability
- Enhanced type annotations for variables in the `SACPolicy` class to improve code clarity. - Updated method calls to use keyword arguments for better readability. - Streamlined the extraction of batch components, ensuring consistent typing across the class methods.
This commit is contained in:
parent
b3ad63cf6e
commit
0150139668
|
@ -200,16 +200,16 @@ class SACPolicy(
|
|||
"""
|
||||
# TODO: (maractingi, azouitine) Respect the function signature we output tensors
|
||||
# Extract common components from batch
|
||||
actions = batch["action"]
|
||||
observations = batch["state"]
|
||||
observation_features = batch.get("observation_feature")
|
||||
actions: Tensor = batch["action"]
|
||||
observations: dict[str, Tensor] = batch["state"]
|
||||
observation_features: Tensor = batch.get("observation_feature")
|
||||
|
||||
if model == "critic":
|
||||
# Extract critic-specific components
|
||||
rewards = batch["reward"]
|
||||
next_observations = batch["next_state"]
|
||||
done = batch["done"]
|
||||
next_observation_features = batch.get("next_observation_feature")
|
||||
rewards: Tensor = batch["reward"]
|
||||
next_observations: dict[str, Tensor] = batch["next_state"]
|
||||
done: Tensor = batch["done"]
|
||||
next_observation_features: Tensor = batch.get("next_observation_feature")
|
||||
|
||||
return self.compute_loss_critic(
|
||||
observations=observations,
|
||||
|
@ -287,8 +287,8 @@ class SACPolicy(
|
|||
|
||||
# 3- compute predicted qs
|
||||
q_preds = self.critic_forward(
|
||||
observations,
|
||||
actions,
|
||||
observations=observations,
|
||||
actions=actions,
|
||||
use_target=False,
|
||||
observation_features=observation_features,
|
||||
)
|
||||
|
@ -302,7 +302,7 @@ class SACPolicy(
|
|||
input=q_preds,
|
||||
target=td_target_duplicate,
|
||||
reduction="none",
|
||||
).mean(1)
|
||||
).mean(dim=1)
|
||||
).sum()
|
||||
return critics_loss
|
||||
|
||||
|
@ -324,11 +324,11 @@ class SACPolicy(
|
|||
actions_pi, log_probs, _ = self.actor(observations, observation_features)
|
||||
|
||||
# TODO: (maractingi, azouitine) This is to slow, we should find a way to do this in a more efficient way
|
||||
actions_pi = self.unnormalize_outputs({"action": actions_pi})["action"]
|
||||
actions_pi: Tensor = self.unnormalize_outputs({"action": actions_pi})["action"]
|
||||
|
||||
q_preds = self.critic_forward(
|
||||
observations,
|
||||
actions_pi,
|
||||
observations=observations,
|
||||
actions=actions_pi,
|
||||
use_target=False,
|
||||
observation_features=observation_features,
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue