Refactor SACPolicy and learner_server for improved clarity and functionality

- Updated the `forward` method in `SACPolicy` to compute the actor, critic, and temperature losses from a single entry point.
- Replaced the direct `compute_loss_*` calls in `learner_server` with this unified `forward` method.
- Consolidated the training inputs into a single batch dictionary (`forward_batch`) for better readability and maintainability; see the sketch below.
- Removed redundant code and improved documentation for clarity.
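
For orientation, here is a minimal sketch of the consolidated batch the learner now builds and hands to the policy. The `policy` object, batch size, and tensor shapes are illustrative assumptions (the shapes mirror the defaults removed from `SACConfig` below); only the dictionary keys and the `model` argument come from the commit itself.

import torch

# Hypothetical shapes; `policy` is assumed to be an already-constructed SACPolicy.
batch_size = 256
forward_batch = {
    "action": torch.randn(batch_size, 3),
    "reward": torch.randn(batch_size),
    "state": {"observation.state": torch.randn(batch_size, 2)},
    "next_state": {"observation.state": torch.randn(batch_size, 2)},
    "done": torch.zeros(batch_size),
    "observation_feature": None,        # optional pre-computed encoder features
    "next_observation_feature": None,
}

# One entry point instead of three compute_loss_* calls; the same dictionary
# is reused for the actor and temperature losses in learner_server.
loss_critic = policy.forward(forward_batch, model="critic")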
AdilZouitine 2025-03-28 16:40:45 +00:00
parent 8b02e81bb5
commit b3ad63cf6e
3 changed files with 96 additions and 42 deletions

View File

@@ -135,18 +135,6 @@ class SACConfig(PreTrainedConfig):
         }
     )

-    input_features: dict[str, PolicyFeature] = field(
-        default_factory=lambda: {
-            "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 64, 64)),
-            "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(2,)),
-        }
-    )
-    output_features: dict[str, PolicyFeature] = field(
-        default_factory=lambda: {
-            "action": PolicyFeature(type=FeatureType.ACTION, shape=(3,)),
-        }
-    )
-
     # Architecture specifics
     camera_number: int = 1
     device: str = "cuda"

View File

@@ -19,7 +19,7 @@
 import math
 from dataclasses import asdict
-from typing import Callable, List, Optional, Tuple
+from typing import Callable, List, Literal, Optional, Tuple

 import einops
 import numpy as np
@@ -177,7 +177,64 @@ class SACPolicy(
         q_values = critics(observations, actions, observation_features)
         return q_values

-    def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]: ...
+    def forward(
+        self,
+        batch: dict[str, Tensor | dict[str, Tensor]],
+        model: Literal["actor", "critic", "temperature"] = "critic",
+    ) -> dict[str, Tensor]:
+        """Compute the loss for the given model
+
+        Args:
+            batch: Dictionary containing:
+                - action: Action tensor
+                - reward: Reward tensor
+                - state: Observations tensor dict
+                - next_state: Next observations tensor dict
+                - done: Done mask tensor
+                - observation_feature: Optional pre-computed observation features
+                - next_observation_feature: Optional pre-computed next observation features
+            model: Which model to compute the loss for ("actor", "critic", or "temperature")
+
+        Returns:
+            The computed loss tensor
+        """
+        # TODO: (maractingi, azouitine) Respect the function signature we output tensors
+        # Extract common components from batch
+        actions = batch["action"]
+        observations = batch["state"]
+        observation_features = batch.get("observation_feature")
+
+        if model == "critic":
+            # Extract critic-specific components
+            rewards = batch["reward"]
+            next_observations = batch["next_state"]
+            done = batch["done"]
+            next_observation_features = batch.get("next_observation_feature")
+
+            return self.compute_loss_critic(
+                observations=observations,
+                actions=actions,
+                rewards=rewards,
+                next_observations=next_observations,
+                done=done,
+                observation_features=observation_features,
+                next_observation_features=next_observation_features,
+            )
+
+        if model == "actor":
+            return self.compute_loss_actor(
+                observations=observations,
+                observation_features=observation_features,
+            )
+
+        if model == "temperature":
+            return self.compute_loss_temperature(
+                observations=observations,
+                observation_features=observation_features,
+            )
+
+        raise ValueError(f"Unknown model type: {model}")
+
     def update_target_networks(self):
         """Update target networks with exponential moving average"""
         for target_param, param in zip(
@@ -257,7 +314,11 @@ class SACPolicy(
         temperature_loss = (-self.log_alpha.exp() * (log_probs + self.config.target_entropy)).mean()
         return temperature_loss

-    def compute_loss_actor(self, observations, observation_features: Tensor | None = None) -> Tensor:
+    def compute_loss_actor(
+        self,
+        observations,
+        observation_features: Tensor | None = None,
+    ) -> Tensor:
         self.temperature = self.log_alpha.exp().item()

         actions_pi, log_probs, _ = self.actor(observations, observation_features)

View File

@@ -382,15 +382,20 @@ def add_actor_information_and_train(
             observation_features, next_observation_features = get_observation_features(
                 policy=policy, observations=observations, next_observations=next_observations
             )
-            loss_critic = policy.compute_loss_critic(
-                observations=observations,
-                actions=actions,
-                rewards=rewards,
-                next_observations=next_observations,
-                done=done,
-                observation_features=observation_features,
-                next_observation_features=next_observation_features,
-            )
+
+            # Create a batch dictionary with all required elements for the forward method
+            forward_batch = {
+                "action": actions,
+                "reward": rewards,
+                "state": observations,
+                "next_state": next_observations,
+                "done": done,
+                "observation_feature": observation_features,
+                "next_observation_feature": next_observation_features,
+            }
+
+            # Use the forward method for critic loss
+            loss_critic = policy.forward(forward_batch, model="critic")

             optimizers["critic"].zero_grad()
             loss_critic.backward()
@@ -422,15 +427,20 @@ def add_actor_information_and_train(
         observation_features, next_observation_features = get_observation_features(
             policy=policy, observations=observations, next_observations=next_observations
         )
-        loss_critic = policy.compute_loss_critic(
-            observations=observations,
-            actions=actions,
-            rewards=rewards,
-            next_observations=next_observations,
-            done=done,
-            observation_features=observation_features,
-            next_observation_features=next_observation_features,
-        )
+
+        # Create a batch dictionary with all required elements for the forward method
+        forward_batch = {
+            "action": actions,
+            "reward": rewards,
+            "state": observations,
+            "next_state": next_observations,
+            "done": done,
+            "observation_feature": observation_features,
+            "next_observation_feature": next_observation_features,
+        }
+
+        # Use the forward method for critic loss
+        loss_critic = policy.forward(forward_batch, model="critic")

        optimizers["critic"].zero_grad()
        loss_critic.backward()
@@ -447,10 +457,8 @@ def add_actor_information_and_train(

         if optimization_step % policy_update_freq == 0:
             for _ in range(policy_update_freq):
-                loss_actor = policy.compute_loss_actor(
-                    observations=observations,
-                    observation_features=observation_features,
-                )
+                # Use the forward method for actor loss
+                loss_actor = policy.forward(forward_batch, model="actor")

                 optimizers["actor"].zero_grad()
                 loss_actor.backward()
@@ -465,11 +473,8 @@ def add_actor_information_and_train(

                 training_infos["loss_actor"] = loss_actor.item()
                 training_infos["actor_grad_norm"] = actor_grad_norm

-                # Temperature optimization
-                loss_temperature = policy.compute_loss_temperature(
-                    observations=observations,
-                    observation_features=observation_features,
-                )
+                # Temperature optimization using forward method
+                loss_temperature = policy.forward(forward_batch, model="temperature")
                 optimizers["temperature"].zero_grad()
                 loss_temperature.backward()
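
Condensed, the learner's update now follows the pattern below. This is a sketch assembled from the hunks above, not code from the commit: the optimizer `.step()` calls and gradient clipping that follow each `backward()` in the actual file fall outside the diff context and are only assumed here.

# Critic update (every optimization step): one consolidated batch, one entry point.
loss_critic = policy.forward(forward_batch, model="critic")
optimizers["critic"].zero_grad()
loss_critic.backward()
optimizers["critic"].step()  # assumed; the hunk cuts off before the step/clip lines

# Actor and temperature updates, gated by policy_update_freq as in the diff.
if optimization_step % policy_update_freq == 0:
    for _ in range(policy_update_freq):
        loss_actor = policy.forward(forward_batch, model="actor")
        optimizers["actor"].zero_grad()
        loss_actor.backward()
        optimizers["actor"].step()  # assumed

        loss_temperature = policy.forward(forward_batch, model="temperature")
        optimizers["temperature"].zero_grad()
        loss_temperature.backward()
        optimizers["temperature"].step()  # assumed

Reusing one `forward_batch` for all three losses keeps the learner loop free of per-loss argument lists; the critic-only keys (`reward`, `next_state`, `done`, `next_observation_feature`) are simply ignored by the actor and temperature branches of `SACPolicy.forward`.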