Enhance SACPolicy and learner server for improved grasp critic integration

- Updated SACPolicy to conditionally compute the grasp critic loss based on whether discrete actions are configured.
- Refactored the forward method so that grasp critic model selection and loss computation are handled in their own branch (a usage sketch follows the commit metadata below).
- Adjusted the learner server to optimize and clip only the grasp critic's parameters_to_optimize during training.
- Improved action handling in ManiskillMockGripperWrapper to accept both tuple and single (flat) action inputs.
Author: AdilZouitine, 2025-04-02 15:50:39 +00:00 (committed by Adil Zouitine)
parent 077d18b439
commit a54baceabb
3 changed files with 72 additions and 50 deletions
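
For context, a minimal sketch of how the learner loop is expected to drive the new per-model forward dispatch after this change. The names (policy, forward_batch, optimizers, clip_grad_norm_value) are the learner server's local variables as they appear in the diff below; the surrounding loop is condensed and illustrative, not an exact excerpt.

    import torch

    # Continuous critic update (condensed).
    critic_output = policy.forward(forward_batch, model="critic")
    optimizers["critic"].zero_grad()
    critic_output["loss_critic"].backward()
    optimizers["critic"].step()

    # The grasp critic now gets its own forward pass, gated on the config
    # rather than on the presence of a loss key in the critic output.
    if policy.config.num_discrete_actions is not None:
        discrete_critic_output = policy.forward(forward_batch, model="grasp_critic")
        loss_grasp_critic = discrete_critic_output["loss_grasp_critic"]
        optimizers["grasp_critic"].zero_grad()
        loss_grasp_critic.backward()
        torch.nn.utils.clip_grad_norm_(
            policy.grasp_critic.parameters_to_optimize, max_norm=clip_grad_norm_value
        )
        optimizers["grasp_critic"].step()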

File 1 of 3: SAC policy model (SACPolicy and GraspCritic, lerobot/common/policies/sac/modeling_sac.py)

@@ -52,8 +52,6 @@ class SACPolicy(
         self.config = config
         continuous_action_dim = config.output_features["action"].shape[0]
-        if config.num_discrete_actions is not None:
-            continuous_action_dim -= 1

         if config.dataset_stats is not None:
             input_normalization_params = _convert_normalization_params_to_tensor(config.dataset_stats)
@@ -191,7 +189,7 @@ class SACPolicy(
         if self.config.num_discrete_actions is not None:
             discrete_action_value = self.grasp_critic(batch)
-            discrete_action = torch.argmax(discrete_action_value, dim=-1)
+            discrete_action = torch.argmax(discrete_action_value, dim=-1, keepdim=True)
             actions = torch.cat([actions, discrete_action], dim=-1)

         return actions
@@ -236,7 +234,7 @@ class SACPolicy(
     def forward(
         self,
         batch: dict[str, Tensor | dict[str, Tensor]],
-        model: Literal["actor", "critic", "temperature"] = "critic",
+        model: Literal["actor", "critic", "temperature", "grasp_critic"] = "critic",
     ) -> dict[str, Tensor]:
         """Compute the loss for the given model
@ -275,18 +273,25 @@ class SACPolicy(
observation_features=observation_features, observation_features=observation_features,
next_observation_features=next_observation_features, next_observation_features=next_observation_features,
) )
if self.config.num_discrete_actions is not None:
return {"loss_critic": loss_critic}
if model == "grasp_critic" and self.config.num_discrete_actions is not None:
# Extract critic-specific components
rewards: Tensor = batch["reward"]
next_observations: dict[str, Tensor] = batch["next_state"]
done: Tensor = batch["done"]
next_observation_features: Tensor = batch.get("next_observation_feature")
loss_grasp_critic = self.compute_loss_grasp_critic( loss_grasp_critic = self.compute_loss_grasp_critic(
observations=observations, observations=observations,
actions=actions, actions=actions,
rewards=rewards, rewards=rewards,
next_observations=next_observations, next_observations=next_observations,
done=done, done=done,
observation_features=observation_features,
next_observation_features=next_observation_features,
) )
return {"loss_critic": loss_critic, "loss_grasp_critic": loss_grasp_critic} return {"loss_grasp_critic": loss_grasp_critic}
return {"loss_critic": loss_critic}
if model == "actor": if model == "actor":
return { return {
"loss_actor": self.compute_loss_actor( "loss_actor": self.compute_loss_actor(
@@ -373,7 +378,6 @@ class SACPolicy(
             # In the buffer we have the full action space (continuous + discrete)
             # We need to split them before concatenating them in the critic forward
             actions: Tensor = actions[:, :DISCRETE_DIMENSION_INDEX]
-
         q_preds = self.critic_forward(
             observations=observations,
             actions=actions,
@@ -407,30 +411,38 @@ class SACPolicy(
         # NOTE: We only want to keep the discrete action part
         # In the buffer we have the full action space (continuous + discrete)
         # We need to split them before concatenating them in the critic forward
-        actions: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:]
-        actions = actions.long()
+        actions_discrete: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:].clone()
+        actions_discrete = actions_discrete.long()

         with torch.no_grad():
             # For DQN, select actions using online network, evaluate with target network
-            next_grasp_qs = self.grasp_critic_forward(next_observations, use_target=False)
-            best_next_grasp_action = torch.argmax(next_grasp_qs, dim=-1)
+            next_grasp_qs = self.grasp_critic_forward(
+                next_observations, use_target=False, observation_features=next_observation_features
+            )
+            best_next_grasp_action = torch.argmax(next_grasp_qs, dim=-1, keepdim=True)

             # Get target Q-values from target network
-            target_next_grasp_qs = self.grasp_critic_forward(observations=next_observations, use_target=True)
+            target_next_grasp_qs = self.grasp_critic_forward(
+                observations=next_observations,
+                use_target=True,
+                observation_features=next_observation_features,
+            )

             # Use gather to select Q-values for best actions
             target_next_grasp_q = torch.gather(
-                target_next_grasp_qs, dim=1, index=best_next_grasp_action.unsqueeze(-1)
+                target_next_grasp_qs, dim=1, index=best_next_grasp_action
             ).squeeze(-1)
             # Compute target Q-value with Bellman equation
             target_grasp_q = rewards + (1 - done) * self.config.discount * target_next_grasp_q

         # Get predicted Q-values for current observations
-        predicted_grasp_qs = self.grasp_critic_forward(observations=observations, use_target=False)
+        predicted_grasp_qs = self.grasp_critic_forward(
+            observations=observations, use_target=False, observation_features=observation_features
+        )

         # Use gather to select Q-values for taken actions
-        predicted_grasp_q = torch.gather(predicted_grasp_qs, dim=1, index=actions.unsqueeze(-1)).squeeze(-1)
+        predicted_grasp_q = torch.gather(predicted_grasp_qs, dim=1, index=actions_discrete).squeeze(-1)

         # Compute MSE loss between predicted and target Q-values
         grasp_critic_loss = F.mse_loss(input=predicted_grasp_q, target=target_grasp_q)
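
In plain terms, the hunk above is a double-DQN style update for the discrete gripper head: the online grasp critic picks the best next action, the target network scores it, and the Bellman target is regressed with MSE. A self-contained sketch with random tensors (shapes and names here are illustrative only, not the repo's code):

    import torch
    import torch.nn.functional as F

    batch_size, num_discrete_actions, discount = 8, 3, 0.99
    rewards = torch.randn(batch_size)
    done = torch.zeros(batch_size)
    next_qs_online = torch.randn(batch_size, num_discrete_actions)   # online net on next observations
    next_qs_target = torch.randn(batch_size, num_discrete_actions)   # target net on next observations
    pred_qs = torch.randn(batch_size, num_discrete_actions, requires_grad=True)  # online net on current observations
    actions_discrete = torch.randint(0, num_discrete_actions, (batch_size, 1))   # discrete actions taken

    with torch.no_grad():
        # Select the argmax action with the online network, evaluate it with the target network
        best_next = torch.argmax(next_qs_online, dim=-1, keepdim=True)
        target_next_q = torch.gather(next_qs_target, dim=1, index=best_next).squeeze(-1)
        target_q = rewards + (1 - done) * discount * target_next_q

    predicted_q = torch.gather(pred_qs, dim=1, index=actions_discrete).squeeze(-1)
    loss = F.mse_loss(predicted_q, target_q)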
@@ -642,49 +654,52 @@ class GraspCritic(nn.Module):
     def __init__(
         self,
         encoder: Optional[nn.Module],
-        network: nn.Module,
-        output_dim: int = 3,  # TODO (azouitine): rename it number of discret acitons smth like that
+        input_dim: int,
+        hidden_dims: list[int],
+        output_dim: int = 3,
+        activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(),
+        activate_final: bool = False,
+        dropout_rate: Optional[float] = None,
         init_final: Optional[float] = None,
-        encoder_is_shared: bool = False,
+        final_activation: Callable[[torch.Tensor], torch.Tensor] | str | None = None,
     ):
         super().__init__()
         self.encoder = encoder
-        self.network = network
         self.output_dim = output_dim

-        # Find the last Linear layer's output dimension
-        for layer in reversed(network.net):
-            if isinstance(layer, nn.Linear):
-                out_features = layer.out_features
-                break
-
-        self.parameters_to_optimize += list(self.network.parameters())
-        if self.encoder is not None and not encoder_is_shared:
-            self.parameters_to_optimize += list(self.encoder.parameters())
-
-        self.output_layer = nn.Linear(in_features=out_features, out_features=self.output_dim)
+        self.net = MLP(
+            input_dim=input_dim,
+            hidden_dims=hidden_dims,
+            activations=activations,
+            activate_final=activate_final,
+            dropout_rate=dropout_rate,
+            final_activation=final_activation,
+        )
+
+        self.output_layer = nn.Linear(in_features=hidden_dims[-1], out_features=self.output_dim)
         if init_final is not None:
             nn.init.uniform_(self.output_layer.weight, -init_final, init_final)
             nn.init.uniform_(self.output_layer.bias, -init_final, init_final)
         else:
             orthogonal_init()(self.output_layer.weight)

+        self.parameters_to_optimize = []
+        self.parameters_to_optimize += list(self.net.parameters())
         self.parameters_to_optimize += list(self.output_layer.parameters())

     def forward(
         self, observations: torch.Tensor, observation_features: torch.Tensor | None = None
     ) -> torch.Tensor:
         device = get_device_from_parameters(self)
-        # Move each tensor in observations to device
+        # Move each tensor in observations to device by cloning first to avoid inplace operations
         observations = {k: v.to(device) for k, v in observations.items()}
         # Encode observations if encoder exists
         obs_enc = (
-            observation_features
+            observation_features.to(device)
             if observation_features is not None
             else (observations if self.encoder is None else self.encoder(observations))
         )
-        return self.output_layer(self.network(obs_enc))
+        return self.output_layer(self.net(obs_enc))


 class Policy(nn.Module):
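
The GraspCritic refactor above swaps an externally supplied network for an internally built MLP trunk plus a small linear head, and collects only the trunk and head in parameters_to_optimize, so a shared encoder is not updated or clipped through the grasp critic. A rough, self-contained approximation of that structure in plain PyTorch (the repo's actual MLP helper, encoder handling, and initialization are not reproduced here):

    import torch
    from torch import nn

    class TinyGraspCritic(nn.Module):
        """Illustrative stand-in for the refactored GraspCritic, not the repo implementation."""

        def __init__(self, input_dim: int, hidden_dims: list, output_dim: int = 3):
            super().__init__()
            layers, last = [], input_dim
            for h in hidden_dims:
                layers += [nn.Linear(last, h), nn.SiLU()]
                last = h
            self.net = nn.Sequential(*layers)                      # MLP trunk
            self.output_layer = nn.Linear(last, output_dim)        # one Q-value per discrete action
            # Only the trunk and head are handed to the optimizer; a shared encoder would be left out.
            self.parameters_to_optimize = list(self.net.parameters()) + list(self.output_layer.parameters())

        def forward(self, obs_enc: torch.Tensor) -> torch.Tensor:
            return self.output_layer(self.net(obs_enc))

    critic = TinyGraspCritic(input_dim=64, hidden_dims=[256, 256], output_dim=3)
    optimizer = torch.optim.Adam(critic.parameters_to_optimize, lr=3e-4)
    q_values = critic(torch.randn(8, 64))  # shape (8, 3): one Q-value per discrete gripper action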

File 2 of 3: learner server (add_actor_information_and_train and make_optimizers_and_scheduler)

@@ -405,12 +405,13 @@ def add_actor_information_and_train(
             optimizers["critic"].step()

             # Grasp critic optimization (if available)
-            if "loss_grasp_critic" in critic_output:
-                loss_grasp_critic = critic_output["loss_grasp_critic"]
+            if policy.config.num_discrete_actions is not None:
+                discrete_critic_output = policy.forward(forward_batch, model="grasp_critic")
+                loss_grasp_critic = discrete_critic_output["loss_grasp_critic"]
                 optimizers["grasp_critic"].zero_grad()
                 loss_grasp_critic.backward()
                 grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
-                    parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value
+                    parameters=policy.grasp_critic.parameters_to_optimize, max_norm=clip_grad_norm_value
                 )
                 optimizers["grasp_critic"].step()
@@ -467,12 +468,13 @@ def add_actor_information_and_train(
             }

             # Grasp critic optimization (if available)
-            if "loss_grasp_critic" in critic_output:
-                loss_grasp_critic = critic_output["loss_grasp_critic"]
+            if policy.config.num_discrete_actions is not None:
+                discrete_critic_output = policy.forward(forward_batch, model="grasp_critic")
+                loss_grasp_critic = discrete_critic_output["loss_grasp_critic"]
                 optimizers["grasp_critic"].zero_grad()
                 loss_grasp_critic.backward()
                 grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
-                    parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value
+                    parameters=policy.grasp_critic.parameters_to_optimize, max_norm=clip_grad_norm_value
                 ).item()
                 optimizers["grasp_critic"].step()
@@ -759,7 +761,7 @@ def make_optimizers_and_scheduler(cfg: TrainPipelineConfig, policy: nn.Module):
     if cfg.policy.num_discrete_actions is not None:
         optimizer_grasp_critic = torch.optim.Adam(
-            params=policy.grasp_critic.parameters(), lr=policy.critic_lr
+            params=policy.grasp_critic.parameters_to_optimize, lr=cfg.policy.critic_lr
         )
     optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr)
     lr_scheduler = None
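
Reading the last hunk together with the training-loop hunks above: the grasp-critic optimizer is now built over parameters_to_optimize with the critic learning rate from the config, and the same parameter list is used for gradient clipping. A condensed sketch keeping only what the diff shows (actor/critic optimizer entries are omitted here):

    import torch

    optimizers = {}  # actor/critic entries omitted; only the parts visible in this diff are shown
    if cfg.policy.num_discrete_actions is not None:
        optimizers["grasp_critic"] = torch.optim.Adam(
            params=policy.grasp_critic.parameters_to_optimize, lr=cfg.policy.critic_lr
        )
    optimizers["temperature"] = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr)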

File 3 of 3: ManiSkill environment wrappers (preprocess_maniskill_observation and ManiskillMockGripperWrapper)

@@ -16,7 +16,6 @@ from lerobot.common.policies.sac.configuration_sac import SACConfig
 from lerobot.common.policies.sac.modeling_sac import SACPolicy

-
 def preprocess_maniskill_observation(
     observations: dict[str, np.ndarray],
 ) -> dict[str, torch.Tensor]:
@@ -156,6 +155,7 @@ class TimeLimitWrapper(gym.Wrapper):
         self.current_step = 0
         return super().reset(seed=seed, options=options)

+
 class ManiskillMockGripperWrapper(gym.Wrapper):
     def __init__(self, env, nb_discrete_actions: int = 3):
         super().__init__(env)
@@ -166,12 +166,17 @@ class ManiskillMockGripperWrapper(gym.Wrapper):
         self.action_space = gym.spaces.Tuple((action_space_agent, env.action_space[1]))

     def step(self, action):
-        action_agent, telop_action = action
+        if isinstance(action, tuple):
+            action_agent, telop_action = action
+        else:
+            telop_action = 0
+            action_agent = action
+
         real_action = action_agent[:-1]
         final_action = (real_action, telop_action)
         obs, reward, terminated, truncated, info = self.env.step(final_action)
         return obs, reward, terminated, truncated, info

 def make_maniskill(
     cfg: ManiskillEnvConfig,
     n_envs: int | None = None,
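
The wrapper change means the env can now be stepped either with the original (agent_action, teleop_action) tuple or with a single flat action whose last dimension is the discrete gripper component. A standalone sketch of just that action-splitting logic (the array sizes and the helper name are illustrative, not part of the repo):

    import numpy as np

    def split_mock_gripper_action(action):
        """Mirrors the tuple-vs-flat handling in ManiskillMockGripperWrapper.step()."""
        if isinstance(action, tuple):
            action_agent, telop_action = action
        else:
            telop_action = 0              # no teleoperation signal provided
            action_agent = action
        real_action = action_agent[:-1]   # strip the discrete gripper component
        return real_action, telop_action

    flat = np.zeros(7)                    # e.g. 6 continuous dims + 1 discrete gripper dim
    print(split_mock_gripper_action(flat))        # (6 continuous values, 0)
    print(split_mock_gripper_action((flat, 2)))   # (6 continuous values, 2)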