Enhance SACPolicy and learner server for improved grasp critic integration
- Updated SACPolicy to conditionally compute the grasp critic loss based on the presence of discrete actions.
- Refactored the forward method to handle grasp critic model selection and loss computation more clearly.
- Adjusted the learner server to use the grasp critic's `parameters_to_optimize` list during training.
- Improved action handling in ManiskillMockGripperWrapper to accommodate both tuple and single action inputs.
parent 077d18b439
commit a54baceabb
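For reviewers, a minimal runnable sketch of the new learner-side flow for the discrete grasp critic. `ToyGraspPolicy` and its toy loss are illustrative stand-ins, not part of this commit; only the call pattern — `policy.forward(batch, model="grasp_critic")` followed by gradient clipping over `grasp_critic.parameters_to_optimize` — mirrors the diff below.

    # Hedged sketch: mirrors the learner-side grasp-critic update introduced below.
    # ToyGraspPolicy is a stand-in; only the call pattern matches the diff.
    import torch
    import torch.nn as nn
    import torch.nn.functional as F


    class ToyGraspPolicy(nn.Module):
        def __init__(self, obs_dim: int = 8, num_discrete_actions: int = 3):
            super().__init__()
            self.num_discrete_actions = num_discrete_actions
            self.grasp_critic = nn.Linear(obs_dim, num_discrete_actions)
            # The real GraspCritic exposes `parameters_to_optimize`; emulate that here.
            self.grasp_critic.parameters_to_optimize = list(self.grasp_critic.parameters())

        def forward(self, batch: dict, model: str = "critic") -> dict:
            if model == "grasp_critic" and self.num_discrete_actions is not None:
                q_values = self.grasp_critic(batch["state"])                # (B, num_discrete_actions)
                target = batch["reward"].unsqueeze(-1).expand_as(q_values)  # toy target, not the DQN target
                return {"loss_grasp_critic": F.mse_loss(q_values, target)}
            return {}


    policy = ToyGraspPolicy()
    optimizer_grasp_critic = torch.optim.Adam(params=policy.grasp_critic.parameters_to_optimize, lr=3e-4)
    forward_batch = {"state": torch.randn(4, 8), "reward": torch.rand(4)}

    # Same control flow as the learner server after this commit:
    if policy.num_discrete_actions is not None:
        discrete_critic_output = policy.forward(forward_batch, model="grasp_critic")
        loss_grasp_critic = discrete_critic_output["loss_grasp_critic"]
        optimizer_grasp_critic.zero_grad()
        loss_grasp_critic.backward()
        torch.nn.utils.clip_grad_norm_(policy.grasp_critic.parameters_to_optimize, max_norm=10.0)
        optimizer_grasp_critic.step()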
@@ -52,8 +52,6 @@ class SACPolicy(
         self.config = config
 
         continuous_action_dim = config.output_features["action"].shape[0]
-        if config.num_discrete_actions is not None:
-            continuous_action_dim -= 1
 
         if config.dataset_stats is not None:
             input_normalization_params = _convert_normalization_params_to_tensor(config.dataset_stats)
@@ -191,7 +189,7 @@ class SACPolicy(
 
         if self.config.num_discrete_actions is not None:
             discrete_action_value = self.grasp_critic(batch)
-            discrete_action = torch.argmax(discrete_action_value, dim=-1)
+            discrete_action = torch.argmax(discrete_action_value, dim=-1, keepdim=True)
             actions = torch.cat([actions, discrete_action], dim=-1)
 
         return actions
@@ -236,7 +234,7 @@ class SACPolicy(
     def forward(
         self,
         batch: dict[str, Tensor | dict[str, Tensor]],
-        model: Literal["actor", "critic", "temperature"] = "critic",
+        model: Literal["actor", "critic", "temperature", "grasp_critic"] = "critic",
     ) -> dict[str, Tensor]:
         """Compute the loss for the given model
 
@@ -275,18 +273,25 @@ class SACPolicy(
                 observation_features=observation_features,
                 next_observation_features=next_observation_features,
             )
-            if self.config.num_discrete_actions is not None:
+            return {"loss_critic": loss_critic}
+
+        if model == "grasp_critic" and self.config.num_discrete_actions is not None:
+            # Extract critic-specific components
+            rewards: Tensor = batch["reward"]
+            next_observations: dict[str, Tensor] = batch["next_state"]
+            done: Tensor = batch["done"]
+            next_observation_features: Tensor = batch.get("next_observation_feature")
             loss_grasp_critic = self.compute_loss_grasp_critic(
                 observations=observations,
                 actions=actions,
                 rewards=rewards,
                 next_observations=next_observations,
                 done=done,
+                observation_features=observation_features,
+                next_observation_features=next_observation_features,
             )
-            return {"loss_critic": loss_critic, "loss_grasp_critic": loss_grasp_critic}
-
-            return {"loss_critic": loss_critic}
+            return {"loss_grasp_critic": loss_grasp_critic}
 
         if model == "actor":
             return {
                 "loss_actor": self.compute_loss_actor(
@@ -373,7 +378,6 @@ class SACPolicy(
         # In the buffer we have the full action space (continuous + discrete)
         # We need to split them before concatenating them in the critic forward
         actions: Tensor = actions[:, :DISCRETE_DIMENSION_INDEX]
-
         q_preds = self.critic_forward(
             observations=observations,
             actions=actions,
@@ -407,30 +411,38 @@ class SACPolicy(
         # NOTE: We only want to keep the discrete action part
         # In the buffer we have the full action space (continuous + discrete)
         # We need to split them before concatenating them in the critic forward
-        actions: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:]
-        actions = actions.long()
+        actions_discrete: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:].clone()
+        actions_discrete = actions_discrete.long()
 
         with torch.no_grad():
             # For DQN, select actions using online network, evaluate with target network
-            next_grasp_qs = self.grasp_critic_forward(next_observations, use_target=False)
-            best_next_grasp_action = torch.argmax(next_grasp_qs, dim=-1)
+            next_grasp_qs = self.grasp_critic_forward(
+                next_observations, use_target=False, observation_features=next_observation_features
+            )
+            best_next_grasp_action = torch.argmax(next_grasp_qs, dim=-1, keepdim=True)
 
             # Get target Q-values from target network
-            target_next_grasp_qs = self.grasp_critic_forward(observations=next_observations, use_target=True)
+            target_next_grasp_qs = self.grasp_critic_forward(
+                observations=next_observations,
+                use_target=True,
+                observation_features=next_observation_features,
+            )
 
             # Use gather to select Q-values for best actions
             target_next_grasp_q = torch.gather(
-                target_next_grasp_qs, dim=1, index=best_next_grasp_action.unsqueeze(-1)
+                target_next_grasp_qs, dim=1, index=best_next_grasp_action
             ).squeeze(-1)
 
             # Compute target Q-value with Bellman equation
             target_grasp_q = rewards + (1 - done) * self.config.discount * target_next_grasp_q
 
         # Get predicted Q-values for current observations
-        predicted_grasp_qs = self.grasp_critic_forward(observations=observations, use_target=False)
+        predicted_grasp_qs = self.grasp_critic_forward(
+            observations=observations, use_target=False, observation_features=observation_features
+        )
 
         # Use gather to select Q-values for taken actions
-        predicted_grasp_q = torch.gather(predicted_grasp_qs, dim=1, index=actions.unsqueeze(-1)).squeeze(-1)
+        predicted_grasp_q = torch.gather(predicted_grasp_qs, dim=1, index=actions_discrete).squeeze(-1)
 
         # Compute MSE loss between predicted and target Q-values
         grasp_critic_loss = F.mse_loss(input=predicted_grasp_q, target=target_grasp_q)
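The `keepdim=True` / `gather` change above can be checked in isolation; the following is a standalone toy example with plain tensors, not LeRobot code:

    import torch

    q_values = torch.tensor([[0.1, 0.9, 0.3],
                             [0.7, 0.2, 0.4]])  # (batch=2, num_discrete_actions=3)

    # keepdim=True keeps the action index as a column vector, the shape gather expects.
    best_action = torch.argmax(q_values, dim=-1, keepdim=True)              # (2, 1), e.g. [[1], [0]]
    best_q = torch.gather(q_values, dim=1, index=best_action).squeeze(-1)   # (2,), the max Q per row

    # Without keepdim, an extra unsqueeze(-1) would be needed before gather,
    # which is exactly the pattern the diff above removes.
    assert best_action.shape == (2, 1) and best_q.shape == (2,)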
@@ -642,49 +654,52 @@ class GraspCritic(nn.Module):
     def __init__(
         self,
         encoder: Optional[nn.Module],
-        network: nn.Module,
-        output_dim: int = 3,  # TODO (azouitine): rename it number of discret acitons smth like that
+        input_dim: int,
+        hidden_dims: list[int],
+        output_dim: int = 3,
+        activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(),
+        activate_final: bool = False,
+        dropout_rate: Optional[float] = None,
         init_final: Optional[float] = None,
-        encoder_is_shared: bool = False,
+        final_activation: Callable[[torch.Tensor], torch.Tensor] | str | None = None,
     ):
         super().__init__()
         self.encoder = encoder
-        self.network = network
         self.output_dim = output_dim
 
-        # Find the last Linear layer's output dimension
-        for layer in reversed(network.net):
-            if isinstance(layer, nn.Linear):
-                out_features = layer.out_features
-                break
+        self.net = MLP(
+            input_dim=input_dim,
+            hidden_dims=hidden_dims,
+            activations=activations,
+            activate_final=activate_final,
+            dropout_rate=dropout_rate,
+            final_activation=final_activation,
+        )
 
-        self.parameters_to_optimize += list(self.network.parameters())
+        self.output_layer = nn.Linear(in_features=hidden_dims[-1], out_features=self.output_dim)
 
-        if self.encoder is not None and not encoder_is_shared:
-            self.parameters_to_optimize += list(self.encoder.parameters())
-
-        self.output_layer = nn.Linear(in_features=out_features, out_features=self.output_dim)
         if init_final is not None:
             nn.init.uniform_(self.output_layer.weight, -init_final, init_final)
             nn.init.uniform_(self.output_layer.bias, -init_final, init_final)
         else:
             orthogonal_init()(self.output_layer.weight)
 
+        self.parameters_to_optimize = []
+        self.parameters_to_optimize += list(self.net.parameters())
         self.parameters_to_optimize += list(self.output_layer.parameters())
 
     def forward(
         self, observations: torch.Tensor, observation_features: torch.Tensor | None = None
     ) -> torch.Tensor:
         device = get_device_from_parameters(self)
-        # Move each tensor in observations to device
+        # Move each tensor in observations to device by cloning first to avoid inplace operations
         observations = {k: v.to(device) for k, v in observations.items()}
         # Encode observations if encoder exists
         obs_enc = (
-            observation_features
+            observation_features.to(device)
             if observation_features is not None
             else (observations if self.encoder is None else self.encoder(observations))
         )
-        return self.output_layer(self.network(obs_enc))
+        return self.output_layer(self.net(obs_enc))
 
 
 class Policy(nn.Module):
@@ -405,12 +405,13 @@ def add_actor_information_and_train(
     optimizers["critic"].step()
 
     # Grasp critic optimization (if available)
-    if "loss_grasp_critic" in critic_output:
-        loss_grasp_critic = critic_output["loss_grasp_critic"]
+    if policy.config.num_discrete_actions is not None:
+        discrete_critic_output = policy.forward(forward_batch, model="grasp_critic")
+        loss_grasp_critic = discrete_critic_output["loss_grasp_critic"]
         optimizers["grasp_critic"].zero_grad()
         loss_grasp_critic.backward()
         grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
-            parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value
+            parameters=policy.grasp_critic.parameters_to_optimize, max_norm=clip_grad_norm_value
         )
         optimizers["grasp_critic"].step()
 
@@ -467,12 +468,13 @@ def add_actor_information_and_train(
     }
 
     # Grasp critic optimization (if available)
-    if "loss_grasp_critic" in critic_output:
-        loss_grasp_critic = critic_output["loss_grasp_critic"]
+    if policy.config.num_discrete_actions is not None:
+        discrete_critic_output = policy.forward(forward_batch, model="grasp_critic")
+        loss_grasp_critic = discrete_critic_output["loss_grasp_critic"]
         optimizers["grasp_critic"].zero_grad()
         loss_grasp_critic.backward()
         grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
-            parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value
+            parameters=policy.grasp_critic.parameters_to_optimize, max_norm=clip_grad_norm_value
         ).item()
         optimizers["grasp_critic"].step()
 
@@ -759,7 +761,7 @@ def make_optimizers_and_scheduler(cfg: TrainPipelineConfig, policy: nn.Module):
 
     if cfg.policy.num_discrete_actions is not None:
         optimizer_grasp_critic = torch.optim.Adam(
-            params=policy.grasp_critic.parameters(), lr=policy.critic_lr
+            params=policy.grasp_critic.parameters_to_optimize, lr=cfg.policy.critic_lr
         )
     optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr)
     lr_scheduler = None
@@ -16,7 +16,6 @@ from lerobot.common.policies.sac.configuration_sac import SACConfig
 from lerobot.common.policies.sac.modeling_sac import SACPolicy
 
-
 
 def preprocess_maniskill_observation(
     observations: dict[str, np.ndarray],
 ) -> dict[str, torch.Tensor]:
@@ -156,6 +155,7 @@ class TimeLimitWrapper(gym.Wrapper):
         self.current_step = 0
         return super().reset(seed=seed, options=options)
 
+
 class ManiskillMockGripperWrapper(gym.Wrapper):
     def __init__(self, env, nb_discrete_actions: int = 3):
         super().__init__(env)
@@ -166,12 +166,17 @@ class ManiskillMockGripperWrapper(gym.Wrapper):
         self.action_space = gym.spaces.Tuple((action_space_agent, env.action_space[1]))
 
     def step(self, action):
-        action_agent, telop_action = action
+        if isinstance(action, tuple):
+            action_agent, telop_action = action
+        else:
+            telop_action = 0
+            action_agent = action
         real_action = action_agent[:-1]
         final_action = (real_action, telop_action)
         obs, reward, terminated, truncated, info = self.env.step(final_action)
         return obs, reward, terminated, truncated, info
 
 
 def make_maniskill(
     cfg: ManiskillEnvConfig,
     n_envs: int | None = None,
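For illustration only, a standalone replica of the action dispatch added to ManiskillMockGripperWrapper.step; the helper name is hypothetical, and the real wrapper additionally forwards the split action to the wrapped environment:

    import numpy as np

    def split_mock_gripper_action(action):
        # Toy replica of the dispatch added to ManiskillMockGripperWrapper.step:
        # accepts either (continuous_action, teleop_discrete_action) or a bare array.
        if isinstance(action, tuple):
            action_agent, telop_action = action
        else:
            telop_action = 0
            action_agent = action
        real_action = action_agent[:-1]  # drop the appended discrete-action slot
        return real_action, telop_action

    continuous_plus_discrete = np.array([0.1, -0.2, 0.3, 1.0])        # last entry is the discrete slot
    print(split_mock_gripper_action(continuous_plus_discrete))        # bare array from the policy
    print(split_mock_gripper_action((continuous_plus_discrete, 2)))   # tuple from teleoperation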