split encoder for critic and actor
This commit is contained in:
parent
9ceb68ee90
commit
41b377211c
|
@ -137,6 +137,22 @@ class SACPolicy(
|
|||
actions, _, _ = self.actor(batch)
|
||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||
return actions
|
||||
|
||||
def critic_forward(self, observations: dict[str, Tensor], actions: Tensor, use_target: bool = False) -> Tensor:
|
||||
"""Forward pass through a critic network ensemble
|
||||
|
||||
Args:
|
||||
observations: Dictionary of observations
|
||||
actions: Action tensor
|
||||
use_target: If True, use target critics, otherwise use ensemble critics
|
||||
|
||||
Returns:
|
||||
Tensor of Q-values from all critics
|
||||
"""
|
||||
critics = self.critic_target if use_target else self.critic_ensemble
|
||||
q_values = torch.stack([critic(observations, actions) for critic in critics])
|
||||
return q_values
|
||||
|
||||
|
||||
def critic_forward(
|
||||
self, observations: dict[str, Tensor], actions: Tensor, use_target: bool = False
|
||||
|
@ -262,8 +278,8 @@ class MLP(nn.Module):
|
|||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.net(x)
|
||||
|
||||
|
||||
|
||||
|
||||
class Critic(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -277,13 +293,13 @@ class Critic(nn.Module):
|
|||
self.encoder = encoder
|
||||
self.network = network
|
||||
self.init_final = init_final
|
||||
|
||||
|
||||
# Find the last Linear layer's output dimension
|
||||
for layer in reversed(network.net):
|
||||
if isinstance(layer, nn.Linear):
|
||||
out_features = layer.out_features
|
||||
break
|
||||
|
||||
|
||||
# Output layer
|
||||
if init_final is not None:
|
||||
self.output_layer = nn.Linear(out_features, 1)
|
||||
|
@ -292,7 +308,7 @@ class Critic(nn.Module):
|
|||
else:
|
||||
self.output_layer = nn.Linear(out_features, 1)
|
||||
orthogonal_init()(self.output_layer.weight)
|
||||
|
||||
|
||||
self.to(self.device)
|
||||
|
||||
def forward(
|
||||
|
@ -303,9 +319,9 @@ class Critic(nn.Module):
|
|||
# Move each tensor in observations to device
|
||||
observations = {k: v.to(self.device) for k, v in observations.items()}
|
||||
actions = actions.to(self.device)
|
||||
|
||||
|
||||
obs_enc = observations if self.encoder is None else self.encoder(observations)
|
||||
|
||||
|
||||
inputs = torch.cat([obs_enc, actions], dim=-1)
|
||||
x = self.network(inputs)
|
||||
value = self.output_layer(x)
|
||||
|
@ -368,16 +384,17 @@ class Policy(nn.Module):
|
|||
self.to(self.device)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
self,
|
||||
observations: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
|
||||
# Encode observations if encoder exists
|
||||
obs_enc = observations if self.encoder is None else self.encoder(observations)
|
||||
|
||||
# Get network outputs
|
||||
outputs = self.network(obs_enc)
|
||||
means = self.mean_layer(outputs)
|
||||
|
||||
|
||||
# Compute standard deviations
|
||||
if self.fixed_std is None:
|
||||
log_std = self.std_layer(outputs)
|
||||
|
@ -535,7 +552,6 @@ def create_critic_ensemble(critics: list[nn.Module], num_critics: int, device: s
|
|||
assert len(critics) == num_critics, f"Expected {num_critics} critics, got {len(critics)}"
|
||||
return nn.ModuleList(critics).to(device)
|
||||
|
||||
|
||||
# borrowed from tdmpc
|
||||
def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor:
|
||||
"""Helper to temporarily flatten extra dims at the start of the image tensor.
|
||||
|
@ -543,7 +559,7 @@ def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tens
|
|||
Args:
|
||||
fn: Callable that the image tensor will be passed to. It should accept (B, C, H, W) and return
|
||||
(B, *), where * is any number of dimensions.
|
||||
image_tensor: An image tensor of shape (**, C, H, W), where ** is any number of dimensions and
|
||||
image_tensor: An image tensor of shape (**, C, H, W), where ** is any number of dimensions and
|
||||
can be more than 1 dimensions, generally different from *.
|
||||
Returns:
|
||||
A return value from the callable reshaped to (**, *).
|
||||
|
@ -553,4 +569,4 @@ def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tens
|
|||
start_dims = image_tensor.shape[:-3]
|
||||
inp = torch.flatten(image_tensor, end_dim=-4)
|
||||
flat_out = fn(inp)
|
||||
return torch.reshape(flat_out, (*start_dims, *flat_out.shape[1:]))
|
||||
return torch.reshape(flat_out, (*start_dims, *flat_out.shape[1:]))
|
Loading…
Reference in New Issue