Refactor SACPolicy for improved type annotations and readability
- Enhanced type annotations for variables in the `SACPolicy` class to improve code clarity.
- Updated method calls to use keyword arguments for better readability.
- Streamlined the extraction of batch components, ensuring consistent typing across the class methods.
parent b3ad63cf6e
commit 0150139668
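For context, a minimal sketch of the two patterns the commit applies, using a hypothetical toy batch in place of the replay-buffer batch the real class receives (`Tensor` is `torch.Tensor`; the key names and shapes here are illustrative, mirroring the hunks below):

```python
import torch
from torch import Tensor

# Hypothetical toy batch; in SACPolicy these tensors come from the replay buffer.
batch = {
    "action": torch.randn(4, 2),
    "state": {"observation.state": torch.randn(4, 8)},
    "reward": torch.randn(4),
    "done": torch.zeros(4),
}

# Pattern 1: annotate each extracted component so the expected batch layout is explicit.
actions: Tensor = batch["action"]
observations: dict[str, Tensor] = batch["state"]
rewards: Tensor = batch["reward"]
done: Tensor = batch["done"]

# Pattern 2: pass arguments, including reduction dims, by keyword rather than position.
q_preds = torch.randn(2, 4)               # shape (num_critics, batch_size)
critics_loss = q_preds.mean(dim=1).sum()  # equivalent to .mean(1).sum()
```

Keyword arguments in calls like `critic_forward(observations=..., actions=...)` also guard against silently swapping two same-rank tensors if the signature ever changes.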
@@ -200,16 +200,16 @@ class SACPolicy(
         """
         # TODO: (maractingi, azouitine) Respect the function signature we output tensors
         # Extract common components from batch
-        actions = batch["action"]
-        observations = batch["state"]
-        observation_features = batch.get("observation_feature")
+        actions: Tensor = batch["action"]
+        observations: dict[str, Tensor] = batch["state"]
+        observation_features: Tensor = batch.get("observation_feature")
 
         if model == "critic":
             # Extract critic-specific components
-            rewards = batch["reward"]
-            next_observations = batch["next_state"]
-            done = batch["done"]
-            next_observation_features = batch.get("next_observation_feature")
+            rewards: Tensor = batch["reward"]
+            next_observations: dict[str, Tensor] = batch["next_state"]
+            done: Tensor = batch["done"]
+            next_observation_features: Tensor = batch.get("next_observation_feature")
 
             return self.compute_loss_critic(
                 observations=observations,
@@ -287,8 +287,8 @@ class SACPolicy(
 
         # 3- compute predicted qs
         q_preds = self.critic_forward(
-            observations,
-            actions,
+            observations=observations,
+            actions=actions,
             use_target=False,
             observation_features=observation_features,
         )
@@ -302,7 +302,7 @@ class SACPolicy(
                 input=q_preds,
                 target=td_target_duplicate,
                 reduction="none",
-            ).mean(1)
+            ).mean(dim=1)
         ).sum()
         return critics_loss
 
@@ -324,11 +324,11 @@ class SACPolicy(
         actions_pi, log_probs, _ = self.actor(observations, observation_features)
 
         # TODO: (maractingi, azouitine) This is to slow, we should find a way to do this in a more efficient way
-        actions_pi = self.unnormalize_outputs({"action": actions_pi})["action"]
+        actions_pi: Tensor = self.unnormalize_outputs({"action": actions_pi})["action"]
 
         q_preds = self.critic_forward(
-            observations,
-            actions_pi,
+            observations=observations,
+            actions=actions_pi,
             use_target=False,
             observation_features=observation_features,
         )