split encoder for critic and actor
This commit is contained in:
parent
9ceb68ee90
commit
41b377211c
|
@ -137,6 +137,22 @@ class SACPolicy(
|
||||||
actions, _, _ = self.actor(batch)
|
actions, _, _ = self.actor(batch)
|
||||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||||
return actions
|
return actions
|
||||||
|
|
||||||
|
def critic_forward(self, observations: dict[str, Tensor], actions: Tensor, use_target: bool = False) -> Tensor:
|
||||||
|
"""Forward pass through a critic network ensemble
|
||||||
|
|
||||||
|
Args:
|
||||||
|
observations: Dictionary of observations
|
||||||
|
actions: Action tensor
|
||||||
|
use_target: If True, use target critics, otherwise use ensemble critics
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor of Q-values from all critics
|
||||||
|
"""
|
||||||
|
critics = self.critic_target if use_target else self.critic_ensemble
|
||||||
|
q_values = torch.stack([critic(observations, actions) for critic in critics])
|
||||||
|
return q_values
|
||||||
|
|
||||||
|
|
||||||
def critic_forward(
|
def critic_forward(
|
||||||
self, observations: dict[str, Tensor], actions: Tensor, use_target: bool = False
|
self, observations: dict[str, Tensor], actions: Tensor, use_target: bool = False
|
||||||
|
@ -262,8 +278,8 @@ class MLP(nn.Module):
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
return self.net(x)
|
return self.net(x)
|
||||||
|
|
||||||
|
|
||||||
class Critic(nn.Module):
|
class Critic(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -277,13 +293,13 @@ class Critic(nn.Module):
|
||||||
self.encoder = encoder
|
self.encoder = encoder
|
||||||
self.network = network
|
self.network = network
|
||||||
self.init_final = init_final
|
self.init_final = init_final
|
||||||
|
|
||||||
# Find the last Linear layer's output dimension
|
# Find the last Linear layer's output dimension
|
||||||
for layer in reversed(network.net):
|
for layer in reversed(network.net):
|
||||||
if isinstance(layer, nn.Linear):
|
if isinstance(layer, nn.Linear):
|
||||||
out_features = layer.out_features
|
out_features = layer.out_features
|
||||||
break
|
break
|
||||||
|
|
||||||
# Output layer
|
# Output layer
|
||||||
if init_final is not None:
|
if init_final is not None:
|
||||||
self.output_layer = nn.Linear(out_features, 1)
|
self.output_layer = nn.Linear(out_features, 1)
|
||||||
|
@ -292,7 +308,7 @@ class Critic(nn.Module):
|
||||||
else:
|
else:
|
||||||
self.output_layer = nn.Linear(out_features, 1)
|
self.output_layer = nn.Linear(out_features, 1)
|
||||||
orthogonal_init()(self.output_layer.weight)
|
orthogonal_init()(self.output_layer.weight)
|
||||||
|
|
||||||
self.to(self.device)
|
self.to(self.device)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
|
@ -303,9 +319,9 @@ class Critic(nn.Module):
|
||||||
# Move each tensor in observations to device
|
# Move each tensor in observations to device
|
||||||
observations = {k: v.to(self.device) for k, v in observations.items()}
|
observations = {k: v.to(self.device) for k, v in observations.items()}
|
||||||
actions = actions.to(self.device)
|
actions = actions.to(self.device)
|
||||||
|
|
||||||
obs_enc = observations if self.encoder is None else self.encoder(observations)
|
obs_enc = observations if self.encoder is None else self.encoder(observations)
|
||||||
|
|
||||||
inputs = torch.cat([obs_enc, actions], dim=-1)
|
inputs = torch.cat([obs_enc, actions], dim=-1)
|
||||||
x = self.network(inputs)
|
x = self.network(inputs)
|
||||||
value = self.output_layer(x)
|
value = self.output_layer(x)
|
||||||
|
@ -368,16 +384,17 @@ class Policy(nn.Module):
|
||||||
self.to(self.device)
|
self.to(self.device)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
observations: torch.Tensor,
|
observations: torch.Tensor,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
|
||||||
# Encode observations if encoder exists
|
# Encode observations if encoder exists
|
||||||
obs_enc = observations if self.encoder is None else self.encoder(observations)
|
obs_enc = observations if self.encoder is None else self.encoder(observations)
|
||||||
|
|
||||||
# Get network outputs
|
# Get network outputs
|
||||||
outputs = self.network(obs_enc)
|
outputs = self.network(obs_enc)
|
||||||
means = self.mean_layer(outputs)
|
means = self.mean_layer(outputs)
|
||||||
|
|
||||||
# Compute standard deviations
|
# Compute standard deviations
|
||||||
if self.fixed_std is None:
|
if self.fixed_std is None:
|
||||||
log_std = self.std_layer(outputs)
|
log_std = self.std_layer(outputs)
|
||||||
|
@ -535,7 +552,6 @@ def create_critic_ensemble(critics: list[nn.Module], num_critics: int, device: s
|
||||||
assert len(critics) == num_critics, f"Expected {num_critics} critics, got {len(critics)}"
|
assert len(critics) == num_critics, f"Expected {num_critics} critics, got {len(critics)}"
|
||||||
return nn.ModuleList(critics).to(device)
|
return nn.ModuleList(critics).to(device)
|
||||||
|
|
||||||
|
|
||||||
# borrowed from tdmpc
|
# borrowed from tdmpc
|
||||||
def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor:
|
def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor:
|
||||||
"""Helper to temporarily flatten extra dims at the start of the image tensor.
|
"""Helper to temporarily flatten extra dims at the start of the image tensor.
|
||||||
|
@ -543,7 +559,7 @@ def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tens
|
||||||
Args:
|
Args:
|
||||||
fn: Callable that the image tensor will be passed to. It should accept (B, C, H, W) and return
|
fn: Callable that the image tensor will be passed to. It should accept (B, C, H, W) and return
|
||||||
(B, *), where * is any number of dimensions.
|
(B, *), where * is any number of dimensions.
|
||||||
image_tensor: An image tensor of shape (**, C, H, W), where ** is any number of dimensions and
|
image_tensor: An image tensor of shape (**, C, H, W), where ** is any number of dimensions and
|
||||||
can be more than 1 dimensions, generally different from *.
|
can be more than 1 dimensions, generally different from *.
|
||||||
Returns:
|
Returns:
|
||||||
A return value from the callable reshaped to (**, *).
|
A return value from the callable reshaped to (**, *).
|
||||||
|
@ -553,4 +569,4 @@ def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tens
|
||||||
start_dims = image_tensor.shape[:-3]
|
start_dims = image_tensor.shape[:-3]
|
||||||
inp = torch.flatten(image_tensor, end_dim=-4)
|
inp = torch.flatten(image_tensor, end_dim=-4)
|
||||||
flat_out = fn(inp)
|
flat_out = fn(inp)
|
||||||
return torch.reshape(flat_out, (*start_dims, *flat_out.shape[1:]))
|
return torch.reshape(flat_out, (*start_dims, *flat_out.shape[1:]))
|
Loading…
Reference in New Issue