precommit

parent 6e48e044d7
commit fe284bbbd7
@@ -76,4 +76,4 @@
     "scheduler_warmup_steps": 1000,
     "scheduler_decay_steps": 30000,
     "scheduler_decay_lr": 2.5e-06
-}
+}
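
The scheduler fields above (warmup steps, decay steps, final decay LR) describe a warmup-then-decay learning-rate schedule. As a minimal sketch only: the snippet below wires such a schedule into torch.optim.lr_scheduler.LambdaLR, assuming a linear warmup followed by cosine decay and a hypothetical peak_lr; it is not necessarily LeRobot's exact scheduler.

import math

import torch

# Values from the config above; peak_lr is a hypothetical peak learning rate.
warmup_steps = 1000
decay_steps = 30000
decay_lr = 2.5e-6
peak_lr = 1e-4


def lr_lambda(step: int) -> float:
    # Linear warmup from 0 to peak_lr, then cosine decay towards decay_lr.
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    progress = min(1.0, (step - warmup_steps) / max(1, decay_steps - warmup_steps))
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return (decay_lr + (peak_lr - decay_lr) * cosine) / peak_lr


optimizer = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=peak_lr)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
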
@@ -8,7 +8,6 @@ from lerobot.configs.policies import PreTrainedConfig
 from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
 
 
-
 @PreTrainedConfig.register_subclass("pi0fast")
 @dataclass
 class PI0FASTConfig(PreTrainedConfig):
@@ -33,7 +32,7 @@ class PI0FASTConfig(PreTrainedConfig):
     resize_imgs_with_padding: tuple[int, int] = (224, 224)
     interpolate_like_pi: bool = False
 
-    # Add empty images. Used by pi0_aloha_sim which adds the emtpy
+    # Add empty images. Used by pi0_aloha_sim which adds the empty
     # left and right wrist cameras in addition to the top camera.
     empty_cameras: int = 0
 
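
The two hunks above only adjust blank lines and fix a comment typo in PI0FASTConfig. For reference, the fields they touch can be overridden when the config is instantiated. A minimal sketch, assuming the dataclass can be constructed directly and the remaining fields keep their defaults:

from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig

# Pad observations with two blank wrist-camera images (as pi0_aloha_sim does)
# and keep the default 224x224 resize-with-padding behaviour.
config = PI0FASTConfig(
    empty_cameras=2,
    resize_imgs_with_padding=(224, 224),
    interpolate_like_pi=False,
)
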
@@ -57,11 +57,10 @@ from transformers import AutoProcessor, AutoTokenizer, PaliGemmaForConditionalGe
 from transformers.cache_utils import HybridCache, StaticCache
 from transformers.models.auto import CONFIG_MAPPING
 
+from lerobot.common.constants import ACTION, OBS_IMAGE, OBS_IMAGE_2, OBS_IMAGE_3, OBS_ROBOT
 from lerobot.common.policies.normalize import Normalize, Unnormalize
 from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
 from lerobot.common.policies.pretrained import PreTrainedPolicy
-from lerobot.common.constants import ACTION, OBS_IMAGE, OBS_IMAGE_2, OBS_IMAGE_3, OBS_ROBOT
-
 
 IMAGES_ORDER = {
     OBS_IMAGE: 0,
@@ -75,6 +74,7 @@ PRECISION = {
     "bfloat16": torch.bfloat16,
 }
 
+
 def display(tensor: torch.Tensor):
     if tensor.dtype == torch.bool:
         tensor = tensor.float()
@@ -139,7 +139,6 @@ def aloha_gripper_from_angular_inv(value):
     return normalize(value, min_val=0.4, max_val=1.5)
 
 
-
 class PI0FASTPolicy(PreTrainedPolicy):
     """Wrapper class around PI0FAST tokenizer and model to train and run inference within LeRobot."""
 
@@ -209,7 +208,7 @@ class PI0FASTPolicy(PreTrainedPolicy):
             for motor_idx in [6, 13]:
                 actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx])
         return actions
-
+
     @torch.no_grad
     def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         """Select a single action given environment observations.
@@ -425,7 +424,9 @@ class PI0FAST(nn.Module):
         self.fast_skip_tokens = self.config.fast_skip_tokens
         self.max_input_seq_len = self.config.max_input_seq_len
         self.action_horizon = self.config.chunk_size
-        self.action_dim = self.config.action_feature.shape[0] #self.config.max_action_dim # self.config.action_feature.shape[0]
+        self.action_dim = self.config.action_feature.shape[
+            0
+        ]  # self.config.max_action_dim # self.config.action_feature.shape[0]
         precision = config.precision
         torch_precision = PRECISION.get(precision, torch.float32)
         self.pad_token_id = (
@@ -496,7 +497,7 @@ class PI0FAST(nn.Module):
             if any(selector in name for selector in params_to_change_dtype):
                 param.data = param.data.to(dtype=torch_precision)
         self.set_requires_grad()
-        self.image_keys = self.config.image_features.keys()
+        self.image_keys = self.config.image_features.keys()
         self.ignore_index = self.pi0_paligemma.config.ignore_index
         self.padding_side = self.config.padding_side
 
@@ -508,7 +509,7 @@ class PI0FAST(nn.Module):
         # To avoid unused params issue with distributed training
         if self.config.freeze_lm_head:
             for name, params in self.pi0_paligemma.named_parameters():
-                if any([k in name for k in ["embed_tokens"]]):  # lm heads and embedding layer are tied
+                if "embed_tokens" in name:  # lm heads and embedding layer are tied
                     params.requires_grad = False
 
     def embed_tokens(self, tokens: torch.Tensor):
@@ -579,9 +580,7 @@ class PI0FAST(nn.Module):
 
         return fast_out
 
-    def create_token_type_ids(
-        self, padded_mask: torch.Tensor, prefix_len: int
-    ) -> torch.Tensor:
+    def create_token_type_ids(self, padded_mask: torch.Tensor, prefix_len: int) -> torch.Tensor:
         token_type_ids = torch.zeros_like(padded_mask, dtype=torch.bool)
         # Compute cumulative sum mask
         cumsum_mask = (padded_mask != 0).cumsum(dim=1)
@@ -635,9 +634,9 @@ class PI0FAST(nn.Module):
                 [self.paligemma_tokenizer.eos_token_id], dtype=torch.long, device=device
             ).expand(bsize, -1)
             eos_mask = torch.tensor([1], dtype=torch.long, device=device).expand(bsize, -1)
-            bos = self.paligemma_tokenizer('Action: ', add_special_tokens=False, return_tensors='pt')
-            bos_token = bos['input_ids'].expand(act_ids.shape[0],-1).to(device)
-            bos_mask = bos['attention_mask'].expand(act_ids.shape[0],-1).to(device)
+            bos = self.paligemma_tokenizer("Action: ", add_special_tokens=False, return_tensors="pt")
+            bos_token = bos["input_ids"].expand(act_ids.shape[0], -1).to(device)
+            bos_mask = bos["attention_mask"].expand(act_ids.shape[0], -1).to(device)
             act_ids = torch.cat([bos_token, act_ids, eos_token], dim=1)
             act_mask = torch.cat([bos_mask, act_mask, eos_mask], dim=1)
             act_mask = act_mask.to(device)
@@ -656,13 +655,9 @@ class PI0FAST(nn.Module):
         padded_mask = padded_output["attention_mask"]
 
         # define tensor of padding lengths
-        att_mask = (padded_mask != 0).cumsum(
-            dim=1
-        ) > prefix_lens
+        att_mask = (padded_mask != 0).cumsum(dim=1) > prefix_lens
 
-        token_type_ids = self.create_token_type_ids(
-            padded_mask=padded_mask, prefix_len=prefix_lens
-        )
+        token_type_ids = self.create_token_type_ids(padded_mask=padded_mask, prefix_len=prefix_lens)
 
         padded_output["padded_mask"] = padded_output.pop("attention_mask")
         padded_output["attention_mask"] = att_mask
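
The (padded_mask != 0).cumsum(dim=1) > prefix_lens expression collapsed onto one line above marks every position whose running count of non-padding tokens exceeds the prefix length. A tiny standalone illustration of the trick, with toy values rather than LeRobot code:

import torch

padded_mask = torch.tensor([[1, 1, 1, 1, 0, 0],
                            [1, 1, 1, 1, 1, 1]])
prefix_lens = torch.tensor([[2], [3]])  # per-sample prefix lengths

# True wherever the running count of real tokens exceeds the prefix length.
att_mask = (padded_mask != 0).cumsum(dim=1) > prefix_lens
print(att_mask)
# tensor([[False, False,  True,  True,  True,  True],
#         [False, False, False,  True,  True,  True]])
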
@@ -713,7 +708,9 @@ class PI0FAST(nn.Module):
         images, img_masks = self.prepare_images(batch)
 
         padded_outs = self.create_input_tokens(
-            state=batch[OBS_ROBOT], lang_text=batch["task"], actions=batch[ACTION],
+            state=batch[OBS_ROBOT],
+            lang_text=batch["task"],
+            actions=batch[ACTION],
         )
 
         embs, pad_masks, _, targets, loss_mask, token_type_ids = self.embed_inputs(
@@ -793,9 +790,9 @@ class PI0FAST(nn.Module):
         self.called_time_horizon = self.time_horizon
         self.called_action_dim = self.action_dim
 
-        assert (
-            self.time_horizon is not None and self.action_dim is not None
-        ), "Tokenizer not initialized, call encode() once or pass in time_horizon and action_dim."
+        assert self.time_horizon is not None and self.action_dim is not None, (
+            "Tokenizer not initialized, call encode() once or pass in time_horizon and action_dim."
+        )
 
         decoded_actions = []
         for token in tokens:
@@ -816,13 +813,12 @@ class PI0FAST(nn.Module):
                 )
 
                 decoded_dct_coeff = decoded_dct_coeff.reshape(-1, self.action_dim)
-                assert (
-                    decoded_dct_coeff.shape
-                    == (
-                        self.time_horizon,
-                        self.action_dim,
-                    )
-                ), f"Decoded DCT coefficients have shape {decoded_dct_coeff.shape}, expected ({self.time_horizon}, {self.action_dim})"
+                assert decoded_dct_coeff.shape == (
+                    self.time_horizon,
+                    self.action_dim,
+                ), (
+                    f"Decoded DCT coefficients have shape {decoded_dct_coeff.shape}, expected ({self.time_horizon}, {self.action_dim})"
+                )
             except Exception as e:
                 print(f"Error decoding tokens: {e}")
                 print(f"Tokens: {token}")
@@ -847,7 +843,7 @@ class PI0FAST(nn.Module):
         cleaned_tokens = [
             tokens_sequence.replace("Action:", "").replace(":", "").strip().split("|")[0].strip()
             for tokens_sequence in decoded_tokens
-        ]
+        ]
         raw_action_tokens = [
             self.processor.tokenizer.encode(sample_tokens, return_tensors="pt", padding=False)
             for sample_tokens in cleaned_tokens
@@ -893,7 +889,7 @@ class PI0FAST(nn.Module):
             attention_mask=pad_masks,
             position_ids=prefix_position_ids,
             past_key_values=None,
-            inputs_embeds=embs,
+            inputs_embeds=embs,
             use_cache=self.config.use_cache,
             max_new_tokens=self.config.max_decoding_steps,
             do_sample=False,
@@ -996,4 +992,4 @@ def resize_with_pad(img, width, height, pad_value=0, interpolate_like_pi=True):
 
     # pad on left and top of image
     padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value)
-    return padded_img
+    return padded_img
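
The final hunk of the modeling file only touches the last line of resize_with_pad. For context, a rough, self-contained sketch of the resize-then-pad-on-left/top behaviour visible in that function; this is an illustrative approximation under assumed semantics, not the library's exact implementation:

import torch
import torch.nn.functional as F


def resize_with_pad_sketch(img: torch.Tensor, width: int, height: int, pad_value: float = 0) -> torch.Tensor:
    # img is a (B, C, H, W) float tensor. Shrink it so it fits inside (height, width)
    # while keeping the aspect ratio, then pad on the left and top, mirroring the
    # F.pad call shown in the diff above.
    _, _, h, w = img.shape
    ratio = max(w / width, h / height)
    resized_h, resized_w = int(h / ratio), int(w / ratio)
    resized_img = F.interpolate(img, size=(resized_h, resized_w), mode="bilinear", align_corners=False)
    pad_height = max(0, height - resized_h)
    pad_width = max(0, width - resized_w)
    return F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value)


padded = resize_with_pad_sketch(torch.rand(1, 3, 480, 640), width=224, height=224)
print(padded.shape)  # torch.Size([1, 3, 224, 224])
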
@@ -24,4 +24,4 @@ python lerobot/scripts/train.py \
     --output_dir=$OUT_DIR \
     --eval_freq=$EVAL_FREQ \
     --steps=$OFFLINE_STEP \
-    --save_freq=$SAVE_FREQ
+    --save_freq=$SAVE_FREQ