Enhance SAC configuration and policy with new parameters and subsampling logic
- Added `num_subsample_critics`, `critic_target_update_weight`, and `utd_ratio` to SACConfig.
- Implemented target entropy calculation in SACPolicy if not provided.
- Introduced subsampling of critics to prevent overfitting during updates.
- Updated temperature loss calculation to use the new target entropy.
- Added comments for future UTD update implementation.

These changes improve the flexibility and performance of the SAC implementation.
This commit is contained in: parent dbadaae28b, commit a5228a0dfe
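For orientation, a minimal standalone illustration of the critic-subsampling idea described in the commit message above (shapes and values are assumptions, not the committed code):

```python
import torch

num_critics = 10           # size of the critic ensemble (illustrative)
num_subsample_critics = 2  # how many critics back up the target (illustrative)
batch_size = 256

# One target Q estimate per critic per transition: shape (num_critics, batch_size).
q_targets = torch.randn(num_critics, batch_size)

# Pick a random subset of critics, then take the element-wise minimum over that subset,
# so the target does not over-commit to any single critic under high update-to-data ratios.
indices = torch.randperm(num_critics)[:num_subsample_critics]
min_q = q_targets[indices].min(dim=0).values  # shape (batch_size,)
```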
@@ -23,8 +23,11 @@ class SACConfig:
    discount = 0.99
    temperature_init = 1.0
    num_critics = 2
    num_subsample_critics = None
    critic_lr = 3e-4
    actor_lr = 3e-4
    critic_target_update_weight = 0.005
    utd_ratio = 2
    critic_network_kwargs = {
        "hidden_dims": [256, 256],
        "activate_final": True,
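The three new fields are `num_subsample_critics`, `critic_target_update_weight`, and `utd_ratio`, sitting alongside the existing defaults. A hypothetical override for a REDQ-style run, assuming SACConfig attributes can simply be reassigned after construction:

```python
config = SACConfig()
config.num_critics = 10                      # larger critic ensemble
config.num_subsample_critics = 2             # back up targets with a random pair of critics
config.critic_target_update_weight = 0.005   # Polyak/EMA coefficient for the target network
config.utd_ratio = 2                         # critic updates per actor/temperature update
```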
@@ -85,7 +85,8 @@ class SACPolicy(
            action_dim=config.output_shapes["action"][0],
            **config.policy_kwargs
        )

        if config.target_entropy is None:
            config.target_entropy = -np.prod(config.output_shapes["action"][0])  # (-dim(A))
        self.temperature = LagrangeMultiplier(init_value=config.temperature_init)

    def reset(self):
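As a quick check of the `-dim(A)` heuristic above (the action shape here is made up, not taken from the repository):

```python
import numpy as np

output_shapes = {"action": [7]}                        # hypothetical 7-dimensional action space
target_entropy = -np.prod(output_shapes["action"][0])  # -> -7, i.e. -dim(A)
```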
@@ -127,7 +128,6 @@ class SACPolicy(
        # perform image augmentation

        # reward bias
        # from HIL-SERL code base
        # add_or_replace={"rewards": batch["rewards"] + self.config["reward_bias"]} in reward_batch

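The reward-bias comment above refers to a constant offset applied to every reward before the TD target is computed. A minimal sketch of that step; the `reward_bias` value is an assumption, since it is not defined in this diff:

```python
import torch

reward_bias = -1.0  # assumed value; shifts sparse 0/1 rewards so every non-success step is penalized
batch = {"rewards": torch.tensor([0.0, 0.0, 1.0])}

# Shift all rewards by the constant bias, as in the HIL-SERL code base referenced above.
batch["rewards"] = batch["rewards"] + reward_bias
```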
@@ -136,11 +136,16 @@ class SACPolicy(
        action_preds, log_probs = self.actor_network(observations)
        # 2- compute q targets
        q_targets = self.target_qs(next_observations, action_preds)
        # subsample critics to prevent overfitting when using a high UTD (update-to-data) ratio
        if self.config.num_subsample_critics is not None:
            indices = torch.randperm(self.config.num_critics)
            indices = indices[:self.config.num_subsample_critics]
            q_targets = q_targets[indices]

        # critics subsample size
        min_q = q_targets.min(dim=0)

        # backup entropy
        # compute td target
        td_target = rewards + self.discount * min_q

        # 3- compute predicted qs
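One PyTorch detail worth flagging when reading the TD-target lines: `Tensor.min(dim=0)` returns a `(values, indices)` named tuple rather than a tensor, so the ensemble minimum is usually taken like this (a standalone sketch with dummy shapes, not the committed code):

```python
import torch

q_targets = torch.randn(2, 256)         # (num_critics, batch_size), dummy values
min_q = q_targets.min(dim=0).values     # per-sample minimum over the critic ensemble
rewards = torch.zeros(256)
discount = 0.99
td_target = rewards + discount * min_q  # shape (batch_size,)
```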
@@ -182,7 +187,10 @@ class SACPolicy(
        # calculate temperature loss
        # 1- calculate entropy
        entropy = -log_probs.mean()
        temperature_loss = temperature * (entropy - self.target_entropy).mean()
        temperature_loss = self.temp(
            lhs=entropy,
            rhs=self.config.target_entropy
        )

        loss = critics_loss + actor_loss + temperature_loss
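Per the commit message and the hunk header, the direct `temperature * (entropy - self.target_entropy)` formula is the line being replaced by the `self.temp(lhs=..., rhs=...)` call. The `LagrangeMultiplier` internals are not part of this diff; a common implementation of such a penalty, sketched here as an assumption rather than the repository's code, keeps a learnable log-temperature and reuses the sign convention of the old line:

```python
import torch
from torch import nn

class TemperatureSketch(nn.Module):
    """Illustrative stand-in for LagrangeMultiplier; not the repository's implementation."""

    def __init__(self, init_value: float = 1.0):
        super().__init__()
        # Learn the log of the multiplier so the temperature stays positive.
        self.log_alpha = nn.Parameter(torch.log(torch.tensor(float(init_value))))

    def forward(self, lhs: torch.Tensor, rhs: float) -> torch.Tensor:
        # Only the multiplier receives gradients, so the entropy term is detached.
        alpha = self.log_alpha.exp()
        return (alpha * (lhs.detach() - rhs)).mean()
```

Called as `temperature_loss = temp(lhs=entropy, rhs=config.target_entropy)`, this yields a scalar whose gradient pushes the temperature up when the policy entropy falls below the target and down when it rises above it.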
@@ -198,6 +206,11 @@ class SACPolicy(

    def update(self):
        self.critic_target.lerp_(self.critic_ensemble, self.config.critic_target_update_weight)
        # TODO: implement UTD update
        # for critic_step in range(self.config.utd_ratio - 1):
        #     only update critic and critic target
        # Then update critic, critic target, actor and temperature

        # for target_param, param in zip(self.critic_target.parameters(), self.critic_ensemble.parameters()):
        #     target_param.data.copy_(target_param.data * (1.0 - self.config.critic_target_update_weight) + param.data * self.critic_target_update_weight)
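The TODO describes the usual update-to-data (UTD) pattern: for a UTD ratio of N, run N-1 critic-only updates, each followed by a soft target update (`lerp_(end, weight)` computes `target + weight * (online - target)`, the same operation as the commented-out Polyak loop), then one full update that also steps the actor and temperature. A schematic sketch under those assumptions; the `update_*` method names are hypothetical placeholders, not the policy's API:

```python
def training_step(policy, replay_buffer, config):
    """Schematic UTD loop; the update_* methods are hypothetical placeholders."""
    # Critic-only updates: (utd_ratio - 1) gradient steps on the critic ensemble,
    # each followed by a soft update of the critic target.
    for _ in range(config.utd_ratio - 1):
        batch = replay_buffer.sample()
        policy.update_critic(batch)
        policy.update_critic_target()
    # Final step of the cycle: critic, critic target, actor and temperature together.
    batch = replay_buffer.sample()
    policy.update_critic(batch)
    policy.update_critic_target()
    policy.update_actor_and_temperature(batch)
```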