SAC works

Adil Zouitine 2025-01-14 11:34:52 +01:00
parent 50e12376de
commit d8e67a2609
2 changed files with 10 additions and 49 deletions


@@ -254,7 +254,9 @@ class SACPolicy(
next_action_preds, next_log_probs, _ = self.actor(next_observations)
# 2- compute q targets
-q_targets = self.critic_forward(next_observations, next_action_preds, use_target=True)
+q_targets = self.critic_forward(
+    observations=next_observations, actions=next_action_preds, use_target=True
+)
# subsample critics to prevent overfitting when using a high UTD (update-to-data) ratio
if self.config.num_subsample_critics is not None:
@@ -264,9 +266,9 @@ class SACPolicy(
# critics subsample size
min_q, _ = q_targets.min(dim=0) # Get values from min operation
if self.config.use_backup_entropy:
-min_q -= temperature * next_log_probs
-td_target = rewards + self.config.discount * min_q * ~done
+min_q = min_q - (temperature * next_log_probs)
+td_target = rewards + (1 - done) * self.config.discount * min_q
# 3- compute predicted qs
q_preds = self.critic_forward(observations, actions, use_target=False)
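
For reference, this is roughly the target computation the hunks above converge on, written as a standalone sketch. The helper name soft_td_target is hypothetical and not part of the diff; the sketch assumes q_targets stacks per-critic values along dim 0 and that done is a 0/1 float tensor, as the buffer change below makes it.

    import torch

    def soft_td_target(rewards, done, q_targets, next_log_probs, temperature,
                       discount, num_subsample_critics=None, use_backup_entropy=True):
        # Optionally subsample critics (REDQ-style) to curb overestimation at high UTD ratios.
        if num_subsample_critics is not None:
            idx = torch.randperm(q_targets.shape[0])[:num_subsample_critics]
            q_targets = q_targets[idx]
        # Pessimistic target: elementwise minimum across the (sub)sampled critics.
        min_q, _ = q_targets.min(dim=0)
        # Entropy backup term from the maximum-entropy objective.
        if use_backup_entropy:
            min_q = min_q - temperature * next_log_probs
        # (1 - done) zeroes the bootstrap on terminal transitions, so done must be a float mask.
        return rewards + (1 - done) * discount * min_q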


@@ -163,7 +163,9 @@ class ReplayBuffer:
)
# -- Build batched dones --
-batch_dones = torch.tensor([t["done"] for t in list_of_transitions], dtype=torch.bool).to(self.device)
+batch_dones = torch.tensor([t["done"] for t in list_of_transitions], dtype=torch.float32).to(
+    self.device
+)
# Return a BatchTransition typed dict
return BatchTransition(
@@ -174,48 +176,6 @@ class ReplayBuffer:
done=batch_dones,
)
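
The switch from torch.bool to torch.float32 lines up with the new TD target above: (1 - done) is plain arithmetic on a 0/1 float mask, while the old ~done expression only works on boolean tensors. A tiny illustration with made-up values:

    import torch

    done = torch.tensor([0.0, 1.0, 0.0])        # float32 mask, as now sampled from the buffer
    rewards = torch.ones(3)
    min_q = torch.full((3,), 10.0)
    discount = 0.99

    # The bootstrap term vanishes where done == 1.
    td_target = rewards + (1 - done) * discount * min_q   # tensor([10.9000, 1.0000, 10.9000])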
# def sample(self, batch_size: int):
# # 1) Randomly sample transitions
# transitions = random.sample(self.memory, batch_size)
# # 2) For each key in state_keys, gather states [b, state_dim], next_states [b, state_dim]
# batch_state = {}
# batch_next_state = {}
# for key in self.state_keys:
# batch_state[key] = torch.cat([t["state"][key] for t in transitions], dim=0).to(
# self.device
# ) # shape [b, state_dim, ...] depending on your data
# batch_next_state[key] = torch.cat([t["next_state"][key] for t in transitions], dim=0).to(
# self.device
# ) # shape [b, state_dim, ...]
# # 3) Build the other tensors
# batch_action = torch.cat([t["action"] for t in transitions], dim=0).to(
# self.device
# ) # shape [b, ...] or [b, action_dim, ...]
# batch_reward = torch.tensor(
# [t["reward"] for t in transitions], dtype=torch.float32, device=self.device
# ).unsqueeze(dim=-1) # shape [b, 1]
# batch_done = torch.tensor(
# [t["done"] for t in transitions], dtype=torch.bool, device=self.device
# ) # shape [b]
# # 4) Create the observation and next_observation dicts
# #
# # Each key is stacked along dim=1 so final shape is [b, 2, state_dim, ...]
# # - observation[key][..., 0, :] is the current state
# # - observation[key][..., 1, :] is the next state
# # - next_observation[key] duplicates the next state to shape [b, 2, ...]
# observation = {}
# for key in self.state_keys:
# observation[key] = torch.stack([batch_state[key], batch_next_state[key]], dim=1)
# # 5) Return your structure
# ret = observation | {"action": batch_action, "next.reward": batch_reward, "next.done": batch_done}
# return ret
def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
if out_dir is None:
@@ -297,8 +257,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
# NOTE: At some point we should use a wrapper to handle the observation
if interaction_step >= cfg.training.online_step_before_learning:
with torch.inference_mode():
action = policy.select_action(batch=obs)
next_obs, reward, done, truncated, info = online_env.step(action.cpu().numpy())
else:
action = online_env.action_space.sample()
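
For context, the branch above implements the usual warm-up scheme: until cfg.training.online_step_before_learning environment steps have been collected, actions are sampled uniformly from the action space to seed the replay buffer; after that, the policy acts under torch.inference_mode(). A minimal sketch as a hypothetical helper (the function and its signature are not part of the diff):

    import torch

    def pick_action(policy, online_env, obs, interaction_step, warmup_steps):
        # Uniform random actions during warm-up, policy actions (no autograd graph) afterwards.
        if interaction_step >= warmup_steps:
            with torch.inference_mode():
                action = policy.select_action(batch=obs)
            return action.cpu().numpy()
        return online_env.action_space.sample()

With warmup_steps set to cfg.training.online_step_before_learning, the returned array is what online_env.step consumes in either branch.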