Refactor SAC policy and training loop to enhance discrete action support

- Updated SACPolicy to conditionally compute losses for the grasp critic based on num_discrete_actions.
- Simplified the forward method to return loss outputs as a dictionary for better clarity.
- Adjusted learner_server to handle both main and grasp critic losses during training.
- Ensured optimizers are created conditionally for the grasp critic based on configuration settings.
This commit is contained in:
parent c3f2487026
commit e35ee47b07
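A brief sketch (not part of the diff) of how the reshaped forward API described in the message above is consumed; `policy`, `forward_batch`, and the loss keys mirror the learner_server code changed below, everything else is a placeholder:

```python
# Usage sketch only: `policy` is a SACPolicy and `forward_batch` is the dict built in
# learner_server; both come from the code in this diff, not from this snippet.
critic_output = policy.forward(forward_batch, model="critic")

# The critic call now returns a dict; the grasp-critic loss is only present when
# config.num_discrete_actions is set.
loss_critic = critic_output["loss_critic"]
loss_grasp_critic = critic_output.get("loss_grasp_critic")

# Actor and temperature losses are likewise returned as single-entry dicts.
loss_actor = policy.forward(forward_batch, model="actor")["loss_actor"]
loss_temperature = policy.forward(forward_batch, model="temperature")["loss_temperature"]
```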
@@ -87,6 +87,7 @@ class SACConfig(PreTrainedConfig):
         freeze_vision_encoder: Whether to freeze the vision encoder during training.
         image_encoder_hidden_dim: Hidden dimension size for the image encoder.
         shared_encoder: Whether to use a shared encoder for actor and critic.
+        num_discrete_actions: Number of discrete actions, eg for gripper actions.
         concurrency: Configuration for concurrency settings.
         actor_learner: Configuration for actor-learner architecture.
         online_steps: Number of steps for online training.

@@ -162,7 +163,6 @@ class SACConfig(PreTrainedConfig):
     num_critics: int = 2
     num_subsample_critics: int | None = None
     critic_lr: float = 3e-4
-    grasp_critic_lr: float = 3e-4
     actor_lr: float = 3e-4
     temperature_lr: float = 3e-4
     critic_target_update_weight: float = 0.005

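The two config hunks above add the num_discrete_actions switch and drop the separate grasp-critic learning rate. Purely as an illustration of how the discrete (gripper) head is enabled, a hedged sketch using only field names that appear in this diff; the values, and the value 3 for num_discrete_actions in particular, are made up:

```python
# Illustrative values only; SACConfig's full constructor is not shown in this diff.
sac_config_kwargs = dict(
    num_critics=2,
    critic_lr=3e-4,
    actor_lr=3e-4,
    temperature_lr=3e-4,
    critic_target_update_weight=0.005,  # Polyak/EMA coefficient for target networks
    num_discrete_actions=3,             # e.g. gripper: open / close / stay
)
```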
@@ -228,7 +228,7 @@ class SACPolicy(
     def forward(
         self,
         batch: dict[str, Tensor | dict[str, Tensor]],
-        model: Literal["actor", "critic", "grasp_critic", "temperature"] = "critic",
+        model: Literal["actor", "critic", "temperature"] = "critic",
     ) -> dict[str, Tensor]:
         """Compute the loss for the given model

@@ -246,7 +246,6 @@ class SACPolicy(
         Returns:
             The computed loss tensor
         """
-        # TODO: (maractingi, azouitine) Respect the function signature we output tensors
         # Extract common components from batch
         actions: Tensor = batch["action"]
         observations: dict[str, Tensor] = batch["state"]

@@ -259,7 +258,7 @@ class SACPolicy(
             done: Tensor = batch["done"]
             next_observation_features: Tensor = batch.get("next_observation_feature")

-            return self.compute_loss_critic(
+            loss_critic = self.compute_loss_critic(
                 observations=observations,
                 actions=actions,
                 rewards=rewards,

@@ -268,29 +267,28 @@ class SACPolicy(
                 observation_features=observation_features,
                 next_observation_features=next_observation_features,
             )
+            if self.config.num_discrete_actions is not None:
+                loss_grasp_critic = self.compute_loss_grasp_critic(
+                    observations=observations,
+                    actions=actions,
+                    rewards=rewards,
+                    next_observations=next_observations,
+                    done=done,
+                )
+                return {"loss_critic": loss_critic, "loss_grasp_critic": loss_grasp_critic}

-        if model == "grasp_critic":
-            return self.compute_loss_grasp_critic(
-                observations=observations,
-                actions=actions,
-                rewards=rewards,
-                next_observations=next_observations,
-                done=done,
-                observation_features=observation_features,
-                next_observation_features=next_observation_features,
-            )
-
         if model == "actor":
-            return self.compute_loss_actor(
+            return {"loss_actor": self.compute_loss_actor(
                 observations=observations,
                 observation_features=observation_features,
-            )
+            )}

         if model == "temperature":
-            return self.compute_loss_temperature(
+            return {"loss_temperature": self.compute_loss_temperature(
                 observations=observations,
                 observation_features=observation_features,
-            )
+            )}

         raise ValueError(f"Unknown model type: {model}")

@@ -305,18 +303,16 @@ class SACPolicy(
                 param.data * self.config.critic_target_update_weight
                 + target_param.data * (1.0 - self.config.critic_target_update_weight)
             )
-
-    def update_grasp_target_networks(self):
-        """Update grasp target networks with exponential moving average"""
-        for target_param, param in zip(
-            self.grasp_critic_target.parameters(),
-            self.grasp_critic.parameters(),
-            strict=False,
-        ):
-            target_param.data.copy_(
-                param.data * self.config.critic_target_update_weight
-                + target_param.data * (1.0 - self.config.critic_target_update_weight)
-            )
+        if self.config.num_discrete_actions is not None:
+            for target_param, param in zip(
+                self.grasp_critic_target.parameters(),
+                self.grasp_critic.parameters(),
+                strict=False,
+            ):
+                target_param.data.copy_(
+                    param.data * self.config.critic_target_update_weight
+                    + target_param.data * (1.0 - self.config.critic_target_update_weight)
+                )

     def update_temperature(self):
         self.temperature = self.log_alpha.exp().item()

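The target update above is standard Polyak averaging, target <- tau * online + (1 - tau) * target, with tau = critic_target_update_weight; after this change the grasp-critic target is folded into the same method behind the num_discrete_actions guard. A self-contained sketch of the pattern (the helper name soft_update is not from the diff):

```python
import torch
import torch.nn as nn


def soft_update(source: nn.Module, target: nn.Module, tau: float) -> None:
    """Polyak/EMA update: target <- tau * source + (1 - tau) * target."""
    with torch.no_grad():
        for target_param, param in zip(target.parameters(), source.parameters()):
            target_param.data.copy_(tau * param.data + (1.0 - tau) * target_param.data)


# Toy usage with tau matching critic_target_update_weight above.
critic = nn.Linear(4, 1)
critic_target = nn.Linear(4, 1)
critic_target.load_state_dict(critic.state_dict())
soft_update(critic, critic_target, tau=0.005)
```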
@@ -392,32 +392,30 @@ def add_actor_information_and_train(
             "next_observation_feature": next_observation_features,
         }

-        # Use the forward method for critic loss
-        loss_critic = policy.forward(forward_batch, model="critic")
+        # Use the forward method for critic loss (includes both main critic and grasp critic)
+        critic_output = policy.forward(forward_batch, model="critic")
+
+        # Main critic optimization
+        loss_critic = critic_output["loss_critic"]
         optimizers["critic"].zero_grad()
         loss_critic.backward()

         # clip gradients
         critic_grad_norm = torch.nn.utils.clip_grad_norm_(
             parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value
         )

         optimizers["critic"].step()

-        # Add gripper critic optimization
-        loss_grasp_critic = policy.forward(forward_batch, model="grasp_critic")
-        optimizers["grasp_critic"].zero_grad()
-        loss_grasp_critic.backward()
-
-        # clip gradients
-        grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
-            parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value
-        )
-
-        optimizers["grasp_critic"].step()
+        # Grasp critic optimization (if available)
+        if "loss_grasp_critic" in critic_output and hasattr(policy, "grasp_critic"):
+            loss_grasp_critic = critic_output["loss_grasp_critic"]
+            optimizers["grasp_critic"].zero_grad()
+            loss_grasp_critic.backward()
+            grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
+                parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value
+            )
+            optimizers["grasp_critic"].step()

         # Update target networks
         policy.update_target_networks()
-        policy.update_grasp_target_networks()

         batch = replay_buffer.sample(batch_size=batch_size)

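One detail behind the clipping calls in these hunks: torch.nn.utils.clip_grad_norm_ rescales gradients in place and returns the total gradient norm as a tensor, which is why the logging path below converts it with .item(). A self-contained illustration (the toy model and loss are placeholders):

```python
import torch
import torch.nn as nn

model = nn.Linear(8, 2)
loss = model(torch.randn(4, 8)).pow(2).mean()
loss.backward()

# Clips gradients in place and returns the total norm as a 0-dim tensor.
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)
print(grad_norm.item())  # a plain float, suitable for a training_infos-style dict
```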
@@ -450,81 +448,80 @@ def add_actor_information_and_train(
             "next_observation_feature": next_observation_features,
         }

-        # Use the forward method for critic loss
-        loss_critic = policy.forward(forward_batch, model="critic")
+        # Use the forward method for critic loss (includes both main critic and grasp critic)
+        critic_output = policy.forward(forward_batch, model="critic")
+
+        # Main critic optimization
+        loss_critic = critic_output["loss_critic"]
         optimizers["critic"].zero_grad()
         loss_critic.backward()

         # clip gradients
         critic_grad_norm = torch.nn.utils.clip_grad_norm_(
             parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value
         ).item()

         optimizers["critic"].step()

-        training_infos = {}
-        training_infos["loss_critic"] = loss_critic.item()
-        training_infos["critic_grad_norm"] = critic_grad_norm
-
-        # Add gripper critic optimization
-        loss_grasp_critic = policy.forward(forward_batch, model="grasp_critic")
-        optimizers["grasp_critic"].zero_grad()
-        loss_grasp_critic.backward()
-
-        # clip gradients
-        grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
-            parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value
-        )
-
-        optimizers["grasp_critic"].step()
-
-        # Add training info for the grasp critic
-        training_infos["loss_grasp_critic"] = loss_grasp_critic.item()
-        training_infos["grasp_critic_grad_norm"] = grasp_critic_grad_norm
+        # Initialize training info dictionary
+        training_infos = {
+            "loss_critic": loss_critic.item(),
+            "critic_grad_norm": critic_grad_norm,
+        }
+
+        # Grasp critic optimization (if available)
+        if "loss_grasp_critic" in critic_output:
+            loss_grasp_critic = critic_output["loss_grasp_critic"]
+            optimizers["grasp_critic"].zero_grad()
+            loss_grasp_critic.backward()
+            grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
+                parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value
+            ).item()
+            optimizers["grasp_critic"].step()
+
+            # Add grasp critic info to training info
+            training_infos["loss_grasp_critic"] = loss_grasp_critic.item()
+            training_infos["grasp_critic_grad_norm"] = grasp_critic_grad_norm

+        # Actor and temperature optimization (at specified frequency)
         if optimization_step % policy_update_freq == 0:
             for _ in range(policy_update_freq):
-                # Use the forward method for actor loss
-                loss_actor = policy.forward(forward_batch, model="actor")
-
+                # Actor optimization
+                actor_output = policy.forward(forward_batch, model="actor")
+                loss_actor = actor_output["loss_actor"]
                 optimizers["actor"].zero_grad()
                 loss_actor.backward()

                 # clip gradients
                 actor_grad_norm = torch.nn.utils.clip_grad_norm_(
                     parameters=policy.actor.parameters_to_optimize, max_norm=clip_grad_norm_value
                 ).item()

                 optimizers["actor"].step()

                 # Add actor info to training info
                 training_infos["loss_actor"] = loss_actor.item()
                 training_infos["actor_grad_norm"] = actor_grad_norm

-                # Temperature optimization using forward method
-                loss_temperature = policy.forward(forward_batch, model="temperature")
+                # Temperature optimization
+                temperature_output = policy.forward(forward_batch, model="temperature")
+                loss_temperature = temperature_output["loss_temperature"]
                 optimizers["temperature"].zero_grad()
                 loss_temperature.backward()

                 # clip gradients
                 temp_grad_norm = torch.nn.utils.clip_grad_norm_(
                     parameters=[policy.log_alpha], max_norm=clip_grad_norm_value
                 ).item()

                 optimizers["temperature"].step()

                 # Add temperature info to training info
                 training_infos["loss_temperature"] = loss_temperature.item()
                 training_infos["temperature_grad_norm"] = temp_grad_norm
                 training_infos["temperature"] = policy.temperature

                 # Update temperature
                 policy.update_temperature()

-        # Check if it's time to push updated policy to actors
+        # Push policy to actors if needed
         if time.time() - last_time_policy_pushed > policy_parameters_push_frequency:
             push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)
             last_time_policy_pushed = time.time()

         # Update target networks
         policy.update_target_networks()
-        policy.update_grasp_target_networks()

         # Log training metrics at specified intervals
         if optimization_step % log_freq == 0:

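The hunk above keeps the existing delayed-update schedule: the actor and temperature are optimized only every policy_update_freq critic steps, and then policy_update_freq times in a row, so the long-run actor-to-critic update ratio stays near one. A stripped-down sketch of that scheduling; update_critic and update_actor_and_temperature are placeholders standing in for the optimization blocks shown in the diff:

```python
def update_critic() -> None:
    ...  # stands in for the critic (and grasp-critic) optimization above


def update_actor_and_temperature() -> None:
    ...  # stands in for the actor and temperature optimization above


def training_loop(num_steps: int, policy_update_freq: int) -> None:
    for optimization_step in range(num_steps):
        update_critic()  # every step
        if optimization_step % policy_update_freq == 0:
            for _ in range(policy_update_freq):
                update_actor_and_temperature()  # batched catch-up, ~1:1 overall ratio
```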
@@ -727,7 +724,7 @@ def save_training_checkpoint(
     logging.info("Resume training")


-def make_optimizers_and_scheduler(cfg, policy: nn.Module):
+def make_optimizers_and_scheduler(cfg: TrainPipelineConfig, policy: nn.Module):
     """
     Creates and returns optimizers for the actor, critic, and temperature components of a reinforcement learning policy.

@@ -759,17 +756,20 @@ def make_optimizers_and_scheduler(cfg, policy: nn.Module):
         lr=cfg.policy.actor_lr,
     )
     optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr)
-    optimizer_grasp_critic = torch.optim.Adam(
-        params=policy.grasp_critic.parameters(), lr=policy.config.grasp_critic_lr
-    )
+
+    if cfg.policy.num_discrete_actions is not None:
+        optimizer_grasp_critic = torch.optim.Adam(
+            params=policy.grasp_critic.parameters(), lr=policy.critic_lr
+        )
     optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr)
     lr_scheduler = None
     optimizers = {
         "actor": optimizer_actor,
         "critic": optimizer_critic,
-        "grasp_critic": optimizer_grasp_critic,
         "temperature": optimizer_temperature,
     }
+    if cfg.policy.num_discrete_actions is not None:
+        optimizers["grasp_critic"] = optimizer_grasp_critic
     return optimizers, lr_scheduler
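Because the grasp-critic optimizer is now created only when num_discrete_actions is configured, anything that touches optimizers["grasp_critic"] has to be guarded, as the learner loop above does via the loss dictionary. A minimal, self-contained sketch of that guard; the parameter and loss here are placeholders:

```python
import torch

# Placeholder parameter standing in for policy.grasp_critic.parameters().
grasp_param = torch.nn.Parameter(torch.zeros(3))

# Mirrors make_optimizers_and_scheduler: the entry exists only when configured.
optimizers = {"grasp_critic": torch.optim.Adam([grasp_param], lr=3e-4)}

critic_output = {"loss_grasp_critic": grasp_param.pow(2).sum()}  # placeholder loss

if "grasp_critic" in optimizers and "loss_grasp_critic" in critic_output:
    optimizers["grasp_critic"].zero_grad()
    critic_output["loss_grasp_critic"].backward()
    optimizers["grasp_critic"].step()
```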