commit 9896e0f830
Authored by Remi on 2025-04-04 16:50:58 +03:00, committed by GitHub
5 changed files with 27 additions and 2 deletions

View File

@@ -75,6 +75,10 @@ class PI0Config(PreTrainedConfig):
train_expert_only: bool = False
train_state_proj: bool = True
# Useful if you want to reproduce the training of pi0_base from
# the original initialization of PaliGemma (before training on robotics data).
paligemma_pretrained_path: str | None = "google/paligemma-3b-pt-224"
# Training presets
optimizer_lr: float = 2.5e-5
optimizer_betas: tuple[float, float] = (0.9, 0.95)
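The new `paligemma_pretrained_path` field makes the PaliGemma initialization explicit. A minimal sketch of the two modes, assuming `PI0Config` is importable from `lerobot.common.policies.pi0.configuration_pi0` (the import path is not shown in this diff):

```python
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config  # assumed path

# Reproduce pi0_base training: start from the original PaliGemma weights.
config = PI0Config(paligemma_pretrained_path="google/paligemma-3b-pt-224")

# Skip the download and build PaliGemma from config only, e.g. when the
# weights will instead be loaded from a converted pi0 checkpoint.
config = PI0Config(paligemma_pretrained_path=None)
```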

View File

@@ -66,6 +66,12 @@ def main():
dataset_meta.stats["observation.state"]["std"] = torch.tensor(
norm_stats["norm_stats"]["state"]["std"][:num_motors], dtype=torch.float32
)
dataset_meta.stats["action"]["mean"] = torch.tensor(
norm_stats["norm_stats"]["actions"]["mean"][:num_motors], dtype=torch.float32
)
dataset_meta.stats["action"]["std"] = torch.tensor(
norm_stats["norm_stats"]["actions"]["std"][:num_motors], dtype=torch.float32
)
# Create LeRobot batch from Jax
batch = {}
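The stat copying above repeats one pattern per key. A compact helper sketching the same logic; the `norm_stats` key layout is taken from the snippet, everything else (names, signature) is illustrative:

```python
import torch

def copy_norm_stats(dataset_meta, norm_stats: dict, num_motors: int) -> None:
    # Map LeRobot stat keys to the JAX-side keys used in norm_stats.
    for lerobot_key, jax_key in (("observation.state", "state"), ("action", "actions")):
        for stat in ("mean", "std"):
            dataset_meta.stats[lerobot_key][stat] = torch.tensor(
                norm_stats["norm_stats"][jax_key][stat][:num_motors],
                dtype=torch.float32,
            )
```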

View File

@@ -360,17 +360,20 @@ def convert_pi0_checkpoint(checkpoint_dir: str, precision: str, tokenizer_id: st
empty_cameras=2,
adapt_to_pi_aloha=True,
use_delta_joint_actions_aloha=False,
paligemma_pretrained_path=None,
)
elif "pi0_aloha_towel" in checkpoint_dir:
pi0_config = PI0Config(
adapt_to_pi_aloha=True,
use_delta_joint_actions_aloha=True,
paligemma_pretrained_path=None,
)
elif "pi0_base" in checkpoint_dir:
pi0_config = PI0Config(
empty_cameras=0,
adapt_to_pi_aloha=False,
use_delta_joint_actions_aloha=False,
paligemma_pretrained_path=None,
)
else:
raise ValueError()
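All conversion presets pass `paligemma_pretrained_path=None` because the PaliGemma weights come from the JAX checkpoint being converted, so there is no need to download `google/paligemma-3b-pt-224` again. An illustrative sketch of that flow, with class names and import paths assumed and the state-dict path a placeholder:

```python
import torch
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config  # assumed path
from lerobot.common.policies.pi0.modeling_pi0 import PI0Policy       # assumed path

# Build the policy from config only: PaliGemma is instantiated without
# downloading pretrained weights...
pi0_config = PI0Config(paligemma_pretrained_path=None)
policy = PI0Policy(pi0_config)

# ...and is then populated with the weights converted from the JAX checkpoint.
state_dict = torch.load("converted_pi0_state_dict.pt")  # placeholder file
policy.load_state_dict(state_dict)
```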

View File

@@ -44,7 +44,7 @@ python lerobot/scripts/train.py \
Example of using the pi0 pretrained model outside LeRobot training framework:
```python
-policy = Pi0Policy.from_pretrained("lerobot/pi0")
+policy = PI0Policy.from_pretrained("lerobot/pi0")
```
"""
@@ -470,6 +470,7 @@ class PI0FlowMatching(nn.Module):
paligemma_with_export_config = PaliGemmaWithExpertConfig(
freeze_vision_encoder=self.config.freeze_vision_encoder,
train_expert_only=self.config.train_expert_only,
paligemma_pretrained_path=self.config.paligemma_pretrained_path,
attention_implementation=self.config.attention_implementation,
)
self.paligemma_with_expert = PaliGemmaWithExpertModel(paligemma_with_export_config)

View File

@ -20,6 +20,7 @@ from pytest import Cache
from torch import nn
from transformers import (
AutoConfig,
AutoModel,
GemmaForCausalLM,
PaliGemmaForConditionalGeneration,
PretrainedConfig,
@@ -66,11 +67,13 @@ class PaliGemmaWithExpertConfig(PretrainedConfig):
gemma_expert_config: dict | None = None,
freeze_vision_encoder: bool = True,
train_expert_only: bool = True,
paligemma_pretrained_path: str | None = "google/paligemma-3b-pt-224",
attention_implementation: str = "eager",
**kwargs,
):
self.freeze_vision_encoder = freeze_vision_encoder
self.train_expert_only = train_expert_only
self.paligemma_pretrained_path = paligemma_pretrained_path
self.attention_implementation = attention_implementation
if paligemma_config is None:
@@ -167,6 +170,11 @@ class PaliGemmaWithExpertConfig(PretrainedConfig):
f"Wrong value provided for `attention_implementation` ({self.attention_implementation}). Expected 'eager', 'fa2' or 'flex'."
)
if self.paligemma_pretrained_path is not None and self.paligemma_config is not None:
raise ValueError(
"When 'paligemma_pretrained_path' is provided, 'paligemma_config' needs to be None."
)
class PaliGemmaWithExpertModel(PreTrainedModel):
config_class = PaliGemmaWithExpertConfig
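The check added above makes `paligemma_pretrained_path` and an explicit `paligemma_config` mutually exclusive. A hedged sketch of the rule, with the import path assumed:

```python
from lerobot.common.policies.pi0.paligemma_with_expert import (  # assumed path
    PaliGemmaWithExpertConfig,
)

# Valid: pretrained path with paligemma_config left as None (the default).
cfg = PaliGemmaWithExpertConfig(paligemma_pretrained_path="google/paligemma-3b-pt-224")

# Valid: config-only initialization, no download.
cfg = PaliGemmaWithExpertConfig(paligemma_pretrained_path=None)

# Invalid: passing both paligemma_pretrained_path and an explicit
# paligemma_config dict trips the ValueError added above.
```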
@@ -174,7 +182,10 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
def __init__(self, config: PaliGemmaWithExpertConfig):
super().__init__(config=config)
self.config = config
-self.paligemma = PaliGemmaForConditionalGeneration(config=config.paligemma_config)
+if config.paligemma_pretrained_path is not None:
+    self.paligemma = AutoModel.from_pretrained(config.paligemma_pretrained_path)
+else:
+    self.paligemma = PaliGemmaForConditionalGeneration(config=config.paligemma_config)
self.gemma_expert = GemmaForCausalLM(config=config.gemma_expert_config)
# Remove unused embed_tokens
self.gemma_expert.model.embed_tokens = None
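Putting the model-side change together: the wrapper now either loads PaliGemma from the pretrained path or instantiates it from config. A short sketch of both construction modes, with import paths assumed; the first variant downloads the PaliGemma checkpoint on first use:

```python
from lerobot.common.policies.pi0.paligemma_with_expert import (  # assumed path
    PaliGemmaWithExpertConfig,
    PaliGemmaWithExpertModel,
)

# Pretrained initialization: PaliGemma weights fetched from the Hub.
cfg = PaliGemmaWithExpertConfig(paligemma_pretrained_path="google/paligemma-3b-pt-224")
model = PaliGemmaWithExpertModel(cfg)

# Config-only initialization: randomly initialized PaliGemma, to be filled
# later, e.g. from a converted pi0 checkpoint.
cfg = PaliGemmaWithExpertConfig(paligemma_pretrained_path=None)
model = PaliGemmaWithExpertModel(cfg)
```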