Refactor SACPolicy for improved readability and action dimension handling
- Cleaned up code formatting for better readability, including consistent spacing and removal of unnecessary blank lines.
- Consolidated the continuous action dimension calculation to enhance clarity and maintainability (a small sketch of the resulting convention follows below).
- Simplified loss return statements in the forward method to improve code structure.
- Ensured grasp critic parameters are included conditionally based on configuration settings.
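A minimal sketch of the convention these changes settle on, assuming an action vector whose last entry is the discrete gripper command. The batch size, action width, and `config_num_discrete_actions` value are made up for illustration; the dimension arithmetic and `DISCRETE_DIMENSION_INDEX` come from the diff below.

import torch

# Hypothetical action batch: 6 continuous joint values + 1 discrete gripper index.
config_num_discrete_actions = 3            # assumption: e.g. open / stay / close
actions = torch.randn(8, 7)                # (batch, action_dim)

DISCRETE_DIMENSION_INDEX = -1              # gripper is always the last dimension (from the diff)

continuous_action_dim = actions.shape[-1]
if config_num_discrete_actions is not None:
    continuous_action_dim -= 1             # drop the gripper slot from the continuous head

continuous_actions = actions[:, :continuous_action_dim]        # handled by the SAC actor/critics
gripper_actions = actions[:, DISCRETE_DIMENSION_INDEX].long()  # handled by the grasp critic

print(continuous_actions.shape, gripper_actions.shape)  # torch.Size([8, 6]) torch.Size([8])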
parent c6cd1475a7
commit 077d18b439
@@ -35,6 +35,7 @@ from lerobot.common.policies.utils import get_device_from_parameters

+DISCRETE_DIMENSION_INDEX = -1  # Gripper is always the last dimension


 class SACPolicy(
     PreTrainedPolicy,
 ):
@@ -50,6 +51,10 @@ class SACPolicy(
         config.validate_features()
         self.config = config

+        continuous_action_dim = config.output_features["action"].shape[0]
+        if config.num_discrete_actions is not None:
+            continuous_action_dim -= 1
+
         if config.dataset_stats is not None:
             input_normalization_params = _convert_normalization_params_to_tensor(config.dataset_stats)
             self.normalize_inputs = Normalize(
@@ -117,10 +122,7 @@ class SACPolicy(
         self.grasp_critic = None
         self.grasp_critic_target = None

-        continuous_action_dim = config.output_features["action"].shape[0]
         if config.num_discrete_actions is not None:
-            continuous_action_dim -= 1
-
             # Create grasp critic
             self.grasp_critic = GraspCritic(
                 encoder=encoder_critic,
@@ -142,7 +144,6 @@ class SACPolicy(
             self.grasp_critic = torch.compile(self.grasp_critic)
             self.grasp_critic_target = torch.compile(self.grasp_critic_target)

-
         self.actor = Policy(
             encoder=encoder_actor,
             network=MLP(input_dim=encoder_actor.output_dim, **asdict(config.actor_network_kwargs)),
@@ -162,11 +163,14 @@ class SACPolicy(
         self.temperature = self.log_alpha.exp().item()

     def get_optim_params(self) -> dict:
-        return {
+        optim_params = {
             "actor": self.actor.parameters_to_optimize,
             "critic": self.critic_ensemble.parameters_to_optimize,
             "temperature": self.log_alpha,
         }
+        if self.config.num_discrete_actions is not None:
+            optim_params["grasp_critic"] = self.grasp_critic.parameters_to_optimize
+        return optim_params

     def reset(self):
         """Reset the policy"""
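For context, a hedged sketch of how a training loop might consume get_optim_params(): one optimizer per entry, so a grasp-critic optimizer exists only when the discrete head is configured. The build_optimizers helper, the use of Adam, and the learning rate are assumptions for illustration and are not taken from this diff.

import torch

def build_optimizers(policy, lr: float = 3e-4) -> dict[str, torch.optim.Optimizer]:
    # One optimizer per model part; keys mirror get_optim_params(), so
    # "grasp_critic" appears only when num_discrete_actions is configured.
    optimizers = {}
    for name, params in policy.get_optim_params().items():
        # "temperature" maps to a single Parameter (log_alpha); wrap it in a list.
        params = [params] if isinstance(params, torch.nn.Parameter) else params
        optimizers[name] = torch.optim.Adam(params, lr=lr)
    return optimizers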
@@ -283,18 +287,21 @@ class SACPolicy(

             return {"loss_critic": loss_critic}

         if model == "actor":
-            return {"loss_actor": self.compute_loss_actor(
+            return {
+                "loss_actor": self.compute_loss_actor(
                     observations=observations,
                     observation_features=observation_features,
-            )}
+                )
+            }

         if model == "temperature":
-            return {"loss_temperature": self.compute_loss_temperature(
+            return {
+                "loss_temperature": self.compute_loss_temperature(
                     observations=observations,
                     observation_features=observation_features,
-            )}
+                )
+            }

         raise ValueError(f"Unknown model type: {model}")
@@ -413,9 +420,7 @@ class SACPolicy(

         # Use gather to select Q-values for best actions
         target_next_grasp_q = torch.gather(
-            target_next_grasp_qs,
-            dim=1,
-            index=best_next_grasp_action.unsqueeze(-1)
+            target_next_grasp_qs, dim=1, index=best_next_grasp_action.unsqueeze(-1)
         ).squeeze(-1)

         # Compute target Q-value with Bellman equation
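The gather call above selects, for each batch element, the Q-value of one discrete action. A self-contained illustration of that selection pattern (tensor names and sizes here are made up, not taken from the repository):

import torch

batch_size, num_discrete_actions = 4, 3
grasp_qs = torch.randn(batch_size, num_discrete_actions)  # per-action Q-values
best_action = grasp_qs.argmax(dim=1)                      # (batch,) action indices

# The index tensor must have the same number of dims as the source, hence unsqueeze/squeeze.
selected_q = torch.gather(grasp_qs, dim=1, index=best_action.unsqueeze(-1)).squeeze(-1)

assert torch.allclose(selected_q, grasp_qs.max(dim=1).values)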
@@ -425,11 +430,7 @@ class SACPolicy(
         predicted_grasp_qs = self.grasp_critic_forward(observations=observations, use_target=False)

         # Use gather to select Q-values for taken actions
-        predicted_grasp_q = torch.gather(
-            predicted_grasp_qs,
-            dim=1,
-            index=actions.unsqueeze(-1)
-        ).squeeze(-1)
+        predicted_grasp_q = torch.gather(predicted_grasp_qs, dim=1, index=actions.unsqueeze(-1)).squeeze(-1)

         # Compute MSE loss between predicted and target Q-values
         grasp_critic_loss = F.mse_loss(input=predicted_grasp_q, target=target_grasp_q)
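For reference, a hedged sketch of the objective the two hunks above compute: a Bellman target built from the selected next-step Q-value, regressed with MSE. The gather/MSE structure follows the diff; the reward, done-flag, and discount names and values are assumptions, not the repository's exact code.

import torch
import torch.nn.functional as F

gamma = 0.99                                   # assumed discount factor
rewards = torch.randn(4)                       # assumed batch of rewards
done = torch.zeros(4)                          # assumed episode-termination flags
target_next_grasp_q = torch.randn(4)           # Q-value of the best next gripper action (gathered above)
predicted_grasp_q = torch.randn(4, requires_grad=True)  # Q-value of the gripper action actually taken

# Bellman target: bootstrap from the next-step value unless the episode ended.
target_grasp_q = rewards + (1.0 - done) * gamma * target_next_grasp_q

# MSE between predicted and target Q-values, as in the diff.
grasp_critic_loss = F.mse_loss(input=predicted_grasp_q, target=target_grasp_q.detach())
grasp_critic_loss.backward()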
@@ -405,7 +405,7 @@ def add_actor_information_and_train(
         optimizers["critic"].step()

         # Grasp critic optimization (if available)
-        if "loss_grasp_critic" in critic_output and hasattr(policy, "grasp_critic"):
+        if "loss_grasp_critic" in critic_output:
             loss_grasp_critic = critic_output["loss_grasp_critic"]
             optimizers["grasp_critic"].zero_grad()
             loss_grasp_critic.backward()