diff --git a/lerobot/common/policies/pi0fast/config.json b/lerobot/common/policies/pi0fast/config.json index 5ac603aa..b2208e2c 100644 --- a/lerobot/common/policies/pi0fast/config.json +++ b/lerobot/common/policies/pi0fast/config.json @@ -76,4 +76,4 @@ "scheduler_warmup_steps": 1000, "scheduler_decay_steps": 30000, "scheduler_decay_lr": 2.5e-06 -} \ No newline at end of file +} diff --git a/lerobot/common/policies/pi0fast/configuration_pi0fast.py b/lerobot/common/policies/pi0fast/configuration_pi0fast.py index 7c5d3db5..f56234ce 100644 --- a/lerobot/common/policies/pi0fast/configuration_pi0fast.py +++ b/lerobot/common/policies/pi0fast/configuration_pi0fast.py @@ -8,7 +8,6 @@ from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature - @PreTrainedConfig.register_subclass("pi0fast") @dataclass class PI0FASTConfig(PreTrainedConfig): @@ -33,7 +32,7 @@ class PI0FASTConfig(PreTrainedConfig): resize_imgs_with_padding: tuple[int, int] = (224, 224) interpolate_like_pi: bool = False - # Add empty images. Used by pi0_aloha_sim which adds the emtpy + # Add empty images. Used by pi0_aloha_sim which adds the empty # left and right wrist cameras in addition to the top camera. empty_cameras: int = 0 diff --git a/lerobot/common/policies/pi0fast/modeling_pi0fast.py b/lerobot/common/policies/pi0fast/modeling_pi0fast.py index 85297b3d..61de917c 100644 --- a/lerobot/common/policies/pi0fast/modeling_pi0fast.py +++ b/lerobot/common/policies/pi0fast/modeling_pi0fast.py @@ -57,11 +57,10 @@ from transformers import AutoProcessor, AutoTokenizer, PaliGemmaForConditionalGe from transformers.cache_utils import HybridCache, StaticCache from transformers.models.auto import CONFIG_MAPPING +from lerobot.common.constants import ACTION, OBS_IMAGE, OBS_IMAGE_2, OBS_IMAGE_3, OBS_ROBOT from lerobot.common.policies.normalize import Normalize, Unnormalize from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig from lerobot.common.policies.pretrained import PreTrainedPolicy -from lerobot.common.constants import ACTION, OBS_IMAGE, OBS_IMAGE_2, OBS_IMAGE_3, OBS_ROBOT - IMAGES_ORDER = { OBS_IMAGE: 0, @@ -75,6 +74,7 @@ PRECISION = { "bfloat16": torch.bfloat16, } + def display(tensor: torch.Tensor): if tensor.dtype == torch.bool: tensor = tensor.float() @@ -139,7 +139,6 @@ def aloha_gripper_from_angular_inv(value): return normalize(value, min_val=0.4, max_val=1.5) - class PI0FASTPolicy(PreTrainedPolicy): """Wrapper class around PI0FAST tokenizer and model to train and run inference within LeRobot.""" @@ -209,7 +208,7 @@ class PI0FASTPolicy(PreTrainedPolicy): for motor_idx in [6, 13]: actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx]) return actions - + @torch.no_grad def select_action(self, batch: dict[str, Tensor]) -> Tensor: """Select a single action given environment observations. @@ -425,7 +424,9 @@ class PI0FAST(nn.Module): self.fast_skip_tokens = self.config.fast_skip_tokens self.max_input_seq_len = self.config.max_input_seq_len self.action_horizon = self.config.chunk_size - self.action_dim = self.config.action_feature.shape[0] #self.config.max_action_dim # self.config.action_feature.shape[0] + self.action_dim = self.config.action_feature.shape[ + 0 + ] # self.config.max_action_dim # self.config.action_feature.shape[0] precision = config.precision torch_precision = PRECISION.get(precision, torch.float32) self.pad_token_id = ( @@ -496,7 +497,7 @@ class PI0FAST(nn.Module): if any(selector in name for selector in params_to_change_dtype): param.data = param.data.to(dtype=torch_precision) self.set_requires_grad() - self.image_keys = self.config.image_features.keys() + self.image_keys = self.config.image_features.keys() self.ignore_index = self.pi0_paligemma.config.ignore_index self.padding_side = self.config.padding_side @@ -508,7 +509,7 @@ class PI0FAST(nn.Module): # To avoid unused params issue with distributed training if self.config.freeze_lm_head: for name, params in self.pi0_paligemma.named_parameters(): - if any([k in name for k in ["embed_tokens"]]): # lm heads and embedding layer are tied + if "embed_tokens" in name: # lm heads and embedding layer are tied params.requires_grad = False def embed_tokens(self, tokens: torch.Tensor): @@ -579,9 +580,7 @@ class PI0FAST(nn.Module): return fast_out - def create_token_type_ids( - self, padded_mask: torch.Tensor, prefix_len: int - ) -> torch.Tensor: + def create_token_type_ids(self, padded_mask: torch.Tensor, prefix_len: int) -> torch.Tensor: token_type_ids = torch.zeros_like(padded_mask, dtype=torch.bool) # Compute cumulative sum mask cumsum_mask = (padded_mask != 0).cumsum(dim=1) @@ -635,9 +634,9 @@ class PI0FAST(nn.Module): [self.paligemma_tokenizer.eos_token_id], dtype=torch.long, device=device ).expand(bsize, -1) eos_mask = torch.tensor([1], dtype=torch.long, device=device).expand(bsize, -1) - bos = self.paligemma_tokenizer('Action: ', add_special_tokens=False, return_tensors='pt') - bos_token = bos['input_ids'].expand(act_ids.shape[0],-1).to(device) - bos_mask = bos['attention_mask'].expand(act_ids.shape[0],-1).to(device) + bos = self.paligemma_tokenizer("Action: ", add_special_tokens=False, return_tensors="pt") + bos_token = bos["input_ids"].expand(act_ids.shape[0], -1).to(device) + bos_mask = bos["attention_mask"].expand(act_ids.shape[0], -1).to(device) act_ids = torch.cat([bos_token, act_ids, eos_token], dim=1) act_mask = torch.cat([bos_mask, act_mask, eos_mask], dim=1) act_mask = act_mask.to(device) @@ -656,13 +655,9 @@ class PI0FAST(nn.Module): padded_mask = padded_output["attention_mask"] # define tensor of padding lengths - att_mask = (padded_mask != 0).cumsum( - dim=1 - ) > prefix_lens + att_mask = (padded_mask != 0).cumsum(dim=1) > prefix_lens - token_type_ids = self.create_token_type_ids( - padded_mask=padded_mask, prefix_len=prefix_lens - ) + token_type_ids = self.create_token_type_ids(padded_mask=padded_mask, prefix_len=prefix_lens) padded_output["padded_mask"] = padded_output.pop("attention_mask") padded_output["attention_mask"] = att_mask @@ -713,7 +708,9 @@ class PI0FAST(nn.Module): images, img_masks = self.prepare_images(batch) padded_outs = self.create_input_tokens( - state=batch[OBS_ROBOT], lang_text=batch["task"], actions=batch[ACTION], + state=batch[OBS_ROBOT], + lang_text=batch["task"], + actions=batch[ACTION], ) embs, pad_masks, _, targets, loss_mask, token_type_ids = self.embed_inputs( @@ -793,9 +790,9 @@ class PI0FAST(nn.Module): self.called_time_horizon = self.time_horizon self.called_action_dim = self.action_dim - assert ( - self.time_horizon is not None and self.action_dim is not None - ), "Tokenizer not initialized, call encode() once or pass in time_horizon and action_dim." + assert self.time_horizon is not None and self.action_dim is not None, ( + "Tokenizer not initialized, call encode() once or pass in time_horizon and action_dim." + ) decoded_actions = [] for token in tokens: @@ -816,13 +813,12 @@ class PI0FAST(nn.Module): ) decoded_dct_coeff = decoded_dct_coeff.reshape(-1, self.action_dim) - assert ( - decoded_dct_coeff.shape - == ( - self.time_horizon, - self.action_dim, - ) - ), f"Decoded DCT coefficients have shape {decoded_dct_coeff.shape}, expected ({self.time_horizon}, {self.action_dim})" + assert decoded_dct_coeff.shape == ( + self.time_horizon, + self.action_dim, + ), ( + f"Decoded DCT coefficients have shape {decoded_dct_coeff.shape}, expected ({self.time_horizon}, {self.action_dim})" + ) except Exception as e: print(f"Error decoding tokens: {e}") print(f"Tokens: {token}") @@ -847,7 +843,7 @@ class PI0FAST(nn.Module): cleaned_tokens = [ tokens_sequence.replace("Action:", "").replace(":", "").strip().split("|")[0].strip() for tokens_sequence in decoded_tokens - ] + ] raw_action_tokens = [ self.processor.tokenizer.encode(sample_tokens, return_tensors="pt", padding=False) for sample_tokens in cleaned_tokens @@ -893,7 +889,7 @@ class PI0FAST(nn.Module): attention_mask=pad_masks, position_ids=prefix_position_ids, past_key_values=None, - inputs_embeds=embs, + inputs_embeds=embs, use_cache=self.config.use_cache, max_new_tokens=self.config.max_decoding_steps, do_sample=False, @@ -996,4 +992,4 @@ def resize_with_pad(img, width, height, pad_value=0, interpolate_like_pi=True): # pad on left and top of image padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value) - return padded_img \ No newline at end of file + return padded_img diff --git a/lerobot/common/policies/pi0fast/test_pi0fast.sh b/lerobot/common/policies/pi0fast/test_pi0fast.sh index 571854fb..df06152a 100644 --- a/lerobot/common/policies/pi0fast/test_pi0fast.sh +++ b/lerobot/common/policies/pi0fast/test_pi0fast.sh @@ -24,4 +24,4 @@ python lerobot/scripts/train.py \ --output_dir=$OUT_DIR \ --eval_freq=$EVAL_FREQ \ --steps=$OFFLINE_STEP \ - --save_freq=$SAVE_FREQ \ No newline at end of file + --save_freq=$SAVE_FREQ