diff --git a/lerobot/common/policies/pi0fast/modeling_pi0fast.py b/lerobot/common/policies/pi0fast/modeling_pi0fast.py index 440645a3..98219bbc 100644 --- a/lerobot/common/policies/pi0fast/modeling_pi0fast.py +++ b/lerobot/common/policies/pi0fast/modeling_pi0fast.py @@ -846,7 +846,10 @@ class PI0FAST(nn.Module): decoded_actions = [ torch.tensor( self.decode_actions_with_fast( - tok.tolist(), time_horizon=action_horizon, action_dim=action_dim, relaxed_decoding=self.config.relaxed_decoding + tok.tolist(), + time_horizon=action_horizon, + action_dim=action_dim, + relaxed_decoding=self.config.relaxed_decoding, ), device=tokens.device, ).squeeze(0)