This commit is contained in:
mshukor 2025-04-02 16:03:22 +02:00
parent 998bd92774
commit 043cc9180e
3 changed files with 3 additions and 12 deletions

View File

@ -22,8 +22,6 @@ OBS_ROBOT = "observation.state"
OBS_IMAGE = "observation.image" OBS_IMAGE = "observation.image"
OBS_IMAGES = "observation.images" OBS_IMAGES = "observation.images"
ACTION = "action" ACTION = "action"
OBS_IMAGE_2 = "observation.image2"
OBS_IMAGE_3 = "observation.image3"
# files & directories # files & directories
CHECKPOINTS_DIR = "checkpoints" CHECKPOINTS_DIR = "checkpoints"

View File

@ -81,7 +81,7 @@ class PI0FASTConfig(PreTrainedConfig):
# Allows padding/truncation of generated action tokens during detokenization to ensure decoding. # Allows padding/truncation of generated action tokens during detokenization to ensure decoding.
# In the original version, tensors of 0s were generated if shapes didn't match for stable decoding. # In the original version, tensors of 0s were generated if shapes didn't match for stable decoding.
relaxed_decoding: bool = True relaxed_action_decoding: bool = True
def __post_init__(self): def __post_init__(self):
super().__post_init__() super().__post_init__()

View File

@ -56,24 +56,17 @@ from transformers import AutoProcessor, AutoTokenizer, PaliGemmaForConditionalGe
from transformers.cache_utils import HybridCache, StaticCache from transformers.cache_utils import HybridCache, StaticCache
from transformers.models.auto import CONFIG_MAPPING from transformers.models.auto import CONFIG_MAPPING
from lerobot.common.constants import ACTION, OBS_IMAGE, OBS_IMAGE_2, OBS_IMAGE_3, OBS_ROBOT from lerobot.common.constants import ACTION, OBS_ROBOT
from lerobot.common.policies.normalize import Normalize, Unnormalize from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
from lerobot.common.policies.pretrained import PreTrainedPolicy from lerobot.common.policies.pretrained import PreTrainedPolicy
IMAGES_ORDER = {
OBS_IMAGE: 0,
OBS_IMAGE_2: 1,
OBS_IMAGE_3: 2,
}
PRECISION = { PRECISION = {
"float16": torch.float16, "float16": torch.float16,
"float32": torch.float32, "float32": torch.float32,
"bfloat16": torch.bfloat16, "bfloat16": torch.bfloat16,
} }
def normalize(x, min_val, max_val): def normalize(x, min_val, max_val):
return (x - min_val) / (max_val - min_val) return (x - min_val) / (max_val - min_val)
@ -839,7 +832,7 @@ class PI0FAST(nn.Module):
tok.tolist(), tok.tolist(),
time_horizon=action_horizon, time_horizon=action_horizon,
action_dim=action_dim, action_dim=action_dim,
relaxed_decoding=self.config.relaxed_decoding, relaxed_decoding=self.config.relaxed_action_decoding,
), ),
device=tokens.device, device=tokens.device,
).squeeze(0) ).squeeze(0)