Fixed the reset function of the Memory class used by ActorCriticRecurrent
This commit is contained in:
parent
a1d25d1fef
commit
f5114ab5c2
|
@ -93,5 +93,7 @@ class Memory(torch.nn.Module):
|
||||||
|
|
||||||
def reset(self, dones=None):
    """Reset RNN hidden states for environments whose episode has ended.

    When the RNN is an LSTM, ``self.hidden_states`` is a list holding both
    the hidden state and the cell state; each entry has shape
    ``(num_layers, num_envs, hidden_size)``.

    Args:
        dones: ``(num_envs,)`` bool/int tensor marking finished environments,
            or ``None`` to reset the hidden states of ALL environments.
    """
    if dones is None:
        # No done mask given: reset every environment.
        done_env_ids = slice(None)
    else:
        # Convert the done mask to environment indices along the env axis.
        # NOTE: a bare `if dones` would raise "Boolean value of Tensor with
        # more than one element is ambiguous" for multi-env tensors, so the
        # None check must be explicit.
        done_env_ids = torch.where(dones)[0]
    for hidden_state in self.hidden_states:
        # Zero the per-env slice of every state tensor (hidden and, for
        # LSTMs, cell state); other environments are left untouched.
        hidden_state[..., done_env_ids, :] = 0.0
|
|
Loading…
Reference in New Issue