rsl_rl/tests/test_dppo_recurrency.py

171 lines
6.3 KiB
Python
Raw Normal View History

import torch
import unittest
from rsl_rl.algorithms import DPPO
from rsl_rl.env.vec_env import VecEnv
from rsl_rl.runners.runner import Runner
from rsl_rl.utils.benchmarkable import Benchmarkable
class FakeNetwork(torch.nn.Module, Benchmarkable):
def __init__(self, values):
super().__init__()
self.hidden_state = None
self.quantile_count = 1
self.recurrent = True
self.values = values
self._hidden_size = 2
def forward(self, x, hidden_state=None):
if not hidden_state:
self.hidden_state = (self.hidden_state[0] + 1, self.hidden_state[1] - 1)
values = self.values.repeat((*x.shape[:-1], 1)).squeeze(-1)
values.requires_grad_(True)
return values
def reset_full_hidden_state(self, batch_size=None):
assert batch_size is None or batch_size == 4, f"batch_size={batch_size}"
self.hidden_state = (torch.zeros((1, 4, self._hidden_size)), torch.zeros((1, 4, self._hidden_size)))
def reset_hidden_state(self, indices):
self.hidden_state[0][:, indices] = torch.zeros((len(indices), self._hidden_size))
self.hidden_state[1][:, indices] = torch.zeros((len(indices), self._hidden_size))
class FakeActorNetwork(FakeNetwork):
def forward(self, x, compute_std=False, hidden_state=None):
values = super().forward(x, hidden_state=hidden_state)
if compute_std:
return values, torch.ones_like(values)
return values
class FakeCriticNetwork(FakeNetwork):
_quantile_count = 1
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def forward(self, x, distribution=False, hidden_state=None, measure_args=None):
values = super().forward(x, hidden_state=hidden_state)
self.last_quantiles = values.reshape(*values.shape, 1)
if distribution:
return self.last_quantiles
return values
def quantile_l1_loss(self, *args, **kwargs):
return torch.tensor(0.0)
def quantiles_to_values(self, quantiles):
return quantiles.squeeze()
class FakeEnv(VecEnv):
def __init__(self, dones=None, **kwargs):
super().__init__(3, 3, **kwargs)
self.num_actions = 3
self._extra = {"observations": {}, "time_outs": torch.zeros((self.num_envs, 1))}
self._step = 0
self._dones = dones
self.reset()
def get_observations(self):
return self._state_buf, self._extra
def get_privileged_observations(self):
return self._state_buf, self._extra
def reset(self):
self._state_buf = torch.zeros((self.num_envs, self.num_obs))
return self._state_buf, self._extra
def step(self, actions):
assert actions.shape[0] == self.num_envs
assert actions.shape[1] == self.num_actions
self._state_buf += actions
rewards = torch.zeros((self.num_envs))
dones = torch.zeros((self.num_envs)) if self._dones is None else self._dones[self._step % self._dones.shape[0]]
self._step += 1
return self._state_buf, rewards, dones, self._extra
class DPPORecurrencyTest(unittest.TestCase):
def test_draw_action_produces_hidden_state(self):
"""Test that the hidden state is correctly added to the data dictionary when drawing actions."""
env = FakeEnv(environment_count=4)
dppo = DPPO(env, device="cpu", recurrent=True)
dppo.actor = FakeActorNetwork(torch.ones(env.num_actions))
dppo.critic = FakeCriticNetwork(torch.zeros(1))
# Done during DPPO.__init__, however we need to reset the hidden state here again since we are using a fake
# network that is added after initialization.
dppo.actor.reset_full_hidden_state(batch_size=env.num_envs)
dppo.critic.reset_full_hidden_state(batch_size=env.num_envs)
ones = torch.ones((1, env.num_envs, dppo.actor._hidden_size))
state, extra = env.reset()
for ctr in range(10):
_, data = dppo.draw_actions(state, extra)
# Actor state is changed every time an action is drawn.
self.assertTrue(torch.allclose(data["actor_state_h"], ones * ctr))
self.assertTrue(torch.allclose(data["actor_state_c"], -ones * ctr))
# Critic state is only changed and saved when processing the transition (evaluating the action) so we can't
# check it here.
def test_update_produces_hidden_state(self):
"""Test that the hidden state is correctly added to the data dictionary when updating."""
dones = torch.cat((torch.tensor([[0, 0, 0, 1]]), torch.zeros((4, 4)), torch.tensor([[1, 0, 0, 0]])), dim=0)
env = FakeEnv(dones=dones, environment_count=4)
dppo = DPPO(env, device="cpu", recurrent=True)
runner = Runner(env, dppo, num_steps_per_env=6)
dppo._value_loss = lambda *args, **kwargs: torch.tensor(0.0)
dppo.actor = FakeActorNetwork(torch.ones(env.num_actions))
dppo.critic = FakeCriticNetwork(torch.zeros(1))
dppo.actor.reset_full_hidden_state(batch_size=env.num_envs)
dppo.critic.reset_full_hidden_state(batch_size=env.num_envs)
runner.learn(1)
state_h_0 = torch.tensor([[0, 0], [0, 0], [0, 0], [0, 0]])
state_h_1 = torch.tensor([[1, 1], [1, 1], [1, 1], [0, 0]])
state_h_2 = state_h_1 + 1
state_h_3 = state_h_2 + 1
state_h_4 = state_h_3 + 1
state_h_5 = state_h_4 + 1
state_h_6 = torch.tensor([[0, 0], [6, 6], [6, 6], [5, 5]])
state_h = (
torch.cat((state_h_0, state_h_1, state_h_2, state_h_3, state_h_4, state_h_5), dim=0).float().unsqueeze(1)
)
next_state_h = (
torch.cat((state_h_1, state_h_2, state_h_3, state_h_4, state_h_5, state_h_6), dim=0).float().unsqueeze(1)
)
self.assertTrue(torch.allclose(dppo.storage._data["critic_state_h"], state_h))
self.assertTrue(torch.allclose(dppo.storage._data["critic_state_c"], -state_h))
self.assertTrue(torch.allclose(dppo.storage._data["critic_next_state_h"], next_state_h))
self.assertTrue(torch.allclose(dppo.storage._data["critic_next_state_c"], -next_state_h))
self.assertTrue(torch.allclose(dppo.storage._data["actor_state_h"], state_h))
self.assertTrue(torch.allclose(dppo.storage._data["actor_state_c"], -state_h))