Refactor SACPolicy and learner_server for improved clarity and functionality

- Updated the `forward` method in `SACPolicy` to handle loss computation for actor, critic, and temperature models.
- Replaced direct calls to the `compute_loss_*` methods with a single `policy.forward(batch, model=...)` call in `learner_server` (a usage sketch follows the commit metadata below).
- Enhanced batch processing by consolidating input parameters into a single dictionary for better readability and maintainability.
- Removed redundant code and improved documentation for clarity.
AdilZouitine 2025-03-28 16:40:45 +00:00
parent 8b02e81bb5
commit b3ad63cf6e
3 changed files with 96 additions and 42 deletions
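
In practice, the refactor boils down to building one batch dictionary and selecting which loss to compute via the `model` argument. A minimal sketch of the new call pattern, assuming an already configured `SACPolicy` instance named `policy` and toy tensor shapes (the real shapes come from the policy's input/output features):

```python
import torch

batch_size = 8
forward_batch = {
    "action": torch.randn(batch_size, 3),
    "reward": torch.randn(batch_size),
    "state": {"observation.state": torch.randn(batch_size, 2)},
    "next_state": {"observation.state": torch.randn(batch_size, 2)},
    "done": torch.zeros(batch_size),
    # "observation_feature" / "next_observation_feature" are optional;
    # when omitted, the policy encodes the observations itself.
}

# One entry point, three losses, selected by the `model` argument.
loss_critic = policy.forward(forward_batch, model="critic")
loss_actor = policy.forward(forward_batch, model="actor")
loss_temperature = policy.forward(forward_batch, model="temperature")
```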

View File

@@ -135,18 +135,6 @@ class SACConfig(PreTrainedConfig):
         }
     )
-    input_features: dict[str, PolicyFeature] = field(
-        default_factory=lambda: {
-            "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 64, 64)),
-            "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(2,)),
-        }
-    )
-    output_features: dict[str, PolicyFeature] = field(
-        default_factory=lambda: {
-            "action": PolicyFeature(type=FeatureType.ACTION, shape=(3,)),
-        }
-    )
     # Architecture specifics
     camera_number: int = 1
     device: str = "cuda"
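
With these defaults removed, the input/output feature dictionaries presumably have to be supplied by the caller instead (e.g. derived from dataset or environment metadata). A hedged sketch reusing the exact values that were dropped; the import path is an assumption:

```python
from lerobot.configs.types import FeatureType, PolicyFeature

# Formerly the SACConfig defaults; now built externally before configuring the policy.
input_features = {
    "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 64, 64)),
    "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(2,)),
}
output_features = {
    "action": PolicyFeature(type=FeatureType.ACTION, shape=(3,)),
}
```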

View File

@@ -19,7 +19,7 @@
 import math
 from dataclasses import asdict
-from typing import Callable, List, Optional, Tuple
+from typing import Callable, List, Literal, Optional, Tuple
 
 import einops
 import numpy as np
@@ -177,7 +177,64 @@ class SACPolicy(
         q_values = critics(observations, actions, observation_features)
         return q_values
 
-    def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]: ...
+    def forward(
+        self,
+        batch: dict[str, Tensor | dict[str, Tensor]],
+        model: Literal["actor", "critic", "temperature"] = "critic",
+    ) -> dict[str, Tensor]:
+        """Compute the loss for the given model
+
+        Args:
+            batch: Dictionary containing:
+                - action: Action tensor
+                - reward: Reward tensor
+                - state: Observations tensor dict
+                - next_state: Next observations tensor dict
+                - done: Done mask tensor
+                - observation_feature: Optional pre-computed observation features
+                - next_observation_feature: Optional pre-computed next observation features
+            model: Which model to compute the loss for ("actor", "critic", or "temperature")
+
+        Returns:
+            The computed loss tensor
+        """
+        # TODO: (maractingi, azouitine) Respect the function signature we output tensors
+        # Extract common components from batch
+        actions = batch["action"]
+        observations = batch["state"]
+        observation_features = batch.get("observation_feature")
+
+        if model == "critic":
+            # Extract critic-specific components
+            rewards = batch["reward"]
+            next_observations = batch["next_state"]
+            done = batch["done"]
+            next_observation_features = batch.get("next_observation_feature")
+
+            return self.compute_loss_critic(
+                observations=observations,
+                actions=actions,
+                rewards=rewards,
+                next_observations=next_observations,
+                done=done,
+                observation_features=observation_features,
+                next_observation_features=next_observation_features,
+            )
+
+        if model == "actor":
+            return self.compute_loss_actor(
+                observations=observations,
+                observation_features=observation_features,
+            )
+
+        if model == "temperature":
+            return self.compute_loss_temperature(
+                observations=observations,
+                observation_features=observation_features,
+            )
+
+        raise ValueError(f"Unknown model type: {model}")
+
     def update_target_networks(self):
         """Update target networks with exponential moving average"""
         for target_param, param in zip(
@@ -257,7 +314,11 @@ class SACPolicy(
         temperature_loss = (-self.log_alpha.exp() * (log_probs + self.config.target_entropy)).mean()
         return temperature_loss
 
-    def compute_loss_actor(self, observations, observation_features: Tensor | None = None) -> Tensor:
+    def compute_loss_actor(
+        self,
+        observations,
+        observation_features: Tensor | None = None,
+    ) -> Tensor:
         self.temperature = self.log_alpha.exp().item()
 
         actions_pi, log_probs, _ = self.actor(observations, observation_features)

View File

@@ -382,15 +382,20 @@ def add_actor_information_and_train(
         observation_features, next_observation_features = get_observation_features(
             policy=policy, observations=observations, next_observations=next_observations
         )
-        loss_critic = policy.compute_loss_critic(
-            observations=observations,
-            actions=actions,
-            rewards=rewards,
-            next_observations=next_observations,
-            done=done,
-            observation_features=observation_features,
-            next_observation_features=next_observation_features,
-        )
+
+        # Create a batch dictionary with all required elements for the forward method
+        forward_batch = {
+            "action": actions,
+            "reward": rewards,
+            "state": observations,
+            "next_state": next_observations,
+            "done": done,
+            "observation_feature": observation_features,
+            "next_observation_feature": next_observation_features,
+        }
+
+        # Use the forward method for critic loss
+        loss_critic = policy.forward(forward_batch, model="critic")
 
         optimizers["critic"].zero_grad()
         loss_critic.backward()
@@ -422,15 +427,20 @@ def add_actor_information_and_train(
         observation_features, next_observation_features = get_observation_features(
             policy=policy, observations=observations, next_observations=next_observations
         )
-        loss_critic = policy.compute_loss_critic(
-            observations=observations,
-            actions=actions,
-            rewards=rewards,
-            next_observations=next_observations,
-            done=done,
-            observation_features=observation_features,
-            next_observation_features=next_observation_features,
-        )
+
+        # Create a batch dictionary with all required elements for the forward method
+        forward_batch = {
+            "action": actions,
+            "reward": rewards,
+            "state": observations,
+            "next_state": next_observations,
+            "done": done,
+            "observation_feature": observation_features,
+            "next_observation_feature": next_observation_features,
+        }
+
+        # Use the forward method for critic loss
+        loss_critic = policy.forward(forward_batch, model="critic")
 
        optimizers["critic"].zero_grad()
        loss_critic.backward()
@@ -447,10 +457,8 @@ def add_actor_information_and_train(
         if optimization_step % policy_update_freq == 0:
             for _ in range(policy_update_freq):
-                loss_actor = policy.compute_loss_actor(
-                    observations=observations,
-                    observation_features=observation_features,
-                )
+                # Use the forward method for actor loss
+                loss_actor = policy.forward(forward_batch, model="actor")
 
                 optimizers["actor"].zero_grad()
                 loss_actor.backward()
@@ -465,11 +473,8 @@ def add_actor_information_and_train(
                 training_infos["loss_actor"] = loss_actor.item()
                 training_infos["actor_grad_norm"] = actor_grad_norm
 
-                # Temperature optimization
-                loss_temperature = policy.compute_loss_temperature(
-                    observations=observations,
-                    observation_features=observation_features,
-                )
+                # Temperature optimization using forward method
+                loss_temperature = policy.forward(forward_batch, model="temperature")
 
                 optimizers["temperature"].zero_grad()
                 loss_temperature.backward()
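
Taken together, the learner-side change means a single `forward_batch` feeds all three updates. A condensed sketch of the resulting update step, assuming `policy`, `optimizers`, and `forward_batch` exist as in the diff above (the `.step()` calls and target update are added for completeness, following the surrounding code):

```python
# Critic update
loss_critic = policy.forward(forward_batch, model="critic")
optimizers["critic"].zero_grad()
loss_critic.backward()
optimizers["critic"].step()

# Actor update
loss_actor = policy.forward(forward_batch, model="actor")
optimizers["actor"].zero_grad()
loss_actor.backward()
optimizers["actor"].step()

# Temperature update
loss_temperature = policy.forward(forward_batch, model="temperature")
optimizers["temperature"].zero_grad()
loss_temperature.backward()
optimizers["temperature"].step()

# Keep the target critics in sync via an exponential moving average
policy.update_target_networks()
```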