Add grasp critic

- Implemented a grasp critic to evaluate gripper actions
- Added corresponding config parameters for tuning
parent 334cf8143e
commit 66693965c0
@@ -42,6 +42,14 @@ class CriticNetworkConfig:
     final_activation: str | None = None
 
 
+@dataclass
+class GraspCriticNetworkConfig:
+    hidden_dims: list[int] = field(default_factory=lambda: [256, 256])
+    activate_final: bool = True
+    final_activation: str | None = None
+    output_dim: int = 3
+
+
 @dataclass
 class ActorNetworkConfig:
     hidden_dims: list[int] = field(default_factory=lambda: [256, 256])
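Below is a minimal, self-contained sketch of how the new dataclass can be used for tuning. The field defaults mirror the hunk above; the override value and the `asdict(...)` expansion (the mechanism the file already uses for the other network configs) are illustrative, not prescribed by this commit.

# Sketch only: defaults mirror GraspCriticNetworkConfig above; the override is illustrative.
from dataclasses import asdict, dataclass, field


@dataclass
class GraspCriticNetworkConfig:
    hidden_dims: list[int] = field(default_factory=lambda: [256, 256])
    activate_final: bool = True
    final_activation: str | None = None
    output_dim: int = 3  # one Q-value per discrete gripper command


# Override the hidden layer sizes, keep the other defaults, then expand into
# keyword arguments the same way the other network configs are expanded.
grasp_critic_network_kwargs = asdict(GraspCriticNetworkConfig(hidden_dims=[512, 512]))
print(grasp_critic_network_kwargs)
# {'hidden_dims': [512, 512], 'activate_final': True, 'final_activation': None, 'output_dim': 3}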
@@ -112,6 +112,26 @@ class SACPolicy(
         self.critic_ensemble = torch.compile(self.critic_ensemble)
         self.critic_target = torch.compile(self.critic_target)
 
+        # Create grasp critic
+        self.grasp_critic = GraspCritic(
+            encoder=encoder_critic,
+            input_dim=encoder_critic.output_dim,
+            **config.grasp_critic_network_kwargs,
+        )
+
+        # Create target grasp critic
+        self.grasp_critic_target = GraspCritic(
+            encoder=encoder_critic,
+            input_dim=encoder_critic.output_dim,
+            **config.grasp_critic_network_kwargs,
+        )
+
+        self.grasp_critic_target.load_state_dict(self.grasp_critic.state_dict())
+
+        self.grasp_critic = torch.compile(self.grasp_critic)
+        self.grasp_critic_target = torch.compile(self.grasp_critic_target)
+
         self.actor = Policy(
             encoder=encoder_actor,
             network=MLP(input_dim=encoder_actor.output_dim, **asdict(config.actor_network_kwargs)),
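The online/target pair above follows the usual recipe: construct two identically configured grasp critics, hard-copy the online weights into the target once, then leave the target to be moved only by the soft updates added later in this commit. A generic, stand-alone sketch of that initialization step, using plain nn.Linear stand-ins rather than the repo's GraspCritic:

import torch
import torch.nn as nn

# Stand-ins for the online and target grasp critics; any two identically
# configured modules behave the same way for this step.
online = nn.Linear(64, 3)
target = nn.Linear(64, 3)

# One-time hard copy, mirroring grasp_critic_target.load_state_dict(...) above.
target.load_state_dict(online.state_dict())

x = torch.randn(8, 64)
assert torch.allclose(online(x), target(x))  # identical until the first soft update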
@@ -176,6 +196,21 @@ class SACPolicy(
         q_values = critics(observations, actions, observation_features)
         return q_values
 
+    def grasp_critic_forward(self, observations, use_target=False, observation_features=None):
+        """Forward pass through a grasp critic network
+
+        Args:
+            observations: Dictionary of observations
+            use_target: If True, use the target grasp critic, otherwise use the online grasp critic
+            observation_features: Optional pre-computed observation features to avoid recomputing encoder output
+
+        Returns:
+            Tensor of Q-values from the grasp critic network
+        """
+        grasp_critic = self.grasp_critic_target if use_target else self.grasp_critic
+        q_values = grasp_critic(observations, observation_features)
+        return q_values
+
     def forward(
         self,
         batch: dict[str, Tensor | dict[str, Tensor]],
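`grasp_critic_forward` returns a `(batch, output_dim)` tensor, one Q-value per discrete gripper command. A small stand-alone sketch of how such Q-values decode back into a gripper command, using a fake Q tensor in place of a real forward pass; the {0, 1, 2} -> {-1, 0, 1} inverse mapping is inferred from the loss code further down:

import torch

# Fake (batch, 3) Q-values standing in for policy.grasp_critic_forward(observations).
grasp_qs = torch.randn(8, 3)
grasp_action_idx = torch.argmax(grasp_qs, dim=-1)   # indices in {0, 1, 2}
gripper_command = grasp_action_idx.float() - 1.0    # inverse of the [-1, 0, 1] -> [0, 1, 2] mapping
print(gripper_command)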
@@ -246,6 +281,18 @@ class SACPolicy(
                 + target_param.data * (1.0 - self.config.critic_target_update_weight)
             )
 
+    def update_grasp_target_networks(self):
+        """Update grasp target networks with exponential moving average"""
+        for target_param, param in zip(
+            self.grasp_critic_target.parameters(),
+            self.grasp_critic.parameters(),
+            strict=False,
+        ):
+            target_param.data.copy_(
+                param.data * self.config.critic_target_update_weight
+                + target_param.data * (1.0 - self.config.critic_target_update_weight)
+            )
+
     def update_temperature(self):
         self.temperature = self.log_alpha.exp().item()
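The soft update above is the standard exponential moving average, target <- tau * online + (1 - tau) * target, with tau = config.critic_target_update_weight. A stand-alone sketch of the same arithmetic; the tau value below is an assumption for illustration, not taken from this commit:

import torch
import torch.nn as nn

TAU = 0.005  # stand-in for config.critic_target_update_weight

online = nn.Linear(4, 3)
target = nn.Linear(4, 3)
target.load_state_dict(online.state_dict())

# Pretend an optimizer step moved the online weights.
with torch.no_grad():
    online.weight.add_(torch.randn_like(online.weight))

# Soft update, same arithmetic as update_grasp_target_networks above.
with torch.no_grad():
    for target_param, param in zip(target.parameters(), online.parameters(), strict=True):
        target_param.data.copy_(TAU * param.data + (1.0 - TAU) * target_param.data)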
@@ -307,6 +354,32 @@ class SACPolicy(
         ).sum()
         return critics_loss
 
+    def compute_loss_grasp_critic(self, observations, actions, rewards, next_observations, done, observation_features=None, next_observation_features=None, complementary_info=None):
+
+        batch_size = rewards.shape[0]
+        grasp_actions = torch.clip(actions[:, -1].long() + 1, 0, 2)  # Map [-1, 0, 1] -> [0, 1, 2]
+
+        with torch.no_grad():
+            next_grasp_qs = self.grasp_critic_forward(next_observations, use_target=False)
+            best_next_grasp_action = torch.argmax(next_grasp_qs, dim=-1)
+
+            target_next_grasp_qs = self.grasp_critic_forward(next_observations, use_target=True)
+            target_next_grasp_q = target_next_grasp_qs[torch.arange(batch_size), best_next_grasp_action]
+
+            # Get the grasp penalty from complementary_info
+            grasp_penalty = torch.zeros_like(rewards)
+            if complementary_info is not None and "grasp_penalty" in complementary_info:
+                grasp_penalty = complementary_info["grasp_penalty"]
+
+            grasp_rewards = rewards + grasp_penalty
+            target_grasp_q = grasp_rewards + (1 - done) * self.config.discount * target_next_grasp_q
+
+        predicted_grasp_qs = self.grasp_critic_forward(observations, use_target=False)
+        predicted_grasp_q = predicted_grasp_qs[torch.arange(batch_size), grasp_actions]
+
+        grasp_critic_loss = F.mse_loss(input=predicted_grasp_q, target=target_grasp_q, reduction="mean")
+        return grasp_critic_loss
+
     def compute_loss_temperature(self, observations, observation_features: Tensor | None = None) -> Tensor:
         """Compute the temperature loss"""
         # calculate temperature loss
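The target above is computed Double-DQN style: the online grasp critic picks the best next gripper action, the target grasp critic supplies its value, and the grasp penalty from `complementary_info` is folded into the reward before bootstrapping. A self-contained numeric sketch of that target computation; all tensor values are made up for illustration:

import torch

batch_size = 2
discount = 0.99  # stand-in for config.discount

rewards = torch.tensor([0.0, 1.0])
grasp_penalty = torch.tensor([-0.1, 0.0])  # would come from complementary_info["grasp_penalty"]
done = torch.tensor([0.0, 1.0])

next_grasp_qs = torch.tensor([[0.2, 0.5, 0.1], [0.3, 0.0, 0.4]])          # online critic on next obs
target_next_grasp_qs = torch.tensor([[0.1, 0.6, 0.2], [0.2, 0.1, 0.5]])   # target critic on next obs

best_next_grasp_action = torch.argmax(next_grasp_qs, dim=-1)                                   # chosen by online critic
target_next_grasp_q = target_next_grasp_qs[torch.arange(batch_size), best_next_grasp_action]   # valued by target critic

target_grasp_q = (rewards + grasp_penalty) + (1 - done) * discount * target_next_grasp_q
print(target_grasp_q)  # tensor([0.4940, 1.0000])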
@@ -509,6 +582,57 @@ class CriticEnsemble(nn.Module):
         return q_values
 
 
+class GraspCritic(nn.Module):
+    def __init__(
+        self,
+        encoder: Optional[nn.Module],
+        network: nn.Module,
+        output_dim: int = 3,
+        init_final: Optional[float] = None,
+        encoder_is_shared: bool = False,
+    ):
+        super().__init__()
+        self.encoder = encoder
+        self.network = network
+        self.output_dim = output_dim
+        self.parameters_to_optimize = []
+
+        # Find the last Linear layer's output dimension
+        for layer in reversed(network.net):
+            if isinstance(layer, nn.Linear):
+                out_features = layer.out_features
+                break
+
+        self.parameters_to_optimize += list(self.network.parameters())
+
+        if self.encoder is not None and not encoder_is_shared:
+            self.parameters_to_optimize += list(self.encoder.parameters())
+
+        self.output_layer = nn.Linear(in_features=out_features, out_features=self.output_dim)
+        if init_final is not None:
+            nn.init.uniform_(self.output_layer.weight, -init_final, init_final)
+            nn.init.uniform_(self.output_layer.bias, -init_final, init_final)
+        else:
+            orthogonal_init()(self.output_layer.weight)
+
+        self.parameters_to_optimize += list(self.output_layer.parameters())
+
+    def forward(
+        self,
+        observations: dict[str, torch.Tensor],
+        observation_features: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        device = get_device_from_parameters(self)
+        # Move each tensor in observations to device
+        observations = {k: v.to(device) for k, v in observations.items()}
+        # Encode observations if encoder exists
+        obs_enc = (
+            observation_features
+            if observation_features is not None
+            else (observations if self.encoder is None else self.encoder(observations))
+        )
+        return self.output_layer(self.network(obs_enc))
+
+
 class Policy(nn.Module):
     def __init__(
         self,
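For readers who want to poke at the idea in isolation, here is a deliberately simplified stand-in for the class above: a small MLP trunk plus a 3-way Q-value head over pre-encoded observation features. The repo's version additionally handles an optional observation encoder, uniform or orthogonal head initialization, and device placement, all of which are omitted here.

import torch
import torch.nn as nn

class TinyGraspCritic(nn.Module):
    """Simplified sketch: MLP trunk + linear head producing one Q-value per gripper command."""

    def __init__(self, input_dim: int, hidden_dims=(256, 256), output_dim: int = 3):
        super().__init__()
        layers, in_dim = [], input_dim
        for h in hidden_dims:
            layers += [nn.Linear(in_dim, h), nn.ReLU()]
            in_dim = h
        self.net = nn.Sequential(*layers)
        self.output_layer = nn.Linear(in_dim, output_dim)

    def forward(self, obs_enc: torch.Tensor) -> torch.Tensor:
        return self.output_layer(self.net(obs_enc))


critic = TinyGraspCritic(input_dim=64)
q_values = critic(torch.randn(8, 64))
print(q_values.shape)  # torch.Size([8, 3])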