diff --git a/lerobot/common/policies/pi0/modeling_pi0.py b/lerobot/common/policies/pi0/modeling_pi0.py index 7599fa63..bf92998d 100644 --- a/lerobot/common/policies/pi0/modeling_pi0.py +++ b/lerobot/common/policies/pi0/modeling_pi0.py @@ -317,16 +317,16 @@ class PI0Policy(PreTrainedPolicy): loss_dict = {} losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time) - loss_dict["losses_after_forward"] = losses.clone() + loss_dict["losses_after_forward"] = losses.mean().item() if actions_is_pad is not None: in_episode_bound = ~actions_is_pad losses = losses * in_episode_bound.unsqueeze(-1) - loss_dict["losses_after_in_ep_bound"] = losses.clone() + loss_dict["losses_after_in_ep_bound"] = losses.mean().item() # Remove padding losses = losses[:, :, : self.config.max_action_dim] - loss_dict["losses_after_rm_padding"] = losses.clone() + loss_dict["losses_after_rm_padding"] = losses.mean().item() # For backward pass loss = losses.mean() diff --git a/lerobot/common/robot_devices/control_utils.py b/lerobot/common/robot_devices/control_utils.py index 4e42a989..d5da3fb5 100644 --- a/lerobot/common/robot_devices/control_utils.py +++ b/lerobot/common/robot_devices/control_utils.py @@ -107,8 +107,11 @@ def predict_action(observation, policy, device, use_amp): torch.inference_mode(), torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(), ): - # Convert to pytorch format: channel first and float32 in [0,1] with batch dimension for name in observation: + # Skip pre-processing the task text, the VLA will do the tokenization + if "task" in name: + continue + # Convert to pytorch format: channel first and float32 in [0,1] with batch dimension if "image" in name: observation[name] = observation[name].type(torch.float32) / 255 observation[name] = observation[name].permute(2, 0, 1).contiguous() @@ -252,6 +255,9 @@ def control_loop( observation = robot.capture_observation() if policy is not None: + # Pass the task to the policy if provided for VLA model + if observation.get("task") is None and single_task is not None: + observation["task"] = [single_task] pred_action = predict_action( observation, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp )