Fixed a problem with the reset function of Memory in actor_critic_recurrent

Bernard Tan 2024-07-17 21:29:21 +08:00 committed by GitHub
parent a1d25d1fef
commit f5114ab5c2
1 changed file with 3 additions and 1 deletion


@@ -93,5 +93,7 @@ class Memory(torch.nn.Module):
     def reset(self, dones=None):
         # When the RNN is an LSTM, self.hidden_states_a is a list with hidden_state and cell_state
+        # dones: (num_envs,), hidden_states: (num_layers, num_envs, hidden_size)
+        dones_envs_id = torch.where(dones)[0] if dones else None
         for hidden_state in self.hidden_states:
-            hidden_state[..., dones, :] = 0.0
+            hidden_state[..., dones_envs_id, :] = 0.0
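
For context, the indexing pattern this change adopts can be sketched standalone. The sketch below is illustrative only, not the library code: the tensor sizes, the boolean dones mask, and the `is not None` guard are assumptions (a multi-element tensor has no unambiguous truth value in a plain `if dones` check). It zeroes the recurrent hidden state of every environment whose episode just ended.

import torch

# Illustrative sizes (not from the library).
num_layers, num_envs, hidden_size = 1, 4, 8
# An LSTM would keep two such tensors (hidden state and cell state).
hidden_states = [torch.randn(num_layers, num_envs, hidden_size)]

# Boolean per-environment done flags; envs 1 and 3 just finished an episode.
dones = torch.tensor([False, True, False, True])

# Turn the mask into explicit environment indices before indexing.
dones_envs_id = torch.where(dones)[0] if dones is not None else None

for hidden_state in hidden_states:
    # Zero the hidden state of every finished environment, across all layers.
    hidden_state[..., dones_envs_id, :] = 0.0

print(hidden_states[0][:, 1])  # the row for env 1 is now all zeros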