From f5114ab5c2f96f5292a1ce3930c09961bf4a9da4 Mon Sep 17 00:00:00 2001
From: Bernard Tan <30761156+thkkk@users.noreply.github.com>
Date: Wed, 17 Jul 2024 21:29:21 +0800
Subject: [PATCH] Fixed the problem of the reset function of Memory
 corresponding to actor_critic_recurrent

---
 rsl_rl/modules/actor_critic_recurrent.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/rsl_rl/modules/actor_critic_recurrent.py b/rsl_rl/modules/actor_critic_recurrent.py
index 6321ec5..44f52d4 100644
--- a/rsl_rl/modules/actor_critic_recurrent.py
+++ b/rsl_rl/modules/actor_critic_recurrent.py
@@ -93,5 +93,7 @@ class Memory(torch.nn.Module):
 
     def reset(self, dones=None):
         # When the RNN is an LSTM, self.hidden_states_a is a list with hidden_state and cell_state
+        # dones: (num_envs,), hidden_states: (num_layers, num_envs, hidden_size)
+        # NOTE: guard with `is not None` -- `if dones` on a multi-element tensor raises
+        # "Boolean value of Tensor with more than one element is ambiguous".
+        dones_envs_id = torch.where(dones)[0] if dones is not None else None
 for hidden_state in self.hidden_states:
-            hidden_state[..., dones, :] = 0.0
+            hidden_state[..., dones_envs_id, :] = 0.0