Add grasp critic
- Implemented grasp critic to evaluate gripper actions - Added corresponding config parameters for tuning
This commit is contained in:
parent
0f706ce543
commit
4a1c26d9ee
|
@ -42,6 +42,14 @@ class CriticNetworkConfig:
|
|||
final_activation: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class GraspCriticNetworkConfig:
|
||||
hidden_dims: list[int] = field(default_factory=lambda: [256, 256])
|
||||
activate_final: bool = True
|
||||
final_activation: str | None = None
|
||||
output_dim: int = 3
|
||||
|
||||
|
||||
@dataclass
|
||||
class ActorNetworkConfig:
|
||||
hidden_dims: list[int] = field(default_factory=lambda: [256, 256])
|
||||
|
|
|
@ -112,6 +112,26 @@ class SACPolicy(
|
|||
|
||||
self.critic_ensemble = torch.compile(self.critic_ensemble)
|
||||
self.critic_target = torch.compile(self.critic_target)
|
||||
|
||||
# Create grasp critic
|
||||
self.grasp_critic = GraspCritic(
|
||||
encoder=encoder_critic,
|
||||
input_dim=encoder_critic.output_dim,
|
||||
**config.grasp_critic_network_kwargs,
|
||||
)
|
||||
|
||||
# Create target grasp critic
|
||||
self.grasp_critic_target = GraspCritic(
|
||||
encoder=encoder_critic,
|
||||
input_dim=encoder_critic.output_dim,
|
||||
**config.grasp_critic_network_kwargs,
|
||||
)
|
||||
|
||||
self.grasp_critic_target.load_state_dict(self.grasp_critic.state_dict())
|
||||
|
||||
self.grasp_critic = torch.compile(self.grasp_critic)
|
||||
self.grasp_critic_target = torch.compile(self.grasp_critic_target)
|
||||
|
||||
self.actor = Policy(
|
||||
encoder=encoder_actor,
|
||||
network=MLP(input_dim=encoder_actor.output_dim, **asdict(config.actor_network_kwargs)),
|
||||
|
@ -176,6 +196,21 @@ class SACPolicy(
|
|||
q_values = critics(observations, actions, observation_features)
|
||||
return q_values
|
||||
|
||||
def grasp_critic_forward(self, observations, use_target=False, observation_features=None):
|
||||
"""Forward pass through a grasp critic network
|
||||
|
||||
Args:
|
||||
observations: Dictionary of observations
|
||||
use_target: If True, use target critics, otherwise use ensemble critics
|
||||
observation_features: Optional pre-computed observation features to avoid recomputing encoder output
|
||||
|
||||
Returns:
|
||||
Tensor of Q-values from the grasp critic network
|
||||
"""
|
||||
grasp_critic = self.grasp_critic_target if use_target else self.grasp_critic
|
||||
q_values = grasp_critic(observations, observation_features)
|
||||
return q_values
|
||||
|
||||
def forward(
|
||||
self,
|
||||
batch: dict[str, Tensor | dict[str, Tensor]],
|
||||
|
@ -246,6 +281,18 @@ class SACPolicy(
|
|||
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
|
||||
)
|
||||
|
||||
def update_grasp_target_networks(self):
|
||||
"""Update grasp target networks with exponential moving average"""
|
||||
for target_param, param in zip(
|
||||
self.grasp_critic_target.parameters(),
|
||||
self.grasp_critic.parameters(),
|
||||
strict=False,
|
||||
):
|
||||
target_param.data.copy_(
|
||||
param.data * self.config.critic_target_update_weight
|
||||
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
|
||||
)
|
||||
|
||||
def update_temperature(self):
|
||||
self.temperature = self.log_alpha.exp().item()
|
||||
|
||||
|
@ -307,6 +354,32 @@ class SACPolicy(
|
|||
).sum()
|
||||
return critics_loss
|
||||
|
||||
def compute_loss_grasp_critic(self, observations, actions, rewards, next_observations, done, observation_features=None, next_observation_features=None, complementary_info=None):
|
||||
|
||||
batch_size = rewards.shape[0]
|
||||
grasp_actions = torch.clip(actions[:, -1].long() + 1, 0, 2) # Map [-1, 0, 1] -> [0, 1, 2]
|
||||
|
||||
with torch.no_grad():
|
||||
next_grasp_qs = self.grasp_critic_forward(next_observations, use_target=False)
|
||||
best_next_grasp_action = torch.argmax(next_grasp_qs, dim=-1)
|
||||
|
||||
target_next_grasp_qs = self.grasp_critic_forward(next_observations, use_target=True)
|
||||
target_next_grasp_q = target_next_grasp_qs[torch.arange(batch_size), best_next_grasp_action]
|
||||
|
||||
# Get the grasp penalty from complementary_info
|
||||
grasp_penalty = torch.zeros_like(rewards)
|
||||
if complementary_info is not None and "grasp_penalty" in complementary_info:
|
||||
grasp_penalty = complementary_info["grasp_penalty"]
|
||||
|
||||
grasp_rewards = rewards + grasp_penalty
|
||||
target_grasp_q = grasp_rewards + (1 - done) * self.config.discount * target_next_grasp_q
|
||||
|
||||
predicted_grasp_qs = self.grasp_critic_forward(observations, use_target=False)
|
||||
predicted_grasp_q = predicted_grasp_qs[torch.arange(batch_size), grasp_actions]
|
||||
|
||||
grasp_critic_loss = F.mse_loss(input=predicted_grasp_q, target=target_grasp_q, reduction="mean")
|
||||
return grasp_critic_loss
|
||||
|
||||
def compute_loss_temperature(self, observations, observation_features: Tensor | None = None) -> Tensor:
|
||||
"""Compute the temperature loss"""
|
||||
# calculate temperature loss
|
||||
|
@ -509,6 +582,57 @@ class CriticEnsemble(nn.Module):
|
|||
return q_values
|
||||
|
||||
|
||||
class GraspCritic(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
encoder: Optional[nn.Module],
|
||||
network: nn.Module,
|
||||
output_dim: int = 3,
|
||||
init_final: Optional[float] = None,
|
||||
encoder_is_shared: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.encoder = encoder
|
||||
self.network = network
|
||||
self.output_dim = output_dim
|
||||
|
||||
# Find the last Linear layer's output dimension
|
||||
for layer in reversed(network.net):
|
||||
if isinstance(layer, nn.Linear):
|
||||
out_features = layer.out_features
|
||||
break
|
||||
|
||||
self.parameters_to_optimize += list(self.network.parameters())
|
||||
|
||||
if self.encoder is not None and not encoder_is_shared:
|
||||
self.parameters_to_optimize += list(self.encoder.parameters())
|
||||
|
||||
self.output_layer = nn.Linear(in_features=out_features, out_features=self.output_dim)
|
||||
if init_final is not None:
|
||||
nn.init.uniform_(self.output_layer.weight, -init_final, init_final)
|
||||
nn.init.uniform_(self.output_layer.bias, -init_final, init_final)
|
||||
else:
|
||||
orthogonal_init()(self.output_layer.weight)
|
||||
|
||||
self.parameters_to_optimize += list(self.output_layer.parameters())
|
||||
|
||||
def forward(
|
||||
self,
|
||||
observations: torch.Tensor,
|
||||
observation_features: torch.Tensor | None = None
|
||||
) -> torch.Tensor:
|
||||
device = get_device_from_parameters(self)
|
||||
# Move each tensor in observations to device
|
||||
observations = {k: v.to(device) for k, v in observations.items()}
|
||||
# Encode observations if encoder exists
|
||||
obs_enc = (
|
||||
observation_features
|
||||
if observation_features is not None
|
||||
else (observations if self.encoder is None else self.encoder(observations))
|
||||
)
|
||||
return self.output_layer(self.network(obs_enc))
|
||||
|
||||
|
||||
class Policy(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
|
|
Loading…
Reference in New Issue