Merge remote-tracking branch 'origin/2025_02_20_add_dexvla' into 2025_02_20_add_dexvla

# Conflicts:
#	lerobot/common/policies/dexvla/modeling_dexvla.py
This commit is contained in:
lesjie-wen 2025-03-18 16:51:48 +08:00
commit a2f80f42fc
4 changed files with 33 additions and 24 deletions

View File

@ -110,13 +110,15 @@ class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig):
return LambdaLR(optimizer, lr_lambda, -1)
@LRSchedulerConfig.register_subclass("constant_with_warmup")
@dataclass
class ConstantWithWarmupSchedulerConfig(LRSchedulerConfig):
"""Used by DexVLA to train Stage2"""
num_warmup_steps: int
def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
num_warmup_steps: int
def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
def lr_lambda(current_step):
def linear_warmup_schedule(current_step):
if current_step <= 0:
@ -134,6 +136,7 @@ class ConstantWithWarmupSchedulerConfig(LRSchedulerConfig):
return LambdaLR(optimizer, lr_lambda, -1)
def save_scheduler_state(scheduler: LRScheduler, save_dir: Path) -> None:
    """Dump the scheduler's state dict as JSON under ``save_dir``.

    Args:
        scheduler: Scheduler whose ``state_dict()`` is serialized.
        save_dir: Directory that receives the file named by ``SCHEDULER_STATE``.
    """
    snapshot = scheduler.state_dict()
    destination = save_dir / SCHEDULER_STATE
    write_json(snapshot, destination)

View File

@ -26,7 +26,7 @@ Besides, your data should include another key ``action_is_pad`` which is a bool
Suppose the action chunk size is 5 and the episode length is 10. Then the action chunks for the last 4 actions must be padded so that every chunk still has length 5.
And the mask looks like:
~~~python
The 6th chunk: [false, false, false, false, true]
The 6th chunk: [false, false, false, false, true]
The 7th chunk: [false, false, false, true, true]
The 8th chunk: [false, false, true, true, true]
The 9th chunk: [false, true, true, true, true]
@ -34,9 +34,9 @@ The 9th chunk: [false, true, true, true, true]
## 🤗Download Pretrained Weights
### Download official Qwen2_VL weights
We construct the VLM backbone by integrating Qwen2-VL-2B, a powerful and efficient model, into our framework.
The Qwen2-VL 2B serves as the core of our architecture, providing robust capabilities
for vision-language tasks. We use off-the-shelf Qwen2-VL model proposed
We construct the VLM backbone by integrating Qwen2-VL-2B, a powerful and efficient model, into our framework.
The Qwen2-VL 2B serves as the core of our architecture, providing robust capabilities
for vision-language tasks. We use off-the-shelf Qwen2-VL model proposed
in [Qwen2-VL](https://arxiv.org/pdf/2409.12191) without any post-training of the VLM itself. You can download the official weights from this link:
| Model | Link |
@ -104,7 +104,5 @@ python lerobot/scripts/eval.py \
--env.task AlohaInsertion-v0 \
--eval.n_episodes 1 \
--eval.batch_size 1 \
--device cuda
--device cuda
~~~

View File

@ -21,8 +21,8 @@ from transformers.utils import logging
from lerobot.common.optim.optimizers import AdamWConfig
from lerobot.common.optim.schedulers import (
ConstantWithWarmupSchedulerConfig,
CosineDecayWithWarmupSchedulerConfig,
ConstantWithWarmupSchedulerConfig
)
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
@ -47,12 +47,14 @@ class DexVLAConfig(PreTrainedConfig):
n_obs_steps: int = 1
hidden_size: int = 1536
qwen2_vl_path: str = None # '/media/rl/HDD/data/weights/Qwen2-VL-2B-Instruct', official weights of qwen2vl
qwen2_vl_path: str = (
None # '/media/rl/HDD/data/weights/Qwen2-VL-2B-Instruct', official weights of qwen2vl
)
pretrained_path: str = None # for loading pretrained weights of whole dexvla, usually for training stage3
pretrained_scaledp_path: str = None # for loading pretrained weights of ScaleDP(Stage1)
pretrained_path: str = None # for loading pretrained weights of whole dexvla, usually for training stage3
pretrained_scaledp_path: str = None # for loading pretrained weights of ScaleDP(Stage1)
training_stage: int = 2 # specific training stage, [2, 3]
training_stage: int = 2 # specific training stage, [2, 3]
using_film: bool = True
llm_loss_weight: float = 1.0
with_llm_head: bool = True
@ -119,7 +121,7 @@ class DexVLAConfig(PreTrainedConfig):
else:
raise ValueError(f"Policy head type {self.policy_head_type} not supported")
if self.training_stage not in [2,3]:
if self.training_stage not in [2, 3]:
raise ValueError(f"Training stage must be 2 or 3. Got {self.training_stage}.")
self.qwen2_vla_config = AutoConfig.from_pretrained(self.qwen2_vl_path)

View File

@ -55,25 +55,31 @@ class DexVLAPolicy(PreTrainedPolicy):
trust_remote_code=True,
_fast_init=False,
# attn_implementation="flash_attention_2",
).to(device='cuda', dtype=torch.bfloat16)
).to(device="cuda", dtype=torch.bfloat16)
if self.config.pretrained_scaledp_path is not None:
print(f'\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>Loading pretrained ScaleDP weights...<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<')
pretrain_scaledp_weights = torch.load(self.config.pretrained_scaledp_path, map_location='cpu')
print(
"\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>Loading pretrained ScaleDP weights...<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<"
)
pretrain_scaledp_weights = torch.load(self.config.pretrained_scaledp_path, map_location="cpu")
pretrain_scaledp_weights = pretrain_scaledp_weights['nets']['nets']
pretrain_scaledp_weights = pretrain_scaledp_weights["nets"]["nets"]
keys_to_del_dit = []
pretrain_scaledp_weights = {k[7:] if k.startswith('policy.') else k: v for k, v in pretrain_scaledp_weights.items()}
pretrain_scaledp_weights = {
k[7:] if k.startswith("policy.") else k: v for k, v in pretrain_scaledp_weights.items()
}
for k in pretrain_scaledp_weights:
if 'noise_pred' not in k: # del weights of vision backbones
if "noise_pred" not in k: # del weights of vision backbones
keys_to_del_dit.append(k)
if 'cond_obs_emb' in k:
if "cond_obs_emb" in k:
keys_to_del_dit.append(k)
for k in keys_to_del_dit:
del pretrain_scaledp_weights[k]
pretrain_scaledp_weights = {k[15:] if k.startswith('noise_pred_net.') else k: v for k, v in
pretrain_scaledp_weights.items()}
pretrain_scaledp_weights = {
k[15:] if k.startswith("noise_pred_net.") else k: v
for k, v in pretrain_scaledp_weights.items()
}
self.model.policy_head.load_state_dict(pretrain_scaledp_weights, strict=False)