Add grasp critic to the training loop

- Integrated the grasp critic gradient update into the training loop in learner_server
- Added a dedicated Adam optimizer for the grasp critic and exposed its learning rate in configuration_sac
- Added target critic network updates after the critic gradient steps
Author: s1lent4gnt
Date: 2025-03-31 18:06:21 +02:00
Commit: c774bbe522 (parent 2c1e5fa28b)
3 changed files with 53 additions and 2 deletions
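
Taken together, the diff below wires one extra gradient step into each learner iteration and refreshes both sets of target networks afterwards. The following is a condensed sketch of that step, not the actual learner_server code; `policy`, the optimizer, `forward_batch`, and `clip_grad_norm_value` stand in for objects built elsewhere in the learner.

```python
import torch


def grasp_critic_update_step(policy, grasp_critic_optimizer, forward_batch, clip_grad_norm_value):
    """Condensed sketch of the extra gradient step this commit adds to the learner loop.

    `policy` is assumed to expose forward(..., model="grasp_critic"), a `grasp_critic`
    module, and the two target-update helpers referenced in the diff below.
    """
    loss_grasp_critic = policy.forward(forward_batch, model="grasp_critic")

    grasp_critic_optimizer.zero_grad()
    loss_grasp_critic.backward()
    grad_norm = torch.nn.utils.clip_grad_norm_(
        parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value
    )
    grasp_critic_optimizer.step()

    # Both target networks are refreshed after the critic gradient steps.
    policy.update_target_networks()
    policy.update_grasp_target_networks()

    return {
        "loss_grasp_critic": loss_grasp_critic.item(),
        "grasp_critic_grad_norm": grad_norm.item(),
    }
```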

Changed file: configuration_sac

@@ -167,6 +167,7 @@ class SACConfig(PreTrainedConfig):
num_critics: int = 2
num_subsample_critics: int | None = None
critic_lr: float = 3e-4
grasp_critic_lr: float = 3e-4
actor_lr: float = 3e-4
temperature_lr: float = 3e-4
critic_target_update_weight: float = 0.005
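
The new grasp_critic_lr field sits next to the other per-module learning rates and shares their 3e-4 default. A hypothetical override is shown below; the import path is an assumption, and the remaining SACConfig fields are assumed to keep their defaults.

```python
# Illustrative only: the import path below is assumed, not taken from this commit.
from lerobot.common.policies.sac.configuration_sac import SACConfig

config = SACConfig(grasp_critic_lr=1e-4)  # tune the grasp critic independently of critic_lr / actor_lr
```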

Changed file: SAC policy module (class SACPolicy)

@@ -214,7 +214,7 @@ class SACPolicy(
def forward(
self,
batch: dict[str, Tensor | dict[str, Tensor]],
model: Literal["actor", "critic", "temperature"] = "critic",
model: Literal["actor", "critic", "grasp_critic", "temperature"] = "critic",
) -> dict[str, Tensor]:
"""Compute the loss for the given model
@@ -227,7 +227,7 @@ class SACPolicy(
- done: Done mask tensor
- observation_feature: Optional pre-computed observation features
- next_observation_feature: Optional pre-computed next observation features
model: Which model to compute the loss for ("actor", "critic", or "temperature")
model: Which model to compute the loss for ("actor", "critic", "grasp_critic", or "temperature")
Returns:
The computed loss tensor
@@ -254,6 +254,21 @@ class SACPolicy(
observation_features=observation_features,
next_observation_features=next_observation_features,
)
if model == "grasp_critic":
# Extract grasp_critic-specific components
complementary_info: dict[str, Tensor] = batch["complementary_info"]
return self.compute_loss_grasp_critic(
observations=observations,
actions=actions,
rewards=rewards,
next_observations=next_observations,
done=done,
observation_features=observation_features,
next_observation_features=next_observation_features,
complementary_info=complementary_info,
)
if model == "actor":
return self.compute_loss_actor(

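compute_loss_grasp_critic itself is not part of this diff. The sketch below shows one way such a loss could look, assuming a DQN-style grasp critic over a small set of discrete gripper actions carried in complementary_info; every key and attribute name here (gripper_action, grasp_reward, grasp_critic_target) is an assumption for illustration, not the repository's API.

```python
import torch
import torch.nn.functional as F


def compute_loss_grasp_critic_sketch(policy, observations, next_observations, done, complementary_info, discount=0.99):
    """Illustrative TD loss for a discrete-gripper grasp critic (assumed structure, not the repo's code)."""
    # Assumed layout of complementary_info: integer gripper actions and a grasp-specific reward.
    gripper_actions = complementary_info["gripper_action"].long()  # (B,)
    grasp_rewards = complementary_info["grasp_reward"]             # (B,)

    with torch.no_grad():
        # Bootstrap from a target copy of the grasp critic (assumed to exist, mirroring the main critics).
        next_q_all = policy.grasp_critic_target(next_observations)  # (B, num_gripper_actions)
        next_q = next_q_all.max(dim=-1).values
        target_q = grasp_rewards + (1.0 - done.float()) * discount * next_q

    q_all = policy.grasp_critic(observations)                       # (B, num_gripper_actions)
    q_taken = q_all.gather(-1, gripper_actions.unsqueeze(-1)).squeeze(-1)
    return F.mse_loss(q_taken, target_q)
```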
Changed file: learner_server

@@ -375,6 +375,7 @@ def add_actor_information_and_train(
observations = batch["state"]
next_observations = batch["next_state"]
done = batch["done"]
complementary_info = batch["complementary_info"]
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
observation_features, next_observation_features = get_observation_features(
@@ -390,6 +391,7 @@ def add_actor_information_and_train(
"done": done,
"observation_feature": observation_features,
"next_observation_feature": next_observation_features,
"complementary_info": complementary_info,
}
# Use the forward method for critic loss
@@ -404,7 +406,20 @@ def add_actor_information_and_train(
optimizers["critic"].step()
# Add gripper critic optimization
loss_grasp_critic = policy.forward(forward_batch, model="grasp_critic")
optimizers["grasp_critic"].zero_grad()
loss_grasp_critic.backward()
# clip gradients
grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value
)
optimizers["grasp_critic"].step()
policy.update_target_networks()
policy.update_grasp_target_networks()
batch = replay_buffer.sample(batch_size=batch_size)
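
Both update_target_networks and update_grasp_target_networks are called here but defined elsewhere. Below is a minimal sketch of what the grasp-critic variant could look like, assuming a soft (Polyak) update with a coefficient like critic_target_update_weight; the target module name is an assumption.

```python
import torch


@torch.no_grad()
def update_grasp_target_networks_sketch(grasp_critic, grasp_critic_target, tau=0.005):
    # target <- tau * online + (1 - tau) * target, mirroring the usual SAC soft update.
    for param, target_param in zip(grasp_critic.parameters(), grasp_critic_target.parameters()):
        target_param.data.mul_(1.0 - tau).add_(tau * param.data)
```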
@@ -435,6 +450,7 @@ def add_actor_information_and_train(
"done": done,
"observation_feature": observation_features,
"next_observation_feature": next_observation_features,
"complementary_info": complementary_info,
}
# Use the forward method for critic loss
@@ -453,6 +469,22 @@ def add_actor_information_and_train(
training_infos["loss_critic"] = loss_critic.item()
training_infos["critic_grad_norm"] = critic_grad_norm
# Add gripper critic optimization
loss_grasp_critic = policy.forward(forward_batch, model="grasp_critic")
optimizers["grasp_critic"].zero_grad()
loss_grasp_critic.backward()
# clip gradients
grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value
)
optimizers["grasp_critic"].step()
# Add training info for the grasp critic
training_infos["loss_grasp_critic"] = loss_grasp_critic.item()
training_infos["grasp_critic_grad_norm"] = grasp_critic_grad_norm
if optimization_step % policy_update_freq == 0:
for _ in range(policy_update_freq):
# Use the forward method for actor loss
@@ -495,6 +527,7 @@ def add_actor_information_and_train(
last_time_policy_pushed = time.time()
policy.update_target_networks()
policy.update_grasp_target_networks()
# Log training metrics at specified intervals
if optimization_step % log_freq == 0:
@@ -729,11 +762,13 @@ def make_optimizers_and_scheduler(cfg, policy: nn.Module):
lr=cfg.policy.actor_lr,
)
optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr)
optimizer_grasp_critic = torch.optim.Adam(params=policy.grasp_critic.parameters(), lr=policy.config.grasp_critic_lr)
optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr)
lr_scheduler = None
optimizers = {
"actor": optimizer_actor,
"critic": optimizer_critic,
"grasp_critic": optimizer_grasp_critic,
"temperature": optimizer_temperature,
}
return optimizers, lr_scheduler
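
With this change the keys of the optimizers dict line up one-to-one with the model argument of SACPolicy.forward. Purely as an illustration (not repository code), a generic helper could drive any of the four updates through that shared interface:

```python
import torch


def optimize_model(policy, optimizers, forward_batch, model_name, max_grad_norm=None):
    """Illustrative: run one gradient step for the model named by `model_name`
    ("critic", "grasp_critic", "actor", or "temperature") using the matching optimizer entry."""
    loss = policy.forward(forward_batch, model=model_name)
    optimizers[model_name].zero_grad()
    loss.backward()
    if max_grad_norm is not None:
        # Clip whatever parameters the chosen optimizer owns.
        params = [p for group in optimizers[model_name].param_groups for p in group["params"]]
        torch.nn.utils.clip_grad_norm_(params, max_norm=max_grad_norm)
    optimizers[model_name].step()
    return loss.item()
```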