Improve pi0 (WIP)

This commit is contained in:
Remi Cadene 2025-02-07 13:38:47 +01:00
parent 638d411cd3
commit 38e36d2d4c
5 changed files with 27 additions and 2 deletions

View File

@ -61,6 +61,10 @@ class PI0Config(PreTrainedConfig):
train_expert_only: bool = False train_expert_only: bool = False
train_state_proj: bool = True train_state_proj: bool = True
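# Pretrained PaliGemma checkpoint (e.g. a Hugging Face Hub id) used to initialize the
# PaliGemma backbone. Useful if you want to reproduce the training of pi0_base from
# the original initialization of PaliGemma (i.e. before any training on robotics data).
# Set to None to build the backbone from its config only, without a pretrained checkpoint
# (this is what the pi0 checkpoint conversion script does).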
paligemma_pretrained_path: str | None = "google/paligemma-3b-pt-224"
# Training presets # Training presets
optimizer_lr: float = 2.5e-5 optimizer_lr: float = 2.5e-5
optimizer_betas: tuple[float, float] = (0.9, 0.95) optimizer_betas: tuple[float, float] = (0.9, 0.95)

View File

@ -52,6 +52,12 @@ def main():
dataset_meta.stats["observation.state"]["std"] = torch.tensor( dataset_meta.stats["observation.state"]["std"] = torch.tensor(
norm_stats["norm_stats"]["state"]["std"][:num_motors], dtype=torch.float32 norm_stats["norm_stats"]["state"]["std"][:num_motors], dtype=torch.float32
) )
dataset_meta.stats["action"]["mean"] = torch.tensor(
norm_stats["norm_stats"]["actions"]["mean"][:num_motors], dtype=torch.float32
)
dataset_meta.stats["action"]["std"] = torch.tensor(
norm_stats["norm_stats"]["actions"]["std"][:num_motors], dtype=torch.float32
)
# Create LeRobot batch from Jax # Create LeRobot batch from Jax
batch = {} batch = {}

View File

@ -346,17 +346,20 @@ def convert_pi0_checkpoint(checkpoint_dir: str, precision: str, tokenizer_id: st
empty_cameras=2, empty_cameras=2,
adapt_to_pi_aloha=True, adapt_to_pi_aloha=True,
use_delta_joint_actions_aloha=False, use_delta_joint_actions_aloha=False,
paligemma_pretrained_path=None,
) )
elif "pi0_aloha_towel" in checkpoint_dir: elif "pi0_aloha_towel" in checkpoint_dir:
pi0_config = PI0Config( pi0_config = PI0Config(
adapt_to_pi_aloha=True, adapt_to_pi_aloha=True,
use_delta_joint_actions_aloha=True, use_delta_joint_actions_aloha=True,
paligemma_pretrained_path=None,
) )
elif "pi0_base" in checkpoint_dir: elif "pi0_base" in checkpoint_dir:
pi0_config = PI0Config( pi0_config = PI0Config(
empty_cameras=0, empty_cameras=0,
adapt_to_pi_aloha=False, adapt_to_pi_aloha=False,
use_delta_joint_actions_aloha=False, use_delta_joint_actions_aloha=False,
paligemma_pretrained_path=None,
) )
else: else:
raise ValueError() raise ValueError()

View File

@ -44,7 +44,7 @@ python lerobot/scripts/train.py \
Example of using the pi0 pretrained model outside LeRobot training framework: Example of using the pi0 pretrained model outside LeRobot training framework:
```python ```python
policy = Pi0Policy.from_pretrained("lerobot/pi0") policy = PI0Policy.from_pretrained("lerobot/pi0")
``` ```
""" """
@ -470,6 +470,7 @@ class PI0FlowMatching(nn.Module):
paligemma_with_export_config = PaliGemmaWithExpertConfig( paligemma_with_export_config = PaliGemmaWithExpertConfig(
freeze_vision_encoder=self.config.freeze_vision_encoder, freeze_vision_encoder=self.config.freeze_vision_encoder,
train_expert_only=self.config.train_expert_only, train_expert_only=self.config.train_expert_only,
paligemma_pretrained_path=self.config.paligemma_pretrained_path,
attention_implementation=self.config.attention_implementation, attention_implementation=self.config.attention_implementation,
) )
self.paligemma_with_expert = PaliGemmaWithExpertModel(paligemma_with_export_config) self.paligemma_with_expert = PaliGemmaWithExpertModel(paligemma_with_export_config)

View File

@ -6,6 +6,7 @@ from pytest import Cache
from torch import nn from torch import nn
from transformers import ( from transformers import (
AutoConfig, AutoConfig,
AutoModel,
GemmaForCausalLM, GemmaForCausalLM,
PaliGemmaForConditionalGeneration, PaliGemmaForConditionalGeneration,
PretrainedConfig, PretrainedConfig,
@ -52,11 +53,13 @@ class PaliGemmaWithExpertConfig(PretrainedConfig):
gemma_expert_config: dict | None = None, gemma_expert_config: dict | None = None,
freeze_vision_encoder: bool = True, freeze_vision_encoder: bool = True,
train_expert_only: bool = True, train_expert_only: bool = True,
paligemma_pretrained_path: str | None = "google/paligemma-3b-pt-224",
attention_implementation: str = "eager", attention_implementation: str = "eager",
**kwargs, **kwargs,
): ):
self.freeze_vision_encoder = freeze_vision_encoder self.freeze_vision_encoder = freeze_vision_encoder
self.train_expert_only = train_expert_only self.train_expert_only = train_expert_only
self.paligemma_pretrained_path = paligemma_pretrained_path
self.attention_implementation = attention_implementation self.attention_implementation = attention_implementation
if paligemma_config is None: if paligemma_config is None:
@ -153,6 +156,11 @@ class PaliGemmaWithExpertConfig(PretrainedConfig):
f"Wrong value provided for `attention_implementation` ({self.attention_implementation}). Expected 'eager', 'fa2' or 'flex'." f"Wrong value provided for `attention_implementation` ({self.attention_implementation}). Expected 'eager', 'fa2' or 'flex'."
) )
if self.paligemma_pretrained_path is not None and self.paligemma_config is not None:
raise ValueError(
"When 'paligemma_pretrained_path' is provided, 'paligemma_config' needs to be None."
)
class PaliGemmaWithExpertModel(PreTrainedModel): class PaliGemmaWithExpertModel(PreTrainedModel):
config_class = PaliGemmaWithExpertConfig config_class = PaliGemmaWithExpertConfig
@ -160,7 +168,10 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
def __init__(self, config: PaliGemmaWithExpertConfig): def __init__(self, config: PaliGemmaWithExpertConfig):
super().__init__(config=config) super().__init__(config=config)
self.config = config self.config = config
self.paligemma = PaliGemmaForConditionalGeneration(config=config.paligemma_config) if config.paligemma_pretrained_path is not None:
self.paligemma = AutoModel(config.paligemma_pretrained_path)
else:
self.paligemma = PaliGemmaForConditionalGeneration(config=config.paligemma_config)
self.gemma_expert = GemmaForCausalLM(config=config.gemma_expert_config) self.gemma_expert = GemmaForCausalLM(config=config.gemma_expert_config)
# Remove unused embed_tokens # Remove unused embed_tokens
self.gemma_expert.model.embed_tokens = None self.gemma_expert.model.embed_tokens = None