Resolved merge conflict with pi0
This commit is contained in:
commit
ef8579eacd
52
Makefile
52
Makefile
|
@ -26,7 +26,6 @@ test-end-to-end:
|
|||
${MAKE} DEVICE=$(DEVICE) test-diffusion-ete-eval
|
||||
${MAKE} DEVICE=$(DEVICE) test-tdmpc-ete-train
|
||||
${MAKE} DEVICE=$(DEVICE) test-tdmpc-ete-eval
|
||||
${MAKE} DEVICE=$(DEVICE) test-tdmpc-ete-train-with-online
|
||||
|
||||
test-act-ete-train:
|
||||
python lerobot/scripts/train.py \
|
||||
|
@ -128,28 +127,29 @@ test-tdmpc-ete-eval:
|
|||
--eval.batch_size=1 \
|
||||
--device=$(DEVICE)
|
||||
|
||||
test-tdmpc-ete-train-with-online:
|
||||
python lerobot/scripts/train.py \
|
||||
--policy.type=tdmpc \
|
||||
--env.type=pusht \
|
||||
--env.obs_type=environment_state_agent_pos \
|
||||
--env.episode_length=5 \
|
||||
--dataset.repo_id=lerobot/pusht_keypoints \
|
||||
--dataset.image_transforms.enable=true \
|
||||
--dataset.episodes="[0]" \
|
||||
--batch_size=2 \
|
||||
--offline.steps=2 \
|
||||
--online.steps=20 \
|
||||
--online.rollout_n_episodes=2 \
|
||||
--online.rollout_batch_size=2 \
|
||||
--online.steps_between_rollouts=10 \
|
||||
--online.buffer_capacity=1000 \
|
||||
--online.env_seed=10000 \
|
||||
--save_checkpoint=false \
|
||||
--save_freq=10 \
|
||||
--log_freq=1 \
|
||||
--eval.use_async_envs=true \
|
||||
--eval.n_episodes=1 \
|
||||
--eval.batch_size=1 \
|
||||
--device=$(DEVICE) \
|
||||
--output_dir=tests/outputs/tdmpc_online/
|
||||
# TODO(rcadene): fix online buffer to storing "task"
|
||||
# test-tdmpc-ete-train-with-online:
|
||||
# python lerobot/scripts/train.py \
|
||||
# --policy.type=tdmpc \
|
||||
# --env.type=pusht \
|
||||
# --env.obs_type=environment_state_agent_pos \
|
||||
# --env.episode_length=5 \
|
||||
# --dataset.repo_id=lerobot/pusht_keypoints \
|
||||
# --dataset.image_transforms.enable=true \
|
||||
# --dataset.episodes="[0]" \
|
||||
# --batch_size=2 \
|
||||
# --offline.steps=2 \
|
||||
# --online.steps=20 \
|
||||
# --online.rollout_n_episodes=2 \
|
||||
# --online.rollout_batch_size=2 \
|
||||
# --online.steps_between_rollouts=10 \
|
||||
# --online.buffer_capacity=1000 \
|
||||
# --online.env_seed=10000 \
|
||||
# --save_checkpoint=false \
|
||||
# --save_freq=10 \
|
||||
# --log_freq=1 \
|
||||
# --eval.use_async_envs=true \
|
||||
# --eval.n_episodes=1 \
|
||||
# --eval.batch_size=1 \
|
||||
# --device=$(DEVICE) \
|
||||
# --output_dir=tests/outputs/tdmpc_online/
|
||||
|
|
|
@ -672,6 +672,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
for cam in image_keys:
|
||||
item[cam] = self.image_transforms(item[cam])
|
||||
|
||||
# Add task as a string
|
||||
task_idx = item["task_index"].item()
|
||||
item["task"] = self.meta.tasks[task_idx]
|
||||
|
||||
return item
|
||||
|
||||
def __repr__(self):
|
||||
|
|
|
@ -8,8 +8,6 @@ import torch
|
|||
@dataclass
|
||||
class OptimizerConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
lr: float
|
||||
betas: tuple[float, float]
|
||||
eps: float
|
||||
weight_decay: float
|
||||
grad_clip_norm: float
|
||||
|
||||
|
@ -54,3 +52,19 @@ class AdamWConfig(OptimizerConfig):
|
|||
kwargs = asdict(self)
|
||||
kwargs.pop("grad_clip_norm")
|
||||
return torch.optim.AdamW(params, **kwargs)
|
||||
|
||||
|
||||
@OptimizerConfig.register_subclass("sgd")
|
||||
@dataclass
|
||||
class SGDConfig(OptimizerConfig):
|
||||
lr: float = 1e-3
|
||||
momentum: float = 0.0
|
||||
dampening: float = 0.0
|
||||
nesterov: bool = False
|
||||
weight_decay: float = 0.0
|
||||
grad_clip_norm: float = 10.0
|
||||
|
||||
def build(self, params: dict) -> torch.optim.Optimizer:
|
||||
kwargs = asdict(self)
|
||||
kwargs.pop("grad_clip_norm")
|
||||
return torch.optim.SGD(params, **kwargs)
|
||||
|
|
|
@ -54,3 +54,38 @@ class VQBeTSchedulerConfig(LRSchedulerConfig):
|
|||
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(self.num_cycles) * 2.0 * progress)))
|
||||
|
||||
return LambdaLR(optimizer, lr_lambda, -1)
|
||||
|
||||
|
||||
@LRSchedulerConfig.register_subclass("cosine_decay_with_warmup")
|
||||
@dataclass
|
||||
class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig):
|
||||
"""Used by Physical Intelligence to train Pi0"""
|
||||
|
||||
num_warmup_steps: int
|
||||
num_decay_steps: int
|
||||
peak_lr: float
|
||||
decay_lr: float
|
||||
|
||||
def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
|
||||
del num_training_steps
|
||||
|
||||
def lr_lambda(current_step):
|
||||
def linear_warmup_schedule(current_step):
|
||||
if current_step <= 0:
|
||||
return 1 / (self.num_warmup_steps + 1)
|
||||
frac = 1 - current_step / self.num_warmup_steps
|
||||
return (1 / (self.num_warmup_steps + 1) - 1) * frac + 1
|
||||
|
||||
def cosine_decay_schedule(current_step):
|
||||
step = min(current_step, self.num_decay_steps)
|
||||
cosine_decay = 0.5 * (1 + math.cos(math.pi * step / self.num_decay_steps))
|
||||
alpha = self.decay_lr / self.peak_lr
|
||||
decayed = (1 - alpha) * cosine_decay + alpha
|
||||
return decayed
|
||||
|
||||
if current_step < self.num_warmup_steps:
|
||||
return linear_warmup_schedule(current_step)
|
||||
|
||||
return cosine_decay_schedule(current_step)
|
||||
|
||||
return LambdaLR(optimizer, lr_lambda, -1)
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
from .act.configuration_act import ACTConfig as ACTConfig
|
||||
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
|
||||
from .pi0.configuration_pi0 import PI0Config as PI0Config
|
||||
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
|
||||
from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
|
||||
|
|
|
@ -26,6 +26,7 @@ from lerobot.common.envs.utils import env_to_policy_features
|
|||
from lerobot.common.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.common.policies.dit_flow.configuration_dit_flow import DiTFlowConfig
|
||||
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||
|
@ -55,6 +56,10 @@ def get_policy_class(name: str) -> PreTrainedPolicy:
|
|||
from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTPolicy
|
||||
|
||||
return VQBeTPolicy
|
||||
elif name == "pi0":
|
||||
from lerobot.common.policies.pi0.modeling_pi0 import PI0Policy
|
||||
|
||||
return PI0Policy
|
||||
else:
|
||||
raise NotImplementedError(f"Policy with name {name} is not implemented.")
|
||||
|
||||
|
@ -70,6 +75,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
|||
return ACTConfig(**kwargs)
|
||||
elif policy_type == "vqbet":
|
||||
return VQBeTConfig(**kwargs)
|
||||
elif policy_type == "pi0":
|
||||
return PI0Config(**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Policy type '{policy_type}' is not available.")
|
||||
|
||||
|
@ -148,4 +155,6 @@ def make_policy(
|
|||
policy.to(device)
|
||||
assert isinstance(policy, nn.Module)
|
||||
|
||||
# policy = torch.compile(policy, mode="reduce-overhead")
|
||||
|
||||
return policy
|
||||
|
|
|
@ -140,6 +140,9 @@ class Normalize(nn.Module):
|
|||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
batch = dict(batch) # shallow copy avoids mutating the input batch
|
||||
for key, ft in self.features.items():
|
||||
if key not in batch:
|
||||
continue
|
||||
|
||||
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
|
||||
if norm_mode is NormalizationMode.IDENTITY:
|
||||
continue
|
||||
|
@ -210,6 +213,9 @@ class Unnormalize(nn.Module):
|
|||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
batch = dict(batch) # shallow copy avoids mutating the input batch
|
||||
for key, ft in self.features.items():
|
||||
if key not in batch:
|
||||
continue
|
||||
|
||||
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
|
||||
if norm_mode is NormalizationMode.IDENTITY:
|
||||
continue
|
||||
|
|
|
@ -0,0 +1,134 @@
|
|||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.common.optim.optimizers import AdamWConfig
|
||||
from lerobot.common.optim.schedulers import (
|
||||
CosineDecayWithWarmupSchedulerConfig,
|
||||
)
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("pi0")
|
||||
@dataclass
|
||||
class PI0Config(PreTrainedConfig):
|
||||
# Input / output structure.
|
||||
n_obs_steps: int = 1
|
||||
chunk_size: int = 50
|
||||
n_action_steps: int = 50
|
||||
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.MEAN_STD,
|
||||
"ACTION": NormalizationMode.MEAN_STD,
|
||||
}
|
||||
)
|
||||
|
||||
# Shorter state and action vectors will be padded
|
||||
max_state_dim: int = 32
|
||||
max_action_dim: int = 32
|
||||
|
||||
# Image preprocessing
|
||||
resize_imgs_with_padding: tuple[int, int] = (224, 224)
|
||||
|
||||
# Add empty images. Used by pi0_aloha_sim which adds the empty
|
||||
# left and right wrist cameras in addition to the top camera.
|
||||
empty_cameras: int = 0
|
||||
|
||||
# Converts the joint and gripper values from the standard Aloha space to
|
||||
# the space used by the pi internal runtime which was used to train the base model.
|
||||
adapt_to_pi_aloha: bool = False
|
||||
|
||||
# Converts joint dimensions to deltas with respect to the current state before passing to the model.
|
||||
# Gripper dimensions will remain in absolute values.
|
||||
use_delta_joint_actions_aloha: bool = False
|
||||
|
||||
# Tokenizer
|
||||
tokenizer_max_length: int = 48
|
||||
|
||||
# Projector
|
||||
proj_width: int = 1024
|
||||
|
||||
# Decoding
|
||||
num_steps: int = 10
|
||||
|
||||
# Attention utils
|
||||
use_cache: bool = True
|
||||
attention_implementation: str = "eager" # or fa2, flex
|
||||
|
||||
# Finetuning settings
|
||||
freeze_vision_encoder: bool = True
|
||||
train_expert_only: bool = False
|
||||
train_state_proj: bool = True
|
||||
|
||||
# Training presets
|
||||
optimizer_lr: float = 2.5e-5
|
||||
optimizer_betas: tuple[float, float] = (0.9, 0.95)
|
||||
optimizer_eps: float = 1e-8
|
||||
optimizer_weight_decay: float = 1e-10
|
||||
|
||||
scheduler_warmup_steps: int = 1_000
|
||||
scheduler_decay_steps: int = 30_000
|
||||
scheduler_decay_lr: float = 2.5e-6
|
||||
|
||||
# TODO: Add EMA
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
"""Input validation (not exhaustive)."""
|
||||
if self.n_action_steps > self.chunk_size:
|
||||
raise ValueError(
|
||||
f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
|
||||
f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
|
||||
)
|
||||
if self.n_obs_steps != 1:
|
||||
raise ValueError(
|
||||
f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
|
||||
)
|
||||
|
||||
if self.use_delta_joint_actions_aloha:
|
||||
raise NotImplementedError(
|
||||
"`use_delta_joint_actions_aloha` is used by pi0 for aloha real models. It is not ported yet in LeRobot."
|
||||
)
|
||||
|
||||
def validate_features(self) -> None:
|
||||
# TODO: implement value error
|
||||
# if not self.image_features and not self.env_state_feature:
|
||||
# raise ValueError("You must provide at least one image or the environment state among the inputs.")
|
||||
|
||||
for i in range(self.empty_cameras):
|
||||
key = f"observation.images.empty_camera_{i}"
|
||||
empty_camera = PolicyFeature(
|
||||
type=FeatureType.VISUAL,
|
||||
shape=(3, 480, 640),
|
||||
)
|
||||
self.input_features[key] = empty_camera
|
||||
|
||||
def get_optimizer_preset(self) -> AdamWConfig:
|
||||
return AdamWConfig(
|
||||
lr=self.optimizer_lr,
|
||||
betas=self.optimizer_betas,
|
||||
eps=self.optimizer_eps,
|
||||
weight_decay=self.optimizer_weight_decay,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self):
|
||||
return CosineDecayWithWarmupSchedulerConfig(
|
||||
peak_lr=self.optimizer_lr,
|
||||
decay_lr=self.scheduler_decay_lr,
|
||||
num_warmup_steps=self.scheduler_warmup_steps,
|
||||
num_decay_steps=self.scheduler_decay_steps,
|
||||
)
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list:
|
||||
return list(range(self.chunk_size))
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
|
@ -0,0 +1,68 @@
|
|||
import torch
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
|
||||
def main():
|
||||
device = "cuda"
|
||||
dataset_repo_id = "danaaubakirova/koch_test"
|
||||
# model_name = "pi0_base"
|
||||
# ckpt_torch_dir = Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}_pytorch"
|
||||
ckpt_torch_dir = "lerobot/pi0"
|
||||
|
||||
dataset = LeRobotDataset(dataset_repo_id, episodes=[0])
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=0,
|
||||
batch_size=1,
|
||||
)
|
||||
|
||||
batch = next(iter(dataloader))
|
||||
|
||||
# To device
|
||||
for k in batch:
|
||||
if isinstance(batch[k], torch.Tensor):
|
||||
batch[k] = batch[k].to(device=device, dtype=torch.float32)
|
||||
|
||||
cfg = PreTrainedConfig.from_pretrained(ckpt_torch_dir)
|
||||
cfg.pretrained_path = ckpt_torch_dir
|
||||
policy = make_policy(cfg, device, ds_meta=dataset.meta)
|
||||
|
||||
# policy = torch.compile(policy, mode="reduce-overhead")
|
||||
|
||||
warmup_iters = 10
|
||||
benchmark_iters = 30
|
||||
|
||||
# Warmup
|
||||
for _ in range(warmup_iters):
|
||||
torch.cuda.synchronize()
|
||||
policy.select_action(batch)
|
||||
policy.reset()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Benchmark
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
|
||||
start_event.record()
|
||||
for _ in range(benchmark_iters):
|
||||
policy.select_action(batch)
|
||||
policy.reset()
|
||||
end_event.record()
|
||||
|
||||
# Synchronize and measure time
|
||||
torch.cuda.synchronize()
|
||||
elapsed_time_ms = start_event.elapsed_time(end_event)
|
||||
|
||||
avg_time_per_iter = elapsed_time_ms / benchmark_iters
|
||||
print(f"Average execution time per iteration: {avg_time_per_iter:.3f} ms")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
with torch.inference_mode():
|
||||
main()
|
|
@ -0,0 +1,117 @@
|
|||
import json
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
|
||||
|
||||
def display(tensor: torch.Tensor):
|
||||
if tensor.dtype == torch.bool:
|
||||
tensor = tensor.float()
|
||||
print(f"Shape: {tensor.shape}")
|
||||
print(f"Mean: {tensor.mean().item()}")
|
||||
print(f"Std: {tensor.std().item()}")
|
||||
print(f"Min: {tensor.min().item()}")
|
||||
print(f"Max: {tensor.max().item()}")
|
||||
|
||||
|
||||
def main():
|
||||
num_motors = 14
|
||||
device = "cuda"
|
||||
# model_name = "pi0_aloha_towel"
|
||||
model_name = "pi0_aloha_sim"
|
||||
|
||||
if model_name == "pi0_aloha_towel":
|
||||
dataset_repo_id = "lerobot/aloha_static_towel"
|
||||
else:
|
||||
dataset_repo_id = "lerobot/aloha_sim_transfer_cube_human"
|
||||
|
||||
ckpt_torch_dir = Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}_pytorch"
|
||||
ckpt_jax_dir = Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}"
|
||||
save_dir = Path(f"../openpi/data/{model_name}/save")
|
||||
|
||||
with open(save_dir / "example.pkl", "rb") as f:
|
||||
example = pickle.load(f)
|
||||
with open(save_dir / "outputs.pkl", "rb") as f:
|
||||
outputs = pickle.load(f)
|
||||
with open(save_dir / "noise.pkl", "rb") as f:
|
||||
noise = pickle.load(f)
|
||||
|
||||
with open(ckpt_jax_dir / "assets/norm_stats.json") as f:
|
||||
norm_stats = json.load(f)
|
||||
|
||||
# Override stats
|
||||
dataset_meta = LeRobotDatasetMetadata(dataset_repo_id)
|
||||
dataset_meta.stats["observation.state"]["mean"] = torch.tensor(
|
||||
norm_stats["norm_stats"]["state"]["mean"][:num_motors], dtype=torch.float32
|
||||
)
|
||||
dataset_meta.stats["observation.state"]["std"] = torch.tensor(
|
||||
norm_stats["norm_stats"]["state"]["std"][:num_motors], dtype=torch.float32
|
||||
)
|
||||
|
||||
# Create LeRobot batch from Jax
|
||||
batch = {}
|
||||
for cam_key, uint_chw_array in example["images"].items():
|
||||
batch[f"observation.images.{cam_key}"] = torch.from_numpy(uint_chw_array) / 255.0
|
||||
batch["observation.state"] = torch.from_numpy(example["state"])
|
||||
batch["action"] = torch.from_numpy(outputs["actions"])
|
||||
batch["task"] = example["prompt"]
|
||||
|
||||
if model_name == "pi0_aloha_towel":
|
||||
del batch["observation.images.cam_low"]
|
||||
elif model_name == "pi0_aloha_sim":
|
||||
batch["observation.images.top"] = batch["observation.images.cam_high"]
|
||||
del batch["observation.images.cam_high"]
|
||||
|
||||
# Batchify
|
||||
for key in batch:
|
||||
if isinstance(batch[key], torch.Tensor):
|
||||
batch[key] = batch[key].unsqueeze(0)
|
||||
elif isinstance(batch[key], str):
|
||||
batch[key] = [batch[key]]
|
||||
else:
|
||||
raise ValueError(f"{key}, {batch[key]}")
|
||||
|
||||
# To device
|
||||
for k in batch:
|
||||
if isinstance(batch[k], torch.Tensor):
|
||||
batch[k] = batch[k].to(device=device, dtype=torch.float32)
|
||||
|
||||
noise = torch.from_numpy(noise).to(device=device, dtype=torch.float32)
|
||||
|
||||
from lerobot.common import policies # noqa
|
||||
|
||||
cfg = PreTrainedConfig.from_pretrained(ckpt_torch_dir)
|
||||
cfg.pretrained_path = ckpt_torch_dir
|
||||
policy = make_policy(cfg, device, dataset_meta)
|
||||
|
||||
# loss_dict = policy.forward(batch, noise=noise, time=time_beta)
|
||||
# loss_dict["loss"].backward()
|
||||
# print("losses")
|
||||
# display(loss_dict["losses_after_forward"])
|
||||
# print("pi_losses")
|
||||
# display(pi_losses)
|
||||
|
||||
actions = []
|
||||
for _ in range(50):
|
||||
action = policy.select_action(batch, noise=noise)
|
||||
actions.append(action)
|
||||
|
||||
actions = torch.stack(actions, dim=1)
|
||||
pi_actions = batch["action"]
|
||||
print("actions")
|
||||
display(actions)
|
||||
print()
|
||||
print("pi_actions")
|
||||
display(pi_actions)
|
||||
print("atol=3e-2", torch.allclose(actions, pi_actions, atol=3e-2))
|
||||
print("atol=2e-2", torch.allclose(actions, pi_actions, atol=2e-2))
|
||||
print("atol=1e-2", torch.allclose(actions, pi_actions, atol=1e-2))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -0,0 +1,70 @@
|
|||
from transformers import GemmaConfig, PaliGemmaConfig
|
||||
|
||||
|
||||
def get_paligemma_config(precision: str):
|
||||
config = {
|
||||
"image_token_index": None,
|
||||
"pad_token_id": 0,
|
||||
"bos_token_id": 2,
|
||||
"eos_token_id": 1,
|
||||
}
|
||||
|
||||
# image_sizes = {"2b-test": 224, "3b-224px": 224, "3b-448px": 448, "3b-896px": 896}
|
||||
|
||||
image_size = 224 # image_sizes[variant]
|
||||
patch_size = 14
|
||||
num_image_tokens = (image_size**2) // (patch_size**2)
|
||||
|
||||
config["image_token_index"] = 257152
|
||||
text_config = {
|
||||
"vocab_size": 257152,
|
||||
"num_hidden_layers": 18,
|
||||
"num_key_value_heads": 1,
|
||||
"head_dim": 256,
|
||||
"torch_dtype": precision,
|
||||
"hidden_size": 2048,
|
||||
"hidden_activation": "gelu_pytorch_tanh",
|
||||
"num_attention_heads": 8,
|
||||
"intermediate_size": 16384,
|
||||
"is_encoder_decoder": False,
|
||||
}
|
||||
vision_config = {
|
||||
"torch_dtype": precision,
|
||||
"image_size": image_size,
|
||||
"patch_size": patch_size,
|
||||
"num_image_tokens": num_image_tokens,
|
||||
"hidden_size": 1152,
|
||||
"intermediate_size": 4304,
|
||||
"num_hidden_layers": 27,
|
||||
"num_attention_heads": 16,
|
||||
"projector_hidden_act": "gelu_fast",
|
||||
"vision_use_head": False,
|
||||
}
|
||||
final_config = PaliGemmaConfig(text_config=text_config, vision_config=vision_config, **config)
|
||||
return final_config
|
||||
|
||||
|
||||
def get_gemma_config(precision: str):
|
||||
config = {
|
||||
"image_token_index": None,
|
||||
"pad_token_id": 0,
|
||||
"bos_token_id": 2,
|
||||
"eos_token_id": 1,
|
||||
}
|
||||
|
||||
config["image_token_index"] = 257152
|
||||
text_config = {
|
||||
"vocab_size": 257152,
|
||||
"num_hidden_layers": 18,
|
||||
"num_key_value_heads": 1,
|
||||
"head_dim": 256,
|
||||
"torch_dtype": precision,
|
||||
"hidden_size": 1024,
|
||||
"hidden_activation": "gelu_pytorch_tanh",
|
||||
"num_attention_heads": 8,
|
||||
"intermediate_size": 4096,
|
||||
"is_encoder_decoder": False,
|
||||
}
|
||||
final_config = GemmaConfig()
|
||||
final_config.update(text_config)
|
||||
return final_config
|
|
@ -0,0 +1,423 @@
|
|||
"""
|
||||
Convert pi0 parameters from Jax to Pytorch
|
||||
|
||||
Follow [README of openpi](https://github.com/Physical-Intelligence/openpi) to create a new environment
|
||||
and install the required librairies.
|
||||
|
||||
```bash
|
||||
cd ~/code/openpi
|
||||
source .venv/bin/activate
|
||||
```
|
||||
|
||||
Example downloading parameters:
|
||||
```bash
|
||||
python
|
||||
>>> import openpi.shared.download as download
|
||||
>>> path='s3://openpi-assets/checkpoints/pi0_base/params'
|
||||
>>> download.maybe_download(path)
|
||||
```
|
||||
|
||||
Converting pi0_base:
|
||||
```python
|
||||
python lerobot/common/policies/pi0/conversion_scripts/convert_pi0_to_hf_lerobot.py \
|
||||
--checkpoint_dir /home/remi_cadene/.cache/openpi/openpi-assets/checkpoints/pi0_base/params \
|
||||
--output_path /home/remi_cadene/.cache/openpi/openpi-assets/checkpoints/pi0_base_pytorch
|
||||
```
|
||||
|
||||
```python
|
||||
python lerobot/common/policies/pi0/conversion_scripts/convert_pi0_to_hf_lerobot.py \
|
||||
--checkpoint_dir /home/remi_cadene/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim/params \
|
||||
--output_path /home/remi_cadene/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim_pytorch
|
||||
```
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import pathlib
|
||||
|
||||
import jax
|
||||
import numpy as np
|
||||
import orbax.checkpoint as ocp
|
||||
import torch
|
||||
from jax.sharding import SingleDeviceSharding
|
||||
|
||||
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.common.policies.pi0.conversion_scripts.conversion_utils import (
|
||||
get_gemma_config,
|
||||
get_paligemma_config,
|
||||
)
|
||||
from lerobot.common.policies.pi0.modeling_pi0 import PI0Policy
|
||||
|
||||
PRECISIONS = {"bfloat16": torch.bfloat16, "float32": torch.float32, "float16": torch.float16}
|
||||
|
||||
|
||||
def slice_paligemma_state_dict(state_dict, config):
|
||||
suffix = "/value" if "img/embedding/kernel/value" in state_dict else ""
|
||||
|
||||
# fmt: off
|
||||
# patch embeddings
|
||||
state_dict["paligemma.vision_tower.vision_model.embeddings.patch_embedding.weight"] = state_dict.pop(f"img/embedding/kernel{suffix}").transpose(
|
||||
3, 2, 0, 1
|
||||
)
|
||||
state_dict["paligemma.vision_tower.vision_model.embeddings.patch_embedding.bias"] = state_dict.pop(f"img/embedding/bias{suffix}")
|
||||
# positional embeddings
|
||||
state_dict["paligemma.vision_tower.vision_model.embeddings.position_embedding.weight"] = state_dict.pop(f"img/pos_embedding{suffix}").reshape(
|
||||
-1, config.vision_config.hidden_size
|
||||
)
|
||||
|
||||
# extract vision layers to be sliced at index 0. There are 27 layers in the base model.
|
||||
encoderblock_layernorm0_scale = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_0/scale{suffix}")
|
||||
encoderblock_layernorm0_bias = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_0/bias{suffix}")
|
||||
encoderblock_layernorm1_scale = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/scale{suffix}")
|
||||
encoderblock_layernorm1_bias = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/bias{suffix}")
|
||||
|
||||
encoderblock_mlp_dense0_kernel= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/kernel{suffix}")
|
||||
encoderblock_mlp_dense0_bias= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/bias{suffix}")
|
||||
encoderblock_mlp_dense1_kernel= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/kernel{suffix}")
|
||||
encoderblock_mlp_dense1_bias= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/bias{suffix}")
|
||||
|
||||
encoderblock_attention_0_key_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/kernel{suffix}")
|
||||
encoderblock_attention_0_key_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/bias{suffix}")
|
||||
encoderblock_attention_0_value_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/kernel{suffix}")
|
||||
encoderblock_attention_0_value_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/bias{suffix}")
|
||||
encoderblock_attention_0_query_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/kernel{suffix}")
|
||||
encoderblock_attention_0_query_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/bias{suffix}")
|
||||
encoderblock_attention_0_out_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/kernel{suffix}")
|
||||
encoderblock_attention_0_out_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/bias{suffix}")
|
||||
|
||||
for i in range(config.vision_config.num_hidden_layers):
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.weight"] = encoderblock_layernorm0_scale[i].transpose()
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.bias"] = encoderblock_layernorm0_bias[i]
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.weight"] = encoderblock_layernorm1_scale[i].transpose()
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.bias"] = encoderblock_layernorm1_bias[i]
|
||||
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.weight"] = encoderblock_mlp_dense0_kernel[i].transpose()
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.bias"] = encoderblock_mlp_dense0_bias[i]
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.weight"] = encoderblock_mlp_dense1_kernel[i].transpose()
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.bias"] = encoderblock_mlp_dense1_bias[i]
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.weight"] = encoderblock_attention_0_key_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.bias"] = encoderblock_attention_0_key_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.weight"] = encoderblock_attention_0_value_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.bias"] = encoderblock_attention_0_value_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.weight"] = encoderblock_attention_0_query_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.bias"] = encoderblock_attention_0_query_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.weight"] = encoderblock_attention_0_out_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.bias"] = encoderblock_attention_0_out_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
|
||||
|
||||
state_dict["paligemma.vision_tower.vision_model.post_layernorm.weight"] = state_dict.pop(f"img/Transformer/encoder_norm/scale{suffix}").transpose()
|
||||
state_dict["paligemma.vision_tower.vision_model.post_layernorm.bias"] = state_dict.pop(f"img/Transformer/encoder_norm/bias{suffix}")
|
||||
|
||||
# multimodal projector
|
||||
|
||||
state_dict['paligemma.multi_modal_projector.linear.weight'] = state_dict.pop(f"img/head/kernel{suffix}").transpose()
|
||||
state_dict['paligemma.multi_modal_projector.linear.bias'] = state_dict.pop(f"img/head/bias{suffix}")
|
||||
|
||||
# text decoder (gemma)
|
||||
embedding_vector = state_dict.pop(f"llm/embedder/input_embedding{suffix}")
|
||||
state_dict["paligemma.language_model.model.embed_tokens.weight"] = embedding_vector
|
||||
|
||||
# pop the einsum attention + mlp representations. There are 18 layers in gemma-2b.
|
||||
|
||||
llm_attention_attn_vec_einsum = state_dict.pop(f"llm/layers/attn/attn_vec_einsum/w{suffix}")
|
||||
llm_attention_kv_einsum = state_dict.pop(f"llm/layers/attn/kv_einsum/w{suffix}")
|
||||
llm_attention_q_einsum = state_dict.pop(f"llm/layers/attn/q_einsum/w{suffix}")
|
||||
|
||||
llm_mlp_gating_einsum = state_dict.pop(f"llm/layers/mlp/gating_einsum{suffix}")
|
||||
llm_mlp_linear = state_dict.pop(f"llm/layers/mlp/linear{suffix}")
|
||||
# TODO verify correctness of layer norm loading
|
||||
|
||||
llm_input_layernorm = state_dict.pop(f"llm/layers/pre_attention_norm/scale{suffix}")
|
||||
llm_post_attention_layernorm = state_dict.pop(f"llm/layers/pre_ffw_norm/scale{suffix}")
|
||||
|
||||
for i in range(config.text_config.num_hidden_layers):
|
||||
# llm_attention_q_einsum[i].shape = (8, 2048, 256)
|
||||
q_proj_weight_reshaped = llm_attention_q_einsum[i].transpose(0, 2, 1).reshape(config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size)
|
||||
|
||||
state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.q_proj.weight"] = q_proj_weight_reshaped
|
||||
|
||||
# llm_attention_kv_einsum[i, 0, 0].shape = (2048, 256)
|
||||
k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose()
|
||||
state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.k_proj.weight"] = k_proj_weight_reshaped
|
||||
# llm_attention_kv_einsum[i, 1, 0].shape = (2048, 256)
|
||||
v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose()
|
||||
state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.v_proj.weight"] = v_proj_weight_reshaped
|
||||
|
||||
# output projection.
|
||||
|
||||
# llm_attention_attn_vec_einsum[i].shape = (8, 256, 2048)
|
||||
o_proj_weight_reshaped = llm_attention_attn_vec_einsum[i].transpose(2, 0, 1).reshape(config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size)
|
||||
|
||||
state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.o_proj.weight"] = o_proj_weight_reshaped
|
||||
# mlp layers
|
||||
gate_proj_weight = llm_mlp_gating_einsum[i, 0]
|
||||
state_dict[f"paligemma.language_model.model.layers.{i}.mlp.gate_proj.weight"] = gate_proj_weight.transpose()
|
||||
up_proj_weight = llm_mlp_gating_einsum[i, 1]
|
||||
state_dict[f"paligemma.language_model.model.layers.{i}.mlp.up_proj.weight"] = up_proj_weight.transpose()
|
||||
state_dict[f"paligemma.language_model.model.layers.{i}.mlp.down_proj.weight"] = llm_mlp_linear[i].transpose()
|
||||
state_dict[f"paligemma.language_model.model.layers.{i}.input_layernorm.weight"] = llm_input_layernorm[i]
|
||||
state_dict[f"paligemma.language_model.model.layers.{i}.post_attention_layernorm.weight"] = llm_post_attention_layernorm[i]
|
||||
|
||||
state_dict["paligemma.language_model.model.norm.weight"] = state_dict.pop(f"llm/final_norm/scale{suffix}")
|
||||
state_dict["paligemma.language_model.lm_head.weight"] = embedding_vector # weights are tied.
|
||||
|
||||
# fmt: on
|
||||
expert_dict = {}
|
||||
final_state_dict = {}
|
||||
for key, value in state_dict.items():
|
||||
if key not in [
|
||||
f"llm/final_norm_1/scale{suffix}",
|
||||
f"llm/layers/attn/attn_vec_einsum_1/w{suffix}",
|
||||
f"llm/layers/attn/kv_einsum_1/w{suffix}",
|
||||
f"llm/layers/attn/q_einsum_1/w{suffix}",
|
||||
f"llm/layers/mlp_1/gating_einsum{suffix}",
|
||||
f"llm/layers/mlp_1/linear{suffix}",
|
||||
f"llm/layers/pre_attention_norm_1/scale{suffix}",
|
||||
f"llm/layers/pre_ffw_norm_1/scale{suffix}",
|
||||
]:
|
||||
final_state_dict[key] = torch.from_numpy(value)
|
||||
else:
|
||||
expert_dict[key] = value
|
||||
|
||||
return final_state_dict, expert_dict
|
||||
|
||||
|
||||
def slice_gemma_state_dict(state_dict, config, num_expert=1):
|
||||
# fmt: off
|
||||
# text decoder (gemma)
|
||||
# no embedding vector, the expert just has the decoder layers
|
||||
|
||||
embedding_vector = torch.zeros([config.vocab_size, config.hidden_size])
|
||||
state_dict["gemma_expert.model.embed_tokens.weight"] = embedding_vector
|
||||
|
||||
# pop the einsum attention + mlp representations. There are 18 layers in gemma-2b.
|
||||
|
||||
suffix = "/value" if f"llm/layers/attn/attn_vec_einsum_{num_expert}/w/value" in state_dict else ""
|
||||
|
||||
llm_attention_attn_vec_einsum = state_dict.pop(f"llm/layers/attn/attn_vec_einsum_{num_expert}/w{suffix}")
|
||||
llm_attention_kv_einsum = state_dict.pop(f"llm/layers/attn/kv_einsum_{num_expert}/w{suffix}")
|
||||
llm_attention_q_einsum = state_dict.pop(f"llm/layers/attn/q_einsum_{num_expert}/w{suffix}")
|
||||
|
||||
llm_mlp_gating_einsum = state_dict.pop(f"llm/layers/mlp_{num_expert}/gating_einsum{suffix}")
|
||||
llm_mlp_linear = state_dict.pop(f"llm/layers/mlp_{num_expert}/linear{suffix}")
|
||||
# TODO verify correctness of layer norm loading
|
||||
|
||||
llm_input_layernorm = state_dict.pop(f"llm/layers/pre_attention_norm_{num_expert}/scale{suffix}")
|
||||
llm_post_attention_layernorm = state_dict.pop(f"llm/layers/pre_ffw_norm_{num_expert}/scale{suffix}")
|
||||
|
||||
for i in range(config.num_hidden_layers):
|
||||
q_proj_weight_reshaped = llm_attention_q_einsum[i].transpose(0, 2, 1).reshape(config.num_attention_heads * config.head_dim, config.hidden_size)
|
||||
|
||||
state_dict[f"gemma_expert.model.layers.{i}.self_attn.q_proj.weight"] = q_proj_weight_reshaped
|
||||
|
||||
k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose()
|
||||
state_dict[f"gemma_expert.model.layers.{i}.self_attn.k_proj.weight"] = k_proj_weight_reshaped
|
||||
v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose()
|
||||
state_dict[f"gemma_expert.model.layers.{i}.self_attn.v_proj.weight"] = v_proj_weight_reshaped
|
||||
|
||||
# output projection.
|
||||
|
||||
# llm_attention_attn_vec_einsum[i].shape = (8, 256, 1024)
|
||||
o_proj_weight_reshaped = llm_attention_attn_vec_einsum[i].reshape(config.num_attention_heads * config.head_dim, config.hidden_size).transpose(1,0)# .transpose(2, 0, 1).reshape(config.num_attention_heads * config.head_dim, config.hidden_size).transpose(1, 0)
|
||||
|
||||
state_dict[f"gemma_expert.model.layers.{i}.self_attn.o_proj.weight"] = o_proj_weight_reshaped
|
||||
# mlp layers
|
||||
gate_proj_weight = llm_mlp_gating_einsum[i, 0]
|
||||
state_dict[f"gemma_expert.model.layers.{i}.mlp.gate_proj.weight"] = gate_proj_weight.transpose()
|
||||
up_proj_weight = llm_mlp_gating_einsum[i, 1]
|
||||
state_dict[f"gemma_expert.model.layers.{i}.mlp.up_proj.weight"] = up_proj_weight.transpose()
|
||||
state_dict[f"gemma_expert.model.layers.{i}.mlp.down_proj.weight"] = llm_mlp_linear[i].transpose()
|
||||
state_dict[f"gemma_expert.model.layers.{i}.input_layernorm.weight"] = llm_input_layernorm[i]
|
||||
state_dict[f"gemma_expert.model.layers.{i}.post_attention_layernorm.weight"] = llm_post_attention_layernorm[i]
|
||||
|
||||
state_dict["gemma_expert.model.norm.weight"] = state_dict.pop(f"llm/final_norm_{num_expert}/scale{suffix}")
|
||||
state_dict["gemma_expert.lm_head.weight"] = embedding_vector # weights are tied. (and zeros here)
|
||||
|
||||
# fmt: on
|
||||
final_state_dict = {}
|
||||
for key, value in state_dict.items():
|
||||
if not isinstance(value, torch.Tensor):
|
||||
final_state_dict[key] = torch.from_numpy(value)
|
||||
else:
|
||||
final_state_dict[key] = value
|
||||
return final_state_dict
|
||||
|
||||
|
||||
def flatten_for_memory(tree, parent_key=""):
|
||||
out = {}
|
||||
for k, v in tree.items():
|
||||
new_key = f"{parent_key}/{k}" if parent_key else k
|
||||
if isinstance(v, dict):
|
||||
out.update(flatten_for_memory(v, new_key))
|
||||
else:
|
||||
out[new_key] = np.array(v) # Ensure conversion to np.array for consistency
|
||||
return out
|
||||
|
||||
|
||||
def flatten_for_npz(tree, parent_key=""):
|
||||
out = {}
|
||||
for k, v in tree.items():
|
||||
new_key = f"{parent_key}/{k}" if parent_key else k
|
||||
if isinstance(v, dict):
|
||||
out.update(flatten_for_npz(v, new_key))
|
||||
else:
|
||||
# bf16/f32 here?
|
||||
out[new_key] = np.array(v)
|
||||
return out
|
||||
|
||||
|
||||
def slice_initial_orbax_checkpoint(checkpoint_dir: str):
|
||||
params_path = pathlib.Path(checkpoint_dir).resolve()
|
||||
checkpointer = ocp.PyTreeCheckpointer()
|
||||
|
||||
metadata = checkpointer.metadata(params_path)
|
||||
print("Metadata keys:", list(metadata.keys()))
|
||||
|
||||
params_name = "params"
|
||||
|
||||
item = {params_name: metadata[params_name]}
|
||||
device = jax.local_devices()[0] # Use the first local device
|
||||
sharding = SingleDeviceSharding(device)
|
||||
restored = checkpointer.restore(
|
||||
params_path,
|
||||
ocp.args.PyTreeRestore(
|
||||
item=item,
|
||||
restore_args=jax.tree_util.tree_map(
|
||||
lambda _: ocp.ArrayRestoreArgs(
|
||||
restore_type=jax.Array, # or np.ndarray, but bf16 is annoying about it
|
||||
sharding=sharding,
|
||||
),
|
||||
item,
|
||||
),
|
||||
transforms={},
|
||||
),
|
||||
)
|
||||
params = restored[params_name]
|
||||
|
||||
# get params for PaliGemma
|
||||
pali_params = params["PaliGemma"]
|
||||
del params["PaliGemma"]
|
||||
pali_params_flat = flatten_for_npz(pali_params)
|
||||
return {"paligemma_params": pali_params_flat, "projection_params": params}
|
||||
|
||||
|
||||
def update_keys_with_prefix(d: dict, prefix: str) -> dict:
|
||||
"""Update dictionary keys by adding a prefix."""
|
||||
return {f"{prefix}{key}": value for key, value in d.items()}
|
||||
|
||||
|
||||
def convert_pi0_checkpoint(checkpoint_dir: str, precision: str, tokenizer_id: str, output_path: str):
|
||||
# Break down orbax ckpts - they are in OCDBT
|
||||
initial_params = slice_initial_orbax_checkpoint(checkpoint_dir=checkpoint_dir)
|
||||
# process projection params
|
||||
keys = [
|
||||
"state_proj",
|
||||
"action_in_proj",
|
||||
"action_out_proj",
|
||||
"action_time_mlp_in",
|
||||
"action_time_mlp_out",
|
||||
]
|
||||
|
||||
projection_params = {}
|
||||
for key in keys:
|
||||
kernel_params = initial_params["projection_params"][key]["kernel"]
|
||||
bias_params = initial_params["projection_params"][key]["bias"]
|
||||
if isinstance(kernel_params, dict):
|
||||
weight = kernel_params["value"]
|
||||
bias = bias_params["value"]
|
||||
else:
|
||||
weight = kernel_params
|
||||
bias = bias_params
|
||||
projection_params[f"{key}.weight"] = torch.from_numpy(np.array(weight)).T
|
||||
projection_params[f"{key}.bias"] = torch.from_numpy(np.array(bias))
|
||||
|
||||
# Process PaliGemma weights
|
||||
paligemma_config = get_paligemma_config(precision)
|
||||
paligemma_params, gemma_raw_dictionary = slice_paligemma_state_dict(
|
||||
initial_params["paligemma_params"], paligemma_config
|
||||
)
|
||||
|
||||
# Process Gemma weights (at this stage they are unused)
|
||||
gemma_config = get_gemma_config(precision)
|
||||
gemma_params = slice_gemma_state_dict(gemma_raw_dictionary, config=gemma_config)
|
||||
|
||||
# Instantiate model from configs
|
||||
|
||||
if "pi0_aloha_sim" in checkpoint_dir:
|
||||
pi0_config = PI0Config(
|
||||
empty_cameras=2,
|
||||
adapt_to_pi_aloha=True,
|
||||
use_delta_joint_actions_aloha=False,
|
||||
)
|
||||
elif "pi0_aloha_towel" in checkpoint_dir:
|
||||
pi0_config = PI0Config(
|
||||
adapt_to_pi_aloha=True,
|
||||
use_delta_joint_actions_aloha=True,
|
||||
)
|
||||
elif "pi0_base" in checkpoint_dir:
|
||||
pi0_config = PI0Config(
|
||||
empty_cameras=0,
|
||||
adapt_to_pi_aloha=False,
|
||||
use_delta_joint_actions_aloha=False,
|
||||
)
|
||||
else:
|
||||
raise ValueError()
|
||||
|
||||
# gemma_config=gemma_config, paligemma_config=paligemma_config)
|
||||
pi0_model = PI0Policy(pi0_config)
|
||||
|
||||
paligemma_params = update_keys_with_prefix(paligemma_params, "model.paligemma_with_expert.")
|
||||
gemma_params = update_keys_with_prefix(gemma_params, "model.paligemma_with_expert.")
|
||||
projection_params = update_keys_with_prefix(projection_params, "model.")
|
||||
|
||||
# load state dict
|
||||
torch_dtype = PRECISIONS[precision]
|
||||
pi0_model.load_state_dict({**paligemma_params, **gemma_params, **projection_params})
|
||||
pi0_model = pi0_model.to(torch_dtype)
|
||||
# pi0_tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
|
||||
|
||||
pi0_model.save_pretrained(output_path, safe_serialization=True)
|
||||
# pi0_tokenizer.save_pretrained(output_path, dtype=torch_dtype)
|
||||
|
||||
# assert that model loads properly
|
||||
del pi0_model
|
||||
PI0Policy.from_pretrained(output_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--checkpoint_dir",
|
||||
default="/raid/pablo/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim/params",
|
||||
type=str,
|
||||
help="Path to the ocdbt checkpoint",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--precision",
|
||||
choices=["float32", "bfloat16", "float16"],
|
||||
default="float32",
|
||||
type=str,
|
||||
help="Precision identifier for model conversion - should match the base checkpoint precision.",
|
||||
)
|
||||
# tokenizer is identical to paligemma, it appears
|
||||
|
||||
parser.add_argument(
|
||||
"--tokenizer_hub_id",
|
||||
default="google/paligemma-3b-pt-224",
|
||||
type=str,
|
||||
help="Hub path to the tokenizer to save",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--output_path",
|
||||
required=True,
|
||||
type=str,
|
||||
help="Path to save converted weights to",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
convert_pi0_checkpoint(
|
||||
checkpoint_dir=args.checkpoint_dir,
|
||||
precision=args.precision,
|
||||
tokenizer_id=args.tokenizer_hub_id,
|
||||
output_path=args.output_path,
|
||||
)
|
|
@ -0,0 +1,127 @@
|
|||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from packaging.version import Version
|
||||
|
||||
if Version(torch.__version__) > Version("2.5.0"):
|
||||
# Ffex attention is only available from torch 2.5 onwards
|
||||
from torch.nn.attention.flex_attention import (
|
||||
_mask_mod_signature,
|
||||
_round_up_to_multiple,
|
||||
create_block_mask,
|
||||
create_mask,
|
||||
flex_attention,
|
||||
)
|
||||
|
||||
|
||||
# @torch.compile(dynamic=False)
|
||||
def flex_attention_forward(
|
||||
attention_mask: torch.Tensor,
|
||||
batch_size: int,
|
||||
head_dim: int,
|
||||
query_states: torch.Tensor,
|
||||
key_states: torch.Tensor,
|
||||
value_states: torch.Tensor,
|
||||
scaling=None,
|
||||
):
|
||||
"""
|
||||
This is defined out of classes to make compile happy.
|
||||
"""
|
||||
|
||||
original_dtype = query_states.dtype
|
||||
num_att_heads = 8
|
||||
num_key_value_heads = 1
|
||||
num_key_value_groups = num_att_heads // num_key_value_heads
|
||||
|
||||
key_states = key_states[:, :, :, None, :]
|
||||
key_states = key_states.expand(
|
||||
batch_size, key_states.shape[1], num_key_value_heads, num_key_value_groups, head_dim
|
||||
)
|
||||
key_states = key_states.reshape(
|
||||
batch_size, key_states.shape[1], num_key_value_heads * num_key_value_groups, head_dim
|
||||
)
|
||||
|
||||
value_states = value_states[:, :, :, None, :]
|
||||
value_states = value_states.expand(
|
||||
batch_size, value_states.shape[1], num_key_value_heads, num_key_value_groups, head_dim
|
||||
)
|
||||
value_states = value_states.reshape(
|
||||
batch_size, value_states.shape[1], num_key_value_heads * num_key_value_groups, head_dim
|
||||
)
|
||||
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
value_states = value_states.transpose(1, 2)
|
||||
|
||||
query_states = query_states.to(torch.float32)
|
||||
key_states = key_states.to(torch.float32)
|
||||
value_states = value_states.to(torch.float32)
|
||||
|
||||
causal_mask = attention_mask
|
||||
if causal_mask is not None:
|
||||
causal_mask = causal_mask[:, None, :, : key_states.shape[2]]
|
||||
|
||||
if causal_mask.shape[1] == 1 and query_states.shape[1] > 1:
|
||||
causal_mask = causal_mask.expand(-1, query_states.shape[1], -1, -1)
|
||||
|
||||
def precomputed_mask_factory(precomputed_mask: torch.Tensor) -> _mask_mod_signature:
|
||||
def mask_mod(b, h, q_idx, kv_idx):
|
||||
# Danger zone: if b,h,q_idx,kv_idx exceed the shape, device-side assert occurs.
|
||||
return precomputed_mask[b][h][q_idx][kv_idx]
|
||||
|
||||
return mask_mod
|
||||
|
||||
b_mask, h_mask, q_len, kv_len = causal_mask.shape # The shape of your mask
|
||||
|
||||
block_size = 128
|
||||
q_len_rounded = _round_up_to_multiple(q_len, block_size)
|
||||
kv_len_rounded = _round_up_to_multiple(kv_len, block_size)
|
||||
|
||||
# *CRITICAL* we do need to expand here, else we get a CUDA index error
|
||||
|
||||
pad_q = q_len_rounded - q_len
|
||||
pad_k = kv_len_rounded - kv_len
|
||||
|
||||
padded_causal_mask = F.pad(causal_mask, (0, pad_k, 0, pad_q), value=0.0)
|
||||
mask_mod_fn_orig = precomputed_mask_factory(padded_causal_mask)
|
||||
|
||||
mask_4d = create_mask(
|
||||
mod_fn=mask_mod_fn_orig,
|
||||
B=b_mask,
|
||||
H=h_mask,
|
||||
Q_LEN=q_len_rounded,
|
||||
KV_LEN=kv_len_rounded,
|
||||
device=causal_mask.device,
|
||||
_compile=False,
|
||||
)
|
||||
|
||||
mask_mod_fn_padded = precomputed_mask_factory(mask_4d)
|
||||
block_mask = create_block_mask(
|
||||
mask_mod=mask_mod_fn_padded,
|
||||
B=b_mask,
|
||||
H=h_mask,
|
||||
Q_LEN=q_len_rounded,
|
||||
KV_LEN=kv_len_rounded,
|
||||
BLOCK_SIZE=block_size,
|
||||
device=causal_mask.device,
|
||||
_compile=False,
|
||||
)
|
||||
|
||||
# mask is applied inside the kernel, ideally more efficiently than score_mod.
|
||||
attn_output, attention_weights = flex_attention(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
block_mask=block_mask,
|
||||
enable_gqa=True, # because we shaped query/key states for GQA
|
||||
scale=head_dim**-0.5 if scaling is None else scaling,
|
||||
return_lse=True,
|
||||
)
|
||||
|
||||
attn_output = attn_output.to(dtype=original_dtype)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous() # [B, Q_LEN, H, head_dim]
|
||||
attn_output = attn_output.reshape(
|
||||
batch_size,
|
||||
-1,
|
||||
attn_output.shape[2] * attn_output.shape[3], # merges [H, head_dim]
|
||||
)
|
||||
return attn_output
|
|
@ -0,0 +1,732 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
π0: A Vision-Language-Action Flow Model for General Robot Control
|
||||
|
||||
[Paper](https://www.physicalintelligence.company/download/pi0.pdf)
|
||||
[Jax code](https://github.com/Physical-Intelligence/openpi)
|
||||
|
||||
Designed by Physical Intelligence. Ported from Jax by Hugging Face.
|
||||
|
||||
Install pi0 extra dependencies:
|
||||
```bash
|
||||
pip install -e ".[pi0]"
|
||||
```
|
||||
|
||||
Example of finetuning the pi0 pretrained model (`pi0_base` in `openpi`):
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
--policy.path=lerobot/pi0 \
|
||||
--dataset.repo_id=danaaubakirova/koch_test
|
||||
```
|
||||
|
||||
Example of finetuning the pi0 neural network with PaliGemma and expert Gemma
|
||||
pretrained with VLM default parameters before pi0 finetuning:
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
--policy.type=pi0 \
|
||||
--dataset.repo_id=danaaubakirova/koch_test
|
||||
```
|
||||
|
||||
Example of using the pi0 pretrained model outside LeRobot training framework:
|
||||
```python
|
||||
policy = Pi0Policy.from_pretrained("lerobot/pi0")
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
import math
|
||||
from collections import deque
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor, nn
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from lerobot.common.constants import ACTION, OBS_ROBOT
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.common.policies.pi0.paligemma_with_expert import (
|
||||
PaliGemmaWithExpertConfig,
|
||||
PaliGemmaWithExpertModel,
|
||||
)
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.utils.utils import get_safe_dtype
|
||||
|
||||
|
||||
def create_sinusoidal_pos_embedding(
|
||||
time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu"
|
||||
) -> Tensor:
|
||||
"""Computes sine-cosine positional embedding vectors for scalar positions."""
|
||||
if dimension % 2 != 0:
|
||||
raise ValueError(f"dimension ({dimension}) must be divisible by 2")
|
||||
|
||||
if time.ndim != 1:
|
||||
raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")
|
||||
|
||||
dtype = get_safe_dtype(torch.float64, device.type)
|
||||
fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
|
||||
period = min_period * (max_period / min_period) ** fraction
|
||||
|
||||
# Compute the outer product
|
||||
scaling_factor = 1.0 / period * 2 * math.pi
|
||||
sin_input = scaling_factor[None, :] * time[:, None]
|
||||
pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
|
||||
return pos_emb
|
||||
|
||||
|
||||
def sample_beta(alpha, beta, bsize, device):
|
||||
gamma1 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / alpha)
|
||||
gamma2 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / beta)
|
||||
return gamma1 / (gamma1 + gamma2)
|
||||
|
||||
|
||||
def make_att_2d_masks(pad_masks, att_masks):
|
||||
"""Copied from big_vision.
|
||||
|
||||
Tokens can attend to valid inputs tokens which have a cumulative mask_ar
|
||||
smaller or equal to theirs. This way `mask_ar` int[B, N] can be used to
|
||||
setup several types of attention, for example:
|
||||
|
||||
[[1 1 1 1 1 1]]: pure causal attention.
|
||||
|
||||
[[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between
|
||||
themselves and the last 3 tokens have a causal attention. The first
|
||||
entry could also be a 1 without changing behaviour.
|
||||
|
||||
[[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a
|
||||
block can attend all previous blocks and all tokens on the same block.
|
||||
|
||||
Args:
|
||||
input_mask: bool[B, N] true if its part of the input, false if padding.
|
||||
mask_ar: int32[B, N] mask that's 1 where previous tokens cannot depend on
|
||||
it and 0 where it shares the same attention mask as the previous token.
|
||||
"""
|
||||
if att_masks.ndim != 2:
|
||||
raise ValueError(att_masks.ndim)
|
||||
if pad_masks.ndim != 2:
|
||||
raise ValueError(pad_masks.ndim)
|
||||
|
||||
cumsum = torch.cumsum(att_masks, dim=1)
|
||||
att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None]
|
||||
pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
|
||||
att_2d_masks = att_2d_masks & pad_2d_masks
|
||||
return att_2d_masks
|
||||
|
||||
|
||||
def resize_with_pad(img, width, height, pad_value=-1):
|
||||
# assume no-op when width height fits already
|
||||
if img.ndim != 4:
|
||||
raise ValueError(f"(b,c,h,w) expected, but {img.shape}")
|
||||
|
||||
cur_height, cur_width = img.shape[2:]
|
||||
|
||||
ratio = max(cur_width / width, cur_height / height)
|
||||
resized_height = int(cur_height / ratio)
|
||||
resized_width = int(cur_width / ratio)
|
||||
resized_img = F.interpolate(
|
||||
img, size=(resized_height, resized_width), mode="bilinear", align_corners=False
|
||||
)
|
||||
|
||||
pad_height = max(0, int(height - resized_height))
|
||||
pad_width = max(0, int(width - resized_width))
|
||||
|
||||
# pad on left and top of image
|
||||
padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value)
|
||||
return padded_img
|
||||
|
||||
|
||||
def pad_vector(vector, new_dim):
|
||||
"""Can be (batch_size x sequence_length x features_dimension)
|
||||
or (batch_size x features_dimension)
|
||||
"""
|
||||
if vector.shape[-1] == new_dim:
|
||||
return vector
|
||||
shape = list(vector.shape)
|
||||
current_dim = shape[-1]
|
||||
shape[-1] = new_dim
|
||||
new_vector = torch.zeros(*shape, dtype=vector.dtype, device=vector.device)
|
||||
new_vector[..., :current_dim] = vector
|
||||
return new_vector
|
||||
|
||||
|
||||
def normalize(x, min_val, max_val):
|
||||
return (x - min_val) / (max_val - min_val)
|
||||
|
||||
|
||||
def unnormalize(x, min_val, max_val):
|
||||
return x * (max_val - min_val) + min_val
|
||||
|
||||
|
||||
def safe_arcsin(value):
|
||||
# This ensures that the input stays within
|
||||
# [−1,1] to avoid invalid values for arcsin
|
||||
return torch.arcsin(torch.clamp(value, -1.0, 1.0))
|
||||
|
||||
|
||||
def aloha_gripper_to_angular(value):
|
||||
# Aloha transforms the gripper positions into a linear space. The following code
|
||||
# reverses this transformation to be consistent with pi0 which is pretrained in
|
||||
# angular space.
|
||||
#
|
||||
# These values are coming from the Aloha code:
|
||||
# PUPPET_GRIPPER_POSITION_OPEN, PUPPET_GRIPPER_POSITION_CLOSED
|
||||
value = unnormalize(value, min_val=0.01844, max_val=0.05800)
|
||||
|
||||
# This is the inverse of the angular to linear transformation inside the Interbotix code.
|
||||
def linear_to_radian(linear_position, arm_length, horn_radius):
|
||||
value = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position)
|
||||
return safe_arcsin(value)
|
||||
|
||||
# The constants are taken from the Interbotix code.
|
||||
value = linear_to_radian(value, arm_length=0.036, horn_radius=0.022)
|
||||
|
||||
# Normalize to [0, 1].
|
||||
# The values 0.4 and 1.5 were measured on an actual Trossen robot.
|
||||
return normalize(value, min_val=0.4, max_val=1.5)
|
||||
|
||||
|
||||
def aloha_gripper_from_angular(value):
|
||||
# Convert from the gripper position used by pi0 to the gripper position that is used by Aloha.
|
||||
# Note that the units are still angular but the range is different.
|
||||
|
||||
# The values 0.4 and 1.5 were measured on an actual Trossen robot.
|
||||
value = unnormalize(value, min_val=0.4, max_val=1.5)
|
||||
|
||||
# These values are coming from the Aloha code:
|
||||
# PUPPET_GRIPPER_JOINT_OPEN, PUPPET_GRIPPER_JOINT_CLOSE
|
||||
return normalize(value, min_val=-0.6213, max_val=1.4910)
|
||||
|
||||
|
||||
def aloha_gripper_from_angular_inv(value):
|
||||
# Directly inverts the gripper_from_angular function.
|
||||
value = unnormalize(value, min_val=-0.6213, max_val=1.4910)
|
||||
return normalize(value, min_val=0.4, max_val=1.5)
|
||||
|
||||
|
||||
class PI0Policy(PreTrainedPolicy):
|
||||
"""Wrapper class around PI0FlowMatching model to train and run inference within LeRobot."""
|
||||
|
||||
config_class = PI0Config
|
||||
name = "pi0"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PI0Config,
|
||||
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
config: Policy configuration class instance or None, in which case the default instantiation of
|
||||
the configuration class is used.
|
||||
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
|
||||
that they will be passed with a call to `load_state_dict` before the policy is used.
|
||||
"""
|
||||
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
|
||||
self.language_tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
|
||||
self.model = PI0FlowMatching(config)
|
||||
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
"""This should be called whenever the environment is reset."""
|
||||
self._action_queue = deque([], maxlen=self.config.n_action_steps)
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
return self.parameters()
|
||||
|
||||
@torch.no_grad
|
||||
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
||||
"""Select a single action given environment observations.
|
||||
|
||||
This method wraps `select_actions` in order to return one action at a time for execution in the
|
||||
environment. It works by managing the actions in a queue and only calling `select_actions` when the
|
||||
queue is empty.
|
||||
"""
|
||||
self.eval()
|
||||
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
|
||||
# Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
|
||||
# querying the policy.
|
||||
if len(self._action_queue) == 0:
|
||||
images, img_masks = self.prepare_images(batch)
|
||||
state = self.prepare_state(batch)
|
||||
lang_tokens, lang_masks = self.prepare_language(batch)
|
||||
|
||||
actions = self.model.sample_actions(
|
||||
images, img_masks, lang_tokens, lang_masks, state, noise=noise
|
||||
)
|
||||
|
||||
# Unpad actions
|
||||
original_action_dim = self.config.action_feature.shape[0]
|
||||
actions = actions[:, :, :original_action_dim]
|
||||
|
||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
actions = self._pi_aloha_encode_actions(actions)
|
||||
|
||||
# `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
|
||||
# effectively has shape (n_action_steps, batch_size, *), hence the transpose.
|
||||
self._action_queue.extend(actions.transpose(0, 1))
|
||||
return self._action_queue.popleft()
|
||||
|
||||
def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> dict[str, Tensor]:
|
||||
"""Do a full training forward pass to compute the loss"""
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
|
||||
batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = self.normalize_targets(batch)
|
||||
|
||||
images, img_masks = self.prepare_images(batch)
|
||||
state = self.prepare_state(batch)
|
||||
lang_tokens, lang_masks = self.prepare_language(batch)
|
||||
actions = self.prepare_action(batch)
|
||||
actions_is_pad = batch.get("actions_id_pad")
|
||||
|
||||
loss_dict = {}
|
||||
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time)
|
||||
loss_dict["losses_after_forward"] = losses.clone()
|
||||
|
||||
if actions_is_pad is not None:
|
||||
in_episode_bound = ~actions_is_pad
|
||||
losses = losses * in_episode_bound.unsqueeze(-1)
|
||||
loss_dict["losses_after_in_ep_bound"] = losses.clone()
|
||||
|
||||
# Remove padding
|
||||
losses = losses[:, :, : self.config.max_action_dim]
|
||||
loss_dict["losses_after_rm_padding"] = losses.clone()
|
||||
|
||||
loss = losses.mean()
|
||||
# For backward pass
|
||||
loss_dict["loss"] = loss
|
||||
# For logging
|
||||
loss_dict["l2_loss"] = loss.item()
|
||||
return loss_dict
|
||||
|
||||
def prepare_images(self, batch):
|
||||
"""Apply Pi0 preprocessing to the images, like resizing to 224x224 and padding to keep aspect ratio, and
|
||||
convert pixel range from [0.0, 1.0] to [-1.0, 1.0] as requested by SigLIP.
|
||||
"""
|
||||
images = []
|
||||
img_masks = []
|
||||
|
||||
present_img_keys = [key for key in self.config.image_features if key in batch]
|
||||
missing_img_keys = [key for key in self.config.image_features if key not in batch]
|
||||
|
||||
if len(present_img_keys) == 0:
|
||||
raise ValueError(
|
||||
f"All image features are missing from the batch. At least one expected. (batch: {batch.keys()}) (image_features:{self.config.image_features})"
|
||||
)
|
||||
|
||||
# Preprocess image features present in the batch
|
||||
for key in present_img_keys:
|
||||
img = batch[key]
|
||||
|
||||
if self.config.resize_imgs_with_padding is not None:
|
||||
img = resize_with_pad(img, *self.config.resize_imgs_with_padding, pad_value=0)
|
||||
|
||||
# Normalize from range [0,1] to [-1,1] as expacted by siglip
|
||||
img = img * 2.0 - 1.0
|
||||
|
||||
bsize = img.shape[0]
|
||||
device = img.device
|
||||
mask = torch.ones(bsize, dtype=torch.bool, device=device)
|
||||
images.append(img)
|
||||
img_masks.append(mask)
|
||||
|
||||
# Create image features not present in the batch
|
||||
# as fully 0 padded images.
|
||||
for num_empty_cameras in range(len(missing_img_keys)):
|
||||
if num_empty_cameras >= self.config.empty_cameras:
|
||||
break
|
||||
img = torch.ones_like(img) * -1
|
||||
mask = torch.zeros_like(mask)
|
||||
images.append(img)
|
||||
img_masks.append(mask)
|
||||
|
||||
return images, img_masks
|
||||
|
||||
def prepare_language(self, batch) -> tuple[Tensor, Tensor]:
|
||||
"""Tokenize the text input"""
|
||||
device = batch[OBS_ROBOT].device
|
||||
tasks = batch["task"]
|
||||
|
||||
# PaliGemma prompt has to end with a new line
|
||||
tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks]
|
||||
|
||||
tokenized_prompt = self.language_tokenizer.__call__(
|
||||
tasks,
|
||||
padding="max_length",
|
||||
padding_side="right",
|
||||
max_length=self.config.tokenizer_max_length,
|
||||
return_tensors="pt",
|
||||
)
|
||||
lang_tokens = tokenized_prompt["input_ids"].to(device=device)
|
||||
lang_masks = tokenized_prompt["attention_mask"].to(device=device, dtype=torch.bool)
|
||||
|
||||
return lang_tokens, lang_masks
|
||||
|
||||
def _pi_aloha_decode_state(self, state):
|
||||
# Flip the joints.
|
||||
for motor_idx in [1, 2, 8, 9]:
|
||||
state[:, motor_idx] *= -1
|
||||
# Reverse the gripper transformation that is being applied by the Aloha runtime.
|
||||
for motor_idx in [6, 13]:
|
||||
state[:, motor_idx] = aloha_gripper_to_angular(state[:, motor_idx])
|
||||
return state
|
||||
|
||||
def _pi_aloha_encode_actions(self, actions):
|
||||
# Flip the joints.
|
||||
for motor_idx in [1, 2, 8, 9]:
|
||||
actions[:, :, motor_idx] *= -1
|
||||
# Reverse the gripper transformation that is being applied by the Aloha runtime.
|
||||
for motor_idx in [6, 13]:
|
||||
actions[:, :, motor_idx] = aloha_gripper_from_angular(actions[:, :, motor_idx])
|
||||
return actions
|
||||
|
||||
def _pi_aloha_encode_actions_inv(self, actions):
|
||||
# Flip the joints again.
|
||||
for motor_idx in [1, 2, 8, 9]:
|
||||
actions[:, :, motor_idx] *= -1
|
||||
# Reverse the gripper transformation that is being applied by the Aloha runtime.
|
||||
for motor_idx in [6, 13]:
|
||||
actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx])
|
||||
return actions
|
||||
|
||||
def prepare_state(self, batch):
|
||||
"""Pad state"""
|
||||
state = pad_vector(batch[OBS_ROBOT], self.config.max_state_dim)
|
||||
return state
|
||||
|
||||
def prepare_action(self, batch):
|
||||
"""Pad action"""
|
||||
actions = pad_vector(batch[ACTION], self.config.max_action_dim)
|
||||
return actions
|
||||
|
||||
|
||||
class PI0FlowMatching(nn.Module):
|
||||
"""
|
||||
π0: A Vision-Language-Action Flow Model for General Robot Control
|
||||
|
||||
[Paper](https://www.physicalintelligence.company/download/pi0.pdf)
|
||||
[Jax code](https://github.com/Physical-Intelligence/openpi)
|
||||
|
||||
Designed by Physical Intelligence. Ported from Jax by Hugging Face.
|
||||
┌──────────────────────────────┐
|
||||
│ actions │
|
||||
│ ▲ │
|
||||
│ ┌┴─────┐ │
|
||||
│ kv cache │Gemma │ │
|
||||
│ ┌──────────►│Expert│ │
|
||||
│ │ │ │ │
|
||||
│ ┌┴────────┐ │x 10 │ │
|
||||
│ │ │ └▲──▲──┘ │
|
||||
│ │PaliGemma│ │ │ │
|
||||
│ │ │ │ robot state │
|
||||
│ │ │ noise │
|
||||
│ └▲──▲─────┘ │
|
||||
│ │ │ │
|
||||
│ │ image(s) │
|
||||
│ language tokens │
|
||||
└──────────────────────────────┘
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
paligemma_with_export_config = PaliGemmaWithExpertConfig(
|
||||
freeze_vision_encoder=self.config.freeze_vision_encoder,
|
||||
train_expert_only=self.config.train_expert_only,
|
||||
attention_implementation=self.config.attention_implementation,
|
||||
)
|
||||
self.paligemma_with_expert = PaliGemmaWithExpertModel(paligemma_with_export_config)
|
||||
|
||||
# Projections are float32
|
||||
self.state_proj = nn.Linear(self.config.max_state_dim, self.config.proj_width)
|
||||
self.action_in_proj = nn.Linear(self.config.max_action_dim, self.config.proj_width)
|
||||
self.action_out_proj = nn.Linear(self.config.proj_width, self.config.max_action_dim)
|
||||
|
||||
self.action_time_mlp_in = nn.Linear(self.config.proj_width * 2, self.config.proj_width)
|
||||
self.action_time_mlp_out = nn.Linear(self.config.proj_width, self.config.proj_width)
|
||||
|
||||
self.set_requires_grad()
|
||||
|
||||
def set_requires_grad(self):
|
||||
for params in self.state_proj.parameters():
|
||||
params.requires_grad = self.config.train_state_proj
|
||||
|
||||
def sample_noise(self, shape, device):
|
||||
noise = torch.normal(
|
||||
mean=0.0,
|
||||
std=1.0,
|
||||
size=shape,
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
return noise
|
||||
|
||||
def sample_time(self, bsize, device):
|
||||
time_beta = sample_beta(1.5, 1.0, bsize, device)
|
||||
time = time_beta * 0.999 + 0.001
|
||||
return time.to(dtype=torch.float32, device=device)
|
||||
|
||||
def embed_prefix(
|
||||
self, images, img_masks, lang_tokens, lang_masks
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Embed images with SigLIP and language tokens with embedding layer to prepare
|
||||
for PaliGemma transformer processing.
|
||||
"""
|
||||
# TODO: avoid list in python and torch.cat ; prefer pre-allocation with torch.empty
|
||||
embs = []
|
||||
pad_masks = []
|
||||
att_masks = []
|
||||
|
||||
# TODO: remove for loop
|
||||
for (
|
||||
img,
|
||||
img_mask,
|
||||
) in zip(images, img_masks, strict=False):
|
||||
img_emb = self.paligemma_with_expert.embed_image(img)
|
||||
img_emb = img_emb.to(dtype=torch.bfloat16)
|
||||
|
||||
# Normalize image embeddings
|
||||
img_emb_dim = img_emb.shape[-1]
|
||||
img_emb = img_emb * torch.tensor(img_emb_dim**0.5, dtype=img_emb.dtype, device=img_emb.device)
|
||||
|
||||
bsize, num_img_embs = img_emb.shape[:2]
|
||||
img_mask = img_mask[:, None].expand(bsize, num_img_embs)
|
||||
|
||||
embs.append(img_emb)
|
||||
pad_masks.append(img_mask)
|
||||
|
||||
# Create attention masks so that image tokens attend to each other
|
||||
att_masks += [0] * num_img_embs
|
||||
|
||||
lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens)
|
||||
|
||||
# Normalize language embeddings
|
||||
lang_emb_dim = lang_emb.shape[-1]
|
||||
lang_emb = lang_emb * math.sqrt(lang_emb_dim)
|
||||
|
||||
embs.append(lang_emb)
|
||||
pad_masks.append(lang_masks)
|
||||
|
||||
# full attention between image and language inputs
|
||||
num_lang_embs = lang_emb.shape[1]
|
||||
att_masks += [0] * num_lang_embs
|
||||
|
||||
embs = torch.cat(embs, dim=1)
|
||||
pad_masks = torch.cat(pad_masks, dim=1)
|
||||
att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device)
|
||||
att_masks = att_masks[None, :].expand(bsize, len(att_masks))
|
||||
|
||||
return embs, pad_masks, att_masks
|
||||
|
||||
def embed_suffix(self, state, noisy_actions, timestep):
|
||||
"""Embed state, noisy_actions, timestep to prepare for Expert Gemma processing."""
|
||||
embs = []
|
||||
pad_masks = []
|
||||
att_masks = []
|
||||
|
||||
# Embed state
|
||||
state_emb = self.state_proj(state)
|
||||
state_emb = state_emb.to(dtype=torch.bfloat16)
|
||||
embs.append(state_emb[:, None, :])
|
||||
bsize = state_emb.shape[0]
|
||||
dtype = state_emb.dtype
|
||||
device = state_emb.device
|
||||
|
||||
state_mask = torch.ones(bsize, 1, dtype=torch.bool, device=device)
|
||||
pad_masks.append(state_mask)
|
||||
|
||||
# Set attention masks so that image and language inputs do not attend to state or actions
|
||||
att_masks += [1]
|
||||
|
||||
# Embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1]
|
||||
time_emb = create_sinusoidal_pos_embedding(
|
||||
timestep, self.config.proj_width, min_period=4e-3, max_period=4.0, device=device
|
||||
)
|
||||
time_emb = time_emb.type(dtype=dtype)
|
||||
|
||||
# Fuse timestep + action information using an MLP
|
||||
action_emb = self.action_in_proj(noisy_actions)
|
||||
|
||||
time_emb = time_emb[:, None, :].expand_as(action_emb)
|
||||
action_time_emb = torch.cat([action_emb, time_emb], dim=2)
|
||||
|
||||
action_time_emb = self.action_time_mlp_in(action_time_emb)
|
||||
action_time_emb = F.silu(action_time_emb) # swish == silu
|
||||
action_time_emb = self.action_time_mlp_out(action_time_emb)
|
||||
|
||||
# Add to input tokens
|
||||
embs.append(action_time_emb)
|
||||
|
||||
bsize, action_time_dim = action_time_emb.shape[:2]
|
||||
action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=device)
|
||||
pad_masks.append(action_time_mask)
|
||||
|
||||
# Set attention masks so that image, language and state inputs do not attend to action tokens
|
||||
att_masks += [1] + ([0] * (self.config.n_action_steps - 1))
|
||||
|
||||
embs = torch.cat(embs, dim=1)
|
||||
pad_masks = torch.cat(pad_masks, dim=1)
|
||||
att_masks = torch.tensor(att_masks, dtype=embs.dtype, device=embs.device)
|
||||
att_masks = att_masks[None, :].expand(bsize, len(att_masks))
|
||||
|
||||
return embs, pad_masks, att_masks
|
||||
|
||||
def forward(
|
||||
self, images, img_masks, lang_tokens, lang_masks, state, actions, noise=None, time=None
|
||||
) -> Tensor:
|
||||
"""Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)"""
|
||||
if noise is None:
|
||||
noise = self.sample_noise(actions.shape, actions.device)
|
||||
|
||||
if time is None:
|
||||
time = self.sample_time(actions.shape[0], actions.device)
|
||||
|
||||
time_expanded = time[:, None, None]
|
||||
x_t = time_expanded * noise + (1 - time_expanded) * actions
|
||||
u_t = noise - actions
|
||||
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
|
||||
images, img_masks, lang_tokens, lang_masks
|
||||
)
|
||||
suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(state, x_t, time)
|
||||
|
||||
pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
|
||||
att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)
|
||||
|
||||
att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
|
||||
position_ids = torch.cumsum(pad_masks, dim=1) - 1
|
||||
|
||||
(_, suffix_out), _ = self.paligemma_with_expert.forward(
|
||||
attention_mask=att_2d_masks,
|
||||
position_ids=position_ids,
|
||||
past_key_values=None,
|
||||
inputs_embeds=[prefix_embs, suffix_embs],
|
||||
use_cache=False,
|
||||
fill_kv_cache=False,
|
||||
)
|
||||
suffix_out = suffix_out[:, -self.config.n_action_steps :]
|
||||
# Original openpi code, upcast attention output
|
||||
suffix_out = suffix_out.to(dtype=torch.float32)
|
||||
v_t = self.action_out_proj(suffix_out)
|
||||
|
||||
losses = F.mse_loss(u_t, v_t, reduction="none")
|
||||
return losses
|
||||
|
||||
def sample_actions(self, images, img_masks, lang_tokens, lang_masks, state, noise=None) -> Tensor:
|
||||
"""Do a full inference forward and compute the action (batch_size x num_steps x num_motors)"""
|
||||
bsize = state.shape[0]
|
||||
device = state.device
|
||||
|
||||
if noise is None:
|
||||
actions_shape = (bsize, self.config.n_action_steps, self.config.max_action_dim)
|
||||
noise = self.sample_noise(actions_shape, device)
|
||||
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
|
||||
images, img_masks, lang_tokens, lang_masks
|
||||
)
|
||||
prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
|
||||
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
|
||||
|
||||
# Compute image and language key value cache
|
||||
_, past_key_values = self.paligemma_with_expert.forward(
|
||||
attention_mask=prefix_att_2d_masks,
|
||||
position_ids=prefix_position_ids,
|
||||
past_key_values=None,
|
||||
inputs_embeds=[prefix_embs, None],
|
||||
use_cache=self.config.use_cache,
|
||||
fill_kv_cache=True,
|
||||
)
|
||||
|
||||
dt = -1.0 / self.config.num_steps
|
||||
dt = torch.tensor(dt, dtype=torch.float32, device=device)
|
||||
|
||||
x_t = noise
|
||||
time = torch.tensor(1.0, dtype=torch.float32, device=device)
|
||||
while time >= -dt / 2:
|
||||
expanded_time = time.expand(bsize)
|
||||
v_t = self.denoise_step(
|
||||
state,
|
||||
prefix_pad_masks,
|
||||
past_key_values,
|
||||
x_t,
|
||||
expanded_time,
|
||||
)
|
||||
|
||||
# Euler step
|
||||
x_t += dt * v_t
|
||||
time += dt
|
||||
return x_t
|
||||
|
||||
def denoise_step(
|
||||
self,
|
||||
state,
|
||||
prefix_pad_masks,
|
||||
past_key_values,
|
||||
x_t,
|
||||
timestep,
|
||||
):
|
||||
"""Apply one denoising step of the noise `x_t` at a given timestep."""
|
||||
suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(state, x_t, timestep)
|
||||
|
||||
suffix_len = suffix_pad_masks.shape[1]
|
||||
batch_size = prefix_pad_masks.shape[0]
|
||||
prefix_len = prefix_pad_masks.shape[1]
|
||||
prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(batch_size, suffix_len, prefix_len)
|
||||
|
||||
suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks)
|
||||
|
||||
full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2)
|
||||
|
||||
prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None]
|
||||
position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1
|
||||
|
||||
outputs_embeds, _ = self.paligemma_with_expert.forward(
|
||||
attention_mask=full_att_2d_masks,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=[None, suffix_embs],
|
||||
use_cache=self.config.use_cache,
|
||||
fill_kv_cache=False,
|
||||
)
|
||||
suffix_out = outputs_embeds[1]
|
||||
suffix_out = suffix_out[:, -self.config.n_action_steps :]
|
||||
suffix_out = suffix_out.to(dtype=torch.float32)
|
||||
v_t = self.action_out_proj(suffix_out)
|
||||
return v_t
|
|
@ -0,0 +1,403 @@
|
|||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.version
|
||||
from pytest import Cache
|
||||
from torch import nn
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
GemmaForCausalLM,
|
||||
PaliGemmaForConditionalGeneration,
|
||||
PretrainedConfig,
|
||||
PreTrainedModel,
|
||||
)
|
||||
from transformers.models.auto import CONFIG_MAPPING
|
||||
|
||||
from lerobot.common.policies.pi0.flex_attention import flex_attention_forward
|
||||
|
||||
|
||||
def apply_rope(x, positions, max_wavelength=10_000):
|
||||
"""
|
||||
Applies RoPE positions [B, L] to x [B, L, H, D].
|
||||
"""
|
||||
d_half = x.shape[-1] // 2
|
||||
device = x.device
|
||||
dtype = x.dtype
|
||||
x = x.to(torch.float32)
|
||||
|
||||
freq_exponents = (2.0 / x.shape[-1]) * torch.arange(d_half, dtype=torch.float32, device=device)
|
||||
timescale = max_wavelength**freq_exponents
|
||||
radians = positions[..., None].to(torch.float32) / timescale[None, None, :].to(torch.float32)
|
||||
|
||||
radians = radians[..., None, :]
|
||||
|
||||
sin = torch.sin(radians) # .to(dtype=dtype)
|
||||
cos = torch.cos(radians) # .to(dtype=dtype)
|
||||
|
||||
x1, x2 = x.split(d_half, dim=-1)
|
||||
res = torch.empty_like(x)
|
||||
res[..., :d_half] = x1 * cos - x2 * sin
|
||||
res[..., d_half:] = x2 * cos + x1 * sin
|
||||
|
||||
return res.to(dtype)
|
||||
|
||||
|
||||
class PaliGemmaWithExpertConfig(PretrainedConfig):
|
||||
model_type = "PaliGemmaWithExpertModel"
|
||||
sub_configs = {"paligemma_config": AutoConfig, "gemma_expert_config": AutoConfig}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
paligemma_config: dict | None = None,
|
||||
gemma_expert_config: dict | None = None,
|
||||
freeze_vision_encoder: bool = True,
|
||||
train_expert_only: bool = True,
|
||||
attention_implementation: str = "eager",
|
||||
**kwargs,
|
||||
):
|
||||
self.freeze_vision_encoder = freeze_vision_encoder
|
||||
self.train_expert_only = train_expert_only
|
||||
self.attention_implementation = attention_implementation
|
||||
|
||||
if paligemma_config is None:
|
||||
# Default config from Pi0
|
||||
self.paligemma_config = CONFIG_MAPPING["paligemma"](
|
||||
transformers_version="4.48.1",
|
||||
_vocab_size=257152,
|
||||
bos_token_id=2,
|
||||
eos_token_id=1,
|
||||
hidden_size=2048,
|
||||
image_token_index=257152,
|
||||
model_type="paligemma",
|
||||
pad_token_id=0,
|
||||
projection_dim=2048,
|
||||
text_config={
|
||||
"hidden_activation": "gelu_pytorch_tanh",
|
||||
"hidden_size": 2048,
|
||||
"intermediate_size": 16384,
|
||||
"model_type": "gemma",
|
||||
"num_attention_heads": 8,
|
||||
"num_hidden_layers": 18,
|
||||
"num_image_tokens": 256,
|
||||
"num_key_value_heads": 1,
|
||||
"torch_dtype": "float32",
|
||||
"vocab_size": 257152,
|
||||
},
|
||||
vision_config={
|
||||
"hidden_size": 1152,
|
||||
"intermediate_size": 4304,
|
||||
"model_type": "siglip_vision_model",
|
||||
"num_attention_heads": 16,
|
||||
"num_hidden_layers": 27,
|
||||
"num_image_tokens": 256,
|
||||
"patch_size": 14,
|
||||
"projection_dim": 2048,
|
||||
"projector_hidden_act": "gelu_fast",
|
||||
"torch_dtype": "float32",
|
||||
"vision_use_head": False,
|
||||
},
|
||||
)
|
||||
elif isinstance(self.paligemma_config, dict):
|
||||
# Override Pi0 default config for PaliGemma
|
||||
if "model_type" not in gemma_expert_config:
|
||||
paligemma_config["model_type"] = "paligemma"
|
||||
|
||||
cfg_cls = CONFIG_MAPPING[paligemma_config["model_type"]]
|
||||
self.paligemma_config = cfg_cls(**paligemma_config)
|
||||
|
||||
if gemma_expert_config is None:
|
||||
# Default config from Pi0
|
||||
self.gemma_expert_config = CONFIG_MAPPING["gemma"](
|
||||
attention_bias=False,
|
||||
attention_dropout=0.0,
|
||||
bos_token_id=2,
|
||||
eos_token_id=1,
|
||||
head_dim=256,
|
||||
hidden_act="gelu_pytorch_tanh",
|
||||
hidden_activation="gelu_pytorch_tanh",
|
||||
hidden_size=1024,
|
||||
initializer_range=0.02,
|
||||
intermediate_size=4096,
|
||||
max_position_embeddings=8192,
|
||||
model_type="gemma",
|
||||
num_attention_heads=8,
|
||||
num_hidden_layers=18,
|
||||
num_key_value_heads=1,
|
||||
pad_token_id=0,
|
||||
rms_norm_eps=1e-06,
|
||||
rope_theta=10000.0,
|
||||
torch_dtype="float32",
|
||||
transformers_version="4.48.1",
|
||||
use_cache=True,
|
||||
vocab_size=257152,
|
||||
)
|
||||
elif isinstance(self.gemma_expert_config, dict):
|
||||
# Override Pi0 default config for Gemma Expert
|
||||
if "model_type" not in gemma_expert_config:
|
||||
gemma_expert_config["model_type"] = "gemma"
|
||||
|
||||
cfg_cls = CONFIG_MAPPING[paligemma_config["model_type"]]
|
||||
self.gemma_expert_config = cfg_cls(**gemma_expert_config)
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
if self.train_expert_only and not self.freeze_vision_encoder:
|
||||
raise ValueError(
|
||||
"You set `freeze_vision_encoder=False` and `train_expert_only=True` which are not compatible."
|
||||
)
|
||||
|
||||
if self.attention_implementation not in ["eager", "fa2", "flex"]:
|
||||
raise ValueError(
|
||||
f"Wrong value provided for `attention_implementation` ({self.attention_implementation}). Expected 'eager', 'fa2' or 'flex'."
|
||||
)
|
||||
|
||||
|
||||
class PaliGemmaWithExpertModel(PreTrainedModel):
|
||||
config_class = PaliGemmaWithExpertConfig
|
||||
|
||||
def __init__(self, config: PaliGemmaWithExpertConfig):
|
||||
super().__init__(config=config)
|
||||
self.config = config
|
||||
self.paligemma = PaliGemmaForConditionalGeneration(config=config.paligemma_config)
|
||||
self.gemma_expert = GemmaForCausalLM(config=config.gemma_expert_config)
|
||||
# Remove unused embed_tokens
|
||||
self.gemma_expert.model.embed_tokens = None
|
||||
|
||||
self.to_bfloat16_like_physical_intelligence()
|
||||
self.set_requires_grad()
|
||||
|
||||
def set_requires_grad(self):
|
||||
if self.config.freeze_vision_encoder:
|
||||
self.paligemma.vision_tower.eval()
|
||||
for params in self.paligemma.vision_tower.parameters():
|
||||
params.requires_grad = False
|
||||
|
||||
if self.config.train_expert_only:
|
||||
self.paligemma.eval()
|
||||
for params in self.paligemma.parameters():
|
||||
params.requires_grad = False
|
||||
|
||||
def train(self, mode: bool = True):
|
||||
super().train(mode)
|
||||
|
||||
if self.config.freeze_vision_encoder:
|
||||
self.paligemma.vision_tower.eval()
|
||||
|
||||
if self.config.train_expert_only:
|
||||
self.paligemma.eval()
|
||||
|
||||
def to_bfloat16_like_physical_intelligence(self):
|
||||
self.paligemma = self.paligemma.to(dtype=torch.bfloat16)
|
||||
|
||||
params_to_change_dtype = [
|
||||
"language_model.model.layers",
|
||||
"gemma_expert.model.layers",
|
||||
"vision_tower",
|
||||
"multi_modal",
|
||||
]
|
||||
for name, param in self.named_parameters():
|
||||
if any(selector in name for selector in params_to_change_dtype):
|
||||
param.data = param.data.to(dtype=torch.bfloat16)
|
||||
|
||||
def embed_image(self, image: torch.Tensor):
|
||||
return self.paligemma.get_image_features(image)
|
||||
|
||||
def embed_language_tokens(self, tokens: torch.Tensor):
|
||||
return self.paligemma.language_model.model.embed_tokens(tokens)
|
||||
|
||||
# TODO: break down this huge forward into modules or functions
|
||||
def forward(
|
||||
self,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
|
||||
inputs_embeds: List[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
fill_kv_cache: Optional[bool] = None,
|
||||
):
|
||||
models = [self.paligemma.language_model.model, self.gemma_expert.model]
|
||||
|
||||
for hidden_states in inputs_embeds:
|
||||
# TODO this is very inefficient
|
||||
# dtype is always the same, batch size too (if > 1 len)
|
||||
# device could be trickier in multi gpu edge cases but that's it
|
||||
if hidden_states is None:
|
||||
continue
|
||||
batch_size = hidden_states.shape[0]
|
||||
|
||||
# RMSNorm
|
||||
num_layers = self.paligemma.config.text_config.num_hidden_layers
|
||||
head_dim = self.paligemma.config.text_config.head_dim
|
||||
for layer_idx in range(num_layers):
|
||||
query_states = []
|
||||
key_states = []
|
||||
value_states = []
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
if hidden_states is None:
|
||||
continue
|
||||
layer = models[i].layers[layer_idx]
|
||||
# normalizer = torch.tensor(models[i].config.hidden_size**0.5, dtype=hidden_states.dtype)
|
||||
# hidden_states = hidden_states * normalizer
|
||||
hidden_states = layer.input_layernorm(hidden_states)
|
||||
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
|
||||
|
||||
hidden_states = hidden_states.to(dtype=torch.bfloat16)
|
||||
query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape)
|
||||
key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape)
|
||||
value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape)
|
||||
|
||||
query_states.append(query_state)
|
||||
key_states.append(key_state)
|
||||
value_states.append(value_state)
|
||||
|
||||
# B,L,H,D with L sequence length, H number of heads, D head dim
|
||||
# concatenate on the number of embeddings/tokens
|
||||
query_states = torch.cat(query_states, dim=1)
|
||||
key_states = torch.cat(key_states, dim=1)
|
||||
value_states = torch.cat(value_states, dim=1)
|
||||
|
||||
query_states = apply_rope(query_states, position_ids)
|
||||
key_states = apply_rope(key_states, position_ids)
|
||||
|
||||
if use_cache and past_key_values is None:
|
||||
past_key_values = {}
|
||||
|
||||
if use_cache:
|
||||
if fill_kv_cache:
|
||||
past_key_values[layer_idx] = {
|
||||
"key_states": key_states,
|
||||
"value_states": value_states,
|
||||
}
|
||||
else:
|
||||
# TODO here, some optimization can be done - similar to a `StaticCache` we can declare the `max_len` before.
|
||||
# so we create an empty cache, with just one cuda malloc, and if (in autoregressive case) we reach
|
||||
# the max len, then we (for instance) double the cache size. This implementation already exists
|
||||
# in `transformers`. (molbap)
|
||||
key_states = torch.cat([past_key_values[layer_idx]["key_states"], key_states], dim=1)
|
||||
value_states = torch.cat(
|
||||
[past_key_values[layer_idx]["value_states"], value_states], dim=1
|
||||
)
|
||||
|
||||
attention_interface = self.get_attention_interface()
|
||||
att_output = attention_interface(
|
||||
attention_mask, batch_size, head_dim, query_states, key_states, value_states
|
||||
)
|
||||
att_output = att_output.to(dtype=torch.bfloat16)
|
||||
|
||||
# first part of att_output is prefix (up to sequence length, [:, 0:prefix_seq_len])
|
||||
outputs_embeds = []
|
||||
start = 0
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
layer = models[i].layers[layer_idx]
|
||||
|
||||
if hidden_states is not None:
|
||||
end = start + hidden_states.shape[1]
|
||||
|
||||
if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
|
||||
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
|
||||
out_emb = layer.self_attn.o_proj(att_output[:, start:end])
|
||||
|
||||
# TODO: first dropout (by default 0.0)
|
||||
|
||||
# first residual
|
||||
out_emb += hidden_states
|
||||
after_first_residual = out_emb.clone()
|
||||
|
||||
out_emb = layer.post_attention_layernorm(out_emb)
|
||||
out_emb = layer.mlp(out_emb)
|
||||
|
||||
# TODO: second dropout (by default 0.0)
|
||||
|
||||
# second residual
|
||||
out_emb += after_first_residual
|
||||
|
||||
outputs_embeds.append(out_emb)
|
||||
|
||||
start = end
|
||||
else:
|
||||
outputs_embeds.append(None)
|
||||
|
||||
inputs_embeds = outputs_embeds
|
||||
|
||||
# final norm
|
||||
outputs_embeds = []
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
if hidden_states is not None:
|
||||
out_emb = models[i].norm(hidden_states)
|
||||
outputs_embeds.append(out_emb)
|
||||
else:
|
||||
outputs_embeds.append(None)
|
||||
|
||||
return outputs_embeds, past_key_values
|
||||
|
||||
def get_attention_interface(self):
|
||||
if self.config.attention_implementation == "fa2":
|
||||
attention_interface = self.flash_attention_forward
|
||||
elif self.config.attention_implementation == "flex":
|
||||
attention_interface = flex_attention_forward
|
||||
else:
|
||||
attention_interface = self.eager_attention_forward
|
||||
return attention_interface
|
||||
|
||||
def flash_attention_forward(
|
||||
self, attention_mask, batch_size, head_dim, query_states, key_states, value_states
|
||||
):
|
||||
raise NotImplementedError("FA2 is not implemented (yet)")
|
||||
|
||||
def eager_attention_forward(
|
||||
self, attention_mask, batch_size, head_dim, query_states, key_states, value_states
|
||||
):
|
||||
num_att_heads = self.config.paligemma_config.text_config.num_attention_heads
|
||||
num_key_value_heads = self.config.paligemma_config.text_config.num_key_value_heads
|
||||
num_key_value_groups = num_att_heads // num_key_value_heads
|
||||
|
||||
# query_states: batch_size, sequence_length, num_att_head, head_dim
|
||||
# key_states: batch_size, sequence_length, num_key_value_head, head_dim
|
||||
# value_states: batch_size, sequence_length, num_key_value_head, head_dim
|
||||
sequence_length = key_states.shape[1]
|
||||
|
||||
key_states = key_states[:, :, :, None, :].expand(
|
||||
batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim
|
||||
)
|
||||
key_states = key_states.reshape(
|
||||
batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim
|
||||
)
|
||||
|
||||
value_states = value_states[:, :, :, None, :].expand(
|
||||
batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim
|
||||
)
|
||||
value_states = value_states.reshape(
|
||||
batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim
|
||||
)
|
||||
|
||||
# Attention here is upcasted to float32 to match the original eager implementation.
|
||||
|
||||
query_states = query_states.to(dtype=torch.float32)
|
||||
key_states = key_states.to(dtype=torch.float32)
|
||||
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
|
||||
att_weights = torch.matmul(query_states, key_states.transpose(2, 3))
|
||||
att_weights *= head_dim**-0.5
|
||||
big_neg = -2.3819763e38 # See gemma/modules.py
|
||||
|
||||
masked_att_weights = torch.where(attention_mask[:, None, :, :], att_weights, big_neg)
|
||||
|
||||
probs = nn.functional.softmax(masked_att_weights, dim=-1)
|
||||
probs = probs.to(dtype=value_states.dtype)
|
||||
|
||||
# probs: batch_size, num_key_value_head, num_att_head, sequence_length, sequence_length
|
||||
# value_states: batch_size, sequence_length, num_att_heads, head_dim
|
||||
|
||||
att_output = torch.matmul(probs, value_states.permute(0, 2, 1, 3))
|
||||
|
||||
att_output = att_output.permute(0, 2, 1, 3)
|
||||
# we use -1 because sequence length can change
|
||||
att_output = att_output.reshape(batch_size, -1, num_key_value_heads * num_key_value_groups * head_dim)
|
||||
|
||||
return att_output
|
|
@ -319,7 +319,7 @@ class TDMPCPolicy(PreTrainedPolicy):
|
|||
|
||||
# (b, t) -> (t, b)
|
||||
for key in batch:
|
||||
if batch[key].ndim > 1:
|
||||
if isinstance(batch[key], torch.Tensor) and batch[key].ndim > 1:
|
||||
batch[key] = batch[key].transpose(1, 0)
|
||||
|
||||
action = batch["action"] # (t, b, action_dim)
|
||||
|
@ -502,7 +502,7 @@ class TDMPCPolicy(PreTrainedPolicy):
|
|||
|
||||
# Undo (b, t) -> (t, b).
|
||||
for key in batch:
|
||||
if batch[key].ndim > 1:
|
||||
if isinstance(batch[key], torch.Tensor) and batch[key].ndim > 1:
|
||||
batch[key] = batch[key].transpose(1, 0)
|
||||
|
||||
return info
|
||||
|
|
|
@ -74,6 +74,18 @@ def get_safe_torch_device(try_device: str, log: bool = False) -> torch.device:
|
|||
return device
|
||||
|
||||
|
||||
def get_safe_dtype(dtype: torch.dtype, device: str | torch.device):
|
||||
"""
|
||||
mps is currently not compatible with float64
|
||||
"""
|
||||
if isinstance(device, torch.device):
|
||||
device = device.type
|
||||
if device == "mps" and dtype == torch.float64:
|
||||
return torch.float32
|
||||
else:
|
||||
return dtype
|
||||
|
||||
|
||||
def is_torch_device_available(try_device: str) -> bool:
|
||||
if try_device == "cuda":
|
||||
return torch.cuda.is_available()
|
||||
|
|
|
@ -125,6 +125,9 @@ def rollout(
|
|||
# Reset the policy and environments.
|
||||
policy.reset()
|
||||
|
||||
if hasattr(policy, "use_ema_modules"):
|
||||
policy.use_ema_modules()
|
||||
|
||||
observation, info = env.reset(seed=seeds)
|
||||
if render_callback is not None:
|
||||
render_callback(env)
|
||||
|
@ -205,6 +208,9 @@ def rollout(
|
|||
stacked_observations[key] = torch.stack([obs[key] for obs in all_observations], dim=1)
|
||||
ret["observation"] = stacked_observations
|
||||
|
||||
if hasattr(policy, "use_original_modules"):
|
||||
policy.use_original_modules()
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
|
@ -235,7 +241,9 @@ def eval_policy(
|
|||
raise ValueError("If max_episodes_rendered > 0, videos_dir must be provided.")
|
||||
|
||||
if not isinstance(policy, PreTrainedPolicy):
|
||||
raise ValueError(policy)
|
||||
raise ValueError(
|
||||
f"Policy of type 'PreTrainedPolicy' is expected, but type '{type(policy)}' was provided."
|
||||
)
|
||||
|
||||
start = time.time()
|
||||
policy.eval()
|
||||
|
|
|
@ -38,6 +38,7 @@ from lerobot.common.policies.factory import make_policy
|
|||
from lerobot.common.policies.utils import get_device_from_parameters
|
||||
from lerobot.common.utils.utils import (
|
||||
format_big_number,
|
||||
get_safe_dtype,
|
||||
get_safe_torch_device,
|
||||
has_method,
|
||||
init_logging,
|
||||
|
@ -86,6 +87,10 @@ def update_policy(
|
|||
|
||||
optimizer.zero_grad()
|
||||
|
||||
if hasattr(policy, "update_ema_modules"):
|
||||
policy.update_ema_modules()
|
||||
|
||||
# Step through pytorch scheduler at every batch instead of epoch
|
||||
if lr_scheduler is not None:
|
||||
lr_scheduler.step()
|
||||
|
||||
|
@ -215,6 +220,7 @@ def train(cfg: TrainPipelineConfig):
|
|||
device=device,
|
||||
ds_meta=offline_dataset.meta,
|
||||
)
|
||||
|
||||
logging.info("Creating optimizer and scheduler")
|
||||
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
|
||||
grad_scaler = GradScaler(device, enabled=cfg.use_amp)
|
||||
|
@ -296,6 +302,10 @@ def train(cfg: TrainPipelineConfig):
|
|||
dl_iter = cycle(dataloader)
|
||||
|
||||
policy.train()
|
||||
|
||||
if hasattr(policy, "init_ema_modules"):
|
||||
policy.init_ema_modules()
|
||||
|
||||
offline_step = 0
|
||||
for _ in range(step, cfg.offline.steps):
|
||||
if offline_step == 0:
|
||||
|
@ -306,7 +316,8 @@ def train(cfg: TrainPipelineConfig):
|
|||
dataloading_s = time.perf_counter() - start_time
|
||||
|
||||
for key in batch:
|
||||
batch[key] = batch[key].to(device, non_blocking=True)
|
||||
if isinstance(batch[key], torch.Tensor):
|
||||
batch[key] = batch[key].to(device, non_blocking=True)
|
||||
|
||||
train_info = update_policy(
|
||||
policy,
|
||||
|
@ -365,6 +376,8 @@ def train(cfg: TrainPipelineConfig):
|
|||
"next.reward": {"shape": (), "dtype": np.dtype("float32")},
|
||||
"next.done": {"shape": (), "dtype": np.dtype("?")},
|
||||
"task_index": {"shape": (), "dtype": np.dtype("int64")},
|
||||
# FIXME: 'task' is a string
|
||||
# "task": {"shape": (), "dtype": np.dtype("?")},
|
||||
# FIXME: 'next.success' is expected by pusht env but not xarm
|
||||
"next.success": {"shape": (), "dtype": np.dtype("?")},
|
||||
},
|
||||
|
@ -451,9 +464,10 @@ def train(cfg: TrainPipelineConfig):
|
|||
if len(offline_dataset.meta.tasks) > 1:
|
||||
raise NotImplementedError("Add support for multi task.")
|
||||
|
||||
# Hack to add a task to the online_dataset (0 is the first task of the offline_dataset)
|
||||
# TODO(rcadene, aliberts): Hack to add a task to the online_dataset (0 is the first task of the offline_dataset)
|
||||
total_num_frames = eval_info["episodes"]["index"].shape[0]
|
||||
eval_info["episodes"]["task_index"] = torch.tensor([0] * total_num_frames, dtype=torch.int64)
|
||||
eval_info["episodes"]["task"] = ["do the thing"] * total_num_frames
|
||||
|
||||
with lock if lock is not None else nullcontext():
|
||||
start_update_buffer_time = time.perf_counter()
|
||||
|
@ -499,7 +513,9 @@ def train(cfg: TrainPipelineConfig):
|
|||
dataloading_s = time.perf_counter() - start_time
|
||||
|
||||
for key in batch:
|
||||
batch[key] = batch[key].to(device, non_blocking=True)
|
||||
if isinstance(batch[key], torch.Tensor):
|
||||
dtype = get_safe_dtype(batch[key].dtype, device)
|
||||
batch[key] = batch[key].to(device=device, dtype=dtype, non_blocking=True)
|
||||
|
||||
train_info = update_policy(
|
||||
policy,
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
<script src="https://cdn.tailwindcss.com"></script>
|
||||
<script defer src="https://cdn.jsdelivr.net/npm/alpinejs@3.x.x/dist/cdn.min.js"></script>
|
||||
</head>
|
||||
<body class="h-screen overflow-hidden font-mono text-white" x-data="{
|
||||
<body class="h-screen overflow-hidden font-mono text-white" x-data="{
|
||||
inputValue: '',
|
||||
navigateToDataset() {
|
||||
const trimmedValue = this.inputValue.trim();
|
||||
|
@ -40,14 +40,14 @@
|
|||
</div>
|
||||
</div>
|
||||
<div class="flex w-full max-w-lg px-4 mb-4">
|
||||
<input
|
||||
type="text"
|
||||
<input
|
||||
type="text"
|
||||
x-model="inputValue"
|
||||
@keyup.enter="navigateToDataset"
|
||||
placeholder="enter dataset id (ex: lerobot/droid_100)"
|
||||
class="flex-grow px-4 py-2 rounded-l bg-white bg-opacity-20 text-white placeholder-gray-300 focus:outline-none focus:ring-2 focus:ring-blue-300"
|
||||
>
|
||||
<button
|
||||
<button
|
||||
@click="navigateToDataset"
|
||||
class="px-4 py-2 bg-blue-500 text-white rounded-r hover:bg-blue-600 focus:outline-none focus:ring-2 focus:ring-blue-300"
|
||||
>
|
||||
|
@ -65,4 +65,4 @@
|
|||
</details>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
</html>
|
||||
|
|
|
@ -107,8 +107,8 @@
|
|||
<span class="truncate">filter videos</span>
|
||||
<div class="transition-transform" :class="{ 'rotate-180': isVideosDropdownOpen }">🔽</div>
|
||||
</div>
|
||||
|
||||
<div x-show="isVideosDropdownOpen"
|
||||
|
||||
<div x-show="isVideosDropdownOpen"
|
||||
class="absolute mt-1 border border-slate-500 rounded shadow-lg z-10">
|
||||
<div>
|
||||
<template x-for="option in videosKeys" :key="option">
|
||||
|
|
|
@ -908,36 +908,29 @@ tests = ["pytest", "pytest-cov", "pytest-xdist"]
|
|||
|
||||
[[package]]
|
||||
name = "dash"
|
||||
version = "2.18.2"
|
||||
version = "2.9.3"
|
||||
description = "A Python framework for building reactive web-apps. Developed by Plotly."
|
||||
optional = true
|
||||
python-versions = ">=3.8"
|
||||
python-versions = ">=3.6"
|
||||
files = [
|
||||
{file = "dash-2.18.2-py3-none-any.whl", hash = "sha256:0ce0479d1bc958e934630e2de7023b8a4558f23ce1f9f5a4b34b65eb3903a869"},
|
||||
{file = "dash-2.18.2.tar.gz", hash = "sha256:20e8404f73d0fe88ce2eae33c25bbc513cbe52f30d23a401fa5f24dbb44296c8"},
|
||||
{file = "dash-2.9.3-py3-none-any.whl", hash = "sha256:a749ae1ea9de3fe7b785353a818ec9b629d39c6b7e02462954203bd1e296fd0e"},
|
||||
{file = "dash-2.9.3.tar.gz", hash = "sha256:47392f8d6455dc989a697407eb5941f3bad80604df985ab1ac9d4244568ffb34"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
dash-core-components = "2.0.0"
|
||||
dash-html-components = "2.0.0"
|
||||
dash-table = "5.0.0"
|
||||
Flask = ">=1.0.4,<3.1"
|
||||
importlib-metadata = "*"
|
||||
nest-asyncio = "*"
|
||||
Flask = ">=1.0.4"
|
||||
plotly = ">=5.0.0"
|
||||
requests = "*"
|
||||
retrying = "*"
|
||||
setuptools = "*"
|
||||
typing-extensions = ">=4.1.1"
|
||||
Werkzeug = "<3.1"
|
||||
|
||||
[package.extras]
|
||||
celery = ["celery[redis] (>=5.1.2)", "redis (>=3.5.3)"]
|
||||
ci = ["black (==22.3.0)", "dash-dangerously-set-inner-html", "dash-flow-example (==0.0.5)", "flake8 (==7.0.0)", "flaky (==3.8.1)", "flask-talisman (==1.0.0)", "jupyterlab (<4.0.0)", "mimesis (<=11.1.0)", "mock (==4.0.3)", "numpy (<=1.26.3)", "openpyxl", "orjson (==3.10.3)", "pandas (>=1.4.0)", "pyarrow", "pylint (==3.0.3)", "pytest-mock", "pytest-rerunfailures", "pytest-sugar (==0.9.6)", "pyzmq (==25.1.2)", "xlrd (>=2.0.1)"]
|
||||
celery = ["celery[redis] (>=5.1.2)", "importlib-metadata (<5)", "redis (>=3.5.3)"]
|
||||
ci = ["black (==21.6b0)", "black (==22.3.0)", "dash-dangerously-set-inner-html", "dash-flow-example (==0.0.5)", "flake8 (==3.9.2)", "flaky (==3.7.0)", "flask-talisman (==1.0.0)", "isort (==4.3.21)", "mimesis", "mock (==4.0.3)", "numpy", "openpyxl", "orjson (==3.5.4)", "orjson (==3.6.7)", "pandas (==1.1.5)", "pandas (>=1.4.0)", "preconditions", "pyarrow", "pyarrow (<3)", "pylint (==2.13.5)", "pytest-mock", "pytest-rerunfailures", "pytest-sugar (==0.9.6)", "xlrd (<2)", "xlrd (>=2.0.1)"]
|
||||
compress = ["flask-compress"]
|
||||
dev = ["PyYAML (>=5.4.1)", "coloredlogs (>=15.0.1)", "fire (>=0.4.0)"]
|
||||
diskcache = ["diskcache (>=5.2.1)", "multiprocess (>=0.70.12)", "psutil (>=5.8.0)"]
|
||||
testing = ["beautifulsoup4 (>=4.8.2)", "cryptography", "dash-testing-stub (>=0.0.2)", "lxml (>=4.6.2)", "multiprocess (>=0.70.12)", "percy (>=2.0.2)", "psutil (>=5.8.0)", "pytest (>=6.0.2)", "requests[security] (>=2.21.0)", "selenium (>=3.141.0,<=4.2.0)", "waitress (>=1.4.4)"]
|
||||
testing = ["beautifulsoup4 (>=4.8.2)", "cryptography (<3.4)", "dash-testing-stub (>=0.0.2)", "lxml (>=4.6.2)", "multiprocess (>=0.70.12)", "percy (>=2.0.2)", "psutil (>=5.8.0)", "pytest (>=6.0.2)", "requests[security] (>=2.21.0)", "selenium (>=3.141.0,<=4.2.0)", "waitress (>=1.4.4)"]
|
||||
|
||||
[[package]]
|
||||
name = "dash-core-components"
|
||||
|
@ -1220,9 +1213,9 @@ files = [
|
|||
absl-py = ">=0.6.1"
|
||||
attrs = ">=18.2.0"
|
||||
numpy = [
|
||||
{version = ">=1.26.0", markers = "python_version >= \"3.12\""},
|
||||
{version = ">=1.23.3", markers = "python_version >= \"3.11\" and python_version < \"3.12\""},
|
||||
{version = ">=1.21.2", markers = "python_version >= \"3.10\" and python_version < \"3.11\""},
|
||||
{version = ">=1.23.3", markers = "python_version >= \"3.11\" and python_version < \"3.12\""},
|
||||
{version = ">=1.26.0", markers = "python_version >= \"3.12\""},
|
||||
]
|
||||
wrapt = ">=1.11.2"
|
||||
|
||||
|
@ -1430,21 +1423,21 @@ typing = ["typing-extensions (>=4.12.2)"]
|
|||
|
||||
[[package]]
|
||||
name = "flask"
|
||||
version = "3.0.3"
|
||||
version = "3.1.0"
|
||||
description = "A simple framework for building complex web applications."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
python-versions = ">=3.9"
|
||||
files = [
|
||||
{file = "flask-3.0.3-py3-none-any.whl", hash = "sha256:34e815dfaa43340d1d15a5c3a02b8476004037eb4840b34910c6e21679d288f3"},
|
||||
{file = "flask-3.0.3.tar.gz", hash = "sha256:ceb27b0af3823ea2737928a4d99d125a06175b8512c445cbd9a9ce200ef76842"},
|
||||
{file = "flask-3.1.0-py3-none-any.whl", hash = "sha256:d667207822eb83f1c4b50949b1623c8fc8d51f2341d65f72e1a1815397551136"},
|
||||
{file = "flask-3.1.0.tar.gz", hash = "sha256:5f873c5184c897c8d9d1b05df1e3d01b14910ce69607a117bd3277098a5836ac"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
blinker = ">=1.6.2"
|
||||
blinker = ">=1.9"
|
||||
click = ">=8.1.3"
|
||||
itsdangerous = ">=2.1.2"
|
||||
itsdangerous = ">=2.2"
|
||||
Jinja2 = ">=3.1.2"
|
||||
Werkzeug = ">=3.0.0"
|
||||
Werkzeug = ">=3.1"
|
||||
|
||||
[package.extras]
|
||||
async = ["asgiref (>=3.2)"]
|
||||
|
@ -4248,10 +4241,10 @@ files = [
|
|||
|
||||
[package.dependencies]
|
||||
numpy = [
|
||||
{version = ">=1.26.0", markers = "python_version >= \"3.12\""},
|
||||
{version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""},
|
||||
{version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\" and python_version < \"3.11\""},
|
||||
{version = ">=1.21.2", markers = "platform_system != \"Darwin\" and python_version >= \"3.10\" and python_version < \"3.11\""},
|
||||
{version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""},
|
||||
{version = ">=1.26.0", markers = "python_version >= \"3.12\""},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -4272,10 +4265,10 @@ files = [
|
|||
|
||||
[package.dependencies]
|
||||
numpy = [
|
||||
{version = ">=1.26.0", markers = "python_version >= \"3.12\""},
|
||||
{version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""},
|
||||
{version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\" and python_version < \"3.11\""},
|
||||
{version = ">=1.21.2", markers = "platform_system != \"Darwin\" and python_version >= \"3.10\" and python_version < \"3.11\""},
|
||||
{version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""},
|
||||
{version = ">=1.26.0", markers = "python_version >= \"3.12\""},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -4364,9 +4357,9 @@ files = [
|
|||
|
||||
[package.dependencies]
|
||||
numpy = [
|
||||
{version = ">=1.26.0", markers = "python_version >= \"3.12\""},
|
||||
{version = ">=1.23.2", markers = "python_version == \"3.11\""},
|
||||
{version = ">=1.22.4", markers = "python_version < \"3.11\""},
|
||||
{version = ">=1.23.2", markers = "python_version == \"3.11\""},
|
||||
{version = ">=1.26.0", markers = "python_version >= \"3.12\""},
|
||||
]
|
||||
python-dateutil = ">=2.8.2"
|
||||
pytz = ">=2020.1"
|
||||
|
@ -4918,6 +4911,8 @@ files = [
|
|||
{file = "PyAudio-0.2.14-cp311-cp311-win_amd64.whl", hash = "sha256:bbeb01d36a2f472ae5ee5e1451cacc42112986abe622f735bb870a5db77cf903"},
|
||||
{file = "PyAudio-0.2.14-cp312-cp312-win32.whl", hash = "sha256:5fce4bcdd2e0e8c063d835dbe2860dac46437506af509353c7f8114d4bacbd5b"},
|
||||
{file = "PyAudio-0.2.14-cp312-cp312-win_amd64.whl", hash = "sha256:12f2f1ba04e06ff95d80700a78967897a489c05e093e3bffa05a84ed9c0a7fa3"},
|
||||
{file = "PyAudio-0.2.14-cp313-cp313-win32.whl", hash = "sha256:95328285b4dab57ea8c52a4a996cb52be6d629353315be5bfda403d15932a497"},
|
||||
{file = "PyAudio-0.2.14-cp313-cp313-win_amd64.whl", hash = "sha256:692d8c1446f52ed2662120bcd9ddcb5aa2b71f38bda31e58b19fb4672fffba69"},
|
||||
{file = "PyAudio-0.2.14-cp38-cp38-win32.whl", hash = "sha256:858caf35b05c26d8fc62f1efa2e8f53d5fa1a01164842bd622f70ddc41f55000"},
|
||||
{file = "PyAudio-0.2.14-cp38-cp38-win_amd64.whl", hash = "sha256:2dac0d6d675fe7e181ba88f2de88d321059b69abd52e3f4934a8878e03a7a074"},
|
||||
{file = "PyAudio-0.2.14-cp39-cp39-win32.whl", hash = "sha256:f745109634a7c19fa4d6b8b7d6967c3123d988c9ade0cd35d4295ee1acdb53e9"},
|
||||
|
@ -6083,20 +6078,6 @@ typing-extensions = ">=4.5"
|
|||
notebook = ["rerun-notebook (==0.21.0)"]
|
||||
tests = ["pytest (==7.1.2)"]
|
||||
|
||||
[[package]]
|
||||
name = "retrying"
|
||||
version = "1.3.4"
|
||||
description = "Retrying"
|
||||
optional = true
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "retrying-1.3.4-py3-none-any.whl", hash = "sha256:8cc4d43cb8e1125e0ff3344e9de678fefd85db3b750b81b2240dc0183af37b35"},
|
||||
{file = "retrying-1.3.4.tar.gz", hash = "sha256:345da8c5765bd982b1d1915deb9102fd3d1f7ad16bd84a9700b85f64d24e8f3e"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
six = ">=1.7.0"
|
||||
|
||||
[[package]]
|
||||
name = "rfc3339-validator"
|
||||
version = "0.1.4"
|
||||
|
@ -6949,6 +6930,38 @@ webencodings = ">=0.4"
|
|||
doc = ["sphinx", "sphinx_rtd_theme"]
|
||||
test = ["pytest", "ruff"]
|
||||
|
||||
[[package]]
|
||||
name = "tokenizers"
|
||||
version = "0.21.0"
|
||||
description = ""
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "tokenizers-0.21.0-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:3c4c93eae637e7d2aaae3d376f06085164e1660f89304c0ab2b1d08a406636b2"},
|
||||
{file = "tokenizers-0.21.0-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:f53ea537c925422a2e0e92a24cce96f6bc5046bbef24a1652a5edc8ba975f62e"},
|
||||
{file = "tokenizers-0.21.0-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6b177fb54c4702ef611de0c069d9169f0004233890e0c4c5bd5508ae05abf193"},
|
||||
{file = "tokenizers-0.21.0-cp39-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6b43779a269f4629bebb114e19c3fca0223296ae9fea8bb9a7a6c6fb0657ff8e"},
|
||||
{file = "tokenizers-0.21.0-cp39-abi3-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9aeb255802be90acfd363626753fda0064a8df06031012fe7d52fd9a905eb00e"},
|
||||
{file = "tokenizers-0.21.0-cp39-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d8b09dbeb7a8d73ee204a70f94fc06ea0f17dcf0844f16102b9f414f0b7463ba"},
|
||||
{file = "tokenizers-0.21.0-cp39-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:400832c0904f77ce87c40f1a8a27493071282f785724ae62144324f171377273"},
|
||||
{file = "tokenizers-0.21.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e84ca973b3a96894d1707e189c14a774b701596d579ffc7e69debfc036a61a04"},
|
||||
{file = "tokenizers-0.21.0-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:eb7202d231b273c34ec67767378cd04c767e967fda12d4a9e36208a34e2f137e"},
|
||||
{file = "tokenizers-0.21.0-cp39-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:089d56db6782a73a27fd8abf3ba21779f5b85d4a9f35e3b493c7bbcbbf0d539b"},
|
||||
{file = "tokenizers-0.21.0-cp39-abi3-musllinux_1_2_i686.whl", hash = "sha256:c87ca3dc48b9b1222d984b6b7490355a6fdb411a2d810f6f05977258400ddb74"},
|
||||
{file = "tokenizers-0.21.0-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:4145505a973116f91bc3ac45988a92e618a6f83eb458f49ea0790df94ee243ff"},
|
||||
{file = "tokenizers-0.21.0-cp39-abi3-win32.whl", hash = "sha256:eb1702c2f27d25d9dd5b389cc1f2f51813e99f8ca30d9e25348db6585a97e24a"},
|
||||
{file = "tokenizers-0.21.0-cp39-abi3-win_amd64.whl", hash = "sha256:87841da5a25a3a5f70c102de371db120f41873b854ba65e52bccd57df5a3780c"},
|
||||
{file = "tokenizers-0.21.0.tar.gz", hash = "sha256:ee0894bf311b75b0c03079f33859ae4b2334d675d4e93f5a4132e1eae2834fe4"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
huggingface-hub = ">=0.16.4,<1.0"
|
||||
|
||||
[package.extras]
|
||||
dev = ["tokenizers[testing]"]
|
||||
docs = ["setuptools-rust", "sphinx", "sphinx-rtd-theme"]
|
||||
testing = ["black (==22.3)", "datasets", "numpy", "pytest", "requests", "ruff"]
|
||||
|
||||
[[package]]
|
||||
name = "toml"
|
||||
version = "0.10.2"
|
||||
|
@ -7151,6 +7164,75 @@ files = [
|
|||
docs = ["myst-parser", "pydata-sphinx-theme", "sphinx"]
|
||||
test = ["argcomplete (>=3.0.3)", "mypy (>=1.7.0)", "pre-commit", "pytest (>=7.0,<8.2)", "pytest-mock", "pytest-mypy-testing"]
|
||||
|
||||
[[package]]
|
||||
name = "transformers"
|
||||
version = "4.48.0"
|
||||
description = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow"
|
||||
optional = false
|
||||
python-versions = ">=3.9.0"
|
||||
files = [
|
||||
{file = "transformers-4.48.0-py3-none-any.whl", hash = "sha256:6d3de6d71cb5f2a10f9775ccc17abce9620195caaf32ec96542bd2a6937f25b0"},
|
||||
{file = "transformers-4.48.0.tar.gz", hash = "sha256:03fdfcbfb8b0367fb6c9fbe9d1c9aa54dfd847618be9b52400b2811d22799cb1"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
filelock = "*"
|
||||
huggingface-hub = ">=0.24.0,<1.0"
|
||||
numpy = ">=1.17"
|
||||
packaging = ">=20.0"
|
||||
pyyaml = ">=5.1"
|
||||
regex = "!=2019.12.17"
|
||||
requests = "*"
|
||||
safetensors = ">=0.4.1"
|
||||
tokenizers = ">=0.21,<0.22"
|
||||
tqdm = ">=4.27"
|
||||
|
||||
[package.extras]
|
||||
accelerate = ["accelerate (>=0.26.0)"]
|
||||
agents = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "datasets (!=2.5.0)", "diffusers", "opencv-python", "sentencepiece (>=0.1.91,!=0.1.92)", "torch (>=2.0)"]
|
||||
all = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "av (==9.2.0)", "codecarbon (>=2.8.1)", "flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune] (>=2.7.0)", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timm (<=1.0.11)", "tokenizers (>=0.21,<0.22)", "torch (>=2.0)", "torchaudio", "torchvision"]
|
||||
audio = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"]
|
||||
benchmark = ["optimum-benchmark (>=0.3.0)"]
|
||||
codecarbon = ["codecarbon (>=2.8.1)"]
|
||||
deepspeed = ["accelerate (>=0.26.0)", "deepspeed (>=0.9.3)"]
|
||||
deepspeed-testing = ["GitPython (<3.1.19)", "accelerate (>=0.26.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "deepspeed (>=0.9.3)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "nltk (<=3.8.1)", "optuna", "parameterized", "protobuf", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-asyncio", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"]
|
||||
dev = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "av (==9.2.0)", "beautifulsoup4", "codecarbon (>=2.8.1)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "flax (>=0.4.1,<=0.7.0)", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "libcst", "librosa", "nltk (<=3.8.1)", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-asyncio", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "timm (<=1.0.11)", "tokenizers (>=0.21,<0.22)", "torch (>=2.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"]
|
||||
dev-tensorflow = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "isort (>=5.5.4)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "libcst", "librosa", "nltk (<=3.8.1)", "onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-asyncio", "pytest-rich", "pytest-timeout", "pytest-xdist", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "tokenizers (>=0.21,<0.22)", "urllib3 (<2.0.0)"]
|
||||
dev-torch = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "beautifulsoup4", "codecarbon (>=2.8.1)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "kenlm", "libcst", "librosa", "nltk (<=3.8.1)", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-asyncio", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "timeout-decorator", "timm (<=1.0.11)", "tokenizers (>=0.21,<0.22)", "torch (>=2.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"]
|
||||
flax = ["flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "optax (>=0.0.8,<=0.1.4)", "scipy (<1.13.0)"]
|
||||
flax-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"]
|
||||
ftfy = ["ftfy"]
|
||||
integrations = ["optuna", "ray[tune] (>=2.7.0)", "sigopt"]
|
||||
ja = ["fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "rhoknp (>=1.1.0,<1.3.1)", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)"]
|
||||
modelcreation = ["cookiecutter (==1.7.3)"]
|
||||
natten = ["natten (>=0.14.6,<0.15.0)"]
|
||||
onnx = ["onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "tf2onnx"]
|
||||
onnxruntime = ["onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)"]
|
||||
optuna = ["optuna"]
|
||||
quality = ["GitPython (<3.1.19)", "datasets (!=2.5.0)", "isort (>=5.5.4)", "libcst", "rich", "ruff (==0.5.1)", "urllib3 (<2.0.0)"]
|
||||
ray = ["ray[tune] (>=2.7.0)"]
|
||||
retrieval = ["datasets (!=2.5.0)", "faiss-cpu"]
|
||||
ruff = ["ruff (==0.5.1)"]
|
||||
sagemaker = ["sagemaker (>=2.31.0)"]
|
||||
sentencepiece = ["protobuf", "sentencepiece (>=0.1.91,!=0.1.92)"]
|
||||
serving = ["fastapi", "pydantic", "starlette", "uvicorn"]
|
||||
sigopt = ["sigopt"]
|
||||
sklearn = ["scikit-learn"]
|
||||
speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"]
|
||||
testing = ["GitPython (<3.1.19)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "nltk (<=3.8.1)", "parameterized", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-asyncio", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"]
|
||||
tf = ["keras-nlp (>=0.3.1,<0.14.0)", "onnxconverter-common", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx"]
|
||||
tf-cpu = ["keras (>2.9,<2.16)", "keras-nlp (>=0.3.1,<0.14.0)", "onnxconverter-common", "tensorflow-cpu (>2.9,<2.16)", "tensorflow-probability (<0.24)", "tensorflow-text (<2.16)", "tf2onnx"]
|
||||
tf-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"]
|
||||
tiktoken = ["blobfile", "tiktoken"]
|
||||
timm = ["timm (<=1.0.11)"]
|
||||
tokenizers = ["tokenizers (>=0.21,<0.22)"]
|
||||
torch = ["accelerate (>=0.26.0)", "torch (>=2.0)"]
|
||||
torch-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"]
|
||||
torch-vision = ["Pillow (>=10.0.1,<=15.0)", "torchvision"]
|
||||
torchhub = ["filelock", "huggingface-hub (>=0.24.0,<1.0)", "importlib-metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "tokenizers (>=0.21,<0.22)", "torch (>=2.0)", "tqdm (>=4.27)"]
|
||||
video = ["av (==9.2.0)"]
|
||||
vision = ["Pillow (>=10.0.1,<=15.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "transforms3d"
|
||||
version = "0.4.2"
|
||||
|
@ -7443,13 +7525,13 @@ test = ["websockets"]
|
|||
|
||||
[[package]]
|
||||
name = "werkzeug"
|
||||
version = "3.0.6"
|
||||
version = "3.1.3"
|
||||
description = "The comprehensive WSGI web application library."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
python-versions = ">=3.9"
|
||||
files = [
|
||||
{file = "werkzeug-3.0.6-py3-none-any.whl", hash = "sha256:1bc0c2310d2fbb07b1dd1105eba2f7af72f322e1e455f2f93c993bee8c8a5f17"},
|
||||
{file = "werkzeug-3.0.6.tar.gz", hash = "sha256:a8dd59d4de28ca70471a34cba79bed5f7ef2e036a76b3ab0835474246eb41f8d"},
|
||||
{file = "werkzeug-3.1.3-py3-none-any.whl", hash = "sha256:54b78bf3716d19a65be4fceccc0d1d7b89e608834989dfae50ea87564639213e"},
|
||||
{file = "werkzeug-3.1.3.tar.gz", hash = "sha256:60723ce945c19328679790e3282cc758aa4a6040e4bb330f53d30fa546d44746"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
|
@ -7843,6 +7925,7 @@ dora = ["gym-dora"]
|
|||
dynamixel = ["dynamixel-sdk", "pynput"]
|
||||
feetech = ["feetech-servo-sdk", "pynput"]
|
||||
intelrealsense = ["pyrealsense2"]
|
||||
pi0 = ["transformers"]
|
||||
pusht = ["gym-pusht"]
|
||||
stretch = ["hello-robot-stretch-body", "pynput", "pyrealsense2", "pyrender"]
|
||||
test = ["pyserial", "pytest", "pytest-cov"]
|
||||
|
@ -7853,4 +7936,4 @@ xarm = ["gym-xarm"]
|
|||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.10,<3.13"
|
||||
content-hash = "d6c8b50f7bb31e4600500730ae7a7973fa5d3ec7b2ca1705dd0d2a1886975e68"
|
||||
content-hash = "089df92a0455bb58d2f600ad645b99362c9ff2b800a0c6108edc09f51509c716"
|
||||
|
|
|
@ -69,6 +69,7 @@ pyrender = {git = "https://github.com/mmatl/pyrender.git", markers = "sys_platfo
|
|||
hello-robot-stretch-body = {version = ">=0.7.27", markers = "sys_platform == 'linux'", optional = true}
|
||||
pyserial = {version = ">=3.5", optional = true}
|
||||
jsonlines = ">=4.0.0"
|
||||
transformers = ">=4.48.0"
|
||||
draccus = {git = "https://github.com/dlwh/draccus.git"} # replace with draccus = ">=0.9.4" when https://github.com/dlwh/draccus/pull/29 is in release
|
||||
|
||||
|
||||
|
@ -85,6 +86,7 @@ dynamixel = ["dynamixel-sdk", "pynput"]
|
|||
feetech = ["feetech-servo-sdk", "pynput"]
|
||||
intelrealsense = ["pyrealsense2"]
|
||||
stretch = ["hello-robot-stretch-body", "pyrender", "pyrealsense2", "pynput"]
|
||||
pi0 = ["transformers"]
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 110
|
||||
|
|
|
@ -82,8 +82,13 @@ def get_policy_stats(ds_repo_id, env_name, policy_name, policy_kwargs, train_kwa
|
|||
batch = next(iter(dataloader))
|
||||
obs = {}
|
||||
for k in batch:
|
||||
# TODO: regenerate the safetensors
|
||||
# for backward compatibility
|
||||
if k.endswith("_is_pad"):
|
||||
continue
|
||||
# for backward compatibility
|
||||
if k == "task":
|
||||
continue
|
||||
if k.startswith("observation"):
|
||||
obs[k] = batch[k]
|
||||
|
||||
|
|
|
@ -323,6 +323,8 @@ def test_backward_compatibility(repo_id):
|
|||
# TODO (michel-aractingi): transform language obs to langauge embeddings via tokenizer
|
||||
new_frame.pop("language_instruction", None)
|
||||
old_frame.pop("language_instruction", None)
|
||||
new_frame.pop("task", None)
|
||||
old_frame.pop("task", None)
|
||||
|
||||
# Remove task_index to allow for backward compatibility
|
||||
# TODO(rcadene): remove when new features have been generated
|
||||
|
|
|
@ -167,14 +167,16 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
|
|||
batch = next(dl_iter)
|
||||
|
||||
for key in batch:
|
||||
batch[key] = batch[key].to(DEVICE, non_blocking=True)
|
||||
if isinstance(batch[key], torch.Tensor):
|
||||
batch[key] = batch[key].to(DEVICE, non_blocking=True)
|
||||
|
||||
# Test updating the policy (and test that it does not mutate the batch)
|
||||
batch_ = deepcopy(batch)
|
||||
policy.forward(batch)
|
||||
assert set(batch) == set(batch_), "Batch keys are not the same after a forward pass."
|
||||
assert all(
|
||||
torch.equal(batch[k], batch_[k]) for k in batch
|
||||
torch.equal(batch[k], batch_[k]) if isinstance(batch[k], torch.Tensor) else batch[k] == batch_[k]
|
||||
for k in batch
|
||||
), "Batch values are not the same after a forward pass."
|
||||
|
||||
# reset the policy and environment
|
||||
|
|
Loading…
Reference in New Issue