add softmax q network
parent a8135629b4
commit 9f6f508edb
@@ -22,8 +22,10 @@ from dataclasses import asdict
 from typing import Callable, List, Literal, Optional, Tuple
 
 import einops
+from importlib_metadata import distribution
 import numpy as np
 import torch
+from torch.distributions import Categorical
 import torch.nn as nn
 import torch.nn.functional as F  # noqa: N812
 from torch import Tensor
@@ -127,6 +129,7 @@ class SACPolicy(
                 encoder=encoder_critic,
                 input_dim=encoder_critic.output_dim,
                 output_dim=config.num_discrete_actions,
+                softmax_temperature=1.0,
                 **asdict(config.grasp_critic_network_kwargs),
             )
 
@@ -194,8 +197,8 @@ class SACPolicy(
         actions = self.unnormalize_outputs({"action": actions})["action"]
 
         if self.config.num_discrete_actions is not None:
-            discrete_action_value = self.grasp_critic(batch, observations_features)
-            discrete_action = torch.argmax(discrete_action_value, dim=-1, keepdim=True)
+            _, discrete_action_distribution = self.grasp_critic(batch, observations_features)
+            discrete_action = discrete_action_distribution.sample()
             actions = torch.cat([actions, discrete_action], dim=-1)
 
         return actions
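Below is a minimal, self-contained sketch (not the repository code) of the behavior change in this hunk: the discrete gripper action is now sampled from a softmax distribution over the grasp critic's Q-values instead of being taken greedily with argmax. The tensor shapes, the unsqueeze, and the dtype cast are illustrative assumptions.

import torch
from torch.distributions import Categorical

batch_size, num_discrete_actions = 4, 3
q_values = torch.randn(batch_size, num_discrete_actions)  # stand-in for grasp critic outputs
softmax_temperature = 1.0

# Old behavior: greedy discrete action, shape (batch_size, 1)
greedy_action = torch.argmax(q_values, dim=-1, keepdim=True)

# New behavior: stochastic discrete action sampled from softmax(Q / temperature)
distribution = Categorical(logits=q_values / softmax_temperature)
sampled_action = distribution.sample().unsqueeze(-1)  # (batch_size,) -> (batch_size, 1)

# Concatenate with placeholder continuous actions, matching dtypes first
continuous_actions = torch.randn(batch_size, 6)
actions = torch.cat([continuous_actions, sampled_action.to(continuous_actions.dtype)], dim=-1)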
@@ -429,13 +432,13 @@ class SACPolicy(
 
         with torch.no_grad():
             # For DQN, select actions using online network, evaluate with target network
-            next_grasp_qs = self.grasp_critic_forward(
+            next_grasp_qs, next_grasp_distribution = self.grasp_critic_forward(
                 next_observations, use_target=False, observation_features=next_observation_features
             )
-            best_next_grasp_action = torch.argmax(next_grasp_qs, dim=-1, keepdim=True)
+            best_next_grasp_action = next_grasp_distribution.sample()
 
             # Get target Q-values from target network
-            target_next_grasp_qs = self.grasp_critic_forward(
+            target_next_grasp_qs, _ = self.grasp_critic_forward(
                 observations=next_observations,
                 use_target=True,
                 observation_features=next_observation_features,
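For reference, a hedged sketch of the target-computation pattern in the hunk above, using assumed shapes and standalone tensors in place of the policy's networks: the next discrete action is sampled from the online grasp critic's softmax distribution and then evaluated with the target network, DQN-style.

import torch
from torch.distributions import Categorical

batch_size, num_discrete_actions, discount = 4, 3, 0.99
next_grasp_qs = torch.randn(batch_size, num_discrete_actions)         # online network outputs
target_next_grasp_qs = torch.randn(batch_size, num_discrete_actions)  # target network outputs
rewards_gripper = torch.randn(batch_size)
done = torch.zeros(batch_size)

# Select the next discrete action by sampling (temperature of 1.0 assumed)
next_grasp_distribution = Categorical(logits=next_grasp_qs)
best_next_grasp_action = next_grasp_distribution.sample().unsqueeze(-1)  # (batch_size, 1)

# Evaluate the sampled action with the target network and form the TD target
target_next_grasp_q = target_next_grasp_qs.gather(-1, best_next_grasp_action).squeeze(-1)
target_grasp_q = rewards_gripper + (1 - done) * discount * target_next_grasp_q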
@@ -453,7 +456,7 @@ class SACPolicy(
             target_grasp_q = rewards_gripper + (1 - done) * self.config.discount * target_next_grasp_q
 
         # Get predicted Q-values for current observations
-        predicted_grasp_qs = self.grasp_critic_forward(
+        predicted_grasp_qs, _ = self.grasp_critic_forward(
             observations=observations, use_target=False, observation_features=observation_features
         )
 
@@ -777,6 +780,7 @@ class GraspCritic(nn.Module):
         dropout_rate: Optional[float] = None,
         init_final: Optional[float] = None,
         final_activation: Callable[[torch.Tensor], torch.Tensor] | str | None = None,
+        softmax_temperature: float = 1.0,
     ):
         super().__init__()
         self.encoder = encoder
@@ -809,7 +813,9 @@ class GraspCritic(nn.Module):
         # Move each tensor in observations to device by cloning first to avoid inplace operations
         observations = {k: v.to(device) for k, v in observations.items()}
         obs_enc = self.encoder(observations, vision_encoder_cache=observation_features)
-        return self.output_layer(self.net(obs_enc))
+        q_values = self.output_layer(self.net(obs_enc))
+        distribution = Categorical(logits=q_values / self.softmax_temperature)
+        return q_values, distribution
 
 
 class Policy(nn.Module):
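To make the new return signature concrete, here is a minimal sketch, under assumptions, of a grasp-critic-style head after this commit: the output layer produces Q-values, and the module also returns a Categorical distribution built from softmax(Q / softmax_temperature). The MLP sizes, the SoftmaxQHead name, and the stored self.softmax_temperature attribute are assumptions for illustration, not the repository's GraspCritic.

import torch
import torch.nn as nn
from torch.distributions import Categorical


class SoftmaxQHead(nn.Module):
    def __init__(self, input_dim: int, output_dim: int, softmax_temperature: float = 1.0):
        super().__init__()
        # Small MLP over already-encoded observation features
        self.net = nn.Sequential(nn.Linear(input_dim, 256), nn.ReLU())
        self.output_layer = nn.Linear(256, output_dim)
        self.softmax_temperature = softmax_temperature

    def forward(self, obs_enc: torch.Tensor) -> tuple[torch.Tensor, Categorical]:
        q_values = self.output_layer(self.net(obs_enc))
        # Lower temperature sharpens the softmax toward the greedy (argmax) action
        distribution = Categorical(logits=q_values / self.softmax_temperature)
        return q_values, distribution


head = SoftmaxQHead(input_dim=32, output_dim=3, softmax_temperature=0.5)
q_values, distribution = head(torch.randn(4, 32))
discrete_action = distribution.sample()  # shape (4,)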