rsl_rl/rsl_rl/utils/recurrency.py

70 lines
3.3 KiB
Python

import torch
from typing import Tuple
def trajectories_to_transitions(trajectories: torch.Tensor, data: Tuple[torch.Tensor, int, bool]) -> torch.Tensor:
    """Unpacks zero-padded trajectories back into a tensor of transitions.

    This reverts ``transitions_to_trajectories``: only the entries whose mask value equals 1.0 are kept, and they
    are regrouped into ``batch_size`` rows.

    Args:
        trajectories (torch.Tensor): Padded trajectories of shape (time, trajectory_count, *), or
            (trajectory_count, time, *) when the ``batch_first`` flag in ``data`` is True.
        data (Tuple[torch.Tensor, int, bool]): The ``(mask, batch_size, batch_first)`` tuple returned alongside the
            trajectories by ``transitions_to_trajectories``. The mask has the same two leading dimensions as
            ``trajectories``.

    Returns:
        A tensor of transitions of shape (batch_size, time, *).
    """
    mask, batch_size, batch_first = data
    if batch_first:
        padded, valid = trajectories, mask
    else:
        # Put the trajectory axis first so boolean indexing walks each trajectory in order.
        padded = trajectories.transpose(0, 1)
        valid = mask.transpose(0, 1)
    kept = padded[valid == 1.0]
    return kept.reshape(batch_size, -1, *padded.shape[2:])
def transitions_to_trajectories(
    transitions: torch.Tensor, dones: torch.Tensor, batch_first: bool = False
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, int, bool]]:
    """Packs a tensor of transitions into a zero-padded tensor of trajectories.

    A trajectory ends at every transition whose ``dones`` entry is non-zero, and also at the last time step of every
    batch row, so a trajectory never spans two rows.

    Example:
        >>> transitions = torch.tensor([[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [11, 12]]])
        >>> dones = torch.tensor([[0, 0, 1], [0, 1, 0]])
        >>> trajectories, (mask, batch_size, batch_first) = transitions_to_trajectories(
        ...     transitions, dones, batch_first=True
        ... )
        >>> trajectories.tolist()
        [[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [0, 0]], [[11, 12], [0, 0], [0, 0]]]
        >>> mask.tolist()
        [[1.0, 1.0, 1.0], [1.0, 1.0, 0.0], [1.0, 0.0, 0.0]]
        >>> batch_size, batch_first
        (2, True)

    Args:
        transitions (torch.Tensor): Tensor of transitions of shape (batch_size, time, *).
        dones (torch.Tensor): Tensor of transition terminations of shape (batch_size, time). Any dtype is accepted
            (bool, integer, float); non-zero entries mark the end of a trajectory.
        batch_first (bool): Whether the first dimension of the output tensor should be the batch dimension. Defaults to
            False.

    Returns:
        A tuple ``(trajectories, (mask, batch_size, batch_first))``.

        - ``trajectories`` has shape (time, trajectory_count, *) and is padded with zeros. If ``batch_first`` is True,
          its shape is (trajectory_count, time, *).
        - ``mask`` is a float tensor with the same two leading dimensions, holding 1.0 on real steps and 0.0 on padding.
        - The second element as a whole is the data needed to revert the operation.
    """
    batch_size = transitions.shape[0]
    # Force a done on the last time step of every row so that trajectories never cross row boundaries.
    padded_dones = dones.clone()
    padded_dones[:, -1] = 1
    # Row-major flat indices of every done; they line up with ``transitions.flatten(0, 1)``.
    done_indices = padded_dones.reshape(-1, 1).nonzero()[:, 0]
    # Each trajectory length is the distance to the previous done, with a virtual done at index -1 anchoring the first
    # trajectory. The anchor is created in the (int64) dtype of ``done_indices``, not in ``dones``' dtype, because
    # bool or uint8 dones cannot represent -1.
    boundaries = torch.cat((done_indices.new_tensor([-1]), done_indices))
    trajectory_lengths = boundaries[1:] - boundaries[:-1]
    # Split the flattened transitions into one tensor per trajectory and zero-pad them to a common length.
    trajectory_list = torch.split(transitions.flatten(0, 1), trajectory_lengths.tolist())
    trajectories = torch.nn.utils.rnn.pad_sequence(trajectory_list, batch_first=batch_first)
    # mask[i, t] is 1.0 while step t lies inside trajectory i and 0.0 in the padding. The (1, max_length) step
    # counter broadcasts against the (trajectory_count, 1) lengths.
    max_length = int(trajectory_lengths.max())
    steps = torch.arange(max_length, device=trajectory_lengths.device)
    mask = (steps.unsqueeze(0) < trajectory_lengths.unsqueeze(1)).float()
    if not batch_first:
        mask = mask.T
    return trajectories, (mask, batch_size, batch_first)