Merge remote-tracking branch 'origin/2025_02_20_add_dexvla' into 2025_02_20_add_dexvla
This commit is contained in:
commit
d9b20fa3c3
|
@ -12,8 +12,6 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Qwen2VL model configuration"""
|
||||
from .policy_heads import register_policy_heads
|
||||
from .qwe2_vla import register_qwen2_vla
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Tuple
|
||||
|
@ -29,6 +27,9 @@ from lerobot.common.optim.schedulers import (
|
|||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import NormalizationMode
|
||||
|
||||
from .policy_heads import register_policy_heads
|
||||
from .qwe2_vla import register_qwen2_vla
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
register_policy_heads()
|
||||
register_qwen2_vla()
|
||||
|
|
|
@ -2,6 +2,7 @@ from collections import deque
|
|||
|
||||
import torch
|
||||
import torchvision.transforms as transforms
|
||||
from safetensors.torch import load_file
|
||||
from torch import Tensor
|
||||
from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer
|
||||
|
||||
|
@ -11,10 +12,6 @@ from lerobot.common.policies.normalize import Normalize, Unnormalize
|
|||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
|
||||
|
||||
from collections import deque
|
||||
from transformers import AutoProcessor, AutoTokenizer, AutoModelForCausalLM
|
||||
import torchvision.transforms as transforms
|
||||
from safetensors.torch import load_file
|
||||
class DexVLAPolicy(PreTrainedPolicy):
|
||||
"""Wrapper class around Qwen2VLForConditionalGenerationForVLA model to train and run inference within LeRobot."""
|
||||
|
||||
|
|
Loading…
Reference in New Issue