Improve pi0 (WIP)
parent 638d411cd3
commit 38e36d2d4c
```diff
@@ -61,6 +61,10 @@ class PI0Config(PreTrainedConfig):
     train_expert_only: bool = False
     train_state_proj: bool = True
 
+    # Useful if you want to reproduce the training of pi0_base from
+    # the original initialization of PaliGemma (before training on robotics data).
+    paligemma_pretrained_path: str | None = "google/paligemma-3b-pt-224"
+
     # Training presets
     optimizer_lr: float = 2.5e-5
     optimizer_betas: tuple[float, float] = (0.9, 0.95)
```
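This makes the PaliGemma initialization source a config field rather than a hard-coded choice. A minimal sketch of how the field might be used, assuming the usual LeRobot import path for `PI0Config`:

```python
# Sketch only; the import path is an assumption about the repo layout.
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config

# Default: initialize PaliGemma from Google's released weights, e.g. to
# reproduce pi0_base training from the original initialization.
config = PI0Config()
assert config.paligemma_pretrained_path == "google/paligemma-3b-pt-224"

# Converted pi0 checkpoints already ship PaliGemma weights, so the
# pretrained path is disabled for them (see the conversion hunk below).
config = PI0Config(paligemma_pretrained_path=None)
```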
```diff
@@ -52,6 +52,12 @@ def main():
     dataset_meta.stats["observation.state"]["std"] = torch.tensor(
         norm_stats["norm_stats"]["state"]["std"][:num_motors], dtype=torch.float32
     )
+    dataset_meta.stats["action"]["mean"] = torch.tensor(
+        norm_stats["norm_stats"]["actions"]["mean"][:num_motors], dtype=torch.float32
+    )
+    dataset_meta.stats["action"]["std"] = torch.tensor(
+        norm_stats["norm_stats"]["actions"]["std"][:num_motors], dtype=torch.float32
+    )
 
     # Create LeRobot batch from Jax
     batch = {}
```
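The hunk extends the state statistics with matching action statistics so normalization agrees with the original Jax pipeline. A hedged sketch of where `norm_stats` plausibly comes from, assuming an openpi-style JSON layout matching the keys used above:

```python
import json

import torch

# File name is a placeholder; the real script may load it from elsewhere.
with open("norm_stats.json") as f:
    norm_stats = json.load(f)

num_motors = 14  # assumption: a bimanual ALOHA motor count

# Same truncation pattern as above: keep only the first num_motors dims.
action_std = torch.tensor(
    norm_stats["norm_stats"]["actions"]["std"][:num_motors], dtype=torch.float32
)
```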
```diff
@@ -346,17 +346,20 @@ def convert_pi0_checkpoint(checkpoint_dir: str, precision: str, tokenizer_id: str
             empty_cameras=2,
             adapt_to_pi_aloha=True,
             use_delta_joint_actions_aloha=False,
+            paligemma_pretrained_path=None,
         )
     elif "pi0_aloha_towel" in checkpoint_dir:
         pi0_config = PI0Config(
             adapt_to_pi_aloha=True,
             use_delta_joint_actions_aloha=True,
+            paligemma_pretrained_path=None,
         )
     elif "pi0_base" in checkpoint_dir:
         pi0_config = PI0Config(
             empty_cameras=0,
             adapt_to_pi_aloha=False,
             use_delta_joint_actions_aloha=False,
+            paligemma_pretrained_path=None,
         )
     else:
         raise ValueError()
```
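Each converted checkpoint now explicitly opts out of downloading PaliGemma weights, since the checkpoint itself provides them. A hypothetical invocation, using only the three parameters visible in the hunk header and placeholder paths:

```python
# Substrings of checkpoint_dir (e.g. "pi0_aloha_towel", "pi0_base") select
# the matching PI0Config preset above; anything else raises ValueError.
convert_pi0_checkpoint(
    checkpoint_dir="/path/to/openpi_checkpoints/pi0_base/params",
    precision="bfloat16",
    tokenizer_id="google/paligemma-3b-pt-224",
)
```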
````diff
@@ -44,7 +44,7 @@ python lerobot/scripts/train.py \
 
 Example of using the pi0 pretrained model outside LeRobot training framework:
 ```python
-policy = Pi0Policy.from_pretrained("lerobot/pi0")
+policy = PI0Policy.from_pretrained("lerobot/pi0")
 ```
 
 """
````
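For completeness, a hedged end-to-end sketch of using the loaded policy; the import path, batch keys, and shapes are assumptions that depend on the dataset the policy was trained on:

```python
import torch

from lerobot.common.policies.pi0.modeling_pi0 import PI0Policy  # path assumed

policy = PI0Policy.from_pretrained("lerobot/pi0")
policy.eval()

# Dummy observation; real key names and shapes depend on the policy's dataset.
batch = {
    "observation.state": torch.zeros(1, 14),
    "observation.images.top": torch.zeros(1, 3, 224, 224),
    "task": ["pick up the cube"],
}
with torch.no_grad():
    action = policy.select_action(batch)
```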
```diff
@@ -470,6 +470,7 @@ class PI0FlowMatching(nn.Module):
         paligemma_with_expert_config = PaliGemmaWithExpertConfig(
             freeze_vision_encoder=self.config.freeze_vision_encoder,
             train_expert_only=self.config.train_expert_only,
+            paligemma_pretrained_path=self.config.paligemma_pretrained_path,
             attention_implementation=self.config.attention_implementation,
         )
         self.paligemma_with_expert = PaliGemmaWithExpertModel(paligemma_with_expert_config)
```
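The flow-matching module simply forwards its own config fields to the expert wrapper. For intuition, a sketch of the parameter freezing these two flags typically imply; the submodule names (`paligemma`, `vision_tower`) are assumptions about the wrapper's internals:

```python
def apply_freezing(model, config) -> None:
    """Sketch of the freezing implied by the two flags (names assumed)."""
    if config.freeze_vision_encoder:
        # Keep the SigLIP vision encoder fixed.
        for param in model.paligemma.vision_tower.parameters():
            param.requires_grad = False
    if config.train_expert_only:
        # Train only the Gemma action expert: freeze all of PaliGemma.
        for param in model.paligemma.parameters():
            param.requires_grad = False
```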
```diff
@@ -6,6 +6,7 @@ from pytest import Cache
 from torch import nn
 from transformers import (
     AutoConfig,
+    AutoModel,
     GemmaForCausalLM,
     PaliGemmaForConditionalGeneration,
     PretrainedConfig,
```
```diff
@@ -52,11 +53,13 @@ class PaliGemmaWithExpertConfig(PretrainedConfig):
         gemma_expert_config: dict | None = None,
         freeze_vision_encoder: bool = True,
         train_expert_only: bool = True,
+        paligemma_pretrained_path: str | None = "google/paligemma-3b-pt-224",
         attention_implementation: str = "eager",
         **kwargs,
     ):
         self.freeze_vision_encoder = freeze_vision_encoder
         self.train_expert_only = train_expert_only
+        self.paligemma_pretrained_path = paligemma_pretrained_path
         self.attention_implementation = attention_implementation
 
         if paligemma_config is None:
```
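With the new argument there are two ways to build the wrapper config; the pretrained path and an explicit `paligemma_config` are mutually exclusive, as the validation in the next hunk enforces. A brief sketch:

```python
# Default: PaliGemma weights will be pulled from the Hub.
config = PaliGemmaWithExpertConfig()

# Alternative: no pretrained path, so the architecture is built from a
# config alone, e.g. before loading weights from a converted checkpoint.
config = PaliGemmaWithExpertConfig(paligemma_pretrained_path=None)

# Passing both paligemma_pretrained_path and paligemma_config raises
# ValueError (see the validation added below).
```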
```diff
@@ -153,6 +156,11 @@ class PaliGemmaWithExpertConfig(PretrainedConfig):
                 f"Wrong value provided for `attention_implementation` ({self.attention_implementation}). Expected 'eager', 'fa2' or 'flex'."
             )
 
+        if self.paligemma_pretrained_path is not None and self.paligemma_config is not None:
+            raise ValueError(
+                "When 'paligemma_pretrained_path' is provided, 'paligemma_config' needs to be None."
+            )
+
 
 class PaliGemmaWithExpertModel(PreTrainedModel):
     config_class = PaliGemmaWithExpertConfig
```
```diff
@@ -160,6 +168,9 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
     def __init__(self, config: PaliGemmaWithExpertConfig):
         super().__init__(config=config)
         self.config = config
-        self.paligemma = PaliGemmaForConditionalGeneration(config=config.paligemma_config)
+        if config.paligemma_pretrained_path is not None:
+            self.paligemma = AutoModel.from_pretrained(config.paligemma_pretrained_path)
+        else:
+            self.paligemma = PaliGemmaForConditionalGeneration(config=config.paligemma_config)
         self.gemma_expert = GemmaForCausalLM(config=config.gemma_expert_config)
         # Remove unused embed_tokens
```
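The trailing comment refers to a step just past the hunk; a plausible sketch of it, with the attribute path an assumption about the Gemma module layout:

```python
def remove_unused_embed_tokens(model: PaliGemmaWithExpertModel) -> None:
    """Drop the action expert's token-embedding table (sketch, names assumed).

    The Gemma expert is fed projected continuous features rather than token
    ids, so its embedding table is never used and only wastes memory.
    """
    model.gemma_expert.model.embed_tokens = None
```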