Add support for two training

This commit is contained in:
lesjie-wen 2025-03-18 15:15:33 +08:00
parent e3be39426b
commit 6bd6a9d63d
1 changed files with 21 additions and 9 deletions

View File

@ -23,6 +23,7 @@ from transformers import AutoConfig
from lerobot.common.optim.optimizers import AdamWConfig from lerobot.common.optim.optimizers import AdamWConfig
from lerobot.common.optim.schedulers import ( from lerobot.common.optim.schedulers import (
CosineDecayWithWarmupSchedulerConfig, CosineDecayWithWarmupSchedulerConfig,
ConstantWithWarmupSchedulerConfig
) )
from transformers.utils import logging from transformers.utils import logging
from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.policies import PreTrainedConfig
@ -45,9 +46,12 @@ class DexVLAConfig(PreTrainedConfig):
n_obs_steps: int = 1 n_obs_steps: int = 1
hidden_size: int = 1536 hidden_size: int = 1536
qwen2_vl_path: str = None # '/media/rl/HDD/data/weights/Qwen2-VL-2B-Instruct' qwen2_vl_path: str = None # '/media/rl/HDD/data/weights/Qwen2-VL-2B-Instruct', official weights of qwen2vl
pretrained_path: str = None # pretrained dexvla pretrained_path: str = None # for loading pretrained weights of whole dexvla, usually for training stage3
pretrained_scaledp_path: str = None # for loading pretrained weights of ScaleDP(Stage1)
training_stage: int = 2 # specific training stage, [2, 3]
using_film: bool = True using_film: bool = True
llm_loss_weight: float = 1.0 llm_loss_weight: float = 1.0
with_llm_head: bool = True with_llm_head: bool = True
@ -59,7 +63,7 @@ class DexVLAConfig(PreTrainedConfig):
optimizer_eps: float = 1e-8 optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 1e-10 optimizer_weight_decay: float = 1e-10
scheduler_warmup_steps: int = 1_000 scheduler_warmup_steps: int = 2_000
scheduler_decay_steps: int = 30_000 scheduler_decay_steps: int = 30_000
scheduler_decay_lr: float = 2.5e-6 scheduler_decay_lr: float = 2.5e-6
@ -110,6 +114,9 @@ class DexVLAConfig(PreTrainedConfig):
else: else:
raise ValueError(f'Policy head type {self.policy_head_type} not supported') raise ValueError(f'Policy head type {self.policy_head_type} not supported')
if self.training_stage not in [2,3]:
raise ValueError(f"Training stage must be 2 or 3. Got {self.training_stage}.")
self.qwen2_vla_config = AutoConfig.from_pretrained(self.qwen2_vl_path) self.qwen2_vla_config = AutoConfig.from_pretrained(self.qwen2_vl_path)
def validate_features(self) -> None: def validate_features(self) -> None:
@ -134,12 +141,17 @@ class DexVLAConfig(PreTrainedConfig):
) )
def get_scheduler_preset(self): def get_scheduler_preset(self):
return CosineDecayWithWarmupSchedulerConfig( if self.training_stage == 3:
peak_lr=self.optimizer_lr, return CosineDecayWithWarmupSchedulerConfig(
decay_lr=self.scheduler_decay_lr, peak_lr=self.optimizer_lr,
num_warmup_steps=self.scheduler_warmup_steps, decay_lr=self.scheduler_decay_lr,
num_decay_steps=self.scheduler_decay_steps, num_warmup_steps=self.scheduler_warmup_steps,
) num_decay_steps=self.scheduler_decay_steps,
)
else:
return ConstantWithWarmupSchedulerConfig(
num_warmup_steps=self.scheduler_warmup_steps,
)
@property @property
def observation_delta_indices(self) -> None: def observation_delta_indices(self) -> None: