Commit 7a45fa0fc1 by mshukor, 2025-03-26 13:21:19 +01:00 (parent 0a4bba1da7)
4 changed files with 17 additions and 10 deletions

File 1/4 — policy configuration (JSON)

@@ -46,9 +46,6 @@
]
}
},
"use_env_state": true,
"exclude_image_keys": "",
"normalize_per_robot_type": false,
"chunk_size": 10,
"n_action_steps": 5,
"max_state_dim": 32,

File 2/4 — PI0FAST policy module (Python)

@@ -175,6 +175,7 @@ class PI0FASTPolicy(PreTrainedPolicy):
self.language_tokenizer = AutoProcessor.from_pretrained("google/paligemma-3b-pt-224")
self.model = PI0FAST(config)
self.adapt_to_pi_aloha = True #self.config.adapt_to_pi_aloha # FIXME(mshukor): debug
self.reset()
@@ -222,7 +223,7 @@ class PI0FASTPolicy(PreTrainedPolicy):
"""
self.eval()
if self.config.adapt_to_pi_aloha:
if self.adapt_to_pi_aloha:
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
batch = self.normalize_inputs(batch)
@@ -241,7 +242,7 @@ class PI0FASTPolicy(PreTrainedPolicy):
actions = self.unnormalize_outputs({"action": actions})["action"]
if self.config.adapt_to_pi_aloha:
if self.adapt_to_pi_aloha:
actions = self._pi_aloha_encode_actions(actions)
# `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
@@ -250,7 +251,7 @@ class PI0FASTPolicy(PreTrainedPolicy):
return self._action_queue.popleft()
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
if self.config.adapt_to_pi_aloha:
if self.adapt_to_pi_aloha:
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
batch = self.normalize_inputs(batch)
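The call sites in this file now read an instance attribute instead of the config flag, and `__init__` pins that attribute to `True` (see the FIXME). A minimal sketch of that pattern, with hypothetical class names, assuming the goal is to force the Aloha adaptation on while debugging without editing the loaded config:

```python
from dataclasses import dataclass

@dataclass
class PolicyConfig:
    adapt_to_pi_aloha: bool = False  # value as loaded from the checkpoint config

class Policy:
    def __init__(self, config: PolicyConfig):
        self.config = config
        # Instance-level override, mirroring the FIXME in the diff: force the
        # Aloha adaptation on regardless of what the loaded config says.
        self.adapt_to_pi_aloha = True

    def select_action(self, state: list[float]) -> list[float]:
        # Runtime code checks the instance attribute rather than self.config,
        # so the debug override takes effect in every code path.
        if self.adapt_to_pi_aloha:
            state = [-s for s in state]  # stand-in for the state decode step
        return state
```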

File 3/4 — training launch script (shell)

@@ -9,15 +9,19 @@ TASK=AlohaTransferCube-v0
REPO_ID=lerobot/aloha_sim_transfer_cube_human
OUT_DIR=~/logs/lerobot/tmp/act_aloha_transfer
EVAL_FREQ=50
EVAL_FREQ=5000
OFFLINE_STEP=30000
SAVE_FREQ=100000
POLICY_PATH=~/.cache/openpi/openpi-assets/checkpoints/pi0_fast_base_pytorch/
POLICY=pi0fast
python lerobot/scripts/train.py \
--policy.type=$POLICY \
--policy.path=$POLICY_PATH \
--dataset.repo_id=$REPO_ID \
--env.type=$ENV \
--env.task=$TASK \
--output_dir=$OUT_DIR \
--eval_freq=$EVAL_FREQ
--eval_freq=$EVAL_FREQ \
--steps=$OFFLINE_STEP \
--save_freq=$SAVE_FREQ
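The new values imply a simple training schedule; a small sanity check, assuming evaluation runs every `eval_freq` optimization steps and periodic checkpoints are written every `save_freq` steps:

```python
# Rough schedule implied by the updated script values.
steps = 30_000       # OFFLINE_STEP
eval_freq = 5_000    # EVAL_FREQ
save_freq = 100_000  # SAVE_FREQ

print("evaluations:", steps // eval_freq)   # 6 evaluation rounds during training
print("checkpoints:", steps // save_freq)   # 0 periodic checkpoints within 30k steps
```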

File 4/4 — evaluation rollout (Python)

@@ -154,7 +154,12 @@ def rollout(
observation = {
key: observation[key].to(device, non_blocking=device.type == "cuda") for key in observation
}
if hasattr(env.envs[0], "task_description"):
observation["task"] = env.call("task_description")
elif hasattr(env.envs[0], "task"):
observation["task"] = env.call("task")
else:
observation["task"] = ["" for _ in range(observation[list(observation.keys())[0]].shape[0])]
with torch.inference_mode():
action = policy.select_action(observation)
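The added branch resolves one task string per environment with a fall-through order: `task_description`, then `task`, then empty strings. A sketch of the same logic factored into a helper (hypothetical name, not part of the diff), using the vector env's `call` as above:

```python
def get_task_strings(env, batch_size: int) -> list[str]:
    """Return one task string per env, with the same fall-through as the diff."""
    first_env = env.envs[0]
    if hasattr(first_env, "task_description"):
        return list(env.call("task_description"))
    if hasattr(first_env, "task"):
        return list(env.call("task"))
    return ["" for _ in range(batch_size)]

# batch_size can be taken from any observation tensor, e.g.
# next(iter(observation.values())).shape[0], which is equivalent to the
# indexing used in the diff.
```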