From a0a50de8c9869d6a6eafc6025ced2c906a3b7ce0 Mon Sep 17 00:00:00 2001
From: Adil Zouitine
Date: Tue, 14 Jan 2025 11:34:52 +0100
Subject: [PATCH] SAC works

---
 lerobot/common/policies/sac/modeling_sac.py | 10 +++--
 lerobot/scripts/train_sac.py                | 49 ++-------------------
 2 files changed, 10 insertions(+), 49 deletions(-)

diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py
index d48cec88..f2d10ae5 100644
--- a/lerobot/common/policies/sac/modeling_sac.py
+++ b/lerobot/common/policies/sac/modeling_sac.py
@@ -254,7 +254,9 @@ class SACPolicy(
         next_action_preds, next_log_probs, _ = self.actor(next_observations)
 
         # 2- compute q targets
-        q_targets = self.critic_forward(next_observations, next_action_preds, use_target=True)
+        q_targets = self.critic_forward(
+            observations=next_observations, actions=next_action_preds, use_target=True
+        )
 
         # subsample critics to prevent overfitting if use high UTD (update to date)
         if self.config.num_subsample_critics is not None:
@@ -264,9 +266,9 @@ class SACPolicy(
             # critics subsample size
 
         min_q, _ = q_targets.min(dim=0)  # Get values from min operation
-        if self.config.use_backup_entropy:
-            min_q -= temperature * next_log_probs
-        td_target = rewards + self.config.discount * min_q * ~done
+        min_q = min_q - (temperature * next_log_probs)
+
+        td_target = rewards + (1 - done) * self.config.discount * min_q
 
         # 3- compute predicted qs
         q_preds = self.critic_forward(observations, actions, use_target=False)
diff --git a/lerobot/scripts/train_sac.py b/lerobot/scripts/train_sac.py
index 9edafb76..30891db9 100644
--- a/lerobot/scripts/train_sac.py
+++ b/lerobot/scripts/train_sac.py
@@ -163,7 +163,9 @@ class ReplayBuffer:
         )
 
         # -- Build batched dones --
-        batch_dones = torch.tensor([t["done"] for t in list_of_transitions], dtype=torch.bool).to(self.device)
+        batch_dones = torch.tensor([t["done"] for t in list_of_transitions], dtype=torch.float32).to(
+            self.device
+        )
 
         # Return a BatchTransition typed dict
         return BatchTransition(
@@ -174,48 +176,6 @@ class ReplayBuffer:
             done=batch_dones,
         )
 
-    # def sample(self, batch_size: int):
-    #     # 1) Randomly sample transitions
-    #     transitions = random.sample(self.memory, batch_size)
-
-    #     # 2) For each key in state_keys, gather states [b, state_dim], next_states [b, state_dim]
-    #     batch_state = {}
-    #     batch_next_state = {}
-    #     for key in self.state_keys:
-    #         batch_state[key] = torch.cat([t["state"][key] for t in transitions], dim=0).to(
-    #             self.device
-    #         )  # shape [b, state_dim, ...] depending on your data
-    #         batch_next_state[key] = torch.cat([t["next_state"][key] for t in transitions], dim=0).to(
-    #             self.device
-    #         )  # shape [b, state_dim, ...]
-
-    #     # 3) Build the other tensors
-    #     batch_action = torch.cat([t["action"] for t in transitions], dim=0).to(
-    #         self.device
-    #     )  # shape [b, ...] or [b, action_dim, ...]
-
-    #     batch_reward = torch.tensor(
-    #         [t["reward"] for t in transitions], dtype=torch.float32, device=self.device
-    #     ).unsqueeze(dim=-1)  # shape [b, 1]
-
-    #     batch_done = torch.tensor(
-    #         [t["done"] for t in transitions], dtype=torch.bool, device=self.device
-    #     )  # shape [b]
-
-    #     # 4) Create the observation and next_observation dicts
-    #     #
-    #     # Each key is stacked along dim=1 so final shape is [b, 2, state_dim, ...]
-    #     #   - observation[key][..., 0, :] is the current state
-    #     #   - observation[key][..., 1, :] is the next state
-    #     #   - next_observation[key] duplicates the next state to shape [b, 2, ...]
-    #     observation = {}
-    #     for key in self.state_keys:
-    #         observation[key] = torch.stack([batch_state[key], batch_next_state[key]], dim=1)
-
-    #     # 5) Return your structure
-    #     ret = observation | {"action": batch_action, "next.reward": batch_reward, "next.done": batch_done}
-    #     return ret
-
 
 def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
     if out_dir is None:
@@ -297,8 +257,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
 
         # NOTE: At some point we should use a wrapper to handle the observation
        if interaction_step >= cfg.training.online_step_before_learning:
-            with torch.inference_mode():
-                action = policy.select_action(batch=obs)
+            action = policy.select_action(batch=obs)
             next_obs, reward, done, truncated, info = online_env.step(action.cpu().numpy())
         else:
             action = online_env.action_space.sample()
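Note on the new critic target: the patch drops the use_backup_entropy branch, so the entropy term is always subtracted from the ensemble minimum, and terminal masking switches from ~done to (1 - done). Below is a minimal sketch of the resulting soft TD target, with illustrative names and assumed shapes (rewards, done, next_log_probs of shape [B]; q_targets of shape [num_critics, B]); it is not the repo's exact code.

import torch

def soft_td_target(rewards, done, q_targets, next_log_probs, temperature, discount):
    # Clipped double-Q: take the minimum over the critic ensemble.
    min_q, _ = q_targets.min(dim=0)
    # Entropy term, now applied unconditionally (no use_backup_entropy flag).
    min_q = min_q - temperature * next_log_probs
    # (1 - done) zeroes the bootstrap term at terminal transitions, so done must be numeric.
    return rewards + (1 - done) * discount * min_q

# Example call: 2 critics, batch of 3 transitions.
td = soft_td_target(
    rewards=torch.zeros(3),
    done=torch.tensor([0.0, 1.0, 0.0]),
    q_targets=torch.randn(2, 3),
    next_log_probs=torch.randn(3),
    temperature=0.2,
    discount=0.99,
)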
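Note on the ReplayBuffer dtype change: batch_dones moves from torch.bool to torch.float32 because the new target uses the arithmetic mask (1 - done). Recent PyTorch versions reject subtraction on bool tensors, while the old ~done form only works on bool or integer tensors. A small standalone illustration (values are made up for the example):

import torch

done_bool = torch.tensor([False, True, False])

# Old target: bitwise negation on a bool tensor.
mask_old = ~done_bool                    # tensor([ True, False,  True])

# New target: arithmetic mask, so the flag must be numeric
# (1 - done_bool raises a RuntimeError on recent PyTorch).
done_f32 = done_bool.to(torch.float32)
mask_new = 1 - done_f32                  # tensor([1., 0., 1.])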