Add grasp critic to the training loop
- Integrated the grasp critic gradient update into the training loop in learner_server
- Added an Adam optimizer and a grasp critic learning rate in configuration_sac
- Added the grasp target critic network update after the critic gradient step
This commit is contained in:
parent
2c1e5fa28b
commit
c774bbe522
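At a glance, the change adds a second critic-style gradient step to the learner loop. The helper below is not part of the commit; it is a condensed sketch stitched together from the hunks that follow, reusing the names from learner_server (policy, optimizers, forward_batch, clip_grad_norm_value).

import torch


def grasp_critic_update_step(policy, optimizers, forward_batch, clip_grad_norm_value):
    # Compute the grasp-critic loss through the policy's forward dispatch,
    # exactly as the diff below does with model="grasp_critic".
    loss_grasp_critic = policy.forward(forward_batch, model="grasp_critic")

    # Dedicated Adam optimizer added in make_optimizers_and_scheduler.
    optimizers["grasp_critic"].zero_grad()
    loss_grasp_critic.backward()

    # Clip only the grasp critic's gradients, mirroring the main critic update.
    grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
        parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value
    )
    optimizers["grasp_critic"].step()

    # Soft-update the grasp critic's target network after the gradient step.
    policy.update_grasp_target_networks()
    return loss_grasp_critic.item(), grasp_critic_grad_norm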
@@ -167,6 +167,7 @@ class SACConfig(PreTrainedConfig):
     num_critics: int = 2
     num_subsample_critics: int | None = None
     critic_lr: float = 3e-4
+    grasp_critic_lr: float = 3e-4
     actor_lr: float = 3e-4
     temperature_lr: float = 3e-4
     critic_target_update_weight: float = 0.005

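A minimal usage sketch for the new field, assuming the dataclass-style config accepts its fields as keyword arguments; the import path is an assumption and may differ in this repository.

# Assumed import path; adjust to wherever configuration_sac lives in this repo.
from lerobot.common.policies.sac.configuration_sac import SACConfig

# grasp_critic_lr defaults to 3e-4, matching the other optimizer learning rates,
# and can be overridden independently of critic_lr.
cfg = SACConfig(critic_lr=3e-4, grasp_critic_lr=1e-4)
print(cfg.grasp_critic_lr)  # 0.0001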
@@ -214,7 +214,7 @@ class SACPolicy(
     def forward(
         self,
         batch: dict[str, Tensor | dict[str, Tensor]],
-        model: Literal["actor", "critic", "temperature"] = "critic",
+        model: Literal["actor", "critic", "grasp_critic", "temperature"] = "critic",
     ) -> dict[str, Tensor]:
         """Compute the loss for the given model

@@ -227,7 +227,7 @@ class SACPolicy(
                 - done: Done mask tensor
                 - observation_feature: Optional pre-computed observation features
                 - next_observation_feature: Optional pre-computed next observation features
-            model: Which model to compute the loss for ("actor", "critic", or "temperature")
+            model: Which model to compute the loss for ("actor", "critic", "grasp_critic", or "temperature")

         Returns:
             The computed loss tensor

@@ -254,6 +254,21 @@ class SACPolicy(
                 observation_features=observation_features,
                 next_observation_features=next_observation_features,
             )

+        if model == "grasp_critic":
+            # Extract grasp_critic-specific components
+            complementary_info: dict[str, Tensor] = batch["complementary_info"]
+
+            return self.compute_loss_grasp_critic(
+                observations=observations,
+                actions=actions,
+                rewards=rewards,
+                next_observations=next_observations,
+                done=done,
+                observation_features=observation_features,
+                next_observation_features=next_observation_features,
+                complementary_info=complementary_info,
+            )
+
         if model == "actor":
             return self.compute_loss_actor(

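compute_loss_grasp_critic is dispatched to here but its body is not part of this commit. For orientation only, the sketch below shows one plausible shape for such a loss, assuming the grasp critic predicts Q-values over discrete gripper actions and the executed gripper action arrives via complementary_info; every name, the signature, and the 1-step TD formulation are assumptions, not the repository's implementation.

import torch
import torch.nn.functional as F


def grasp_critic_td_loss(
    grasp_critic,             # hypothetical module: features (B, D) -> Q-values (B, num_gripper_actions)
    grasp_critic_target,      # hypothetical target copy of the module above
    observation_features,     # (B, D) pre-computed observation features
    next_observation_features,
    gripper_actions,          # (B,) discrete gripper action indices, e.g. taken from complementary_info
    rewards,                  # (B,)
    done,                     # (B,) float done mask
    gamma=0.99,
):
    # Bootstrapped 1-step TD target from the target network (no gradient).
    with torch.no_grad():
        next_q = grasp_critic_target(next_observation_features).max(dim=-1).values
        td_target = rewards + gamma * (1.0 - done) * next_q

    # Q-value of the gripper action that was actually executed.
    q_all = grasp_critic(observation_features)
    q_taken = q_all.gather(1, gripper_actions.long().unsqueeze(-1)).squeeze(-1)
    return F.mse_loss(q_taken, td_target)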
@@ -375,6 +375,7 @@ def add_actor_information_and_train(
         observations = batch["state"]
         next_observations = batch["next_state"]
         done = batch["done"]
+        complementary_info = batch["complementary_info"]
         check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)

         observation_features, next_observation_features = get_observation_features(

@@ -390,6 +391,7 @@ def add_actor_information_and_train(
             "done": done,
             "observation_feature": observation_features,
             "next_observation_feature": next_observation_features,
+            "complementary_info": complementary_info,
         }

         # Use the forward method for critic loss

@@ -404,7 +406,20 @@ def add_actor_information_and_train(

         optimizers["critic"].step()

+        # Add gripper critic optimization
+        loss_grasp_critic = policy.forward(forward_batch, model="grasp_critic")
+        optimizers["grasp_critic"].zero_grad()
+        loss_grasp_critic.backward()
+
+        # clip gradients
+        grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
+            parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value
+        )
+
+        optimizers["grasp_critic"].step()
+
         policy.update_target_networks()
+        policy.update_grasp_target_networks()

         batch = replay_buffer.sample(batch_size=batch_size)

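update_grasp_target_networks() is called here but defined elsewhere in the policy. Given the critic_target_update_weight field in SACConfig, it presumably performs a Polyak (soft) update like the main critic targets; the sketch below is that assumption made explicit, using stand-in networks rather than the real grasp critic.

import torch
import torch.nn as nn


def soft_update(target: nn.Module, source: nn.Module, tau: float) -> None:
    # In-place Polyak averaging: target <- (1 - tau) * target + tau * source.
    with torch.no_grad():
        for target_param, param in zip(target.parameters(), source.parameters()):
            target_param.mul_(1.0 - tau).add_(param, alpha=tau)


# Stand-in networks; tau matches the default critic_target_update_weight (0.005).
grasp_critic = nn.Linear(8, 3)
grasp_critic_target = nn.Linear(8, 3)
grasp_critic_target.load_state_dict(grasp_critic.state_dict())
soft_update(grasp_critic_target, grasp_critic, tau=0.005)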
@@ -435,6 +450,7 @@ def add_actor_information_and_train(
             "done": done,
             "observation_feature": observation_features,
             "next_observation_feature": next_observation_features,
+            "complementary_info": complementary_info,
         }

         # Use the forward method for critic loss

@@ -453,6 +469,22 @@ def add_actor_information_and_train(
         training_infos["loss_critic"] = loss_critic.item()
         training_infos["critic_grad_norm"] = critic_grad_norm

+        # Add gripper critic optimization
+        loss_grasp_critic = policy.forward(forward_batch, model="grasp_critic")
+        optimizers["grasp_critic"].zero_grad()
+        loss_grasp_critic.backward()
+
+        # clip gradients
+        grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
+            parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value
+        )
+
+        optimizers["grasp_critic"].step()
+
+        # Add training info for the grasp critic
+        training_infos["loss_grasp_critic"] = loss_grasp_critic.item()
+        training_infos["grasp_critic_grad_norm"] = grasp_critic_grad_norm
+
         if optimization_step % policy_update_freq == 0:
             for _ in range(policy_update_freq):
                 # Use the forward method for actor loss

@@ -495,6 +527,7 @@ def add_actor_information_and_train(
                 last_time_policy_pushed = time.time()

             policy.update_target_networks()
+            policy.update_grasp_target_networks()

         # Log training metrics at specified intervals
         if optimization_step % log_freq == 0:

@@ -729,11 +762,13 @@ def make_optimizers_and_scheduler(cfg, policy: nn.Module):
         lr=cfg.policy.actor_lr,
     )
     optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr)
+    optimizer_grasp_critic = torch.optim.Adam(params=policy.grasp_critic.parameters(), lr=policy.config.grasp_critic_lr)
     optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr)
     lr_scheduler = None
     optimizers = {
         "actor": optimizer_actor,
         "critic": optimizer_critic,
+        "grasp_critic": optimizer_grasp_critic,
         "temperature": optimizer_temperature,
     }
     return optimizers, lr_scheduler
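The returned dict keys must line up with the optimizers[...] lookups in the training loop ("actor", "critic", "grasp_critic", "temperature"). A self-contained illustration of the same pattern with stand-in modules:

import torch
import torch.nn as nn

# Stand-in modules; in the real code these are the policy's actor, critic
# ensemble, grasp critic, and the log_alpha temperature parameter.
actor = nn.Linear(16, 4)
critic_ensemble = nn.Linear(20, 1)
grasp_critic = nn.Linear(16, 3)
log_alpha = torch.zeros(1, requires_grad=True)

optimizers = {
    "actor": torch.optim.Adam(actor.parameters(), lr=3e-4),
    "critic": torch.optim.Adam(critic_ensemble.parameters(), lr=3e-4),
    "grasp_critic": torch.optim.Adam(grasp_critic.parameters(), lr=3e-4),
    "temperature": torch.optim.Adam([log_alpha], lr=3e-4),
}

# Each component is then stepped independently, as the training loop does:
optimizers["grasp_critic"].zero_grad()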