Refactor SAC policy and training loop to enhance discrete action support
- Updated SACPolicy to conditionally compute losses for the grasp critic based on num_discrete_actions.
- Simplified the forward method to return loss outputs as a dictionary for better clarity.
- Adjusted learner_server to handle both main and grasp critic losses during training.
- Ensured optimizers are created conditionally for the grasp critic based on configuration settings.
commit 306c735172
parent 6a215f47dd
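For orientation before the diff: a minimal, self-contained sketch of the loss-dictionary contract this commit introduces. ToyPolicy, its parameter, and the batch contents are placeholders invented for illustration; only the shape of the returned dictionary and the `"loss_grasp_critic" in ...` guard mirror the changes shown below.

# Illustrative stand-in for the refactored SACPolicy.forward() contract; not the real class.
import torch
import torch.nn as nn


class ToyPolicy(nn.Module):
    def __init__(self, num_discrete_actions: int | None = None):
        super().__init__()
        self.num_discrete_actions = num_discrete_actions
        self.param = nn.Parameter(torch.zeros(1))

    def forward(self, batch: dict[str, torch.Tensor], model: str = "critic") -> dict[str, torch.Tensor]:
        # Every branch now returns a dict of named losses instead of a bare tensor.
        if model == "critic":
            out = {"loss_critic": (self.param - batch["reward"]).pow(2).mean()}
            if self.num_discrete_actions is not None:
                # With discrete (gripper) actions enabled, the grasp-critic loss is
                # folded into the same "critic" call.
                out["loss_grasp_critic"] = self.param.abs().mean()
            return out
        if model == "actor":
            return {"loss_actor": self.param.mean()}
        if model == "temperature":
            return {"loss_temperature": self.param.mean()}
        raise ValueError(f"Unknown model type: {model}")


policy = ToyPolicy(num_discrete_actions=3)
critic_output = policy({"reward": torch.ones(4)}, model="critic")
print(sorted(critic_output.keys()))  # ['loss_critic', 'loss_grasp_critic']

The learner-side hunks below consume exactly this kind of dictionary, guarding on the presence of the "loss_grasp_critic" key.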
@@ -87,6 +87,7 @@ class SACConfig(PreTrainedConfig):
         freeze_vision_encoder: Whether to freeze the vision encoder during training.
         image_encoder_hidden_dim: Hidden dimension size for the image encoder.
         shared_encoder: Whether to use a shared encoder for actor and critic.
+        num_discrete_actions: Number of discrete actions, eg for gripper actions.
         concurrency: Configuration for concurrency settings.
         actor_learner: Configuration for actor-learner architecture.
         online_steps: Number of steps for online training.
@@ -162,7 +163,6 @@ class SACConfig(PreTrainedConfig):
     num_critics: int = 2
     num_subsample_critics: int | None = None
     critic_lr: float = 3e-4
-    grasp_critic_lr: float = 3e-4
     actor_lr: float = 3e-4
     temperature_lr: float = 3e-4
     critic_target_update_weight: float = 0.005
@@ -228,7 +228,7 @@ class SACPolicy(
     def forward(
         self,
         batch: dict[str, Tensor | dict[str, Tensor]],
-        model: Literal["actor", "critic", "grasp_critic", "temperature"] = "critic",
+        model: Literal["actor", "critic", "temperature"] = "critic",
     ) -> dict[str, Tensor]:
         """Compute the loss for the given model
 
@@ -246,7 +246,6 @@ class SACPolicy(
         Returns:
             The computed loss tensor
         """
-        # TODO: (maractingi, azouitine) Respect the function signature we output tensors
         # Extract common components from batch
         actions: Tensor = batch["action"]
         observations: dict[str, Tensor] = batch["state"]
@@ -259,7 +258,7 @@ class SACPolicy(
             done: Tensor = batch["done"]
             next_observation_features: Tensor = batch.get("next_observation_feature")
 
-            return self.compute_loss_critic(
+            loss_critic = self.compute_loss_critic(
                 observations=observations,
                 actions=actions,
                 rewards=rewards,
@@ -268,29 +267,28 @@ class SACPolicy(
                 observation_features=observation_features,
                 next_observation_features=next_observation_features,
             )
+            if self.config.num_discrete_actions is not None:
+                loss_grasp_critic = self.compute_loss_grasp_critic(
+                    observations=observations,
+                    actions=actions,
+                    rewards=rewards,
+                    next_observations=next_observations,
+                    done=done,
+                )
+                return {"loss_critic": loss_critic, "loss_grasp_critic": loss_grasp_critic}
 
-        if model == "grasp_critic":
-            return self.compute_loss_grasp_critic(
-                observations=observations,
-                actions=actions,
-                rewards=rewards,
-                next_observations=next_observations,
-                done=done,
-                observation_features=observation_features,
-                next_observation_features=next_observation_features,
-            )
 
         if model == "actor":
-            return self.compute_loss_actor(
+            return {"loss_actor": self.compute_loss_actor(
                 observations=observations,
                 observation_features=observation_features,
-            )
+            )}
 
         if model == "temperature":
-            return self.compute_loss_temperature(
+            return {"loss_temperature": self.compute_loss_temperature(
                 observations=observations,
                 observation_features=observation_features,
-            )
+            )}
 
         raise ValueError(f"Unknown model type: {model}")
 
@@ -305,9 +303,7 @@ class SACPolicy(
                 param.data * self.config.critic_target_update_weight
                 + target_param.data * (1.0 - self.config.critic_target_update_weight)
             )
-
-    def update_grasp_target_networks(self):
-        """Update grasp target networks with exponential moving average"""
+        if self.config.num_discrete_actions is not None:
             for target_param, param in zip(
                 self.grasp_critic_target.parameters(),
                 self.grasp_critic.parameters(),
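For reference, a generic sketch of the soft (Polyak) target update that update_target_networks applies above, including the conditional grasp-target branch. The nn.Linear modules and the standalone tau and num_discrete_actions variables are placeholders for the real critic networks and config object.

# Placeholder networks; only the update rule and the conditional branch are the point.
import torch
import torch.nn as nn

critic = nn.Linear(4, 1)
critic_target = nn.Linear(4, 1)
critic_target.load_state_dict(critic.state_dict())  # targets start as copies

grasp_critic = nn.Linear(4, 3)          # only exists when discrete actions are configured
grasp_critic_target = nn.Linear(4, 3)
grasp_critic_target.load_state_dict(grasp_critic.state_dict())

tau = 0.005  # plays the role of critic_target_update_weight
num_discrete_actions = 3  # stand-in for config.num_discrete_actions


def soft_update(target: nn.Module, source: nn.Module, tau: float) -> None:
    # target <- tau * source + (1 - tau) * target, parameter by parameter
    with torch.no_grad():
        for target_param, param in zip(target.parameters(), source.parameters()):
            target_param.data.copy_(tau * param.data + (1.0 - tau) * target_param.data)


soft_update(critic_target, critic, tau)
if num_discrete_actions is not None:
    # Grasp targets are only updated when discrete (gripper) actions are configured,
    # matching the conditional added in the hunk above.
    soft_update(grasp_critic_target, grasp_critic, tau)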
@@ -392,32 +392,30 @@ def add_actor_information_and_train(
             "next_observation_feature": next_observation_features,
         }
 
-        # Use the forward method for critic loss
-        loss_critic = policy.forward(forward_batch, model="critic")
+        # Use the forward method for critic loss (includes both main critic and grasp critic)
+        critic_output = policy.forward(forward_batch, model="critic")
 
+        # Main critic optimization
+        loss_critic = critic_output["loss_critic"]
         optimizers["critic"].zero_grad()
         loss_critic.backward()
 
-        # clip gradients
         critic_grad_norm = torch.nn.utils.clip_grad_norm_(
             parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value
         )
 
         optimizers["critic"].step()
 
-        # Add gripper critic optimization
-        loss_grasp_critic = policy.forward(forward_batch, model="grasp_critic")
+        # Grasp critic optimization (if available)
+        if "loss_grasp_critic" in critic_output and hasattr(policy, "grasp_critic"):
+            loss_grasp_critic = critic_output["loss_grasp_critic"]
             optimizers["grasp_critic"].zero_grad()
             loss_grasp_critic.backward()
 
-            # clip gradients
             grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
                 parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value
             )
 
             optimizers["grasp_critic"].step()
 
+        # Update target networks
         policy.update_target_networks()
-        policy.update_grasp_target_networks()
 
         batch = replay_buffer.sample(batch_size=batch_size)
@@ -450,81 +448,80 @@ def add_actor_information_and_train(
             "next_observation_feature": next_observation_features,
         }
 
-        # Use the forward method for critic loss
-        loss_critic = policy.forward(forward_batch, model="critic")
+        # Use the forward method for critic loss (includes both main critic and grasp critic)
+        critic_output = policy.forward(forward_batch, model="critic")
 
+        # Main critic optimization
+        loss_critic = critic_output["loss_critic"]
         optimizers["critic"].zero_grad()
         loss_critic.backward()
 
-        # clip gradients
         critic_grad_norm = torch.nn.utils.clip_grad_norm_(
             parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value
         ).item()
 
         optimizers["critic"].step()
 
-        training_infos = {}
-        training_infos["loss_critic"] = loss_critic.item()
-        training_infos["critic_grad_norm"] = critic_grad_norm
+        # Initialize training info dictionary
+        training_infos = {
+            "loss_critic": loss_critic.item(),
+            "critic_grad_norm": critic_grad_norm,
+        }
 
-        # Add gripper critic optimization
-        loss_grasp_critic = policy.forward(forward_batch, model="grasp_critic")
+        # Grasp critic optimization (if available)
+        if "loss_grasp_critic" in critic_output:
+            loss_grasp_critic = critic_output["loss_grasp_critic"]
             optimizers["grasp_critic"].zero_grad()
             loss_grasp_critic.backward()
 
-            # clip gradients
             grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
                 parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value
-            )
+            ).item()
 
             optimizers["grasp_critic"].step()
 
-            # Add training info for the grasp critic
+            # Add grasp critic info to training info
             training_infos["loss_grasp_critic"] = loss_grasp_critic.item()
             training_infos["grasp_critic_grad_norm"] = grasp_critic_grad_norm
 
+        # Actor and temperature optimization (at specified frequency)
         if optimization_step % policy_update_freq == 0:
             for _ in range(policy_update_freq):
-                # Use the forward method for actor loss
-                loss_actor = policy.forward(forward_batch, model="actor")
+                # Actor optimization
+                actor_output = policy.forward(forward_batch, model="actor")
+                loss_actor = actor_output["loss_actor"]
                 optimizers["actor"].zero_grad()
                 loss_actor.backward()
 
-                # clip gradients
                 actor_grad_norm = torch.nn.utils.clip_grad_norm_(
                     parameters=policy.actor.parameters_to_optimize, max_norm=clip_grad_norm_value
                 ).item()
 
                 optimizers["actor"].step()
 
+                # Add actor info to training info
                 training_infos["loss_actor"] = loss_actor.item()
                 training_infos["actor_grad_norm"] = actor_grad_norm
 
-                # Temperature optimization using forward method
-                loss_temperature = policy.forward(forward_batch, model="temperature")
+                # Temperature optimization
+                temperature_output = policy.forward(forward_batch, model="temperature")
+                loss_temperature = temperature_output["loss_temperature"]
                 optimizers["temperature"].zero_grad()
                 loss_temperature.backward()
 
-                # clip gradients
                 temp_grad_norm = torch.nn.utils.clip_grad_norm_(
                     parameters=[policy.log_alpha], max_norm=clip_grad_norm_value
                 ).item()
 
                 optimizers["temperature"].step()
 
+                # Add temperature info to training info
                 training_infos["loss_temperature"] = loss_temperature.item()
                 training_infos["temperature_grad_norm"] = temp_grad_norm
                 training_infos["temperature"] = policy.temperature
 
+        # Update temperature
         policy.update_temperature()
 
-        # Check if it's time to push updated policy to actors
+        # Push policy to actors if needed
         if time.time() - last_time_policy_pushed > policy_parameters_push_frequency:
             push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)
             last_time_policy_pushed = time.time()
 
+        # Update target networks
         policy.update_target_networks()
-        policy.update_grasp_target_networks()
 
         # Log training metrics at specified intervals
         if optimization_step % log_freq == 0:
@@ -727,7 +724,7 @@ def save_training_checkpoint(
         logging.info("Resume training")
 
 
-def make_optimizers_and_scheduler(cfg, policy: nn.Module):
+def make_optimizers_and_scheduler(cfg: TrainPipelineConfig, policy: nn.Module):
     """
     Creates and returns optimizers for the actor, critic, and temperature components of a reinforcement learning policy.
 
@@ -759,17 +756,20 @@ def make_optimizers_and_scheduler(cfg, policy: nn.Module):
         lr=cfg.policy.actor_lr,
     )
     optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr)
+
+    if cfg.policy.num_discrete_actions is not None:
         optimizer_grasp_critic = torch.optim.Adam(
-        params=policy.grasp_critic.parameters(), lr=policy.config.grasp_critic_lr
+            params=policy.grasp_critic.parameters(), lr=policy.critic_lr
         )
     optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr)
     lr_scheduler = None
     optimizers = {
         "actor": optimizer_actor,
         "critic": optimizer_critic,
-        "grasp_critic": optimizer_grasp_critic,
         "temperature": optimizer_temperature,
     }
+    if cfg.policy.num_discrete_actions is not None:
+        optimizers["grasp_critic"] = optimizer_grasp_critic
     return optimizers, lr_scheduler
 
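Finally, a hypothetical end-to-end sketch of the conditional optimizer wiring from the last hunk, using placeholder nn.Linear modules rather than the real SACPolicy; the point is only the pattern of creating the grasp-critic optimizer conditionally and guarding on its key at training time.

# Placeholder modules standing in for the policy's sub-networks.
import torch
import torch.nn as nn

actor = nn.Linear(8, 4)
critic_ensemble = nn.Linear(8, 1)
grasp_critic = nn.Linear(8, 3)
log_alpha = nn.Parameter(torch.zeros(1))
num_discrete_actions = 3  # set to None to drop the grasp-critic optimizer

optimizers = {
    "actor": torch.optim.Adam(actor.parameters(), lr=3e-4),
    "critic": torch.optim.Adam(critic_ensemble.parameters(), lr=3e-4),
    "temperature": torch.optim.Adam([log_alpha], lr=3e-4),
}
if num_discrete_actions is not None:
    optimizers["grasp_critic"] = torch.optim.Adam(grasp_critic.parameters(), lr=3e-4)

# Later, a training step only touches the grasp optimizer when it exists:
if "grasp_critic" in optimizers:
    optimizers["grasp_critic"].zero_grad()
    grasp_critic(torch.randn(2, 8)).sum().backward()
    optimizers["grasp_critic"].step()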