This commit is contained in:
Xiaoxuan Liu 2025-04-04 12:59:29 +03:00 committed by GitHub
commit 098acc3caf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 10 additions and 4 deletions

View File

@ -317,16 +317,16 @@ class PI0Policy(PreTrainedPolicy):
loss_dict = {} loss_dict = {}
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time) losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time)
loss_dict["losses_after_forward"] = losses.clone() loss_dict["losses_after_forward"] = losses.mean().item()
if actions_is_pad is not None: if actions_is_pad is not None:
in_episode_bound = ~actions_is_pad in_episode_bound = ~actions_is_pad
losses = losses * in_episode_bound.unsqueeze(-1) losses = losses * in_episode_bound.unsqueeze(-1)
loss_dict["losses_after_in_ep_bound"] = losses.clone() loss_dict["losses_after_in_ep_bound"] = losses.mean().item()
# Remove padding # Remove padding
losses = losses[:, :, : self.config.max_action_dim] losses = losses[:, :, : self.config.max_action_dim]
loss_dict["losses_after_rm_padding"] = losses.clone() loss_dict["losses_after_rm_padding"] = losses.mean().item()
# For backward pass # For backward pass
loss = losses.mean() loss = losses.mean()

View File

@ -107,8 +107,11 @@ def predict_action(observation, policy, device, use_amp):
torch.inference_mode(), torch.inference_mode(),
torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(), torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(),
): ):
# Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
for name in observation: for name in observation:
# Skip pre-processing the task text, the VLA will do the tokenization
if "task" in name:
continue
# Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
if "image" in name: if "image" in name:
observation[name] = observation[name].type(torch.float32) / 255 observation[name] = observation[name].type(torch.float32) / 255
observation[name] = observation[name].permute(2, 0, 1).contiguous() observation[name] = observation[name].permute(2, 0, 1).contiguous()
@ -252,6 +255,9 @@ def control_loop(
observation = robot.capture_observation() observation = robot.capture_observation()
if policy is not None: if policy is not None:
# Pass the task to the policy if provided for VLA model
if observation.get("task") is None and single_task is not None:
observation["task"] = [single_task]
pred_action = predict_action( pred_action = predict_action(
observation, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp observation, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp
) )