add softmax q network

commit 9f6f508edb
parent a8135629b4
@@ -22,8 +22,10 @@ from dataclasses import asdict
 from typing import Callable, List, Literal, Optional, Tuple
 
 import einops
+from importlib_metadata import distribution
 import numpy as np
 import torch
+from torch.distributions import Categorical
 import torch.nn as nn
 import torch.nn.functional as F  # noqa: N812
 from torch import Tensor
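For context, Categorical(logits=q / T) implements a softmax (Boltzmann) distribution over Q-values: low temperature approaches greedy argmax, high temperature approaches uniform. A minimal sketch with illustrative numbers, not taken from the repository:

    import torch
    from torch.distributions import Categorical

    q_values = torch.tensor([1.0, 2.0, 0.5])  # illustrative Q-values for three discrete actions
    for temperature in (0.1, 1.0, 10.0):
        dist = Categorical(logits=q_values / temperature)
        # low temperature -> nearly greedy (argmax), high temperature -> nearly uniform
        print(temperature, dist.probs)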
@@ -127,6 +129,7 @@ class SACPolicy(
                 encoder=encoder_critic,
                 input_dim=encoder_critic.output_dim,
                 output_dim=config.num_discrete_actions,
+                softmax_temperature=1.0,
                 **asdict(config.grasp_critic_network_kwargs),
             )
 
@@ -194,8 +197,8 @@ class SACPolicy(
         actions = self.unnormalize_outputs({"action": actions})["action"]
 
         if self.config.num_discrete_actions is not None:
-            discrete_action_value = self.grasp_critic(batch, observations_features)
-            discrete_action = torch.argmax(discrete_action_value, dim=-1, keepdim=True)
+            _, discrete_action_distribution = self.grasp_critic(batch, observations_features)
+            discrete_action = discrete_action_distribution.sample()
             actions = torch.cat([actions, discrete_action], dim=-1)
 
         return actions
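A minimal sketch of the behavioural change above: greedy argmax selection is replaced by sampling from a temperature-scaled Categorical. Shapes and the temperature value below are illustrative assumptions:

    import torch
    from torch.distributions import Categorical

    q_values = torch.randn(4, 3)  # (batch, num_discrete_actions), illustrative
    softmax_temperature = 1.0

    # old behaviour: deterministic, always the highest-valued action, shape (batch, 1)
    greedy_action = torch.argmax(q_values, dim=-1, keepdim=True)

    # new behaviour: stochastic, drawn in proportion to softmax(q / T), shape (batch,)
    distribution = Categorical(logits=q_values / softmax_temperature)
    sampled_action = distribution.sample()
    # note: sample() drops the keepdim axis; unsqueeze(-1) restores (batch, 1) if a trailing dim is needed
    sampled_action = sampled_action.unsqueeze(-1)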
@@ -429,13 +432,13 @@ class SACPolicy(
 
         with torch.no_grad():
             # For DQN, select actions using online network, evaluate with target network
-            next_grasp_qs = self.grasp_critic_forward(
+            next_grasp_qs, next_grasp_distribution = self.grasp_critic_forward(
                 next_observations, use_target=False, observation_features=next_observation_features
             )
-            best_next_grasp_action = torch.argmax(next_grasp_qs, dim=-1, keepdim=True)
+            best_next_grasp_action = next_grasp_distribution.sample()
 
             # Get target Q-values from target network
-            target_next_grasp_qs = self.grasp_critic_forward(
+            target_next_grasp_qs, _ = self.grasp_critic_forward(
                 observations=next_observations,
                 use_target=True,
                 observation_features=next_observation_features,
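A sketch of how the sampled next action would feed the Bellman target below; the gather step is an assumption, since the corresponding lines fall outside the hunks shown here, and all tensors are illustrative stand-ins for the networks' outputs:

    import torch
    from torch.distributions import Categorical

    batch_size, num_actions = 4, 3
    next_grasp_qs = torch.randn(batch_size, num_actions)         # online network output, illustrative
    target_next_grasp_qs = torch.randn(batch_size, num_actions)  # target network output, illustrative
    rewards_gripper = torch.randn(batch_size)
    done = torch.zeros(batch_size)
    discount = 0.99

    # Double-DQN style decoupling, with a sampled (softmax) action instead of argmax
    best_next_grasp_action = Categorical(logits=next_grasp_qs).sample().unsqueeze(-1)          # (batch, 1)
    target_next_grasp_q = target_next_grasp_qs.gather(-1, best_next_grasp_action).squeeze(-1)  # (batch,)
    target_grasp_q = rewards_gripper + (1 - done) * discount * target_next_grasp_q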
@@ -453,7 +456,7 @@ class SACPolicy(
             target_grasp_q = rewards_gripper + (1 - done) * self.config.discount * target_next_grasp_q
 
         # Get predicted Q-values for current observations
-        predicted_grasp_qs = self.grasp_critic_forward(
+        predicted_grasp_qs, _ = self.grasp_critic_forward(
             observations=observations, use_target=False, observation_features=observation_features
         )
 
@@ -777,6 +780,7 @@ class GraspCritic(nn.Module):
         dropout_rate: Optional[float] = None,
         init_final: Optional[float] = None,
         final_activation: Callable[[torch.Tensor], torch.Tensor] | str | None = None,
+        softmax_temperature: float = 1.0,
     ):
         super().__init__()
         self.encoder = encoder
@@ -809,7 +813,9 @@ class GraspCritic(nn.Module):
         # Move each tensor in observations to device by cloning first to avoid inplace operations
         observations = {k: v.to(device) for k, v in observations.items()}
         obs_enc = self.encoder(observations, vision_encoder_cache=observation_features)
-        return self.output_layer(self.net(obs_enc))
+        q_values = self.output_layer(self.net(obs_enc))
+        distribution = Categorical(logits=q_values / self.softmax_temperature)
+        return q_values, distribution
 
 
 class Policy(nn.Module):
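Taken together, the changes amount to a Q-head that returns both raw Q-values and a softmax distribution over them. A self-contained sketch of that pattern follows; SoftmaxQHead is a hypothetical stand-in, not the repository's GraspCritic, and it stores softmax_temperature on the module in __init__, an assignment the hunks above do not show, so treat that wiring as an assumption:

    import torch
    import torch.nn as nn
    from torch.distributions import Categorical

    class SoftmaxQHead(nn.Module):
        """Hypothetical stand-in for a grasp-critic-style softmax Q-head."""

        def __init__(self, input_dim: int, num_actions: int, softmax_temperature: float = 1.0):
            super().__init__()
            self.net = nn.Sequential(nn.Linear(input_dim, 64), nn.ReLU())
            self.output_layer = nn.Linear(64, num_actions)
            # forward() below relies on this attribute being set here
            self.softmax_temperature = softmax_temperature

        def forward(self, obs_enc: torch.Tensor) -> tuple[torch.Tensor, Categorical]:
            q_values = self.output_layer(self.net(obs_enc))
            distribution = Categorical(logits=q_values / self.softmax_temperature)
            return q_values, distribution

    q_head = SoftmaxQHead(input_dim=32, num_actions=3)
    q_values, distribution = q_head(torch.randn(4, 32))
    discrete_action = distribution.sample()  # (4,)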