# rsl_rl/tests/test_dppo.py
# Repository viewer metadata: 279 lines, 10 KiB, Python.

import torch
import unittest
from rsl_rl.algorithms import DPPO
from rsl_rl.env.pole_balancing import PoleBalancing
class FakeCritic(torch.nn.Module):
    """Stub critic that ignores its input and always answers with a fixed tensor.

    The tests assign an instance to ``dppo.critic`` so that the values used for
    bootstrapping are known in advance.
    """

    def __init__(self, values, quantile_count=1):
        """Store the canned output.

        Args:
            values: Tensor returned (or averaged over its last axis) by ``forward``.
            quantile_count: Stored as ``self.quantile_count``; not used by this class itself.
        """
        # torch.nn.Module keeps its hooks and parameter registries in state created by its own
        # __init__; without this call the instance cannot be invoked as ``critic(obs)``.
        super().__init__()

        self.quantile_count = quantile_count
        self.recurrent = False
        self.values = values
        self.last_quantiles = values

    def forward(self, _, distribution=False, measure_args=None):
        """Return the stored tensor, or its mean over the last axis.

        Args:
            _: Ignored; present only to match the call signature.
            distribution: When true, return the stored tensor unchanged.
            measure_args: Accepted and ignored.

        Returns:
            ``self.values`` if ``distribution`` is true, otherwise ``self.values.mean(-1)``.
        """
        if distribution:
            return self.values

        return self.values.mean(-1)

    def quantiles_to_values(self, quantiles):
        """Reduce ``quantiles`` to values by averaging over the last axis."""
        return quantiles.mean(-1)
class DPPOTest(unittest.TestCase):
    """Checks the output of ``DPPO._process_dataset`` against explicitly specified expectations."""

    @staticmethod
    def _build_dataset(dones, rewards, values, value_quants):
        """Slice per-environment rollout tensors into the per-step list given to ``_process_dataset``.

        Args:
            dones: Tensor indexed as ``[environment, step]``.
            rewards: Tensor indexed as ``[environment, step]``.
            values: Tensor indexed as ``[environment, step]``.
            value_quants: Tensor indexed as ``[environment, step, quantile]``.

        Returns:
            A list with one dict per step, each holding that step's slice of every tensor.
        """
        environment_count, step_count = dones.shape[0], dones.shape[1]
        # Observations and timeouts are all-zero placeholders in every test of this class.
        # NOTE(review): the trailing 24 is presumably the critic observation size - confirm.
        observations = torch.zeros((environment_count, step_count, 24))
        timeouts = torch.zeros((environment_count, step_count))

        return [
            {
                "dones": dones[:, i],
                "full_next_critic_observations": observations[:, i].clone(),
                "next_critic_observations": observations[:, i],
                "rewards": rewards[:, i],
                "timeouts": timeouts[:, i],
                "values": values[:, i],
                "value_quants": value_quants[:, i],
            }
            for i in range(step_count)
        ]

    @staticmethod
    def _quantile_rollout():
        """Return the ``(rewards, dones, value_quants)`` rollout shared by the two value-target tests."""
        rewards = torch.tensor(
            [
                [0.6, 1.0, 0.0, 0.6, 0.1, -0.2],
                [0.1, -0.2, -0.4, 0.1, 1.0, 1.0],
            ]
        )
        dones = torch.tensor(
            [
                [1.0, 0.0, 0.0, 0.0, 1.0, 0.0],
                [0.0, 0.0, 1.0, 0.0, 1.0, 0.0],
            ]
        )
        value_quants = torch.tensor(
            [
                [
                    [-0.4, 0.0, 0.1],
                    [0.9, 0.8, 0.7],
                    [0.7, 1.3, 0.0],
                    [1.2, 0.4, 1.2],
                    [1.3, 1.3, 1.1],
                    [0.0, 0.7, 0.5],
                ],
                [
                    [1.3, 1.3, 0.9],
                    [0.4, -0.4, -0.1],
                    [0.4, 0.6, 0.1],
                    [0.7, 0.1, 0.3],
                    [0.2, 0.1, 0.3],
                    [1.4, 1.4, -0.3],
                ],
            ]
        )

        return rewards, dones, value_quants

    @staticmethod
    def _random_last_values(environment_count, quantile_count):
        """Draw bootstrap quantiles from a locally seeded generator.

        A private generator keeps the fixture reproducible without touching the global RNG state.
        """
        generator = torch.Generator().manual_seed(0)

        return torch.rand((environment_count, quantile_count), generator=generator)

    def test_gae_computation(self):
        """Returns and normalized advantages must match the hard-coded reference values."""
        # GT taken from old PPO implementation.
        env = PoleBalancing(environment_count=4)
        dppo = DPPO(env, device="cpu", gae_lambda=0.97, gamma=0.99, qrdqn_quantile_count=1)

        rewards = torch.tensor(
            [
                [-1.0000e02, -1.4055e-01, -3.0476e-02, -2.7149e-01, -1.1157e-01, -2.3366e-01, -3.3658e-01, -1.6447e-01],
                [
                    -1.7633e-01,
                    -2.6533e-01,
                    -3.0786e-01,
                    -2.6038e-01,
                    -2.7176e-01,
                    -2.1655e-01,
                    -1.5441e-01,
                    -2.9580e-01,
                ],
                [-1.5952e-01, -1.5177e-01, -1.4296e-01, -1.6131e-01, -3.1395e-02, 2.8808e-03, -3.1242e-02, 4.8696e-03],
                [1.1407e-02, -1.0000e02, -6.2290e-02, -3.7030e-01, -2.7648e-01, -3.6655e-01, -2.8456e-01, -2.3165e-01],
            ]
        )
        dones = torch.tensor(
            [
                [1, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0],
                [0, 1, 0, 0, 0, 0, 0, 0],
            ]
        )
        values = torch.tensor(
            [
                [-4.6342, -7.6510, -7.0166, -7.6137, -7.4130, -7.7071, -7.7413, -7.8301],
                [-7.0442, -7.0032, -6.9321, -6.7765, -6.5433, -6.3503, -6.2529, -5.9337],
                [-7.5753, -7.8146, -7.6142, -7.8443, -7.8791, -7.7973, -7.7853, -7.7724],
                [-6.4326, -6.1673, -7.6511, -7.7505, -8.0004, -7.8584, -7.5949, -7.9023],
            ]
        )
        # With a single quantile, the quantile tensors are the values with a trailing axis of size one.
        value_quants = values.unsqueeze(-1)
        last_values = torch.tensor([-7.9343, -5.8734, -7.8527, -8.1257])
        dppo.critic = FakeCritic(last_values.unsqueeze(-1))

        dataset = self._build_dataset(dones, rewards, values, value_quants)
        processed_dataset = dppo._process_dataset(dataset)

        step_count = dones.shape[1]
        # The return of a step is reconstructed as advantage + value.
        processed_returns = torch.stack(
            [processed_dataset[i]["advantages"] + processed_dataset[i]["values"] for i in range(step_count)],
            dim=-1,
        )
        processed_advantages = torch.stack(
            [processed_dataset[i]["normalized_advantages"] for i in range(step_count)], dim=-1
        )

        expected_returns = torch.tensor(
            [
                [-100.0000, -8.4983, -8.4863, -8.5699, -8.4122, -8.4054, -8.2702, -8.0194],
                [-7.2900, -7.1912, -6.9978, -6.7569, -6.5627, -6.3547, -6.1985, -6.1104],
                [-7.9179, -7.8374, -7.7679, -7.6976, -7.6041, -7.6446, -7.7229, -7.7693],
                [-96.2018, -100.0000, -9.0710, -9.1415, -8.8863, -8.7228, -8.4668, -8.2761],
            ]
        )
        expected_advantages = torch.tensor(
            [
                [-3.1452, 0.3006, 0.2779, 0.2966, 0.2951, 0.3060, 0.3122, 0.3246],
                [0.3225, 0.3246, 0.3291, 0.3322, 0.3308, 0.3313, 0.3335, 0.3250],
                [0.3190, 0.3307, 0.3259, 0.3368, 0.3415, 0.3371, 0.3338, 0.3316],
                [-2.9412, -3.0893, 0.2797, 0.2808, 0.2992, 0.3000, 0.2997, 0.3179],
            ]
        )

        self.assertTrue(torch.isclose(processed_returns, expected_returns, atol=1e-4).all())
        self.assertTrue(torch.isclose(processed_advantages, expected_advantages, atol=1e-4).all())

    def test_target_computation_nstep(self):
        """With ``value_lambda=1.0`` the value-target quantiles must equal the n-step targets built below."""
        # NOTE(review): the original comment here said the ground truth was taken from the old PPO
        # implementation, but the expectation is assembled inline below - confirm which is intended.
        env = PoleBalancing(environment_count=2)
        dppo = DPPO(env, device="cpu", gae_lambda=0.97, gamma=0.99, qrdqn_quantile_count=3, value_lambda=1.0)

        rewards, dones, value_quants = self._quantile_rollout()
        values = value_quants.mean(dim=-1)
        last_values = self._random_last_values(dones.shape[0], 3)
        dppo.critic = FakeCritic(last_values, quantile_count=3)

        dataset = self._build_dataset(dones, rewards, values, value_quants)
        processed_dataset = dppo._process_dataset(dataset)
        processed_value_target_quants = torch.stack(
            [processed_dataset[i]["value_target_quants"] for i in range(dones.shape[1])],
            dim=-2,
        )

        # N-step returns
        # These exclude the reward received on the final step since it should not be added to the value target.
        expected_returns = torch.tensor(
            [
                [0.6, 1.58806, 0.594, 0.6, 0.1, -0.2],
                [-0.098, -0.2, -0.4, 0.1, 1.0, 1.0],
            ]
        )

        def reset(quantiles):
            """Zero contribution; used below on the steps where ``dones`` is 1."""
            return [0.0 for _ in quantiles]

        def discount(quantiles, steps=0):
            """Scale every quantile by ``dppo.gamma ** steps``."""
            return [quantile * dppo.gamma**steps for quantile in quantiles]

        expected_value_target_quants = expected_returns.unsqueeze(-1) + torch.tensor(
            [
                [
                    reset([-0.4, 0.0, 0.1]),
                    discount([1.3, 1.3, 1.1], 3),
                    discount([1.3, 1.3, 1.1], 2),
                    discount([1.3, 1.3, 1.1], 1),
                    reset([1.3, 1.3, 1.1]),
                    discount(last_values[0], 1),
                ],
                [
                    discount([0.4, 0.6, 0.1], 2),
                    discount([0.4, 0.6, 0.1], 1),
                    reset([0.4, 0.6, 0.1]),
                    discount([0.2, 0.1, 0.3], 1),
                    reset([0.2, 0.1, 0.3]),
                    discount(last_values[1], 1),
                ],
            ]
        )

        self.assertTrue(torch.isclose(processed_value_target_quants, expected_value_target_quants, atol=1e-4).all())

    def test_target_computation_1step(self):
        """With ``value_lambda=0.0`` the value-target quantiles must equal the 1-step targets built below."""
        # NOTE(review): the original comment here said the ground truth was taken from the old PPO
        # implementation, but the expectation is computed inline below - confirm which is intended.
        env = PoleBalancing(environment_count=2)
        dppo = DPPO(env, device="cpu", gae_lambda=0.97, gamma=0.99, qrdqn_quantile_count=3, value_lambda=0.0)

        rewards, dones, value_quants = self._quantile_rollout()
        values = value_quants.mean(dim=-1)
        last_values = self._random_last_values(dones.shape[0], 3)
        dppo.critic = FakeCritic(last_values, quantile_count=3)

        dataset = self._build_dataset(dones, rewards, values, value_quants)
        processed_dataset = dppo._process_dataset(dataset)
        processed_value_target_quants = torch.stack(
            [processed_dataset[i]["value_target_quants"] for i in range(dones.shape[1])],
            dim=-2,
        )

        # 1-step returns
        # Reward plus the discounted quantiles of the following step (the critic's output after the
        # last step), with the second term zeroed wherever ``dones`` is 1.
        next_value_quants = torch.cat((value_quants[:, 1:], last_values.unsqueeze(1)), dim=1)
        expected_value_target_quants = rewards.unsqueeze(-1) + (
            (1.0 - dones).float().unsqueeze(-1) * dppo.gamma * next_value_quants
        )

        self.assertTrue(torch.isclose(processed_value_target_quants, expected_value_target_quants, atol=1e-4).all())