Refactor SACPolicy initialization by breaking down the constructor into smaller methods for normalization, encoders, critics, actor, and temperature setup. This enhances readability and maintainability.

This commit is contained in:
AdilZouitine 2025-04-17 16:37:43 +00:00 committed by Michel Aractingi
parent fb075a709d
commit 1ce368503d
1 changed files with 101 additions and 108 deletions

View File

@ -52,115 +52,13 @@ class SACPolicy(
config.validate_features()
self.config = config
# Determine action dimension and initialize all components
continuous_action_dim = config.output_features["action"].shape[0]
# Default to identity normalizations
self.normalize_inputs = nn.Identity()
self.normalize_targets = nn.Identity()
self.unnormalize_outputs = nn.Identity()
# Apply normalization if dataset stats provided
if config.dataset_stats:
params = _convert_normalization_params_to_tensor(config.dataset_stats)
self.normalize_inputs = Normalize(
config.input_features,
config.normalization_mapping,
params,
)
stats = dataset_stats or params
self.normalize_targets = Normalize(
config.output_features,
config.normalization_mapping,
stats,
)
self.unnormalize_outputs = Unnormalize(
config.output_features,
config.normalization_mapping,
stats,
)
# NOTE: For images the encoder should be shared between the actor and critic
self.shared_encoder = config.shared_encoder
encoder_critic = SACObservationEncoder(config, self.normalize_inputs)
encoder_actor = (
encoder_critic if self.shared_encoder else SACObservationEncoder(config, self.normalize_inputs)
)
# Create a list of critic heads
critic_heads = [
CriticHead(
input_dim=encoder_critic.output_dim + continuous_action_dim,
**asdict(config.critic_network_kwargs),
)
for _ in range(config.num_critics)
]
self.critic_ensemble = CriticEnsemble(
encoder=encoder_critic,
ensemble=critic_heads,
output_normalization=self.normalize_targets,
)
# Create target critic heads as deepcopies of the original critic heads
target_critic_heads = [
CriticHead(
input_dim=encoder_critic.output_dim + continuous_action_dim,
**asdict(config.critic_network_kwargs),
)
for _ in range(config.num_critics)
]
self.critic_target = CriticEnsemble(
encoder=encoder_critic,
ensemble=target_critic_heads,
output_normalization=self.normalize_targets,
)
self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
self.critic_ensemble = torch.compile(self.critic_ensemble)
self.critic_target = torch.compile(self.critic_target)
self.grasp_critic = None
self.grasp_critic_target = None
if config.num_discrete_actions is not None:
# Create grasp critic
self.grasp_critic = GraspCritic(
encoder=encoder_critic,
input_dim=encoder_critic.output_dim,
output_dim=config.num_discrete_actions,
**asdict(config.grasp_critic_network_kwargs),
)
# Create target grasp critic
self.grasp_critic_target = GraspCritic(
encoder=encoder_critic,
input_dim=encoder_critic.output_dim,
output_dim=config.num_discrete_actions,
**asdict(config.grasp_critic_network_kwargs),
)
self.grasp_critic_target.load_state_dict(self.grasp_critic.state_dict())
self.grasp_critic = torch.compile(self.grasp_critic)
self.grasp_critic_target = torch.compile(self.grasp_critic_target)
self.actor = Policy(
encoder=encoder_actor,
network=MLP(input_dim=encoder_actor.output_dim, **asdict(config.actor_network_kwargs)),
action_dim=continuous_action_dim,
encoder_is_shared=config.shared_encoder,
**asdict(config.policy_kwargs),
)
if config.target_entropy is None:
discrete_actions_dim: Literal[1] | Literal[0] = (
1 if config.num_discrete_actions is not None else 0
)
config.target_entropy = -np.prod(continuous_action_dim + discrete_actions_dim) / 2 # (-dim(A)/2)
temperature_init = config.temperature_init
self.log_alpha = nn.Parameter(torch.tensor([math.log(temperature_init)]))
self.temperature = self.log_alpha.exp().item()
self._init_normalization(dataset_stats)
self._init_encoders()
self._init_critics(continuous_action_dim)
self._init_actor(continuous_action_dim)
self._init_temperature()
def get_optim_params(self) -> dict:
optim_params = {
@ -492,6 +390,101 @@ class SACPolicy(
actor_loss = ((self.temperature * log_probs) - min_q_preds).mean()
return actor_loss
def _init_normalization(self, dataset_stats):
"""Initialize input/output normalization modules."""
self.normalize_inputs = nn.Identity()
self.normalize_targets = nn.Identity()
self.unnormalize_outputs = nn.Identity()
if self.config.dataset_stats:
params = _convert_normalization_params_to_tensor(self.config.dataset_stats)
self.normalize_inputs = Normalize(
self.config.input_features, self.config.normalization_mapping, params
)
stats = dataset_stats or params
self.normalize_targets = Normalize(
self.config.output_features, self.config.normalization_mapping, stats
)
self.unnormalize_outputs = Unnormalize(
self.config.output_features, self.config.normalization_mapping, stats
)
def _init_encoders(self):
"""Initialize shared or separate encoders for actor and critic."""
self.shared_encoder = self.config.shared_encoder
self.encoder_critic = SACObservationEncoder(self.config, self.normalize_inputs)
self.encoder_actor = (
self.encoder_critic
if self.shared_encoder
else SACObservationEncoder(self.config, self.normalize_inputs)
)
def _init_critics(self, continuous_action_dim):
"""Build critic ensemble, targets, and optional grasp critic."""
heads = [
CriticHead(
input_dim=self.encoder_critic.output_dim + continuous_action_dim,
**asdict(self.config.critic_network_kwargs),
)
for _ in range(self.config.num_critics)
]
self.critic_ensemble = CriticEnsemble(
encoder=self.encoder_critic, ensemble=heads, output_normalization=self.normalize_targets
)
target_heads = [
CriticHead(
input_dim=self.encoder_critic.output_dim + continuous_action_dim,
**asdict(self.config.critic_network_kwargs),
)
for _ in range(self.config.num_critics)
]
self.critic_target = CriticEnsemble(
encoder=self.encoder_critic, ensemble=target_heads, output_normalization=self.normalize_targets
)
self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
self.critic_ensemble = torch.compile(self.critic_ensemble)
self.critic_target = torch.compile(self.critic_target)
if self.config.num_discrete_actions is not None:
self._init_grasp_critics()
def _init_grasp_critics(self):
"""Build discrete grasp critic ensemble and target networks."""
self.grasp_critic = GraspCritic(
encoder=self.encoder_critic,
input_dim=self.encoder_critic.output_dim,
output_dim=self.config.num_discrete_actions,
**asdict(self.config.grasp_critic_network_kwargs),
)
self.grasp_critic_target = GraspCritic(
encoder=self.encoder_critic,
input_dim=self.encoder_critic.output_dim,
output_dim=self.config.num_discrete_actions,
**asdict(self.config.grasp_critic_network_kwargs),
)
# TODO: (maractingi, azouitine) Compile the grasp critic
self.grasp_critic_target.load_state_dict(self.grasp_critic.state_dict())
def _init_actor(self, continuous_action_dim):
"""Initialize policy actor network and default target entropy."""
self.actor = Policy(
encoder=self.encoder_actor,
network=MLP(input_dim=self.encoder_actor.output_dim, **asdict(self.config.actor_network_kwargs)),
action_dim=continuous_action_dim,
encoder_is_shared=self.shared_encoder,
**asdict(self.config.policy_kwargs),
)
if self.config.target_entropy is None:
dim = continuous_action_dim + (1 if self.config.num_discrete_actions is not None else 0)
self.config.target_entropy = -np.prod(dim) / 2
def _init_temperature(self):
"""Set up temperature parameter and initial log_alpha."""
temp_init = self.config.temperature_init
self.log_alpha = nn.Parameter(torch.tensor([math.log(temp_init)]))
self.temperature = self.log_alpha.exp().item()
class SACObservationEncoder(nn.Module):
"""Encode image and/or state vector observations."""