# rsl_rl/tests/test_dppo_iqn.py

import types
import unittest

import torch

from rsl_rl.algorithms import DPPO
from rsl_rl.env.vec_env import VecEnv

ACTION_SIZE = 3
ENV_COUNT = 3
OBS_SIZE = 24
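

# Deterministic stand-in for VecEnv: observations are all zeros and step() replays
# the reward and done tensors supplied to the constructor.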
class FakeEnv(VecEnv):
    def __init__(self, rewards, dones, environment_count=1):
        super().__init__(OBS_SIZE, OBS_SIZE, environment_count=environment_count)

        self.num_actions = ACTION_SIZE
        self.rewards = rewards
        self.dones = dones
        self._step = 0

    def get_observations(self):
        return torch.zeros((self.num_envs, self.num_obs)), {"observations": {}}

    def get_privileged_observations(self):
        return torch.zeros((self.num_envs, self.num_privileged_obs)), {"observations": {}}

    def step(self, actions):
        obs, _ = self.get_observations()
        rewards = self.rewards[self._step]
        dones = self.dones[self._step]

        self._step += 1

        return obs, rewards, dones, {"observations": {}}

    def reset(self):
        pass
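

# Fake distributional critic: instead of computing quantiles from observations, it
# returns pre-computed quantile values and taus, selecting the "action" or "value"
# set based on the requested sample count.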
class FakeCritic(torch.nn.Module):
    def __init__(self, action_samples, value_samples, action_values, value_values, action_taus, value_taus):
        super().__init__()

        self.recurrent = False

        self.action_samples = action_samples
        self.value_samples = value_samples
        self.action_values = action_values
        self.value_values = value_values
        self.action_taus = action_taus
        self.value_taus = value_taus

        self.last_quantiles = None
        self.last_taus = None

    def forward(self, _, distribution=False, measure_args=None, sample_count=8, taus=None, use_measure=True):
        if taus is not None:
            sample_count = taus.shape[-1]

        if sample_count == self.action_samples:
            self.last_taus = self.action_taus
            self.last_quantiles = self.action_values
        elif sample_count == self.value_samples:
            self.last_taus = self.value_taus
            self.last_quantiles = self.value_values
        else:
            raise ValueError(f"Invalid sample count: {sample_count}")

        if distribution:
            return self.last_quantiles

        return self.last_quantiles.mean(-1)


def fake_process_quants(self, x):
    # Stand-in for the critic's quantile post-processing: return the quantiles
    # unchanged together with identity indices of the same shape.
    idx = torch.arange(0, x.shape[-1]).expand(*x.shape[:-1], -1)

    return x, idx
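

# Drives DPPO through a four-step rollout with hand-crafted rewards, dones, and
# critic quantiles, then runs the dataset post-processing that computes the value
# targets.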
class DPPOTest(unittest.TestCase):
    def test_value_target_computation(self):
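        # Rewards and dones are indexed as (step, environment); a done of 1 marks a
        # termination of that environment at that step.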
        rewards = torch.tensor(
            [
                [-1.0000e02, -1.4055e-01, -3.0476e-02],
                [-1.7633e-01, -2.6533e-01, -3.0786e-01],
                [-1.5952e-01, -1.5177e-01, -1.4296e-01],
                [1.1407e-02, -1.0000e02, -6.2290e-02],
            ]
        )
        dones = torch.tensor(
            [
                [1, 0, 0],
                [0, 0, 0],
                [0, 0, 0],
                [0, 1, 0],
            ]
        )

        env = FakeEnv(rewards, dones, environment_count=ENV_COUNT)
        dppo = DPPO(
            env,
            critic_network=DPPO.network_iqn,
            device="cpu",
            gae_lambda=0.97,
            gamma=0.99,
            iqn_action_samples=4,
            iqn_value_samples=2,
            value_lambda=1.0,
            value_loss=DPPO.value_loss_energy,
        )

        # Generate fake dataset
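        # Quantile taus and values are shaped (step, environment, sample); the action
        # tensors use iqn_action_samples=4 and the value tensors iqn_value_samples=2,
        # which is how FakeCritic tells the two query types apart.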
        action_taus = torch.tensor(
            [
                [[0.3, 0.5, 1.0, 0.2], [0.8, 0.9, 0.0, 0.9], [0.6, 0.1, 0.6, 0.5]],
                [[0.7, 0.9, 0.3, 0.0], [1.0, 0.7, 0.7, 0.7], [0.3, 0.8, 0.8, 0.1]],
                [[0.3, 0.8, 0.3, 0.2], [0.2, 0.9, 0.6, 0.4], [0.8, 0.4, 0.8, 1.0]],
                [[0.6, 0.6, 0.8, 0.8], [0.8, 0.0, 0.9, 0.1], [0.2, 0.3, 0.6, 0.2]],
            ]
        )
        action_value_quants = torch.tensor(
            [
                [[0.2, 0.2, 0.6, 0.5], [0.5, 0.8, 0.1, 0.0], [1.0, 0.1, 0.8, 0.8]],
                [[0.0, 0.6, 0.1, 0.9], [0.2, 1.0, 0.9, 1.0], [0.4, 0.1, 0.1, 0.8]],
                [[0.7, 0.0, 0.6, 0.8], [0.7, 0.7, 0.7, 0.8], [0.0, 0.1, 0.5, 0.8]],
                [[0.5, 0.8, 0.1, 0.1], [0.9, 0.4, 0.7, 0.6], [0.6, 0.3, 0.1, 0.4]],
            ]
        )
        value_taus = torch.tensor(
            [
                [[0.3, 0.5], [0.8, 0.9], [0.6, 0.1]],
                [[0.7, 0.9], [1.0, 0.7], [0.3, 0.8]],
                [[0.3, 0.8], [0.2, 0.9], [0.8, 0.4]],
                [[0.6, 0.6], [0.8, 0.0], [0.2, 0.3]],
            ]
        )
        value_value_quants = torch.tensor(
            [
                [[0.9, 0.8], [0.1, 0.3], [0.3, 0.5]],
                [[0.2, 0.1], [0.9, 0.3], [0.4, 0.2]],
                [[0.7, 1.0], [0.6, 0.2], [0.2, 0.6]],
                [[0.4, 1.0], [0.3, 0.6], [0.3, 0.1]],
            ]
        )

        actions = torch.zeros(ENV_COUNT, ACTION_SIZE)
        env_info = {"observations": {}}
        obs = torch.zeros(ENV_COUNT, OBS_SIZE)

        dataset = []

        for i in range(4):
            dppo.critic = FakeCritic(4, 2, action_value_quants[i], value_value_quants[i], action_taus[i], value_taus[i])
            # Bind the fake quantile post-processing to the fake critic so it can be
            # called like a regular method.
            dppo.critic._process_quants = types.MethodType(fake_process_quants, dppo.critic)

            _, data = dppo.draw_actions(obs, {})
            _, rewards, dones, _ = env.step(actions)

            dataset.append(
                dppo.process_transition(
                    obs,
                    env_info,
                    actions,
                    rewards,
                    obs,
                    env_info,
                    dones,
                    data,
                )
            )

        processed_dataset = dppo._process_dataset(dataset)

        # TODO: Test that the value targets are correct.
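

# Allows running this file directly in addition to unittest discovery
# (e.g. `python -m unittest`).
if __name__ == "__main__":
    unittest.main()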