diff --git a/lerobot/common/policies/__init__.py b/lerobot/common/policies/__init__.py index 5a7d1b8a..d212ef7e 100644 --- a/lerobot/common/policies/__init__.py +++ b/lerobot/common/policies/__init__.py @@ -1,6 +1,6 @@ from .act.configuration_act import ACTConfig as ACTConfig +from .dexvla.configuration_dexvla import DexVLAConfig as DexVLAConfig from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig from .pi0.configuration_pi0 import PI0Config as PI0Config from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig -from .dexvla.configuration_dexvla import DexVLAConfig as DexVLAConfig diff --git a/lerobot/common/policies/dexvla/configuration_dexvla.py b/lerobot/common/policies/dexvla/configuration_dexvla.py index 6f3c0ef0..b39e74d1 100644 --- a/lerobot/common/policies/dexvla/configuration_dexvla.py +++ b/lerobot/common/policies/dexvla/configuration_dexvla.py @@ -1,4 +1,3 @@ -# coding=utf-8 # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,30 +13,28 @@ # limitations under the License. """Qwen2VL model configuration""" +from dataclasses import dataclass, field from typing import Tuple -from dataclasses import dataclass, field - from transformers import AutoConfig +from transformers.utils import logging from lerobot.common.optim.optimizers import AdamWConfig from lerobot.common.optim.schedulers import ( CosineDecayWithWarmupSchedulerConfig, ) -from transformers.utils import logging from lerobot.configs.policies import PreTrainedConfig -from lerobot.common.policies.dexvla.policy_heads.configuration_scaledp import ScaleDPPolicyConfig -from lerobot.common.policies.dexvla.policy_heads.configuration_unet_diffusion import UnetDiffusionPolicyConfig -from lerobot.common.policies.dexvla.qwe2_vla.configuration_qwen2_vla import Qwen2VLAConfig from lerobot.configs.types import NormalizationMode logger = logging.get_logger(__name__) + + @PreTrainedConfig.register_subclass("dexvla") @dataclass class DexVLAConfig(PreTrainedConfig): # For loading policy head - policy_head_type: str = 'scale_dp_policy' - policy_head_size: str = 'ScaleDP_L' + policy_head_type: str = "scale_dp_policy" + policy_head_size: str = "ScaleDP_L" action_dim: int = 14 state_dim: int = 14 chunk_size: int = 50 @@ -45,9 +42,9 @@ class DexVLAConfig(PreTrainedConfig): n_obs_steps: int = 1 hidden_size: int = 1536 - qwen2_vl_path: str = None # '/media/rl/HDD/data/weights/Qwen2-VL-2B-Instruct' + qwen2_vl_path: str = None # '/media/rl/HDD/data/weights/Qwen2-VL-2B-Instruct' - pretrained_path: str = None # pretrained dexvla + pretrained_path: str = None # pretrained dexvla using_film: bool = True llm_loss_weight: float = 1.0 with_llm_head: bool = True @@ -82,33 +79,37 @@ class DexVLAConfig(PreTrainedConfig): f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`" ) if self.using_reasoning: - assert self.using_film, f"using_reasoning requires `using_film=True`" - assert self.with_llm_head, f"using_reasoning requires `with_llm_head=True`" + assert self.using_film, "using_reasoning requires `using_film=True`" + assert self.with_llm_head, "using_reasoning requires `with_llm_head=True`" print("You have set using_reasoning=True, please make sure your data has key 'reasoning'.") else: - print(f"Warning:DexVLA recommends to use reasoning data which can better handle long-horizon and dexterous tasks. You can set 'using_reaasoning=True'.") + print( + "Warning:DexVLA recommends to use reasoning data which can better handle long-horizon and dexterous tasks. You can set 'using_reaasoning=True'." + ) if self.qwen2_vl_path is None: - raise ValueError("DexVLA is built on official qwen2_vl-2B. You have to download the official weights of qwen2_vl-2B first and set 'qwen2_vl_path'.") + raise ValueError( + "DexVLA is built on official qwen2_vl-2B. You have to download the official weights of qwen2_vl-2B first and set 'qwen2_vl_path'." + ) - if self.policy_head_type == 'scale_dp_policy': + if self.policy_head_type == "scale_dp_policy": self.policy_head_config = AutoConfig.for_model( model_type=self.policy_head_type, model_size=self.policy_head_size, cond_dim=self.hidden_size, action_dim=self.action_dim, prediction_horizon=self.chunk_size, - state_dim=self.state_dim + state_dim=self.state_dim, ) - elif self.policy_head_type == 'unet_diffusion': + elif self.policy_head_type == "unet_diffusion": self.policy_head_config = AutoConfig.for_model( model_type=self.policy_head_type, global_cond_dim=self.hidden_size, action_dim=self.action_dim, - state_dim=self.state_dim + state_dim=self.state_dim, ) else: - raise ValueError(f'Policy head type {self.policy_head_type} not supported') + raise ValueError(f"Policy head type {self.policy_head_type} not supported") self.qwen2_vla_config = AutoConfig.from_pretrained(self.qwen2_vl_path) @@ -152,6 +153,3 @@ class DexVLAConfig(PreTrainedConfig): @property def reward_delta_indices(self) -> None: return None - - - diff --git a/lerobot/common/policies/dexvla/fusion_modules.py b/lerobot/common/policies/dexvla/fusion_modules.py index 7eb452e0..0d977edc 100644 --- a/lerobot/common/policies/dexvla/fusion_modules.py +++ b/lerobot/common/policies/dexvla/fusion_modules.py @@ -1,16 +1,18 @@ import torch.nn as nn + class ActionProjector(nn.Module): def __init__(self, in_dim, out_dim=1024): - super(ActionProjector, self).__init__() + super().__init__() self.global_1d_pool = nn.AdaptiveAvgPool1d(1) - self.mlps = nn.ModuleList([ - # nn.LayerNorm(in_dim), - nn.Linear(in_dim, in_dim), - nn.GELU(), - nn.Linear(in_dim, out_dim), - nn.Dropout(0.0), - ] + self.mlps = nn.ModuleList( + [ + # nn.LayerNorm(in_dim), + nn.Linear(in_dim, in_dim), + nn.GELU(), + nn.Linear(in_dim, out_dim), + nn.Dropout(0.0), + ] ) def forward(self, x): @@ -22,7 +24,7 @@ class ActionProjector(nn.Module): class FiLM(nn.Module): def __init__(self, feature_dim, condition_dim): - super(FiLM, self).__init__() + super().__init__() self.scale_fc = nn.Linear(condition_dim, feature_dim) self.shift_fc = nn.Linear(condition_dim, feature_dim) diff --git a/lerobot/common/policies/dexvla/modeling_dexvla.py b/lerobot/common/policies/dexvla/modeling_dexvla.py index e9330a79..8734751c 100644 --- a/lerobot/common/policies/dexvla/modeling_dexvla.py +++ b/lerobot/common/policies/dexvla/modeling_dexvla.py @@ -1,18 +1,16 @@ -import torch -from torch import Tensor - -from lerobot.common.policies.normalize import Normalize, Unnormalize -from lerobot.common.policies.dexvla.configuration_dexvla import DexVLAConfig -from lerobot.common.policies.dexvla.qwe2_vla.modeling_qwen2_vla import ( - Qwen2VLForConditionalGenerationForVLA -) -from lerobot.common.policies.pretrained import PreTrainedPolicy from collections import deque -from lerobot.common.policies.dexvla.policy_heads.modeling_unet_diffusion import ConditionalUnet1D -from lerobot.common.policies.dexvla.policy_heads.modeling_scaledp import ScaleDP -from lerobot.common.policies.dexvla.robot_data_processor import Qwen2VLAProcess -from transformers import AutoProcessor, AutoTokenizer + +import torch import torchvision.transforms as transforms +from torch import Tensor +from transformers import AutoProcessor, AutoTokenizer + +from lerobot.common.policies.dexvla.configuration_dexvla import DexVLAConfig +from lerobot.common.policies.dexvla.qwe2_vla.modeling_qwen2_vla import Qwen2VLForConditionalGenerationForVLA +from lerobot.common.policies.dexvla.robot_data_processor import Qwen2VLAProcess +from lerobot.common.policies.normalize import Normalize, Unnormalize +from lerobot.common.policies.pretrained import PreTrainedPolicy + class DexVLAPolicy(PreTrainedPolicy): """Wrapper class around Qwen2VLForConditionalGenerationForVLA model to train and run inference within LeRobot.""" @@ -44,17 +42,17 @@ class DexVLAPolicy(PreTrainedPolicy): config.output_features, config.normalization_mapping, dataset_stats ) - for k in ['using_film', 'llm_loss_weight', 'with_llm_head', 'policy_head_config']: + for k in ["using_film", "llm_loss_weight", "with_llm_head", "policy_head_config"]: setattr(config.qwen2_vla_config, k, config.__dict__[k]) self.model = Qwen2VLForConditionalGenerationForVLA(config.qwen2_vla_config).to(torch.bfloat16) self.model.requires_grad_(False) self.model.policy_head.requires_grad_(True) self.qwen2_vl_processor = AutoProcessor.from_pretrained(config.qwen2_vl_path) - self.tokenizer = AutoTokenizer.from_pretrained( - config.qwen2_vl_path - ) - self.vla_processor = Qwen2VLAProcess(tokenizer=self.tokenizer, multimodal_processor=self.qwen2_vl_processor) # process the input data into VLM format + self.tokenizer = AutoTokenizer.from_pretrained(config.qwen2_vl_path) + self.vla_processor = Qwen2VLAProcess( + tokenizer=self.tokenizer, multimodal_processor=self.qwen2_vl_processor + ) # process the input data into VLM format self.resize_size = self.config.resize_size ratio = 0.95 @@ -73,14 +71,14 @@ class DexVLAPolicy(PreTrainedPolicy): batch = self.normalize_inputs(batch) batch = self.normalize_targets(batch) present_img_keys = [key for key in self.config.image_features if key in batch] - task_descs = batch['task'] + task_descs = batch["task"] try: - reasonings = batch['reasoning'] + reasonings = batch["reasoning"] except KeyError: - reasonings = ['no reasoning'] * len(task_descs) + reasonings = ["no reasoning"] * len(task_descs) pass - is_pad = batch['action_is_pad'] + is_pad = batch["action_is_pad"] all_cam_images = [] for k in present_img_keys: all_cam_images.append(batch[k]) @@ -89,8 +87,8 @@ class DexVLAPolicy(PreTrainedPolicy): image_data = torch.stack(all_cam_images) * 255 image_data = image_data.to(dtype=torch.uint8) # construct observations - qpos_data = batch['observation.state'].float() - action_data = batch['action'].float() + qpos_data = batch["observation.state"].float() + action_data = batch["action"].float() orig_shape = image_data.shape image_data = image_data.view(-1, *orig_shape[2:]) @@ -100,40 +98,35 @@ class DexVLAPolicy(PreTrainedPolicy): image_data = image_data.view(*orig_shape[:3], *self.resize_size) - vl_data = { - 'images': image_data, - 'raw_langs': task_descs, - 'reasonings': reasonings - } + vl_data = {"images": image_data, "raw_langs": task_descs, "reasonings": reasonings} # processing vl_data into qwen2_vl format vla_inputs = self.vla_processor.forward(vl_data, use_reasoning=self.config.using_reasoning) - vla_inputs['states'] = qpos_data - vla_inputs['is_pad'] = is_pad - vla_inputs['actions'] = action_data + vla_inputs["states"] = qpos_data + vla_inputs["is_pad"] = is_pad + vla_inputs["actions"] = action_data return vla_inputs - def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Tensor]]: - processed_batch = self.process_batch(batch) ret = self.model.forward(**processed_batch) - loss_dict = ret['loss'] - loss = loss_dict['loss'].mean() + loss_dict = ret["loss"] + loss = loss_dict["loss"].mean() return loss, loss_dict - def dexvla_predict_action(self, - input_ids: torch.LongTensor = None, - actions=None, - states=None, - is_pad=None, - tokenizer=None, - is_eval=True, - pixel_values=None, - attention_mask=None, - image_grid_thw=None, - ): - input_ids = input_ids.to('cuda') + def dexvla_predict_action( + self, + input_ids: torch.LongTensor = None, + actions=None, + states=None, + is_pad=None, + tokenizer=None, + is_eval=True, + pixel_values=None, + attention_mask=None, + image_grid_thw=None, + ): + input_ids = input_ids.to("cuda") with torch.inference_mode(): outputs = self.model.generate( input_ids, @@ -157,7 +150,7 @@ class DexVLAPolicy(PreTrainedPolicy): input_token_len = input_ids.shape[1] n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item() if n_diff_input_output > 0: - print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids') + print(f"[Warning] {n_diff_input_output} output_ids are not the same as the input_ids") outputs_text = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=False)[0] outputs_text = outputs_text.strip() @@ -167,35 +160,44 @@ class DexVLAPolicy(PreTrainedPolicy): action_hidden_states = None if self.model.using_film: - action_hidden_states = self.model.film_forward(labels=torch.ones_like(output_ids), - input_ids=output_ids, - hidden_states=torch.cat(last_hidden_states, dim=1)) + action_hidden_states = self.model.film_forward( + labels=torch.ones_like(output_ids), + input_ids=output_ids, + hidden_states=torch.cat(last_hidden_states, dim=1), + ) - action = self.model.policy_head(actions, action_hidden_states, states.to(all_hidden_states.dtype), is_pad) + action = self.model.policy_head( + actions, action_hidden_states, states.to(all_hidden_states.dtype), is_pad + ) return action, outputs_text - def tinyvla_predict_action(self, - input_ids: torch.LongTensor = None, - actions=None, - states=None, - is_pad=None, - is_eval=True, - pixel_values=None, - attention_mask=None, - image_grid_thw=None, - ): - input_ids = input_ids.to('cuda') + def tinyvla_predict_action( + self, + input_ids: torch.LongTensor = None, + actions=None, + states=None, + is_pad=None, + is_eval=True, + pixel_values=None, + attention_mask=None, + image_grid_thw=None, + ): + input_ids = input_ids.to("cuda") with torch.inference_mode(): - all_hidden_states = self.model.forward(input_ids, - pixel_values=pixel_values, - attention_mask=attention_mask, - image_grid_thw=image_grid_thw, - is_eval=is_eval, - tinyvla=True) + all_hidden_states = self.model.forward( + input_ids, + pixel_values=pixel_values, + attention_mask=attention_mask, + image_grid_thw=image_grid_thw, + is_eval=is_eval, + tinyvla=True, + ) all_hidden_states = torch.mean(all_hidden_states, dim=1).unsqueeze(1) - action = self.model.policy_head(actions, all_hidden_states, states.to(all_hidden_states.dtype), is_pad) + action = self.model.policy_head( + actions, all_hidden_states, states.to(all_hidden_states.dtype), is_pad + ) return action, "tinyvla generates no reasoning" def reset(self): @@ -219,7 +221,7 @@ class DexVLAPolicy(PreTrainedPolicy): if len(self._action_queue) == 0: present_img_keys = [key for key in self.config.image_features if key in batch] try: - task_descs = batch['task'] + task_descs = batch["task"] except KeyError: task_descs = " " print("No task descriptions found for this task") @@ -232,7 +234,7 @@ class DexVLAPolicy(PreTrainedPolicy): image_data = torch.stack(all_cam_images) * 255 image_data = image_data.to(dtype=torch.uint8) # construct observations - qpos_data = batch['observation.state'].float() + qpos_data = batch["observation.state"].float() image_data = image_data.squeeze(0) @@ -240,20 +242,19 @@ class DexVLAPolicy(PreTrainedPolicy): image_data = transform(image_data) # processing vl_data into qwen2_vl format - vla_inputs = self.vla_processor.single_forward_process(images=image_data, raw_lang=task_descs, reasoning=None, eval=True) - vla_inputs['states'] = qpos_data + vla_inputs = self.vla_processor.single_forward_process( + images=image_data, raw_lang=task_descs, reasoning=None, eval=True + ) + vla_inputs["states"] = qpos_data - if self.config.using_film and self.config.with_llm_head: # dexvla - all_actions, outputs = self.dexvla_predict_action(**vla_inputs, is_eval=True, tokenizer=self.tokenizer) - else: # tinyvla + if self.config.using_film and self.config.with_llm_head: # dexvla + all_actions, outputs = self.dexvla_predict_action( + **vla_inputs, is_eval=True, tokenizer=self.tokenizer + ) + else: # tinyvla all_actions, outputs = self.tinyvla_predict_action(**vla_inputs, is_eval=True) actions = self.unnormalize_outputs({"action": all_actions})["action"] self._action_queue.extend(actions.transpose(0, 1)) return self._action_queue.popleft() - - - - - diff --git a/lerobot/common/policies/dexvla/policy_heads/configuration_scaledp.py b/lerobot/common/policies/dexvla/policy_heads/configuration_scaledp.py index 0837f499..6a8f7ea9 100644 --- a/lerobot/common/policies/dexvla/policy_heads/configuration_scaledp.py +++ b/lerobot/common/policies/dexvla/policy_heads/configuration_scaledp.py @@ -1,47 +1,58 @@ import os -from typing import Union, List -from transformers import PretrainedConfig +from typing import Union +from transformers import AutoConfig, PretrainedConfig from transformers.utils import logging -from transformers import AutoConfig, AutoModelForCausalLM + logger = logging.get_logger(__name__) MODEL_STRUCTURE = { - 'ScaleDP_H': {'depth': 32, 'n_emb': 1280, 'num_heads': 16, }, - 'ScaleDP_L': {'depth': 24, 'n_emb': 1024, 'num_heads': 16, }, # 400M + "ScaleDP_H": { + "depth": 32, + "n_emb": 1280, + "num_heads": 16, + }, + "ScaleDP_L": { + "depth": 24, + "n_emb": 1024, + "num_heads": 16, + }, # 400M } + class ScaleDPPolicyConfig(PretrainedConfig): - ''' + """ Configuration for ScaleDP policy head - ''' + """ + model_type = "scale_dp_policy" + def __init__( - self, - eval: bool = False, - action_dim: int = 14, # action dim - # output_dim: int = 14, # action dim - cond_dim: int = 1536, # the input dim of the condition - state_dim: int = 14, # the input dim of the state - prediction_horizon: int = 16, # horizon - n_obs_steps: int = 2, # number of observation steps - depth: int = 28, # number of DiT blocks - n_emb: int = 256, # embedding size - num_heads: int = 16, - mlp_ratio: int = 4.0, - time_as_cond: bool = True, - obs_as_cond: bool = True, - learn_sigma: bool = False, - model_size: str = "none", - num_inference_timesteps: int = 10, - noise_samples: int = 1, - num_train_timesteps: int = 100, - **kwargs + self, + eval: bool = False, + action_dim: int = 14, # action dim + # output_dim: int = 14, # action dim + cond_dim: int = 1536, # the input dim of the condition + state_dim: int = 14, # the input dim of the state + prediction_horizon: int = 16, # horizon + n_obs_steps: int = 2, # number of observation steps + depth: int = 28, # number of DiT blocks + n_emb: int = 256, # embedding size + num_heads: int = 16, + mlp_ratio: int = 4.0, + time_as_cond: bool = True, + obs_as_cond: bool = True, + learn_sigma: bool = False, + model_size: str = "none", + num_inference_timesteps: int = 10, + noise_samples: int = 1, + num_train_timesteps: int = 100, + **kwargs, ): if model_size != "none": - depth = MODEL_STRUCTURE[model_size]['depth'] - n_emb = MODEL_STRUCTURE[model_size]['n_emb'] - num_heads = MODEL_STRUCTURE[model_size]['num_heads'] + depth = MODEL_STRUCTURE[model_size]["depth"] + n_emb = MODEL_STRUCTURE[model_size]["n_emb"] + num_heads = MODEL_STRUCTURE[model_size]["num_heads"] else: # raise ValueError("model_size show not be 'none'") pass @@ -52,7 +63,6 @@ class ScaleDPPolicyConfig(PretrainedConfig): self.output_dim = action_dim self.prediction_horizon = prediction_horizon - self.cond_dim = cond_dim self.state_dim = state_dim @@ -72,7 +82,9 @@ class ScaleDPPolicyConfig(PretrainedConfig): super().__init__(**kwargs) @classmethod - def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": + def from_pretrained( + cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs + ) -> "PretrainedConfig": cls._set_token_in_kwargs(kwargs) config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) @@ -81,7 +93,11 @@ class ScaleDPPolicyConfig(PretrainedConfig): if config_dict.get("model_type") == "llava_pythia": config_dict = config_dict["action_head"] - if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: + if ( + "model_type" in config_dict + and hasattr(cls, "model_type") + and config_dict["model_type"] != cls.model_type + ): logger.warning( f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." @@ -89,4 +105,5 @@ class ScaleDPPolicyConfig(PretrainedConfig): return cls.from_dict(config_dict, **kwargs) + AutoConfig.register("scale_dp_policy", ScaleDPPolicyConfig) diff --git a/lerobot/common/policies/dexvla/policy_heads/configuration_unet_diffusion.py b/lerobot/common/policies/dexvla/policy_heads/configuration_unet_diffusion.py index 38e403a6..aaf66447 100644 --- a/lerobot/common/policies/dexvla/policy_heads/configuration_unet_diffusion.py +++ b/lerobot/common/policies/dexvla/policy_heads/configuration_unet_diffusion.py @@ -1,31 +1,33 @@ import os -from typing import Union, List -from transformers import PretrainedConfig +from typing import Union +from transformers import AutoConfig, PretrainedConfig from transformers.utils import logging -from transformers import AutoConfig, AutoModelForCausalLM + logger = logging.get_logger(__name__) + class UnetDiffusionPolicyConfig(PretrainedConfig): - ''' + """ Configuration for dit diffusion policy head - ''' + """ + model_type = "unet_diffusion_policy" def __init__( - self, - action_dim=10, - global_cond_dim=2048, - diffusion_step_embed_dim=256, - down_dims=[256, 512, 1024], - kernel_size=5, - n_groups=8, - state_dim=7, - prediction_horizon=16, - noise_samples=1, - num_inference_timesteps=10, - num_train_timesteps=100, - **kwargs + self, + action_dim=10, + global_cond_dim=2048, + diffusion_step_embed_dim=256, + down_dims=[256, 512, 1024], + kernel_size=5, + n_groups=8, + state_dim=7, + prediction_horizon=16, + noise_samples=1, + num_inference_timesteps=10, + num_train_timesteps=100, + **kwargs, ): self.input_dim = action_dim self.noise_samples = noise_samples @@ -42,7 +44,9 @@ class UnetDiffusionPolicyConfig(PretrainedConfig): super().__init__(**kwargs) @classmethod - def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": + def from_pretrained( + cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs + ) -> "PretrainedConfig": cls._set_token_in_kwargs(kwargs) config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) @@ -51,7 +55,11 @@ class UnetDiffusionPolicyConfig(PretrainedConfig): if config_dict.get("model_type") == "llava_pythia": config_dict = config_dict["action_head"] - if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: + if ( + "model_type" in config_dict + and hasattr(cls, "model_type") + and config_dict["model_type"] != cls.model_type + ): logger.warning( f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." @@ -59,4 +67,5 @@ class UnetDiffusionPolicyConfig(PretrainedConfig): return cls.from_dict(config_dict, **kwargs) + AutoConfig.register("unet_diffusion_policy", UnetDiffusionPolicyConfig) diff --git a/lerobot/common/policies/dexvla/policy_heads/modeling_scaledp.py b/lerobot/common/policies/dexvla/policy_heads/modeling_scaledp.py index 4c78b6e1..b9a9b919 100644 --- a/lerobot/common/policies/dexvla/policy_heads/modeling_scaledp.py +++ b/lerobot/common/policies/dexvla/policy_heads/modeling_scaledp.py @@ -1,27 +1,24 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. -from typing import Tuple - -import timm -import numpy as np import logging - import math from typing import Tuple +import numpy as np + try: from typing import Literal except ImportError: - from typing_extensions import Literal + pass import torch import torch.nn as nn import torch.nn.functional as F import torch.utils.checkpoint -from torch.jit import Final from timm.models.vision_transformer import Mlp, use_fused_attn +from torch.jit import Final +from transformers import AutoModel from transformers.modeling_utils import PreTrainedModel -from transformers import AutoModel, AutoModelForCausalLM _logger = logging.getLogger(__name__) @@ -30,20 +27,20 @@ class Attention(nn.Module): fused_attn: Final[bool] def __init__( - self, - dim: int, - num_heads: int = 8, - qkv_bias: bool = False, - qk_norm: bool = False, - attn_drop: float = 0., - proj_drop: float = 0., - norm_layer: nn.Module = nn.LayerNorm, + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_norm: bool = False, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + norm_layer: nn.Module = nn.LayerNorm, ) -> None: super().__init__() - assert dim % num_heads == 0, 'dim should be divisible by num_heads' + assert dim % num_heads == 0, "dim should be divisible by num_heads" self.num_heads = num_heads self.head_dim = dim // num_heads - self.scale = self.head_dim ** -0.5 + self.scale = self.head_dim**-0.5 self.fused_attn = use_fused_attn() self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) @@ -61,8 +58,11 @@ class Attention(nn.Module): if self.fused_attn: x = F.scaled_dot_product_attention( - q, k, v, attn_mask=attn_mask, - dropout_p=self.attn_drop.p if self.training else 0., + q, + k, + v, + attn_mask=attn_mask, + dropout_p=self.attn_drop.p if self.training else 0.0, ) else: q = q * self.scale @@ -104,6 +104,7 @@ def modulate(x, shift, scale): # Embedding Layers for Timesteps and Class Labels # ################################################################################# + class TimestepEmbedder(nn.Module): """ Embeds scalar timesteps into vector representations. @@ -145,11 +146,11 @@ class TimestepEmbedder(nn.Module): return t_emb - ################################################################################# # Core ScaleDP Model # ################################################################################# + class ScaleDPBlock(nn.Module): """ A ScaleDP block with adaptive layer norm zero (adaLN-Zero) conScaleDPioning. @@ -163,14 +164,15 @@ class ScaleDPBlock(nn.Module): mlp_hidden_dim = int(hidden_size * mlp_ratio) approx_gelu = lambda: nn.GELU(approximate="tanh") self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) - self.adaLN_modulation = nn.Sequential( - nn.SiLU(), - nn.Linear(hidden_size, 6 * hidden_size, bias=True) - ) + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True)) def forward(self, x, c, attn_mask=None): - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1) - x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), attn_mask=attn_mask) # norm, scale&shift, attn, scale, + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk( + 6, dim=1 + ) + x = x + gate_msa.unsqueeze(1) * self.attn( + modulate(self.norm1(x), shift_msa, scale_msa), attn_mask=attn_mask + ) # norm, scale&shift, attn, scale, x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) return x @@ -184,10 +186,7 @@ class FinalLayer(nn.Module): super().__init__() self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.linear = nn.Linear(hidden_size, output_dim, bias=True) - self.adaLN_modulation = nn.Sequential( - nn.SiLU(), - nn.Linear(hidden_size, 2 * hidden_size, bias=True) - ) + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)) def forward(self, x, c): shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) @@ -195,15 +194,20 @@ class FinalLayer(nn.Module): x = self.linear(x) return x + from .configuration_scaledp import ScaleDPPolicyConfig + + class ScaleDP(PreTrainedModel): """ Diffusion models with a Transformer backbone. """ + config_class = ScaleDPPolicyConfig + def __init__( - self, - config: ScaleDPPolicyConfig, + self, + config: ScaleDPPolicyConfig, ): super().__init__(config) # compute number of tokens for main trunk and conScaleDPion encoder @@ -221,11 +225,11 @@ class ScaleDP(PreTrainedModel): # self.combine = nn.Linear(cond_dim+state_dim, cond_dim) self.combine = nn.Sequential( - nn.Linear(config.cond_dim+config.state_dim, 1024), + nn.Linear(config.cond_dim + config.state_dim, 1024), nn.ReLU(), nn.Linear(1024, 1024), nn.ReLU(), - nn.Linear(1024, config.cond_dim) + nn.Linear(1024, config.cond_dim), ) self.learn_sigma = config.learn_sigma self.input_dim = config.input_dim @@ -241,9 +245,12 @@ class ScaleDP(PreTrainedModel): # Will use fixed sin-cos embedding: self.pos_embed = nn.Parameter(torch.zeros(1, config.prediction_horizon, config.n_emb)) - self.blocks = nn.ModuleList([ - ScaleDPBlock(config.n_emb, config.num_heads, mlp_ratio=config.mlp_ratio) for _ in range(config.depth) - ]) + self.blocks = nn.ModuleList( + [ + ScaleDPBlock(config.n_emb, config.num_heads, mlp_ratio=config.mlp_ratio) + for _ in range(config.depth) + ] + ) self.final_layer = FinalLayer(config.n_emb, output_dim=config.output_dim) # self.initialize_weights() # constants @@ -253,23 +260,22 @@ class ScaleDP(PreTrainedModel): self.time_as_cond = config.time_as_cond self.action_dim = config.output_dim self.obs_as_cond = obs_as_cond - logger.info( - "number of parameters in ScaleDP: %e", sum(p.numel() for p in self.parameters()) - ) + logger.info("number of parameters in ScaleDP: %e", sum(p.numel() for p in self.parameters())) from diffusers.schedulers.scheduling_ddim import DDIMScheduler + self.num_inference_timesteps = config.num_inference_timesteps # self.proj_to_action = nn.Identity() self.noise_scheduler = DDIMScheduler( - num_train_timesteps=config.num_train_timesteps, # 100 - beta_schedule='squaredcos_cap_v2', + num_train_timesteps=config.num_train_timesteps, # 100 + beta_schedule="squaredcos_cap_v2", clip_sample=True, set_alpha_to_one=True, steps_offset=0, - prediction_type='epsilon' + prediction_type="epsilon", ) - self.num_queries = config.num_queries #16 - self.noise_samples = config.noise_samples # 1 + self.num_queries = config.num_queries # 16 + self.noise_samples = config.noise_samples # 1 # self.num_inference_timesteps = config.num_inference_timesteps # 100 def initialize_weights(self): @@ -308,7 +314,6 @@ class ScaleDP(PreTrainedModel): nn.init.constant_(self.final_layer.linear.weight, 0) nn.init.constant_(self.final_layer.linear.bias, 0) - def get_optim_groups(self, weight_decay: float = 1e-3): """ This long function is unfortunately doing something very simple and is being very defensive: @@ -324,7 +329,7 @@ class ScaleDP(PreTrainedModel): blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding) for mn, m in self.named_modules(): for pn, p in m.named_parameters(): - fpn = "%s.%s" % (mn, pn) if mn else pn # full param name + fpn = "{}.{}".format(mn, pn) if mn else pn # full param name if pn.endswith("bias"): # all biases will not be decayed @@ -343,13 +348,13 @@ class ScaleDP(PreTrainedModel): param_dict = {pn: p for pn, p in self.named_parameters()} inter_params = decay & no_decay union_params = decay | no_decay - assert ( - len(inter_params) == 0 - ), "parameters %s made it into both decay/no_decay sets!" % (str(inter_params),) - assert ( - len(param_dict.keys() - union_params) == 0 - ), "parameters %s were not separated into either decay/no_decay set!" % ( - str(param_dict.keys() - union_params), + assert len(inter_params) == 0, "parameters {} made it into both decay/no_decay sets!".format( + str(inter_params) + ) + assert len(param_dict.keys() - union_params) == 0, ( + "parameters {} were not separated into either decay/no_decay set!".format( + str(param_dict.keys() - union_params), + ) ) # create the pytorch optimizer object @@ -365,14 +370,14 @@ class ScaleDP(PreTrainedModel): ] return optim_groups - def configure_optimizers(self, - learning_rate: float = 1e-4, - weight_decay: float = 1e-3, - betas: Tuple[float, float] = (0.9, 0.95)): + def configure_optimizers( + self, + learning_rate: float = 1e-4, + weight_decay: float = 1e-3, + betas: Tuple[float, float] = (0.9, 0.95), + ): optim_groups = self.get_optim_groups(weight_decay=weight_decay) - optimizer = torch.optim.AdamW( - optim_groups, lr=learning_rate, betas=betas - ) + optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas) return optimizer def forward(self, actions, hidden_states, states, is_pad): @@ -385,25 +390,26 @@ class ScaleDP(PreTrainedModel): """ if actions is not None: # training time B = actions.size(0) - actions = actions[:, :self.num_queries] - is_pad = is_pad[:, :self.num_queries] + actions = actions[:, : self.num_queries] + is_pad = is_pad[:, : self.num_queries] num_noise_samples = self.noise_samples # sample noise to add to actions - noise = torch.randn([num_noise_samples] + list(actions.shape), device=actions.device, - dtype=actions.dtype) # num_noise, B, Ta, D(1, 2, 16, 14) + noise = torch.randn( + [num_noise_samples] + list(actions.shape), device=actions.device, dtype=actions.dtype + ) # num_noise, B, Ta, D(1, 2, 16, 14) # sample a diffusion iteration for each data point timesteps = torch.randint( - 0, self.noise_scheduler.config.num_train_timesteps, - (B,), device=actions.device + 0, self.noise_scheduler.config.num_train_timesteps, (B,), device=actions.device ).long() timesteps, noise = timesteps.to(actions.device), noise.to(actions.device) # add noise to the clean actions according to the noise magnitude at each diffusion iteration # (this is the forward diffusion process) - noisy_actions = torch.cat([self.noise_scheduler.add_noise( - actions, noise[i], timesteps) - for i in range(len(noise))], dim=0) # [num_noise_samples * B, Ta, action_dim] + noisy_actions = torch.cat( + [self.noise_scheduler.add_noise(actions, noise[i], timesteps) for i in range(len(noise))], + dim=0, + ) # [num_noise_samples * B, Ta, action_dim] noisy_actions = noisy_actions.to(dtype=actions.dtype) assert hidden_states.ndim == 3 @@ -411,14 +417,16 @@ class ScaleDP(PreTrainedModel): hidden_states = hidden_states.repeat(num_noise_samples, 1, 1) timesteps = timesteps.repeat(num_noise_samples) is_pad = is_pad.repeat(num_noise_samples, 1) - states = states.repeat(num_noise_samples, 1) + states = states.repeat(num_noise_samples, 1) - noise_pred = self.model_forward(noisy_actions, timesteps, global_cond=hidden_states, states=states) + noise_pred = self.model_forward( + noisy_actions, timesteps, global_cond=hidden_states, states=states + ) noise = noise.view(noise.size(0) * noise.size(1), *noise.size()[2:]) - loss = torch.nn.functional.mse_loss(noise_pred, noise, reduction='none') + loss = torch.nn.functional.mse_loss(noise_pred, noise, reduction="none") loss = (loss * ~is_pad.unsqueeze(-1)).mean() # loss_dict['loss'] = loss - return {'loss': loss} + return {"loss": loss} # return loss else: # inference time B = 1 @@ -438,9 +446,7 @@ class ScaleDP(PreTrainedModel): # inverse diffusion step (remove noise) naction = self.noise_scheduler.step( - model_output=noise_pred, - timestep=k, - sample=naction + model_output=noise_pred, timestep=k, sample=naction ).prev_sample return naction @@ -462,7 +468,9 @@ class ScaleDP(PreTrainedModel): t = t[None].to(x.device) t = t.expand(t.shape[0]) - x = self.x_embedder(x) + self.pos_embed.to(device=x.device, dtype=x.dtype) # (N, T, D), where T = prediction_horizon + x = self.x_embedder(x) + self.pos_embed.to( + device=x.device, dtype=x.dtype + ) # (N, T, D), where T = prediction_horizon t = self.t_embedder(t) # (N, D) if self.obs_as_cond: global_cond = self.cond_obs_emb(global_cond) # (N, D) @@ -474,11 +482,13 @@ class ScaleDP(PreTrainedModel): x = self.final_layer(x, c) # (N, T, output_dim) return x + ################################################################################# # Sine/Cosine Positional Embedding Functions # ################################################################################# # https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py + def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): """ grid_size: int of the grid height and width @@ -516,11 +526,11 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): """ assert embed_dim % 2 == 0 omega = np.arange(embed_dim // 2, dtype=np.float64) - omega /= embed_dim / 2. - omega = 1. / 10000 ** omega # (D/2,) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) pos = pos.reshape(-1) # (M,) - out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product emb_sin = np.sin(out) # (M, D/2) emb_cos = np.cos(out) # (M, D/2) @@ -533,12 +543,13 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): # ScaleDP Configs # ################################################################################# + def ScaleDP_H(**kwargs): return ScaleDP(depth=32, n_emb=1280, num_heads=16, **kwargs) + def ScaleDP_L(**kwargs): return ScaleDP(depth=24, n_emb=1024, num_heads=16, **kwargs) - AutoModel.register(ScaleDPPolicyConfig, ScaleDP) diff --git a/lerobot/common/policies/dexvla/policy_heads/modeling_unet_diffusion.py b/lerobot/common/policies/dexvla/policy_heads/modeling_unet_diffusion.py index a7b456d2..eba83e36 100644 --- a/lerobot/common/policies/dexvla/policy_heads/modeling_unet_diffusion.py +++ b/lerobot/common/policies/dexvla/policy_heads/modeling_unet_diffusion.py @@ -1,29 +1,29 @@ """ Implementation of Diffusion Policy https://diffusion-policy.cs.columbia.edu/ by Cheng Chi """ -from typing import Callable, Union + +import copy import math -from collections import OrderedDict, deque -from packaging.version import parse as parse_version -import random +from typing import Union + import torch import torch.nn as nn -import torch.nn.functional as F + # requires diffusers==0.11.1 -from diffusers.schedulers.scheduling_ddpm import DDPMScheduler from diffusers.schedulers.scheduling_ddim import DDIMScheduler -from diffusers.training_utils import EMAModel -from .configuration_unet_diffusion import UnetDiffusionPolicyConfig +from transformers import AutoModel from transformers.modeling_utils import PreTrainedModel -from transformers import AutoModel, AutoModelForCausalLM -import copy + +from .configuration_unet_diffusion import UnetDiffusionPolicyConfig + # =================== UNet for Diffusion ============== + class SinusoidalPosEmb(nn.Module): def __init__(self, dim, dtype): super().__init__() self.dim = dim - self.dtype=dtype + self.dtype = dtype def forward(self, x): device = x.device @@ -54,9 +54,9 @@ class Upsample1d(nn.Module): class Conv1dBlock(nn.Module): - ''' - Conv1d --> GroupNorm --> Mish - ''' + """ + Conv1d --> GroupNorm --> Mish + """ def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8): super().__init__() @@ -72,46 +72,41 @@ class Conv1dBlock(nn.Module): class ConditionalResidualBlock1D(nn.Module): - def __init__(self, - in_channels, - out_channels, - cond_dim, - kernel_size=3, - n_groups=8): + def __init__(self, in_channels, out_channels, cond_dim, kernel_size=3, n_groups=8): super().__init__() - self.blocks = nn.ModuleList([ - Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups), - Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups), - ]) + self.blocks = nn.ModuleList( + [ + Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups), + Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups), + ] + ) # FiLM modulation https://arxiv.org/abs/1709.07871 # predicts per-channel scale and bias cond_channels = out_channels * 2 self.out_channels = out_channels self.cond_encoder = nn.Sequential( - nn.Mish(), - nn.Linear(cond_dim, cond_channels), - nn.Unflatten(-1, (-1, 1)) + nn.Mish(), nn.Linear(cond_dim, cond_channels), nn.Unflatten(-1, (-1, 1)) ) # make sure dimensions compatible - self.residual_conv = nn.Conv1d(in_channels, out_channels, 1) \ - if in_channels != out_channels else nn.Identity() + self.residual_conv = ( + nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity() + ) def forward(self, x, cond): - ''' - x : [ batch_size x in_channels x horizon ] - cond : [ batch_size x cond_dim] + """ + x : [ batch_size x in_channels x horizon ] + cond : [ batch_size x cond_dim] - returns: - out : [ batch_size x out_channels x horizon ] - ''' + returns: + out : [ batch_size x out_channels x horizon ] + """ out = self.blocks[0](x) embed = self.cond_encoder(cond) - embed = embed.reshape( - embed.shape[0], 2, self.out_channels, 1) + embed = embed.reshape(embed.shape[0], 2, self.out_channels, 1) scale = embed[:, 0, ...] bias = embed[:, 1, ...] out = scale * out + bias @@ -125,9 +120,8 @@ class ConditionalUnet1D(PreTrainedModel): _no_split_modules = ["mid_modules", "down_modules", "up_modules"] config_class = UnetDiffusionPolicyConfig - def __init__(self, - config: UnetDiffusionPolicyConfig - ): + + def __init__(self, config: UnetDiffusionPolicyConfig): """ input_dim: Dim of actions. global_cond_dim: Dim of global conditioning applied with FiLM @@ -148,7 +142,7 @@ class ConditionalUnet1D(PreTrainedModel): # self.global_1d_pool = nn.AdaptiveAvgPool1d(1) # self.proj2action = nn.Linear(config.hidden_dim, config.global_cond_dim) self.norm_after_pool = nn.LayerNorm(config.global_cond_dim) - self.combine = nn.Linear(config.global_cond_dim+config.state_dim, config.global_cond_dim) + self.combine = nn.Linear(config.global_cond_dim + config.state_dim, config.global_cond_dim) dsed = config.diffusion_step_embed_dim diffusion_step_encoder = nn.Sequential( SinusoidalPosEmb(dsed, torch.bfloat16), @@ -158,44 +152,76 @@ class ConditionalUnet1D(PreTrainedModel): ) cond_dim = dsed + config.global_cond_dim - in_out = list(zip(all_dims[:-1], all_dims[1:])) + in_out = list(zip(all_dims[:-1], all_dims[1:], strict=False)) mid_dim = all_dims[-1] - self.mid_modules = nn.ModuleList([ - ConditionalResidualBlock1D( - mid_dim, mid_dim, cond_dim=cond_dim, - kernel_size=config.kernel_size, n_groups=config.n_groups - ), - ConditionalResidualBlock1D( - mid_dim, mid_dim, cond_dim=cond_dim, - kernel_size=config.kernel_size, n_groups=config.n_groups - ), - ]) + self.mid_modules = nn.ModuleList( + [ + ConditionalResidualBlock1D( + mid_dim, + mid_dim, + cond_dim=cond_dim, + kernel_size=config.kernel_size, + n_groups=config.n_groups, + ), + ConditionalResidualBlock1D( + mid_dim, + mid_dim, + cond_dim=cond_dim, + kernel_size=config.kernel_size, + n_groups=config.n_groups, + ), + ] + ) down_modules = nn.ModuleList([]) for ind, (dim_in, dim_out) in enumerate(in_out): is_last = ind >= (len(in_out) - 1) - down_modules.append(nn.ModuleList([ - ConditionalResidualBlock1D( - dim_in, dim_out, cond_dim=cond_dim, - kernel_size=config.kernel_size, n_groups=config.n_groups), - ConditionalResidualBlock1D( - dim_out, dim_out, cond_dim=cond_dim, - kernel_size=config.kernel_size, n_groups=config.n_groups), - Downsample1d(dim_out) if not is_last else nn.Identity() - ])) + down_modules.append( + nn.ModuleList( + [ + ConditionalResidualBlock1D( + dim_in, + dim_out, + cond_dim=cond_dim, + kernel_size=config.kernel_size, + n_groups=config.n_groups, + ), + ConditionalResidualBlock1D( + dim_out, + dim_out, + cond_dim=cond_dim, + kernel_size=config.kernel_size, + n_groups=config.n_groups, + ), + Downsample1d(dim_out) if not is_last else nn.Identity(), + ] + ) + ) up_modules = nn.ModuleList([]) for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): is_last = ind >= (len(in_out) - 1) - up_modules.append(nn.ModuleList([ - ConditionalResidualBlock1D( - dim_out * 2, dim_in, cond_dim=cond_dim, - kernel_size=config.kernel_size, n_groups=config.n_groups), - ConditionalResidualBlock1D( - dim_in, dim_in, cond_dim=cond_dim, - kernel_size=config.kernel_size, n_groups=config.n_groups), - Upsample1d(dim_in) if not is_last else nn.Identity() - ])) + up_modules.append( + nn.ModuleList( + [ + ConditionalResidualBlock1D( + dim_out * 2, + dim_in, + cond_dim=cond_dim, + kernel_size=config.kernel_size, + n_groups=config.n_groups, + ), + ConditionalResidualBlock1D( + dim_in, + dim_in, + cond_dim=cond_dim, + kernel_size=config.kernel_size, + n_groups=config.n_groups, + ), + Upsample1d(dim_in) if not is_last else nn.Identity(), + ] + ) + ) final_conv = nn.Sequential( Conv1dBlock(start_dim, start_dim, kernel_size=config.kernel_size), @@ -207,20 +233,17 @@ class ConditionalUnet1D(PreTrainedModel): self.down_modules = down_modules self.final_conv = final_conv - print("number of parameters: {:e}".format( - sum(p.numel() for p in self.parameters())) - ) + print("number of parameters: {:e}".format(sum(p.numel() for p in self.parameters()))) - from diffusers.schedulers.scheduling_ddim import DDIMScheduler self.num_inference_timesteps = config.num_inference_timesteps # self.proj_to_action = nn.Identity() self.noise_scheduler = DDIMScheduler( num_train_timesteps=config.num_train_timesteps, # 100 - beta_schedule='squaredcos_cap_v2', + beta_schedule="squaredcos_cap_v2", clip_sample=True, set_alpha_to_one=True, steps_offset=0, - prediction_type='epsilon' + prediction_type="epsilon", ) # self.num_inference_timesteps = config.num_inference_timesteps # 100 @@ -235,25 +258,26 @@ class ConditionalUnet1D(PreTrainedModel): """ if actions is not None: # training time B = actions.size(0) - actions = copy.deepcopy(actions[:, :self.num_queries]) - is_pad = copy.deepcopy(is_pad[:, :self.num_queries]) + actions = copy.deepcopy(actions[:, : self.num_queries]) + is_pad = copy.deepcopy(is_pad[:, : self.num_queries]) num_noise_samples = self.noise_samples # sample noise to add to actions - noise = torch.randn([num_noise_samples] + list(actions.shape), device=actions.device, - dtype=actions.dtype) # num_noise, B, Ta, D + noise = torch.randn( + [num_noise_samples] + list(actions.shape), device=actions.device, dtype=actions.dtype + ) # num_noise, B, Ta, D # sample a diffusion iteration for each data point timesteps = torch.randint( - 0, self.noise_scheduler.config.num_train_timesteps, - (B,), device=actions.device + 0, self.noise_scheduler.config.num_train_timesteps, (B,), device=actions.device ).long() timesteps, noise = timesteps.to(actions.device), noise.to(actions.device) # add noise to the clean actions according to the noise magnitude at each diffusion iteration # (this is the forward diffusion process) - noisy_actions = torch.cat([self.noise_scheduler.add_noise( - actions, noise[i], timesteps) - for i in range(len(noise))], dim=0) # [num_noise_samples * B, Ta, action_dim] + noisy_actions = torch.cat( + [self.noise_scheduler.add_noise(actions, noise[i], timesteps) for i in range(len(noise))], + dim=0, + ) # [num_noise_samples * B, Ta, action_dim] noisy_actions = noisy_actions.to(dtype=actions.dtype) assert hidden_states.ndim == 3 @@ -263,12 +287,14 @@ class ConditionalUnet1D(PreTrainedModel): is_pad = is_pad.repeat(num_noise_samples, 1) states = states.repeat(num_noise_samples, 1) - noise_pred = self.model_forward(noisy_actions, timesteps, global_cond=hidden_states, states=states) + noise_pred = self.model_forward( + noisy_actions, timesteps, global_cond=hidden_states, states=states + ) noise = noise.view(noise.size(0) * noise.size(1), *noise.size()[2:]) - loss = torch.nn.functional.mse_loss(noise_pred, noise, reduction='none') + loss = torch.nn.functional.mse_loss(noise_pred, noise, reduction="none") loss = (loss * ~is_pad.unsqueeze(-1)).mean() # loss_dict['loss'] = loss - return {'loss': loss} + return {"loss": loss} # return loss else: # inference time B = 1 @@ -288,18 +314,14 @@ class ConditionalUnet1D(PreTrainedModel): # inverse diffusion step (remove noise) naction = self.noise_scheduler.step( - model_output=noise_pred, - timestep=k, - sample=naction + model_output=noise_pred, timestep=k, sample=naction ).prev_sample return naction - def model_forward(self, - sample: torch.Tensor, - timestep: Union[torch.Tensor, float, int], - global_cond=None, - states=None): + def model_forward( + self, sample: torch.Tensor, timestep: Union[torch.Tensor, float, int], global_cond=None, states=None + ): """ x: (B,T,input_dim) timestep: (B,) or int, diffusion step @@ -327,9 +349,7 @@ class ConditionalUnet1D(PreTrainedModel): global_feature = self.diffusion_step_encoder(timesteps) if global_cond is not None: - global_feature = torch.cat([ - global_feature, global_cond - ], axis=-1) + global_feature = torch.cat([global_feature, global_cond], axis=-1) x = sample h = [] @@ -355,4 +375,5 @@ class ConditionalUnet1D(PreTrainedModel): # (B,T,C) return x + AutoModel.register(UnetDiffusionPolicyConfig, ConditionalUnet1D) diff --git a/lerobot/common/policies/dexvla/qwe2_vla/configuration_qwen2_vla.py b/lerobot/common/policies/dexvla/qwe2_vla/configuration_qwen2_vla.py index a1a1d81f..f6b46350 100644 --- a/lerobot/common/policies/dexvla/qwe2_vla/configuration_qwen2_vla.py +++ b/lerobot/common/policies/dexvla/qwe2_vla/configuration_qwen2_vla.py @@ -1,4 +1,3 @@ -# coding=utf-8 # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -17,10 +16,10 @@ import os from typing import Union +from transformers import AutoConfig from transformers.configuration_utils import PretrainedConfig from transformers.modeling_rope_utils import rope_config_validation from transformers.utils import logging -from transformers import AutoModel, AutoConfig logger = logging.get_logger(__name__) @@ -56,7 +55,9 @@ class Qwen2VLVisionConfig(PretrainedConfig): self.temporal_patch_size = temporal_patch_size @classmethod - def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": + def from_pretrained( + cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs + ) -> "PretrainedConfig": cls._set_token_in_kwargs(kwargs) config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) @@ -64,7 +65,11 @@ class Qwen2VLVisionConfig(PretrainedConfig): if config_dict.get("model_type") == "qwen2_vl": config_dict = config_dict["vision_config"] - if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: + if ( + "model_type" in config_dict + and hasattr(cls, "model_type") + and config_dict["model_type"] != cls.model_type + ): logger.warning( f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." @@ -204,7 +209,7 @@ class Qwen2VLAConfig(PretrainedConfig): vision_config=None, rope_scaling=None, # For loading policy head - policy_head_type='scale_dp_policy', # unet_diffusion_policy + policy_head_type="scale_dp_policy", # unet_diffusion_policy **kwargs, ): if isinstance(vision_config, dict): @@ -221,7 +226,7 @@ class Qwen2VLAConfig(PretrainedConfig): self.use_sliding_window = use_sliding_window self.sliding_window = sliding_window self.max_window_layers = max_window_layers - self.policy_head_type = policy_head_type # for loading policy head + self.policy_head_type = policy_head_type # for loading policy head # for backward compatibility if num_key_value_heads is None: @@ -248,5 +253,5 @@ class Qwen2VLAConfig(PretrainedConfig): super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) -from transformers import AutoConfig + AutoConfig.register("qwen2_vla", Qwen2VLAConfig) diff --git a/lerobot/common/policies/dexvla/qwe2_vla/modeling_qwen2_vla.py b/lerobot/common/policies/dexvla/qwe2_vla/modeling_qwen2_vla.py index e37fea19..235c66a3 100644 --- a/lerobot/common/policies/dexvla/qwe2_vla/modeling_qwen2_vla.py +++ b/lerobot/common/policies/dexvla/qwe2_vla/modeling_qwen2_vla.py @@ -1,4 +1,3 @@ -# coding=utf-8 # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. # # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX @@ -19,6 +18,7 @@ # limitations under the License. """PyTorch Qwen2-VL model.""" +import gc import math from dataclasses import dataclass from typing import Any, Dict, List, Optional, Tuple, Union @@ -28,7 +28,7 @@ import torch.nn as nn import torch.nn.functional as F import torch.utils.checkpoint from torch.nn import CrossEntropyLoss, LayerNorm - +from transformers import AutoConfig, AutoModel from transformers.activations import ACT2FN from transformers.cache_utils import Cache, SlidingWindowCache, StaticCache from transformers.generation import GenerationMixin @@ -37,8 +37,6 @@ from transformers.modeling_outputs import ( BaseModelOutputWithPast, ModelOutput, ) -from lerobot.common.policies.dexvla.fusion_modules import * - from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS from transformers.modeling_utils import PreTrainedModel from transformers.utils import ( @@ -49,14 +47,13 @@ from transformers.utils import ( logging, replace_return_docstrings, ) -from .configuration_qwen2_vla import Qwen2VLAConfig, Qwen2VLVisionConfig -from transformers import AutoConfig, AutoModel -import gc +from lerobot.common.policies.dexvla.fusion_modules import * + +from .configuration_qwen2_vla import Qwen2VLAConfig, Qwen2VLVisionConfig if is_flash_attn_2_available(): from flash_attn import flash_attn_varlen_func - from transformers.modeling_flash_attention_utils import _flash_attention_forward else: flash_attn_varlen_func = None @@ -161,10 +158,12 @@ class Qwen2VLRotaryEmbedding(nn.Module): inv_freq, self.attention_scaling = self.rope_init_fn( self.config, device, seq_len=seq_len, **self.rope_kwargs ) - self.register_buffer("inv_freq", inv_freq, persistent=False) + self.register_buffer("inv_freq", inv_freq, persistent=False) self.max_seq_len_cached = seq_len - if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + if ( + seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len + ): # reset self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) self.max_seq_len_cached = self.original_max_seq_len @@ -335,7 +334,9 @@ class VisionAttention(nn.Module): self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None ) -> torch.Tensor: seq_length = hidden_states.shape[0] - q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) + q, k, v = ( + self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) + ) q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0) k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0) @@ -369,7 +370,9 @@ class VisionFlashAttention2(nn.Module): self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None ) -> torch.Tensor: seq_length = hidden_states.shape[0] - q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) + q, k, v = ( + self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) + ) q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0) k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0) @@ -392,7 +395,9 @@ class VisionSdpaAttention(nn.Module): self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None ) -> torch.Tensor: seq_length = hidden_states.shape[0] - q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) + q, k, v = ( + self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) + ) q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0) k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0) @@ -538,7 +543,9 @@ class Qwen2VLAttention(nn.Module): output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[ + Tuple[torch.Tensor, torch.Tensor] + ] = None, # will become mandatory in v4.46 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -569,8 +576,14 @@ class Qwen2VLAttention(nn.Module): ) if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + cache_kwargs = { + "sin": sin, + "cos": cos, + "cache_position": cache_position, + } # Specific to RoPE models + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) # repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups) @@ -585,7 +598,9 @@ class Qwen2VLAttention(nn.Module): # Fix precision issues in Qwen2-VL float16 inference # Replace inf values with zeros in attention weights to prevent NaN propagation if query_states.dtype == torch.float16: - attn_weights = torch.where(torch.isinf(attn_weights), torch.zeros_like(attn_weights), attn_weights) + attn_weights = torch.where( + torch.isinf(attn_weights), torch.zeros_like(attn_weights), attn_weights + ) # upcast attention to fp32 attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) @@ -634,7 +649,9 @@ class Qwen2VLFlashAttention2(Qwen2VLAttention): output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[ + Tuple[torch.Tensor, torch.Tensor] + ] = None, # will become mandatory in v4.46 ): bsz, q_len, _ = hidden_states.size() @@ -696,10 +713,18 @@ class Qwen2VLFlashAttention2(Qwen2VLAttention): if attention_mask is not None: attention_mask = attention_mask[:, slicing_tokens:] - attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) + attention_mask = torch.cat( + [attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1 + ) - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + cache_kwargs = { + "sin": sin, + "cos": cos, + "cache_position": cache_position, + } # Specific to RoPE models + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) # repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups) @@ -781,7 +806,9 @@ class Qwen2VLSdpaAttention(Qwen2VLAttention): output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[ + Tuple[torch.Tensor, torch.Tensor] + ] = None, # will become mandatory in v4.46 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions: logger.warning_once( @@ -826,8 +853,14 @@ class Qwen2VLSdpaAttention(Qwen2VLAttention): ) if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + cache_kwargs = { + "sin": sin, + "cos": cos, + "cache_position": cache_position, + } # Specific to RoPE models + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) @@ -897,7 +930,9 @@ class Qwen2VLDecoderLayer(nn.Module): output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[ + Tuple[torch.Tensor, torch.Tensor] + ] = None, # will become mandatory in v4.46 **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -1116,7 +1151,9 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel): return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_attentions = ( + output_attentions if output_attentions is not None else self.config.output_attentions + ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) @@ -1208,7 +1245,9 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel): next_cache = next_decoder_cache if use_cache else None if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return tuple( + v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None + ) return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=next_cache, @@ -1436,6 +1475,7 @@ QWEN2_VL_INPUTS_DOCSTRING = r""" The rope index difference between sequence length and multimodal rope. """ + class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] @@ -1599,9 +1639,15 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) - t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten() - h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() - w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() + t_index = ( + torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten() + ) + h_index = ( + torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() + ) + w_index = ( + torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() + ) llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx) st = ed + llm_grid_t * llm_grid_h * llm_grid_w @@ -1721,18 +1767,20 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi ```""" self.computed_type = torch.bfloat16 - input_ids=input_ids.to("cuda") - attention_mask=attention_mask.to("cuda") + input_ids = input_ids.to("cuda") + attention_mask = attention_mask.to("cuda") if not is_eval: labels = labels.to("cuda") - actions = actions.to(dtype=self.computed_type, device='cuda') - states = states.to(dtype=self.computed_type, device='cuda') + actions = actions.to(dtype=self.computed_type, device="cuda") + states = states.to(dtype=self.computed_type, device="cuda") position_ids, rope_deltas = self.get_rope_index( input_ids, image_grid_thw, video_grid_thw, attention_mask ) if pixel_values is not None: - pixel_values = pixel_values.to(dtype=self.computed_type, device='cuda') - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + pixel_values = pixel_values.to(dtype=self.computed_type, device="cuda") + output_attentions = ( + output_attentions if output_attentions is not None else self.config.output_attentions + ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) @@ -1792,7 +1840,7 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi ) hidden_states = outputs[0] - if tinyvla: # dex-vla supports tinyvla-style VLA + if tinyvla: # dex-vla supports tinyvla-style VLA return hidden_states if self.with_llm_head: @@ -1831,23 +1879,30 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi attentions=outputs.attentions, rope_deltas=rope_deltas, ) - + if self.using_film: - action_hidden_states = self.film_forward(labels=labels, input_ids=input_ids, - hidden_states=hidden_states) - else: # tinyvla + action_hidden_states = self.film_forward( + labels=labels, input_ids=input_ids, hidden_states=hidden_states + ) + else: # tinyvla action_hidden_states = hidden_states - ret = self.policy_head(actions=actions, hidden_states=action_hidden_states, states=states, is_pad=is_pad) + ret = self.policy_head( + actions=actions, hidden_states=action_hidden_states, states=states, is_pad=is_pad + ) if self.with_llm_head: - loss = {'loss': ret['loss'] + self.llm_loss_weight * llm_loss, - 'llm_loss': llm_loss, - 'action_loss': ret['loss']} + loss = { + "loss": ret["loss"] + self.llm_loss_weight * llm_loss, + "llm_loss": llm_loss, + "action_loss": ret["loss"], + } else: - loss = {'loss': ret['loss'], - 'llm_loss': (torch.ones(1)*(-100)).to(ret['loss'].dtype).squeeze(0), - 'action_loss': ret['loss']} + loss = { + "loss": ret["loss"], + "llm_loss": (torch.ones(1) * (-100)).to(ret["loss"].dtype).squeeze(0), + "action_loss": ret["loss"], + } if not return_dict: output = (logits,) + outputs[1:] @@ -1904,30 +1959,32 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi return action_hidden_states def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - attention_mask=None, - inputs_embeds=None, - cache_position=None, - position_ids=None, - use_cache=True, - pixel_values=None, - pixel_values_videos=None, - image_grid_thw=None, - video_grid_thw=None, - **kwargs, + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + pixel_values=None, + pixel_values_videos=None, + image_grid_thw=None, + video_grid_thw=None, + **kwargs, ): # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens # Exception 1: when passing input_embeds, input_ids may be missing entries # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here if past_key_values is not None: if inputs_embeds is not None: # Exception 1 - input_ids = input_ids[:, -cache_position.shape[0]:] - elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, -cache_position.shape[0] :] + elif ( + input_ids.shape[1] != cache_position.shape[0] + ): # Default case (the "else", a no op, is Exception 2) input_ids = input_ids[:, cache_position] - rope_deltas = kwargs.get("rope_deltas", None) + rope_deltas = kwargs.get("rope_deltas") if attention_mask is not None and position_ids is None: if cache_position is None or (cache_position is not None and cache_position[0] == 0): position_ids, rope_deltas = self.get_rope_index( @@ -1936,7 +1993,9 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi else: batch_size, seq_length = input_ids.shape delta = ( - cache_position[0] + rope_deltas if cache_position is not None and rope_deltas is not None else 0 + cache_position[0] + rope_deltas + if cache_position is not None and rope_deltas is not None + else 0 ) position_ids = torch.arange(seq_length, device=input_ids.device) position_ids = position_ids.view(1, -1).expand(batch_size, -1) @@ -1990,6 +2049,6 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi return model_inputs - from transformers import AutoModelForCausalLM + AutoModelForCausalLM.register(Qwen2VLAConfig, Qwen2VLForConditionalGenerationForVLA) diff --git a/lerobot/common/policies/dexvla/robot_data_processor.py b/lerobot/common/policies/dexvla/robot_data_processor.py index 85e43572..08ede3e0 100644 --- a/lerobot/common/policies/dexvla/robot_data_processor.py +++ b/lerobot/common/policies/dexvla/robot_data_processor.py @@ -1,16 +1,15 @@ -from PIL import Image import numpy as np -from torchvision.transforms.functional import to_pil_image, to_tensor -import torchvision.transforms as transforms import torch -from qwen_vl_utils import process_vision_info +from PIL import Image from qwen_vl_utils import fetch_image + + class Qwen2VLAProcess: def __init__( - self, - tokenizer=None, - max_seq_len=512, - multimodal_processor=None, + self, + tokenizer=None, + max_seq_len=512, + multimodal_processor=None, ): super().__init__() self.tokenizer = tokenizer @@ -20,10 +19,10 @@ class Qwen2VLAProcess: def qwen2_image_preprocess(self, each): ele = {} each = Image.fromarray(each.squeeze(0).permute(1, 2, 0).cpu().numpy().astype(np.uint8)) - ele['image'] = each + ele["image"] = each - ele['resized_height'] = each.height - ele['resized_width'] = each.width + ele["resized_height"] = each.height + ele["resized_width"] = each.width each = fetch_image(ele) return torch.from_numpy(np.array(each)) @@ -58,61 +57,63 @@ class Qwen2VLAProcess: if eval: return model_inputs - input_labels = torch.ones_like(model_inputs['input_ids']) * -100 + input_labels = torch.ones_like(model_inputs["input_ids"]) * -100 if use_reasoning: - answer =reasoning + "Next action:" + '<|im_end|>' + answer = reasoning + "Next action:" + "<|im_end|>" else: - answer = '' + '<|im_end|>' + answer = "" + "<|im_end|>" output_text = self.tokenizer(answer, padding=True, return_tensors="pt") - output_labels = output_text['input_ids'] - model_inputs['input_ids'] = torch.cat((model_inputs['input_ids'], output_text['input_ids']), dim=-1) - model_inputs['attention_mask'] = torch.cat((model_inputs['attention_mask'], output_text['attention_mask']), dim=-1) + output_labels = output_text["input_ids"] + model_inputs["input_ids"] = torch.cat((model_inputs["input_ids"], output_text["input_ids"]), dim=-1) + model_inputs["attention_mask"] = torch.cat( + (model_inputs["attention_mask"], output_text["attention_mask"]), dim=-1 + ) labels = torch.cat((input_labels, output_labels), dim=-1) - data_dict['labels'] = labels + data_dict["labels"] = labels for k, v in model_inputs.items(): data_dict[k] = v return data_dict def forward(self, batch, use_reasoning=True): """This is the main process function for processing vl data into Qwen2_vl format""" - all_images = batch['images'] - all_images = torch.einsum('v b c h w -> b v c h w', all_images) # camera_views, batch_size, channel, height, width + all_images = batch["images"] + all_images = torch.einsum( + "v b c h w -> b v c h w", all_images + ) # camera_views, batch_size, channel, height, width ret_l = [] for idx, images in enumerate(all_images): - raw_lang = batch['raw_langs'][idx] - reasoning = batch['reasonings'][idx] + raw_lang = batch["raw_langs"][idx] + reasoning = batch["reasonings"][idx] ret_dict = self.single_forward_process(images, raw_lang, reasoning, use_reasoning=use_reasoning) ret_l.append(ret_dict) return self.post_process(ret_l) def post_process(self, instances): - input_ids = [torch.flip(instance['input_ids'].squeeze(0), dims=[0]) for instance in instances] - labels = [torch.flip(instance['labels'].squeeze(0), dims=[0]) for instance in instances] + input_ids = [torch.flip(instance["input_ids"].squeeze(0), dims=[0]) for instance in instances] + labels = [torch.flip(instance["labels"].squeeze(0), dims=[0]) for instance in instances] - image_grid_thw = torch.stack([instances['image_grid_thw'] for instances in instances]) - pixel_values = torch.stack([instances['pixel_values'] for instances in instances]) + image_grid_thw = torch.stack([instances["image_grid_thw"] for instances in instances]) + pixel_values = torch.stack([instances["pixel_values"] for instances in instances]) pixel_values_videos = None video_grid_thw = None - labels = torch.nn.utils.rnn.pad_sequence(labels, - batch_first=True, - padding_value=-100) + labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100) labels = torch.flip(labels, dims=[1]) - input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, - batch_first=True, - padding_value=self.tokenizer.pad_token_id) + input_ids = torch.nn.utils.rnn.pad_sequence( + input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id + ) input_ids = torch.flip(input_ids, dims=[1]) b = input_ids.shape[0] image_grid_thw = image_grid_thw.reshape(b * image_grid_thw.shape[1], image_grid_thw.shape[2]) pixel_values = pixel_values.reshape(b * pixel_values.shape[1], pixel_values.shape[2]) - attention_mask = input_ids.ne(self.tokenizer.pad_token_id), + attention_mask = (input_ids.ne(self.tokenizer.pad_token_id),) batch = dict( input_ids=input_ids, @@ -126,7 +127,6 @@ class Qwen2VLAProcess: return batch def construct_chat_data(self, len_image, raw_lang): - messages = [ { "role": "user", @@ -135,11 +135,13 @@ class Qwen2VLAProcess: ] for i in range(len_image): - messages[0]['content'].append({ - "type": "image", - "image": None, - }) - messages[0]['content'].append({"type": "text", "text": f""}) - messages[0]['content'][-1]['text'] = raw_lang + messages[0]["content"].append( + { + "type": "image", + "image": None, + } + ) + messages[0]["content"].append({"type": "text", "text": ""}) + messages[0]["content"][-1]["text"] = raw_lang - return messages \ No newline at end of file + return messages diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py index e7777367..299877a6 100644 --- a/lerobot/common/policies/factory.py +++ b/lerobot/common/policies/factory.py @@ -24,9 +24,9 @@ from lerobot.common.datasets.utils import dataset_to_policy_features from lerobot.common.envs.configs import EnvConfig from lerobot.common.envs.utils import env_to_policy_features from lerobot.common.policies.act.configuration_act import ACTConfig +from lerobot.common.policies.dexvla.configuration_dexvla import DexVLAConfig from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig from lerobot.common.policies.pi0.configuration_pi0 import PI0Config -from lerobot.common.policies.dexvla.configuration_dexvla import DexVLAConfig from lerobot.common.policies.pretrained import PreTrainedPolicy from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig