cleaning
This commit is contained in:
parent
998bd92774
commit
043cc9180e
|
@ -22,8 +22,6 @@ OBS_ROBOT = "observation.state"
|
|||
OBS_IMAGE = "observation.image"
|
||||
OBS_IMAGES = "observation.images"
|
||||
ACTION = "action"
|
||||
OBS_IMAGE_2 = "observation.image2"
|
||||
OBS_IMAGE_3 = "observation.image3"
|
||||
|
||||
# files & directories
|
||||
CHECKPOINTS_DIR = "checkpoints"
|
||||
|
|
|
@ -81,7 +81,7 @@ class PI0FASTConfig(PreTrainedConfig):
|
|||
|
||||
# Allows padding/truncation of generated action tokens during detokenization to ensure decoding.
|
||||
# In the original version, tensors of 0s were generated if shapes didn't match for stable decoding.
|
||||
relaxed_decoding: bool = True
|
||||
relaxed_action_decoding: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
|
|
@ -56,24 +56,17 @@ from transformers import AutoProcessor, AutoTokenizer, PaliGemmaForConditionalGe
|
|||
from transformers.cache_utils import HybridCache, StaticCache
|
||||
from transformers.models.auto import CONFIG_MAPPING
|
||||
|
||||
from lerobot.common.constants import ACTION, OBS_IMAGE, OBS_IMAGE_2, OBS_IMAGE_3, OBS_ROBOT
|
||||
from lerobot.common.constants import ACTION, OBS_ROBOT
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
|
||||
IMAGES_ORDER = {
|
||||
OBS_IMAGE: 0,
|
||||
OBS_IMAGE_2: 1,
|
||||
OBS_IMAGE_3: 2,
|
||||
}
|
||||
|
||||
PRECISION = {
|
||||
"float16": torch.float16,
|
||||
"float32": torch.float32,
|
||||
"bfloat16": torch.bfloat16,
|
||||
}
|
||||
|
||||
|
||||
def normalize(x, min_val, max_val):
|
||||
return (x - min_val) / (max_val - min_val)
|
||||
|
||||
|
@ -839,7 +832,7 @@ class PI0FAST(nn.Module):
|
|||
tok.tolist(),
|
||||
time_horizon=action_horizon,
|
||||
action_dim=action_dim,
|
||||
relaxed_decoding=self.config.relaxed_decoding,
|
||||
relaxed_decoding=self.config.relaxed_action_decoding,
|
||||
),
|
||||
device=tokens.device,
|
||||
).squeeze(0)
|
||||
|
|
Loading…
Reference in New Issue