Merge 38e36d2d4c into 1c873df5c0
commit 9896e0f830
@@ -75,6 +75,10 @@ class PI0Config(PreTrainedConfig):
     train_expert_only: bool = False
     train_state_proj: bool = True
 
+    # Useful if you want to reproduce the training of pi0_base from
+    # the original initialization of PaliGemma (before training on robotics data).
+    paligemma_pretrained_path: str | None = "google/paligemma-3b-pt-224"
+
     # Training presets
     optimizer_lr: float = 2.5e-5
     optimizer_betas: tuple[float, float] = (0.9, 0.95)

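For context, a minimal sketch of how the new `paligemma_pretrained_path` field can be set on `PI0Config`. The import path below is assumed from the usual lerobot layout and is not part of this commit:

```python
# Sketch only; module path assumed, not taken from this commit.
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config

# Default: initialize the PaliGemma backbone from the pretrained Hub checkpoint.
config = PI0Config()  # paligemma_pretrained_path == "google/paligemma-3b-pt-224"

# Pass None to build the backbone from its config instead of pretrained weights,
# as the checkpoint-conversion presets below do.
config = PI0Config(paligemma_pretrained_path=None)
```
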
@@ -66,6 +66,12 @@ def main():
     dataset_meta.stats["observation.state"]["std"] = torch.tensor(
         norm_stats["norm_stats"]["state"]["std"][:num_motors], dtype=torch.float32
     )
+    dataset_meta.stats["action"]["mean"] = torch.tensor(
+        norm_stats["norm_stats"]["actions"]["mean"][:num_motors], dtype=torch.float32
+    )
+    dataset_meta.stats["action"]["std"] = torch.tensor(
+        norm_stats["norm_stats"]["actions"]["std"][:num_motors], dtype=torch.float32
+    )
 
     # Create LeRobot batch from Jax
     batch = {}

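For readability, a hypothetical example of the nested `norm_stats` structure this snippet indexes into. The key names are taken from the code above; the values and motor count are illustrative only:

```python
# Illustrative layout only; real values come from the original pi0 norm-stats file.
num_motors = 14  # hypothetical motor count used for the [:num_motors] slice above
norm_stats = {
    "norm_stats": {
        "state": {"mean": [0.0] * num_motors, "std": [1.0] * num_motors},
        "actions": {"mean": [0.0] * num_motors, "std": [1.0] * num_motors},
    }
}
```
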
@@ -360,17 +360,20 @@ def convert_pi0_checkpoint(checkpoint_dir: str, precision: str, tokenizer_id: str
             empty_cameras=2,
             adapt_to_pi_aloha=True,
             use_delta_joint_actions_aloha=False,
+            paligemma_pretrained_path=None,
         )
     elif "pi0_aloha_towel" in checkpoint_dir:
         pi0_config = PI0Config(
             adapt_to_pi_aloha=True,
             use_delta_joint_actions_aloha=True,
+            paligemma_pretrained_path=None,
         )
     elif "pi0_base" in checkpoint_dir:
         pi0_config = PI0Config(
             empty_cameras=0,
             adapt_to_pi_aloha=False,
             use_delta_joint_actions_aloha=False,
+            paligemma_pretrained_path=None,
         )
     else:
         raise ValueError()

@@ -44,7 +44,7 @@ python lerobot/scripts/train.py \
 
 Example of using the pi0 pretrained model outside LeRobot training framework:
 ```python
-policy = Pi0Policy.from_pretrained("lerobot/pi0")
+policy = PI0Policy.from_pretrained("lerobot/pi0")
 ```
 
 """

@@ -470,6 +470,7 @@ class PI0FlowMatching(nn.Module):
         paligemma_with_export_config = PaliGemmaWithExpertConfig(
             freeze_vision_encoder=self.config.freeze_vision_encoder,
             train_expert_only=self.config.train_expert_only,
+            paligemma_pretrained_path=self.config.paligemma_pretrained_path,
             attention_implementation=self.config.attention_implementation,
         )
         self.paligemma_with_expert = PaliGemmaWithExpertModel(paligemma_with_export_config)

@@ -20,6 +20,7 @@ from pytest import Cache
 from torch import nn
 from transformers import (
     AutoConfig,
+    AutoModel,
     GemmaForCausalLM,
     PaliGemmaForConditionalGeneration,
     PretrainedConfig,

@@ -66,11 +67,13 @@ class PaliGemmaWithExpertConfig(PretrainedConfig):
         gemma_expert_config: dict | None = None,
         freeze_vision_encoder: bool = True,
         train_expert_only: bool = True,
+        paligemma_pretrained_path: str | None = "google/paligemma-3b-pt-224",
         attention_implementation: str = "eager",
         **kwargs,
     ):
         self.freeze_vision_encoder = freeze_vision_encoder
         self.train_expert_only = train_expert_only
+        self.paligemma_pretrained_path = paligemma_pretrained_path
         self.attention_implementation = attention_implementation
 
         if paligemma_config is None:

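A small hedged sketch of selecting the backbone initialization source through the new constructor argument. The import path is assumed and this snippet is not part of the commit:

```python
# Sketch only; module path assumed, not taken from this commit.
from lerobot.common.policies.pi0.paligemma_with_expert import PaliGemmaWithExpertConfig

# Point the backbone at a pretrained PaliGemma checkpoint on the Hub...
config = PaliGemmaWithExpertConfig(paligemma_pretrained_path="google/paligemma-3b-pt-224")

# ...or pass None to fall back to building PaliGemma from its config.
config = PaliGemmaWithExpertConfig(paligemma_pretrained_path=None)
```
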
@@ -167,6 +170,11 @@ class PaliGemmaWithExpertConfig(PretrainedConfig):
                 f"Wrong value provided for `attention_implementation` ({self.attention_implementation}). Expected 'eager', 'fa2' or 'flex'."
             )
 
+        if self.paligemma_pretrained_path is not None and self.paligemma_config is not None:
+            raise ValueError(
+                "When 'paligemma_pretrained_path' is provided, 'paligemma_config' needs to be None."
+            )
+
 
 class PaliGemmaWithExpertModel(PreTrainedModel):
     config_class = PaliGemmaWithExpertConfig

@@ -174,6 +182,9 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
     def __init__(self, config: PaliGemmaWithExpertConfig):
         super().__init__(config=config)
         self.config = config
+        if config.paligemma_pretrained_path is not None:
+            self.paligemma = AutoModel(config.paligemma_pretrained_path)
+        else:
         self.paligemma = PaliGemmaForConditionalGeneration(config=config.paligemma_config)
         self.gemma_expert = GemmaForCausalLM(config=config.gemma_expert_config)
         # Remove unused embed_tokens
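As a side note on the loading call above: in transformers, pretrained weights are normally loaded through a `from_pretrained` classmethod rather than by calling `AutoModel(...)` directly. A minimal hedged sketch of that standard pattern, using the `PaliGemmaForConditionalGeneration` class already imported in this file; this illustrates the usual API, not the exact code in this commit:

```python
# Sketch of the standard transformers loading pattern; not taken from this commit.
from transformers import PaliGemmaForConditionalGeneration

paligemma = PaliGemmaForConditionalGeneration.from_pretrained("google/paligemma-3b-pt-224")
```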