import unittest

import torch

from rsl_rl.algorithms import PPO
from rsl_rl.env.vec_env import VecEnv
from rsl_rl.runners.runner import Runner


class FakeNetwork(torch.nn.Module):
    """Stand-in for a recurrent network with a fully predictable hidden state.

    The hidden state is a pair of tensors. Every forward pass that is not given
    an explicit hidden state adds one to the first tensor and subtracts one
    from the second, so tests can tell how many such passes have happened.
    The output of the network is always the constant ``values`` tensor.
    """

    def __init__(self, values):
        """Store the constant output.

        Args:
            values: Tensor that `forward` tiles over the leading dimensions of
                its input.
        """
        super().__init__()

        # Stays None until reset_full_hidden_state() is called.
        self.hidden_state = None
        self.recurrent = True
        self.values = values

        self._hidden_size = 2

    def forward(self, x, hidden_state=None):
        """Return the constant values, advancing the stored state if needed.

        Args:
            x: Input tensor; only its shape is used.
            hidden_state: When ``None`` the internally stored hidden state is
                advanced by one step. Any other value leaves the stored state
                untouched.

        Returns:
            ``self.values`` repeated over all but the last dimension of ``x``,
            with a trailing dimension of size one squeezed away and
            ``requires_grad`` enabled.
        """
        # Compare against None explicitly: truth-testing the argument would
        # raise for a multi-element tensor and conflates "empty" with "absent".
        if hidden_state is None:
            self.hidden_state = (self.hidden_state[0] + 1, self.hidden_state[1] - 1)

        values = self.values.repeat((*x.shape[:-1], 1)).squeeze(-1)
        values.requires_grad_(True)

        return values

    def reset_full_hidden_state(self, batch_size=None):
        """Replace both hidden-state tensors with zeros of shape (1, 4, 2).

        Args:
            batch_size: Must be ``None`` or 4; the fake only supports a batch
                of four.
        """
        assert batch_size is None or batch_size == 4, f"batch_size={batch_size}"

        self.hidden_state = (torch.zeros((1, 4, self._hidden_size)), torch.zeros((1, 4, self._hidden_size)))

    def reset_hidden_state(self, indices):
        """Zero both hidden-state tensors, in place, at ``indices`` along dim 1.

        Args:
            indices: Sized collection of positions along the second dimension
                of the hidden state.
        """
        self.hidden_state[0][:, indices] = torch.zeros((len(indices), self._hidden_size))
        self.hidden_state[1][:, indices] = torch.zeros((len(indices), self._hidden_size))


class FakeActorNetwork(FakeNetwork):
    """`FakeNetwork` variant that can also return a standard deviation."""

    def forward(self, x, compute_std=False, hidden_state=None):
        """Return the constant values, optionally paired with a unit std.

        Args:
            x: Input tensor; only its shape is used.
            compute_std: When true, return ``(values, ones_like(values))``
                instead of just ``values``.
            hidden_state: Forwarded unchanged to `FakeNetwork.forward`.
        """
        values = super().forward(x, hidden_state=hidden_state)

        if compute_std:
            return values, torch.ones_like(values)

        return values


class FakeEnv(VecEnv):
    """Vectorized environment whose state is the running sum of all actions.

    Rewards are always zero. Done flags are either all zero or replayed
    cyclically from a user-supplied schedule.
    """

    def __init__(self, dones=None, **kwargs):
        """Create the environment and reset it.

        Args:
            dones: Optional tensor; row ``step % dones.shape[0]`` is returned
                as the done flags of each step. ``None`` means never done.
            **kwargs: Passed through to ``VecEnv.__init__``.
        """
        # NOTE(review): the two positional 3s are presumably the observation
        # sizes expected by VecEnv -- confirm against its signature.
        super().__init__(3, 3, **kwargs)

        self.num_actions = 3
        self._extra = {"observations": {}, "time_outs": torch.zeros((self.num_envs, 1))}

        self._step = 0
        self._dones = dones

        self.reset()

    def get_observations(self):
        """Return the current state buffer and the extras dictionary."""
        return self._state_buf, self._extra

    def get_privileged_observations(self):
        """Return the same state buffer and extras as `get_observations`."""
        return self._state_buf, self._extra

    def reset(self):
        """Zero the state buffer and return it with the extras dictionary."""
        self._state_buf = torch.zeros((self.num_envs, self.num_obs))

        return self._state_buf, self._extra

    def step(self, actions):
        """Add ``actions`` to the state in place and advance the step counter.

        Args:
            actions: Tensor of shape ``(num_envs, num_actions)``.

        Returns:
            Tuple ``(state, rewards, dones, extra)`` where ``rewards`` is all
            zeros and ``dones`` follows the schedule given at construction.
        """
        assert actions.shape[0] == self.num_envs
        assert actions.shape[1] == self.num_actions

        self._state_buf += actions

        rewards = torch.zeros((self.num_envs))
        if self._dones is None:
            dones = torch.zeros((self.num_envs))
        else:
            # Cycle through the schedule so it may be shorter than the rollout.
            dones = self._dones[self._step % self._dones.shape[0]]

        self._step += 1

        return self._state_buf, rewards, dones, self._extra


class PPORecurrencyTest(unittest.TestCase):
    """Checks that recurrent PPO records the hidden states of its networks."""

    def test_draw_action_produces_hidden_state(self):
        """Test that the hidden state is correctly added to the data dictionary when drawing actions."""
        env = FakeEnv(environment_count=4)
        ppo = PPO(env, device="cpu", recurrent=True)

        ppo.actor = FakeActorNetwork(torch.ones(env.num_actions))
        ppo.critic = FakeNetwork(torch.zeros(1))

        # Done during PPO.__init__, however we need to reset the hidden state here again since we are using a fake
        # network that is added after initialization.
        ppo.actor.reset_full_hidden_state(batch_size=env.num_envs)
        ppo.critic.reset_full_hidden_state(batch_size=env.num_envs)

        ones = torch.ones((1, env.num_envs, ppo.actor._hidden_size))
        state, extra = env.reset()
        for ctr in range(10):
            _, data = ppo.draw_actions(state, extra)

            # Actor state is changed every time an action is drawn.
            self.assertTrue(torch.allclose(data["actor_state_h"], ones * ctr))
            self.assertTrue(torch.allclose(data["actor_state_c"], -ones * ctr))
            # Critic state is only changed and saved when processing the transition (evaluating the action) so we can't
            # check it here.

    def test_update_produces_hidden_state(self):
        """Test that the hidden state is correctly added to the data dictionary when updating."""
        # Six-step schedule: the last environment is done on the first step and
        # the first environment is done on the sixth step.
        dones = torch.cat((torch.tensor([[0, 0, 0, 1]]), torch.zeros((4, 4)), torch.tensor([[1, 0, 0, 0]])), dim=0)

        env = FakeEnv(dones=dones, environment_count=4)
        ppo = PPO(env, device="cpu", recurrent=True)
        runner = Runner(env, ppo, num_steps_per_env=6)

        ppo.actor = FakeActorNetwork(torch.ones(env.num_actions))
        ppo.critic = FakeNetwork(torch.zeros(1))

        ppo.actor.reset_full_hidden_state(batch_size=env.num_envs)
        ppo.critic.reset_full_hidden_state(batch_size=env.num_envs)

        runner.learn(1)

        # Expected hidden state per environment at each step. A finished
        # environment is expected to restart from zero on the following step.
        state_h_0 = torch.tensor([[0, 0], [0, 0], [0, 0], [0, 0]])
        state_h_1 = torch.tensor([[1, 1], [1, 1], [1, 1], [0, 0]])
        state_h_2 = state_h_1 + 1
        state_h_3 = state_h_2 + 1
        state_h_4 = state_h_3 + 1
        state_h_5 = state_h_4 + 1
        state_h_6 = torch.tensor([[0, 0], [6, 6], [6, 6], [5, 5]])
        state_h = (
            torch.cat((state_h_0, state_h_1, state_h_2, state_h_3, state_h_4, state_h_5), dim=0).float().unsqueeze(1)
        )
        next_state_h = (
            torch.cat((state_h_1, state_h_2, state_h_3, state_h_4, state_h_5, state_h_6), dim=0).float().unsqueeze(1)
        )

        self.assertTrue(torch.allclose(ppo.storage._data["critic_state_h"], state_h))
        self.assertTrue(torch.allclose(ppo.storage._data["critic_state_c"], -state_h))
        self.assertTrue(torch.allclose(ppo.storage._data["critic_next_state_h"], next_state_h))
        self.assertTrue(torch.allclose(ppo.storage._data["critic_next_state_c"], -next_state_h))
        self.assertTrue(torch.allclose(ppo.storage._data["actor_state_h"], state_h))
        self.assertTrue(torch.allclose(ppo.storage._data["actor_state_c"], -state_h))