rsl_rl/tests/test_ppo_recurrency.py

145 lines
5.5 KiB
Python

import torch
import unittest
from rsl_rl.algorithms import PPO
from rsl_rl.env.vec_env import VecEnv
from rsl_rl.runners.runner import Runner
class FakeNetwork(torch.nn.Module):
def __init__(self, values):
super().__init__()
self.hidden_state = None
self.recurrent = True
self.values = values
self._hidden_size = 2
def forward(self, x, hidden_state=None):
if not hidden_state:
self.hidden_state = (self.hidden_state[0] + 1, self.hidden_state[1] - 1)
values = self.values.repeat((*x.shape[:-1], 1)).squeeze(-1)
values.requires_grad_(True)
return values
def reset_full_hidden_state(self, batch_size=None):
assert batch_size is None or batch_size == 4, f"batch_size={batch_size}"
self.hidden_state = (torch.zeros((1, 4, self._hidden_size)), torch.zeros((1, 4, self._hidden_size)))
def reset_hidden_state(self, indices):
self.hidden_state[0][:, indices] = torch.zeros((len(indices), self._hidden_size))
self.hidden_state[1][:, indices] = torch.zeros((len(indices), self._hidden_size))
class FakeActorNetwork(FakeNetwork):
def forward(self, x, compute_std=False, hidden_state=None):
values = super().forward(x, hidden_state=hidden_state)
if compute_std:
return values, torch.ones_like(values)
return values
class FakeEnv(VecEnv):
def __init__(self, dones=None, **kwargs):
super().__init__(3, 3, **kwargs)
self.num_actions = 3
self._extra = {"observations": {}, "time_outs": torch.zeros((self.num_envs, 1))}
self._step = 0
self._dones = dones
self.reset()
def get_observations(self):
return self._state_buf, self._extra
def get_privileged_observations(self):
return self._state_buf, self._extra
def reset(self):
self._state_buf = torch.zeros((self.num_envs, self.num_obs))
return self._state_buf, self._extra
def step(self, actions):
assert actions.shape[0] == self.num_envs
assert actions.shape[1] == self.num_actions
self._state_buf += actions
rewards = torch.zeros((self.num_envs))
dones = torch.zeros((self.num_envs)) if self._dones is None else self._dones[self._step % self._dones.shape[0]]
self._step += 1
return self._state_buf, rewards, dones, self._extra
class PPORecurrencyTest(unittest.TestCase):
def test_draw_action_produces_hidden_state(self):
"""Test that the hidden state is correctly added to the data dictionary when drawing actions."""
env = FakeEnv(environment_count=4)
ppo = PPO(env, device="cpu", recurrent=True)
ppo.actor = FakeActorNetwork(torch.ones(env.num_actions))
ppo.critic = FakeNetwork(torch.zeros(1))
# Done during PPO.__init__, however we need to reset the hidden state here again since we are using a fake
# network that is added after initialization.
ppo.actor.reset_full_hidden_state(batch_size=env.num_envs)
ppo.critic.reset_full_hidden_state(batch_size=env.num_envs)
ones = torch.ones((1, env.num_envs, ppo.actor._hidden_size))
state, extra = env.reset()
for ctr in range(10):
_, data = ppo.draw_actions(state, extra)
# Actor state is changed every time an action is drawn.
self.assertTrue(torch.allclose(data["actor_state_h"], ones * ctr))
self.assertTrue(torch.allclose(data["actor_state_c"], -ones * ctr))
# Critic state is only changed and saved when processing the transition (evaluating the action) so we can't
# check it here.
def test_update_produces_hidden_state(self):
"""Test that the hidden state is correctly added to the data dictionary when updating."""
dones = torch.cat((torch.tensor([[0, 0, 0, 1]]), torch.zeros((4, 4)), torch.tensor([[1, 0, 0, 0]])), dim=0)
env = FakeEnv(dones=dones, environment_count=4)
ppo = PPO(env, device="cpu", recurrent=True)
runner = Runner(env, ppo, num_steps_per_env=6)
ppo.actor = FakeActorNetwork(torch.ones(env.num_actions))
ppo.critic = FakeNetwork(torch.zeros(1))
ppo.actor.reset_full_hidden_state(batch_size=env.num_envs)
ppo.critic.reset_full_hidden_state(batch_size=env.num_envs)
runner.learn(1)
state_h_0 = torch.tensor([[0, 0], [0, 0], [0, 0], [0, 0]])
state_h_1 = torch.tensor([[1, 1], [1, 1], [1, 1], [0, 0]])
state_h_2 = state_h_1 + 1
state_h_3 = state_h_2 + 1
state_h_4 = state_h_3 + 1
state_h_5 = state_h_4 + 1
state_h_6 = torch.tensor([[0, 0], [6, 6], [6, 6], [5, 5]])
state_h = (
torch.cat((state_h_0, state_h_1, state_h_2, state_h_3, state_h_4, state_h_5), dim=0).float().unsqueeze(1)
)
next_state_h = (
torch.cat((state_h_1, state_h_2, state_h_3, state_h_4, state_h_5, state_h_6), dim=0).float().unsqueeze(1)
)
self.assertTrue(torch.allclose(ppo.storage._data["critic_state_h"], state_h))
self.assertTrue(torch.allclose(ppo.storage._data["critic_state_c"], -state_h))
self.assertTrue(torch.allclose(ppo.storage._data["critic_next_state_h"], next_state_h))
self.assertTrue(torch.allclose(ppo.storage._data["critic_next_state_c"], -next_state_h))
self.assertTrue(torch.allclose(ppo.storage._data["actor_state_h"], state_h))
self.assertTrue(torch.allclose(ppo.storage._data["actor_state_c"], -state_h))