Merge remote-tracking branch 'origin/2025_02_20_add_dexvla' into 2025_02_20_add_dexvla
This commit is contained in:
commit
d9b20fa3c3
|
@ -12,8 +12,6 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Qwen2VL model configuration"""
|
"""Qwen2VL model configuration"""
|
||||||
from .policy_heads import register_policy_heads
|
|
||||||
from .qwe2_vla import register_qwen2_vla
|
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
|
@ -29,6 +27,9 @@ from lerobot.common.optim.schedulers import (
|
||||||
from lerobot.configs.policies import PreTrainedConfig
|
from lerobot.configs.policies import PreTrainedConfig
|
||||||
from lerobot.configs.types import NormalizationMode
|
from lerobot.configs.types import NormalizationMode
|
||||||
|
|
||||||
|
from .policy_heads import register_policy_heads
|
||||||
|
from .qwe2_vla import register_qwen2_vla
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
register_policy_heads()
|
register_policy_heads()
|
||||||
register_qwen2_vla()
|
register_qwen2_vla()
|
||||||
|
|
|
@ -2,6 +2,7 @@ from collections import deque
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torchvision.transforms as transforms
|
import torchvision.transforms as transforms
|
||||||
|
from safetensors.torch import load_file
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer
|
from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer
|
||||||
|
|
||||||
|
@ -11,10 +12,6 @@ from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||||
|
|
||||||
|
|
||||||
from collections import deque
|
|
||||||
from transformers import AutoProcessor, AutoTokenizer, AutoModelForCausalLM
|
|
||||||
import torchvision.transforms as transforms
|
|
||||||
from safetensors.torch import load_file
|
|
||||||
class DexVLAPolicy(PreTrainedPolicy):
|
class DexVLAPolicy(PreTrainedPolicy):
|
||||||
"""Wrapper class around Qwen2VLForConditionalGenerationForVLA model to train and run inference within LeRobot."""
|
"""Wrapper class around Qwen2VLForConditionalGenerationForVLA model to train and run inference within LeRobot."""
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue