From 41b377211cc82917799eea43f16e86111c8b59d1 Mon Sep 17 00:00:00 2001
From: Michel Aractingi
Date: Sun, 29 Dec 2024 23:59:39 +0000
Subject: [PATCH] split encoder for critic and actor

---
 lerobot/common/policies/sac/modeling_sac.py | 40 ++++++++++++++-------
 1 file changed, 28 insertions(+), 12 deletions(-)

diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py
index 8fb46199..aada2714 100644
--- a/lerobot/common/policies/sac/modeling_sac.py
+++ b/lerobot/common/policies/sac/modeling_sac.py
@@ -137,6 +137,22 @@ class SACPolicy(
         actions, _, _ = self.actor(batch)
         actions = self.unnormalize_outputs({"action": actions})["action"]
         return actions
+
+    def critic_forward(self, observations: dict[str, Tensor], actions: Tensor, use_target: bool = False) -> Tensor:
+        """Forward pass through a critic network ensemble
+
+        Args:
+            observations: Dictionary of observations
+            actions: Action tensor
+            use_target: If True, use target critics, otherwise use ensemble critics
+
+        Returns:
+            Tensor of Q-values from all critics
+        """
+        critics = self.critic_target if use_target else self.critic_ensemble
+        q_values = torch.stack([critic(observations, actions) for critic in critics])
+        return q_values
+
     def critic_forward(
         self, observations: dict[str, Tensor], actions: Tensor, use_target: bool = False
     ) -> Tensor:
@@ -262,8 +278,8 @@ class MLP(nn.Module):
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.net(x)
-        
-        
+
+
 class Critic(nn.Module):
     def __init__(
         self,
@@ -277,13 +293,13 @@ class Critic(nn.Module):
         self.encoder = encoder
         self.network = network
         self.init_final = init_final
-        
+
         # Find the last Linear layer's output dimension
         for layer in reversed(network.net):
             if isinstance(layer, nn.Linear):
                 out_features = layer.out_features
                 break
-        
+
         # Output layer
         if init_final is not None:
             self.output_layer = nn.Linear(out_features, 1)
@@ -292,7 +308,7 @@ class Critic(nn.Module):
         else:
             self.output_layer = nn.Linear(out_features, 1)
             orthogonal_init()(self.output_layer.weight)
-        
+
         self.to(self.device)
 
     def forward(
@@ -303,9 +319,9 @@ class Critic(nn.Module):
         # Move each tensor in observations to device
         observations = {k: v.to(self.device) for k, v in observations.items()}
         actions = actions.to(self.device)
-        
+
         obs_enc = observations if self.encoder is None else self.encoder(observations)
-        
+
         inputs = torch.cat([obs_enc, actions], dim=-1)
         x = self.network(inputs)
         value = self.output_layer(x)
@@ -368,16 +384,17 @@ class Policy(nn.Module):
         self.to(self.device)
 
     def forward(
-        self, 
+        self,
         observations: torch.Tensor,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
 
+        # Encode observations if encoder exists
         obs_enc = observations if self.encoder is None else self.encoder(observations)
 
         # Get network outputs
         outputs = self.network(obs_enc)
         means = self.mean_layer(outputs)
-        
+
         # Compute standard deviations
         if self.fixed_std is None:
             log_std = self.std_layer(outputs)
@@ -535,7 +552,6 @@ def create_critic_ensemble(critics: list[nn.Module], num_critics: int, device: s
     assert len(critics) == num_critics, f"Expected {num_critics} critics, got {len(critics)}"
     return nn.ModuleList(critics).to(device)
 
-
 # borrowed from tdmpc
 def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor:
     """Helper to temporarily flatten extra dims at the start of the image tensor.
@@ -543,7 +559,7 @@ def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tens
     Args:
         fn: Callable that the image tensor will be passed to. It should accept (B, C, H, W) and return
            (B, *), where * is any number of dimensions.
-        image_tensor: An image tensor of shape (**, C, H, W), where ** is any number of dimensions and 
+        image_tensor: An image tensor of shape (**, C, H, W), where ** is any number of dimensions and
            can be more than 1 dimensions, generally different from *.
     Returns:
         A return value from the callable reshaped to (**, *).
@@ -553,4 +569,4 @@ def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tens
     start_dims = image_tensor.shape[:-3]
     inp = torch.flatten(image_tensor, end_dim=-4)
     flat_out = fn(inp)
-    return torch.reshape(flat_out, (*start_dims, *flat_out.shape[1:]))
+    return torch.reshape(flat_out, (*start_dims, *flat_out.shape[1:]))
\ No newline at end of file
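
For reference, the sketch below is not part of the patch; it only illustrates how the critic_forward ensemble call added above could be used when computing a SAC critic target. It assumes policy is an instance of the modified SACPolicy, and the names observations, next_observations, actions, rewards, dones, discount, and temperature are hypothetical placeholders supplied by a training loop.

# Illustrative sketch only -- not part of the patch. Assumes `policy` is a
# SACPolicy instance with the critic_forward method above; observations,
# next_observations, actions, rewards, dones, discount, and temperature are
# hypothetical names provided by the training loop.
import torch
import torch.nn.functional as F

with torch.no_grad():
    # Sample next actions from the actor (3-tuple, as unpacked in select_action above).
    next_actions, next_log_probs, _ = policy.actor(next_observations)
    # Q-values stacked over the target critic ensemble (one row per critic).
    target_q = policy.critic_forward(next_observations, next_actions, use_target=True)
    # Clipped double-Q: take the minimum over the ensemble, subtract the entropy term.
    min_target_q = target_q.min(dim=0).values - temperature * next_log_probs
    td_target = rewards + discount * (1.0 - dones) * min_target_q

# Q-values from the online critics for the actions actually taken.
q_preds = policy.critic_forward(observations, actions, use_target=False)
# Regress every critic in the ensemble toward the same TD target.
critic_loss = F.mse_loss(q_preds, td_target.expand_as(q_preds))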