Refactor modeling_sac and parameter handling for clarity and reusability.
Co-authored-by: s1lent4gnt <kmeftah.khalil@gmail.com>
parent 854bfb4ff8
commit 320a1a92a3
@@ -167,8 +167,12 @@ class SACPolicy(
     def get_optim_params(self) -> dict:
         optim_params = {
-            "actor": self.actor.parameters_to_optimize,
-            "critic": self.critic_ensemble.parameters_to_optimize,
+            "actor": [
+                p
+                for n, p in self.actor.named_parameters()
+                if not n.startswith("encoder") or not self.shared_encoder
+            ],
+            "critic": self.critic_ensemble.parameters(),
             "temperature": self.log_alpha,
         }
         if self.config.num_discrete_actions is not None:
 
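With a shared encoder, the actor's optimizer must not touch the encoder weights (those are updated through the critic), so get_optim_params now filters on the parameter-name prefix instead of relying on a hand-maintained parameters_to_optimize list. A minimal sketch of the filter on a toy module (ToyActor and its attribute names are illustrative, not from this commit):

    import torch.nn as nn

    class ToyActor(nn.Module):
        def __init__(self):
            super().__init__()
            self.encoder = nn.Linear(8, 16)  # possibly shared with the critic
            self.head = nn.Linear(16, 4)     # actor-only weights

    actor = ToyActor()
    shared_encoder = True

    # Same predicate as in get_optim_params: drop "encoder.*" when shared.
    kept = [
        n for n, p in actor.named_parameters()
        if not n.startswith("encoder") or not shared_encoder
    ]
    print(kept)  # ['head.weight', 'head.bias']

With shared_encoder = False, the encoder's weight and bias are kept as well.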
@@ -510,7 +514,6 @@ class SACObservationEncoder(nn.Module):
         self.config = config
         self.input_normalization = input_normalizer
         self.has_pretrained_vision_encoder = False
-        self.parameters_to_optimize = []
 
         self.aggregation_size: int = 0
         if any("observation.image" in key for key in config.input_features):
@@ -527,8 +530,6 @@ class SACObservationEncoder(nn.Module):
             if config.freeze_vision_encoder:
                 freeze_image_encoder(self.image_enc_layers.image_enc_layers)
 
-            self.parameters_to_optimize += self.image_enc_layers.parameters_to_optimize
-
         self.all_image_keys = [k for k in config.input_features if k.startswith("observation.image")]
 
         if "observation.state" in config.input_features:
@@ -542,8 +543,6 @@ class SACObservationEncoder(nn.Module):
             )
             self.aggregation_size += config.latent_dim
 
-            self.parameters_to_optimize += list(self.state_enc_layers.parameters())
-
         if "observation.environment_state" in config.input_features:
             self.env_state_enc_layers = nn.Sequential(
                 nn.Linear(
@@ -554,10 +553,8 @@ class SACObservationEncoder(nn.Module):
                 nn.Tanh(),
             )
             self.aggregation_size += config.latent_dim
-            self.parameters_to_optimize += list(self.env_state_enc_layers.parameters())
 
         self.aggregation_layer = nn.Linear(in_features=self.aggregation_size, out_features=config.latent_dim)
-        self.parameters_to_optimize += list(self.aggregation_layer.parameters())
 
     def forward(
         self, obs_dict: dict[str, Tensor], vision_encoder_cache: torch.Tensor | None = None
@@ -737,12 +734,6 @@ class CriticEnsemble(nn.Module):
         self.output_normalization = output_normalization
         self.critics = nn.ModuleList(ensemble)
 
-        self.parameters_to_optimize = []
-        # Handle the case where a part of the encoder is frozen
-        if self.encoder is not None:
-            self.parameters_to_optimize += list(self.encoder.parameters_to_optimize)
-        self.parameters_to_optimize += list(self.critics.parameters())
-
     def forward(
         self,
         observations: dict[str, torch.Tensor],
@@ -805,10 +796,6 @@ class GraspCritic(nn.Module):
         else:
             orthogonal_init()(self.output_layer.weight)
 
-        self.parameters_to_optimize = []
-        self.parameters_to_optimize += list(self.net.parameters())
-        self.parameters_to_optimize += list(self.output_layer.parameters())
-
     def forward(
         self, observations: torch.Tensor, observation_features: torch.Tensor | None = None
     ) -> torch.Tensor:
@@ -840,12 +827,8 @@ class Policy(nn.Module):
         self.log_std_max = log_std_max
         self.fixed_std = fixed_std
         self.use_tanh_squash = use_tanh_squash
-        self.parameters_to_optimize = []
+        self.encoder_is_shared = encoder_is_shared
 
-        self.parameters_to_optimize += list(self.network.parameters())
-
-        if self.encoder is not None and not encoder_is_shared:
-            self.parameters_to_optimize += list(self.encoder.parameters())
         # Find the last Linear layer's output dimension
         for layer in reversed(network.net):
             if isinstance(layer, nn.Linear):
@@ -859,7 +842,6 @@ class Policy(nn.Module):
         else:
             orthogonal_init()(self.mean_layer.weight)
 
-        self.parameters_to_optimize += list(self.mean_layer.parameters())
         # Standard deviation layer or parameter
         if fixed_std is None:
             self.std_layer = nn.Linear(out_features, action_dim)
@@ -868,7 +850,6 @@ class Policy(nn.Module):
             nn.init.uniform_(self.std_layer.bias, -init_final, init_final)
         else:
             orthogonal_init()(self.std_layer.weight)
-        self.parameters_to_optimize += list(self.std_layer.parameters())
 
     def forward(
         self,
@@ -877,6 +858,8 @@ class Policy(nn.Module):
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Encode observations if encoder exists
         obs_enc = self.encoder(observations, vision_encoder_cache=observation_features)
+        if self.encoder_is_shared:
+            obs_enc = obs_enc.detach()
 
         # Get network outputs
         outputs = self.network(obs_enc)
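Rather than only hiding the shared encoder from the actor's optimizer, the forward pass now detaches the encoded observation, so no actor gradient can reach a shared encoder regardless of how the optimizer was built. A self-contained sketch of what detach() does here (the tensor names are stand-ins, not code from this commit):

    import torch

    enc_weight = torch.ones(3, requires_grad=True)   # stands in for encoder params
    head_weight = torch.ones(3, requires_grad=True)  # stands in for actor-head params

    obs_enc = enc_weight * 2.0                           # encoder forward
    actor_loss = (obs_enc.detach() * head_weight).sum()  # actor head on detached features
    actor_loss.backward()

    print(enc_weight.grad)   # None: the encoder received no actor gradient
    print(head_weight.grad)  # tensor([2., 2., 2.]): the head still trains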
@@ -966,13 +949,13 @@ class DefaultImageEncoder(nn.Module):
             nn.Tanh(),
         )
 
-        self.parameters_to_optimize = []
-        if not config.freeze_vision_encoder:
-            self.parameters_to_optimize += list(self.image_enc_layers.parameters())
-        self.parameters_to_optimize += list(self.image_enc_proj.parameters())
+        self.freeze_image_encoder = config.freeze_vision_encoder
 
     def forward(self, x):
-        return self.image_enc_proj(self.image_enc_layers(x))
+        x = self.image_enc_layers(x)
+        if self.freeze_image_encoder:
+            x = x.detach()
+        return self.image_enc_proj(x)
 
 
 class PretrainedImageEncoder(nn.Module):
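Freezing via detach() is what makes it safe to drop the parameters_to_optimize bookkeeping: a parameter that never receives a gradient is left untouched by the optimizer even when it is registered with it (its .grad stays None, and torch optimizers with zero weight decay skip such params). A quick check of that behavior (module names are illustrative):

    import torch
    import torch.nn as nn

    enc = nn.Linear(4, 4)   # frozen backbone
    proj = nn.Linear(4, 2)  # trainable projection
    opt = torch.optim.Adam(list(enc.parameters()) + list(proj.parameters()), lr=1e-2)

    before = enc.weight.detach().clone()
    loss = proj(enc(torch.randn(8, 4)).detach()).pow(2).mean()
    loss.backward()
    opt.step()

    print(torch.equal(before, enc.weight))  # True: unchanged despite being in the optimizer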
@@ -985,10 +968,7 @@ class PretrainedImageEncoder(nn.Module):
             nn.Tanh(),
         )
 
-        self.parameters_to_optimize = []
-        if not config.freeze_vision_encoder:
-            self.parameters_to_optimize += list(self.image_enc_layers.parameters())
-        self.parameters_to_optimize += list(self.image_enc_proj.parameters())
+        self.freeze_image_encoder = config.freeze_vision_encoder
 
     def _load_pretrained_vision_encoder(self, config: SACConfig):
         """Set up CNN encoder"""
@@ -1009,6 +989,8 @@ class PretrainedImageEncoder(nn.Module):
         # TODO: (maractingi, azouitine) check the forward pass of the pretrained model
         # doesn't reach the classifier layer because we don't need it
         enc_feat = self.image_enc_layers(x).pooler_output
+        if self.freeze_image_encoder:
+            enc_feat = enc_feat.detach()
         enc_feat = self.image_enc_proj(enc_feat.view(enc_feat.shape[0], -1))
         return enc_feat
 
@@ -510,7 +510,7 @@ def add_actor_information_and_train(
         optimizers["actor"].zero_grad()
         loss_actor.backward()
         actor_grad_norm = torch.nn.utils.clip_grad_norm_(
-            parameters=policy.actor.parameters_to_optimize, max_norm=clip_grad_norm_value
+            parameters=policy.actor.parameters(), max_norm=clip_grad_norm_value
        ).item()
         optimizers["actor"].step()
 
@@ -773,12 +773,14 @@ def make_optimizers_and_scheduler(cfg: TrainPipelineConfig, policy: nn.Module):
     """
     optimizer_actor = torch.optim.Adam(
         # NOTE: Handle the case of shared encoder where the encoder weights are not optimized with the gradient of the actor
-        params=policy.actor.parameters_to_optimize,
+        params=[
+            p
+            for n, p in policy.actor.named_parameters()
+            if not n.startswith("encoder") or not policy.config.shared_encoder
+        ],
         lr=cfg.policy.actor_lr,
     )
-    optimizer_critic = torch.optim.Adam(
-        params=policy.critic_ensemble.parameters_to_optimize, lr=cfg.policy.critic_lr
-    )
+    optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr)
 
     if cfg.policy.num_discrete_actions is not None:
         optimizer_grasp_critic = torch.optim.Adam(
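The actor-parameter filter now appears twice, once in SACPolicy.get_optim_params and once here. A small helper would keep the two call sites in sync; this is a hypothetical follow-up, not part of the commit:

    import torch.nn as nn

    def actor_trainable_params(actor: nn.Module, shared_encoder: bool) -> list[nn.Parameter]:
        # Exclude "encoder.*" parameters when the encoder is shared with the critic.
        return [
            p for n, p in actor.named_parameters()
            if not n.startswith("encoder") or not shared_encoder
        ]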
@@ -1089,6 +1091,44 @@ def push_actor_policy_to_queue(parameters_queue: Queue, policy: nn.Module):
     parameters_queue.put(state_bytes)
 
 
+def check_weight_gradients(module: nn.Module) -> dict[str, bool]:
+    """
+    Checks whether each parameter in the module has a gradient.
+
+    Args:
+        module (nn.Module): A PyTorch module whose parameters will be inspected.
+
+    Returns:
+        dict[str, bool]: A dictionary where each key is the parameter name and the value is
+            True if the parameter has an associated gradient (i.e. .grad is not None),
+            otherwise False.
+    """
+    grad_status = {}
+    for name, param in module.named_parameters():
+        grad_status[name] = param.grad is not None
+    return grad_status
+
+
+def get_overlapping_parameters(model: nn.Module, grad_status: dict[str, bool]) -> dict[str, bool]:
+    """
+    Returns a dictionary of parameters (from the model) that also exist in the grad_status dictionary.
+
+    Args:
+        model (nn.Module): The model whose parameter names are intersected with grad_status.
+        grad_status (dict[str, bool]): A dictionary where keys are parameter names and values indicate
+            whether each parameter has a gradient.
+
+    Returns:
+        dict[str, bool]: A dictionary containing only the overlapping parameter names and their gradient status.
+    """
+    # Get model parameter names as a set.
+    model_param_names = {name for name, _ in model.named_parameters()}
+
+    # Intersect parameter names between the model and grad_status.
+    overlapping = {name: grad_status[name] for name in grad_status if name in model_param_names}
+    return overlapping
+
+
 def process_interaction_message(
     message, interaction_step_shift: int, wandb_logger: WandBLogger | None = None
 ):
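The two new helpers are debugging aids for exactly this refactor: after a backward pass they show which parameters received gradients, and where two modules share parameter names. A hypothetical check after the actor update (the policy and loss objects are assumed to exist in the training loop):

    # Which actor parameters got a gradient from loss_actor?
    actor_grads = check_weight_gradients(policy.actor)

    # Restrict to names the critic ensemble also has (e.g. a shared "encoder.*"):
    shared = get_overlapping_parameters(policy.critic_ensemble, actor_grads)
    for name, has_grad in shared.items():
        print(name, has_grad)  # expect False for shared encoder weights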