match target entropy hil serl

Co-authored-by: s1lent4gnt <kmeftah.khalil@gmail.com>
This commit is contained in:
AdilZouitine 2025-04-15 08:00:38 +00:00
parent 5d7820527d
commit bda2053106
3 changed files with 11 additions and 7 deletions

View File

@ -155,7 +155,8 @@ class SACPolicy(
**asdict(config.policy_kwargs),
)
if config.target_entropy is None:
config.target_entropy = -np.prod(continuous_action_dim) / 2 # (-dim(A)/2)
discrete_actions_dim: Literal[1] | Literal[0] = 1 if config.num_discrete_actions is None else 0
config.target_entropy = -np.prod(continuous_action_dim + discrete_actions_dim) / 2 # (-dim(A)/2)
# TODO (azouitine): Handle the case where the temparameter is a fixed
# TODO (michel-aractingi): Put the log_alpha in cuda by default because otherwise
@ -176,7 +177,7 @@ class SACPolicy(
"temperature": self.log_alpha,
}
if self.config.num_discrete_actions is not None:
optim_params["grasp_critic"] = self.grasp_critic.parameters_to_optimize
optim_params["grasp_critic"] = self.grasp_critic.parameters()
return optim_params
def reset(self):

View File

@ -231,6 +231,7 @@ def act_with_policy(
cfg=cfg.policy,
env_cfg=cfg.env,
)
policy = policy.eval()
assert isinstance(policy, nn.Module)
obs, info = online_env.reset()

View File

@ -429,7 +429,7 @@ def add_actor_information_and_train(
optimizers["grasp_critic"].zero_grad()
loss_grasp_critic.backward()
grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
parameters=policy.grasp_critic.parameters_to_optimize, max_norm=clip_grad_norm_value
parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value
)
optimizers["grasp_critic"].step()
@ -493,7 +493,7 @@ def add_actor_information_and_train(
optimizers["grasp_critic"].zero_grad()
loss_grasp_critic.backward()
grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
parameters=policy.grasp_critic.parameters_to_optimize, max_norm=clip_grad_norm_value
parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value
).item()
optimizers["grasp_critic"].step()
@ -784,7 +784,7 @@ def make_optimizers_and_scheduler(cfg: TrainPipelineConfig, policy: nn.Module):
if cfg.policy.num_discrete_actions is not None:
optimizer_grasp_critic = torch.optim.Adam(
params=policy.grasp_critic.parameters_to_optimize, lr=cfg.policy.critic_lr
params=policy.grasp_critic.parameters(), lr=cfg.policy.critic_lr
)
optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr)
lr_scheduler = None
@ -1028,8 +1028,10 @@ def get_observation_features(
return None, None
with torch.no_grad():
observation_features = policy.actor.encoder.get_image_features(observations, normalize=True)
next_observation_features = policy.actor.encoder.get_image_features(next_observations, normalize=True)
observation_features = policy.actor.encoder.get_base_image_features(observations, normalize=True)
next_observation_features = policy.actor.encoder.get_base_image_features(
next_observations, normalize=True
)
return observation_features, next_observation_features