From 6e687e2910cd8ac10c6f0df1e85decdb354f1f25 Mon Sep 17 00:00:00 2001
From: AdilZouitine
Date: Fri, 28 Mar 2025 16:40:45 +0000
Subject: [PATCH] Refactor SACPolicy and learner_server for improved clarity and functionality

- Implemented the previously stubbed `forward` method in `SACPolicy` to compute the actor, critic, and temperature losses.
- Replaced the direct `compute_loss_*` calls in `learner_server` with calls to the unified `forward` method.
- Consolidated the loss inputs into a single `forward_batch` dictionary for better readability and maintainability.
- Removed the redundant `input_features`/`output_features` definitions from `SACConfig` and improved documentation for clarity.
---
 .../common/policies/sac/configuration_sac.py | 12 ----
 lerobot/common/policies/sac/modeling_sac.py  | 67 ++++++++++++++++++-
 lerobot/scripts/server/learner_server.py     | 59 ++++++++--------
 3 files changed, 96 insertions(+), 42 deletions(-)

diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py
index 629c6576..0d2c3765 100644
--- a/lerobot/common/policies/sac/configuration_sac.py
+++ b/lerobot/common/policies/sac/configuration_sac.py
@@ -135,18 +135,6 @@ class SACConfig(PreTrainedConfig):
         }
     )
 
-    input_features: dict[str, PolicyFeature] = field(
-        default_factory=lambda: {
-            "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 64, 64)),
-            "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(2,)),
-        }
-    )
-    output_features: dict[str, PolicyFeature] = field(
-        default_factory=lambda: {
-            "action": PolicyFeature(type=FeatureType.ACTION, shape=(3,)),
-        }
-    )
-
     # Architecture specifics
     camera_number: int = 1
     device: str = "cuda"
diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py
index e46cafcf..c15b8f02 100644
--- a/lerobot/common/policies/sac/modeling_sac.py
+++ b/lerobot/common/policies/sac/modeling_sac.py
@@ -19,7 +19,7 @@
 import math
 from dataclasses import asdict
-from typing import Callable, List, Optional, Tuple
+from typing import Callable, List, Literal, Optional, Tuple
 
 import einops
 import numpy as np
@@ -177,7 +177,64 @@ class SACPolicy(
         q_values = critics(observations, actions, observation_features)
 
         return q_values
 
-    def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]: ...
+ def forward( + self, + batch: dict[str, Tensor | dict[str, Tensor]], + model: Literal["actor", "critic", "temperature"] = "critic", + ) -> dict[str, Tensor]: + """Compute the loss for the given model + + Args: + batch: Dictionary containing: + - action: Action tensor + - reward: Reward tensor + - state: Observations tensor dict + - next_state: Next observations tensor dict + - done: Done mask tensor + - observation_feature: Optional pre-computed observation features + - next_observation_feature: Optional pre-computed next observation features + model: Which model to compute the loss for ("actor", "critic", or "temperature") + + Returns: + The computed loss tensor + """ + # TODO: (maractingi, azouitine) Respect the function signature we output tensors + # Extract common components from batch + actions = batch["action"] + observations = batch["state"] + observation_features = batch.get("observation_feature") + + if model == "critic": + # Extract critic-specific components + rewards = batch["reward"] + next_observations = batch["next_state"] + done = batch["done"] + next_observation_features = batch.get("next_observation_feature") + + return self.compute_loss_critic( + observations=observations, + actions=actions, + rewards=rewards, + next_observations=next_observations, + done=done, + observation_features=observation_features, + next_observation_features=next_observation_features, + ) + + if model == "actor": + return self.compute_loss_actor( + observations=observations, + observation_features=observation_features, + ) + + if model == "temperature": + return self.compute_loss_temperature( + observations=observations, + observation_features=observation_features, + ) + + raise ValueError(f"Unknown model type: {model}") + def update_target_networks(self): """Update target networks with exponential moving average""" for target_param, param in zip( @@ -257,7 +314,11 @@ class SACPolicy( temperature_loss = (-self.log_alpha.exp() * (log_probs + self.config.target_entropy)).mean() return temperature_loss - def compute_loss_actor(self, observations, observation_features: Tensor | None = None) -> Tensor: + def compute_loss_actor( + self, + observations, + observation_features: Tensor | None = None, + ) -> Tensor: self.temperature = self.log_alpha.exp().item() actions_pi, log_probs, _ = self.actor(observations, observation_features) diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py index d9bdcf6c..244a6a47 100644 --- a/lerobot/scripts/server/learner_server.py +++ b/lerobot/scripts/server/learner_server.py @@ -382,15 +382,20 @@ def add_actor_information_and_train( observation_features, next_observation_features = get_observation_features( policy=policy, observations=observations, next_observations=next_observations ) - loss_critic = policy.compute_loss_critic( - observations=observations, - actions=actions, - rewards=rewards, - next_observations=next_observations, - done=done, - observation_features=observation_features, - next_observation_features=next_observation_features, - ) + + # Create a batch dictionary with all required elements for the forward method + forward_batch = { + "action": actions, + "reward": rewards, + "state": observations, + "next_state": next_observations, + "done": done, + "observation_feature": observation_features, + "next_observation_feature": next_observation_features, + } + + # Use the forward method for critic loss + loss_critic = policy.forward(forward_batch, model="critic") optimizers["critic"].zero_grad() 
loss_critic.backward() @@ -422,15 +427,20 @@ def add_actor_information_and_train( observation_features, next_observation_features = get_observation_features( policy=policy, observations=observations, next_observations=next_observations ) - loss_critic = policy.compute_loss_critic( - observations=observations, - actions=actions, - rewards=rewards, - next_observations=next_observations, - done=done, - observation_features=observation_features, - next_observation_features=next_observation_features, - ) + + # Create a batch dictionary with all required elements for the forward method + forward_batch = { + "action": actions, + "reward": rewards, + "state": observations, + "next_state": next_observations, + "done": done, + "observation_feature": observation_features, + "next_observation_feature": next_observation_features, + } + + # Use the forward method for critic loss + loss_critic = policy.forward(forward_batch, model="critic") optimizers["critic"].zero_grad() loss_critic.backward() @@ -447,10 +457,8 @@ def add_actor_information_and_train( if optimization_step % policy_update_freq == 0: for _ in range(policy_update_freq): - loss_actor = policy.compute_loss_actor( - observations=observations, - observation_features=observation_features, - ) + # Use the forward method for actor loss + loss_actor = policy.forward(forward_batch, model="actor") optimizers["actor"].zero_grad() loss_actor.backward() @@ -465,11 +473,8 @@ def add_actor_information_and_train( training_infos["loss_actor"] = loss_actor.item() training_infos["actor_grad_norm"] = actor_grad_norm - # Temperature optimization - loss_temperature = policy.compute_loss_temperature( - observations=observations, - observation_features=observation_features, - ) + # Temperature optimization using forward method + loss_temperature = policy.forward(forward_batch, model="temperature") optimizers["temperature"].zero_grad() loss_temperature.backward()
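
Review note / usage sketch: the patch routes all three SAC losses through a single `policy.forward(batch, model=...)` call, which dispatches internally to `compute_loss_critic`, `compute_loss_actor`, or `compute_loss_temperature`. As the in-code TODO acknowledges, `forward` currently returns a raw loss tensor even though its annotation says `dict[str, Tensor]`. The minimal sketch below illustrates the new calling convention; the batch keys and the `model` argument are taken from the diff, while the helper name `compute_sac_losses`, the pre-built `policy` object, and the `None` feature placeholders are illustrative assumptions rather than part of the patch.

```python
# Illustrative sketch only (not part of the patch): `policy` is assumed to be an
# already-constructed SACPolicy, and the tensors to come from a replay-buffer sample.


def compute_sac_losses(policy, observations, actions, rewards, next_observations, done):
    """Compute the critic, actor, and temperature losses via the unified forward()."""
    forward_batch = {
        "action": actions,
        "reward": rewards,
        "state": observations,
        "next_state": next_observations,
        "done": done,
        # The forward() docstring marks the pre-computed observation features as
        # optional, so None is passed here.
        "observation_feature": None,
        "next_observation_feature": None,
    }

    # Each call dispatches to the matching compute_loss_* method.
    loss_critic = policy.forward(forward_batch, model="critic")
    loss_actor = policy.forward(forward_batch, model="actor")
    loss_temperature = policy.forward(forward_batch, model="temperature")
    return loss_critic, loss_actor, loss_temperature
```

One design consequence visible in `learner_server`: the same `forward_batch` dictionary is reused for the critic, actor, and temperature updates, even though the actor and temperature branches only read the `state` and `observation_feature` entries; the extra keys are simply ignored by those branches.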