split encoder for critic and actor
This commit is contained in:
parent
9ceb68ee90
commit
41b377211c
|
@ -138,6 +138,22 @@ class SACPolicy(
|
||||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||||
return actions
|
return actions
|
||||||
|
|
||||||
|
def critic_forward(self, observations: dict[str, Tensor], actions: Tensor, use_target: bool = False) -> Tensor:
|
||||||
|
"""Forward pass through a critic network ensemble
|
||||||
|
|
||||||
|
Args:
|
||||||
|
observations: Dictionary of observations
|
||||||
|
actions: Action tensor
|
||||||
|
use_target: If True, use target critics, otherwise use ensemble critics
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor of Q-values from all critics
|
||||||
|
"""
|
||||||
|
critics = self.critic_target if use_target else self.critic_ensemble
|
||||||
|
q_values = torch.stack([critic(observations, actions) for critic in critics])
|
||||||
|
return q_values
|
||||||
|
|
||||||
|
|
||||||
def critic_forward(
|
def critic_forward(
|
||||||
self, observations: dict[str, Tensor], actions: Tensor, use_target: bool = False
|
self, observations: dict[str, Tensor], actions: Tensor, use_target: bool = False
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
|
@ -371,6 +387,7 @@ class Policy(nn.Module):
|
||||||
self,
|
self,
|
||||||
observations: torch.Tensor,
|
observations: torch.Tensor,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
|
||||||
# Encode observations if encoder exists
|
# Encode observations if encoder exists
|
||||||
obs_enc = observations if self.encoder is None else self.encoder(observations)
|
obs_enc = observations if self.encoder is None else self.encoder(observations)
|
||||||
|
|
||||||
|
@ -535,7 +552,6 @@ def create_critic_ensemble(critics: list[nn.Module], num_critics: int, device: s
|
||||||
assert len(critics) == num_critics, f"Expected {num_critics} critics, got {len(critics)}"
|
assert len(critics) == num_critics, f"Expected {num_critics} critics, got {len(critics)}"
|
||||||
return nn.ModuleList(critics).to(device)
|
return nn.ModuleList(critics).to(device)
|
||||||
|
|
||||||
|
|
||||||
# borrowed from tdmpc
|
# borrowed from tdmpc
|
||||||
def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor:
|
def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor:
|
||||||
"""Helper to temporarily flatten extra dims at the start of the image tensor.
|
"""Helper to temporarily flatten extra dims at the start of the image tensor.
|
||||||
|
|
Loading…
Reference in New Issue