Refactor SACPolicy and learner_server for improved clarity and functionality
- Updated the `forward` method in `SACPolicy` to handle loss computation for actor, critic, and temperature models.
- Replaced direct calls to `compute_loss_*` methods with a unified `forward` method in `learner_server`.
- Enhanced batch processing by consolidating input parameters into a single dictionary for better readability and maintainability.
- Removed redundant code and improved documentation for clarity.
parent 8b02e81bb5
commit b3ad63cf6e
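Below is a minimal usage sketch of the unified entry point described above. It assumes an already-constructed `SACPolicy` instance named `policy`; the batch keys and the `model` argument match the new `forward` signature in the diff, while the batch size and tensor shapes are illustrative only (they follow the example features that appear in the configuration diff below).

import torch

# `policy` is assumed to be an already-constructed SACPolicy (see note above).
batch_size = 8  # illustrative only

observations = {
    "observation.image": torch.randn(batch_size, 3, 64, 64),
    "observation.state": torch.randn(batch_size, 2),
}
next_observations = {
    "observation.image": torch.randn(batch_size, 3, 64, 64),
    "observation.state": torch.randn(batch_size, 2),
}

forward_batch = {
    "action": torch.randn(batch_size, 3),
    "reward": torch.randn(batch_size),
    "done": torch.zeros(batch_size),  # no terminal transitions in this toy batch
    "state": observations,
    "next_state": next_observations,
    "observation_feature": None,       # let the policy encode the observations itself
    "next_observation_feature": None,
}

# One entry point, three losses, selected with the `model` argument.
loss_critic = policy.forward(forward_batch, model="critic")
loss_actor = policy.forward(forward_batch, model="actor")
loss_temperature = policy.forward(forward_batch, model="temperature")

Building one `forward_batch` up front lets the learner reuse the same dictionary for the critic, actor, and temperature updates, which is exactly what the `learner_server` changes below do.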
@@ -135,18 +135,6 @@ class SACConfig(PreTrainedConfig):
         }
     )
 
-    input_features: dict[str, PolicyFeature] = field(
-        default_factory=lambda: {
-            "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 64, 64)),
-            "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(2,)),
-        }
-    )
-    output_features: dict[str, PolicyFeature] = field(
-        default_factory=lambda: {
-            "action": PolicyFeature(type=FeatureType.ACTION, shape=(3,)),
-        }
-    )
-
     # Architecture specifics
     camera_number: int = 1
     device: str = "cuda"
@@ -19,7 +19,7 @@
 
 import math
 from dataclasses import asdict
-from typing import Callable, List, Optional, Tuple
+from typing import Callable, List, Literal, Optional, Tuple
 
 import einops
 import numpy as np
@@ -177,7 +177,64 @@ class SACPolicy(
         q_values = critics(observations, actions, observation_features)
         return q_values
 
-    def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]: ...
+    def forward(
+        self,
+        batch: dict[str, Tensor | dict[str, Tensor]],
+        model: Literal["actor", "critic", "temperature"] = "critic",
+    ) -> dict[str, Tensor]:
+        """Compute the loss for the given model
+
+        Args:
+            batch: Dictionary containing:
+                - action: Action tensor
+                - reward: Reward tensor
+                - state: Observations tensor dict
+                - next_state: Next observations tensor dict
+                - done: Done mask tensor
+                - observation_feature: Optional pre-computed observation features
+                - next_observation_feature: Optional pre-computed next observation features
+            model: Which model to compute the loss for ("actor", "critic", or "temperature")
+
+        Returns:
+            The computed loss tensor
+        """
+        # TODO: (maractingi, azouitine) Respect the function signature we output tensors
+        # Extract common components from batch
+        actions = batch["action"]
+        observations = batch["state"]
+        observation_features = batch.get("observation_feature")
+
+        if model == "critic":
+            # Extract critic-specific components
+            rewards = batch["reward"]
+            next_observations = batch["next_state"]
+            done = batch["done"]
+            next_observation_features = batch.get("next_observation_feature")
+
+            return self.compute_loss_critic(
+                observations=observations,
+                actions=actions,
+                rewards=rewards,
+                next_observations=next_observations,
+                done=done,
+                observation_features=observation_features,
+                next_observation_features=next_observation_features,
+            )
+
+        if model == "actor":
+            return self.compute_loss_actor(
+                observations=observations,
+                observation_features=observation_features,
+            )
+
+        if model == "temperature":
+            return self.compute_loss_temperature(
+                observations=observations,
+                observation_features=observation_features,
+            )
+
+        raise ValueError(f"Unknown model type: {model}")
+
     def update_target_networks(self):
         """Update target networks with exponential moving average"""
         for target_param, param in zip(
@@ -257,7 +314,11 @@ class SACPolicy(
         temperature_loss = (-self.log_alpha.exp() * (log_probs + self.config.target_entropy)).mean()
         return temperature_loss
 
-    def compute_loss_actor(self, observations, observation_features: Tensor | None = None) -> Tensor:
+    def compute_loss_actor(
+        self,
+        observations,
+        observation_features: Tensor | None = None,
+    ) -> Tensor:
         self.temperature = self.log_alpha.exp().item()
 
         actions_pi, log_probs, _ = self.actor(observations, observation_features)
@@ -382,15 +382,20 @@ def add_actor_information_and_train(
         observation_features, next_observation_features = get_observation_features(
             policy=policy, observations=observations, next_observations=next_observations
         )
-        loss_critic = policy.compute_loss_critic(
-            observations=observations,
-            actions=actions,
-            rewards=rewards,
-            next_observations=next_observations,
-            done=done,
-            observation_features=observation_features,
-            next_observation_features=next_observation_features,
-        )
+
+        # Create a batch dictionary with all required elements for the forward method
+        forward_batch = {
+            "action": actions,
+            "reward": rewards,
+            "state": observations,
+            "next_state": next_observations,
+            "done": done,
+            "observation_feature": observation_features,
+            "next_observation_feature": next_observation_features,
+        }
+
+        # Use the forward method for critic loss
+        loss_critic = policy.forward(forward_batch, model="critic")
         optimizers["critic"].zero_grad()
         loss_critic.backward()
 
@@ -422,15 +427,20 @@ def add_actor_information_and_train(
         observation_features, next_observation_features = get_observation_features(
             policy=policy, observations=observations, next_observations=next_observations
         )
-        loss_critic = policy.compute_loss_critic(
-            observations=observations,
-            actions=actions,
-            rewards=rewards,
-            next_observations=next_observations,
-            done=done,
-            observation_features=observation_features,
-            next_observation_features=next_observation_features,
-        )
+
+        # Create a batch dictionary with all required elements for the forward method
+        forward_batch = {
+            "action": actions,
+            "reward": rewards,
+            "state": observations,
+            "next_state": next_observations,
+            "done": done,
+            "observation_feature": observation_features,
+            "next_observation_feature": next_observation_features,
+        }
+
+        # Use the forward method for critic loss
+        loss_critic = policy.forward(forward_batch, model="critic")
         optimizers["critic"].zero_grad()
         loss_critic.backward()
 
@@ -447,10 +457,8 @@ def add_actor_information_and_train(
 
         if optimization_step % policy_update_freq == 0:
             for _ in range(policy_update_freq):
-                loss_actor = policy.compute_loss_actor(
-                    observations=observations,
-                    observation_features=observation_features,
-                )
+                # Use the forward method for actor loss
+                loss_actor = policy.forward(forward_batch, model="actor")
 
                 optimizers["actor"].zero_grad()
                 loss_actor.backward()
@@ -465,11 +473,8 @@ def add_actor_information_and_train(
                 training_infos["loss_actor"] = loss_actor.item()
                 training_infos["actor_grad_norm"] = actor_grad_norm
 
-                # Temperature optimization
-                loss_temperature = policy.compute_loss_temperature(
-                    observations=observations,
-                    observation_features=observation_features,
-                )
+                # Temperature optimization using forward method
+                loss_temperature = policy.forward(forward_batch, model="temperature")
                 optimizers["temperature"].zero_grad()
                 loss_temperature.backward()
 