Refactor SAC configuration and policy to support discrete actions
- Removed the GraspCriticNetworkConfig class and integrated its parameters into SACConfig.
- Added a num_discrete_actions parameter to SACConfig for better action handling.
- Updated SACPolicy to conditionally create grasp critic networks based on num_discrete_actions.
- Enhanced the grasp critic forward pass to handle discrete actions and compute losses accordingly.
parent fe2ff516a8
commit 6a215f47dd
```diff
--- a/lerobot/common/policies/sac/configuration_sac.py
+++ b/lerobot/common/policies/sac/configuration_sac.py
@@ -42,12 +42,6 @@ class CriticNetworkConfig:
     final_activation: str | None = None
 
 
-@dataclass
-class GraspCriticNetworkConfig:
-    hidden_dims: list[int] = field(default_factory=lambda: [256, 256])
-    activate_final: bool = True
-    final_activation: str | None = None
-    output_dim: int = 3
 
 
 @dataclass
```
```diff
@@ -152,6 +146,7 @@ class SACConfig(PreTrainedConfig):
     freeze_vision_encoder: bool = True
     image_encoder_hidden_dim: int = 32
     shared_encoder: bool = True
+    num_discrete_actions: int | None = None
 
     # Training parameter
     online_steps: int = 1000000
```
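This makes the discrete grasp head opt-in at the config level. A minimal usage sketch (assuming the remaining SACConfig fields keep workable defaults; the value 3 mirrors the removed GraspCriticNetworkConfig's output_dim=3 default):

```python
from lerobot.common.policies.sac.configuration_sac import SACConfig

# Continuous-only SAC: num_discrete_actions stays None and no grasp critic is built.
continuous_cfg = SACConfig()

# Hybrid control: add a 3-way discrete head (e.g. gripper open / stay / close).
hybrid_cfg = SACConfig(num_discrete_actions=3)
```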
```diff
@@ -182,7 +177,7 @@ class SACConfig(PreTrainedConfig):
     critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
     actor_network_kwargs: ActorNetworkConfig = field(default_factory=ActorNetworkConfig)
     policy_kwargs: PolicyConfig = field(default_factory=PolicyConfig)
-    grasp_critic_network_kwargs: GraspCriticNetworkConfig = field(default_factory=GraspCriticNetworkConfig)
+    grasp_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
 
     actor_learner_config: ActorLearnerConfig = field(default_factory=ActorLearnerConfig)
     concurrency: ConcurrencyConfig = field(default_factory=ConcurrencyConfig)
```
```diff
--- a/lerobot/common/policies/sac/modeling_sac.py
+++ b/lerobot/common/policies/sac/modeling_sac.py
@@ -33,6 +33,7 @@ from lerobot.common.policies.pretrained import PreTrainedPolicy
 from lerobot.common.policies.sac.configuration_sac import SACConfig
 from lerobot.common.policies.utils import get_device_from_parameters
 
+DISCRETE_DIMENSION_INDEX = -1
 
 class SACPolicy(
     PreTrainedPolicy,
```
```diff
@@ -113,24 +114,30 @@ class SACPolicy(
         self.critic_ensemble = torch.compile(self.critic_ensemble)
         self.critic_target = torch.compile(self.critic_target)
 
-        # Create grasp critic
-        self.grasp_critic = GraspCritic(
-            encoder=encoder_critic,
-            input_dim=encoder_critic.output_dim,
-            **config.grasp_critic_network_kwargs,
-        )
-
-        # Create target grasp critic
-        self.grasp_critic_target = GraspCritic(
-            encoder=encoder_critic,
-            input_dim=encoder_critic.output_dim,
-            **config.grasp_critic_network_kwargs,
-        )
-
-        self.grasp_critic_target.load_state_dict(self.grasp_critic.state_dict())
-
-        self.grasp_critic = torch.compile(self.grasp_critic)
-        self.grasp_critic_target = torch.compile(self.grasp_critic_target)
+        self.grasp_critic = None
+        self.grasp_critic_target = None
+
+        if config.num_discrete_actions is not None:
+            # Create grasp critic
+            self.grasp_critic = GraspCritic(
+                encoder=encoder_critic,
+                input_dim=encoder_critic.output_dim,
+                output_dim=config.num_discrete_actions,
+                **asdict(config.grasp_critic_network_kwargs),
+            )
+
+            # Create target grasp critic
+            self.grasp_critic_target = GraspCritic(
+                encoder=encoder_critic,
+                input_dim=encoder_critic.output_dim,
+                output_dim=config.num_discrete_actions,
+                **asdict(config.grasp_critic_network_kwargs),
+            )
+
+            self.grasp_critic_target.load_state_dict(self.grasp_critic.state_dict())
+
+            self.grasp_critic = torch.compile(self.grasp_critic)
+            self.grasp_critic_target = torch.compile(self.grasp_critic_target)
 
         self.actor = Policy(
             encoder=encoder_actor,
```
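The pattern here is standard DQN-style initialization: build the online and target networks, hard-sync the target from the online weights, then compile both. A condensed sketch of the conditional construction (make_q_net is a hypothetical stand-in for the GraspCritic constructor call):

```python
import torch.nn as nn

def build_discrete_head(num_discrete_actions, make_q_net):
    """Conditionally build an online/target Q-network pair for the discrete head."""
    if num_discrete_actions is None:
        return None, None  # continuous-only SAC: no grasp critic
    online = make_q_net(num_discrete_actions)
    target = make_q_net(num_discrete_actions)
    target.load_state_dict(online.state_dict())  # hard-sync so both start identical
    return online, target

# Usage with a toy 2-layer MLP head over a 64-dim encoding:
online, target = build_discrete_head(
    3, lambda n: nn.Sequential(nn.Linear(64, 256), nn.ReLU(), nn.Linear(256, n))
)
```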
```diff
@@ -173,6 +180,12 @@ class SACPolicy(
         """Select action for inference/evaluation"""
         actions, _, _ = self.actor(batch)
         actions = self.unnormalize_outputs({"action": actions})["action"]
+
+        if self.config.num_discrete_actions is not None:
+            discrete_action_value = self.grasp_critic(batch)
+            discrete_action = torch.argmax(discrete_action_value, dim=-1)
+            actions = torch.cat([actions, discrete_action], dim=-1)
+
         return actions
 
     def critic_forward(
```
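One thing to note: torch.argmax(..., dim=-1) returns a tensor of shape (batch,), while the continuous actions are (batch, action_dim), so the concatenation would need the index unsqueezed to (batch, 1). A sketch of the hybrid action assembly this hunk appears to intend (function name and shapes are illustrative):

```python
import torch

def assemble_hybrid_action(continuous: torch.Tensor, grasp_q_values: torch.Tensor) -> torch.Tensor:
    """Append the greedy discrete action as an extra column on the continuous actions.

    continuous:     (batch, action_dim) commands from the SAC actor.
    grasp_q_values: (batch, num_discrete_actions) Q-values from the grasp critic.
    """
    discrete = torch.argmax(grasp_q_values, dim=-1)         # (batch,)
    discrete = discrete.unsqueeze(-1).to(continuous.dtype)  # (batch, 1), dtype must match for cat
    return torch.cat([continuous, discrete], dim=-1)        # (batch, action_dim + 1)

# Example: batch of 4, 6-DoF arm plus a 3-way gripper head.
assert assemble_hybrid_action(torch.randn(4, 6), torch.randn(4, 3)).shape == (4, 7)
```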
```diff
@@ -192,11 +205,12 @@ class SACPolicy(
         Returns:
             Tensor of Q-values from all critics
         """
+
         critics = self.critic_target if use_target else self.critic_ensemble
         q_values = critics(observations, actions, observation_features)
         return q_values
 
-    def grasp_critic_forward(self, observations, use_target=False, observation_features=None):
+    def grasp_critic_forward(self, observations, use_target=False, observation_features=None) -> torch.Tensor:
         """Forward pass through a grasp critic network
 
         Args:
```
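The body of grasp_critic_forward is truncated in this view; by analogy with critic_forward just above, it presumably routes to the online or target grasp critic. An assumed sketch (the exact GraspCritic call signature is a guess):

```python
def grasp_critic_forward(self, observations, use_target=False, observation_features=None) -> torch.Tensor:
    """Assumed body: pick the online or target grasp critic and return its
    (batch, num_discrete_actions) Q-values, mirroring critic_forward above."""
    grasp_critic = self.grasp_critic_target if use_target else self.grasp_critic
    return grasp_critic(observations, observation_features)
```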
```diff
@@ -256,9 +270,6 @@ class SACPolicy(
         )
 
         if model == "grasp_critic":
-            # Extract grasp_critic-specific components
-            complementary_info: dict[str, Tensor] = batch["complementary_info"]
-
             return self.compute_loss_grasp_critic(
                 observations=observations,
                 actions=actions,
```
```diff
@@ -267,7 +278,6 @@ class SACPolicy(
                 done=done,
                 observation_features=observation_features,
                 next_observation_features=next_observation_features,
-                complementary_info=complementary_info,
             )
 
         if model == "actor":
```
```diff
@@ -349,6 +359,12 @@ class SACPolicy(
         td_target = rewards + (1 - done) * self.config.discount * min_q
 
         # 3- compute predicted qs
+        if self.config.num_discrete_actions is not None:
+            # NOTE: We only want to keep the continuous action part
+            # In the buffer we have the full action space (continuous + discrete)
+            # We need to split them before concatenating them in the critic forward
+            actions: Tensor = actions[:, :DISCRETE_DIMENSION_INDEX]
+
         q_preds = self.critic_forward(
             observations=observations,
             actions=actions,
```
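With DISCRETE_DIMENSION_INDEX = -1, replayed hybrid actions split by plain slicing: the continuous critics see every column but the last, and the grasp critic loss sees only the last. A small illustration (6 continuous dims plus 1 discrete command are assumed shapes):

```python
import torch

DISCRETE_DIMENSION_INDEX = -1  # the discrete command is stored as the last action column

# A replayed batch of hybrid actions: 6 continuous dims + 1 discrete gripper command.
actions = torch.tensor([[0.1, -0.2, 0.0, 0.3, 0.5, -0.1, 2.0],
                        [0.0, 0.4, 0.1, -0.3, 0.2, 0.6, 0.0]])

continuous_part = actions[:, :DISCRETE_DIMENSION_INDEX]       # (2, 6) -> critic ensemble
discrete_part = actions[:, DISCRETE_DIMENSION_INDEX:].long()  # (2, 1) -> grasp critic loss
```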
```diff
@@ -378,30 +394,43 @@ class SACPolicy(
         done,
         observation_features=None,
         next_observation_features=None,
-        complementary_info=None,
     ):
-        batch_size = rewards.shape[0]
-        grasp_actions = torch.clip(actions[:, -1].long() + 1, 0, 2)  # Map [-1, 0, 1] -> [0, 1, 2]
+        # NOTE: We only want to keep the discrete action part
+        # In the buffer we have the full action space (continuous + discrete)
+        # We need to split them before concatenating them in the critic forward
+        actions: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:]
+        actions = actions.long()
 
         with torch.no_grad():
             # For DQN, select actions using online network, evaluate with target network
             next_grasp_qs = self.grasp_critic_forward(next_observations, use_target=False)
             best_next_grasp_action = torch.argmax(next_grasp_qs, dim=-1)
 
-            target_next_grasp_qs = self.grasp_critic_forward(next_observations, use_target=True)
-            target_next_grasp_q = target_next_grasp_qs[torch.arange(batch_size), best_next_grasp_action]
+            # Get target Q-values from target network
+            target_next_grasp_qs = self.grasp_critic_forward(observations=next_observations, use_target=True)
+
+            # Use gather to select Q-values for best actions
+            target_next_grasp_q = torch.gather(
+                target_next_grasp_qs,
+                dim=1,
+                index=best_next_grasp_action.unsqueeze(-1),
+            ).squeeze(-1)
 
-        # Get the grasp penalty from complementary_info
-        grasp_penalty = torch.zeros_like(rewards)
-        if complementary_info is not None and "grasp_penalty" in complementary_info:
-            grasp_penalty = complementary_info["grasp_penalty"]
-
-        grasp_rewards = rewards + grasp_penalty
-        target_grasp_q = grasp_rewards + (1 - done) * self.config.discount * target_next_grasp_q
-
-        predicted_grasp_qs = self.grasp_critic_forward(observations, use_target=False)
-        predicted_grasp_q = predicted_grasp_qs[torch.arange(batch_size), grasp_actions]
+        # Compute target Q-value with Bellman equation
+        target_grasp_q = rewards + (1 - done) * self.config.discount * target_next_grasp_q
+
+        # Get predicted Q-values for current observations
+        predicted_grasp_qs = self.grasp_critic_forward(observations=observations, use_target=False)
+
+        # Use gather to select Q-values for taken actions
+        predicted_grasp_q = torch.gather(
+            predicted_grasp_qs,
+            dim=1,
+            index=actions.unsqueeze(-1),
+        ).squeeze(-1)
 
-        grasp_critic_loss = F.mse_loss(input=predicted_grasp_q, target=target_grasp_q, reduction="mean")
+        # Compute MSE loss between predicted and target Q-values
+        grasp_critic_loss = F.mse_loss(input=predicted_grasp_q, target=target_grasp_q)
         return grasp_critic_loss
 
     def compute_loss_temperature(self, observations, observation_features: Tensor | None = None) -> Tensor:
```
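The rewritten target is the Double DQN pattern: the online network picks the best next action and the target network evaluates it, which curbs Q-value overestimation compared with a plain max over the target network. A standalone sketch of the computation (shapes and names are illustrative):

```python
import torch

def double_dqn_target(next_qs_online, next_qs_target, rewards, done, discount):
    """Double DQN target: argmax with the online net, evaluate with the target net.

    next_qs_online / next_qs_target: (batch, num_actions) Q-values at the next state.
    rewards, done: (batch,) tensors; discount: scalar gamma.
    """
    best_next = torch.argmax(next_qs_online, dim=-1)                      # (batch,)
    evaluated = torch.gather(next_qs_target, 1, best_next.unsqueeze(-1))  # (batch, 1)
    return rewards + (1.0 - done) * discount * evaluated.squeeze(-1)      # (batch,)

# Example: batch of 2, 3 discrete actions.
target = double_dqn_target(
    torch.randn(2, 3), torch.randn(2, 3),
    rewards=torch.tensor([1.0, 0.0]), done=torch.tensor([0.0, 1.0]), discount=0.99,
)
```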
```diff
@@ -611,7 +640,7 @@ class GraspCritic(nn.Module):
         self,
         encoder: Optional[nn.Module],
         network: nn.Module,
-        output_dim: int = 3,
+        output_dim: int = 3,  # TODO (azouitine): rename to num_discrete_actions or similar
         init_final: Optional[float] = None,
         encoder_is_shared: bool = False,
     ):
```