Refactor SAC policy and training loop to enhance discrete action support

- Updated SACPolicy to compute the grasp-critic loss only when num_discrete_actions is configured.
- Simplified the forward method to return its loss outputs as a dictionary for clarity (see the sketch below).
- Adjusted learner_server to handle both the main critic and grasp critic losses during training.
- Ensured the grasp-critic optimizer is created only when the configuration enables discrete actions.
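
To make the new data flow concrete, here is a minimal sketch of how the learner consumes the dictionary returned by SACPolicy.forward after this change. It mirrors the pattern in the diff below rather than reproducing it exactly: policy, forward_batch, and optimizers are assumed to be set up as in learner_server, and gradient clipping is omitted for brevity.

    # forward(model="critic") now returns a dict of losses instead of a bare tensor.
    critic_output = policy.forward(forward_batch, model="critic")

    # The main critic loss is always present.
    loss_critic = critic_output["loss_critic"]
    optimizers["critic"].zero_grad()
    loss_critic.backward()
    optimizers["critic"].step()

    # The grasp-critic loss and its optimizer only exist when
    # SACConfig.num_discrete_actions is set, e.g. for a discrete gripper action.
    if "loss_grasp_critic" in critic_output and hasattr(policy, "grasp_critic"):
        loss_grasp_critic = critic_output["loss_grasp_critic"]
        optimizers["grasp_critic"].zero_grad()
        loss_grasp_critic.backward()
        optimizers["grasp_critic"].step()

The grasp-critic target update is likewise folded into policy.update_target_networks() and guarded by the same num_discrete_actions check, so the learner no longer calls a separate update_grasp_target_networks().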
AdilZouitine 2025-04-01 11:42:28 +00:00
parent 6a215f47dd
commit 306c735172
3 changed files with 86 additions and 90 deletions

View File

@@ -87,6 +87,7 @@ class SACConfig(PreTrainedConfig):
         freeze_vision_encoder: Whether to freeze the vision encoder during training.
         image_encoder_hidden_dim: Hidden dimension size for the image encoder.
         shared_encoder: Whether to use a shared encoder for actor and critic.
+        num_discrete_actions: Number of discrete actions, eg for gripper actions.
         concurrency: Configuration for concurrency settings.
         actor_learner: Configuration for actor-learner architecture.
         online_steps: Number of steps for online training.
@@ -162,7 +163,6 @@ class SACConfig(PreTrainedConfig):
     num_critics: int = 2
     num_subsample_critics: int | None = None
     critic_lr: float = 3e-4
-    grasp_critic_lr: float = 3e-4
     actor_lr: float = 3e-4
     temperature_lr: float = 3e-4
     critic_target_update_weight: float = 0.005

View File

@@ -228,7 +228,7 @@ class SACPolicy(
     def forward(
         self,
         batch: dict[str, Tensor | dict[str, Tensor]],
-        model: Literal["actor", "critic", "grasp_critic", "temperature"] = "critic",
+        model: Literal["actor", "critic", "temperature"] = "critic",
     ) -> dict[str, Tensor]:
         """Compute the loss for the given model
@@ -246,7 +246,6 @@ class SACPolicy(
         Returns:
             The computed loss tensor
         """
-        # TODO: (maractingi, azouitine) Respect the function signature we output tensors
         # Extract common components from batch
         actions: Tensor = batch["action"]
         observations: dict[str, Tensor] = batch["state"]
@@ -259,7 +258,7 @@ class SACPolicy(
             done: Tensor = batch["done"]
             next_observation_features: Tensor = batch.get("next_observation_feature")

-            return self.compute_loss_critic(
+            loss_critic = self.compute_loss_critic(
                 observations=observations,
                 actions=actions,
                 rewards=rewards,
@@ -268,29 +267,28 @@ class SACPolicy(
                 observation_features=observation_features,
                 next_observation_features=next_observation_features,
             )
+            if self.config.num_discrete_actions is not None:
+                loss_grasp_critic = self.compute_loss_grasp_critic(
+                    observations=observations,
+                    actions=actions,
+                    rewards=rewards,
+                    next_observations=next_observations,
+                    done=done,
+                )
+            return {"loss_critic": loss_critic, "loss_grasp_critic": loss_grasp_critic}

-        if model == "grasp_critic":
-            return self.compute_loss_grasp_critic(
-                observations=observations,
-                actions=actions,
-                rewards=rewards,
-                next_observations=next_observations,
-                done=done,
-                observation_features=observation_features,
-                next_observation_features=next_observation_features,
-            )

         if model == "actor":
-            return self.compute_loss_actor(
+            return {"loss_actor": self.compute_loss_actor(
                 observations=observations,
                 observation_features=observation_features,
-            )
+            )}

         if model == "temperature":
-            return self.compute_loss_temperature(
+            return {"loss_temperature": self.compute_loss_temperature(
                 observations=observations,
                 observation_features=observation_features,
-            )
+            )}

         raise ValueError(f"Unknown model type: {model}")
@@ -305,9 +303,7 @@ class SACPolicy(
                 param.data * self.config.critic_target_update_weight
                 + target_param.data * (1.0 - self.config.critic_target_update_weight)
             )
+        if self.config.num_discrete_actions is not None:
-    def update_grasp_target_networks(self):
-        """Update grasp target networks with exponential moving average"""
             for target_param, param in zip(
                 self.grasp_critic_target.parameters(),
                 self.grasp_critic.parameters(),

View File

@@ -392,32 +392,30 @@ def add_actor_information_and_train(
                 "next_observation_feature": next_observation_features,
             }

-            # Use the forward method for critic loss
-            loss_critic = policy.forward(forward_batch, model="critic")
+            # Use the forward method for critic loss (includes both main critic and grasp critic)
+            critic_output = policy.forward(forward_batch, model="critic")
+
+            # Main critic optimization
+            loss_critic = critic_output["loss_critic"]
             optimizers["critic"].zero_grad()
             loss_critic.backward()
-            # clip gradients
             critic_grad_norm = torch.nn.utils.clip_grad_norm_(
                 parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value
             )
             optimizers["critic"].step()

-            # Add gripper critic optimization
-            loss_grasp_critic = policy.forward(forward_batch, model="grasp_critic")
+            # Grasp critic optimization (if available)
+            if "loss_grasp_critic" in critic_output and hasattr(policy, "grasp_critic"):
+                loss_grasp_critic = critic_output["loss_grasp_critic"]
                 optimizers["grasp_critic"].zero_grad()
                 loss_grasp_critic.backward()
-            # clip gradients
                 grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
                     parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value
                 )
                 optimizers["grasp_critic"].step()

-            # Update target networks
             policy.update_target_networks()
-            policy.update_grasp_target_networks()

             batch = replay_buffer.sample(batch_size=batch_size)
@@ -450,81 +448,80 @@ def add_actor_information_and_train(
             "next_observation_feature": next_observation_features,
         }

-        # Use the forward method for critic loss
-        loss_critic = policy.forward(forward_batch, model="critic")
+        # Use the forward method for critic loss (includes both main critic and grasp critic)
+        critic_output = policy.forward(forward_batch, model="critic")
+
+        # Main critic optimization
+        loss_critic = critic_output["loss_critic"]
         optimizers["critic"].zero_grad()
         loss_critic.backward()
-        # clip gradients
         critic_grad_norm = torch.nn.utils.clip_grad_norm_(
             parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value
         ).item()
         optimizers["critic"].step()

-        training_infos = {}
-        training_infos["loss_critic"] = loss_critic.item()
-        training_infos["critic_grad_norm"] = critic_grad_norm
+        # Initialize training info dictionary
+        training_infos = {
+            "loss_critic": loss_critic.item(),
+            "critic_grad_norm": critic_grad_norm,
+        }

-        # Add gripper critic optimization
-        loss_grasp_critic = policy.forward(forward_batch, model="grasp_critic")
+        # Grasp critic optimization (if available)
+        if "loss_grasp_critic" in critic_output:
+            loss_grasp_critic = critic_output["loss_grasp_critic"]
             optimizers["grasp_critic"].zero_grad()
             loss_grasp_critic.backward()
-        # clip gradients
             grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
                 parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value
-        )
+            ).item()
             optimizers["grasp_critic"].step()

-        # Add training info for the grasp critic
+            # Add grasp critic info to training info
             training_infos["loss_grasp_critic"] = loss_grasp_critic.item()
             training_infos["grasp_critic_grad_norm"] = grasp_critic_grad_norm

+        # Actor and temperature optimization (at specified frequency)
         if optimization_step % policy_update_freq == 0:
             for _ in range(policy_update_freq):
-                # Use the forward method for actor loss
-                loss_actor = policy.forward(forward_batch, model="actor")
+                # Actor optimization
+                actor_output = policy.forward(forward_batch, model="actor")
+                loss_actor = actor_output["loss_actor"]
                 optimizers["actor"].zero_grad()
                 loss_actor.backward()
-                # clip gradients
                 actor_grad_norm = torch.nn.utils.clip_grad_norm_(
                     parameters=policy.actor.parameters_to_optimize, max_norm=clip_grad_norm_value
                 ).item()
                 optimizers["actor"].step()

+                # Add actor info to training info
                 training_infos["loss_actor"] = loss_actor.item()
                 training_infos["actor_grad_norm"] = actor_grad_norm

-                # Temperature optimization using forward method
-                loss_temperature = policy.forward(forward_batch, model="temperature")
+                # Temperature optimization
+                temperature_output = policy.forward(forward_batch, model="temperature")
+                loss_temperature = temperature_output["loss_temperature"]
                 optimizers["temperature"].zero_grad()
                 loss_temperature.backward()
-                # clip gradients
                 temp_grad_norm = torch.nn.utils.clip_grad_norm_(
                     parameters=[policy.log_alpha], max_norm=clip_grad_norm_value
                 ).item()
                 optimizers["temperature"].step()

+                # Add temperature info to training info
                 training_infos["loss_temperature"] = loss_temperature.item()
                 training_infos["temperature_grad_norm"] = temp_grad_norm
                 training_infos["temperature"] = policy.temperature

-                # Update temperature
                 policy.update_temperature()

-        # Check if it's time to push updated policy to actors
+        # Push policy to actors if needed
         if time.time() - last_time_policy_pushed > policy_parameters_push_frequency:
             push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)
             last_time_policy_pushed = time.time()

-        # Update target networks
         policy.update_target_networks()
-        policy.update_grasp_target_networks()

         # Log training metrics at specified intervals
         if optimization_step % log_freq == 0:
@@ -727,7 +724,7 @@ def save_training_checkpoint(
     logging.info("Resume training")


-def make_optimizers_and_scheduler(cfg, policy: nn.Module):
+def make_optimizers_and_scheduler(cfg: TrainPipelineConfig, policy: nn.Module):
     """
     Creates and returns optimizers for the actor, critic, and temperature components of a reinforcement learning policy.
@@ -759,17 +756,20 @@ def make_optimizers_and_scheduler(cfg, policy: nn.Module):
         lr=cfg.policy.actor_lr,
     )
     optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr)
+    if cfg.policy.num_discrete_actions is not None:
         optimizer_grasp_critic = torch.optim.Adam(
-        params=policy.grasp_critic.parameters(), lr=policy.config.grasp_critic_lr
+            params=policy.grasp_critic.parameters(), lr=policy.critic_lr
         )
     optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr)

     lr_scheduler = None

     optimizers = {
         "actor": optimizer_actor,
         "critic": optimizer_critic,
-        "grasp_critic": optimizer_grasp_critic,
         "temperature": optimizer_temperature,
     }
+    if cfg.policy.num_discrete_actions is not None:
+        optimizers["grasp_critic"] = optimizer_grasp_critic
     return optimizers, lr_scheduler