70 lines
3.3 KiB
Python
70 lines
3.3 KiB
Python
import torch
|
|
from typing import Tuple
|
|
|
|
|
|
def trajectories_to_transitions(trajectories: torch.Tensor, data: Tuple[torch.Tensor, int, bool]) -> torch.Tensor:
    """Unpacks a zero-padded tensor of trajectories back into a tensor of transitions.

    This reverses ``transitions_to_trajectories``: ``data`` is the ``(mask, batch_size, batch_first)`` tuple that
    function returns alongside the trajectories.

    Args:
        trajectories (torch.Tensor): Padded trajectories of shape (time, trajectory_count, *), or
            (trajectory_count, time, *) when the ``batch_first`` flag stored in ``data`` is True.
        data (Tuple[torch.Tensor, int, bool]): ``(mask, batch_size, batch_first)``. ``mask`` marks valid
            (non-padding) steps with 1.0 and has the same layout as the first two dimensions of ``trajectories``;
            ``batch_size`` is the leading dimension of the output; ``batch_first`` states which layout
            ``trajectories`` and ``mask`` use.

    Returns:
        A tensor of transitions of shape (batch_size, time, *).
    """
    mask, batch_size, batch_first = data

    # Move to (trajectory_count, time, *) so boolean indexing walks the valid steps trajectory by trajectory,
    # which reproduces the original flattened transition order.
    traj_major = trajectories if batch_first else trajectories.transpose(0, 1)
    valid_steps = (mask if batch_first else mask.transpose(0, 1)) == 1.0

    feature_shape = traj_major.shape[2:]
    return traj_major[valid_steps].reshape(batch_size, -1, *feature_shape)
|
|
|
|
|
|
def transitions_to_trajectories(
    transitions: torch.Tensor, dones: torch.Tensor, batch_first: bool = False
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, int, bool]]:
    """Packs a tensor of transitions into a zero-padded tensor of trajectories.

    Each row of ``transitions`` is split into trajectories after every step whose ``dones`` entry is non-zero. The
    last step of every row is always treated as a trajectory end, so no trajectory spans two rows. The resulting
    trajectories are zero-padded to the length of the longest one.

    Example:
        >>> transitions = torch.tensor([[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [11, 12]]])
        >>> dones = torch.tensor([[0, 0, 1], [0, 1, 0]])
        >>> trajectories, (mask, batch_size, batch_first) = transitions_to_trajectories(
        ...     transitions, dones, batch_first=True
        ... )
        >>> trajectories.tolist()
        [[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [0, 0]], [[11, 12], [0, 0], [0, 0]]]
        >>> mask.tolist()
        [[1.0, 1.0, 1.0], [1.0, 1.0, 0.0], [1.0, 0.0, 0.0]]
        >>> batch_size, batch_first
        (2, True)

    Args:
        transitions (torch.Tensor): Tensor of transitions of shape (batch_size, time, *).
        dones (torch.Tensor): Tensor of transition terminations of shape (batch_size, time). Bool, integer and
            floating point dtypes are accepted; non-zero entries mark the last step of a trajectory.
        batch_first (bool): Whether the first dimension of the output tensor should be the trajectory dimension.
            Defaults to False.

    Returns:
        A tuple ``(trajectories, data)``. ``trajectories`` is zero-padded and has shape
        (max_length, trajectory_count, *), or (trajectory_count, max_length, *) if batch_first is True.
        ``data`` is ``(mask, batch_size, batch_first)``, where ``mask`` is a float tensor on ``dones``' device
        with the same layout as the first two dimensions of ``trajectories`` (1.0 for real steps, 0.0 for
        padding). Pass ``data`` to ``trajectories_to_transitions`` to revert the operation.
    """
    batch_size = transitions.shape[0]

    # Force a trajectory end at the last step of every row so trajectories never span two rows of the batch.
    padded_dones = dones.clone()
    padded_dones[:, -1] = 1

    # Row-major flat indices of every trajectory end. ``nonzero`` always returns int64 indices, so the -1 sentinel
    # is derived from them. A sentinel derived from ``dones`` (e.g. bool or uint8) could not hold the value -1.
    done_indices = padded_dones.reshape(-1).nonzero().squeeze(1)
    boundaries = torch.cat((done_indices.new_tensor([-1]), done_indices))
    trajectory_lengths = boundaries[1:] - boundaries[:-1]

    # Split the flattened transitions (same row-major order as the dones) into trajectories and zero-pad them.
    trajectory_list = torch.split(transitions.flatten(0, 1), trajectory_lengths.tolist())
    trajectories = torch.nn.utils.rnn.pad_sequence(trajectory_list, batch_first=batch_first)

    # mask[i, t] is 1.0 when step t lies inside trajectory i. The step indices are created directly on the device of
    # ``dones`` (where ``trajectory_lengths`` lives), which works for any device type, not only CUDA.
    steps = torch.arange(int(trajectory_lengths.max()), device=dones.device)
    mask = (trajectory_lengths.unsqueeze(1) > steps.unsqueeze(0)).float()

    if not batch_first:
        mask = mask.T

    return trajectories, (mask, batch_size, batch_first)
|