add softmax q network

AdilZouitine 2025-04-08 09:14:49 +00:00
parent a8135629b4
commit 9f6f508edb
1 changed file with 13 additions and 7 deletions


@@ -22,8 +22,10 @@ from dataclasses import asdict
 from typing import Callable, List, Literal, Optional, Tuple

 import einops
+from importlib_metadata import distribution
 import numpy as np
 import torch
+from torch.distributions import Categorical
 import torch.nn as nn
 import torch.nn.functional as F  # noqa: N812
 from torch import Tensor
@@ -127,6 +129,7 @@ class SACPolicy(
                 encoder=encoder_critic,
                 input_dim=encoder_critic.output_dim,
                 output_dim=config.num_discrete_actions,
+                softmax_temperature=1.0,
                 **asdict(config.grasp_critic_network_kwargs),
             )
@@ -194,8 +197,8 @@ class SACPolicy(
         actions = self.unnormalize_outputs({"action": actions})["action"]

         if self.config.num_discrete_actions is not None:
-            discrete_action_value = self.grasp_critic(batch, observations_features)
-            discrete_action = torch.argmax(discrete_action_value, dim=-1, keepdim=True)
+            _, discrete_action_distribution = self.grasp_critic(batch, observations_features)
+            discrete_action = discrete_action_distribution.sample()
             actions = torch.cat([actions, discrete_action], dim=-1)

         return actions
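In select_action, the discrete gripper action is now drawn from the grasp critic's softmax distribution over Q-values instead of being picked greedily with argmax. A minimal sketch of the two behaviours, using illustrative tensor shapes; the unsqueeze and the float cast are assumptions made so the sampled index matches the trailing action dimension and dtype that torch.cat expects:

import torch
from torch.distributions import Categorical

batch_size, continuous_dim, num_discrete_actions = 2, 4, 3
continuous_actions = torch.randn(batch_size, continuous_dim)  # actor output (illustrative)
q_values = torch.randn(batch_size, num_discrete_actions)      # grasp critic output (illustrative)
softmax_temperature = 1.0

# Previous behaviour: greedy selection over Q-values.
greedy_action = torch.argmax(q_values, dim=-1, keepdim=True)  # shape (batch, 1)

# New behaviour: Boltzmann (softmax) exploration over Q-values.
distribution = Categorical(logits=q_values / softmax_temperature)
sampled_action = distribution.sample().unsqueeze(-1)          # shape (batch, 1)

# Append the discrete action to the continuous action vector.
actions = torch.cat([continuous_actions, sampled_action.float()], dim=-1)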
@@ -429,13 +432,13 @@ class SACPolicy(

         with torch.no_grad():
             # For DQN, select actions using online network, evaluate with target network
-            next_grasp_qs = self.grasp_critic_forward(
+            next_grasp_qs, next_grasp_distribution = self.grasp_critic_forward(
                 next_observations, use_target=False, observation_features=next_observation_features
             )
-            best_next_grasp_action = torch.argmax(next_grasp_qs, dim=-1, keepdim=True)
+            best_next_grasp_action = next_grasp_distribution.sample()

             # Get target Q-values from target network
-            target_next_grasp_qs = self.grasp_critic_forward(
+            target_next_grasp_qs, _ = self.grasp_critic_forward(
                 observations=next_observations,
                 use_target=True,
                 observation_features=next_observation_features,
@@ -453,7 +456,7 @@ class SACPolicy(
         target_grasp_q = rewards_gripper + (1 - done) * self.config.discount * target_next_grasp_q

         # Get predicted Q-values for current observations
-        predicted_grasp_qs = self.grasp_critic_forward(
+        predicted_grasp_qs, _ = self.grasp_critic_forward(
             observations=observations, use_target=False, observation_features=observation_features
         )

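In the critic update, the next discrete action now comes from sampling the online network's softmax distribution rather than an argmax, and, per the double-DQN comment above, it is evaluated with the target network. A rough sketch of the resulting TD target, assuming the sampled action indexes the target Q-values via gather (that step lies between the hunks shown) and using illustrative shapes:

import torch
from torch.distributions import Categorical

batch_size, num_discrete_actions = 2, 3
discount, softmax_temperature = 0.99, 1.0

next_grasp_qs = torch.randn(batch_size, num_discrete_actions)         # online network on next obs
target_next_grasp_qs = torch.randn(batch_size, num_discrete_actions)  # target network on next obs
rewards_gripper = torch.randn(batch_size)
done = torch.zeros(batch_size)

# Sample the next discrete action from the online network's softmax distribution.
next_grasp_distribution = Categorical(logits=next_grasp_qs / softmax_temperature)
best_next_grasp_action = next_grasp_distribution.sample().unsqueeze(-1)  # (batch, 1)

# Evaluate the sampled action with the target network.
target_next_grasp_q = target_next_grasp_qs.gather(-1, best_next_grasp_action).squeeze(-1)

# One-step TD target for the grasp critic.
target_grasp_q = rewards_gripper + (1 - done) * discount * target_next_grasp_q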
@@ -777,6 +780,7 @@ class GraspCritic(nn.Module):
         dropout_rate: Optional[float] = None,
         init_final: Optional[float] = None,
         final_activation: Callable[[torch.Tensor], torch.Tensor] | str | None = None,
+        softmax_temperature: float = 1.0,
     ):
         super().__init__()
         self.encoder = encoder
@@ -809,7 +813,9 @@ class GraspCritic(nn.Module):
         # Move each tensor in observations to device by cloning first to avoid inplace operations
         observations = {k: v.to(device) for k, v in observations.items()}
         obs_enc = self.encoder(observations, vision_encoder_cache=observation_features)
-        return self.output_layer(self.net(obs_enc))
+        q_values = self.output_layer(self.net(obs_enc))
+        distribution = Categorical(logits=q_values / self.softmax_temperature)
+        return q_values, distribution


 class Policy(nn.Module):
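Taken together, the critic head now returns both raw Q-values (used for the TD loss) and a temperature-scaled Categorical over them (used for action selection). A condensed, self-contained sketch of that pattern with assumed dimensions and a hypothetical class name; it explicitly stores softmax_temperature on self because forward() reads self.softmax_temperature, an assignment that does not appear in the hunks above:

import torch
import torch.nn as nn
from torch.distributions import Categorical


class SoftmaxQHead(nn.Module):
    """Toy stand-in for the grasp critic head: MLP Q-values plus a softmax policy over them."""

    def __init__(self, input_dim: int, output_dim: int, softmax_temperature: float = 1.0):
        super().__init__()
        self.softmax_temperature = softmax_temperature
        self.net = nn.Sequential(nn.Linear(input_dim, 64), nn.ReLU())
        self.output_layer = nn.Linear(64, output_dim)

    def forward(self, obs_enc: torch.Tensor) -> tuple[torch.Tensor, Categorical]:
        q_values = self.output_layer(self.net(obs_enc))
        # Lower temperatures sharpen the distribution toward argmax; higher ones flatten it.
        distribution = Categorical(logits=q_values / self.softmax_temperature)
        return q_values, distribution


# Usage: sample a discrete action and keep the Q-values for the critic loss.
head = SoftmaxQHead(input_dim=32, output_dim=3, softmax_temperature=1.0)
q_values, dist = head(torch.randn(2, 32))
action = dist.sample()  # shape (2,)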