This commit is contained in:
mshukor 2025-03-26 13:21:19 +01:00
parent 0a4bba1da7
commit 7a45fa0fc1
4 changed files with 17 additions and 10 deletions

View File

@ -46,9 +46,6 @@
] ]
} }
}, },
"use_env_state": true,
"exclude_image_keys": "",
"normalize_per_robot_type": false,
"chunk_size": 10, "chunk_size": 10,
"n_action_steps": 5, "n_action_steps": 5,
"max_state_dim": 32, "max_state_dim": 32,

View File

@ -175,6 +175,7 @@ class PI0FASTPolicy(PreTrainedPolicy):
self.language_tokenizer = AutoProcessor.from_pretrained("google/paligemma-3b-pt-224") self.language_tokenizer = AutoProcessor.from_pretrained("google/paligemma-3b-pt-224")
self.model = PI0FAST(config) self.model = PI0FAST(config)
self.adapt_to_pi_aloha = True #self.config.adapt_to_pi_aloha # FIXME(mshukor): debug
self.reset() self.reset()
@ -222,7 +223,7 @@ class PI0FASTPolicy(PreTrainedPolicy):
""" """
self.eval() self.eval()
if self.config.adapt_to_pi_aloha: if self.adapt_to_pi_aloha:
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT]) batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
batch = self.normalize_inputs(batch) batch = self.normalize_inputs(batch)
@ -241,7 +242,7 @@ class PI0FASTPolicy(PreTrainedPolicy):
actions = self.unnormalize_outputs({"action": actions})["action"] actions = self.unnormalize_outputs({"action": actions})["action"]
if self.config.adapt_to_pi_aloha: if self.adapt_to_pi_aloha:
actions = self._pi_aloha_encode_actions(actions) actions = self._pi_aloha_encode_actions(actions)
# `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
@ -250,7 +251,7 @@ class PI0FASTPolicy(PreTrainedPolicy):
return self._action_queue.popleft() return self._action_queue.popleft()
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
if self.config.adapt_to_pi_aloha: if self.adapt_to_pi_aloha:
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT]) batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION]) batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
batch = self.normalize_inputs(batch) batch = self.normalize_inputs(batch)

View File

@ -9,15 +9,19 @@ TASK=AlohaTransferCube-v0
REPO_ID=lerobot/aloha_sim_transfer_cube_human REPO_ID=lerobot/aloha_sim_transfer_cube_human
OUT_DIR=~/logs/lerobot/tmp/act_aloha_transfer OUT_DIR=~/logs/lerobot/tmp/act_aloha_transfer
EVAL_FREQ=50 EVAL_FREQ=5000
OFFLINE_STEP=30000
SAVE_FREQ=100000
POLICY_PATH=~/.cache/openpi/openpi-assets/checkpoints/pi0_fast_base_pytorch/ POLICY_PATH=~/.cache/openpi/openpi-assets/checkpoints/pi0_fast_base_pytorch/
POLICY=pi0fast POLICY=pi0fast
python lerobot/scripts/train.py \ python lerobot/scripts/train.py \
--policy.type=$POLICY \ --policy.path=$POLICY_PATH \
--dataset.repo_id=$REPO_ID \ --dataset.repo_id=$REPO_ID \
--env.type=$ENV \ --env.type=$ENV \
--env.task=$TASK \ --env.task=$TASK \
--output_dir=$OUT_DIR \ --output_dir=$OUT_DIR \
--eval_freq=$EVAL_FREQ --eval_freq=$EVAL_FREQ \
--steps=$OFFLINE_STEP \
--save_freq=$SAVE_FREQ

View File

@ -154,7 +154,12 @@ def rollout(
observation = { observation = {
key: observation[key].to(device, non_blocking=device.type == "cuda") for key in observation key: observation[key].to(device, non_blocking=device.type == "cuda") for key in observation
} }
if hasattr(env.envs[0], "task_description"):
observation["task"] = env.call("task_description")
elif hasattr(env.envs[0], "task"):
observation["task"] = env.call("task")
else:
observation["task"] = ["" for _ in range(observation[list(observation.keys())[0]].shape[0])]
with torch.inference_mode(): with torch.inference_mode():
action = policy.select_action(observation) action = policy.select_action(observation)