diff --git a/lerobot/common/policies/dexvla/configuration_dexvla.py b/lerobot/common/policies/dexvla/configuration_dexvla.py
index 6f3c0ef0..1519da14 100644
--- a/lerobot/common/policies/dexvla/configuration_dexvla.py
+++ b/lerobot/common/policies/dexvla/configuration_dexvla.py
@@ -23,6 +23,7 @@ from transformers import AutoConfig
 from lerobot.common.optim.optimizers import AdamWConfig
 from lerobot.common.optim.schedulers import (
     CosineDecayWithWarmupSchedulerConfig,
+    ConstantWithWarmupSchedulerConfig,
 )
 from transformers.utils import logging
 from lerobot.configs.policies import PreTrainedConfig
@@ -45,9 +46,12 @@ class DexVLAConfig(PreTrainedConfig):
     n_obs_steps: int = 1
     hidden_size: int = 1536
 
-    qwen2_vl_path: str = None  # '/media/rl/HDD/data/weights/Qwen2-VL-2B-Instruct'
+    qwen2_vl_path: str = None  # '/media/rl/HDD/data/weights/Qwen2-VL-2B-Instruct', official Qwen2-VL weights
 
-    pretrained_path: str = None  # pretrained dexvla
+    pretrained_path: str = None  # loads pretrained weights of the whole DexVLA, usually for training stage 3
+    pretrained_scaledp_path: str = None  # loads pretrained weights of ScaleDP (stage 1)
+
+    training_stage: int = 2  # specifies the training stage, one of [2, 3]
     using_film: bool = True
     llm_loss_weight: float = 1.0
     with_llm_head: bool = True
@@ -59,7 +63,7 @@ class DexVLAConfig(PreTrainedConfig):
     optimizer_eps: float = 1e-8
     optimizer_weight_decay: float = 1e-10
 
-    scheduler_warmup_steps: int = 1_000
+    scheduler_warmup_steps: int = 2_000
     scheduler_decay_steps: int = 30_000
     scheduler_decay_lr: float = 2.5e-6
 
@@ -110,6 +114,9 @@ class DexVLAConfig(PreTrainedConfig):
         else:
             raise ValueError(f'Policy head type {self.policy_head_type} not supported')
 
+        if self.training_stage not in [2, 3]:
+            raise ValueError(f"Training stage must be 2 or 3. Got {self.training_stage}.")
+
         self.qwen2_vla_config = AutoConfig.from_pretrained(self.qwen2_vl_path)
 
     def validate_features(self) -> None:
@@ -134,12 +141,17 @@ class DexVLAConfig(PreTrainedConfig):
         )
 
     def get_scheduler_preset(self):
-        return CosineDecayWithWarmupSchedulerConfig(
-            peak_lr=self.optimizer_lr,
-            decay_lr=self.scheduler_decay_lr,
-            num_warmup_steps=self.scheduler_warmup_steps,
-            num_decay_steps=self.scheduler_decay_steps,
-        )
+        if self.training_stage == 3:
+            return CosineDecayWithWarmupSchedulerConfig(
+                peak_lr=self.optimizer_lr,
+                decay_lr=self.scheduler_decay_lr,
+                num_warmup_steps=self.scheduler_warmup_steps,
+                num_decay_steps=self.scheduler_decay_steps,
+            )
+        else:
+            return ConstantWithWarmupSchedulerConfig(
+                num_warmup_steps=self.scheduler_warmup_steps,
+            )
 
     @property
     def observation_delta_indices(self) -> None:
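
A minimal usage sketch (not part of the patch) of how the new `training_stage` field steers `get_scheduler_preset()`: stage 3 keeps the cosine-decay-with-warmup schedule, while stage 2 falls back to the constant-with-warmup schedule imported above. The repo id `Qwen/Qwen2-VL-2B-Instruct` and the checkpoint path below are illustrative assumptions; `__post_init__` loads the Qwen2-VL config via `AutoConfig.from_pretrained`, so `qwen2_vl_path` must resolve to real weights or a reachable Hugging Face repo.

```python
# Illustrative sketch, assuming lerobot with this patch is installed and the
# Qwen2-VL config is reachable (local path or Hugging Face hub).
from lerobot.common.policies.dexvla.configuration_dexvla import DexVLAConfig

# Stage 2: scheduler preset is ConstantWithWarmupSchedulerConfig (constant LR after warmup).
stage2_cfg = DexVLAConfig(
    qwen2_vl_path="Qwen/Qwen2-VL-2B-Instruct",  # assumed repo id; point to local weights if offline
    training_stage=2,
)
print(stage2_cfg.get_scheduler_preset())

# Stage 3: scheduler preset is CosineDecayWithWarmupSchedulerConfig, as before this change.
stage3_cfg = DexVLAConfig(
    qwen2_vl_path="Qwen/Qwen2-VL-2B-Instruct",
    pretrained_path="/path/to/pretrained_dexvla",  # hypothetical stage-2 checkpoint
    training_stage=3,
)
print(stage3_cfg.get_scheduler_preset())
```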