# lerobot/lerobot/common/policies/dexvla/configuration_dexvla.py

# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Qwen2VL model configuration"""
from dataclasses import dataclass, field
from typing import Tuple

from transformers import AutoConfig
from transformers.utils import logging

from lerobot.common.optim.optimizers import AdamWConfig
from lerobot.common.optim.schedulers import (
    ConstantWithWarmupSchedulerConfig,
    CosineDecayWithWarmupSchedulerConfig,
)
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode

from .policy_heads import register_policy_heads
from .qwe2_vla import register_qwen2_vla

logger = logging.get_logger(__name__)

register_policy_heads()
register_qwen2_vla()


@PreTrainedConfig.register_subclass("dexvla")
@dataclass
class DexVLAConfig(PreTrainedConfig):
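    """Configuration for the DexVLA policy.

    DexVLA couples a Qwen2-VL-2B vision-language backbone with an action head
    (ScaleDP or UNet diffusion), optional FiLM conditioning, and an optional
    LLM/reasoning head. The consistency checks between these fields live in
    `__post_init__`.
    """
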
# For loading policy head
policy_head_type: str = "scale_dp_policy"
policy_head_size: str = "scaledp_l"
action_dim: int = 14
state_dim: int = 14
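    # The model predicts `chunk_size` actions per invocation and executes up to `n_action_steps` of them.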
chunk_size: int = 50
n_action_steps: int = 50
n_obs_steps: int = 1
hidden_size: int = 1536
    # Path to the official Qwen2-VL-2B-Instruct weights (required; e.g. a local download of the checkpoint).
    qwen2_vl_path: str | None = None
    # Pretrained weights of the whole DexVLA model, usually loaded for training stage 3.
    pretrained_path: str | None = None
    # Pretrained weights of the ScaleDP action head produced by training stage 1.
    pretrained_scaledp_path: str | None = None
    # DexVLA training stage; only 2 and 3 are supported.
    training_stage: int = 2
using_film: bool = True
llm_loss_weight: float = 1.0
with_llm_head: bool = True
using_reasoning: bool = True
resize_size: tuple = (240, 320)
# Training presets
optimizer_lr: float = 2e-5
optimizer_betas: Tuple[float, float] = (0.9, 0.95)
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 1e-10
scheduler_warmup_steps: int = 2_000
scheduler_decay_steps: int = 30_000
scheduler_decay_lr: float = 2.5e-6
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
# "VISUAL": NormalizationMode.MEAN_STD,
"STATE": NormalizationMode.MEAN_STD,
"ACTION": NormalizationMode.MIN_MAX,
}
)

    def __post_init__(self):
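        """Validate field combinations and build the policy-head and Qwen2-VL sub-configs."""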
if self.n_action_steps > self.chunk_size:
raise ValueError(
f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
)
if self.n_obs_steps != 1:
raise ValueError(
f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
)
if self.using_reasoning:
assert self.using_film, "using_reasoning requires `using_film=True`"
assert self.with_llm_head, "using_reasoning requires `with_llm_head=True`"
print("You have set using_reasoning=True, please make sure your data has key 'reasoning'.")
else:
print(
"Warning:DexVLA recommends to use reasoning data which can better handle long-horizon and dexterous tasks. You can set 'using_reaasoning=True'."
)
if self.qwen2_vl_path is None:
raise ValueError(
"DexVLA is built on official qwen2_vl-2B. You have to download the official weights of qwen2_vl-2B first and set 'qwen2_vl_path'."
)
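        # Build the sub-config of the action head registered by register_policy_heads().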
if self.policy_head_type == "scale_dp_policy":
self.policy_head_config = AutoConfig.for_model(
model_type=self.policy_head_type,
model_size=self.policy_head_size,
cond_dim=self.hidden_size,
action_dim=self.action_dim,
prediction_horizon=self.chunk_size,
state_dim=self.state_dim,
)
elif self.policy_head_type == "unet_diffusion":
self.policy_head_config = AutoConfig.for_model(
model_type=self.policy_head_type,
global_cond_dim=self.hidden_size,
action_dim=self.action_dim,
state_dim=self.state_dim,
)
else:
raise ValueError(f"Policy head type {self.policy_head_type} not supported")
if self.training_stage not in [2, 3]:
raise ValueError(f"Training stage must be 2 or 3. Got {self.training_stage}.")
self.qwen2_vla_config = AutoConfig.from_pretrained(self.qwen2_vl_path)

    def validate_features(self) -> None:
if not self.image_features and not self.env_state_feature:
raise ValueError("You must provide at least one image or the environment state among the inputs.")
# for i in range(self.empty_cameras):
# key = f"observation.images.empty_camera_{i}"
# empty_camera = PolicyFeature(
# type=FeatureType.VISUAL,
# shape=(3, 480, 640),
# )
# self.input_features[key] = empty_camera

    def get_optimizer_preset(self) -> AdamWConfig:
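        """AdamW preset shared by both supported training stages."""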
return AdamWConfig(
lr=self.optimizer_lr,
betas=self.optimizer_betas,
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
)

    def get_scheduler_preset(self):
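        """Stage 3 uses cosine decay with warmup; stage 2 keeps a constant LR after warmup."""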
if self.training_stage == 3:
return CosineDecayWithWarmupSchedulerConfig(
peak_lr=self.optimizer_lr,
decay_lr=self.scheduler_decay_lr,
num_warmup_steps=self.scheduler_warmup_steps,
num_decay_steps=self.scheduler_decay_steps,
)
else:
return ConstantWithWarmupSchedulerConfig(
num_warmup_steps=self.scheduler_warmup_steps,
)
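
    # Delta indices tell the dataset which relative timesteps to load: observations
    # use only the current step, while actions are loaded for the whole
    # `chunk_size`-step chunk starting at the current frame.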
    @property
    def observation_delta_indices(self) -> None:
        return None

    @property
    def action_delta_indices(self) -> list:
        return list(range(self.chunk_size))

    @property
    def reward_delta_indices(self) -> None:
        return None
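

# Minimal usage sketch (not part of the library; the checkpoint path below is a
# placeholder and must point to a local copy of the official Qwen2-VL-2B-Instruct weights):
#
#     config = DexVLAConfig(
#         qwen2_vl_path="/path/to/Qwen2-VL-2B-Instruct",
#         policy_head_type="scale_dp_policy",
#         training_stage=2,
#     )
#     optimizer_cfg = config.get_optimizer_preset()
#     scheduler_cfg = config.get_scheduler_preset()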