diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py
index e3d83d36..281ffe2e 100644
--- a/lerobot/common/policies/sac/modeling_sac.py
+++ b/lerobot/common/policies/sac/modeling_sac.py
@@ -22,8 +22,9 @@ from dataclasses import asdict
 from typing import Callable, List, Literal, Optional, Tuple
 
 import einops
 import numpy as np
 import torch
+from torch.distributions import Categorical
 import torch.nn as nn
 import torch.nn.functional as F  # noqa: N812
 from torch import Tensor
@@ -127,6 +128,7 @@ class SACPolicy(
                 encoder=encoder_critic,
                 input_dim=encoder_critic.output_dim,
                 output_dim=config.num_discrete_actions,
+                softmax_temperature=1.0,
                 **asdict(config.grasp_critic_network_kwargs),
             )
 
@@ -194,8 +196,8 @@ class SACPolicy(
         actions = self.unnormalize_outputs({"action": actions})["action"]
 
         if self.config.num_discrete_actions is not None:
-            discrete_action_value = self.grasp_critic(batch, observations_features)
-            discrete_action = torch.argmax(discrete_action_value, dim=-1, keepdim=True)
+            _, discrete_action_distribution = self.grasp_critic(batch, observations_features)
+            discrete_action = discrete_action_distribution.sample().unsqueeze(-1)
             actions = torch.cat([actions, discrete_action], dim=-1)
 
         return actions
@@ -429,13 +431,13 @@ class SACPolicy(
 
         with torch.no_grad():
             # For DQN, select actions using online network, evaluate with target network
-            next_grasp_qs = self.grasp_critic_forward(
+            next_grasp_qs, next_grasp_distribution = self.grasp_critic_forward(
                 next_observations, use_target=False, observation_features=next_observation_features
             )
-            best_next_grasp_action = torch.argmax(next_grasp_qs, dim=-1, keepdim=True)
+            best_next_grasp_action = next_grasp_distribution.sample().unsqueeze(-1)
 
             # Get target Q-values from target network
-            target_next_grasp_qs = self.grasp_critic_forward(
+            target_next_grasp_qs, _ = self.grasp_critic_forward(
                 observations=next_observations,
                 use_target=True,
                 observation_features=next_observation_features,
@@ -453,7 +455,7 @@ class SACPolicy(
             target_grasp_q = rewards_gripper + (1 - done) * self.config.discount * target_next_grasp_q
 
         # Get predicted Q-values for current observations
-        predicted_grasp_qs = self.grasp_critic_forward(
+        predicted_grasp_qs, _ = self.grasp_critic_forward(
             observations=observations, use_target=False, observation_features=observation_features
         )
 
@@ -777,6 +779,8 @@ class GraspCritic(nn.Module):
         dropout_rate: Optional[float] = None,
         init_final: Optional[float] = None,
         final_activation: Callable[[torch.Tensor], torch.Tensor] | str | None = None,
+        softmax_temperature: float = 1.0,
     ):
         super().__init__()
+        self.softmax_temperature = softmax_temperature
         self.encoder = encoder
@@ -809,7 +813,10 @@ class GraspCritic(nn.Module):
         # Move each tensor in observations to device by cloning first to avoid inplace operations
         observations = {k: v.to(device) for k, v in observations.items()}
         obs_enc = self.encoder(observations, vision_encoder_cache=observation_features)
-        return self.output_layer(self.net(obs_enc))
+        q_values = self.output_layer(self.net(obs_enc))
+        # Boltzmann (softmax) distribution over the discrete grasp actions, scaled by temperature
+        distribution = Categorical(logits=q_values / self.softmax_temperature)
+        return q_values, distribution
 
 
 class Policy(nn.Module):
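
Note on the behavior change (not part of the patch): the grasp critic now returns both its Q-values and a temperature-scaled softmax (Boltzmann) distribution over them, and the discrete action is sampled from that distribution rather than taken greedily via argmax. The snippet below is a minimal, self-contained sketch of that sampling scheme using `torch.distributions.Categorical`; the batch size, number of discrete actions, Q-values, and temperatures are invented for illustration and do not come from the lerobot code.

```python
import torch
from torch.distributions import Categorical

# Invented example data: 4 batch elements, 3 discrete grasp actions.
q_values = torch.tensor(
    [
        [1.0, 2.0, 0.5],
        [0.1, 0.2, 0.3],
        [3.0, 1.0, 1.0],
        [0.0, 0.0, 5.0],
    ]
)

greedy = torch.argmax(q_values, dim=-1, keepdim=True)  # previous behavior, shape (batch, 1)

for temperature in (0.1, 1.0, 10.0):
    # Dividing the logits by the temperature sharpens (T < 1) or flattens (T > 1) the softmax:
    # T -> 0 recovers greedy argmax, large T approaches a uniform distribution.
    dist = Categorical(logits=q_values / temperature)
    sampled = dist.sample().unsqueeze(-1)  # shape (batch, 1), matching argmax(..., keepdim=True)
    print(f"T={temperature}: probs={dist.probs.tolist()}")
    print(f"  sampled={sampled.squeeze(-1).tolist()}  greedy={greedy.squeeze(-1).tolist()}")
```

Note that in the patch, sampling is used both for action selection and, inside the `no_grad` block of the critic update, for choosing the bootstrap action, so the DQN-style target is evaluated at a sampled action rather than at the argmax.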