Merge remote-tracking branch 'origin/2025_02_20_add_dexvla' into 2025_02_20_add_dexvla
# Conflicts:
#	lerobot/common/policies/dexvla/modeling_dexvla.py
commit a2f80f42fc
@@ -110,13 +110,15 @@ class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig):
         return LambdaLR(optimizer, lr_lambda, -1)


 @LRSchedulerConfig.register_subclass("constant_with_warmup")
 @dataclass
 class ConstantWithWarmupSchedulerConfig(LRSchedulerConfig):
     """Used by DexVLA to train Stage2"""

-    num_warmup_steps: int
-    def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
+    num_warmup_steps: int
+
+    def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
+        def lr_lambda(current_step):
             def linear_warmup_schedule(current_step):
                 if current_step <= 0:
@@ -134,6 +136,7 @@ class ConstantWithWarmupSchedulerConfig(LRSchedulerConfig):
         return LambdaLR(optimizer, lr_lambda, -1)


 def save_scheduler_state(scheduler: LRScheduler, save_dir: Path) -> None:
     state_dict = scheduler.state_dict()
     write_json(state_dict, save_dir / SCHEDULER_STATE)
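For reference, a constant-with-warmup schedule of this kind can be expressed as a single `LambdaLR` multiplier: ramp linearly from near zero to 1 over `num_warmup_steps`, then hold at 1. The snippet below is a minimal, self-contained sketch of that behaviour, not the exact lerobot implementation; the optimizer, learning rate, and step counts are illustrative only.

~~~python
# Minimal sketch of a constant-with-warmup LR lambda (illustrative, not lerobot's code).
import torch
from torch.optim.lr_scheduler import LambdaLR

num_warmup_steps = 100  # example value

def lr_lambda(current_step: int) -> float:
    # Linear ramp up to 1.0 during warmup, constant 1.0 afterwards.
    if current_step < num_warmup_steps:
        return max(current_step, 1) / num_warmup_steps
    return 1.0

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.AdamW(params, lr=1e-4)
scheduler = LambdaLR(optimizer, lr_lambda, last_epoch=-1)

for step in range(3):
    optimizer.step()
    scheduler.step()
    print(step, scheduler.get_last_lr())
~~~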
@@ -26,7 +26,7 @@ Besides, your data should include another key ``action_is_pad`` which is a bool
 Suppose the size of the action chunk is 5, and the length of the episode is 10. So the action chunk for the last 4 actions must be padded to make sure the length of action chunk is 5.
 And the mask looks like:
 ~~~python
-The 6th chunk: [false, false, false, false, true]
+The 6th chunk: [false, false, false, false, true]
 The 7th chunk: [false, false, false, true, true]
 The 8th chunk: [false, false, true, true, true]
 The 9th chunk: [false, true, true, true, true]
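As a sketch of how such a mask can be produced: a position in an action chunk is padding whenever its timestep falls past the end of the episode. The helper below is illustrative only (it mirrors the `action_is_pad` key described above but is not the dataset code itself); it reproduces the masks listed in the excerpt.

~~~python
# Illustrative helper (not the dataset code): mark positions of an action chunk
# that fall past the end of the episode as padding (True).
def make_action_is_pad(start: int, chunk_size: int, episode_len: int) -> list[bool]:
    return [start + i >= episode_len for i in range(chunk_size)]

episode_len, chunk_size = 10, 5
for start in range(6, 10):
    # e.g. start=6 -> [False, False, False, False, True], matching "The 6th chunk" above
    print(f"The {start}th chunk:", make_action_is_pad(start, chunk_size, episode_len))
~~~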
@@ -34,9 +34,9 @@ The 9th chunk: [false, true, true, true, true]

 ## 🤗Download Pretrained Weights
 ### Download official Qwen2_VL weights
-We construct the VLM backbone by integrating Qwen2-VL-2B, a powerful and efficient model, into our framework.
-The Qwen2-VL 2B serves as the core of our architecture, providing robust capabilities
-for vision-language tasks. We use off-the-shelf Qwen2-VL model proposed
+We construct the VLM backbone by integrating Qwen2-VL-2B, a powerful and efficient model, into our framework.
+The Qwen2-VL 2B serves as the core of our architecture, providing robust capabilities
+for vision-language tasks. We use off-the-shelf Qwen2-VL model proposed
 in [Qwen2-VL](https://arxiv.org/pdf/2409.12191) without any post training on VLM itself. You can download the official weights from this link:

 | Model | Link |
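The same weights can also be fetched programmatically from the Hugging Face Hub. A minimal sketch, assuming the official `Qwen/Qwen2-VL-2B-Instruct` repository and an example local directory of your choice:

~~~python
# Sketch: download the official Qwen2-VL-2B-Instruct weights from the Hugging Face Hub.
# The local_dir path is an arbitrary example; point qwen2_vl_path at wherever you store them.
from huggingface_hub import snapshot_download

local_path = snapshot_download(
    repo_id="Qwen/Qwen2-VL-2B-Instruct",
    local_dir="./weights/Qwen2-VL-2B-Instruct",  # example location
)
print("Qwen2-VL weights downloaded to:", local_path)
~~~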
@@ -104,7 +104,5 @@ python lerobot/scripts/eval.py \
     --env.task AlohaInsertion-v0 \
     --eval.n_episodes 1 \
     --eval.batch_size 1 \
-    --device cuda
+    --device cuda
 ~~~
@@ -21,8 +21,8 @@ from transformers.utils import logging

 from lerobot.common.optim.optimizers import AdamWConfig
 from lerobot.common.optim.schedulers import (
-    CosineDecayWithWarmupSchedulerConfig,
-    ConstantWithWarmupSchedulerConfig
+    ConstantWithWarmupSchedulerConfig,
+    CosineDecayWithWarmupSchedulerConfig,
 )
 from lerobot.configs.policies import PreTrainedConfig
 from lerobot.configs.types import NormalizationMode
@@ -47,12 +47,14 @@ class DexVLAConfig(PreTrainedConfig):
     n_obs_steps: int = 1

     hidden_size: int = 1536
-    qwen2_vl_path: str = None # '/media/rl/HDD/data/weights/Qwen2-VL-2B-Instruct', official weights of qwen2vl
+    qwen2_vl_path: str = (
+        None  # '/media/rl/HDD/data/weights/Qwen2-VL-2B-Instruct', official weights of qwen2vl
+    )

-    pretrained_path: str = None # for loading pretrained weights of whole dexvla, usually for training stage3
-    pretrained_scaledp_path: str = None # for loading pretrained weights of ScaleDP(Stage1)
+    pretrained_path: str = None  # for loading pretrained weights of whole dexvla, usually for training stage3
+    pretrained_scaledp_path: str = None  # for loading pretrained weights of ScaleDP(Stage1)

-    training_stage: int = 2 # specific training stage, [2, 3]
+    training_stage: int = 2  # specific training stage, [2, 3]
     using_film: bool = True
     llm_loss_weight: float = 1.0
     with_llm_head: bool = True
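To make the role of these fields concrete, here is a hypothetical, self-contained sketch of how the stage-related fields interact, following the comments in the diff above. It is not the actual `DexVLAConfig` class, and all paths are placeholders.

~~~python
# Hypothetical sketch mirroring the stage semantics commented above (not DexVLAConfig itself).
from dataclasses import dataclass

@dataclass
class StageConfigSketch:
    qwen2_vl_path: str                            # official Qwen2-VL-2B weights (VLM backbone)
    training_stage: int = 2                       # 2 or 3, as validated in the config
    pretrained_scaledp_path: str | None = None    # ScaleDP weights from Stage 1
    pretrained_path: str | None = None            # whole-DexVLA weights, usually for Stage 3

    def weights_to_load(self) -> str | None:
        # Stage 2 starts from the pretrained ScaleDP head; Stage 3 from a whole DexVLA checkpoint.
        if self.training_stage == 2:
            return self.pretrained_scaledp_path
        if self.training_stage == 3:
            return self.pretrained_path
        raise ValueError(f"Training stage must be 2 or 3. Got {self.training_stage}.")

cfg = StageConfigSketch(
    qwen2_vl_path="./weights/Qwen2-VL-2B-Instruct",       # placeholder path
    training_stage=2,
    pretrained_scaledp_path="./weights/scaledp_stage1.ckpt",  # placeholder path
)
print(cfg.weights_to_load())
~~~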
@@ -119,7 +121,7 @@ class DexVLAConfig(PreTrainedConfig):
         else:
             raise ValueError(f"Policy head type {self.policy_head_type} not supported")

-        if self.training_stage not in [2,3]:
+        if self.training_stage not in [2, 3]:
             raise ValueError(f"Training stage must be 2 or 3. Got {self.training_stage}.")

         self.qwen2_vla_config = AutoConfig.from_pretrained(self.qwen2_vl_path)
@@ -55,25 +55,31 @@ class DexVLAPolicy(PreTrainedPolicy):
             trust_remote_code=True,
             _fast_init=False,
             # attn_implementation="flash_attention_2",
-        ).to(device='cuda', dtype=torch.bfloat16)
+        ).to(device="cuda", dtype=torch.bfloat16)

         if self.config.pretrained_scaledp_path is not None:
-            print(f'\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>Loading pretrained ScaleDP weights...<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<')
-            pretrain_scaledp_weights = torch.load(self.config.pretrained_scaledp_path, map_location='cpu')
+            print(
+                "\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>Loading pretrained ScaleDP weights...<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<"
+            )
+            pretrain_scaledp_weights = torch.load(self.config.pretrained_scaledp_path, map_location="cpu")

-            pretrain_scaledp_weights = pretrain_scaledp_weights['nets']['nets']
+            pretrain_scaledp_weights = pretrain_scaledp_weights["nets"]["nets"]

             keys_to_del_dit = []
-            pretrain_scaledp_weights = {k[7:] if k.startswith('policy.') else k: v for k, v in pretrain_scaledp_weights.items()}
+            pretrain_scaledp_weights = {
+                k[7:] if k.startswith("policy.") else k: v for k, v in pretrain_scaledp_weights.items()
+            }
             for k in pretrain_scaledp_weights:
-                if 'noise_pred' not in k: # del weights of vision backbones
+                if "noise_pred" not in k:  # del weights of vision backbones
                     keys_to_del_dit.append(k)
-                if 'cond_obs_emb' in k:
+                if "cond_obs_emb" in k:
                     keys_to_del_dit.append(k)
             for k in keys_to_del_dit:
                 del pretrain_scaledp_weights[k]
-            pretrain_scaledp_weights = {k[15:] if k.startswith('noise_pred_net.') else k: v for k, v in
-                                        pretrain_scaledp_weights.items()}
+            pretrain_scaledp_weights = {
+                k[15:] if k.startswith("noise_pred_net.") else k: v
+                for k, v in pretrain_scaledp_weights.items()
+            }

             self.model.policy_head.load_state_dict(pretrain_scaledp_weights, strict=False)
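The key manipulation in this hunk follows a common checkpoint-surgery pattern: strip the saving-time prefix, drop entries that do not belong to the diffusion head, then strip the submodule prefix so the remaining keys line up with the target module. A small standalone sketch of that pattern on a dummy state dict (the keys below are invented for illustration):

~~~python
# Standalone sketch of the checkpoint-key cleanup pattern above; dummy keys are invented.
dummy_ckpt = {
    "policy.noise_pred_net.mlp.weight": 1,
    "policy.cond_obs_emb.weight": 2,     # dropped: conditioning embedding
    "policy.backbone.conv1.weight": 3,   # dropped: vision backbone
}

# 1) strip the "policy." prefix
ckpt = {k.removeprefix("policy."): v for k, v in dummy_ckpt.items()}
# 2) keep only noise-prediction weights and drop conditioning embeddings
ckpt = {k: v for k, v in ckpt.items() if "noise_pred" in k and "cond_obs_emb" not in k}
# 3) strip the "noise_pred_net." prefix so keys match the policy head module
ckpt = {k.removeprefix("noise_pred_net."): v for k, v in ckpt.items()}

print(ckpt)  # {'mlp.weight': 1}
~~~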