From 2d932b710c0592937f0777aa7d6c90f5d9a237df Mon Sep 17 00:00:00 2001
From: AdilZouitine
Date: Thu, 3 Apr 2025 07:44:46 +0000
Subject: [PATCH] Enhance SACPolicy to support shared encoder and optimize
 action selection

- Cache the encoder output in select_action to avoid redundant computation.
- Update the actor and grasp critic calls to use the cached encoder features
  when they are available.
---
 lerobot/common/policies/sac/modeling_sac.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py
index af624592..2246bf8c 100644
--- a/lerobot/common/policies/sac/modeling_sac.py
+++ b/lerobot/common/policies/sac/modeling_sac.py
@@ -81,6 +81,7 @@ class SACPolicy(
         else:
             encoder_critic = SACObservationEncoder(config, self.normalize_inputs)
             encoder_actor = SACObservationEncoder(config, self.normalize_inputs)
+        self.shared_encoder = config.shared_encoder
 
         # Create a list of critic heads
         critic_heads = [
@@ -184,11 +185,16 @@ class SACPolicy(
     @torch.no_grad()
     def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         """Select action for inference/evaluation"""
-        actions, _, _ = self.actor(batch)
+        # Cache the encoder output so downstream modules do not recompute it
+        observations_features = None
+        if self.shared_encoder and self.actor.encoder is not None:
+            observations_features = self.actor.encoder(batch)
+
+        actions, _, _ = self.actor(batch, observations_features)
         actions = self.unnormalize_outputs({"action": actions})["action"]
 
         if self.config.num_discrete_actions is not None:
-            discrete_action_value = self.grasp_critic(batch)
+            discrete_action_value = self.grasp_critic(batch, observations_features)
             discrete_action = torch.argmax(discrete_action_value, dim=-1, keepdim=True)
             actions = torch.cat([actions, discrete_action], dim=-1)
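
Note on the call-side contract: the hunks above pass observations_features as a
second positional argument to self.actor and self.grasp_critic, which implies
both modules fall back to running their own encoder when the cached features
are None. Below is a minimal self-contained sketch of that pattern; the class
names (EncoderSketch, ActorSketch), their layer shapes, the "observation.state"
key, and the 3-tuple return are illustrative assumptions, not the actual
definitions in modeling_sac.py.

    import torch
    from torch import Tensor, nn


    class EncoderSketch(nn.Module):
        """Stand-in for an observation encoder: maps an observation dict to features."""

        def __init__(self, obs_dim: int, feat_dim: int):
            super().__init__()
            self.proj = nn.Linear(obs_dim, feat_dim)

        def forward(self, batch: dict[str, Tensor]) -> Tensor:
            return self.proj(batch["observation.state"])


    class ActorSketch(nn.Module):
        """Toy actor whose forward accepts optionally precomputed encoder features."""

        def __init__(self, encoder: EncoderSketch, feat_dim: int, action_dim: int):
            super().__init__()
            self.encoder = encoder
            self.head = nn.Linear(feat_dim, action_dim)

        def forward(
            self,
            batch: dict[str, Tensor],
            observations_features: Tensor | None = None,
        ):
            # Reuse the cached encoding when the caller already ran the shared
            # encoder; otherwise encode the observations here.
            if observations_features is None:
                observations_features = self.encoder(batch)
            actions = self.head(observations_features)
            # The real actor also returns log-probabilities and distribution
            # statistics; None placeholders keep the 3-tuple unpacking used in
            # select_action.
            return actions, None, None


    # With a shared encoder, the caller encodes once and every consumer reads
    # the same tensor, so the encoder runs a single forward pass per step.
    encoder = EncoderSketch(obs_dim=6, feat_dim=32)
    actor = ActorSketch(encoder, feat_dim=32, action_dim=4)
    batch = {"observation.state": torch.randn(1, 6)}
    features = encoder(batch)               # computed once, as in select_action
    actions, _, _ = actor(batch, features)  # no second encoder forward pass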