# License: see [LICENSE, LICENSES/rsl_rl/LICENSE]

import torch


def split_and_pad_trajectories(tensor, dones):
    """Split per-environment rollouts at episode ends and zero-pad the pieces.

    The input is laid out as ``[time, num_envs, *extra_dims]``. Each
    environment's time series is cut after every step whose ``dones`` flag is
    set; the final time step is always treated as a cut. The resulting pieces
    are stacked along a trajectory dimension and zero-padded to
    ``tensor.shape[0]`` time steps. A boolean mask marks the valid
    (non-padding) steps.

    Example (time runs left to right, ``|`` marks a done)::

        Input:  [[a1, a2, a3, a4 | a5, a6],
                 [b1, b2 | b3, b4, b5 | b6]]

        Output (shown trajectory-major; the returned tensors are time-major):
            [[a1, a2, a3, a4, 0, 0],      [[T, T, T, T, F, F],
             [a5, a6, 0,  0,  0, 0],       [T, T, F, F, F, F],
             [b1, b2, 0,  0,  0, 0],       [T, T, F, F, F, F],
             [b3, b4, b5, 0,  0, 0],       [T, T, T, F, F, F],
             [b6, 0,  0,  0,  0, 0]]       [T, F, F, F, F, F]]

    Args:
        tensor: data of shape ``[time, num_envs, *extra_dims]``.
        dones: done flags of shape ``[time, num_envs]`` or
            ``[time, num_envs, 1]``. Not modified; a copy is edited.

    Returns:
        A tuple ``(padded_trajectories, trajectory_masks)``:
        ``padded_trajectories`` has shape
        ``[time, num_trajectories, *extra_dims]`` and is zero-padded.
        ``trajectory_masks`` is a bool tensor of shape
        ``[time, num_trajectories]`` that is True on valid steps.
    """
    dones = dones.clone()
    # Force a boundary at the final step so no trajectory spans two environments.
    dones[-1] = 1
    # Permute to (num_envs, time, ...) so flattening concatenates each env's time series.
    flat_dones = dones.transpose(1, 0).reshape(-1, 1)

    # Trajectory lengths are the gaps between consecutive done indices; the leading -1
    # makes the first trajectory start at flat index 0.
    done_indices = torch.cat(
        (flat_dones.new_tensor([-1], dtype=torch.int64), flat_dones.nonzero()[:, 0])
    )
    trajectory_lengths = done_indices[1:] - done_indices[:-1]
    trajectory_lengths_list = trajectory_lengths.tolist()

    # Extract the individual trajectories and pad them along a new trajectory dimension.
    trajectories = torch.split(tensor.transpose(1, 0).flatten(0, 1), trajectory_lengths_list)
    padded_trajectories = torch.nn.utils.rnn.pad_sequence(trajectories)

    # pad_sequence pads only to the longest trajectory. That is shorter than the
    # rollout whenever every environment has an episode end before the last step.
    # Pad the remainder so the time dimension always matches trajectory_masks.
    num_steps = tensor.shape[0]
    missing_steps = num_steps - padded_trajectories.shape[0]
    if missing_steps > 0:
        filler = padded_trajectories.new_zeros(
            (missing_steps, *padded_trajectories.shape[1:])
        )
        padded_trajectories = torch.cat((padded_trajectories, filler), dim=0)

    # Step t of trajectory k is valid iff t < length_k.
    trajectory_masks = trajectory_lengths > torch.arange(
        0, num_steps, device=tensor.device
    ).unsqueeze(1)
    return padded_trajectories, trajectory_masks


def unpad_trajectories(trajectories, masks):
    """Invert split_and_pad_trajectories().

    Args:
        trajectories: padded tensor of shape
            ``[time, num_trajectories, *extra_dims]``.
        masks: bool tensor of shape ``[time, num_trajectories]`` marking valid steps.
            Assumes it was produced by split_and_pad_trajectories(), so that each
            environment contributes exactly ``time`` valid steps.

    Returns:
        Tensor of shape ``[time, num_envs, *extra_dims]``.
    """
    num_steps = trajectories.shape[0]
    # Transpose to trajectory-major so the boolean mask yields steps in (env, time)
    # order, then regroup into one row of num_steps per environment and transpose back.
    valid_steps = trajectories.transpose(1, 0)[masks.transpose(1, 0)]
    return valid_steps.view(-1, num_steps, *trajectories.shape[2:]).transpose(1, 0)