testing adapting aloha

This commit is contained in:
mshukor 2025-03-24 18:02:15 +01:00
parent 85eea19264
commit 0a4bba1da7
4 changed files with 168 additions and 3 deletions

View File

@ -0,0 +1,82 @@
{
"type": "pi0fast",
"n_obs_steps": 1,
"normalization_mapping": {
"VISUAL": "IDENTITY",
"STATE": "MEAN_STD",
"ACTION": "MEAN_STD"
},
"input_features": {
"observation.image": {
"type": "VISUAL",
"shape": [
3,
256,
256
]
},
"observation.image2": {
"type": "VISUAL",
"shape": [
3,
256,
256
]
},
"observation.image3": {
"type": "VISUAL",
"shape": [
3,
256,
256
]
},
"observation.state": {
"type": "STATE",
"shape": [
8
]
}
},
"output_features": {
"action": {
"type": "ACTION",
"shape": [
7
]
}
},
"use_env_state": true,
"exclude_image_keys": "",
"normalize_per_robot_type": false,
"chunk_size": 10,
"n_action_steps": 5,
"max_state_dim": 32,
"max_action_dim": 32,
"resize_imgs_with_padding": [
224,
224
],
"interpolate_like_pi": false,
"empty_cameras": 0,
"adapt_to_pi_aloha": false,
"use_delta_joint_actions_aloha": false,
"tokenizer_max_length": 48,
"proj_width": 1024,
"max_decoding_steps": 256,
"fast_skip_tokens": 128,
"max_input_seq_len": 256,
"use_cache": true,
"freeze_vision_encoder": true,
"freeze_lm_head": true,
"optimizer_lr": 0.0001,
"optimizer_betas": [
0.9,
0.95
],
"optimizer_eps": 1e-08,
"optimizer_weight_decay": 1e-10,
"scheduler_warmup_steps": 1000,
"scheduler_decay_steps": 30000,
"scheduler_decay_lr": 2.5e-06
}

View File

@ -21,8 +21,8 @@ class PEFTConfig:
class PI0FASTConfig(PreTrainedConfig):
# Input / output structure.
n_obs_steps: int = 1
chunk_size: int = 51
n_action_steps: int = 50
chunk_size: int = 10
n_action_steps: int = 5
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {

View File

@ -87,6 +87,61 @@ def display(tensor: torch.Tensor):
print(f"Max: {tensor.max().item()}")
def normalize(x, min_val, max_val):
return (x - min_val) / (max_val - min_val)
def unnormalize(x, min_val, max_val):
return x * (max_val - min_val) + min_val
def safe_arcsin(value):
# This ensures that the input stays within
# [1,1] to avoid invalid values for arcsin
return torch.arcsin(torch.clamp(value, -1.0, 1.0))
def aloha_gripper_to_angular(value):
# Aloha transforms the gripper positions into a linear space. The following code
# reverses this transformation to be consistent with pi0 which is pretrained in
# angular space.
#
# These values are coming from the Aloha code:
# PUPPET_GRIPPER_POSITION_OPEN, PUPPET_GRIPPER_POSITION_CLOSED
value = unnormalize(value, min_val=0.01844, max_val=0.05800)
# This is the inverse of the angular to linear transformation inside the Interbotix code.
def linear_to_radian(linear_position, arm_length, horn_radius):
value = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position)
return safe_arcsin(value)
# The constants are taken from the Interbotix code.
value = linear_to_radian(value, arm_length=0.036, horn_radius=0.022)
# Normalize to [0, 1].
# The values 0.4 and 1.5 were measured on an actual Trossen robot.
return normalize(value, min_val=0.4, max_val=1.5)
def aloha_gripper_from_angular(value):
# Convert from the gripper position used by pi0 to the gripper position that is used by Aloha.
# Note that the units are still angular but the range is different.
# The values 0.4 and 1.5 were measured on an actual Trossen robot.
value = unnormalize(value, min_val=0.4, max_val=1.5)
# These values are coming from the Aloha code:
# PUPPET_GRIPPER_JOINT_OPEN, PUPPET_GRIPPER_JOINT_CLOSE
return normalize(value, min_val=-0.6213, max_val=1.4910)
def aloha_gripper_from_angular_inv(value):
# Directly inverts the gripper_from_angular function.
value = unnormalize(value, min_val=-0.6213, max_val=1.4910)
return normalize(value, min_val=0.4, max_val=1.5)
class PI0FASTPolicy(PreTrainedPolicy):
"""Wrapper class around PI0FAST tokenizer and model to train and run inference within LeRobot."""
@ -130,6 +185,33 @@ class PI0FASTPolicy(PreTrainedPolicy):
def get_optim_params(self) -> dict:
return self.parameters()
def _pi_aloha_decode_state(self, state):
# Flip the joints.
for motor_idx in [1, 2, 8, 9]:
state[:, motor_idx] *= -1
# Reverse the gripper transformation that is being applied by the Aloha runtime.
for motor_idx in [6, 13]:
state[:, motor_idx] = aloha_gripper_to_angular(state[:, motor_idx])
return state
def _pi_aloha_encode_actions(self, actions):
# Flip the joints.
for motor_idx in [1, 2, 8, 9]:
actions[:, :, motor_idx] *= -1
# Reverse the gripper transformation that is being applied by the Aloha runtime.
for motor_idx in [6, 13]:
actions[:, :, motor_idx] = aloha_gripper_from_angular(actions[:, :, motor_idx])
return actions
def _pi_aloha_encode_actions_inv(self, actions):
# Flip the joints again.
for motor_idx in [1, 2, 8, 9]:
actions[:, :, motor_idx] *= -1
# Reverse the gripper transformation that is being applied by the Aloha runtime.
for motor_idx in [6, 13]:
actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx])
return actions
@torch.no_grad
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a single action given environment observations.

View File

@ -9,8 +9,9 @@ TASK=AlohaTransferCube-v0
REPO_ID=lerobot/aloha_sim_transfer_cube_human
OUT_DIR=~/logs/lerobot/tmp/act_aloha_transfer
EVAL_FREQ=5
EVAL_FREQ=50
POLICY_PATH=~/.cache/openpi/openpi-assets/checkpoints/pi0_fast_base_pytorch/
POLICY=pi0fast
python lerobot/scripts/train.py \