rsl_rl/tests/test_trajectory_conversion.py

34 lines
1.2 KiB
Python

import torch
import unittest
from rsl_rl.utils.recurrency import trajectories_to_transitions, transitions_to_trajectories
class TrajectoryConversionTest(unittest.TestCase):
def test_basic_conversion(self):
input = torch.rand(128, 24)
dones = (torch.rand(128, 24) > 0.8).float()
trajectories, data = transitions_to_trajectories(input, dones)
transitions = trajectories_to_transitions(trajectories, data)
self.assertTrue(torch.allclose(input, transitions))
def test_2d_observations(self):
input = torch.rand(128, 24, 32)
dones = (torch.rand(128, 24) > 0.8).float()
trajectories, data = transitions_to_trajectories(input, dones)
transitions = trajectories_to_transitions(trajectories, data)
self.assertTrue(torch.allclose(input, transitions))
def test_batch_first(self):
input = torch.rand(128, 24, 32)
dones = (torch.rand(128, 24) > 0.8).float()
trajectories, data = transitions_to_trajectories(input, dones, batch_first=True)
transitions = trajectories_to_transitions(trajectories, data)
self.assertTrue(torch.allclose(input, transitions))