Enhance SACPolicy to support shared encoder and optimize action selection
- Cached the encoder output in select_action to reduce redundant computation.
- Updated the action selection and grasp critic calls to use the cached encoder features when available.
parent 0ed7ff142c
commit 51f1625c20
@@ -81,6 +81,7 @@ class SACPolicy(
         else:
             encoder_critic = SACObservationEncoder(config, self.normalize_inputs)
             encoder_actor = SACObservationEncoder(config, self.normalize_inputs)
+        self.shared_encoder = config.shared_encoder

         # Create a list of critic heads
         critic_heads = [
@@ -184,11 +185,16 @@ class SACPolicy(
     @torch.no_grad()
     def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         """Select action for inference/evaluation"""
-        actions, _, _ = self.actor(batch)
+        # We cached the encoder output to avoid recomputing it
+        observations_features = None
+        if self.shared_encoder and self.actor.encoder is not None:
+            observations_features = self.actor.encoder(batch)
+
+        actions, _, _ = self.actor(batch, observations_features)
         actions = self.unnormalize_outputs({"action": actions})["action"]

         if self.config.num_discrete_actions is not None:
-            discrete_action_value = self.grasp_critic(batch)
+            discrete_action_value = self.grasp_critic(batch, observations_features)
             discrete_action = torch.argmax(discrete_action_value, dim=-1, keepdim=True)
             actions = torch.cat([actions, discrete_action], dim=-1)

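For context on what the second hunk buys: when the encoder is shared, the observation features only need to be computed once per select_action call and can then be fed to both the actor and the grasp critic. Below is a minimal, self-contained sketch of that pattern; TinySharedEncoderPolicy and all its dimensions are hypothetical stand-ins, not the repository's actual SACPolicy/SACObservationEncoder API.

import torch
import torch.nn as nn


class TinySharedEncoderPolicy(nn.Module):
    """Hypothetical stand-in illustrating the caching pattern in the diff."""

    def __init__(self, obs_dim: int = 8, feat_dim: int = 16, act_dim: int = 2, n_discrete: int = 3):
        super().__init__()
        self.encoder = nn.Linear(obs_dim, feat_dim)               # shared observation encoder
        self.actor_head = nn.Linear(feat_dim, act_dim)            # continuous-action head
        self.grasp_critic_head = nn.Linear(feat_dim, n_discrete)  # discrete grasp Q-values

    @torch.no_grad()
    def select_action(self, obs: torch.Tensor) -> torch.Tensor:
        # Encode once and cache; both heads consume the same features,
        # mirroring the observations_features plumbing in the diff.
        observations_features = self.encoder(obs)
        actions = self.actor_head(observations_features)
        discrete_q = self.grasp_critic_head(observations_features)
        discrete_action = torch.argmax(discrete_q, dim=-1, keepdim=True)
        return torch.cat([actions, discrete_action.to(actions.dtype)], dim=-1)


policy = TinySharedEncoderPolicy()
print(policy.select_action(torch.randn(1, 8)).shape)  # torch.Size([1, 3])

Before this change, grasp_critic(batch) had to run an encoder forward pass of its own on the same observation, which is presumably the redundant computation the commit message refers to; passing observations_features through removes that duplicated work at inference time.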