import torch
import unittest

from rsl_rl.algorithms import DPPO
from rsl_rl.env.vec_env import VecEnv

ACTION_SIZE = 3
ENV_COUNT = 3
OBS_SIZE = 24


class FakeEnv(VecEnv):
    """Deterministic stand-in for a vectorized environment that replays scripted rewards and dones."""

    def __init__(self, rewards, dones, environment_count=1):
        super().__init__(OBS_SIZE, OBS_SIZE, environment_count=environment_count)

        self.num_actions = ACTION_SIZE
        self.rewards = rewards
        self.dones = dones

        self._step = 0

    def get_observations(self):
        return torch.zeros((self.num_envs, self.num_obs)), {"observations": {}}

    def get_privileged_observations(self):
        return torch.zeros((self.num_envs, self.num_privileged_obs)), {"observations": {}}

    def step(self, actions):
        # Ignore the actions and return the scripted reward/done entries for the current step.
        obs, _ = self.get_observations()
        rewards = self.rewards[self._step]
        dones = self.dones[self._step]

        self._step += 1

        return obs, rewards, dones, {"observations": {}}

    def reset(self):
        pass


class FakeCritic(torch.nn.Module):
    """Critic stub that serves pre-computed quantiles and taus depending on the requested sample count."""

    def __init__(self, action_samples, value_samples, action_values, value_values, action_taus, value_taus):
        super().__init__()

        self.recurrent = False
        self.action_samples = action_samples
        self.value_samples = value_samples
        self.action_values = action_values
        self.value_values = value_values
        self.action_taus = action_taus
        self.value_taus = value_taus

        self.last_quantiles = None
        self.last_taus = None

    def forward(self, _, distribution=False, measure_args=None, sample_count=8, taus=None, use_measure=True):
        if taus is not None:
            sample_count = taus.shape[-1]

        # Dispatch on the sample count to decide whether the call comes from the
        # action path or the value path and return the matching canned tensors.
        if sample_count == self.action_samples:
            self.last_taus = self.action_taus
            self.last_quantiles = self.action_values
        elif sample_count == self.value_samples:
            self.last_taus = self.value_taus
            self.last_quantiles = self.value_values
        else:
            raise ValueError(f"Invalid sample count: {sample_count}")

        if distribution:
            return self.last_quantiles

        return self.last_quantiles.mean(-1)


def fake_process_quants(self, x):
    # Return the quantiles unchanged together with identity sort indices of the same shape.
    idx = torch.arange(0, x.shape[-1]).expand(*x.shape[:-1], -1)

    return x, idx


class DPPOTest(unittest.TestCase):
    def test_value_target_computation(self):
        rewards = torch.tensor(
            [
                [-1.0000e02, -1.4055e-01, -3.0476e-02],
                [-1.7633e-01, -2.6533e-01, -3.0786e-01],
                [-1.5952e-01, -1.5177e-01, -1.4296e-01],
                [1.1407e-02, -1.0000e02, -6.2290e-02],
            ]
        )
        dones = torch.tensor(
            [
                [1, 0, 0],
                [0, 0, 0],
                [0, 0, 0],
                [0, 1, 0],
            ]
        )

        env = FakeEnv(rewards, dones, environment_count=ENV_COUNT)
        dppo = DPPO(
            env,
            critic_network=DPPO.network_iqn,
            device="cpu",
            gae_lambda=0.97,
            gamma=0.99,
            iqn_action_samples=4,
            iqn_value_samples=2,
            value_lambda=1.0,
            value_loss=DPPO.value_loss_energy,
        )

        # Generate fake dataset

        action_taus = torch.tensor(
            [
                [[0.3, 0.5, 1.0, 0.2], [0.8, 0.9, 0.0, 0.9], [0.6, 0.1, 0.6, 0.5]],
                [[0.7, 0.9, 0.3, 0.0], [1.0, 0.7, 0.7, 0.7], [0.3, 0.8, 0.8, 0.1]],
                [[0.3, 0.8, 0.3, 0.2], [0.2, 0.9, 0.6, 0.4], [0.8, 0.4, 0.8, 1.0]],
                [[0.6, 0.6, 0.8, 0.8], [0.8, 0.0, 0.9, 0.1], [0.2, 0.3, 0.6, 0.2]],
            ]
        )
        action_value_quants = torch.tensor(
            [
                [[0.2, 0.2, 0.6, 0.5], [0.5, 0.8, 0.1, 0.0], [1.0, 0.1, 0.8, 0.8]],
                [[0.0, 0.6, 0.1, 0.9], [0.2, 1.0, 0.9, 1.0], [0.4, 0.1, 0.1, 0.8]],
                [[0.7, 0.0, 0.6, 0.8], [0.7, 0.7, 0.7, 0.8], [0.0, 0.1, 0.5, 0.8]],
                [[0.5, 0.8, 0.1, 0.1], [0.9, 0.4, 0.7, 0.6], [0.6, 0.3, 0.1, 0.4]],
            ]
        )
        value_taus = torch.tensor(
            [
                [[0.3, 0.5], [0.8, 0.9], [0.6, 0.1]],
                [[0.7, 0.9], [1.0, 0.7], [0.3, 0.8]],
                [[0.3, 0.8], [0.2, 0.9], [0.8, 0.4]],
                [[0.6, 0.6], [0.8, 0.0], [0.2, 0.3]],
            ]
        )
        value_value_quants = torch.tensor(
            [
                [[0.9, 0.8], [0.1, 0.3], [0.3, 0.5]],
                [[0.2, 0.1], [0.9, 0.3], [0.4, 0.2]],
                [[0.7, 1.0], [0.6, 0.2], [0.2, 0.6]],
                [[0.4, 1.0], [0.3, 0.6], [0.3, 0.1]],
            ]
        )

        actions = torch.zeros(ENV_COUNT, ACTION_SIZE)
        env_info = {"observations": {}}
        obs = torch.zeros(ENV_COUNT, OBS_SIZE)
        dataset = []
        for i in range(4):
            # Swap in a scripted critic for each step so that the quantiles and
            # taus recorded in the transition are fully deterministic.
            dppo.critic = FakeCritic(4, 2, action_value_quants[i], value_value_quants[i], action_taus[i], value_taus[i])
            dppo.critic._process_quants = fake_process_quants

            _, data = dppo.draw_actions(obs, {})
            _, rewards, dones, _ = env.step(actions)

            dataset.append(
                dppo.process_transition(
                    obs,
                    env_info,
                    actions,
                    rewards,
                    obs,
                    env_info,
                    dones,
                    data,
                )
            )

        processed_dataset = dppo._process_dataset(dataset)

        # TODO: Test that the value targets are correct.
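
        # Placeholder smoke check until the TODO above is addressed. It only
        # assumes that _process_dataset returns some dataset object; the exact
        # structure of the return value is an assumption of this sketch and is
        # deliberately not asserted here.
        self.assertIsNotNone(processed_dataset)


if __name__ == "__main__":
    # Convenience entry point so the file can be run directly; tests are
    # normally collected by the project's test runner.
    unittest.main()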