From ca948c1e5bec0001d4b76162a806e9a2dbc45d9b Mon Sep 17 00:00:00 2001 From: Cadene Date: Fri, 1 Mar 2024 00:45:23 +0000 Subject: [PATCH] fix zip strict=False --- lerobot/common/policies/tdmpc_helper.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/lerobot/common/policies/tdmpc_helper.py b/lerobot/common/policies/tdmpc_helper.py index 264cd829..2c2ab4f2 100644 --- a/lerobot/common/policies/tdmpc_helper.py +++ b/lerobot/common/policies/tdmpc_helper.py @@ -73,8 +73,20 @@ def orthogonal_init(m): def ema(m, m_target, tau): """Update slow-moving average of online network (target network) at rate tau.""" with torch.no_grad(): - for p, p_target in zip(m.parameters(), m_target.parameters(), strict=False): - p_target.data.lerp_(p.data, tau) + # TODO(rcadene, aliberts): issue with strict=False + # for p, p_target in zip(m.parameters(), m_target.parameters(), strict=False): + # p_target.data.lerp_(p.data, tau) + m_params_iter = iter(m.parameters()) + m_target_params_iter = iter(m_target.parameters()) + + while True: + try: + p = next(m_params_iter) + p_target = next(m_target_params_iter) + p_target.data.lerp_(p.data, tau) + except StopIteration: + # If any iterator is exhausted, exit the loop + break def set_requires_grad(net, value):