testing
This commit is contained in:
parent
0a4bba1da7
commit
7a45fa0fc1
|
@ -46,9 +46,6 @@
|
|||
]
|
||||
}
|
||||
},
|
||||
"use_env_state": true,
|
||||
"exclude_image_keys": "",
|
||||
"normalize_per_robot_type": false,
|
||||
"chunk_size": 10,
|
||||
"n_action_steps": 5,
|
||||
"max_state_dim": 32,
|
||||
|
|
|
@ -175,6 +175,7 @@ class PI0FASTPolicy(PreTrainedPolicy):
|
|||
|
||||
self.language_tokenizer = AutoProcessor.from_pretrained("google/paligemma-3b-pt-224")
|
||||
self.model = PI0FAST(config)
|
||||
self.adapt_to_pi_aloha = True #self.config.adapt_to_pi_aloha # FIXME(mshukor): debug
|
||||
|
||||
self.reset()
|
||||
|
||||
|
@ -222,7 +223,7 @@ class PI0FASTPolicy(PreTrainedPolicy):
|
|||
"""
|
||||
self.eval()
|
||||
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
if self.adapt_to_pi_aloha:
|
||||
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
|
@ -241,7 +242,7 @@ class PI0FASTPolicy(PreTrainedPolicy):
|
|||
|
||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
if self.adapt_to_pi_aloha:
|
||||
actions = self._pi_aloha_encode_actions(actions)
|
||||
|
||||
# `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
|
||||
|
@ -250,7 +251,7 @@ class PI0FASTPolicy(PreTrainedPolicy):
|
|||
return self._action_queue.popleft()
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
if self.adapt_to_pi_aloha:
|
||||
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
|
||||
batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
|
||||
batch = self.normalize_inputs(batch)
|
||||
|
|
|
@ -9,15 +9,19 @@ TASK=AlohaTransferCube-v0
|
|||
REPO_ID=lerobot/aloha_sim_transfer_cube_human
|
||||
OUT_DIR=~/logs/lerobot/tmp/act_aloha_transfer
|
||||
|
||||
EVAL_FREQ=50
|
||||
EVAL_FREQ=5000
|
||||
OFFLINE_STEP=30000
|
||||
SAVE_FREQ=100000
|
||||
|
||||
POLICY_PATH=~/.cache/openpi/openpi-assets/checkpoints/pi0_fast_base_pytorch/
|
||||
POLICY=pi0fast
|
||||
|
||||
python lerobot/scripts/train.py \
|
||||
--policy.type=$POLICY \
|
||||
--policy.path=$POLICY_PATH \
|
||||
--dataset.repo_id=$REPO_ID \
|
||||
--env.type=$ENV \
|
||||
--env.task=$TASK \
|
||||
--output_dir=$OUT_DIR \
|
||||
--eval_freq=$EVAL_FREQ
|
||||
--eval_freq=$EVAL_FREQ \
|
||||
--steps=$OFFLINE_STEP \
|
||||
--save_freq=$SAVE_FREQ
|
|
@ -154,7 +154,12 @@ def rollout(
|
|||
observation = {
|
||||
key: observation[key].to(device, non_blocking=device.type == "cuda") for key in observation
|
||||
}
|
||||
|
||||
if hasattr(env.envs[0], "task_description"):
|
||||
observation["task"] = env.call("task_description")
|
||||
elif hasattr(env.envs[0], "task"):
|
||||
observation["task"] = env.call("task")
|
||||
else:
|
||||
observation["task"] = ["" for _ in range(observation[list(observation.keys())[0]].shape[0])]
|
||||
with torch.inference_mode():
|
||||
action = policy.select_action(observation)
|
||||
|
||||
|
|
Loading…
Reference in New Issue