SAC works
This commit is contained in: parent 50e12376de, commit d8e67a2609
@@ -254,7 +254,9 @@ class SACPolicy(
        next_action_preds, next_log_probs, _ = self.actor(next_observations)

        # 2- compute q targets
-       q_targets = self.critic_forward(next_observations, next_action_preds, use_target=True)
+       q_targets = self.critic_forward(
+           observations=next_observations, actions=next_action_preds, use_target=True
+       )

        # subsample critics to prevent overfitting if using a high UTD (update-to-data) ratio
        if self.config.num_subsample_critics is not None:
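The surrounding context refers to subsampling critics to avoid overfitting at high UTD (update-to-data) ratios. As a rough illustration only (the subsampling code itself is not part of this hunk, and the shapes are assumptions), REDQ-style subsampling over an ensemble of target Q-values could look like this:

    import torch

    # Sketch, not repo code: assumes q_targets is a [num_critics, batch_size] tensor
    # and num_subsample_critics is an int taken from the config.
    def subsample_critics(q_targets: torch.Tensor, num_subsample_critics: int) -> torch.Tensor:
        # Randomly keep a subset of the critic ensemble for the target computation.
        indices = torch.randperm(q_targets.shape[0], device=q_targets.device)[:num_subsample_critics]
        return q_targets[indices]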
@@ -264,9 +266,9 @@ class SACPolicy(

        # critics subsample size
        min_q, _ = q_targets.min(dim=0)  # Get values from min operation
        if self.config.use_backup_entropy:
-           min_q -= temperature * next_log_probs
-       td_target = rewards + self.config.discount * min_q * ~done
+           min_q = min_q - (temperature * next_log_probs)
+       td_target = rewards + (1 - done) * self.config.discount * min_q

        # 3- compute predicted qs
        q_preds = self.critic_forward(observations, actions, use_target=False)
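This hunk rewrites the TD target: the in-place `-=` becomes an explicit assignment, and the bool mask `~done` is replaced by an arithmetic `(1 - done)` factor. A minimal sketch of the resulting soft target, assuming `rewards`, `min_q`, `next_log_probs`, and a float `done` are all shaped [batch_size]; this illustrates the formula, not the repository's exact code:

    import torch

    def soft_td_target(rewards, min_q, next_log_probs, done, discount, temperature, use_backup_entropy=True):
        # done is expected as float 0.0/1.0; see the ReplayBuffer dtype change below.
        if use_backup_entropy:
            # Soft value of the next state: V(s') = min_i Q_i(s', a') - alpha * log pi(a'|s')
            min_q = min_q - temperature * next_log_probs
        # (1 - done) zeroes the bootstrap term on terminal transitions.
        return rewards + (1 - done) * discount * min_q

    rewards = torch.tensor([1.0, 0.5])
    min_q = torch.tensor([10.0, 8.0])
    next_log_probs = torch.tensor([-1.0, -2.0])
    done = torch.tensor([0.0, 1.0])
    print(soft_td_target(rewards, min_q, next_log_probs, done, discount=0.99, temperature=0.2))
    # tensor([11.0980, 0.5000]) -- the second entry keeps only the reward because done = 1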
@@ -163,7 +163,9 @@ class ReplayBuffer:
        )

        # -- Build batched dones --
-       batch_dones = torch.tensor([t["done"] for t in list_of_transitions], dtype=torch.bool).to(self.device)
+       batch_dones = torch.tensor([t["done"] for t in list_of_transitions], dtype=torch.float32).to(
+           self.device
+       )

        # Return a BatchTransition typed dict
        return BatchTransition(
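Switching the stored dones from bool to float32 is what lets the `(1 - done)` masking above work without an extra cast: PyTorch does not allow the `-` operator on bool tensors, while the old `* ~done` form only worked on bools. A small illustration (not repo code):

    import torch

    done_bool = torch.tensor([False, True, False])
    done_float = done_bool.to(torch.float32)   # what the buffer stores after this commit

    mask = 1 - done_float                      # tensor([1., 0., 1.]) -- usable in the TD target
    # 1 - done_bool                            # RuntimeError: subtraction with a bool tensor is not supported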
@@ -174,48 +176,6 @@ class ReplayBuffer:
            done=batch_dones,
        )

-   # def sample(self, batch_size: int):
-   #     # 1) Randomly sample transitions
-   #     transitions = random.sample(self.memory, batch_size)
-
-   #     # 2) For each key in state_keys, gather states [b, state_dim], next_states [b, state_dim]
-   #     batch_state = {}
-   #     batch_next_state = {}
-   #     for key in self.state_keys:
-   #         batch_state[key] = torch.cat([t["state"][key] for t in transitions], dim=0).to(
-   #             self.device
-   #         )  # shape [b, state_dim, ...] depending on your data
-   #         batch_next_state[key] = torch.cat([t["next_state"][key] for t in transitions], dim=0).to(
-   #             self.device
-   #         )  # shape [b, state_dim, ...]
-
-   #     # 3) Build the other tensors
-   #     batch_action = torch.cat([t["action"] for t in transitions], dim=0).to(
-   #         self.device
-   #     )  # shape [b, ...] or [b, action_dim, ...]
-
-   #     batch_reward = torch.tensor(
-   #         [t["reward"] for t in transitions], dtype=torch.float32, device=self.device
-   #     ).unsqueeze(dim=-1)  # shape [b, 1]
-
-   #     batch_done = torch.tensor(
-   #         [t["done"] for t in transitions], dtype=torch.bool, device=self.device
-   #     )  # shape [b]
-
-   #     # 4) Create the observation and next_observation dicts
-   #     #
-   #     # Each key is stacked along dim=1 so final shape is [b, 2, state_dim, ...]
-   #     # - observation[key][..., 0, :] is the current state
-   #     # - observation[key][..., 1, :] is the next state
-   #     # - next_observation[key] duplicates the next state to shape [b, 2, ...]
-   #     observation = {}
-   #     for key in self.state_keys:
-   #         observation[key] = torch.stack([batch_state[key], batch_next_state[key]], dim=1)
-
-   #     # 5) Return your structure
-   #     ret = observation | {"action": batch_action, "next.reward": batch_reward, "next.done": batch_done}
-   #     return ret
-

def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
    if out_dir is None:
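The deleted block is an older, commented-out `sample` that stacked current and next states along dim=1 into a [b, 2, ...] observation. The live path instead returns a flat `BatchTransition` typed dict (only its `done` field is visible in the context above). A hedged sketch of that shape, where every field name except `done` is an assumption made for illustration:

    from typing import TypedDict
    import torch

    class BatchTransitionSketch(TypedDict):
        # Only `done` is confirmed by the diff; the other fields are illustrative guesses.
        state: dict[str, torch.Tensor]       # each entry shaped [b, ...]
        action: torch.Tensor                 # [b, action_dim]
        reward: torch.Tensor                 # [b]
        next_state: dict[str, torch.Tensor]  # each entry shaped [b, ...]
        done: torch.Tensor                   # [b], float32 after this commit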
@@ -297,8 +257,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
        # NOTE: At some point we should use a wrapper to handle the observation

        if interaction_step >= cfg.training.online_step_before_learning:
-           with torch.inference_mode():
-               action = policy.select_action(batch=obs)
+           action = policy.select_action(batch=obs)
            next_obs, reward, done, truncated, info = online_env.step(action.cpu().numpy())
        else:
            action = online_env.action_space.sample()
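The `torch.inference_mode()` wrapper around action selection is dropped. The commit does not say why; one plausible reason (an assumption, not stated anywhere in this diff) is that tensors created under inference mode cannot participate in autograd later, which becomes a problem once the sampled action is stored in the replay buffer and reused in a gradient step. A small illustration of the difference from `torch.no_grad()`:

    import torch

    x = torch.randn(3)
    with torch.inference_mode():
        a_inf = x * 2        # inference tensor: excluded from autograd for its whole lifetime
    with torch.no_grad():
        a_ngrad = x * 2      # ordinary tensor, simply built without a graph

    print(a_inf.is_inference(), a_ngrad.is_inference())  # True False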