cleaning
This commit is contained in:
parent
998bd92774
commit
043cc9180e
|
@ -22,8 +22,6 @@ OBS_ROBOT = "observation.state"
|
||||||
OBS_IMAGE = "observation.image"
|
OBS_IMAGE = "observation.image"
|
||||||
OBS_IMAGES = "observation.images"
|
OBS_IMAGES = "observation.images"
|
||||||
ACTION = "action"
|
ACTION = "action"
|
||||||
OBS_IMAGE_2 = "observation.image2"
|
|
||||||
OBS_IMAGE_3 = "observation.image3"
|
|
||||||
|
|
||||||
# files & directories
|
# files & directories
|
||||||
CHECKPOINTS_DIR = "checkpoints"
|
CHECKPOINTS_DIR = "checkpoints"
|
||||||
|
|
|
@ -81,7 +81,7 @@ class PI0FASTConfig(PreTrainedConfig):
|
||||||
|
|
||||||
# Allows padding/truncation of generated action tokens during detokenization to ensure decoding.
|
# Allows padding/truncation of generated action tokens during detokenization to ensure decoding.
|
||||||
# In the original version, tensors of 0s were generated if shapes didn't match for stable decoding.
|
# In the original version, tensors of 0s were generated if shapes didn't match for stable decoding.
|
||||||
relaxed_decoding: bool = True
|
relaxed_action_decoding: bool = True
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
super().__post_init__()
|
super().__post_init__()
|
||||||
|
|
|
@ -56,24 +56,17 @@ from transformers import AutoProcessor, AutoTokenizer, PaliGemmaForConditionalGe
|
||||||
from transformers.cache_utils import HybridCache, StaticCache
|
from transformers.cache_utils import HybridCache, StaticCache
|
||||||
from transformers.models.auto import CONFIG_MAPPING
|
from transformers.models.auto import CONFIG_MAPPING
|
||||||
|
|
||||||
from lerobot.common.constants import ACTION, OBS_IMAGE, OBS_IMAGE_2, OBS_IMAGE_3, OBS_ROBOT
|
from lerobot.common.constants import ACTION, OBS_ROBOT
|
||||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||||
from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
|
from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
|
||||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||||
|
|
||||||
IMAGES_ORDER = {
|
|
||||||
OBS_IMAGE: 0,
|
|
||||||
OBS_IMAGE_2: 1,
|
|
||||||
OBS_IMAGE_3: 2,
|
|
||||||
}
|
|
||||||
|
|
||||||
PRECISION = {
|
PRECISION = {
|
||||||
"float16": torch.float16,
|
"float16": torch.float16,
|
||||||
"float32": torch.float32,
|
"float32": torch.float32,
|
||||||
"bfloat16": torch.bfloat16,
|
"bfloat16": torch.bfloat16,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def normalize(x, min_val, max_val):
|
def normalize(x, min_val, max_val):
|
||||||
return (x - min_val) / (max_val - min_val)
|
return (x - min_val) / (max_val - min_val)
|
||||||
|
|
||||||
|
@ -839,7 +832,7 @@ class PI0FAST(nn.Module):
|
||||||
tok.tolist(),
|
tok.tolist(),
|
||||||
time_horizon=action_horizon,
|
time_horizon=action_horizon,
|
||||||
action_dim=action_dim,
|
action_dim=action_dim,
|
||||||
relaxed_decoding=self.config.relaxed_decoding,
|
relaxed_decoding=self.config.relaxed_action_decoding,
|
||||||
),
|
),
|
||||||
device=tokens.device,
|
device=tokens.device,
|
||||||
).squeeze(0)
|
).squeeze(0)
|
||||||
|
|
Loading…
Reference in New Issue