# License: see [LICENSE, LICENSES/rsl_rl/LICENSE]
|
||
|
|
||
|
import torch
|
||
|
|
||
|
def split_and_pad_trajectories(tensor, dones):
    """Split trajectories at done indices, then stack them zero-padded along time.

    Each environment's time series is cut after every step whose ``done`` flag is
    nonzero (the last time step is always treated as a cut). Every resulting
    trajectory is zero-padded to the full time length ``tensor.shape[0]``, so the
    padded tensor and the returned masks always share the same time dimension.

    Example (trajectories drawn as rows, time running left to right, ``|`` marks a
    cut after a done; the returned tensors are time-major, i.e. the transpose of
    this drawing)::

        Input:  [ [a1, a2, a3, a4 | a5, a6],
                  [b1, b2 | b3, b4, b5 | b6] ]

        Padded: [ [a1, a2, a3, a4, 0,  0],
                  [a5, a6, 0,  0,  0,  0],
                  [b1, b2, 0,  0,  0,  0],
                  [b3, b4, b5, 0,  0,  0],
                  [b6, 0,  0,  0,  0,  0] ]

        Masks:  [ [True, True, True,  True,  False, False],
                  [True, True, False, False, False, False],
                  [True, True, False, False, False, False],
                  [True, True, True,  False, False, False],
                  [True, False, False, False, False, False] ]

    Args:
        tensor: Tensor with dimension order ``[time, num_envs, additional dims...]``.
        dones: Done flags laid out as ``[time, num_envs]`` (nonzero = cut after that
            step). It is cloned, so the caller's tensor is not modified.

    Returns:
        A tuple ``(padded_trajectories, trajectory_masks)``:
        ``padded_trajectories`` has shape ``[time, num_trajectories, additional dims...]``
        and the same dtype/device as ``tensor``; ``trajectory_masks`` is a boolean
        tensor of shape ``[time, num_trajectories]`` that is True on valid steps.
    """
    dones = dones.clone()
    # Force a cut at the final step so every environment's last trajectory is closed.
    dones[-1] = 1
    # Permute to (num_envs, time) before flattening so each env's steps are contiguous.
    flat_dones = dones.transpose(1, 0).reshape(-1, 1)

    # Trajectory lengths are the gaps between successive done indices; the leading -1
    # makes the first trajectory start at flat index 0.
    done_indices = torch.cat(
        (flat_dones.new_tensor([-1], dtype=torch.int64), flat_dones.nonzero()[:, 0])
    )
    trajectory_lengths = done_indices[1:] - done_indices[:-1]
    trajectory_lengths_list = trajectory_lengths.tolist()

    # Extract the individual trajectories from the env-major flattened tensor.
    trajectories = torch.split(tensor.transpose(1, 0).flatten(0, 1), trajectory_lengths_list)

    # pad_sequence only pads to the longest trajectory, which can be shorter than the
    # time dimension the masks are built on. Append one full-length zero filler so the
    # padding always reaches tensor.shape[0], then drop that filler column again.
    num_steps = tensor.shape[0]
    filler = tensor.new_zeros((num_steps, *tensor.shape[2:]))
    padded_trajectories = torch.nn.utils.rnn.pad_sequence(trajectories + (filler,))[:, :-1]

    # mask[t, j] is True iff step t lies within trajectory j.
    trajectory_masks = trajectory_lengths > torch.arange(
        0, num_steps, device=tensor.device
    ).unsqueeze(1)
    return padded_trajectories, trajectory_masks
|
||
|
|
||
|
def unpad_trajectories(trajectories, masks):
    """Invert ``split_and_pad_trajectories``.

    Args:
        trajectories: Padded tensor of shape
            ``[time, num_trajectories, additional dims...]``.
        masks: Boolean validity mask of shape ``[time, num_trajectories]``, as
            returned by ``split_and_pad_trajectories``.

    Returns:
        Tensor of shape ``[time, num_envs, additional dims...]`` containing only the
        valid (masked) steps, regrouped per environment.
    """
    num_steps = trajectories.shape[0]
    # Transpose to trajectory-major so boolean masking emits each trajectory's valid
    # steps in order; for masks from split_and_pad_trajectories, consecutive
    # trajectories of one env then concatenate back into that env's time series.
    valid_steps = trajectories.transpose(1, 0)[masks.transpose(1, 0)]
    # Regroup as (num_envs, time, ...) keeping every trailing dimension, then
    # restore the time-major layout.
    return valid_steps.view(-1, num_steps, *trajectories.shape[2:]).transpose(1, 0)
|