Merge remote-tracking branch 'origin/2025_02_20_add_dexvla' into 2025_02_20_add_dexvla

This commit is contained in:
lesjie-wen 2025-03-18 18:37:18 +08:00
commit d9b20fa3c3
2 changed files with 4 additions and 6 deletions

View File

@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Qwen2VL model configuration""" """Qwen2VL model configuration"""
from .policy_heads import register_policy_heads
from .qwe2_vla import register_qwen2_vla
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Tuple from typing import Tuple
@ -29,6 +27,9 @@ from lerobot.common.optim.schedulers import (
from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode from lerobot.configs.types import NormalizationMode
from .policy_heads import register_policy_heads
from .qwe2_vla import register_qwen2_vla
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
register_policy_heads() register_policy_heads()
register_qwen2_vla() register_qwen2_vla()

View File

@ -2,6 +2,7 @@ from collections import deque
import torch import torch
import torchvision.transforms as transforms import torchvision.transforms as transforms
from safetensors.torch import load_file
from torch import Tensor from torch import Tensor
from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer
@ -11,10 +12,6 @@ from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.pretrained import PreTrainedPolicy from lerobot.common.policies.pretrained import PreTrainedPolicy
from collections import deque
from transformers import AutoProcessor, AutoTokenizer, AutoModelForCausalLM
import torchvision.transforms as transforms
from safetensors.torch import load_file
class DexVLAPolicy(PreTrainedPolicy): class DexVLAPolicy(PreTrainedPolicy):
"""Wrapper class around Qwen2VLForConditionalGenerationForVLA model to train and run inference within LeRobot.""" """Wrapper class around Qwen2VLForConditionalGenerationForVLA model to train and run inference within LeRobot."""