Merge remote-tracking branch 'origin/2025_02_20_add_dexvla' into 2025_02_20_add_dexvla

This commit is contained in:
lesjie-wen 2025-03-18 18:37:18 +08:00
commit d9b20fa3c3
2 changed files with 4 additions and 6 deletions

View File

@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Qwen2VL model configuration"""
from .policy_heads import register_policy_heads
from .qwe2_vla import register_qwen2_vla
from dataclasses import dataclass, field
from typing import Tuple
@ -29,6 +27,9 @@ from lerobot.common.optim.schedulers import (
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
from .policy_heads import register_policy_heads
from .qwe2_vla import register_qwen2_vla
logger = logging.get_logger(__name__)
register_policy_heads()
register_qwen2_vla()

View File

@ -2,6 +2,7 @@ from collections import deque
import torch
import torchvision.transforms as transforms
from safetensors.torch import load_file
from torch import Tensor
from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer
@ -11,10 +12,6 @@ from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.pretrained import PreTrainedPolicy
from collections import deque
from transformers import AutoProcessor, AutoTokenizer, AutoModelForCausalLM
import torchvision.transforms as transforms
from safetensors.torch import load_file
class DexVLAPolicy(PreTrainedPolicy):
"""Wrapper class around Qwen2VLForConditionalGenerationForVLA model to train and run inference within LeRobot."""