diff --git a/lerobot/common/policies/vqbet/configuration_vqbet.py b/lerobot/common/policies/vqbet/configuration_vqbet.py index 35aa6b6e..718accd5 100644 --- a/lerobot/common/policies/vqbet/configuration_vqbet.py +++ b/lerobot/common/policies/vqbet/configuration_vqbet.py @@ -117,6 +117,7 @@ class VQBeTConfig: n_embd: int = 120 dropout: float = 0.1 mlp_hidden_dim: int = 1024 + offset_loss_weight: float = 10000. def __post_init__(self): """Input validation (not exhaustive).""" diff --git a/lerobot/common/policies/vqbet/modeling_vqbet.py b/lerobot/common/policies/vqbet/modeling_vqbet.py index 728d759f..b98fdd52 100644 --- a/lerobot/common/policies/vqbet/modeling_vqbet.py +++ b/lerobot/common/policies/vqbet/modeling_vqbet.py @@ -12,6 +12,7 @@ from dataclasses import dataclass import einops from einops import rearrange, repeat, reduce, pack, unpack +import numpy as np import torch import torch.nn.functional as F # noqa: N812 import torchvision @@ -78,9 +79,7 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin): "observation.state": deque(maxlen=self.config.n_obs_steps), } if self.config.n_action_pred_chunk is not None: - self._action_queue = deque([], maxlen=self.config.n_action_pred_chunk) # original one - # self._action_queue = deque([], maxlen=self.config.n_action_pred_token) # jay temp2 - self._action_history_queue = deque([], maxlen=self.config.n_obs_steps - 1) + self._action_queue = deque([], maxlen=self.config.n_action_pred_chunk) @torch.no_grad def select_action(self, batch: dict[str, Tensor]) -> Tensor: @@ -91,17 +90,11 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin): queue is empty. """ - # seungjae TODO: implement averaging action over horizons - self.eval() batch = self.normalize_inputs(batch) self._obs_queues = populate_queues(self._obs_queues, batch) - if len(self._action_history_queue) == 0: - while len(self._action_history_queue) != self._action_history_queue.maxlen: - self._action_history_queue.append(batch["observation.state"]) - if not self.check_discretized(): self.vqbet._action_head._vqvae_model.discretized = True # raise NotImplementedError( @@ -110,39 +103,16 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin): assert "observation.image" in batch assert "observation.state" in batch - # jay TODO - # took from act, need to separate single act pred, and act seq pred. - # for act seq pred, we should provide averaged act over horizon. - if len(self._action_queue) == 0: - # original one - batch = {key: torch.stack(list(self._obs_queues[key]), dim=1) for key in batch} - batch["action"] = torch.stack(list(self._action_history_queue), dim=1) actions = self.vqbet(batch, rollout=True)[:, : self.config.n_action_pred_chunk] - # TODO(rcadene): make _forward return output dictionary? actions = self.unnormalize_outputs({"action": actions})["action"] self._action_queue.extend(actions.transpose(0, 1)) - - # jay temp2 - # batch = {key: torch.stack(list(self._obs_queues[key]), dim=1) for key in batch} - # batch["action"] = torch.stack(list(self._action_history_queue), dim=1) - # action_predicted = [] - # for i in range(self.config.n_action_pred_token): - - # actions = self.vqbet(batch, rollout=True, input_predicted = action_predicted)[:, i] - - # action_predicted.append(actions) - # actions = self.unnormalize_outputs({"action": torch.stack(action_predicted)})["action"] - # self._action_queue.extend(actions) - ############# - action = self._action_queue.popleft() - self._action_history_queue.append(self.normalize_targets({"action": action})["action"]) return action def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: @@ -150,63 +120,14 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin): batch = self.normalize_inputs(batch) batch = self.normalize_targets(batch) if not self.check_discretized(): - loss = self.vqbet.discretize(self.config.discretize_step, batch['action']) - return {"loss": loss} + loss, n_different_codes, n_different_combinations = self.vqbet.discretize(self.config.discretize_step, batch['action']) + return {"loss": loss, "n_different_codes": n_different_codes, "n_different_combinations": n_different_combinations} - # original one _, loss = self.vqbet(batch, rollout=False) - # jay temp2 - # action_predicted = [] - # pred_values = [] - # for i in range(self.config.n_action_pred_token): - # pred = self.vqbet(batch, rollout=False, input_predicted = action_predicted, recurrent = True) - # actions = pred["predicted_action"][:, i] - # action_predicted.append(actions) - # pred_values.append(pred) - - # # Stack 'a', 'b', and 'c' from each dictionary in the list - # # jay TODO - # predicted_action = torch.stack([d["predicted_action"][:, i] for (i, d) in enumerate(pred_values)]).transpose(0, 1) - # decoded_action = torch.stack([einops.rearrange(d["decoded_action"], "(B T) 1 N->B T N", T=5)[:, i] for (i, d) in enumerate(pred_values)]).transpose(0, 1).reshape(-1, 1, 2) - # sampled_centers = torch.stack([einops.rearrange(d["sampled_centers"], "(B T) N->B T N", T=5)[:, i] for (i, d) in enumerate(pred_values)]).transpose(0, 1).reshape(-1, 2) - # cbet_logits = torch.stack([einops.rearrange(d["cbet_logits"], "(B T) G N->B T G N", T=5)[:, i] for (i, d) in enumerate(pred_values)]).transpose(0, 1).reshape(-1, 2, 16) - - # gpt_output, G, NT, N, T = pred_values[-1]["input"], pred_values[-1]["G"], pred_values[-1]["NT"], pred_values[-1]["N"], pred_values[-1]["T"] - - - # pred_action = { - # "input": gpt_output, - # "cbet_logits1": None, - # "cbet_logits2": None, - # "cbet_logits": cbet_logits if "cbet_logits" in locals() else None, - # "predicted_action": predicted_action, - # "decoded_action": decoded_action, - # "sampled_centers": sampled_centers, - # "G": G, - # "NT": NT, - # "N": N, - # "T": T, - # } - - # action = batch["action"] - # n, total_w, act_dim = action.shape - # act_w = self.config.n_action_pred_chunk - # num_token = total_w + 1 - act_w - # output_shape = (n, num_token, act_w, act_dim) - # output = torch.empty(output_shape).to(action.device) - # for i in range(num_token): - # output[:, i, :, :] = action[:, i : i + act_w, :] - # action = output # batch size, 7, 5, 2 - - # loss = self.vqbet._action_head.loss_fn( - # pred_action, - # action[:, self.config.n_obs_steps-1:], - # reduction="mean", - # )[0] - ############# - - return {"loss": loss['actor_loss']} + return {"loss": loss['actor_loss'], 'equal_single_code_rate': loss['equal_single_code_rate'], 'equal_single_code_rate2': loss['equal_single_code_rate2'], "offset_loss_weight": loss['offset_loss_weight'], \ + "action_diff": loss['action_diff'], "action_diff_tot": loss['action_diff_tot'], "action_diff_mean_res1": loss['action_diff_mean_res1'], "action_diff_mean_res2": loss['action_diff_mean_res2'], \ + "action_diff_max": loss['action_diff_max']} class VQBeTModel(nn.Module): @@ -214,24 +135,24 @@ class VQBeTModel(nn.Module): super().__init__() self.config = config - self.rgb_encoder = DiffusionRgbEncoder(config) + self.rgb_encoder = VQBeTRgbEncoder(config) self.global_cond_dim = self.rgb_encoder.feature_dim # action token and EOS token - self._action_token = nn.Parameter(torch.randn(1, 1, self.global_cond_dim)) # Batch, Timestep, Data type, GPT input dim - self._eos_token = nn.Parameter(torch.randn(1, 1, self.global_cond_dim)) + self._action_token = nn.Parameter(torch.randn(1, 1, self.config.n_embd)) # Batch, Timestep, Data type, GPT input dim + self._eos_token = nn.Parameter(torch.randn(1, 1, self.config.n_embd)) self.state_projector = MLP( - config.output_shapes["action"][0], hidden_channels=[self.global_cond_dim] + config.output_shapes["action"][0], hidden_channels=[self.config.n_embd] ) - self.action_projector = MLP( - config.output_shapes["action"][0], hidden_channels=[self.global_cond_dim] + self.obs_projector = MLP( + self.global_cond_dim, hidden_channels=[self.config.n_embd] ) self._policy = GPT( GPTConfig( block_size=self.config.block_size, - input_dim=self.global_cond_dim, + input_dim=self.config.n_embd, output_dim=self.config.output_dim, n_layer=self.config.n_layer, n_head=self.config.n_head, @@ -242,6 +163,7 @@ class VQBeTModel(nn.Module): self._action_head = VQBeTHead( config.output_dim, config.output_shapes["action"][0], + offset_loss_weight=config.offset_loss_weight, hidden_size=config.mlp_hidden_dim, vqvae_groups=config.vqvae_groups, vqvae_n_embed=config.vqvae_n_embed, @@ -253,7 +175,7 @@ class VQBeTModel(nn.Module): return self._action_head.discretize(discretize_step, actions) # ========= inference ============ - def forward(self, batch: dict[str, Tensor], rollout: bool, input_predicted:list = None, recurrent = False) -> Tensor: # jay temp2 + def forward(self, batch: dict[str, Tensor], rollout: bool) -> Tensor: # Input validation. assert set(batch).issuperset({"observation.state", "observation.image"}) batch_size, n_obs_steps = batch["observation.state"].shape[:2] @@ -265,65 +187,36 @@ class VQBeTModel(nn.Module): img_features = einops.rearrange(img_features, "(b n) ... -> b n ...", b=batch_size) # Concatenate state and image features then flatten to (B, global_cond_dim). - - if 'goal' in batch.keys(): - goal = batch["goal"] - num_goal_token = goal.shape[1] - else: - goal = None - num_goal_token = 0 - - - if rollout: - gpt_input_action = torch.cat((batch["action"], torch.zeros(batch_size, 1, batch["action"].shape[-1]).to(get_device_from_parameters(self))), dim=1) - else: - gpt_input_action = batch["action"][..., :n_obs_steps, :] - for i, val in enumerate(batch["frame_index"]): - if val < n_obs_steps - 1: - # Update original_data based on the rules - gpt_input_action[i, :-(val+1)] = batch["observation.state"][i, :-(val+1)] - - - gpt_input_action[:, -1] = 0 global_cond = torch.cat([ - torch.unsqueeze(img_features, dim=2), + torch.unsqueeze(self.obs_projector(img_features), dim=2), torch.unsqueeze(self.state_projector(batch["observation.state"]), dim=2), - # torch.unsqueeze(self.action_projector(gpt_input_action), dim=2) # jay temp - ], dim=-2).view(batch_size, -1, self.global_cond_dim) # batch size, num tokens, GPT input dim - # global_cond = global_cond[:, :-1] # get rid of final action # jay temp - - - # 앞선 action을 주고 뒤를 예측하게 해도 되고, 그냥 싹 다 예측하게 해도 될듯 -> 일단 action 앞에거까지 주는걸로 - # 토큰이 그냥 single action이어도 되고, multi step (chunk) 여도 될듯 -> 일단 chunk로 - eos_token = self._eos_token.repeat(batch_size, 1, 1) - action_token = self._action_token.repeat(batch_size, self.config.n_action_pred_token, 1) - - # jay temp2 - # if (input_predicted is not None) and (len(input_predicted) > 0): - # len_input = len(input_predicted) - # action_token[:, :len_input] = self.action_projector(torch.stack(input_predicted, dim=1).to(get_device_from_parameters(self))) + self._action_token.repeat(batch_size, n_obs_steps, 1, 1) + ], dim=-2).view(batch_size, -1, self.config.n_embd) + if img_features.shape[1] != n_obs_steps: + raise NotImplementedError + # eos_token = self._eos_token.repeat(batch_size, 1, 1) + len_additional_action_token = self.config.n_action_pred_token-1 + action_token = self._action_token.repeat(batch_size, len_additional_action_token, 1) - prompt_length = global_cond.shape[1]+1 - global_cond = torch.cat([global_cond, eos_token, action_token], dim=1) + # prompt_length = global_cond.shape[1]+1 + global_cond = torch.cat([global_cond, action_token], dim=1) # get action features features = self._policy(global_cond) - features = features[:, prompt_length:] + historical_act_pred_index = np.arange(0, n_obs_steps) * 3 + 2 + features = torch.cat([ + features[:, historical_act_pred_index], + features[:, -len_additional_action_token:] + ], dim=1) # action head pred_action = self._action_head( features, - # **{"action_seq": action}, ) if rollout: - return pred_action["predicted_action"][:, -1, :].reshape(batch_size, self.config.n_action_pred_chunk, -1) # original one - # return pred_action["predicted_action"] # jay temp2 - - # should put this part between "if" and "else" - # elif recurrent: # jay temp2 => training - # return pred_action + return pred_action["predicted_action"][:, n_obs_steps-1, :].reshape(batch_size, self.config.n_action_pred_chunk, -1) else: action = batch["action"] n, total_w, act_dim = action.shape @@ -333,11 +226,11 @@ class VQBeTModel(nn.Module): output = torch.empty(output_shape).to(action.device) for i in range(num_token): output[:, i, :, :] = action[:, i : i + act_w, :] - action = output # batch size, 7, 5, 2 + action = output loss = self._action_head.loss_fn( pred_action, - action[:, n_obs_steps-1:], + action, reduction="mean", ) return pred_action, loss[0] if isinstance(loss, tuple) else loss @@ -364,7 +257,6 @@ class VQBeTHead(nn.Module): self.hidden_size = hidden_size self.offset_loss_weight = offset_loss_weight self.secondary_code_multiplier = secondary_code_multiplier - self.sequentially_select = False self._G = vqvae_groups # G(number of groups) self._C = vqvae_n_embed # C(number of code integers) @@ -372,26 +264,13 @@ class VQBeTHead(nn.Module): self.n_action_pred_chunk = n_action_pred_chunk # action chunk size - if self.sequentially_select: - print("use sequantial prediction for vq dictionary!") - self._map_to_cbet_preds_bin1 = MLP( - in_channels=self.input_size, - hidden_channels=[self.hidden_size, self.hidden_size, self._C], - ) - self._map_to_cbet_preds_bin2 = MLP( - in_channels=self.input_size + self._C, - hidden_channels=[self.hidden_size, self._C], - ) - else: - self._map_to_cbet_preds_bin = MLP( - in_channels=self.input_size, - hidden_channels=[self.hidden_size, self.hidden_size, self._G * self._C], - ) + self._map_to_cbet_preds_bin = MLP( + in_channels=self.input_size, + hidden_channels=[self._G * self._C], + ) self._map_to_cbet_preds_offset = MLP( in_channels=self.input_size, hidden_channels=[ - self.hidden_size, - self.hidden_size, self._G * self._C * n_action_pred_chunk * self.output_size, ], ) @@ -415,63 +294,33 @@ class VQBeTHead(nn.Module): self._vqvae_model.decoder.to(get_device_from_parameters(self)) self._vqvae_model.device = get_device_from_parameters(self) - loss = pretrain_vqvae(self._vqvae_model, discretize_step, actions) + loss, n_different_codes, n_different_combinations = pretrain_vqvae(self._vqvae_model, discretize_step, actions) if self._vqvae_model.discretized: print("Finished discretizing action data!") self._vqvae_model.eval() for param in self._vqvae_model.vq_layer.parameters(): param.requires_grad = False - return loss + return loss, n_different_codes, n_different_combinations def forward(self, x, **kwargs): N, T, _ = x.shape x = einops.rearrange(x, "N T WA -> (N T) WA") - if self.sequentially_select: - cbet_logits1 = self._map_to_cbet_preds_bin1(x) - cbet_offsets = self._map_to_cbet_preds_offset(x) - cbet_offsets = einops.rearrange( - cbet_offsets, "(NT) (G C WA) -> (NT) G C WA", G=self._G, C=self._C - ) - cbet_probs1 = torch.softmax(cbet_logits1, dim=-1) - NT, choices = cbet_probs1.shape - G = self._G - sampled_centers1 = einops.rearrange( - torch.multinomial(cbet_probs1.view(-1, choices), num_samples=1), - "(NT) 1 -> NT", - NT=NT, - ) - cbet_logits2 = self._map_to_cbet_preds_bin2( - torch.cat( - (x, F.one_hot(sampled_centers1, num_classes=self._C)), - axis=1, - ) - ) - cbet_probs2 = torch.softmax(cbet_logits2, dim=-1) - sampled_centers2 = einops.rearrange( - torch.multinomial(cbet_probs2.view(-1, choices), num_samples=1), - "(NT) 1 -> NT", - NT=NT, - ) - sampled_centers = torch.stack( - (sampled_centers1, sampled_centers2), axis=1 - ) # NT, G - else: - cbet_logits = self._map_to_cbet_preds_bin(x) - cbet_offsets = self._map_to_cbet_preds_offset(x) - cbet_logits = einops.rearrange( - cbet_logits, "(NT) (G C) -> (NT) G C", G=self._G - ) - cbet_offsets = einops.rearrange( - cbet_offsets, "(NT) (G C WA) -> (NT) G C WA", G=self._G, C=self._C - ) - cbet_probs = torch.softmax(cbet_logits, dim=-1) - NT, G, choices = cbet_probs.shape - sampled_centers = einops.rearrange( - torch.multinomial(cbet_probs.view(-1, choices), num_samples=1), - "(NT G) 1 -> NT G", - NT=NT, - ) + cbet_logits = self._map_to_cbet_preds_bin(x) + cbet_offsets = self._map_to_cbet_preds_offset(x) + cbet_logits = einops.rearrange( + cbet_logits, "(NT) (G C) -> (NT) G C", G=self._G + ) + cbet_offsets = einops.rearrange( + cbet_offsets, "(NT) (G C WA) -> (NT) G C WA", G=self._G, C=self._C + ) + cbet_probs = torch.softmax(cbet_logits, dim=-1) + NT, G, choices = cbet_probs.shape + sampled_centers = einops.rearrange( + torch.multinomial(cbet_probs.view(-1, choices), num_samples=1), + "(NT G) 1 -> NT G", + NT=NT, + ) indices = ( torch.arange(NT).unsqueeze(1).cuda(), @@ -506,13 +355,10 @@ class VQBeTHead(nn.Module): ) return { - "input": x, - "cbet_logits1": cbet_logits1 if "cbet_logits1" in locals() else None, - "cbet_logits2": cbet_logits2 if "cbet_logits2" in locals() else None, "cbet_logits": cbet_logits if "cbet_logits" in locals() else None, "predicted_action": predicted_action, - "decoded_action": decoded_action, "sampled_centers": sampled_centers, + "decoded_action": decoded_action, "G": G, "NT": NT, "N": N, @@ -522,28 +368,16 @@ class VQBeTHead(nn.Module): def loss_fn(self, pred, target, **kwargs): # Rename the inputs for clarity. action_seq = target - gpt_output = pred["input"] predicted_action = pred["predicted_action"] - decoded_action = pred["decoded_action"] sampled_centers = pred["sampled_centers"] + decoded_action = pred["decoded_action"] G, NT, N, T = pred["G"], pred["NT"], pred["N"], pred["T"] - if self.sequentially_select: - cbet_logits1 = pred["cbet_logits1"] - cbet_logits2 = pred["cbet_logits2"] - else: - cbet_logits = pred["cbet_logits"] + cbet_logits = pred["cbet_logits"] predicted_action = einops.rearrange( predicted_action, "N T (W A) -> (N T) W A", W=self._vqvae_model.input_dim_h ) - # n, total_w, act_dim = action_seq.shape - # act_w = self._vqvae_model.input_dim_h - # obs_w = total_w + 1 - act_w - # output_shape = (n, obs_w, act_w, act_dim) - # output = torch.empty(output_shape).to(action_seq.device) - # for i in range(obs_w): - # output[:, i, :, :] = action_seq[:, i : i + act_w, :] action_seq = einops.rearrange(action_seq, "N T W A -> (N T) W A") # Figure out the loss for the actions. # First, we need to find the closest cluster center for each action. @@ -557,81 +391,15 @@ class VQBeTHead(nn.Module): offset_loss = torch.nn.L1Loss()(action_seq, predicted_action) - action_diff = F.mse_loss( - einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T)[:, -1, 0, :], - einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=T)[ - :, -1, 0, : - ], - ) # batch, time, windowsize (t ... t+N), action dim -> [:, -1, 0, :] is for rollout - action_diff_tot = F.mse_loss( - einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T)[:, -1, :, :], - einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=T)[ - :, -1, :, : - ], - ) # batch, time, windowsize (t ... t+N), action dim -> [:, -1, 0, :] is for rollout - action_diff_mean_res1 = ( - abs( - einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T)[ - :, -1, 0, : - ] - - einops.rearrange(decoded_action, "(N T) W A -> N T W A", T=T)[ - :, -1, 0, : - ] - ) - ).mean() - action_diff_mean_res2 = ( - abs( - einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T)[ - :, -1, 0, : - ] - - einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=T)[ - :, -1, 0, : - ] - ) - ).mean() - action_diff_max = ( - abs( - einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T)[ - :, -1, 0, : - ] - - einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=T)[ - :, -1, 0, : - ] - ) - ).max() - if self.sequentially_select: - cbet_loss1 = self._criterion( # F.cross_entropy - cbet_logits1[:, :], - action_bins[:, 0], - ) - cbet_logits2 = self._map_to_cbet_preds_bin2( - torch.cat( - (gpt_output, F.one_hot(action_bins[:, 0], num_classes=self._C)), - axis=1, - ) - ) - cbet_loss2 = self._criterion( # F.cross_entropy - cbet_logits2[:, :], - action_bins[:, 1], - ) - else: - cbet_loss1 = self._criterion( # F.cross_entropy - cbet_logits[:, 0, :], - action_bins[:, 0], - ) - cbet_loss2 = self._criterion( # F.cross_entropy - cbet_logits[:, 1, :], - action_bins[:, 1], - ) - # cbet_loss3 = self._criterion( # F.cross_entropy - # cbet_logits[:, 2, :], - # action_bins[:, 2], - # ) - # cbet_loss4 = self._criterion( # F.cross_entropy - # cbet_logits[:, 3, :], - # action_bins[:, 3], - # ) + cbet_loss1 = self._criterion( # F.cross_entropy + cbet_logits[:, 0, :], + action_bins[:, 0], + ) + cbet_loss2 = self._criterion( # F.cross_entropy + cbet_logits[:, 1, :], + action_bins[:, 1], + ) cbet_loss = cbet_loss1 * 5 + cbet_loss2 * self.secondary_code_multiplier equal_total_code_rate = ( @@ -647,6 +415,49 @@ class VQBeTHead(nn.Module): (action_bins[:, 1] == sampled_centers[:, 1]).int() ) / (NT) + action_diff = F.mse_loss( + einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T)[:, 4, :, :], + einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=T)[ + :, 4, :, : + ], + ) # batch, time, windowsize (t ... t+N), action dim -> [:, -1, 0, :] is for rollout + action_diff_tot = F.mse_loss( + einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T)[:, :, :, :], + einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=T)[ + :, :, :, : + ], + ) # batch, time, windowsize (t ... t+N), action dim -> [:, -1, 0, :] is for rollout + action_diff_mean_res1 = ( + abs( + einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T)[ + :, 4, :, : + ] + - einops.rearrange(decoded_action, "(N T) W A -> N T W A", T=T)[ + :, 4, :, : + ] + ) + ).mean() + action_diff_mean_res2 = ( + abs( + einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T)[ + :, 4, :, : + ] + - einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=T)[ + :, 4, :, : + ] + ) + ).mean() + action_diff_max = ( + abs( + einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T)[ + :, 4, :, : + ] + - einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=T)[ + :, 4, :, : + ] + ) + ).max() + loss = cbet_loss + self.offset_loss_weight * offset_loss loss_dict = { "classification_loss": cbet_loss.detach().cpu().item(), @@ -655,13 +466,10 @@ class VQBeTHead(nn.Module): "equal_total_code_rate": equal_total_code_rate, "equal_single_code_rate": equal_single_code_rate, "equal_single_code_rate2": equal_single_code_rate2, - "action_diff": action_diff.detach().cpu().item(), - "action_diff_tot": action_diff_tot.detach().cpu().item(), - "action_diff_mean_res1": action_diff_mean_res1.detach().cpu().item(), - "action_diff_mean_res2": action_diff_mean_res2.detach().cpu().item(), - "action_diff_max": action_diff_max.detach().cpu().item(), } - return {"actor_loss": loss}, loss_dict + return {"actor_loss": loss, "equal_single_code_rate": equal_single_code_rate, "equal_single_code_rate2": equal_single_code_rate2, "offset_loss_weight": self.offset_loss_weight, \ + "action_diff": action_diff, "action_diff_tot": action_diff_tot, "action_diff_mean_res1": action_diff_mean_res1, "action_diff_mean_res2": action_diff_mean_res2, \ + "action_diff_max": action_diff_max}, loss_dict class VQBeTOptimizer: def __init__(self, policy, cfg): @@ -669,139 +477,6 @@ class VQBeTOptimizer: self.offline_steps = cfg.training.offline_steps self.optimizing_step = 0 - # Option 1 - - # vqvae_params = ( - # list(policy.vqbet._action_head._vqvae_model.encoder.parameters()) - # + list(policy.vqbet._action_head._vqvae_model.decoder.parameters()) - # + list(policy.vqbet._action_head._vqvae_model.vq_layer.parameters()) - # ) - # self.vqvae_optimizer = torch.optim.Adam( - # vqvae_params, lr=cfg.training.vqvae_lr, weight_decay=0.0001 - # ) - - # self.encoder_optimizer = torch.optim.Adam( - # policy.vqbet.parameters(), - # cfg.training.lr, - # cfg.training.adam_betas, - # cfg.training.adam_eps, - # cfg.training.adam_weight_decay, - # ) - - # self.bet_optimizer1 = policy.vqbet._policy.configure_optimizers( - # weight_decay=cfg.training.bet_weight_decay, - # learning_rate=cfg.training.bet_learning_rate, - # betas=cfg.training.bet_betas, - # ) - # if policy.vqbet._action_head.sequentially_select: - # self.bet_optimizer1.add_param_group( - # {"params": policy.vqbet._action_head._map_to_cbet_preds_bin1.parameters()} - # ) - # self.bet_optimizer1.add_param_group( - # {"params": policy.vqbet._action_head._map_to_cbet_preds_bin2.parameters()} - # ) - # else: - # self.bet_optimizer1.add_param_group( - # {"params": policy.vqbet._action_head._map_to_cbet_preds_bin.parameters()} - # ) - - # self.bet_optimizer2 = torch.optim.AdamW( - # policy.vqbet._action_head._map_to_cbet_preds_offset.parameters(), - # lr=cfg.training.bet_learning_rate, - # weight_decay=cfg.training.bet_weight_decay, - # betas=cfg.training.bet_betas, - # ) - - # Option 2 - - # vqvae_params = ( - # list(policy.vqbet._action_head._vqvae_model.encoder.parameters()) - # + list(policy.vqbet._action_head._vqvae_model.decoder.parameters()) - # + list(policy.vqbet._action_head._vqvae_model.vq_layer.parameters()) - # ) - # self.vqvae_optimizer = torch.optim.Adam( - # vqvae_params, lr=cfg.training.vqvae_lr, weight_decay=0.0001 - # ) - - # self.encoder_optimizer = torch.optim.Adam( - # policy.vqbet.parameters(), - # cfg.training.lr, - # cfg.training.adam_betas, - # cfg.training.adam_eps, - # cfg.training.adam_weight_decay, - # ) - - # self.bet_optimizer1 = policy.vqbet._policy.configure_optimizers( - # weight_decay=cfg.training.adam_weight_decay, - # learning_rate=cfg.training.lr, - # betas=cfg.training.adam_betas, - # optimizer = "Adam", - # eps=cfg.training.adam_eps, - # ) - # if policy.vqbet._action_head.sequentially_select: - # self.bet_optimizer1.add_param_group( - # {"params": policy.vqbet._action_head._map_to_cbet_preds_bin1.parameters()} - # ) - # self.bet_optimizer1.add_param_group( - # {"params": policy.vqbet._action_head._map_to_cbet_preds_bin2.parameters()} - # ) - # else: - # self.bet_optimizer1.add_param_group( - # {"params": policy.vqbet._action_head._map_to_cbet_preds_bin.parameters()} - # ) - - # self.bet_optimizer2 = torch.optim.Adam( - # policy.vqbet._action_head._map_to_cbet_preds_offset.parameters(), - # cfg.training.lr, - # cfg.training.adam_betas, - # cfg.training.adam_eps, - # cfg.training.adam_weight_decay, - # ) - - # Option 3 - - # vqvae_params = ( - # list(policy.vqbet._action_head._vqvae_model.encoder.parameters()) - # + list(policy.vqbet._action_head._vqvae_model.decoder.parameters()) - # + list(policy.vqbet._action_head._vqvae_model.vq_layer.parameters()) - # ) - # self.vqvae_optimizer = torch.optim.Adam( - # vqvae_params, lr=cfg.training.vqvae_lr, weight_decay=0.0001 - # ) - - # self.encoder_optimizer = torch.optim.AdamW( - # policy.vqbet.parameters(), - # lr=cfg.training.bet_learning_rate, - # weight_decay=cfg.training.bet_weight_decay, - # betas=cfg.training.bet_betas, - # ) - - # self.bet_optimizer1 = policy.vqbet._policy.configure_optimizers( - # weight_decay=cfg.training.bet_weight_decay, - # learning_rate=cfg.training.bet_learning_rate, - # betas=cfg.training.bet_betas, - # ) - # if policy.vqbet._action_head.sequentially_select: - # self.bet_optimizer1.add_param_group( - # {"params": policy.vqbet._action_head._map_to_cbet_preds_bin1.parameters()} - # ) - # self.bet_optimizer1.add_param_group( - # {"params": policy.vqbet._action_head._map_to_cbet_preds_bin2.parameters()} - # ) - # else: - # self.bet_optimizer1.add_param_group( - # {"params": policy.vqbet._action_head._map_to_cbet_preds_bin.parameters()} - # ) - - # self.bet_optimizer2 = torch.optim.AdamW( - # policy.vqbet._action_head._map_to_cbet_preds_offset.parameters(), - # lr=cfg.training.bet_learning_rate, - # weight_decay=cfg.training.bet_weight_decay, - # betas=cfg.training.bet_betas, - # ) - - - # Option 4 vqvae_params = ( list(policy.vqbet._action_head._vqvae_model.encoder.parameters()) @@ -813,7 +488,7 @@ class VQBeTOptimizer: ) self.encoder_optimizer = torch.optim.Adam( - policy.vqbet.parameters(), + policy.vqbet.rgb_encoder.parameters(), cfg.training.lr, cfg.training.adam_betas, cfg.training.adam_eps, @@ -836,28 +511,17 @@ class VQBeTOptimizer: {"params": policy.vqbet.state_projector.parameters()} ) self.bet_optimizer1.add_param_group( - {"params": policy.vqbet.action_projector.parameters()} + {"params": policy.vqbet.obs_projector.parameters()} ) - # if policy.vqbet._action_head.sequentially_select: - # self.bet_optimizer1.add_param_group( - # {"params": policy.vqbet._action_head._map_to_cbet_preds_bin1.parameters()} - # ) - # self.bet_optimizer1.add_param_group( - # {"params": policy.vqbet._action_head._map_to_cbet_preds_bin2.parameters()} - # ) - # else: - # self.bet_optimizer1.add_param_group( - # {"params": policy.vqbet._action_head._map_to_cbet_preds_bin.parameters()} - # ) - self.bet_optimizer0 = torch.optim.AdamW( + self.bet_optimizer2 = torch.optim.AdamW( policy.vqbet._action_head._map_to_cbet_preds_bin.parameters(), lr=cfg.training.bet_learning_rate, weight_decay=cfg.training.bet_weight_decay, betas=cfg.training.bet_betas, ) - self.bet_optimizer2 = torch.optim.AdamW( + self.bet_optimizer3 = torch.optim.AdamW( policy.vqbet._action_head._map_to_cbet_preds_offset.parameters(), lr=cfg.training.bet_learning_rate, weight_decay=cfg.training.bet_weight_decay, @@ -876,10 +540,10 @@ class VQBeTOptimizer: if self.optimizing_step < 0.6 * self.offline_steps: self.encoder_optimizer.step() self.bet_optimizer1.step() - self.bet_optimizer0.step() self.bet_optimizer2.step() + self.bet_optimizer3.step() else: - self.bet_optimizer2.step() + self.bet_optimizer3.step() def zero_grad(self): if self.optimizing_step < self.discretize_step: @@ -890,10 +554,10 @@ class VQBeTOptimizer: if self.optimizing_step < 0.6 * self.offline_steps: self.encoder_optimizer.zero_grad() self.bet_optimizer1.zero_grad() - self.bet_optimizer0.zero_grad() self.bet_optimizer2.zero_grad() + self.bet_optimizer3.zero_grad() else: - self.bet_optimizer2.zero_grad() + self.bet_optimizer3.zero_grad() class VQBeTScheduler: def __init__(self, optimizer, cfg): @@ -909,19 +573,6 @@ class VQBeTScheduler: num_training_steps=cfg.training.offline_steps, ) - # self.lr_scheduler2 = get_scheduler( - # cfg.training.lr_scheduler, - # optimizer=optimizer.bet_optimizer1, - # num_warmup_steps=cfg.training.lr_warmup_steps, - # num_training_steps=cfg.training.offline_steps, - # ) - - # self.lr_scheduler3 = get_scheduler( - # cfg.training.lr_scheduler, - # optimizer=optimizer.bet_optimizer2, - # num_warmup_steps=cfg.training.lr_warmup_steps, - # num_training_steps=cfg.training.offline_steps, - # ) def step(self): self.optimizing_step +=1 @@ -930,10 +581,12 @@ class VQBeTScheduler: # self.lr_scheduler2.step() # self.lr_scheduler3.step() -class DiffusionRgbEncoder(nn.Module): +class VQBeTRgbEncoder(nn.Module): """Encoder an RGB image into a 1D feature vector. Includes the ability to normalize and crop the image first. + + Same with DiffusionRgbEncoder from modeling_diffusion.py """ def __init__(self, config: VQBeTConfig): @@ -1034,50 +687,6 @@ def _replace_submodules( -# PyTorch dataset class for loading actions -# class ActionDataset(torch.utils.data.Dataset): -# def __init__(self, actions): -# self.actions = actions - -# def __len__(self): -# return len(self.actions) - -# def __getitem__(self, idx): -# return self.actions[idx] -class EncoderMLP(nn.Module): - def __init__( - self, - input_dim, - output_dim=16, - hidden_dim=128, - layer_num=1, - last_activation=None, - ): - super(EncoderMLP, self).__init__() - layers = [] - - layers.append(nn.Linear(input_dim, hidden_dim)) - layers.append(nn.ReLU()) - for _ in range(layer_num): - layers.append(nn.Linear(hidden_dim, hidden_dim)) - layers.append(nn.ReLU()) - - self.encoder = nn.Sequential(*layers) - self.fc = nn.Linear(hidden_dim, output_dim) - - if last_activation is not None: - self.last_layer = last_activation - else: - self.last_layer = None - self.apply(weights_init_encoder) - - def forward(self, x): - h = self.encoder(x) - state = self.fc(h) - if self.last_layer: - state = self.last_layer(state) - return state - class VqVae(nn.Module): def __init__( @@ -1116,18 +725,22 @@ class VqVae(nn.Module): self.embedding_dim = self.n_latent_dims if self.input_dim_h == 1: - self.encoder = EncoderMLP( - input_dim=input_dim_w, output_dim=n_latent_dims + self.encoder = MLP( + in_channels=input_dim_w, + hidden_channels=[128, 128, n_latent_dims], ) - self.decoder = EncoderMLP( - input_dim=n_latent_dims, output_dim=input_dim_w + self.decoder = MLP( + in_channels=n_latent_dims, + hidden_channels=[128, 128, input_dim_w], ) else: - self.encoder = EncoderMLP( - input_dim=input_dim_w * self.input_dim_h, output_dim=n_latent_dims + self.encoder = MLP( + in_channels=input_dim_w * self.input_dim_h, + hidden_channels=[128, 128, n_latent_dims], ) - self.decoder = EncoderMLP( - input_dim=n_latent_dims, output_dim=input_dim_w * self.input_dim_h + self.decoder = MLP( + in_channels=n_latent_dims, + hidden_channels=[128, 128, input_dim_w * self.input_dim_h], ) @@ -1276,10 +889,12 @@ def pretrain_vqvae(vqvae_model, discretize_step, actions): loss, metric = vqvae_model.vqvae_forward( actions ) # N T D + n_different_codes = len(torch.unique(metric[2])) + n_different_combinations = len(torch.unique(metric[2], dim=0)) vqvae_model.optimized_steps += 1 if vqvae_model.optimized_steps >= discretize_step: vqvae_model.discretized = True - return loss + return loss, n_different_codes, n_different_combinations def exists(val): @@ -1528,83 +1143,6 @@ class ResidualVQ(nn.Module): return ret -# grouped residual vq - - -class GroupedResidualVQ(nn.Module): - def __init__(self, *, dim, groups=1, accept_image_fmap=False, **kwargs): - super().__init__() - self.dim = dim - self.groups = groups - assert (dim % groups) == 0 - dim_per_group = dim // groups - - self.accept_image_fmap = accept_image_fmap - - self.rvqs = nn.ModuleList([]) - - for _ in range(groups): - self.rvqs.append( - ResidualVQ( - dim=dim_per_group, accept_image_fmap=accept_image_fmap, **kwargs - ) - ) - - @property - def codebooks(self): - return torch.stack(tuple(rvq.codebooks for rvq in self.rvqs)) - - def get_codes_from_indices(self, indices): - codes = tuple( - rvq.get_codes_from_indices(chunk_indices) - for rvq, chunk_indices in zip(self.rvqs, indices) - ) - return torch.stack(codes) - - def forward( - self, x, indices=None, return_all_codes=False, sample_codebook_temp=None - ): - shape = x.shape - split_dim = 1 if self.accept_image_fmap else -1 - assert shape[split_dim] == self.dim - - # split the feature dimension into groups - - x = x.chunk(self.groups, dim=split_dim) - - indices = default(indices, tuple()) - return_ce_loss = len(indices) > 0 - assert len(indices) == 0 or len(indices) == self.groups - - forward_kwargs = dict( - return_all_codes=return_all_codes, sample_codebook_temp=sample_codebook_temp - ) - - # invoke residual vq on each group - - out = tuple( - rvq(chunk, indices=chunk_indices, **forward_kwargs) - for rvq, chunk, chunk_indices in zip_longest(self.rvqs, x, indices) - ) - out = tuple(zip(*out)) - - # if returning cross entropy loss to rvq codebooks - - if return_ce_loss: - quantized, ce_losses = out - return torch.cat(quantized, dim=split_dim), sum(ce_losses) - - # otherwise, get all the zipped outputs and combine them - - quantized, all_indices, commit_losses, *maybe_all_codes = out - - quantized = torch.cat(quantized, dim=split_dim) - all_indices = torch.stack(all_indices) - commit_losses = torch.stack(commit_losses) - - ret = (quantized, all_indices, commit_losses, *maybe_all_codes) - return ret - class VectorQuantize(nn.Module): @@ -1621,7 +1159,6 @@ class VectorQuantize(nn.Module): kmeans_init=False, kmeans_iters=10, sync_kmeans=True, - use_cosine_sim=False, threshold_ema_dead_code=0, channel_last=True, accept_image_fmap=False, @@ -1685,7 +1222,6 @@ class VectorQuantize(nn.Module): self.sync_update_v = sync_update_v - codebook_class = EuclideanCodebook if not use_cosine_sim else CosineSimCodebook gumbel_sample_fn = partial( gumbel_sample, @@ -1717,9 +1253,6 @@ class VectorQuantize(nn.Module): ) if affine_param: - assert ( - not use_cosine_sim - ), "affine param is only compatible with euclidean codebook" codebook_kwargs = dict( **codebook_kwargs, affine_param=True, @@ -1728,7 +1261,7 @@ class VectorQuantize(nn.Module): affine_param_codebook_decay=affine_param_codebook_decay, ) - self._codebook = codebook_class(**codebook_kwargs) + self._codebook = EuclideanCodebook(**codebook_kwargs) self.in_place_codebook_optimizer = ( in_place_codebook_optimizer(self._codebook.parameters()) @@ -2218,7 +1751,6 @@ def kmeans( samples, num_clusters, num_iters=10, - use_cosine_sim=False, sample_fn=batched_sample_vectors, all_reduce_fn=noop, ): @@ -2232,10 +1764,7 @@ def kmeans( means = sample_fn(samples, num_clusters) for _ in range(num_iters): - if use_cosine_sim: - dists = samples @ rearrange(means, "h n d -> h d n") - else: - dists = -torch.cdist(samples, means, p=2) + dists = -torch.cdist(samples, means, p=2) buckets = torch.argmax(dists, dim=-1) bins = batched_bincount(buckets, minlength=num_clusters) @@ -2250,9 +1779,6 @@ def kmeans( new_means = new_means / rearrange(bins_min_clamped, "... -> ... 1") all_reduce_fn(new_means) - if use_cosine_sim: - new_means = l2norm(new_means) - means = torch.where(rearrange(zero_mask, "... -> ... 1"), means, new_means) return means, bins @@ -2591,193 +2117,6 @@ class EuclideanCodebook(nn.Module): return quantize, embed_ind, dist -class CosineSimCodebook(nn.Module): - def __init__( - self, - dim, - codebook_size, - num_codebooks=1, - kmeans_init=False, - kmeans_iters=10, - sync_kmeans=True, - decay=0.8, - eps=1e-5, - threshold_ema_dead_code=2, - reset_cluster_size=None, - use_ddp=False, - learnable_codebook=False, - gumbel_sample=gumbel_sample, - sample_codebook_temp=1.0, - ema_update=True, - ): - super().__init__() - self.transform_input = l2norm - - self.ema_update = ema_update - self.decay = decay - - if not kmeans_init: - embed = l2norm(uniform_init(num_codebooks, codebook_size, dim)) - else: - embed = torch.zeros(num_codebooks, codebook_size, dim) - - self.codebook_size = codebook_size - self.num_codebooks = num_codebooks - - self.kmeans_iters = kmeans_iters - self.eps = eps - self.threshold_ema_dead_code = threshold_ema_dead_code - self.reset_cluster_size = default(reset_cluster_size, threshold_ema_dead_code) - - assert callable(gumbel_sample) - self.gumbel_sample = gumbel_sample - self.sample_codebook_temp = sample_codebook_temp - - self.sample_fn = ( - sample_vectors_distributed - if use_ddp and sync_kmeans - else batched_sample_vectors - ) - self.kmeans_all_reduce_fn = ( - distributed.all_reduce if use_ddp and sync_kmeans else noop - ) - self.all_reduce_fn = distributed.all_reduce if use_ddp else noop - - self.register_buffer("initted", torch.Tensor([not kmeans_init])) - self.register_buffer("cluster_size", torch.zeros(num_codebooks, codebook_size)) - self.register_buffer("embed_avg", embed.clone()) - - self.learnable_codebook = learnable_codebook - if learnable_codebook: - self.embed = nn.Parameter(embed) - else: - self.register_buffer("embed", embed) - - @torch.jit.ignore - def init_embed_(self, data, mask=None): - if self.initted: - return - - if exists(mask): - c = data.shape[0] - data = rearrange(data[mask], "(c n) d -> c n d", c=c) - - embed, cluster_size = kmeans( - data, - self.codebook_size, - self.kmeans_iters, - use_cosine_sim=True, - sample_fn=self.sample_fn, - all_reduce_fn=self.kmeans_all_reduce_fn, - ) - - embed_sum = embed * rearrange(cluster_size, "... -> ... 1") - - self.embed.data.copy_(embed) - self.embed_avg.data.copy_(embed_sum) - self.cluster_size.data.copy_(cluster_size) - self.initted.data.copy_(torch.Tensor([True])) - - def replace(self, batch_samples, batch_mask): - batch_samples = l2norm(batch_samples) - - for ind, (samples, mask) in enumerate( - zip(batch_samples.unbind(dim=0), batch_mask.unbind(dim=0)) - ): - if not torch.any(mask): - continue - - sampled = self.sample_fn( - rearrange(samples, "... -> 1 ..."), mask.sum().item() - ) - sampled = rearrange(sampled, "1 ... -> ...") - - self.embed.data[ind][mask] = sampled - self.embed_avg.data[ind][mask] = sampled * self.reset_cluster_size - self.cluster_size.data[ind][mask] = self.reset_cluster_size - - def expire_codes_(self, batch_samples): - if self.threshold_ema_dead_code == 0: - return - - expired_codes = self.cluster_size < self.threshold_ema_dead_code - - if not torch.any(expired_codes): - return - - batch_samples = rearrange(batch_samples, "h ... d -> h (...) d") - self.replace(batch_samples, batch_mask=expired_codes) - - @autocast(enabled=False) - def forward(self, x, sample_codebook_temp=None, mask=None, freeze_codebook=False): - needs_codebook_dim = x.ndim < 4 - sample_codebook_temp = default(sample_codebook_temp, self.sample_codebook_temp) - - x = x.float() - - if needs_codebook_dim: - x = rearrange(x, "... -> 1 ...") - - dtype = x.dtype - - flatten, ps = pack_one(x, "h * d") - - if exists(mask): - mask = repeat( - mask, - "b n -> c (b h n)", - c=flatten.shape[0], - h=flatten.shape[-2] // (mask.shape[0] * mask.shape[1]), - ) - - self.init_embed_(flatten, mask=mask) - - embed = self.embed if self.learnable_codebook else self.embed.detach() - - dist = einsum("h n d, h c d -> h n c", flatten, embed) - - embed_ind, embed_onehot = self.gumbel_sample( - dist, dim=-1, temperature=sample_codebook_temp, training=self.training - ) - embed_ind = unpack_one(embed_ind, ps, "h *") - - if self.training: - unpacked_onehot = unpack_one(embed_onehot, ps, "h * c") - quantize = einsum("h b n c, h c d -> h b n d", unpacked_onehot, embed) - else: - quantize = batched_embedding(embed_ind, embed) - - if self.training and self.ema_update and not freeze_codebook: - if exists(mask): - embed_onehot[~mask] = 0.0 - - bins = embed_onehot.sum(dim=1) - self.all_reduce_fn(bins) - - ema_inplace(self.cluster_size.data, bins, self.decay) - - embed_sum = einsum("h n d, h n c -> h c d", flatten, embed_onehot) - self.all_reduce_fn(embed_sum.contiguous()) - ema_inplace(self.embed_avg.data, embed_sum, self.decay) - - cluster_size = laplace_smoothing( - self.cluster_size, self.codebook_size, self.eps - ) * self.cluster_size.sum(dim=-1, keepdim=True) - - embed_normalized = self.embed_avg / rearrange(cluster_size, "... -> ... 1") - embed_normalized = l2norm(embed_normalized) - - self.embed.data.copy_(l2norm(embed_normalized)) - self.expire_codes_(x) - - if needs_codebook_dim: - quantize, embed_ind = map( - lambda t: rearrange(t, "1 ... -> ..."), (quantize, embed_ind) - ) - - dist = unpack_one(dist, ps, "h * d") - return quantize, embed_ind, dist - class FocalLoss(nn.Module): """ @@ -2806,154 +2145,24 @@ class FocalLoss(nn.Module): return loss.sum() class MLP(torch.nn.Sequential): - """This block implements the multi-layer perceptron (MLP) module. - Adapted for backward compatibility from the torchvision library: - https://pytorch.org/vision/0.14/generated/torchvision.ops.MLP.html - - LICENSE: - - From PyTorch: - - Copyright (c) 2016- Facebook, Inc (Adam Paszke) - Copyright (c) 2014- Facebook, Inc (Soumith Chintala) - Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) - Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) - Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) - Copyright (c) 2011-2013 NYU (Clement Farabet) - Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) - Copyright (c) 2006 Idiap Research Institute (Samy Bengio) - Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) - - From Caffe2: - - Copyright (c) 2016-present, Facebook Inc. All rights reserved. - - All contributions by Facebook: - Copyright (c) 2016 Facebook Inc. - - All contributions by Google: - Copyright (c) 2015 Google Inc. - All rights reserved. - - All contributions by Yangqing Jia: - Copyright (c) 2015 Yangqing Jia - All rights reserved. - - All contributions by Kakao Brain: - Copyright 2019-2020 Kakao Brain - - All contributions by Cruise LLC: - Copyright (c) 2022 Cruise LLC. - All rights reserved. - - All contributions from Caffe: - Copyright(c) 2013, 2014, 2015, the respective contributors - All rights reserved. - - All other contributions: - Copyright(c) 2015, 2016 the respective contributors - All rights reserved. - - Caffe2 uses a copyright model similar to Caffe: each contributor holds - copyright over their contributions to Caffe2. The project versioning records - all such contribution and copyright details. If a contributor wants to further - mark their specific copyright on a particular contribution, they should - indicate their copyright solely in the commit message of the change when it is - committed. - - All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are met: - - 1. Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - 2. Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America - and IDIAP Research Institute nor the names of its contributors may be - used to endorse or promote products derived from this software without - specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE - ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE - LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF - SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS - INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN - CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) - ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE - POSSIBILITY OF SUCH DAMAGE. - - - Args: - in_channels (int): Number of channels of the input - hidden_channels (List[int]): List of the hidden channel dimensions - norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the linear layer. If ``None`` this layer won't be used. Default: ``None`` - activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the linear layer. If ``None`` this layer won't be used. Default: ``torch.nn.ReLU`` - inplace (bool, optional): Parameter for the activation layer, which can optionally do the operation in-place. - Default is ``None``, which uses the respective default values of the ``activation_layer`` and Dropout layer. - bias (bool): Whether to use bias in the linear layer. Default ``True`` - dropout (float): The probability for the dropout layer. Default: 0.0 - """ def __init__( self, in_channels: int, hidden_channels: List[int], - activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU, - inplace: Optional[bool] = None, - bias: bool = True, - dropout: float = 0.0, ): - params = {} if inplace is None else {"inplace": inplace} layers = [] in_dim = in_channels for hidden_dim in hidden_channels[:-1]: - layers.append(torch.nn.Linear(in_dim, hidden_dim, bias=bias)) - layers.append(activation_layer(**params)) - layers.append(torch.nn.Dropout(dropout, **params)) + layers.append(torch.nn.Linear(in_dim, hidden_dim)) + layers.append(torch.nn.ReLU()) in_dim = hidden_dim - layers.append(torch.nn.Linear(in_dim, hidden_channels[-1], bias=bias)) - layers.append(torch.nn.Dropout(dropout, **params)) + layers.append(torch.nn.Linear(in_dim, hidden_channels[-1])) super().__init__(*layers) -def weights_init_encoder(m): - if isinstance(m, nn.Linear): - nn.init.orthogonal_(m.weight.data) - m.bias.data.fill_(0.0) - elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): - assert m.weight.size(2) == m.weight.size(3) - m.weight.data.fill_(0.0) - m.bias.data.fill_(0.0) - mid = m.weight.size(2) // 2 - gain = nn.init.calculate_gain("relu") - nn.init.orthogonal_(m.weight.data[:, :, mid, mid], gain) - - - -# @torch.jit.script # good to enable when not using torch.compile, disable when using (our default) -def new_gelu(x): - """ - Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). - Reference: Gaussian Error Linear Units (GELU) paper: https://arxiv.org/abs/1606.08415 - """ - return ( - 0.5 - * x - * ( - 1.0 - + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))) - ) - ) class CausalSelfAttention(nn.Module): @@ -3011,20 +2220,6 @@ class CausalSelfAttention(nn.Module): return y -class GPT_MLP(nn.Module): - def __init__(self, config): - super().__init__() - self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd) - self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd) - self.dropout = nn.Dropout(config.dropout) - - def forward(self, x): - x = self.c_fc(x) - x = new_gelu(x) - x = self.c_proj(x) - x = self.dropout(x) - return x - class Block(nn.Module): def __init__(self, config): @@ -3032,7 +2227,12 @@ class Block(nn.Module): self.ln_1 = nn.LayerNorm(config.n_embd) self.attn = CausalSelfAttention(config) self.ln_2 = nn.LayerNorm(config.n_embd) - self.mlp = GPT_MLP(config) + self.mlp = nn.Sequential( + nn.Linear(config.n_embd, 4 * config.n_embd), + nn.GELU(), + nn.Linear(4 * config.n_embd, config.n_embd), + nn.Dropout(config.dropout) + ) def forward(self, x): x = x + self.attn(self.ln_1(x)) diff --git a/lerobot/configs/policy/vqbet.yaml b/lerobot/configs/policy/vqbet.yaml index 5fff7657..fe502d3e 100644 --- a/lerobot/configs/policy/vqbet.yaml +++ b/lerobot/configs/policy/vqbet.yaml @@ -52,7 +52,7 @@ training: action: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, ${policy.n_action_pred_token} + ${policy.n_action_pred_chunk} - 1)]" eval: - n_episodes: 50 + n_episodes: 500 batch_size: 50 policy: @@ -61,8 +61,8 @@ policy: # Input / output structure. n_obs_steps: 5 # n_action_steps: 7 # n_action_pred_token + n_action_pred_window - 1 - n_action_pred_token: 3 # 3 5 - n_action_pred_chunk: 5 # 5 1 jay temp2 + n_action_pred_token: 7 + n_action_pred_chunk: 5 input_shapes: # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env? @@ -93,9 +93,10 @@ policy: vqvae_embedding_dim: 256 # VQ-BeT block_size: 500 - output_dim: 256 # 512 - n_layer: 6 # 8 - n_head: 6 # 4 - n_embd: 120 # 512 + output_dim: 512 + n_layer: 8 # 8 + n_head: 8 # 4 + n_embd: 512 dropout: 0.1 - mlp_hidden_dim: 1024 # 512 \ No newline at end of file + mlp_hidden_dim: 1024 # 512 + offset_loss_weight: 10000. \ No newline at end of file