Refactor SACPolicy for improved readability and action dimension handling
- Cleaned up code formatting for better readability, including consistent spacing and removal of unnecessary blank lines.
- Consolidated the continuous action dimension calculation into a single place to improve clarity and maintainability.
- Simplified the loss return statements in the forward method to improve code structure.
- Included the grasp critic parameters in the optimizer parameters only when `num_discrete_actions` is configured.
parent c6cd1475a7
commit 077d18b439
@@ -33,7 +33,8 @@ from lerobot.common.policies.pretrained import PreTrainedPolicy
 from lerobot.common.policies.sac.configuration_sac import SACConfig
 from lerobot.common.policies.utils import get_device_from_parameters
 
 DISCRETE_DIMENSION_INDEX = -1  # Gripper is always the last dimension
 
+
 class SACPolicy(
     PreTrainedPolicy,
@@ -50,6 +51,10 @@ class SACPolicy(
         config.validate_features()
         self.config = config
 
+        continuous_action_dim = config.output_features["action"].shape[0]
+        if config.num_discrete_actions is not None:
+            continuous_action_dim -= 1
+
         if config.dataset_stats is not None:
             input_normalization_params = _convert_normalization_params_to_tensor(config.dataset_stats)
             self.normalize_inputs = Normalize(
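The hunk above is the consolidated action-dimension handling the commit message refers to. A minimal standalone sketch of the idea, using a hypothetical `SimpleConfig` rather than the real `SACConfig`:

```python
# Standalone sketch (not the repository code): how the continuous action
# dimension is derived when the last action slot is a discrete gripper command.
from dataclasses import dataclass

DISCRETE_DIMENSION_INDEX = -1  # gripper is assumed to be the last dimension


@dataclass
class SimpleConfig:
    action_shape: tuple[int, ...]      # shape of the full action vector
    num_discrete_actions: int | None   # e.g. 3 gripper commands, or None


def continuous_action_dim(config: SimpleConfig) -> int:
    dim = config.action_shape[0]
    if config.num_discrete_actions is not None:
        dim -= 1  # drop the discrete gripper slot from the continuous head
    return dim


assert continuous_action_dim(SimpleConfig((7,), 3)) == 6
assert continuous_action_dim(SimpleConfig((7,), None)) == 7
```

Computing this width once near the top of `__init__` lets the actor and critic heads reuse it, instead of recomputing it inside the grasp-critic branch as in the next hunk.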
@@ -117,10 +122,7 @@ class SACPolicy(
         self.grasp_critic = None
         self.grasp_critic_target = None
 
-        continuous_action_dim = config.output_features["action"].shape[0]
         if config.num_discrete_actions is not None:
-
-            continuous_action_dim -= 1
             # Create grasp critic
             self.grasp_critic = GraspCritic(
                 encoder=encoder_critic,
@@ -142,7 +144,6 @@ class SACPolicy(
             self.grasp_critic = torch.compile(self.grasp_critic)
             self.grasp_critic_target = torch.compile(self.grasp_critic_target)
 
-
         self.actor = Policy(
             encoder=encoder_actor,
             network=MLP(input_dim=encoder_actor.output_dim, **asdict(config.actor_network_kwargs)),
@@ -162,11 +163,14 @@ class SACPolicy(
         self.temperature = self.log_alpha.exp().item()
 
     def get_optim_params(self) -> dict:
-        return {
+        optim_params = {
             "actor": self.actor.parameters_to_optimize,
             "critic": self.critic_ensemble.parameters_to_optimize,
             "temperature": self.log_alpha,
         }
+        if self.config.num_discrete_actions is not None:
+            optim_params["grasp_critic"] = self.grasp_critic.parameters_to_optimize
+        return optim_params
 
     def reset(self):
         """Reset the policy"""
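One way to read the new `get_optim_params` contract is that downstream code can build one optimizer per returned group without special-casing the grasp critic. A rough trainer-side sketch, with a toy stand-in policy and made-up learning rates (not the repository's `make_optimizers_and_scheduler`):

```python
# Illustrative sketch only: build one optimizer per parameter group returned
# by a get_optim_params()-style method. Learning rates are made up here.
import torch
from torch import nn


class TinyPolicy(nn.Module):
    """Hypothetical stand-in exposing the same dict-of-groups contract."""

    def __init__(self, with_grasp_critic: bool = True):
        super().__init__()
        self.actor = nn.Linear(4, 2)
        self.critic = nn.Linear(6, 1)
        self.log_alpha = nn.Parameter(torch.zeros(1))
        self.grasp_critic = nn.Linear(4, 3) if with_grasp_critic else None

    def get_optim_params(self) -> dict:
        optim_params = {
            "actor": self.actor.parameters(),
            "critic": self.critic.parameters(),
            "temperature": [self.log_alpha],
        }
        if self.grasp_critic is not None:
            optim_params["grasp_critic"] = self.grasp_critic.parameters()
        return optim_params


policy = TinyPolicy(with_grasp_critic=True)
optimizers = {
    name: torch.optim.Adam(params, lr=3e-4)
    for name, params in policy.get_optim_params().items()
}
print(sorted(optimizers))  # ['actor', 'critic', 'grasp_critic', 'temperature']
```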
@@ -262,7 +266,7 @@ class SACPolicy(
             done: Tensor = batch["done"]
             next_observation_features: Tensor = batch.get("next_observation_feature")
 
             loss_critic = self.compute_loss_critic(
                 observations=observations,
                 actions=actions,
                 rewards=rewards,
@@ -283,18 +287,21 @@ class SACPolicy(
 
             return {"loss_critic": loss_critic}
 
 
         if model == "actor":
-            return {"loss_actor": self.compute_loss_actor(
-                observations=observations,
-                observation_features=observation_features,
-            )}
+            return {
+                "loss_actor": self.compute_loss_actor(
+                    observations=observations,
+                    observation_features=observation_features,
+                )
+            }
 
         if model == "temperature":
-            return {"loss_temperature": self.compute_loss_temperature(
-                observations=observations,
-                observation_features=observation_features,
-            )}
+            return {
+                "loss_temperature": self.compute_loss_temperature(
+                    observations=observations,
+                    observation_features=observation_features,
+                )
+            }
 
         raise ValueError(f"Unknown model type: {model}")
 
@@ -366,7 +373,7 @@ class SACPolicy(
             # In the buffer we have the full action space (continuous + discrete)
             # We need to split them before concatenating them in the critic forward
             actions: Tensor = actions[:, :DISCRETE_DIMENSION_INDEX]
 
         q_preds = self.critic_forward(
             observations=observations,
             actions=actions,
@@ -407,15 +414,13 @@ class SACPolicy(
             # For DQN, select actions using online network, evaluate with target network
             next_grasp_qs = self.grasp_critic_forward(next_observations, use_target=False)
             best_next_grasp_action = torch.argmax(next_grasp_qs, dim=-1)
 
             # Get target Q-values from target network
             target_next_grasp_qs = self.grasp_critic_forward(observations=next_observations, use_target=True)
 
             # Use gather to select Q-values for best actions
             target_next_grasp_q = torch.gather(
-                target_next_grasp_qs,
-                dim=1,
-                index=best_next_grasp_action.unsqueeze(-1)
+                target_next_grasp_qs, dim=1, index=best_next_grasp_action.unsqueeze(-1)
             ).squeeze(-1)
 
             # Compute target Q-value with Bellman equation
@@ -423,13 +428,9 @@ class SACPolicy(
 
         # Get predicted Q-values for current observations
         predicted_grasp_qs = self.grasp_critic_forward(observations=observations, use_target=False)
 
         # Use gather to select Q-values for taken actions
-        predicted_grasp_q = torch.gather(
-            predicted_grasp_qs,
-            dim=1,
-            index=actions.unsqueeze(-1)
-        ).squeeze(-1)
+        predicted_grasp_q = torch.gather(predicted_grasp_qs, dim=1, index=actions.unsqueeze(-1)).squeeze(-1)
 
         # Compute MSE loss between predicted and target Q-values
         grasp_critic_loss = F.mse_loss(input=predicted_grasp_q, target=target_grasp_q)
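The two hunks above only reformat the `torch.gather` calls, but the computation they touch is a double-DQN-style update for the discrete gripper head: the online critic picks the best next action, the target critic evaluates it, and `gather` selects per-sample Q-values. A self-contained toy sketch of that computation, assuming a standard `(1 - done)` discounted Bellman target and `nn.Linear` stand-ins for the grasp critic and its target (not the `SACPolicy` implementation):

```python
# Toy sketch of the DQN-style grasp-critic loss; all shapes and values are illustrative.
import torch
import torch.nn.functional as F
from torch import nn

batch_size, obs_dim, num_discrete_actions = 8, 16, 3
gamma = 0.99

online_head = nn.Linear(obs_dim, num_discrete_actions)  # stand-in grasp critic
target_head = nn.Linear(obs_dim, num_discrete_actions)  # stand-in target copy
target_head.load_state_dict(online_head.state_dict())

obs = torch.randn(batch_size, obs_dim)
next_obs = torch.randn(batch_size, obs_dim)
actions = torch.randint(0, num_discrete_actions, (batch_size,))
rewards = torch.randn(batch_size)
done = torch.zeros(batch_size)

with torch.no_grad():
    # Select next actions with the online head, evaluate them with the target head.
    best_next_action = online_head(next_obs).argmax(dim=-1)
    target_next_q = torch.gather(
        target_head(next_obs), dim=1, index=best_next_action.unsqueeze(-1)
    ).squeeze(-1)
    # Assumed Bellman target: reward plus discounted bootstrap on non-terminal steps.
    target_q = rewards + (1 - done) * gamma * target_next_q

# Q-values of the actions actually taken, regressed to the target with MSE.
predicted_q = torch.gather(online_head(obs), dim=1, index=actions.unsqueeze(-1)).squeeze(-1)
loss = F.mse_loss(predicted_q, target_q)
print(loss.item())
```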
@@ -642,7 +643,7 @@ class GraspCritic(nn.Module):
         self,
         encoder: Optional[nn.Module],
         network: nn.Module,
         output_dim: int = 3,  # TODO (azouitine): rename it number of discret acitons smth like that
         init_final: Optional[float] = None,
         encoder_is_shared: bool = False,
     ):
@@ -394,7 +394,7 @@ def add_actor_information_and_train(
 
         # Use the forward method for critic loss (includes both main critic and grasp critic)
         critic_output = policy.forward(forward_batch, model="critic")
 
         # Main critic optimization
         loss_critic = critic_output["loss_critic"]
         optimizers["critic"].zero_grad()
@@ -405,7 +405,7 @@ def add_actor_information_and_train(
         optimizers["critic"].step()
 
         # Grasp critic optimization (if available)
-        if "loss_grasp_critic" in critic_output and hasattr(policy, "grasp_critic"):
+        if "loss_grasp_critic" in critic_output:
             loss_grasp_critic = critic_output["loss_grasp_critic"]
             optimizers["grasp_critic"].zero_grad()
             loss_grasp_critic.backward()
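After this change the grasp-critic update is gated purely on whether the forward pass produced `loss_grasp_critic`, rather than on `hasattr(policy, "grasp_critic")`. A toy sketch of that guarded step, reusing the names from the diff but with stand-in modules and an illustrative clipping value:

```python
# Sketch of the simplified guard: step the grasp critic only when the forward
# pass produced its loss. `policy`, `critic_output`, and `optimizers` are toy
# stand-ins here, not the learner's real objects.
import torch
from torch import nn

policy = nn.ModuleDict({"grasp_critic": nn.Linear(4, 3)})
optimizers = {"grasp_critic": torch.optim.Adam(policy["grasp_critic"].parameters(), lr=3e-4)}
clip_grad_norm_value = 10.0  # illustrative

qs = policy["grasp_critic"](torch.randn(8, 4))
critic_output = {"loss_grasp_critic": qs.pow(2).mean()}  # stand-in loss

if "loss_grasp_critic" in critic_output:
    loss_grasp_critic = critic_output["loss_grasp_critic"]
    optimizers["grasp_critic"].zero_grad()
    loss_grasp_critic.backward()
    grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
        parameters=policy["grasp_critic"].parameters(), max_norm=clip_grad_norm_value
    ).item()
    optimizers["grasp_critic"].step()
```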
@@ -450,7 +450,7 @@ def add_actor_information_and_train(
 
         # Use the forward method for critic loss (includes both main critic and grasp critic)
         critic_output = policy.forward(forward_batch, model="critic")
 
         # Main critic optimization
         loss_critic = critic_output["loss_critic"]
         optimizers["critic"].zero_grad()
@@ -475,7 +475,7 @@ def add_actor_information_and_train(
                 parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value
             ).item()
             optimizers["grasp_critic"].step()
 
             # Add grasp critic info to training info
             training_infos["loss_grasp_critic"] = loss_grasp_critic.item()
             training_infos["grasp_critic_grad_norm"] = grasp_critic_grad_norm
@@ -492,7 +492,7 @@ def add_actor_information_and_train(
             parameters=policy.actor.parameters_to_optimize, max_norm=clip_grad_norm_value
         ).item()
         optimizers["actor"].step()
 
         # Add actor info to training info
         training_infos["loss_actor"] = loss_actor.item()
         training_infos["actor_grad_norm"] = actor_grad_norm
@@ -506,7 +506,7 @@ def add_actor_information_and_train(
             parameters=[policy.log_alpha], max_norm=clip_grad_norm_value
         ).item()
         optimizers["temperature"].step()
 
         # Add temperature info to training info
         training_infos["loss_temperature"] = loss_temperature.item()
         training_infos["temperature_grad_norm"] = temp_grad_norm
@@ -756,7 +756,7 @@ def make_optimizers_and_scheduler(cfg: TrainPipelineConfig, policy: nn.Module):
         lr=cfg.policy.actor_lr,
     )
     optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr)
 
     if cfg.policy.num_discrete_actions is not None:
         optimizer_grasp_critic = torch.optim.Adam(
             params=policy.grasp_critic.parameters(), lr=policy.critic_lr