From 7e0f20fbf285418a78b8619107371f2f0a6c7fd1 Mon Sep 17 00:00:00 2001
From: KeWang1017
Date: Tue, 17 Dec 2024 15:58:04 +0000
Subject: [PATCH] Enhance SAC configuration and policy with new parameters and
 subsampling logic

- Added `num_subsample_critics`, `critic_target_update_weight`, and
  `utd_ratio` to SACConfig.
- Implemented a default target entropy of -dim(A) in SACPolicy when none is
  provided.
- Introduced subsampling of the critic ensemble to prevent overfitting when
  training with a high update-to-data (UTD) ratio.
- Updated the temperature loss calculation to use the new target entropy.
- Added placeholder comments for the future UTD update loop.

These changes make the SAC implementation more configurable and lay the
groundwork for high-UTD training.
---
 .../common/policies/sac/configuration_sac.py |  3 +++
 lerobot/common/policies/sac/modeling_sac.py  | 21 +++++++++++++++---
 2 files changed, 20 insertions(+), 4 deletions(-)

diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py
index 441b3566..d324462e 100644
--- a/lerobot/common/policies/sac/configuration_sac.py
+++ b/lerobot/common/policies/sac/configuration_sac.py
@@ -23,8 +23,11 @@ class SACConfig:
     discount = 0.99
     temperature_init = 1.0
     num_critics = 2
+    num_subsample_critics = None
     critic_lr = 3e-4
     actor_lr = 3e-4
+    critic_target_update_weight = 0.005
+    utd_ratio = 2
     critic_network_kwargs = {
         "hidden_dims": [256, 256],
         "activate_final": True,
diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py
index 9ea9449d..7d451b4e 100644
--- a/lerobot/common/policies/sac/modeling_sac.py
+++ b/lerobot/common/policies/sac/modeling_sac.py
@@ -85,7 +85,8 @@ class SACPolicy(
             action_dim=config.output_shapes["action"][0],
             **config.policy_kwargs
         )
-
+        if config.target_entropy is None:
+            config.target_entropy = -np.prod(config.output_shapes["action"][0])  # (-dim(A))
         self.temperature = LagrangeMultiplier(init_value=config.temperature_init)

     def reset(self):
@@ -127,7 +128,6 @@ class SACPolicy(
         # perform image augmentation

         # reward bias
-        # from HIL-SERL code base
         # add_or_replace={"rewards": batch["rewards"] + self.config["reward_bias"]} in reward_batch

@@ -136,11 +136,16 @@ class SACPolicy(
         action_preds, log_probs = self.actor_network(observations)
         # 2- compute q targets
         q_targets = self.target_qs(next_observations, action_preds)
+        # subsample critics to prevent overfitting when using a high UTD (update-to-data) ratio
+        if self.config.num_subsample_critics is not None:
+            indices = torch.randperm(self.config.num_critics)
+            indices = indices[:self.config.num_subsample_critics]
+            q_targets = q_targets[indices]  # keep only the subsampled critics' target values

         min_q = q_targets.min(dim=0)

-        # backup entropy
+        # compute the TD target
         td_target = rewards + self.discount * min_q

         # 3- compute predicted qs
@@ -182,7 +187,10 @@ class SACPolicy(
         # calculate temperature loss
         # 1- calculate entropy
         entropy = -log_probs.mean()
-        temperature_loss = temperature * (entropy - self.target_entropy).mean()
+        temperature_loss = self.temp(
+            lhs=entropy,
+            rhs=self.config.target_entropy
+        )

         loss = critics_loss + actor_loss + temperature_loss

@@ -198,6 +206,11 @@ class SACPolicy(
     def update(self):
         self.critic_target.lerp_(self.critic_ensemble, self.config.critic_target_update_weight)
+        # TODO: implement the UTD (update-to-data) update:
+        # for critic_step in range(self.config.utd_ratio - 1):
+        #     only update the critic and critic target
+        # then update the critic, critic target, actor and temperature
+        # for target_param, param in zip(self.critic_target.parameters(), self.critic_ensemble.parameters()): target_param.data.copy_(target_param.data * (1.0 - self.config.critic_target_update_weight) + param.data * self.critic_target_update_weight)
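
Note on the two mechanisms this patch introduces: below is a minimal, self-contained PyTorch sketch of REDQ-style subsampling of the target-critic ensemble before taking the minimum, and of the Polyak (soft) target update driven by critic_target_update_weight. Only those two ideas come from the patch; the helper names (subsampled_min_q, soft_update), the assumed (num_critics, batch_size) shape of q_targets, and the toy nn.Linear critics are illustrative assumptions, not lerobot API.

# Illustrative sketch only; names and shapes are assumptions, not lerobot API.
from typing import Optional

import torch
import torch.nn as nn


def subsampled_min_q(q_targets: torch.Tensor, num_subsample_critics: Optional[int]) -> torch.Tensor:
    """q_targets: (num_critics, batch_size) target-critic values; returns (batch_size,)."""
    if num_subsample_critics is not None:
        # pick a random subset of critics, as in the patch's torch.randperm + slicing
        indices = torch.randperm(q_targets.shape[0])[:num_subsample_critics]
        q_targets = q_targets[indices]
    # element-wise minimum over the (sub)ensemble, which then feeds the TD target
    return q_targets.min(dim=0).values


def soft_update(target: nn.Module, source: nn.Module, tau: float) -> None:
    """Polyak averaging: target <- (1 - tau) * target + tau * source."""
    with torch.no_grad():
        for t_param, s_param in zip(target.parameters(), source.parameters()):
            t_param.lerp_(s_param, tau)


# toy usage
q_targets = torch.randn(10, 256)  # 10 critics, batch of 256
min_q = subsampled_min_q(q_targets, num_subsample_critics=2)
critic = nn.Linear(4, 1)
critic_target = nn.Linear(4, 1)
critic_target.load_state_dict(critic.state_dict())
soft_update(critic_target, critic, tau=0.005)  # tau plays the role of critic_target_update_weight

Subsampling before the minimum keeps the pessimism of clipped double-Q while allowing a larger ensemble for high-UTD training. One PyTorch detail worth checking when wiring this into the TD target: Tensor.min(dim=0) returns a (values, indices) pair, so the .values accessor (or indexing with [0]) is needed to get the per-sample minimum as a tensor.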