import torch

from go2_gym_learn.utils import split_and_pad_trajectories


class RolloutStorage:
    """Fixed-size rollout buffer for PPO-style training.

    Every buffer is a tensor whose first two dimensions are
    ``(num_transitions_per_env, num_envs)``. Transitions are appended one
    time step at a time with :meth:`add_transitions`; once the buffer is
    full, :meth:`compute_returns` fills ``returns`` / ``advantages`` and the
    mini-batch generators iterate over the stored data.
    """

    class Transition:
        """Plain container for the data of one time step across all envs."""

        def __init__(self):
            self.observations = None
            self.privileged_observations = None
            self.observation_histories = None
            # NOTE: critic_observations is kept on the transition but is not
            # copied into the storage by RolloutStorage.add_transitions.
            self.critic_observations = None
            self.actions = None
            self.rewards = None
            self.dones = None
            self.values = None
            self.actions_log_prob = None
            self.action_mean = None
            self.action_sigma = None
            self.env_bins = None

        def clear(self):
            """Reset every field back to ``None``."""
            self.__init__()

    def __init__(self, num_envs, num_transitions_per_env, obs_shape, privileged_obs_shape,
                 obs_history_shape, actions_shape, device='cpu'):
        """Allocate zero-initialised buffers on ``device``.

        Args:
            num_envs: Number of environments stored per time step.
            num_transitions_per_env: Number of time steps the buffer holds.
            obs_shape: Iterable with the trailing shape of one observation.
            privileged_obs_shape: Trailing shape of one privileged observation.
            obs_history_shape: Trailing shape of one observation history.
            actions_shape: Trailing shape of one action.
            device: Torch device on which all buffers are allocated.
        """
        self.device = device

        self.obs_shape = obs_shape
        self.privileged_obs_shape = privileged_obs_shape
        self.obs_history_shape = obs_history_shape
        self.actions_shape = actions_shape

        # Core
        self.observations = torch.zeros(
            num_transitions_per_env, num_envs, *obs_shape, device=self.device)
        self.privileged_observations = torch.zeros(
            num_transitions_per_env, num_envs, *privileged_obs_shape, device=self.device)
        self.observation_histories = torch.zeros(
            num_transitions_per_env, num_envs, *obs_history_shape, device=self.device)
        self.rewards = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device)
        self.actions = torch.zeros(
            num_transitions_per_env, num_envs, *actions_shape, device=self.device)
        # Done flags are stored as uint8 (``.byte()``), not bool.
        self.dones = torch.zeros(
            num_transitions_per_env, num_envs, 1, device=self.device).byte()

        # For PPO
        self.actions_log_prob = torch.zeros(
            num_transitions_per_env, num_envs, 1, device=self.device)
        self.values = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device)
        self.returns = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device)
        self.advantages = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device)
        self.mu = torch.zeros(
            num_transitions_per_env, num_envs, *actions_shape, device=self.device)
        self.sigma = torch.zeros(
            num_transitions_per_env, num_envs, *actions_shape, device=self.device)
        self.env_bins = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device)

        self.num_transitions_per_env = num_transitions_per_env
        self.num_envs = num_envs

        # Index of the next time-step slot to be written.
        self.step = 0

    def add_transitions(self, transition: Transition):
        """Copy one time step of data into the slot at ``self.step``.

        Rewards, dones, log-probabilities and env bins are reshaped to a
        trailing dimension of 1 before being copied.

        Args:
            transition: Filled ``Transition`` holding one value per environment.

        Raises:
            AssertionError: If the buffer already holds
                ``num_transitions_per_env`` steps.
        """
        if self.step >= self.num_transitions_per_env:
            raise AssertionError("Rollout buffer overflow")
        self.observations[self.step].copy_(transition.observations)
        self.privileged_observations[self.step].copy_(transition.privileged_observations)
        self.observation_histories[self.step].copy_(transition.observation_histories)
        self.actions[self.step].copy_(transition.actions)
        self.rewards[self.step].copy_(transition.rewards.view(-1, 1))
        self.dones[self.step].copy_(transition.dones.view(-1, 1))
        self.values[self.step].copy_(transition.values)
        self.actions_log_prob[self.step].copy_(transition.actions_log_prob.view(-1, 1))
        self.mu[self.step].copy_(transition.action_mean)
        self.sigma[self.step].copy_(transition.action_sigma)
        self.env_bins[self.step].copy_(transition.env_bins.view(-1, 1))
        self.step += 1

    def clear(self):
        """Rewind the write index; stored tensors are left untouched."""
        self.step = 0

    def compute_returns(self, last_values, gamma, lam):
        """Fill ``self.returns`` and ``self.advantages`` from the stored rollout.

        Walks the buffer backwards applying the GAE(lambda) recursion
        ``A_t = delta_t + gamma * lam * not_done_t * A_{t+1}`` with
        ``delta_t = r_t + gamma * not_done_t * V_{t+1} - V_t``.

        Args:
            last_values: Value estimates used as ``V_{t+1}`` for the final
                stored step.
            gamma: Discount factor.
            lam: GAE lambda.
        """
        advantage = 0
        for step in reversed(range(self.num_transitions_per_env)):
            if step == self.num_transitions_per_env - 1:
                next_values = last_values
            else:
                next_values = self.values[step + 1]
            # A done flag at this step cuts both the bootstrap and the
            # accumulated advantage from later steps.
            next_is_not_terminal = 1.0 - self.dones[step].float()
            delta = (self.rewards[step]
                     + next_is_not_terminal * gamma * next_values
                     - self.values[step])
            advantage = delta + next_is_not_terminal * gamma * lam * advantage
            self.returns[step] = advantage + self.values[step]

        # Compute and normalize the advantages
        self.advantages = self.returns - self.values
        self.advantages = (
            (self.advantages - self.advantages.mean()) / (self.advantages.std() + 1e-8)
        )

    def get_statistics(self):
        """Return ``(mean trajectory length, mean reward)`` for the buffer.

        Trajectory lengths are the gaps between consecutive done flags after
        flattening the buffer env-major; the last stored step of every env is
        treated as a trajectory end for this computation only.
        """
        # Work on a copy: writing through ``self.dones`` would permanently mark
        # the last stored step as terminal and corrupt later consumers of the
        # buffer (e.g. the bootstrap in ``compute_returns``).
        done = self.dones.clone()
        done[-1] = 1
        # (time, env, 1) -> (env, time, 1) -> (env * time, 1)
        flat_dones = done.permute(1, 0, 2).reshape(-1, 1)
        # Prepend -1 so the first trajectory's length is measured from index 0.
        done_indices = torch.cat((
            flat_dones.new_tensor([-1], dtype=torch.int64),
            flat_dones.nonzero(as_tuple=False)[:, 0],
        ))
        trajectory_lengths = done_indices[1:] - done_indices[:-1]
        return trajectory_lengths.float().mean(), self.rewards.mean()

    def mini_batch_generator(self, num_mini_batches, num_epochs=8):
        """Yield shuffled flat mini-batches for feed-forward policies.

        The time and env dimensions are flattened together and a single random
        permutation, drawn once, is reused for every epoch. Samples beyond
        ``num_mini_batches * mini_batch_size`` are never visited.

        Args:
            num_mini_batches: Number of mini-batches per epoch.
            num_epochs: Number of passes over the data.

        Yields:
            A 13-tuple ``(obs, critic_obs, privileged_obs, obs_history,
            actions, target_values, advantages, returns, old_actions_log_prob,
            old_mu, old_sigma, None, env_bins)``. ``critic_obs`` is the same
            tensor selection as ``obs``; the twelfth entry is always ``None``.
        """
        batch_size = self.num_envs * self.num_transitions_per_env
        mini_batch_size = batch_size // num_mini_batches
        indices = torch.randperm(
            num_mini_batches*mini_batch_size, requires_grad=False, device=self.device)

        observations = self.observations.flatten(0, 1)
        privileged_obs = self.privileged_observations.flatten(0, 1)
        obs_history = self.observation_histories.flatten(0, 1)
        critic_observations = observations

        actions = self.actions.flatten(0, 1)
        values = self.values.flatten(0, 1)
        returns = self.returns.flatten(0, 1)
        old_actions_log_prob = self.actions_log_prob.flatten(0, 1)
        advantages = self.advantages.flatten(0, 1)
        old_mu = self.mu.flatten(0, 1)
        old_sigma = self.sigma.flatten(0, 1)
        old_env_bins = self.env_bins.flatten(0, 1)

        for epoch in range(num_epochs):
            for i in range(num_mini_batches):
                start = i*mini_batch_size
                end = (i+1)*mini_batch_size
                batch_idx = indices[start:end]

                obs_batch = observations[batch_idx]
                critic_observations_batch = critic_observations[batch_idx]
                privileged_obs_batch = privileged_obs[batch_idx]
                obs_history_batch = obs_history[batch_idx]
                actions_batch = actions[batch_idx]
                target_values_batch = values[batch_idx]
                returns_batch = returns[batch_idx]
                old_actions_log_prob_batch = old_actions_log_prob[batch_idx]
                advantages_batch = advantages[batch_idx]
                old_mu_batch = old_mu[batch_idx]
                old_sigma_batch = old_sigma[batch_idx]
                env_bins_batch = old_env_bins[batch_idx]
                yield (obs_batch, critic_observations_batch, privileged_obs_batch,
                       obs_history_batch, actions_batch, target_values_batch,
                       advantages_batch, returns_batch, old_actions_log_prob_batch,
                       old_mu_batch, old_sigma_batch, None, env_bins_batch)

    # for RNNs only
    def reccurent_mini_batch_generator(self, num_mini_batches, num_epochs=8):
        """Yield per-environment mini-batches for recurrent policies.

        Observation-like buffers are passed through
        ``split_and_pad_trajectories`` (defined in ``go2_gym_learn.utils``, not
        shown here; presumably it splits at done flags and pads to a common
        length - confirm there) and sliced by trajectory index. All other
        buffers are sliced by environment range along dimension 1.

        Args:
            num_mini_batches: Number of mini-batches per epoch; environments
                are divided into this many contiguous groups.
            num_epochs: Number of passes over the data.

        Yields:
            A 12-tuple ``(obs, critic_obs, privileged_obs, obs_history,
            actions, values, advantages, returns, old_actions_log_prob,
            old_mu, old_sigma, masks)``.
        """
        padded_obs_trajectories, trajectory_masks = split_and_pad_trajectories(
            self.observations, self.dones)
        padded_privileged_obs_trajectories, trajectory_masks = split_and_pad_trajectories(
            self.privileged_observations, self.dones)
        padded_obs_history_trajectories, trajectory_masks = split_and_pad_trajectories(
            self.observation_histories, self.dones)
        padded_critic_obs_trajectories = padded_obs_trajectories

        mini_batch_size = self.num_envs // num_mini_batches

        # Loop-invariant: True wherever a new trajectory starts, i.e. at the
        # first stored step and right after every done flag.
        dones = self.dones.squeeze(-1)
        last_was_done = torch.zeros_like(dones, dtype=torch.bool)
        last_was_done[1:] = dones[:-1]
        last_was_done[0] = True

        for ep in range(num_epochs):
            first_traj = 0
            for i in range(num_mini_batches):
                start = i*mini_batch_size
                stop = (i+1)*mini_batch_size

                # Number of trajectories that belong to envs [start, stop).
                trajectories_batch_size = torch.sum(last_was_done[:, start:stop])
                last_traj = first_traj + trajectories_batch_size

                masks_batch = trajectory_masks[:, first_traj:last_traj]
                obs_batch = padded_obs_trajectories[:, first_traj:last_traj]
                critic_obs_batch = padded_critic_obs_trajectories[:, first_traj:last_traj]
                privileged_obs_batch = padded_privileged_obs_trajectories[:, first_traj:last_traj]
                obs_history_batch = padded_obs_history_trajectories[:, first_traj:last_traj]

                actions_batch = self.actions[:, start:stop]
                old_mu_batch = self.mu[:, start:stop]
                old_sigma_batch = self.sigma[:, start:stop]
                returns_batch = self.returns[:, start:stop]
                advantages_batch = self.advantages[:, start:stop]
                values_batch = self.values[:, start:stop]
                old_actions_log_prob_batch = self.actions_log_prob[:, start:stop]

                yield (obs_batch, critic_obs_batch, privileged_obs_batch, obs_history_batch,
                       actions_batch, values_batch, advantages_batch, returns_batch,
                       old_actions_log_prob_batch, old_mu_batch, old_sigma_batch, masks_batch)

                first_traj = last_traj