Match target entropy (HIL-SERL)
Co-authored-by: s1lent4gnt <kmeftah.khalil@gmail.com>
parent 5d7820527d
commit bda2053106
@@ -155,7 +155,8 @@ class SACPolicy(
             **asdict(config.policy_kwargs),
         )
         if config.target_entropy is None:
-            config.target_entropy = -np.prod(continuous_action_dim) / 2  # (-dim(A)/2)
+            discrete_actions_dim: Literal[1] | Literal[0] = 1 if config.num_discrete_actions is None else 0
+            config.target_entropy = -np.prod(continuous_action_dim + discrete_actions_dim) / 2  # (-dim(A)/2)
 
         # TODO (azouitine): Handle the case where the temparameter is a fixed
         # TODO (michel-aractingi): Put the log_alpha in cuda by default because otherwise
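For context, SAC commonly sets the target entropy proportional to the negative action dimensionality; this hunk halves that heuristic and folds the discrete grasp head into the dimension count. A minimal sketch of the computation, assuming continuous_action_dim is a plain integer (the helper name below is illustrative, not from the repository):

import numpy as np


def default_target_entropy(continuous_action_dim: int, num_discrete_actions: int | None) -> float:
    # Mirrors the hunk above: the extra dimension is counted exactly when
    # num_discrete_actions is None, and the usual -dim(A) heuristic is halved.
    discrete_actions_dim = 1 if num_discrete_actions is None else 0
    return float(-np.prod(continuous_action_dim + discrete_actions_dim) / 2)


# Example: a 6-dimensional continuous action space with no discrete grasp head
print(default_target_entropy(6, None))  # -3.5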
@@ -176,7 +177,7 @@ class SACPolicy(
             "temperature": self.log_alpha,
         }
         if self.config.num_discrete_actions is not None:
-            optim_params["grasp_critic"] = self.grasp_critic.parameters_to_optimize
+            optim_params["grasp_critic"] = self.grasp_critic.parameters()
         return optim_params
 
     def reset(self):
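Several hunks in this commit replace a hand-maintained parameters_to_optimize attribute with the standard nn.Module.parameters() iterator, so the optimizer automatically sees every registered parameter. A minimal sketch of how such a parameter dictionary can be consumed, with a toy module standing in for the grasp critic (the class below is illustrative, not the repository's implementation):

import torch
from torch import nn


class ToyGraspCritic(nn.Module):
    """Stand-in for the grasp critic; any nn.Module exposes .parameters()."""

    def __init__(self, obs_dim: int = 8, num_discrete_actions: int = 3):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(obs_dim, 32), nn.ReLU(), nn.Linear(32, num_discrete_actions))

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        return self.net(obs)


critic = ToyGraspCritic()
# parameters() yields every registered parameter, so no manually curated list is needed.
optim_params = {"grasp_critic": critic.parameters()}
optimizer = torch.optim.Adam(params=optim_params["grasp_critic"], lr=3e-4)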
@@ -231,6 +231,7 @@ def act_with_policy(
         cfg=cfg.policy,
+        env_cfg=cfg.env,
     )
     policy = policy.eval()
     assert isinstance(policy, nn.Module)
 
     obs, info = online_env.reset()
@@ -429,7 +429,7 @@ def add_actor_information_and_train(
            optimizers["grasp_critic"].zero_grad()
            loss_grasp_critic.backward()
            grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
-                parameters=policy.grasp_critic.parameters_to_optimize, max_norm=clip_grad_norm_value
+                parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value
            )
            optimizers["grasp_critic"].step()
 
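torch.nn.utils.clip_grad_norm_ rescales the gradients in place and returns the total norm measured before clipping, which is why the training loop can clip and log in a single call. A self-contained sketch of the pattern, with a toy linear critic and a placeholder loss standing in for the real grasp-critic update:

import torch
from torch import nn

critic = nn.Linear(8, 3)  # placeholder for policy.grasp_critic
optimizer = torch.optim.Adam(critic.parameters(), lr=3e-4)
clip_grad_norm_value = 10.0

loss_grasp_critic = critic(torch.randn(4, 8)).pow(2).mean()  # placeholder loss

optimizer.zero_grad()
loss_grasp_critic.backward()
# Rescales gradients in place so their global norm is at most max_norm,
# and returns the total norm computed before clipping (useful for logging).
grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
    parameters=critic.parameters(), max_norm=clip_grad_norm_value
)
optimizer.step()
print(f"grasp critic grad norm: {grasp_critic_grad_norm.item():.3f}")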
@@ -493,7 +493,7 @@ def add_actor_information_and_train(
            optimizers["grasp_critic"].zero_grad()
            loss_grasp_critic.backward()
            grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
-                parameters=policy.grasp_critic.parameters_to_optimize, max_norm=clip_grad_norm_value
+                parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value
            ).item()
            optimizers["grasp_critic"].step()
 
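The only difference from the previous hunk is the trailing .item(), which converts the 0-dimensional tensor returned by clip_grad_norm_ into a plain Python float, convenient for logging. A tiny illustration:

import torch

total_norm = torch.tensor(3.7)   # clip_grad_norm_ returns a 0-dim tensor like this
logged_norm = total_norm.item()  # plain Python float, detached from any device or graph
assert isinstance(logged_norm, float)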
@@ -784,7 +784,7 @@ def make_optimizers_and_scheduler(cfg: TrainPipelineConfig, policy: nn.Module):
 
     if cfg.policy.num_discrete_actions is not None:
         optimizer_grasp_critic = torch.optim.Adam(
-            params=policy.grasp_critic.parameters_to_optimize, lr=cfg.policy.critic_lr
+            params=policy.grasp_critic.parameters(), lr=cfg.policy.critic_lr
         )
     optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr)
     lr_scheduler = None
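Optimizing log_alpha rather than alpha itself keeps the SAC temperature positive, and the target entropy set in the first hunk is what this optimizer drives the policy entropy toward. A minimal sketch of the standard temperature update, with placeholder log-probabilities standing in for samples from the current actor:

import torch

target_entropy = -3.5  # e.g. the -dim(A)/2 heuristic computed above
log_alpha = torch.zeros(1, requires_grad=True)
optimizer_temperature = torch.optim.Adam(params=[log_alpha], lr=3e-4)

# Log-probabilities of actions sampled from the current policy (placeholder values).
log_probs = torch.full((256,), -2.0)

# Standard SAC temperature loss: pushes policy entropy (-log_probs) toward target_entropy.
loss_temperature = (-log_alpha.exp() * (log_probs + target_entropy).detach()).mean()

optimizer_temperature.zero_grad()
loss_temperature.backward()
optimizer_temperature.step()

alpha = log_alpha.exp().detach()  # temperature used to scale the entropy bonus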
@@ -1028,8 +1028,10 @@ def get_observation_features(
         return None, None
 
     with torch.no_grad():
-        observation_features = policy.actor.encoder.get_image_features(observations, normalize=True)
-        next_observation_features = policy.actor.encoder.get_image_features(next_observations, normalize=True)
+        observation_features = policy.actor.encoder.get_base_image_features(observations, normalize=True)
+        next_observation_features = policy.actor.encoder.get_base_image_features(
+            next_observations, normalize=True
+        )
 
     return observation_features, next_observation_features
 
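The renamed get_base_image_features call runs under torch.no_grad(), which is only appropriate when the vision backbone is frozen; the resulting features can then be cached and reused across critic and actor updates without re-running the encoder. A minimal sketch of that caching pattern with a toy encoder (the class below is a stand-in; only the method name comes from the diff):

import torch
from torch import nn


class ToyEncoder(nn.Module):
    """Stand-in for the frozen vision backbone of the SAC actor."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Sequential(nn.Flatten(), nn.Linear(3 * 64 * 64, 128))

    def get_base_image_features(self, observations: torch.Tensor, normalize: bool = True) -> torch.Tensor:
        feats = self.backbone(observations)
        return torch.nn.functional.normalize(feats, dim=-1) if normalize else feats


encoder = ToyEncoder().eval()
observations = torch.randn(4, 3, 64, 64)
next_observations = torch.randn(4, 3, 64, 64)

# No gradients are needed for a frozen backbone, so the features can be
# computed once per batch and reused by every subsequent update step.
with torch.no_grad():
    observation_features = encoder.get_base_image_features(observations, normalize=True)
    next_observation_features = encoder.get_base_image_features(next_observations, normalize=True)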