[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
pre-commit-ci[bot] 2025-03-05 00:35:06 +00:00
parent eb5ed64e62
commit f7d664dcc0
12 changed files with 591 additions and 466 deletions

View File

@ -1,6 +1,6 @@
from .act.configuration_act import ACTConfig as ACTConfig
from .dexvla.configuration_dexvla import DexVLAConfig as DexVLAConfig
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
from .pi0.configuration_pi0 import PI0Config as PI0Config
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
from .dexvla.configuration_dexvla import DexVLAConfig as DexVLAConfig

View File

@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -14,30 +13,28 @@
# limitations under the License.
"""Qwen2VL model configuration"""
from dataclasses import dataclass, field
from typing import Tuple
from dataclasses import dataclass, field
from transformers import AutoConfig
from transformers.utils import logging
from lerobot.common.optim.optimizers import AdamWConfig
from lerobot.common.optim.schedulers import (
CosineDecayWithWarmupSchedulerConfig,
)
from transformers.utils import logging
from lerobot.configs.policies import PreTrainedConfig
from lerobot.common.policies.dexvla.policy_heads.configuration_scaledp import ScaleDPPolicyConfig
from lerobot.common.policies.dexvla.policy_heads.configuration_unet_diffusion import UnetDiffusionPolicyConfig
from lerobot.common.policies.dexvla.qwe2_vla.configuration_qwen2_vla import Qwen2VLAConfig
from lerobot.configs.types import NormalizationMode
logger = logging.get_logger(__name__)
@PreTrainedConfig.register_subclass("dexvla")
@dataclass
class DexVLAConfig(PreTrainedConfig):
# For loading policy head
policy_head_type: str = 'scale_dp_policy'
policy_head_size: str = 'ScaleDP_L'
policy_head_type: str = "scale_dp_policy"
policy_head_size: str = "ScaleDP_L"
action_dim: int = 14
state_dim: int = 14
chunk_size: int = 50
@ -45,9 +42,9 @@ class DexVLAConfig(PreTrainedConfig):
n_obs_steps: int = 1
hidden_size: int = 1536
qwen2_vl_path: str = None # '/media/rl/HDD/data/weights/Qwen2-VL-2B-Instruct'
qwen2_vl_path: str = None # '/media/rl/HDD/data/weights/Qwen2-VL-2B-Instruct'
pretrained_path: str = None # pretrained dexvla
pretrained_path: str = None # pretrained dexvla
using_film: bool = True
llm_loss_weight: float = 1.0
with_llm_head: bool = True
@ -82,33 +79,37 @@ class DexVLAConfig(PreTrainedConfig):
f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
)
if self.using_reasoning:
assert self.using_film, f"using_reasoning requires `using_film=True`"
assert self.with_llm_head, f"using_reasoning requires `with_llm_head=True`"
assert self.using_film, "using_reasoning requires `using_film=True`"
assert self.with_llm_head, "using_reasoning requires `with_llm_head=True`"
print("You have set using_reasoning=True, please make sure your data has key 'reasoning'.")
else:
print(f"Warning:DexVLA recommends to use reasoning data which can better handle long-horizon and dexterous tasks. You can set 'using_reaasoning=True'.")
print(
"Warning: DexVLA recommends using reasoning data, which better handles long-horizon and dexterous tasks. You can set 'using_reasoning=True'."
)
if self.qwen2_vl_path is None:
raise ValueError("DexVLA is built on official qwen2_vl-2B. You have to download the official weights of qwen2_vl-2B first and set 'qwen2_vl_path'.")
raise ValueError(
"DexVLA is built on official qwen2_vl-2B. You have to download the official weights of qwen2_vl-2B first and set 'qwen2_vl_path'."
)
if self.policy_head_type == 'scale_dp_policy':
if self.policy_head_type == "scale_dp_policy":
self.policy_head_config = AutoConfig.for_model(
model_type=self.policy_head_type,
model_size=self.policy_head_size,
cond_dim=self.hidden_size,
action_dim=self.action_dim,
prediction_horizon=self.chunk_size,
state_dim=self.state_dim
state_dim=self.state_dim,
)
elif self.policy_head_type == 'unet_diffusion':
elif self.policy_head_type == "unet_diffusion":
self.policy_head_config = AutoConfig.for_model(
model_type=self.policy_head_type,
global_cond_dim=self.hidden_size,
action_dim=self.action_dim,
state_dim=self.state_dim
state_dim=self.state_dim,
)
else:
raise ValueError(f'Policy head type {self.policy_head_type} not supported')
raise ValueError(f"Policy head type {self.policy_head_type} not supported")
self.qwen2_vla_config = AutoConfig.from_pretrained(self.qwen2_vl_path)
@ -152,6 +153,3 @@ class DexVLAConfig(PreTrainedConfig):
@property
def reward_delta_indices(self) -> None:
return None
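
A minimal sketch of how this config might be instantiated, assuming the lerobot DexVLA modules from this diff are importable. The checkpoint path below is a placeholder, and the behavior noted in the comments (validation building the head config at construction time) is inferred from the hunks above rather than confirmed here.

from lerobot.common.policies.dexvla.configuration_dexvla import DexVLAConfig

# Hedged sketch: "/path/to/Qwen2-VL-2B-Instruct" is a placeholder, not a path from this repo.
cfg = DexVLAConfig(
    policy_head_type="scale_dp_policy",  # or "unet_diffusion"
    policy_head_size="ScaleDP_L",
    action_dim=14,
    state_dim=14,
    chunk_size=50,
    qwen2_vl_path="/path/to/Qwen2-VL-2B-Instruct",
)
# The validation shown above (presumably run at construction) then builds
# cfg.policy_head_config via AutoConfig.for_model(...) and loads
# cfg.qwen2_vla_config from the Qwen2-VL checkpoint.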

View File

@ -1,16 +1,18 @@
import torch.nn as nn
class ActionProjector(nn.Module):
def __init__(self, in_dim, out_dim=1024):
super(ActionProjector, self).__init__()
super().__init__()
self.global_1d_pool = nn.AdaptiveAvgPool1d(1)
self.mlps = nn.ModuleList([
# nn.LayerNorm(in_dim),
nn.Linear(in_dim, in_dim),
nn.GELU(),
nn.Linear(in_dim, out_dim),
nn.Dropout(0.0),
]
self.mlps = nn.ModuleList(
[
# nn.LayerNorm(in_dim),
nn.Linear(in_dim, in_dim),
nn.GELU(),
nn.Linear(in_dim, out_dim),
nn.Dropout(0.0),
]
)
def forward(self, x):
@ -22,7 +24,7 @@ class ActionProjector(nn.Module):
class FiLM(nn.Module):
def __init__(self, feature_dim, condition_dim):
super(FiLM, self).__init__()
super().__init__()
self.scale_fc = nn.Linear(condition_dim, feature_dim)
self.shift_fc = nn.Linear(condition_dim, feature_dim)
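
For context, FiLM layers predict a per-feature scale and shift from a conditioning vector. A self-contained sketch of the idea follows; only scale_fc and shift_fc come from the class above, while the forward pass here is an assumed, typical FiLM application, since the diff does not show FiLM.forward.

import torch
import torch.nn as nn

class FiLMSketch(nn.Module):
    # Hypothetical stand-in for the FiLM module above; the forward is an
    # assumed standard feature-wise affine modulation, not copied from the repo.
    def __init__(self, feature_dim, condition_dim):
        super().__init__()
        self.scale_fc = nn.Linear(condition_dim, feature_dim)
        self.shift_fc = nn.Linear(condition_dim, feature_dim)

    def forward(self, features, condition):
        scale = self.scale_fc(condition)  # (B, feature_dim)
        shift = self.shift_fc(condition)  # (B, feature_dim)
        return features * (1 + scale) + shift

film = FiLMSketch(feature_dim=1536, condition_dim=1536)
x = torch.randn(2, 1536)
c = torch.randn(2, 1536)
y = film(x, c)  # (2, 1536)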

View File

@ -1,18 +1,16 @@
import torch
from torch import Tensor
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.dexvla.configuration_dexvla import DexVLAConfig
from lerobot.common.policies.dexvla.qwe2_vla.modeling_qwen2_vla import (
Qwen2VLForConditionalGenerationForVLA
)
from lerobot.common.policies.pretrained import PreTrainedPolicy
from collections import deque
from lerobot.common.policies.dexvla.policy_heads.modeling_unet_diffusion import ConditionalUnet1D
from lerobot.common.policies.dexvla.policy_heads.modeling_scaledp import ScaleDP
from lerobot.common.policies.dexvla.robot_data_processor import Qwen2VLAProcess
from transformers import AutoProcessor, AutoTokenizer
import torch
import torchvision.transforms as transforms
from torch import Tensor
from transformers import AutoProcessor, AutoTokenizer
from lerobot.common.policies.dexvla.configuration_dexvla import DexVLAConfig
from lerobot.common.policies.dexvla.qwe2_vla.modeling_qwen2_vla import Qwen2VLForConditionalGenerationForVLA
from lerobot.common.policies.dexvla.robot_data_processor import Qwen2VLAProcess
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.pretrained import PreTrainedPolicy
class DexVLAPolicy(PreTrainedPolicy):
"""Wrapper class around Qwen2VLForConditionalGenerationForVLA model to train and run inference within LeRobot."""
@ -44,17 +42,17 @@ class DexVLAPolicy(PreTrainedPolicy):
config.output_features, config.normalization_mapping, dataset_stats
)
for k in ['using_film', 'llm_loss_weight', 'with_llm_head', 'policy_head_config']:
for k in ["using_film", "llm_loss_weight", "with_llm_head", "policy_head_config"]:
setattr(config.qwen2_vla_config, k, config.__dict__[k])
self.model = Qwen2VLForConditionalGenerationForVLA(config.qwen2_vla_config).to(torch.bfloat16)
self.model.requires_grad_(False)
self.model.policy_head.requires_grad_(True)
self.qwen2_vl_processor = AutoProcessor.from_pretrained(config.qwen2_vl_path)
self.tokenizer = AutoTokenizer.from_pretrained(
config.qwen2_vl_path
)
self.vla_processor = Qwen2VLAProcess(tokenizer=self.tokenizer, multimodal_processor=self.qwen2_vl_processor) # process the input data into VLM format
self.tokenizer = AutoTokenizer.from_pretrained(config.qwen2_vl_path)
self.vla_processor = Qwen2VLAProcess(
tokenizer=self.tokenizer, multimodal_processor=self.qwen2_vl_processor
) # process the input data into VLM format
self.resize_size = self.config.resize_size
ratio = 0.95
@ -73,14 +71,14 @@ class DexVLAPolicy(PreTrainedPolicy):
batch = self.normalize_inputs(batch)
batch = self.normalize_targets(batch)
present_img_keys = [key for key in self.config.image_features if key in batch]
task_descs = batch['task']
task_descs = batch["task"]
try:
reasonings = batch['reasoning']
reasonings = batch["reasoning"]
except KeyError:
reasonings = ['no reasoning'] * len(task_descs)
reasonings = ["no reasoning"] * len(task_descs)
pass
is_pad = batch['action_is_pad']
is_pad = batch["action_is_pad"]
all_cam_images = []
for k in present_img_keys:
all_cam_images.append(batch[k])
@ -89,8 +87,8 @@ class DexVLAPolicy(PreTrainedPolicy):
image_data = torch.stack(all_cam_images) * 255
image_data = image_data.to(dtype=torch.uint8)
# construct observations
qpos_data = batch['observation.state'].float()
action_data = batch['action'].float()
qpos_data = batch["observation.state"].float()
action_data = batch["action"].float()
orig_shape = image_data.shape
image_data = image_data.view(-1, *orig_shape[2:])
@ -100,40 +98,35 @@ class DexVLAPolicy(PreTrainedPolicy):
image_data = image_data.view(*orig_shape[:3], *self.resize_size)
vl_data = {
'images': image_data,
'raw_langs': task_descs,
'reasonings': reasonings
}
vl_data = {"images": image_data, "raw_langs": task_descs, "reasonings": reasonings}
# processing vl_data into qwen2_vl format
vla_inputs = self.vla_processor.forward(vl_data, use_reasoning=self.config.using_reasoning)
vla_inputs['states'] = qpos_data
vla_inputs['is_pad'] = is_pad
vla_inputs['actions'] = action_data
vla_inputs["states"] = qpos_data
vla_inputs["is_pad"] = is_pad
vla_inputs["actions"] = action_data
return vla_inputs
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Tensor]]:
processed_batch = self.process_batch(batch)
ret = self.model.forward(**processed_batch)
loss_dict = ret['loss']
loss = loss_dict['loss'].mean()
loss_dict = ret["loss"]
loss = loss_dict["loss"].mean()
return loss, loss_dict
def dexvla_predict_action(self,
input_ids: torch.LongTensor = None,
actions=None,
states=None,
is_pad=None,
tokenizer=None,
is_eval=True,
pixel_values=None,
attention_mask=None,
image_grid_thw=None,
):
input_ids = input_ids.to('cuda')
def dexvla_predict_action(
self,
input_ids: torch.LongTensor = None,
actions=None,
states=None,
is_pad=None,
tokenizer=None,
is_eval=True,
pixel_values=None,
attention_mask=None,
image_grid_thw=None,
):
input_ids = input_ids.to("cuda")
with torch.inference_mode():
outputs = self.model.generate(
input_ids,
@ -157,7 +150,7 @@ class DexVLAPolicy(PreTrainedPolicy):
input_token_len = input_ids.shape[1]
n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
if n_diff_input_output > 0:
print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
print(f"[Warning] {n_diff_input_output} output_ids are not the same as the input_ids")
outputs_text = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=False)[0]
outputs_text = outputs_text.strip()
@ -167,35 +160,44 @@ class DexVLAPolicy(PreTrainedPolicy):
action_hidden_states = None
if self.model.using_film:
action_hidden_states = self.model.film_forward(labels=torch.ones_like(output_ids),
input_ids=output_ids,
hidden_states=torch.cat(last_hidden_states, dim=1))
action_hidden_states = self.model.film_forward(
labels=torch.ones_like(output_ids),
input_ids=output_ids,
hidden_states=torch.cat(last_hidden_states, dim=1),
)
action = self.model.policy_head(actions, action_hidden_states, states.to(all_hidden_states.dtype), is_pad)
action = self.model.policy_head(
actions, action_hidden_states, states.to(all_hidden_states.dtype), is_pad
)
return action, outputs_text
def tinyvla_predict_action(self,
input_ids: torch.LongTensor = None,
actions=None,
states=None,
is_pad=None,
is_eval=True,
pixel_values=None,
attention_mask=None,
image_grid_thw=None,
):
input_ids = input_ids.to('cuda')
def tinyvla_predict_action(
self,
input_ids: torch.LongTensor = None,
actions=None,
states=None,
is_pad=None,
is_eval=True,
pixel_values=None,
attention_mask=None,
image_grid_thw=None,
):
input_ids = input_ids.to("cuda")
with torch.inference_mode():
all_hidden_states = self.model.forward(input_ids,
pixel_values=pixel_values,
attention_mask=attention_mask,
image_grid_thw=image_grid_thw,
is_eval=is_eval,
tinyvla=True)
all_hidden_states = self.model.forward(
input_ids,
pixel_values=pixel_values,
attention_mask=attention_mask,
image_grid_thw=image_grid_thw,
is_eval=is_eval,
tinyvla=True,
)
all_hidden_states = torch.mean(all_hidden_states, dim=1).unsqueeze(1)
action = self.model.policy_head(actions, all_hidden_states, states.to(all_hidden_states.dtype), is_pad)
action = self.model.policy_head(
actions, all_hidden_states, states.to(all_hidden_states.dtype), is_pad
)
return action, "tinyvla generates no reasoning"
def reset(self):
@ -219,7 +221,7 @@ class DexVLAPolicy(PreTrainedPolicy):
if len(self._action_queue) == 0:
present_img_keys = [key for key in self.config.image_features if key in batch]
try:
task_descs = batch['task']
task_descs = batch["task"]
except KeyError:
task_descs = " "
print("No task descriptions found for this task")
@ -232,7 +234,7 @@ class DexVLAPolicy(PreTrainedPolicy):
image_data = torch.stack(all_cam_images) * 255
image_data = image_data.to(dtype=torch.uint8)
# construct observations
qpos_data = batch['observation.state'].float()
qpos_data = batch["observation.state"].float()
image_data = image_data.squeeze(0)
@ -240,20 +242,19 @@ class DexVLAPolicy(PreTrainedPolicy):
image_data = transform(image_data)
# processing vl_data into qwen2_vl format
vla_inputs = self.vla_processor.single_forward_process(images=image_data, raw_lang=task_descs, reasoning=None, eval=True)
vla_inputs['states'] = qpos_data
vla_inputs = self.vla_processor.single_forward_process(
images=image_data, raw_lang=task_descs, reasoning=None, eval=True
)
vla_inputs["states"] = qpos_data
if self.config.using_film and self.config.with_llm_head: # dexvla
all_actions, outputs = self.dexvla_predict_action(**vla_inputs, is_eval=True, tokenizer=self.tokenizer)
else: # tinyvla
if self.config.using_film and self.config.with_llm_head: # dexvla
all_actions, outputs = self.dexvla_predict_action(
**vla_inputs, is_eval=True, tokenizer=self.tokenizer
)
else: # tinyvla
all_actions, outputs = self.tinyvla_predict_action(**vla_inputs, is_eval=True)
actions = self.unnormalize_outputs({"action": all_actions})["action"]
self._action_queue.extend(actions.transpose(0, 1))
return self._action_queue.popleft()
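
select_action above follows the common action-chunking pattern: when the queue is empty, run the VLA once to predict a chunk of actions, push them onto a deque, then pop one action per control step. A stripped-down sketch of just that queueing logic; predict_chunk is a placeholder for the dexvla/tinyvla prediction call, not an API from this repo.

from collections import deque
import torch

class ChunkedActionQueue:
    # Hedged sketch of the deque-based chunking used by select_action above.
    def __init__(self, predict_chunk, n_action_steps):
        self.predict_chunk = predict_chunk
        self._queue = deque(maxlen=n_action_steps)

    @torch.no_grad()
    def select_action(self, batch):
        if len(self._queue) == 0:
            actions = self.predict_chunk(batch)          # (B, chunk, action_dim)
            self._queue.extend(actions.transpose(0, 1))  # queue along the time axis
        return self._queue.popleft()                     # (B, action_dim)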

View File

@ -1,47 +1,58 @@
import os
from typing import Union, List
from transformers import PretrainedConfig
from typing import Union
from transformers import AutoConfig, PretrainedConfig
from transformers.utils import logging
from transformers import AutoConfig, AutoModelForCausalLM
logger = logging.get_logger(__name__)
MODEL_STRUCTURE = {
'ScaleDP_H': {'depth': 32, 'n_emb': 1280, 'num_heads': 16, },
'ScaleDP_L': {'depth': 24, 'n_emb': 1024, 'num_heads': 16, }, # 400M
"ScaleDP_H": {
"depth": 32,
"n_emb": 1280,
"num_heads": 16,
},
"ScaleDP_L": {
"depth": 24,
"n_emb": 1024,
"num_heads": 16,
}, # 400M
}
class ScaleDPPolicyConfig(PretrainedConfig):
'''
"""
Configuration for ScaleDP policy head
'''
"""
model_type = "scale_dp_policy"
def __init__(
self,
eval: bool = False,
action_dim: int = 14, # action dim
# output_dim: int = 14, # action dim
cond_dim: int = 1536, # the input dim of the condition
state_dim: int = 14, # the input dim of the state
prediction_horizon: int = 16, # horizon
n_obs_steps: int = 2, # number of observation steps
depth: int = 28, # number of DiT blocks
n_emb: int = 256, # embedding size
num_heads: int = 16,
mlp_ratio: int = 4.0,
time_as_cond: bool = True,
obs_as_cond: bool = True,
learn_sigma: bool = False,
model_size: str = "none",
num_inference_timesteps: int = 10,
noise_samples: int = 1,
num_train_timesteps: int = 100,
**kwargs
self,
eval: bool = False,
action_dim: int = 14, # action dim
# output_dim: int = 14, # action dim
cond_dim: int = 1536, # the input dim of the condition
state_dim: int = 14, # the input dim of the state
prediction_horizon: int = 16, # horizon
n_obs_steps: int = 2, # number of observation steps
depth: int = 28, # number of DiT blocks
n_emb: int = 256, # embedding size
num_heads: int = 16,
mlp_ratio: int = 4.0,
time_as_cond: bool = True,
obs_as_cond: bool = True,
learn_sigma: bool = False,
model_size: str = "none",
num_inference_timesteps: int = 10,
noise_samples: int = 1,
num_train_timesteps: int = 100,
**kwargs,
):
if model_size != "none":
depth = MODEL_STRUCTURE[model_size]['depth']
n_emb = MODEL_STRUCTURE[model_size]['n_emb']
num_heads = MODEL_STRUCTURE[model_size]['num_heads']
depth = MODEL_STRUCTURE[model_size]["depth"]
n_emb = MODEL_STRUCTURE[model_size]["n_emb"]
num_heads = MODEL_STRUCTURE[model_size]["num_heads"]
else:
# raise ValueError("model_size show not be 'none'")
pass
@ -52,7 +63,6 @@ class ScaleDPPolicyConfig(PretrainedConfig):
self.output_dim = action_dim
self.prediction_horizon = prediction_horizon
self.cond_dim = cond_dim
self.state_dim = state_dim
@ -72,7 +82,9 @@ class ScaleDPPolicyConfig(PretrainedConfig):
super().__init__(**kwargs)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
def from_pretrained(
cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
) -> "PretrainedConfig":
cls._set_token_in_kwargs(kwargs)
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
@ -81,7 +93,11 @@ class ScaleDPPolicyConfig(PretrainedConfig):
if config_dict.get("model_type") == "llava_pythia":
config_dict = config_dict["action_head"]
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
if (
"model_type" in config_dict
and hasattr(cls, "model_type")
and config_dict["model_type"] != cls.model_type
):
logger.warning(
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
@ -89,4 +105,5 @@ class ScaleDPPolicyConfig(PretrainedConfig):
return cls.from_dict(config_dict, **kwargs)
AutoConfig.register("scale_dp_policy", ScaleDPPolicyConfig)

View File

@ -1,31 +1,33 @@
import os
from typing import Union, List
from transformers import PretrainedConfig
from typing import Union
from transformers import AutoConfig, PretrainedConfig
from transformers.utils import logging
from transformers import AutoConfig, AutoModelForCausalLM
logger = logging.get_logger(__name__)
class UnetDiffusionPolicyConfig(PretrainedConfig):
'''
"""
Configuration for dit diffusion policy head
'''
"""
model_type = "unet_diffusion_policy"
def __init__(
self,
action_dim=10,
global_cond_dim=2048,
diffusion_step_embed_dim=256,
down_dims=[256, 512, 1024],
kernel_size=5,
n_groups=8,
state_dim=7,
prediction_horizon=16,
noise_samples=1,
num_inference_timesteps=10,
num_train_timesteps=100,
**kwargs
self,
action_dim=10,
global_cond_dim=2048,
diffusion_step_embed_dim=256,
down_dims=[256, 512, 1024],
kernel_size=5,
n_groups=8,
state_dim=7,
prediction_horizon=16,
noise_samples=1,
num_inference_timesteps=10,
num_train_timesteps=100,
**kwargs,
):
self.input_dim = action_dim
self.noise_samples = noise_samples
@ -42,7 +44,9 @@ class UnetDiffusionPolicyConfig(PretrainedConfig):
super().__init__(**kwargs)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
def from_pretrained(
cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
) -> "PretrainedConfig":
cls._set_token_in_kwargs(kwargs)
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
@ -51,7 +55,11 @@ class UnetDiffusionPolicyConfig(PretrainedConfig):
if config_dict.get("model_type") == "llava_pythia":
config_dict = config_dict["action_head"]
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
if (
"model_type" in config_dict
and hasattr(cls, "model_type")
and config_dict["model_type"] != cls.model_type
):
logger.warning(
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
@ -59,4 +67,5 @@ class UnetDiffusionPolicyConfig(PretrainedConfig):
return cls.from_dict(config_dict, **kwargs)
AutoConfig.register("unet_diffusion_policy", UnetDiffusionPolicyConfig)

View File

@ -1,27 +1,24 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
from typing import Tuple
import timm
import numpy as np
import logging
import math
from typing import Tuple
import numpy as np
try:
from typing import Literal
except ImportError:
from typing_extensions import Literal
pass
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.jit import Final
from timm.models.vision_transformer import Mlp, use_fused_attn
from torch.jit import Final
from transformers import AutoModel
from transformers.modeling_utils import PreTrainedModel
from transformers import AutoModel, AutoModelForCausalLM
_logger = logging.getLogger(__name__)
@ -30,20 +27,20 @@ class Attention(nn.Module):
fused_attn: Final[bool]
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
qk_norm: bool = False,
attn_drop: float = 0.,
proj_drop: float = 0.,
norm_layer: nn.Module = nn.LayerNorm,
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
qk_norm: bool = False,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
norm_layer: nn.Module = nn.LayerNorm,
) -> None:
super().__init__()
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
assert dim % num_heads == 0, "dim should be divisible by num_heads"
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.scale = self.head_dim**-0.5
self.fused_attn = use_fused_attn()
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
@ -61,8 +58,11 @@ class Attention(nn.Module):
if self.fused_attn:
x = F.scaled_dot_product_attention(
q, k, v, attn_mask=attn_mask,
dropout_p=self.attn_drop.p if self.training else 0.,
q,
k,
v,
attn_mask=attn_mask,
dropout_p=self.attn_drop.p if self.training else 0.0,
)
else:
q = q * self.scale
@ -104,6 +104,7 @@ def modulate(x, shift, scale):
# Embedding Layers for Timesteps and Class Labels #
#################################################################################
class TimestepEmbedder(nn.Module):
"""
Embeds scalar timesteps into vector representations.
@ -145,11 +146,11 @@ class TimestepEmbedder(nn.Module):
return t_emb
#################################################################################
# Core ScaleDP Model #
#################################################################################
class ScaleDPBlock(nn.Module):
"""
A ScaleDP block with adaptive layer norm zero (adaLN-Zero) conditioning.
@ -163,14 +164,15 @@ class ScaleDPBlock(nn.Module):
mlp_hidden_dim = int(hidden_size * mlp_ratio)
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
def forward(self, x, c, attn_mask=None):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), attn_mask=attn_mask) # norm, scale&shift, attn, scale,
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(
6, dim=1
)
x = x + gate_msa.unsqueeze(1) * self.attn(
modulate(self.norm1(x), shift_msa, scale_msa), attn_mask=attn_mask
) # norm, scale&shift, attn, scale,
x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
return x
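
The adaLN-Zero block above conditions each sublayer by predicting six vectors from c, applying a scale-and-shift before attention and the MLP, and gating their outputs. A compact sketch of just that modulation arithmetic; the modulate body here is the usual x * (1 + scale) + shift, assumed rather than taken from the hunk, and the attention/MLP calls are omitted.

import torch
import torch.nn as nn

hidden_size = 64
adaLN = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))

def modulate(x, shift, scale):
    # Broadcast per-sample shift/scale over the token dimension (assumed standard form).
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

x = torch.randn(2, 16, hidden_size)  # (B, T, D)
c = torch.randn(2, hidden_size)      # conditioning vector
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = adaLN(c).chunk(6, dim=1)
x = x + gate_msa.unsqueeze(1) * modulate(x, shift_msa, scale_msa)  # attention omitted in this sketch
x = x + gate_mlp.unsqueeze(1) * modulate(x, shift_mlp, scale_mlp)  # MLP omitted in this sketch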
@ -184,10 +186,7 @@ class FinalLayer(nn.Module):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, output_dim, bias=True)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(hidden_size, 2 * hidden_size, bias=True)
)
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
def forward(self, x, c):
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
@ -195,15 +194,20 @@ class FinalLayer(nn.Module):
x = self.linear(x)
return x
from .configuration_scaledp import ScaleDPPolicyConfig
class ScaleDP(PreTrainedModel):
"""
Diffusion models with a Transformer backbone.
"""
config_class = ScaleDPPolicyConfig
def __init__(
self,
config: ScaleDPPolicyConfig,
self,
config: ScaleDPPolicyConfig,
):
super().__init__(config)
# compute number of tokens for main trunk and condition encoder
@ -221,11 +225,11 @@ class ScaleDP(PreTrainedModel):
# self.combine = nn.Linear(cond_dim+state_dim, cond_dim)
self.combine = nn.Sequential(
nn.Linear(config.cond_dim+config.state_dim, 1024),
nn.Linear(config.cond_dim + config.state_dim, 1024),
nn.ReLU(),
nn.Linear(1024, 1024),
nn.ReLU(),
nn.Linear(1024, config.cond_dim)
nn.Linear(1024, config.cond_dim),
)
self.learn_sigma = config.learn_sigma
self.input_dim = config.input_dim
@ -241,9 +245,12 @@ class ScaleDP(PreTrainedModel):
# Will use fixed sin-cos embedding:
self.pos_embed = nn.Parameter(torch.zeros(1, config.prediction_horizon, config.n_emb))
self.blocks = nn.ModuleList([
ScaleDPBlock(config.n_emb, config.num_heads, mlp_ratio=config.mlp_ratio) for _ in range(config.depth)
])
self.blocks = nn.ModuleList(
[
ScaleDPBlock(config.n_emb, config.num_heads, mlp_ratio=config.mlp_ratio)
for _ in range(config.depth)
]
)
self.final_layer = FinalLayer(config.n_emb, output_dim=config.output_dim)
# self.initialize_weights()
# constants
@ -253,23 +260,22 @@ class ScaleDP(PreTrainedModel):
self.time_as_cond = config.time_as_cond
self.action_dim = config.output_dim
self.obs_as_cond = obs_as_cond
logger.info(
"number of parameters in ScaleDP: %e", sum(p.numel() for p in self.parameters())
)
logger.info("number of parameters in ScaleDP: %e", sum(p.numel() for p in self.parameters()))
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
self.num_inference_timesteps = config.num_inference_timesteps
# self.proj_to_action = nn.Identity()
self.noise_scheduler = DDIMScheduler(
num_train_timesteps=config.num_train_timesteps, # 100
beta_schedule='squaredcos_cap_v2',
num_train_timesteps=config.num_train_timesteps, # 100
beta_schedule="squaredcos_cap_v2",
clip_sample=True,
set_alpha_to_one=True,
steps_offset=0,
prediction_type='epsilon'
prediction_type="epsilon",
)
self.num_queries = config.num_queries #16
self.noise_samples = config.noise_samples # 1
self.num_queries = config.num_queries # 16
self.noise_samples = config.noise_samples # 1
# self.num_inference_timesteps = config.num_inference_timesteps # 100
def initialize_weights(self):
@ -308,7 +314,6 @@ class ScaleDP(PreTrainedModel):
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)
def get_optim_groups(self, weight_decay: float = 1e-3):
"""
This long function is unfortunately doing something very simple and is being very defensive:
@ -324,7 +329,7 @@ class ScaleDP(PreTrainedModel):
blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
for mn, m in self.named_modules():
for pn, p in m.named_parameters():
fpn = "%s.%s" % (mn, pn) if mn else pn # full param name
fpn = "{}.{}".format(mn, pn) if mn else pn # full param name
if pn.endswith("bias"):
# all biases will not be decayed
@ -343,13 +348,13 @@ class ScaleDP(PreTrainedModel):
param_dict = {pn: p for pn, p in self.named_parameters()}
inter_params = decay & no_decay
union_params = decay | no_decay
assert (
len(inter_params) == 0
), "parameters %s made it into both decay/no_decay sets!" % (str(inter_params),)
assert (
len(param_dict.keys() - union_params) == 0
), "parameters %s were not separated into either decay/no_decay set!" % (
str(param_dict.keys() - union_params),
assert len(inter_params) == 0, "parameters {} made it into both decay/no_decay sets!".format(
str(inter_params)
)
assert len(param_dict.keys() - union_params) == 0, (
"parameters {} were not separated into either decay/no_decay set!".format(
str(param_dict.keys() - union_params),
)
)
# create the pytorch optimizer object
@ -365,14 +370,14 @@ class ScaleDP(PreTrainedModel):
]
return optim_groups
def configure_optimizers(self,
learning_rate: float = 1e-4,
weight_decay: float = 1e-3,
betas: Tuple[float, float] = (0.9, 0.95)):
def configure_optimizers(
self,
learning_rate: float = 1e-4,
weight_decay: float = 1e-3,
betas: Tuple[float, float] = (0.9, 0.95),
):
optim_groups = self.get_optim_groups(weight_decay=weight_decay)
optimizer = torch.optim.AdamW(
optim_groups, lr=learning_rate, betas=betas
)
optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas)
return optimizer
def forward(self, actions, hidden_states, states, is_pad):
@ -385,25 +390,26 @@ class ScaleDP(PreTrainedModel):
"""
if actions is not None: # training time
B = actions.size(0)
actions = actions[:, :self.num_queries]
is_pad = is_pad[:, :self.num_queries]
actions = actions[:, : self.num_queries]
is_pad = is_pad[:, : self.num_queries]
num_noise_samples = self.noise_samples
# sample noise to add to actions
noise = torch.randn([num_noise_samples] + list(actions.shape), device=actions.device,
dtype=actions.dtype) # num_noise, B, Ta, D(1, 2, 16, 14)
noise = torch.randn(
[num_noise_samples] + list(actions.shape), device=actions.device, dtype=actions.dtype
) # num_noise, B, Ta, D(1, 2, 16, 14)
# sample a diffusion iteration for each data point
timesteps = torch.randint(
0, self.noise_scheduler.config.num_train_timesteps,
(B,), device=actions.device
0, self.noise_scheduler.config.num_train_timesteps, (B,), device=actions.device
).long()
timesteps, noise = timesteps.to(actions.device), noise.to(actions.device)
# add noise to the clean actions according to the noise magnitude at each diffusion iteration
# (this is the forward diffusion process)
noisy_actions = torch.cat([self.noise_scheduler.add_noise(
actions, noise[i], timesteps)
for i in range(len(noise))], dim=0) # [num_noise_samples * B, Ta, action_dim]
noisy_actions = torch.cat(
[self.noise_scheduler.add_noise(actions, noise[i], timesteps) for i in range(len(noise))],
dim=0,
) # [num_noise_samples * B, Ta, action_dim]
noisy_actions = noisy_actions.to(dtype=actions.dtype)
assert hidden_states.ndim == 3
@ -411,14 +417,16 @@ class ScaleDP(PreTrainedModel):
hidden_states = hidden_states.repeat(num_noise_samples, 1, 1)
timesteps = timesteps.repeat(num_noise_samples)
is_pad = is_pad.repeat(num_noise_samples, 1)
states = states.repeat(num_noise_samples, 1)
states = states.repeat(num_noise_samples, 1)
noise_pred = self.model_forward(noisy_actions, timesteps, global_cond=hidden_states, states=states)
noise_pred = self.model_forward(
noisy_actions, timesteps, global_cond=hidden_states, states=states
)
noise = noise.view(noise.size(0) * noise.size(1), *noise.size()[2:])
loss = torch.nn.functional.mse_loss(noise_pred, noise, reduction='none')
loss = torch.nn.functional.mse_loss(noise_pred, noise, reduction="none")
loss = (loss * ~is_pad.unsqueeze(-1)).mean()
# loss_dict['loss'] = loss
return {'loss': loss}
return {"loss": loss}
# return loss
else: # inference time
B = 1
@ -438,9 +446,7 @@ class ScaleDP(PreTrainedModel):
# inverse diffusion step (remove noise)
naction = self.noise_scheduler.step(
model_output=noise_pred,
timestep=k,
sample=naction
model_output=noise_pred, timestep=k, sample=naction
).prev_sample
return naction
@ -462,7 +468,9 @@ class ScaleDP(PreTrainedModel):
t = t[None].to(x.device)
t = t.expand(t.shape[0])
x = self.x_embedder(x) + self.pos_embed.to(device=x.device, dtype=x.dtype) # (N, T, D), where T = prediction_horizon
x = self.x_embedder(x) + self.pos_embed.to(
device=x.device, dtype=x.dtype
) # (N, T, D), where T = prediction_horizon
t = self.t_embedder(t) # (N, D)
if self.obs_as_cond:
global_cond = self.cond_obs_emb(global_cond) # (N, D)
@ -474,11 +482,13 @@ class ScaleDP(PreTrainedModel):
x = self.final_layer(x, c) # (N, T, output_dim)
return x
#################################################################################
# Sine/Cosine Positional Embedding Functions #
#################################################################################
# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
"""
grid_size: int of the grid height and width
@ -516,11 +526,11 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
"""
assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=np.float64)
omega /= embed_dim / 2.
omega = 1. / 10000 ** omega # (D/2,)
omega /= embed_dim / 2.0
omega = 1.0 / 10000**omega # (D/2,)
pos = pos.reshape(-1) # (M,)
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2)
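
The frequency construction above is the standard transformer sinusoidal embedding. A compact standalone sketch; the names are illustrative, and the final sin/cos concatenation follows the usual MAE implementation, which sits outside the hunk shown here.

import numpy as np

def sincos_1d(embed_dim, positions):
    # embed_dim must be even: half the dims carry sin, half carry cos.
    assert embed_dim % 2 == 0
    omega = 1.0 / 10000 ** (np.arange(embed_dim // 2, dtype=np.float64) / (embed_dim / 2))
    out = np.einsum("m,d->md", positions.reshape(-1), omega)   # (M, D/2)
    return np.concatenate([np.sin(out), np.cos(out)], axis=1)  # (M, D)

pe = sincos_1d(1024, np.arange(50))  # e.g. a 50-step horizon with n_emb=1024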
@ -533,12 +543,13 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
# ScaleDP Configs #
#################################################################################
def ScaleDP_H(**kwargs):
return ScaleDP(depth=32, n_emb=1280, num_heads=16, **kwargs)
def ScaleDP_L(**kwargs):
return ScaleDP(depth=24, n_emb=1024, num_heads=16, **kwargs)
AutoModel.register(ScaleDPPolicyConfig, ScaleDP)
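
The training branch of ScaleDP.forward (and of the UNet head in the next file) follows the standard diffusion recipe: sample noise and a timestep, add noise with the scheduler, predict the noise, and mask padded steps out of the MSE. A self-contained sketch of that recipe with a toy predictor; the real predictor is the conditioned transformer above, and the function name here is illustrative.

import torch
from diffusers.schedulers.scheduling_ddim import DDIMScheduler

scheduler = DDIMScheduler(
    num_train_timesteps=100,
    beta_schedule="squaredcos_cap_v2",
    clip_sample=True,
    set_alpha_to_one=True,
    steps_offset=0,
    prediction_type="epsilon",
)

def diffusion_training_loss(actions, is_pad, predict_noise):
    # actions: (B, Ta, D); is_pad: (B, Ta) bool; predict_noise stands in for the
    # conditioned denoiser (ScaleDP / ConditionalUnet1D above).
    noise = torch.randn_like(actions)
    t = torch.randint(
        0, scheduler.config.num_train_timesteps, (actions.size(0),), device=actions.device
    ).long()
    noisy = scheduler.add_noise(actions, noise, t)      # forward diffusion step
    pred = predict_noise(noisy, t)                      # epsilon prediction
    loss = torch.nn.functional.mse_loss(pred, noise, reduction="none")
    return (loss * ~is_pad.unsqueeze(-1)).mean()        # ignore padded action steps

# Toy usage with a zero predictor, purely to exercise the shapes.
B, Ta, D = 2, 16, 14
loss = diffusion_training_loss(
    torch.randn(B, Ta, D), torch.zeros(B, Ta, dtype=torch.bool), lambda x, t: torch.zeros_like(x)
)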

View File

@ -1,29 +1,29 @@
"""
Implementation of Diffusion Policy https://diffusion-policy.cs.columbia.edu/ by Cheng Chi
"""
from typing import Callable, Union
import copy
import math
from collections import OrderedDict, deque
from packaging.version import parse as parse_version
import random
from typing import Union
import torch
import torch.nn as nn
import torch.nn.functional as F
# requires diffusers==0.11.1
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from diffusers.training_utils import EMAModel
from .configuration_unet_diffusion import UnetDiffusionPolicyConfig
from transformers import AutoModel
from transformers.modeling_utils import PreTrainedModel
from transformers import AutoModel, AutoModelForCausalLM
import copy
from .configuration_unet_diffusion import UnetDiffusionPolicyConfig
# =================== UNet for Diffusion ==============
class SinusoidalPosEmb(nn.Module):
def __init__(self, dim, dtype):
super().__init__()
self.dim = dim
self.dtype=dtype
self.dtype = dtype
def forward(self, x):
device = x.device
@ -54,9 +54,9 @@ class Upsample1d(nn.Module):
class Conv1dBlock(nn.Module):
'''
Conv1d --> GroupNorm --> Mish
'''
"""
Conv1d --> GroupNorm --> Mish
"""
def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
super().__init__()
@ -72,46 +72,41 @@ class Conv1dBlock(nn.Module):
class ConditionalResidualBlock1D(nn.Module):
def __init__(self,
in_channels,
out_channels,
cond_dim,
kernel_size=3,
n_groups=8):
def __init__(self, in_channels, out_channels, cond_dim, kernel_size=3, n_groups=8):
super().__init__()
self.blocks = nn.ModuleList([
Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups),
Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups),
])
self.blocks = nn.ModuleList(
[
Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups),
Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups),
]
)
# FiLM modulation https://arxiv.org/abs/1709.07871
# predicts per-channel scale and bias
cond_channels = out_channels * 2
self.out_channels = out_channels
self.cond_encoder = nn.Sequential(
nn.Mish(),
nn.Linear(cond_dim, cond_channels),
nn.Unflatten(-1, (-1, 1))
nn.Mish(), nn.Linear(cond_dim, cond_channels), nn.Unflatten(-1, (-1, 1))
)
# make sure dimensions compatible
self.residual_conv = nn.Conv1d(in_channels, out_channels, 1) \
if in_channels != out_channels else nn.Identity()
self.residual_conv = (
nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
)
def forward(self, x, cond):
'''
x : [ batch_size x in_channels x horizon ]
cond : [ batch_size x cond_dim]
"""
x : [ batch_size x in_channels x horizon ]
cond : [ batch_size x cond_dim]
returns:
out : [ batch_size x out_channels x horizon ]
'''
returns:
out : [ batch_size x out_channels x horizon ]
"""
out = self.blocks[0](x)
embed = self.cond_encoder(cond)
embed = embed.reshape(
embed.shape[0], 2, self.out_channels, 1)
embed = embed.reshape(embed.shape[0], 2, self.out_channels, 1)
scale = embed[:, 0, ...]
bias = embed[:, 1, ...]
out = scale * out + bias
@ -125,9 +120,8 @@ class ConditionalUnet1D(PreTrainedModel):
_no_split_modules = ["mid_modules", "down_modules", "up_modules"]
config_class = UnetDiffusionPolicyConfig
def __init__(self,
config: UnetDiffusionPolicyConfig
):
def __init__(self, config: UnetDiffusionPolicyConfig):
"""
input_dim: Dim of actions.
global_cond_dim: Dim of global conditioning applied with FiLM
@ -148,7 +142,7 @@ class ConditionalUnet1D(PreTrainedModel):
# self.global_1d_pool = nn.AdaptiveAvgPool1d(1)
# self.proj2action = nn.Linear(config.hidden_dim, config.global_cond_dim)
self.norm_after_pool = nn.LayerNorm(config.global_cond_dim)
self.combine = nn.Linear(config.global_cond_dim+config.state_dim, config.global_cond_dim)
self.combine = nn.Linear(config.global_cond_dim + config.state_dim, config.global_cond_dim)
dsed = config.diffusion_step_embed_dim
diffusion_step_encoder = nn.Sequential(
SinusoidalPosEmb(dsed, torch.bfloat16),
@ -158,44 +152,76 @@ class ConditionalUnet1D(PreTrainedModel):
)
cond_dim = dsed + config.global_cond_dim
in_out = list(zip(all_dims[:-1], all_dims[1:]))
in_out = list(zip(all_dims[:-1], all_dims[1:], strict=False))
mid_dim = all_dims[-1]
self.mid_modules = nn.ModuleList([
ConditionalResidualBlock1D(
mid_dim, mid_dim, cond_dim=cond_dim,
kernel_size=config.kernel_size, n_groups=config.n_groups
),
ConditionalResidualBlock1D(
mid_dim, mid_dim, cond_dim=cond_dim,
kernel_size=config.kernel_size, n_groups=config.n_groups
),
])
self.mid_modules = nn.ModuleList(
[
ConditionalResidualBlock1D(
mid_dim,
mid_dim,
cond_dim=cond_dim,
kernel_size=config.kernel_size,
n_groups=config.n_groups,
),
ConditionalResidualBlock1D(
mid_dim,
mid_dim,
cond_dim=cond_dim,
kernel_size=config.kernel_size,
n_groups=config.n_groups,
),
]
)
down_modules = nn.ModuleList([])
for ind, (dim_in, dim_out) in enumerate(in_out):
is_last = ind >= (len(in_out) - 1)
down_modules.append(nn.ModuleList([
ConditionalResidualBlock1D(
dim_in, dim_out, cond_dim=cond_dim,
kernel_size=config.kernel_size, n_groups=config.n_groups),
ConditionalResidualBlock1D(
dim_out, dim_out, cond_dim=cond_dim,
kernel_size=config.kernel_size, n_groups=config.n_groups),
Downsample1d(dim_out) if not is_last else nn.Identity()
]))
down_modules.append(
nn.ModuleList(
[
ConditionalResidualBlock1D(
dim_in,
dim_out,
cond_dim=cond_dim,
kernel_size=config.kernel_size,
n_groups=config.n_groups,
),
ConditionalResidualBlock1D(
dim_out,
dim_out,
cond_dim=cond_dim,
kernel_size=config.kernel_size,
n_groups=config.n_groups,
),
Downsample1d(dim_out) if not is_last else nn.Identity(),
]
)
)
up_modules = nn.ModuleList([])
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
is_last = ind >= (len(in_out) - 1)
up_modules.append(nn.ModuleList([
ConditionalResidualBlock1D(
dim_out * 2, dim_in, cond_dim=cond_dim,
kernel_size=config.kernel_size, n_groups=config.n_groups),
ConditionalResidualBlock1D(
dim_in, dim_in, cond_dim=cond_dim,
kernel_size=config.kernel_size, n_groups=config.n_groups),
Upsample1d(dim_in) if not is_last else nn.Identity()
]))
up_modules.append(
nn.ModuleList(
[
ConditionalResidualBlock1D(
dim_out * 2,
dim_in,
cond_dim=cond_dim,
kernel_size=config.kernel_size,
n_groups=config.n_groups,
),
ConditionalResidualBlock1D(
dim_in,
dim_in,
cond_dim=cond_dim,
kernel_size=config.kernel_size,
n_groups=config.n_groups,
),
Upsample1d(dim_in) if not is_last else nn.Identity(),
]
)
)
final_conv = nn.Sequential(
Conv1dBlock(start_dim, start_dim, kernel_size=config.kernel_size),
@ -207,20 +233,17 @@ class ConditionalUnet1D(PreTrainedModel):
self.down_modules = down_modules
self.final_conv = final_conv
print("number of parameters: {:e}".format(
sum(p.numel() for p in self.parameters()))
)
print("number of parameters: {:e}".format(sum(p.numel() for p in self.parameters())))
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
self.num_inference_timesteps = config.num_inference_timesteps
# self.proj_to_action = nn.Identity()
self.noise_scheduler = DDIMScheduler(
num_train_timesteps=config.num_train_timesteps, # 100
beta_schedule='squaredcos_cap_v2',
beta_schedule="squaredcos_cap_v2",
clip_sample=True,
set_alpha_to_one=True,
steps_offset=0,
prediction_type='epsilon'
prediction_type="epsilon",
)
# self.num_inference_timesteps = config.num_inference_timesteps # 100
@ -235,25 +258,26 @@ class ConditionalUnet1D(PreTrainedModel):
"""
if actions is not None: # training time
B = actions.size(0)
actions = copy.deepcopy(actions[:, :self.num_queries])
is_pad = copy.deepcopy(is_pad[:, :self.num_queries])
actions = copy.deepcopy(actions[:, : self.num_queries])
is_pad = copy.deepcopy(is_pad[:, : self.num_queries])
num_noise_samples = self.noise_samples
# sample noise to add to actions
noise = torch.randn([num_noise_samples] + list(actions.shape), device=actions.device,
dtype=actions.dtype) # num_noise, B, Ta, D
noise = torch.randn(
[num_noise_samples] + list(actions.shape), device=actions.device, dtype=actions.dtype
) # num_noise, B, Ta, D
# sample a diffusion iteration for each data point
timesteps = torch.randint(
0, self.noise_scheduler.config.num_train_timesteps,
(B,), device=actions.device
0, self.noise_scheduler.config.num_train_timesteps, (B,), device=actions.device
).long()
timesteps, noise = timesteps.to(actions.device), noise.to(actions.device)
# add noise to the clean actions according to the noise magnitude at each diffusion iteration
# (this is the forward diffusion process)
noisy_actions = torch.cat([self.noise_scheduler.add_noise(
actions, noise[i], timesteps)
for i in range(len(noise))], dim=0) # [num_noise_samples * B, Ta, action_dim]
noisy_actions = torch.cat(
[self.noise_scheduler.add_noise(actions, noise[i], timesteps) for i in range(len(noise))],
dim=0,
) # [num_noise_samples * B, Ta, action_dim]
noisy_actions = noisy_actions.to(dtype=actions.dtype)
assert hidden_states.ndim == 3
@ -263,12 +287,14 @@ class ConditionalUnet1D(PreTrainedModel):
is_pad = is_pad.repeat(num_noise_samples, 1)
states = states.repeat(num_noise_samples, 1)
noise_pred = self.model_forward(noisy_actions, timesteps, global_cond=hidden_states, states=states)
noise_pred = self.model_forward(
noisy_actions, timesteps, global_cond=hidden_states, states=states
)
noise = noise.view(noise.size(0) * noise.size(1), *noise.size()[2:])
loss = torch.nn.functional.mse_loss(noise_pred, noise, reduction='none')
loss = torch.nn.functional.mse_loss(noise_pred, noise, reduction="none")
loss = (loss * ~is_pad.unsqueeze(-1)).mean()
# loss_dict['loss'] = loss
return {'loss': loss}
return {"loss": loss}
# return loss
else: # inference time
B = 1
@ -288,18 +314,14 @@ class ConditionalUnet1D(PreTrainedModel):
# inverse diffusion step (remove noise)
naction = self.noise_scheduler.step(
model_output=noise_pred,
timestep=k,
sample=naction
model_output=noise_pred, timestep=k, sample=naction
).prev_sample
return naction
def model_forward(self,
sample: torch.Tensor,
timestep: Union[torch.Tensor, float, int],
global_cond=None,
states=None):
def model_forward(
self, sample: torch.Tensor, timestep: Union[torch.Tensor, float, int], global_cond=None, states=None
):
"""
x: (B,T,input_dim)
timestep: (B,) or int, diffusion step
@ -327,9 +349,7 @@ class ConditionalUnet1D(PreTrainedModel):
global_feature = self.diffusion_step_encoder(timesteps)
if global_cond is not None:
global_feature = torch.cat([
global_feature, global_cond
], axis=-1)
global_feature = torch.cat([global_feature, global_cond], axis=-1)
x = sample
h = []
@ -355,4 +375,5 @@ class ConditionalUnet1D(PreTrainedModel):
# (B,T,C)
return x
AutoModel.register(UnetDiffusionPolicyConfig, ConditionalUnet1D)
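
At inference time both heads run the mirror-image loop: start from Gaussian noise and repeatedly call noise_scheduler.step to strip off the predicted noise. A minimal sketch of that loop with a placeholder denoiser; the set_timesteps call and loop structure are the usual diffusers pattern, assumed to match the elided parts of the hunk above.

import torch
from diffusers.schedulers.scheduling_ddim import DDIMScheduler

@torch.inference_mode()
def ddim_sample(predict_noise, horizon, action_dim, num_inference_timesteps=10):
    # predict_noise(naction, k) stands in for model_forward(..., timestep=k, ...).
    scheduler = DDIMScheduler(
        num_train_timesteps=100,
        beta_schedule="squaredcos_cap_v2",
        clip_sample=True,
        set_alpha_to_one=True,
        steps_offset=0,
        prediction_type="epsilon",
    )
    naction = torch.randn(1, horizon, action_dim)      # start from pure noise
    scheduler.set_timesteps(num_inference_timesteps)
    for k in scheduler.timesteps:
        noise_pred = predict_noise(naction, k)
        naction = scheduler.step(model_output=noise_pred, timestep=k, sample=naction).prev_sample
    return naction

actions = ddim_sample(lambda x, k: torch.zeros_like(x), horizon=16, action_dim=10)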

View File

@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -17,10 +16,10 @@
import os
from typing import Union
from transformers import AutoConfig
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation
from transformers.utils import logging
from transformers import AutoModel, AutoConfig
logger = logging.get_logger(__name__)
@ -56,7 +55,9 @@ class Qwen2VLVisionConfig(PretrainedConfig):
self.temporal_patch_size = temporal_patch_size
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
def from_pretrained(
cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
) -> "PretrainedConfig":
cls._set_token_in_kwargs(kwargs)
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
@ -64,7 +65,11 @@ class Qwen2VLVisionConfig(PretrainedConfig):
if config_dict.get("model_type") == "qwen2_vl":
config_dict = config_dict["vision_config"]
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
if (
"model_type" in config_dict
and hasattr(cls, "model_type")
and config_dict["model_type"] != cls.model_type
):
logger.warning(
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
@ -204,7 +209,7 @@ class Qwen2VLAConfig(PretrainedConfig):
vision_config=None,
rope_scaling=None,
# For loading policy head
policy_head_type='scale_dp_policy', # unet_diffusion_policy
policy_head_type="scale_dp_policy", # unet_diffusion_policy
**kwargs,
):
if isinstance(vision_config, dict):
@ -221,7 +226,7 @@ class Qwen2VLAConfig(PretrainedConfig):
self.use_sliding_window = use_sliding_window
self.sliding_window = sliding_window
self.max_window_layers = max_window_layers
self.policy_head_type = policy_head_type # for loading policy head
self.policy_head_type = policy_head_type # for loading policy head
# for backward compatibility
if num_key_value_heads is None:
@ -248,5 +253,5 @@ class Qwen2VLAConfig(PretrainedConfig):
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
from transformers import AutoConfig
AutoConfig.register("qwen2_vla", Qwen2VLAConfig)

View File

@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
@ -19,6 +18,7 @@
# limitations under the License.
"""PyTorch Qwen2-VL model."""
import gc
import math
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
@ -28,7 +28,7 @@ import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.nn import CrossEntropyLoss, LayerNorm
from transformers import AutoConfig, AutoModel
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, SlidingWindowCache, StaticCache
from transformers.generation import GenerationMixin
@ -37,8 +37,6 @@ from transformers.modeling_outputs import (
BaseModelOutputWithPast,
ModelOutput,
)
from lerobot.common.policies.dexvla.fusion_modules import *
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (
@ -49,14 +47,13 @@ from transformers.utils import (
logging,
replace_return_docstrings,
)
from .configuration_qwen2_vla import Qwen2VLAConfig, Qwen2VLVisionConfig
from transformers import AutoConfig, AutoModel
import gc
from lerobot.common.policies.dexvla.fusion_modules import *
from .configuration_qwen2_vla import Qwen2VLAConfig, Qwen2VLVisionConfig
if is_flash_attn_2_available():
from flash_attn import flash_attn_varlen_func
from transformers.modeling_flash_attention_utils import _flash_attention_forward
else:
flash_attn_varlen_func = None
@ -164,7 +161,9 @@ class Qwen2VLRotaryEmbedding(nn.Module):
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.max_seq_len_cached = seq_len
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
if (
seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len
): # reset
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
self.max_seq_len_cached = self.original_max_seq_len
@ -335,7 +334,9 @@ class VisionAttention(nn.Module):
self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
) -> torch.Tensor:
seq_length = hidden_states.shape[0]
q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
q, k, v = (
self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
)
q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
@ -369,7 +370,9 @@ class VisionFlashAttention2(nn.Module):
self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
) -> torch.Tensor:
seq_length = hidden_states.shape[0]
q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
q, k, v = (
self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
)
q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
@ -392,7 +395,9 @@ class VisionSdpaAttention(nn.Module):
self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
) -> torch.Tensor:
seq_length = hidden_states.shape[0]
q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
q, k, v = (
self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
)
q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
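
The vision attention variants above all share the fused-QKV pattern: a single linear produces queries, keys, and values, which are split apart by reshape, permute, and unbind. A toy sketch of just that split, with illustrative dimensions.

import torch
import torch.nn as nn

seq_len, num_heads, head_dim = 8, 4, 16
dim = num_heads * head_dim
qkv_proj = nn.Linear(dim, dim * 3, bias=True)

hidden_states = torch.randn(seq_len, dim)
# (seq, 3*dim) -> (seq, 3, heads, head_dim) -> (3, seq, heads, head_dim) -> unbind into q, k, v
q, k, v = qkv_proj(hidden_states).reshape(seq_len, 3, num_heads, -1).permute(1, 0, 2, 3).unbind(0)
assert q.shape == (seq_len, num_heads, head_dim)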
@ -538,7 +543,9 @@ class Qwen2VLAttention(nn.Module):
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
position_embeddings: Optional[
Tuple[torch.Tensor, torch.Tensor]
] = None, # will become mandatory in v4.46
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
@ -569,8 +576,14 @@ class Qwen2VLAttention(nn.Module):
)
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
cache_kwargs = {
"sin": sin,
"cos": cos,
"cache_position": cache_position,
} # Specific to RoPE models
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, cache_kwargs
)
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
@ -585,7 +598,9 @@ class Qwen2VLAttention(nn.Module):
# Fix precision issues in Qwen2-VL float16 inference
# Replace inf values with zeros in attention weights to prevent NaN propagation
if query_states.dtype == torch.float16:
attn_weights = torch.where(torch.isinf(attn_weights), torch.zeros_like(attn_weights), attn_weights)
attn_weights = torch.where(
torch.isinf(attn_weights), torch.zeros_like(attn_weights), attn_weights
)
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
@ -634,7 +649,9 @@ class Qwen2VLFlashAttention2(Qwen2VLAttention):
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
position_embeddings: Optional[
Tuple[torch.Tensor, torch.Tensor]
] = None, # will become mandatory in v4.46
):
bsz, q_len, _ = hidden_states.size()
@ -696,10 +713,18 @@ class Qwen2VLFlashAttention2(Qwen2VLAttention):
if attention_mask is not None:
attention_mask = attention_mask[:, slicing_tokens:]
attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
attention_mask = torch.cat(
[attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1
)
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
cache_kwargs = {
"sin": sin,
"cos": cos,
"cache_position": cache_position,
} # Specific to RoPE models
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, cache_kwargs
)
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
@ -781,7 +806,9 @@ class Qwen2VLSdpaAttention(Qwen2VLAttention):
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
position_embeddings: Optional[
Tuple[torch.Tensor, torch.Tensor]
] = None, # will become mandatory in v4.46
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions:
logger.warning_once(
@ -826,8 +853,14 @@ class Qwen2VLSdpaAttention(Qwen2VLAttention):
)
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
cache_kwargs = {
"sin": sin,
"cos": cos,
"cache_position": cache_position,
} # Specific to RoPE models
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, cache_kwargs
)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
@ -897,7 +930,9 @@ class Qwen2VLDecoderLayer(nn.Module):
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
position_embeddings: Optional[
Tuple[torch.Tensor, torch.Tensor]
] = None, # will become mandatory in v4.46
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
@ -1116,7 +1151,9 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_attentions = (
output_attentions if output_attentions is not None else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
@ -1208,7 +1245,9 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return tuple(
v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None
)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
@ -1436,6 +1475,7 @@ QWEN2_VL_INPUTS_DOCSTRING = r"""
The rope index difference between sequence length and multimodal rope.
"""
class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
@ -1599,9 +1639,15 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
t_index = (
torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
)
h_index = (
torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
)
w_index = (
torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
)
llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
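# A small worked example of the 3D (t, h, w) position indices built above for multimodal RoPE,
# assuming a single video frame split into a 2x3 patch grid (grid sizes are illustrative only):
import torch

llm_grid_t, llm_grid_h, llm_grid_w = 1, 2, 3
t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
# t_index -> [0, 0, 0, 0, 0, 0]
# h_index -> [0, 0, 0, 1, 1, 1]
# w_index -> [0, 1, 2, 0, 1, 2]
# Stacked together, every image token receives one (t, h, w) triple, offset by the preceding text length.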
@ -1721,18 +1767,20 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMixin):
```"""
self.computed_type = torch.bfloat16
input_ids=input_ids.to("cuda")
attention_mask=attention_mask.to("cuda")
input_ids = input_ids.to("cuda")
attention_mask = attention_mask.to("cuda")
if not is_eval:
labels = labels.to("cuda")
actions = actions.to(dtype=self.computed_type, device='cuda')
states = states.to(dtype=self.computed_type, device='cuda')
actions = actions.to(dtype=self.computed_type, device="cuda")
states = states.to(dtype=self.computed_type, device="cuda")
position_ids, rope_deltas = self.get_rope_index(
input_ids, image_grid_thw, video_grid_thw, attention_mask
)
if pixel_values is not None:
pixel_values = pixel_values.to(dtype=self.computed_type, device='cuda')
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
pixel_values = pixel_values.to(dtype=self.computed_type, device="cuda")
output_attentions = (
output_attentions if output_attentions is not None else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
@ -1792,7 +1840,7 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMixin):
)
hidden_states = outputs[0]
if tinyvla: # dex-vla supports tinyvla-style VLA
if tinyvla:  # DexVLA also supports TinyVLA-style inference: return the backbone hidden states directly
return hidden_states
if self.with_llm_head:
@ -1833,21 +1881,28 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMixin):
)
if self.using_film:
action_hidden_states = self.film_forward(labels=labels, input_ids=input_ids,
hidden_states=hidden_states)
else: # tinyvla
action_hidden_states = self.film_forward(
labels=labels, input_ids=input_ids, hidden_states=hidden_states
)
else: # tinyvla
action_hidden_states = hidden_states
ret = self.policy_head(actions=actions, hidden_states=action_hidden_states, states=states, is_pad=is_pad)
ret = self.policy_head(
actions=actions, hidden_states=action_hidden_states, states=states, is_pad=is_pad
)
if self.with_llm_head:
loss = {'loss': ret['loss'] + self.llm_loss_weight * llm_loss,
'llm_loss': llm_loss,
'action_loss': ret['loss']}
loss = {
"loss": ret["loss"] + self.llm_loss_weight * llm_loss,
"llm_loss": llm_loss,
"action_loss": ret["loss"],
}
else:
loss = {'loss': ret['loss'],
'llm_loss': (torch.ones(1)*(-100)).to(ret['loss'].dtype).squeeze(0),
'action_loss': ret['loss']}
loss = {
"loss": ret["loss"],
"llm_loss": (torch.ones(1) * (-100)).to(ret["loss"].dtype).squeeze(0),
"action_loss": ret["loss"],
}
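# How the loss dictionary built above is typically consumed in a training step (the surrounding
# loop is an assumption; only the dictionary keys come from this method):
#
#     out = policy.forward(...)                      # returns {"loss", "llm_loss", "action_loss"}
#     out["loss"].backward()                         # action loss + llm_loss_weight * LLM loss
#     print(out["llm_loss"].item(), out["action_loss"].item())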
if not return_dict:
output = (logits,) + outputs[1:]
@ -1904,30 +1959,32 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMixin):
return action_hidden_states
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
attention_mask=None,
inputs_embeds=None,
cache_position=None,
position_ids=None,
use_cache=True,
pixel_values=None,
pixel_values_videos=None,
image_grid_thw=None,
video_grid_thw=None,
**kwargs,
self,
input_ids,
past_key_values=None,
attention_mask=None,
inputs_embeds=None,
cache_position=None,
position_ids=None,
use_cache=True,
pixel_values=None,
pixel_values_videos=None,
image_grid_thw=None,
video_grid_thw=None,
**kwargs,
):
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
# Exception 1: when passing input_embeds, input_ids may be missing entries
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
if past_key_values is not None:
if inputs_embeds is not None: # Exception 1
input_ids = input_ids[:, -cache_position.shape[0]:]
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
input_ids = input_ids[:, -cache_position.shape[0] :]
elif (
input_ids.shape[1] != cache_position.shape[0]
): # Default case (the "else", a no op, is Exception 2)
input_ids = input_ids[:, cache_position]
rope_deltas = kwargs.get("rope_deltas", None)
rope_deltas = kwargs.get("rope_deltas")
if attention_mask is not None and position_ids is None:
if cache_position is None or (cache_position is not None and cache_position[0] == 0):
position_ids, rope_deltas = self.get_rope_index(
@ -1936,7 +1993,9 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMixin):
else:
batch_size, seq_length = input_ids.shape
delta = (
cache_position[0] + rope_deltas if cache_position is not None and rope_deltas is not None else 0
cache_position[0] + rope_deltas
if cache_position is not None and rope_deltas is not None
else 0
)
position_ids = torch.arange(seq_length, device=input_ids.device)
position_ids = position_ids.view(1, -1).expand(batch_size, -1)
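# Illustrative numbers for the cached-decoding branch above: with 20 prompt tokens already in the
# cache (cache_position[0] == 20) and rope_deltas == -3 returned earlier by get_rope_index, the
# newly generated token gets position 20 + (-3) = 17, keeping text positions contiguous with the
# multimodal RoPE indices assigned to the prompt (these numbers are assumptions, not taken from this diff).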
@ -1990,6 +2049,6 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMixin):
return model_inputs
from transformers import AutoModelForCausalLM
AutoModelForCausalLM.register(Qwen2VLAConfig, Qwen2VLForConditionalGenerationForVLA)
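# With the registration above, the VLA model resolves through the standard Auto* factory. A
# minimal usage sketch (constructing Qwen2VLAConfig with defaults is an assumption; real use
# passes the DexVLA settings):
config = Qwen2VLAConfig()
model = AutoModelForCausalLM.from_config(config)
assert isinstance(model, Qwen2VLForConditionalGenerationForVLA)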
View File
@ -1,16 +1,15 @@
from PIL import Image
import numpy as np
from torchvision.transforms.functional import to_pil_image, to_tensor
import torchvision.transforms as transforms
import torch
from qwen_vl_utils import process_vision_info
from PIL import Image
from qwen_vl_utils import fetch_image
class Qwen2VLAProcess:
def __init__(
self,
tokenizer=None,
max_seq_len=512,
multimodal_processor=None,
self,
tokenizer=None,
max_seq_len=512,
multimodal_processor=None,
):
super().__init__()
self.tokenizer = tokenizer
@ -20,10 +19,10 @@ class Qwen2VLAProcess:
def qwen2_image_preprocess(self, each):
ele = {}
each = Image.fromarray(each.squeeze(0).permute(1, 2, 0).cpu().numpy().astype(np.uint8))
ele['image'] = each
ele["image"] = each
ele['resized_height'] = each.height
ele['resized_width'] = each.width
ele["resized_height"] = each.height
ele["resized_width"] = each.width
each = fetch_image(ele)
return torch.from_numpy(np.array(each))
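# Usage sketch for the preprocessing step above, assuming a single camera frame stored as a
# (1, C, H, W) uint8-range tensor (the shape and dummy data are assumptions, not from this diff):
import torch

frame = torch.randint(0, 256, (1, 3, 480, 640), dtype=torch.uint8)
processed = Qwen2VLAProcess(tokenizer=None).qwen2_image_preprocess(frame)
# `processed` is an (H, W, C) uint8 tensor; qwen_vl_utils.fetch_image may snap its size to the
# patch-size multiples expected by the Qwen2-VL vision tower.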
@ -58,61 +57,63 @@ class Qwen2VLAProcess:
if eval:
return model_inputs
input_labels = torch.ones_like(model_inputs['input_ids']) * -100
input_labels = torch.ones_like(model_inputs["input_ids"]) * -100
if use_reasoning:
answer =reasoning + "Next action:" + '<|im_end|>'
answer = reasoning + "Next action:" + "<|im_end|>"
else:
answer = '' + '<|im_end|>'
answer = "" + "<|im_end|>"
output_text = self.tokenizer(answer, padding=True, return_tensors="pt")
output_labels = output_text['input_ids']
model_inputs['input_ids'] = torch.cat((model_inputs['input_ids'], output_text['input_ids']), dim=-1)
model_inputs['attention_mask'] = torch.cat((model_inputs['attention_mask'], output_text['attention_mask']), dim=-1)
output_labels = output_text["input_ids"]
model_inputs["input_ids"] = torch.cat((model_inputs["input_ids"], output_text["input_ids"]), dim=-1)
model_inputs["attention_mask"] = torch.cat(
(model_inputs["attention_mask"], output_text["attention_mask"]), dim=-1
)
labels = torch.cat((input_labels, output_labels), dim=-1)
data_dict['labels'] = labels
data_dict["labels"] = labels
for k, v in model_inputs.items():
data_dict[k] = v
return data_dict
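# The labelling scheme above in miniature: prompt tokens are masked with -100 (ignored by the LLM
# cross-entropy loss) and only the answer/reasoning tokens are supervised. Token ids below are
# placeholders:
import torch

prompt_ids = torch.tensor([[101, 102, 103]])
answer_ids = torch.tensor([[201, 202]])
labels = torch.cat((torch.full_like(prompt_ids, -100), answer_ids), dim=-1)
# labels -> tensor([[-100, -100, -100, 201, 202]]); loss is computed on the answer span only.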
def forward(self, batch, use_reasoning=True):
"""This is the main process function for processing vl data into Qwen2_vl format"""
all_images = batch['images']
all_images = torch.einsum('v b c h w -> b v c h w', all_images) # camera_views, batch_size, channel, height, width
all_images = batch["images"]
all_images = torch.einsum(
"v b c h w -> b v c h w", all_images
)  # (camera_views, batch_size, C, H, W) -> (batch_size, camera_views, C, H, W)
ret_l = []
for idx, images in enumerate(all_images):
raw_lang = batch['raw_langs'][idx]
reasoning = batch['reasonings'][idx]
raw_lang = batch["raw_langs"][idx]
reasoning = batch["reasonings"][idx]
ret_dict = self.single_forward_process(images, raw_lang, reasoning, use_reasoning=use_reasoning)
ret_l.append(ret_dict)
return self.post_process(ret_l)
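# Batch layout expected by forward above (keys are taken from this method; sizes and strings are
# placeholders):
#
#     batch = {
#         "images": torch.zeros(2, 1, 3, 480, 640),   # (camera_views, batch, C, H, W)
#         "raw_langs": ["fold the towel"],
#         "reasonings": ["first locate the towel edge"],
#     }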
def post_process(self, instances):
input_ids = [torch.flip(instance['input_ids'].squeeze(0), dims=[0]) for instance in instances]
labels = [torch.flip(instance['labels'].squeeze(0), dims=[0]) for instance in instances]
input_ids = [torch.flip(instance["input_ids"].squeeze(0), dims=[0]) for instance in instances]
labels = [torch.flip(instance["labels"].squeeze(0), dims=[0]) for instance in instances]
image_grid_thw = torch.stack([instances['image_grid_thw'] for instances in instances])
pixel_values = torch.stack([instances['pixel_values'] for instances in instances])
image_grid_thw = torch.stack([instance["image_grid_thw"] for instance in instances])
pixel_values = torch.stack([instance["pixel_values"] for instance in instances])
pixel_values_videos = None
video_grid_thw = None
labels = torch.nn.utils.rnn.pad_sequence(labels,
batch_first=True,
padding_value=-100)
labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100)
labels = torch.flip(labels, dims=[1])
input_ids = torch.nn.utils.rnn.pad_sequence(input_ids,
batch_first=True,
padding_value=self.tokenizer.pad_token_id)
input_ids = torch.nn.utils.rnn.pad_sequence(
input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
)
input_ids = torch.flip(input_ids, dims=[1])
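# The flip -> pad_sequence -> flip pattern above is a left-padding trick: pad_sequence only pads
# on the right, so reversing each sequence first and un-reversing the padded result yields
# left-padded batches, which decoder-only generation expects. Minimal illustration (values are
# placeholders):
import torch

a, b = torch.tensor([1, 2, 3]), torch.tensor([4, 5])
flipped = [torch.flip(x, dims=[0]) for x in (a, b)]
padded = torch.nn.utils.rnn.pad_sequence(flipped, batch_first=True, padding_value=0)
left_padded = torch.flip(padded, dims=[1])
# left_padded -> tensor([[1, 2, 3], [0, 4, 5]])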
b = input_ids.shape[0]
image_grid_thw = image_grid_thw.reshape(b * image_grid_thw.shape[1], image_grid_thw.shape[2])
pixel_values = pixel_values.reshape(b * pixel_values.shape[1], pixel_values.shape[2])
attention_mask = input_ids.ne(self.tokenizer.pad_token_id),
attention_mask = input_ids.ne(self.tokenizer.pad_token_id)  # (batch, seq_len) bool mask marking non-pad tokens
batch = dict(
input_ids=input_ids,
@ -126,7 +127,6 @@ class Qwen2VLAProcess:
return batch
def construct_chat_data(self, len_image, raw_lang):
messages = [
{
"role": "user",
@ -135,11 +135,13 @@ class Qwen2VLAProcess:
]
for i in range(len_image):
messages[0]['content'].append({
"type": "image",
"image": None,
})
messages[0]['content'].append({"type": "text", "text": f""})
messages[0]['content'][-1]['text'] = raw_lang
messages[0]["content"].append(
{
"type": "image",
"image": None,
}
)
messages[0]["content"].append({"type": "text", "text": ""})
messages[0]["content"][-1]["text"] = raw_lang
return messages
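# What construct_chat_data returns for two camera views plus a prompt, and how such a message
# list is typically turned into model inputs (the apply_chat_template call is standard
# Hugging Face processor usage, assumed rather than taken from this diff):
#
#     messages = [{"role": "user", "content": [
#         {"type": "image", "image": None},
#         {"type": "image", "image": None},
#         {"type": "text", "text": "pick up the red cube"},
#     ]}]
#     text = multimodal_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)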
View File
@ -24,9 +24,9 @@ from lerobot.common.datasets.utils import dataset_to_policy_features
from lerobot.common.envs.configs import EnvConfig
from lerobot.common.envs.utils import env_to_policy_features
from lerobot.common.policies.act.configuration_act import ACTConfig
from lerobot.common.policies.dexvla.configuration_dexvla import DexVLAConfig
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
from lerobot.common.policies.dexvla.configuration_dexvla import DexVLAConfig
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig