split encoder for critic and actor

Michel Aractingi 2024-12-29 23:59:39 +00:00
parent 9ceb68ee90
commit 41b377211c
1 changed file with 28 additions and 12 deletions


@@ -138,6 +138,22 @@ class SACPolicy(
        actions = self.unnormalize_outputs({"action": actions})["action"]
        return actions

    def critic_forward(self, observations: dict[str, Tensor], actions: Tensor, use_target: bool = False) -> Tensor:
        """Forward pass through a critic network ensemble

        Args:
            observations: Dictionary of observations
            actions: Action tensor
            use_target: If True, use target critics, otherwise use ensemble critics

        Returns:
            Tensor of Q-values from all critics
        """
        critics = self.critic_target if use_target else self.critic_ensemble
        q_values = torch.stack([critic(observations, actions) for critic in critics])
        return q_values
    def critic_forward(
        self, observations: dict[str, Tensor], actions: Tensor, use_target: bool = False
    ) -> Tensor:
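
For orientation, a small self-contained toy (not part of this diff, and not this repo's API; all names and sizes below are placeholders) showing the ensemble pattern critic_forward implements: every member's Q-estimate is stacked along a leading dimension so callers can reduce over the ensemble, for example with the min used in SAC-style targets.

# Toy sketch of the critic-ensemble stacking pattern; placeholder modules, not this repo's classes.
import torch
from torch import nn

critics = nn.ModuleList([nn.Linear(6, 1) for _ in range(2)])   # stand-in critic ensemble
obs_and_act = torch.randn(4, 6)                                # stand-in (observation, action) features
q_values = torch.stack([c(obs_and_act) for c in critics])      # shape (num_critics, batch, 1)
min_q = q_values.min(dim=0).values                             # conservative reduction over the ensemble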
@@ -371,6 +387,7 @@ class Policy(nn.Module):
        self,
        observations: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Encode observations if encoder exists
        obs_enc = observations if self.encoder is None else self.encoder(observations)
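
The commit title says the encoder is split between critic and actor. As a hedged illustration of that idea (every module name and size below is an assumption for the sketch, not taken from this diff), giving each network its own encoder instance keeps critic gradients from flowing through the actor's features:

# Toy sketch of a "split encoder": actor and critic own separate encoder instances.
import torch
from torch import nn

def make_encoder(obs_dim: int = 8, feat_dim: int = 32) -> nn.Module:
    # Placeholder encoder; the real repo presumably builds something richer.
    return nn.Sequential(nn.Linear(obs_dim, feat_dim), nn.ReLU())

encoder_actor = make_encoder()
encoder_critic = make_encoder()            # separate parameters from encoder_actor

actor_head = nn.Linear(32, 2)              # maps actor features to an action
critic_head = nn.Linear(32 + 2, 1)         # scores (critic features, action) pairs

obs = torch.randn(4, 8)
action = torch.tanh(actor_head(encoder_actor(obs)))
q_value = critic_head(torch.cat([encoder_critic(obs), action], dim=-1))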
@@ -535,7 +552,6 @@ def create_critic_ensemble(critics: list[nn.Module], num_critics: int, device: s
    assert len(critics) == num_critics, f"Expected {num_critics} critics, got {len(critics)}"
    return nn.ModuleList(critics).to(device)

# borrowed from tdmpc
def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor:
    """Helper to temporarily flatten extra dims at the start of the image tensor.