Enhance SACPolicy to support shared encoder and optimize action selection

- Cached encoder output in select_action method to reduce redundant computations.
- Updated action selection and grasp critic calls to utilize cached encoder features when available.
This commit is contained in:
AdilZouitine 2025-04-03 07:44:46 +00:00
parent 0ed7ff142c
commit 51f1625c20
1 changed files with 8 additions and 2 deletions

View File

@ -81,6 +81,7 @@ class SACPolicy(
else: else:
encoder_critic = SACObservationEncoder(config, self.normalize_inputs) encoder_critic = SACObservationEncoder(config, self.normalize_inputs)
encoder_actor = SACObservationEncoder(config, self.normalize_inputs) encoder_actor = SACObservationEncoder(config, self.normalize_inputs)
self.shared_encoder = config.shared_encoder
# Create a list of critic heads # Create a list of critic heads
critic_heads = [ critic_heads = [
@ -184,11 +185,16 @@ class SACPolicy(
@torch.no_grad() @torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor: def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select action for inference/evaluation""" """Select action for inference/evaluation"""
actions, _, _ = self.actor(batch) # We cached the encoder output to avoid recomputing it
observations_features = None
if self.shared_encoder and self.actor.encoder is not None:
observations_features = self.actor.encoder(batch)
actions, _, _ = self.actor(batch, observations_features)
actions = self.unnormalize_outputs({"action": actions})["action"] actions = self.unnormalize_outputs({"action": actions})["action"]
if self.config.num_discrete_actions is not None: if self.config.num_discrete_actions is not None:
discrete_action_value = self.grasp_critic(batch) discrete_action_value = self.grasp_critic(batch, observations_features)
discrete_action = torch.argmax(discrete_action_value, dim=-1, keepdim=True) discrete_action = torch.argmax(discrete_action_value, dim=-1, keepdim=True)
actions = torch.cat([actions, discrete_action], dim=-1) actions = torch.cat([actions, discrete_action], dim=-1)