diff --git a/lerobot/common/policies/pi0fast/config.json b/lerobot/common/policies/pi0fast/config.json new file mode 100644 index 00000000..6c59f1b9 --- /dev/null +++ b/lerobot/common/policies/pi0fast/config.json @@ -0,0 +1,82 @@ +{ + "type": "pi0fast", + "n_obs_steps": 1, + "normalization_mapping": { + "VISUAL": "IDENTITY", + "STATE": "MEAN_STD", + "ACTION": "MEAN_STD" + }, + "input_features": { + "observation.image": { + "type": "VISUAL", + "shape": [ + 3, + 256, + 256 + ] + }, + "observation.image2": { + "type": "VISUAL", + "shape": [ + 3, + 256, + 256 + ] + }, + "observation.image3": { + "type": "VISUAL", + "shape": [ + 3, + 256, + 256 + ] + }, + "observation.state": { + "type": "STATE", + "shape": [ + 8 + ] + } + }, + "output_features": { + "action": { + "type": "ACTION", + "shape": [ + 7 + ] + } + }, + "use_env_state": true, + "exclude_image_keys": "", + "normalize_per_robot_type": false, + "chunk_size": 10, + "n_action_steps": 5, + "max_state_dim": 32, + "max_action_dim": 32, + "resize_imgs_with_padding": [ + 224, + 224 + ], + "interpolate_like_pi": false, + "empty_cameras": 0, + "adapt_to_pi_aloha": false, + "use_delta_joint_actions_aloha": false, + "tokenizer_max_length": 48, + "proj_width": 1024, + "max_decoding_steps": 256, + "fast_skip_tokens": 128, + "max_input_seq_len": 256, + "use_cache": true, + "freeze_vision_encoder": true, + "freeze_lm_head": true, + "optimizer_lr": 0.0001, + "optimizer_betas": [ + 0.9, + 0.95 + ], + "optimizer_eps": 1e-08, + "optimizer_weight_decay": 1e-10, + "scheduler_warmup_steps": 1000, + "scheduler_decay_steps": 30000, + "scheduler_decay_lr": 2.5e-06 +} \ No newline at end of file diff --git a/lerobot/common/policies/pi0fast/configuration_pi0fast.py b/lerobot/common/policies/pi0fast/configuration_pi0fast.py index 34bc3ebc..5cacc1a3 100644 --- a/lerobot/common/policies/pi0fast/configuration_pi0fast.py +++ b/lerobot/common/policies/pi0fast/configuration_pi0fast.py @@ -21,8 +21,8 @@ class PEFTConfig: class PI0FASTConfig(PreTrainedConfig): # Input / output structure. n_obs_steps: int = 1 - chunk_size: int = 51 - n_action_steps: int = 50 + chunk_size: int = 10 + n_action_steps: int = 5 normalization_mapping: dict[str, NormalizationMode] = field( default_factory=lambda: { diff --git a/lerobot/common/policies/pi0fast/modeling_pi0fast.py b/lerobot/common/policies/pi0fast/modeling_pi0fast.py index 5c63eb47..293fe6e1 100644 --- a/lerobot/common/policies/pi0fast/modeling_pi0fast.py +++ b/lerobot/common/policies/pi0fast/modeling_pi0fast.py @@ -87,6 +87,61 @@ def display(tensor: torch.Tensor): print(f"Max: {tensor.max().item()}") +def normalize(x, min_val, max_val): + return (x - min_val) / (max_val - min_val) + + +def unnormalize(x, min_val, max_val): + return x * (max_val - min_val) + min_val + + +def safe_arcsin(value): + # This ensures that the input stays within + # [−1,1] to avoid invalid values for arcsin + return torch.arcsin(torch.clamp(value, -1.0, 1.0)) + + +def aloha_gripper_to_angular(value): + # Aloha transforms the gripper positions into a linear space. The following code + # reverses this transformation to be consistent with pi0 which is pretrained in + # angular space. + # + # These values are coming from the Aloha code: + # PUPPET_GRIPPER_POSITION_OPEN, PUPPET_GRIPPER_POSITION_CLOSED + value = unnormalize(value, min_val=0.01844, max_val=0.05800) + + # This is the inverse of the angular to linear transformation inside the Interbotix code. + def linear_to_radian(linear_position, arm_length, horn_radius): + value = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position) + return safe_arcsin(value) + + # The constants are taken from the Interbotix code. + value = linear_to_radian(value, arm_length=0.036, horn_radius=0.022) + + # Normalize to [0, 1]. + # The values 0.4 and 1.5 were measured on an actual Trossen robot. + return normalize(value, min_val=0.4, max_val=1.5) + + +def aloha_gripper_from_angular(value): + # Convert from the gripper position used by pi0 to the gripper position that is used by Aloha. + # Note that the units are still angular but the range is different. + + # The values 0.4 and 1.5 were measured on an actual Trossen robot. + value = unnormalize(value, min_val=0.4, max_val=1.5) + + # These values are coming from the Aloha code: + # PUPPET_GRIPPER_JOINT_OPEN, PUPPET_GRIPPER_JOINT_CLOSE + return normalize(value, min_val=-0.6213, max_val=1.4910) + + +def aloha_gripper_from_angular_inv(value): + # Directly inverts the gripper_from_angular function. + value = unnormalize(value, min_val=-0.6213, max_val=1.4910) + return normalize(value, min_val=0.4, max_val=1.5) + + + class PI0FASTPolicy(PreTrainedPolicy): """Wrapper class around PI0FAST tokenizer and model to train and run inference within LeRobot.""" @@ -130,6 +185,33 @@ class PI0FASTPolicy(PreTrainedPolicy): def get_optim_params(self) -> dict: return self.parameters() + def _pi_aloha_decode_state(self, state): + # Flip the joints. + for motor_idx in [1, 2, 8, 9]: + state[:, motor_idx] *= -1 + # Reverse the gripper transformation that is being applied by the Aloha runtime. + for motor_idx in [6, 13]: + state[:, motor_idx] = aloha_gripper_to_angular(state[:, motor_idx]) + return state + + def _pi_aloha_encode_actions(self, actions): + # Flip the joints. + for motor_idx in [1, 2, 8, 9]: + actions[:, :, motor_idx] *= -1 + # Reverse the gripper transformation that is being applied by the Aloha runtime. + for motor_idx in [6, 13]: + actions[:, :, motor_idx] = aloha_gripper_from_angular(actions[:, :, motor_idx]) + return actions + + def _pi_aloha_encode_actions_inv(self, actions): + # Flip the joints again. + for motor_idx in [1, 2, 8, 9]: + actions[:, :, motor_idx] *= -1 + # Reverse the gripper transformation that is being applied by the Aloha runtime. + for motor_idx in [6, 13]: + actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx]) + return actions + @torch.no_grad def select_action(self, batch: dict[str, Tensor]) -> Tensor: """Select a single action given environment observations. diff --git a/lerobot/common/policies/pi0fast/test_pi0fast.sh b/lerobot/common/policies/pi0fast/test_pi0fast.sh index 4852178d..9e889e9a 100644 --- a/lerobot/common/policies/pi0fast/test_pi0fast.sh +++ b/lerobot/common/policies/pi0fast/test_pi0fast.sh @@ -9,8 +9,9 @@ TASK=AlohaTransferCube-v0 REPO_ID=lerobot/aloha_sim_transfer_cube_human OUT_DIR=~/logs/lerobot/tmp/act_aloha_transfer -EVAL_FREQ=5 +EVAL_FREQ=50 +POLICY_PATH=~/.cache/openpi/openpi-assets/checkpoints/pi0_fast_base_pytorch/ POLICY=pi0fast python lerobot/scripts/train.py \