Merge b158576896 into 1c873df5c0
commit 098acc3caf
@@ -317,16 +317,16 @@ class PI0Policy(PreTrainedPolicy):
         loss_dict = {}
         losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time)
-        loss_dict["losses_after_forward"] = losses.clone()
+        loss_dict["losses_after_forward"] = losses.mean().item()

         if actions_is_pad is not None:
             in_episode_bound = ~actions_is_pad
             losses = losses * in_episode_bound.unsqueeze(-1)
-            loss_dict["losses_after_in_ep_bound"] = losses.clone()
+            loss_dict["losses_after_in_ep_bound"] = losses.mean().item()

         # Remove padding
         losses = losses[:, :, : self.config.max_action_dim]
-        loss_dict["losses_after_rm_padding"] = losses.clone()
+        loss_dict["losses_after_rm_padding"] = losses.mean().item()

         # For backward pass
         loss = losses.mean()

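Note: the logged entries switch from full tensor clones to detached Python floats via `.mean().item()`, so `loss_dict` no longer holds per-element loss tensors that are still attached to the autograd graph. A minimal, hypothetical sketch of the difference (not the repository's code; the `losses` tensor and its shape are assumptions):

import torch

# Dummy per-element losses standing in for the model output (assumed shape).
losses = torch.rand(8, 50, 32, requires_grad=True)

loss_dict = {}
# Before: a full tensor copy, still connected to the autograd graph.
old_style = losses.clone()
# After: a single detached float, cheap to store and safe to log.
loss_dict["losses_after_forward"] = losses.mean().item()

print(type(loss_dict["losses_after_forward"]))  # <class 'float'>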
@@ -107,8 +107,11 @@ def predict_action(observation, policy, device, use_amp):
         torch.inference_mode(),
         torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(),
     ):
-        # Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
         for name in observation:
+            # Skip pre-processing the task text, the VLA will do the tokenization
+            if "task" in name:
+                continue
+            # Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
             if "image" in name:
                 observation[name] = observation[name].type(torch.float32) / 255
                 observation[name] = observation[name].permute(2, 0, 1).contiguous()
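Note: the preprocessing loop now leaves the natural-language "task" entry untouched (the VLA tokenizes it later) and applies the float32 / [0,1] / channel-first conversion only to image tensors. A rough standalone sketch of that logic, using a made-up observation dict whose keys and shapes are assumptions:

import torch

observation = {
    "observation.images.cam_high": torch.randint(0, 256, (480, 640, 3), dtype=torch.uint8),
    "observation.state": torch.zeros(6),
    "task": ["pick up the red cube"],  # stays a string list, not preprocessed here
}

for name in observation:
    # Skip pre-processing the task text, the VLA will do the tokenization
    if "task" in name:
        continue
    # Convert to pytorch format: channel first and float32 in [0,1]
    if "image" in name:
        observation[name] = observation[name].type(torch.float32) / 255
        observation[name] = observation[name].permute(2, 0, 1).contiguous()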
@@ -252,6 +255,9 @@ def control_loop(
         observation = robot.capture_observation()

         if policy is not None:
+            # Pass the task to the policy if provided for VLA model
+            if observation.get("task") is None and single_task is not None:
+                observation["task"] = [single_task]
             pred_action = predict_action(
                 observation, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp
             )
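Note: before calling predict_action, the control loop injects the user-supplied task string into the observation as a one-element list when no "task" key is already present. A hedged sketch of only that injection step, with placeholder values standing in for the real `observation` and `single_task` objects:

# Placeholder values (assumptions, not from the robot stack).
observation = {"observation.state": [0.0] * 6}
single_task = "fold the towel"  # e.g. a task string provided on the command line

if observation.get("task") is None and single_task is not None:
    # Wrap in a list so the policy receives a batch of task strings.
    observation["task"] = [single_task]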