fix zip strict=False

This commit is contained in:
Cadene 2024-03-01 00:45:23 +00:00
parent ae050d2e94
commit ca948c1e5b
1 changed files with 14 additions and 2 deletions

View File

@ -73,8 +73,20 @@ def orthogonal_init(m):
def ema(m, m_target, tau): def ema(m, m_target, tau):
"""Update slow-moving average of online network (target network) at rate tau.""" """Update slow-moving average of online network (target network) at rate tau."""
with torch.no_grad(): with torch.no_grad():
for p, p_target in zip(m.parameters(), m_target.parameters(), strict=False): # TODO(rcadene, aliberts): issue with strict=False
p_target.data.lerp_(p.data, tau) # for p, p_target in zip(m.parameters(), m_target.parameters(), strict=False):
# p_target.data.lerp_(p.data, tau)
m_params_iter = iter(m.parameters())
m_target_params_iter = iter(m_target.parameters())
while True:
try:
p = next(m_params_iter)
p_target = next(m_target_params_iter)
p_target.data.lerp_(p.data, tau)
except StopIteration:
# If any iterator is exhausted, exit the loop
break
def set_requires_grad(net, value): def set_requires_grad(net, value):