Enhance SACPolicy to support shared encoder and optimize action selection
- Cached encoder output in select_action method to reduce redundant computations. - Updated action selection and grasp critic calls to utilize cached encoder features when available.
This commit is contained in:
parent
0ed7ff142c
commit
51f1625c20
|
@ -81,6 +81,7 @@ class SACPolicy(
|
|||
else:
|
||||
encoder_critic = SACObservationEncoder(config, self.normalize_inputs)
|
||||
encoder_actor = SACObservationEncoder(config, self.normalize_inputs)
|
||||
self.shared_encoder = config.shared_encoder
|
||||
|
||||
# Create a list of critic heads
|
||||
critic_heads = [
|
||||
|
@ -184,11 +185,16 @@ class SACPolicy(
|
|||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Select action for inference/evaluation"""
|
||||
actions, _, _ = self.actor(batch)
|
||||
# We cached the encoder output to avoid recomputing it
|
||||
observations_features = None
|
||||
if self.shared_encoder and self.actor.encoder is not None:
|
||||
observations_features = self.actor.encoder(batch)
|
||||
|
||||
actions, _, _ = self.actor(batch, observations_features)
|
||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||
|
||||
if self.config.num_discrete_actions is not None:
|
||||
discrete_action_value = self.grasp_critic(batch)
|
||||
discrete_action_value = self.grasp_critic(batch, observations_features)
|
||||
discrete_action = torch.argmax(discrete_action_value, dim=-1, keepdim=True)
|
||||
actions = torch.cat([actions, discrete_action], dim=-1)
|
||||
|
||||
|
|
Loading…
Reference in New Issue