Enhance pi0 model inference

1. Pass the task into the observation for the VLA model (pi0).
2. Update the loss_dict stats data format.
This commit is contained in:
Xiaoxuan Liu 2025-03-18 10:32:20 +08:00
parent 1c15bab70f
commit b158576896
2 changed files with 10 additions and 4 deletions

View File

@@ -317,16 +317,16 @@ class PI0Policy(PreTrainedPolicy):
loss_dict = {}
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time)
- loss_dict["losses_after_forward"] = losses.clone()
+ loss_dict["losses_after_forward"] = losses.mean().item()
if actions_is_pad is not None:
in_episode_bound = ~actions_is_pad
losses = losses * in_episode_bound.unsqueeze(-1)
- loss_dict["losses_after_in_ep_bound"] = losses.clone()
+ loss_dict["losses_after_in_ep_bound"] = losses.mean().item()
# Remove padding
losses = losses[:, :, : self.config.max_action_dim]
- loss_dict["losses_after_rm_padding"] = losses.clone()
+ loss_dict["losses_after_rm_padding"] = losses.mean().item()
# For backward pass
loss = losses.mean()

View File

@@ -107,8 +107,11 @@ def predict_action(observation, policy, device, use_amp):
torch.inference_mode(),
torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(),
):
- # Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
  for name in observation:
+     # Skip pre-processing the task text; the VLA will do the tokenization
+     if "task" in name:
+         continue
+     # Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
if "image" in name:
observation[name] = observation[name].type(torch.float32) / 255
observation[name] = observation[name].permute(2, 0, 1).contiguous()
@@ -252,6 +255,9 @@ def control_loop(
observation = robot.capture_observation()
if policy is not None:
+ # Pass the task to the policy if provided, for the VLA model
+ if observation.get("task") is None and single_task is not None:
+     observation["task"] = [single_task]
pred_action = predict_action(
observation, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp
)