Enhance SAC configuration and policy with new parameters and subsampling logic

- Added `num_subsample_critics`, `critic_target_update_weight`, and `utd_ratio` (update-to-data ratio) to SACConfig.
- Implemented a default target entropy in SACPolicy, set to -dim(A) when none is provided.
- Introduced random subsampling of the target critics to prevent overfitting when training with a high UTD ratio.
- Updated the temperature loss calculation to use the new target entropy.
- Added comments sketching the future UTD update implementation.

These changes improve the flexibility and performance of the SAC implementation.
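For orientation, here is a minimal usage sketch of the new knobs. Only the field names come from this commit; the instantiation pattern and the chosen values are illustrative assumptions.

```python
# Illustrative only: field names are from this commit, values and usage are assumptions.
config = SACConfig()
config.num_subsample_critics = 2            # take the min over 2 randomly drawn critic targets
config.critic_target_update_weight = 0.005  # Polyak coefficient for the soft target update
config.utd_ratio = 2                        # gradient updates per environment step
```

A higher `utd_ratio` reuses each environment step for more gradient updates, which is exactly the regime where subsampling the critic targets is meant to help.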
KeWang1017 2024-12-17 15:58:04 +00:00 committed by AdilZouitine
parent dbadaae28b
commit a5228a0dfe
2 changed files with 20 additions and 4 deletions


@@ -23,8 +23,11 @@ class SACConfig:
    discount = 0.99
    temperature_init = 1.0
    num_critics = 2
    num_subsample_critics = None
    critic_lr = 3e-4
    actor_lr = 3e-4
    critic_target_update_weight = 0.005
    utd_ratio = 2
    critic_network_kwargs = {
        "hidden_dims": [256, 256],
        "activate_final": True,

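A side note on `critic_target_update_weight`: it plays the role of the Polyak coefficient τ in a soft target update, target ← (1 − τ)·target + τ·online. A small numeric illustration (not code from this repository) follows.

```python
import torch

# Soft target update: target <- (1 - tau) * target + tau * online,
# with tau = critic_target_update_weight.
tau = 0.005
target_param = torch.zeros(3)
online_param = torch.ones(3)
target_param = torch.lerp(target_param, online_param, tau)
print(target_param)  # tensor([0.0050, 0.0050, 0.0050])
```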

@@ -85,7 +85,8 @@ class SACPolicy(
            action_dim=config.output_shapes["action"][0],
            **config.policy_kwargs
        )
        if config.target_entropy is None:
            config.target_entropy = -np.prod(config.output_shapes["action"][0])  # (-dim(A))
        self.temperature = LagrangeMultiplier(init_value=config.temperature_init)
    def reset(self):
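To make the default concrete, here is a small illustration of the fallback above; the 6-dimensional action space is hypothetical.

```python
import numpy as np

# Hypothetical action shape; output_shapes mirrors the config structure used above.
output_shapes = {"action": [6]}
target_entropy = -np.prod(output_shapes["action"][0])  # -dim(A)
print(target_entropy)  # -6
```

This is the standard SAC heuristic of targeting a policy entropy of -|A| nats.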
@@ -127,7 +128,6 @@ class SACPolicy(
        # perform image augmentation
        # reward bias
        # from HIL-SERL code base
        # add_or_replace={"rewards": batch["rewards"] + self.config["reward_bias"]} in reward_batch
@@ -136,11 +136,16 @@ class SACPolicy(
        action_preds, log_probs = self.actor_network(observations)
        # 2- compute q targets
        q_targets = self.target_qs(next_observations, action_preds)
        # subsample critics to prevent overfitting when using a high UTD (update-to-data) ratio
        if self.config.num_subsample_critics is not None:
            indices = torch.randperm(self.config.num_critics)
            indices = indices[:self.config.num_subsample_critics]
            q_targets = q_targets[indices]
        # take the element-wise min over the (sub)sampled critics
        min_q, _ = q_targets.min(dim=0)
        # backup entropy
        # compute td target
        td_target = rewards + self.discount * min_q
        # 3- compute predicted qs
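Below is a standalone sketch of the subsampled target computation, with made-up shapes (an ensemble of 10 critics, batch size 4); the entropy backup and done-masking flagged in the comments above are still omitted here.

```python
import torch

num_critics, num_subsample_critics, batch_size = 10, 2, 4
q_targets = torch.randn(num_critics, batch_size)  # one row of Q-estimates per target critic
rewards = torch.randn(batch_size)
discount = 0.99

indices = torch.randperm(num_critics)[:num_subsample_critics]
subsampled = q_targets[indices]         # keep a random subset of critic targets
min_q, _ = subsampled.min(dim=0)        # element-wise min over the sampled critics
td_target = rewards + discount * min_q  # shape: (batch_size,)
```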
@@ -182,7 +187,10 @@ class SACPolicy(
        # calculate temperature loss
        # 1- calculate entropy
        entropy = -log_probs.mean()
        temperature_loss = temperature * (entropy - self.target_entropy).mean()
        temperature_loss = self.temperature(
            lhs=entropy,
            rhs=self.config.target_entropy
        )
        loss = critics_loss + actor_loss + temperature_loss
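The LagrangeMultiplier itself is not part of this diff; a minimal sketch of one common formulation (a log-parameterized multiplier whose dual loss is α·(lhs − rhs)) could look like the following. This is an assumption about its interface, not the repository's implementation.

```python
import torch
from torch import nn

class LagrangeMultiplierSketch(nn.Module):
    """Hypothetical stand-in for the LagrangeMultiplier used above."""

    def __init__(self, init_value: float = 1.0):
        super().__init__()
        # Log-space parameterization keeps the temperature strictly positive.
        self.log_alpha = nn.Parameter(torch.log(torch.tensor(float(init_value))))

    def forward(self, lhs: torch.Tensor, rhs: float) -> torch.Tensor:
        # Dual loss alpha * (lhs - rhs): with lhs = policy entropy and
        # rhs = target entropy, alpha grows when entropy drops below the
        # target and shrinks when it exceeds it.
        alpha = self.log_alpha.exp()
        return alpha * (lhs.detach() - rhs)
```

Under that reading, the new call would compute the same quantity as the replaced line `temperature * (entropy - self.target_entropy).mean()`, with the multiplier now owned by a learnable module.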
@@ -198,6 +206,11 @@ class SACPolicy(
    def update(self):
        self.critic_target.lerp_(self.critic_ensemble, self.config.critic_target_update_weight)
        # TODO: implement UTD update
        # for critic_step in range(self.config.utd_ratio - 1):
        #     only update critic and critic target
        # Then update critic, critic target, actor and temperature
        # for target_param, param in zip(self.critic_target.parameters(), self.critic_ensemble.parameters()):
        #     target_param.data.copy_(target_param.data * (1.0 - self.config.critic_target_update_weight) + param.data * self.critic_target_update_weight)
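As a rough sketch of what the TODO describes, the update could be organized as below; the helper names and the idea of passing per-pass callbacks are assumptions, not this repository's API.

```python
from torch import nn

def soft_update(target: nn.Module, online: nn.Module, tau: float) -> None:
    # Per-parameter Polyak averaging, equivalent to the commented-out copy_ loop above.
    for target_param, param in zip(target.parameters(), online.parameters()):
        target_param.data.lerp_(param.data, tau)

def utd_update(critic_step, full_step, critic_target, critic_ensemble, utd_ratio, tau):
    # utd_ratio - 1 passes that only update the critic and its target ...
    for _ in range(utd_ratio - 1):
        critic_step()
        soft_update(critic_target, critic_ensemble, tau)
    # ... then one pass updating critic, critic target, actor and temperature.
    full_step()
    soft_update(critic_target, critic_ensemble, tau)
```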