import torch
import unittest

from rsl_rl.algorithms import PPO
from rsl_rl.env.vec_env import VecEnv
from rsl_rl.runners.runner import Runner


class FakeNetwork(torch.nn.Module):
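    """Recurrent network stub that returns constant values.

    Every forward call made without an explicit hidden state advances the internal (h, c) state by
    (+1, -1), so the tests can verify exactly when PPO reads, stores, and resets hidden states.
    """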
    def __init__(self, values):
        super().__init__()

        self.hidden_state = None
        self.recurrent = True
        self.values = values

        self._hidden_size = 2

    def forward(self, x, hidden_state=None):
        if not hidden_state:
            self.hidden_state = (self.hidden_state[0] + 1, self.hidden_state[1] - 1)

        values = self.values.repeat((*x.shape[:-1], 1)).squeeze(-1)
        values.requires_grad_(True)

        return values

    def reset_full_hidden_state(self, batch_size=None):
        assert batch_size is None or batch_size == 4, f"batch_size={batch_size}"

        self.hidden_state = (torch.zeros((1, 4, self._hidden_size)), torch.zeros((1, 4, self._hidden_size)))

    def reset_hidden_state(self, indices):
        self.hidden_state[0][:, indices] = torch.zeros((len(indices), self._hidden_size))
        self.hidden_state[1][:, indices] = torch.zeros((len(indices), self._hidden_size))


class FakeActorNetwork(FakeNetwork):
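    """Actor variant of FakeNetwork that optionally returns a constant standard deviation of one."""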
    def forward(self, x, compute_std=False, hidden_state=None):
        values = super().forward(x, hidden_state=hidden_state)

        if compute_std:
            return values, torch.ones_like(values)

        return values


class FakeEnv(VecEnv):
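    """Minimal VecEnv stand-in that accumulates actions into its state, returns zero rewards, and
    steps through an optional scripted `dones` schedule."""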
    def __init__(self, dones=None, **kwargs):
        super().__init__(3, 3, **kwargs)

        self.num_actions = 3
        self._extra = {"observations": {}, "time_outs": torch.zeros((self.num_envs, 1))}

        self._step = 0
        self._dones = dones

        self.reset()

    def get_observations(self):
        return self._state_buf, self._extra

    def get_privileged_observations(self):
        return self._state_buf, self._extra

    def reset(self):
        self._state_buf = torch.zeros((self.num_envs, self.num_obs))

        return self._state_buf, self._extra

    def step(self, actions):
        assert actions.shape[0] == self.num_envs
        assert actions.shape[1] == self.num_actions

        self._state_buf += actions

        rewards = torch.zeros((self.num_envs))
        dones = torch.zeros((self.num_envs)) if self._dones is None else self._dones[self._step % self._dones.shape[0]]

        self._step += 1

        return self._state_buf, rewards, dones, self._extra


class PPORecurrencyTest(unittest.TestCase):
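    """Tests that PPO stores and resets actor and critic hidden states correctly when run with a recurrent policy."""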
    def test_draw_action_produces_hidden_state(self):
        """Test that the hidden state is correctly added to the data dictionary when drawing actions."""
        env = FakeEnv(environment_count=4)
        ppo = PPO(env, device="cpu", recurrent=True)

        ppo.actor = FakeActorNetwork(torch.ones(env.num_actions))
        ppo.critic = FakeNetwork(torch.zeros(1))

        # This is done during PPO.__init__; however, we need to reset the hidden state here again since we are
        # using fake networks that are added after initialization.
        ppo.actor.reset_full_hidden_state(batch_size=env.num_envs)
        ppo.critic.reset_full_hidden_state(batch_size=env.num_envs)

        ones = torch.ones((1, env.num_envs, ppo.actor._hidden_size))
        state, extra = env.reset()
        for ctr in range(10):
            _, data = ppo.draw_actions(state, extra)

            # Actor state is changed every time an action is drawn.
            self.assertTrue(torch.allclose(data["actor_state_h"], ones * ctr))
            self.assertTrue(torch.allclose(data["actor_state_c"], -ones * ctr))
            # Critic state is only changed and saved when processing the transition (evaluating the action),
            # so we can't check it here.

    def test_update_produces_hidden_state(self):
        """Test that the hidden state is correctly added to the data dictionary when updating."""
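        # Scripted dones, one row per step: env 3 terminates at step 0 and env 0 at step 5, so their hidden
        # states should be reset at those points while the other environments keep accumulating.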
        dones = torch.cat((torch.tensor([[0, 0, 0, 1]]), torch.zeros((4, 4)), torch.tensor([[1, 0, 0, 0]])), dim=0)

        env = FakeEnv(dones=dones, environment_count=4)
        ppo = PPO(env, device="cpu", recurrent=True)
        runner = Runner(env, ppo, num_steps_per_env=6)

        ppo.actor = FakeActorNetwork(torch.ones(env.num_actions))
        ppo.critic = FakeNetwork(torch.zeros(1))

        ppo.actor.reset_full_hidden_state(batch_size=env.num_envs)
        ppo.critic.reset_full_hidden_state(batch_size=env.num_envs)

        runner.learn(1)

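        # Expected hidden states per step: the fake h-state advances by 1 each time the network is evaluated
        # without an explicit hidden state and is zeroed for an environment when it reports done (env 3 at
        # step 0, env 0 at step 5); the c-state mirrors the h-state with the opposite sign.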
        state_h_0 = torch.tensor([[0, 0], [0, 0], [0, 0], [0, 0]])
        state_h_1 = torch.tensor([[1, 1], [1, 1], [1, 1], [0, 0]])
        state_h_2 = state_h_1 + 1
        state_h_3 = state_h_2 + 1
        state_h_4 = state_h_3 + 1
        state_h_5 = state_h_4 + 1
        state_h_6 = torch.tensor([[0, 0], [6, 6], [6, 6], [5, 5]])
        state_h = (
            torch.cat((state_h_0, state_h_1, state_h_2, state_h_3, state_h_4, state_h_5), dim=0).float().unsqueeze(1)
        )
        next_state_h = (
            torch.cat((state_h_1, state_h_2, state_h_3, state_h_4, state_h_5, state_h_6), dim=0).float().unsqueeze(1)
        )

        self.assertTrue(torch.allclose(ppo.storage._data["critic_state_h"], state_h))
        self.assertTrue(torch.allclose(ppo.storage._data["critic_state_c"], -state_h))
        self.assertTrue(torch.allclose(ppo.storage._data["critic_next_state_h"], next_state_h))
        self.assertTrue(torch.allclose(ppo.storage._data["critic_next_state_c"], -next_state_h))
        self.assertTrue(torch.allclose(ppo.storage._data["actor_state_h"], state_h))
        self.assertTrue(torch.allclose(ppo.storage._data["actor_state_c"], -state_h))
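

# Run the recurrency tests when this file is executed directly.
if __name__ == "__main__":
    unittest.main()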