match target entropy hil serl

Co-authored-by: s1lent4gnt <kmeftah.khalil@gmail.com>
AdilZouitine 2025-04-15 08:00:38 +00:00 committed by Adil Zouitine
parent 35ecaae8a9
commit a850d436db
3 changed files with 11 additions and 7 deletions

View File

@@ -155,7 +155,8 @@ class SACPolicy(
             **asdict(config.policy_kwargs),
         )
         if config.target_entropy is None:
-            config.target_entropy = -np.prod(continuous_action_dim) / 2  # (-dim(A)/2)
+            discrete_actions_dim: Literal[1] | Literal[0] = 1 if config.num_discrete_actions is None else 0
+            config.target_entropy = -np.prod(continuous_action_dim + discrete_actions_dim) / 2  # (-dim(A)/2)

         # TODO (azouitine): Handle the case where the temparameter is a fixed
         # TODO (michel-aractingi): Put the log_alpha in cuda by default because otherwise
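
For reference, a minimal sketch of how the updated default behaves, mirroring the expression in the diff; continuous_action_dim and num_discrete_actions stand in for the config fields, and the concrete values are made up:

    import numpy as np

    continuous_action_dim = 6     # e.g. a 6-DoF continuous action
    num_discrete_actions = None   # None when no discrete (gripper) head is configured

    # Same expression as the added lines above: SAC's heuristic target entropy, -dim(A)/2.
    discrete_actions_dim = 1 if num_discrete_actions is None else 0
    target_entropy = -np.prod(continuous_action_dim + discrete_actions_dim) / 2
    print(target_entropy)  # -3.5 for the values above
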
@@ -176,7 +177,7 @@ class SACPolicy(
             "temperature": self.log_alpha,
         }
         if self.config.num_discrete_actions is not None:
-            optim_params["grasp_critic"] = self.grasp_critic.parameters_to_optimize
+            optim_params["grasp_critic"] = self.grasp_critic.parameters()
         return optim_params

     def reset(self):
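
A rough, self-contained illustration of the optimizer-parameter dict this method builds, assuming the usual pattern of one optimizer per component; TinyPolicy and its sub-modules are placeholders, not the real SACPolicy:

    import torch
    from torch import nn

    class TinyPolicy(nn.Module):
        def __init__(self):
            super().__init__()
            self.actor = nn.Linear(4, 2)
            self.grasp_critic = nn.Linear(4, 3)
            self.log_alpha = nn.Parameter(torch.zeros(1))

        def get_optim_params(self):
            optim_params = {
                "actor": self.actor.parameters(),
                "temperature": [self.log_alpha],
            }
            # As in the diff: hand over the module's parameters() iterator.
            optim_params["grasp_critic"] = self.grasp_critic.parameters()
            return optim_params

    policy = TinyPolicy()
    optimizers = {
        name: torch.optim.Adam(params, lr=3e-4)
        for name, params in policy.get_optim_params().items()
    }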

View File

@@ -231,6 +231,7 @@ def act_with_policy(
         cfg=cfg.policy,
         env_cfg=cfg.env,
     )
+    policy = policy.eval()
     assert isinstance(policy, nn.Module)

     obs, info = online_env.reset()
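
The added policy.eval() call switches the rollout copy of the network into inference mode. A small sketch of why that matters, with a hypothetical actor containing a Dropout layer; note that eval() does not disable gradient tracking, which is why torch.no_grad() is still used separately when acting:

    import torch
    from torch import nn

    actor = nn.Sequential(nn.Linear(8, 64), nn.Dropout(p=0.1), nn.ReLU(), nn.Linear(64, 4))
    actor = actor.eval()   # Dropout/BatchNorm-style layers become deterministic

    obs = torch.randn(1, 8)
    with torch.no_grad():  # no gradients needed when only selecting actions
        action = actor(obs)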

View File

@@ -429,7 +429,7 @@ def add_actor_information_and_train(
                 optimizers["grasp_critic"].zero_grad()
                 loss_grasp_critic.backward()
                 grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
-                    parameters=policy.grasp_critic.parameters_to_optimize, max_norm=clip_grad_norm_value
+                    parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value
                 )
                 optimizers["grasp_critic"].step()

@@ -493,7 +493,7 @@ def add_actor_information_and_train(
                 optimizers["grasp_critic"].zero_grad()
                 loss_grasp_critic.backward()
                 grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
-                    parameters=policy.grasp_critic.parameters_to_optimize, max_norm=clip_grad_norm_value
+                    parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value
                 ).item()
                 optimizers["grasp_critic"].step()

@@ -784,7 +784,7 @@ def make_optimizers_and_scheduler(cfg: TrainPipelineConfig, policy: nn.Module):

     if cfg.policy.num_discrete_actions is not None:
         optimizer_grasp_critic = torch.optim.Adam(
-            params=policy.grasp_critic.parameters_to_optimize, lr=cfg.policy.critic_lr
+            params=policy.grasp_critic.parameters(), lr=cfg.policy.critic_lr
         )
     optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr)
     lr_scheduler = None
@@ -1028,8 +1028,10 @@ def get_observation_features(
         return None, None

     with torch.no_grad():
-        observation_features = policy.actor.encoder.get_image_features(observations, normalize=True)
-        next_observation_features = policy.actor.encoder.get_image_features(next_observations, normalize=True)
+        observation_features = policy.actor.encoder.get_base_image_features(observations, normalize=True)
+        next_observation_features = policy.actor.encoder.get_base_image_features(
+            next_observations, normalize=True
+        )

     return observation_features, next_observation_features

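The renamed call computes image features for the observations and next observations under torch.no_grad(). A self-contained sketch of that pattern, where DummyEncoder and its get_base_image_features method are stand-ins for the real encoder API:

    import torch
    from torch import nn

    class DummyEncoder(nn.Module):
        def __init__(self):
            super().__init__()
            self.backbone = nn.Linear(32, 16)

        def get_base_image_features(self, observations, normalize=True):
            feats = self.backbone(observations)
            return nn.functional.normalize(feats, dim=-1) if normalize else feats

    encoder = DummyEncoder()
    observations = torch.randn(8, 32)
    next_observations = torch.randn(8, 32)

    with torch.no_grad():  # the encoder is treated as fixed here, so skip autograd tracking
        observation_features = encoder.get_base_image_features(observations, normalize=True)
        next_observation_features = encoder.get_base_image_features(
            next_observations, normalize=True
        )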