From 38e36d2d4c1535d89247cf043f08a6bf5198465f Mon Sep 17 00:00:00 2001
From: Remi Cadene
Date: Fri, 7 Feb 2025 13:38:47 +0100
Subject: [PATCH] Improve pi0 (WIP)

---
 lerobot/common/policies/pi0/configuration_pi0.py     |  4 ++++
 .../pi0/conversion_scripts/compare_with_jax.py       |  6 ++++++
 .../conversion_scripts/convert_pi0_to_hf_lerobot.py  |  3 +++
 lerobot/common/policies/pi0/modeling_pi0.py          |  3 ++-
 .../common/policies/pi0/paligemma_with_expert.py     | 13 ++++++++++++-
 5 files changed, 27 insertions(+), 2 deletions(-)

diff --git a/lerobot/common/policies/pi0/configuration_pi0.py b/lerobot/common/policies/pi0/configuration_pi0.py
index 8d2eedf6..921aa196 100644
--- a/lerobot/common/policies/pi0/configuration_pi0.py
+++ b/lerobot/common/policies/pi0/configuration_pi0.py
@@ -61,6 +61,10 @@ class PI0Config(PreTrainedConfig):
     train_expert_only: bool = False
     train_state_proj: bool = True
 
+    # Useful if you want to reproduce the training of pi0_base from
+    # the original initialization of PaliGemma (before training on robotics data).
+    paligemma_pretrained_path: str | None = "google/paligemma-3b-pt-224"
+
     # Training presets
     optimizer_lr: float = 2.5e-5
     optimizer_betas: tuple[float, float] = (0.9, 0.95)
diff --git a/lerobot/common/policies/pi0/conversion_scripts/compare_with_jax.py b/lerobot/common/policies/pi0/conversion_scripts/compare_with_jax.py
index 8b2e1c66..2f946a38 100644
--- a/lerobot/common/policies/pi0/conversion_scripts/compare_with_jax.py
+++ b/lerobot/common/policies/pi0/conversion_scripts/compare_with_jax.py
@@ -52,6 +52,12 @@ def main():
     dataset_meta.stats["observation.state"]["std"] = torch.tensor(
         norm_stats["norm_stats"]["state"]["std"][:num_motors], dtype=torch.float32
     )
+    dataset_meta.stats["action"]["mean"] = torch.tensor(
+        norm_stats["norm_stats"]["actions"]["mean"][:num_motors], dtype=torch.float32
+    )
+    dataset_meta.stats["action"]["std"] = torch.tensor(
+        norm_stats["norm_stats"]["actions"]["std"][:num_motors], dtype=torch.float32
+    )
 
     # Create LeRobot batch from Jax
     batch = {}
diff --git a/lerobot/common/policies/pi0/conversion_scripts/convert_pi0_to_hf_lerobot.py b/lerobot/common/policies/pi0/conversion_scripts/convert_pi0_to_hf_lerobot.py
index f85437a5..3770b833 100644
--- a/lerobot/common/policies/pi0/conversion_scripts/convert_pi0_to_hf_lerobot.py
+++ b/lerobot/common/policies/pi0/conversion_scripts/convert_pi0_to_hf_lerobot.py
@@ -346,17 +346,20 @@ def convert_pi0_checkpoint(checkpoint_dir: str, precision: str, tokenizer_id: st
             empty_cameras=2,
             adapt_to_pi_aloha=True,
             use_delta_joint_actions_aloha=False,
+            paligemma_pretrained_path=None,
         )
     elif "pi0_aloha_towel" in checkpoint_dir:
         pi0_config = PI0Config(
             adapt_to_pi_aloha=True,
             use_delta_joint_actions_aloha=True,
+            paligemma_pretrained_path=None,
         )
     elif "pi0_base" in checkpoint_dir:
         pi0_config = PI0Config(
             empty_cameras=0,
             adapt_to_pi_aloha=False,
             use_delta_joint_actions_aloha=False,
+            paligemma_pretrained_path=None,
         )
     else:
         raise ValueError()
diff --git a/lerobot/common/policies/pi0/modeling_pi0.py b/lerobot/common/policies/pi0/modeling_pi0.py
index 90d1a14c..1ecb327c 100644
--- a/lerobot/common/policies/pi0/modeling_pi0.py
+++ b/lerobot/common/policies/pi0/modeling_pi0.py
@@ -44,7 +44,7 @@ python lerobot/scripts/train.py \
 
 Example of using the pi0 pretrained model outside LeRobot training framework:
 ```python
-policy = Pi0Policy.from_pretrained("lerobot/pi0")
+policy = PI0Policy.from_pretrained("lerobot/pi0")
 ```
 """
 
@@ -470,6 +470,7 @@ class PI0FlowMatching(nn.Module):
         paligemma_with_export_config = PaliGemmaWithExpertConfig(
             freeze_vision_encoder=self.config.freeze_vision_encoder,
             train_expert_only=self.config.train_expert_only,
+            paligemma_pretrained_path=self.config.paligemma_pretrained_path,
             attention_implementation=self.config.attention_implementation,
         )
         self.paligemma_with_expert = PaliGemmaWithExpertModel(paligemma_with_export_config)
diff --git a/lerobot/common/policies/pi0/paligemma_with_expert.py b/lerobot/common/policies/pi0/paligemma_with_expert.py
index 08c36c11..21ddd28b 100644
--- a/lerobot/common/policies/pi0/paligemma_with_expert.py
+++ b/lerobot/common/policies/pi0/paligemma_with_expert.py
@@ -6,6 +6,7 @@ from pytest import Cache
 from torch import nn
 from transformers import (
     AutoConfig,
+    AutoModel,
     GemmaForCausalLM,
     PaliGemmaForConditionalGeneration,
     PretrainedConfig,
@@ -52,11 +53,13 @@ class PaliGemmaWithExpertConfig(PretrainedConfig):
         gemma_expert_config: dict | None = None,
         freeze_vision_encoder: bool = True,
         train_expert_only: bool = True,
+        paligemma_pretrained_path: str | None = "google/paligemma-3b-pt-224",
         attention_implementation: str = "eager",
         **kwargs,
     ):
         self.freeze_vision_encoder = freeze_vision_encoder
         self.train_expert_only = train_expert_only
+        self.paligemma_pretrained_path = paligemma_pretrained_path
         self.attention_implementation = attention_implementation
 
         if paligemma_config is None:
@@ -153,6 +156,11 @@ class PaliGemmaWithExpertConfig(PretrainedConfig):
                 f"Wrong value provided for `attention_implementation` ({self.attention_implementation}). Expected 'eager', 'fa2' or 'flex'."
             )
 
+        if self.paligemma_pretrained_path is not None and self.paligemma_config is not None:
+            raise ValueError(
+                "When 'paligemma_pretrained_path' is provided, 'paligemma_config' needs to be None."
+            )
+
 
 class PaliGemmaWithExpertModel(PreTrainedModel):
     config_class = PaliGemmaWithExpertConfig
@@ -160,7 +168,10 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
     def __init__(self, config: PaliGemmaWithExpertConfig):
         super().__init__(config=config)
         self.config = config
-        self.paligemma = PaliGemmaForConditionalGeneration(config=config.paligemma_config)
+        if config.paligemma_pretrained_path is not None:
+            self.paligemma = PaliGemmaForConditionalGeneration.from_pretrained(config.paligemma_pretrained_path)
+        else:
+            self.paligemma = PaliGemmaForConditionalGeneration(config=config.paligemma_config)
         self.gemma_expert = GemmaForCausalLM(config=config.gemma_expert_config)
         # Remove unused embed_tokens
         self.gemma_expert.model.embed_tokens = None
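For context, a minimal sketch of how the new `paligemma_pretrained_path` option is exercised. The import path follows the files touched above; instantiating `PI0Config` standalone (outside the training and conversion scripts) is an assumption for illustration, not something this patch does:

```python
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config

# Default: keep "google/paligemma-3b-pt-224" so the PaliGemma backbone is
# initialized from the original pretrained weights, as when reproducing the
# training of pi0_base. The Gemma expert is still built from its config.
train_config = PI0Config()
assert train_config.paligemma_pretrained_path == "google/paligemma-3b-pt-224"

# Checkpoint conversion: pass None so PaliGemma is constructed from its config
# alone, as the conversion script above does for every checkpoint variant.
convert_config = PI0Config(paligemma_pretrained_path=None)
assert convert_config.paligemma_pretrained_path is None
```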