diff --git a/lerobot/common/policies/vqbet/modeling_vqbet.py b/lerobot/common/policies/vqbet/modeling_vqbet.py index c5ab1797..7271150c 100644 --- a/lerobot/common/policies/vqbet/modeling_vqbet.py +++ b/lerobot/common/policies/vqbet/modeling_vqbet.py @@ -801,7 +801,7 @@ class VqVae(nn.Module): dec_out = self.decoder(state_vq) encoder_loss = (state - dec_out).abs().mean() - rep_loss = encoder_loss * vq_loss_state * 5 + rep_loss = encoder_loss + vq_loss_state * 5 metric = ( encoder_loss.clone().detach(),