precommit

mshukor 2025-03-31 16:40:18 +02:00
parent 6e48e044d7
commit fe284bbbd7
4 changed files with 32 additions and 37 deletions

lerobot/common/policies/pi0fast/configuration_pi0fast.py

@@ -8,7 +8,6 @@ from lerobot.configs.policies import PreTrainedConfig
 from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
 
 
-
 @PreTrainedConfig.register_subclass("pi0fast")
 @dataclass
 class PI0FASTConfig(PreTrainedConfig):
@@ -33,7 +32,7 @@ class PI0FASTConfig(PreTrainedConfig):
     resize_imgs_with_padding: tuple[int, int] = (224, 224)
     interpolate_like_pi: bool = False
 
-    # Add empty images. Used by pi0_aloha_sim which adds the emtpy
+    # Add empty images. Used by pi0_aloha_sim which adds the empty
     # left and right wrist cameras in addition to the top camera.
     empty_cameras: int = 0
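The empty_cameras option pads the image inputs with blank camera slots, as the comment above describes for pi0_aloha_sim. A minimal usage sketch (field names taken from this file; values illustrative, and the config may accept further arguments in practice):

from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig

# pi0_aloha_sim-style setup: two empty wrist-camera slots next to the top camera
config = PI0FASTConfig(
    resize_imgs_with_padding=(224, 224),
    interpolate_like_pi=False,
    empty_cameras=2,
)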

lerobot/common/policies/pi0fast/modeling_pi0fast.py

@@ -57,11 +57,10 @@ from transformers import AutoProcessor, AutoTokenizer, PaliGemmaForConditionalGeneration
 from transformers.cache_utils import HybridCache, StaticCache
 from transformers.models.auto import CONFIG_MAPPING
 
+from lerobot.common.constants import ACTION, OBS_IMAGE, OBS_IMAGE_2, OBS_IMAGE_3, OBS_ROBOT
 from lerobot.common.policies.normalize import Normalize, Unnormalize
 from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
 from lerobot.common.policies.pretrained import PreTrainedPolicy
-from lerobot.common.constants import ACTION, OBS_IMAGE, OBS_IMAGE_2, OBS_IMAGE_3, OBS_ROBOT
-
 
 IMAGES_ORDER = {
     OBS_IMAGE: 0,
@@ -75,6 +74,7 @@ PRECISION = {
     "bfloat16": torch.bfloat16,
 }
 
+
 def display(tensor: torch.Tensor):
     if tensor.dtype == torch.bool:
         tensor = tensor.float()
@@ -139,7 +139,6 @@ def aloha_gripper_from_angular_inv(value):
     return normalize(value, min_val=0.4, max_val=1.5)
 
 
-
 class PI0FASTPolicy(PreTrainedPolicy):
     """Wrapper class around PI0FAST tokenizer and model to train and run inference within LeRobot."""
@@ -425,7 +424,9 @@ class PI0FAST(nn.Module):
         self.fast_skip_tokens = self.config.fast_skip_tokens
         self.max_input_seq_len = self.config.max_input_seq_len
         self.action_horizon = self.config.chunk_size
-        self.action_dim = self.config.action_feature.shape[0]  # self.config.max_action_dim # self.config.action_feature.shape[0]
+        self.action_dim = self.config.action_feature.shape[
+            0
+        ]  # self.config.max_action_dim # self.config.action_feature.shape[0]
         precision = config.precision
         torch_precision = PRECISION.get(precision, torch.float32)
         self.pad_token_id = (
@@ -496,7 +497,7 @@ class PI0FAST(nn.Module):
             if any(selector in name for selector in params_to_change_dtype):
                 param.data = param.data.to(dtype=torch_precision)
         self.set_requires_grad()
 
         self.image_keys = self.config.image_features.keys()
         self.ignore_index = self.pi0_paligemma.config.ignore_index
         self.padding_side = self.config.padding_side
@@ -508,7 +509,7 @@ class PI0FAST(nn.Module):
         # To avoid unused params issue with distributed training
         if self.config.freeze_lm_head:
             for name, params in self.pi0_paligemma.named_parameters():
-                if any([k in name for k in ["embed_tokens"]]):  # lm heads and embedding layer are tied
+                if "embed_tokens" in name:  # lm heads and embedding layer are tied
                     params.requires_grad = False
 
     def embed_tokens(self, tokens: torch.Tensor):
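The inline comment notes that the LM head and the embedding layer are tied, so freezing embed_tokens is enough to freeze both. A standalone sketch of that mechanism (toy modules, not the repo's code):

import torch.nn as nn

vocab_size, hidden_dim = 100, 16
embed_tokens = nn.Embedding(vocab_size, hidden_dim)
lm_head = nn.Linear(hidden_dim, vocab_size, bias=False)
lm_head.weight = embed_tokens.weight  # weight tying: both modules share one Parameter

embed_tokens.weight.requires_grad = False
print(lm_head.weight.requires_grad)  # False -- flipping the shared flag freezes both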
@@ -579,9 +580,7 @@ class PI0FAST(nn.Module):
         return fast_out
 
-    def create_token_type_ids(
-        self, padded_mask: torch.Tensor, prefix_len: int
-    ) -> torch.Tensor:
+    def create_token_type_ids(self, padded_mask: torch.Tensor, prefix_len: int) -> torch.Tensor:
         token_type_ids = torch.zeros_like(padded_mask, dtype=torch.bool)
         # Compute cumulative sum mask
         cumsum_mask = (padded_mask != 0).cumsum(dim=1)
@@ -635,9 +634,9 @@ class PI0FAST(nn.Module):
             [self.paligemma_tokenizer.eos_token_id], dtype=torch.long, device=device
         ).expand(bsize, -1)
         eos_mask = torch.tensor([1], dtype=torch.long, device=device).expand(bsize, -1)
-        bos = self.paligemma_tokenizer('Action: ', add_special_tokens=False, return_tensors='pt')
-        bos_token = bos['input_ids'].expand(act_ids.shape[0],-1).to(device)
-        bos_mask = bos['attention_mask'].expand(act_ids.shape[0],-1).to(device)
+        bos = self.paligemma_tokenizer("Action: ", add_special_tokens=False, return_tensors="pt")
+        bos_token = bos["input_ids"].expand(act_ids.shape[0], -1).to(device)
+        bos_mask = bos["attention_mask"].expand(act_ids.shape[0], -1).to(device)
         act_ids = torch.cat([bos_token, act_ids, eos_token], dim=1)
         act_mask = torch.cat([bos_mask, act_mask, eos_mask], dim=1)
         act_mask = act_mask.to(device)
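The three lines above prepend the token ids of a literal "Action: " prefix and append EOS to each row of action tokens. A toy sketch of the shape bookkeeping (stand-in ids, not a real tokenizer):

import torch

bsize, n_act = 2, 4
act_ids = torch.randint(10, 100, (bsize, n_act))       # stand-in FAST action tokens
bos_token = torch.tensor([[5, 6]]).expand(bsize, -1)   # stand-in ids for "Action: "
eos_token = torch.tensor([[1]]).expand(bsize, -1)      # stand-in EOS id
act_ids = torch.cat([bos_token, act_ids, eos_token], dim=1)
print(act_ids.shape)  # torch.Size([2, 7]): prefix + actions + eos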
@@ -656,13 +655,9 @@ class PI0FAST(nn.Module):
         padded_mask = padded_output["attention_mask"]
 
         # define tensor of padding lengths
-        att_mask = (padded_mask != 0).cumsum(
-            dim=1
-        ) > prefix_lens
+        att_mask = (padded_mask != 0).cumsum(dim=1) > prefix_lens
 
-        token_type_ids = self.create_token_type_ids(
-            padded_mask=padded_mask, prefix_len=prefix_lens
-        )
+        token_type_ids = self.create_token_type_ids(padded_mask=padded_mask, prefix_len=prefix_lens)
 
         padded_output["padded_mask"] = padded_output.pop("attention_mask")
         padded_output["attention_mask"] = att_mask
@@ -713,7 +708,9 @@ class PI0FAST(nn.Module):
         images, img_masks = self.prepare_images(batch)
 
         padded_outs = self.create_input_tokens(
-            state=batch[OBS_ROBOT], lang_text=batch["task"], actions=batch[ACTION],
+            state=batch[OBS_ROBOT],
+            lang_text=batch["task"],
+            actions=batch[ACTION],
         )
 
         embs, pad_masks, _, targets, loss_mask, token_type_ids = self.embed_inputs(
@@ -793,9 +790,9 @@ class PI0FAST(nn.Module):
         self.called_time_horizon = self.time_horizon
         self.called_action_dim = self.action_dim
 
-        assert (
-            self.time_horizon is not None and self.action_dim is not None
-        ), "Tokenizer not initialized, call encode() once or pass in time_horizon and action_dim."
+        assert self.time_horizon is not None and self.action_dim is not None, (
+            "Tokenizer not initialized, call encode() once or pass in time_horizon and action_dim."
+        )
 
         decoded_actions = []
         for token in tokens:
@@ -816,13 +813,12 @@ class PI0FAST(nn.Module):
                 )
                 decoded_dct_coeff = decoded_dct_coeff.reshape(-1, self.action_dim)
-                assert (
-                    decoded_dct_coeff.shape
-                    == (
-                        self.time_horizon,
-                        self.action_dim,
-                    )
-                ), f"Decoded DCT coefficients have shape {decoded_dct_coeff.shape}, expected ({self.time_horizon}, {self.action_dim})"
+                assert decoded_dct_coeff.shape == (
+                    self.time_horizon,
+                    self.action_dim,
+                ), (
+                    f"Decoded DCT coefficients have shape {decoded_dct_coeff.shape}, expected ({self.time_horizon}, {self.action_dim})"
+                )
             except Exception as e:
                 print(f"Error decoding tokens: {e}")
                 print(f"Tokens: {token}")