diff --git a/lerobot/common/policies/pi0/configuration_pi0.py b/lerobot/common/policies/pi0/configuration_pi0.py
index 8c7cc130..98823207 100644
--- a/lerobot/common/policies/pi0/configuration_pi0.py
+++ b/lerobot/common/policies/pi0/configuration_pi0.py
@@ -75,6 +75,10 @@ class PI0Config(PreTrainedConfig):
     train_expert_only: bool = False
     train_state_proj: bool = True
 
+    # Useful if you want to reproduce the training of pi0_base from
+    # the original initialization of PaliGemma (before training on robotics data).
+    paligemma_pretrained_path: str | None = "google/paligemma-3b-pt-224"
+
     # Training presets
     optimizer_lr: float = 2.5e-5
     optimizer_betas: tuple[float, float] = (0.9, 0.95)
diff --git a/lerobot/common/policies/pi0/conversion_scripts/compare_with_jax.py b/lerobot/common/policies/pi0/conversion_scripts/compare_with_jax.py
index 6bd7c91f..1aa7ea4c 100644
--- a/lerobot/common/policies/pi0/conversion_scripts/compare_with_jax.py
+++ b/lerobot/common/policies/pi0/conversion_scripts/compare_with_jax.py
@@ -66,6 +66,12 @@ def main():
     dataset_meta.stats["observation.state"]["std"] = torch.tensor(
         norm_stats["norm_stats"]["state"]["std"][:num_motors], dtype=torch.float32
     )
+    dataset_meta.stats["action"]["mean"] = torch.tensor(
+        norm_stats["norm_stats"]["actions"]["mean"][:num_motors], dtype=torch.float32
+    )
+    dataset_meta.stats["action"]["std"] = torch.tensor(
+        norm_stats["norm_stats"]["actions"]["std"][:num_motors], dtype=torch.float32
+    )
 
     # Create LeRobot batch from Jax
     batch = {}
diff --git a/lerobot/common/policies/pi0/conversion_scripts/convert_pi0_to_hf_lerobot.py b/lerobot/common/policies/pi0/conversion_scripts/convert_pi0_to_hf_lerobot.py
index 73ff506f..6e6b8c2e 100644
--- a/lerobot/common/policies/pi0/conversion_scripts/convert_pi0_to_hf_lerobot.py
+++ b/lerobot/common/policies/pi0/conversion_scripts/convert_pi0_to_hf_lerobot.py
@@ -360,17 +360,20 @@ def convert_pi0_checkpoint(checkpoint_dir: str, precision: str, tokenizer_id: st
             empty_cameras=2,
             adapt_to_pi_aloha=True,
             use_delta_joint_actions_aloha=False,
+            paligemma_pretrained_path=None,
         )
     elif "pi0_aloha_towel" in checkpoint_dir:
         pi0_config = PI0Config(
             adapt_to_pi_aloha=True,
             use_delta_joint_actions_aloha=True,
+            paligemma_pretrained_path=None,
         )
     elif "pi0_base" in checkpoint_dir:
         pi0_config = PI0Config(
             empty_cameras=0,
             adapt_to_pi_aloha=False,
             use_delta_joint_actions_aloha=False,
+            paligemma_pretrained_path=None,
        )
     else:
         raise ValueError()
diff --git a/lerobot/common/policies/pi0/modeling_pi0.py b/lerobot/common/policies/pi0/modeling_pi0.py
index 7599fa63..624fa4eb 100644
--- a/lerobot/common/policies/pi0/modeling_pi0.py
+++ b/lerobot/common/policies/pi0/modeling_pi0.py
@@ -44,7 +44,7 @@ python lerobot/scripts/train.py \
 
 Example of using the pi0 pretrained model outside LeRobot training framework:
 ```python
-policy = Pi0Policy.from_pretrained("lerobot/pi0")
+policy = PI0Policy.from_pretrained("lerobot/pi0")
 ```
 
 """
@@ -470,6 +470,7 @@ class PI0FlowMatching(nn.Module):
         paligemma_with_export_config = PaliGemmaWithExpertConfig(
             freeze_vision_encoder=self.config.freeze_vision_encoder,
             train_expert_only=self.config.train_expert_only,
+            paligemma_pretrained_path=self.config.paligemma_pretrained_path,
             attention_implementation=self.config.attention_implementation,
         )
         self.paligemma_with_expert = PaliGemmaWithExpertModel(paligemma_with_export_config)
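The new `paligemma_pretrained_path` field is the user-facing knob of this change: leave it at its default to start pi0 training from the original PaliGemma VLM weights, or set it to `None` when the PaliGemma weights already live inside a converted pi0 checkpoint (which is why the conversion script above passes `None` for every checkpoint variant). A minimal sketch of the two modes, as hypothetical usage; only `PI0Config` and its new field come from this diff:

```python
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config

# Default: pull "google/paligemma-3b-pt-224" from the Hub, i.e. start from
# the original PaliGemma initialization (pi0_base before robotics data).
config_scratch = PI0Config()

# Converted pi0 checkpoints already ship PaliGemma weights inside the policy
# state dict, so the conversion script disables the extra download.
config_converted = PI0Config(paligemma_pretrained_path=None)
```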
diff --git a/lerobot/common/policies/pi0/paligemma_with_expert.py b/lerobot/common/policies/pi0/paligemma_with_expert.py
index 76e2ce60..527cad29 100644
--- a/lerobot/common/policies/pi0/paligemma_with_expert.py
+++ b/lerobot/common/policies/pi0/paligemma_with_expert.py
@@ -20,6 +20,7 @@ from pytest import Cache
 from torch import nn
 from transformers import (
     AutoConfig,
+    AutoModel,
     GemmaForCausalLM,
     PaliGemmaForConditionalGeneration,
     PretrainedConfig,
@@ -66,11 +67,13 @@ class PaliGemmaWithExpertConfig(PretrainedConfig):
         gemma_expert_config: dict | None = None,
         freeze_vision_encoder: bool = True,
         train_expert_only: bool = True,
+        paligemma_pretrained_path: str | None = "google/paligemma-3b-pt-224",
         attention_implementation: str = "eager",
         **kwargs,
     ):
         self.freeze_vision_encoder = freeze_vision_encoder
         self.train_expert_only = train_expert_only
+        self.paligemma_pretrained_path = paligemma_pretrained_path
         self.attention_implementation = attention_implementation
 
         if paligemma_config is None:
@@ -167,6 +170,11 @@
                 f"Wrong value provided for `attention_implementation` ({self.attention_implementation}). Expected 'eager', 'fa2' or 'flex'."
             )
 
+        if self.paligemma_pretrained_path is not None and self.paligemma_config is not None:
+            raise ValueError(
+                "When 'paligemma_pretrained_path' is provided, 'paligemma_config' needs to be None."
+            )
+
 
 class PaliGemmaWithExpertModel(PreTrainedModel):
     config_class = PaliGemmaWithExpertConfig
@@ -174,7 +182,10 @@
     def __init__(self, config: PaliGemmaWithExpertConfig):
         super().__init__(config=config)
         self.config = config
-        self.paligemma = PaliGemmaForConditionalGeneration(config=config.paligemma_config)
+        if config.paligemma_pretrained_path is not None:
+            self.paligemma = AutoModel.from_pretrained(config.paligemma_pretrained_path)
+        else:
+            self.paligemma = PaliGemmaForConditionalGeneration(config=config.paligemma_config)
         self.gemma_expert = GemmaForCausalLM(config=config.gemma_expert_config)
         # Remove unused embed_tokens
         self.gemma_expert.model.embed_tokens = None
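At construction time the new branch means `PaliGemmaWithExpertModel` either downloads pretrained VLM weights from the Hub or builds a randomly initialized `PaliGemmaForConditionalGeneration`, while the Gemma action expert is config-initialized in both cases. A sketch of the intended usage (hypothetical; the class names and the constraint that `paligemma_config` stay `None` in pretrained mode come from the hunks above, the rest is assumed):

```python
from lerobot.common.policies.pi0.paligemma_with_expert import (
    PaliGemmaWithExpertConfig,
    PaliGemmaWithExpertModel,
)

# Pretrained mode: the new validation check requires paligemma_config=None.
config = PaliGemmaWithExpertConfig(
    paligemma_pretrained_path="google/paligemma-3b-pt-224",
    paligemma_config=None,
)
model = PaliGemmaWithExpertModel(config)  # gemma_expert remains randomly initialized

# Config mode: both submodels are built from configs with random weights,
# exactly as before this change.
config = PaliGemmaWithExpertConfig(paligemma_pretrained_path=None)
model = PaliGemmaWithExpertModel(config)
```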