walk-these-ways-go2/go1_gym_learn/utils/utils.py

43 lines
2.1 KiB
Python
Raw Normal View History

2024-01-28 17:11:38 +08:00
# License: see [LICENSE, LICENSES/rsl_rl/LICENSE]
import torch
def split_and_pad_trajectories(tensor, dones):
    """Split rollouts at done flags and zero-pad the pieces to a common length.

    Each environment's time series is cut after every step flagged in ``dones``;
    the last step of the rollout always closes the current trajectory. The
    pieces are stacked along a new trajectory axis and padded with zeros up to
    the full rollout length ``tensor.shape[0]``, so the padded tensor and the
    mask always agree on their first two dimensions (also when every
    trajectory is shorter than the rollout).

    Example (rollout of 6 steps, 2 envs, ``|`` marks a done). The output is
    drawn one trajectory per row for readability; the returned tensors are
    time-major, i.e. the transpose of this picture:

        Input:  [[a1, a2, a3, a4 | a5, a6],
                 [b1, b2 | b3, b4, b5 | b6]]

        Output: [[a1, a2, a3, a4, 0, 0],   |  [[True, True, True,  True,  False, False],
                 [a5, a6, 0,  0,  0, 0],   |   [True, True, False, False, False, False],
                 [b1, b2, 0,  0,  0, 0],   |   [True, True, False, False, False, False],
                 [b3, b4, b5, 0,  0, 0],   |   [True, True, True,  False, False, False],
                 [b6, 0,  0,  0,  0, 0]]   |   [True, False, False, False, False, False]]

    Args:
        tensor: data ordered [time, number of envs, additional dimensions].
        dones: done flags ordered [time, number of envs] (a trailing singleton
            dimension is accepted). The caller's tensor is not modified.

    Returns:
        A tuple ``(padded_trajectories, trajectory_masks)``:
        ``padded_trajectories`` is ordered [time, number of trajectories,
        additional dimensions]; ``trajectory_masks`` is a boolean tensor
        ordered [time, number of trajectories] that is True on valid steps.
    """
    num_steps = tensor.shape[0]

    # Work on a copy; forcing a done on the last step closes every env's final trajectory.
    dones = dones.clone()
    dones[-1] = 1

    # Permute to (num_envs, num_transitions_per_env, ...) so that flattening keeps
    # each environment's steps contiguous.
    flat_dones = dones.transpose(1, 0).reshape(-1, 1)

    # A trajectory's length is the distance between consecutive done positions;
    # the leading -1 makes the first trajectory start at flat index 0.
    done_indices = torch.cat((flat_dones.new_tensor([-1], dtype=torch.int64), flat_dones.nonzero()[:, 0]))
    trajectory_lengths = done_indices[1:] - done_indices[:-1]

    # Extract the individual trajectories and stack them, zero-padded to the longest one.
    trajectories = torch.split(tensor.transpose(1, 0).flatten(0, 1), trajectory_lengths.tolist())
    padded_trajectories = torch.nn.utils.rnn.pad_sequence(trajectories)

    # pad_sequence stops at the longest trajectory, while the mask below always has
    # num_steps rows. Extend with zeros so both line up on the time axis.
    missing_steps = num_steps - padded_trajectories.shape[0]
    if missing_steps > 0:
        filler = padded_trajectories.new_zeros((missing_steps, *padded_trajectories.shape[1:]))
        padded_trajectories = torch.cat((padded_trajectories, filler), dim=0)

    trajectory_masks = trajectory_lengths > torch.arange(0, num_steps, device=tensor.device).unsqueeze(1)
    return padded_trajectories, trajectory_masks
def unpad_trajectories(trajectories, masks):
    """Do the inverse operation of split_and_pad_trajectories().

    Drops the padded steps and regroups the valid ones per environment.

    Args:
        trajectories: padded tensor ordered [time, number of trajectories,
            additional dimensions]. Any number of additional dimensions
            (including none) is supported. The first dimension is used as the
            number of steps per environment when regrouping, so it must equal
            the rollout length.
        masks: boolean tensor ordered [time, number of trajectories], True on
            the valid (non-padded) steps.

    Returns:
        Tensor ordered [time, number of envs, additional dimensions].
    """
    num_steps = trajectories.shape[0]
    # Need to transpose before and after the masking to have proper reshaping:
    # trajectory-major order makes the selected steps come out trajectory by
    # trajectory, each in time order.
    valid_steps = trajectories.transpose(1, 0)[masks.transpose(1, 0)]
    # Regroup into blocks of num_steps steps; all trailing dimensions are kept as-is.
    per_env = valid_steps.view(-1, num_steps, *trajectories.shape[2:])
    return per_env.transpose(1, 0)