diff --git a/lerobot/common/policies/tdmpc_helper.py b/lerobot/common/policies/tdmpc_helper.py index 264cd829..2c2ab4f2 100644 --- a/lerobot/common/policies/tdmpc_helper.py +++ b/lerobot/common/policies/tdmpc_helper.py @@ -73,8 +73,20 @@ def orthogonal_init(m): def ema(m, m_target, tau): """Update slow-moving average of online network (target network) at rate tau.""" with torch.no_grad(): - for p, p_target in zip(m.parameters(), m_target.parameters(), strict=False): - p_target.data.lerp_(p.data, tau) + # TODO(rcadene, aliberts): issue with strict=False + # for p, p_target in zip(m.parameters(), m_target.parameters(), strict=False): + # p_target.data.lerp_(p.data, tau) + m_params_iter = iter(m.parameters()) + m_target_params_iter = iter(m_target.parameters()) + + while True: + try: + p = next(m_params_iter) + p_target = next(m_target_params_iter) + p_target.data.lerp_(p.data, tau) + except StopIteration: + # If any iterator is exhausted, exit the loop + break def set_requires_grad(net, value):