diff --git a/lerobot/common/optim/schedulers.py b/lerobot/common/optim/schedulers.py
index 486a0b99..e2ebb9e3 100644
--- a/lerobot/common/optim/schedulers.py
+++ b/lerobot/common/optim/schedulers.py
@@ -110,13 +110,15 @@ class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig):
         return LambdaLR(optimizer, lr_lambda, -1)
 
+
 @LRSchedulerConfig.register_subclass("constant_with_warmup")
 @dataclass
 class ConstantWithWarmupSchedulerConfig(LRSchedulerConfig):
     """Used by DexVLA to train Stage2"""
-    num_warmup_steps: int
-    def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
+    num_warmup_steps: int
+
+    def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
         def lr_lambda(current_step):
             def linear_warmup_schedule(current_step):
                 if current_step <= 0:
@@ -134,6 +136,7 @@ class ConstantWithWarmupSchedulerConfig(LRSchedulerConfig):
 
         return LambdaLR(optimizer, lr_lambda, -1)
 
+
 def save_scheduler_state(scheduler: LRScheduler, save_dir: Path) -> None:
     state_dict = scheduler.state_dict()
     write_json(state_dict, save_dir / SCHEDULER_STATE)
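Note: the hunk above cuts off inside `lr_lambda`, so for reference, here is a minimal self-contained sketch of the kind of schedule `ConstantWithWarmupSchedulerConfig.build` produces. The warmup formula is an assumption (a standard linear ramp to the base LR over `num_warmup_steps`, constant afterwards), not the exact truncated body:

~~~python
import torch
from torch.optim.lr_scheduler import LambdaLR


def constant_with_warmup(optimizer: torch.optim.Optimizer, num_warmup_steps: int) -> LambdaLR:
    def lr_lambda(current_step: int) -> float:
        # Assumed shape: linear warmup, then a constant multiplier of 1.0.
        if current_step < num_warmup_steps:
            return current_step / max(1, num_warmup_steps)
        return 1.0

    return LambdaLR(optimizer, lr_lambda, -1)


# LambdaLR multiplies the optimizer's base lr by lr_lambda(step) on each step.
opt = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=1e-4)
scheduler = constant_with_warmup(opt, num_warmup_steps=100)
for _ in range(3):
    opt.step()
    scheduler.step()
~~~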
diff --git a/lerobot/common/policies/dexvla/README.md b/lerobot/common/policies/dexvla/README.md
index 51f6917f..4d832c01 100644
--- a/lerobot/common/policies/dexvla/README.md
+++ b/lerobot/common/policies/dexvla/README.md
@@ -26,7 +26,7 @@ Besides, your data should include another key ``action_is_pad`` which is a bool
 Suppose the size of the action chunk is 5, and the length of the episode is 10. So the action chunk for the last 4 actions must be padded to make sure the length of action chunk is 5. And the mask looks like:
 
 ~~~python
-The 6th chunk: [false, false, false, false, true] 
+The 6th chunk: [false, false, false, false, true]
 The 7th chunk: [false, false, false, true, true]
 The 8th chunk: [false, false, true, true, true]
 The 9th chunk: [false, true, true, true, true]
@@ -34,9 +34,9 @@ The 9th chunk: [false, true, true, true, true]
 ## šŸ¤—Download Pretrained Weights
 ### Download official Qwen2_VL weights
-We construct the VLM backbone by integrating Qwen2-VL-2B, a powerful and efficient model, into our framework. 
-The Qwen2-VL 2B serves as the core of our architecture, providing robust capabilities 
-for vision-language tasks. We use off-the-shelf Qwen2-VL model proposed 
+We construct the VLM backbone by integrating Qwen2-VL-2B, a powerful and efficient model, into our framework.
+The Qwen2-VL 2B serves as the core of our architecture, providing robust capabilities
+for vision-language tasks. We use off-the-shelf Qwen2-VL model proposed
 in [Qwen2-VL](https://arxiv.org/pdf/2409.12191) without any post training on VLM itself.
 You can download the official weights from this link:
 
 | Model | Link |
@@ -104,7 +104,5 @@ python lerobot/scripts/eval.py \
 --env.task AlohaInsertion-v0 \
 --eval.n_episodes 1 \
 --eval.batch_size 1 \
---device cuda 
+--device cuda
 ~~~
-
-
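The `action_is_pad` mask shown in the README hunk above can be reproduced with a small helper. This is a hypothetical function (not part of lerobot), and the README's "Nth chunk" is read as a chunk starting at 0-indexed step N:

~~~python
def action_is_pad(chunk_start: int, chunk_size: int, episode_len: int) -> list[bool]:
    # A position is padding once the chunk runs past the last step of the episode.
    return [chunk_start + i >= episode_len for i in range(chunk_size)]


# Episode length 10, chunk size 5, matching the README's table:
assert action_is_pad(6, 5, 10) == [False, False, False, False, True]  # 6th chunk
assert action_is_pad(9, 5, 10) == [False, True, True, True, True]  # 9th chunk
~~~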
diff --git a/lerobot/common/policies/dexvla/configuration_dexvla.py b/lerobot/common/policies/dexvla/configuration_dexvla.py
index 5c1f6743..f666b041 100644
--- a/lerobot/common/policies/dexvla/configuration_dexvla.py
+++ b/lerobot/common/policies/dexvla/configuration_dexvla.py
@@ -21,8 +21,8 @@ from transformers.utils import logging
 
 from lerobot.common.optim.optimizers import AdamWConfig
 from lerobot.common.optim.schedulers import (
+    ConstantWithWarmupSchedulerConfig,
     CosineDecayWithWarmupSchedulerConfig,
-    ConstantWithWarmupSchedulerConfig
 )
 from lerobot.configs.policies import PreTrainedConfig
 from lerobot.configs.types import NormalizationMode
@@ -47,12 +47,14 @@ class DexVLAConfig(PreTrainedConfig):
     n_obs_steps: int = 1
     hidden_size: int = 1536
-    qwen2_vl_path: str = None # '/media/rl/HDD/data/weights/Qwen2-VL-2B-Instruct', official weights of qwen2vl
+    qwen2_vl_path: str = (
+        None  # '/media/rl/HDD/data/weights/Qwen2-VL-2B-Instruct', official weights of qwen2vl
+    )
 
-    pretrained_path: str = None # for loading pretrained weights of whole dexvla, usually for training stage3
-    pretrained_scaledp_path: str = None # for loading pretrained weights of ScaleDP(Stage1)
+    pretrained_path: str = None  # for loading pretrained weights of whole dexvla, usually for training stage3
+    pretrained_scaledp_path: str = None  # for loading pretrained weights of ScaleDP(Stage1)
 
-    training_stage: int = 2 # specific training stage, [2, 3]
+    training_stage: int = 2  # specific training stage, [2, 3]
     using_film: bool = True
     llm_loss_weight: float = 1.0
     with_llm_head: bool = True
@@ -119,7 +121,7 @@ class DexVLAConfig(PreTrainedConfig):
         else:
             raise ValueError(f"Policy head type {self.policy_head_type} not supported")
 
-        if self.training_stage not in [2,3]:
+        if self.training_stage not in [2, 3]:
             raise ValueError(f"Training stage must be 2 or 3. Got {self.training_stage}.")
 
         self.qwen2_vla_config = AutoConfig.from_pretrained(self.qwen2_vl_path)
diff --git a/lerobot/common/policies/dexvla/modeling_dexvla.py b/lerobot/common/policies/dexvla/modeling_dexvla.py
index 48fdd2ad..7c321c1b 100644
--- a/lerobot/common/policies/dexvla/modeling_dexvla.py
+++ b/lerobot/common/policies/dexvla/modeling_dexvla.py
@@ -55,25 +55,31 @@ class DexVLAPolicy(PreTrainedPolicy):
             trust_remote_code=True,
             _fast_init=False,
             # attn_implementation="flash_attention_2",
-        ).to(device='cuda', dtype=torch.bfloat16)
+        ).to(device="cuda", dtype=torch.bfloat16)
 
         if self.config.pretrained_scaledp_path is not None:
-            print(f'\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>Loading pretrained ScaleDP weights...<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<')
-            pretrain_scaledp_weights = torch.load(self.config.pretrained_scaledp_path, map_location='cpu')
+            print(
+                "\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>Loading pretrained ScaleDP weights...<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<"
+            )
+            pretrain_scaledp_weights = torch.load(self.config.pretrained_scaledp_path, map_location="cpu")
 
-            pretrain_scaledp_weights = pretrain_scaledp_weights['nets']['nets']
+            pretrain_scaledp_weights = pretrain_scaledp_weights["nets"]["nets"]
 
             keys_to_del_dit = []
-            pretrain_scaledp_weights = {k[7:] if k.startswith('policy.') else k: v for k, v in pretrain_scaledp_weights.items()}
+            pretrain_scaledp_weights = {
+                k[7:] if k.startswith("policy.") else k: v for k, v in pretrain_scaledp_weights.items()
+            }
             for k in pretrain_scaledp_weights:
-                if 'noise_pred' not in k: # del weights of vision backbones
+                if "noise_pred" not in k:  # del weights of vision backbones
                     keys_to_del_dit.append(k)
-                if 'cond_obs_emb' in k:
+                if "cond_obs_emb" in k:
                     keys_to_del_dit.append(k)
             for k in keys_to_del_dit:
                 del pretrain_scaledp_weights[k]
-            pretrain_scaledp_weights = {k[15:] if k.startswith('noise_pred_net.') else k: v for k, v in
-                                        pretrain_scaledp_weights.items()}
+            pretrain_scaledp_weights = {
+                k[15:] if k.startswith("noise_pred_net.") else k: v
+                for k, v in pretrain_scaledp_weights.items()
+            }
 
             self.model.policy_head.load_state_dict(pretrain_scaledp_weights, strict=False)
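The reflowed block in `modeling_dexvla.py` filters and renames ScaleDP checkpoint keys before `load_state_dict(..., strict=False)`. A standalone sketch of that pass, with made-up dict contents for illustration:

~~~python
# Toy stand-in for the checkpoint's ["nets"]["nets"] state dict.
weights = {
    "policy.noise_pred_net.mlp.weight": 1.0,  # kept; both prefixes stripped below
    "policy.backbone.conv.weight": 2.0,  # dropped: no "noise_pred" in key
    "policy.noise_pred_net.cond_obs_emb.weight": 3.0,  # dropped: conditioning embed
}

# 1) Strip the "policy." prefix (the original slices with k[7:]).
weights = {k.removeprefix("policy."): v for k, v in weights.items()}
# 2) Keep only noise-prediction weights; drop vision-backbone and cond_obs_emb keys.
weights = {k: v for k, v in weights.items() if "noise_pred" in k and "cond_obs_emb" not in k}
# 3) Strip the "noise_pred_net." prefix (k[15:]) so keys match the policy head module.
weights = {k.removeprefix("noise_pred_net."): v for k, v in weights.items()}

assert weights == {"mlp.weight": 1.0}
~~~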