delete redundant kwargs, change parameter names

jayLEE0301 2024-05-23 10:25:35 -04:00
parent 2f4f137586
commit d209c0f73c
3 changed files with 150 additions and 262 deletions


@ -108,15 +108,18 @@ class VQBeTConfig:
vqvae_groups: int = 2
vqvae_n_embed: int = 16
vqvae_embedding_dim: int = 256
vqvae_enc_hidden_dim: int = 128
# VQ-BeT
block_size: int = 50
output_dim: int = 256
n_layer: int = 6
n_head: int = 6
n_embd: int = 120
gpt_block_size: int = 500
gpt_input_dim: int = 512
gpt_output_dim: int = 512
gpt_n_layer: int = 8
gpt_n_head: int = 8
gpt_n_embed: int = 512
dropout: float = 0.1
mlp_hidden_dim: int = 1024
offset_loss_weight: float = 10000.
secondary_code_multiplier: float = 0.5
def __post_init__(self):
"""Input validation (not exhaustive)."""


@ -62,7 +62,7 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
# queues are populated during rollout of the policy, they contain the n latest observations and actions
self._obs_queues = None
self._queues = None
self.vqbet = VQBeTModel(config)
@ -74,12 +74,11 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
"""
Clear observation and action queues. Should be called on `env.reset()`
"""
self._obs_queues = {
self._queues = {
"observation.image": deque(maxlen=self.config.n_obs_steps),
"observation.state": deque(maxlen=self.config.n_obs_steps),
"action": deque(maxlen=self.config.n_action_pred_chunk),
}
if self.config.n_action_pred_chunk is not None:
self._action_queue = deque([], maxlen=self.config.n_action_pred_chunk)
@torch.no_grad
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
@ -93,7 +92,7 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
self.eval()
batch = self.normalize_inputs(batch)
self._obs_queues = populate_queues(self._obs_queues, batch)
self._queues = populate_queues(self._queues, batch)
if not self.check_discretized():
self.vqbet._action_head._vqvae_model.discretized = True
@ -101,16 +100,16 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
assert "observation.image" in batch
assert "observation.state" in batch
if len(self._action_queue) == 0:
if len(self._queues["action"]) == 0:
batch = {key: torch.stack(list(self._obs_queues[key]), dim=1) for key in batch}
batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
actions = self.vqbet(batch, rollout=True)[:, : self.config.n_action_pred_chunk]
actions = self.unnormalize_outputs({"action": actions})["action"]
self._action_queue.extend(actions.transpose(0, 1))
self._queues["action"].extend(actions.transpose(0, 1))
action = self._action_queue.popleft()
action = self._queues["action"].popleft()
return action
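A usage sketch of the consolidated queues during rollout; the environment, its observation keys, and the gymnasium-style step signature are assumptions for illustration.

import torch

policy.reset()                                   # rebuilds self._queues with empty deques
obs, _ = env.reset()
done = False
while not done:
    batch = {
        "observation.image": torch.from_numpy(obs["image"]).float().unsqueeze(0),
        "observation.state": torch.from_numpy(obs["state"]).float().unsqueeze(0),
    }
    action = policy.select_action(batch)         # pops one step from self._queues["action"]
    obs, reward, terminated, truncated, info = env.step(action.squeeze(0).numpy())
    done = terminated or truncated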
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
@ -123,9 +122,7 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
_, loss = self.vqbet(batch, rollout=False)
return {"loss": loss['actor_loss'], 'equal_single_code_rate': loss['equal_single_code_rate'], 'equal_single_code_rate2': loss['equal_single_code_rate2'], "offset_loss_weight": loss['offset_loss_weight'], \
"action_diff": loss['action_diff'], "action_diff_tot": loss['action_diff_tot'], "action_diff_mean_res1": loss['action_diff_mean_res1'], "action_diff_mean_res2": loss['action_diff_mean_res2'], \
"action_diff_max": loss['action_diff_max']}
return loss
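A minimal training-step sketch against the simplified forward() return value; the dataloader and optimizer are assumptions.

for batch in dataloader:                 # batches with "observation.*" and "action" tensors assumed
    loss_dict = policy.forward(batch)    # forward() now returns the dict built by VQBeTHead.loss_fn
    loss = loss_dict["loss"]             # cbet_loss + offset_loss_weight * offset_loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()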
class VQBeTModel(nn.Module):
@ -138,36 +135,19 @@ class VQBeTModel(nn.Module):
self.global_cond_dim = self.rgb_encoder.feature_dim
# action token and EOS token
self._action_token = nn.Parameter(torch.randn(1, 1, self.config.n_embd)) # Batch, Timestep, Data type, GPT input dim
self._eos_token = nn.Parameter(torch.randn(1, 1, self.config.n_embd))
self._action_token = nn.Parameter(torch.randn(1, 1, self.config.gpt_input_dim)) # Batch, Timestep, Data type, GPT input dim
self._eos_token = nn.Parameter(torch.randn(1, 1, self.config.gpt_input_dim))
self.state_projector = MLP(
config.output_shapes["action"][0], hidden_channels=[self.config.n_embd]
config.output_shapes["action"][0],
hidden_channels=[self.config.gpt_input_dim]
)
self.obs_projector = MLP(
self.global_cond_dim, hidden_channels=[self.config.n_embd]
self.global_cond_dim,
hidden_channels=[self.config.gpt_input_dim]
)
self._policy = GPT(
GPTConfig(
block_size=self.config.block_size,
input_dim=self.config.n_embd,
output_dim=self.config.output_dim,
n_layer=self.config.n_layer,
n_head=self.config.n_head,
n_embd=self.config.n_embd,
dropout=self.config.dropout,
)
)
self._action_head = VQBeTHead(
config.output_dim,
config.output_shapes["action"][0],
offset_loss_weight=config.offset_loss_weight,
hidden_size=config.mlp_hidden_dim,
vqvae_groups=config.vqvae_groups,
vqvae_n_embed=config.vqvae_n_embed,
vqvae_embedding_dim=config.vqvae_embedding_dim,
n_action_pred_chunk=config.n_action_pred_chunk
)
self._policy = GPT(config)
self._action_head = VQBeTHead(config)
def discretize(self, discretize_step, actions):
return self._action_head.discretize(discretize_step, actions)
@ -190,7 +170,7 @@ class VQBeTModel(nn.Module):
torch.unsqueeze(self.obs_projector(img_features), dim=2),
torch.unsqueeze(self.state_projector(batch["observation.state"]), dim=2),
self._action_token.repeat(batch_size, n_obs_steps, 1, 1)
], dim=-2).view(batch_size, -1, self.config.n_embd)
], dim=-2).view(batch_size, -1, self.config.gpt_input_dim)
if img_features.shape[1] != n_obs_steps:
raise NotImplementedError
# eos_token = self._eos_token.repeat(batch_size, 1, 1) # TODO remove EOS token
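A self-contained shape sketch of the token interleaving above; the dimensions are illustrative assumptions.

import torch

batch_size, n_obs_steps, gpt_input_dim = 4, 5, 512
obs_tok = torch.randn(batch_size, n_obs_steps, 1, gpt_input_dim)     # projected image features
state_tok = torch.randn(batch_size, n_obs_steps, 1, gpt_input_dim)   # projected proprioceptive state
action_tok = torch.randn(1, 1, gpt_input_dim).repeat(batch_size, n_obs_steps, 1, 1)

# [observation, state, action] tokens are interleaved per timestep, then flattened into one sequence.
tokens = torch.cat([obs_tok, state_tok, action_tok], dim=-2)
tokens = tokens.view(batch_size, -1, gpt_input_dim)
assert tokens.shape == (batch_size, 3 * n_obs_steps, gpt_input_dim)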
@ -231,57 +211,40 @@ class VQBeTModel(nn.Module):
action,
reduction="mean",
)
return pred_action, loss[0] if isinstance(loss, tuple) else loss
return pred_action, loss
class VQBeTHead(nn.Module):
def __init__(
self,
# network_kwargs
input_size,
output_size,
hidden_size=1024,
# loss_kwargs
offset_loss_weight=100.0,
secondary_code_multiplier=0.5,
vqvae_groups=2, # G(number of groups)
vqvae_n_embed=16, # C(number of code integers)
vqvae_embedding_dim=512, # D(embedding dims)
n_action_pred_chunk=1, # action chunk size
):
super().__init__()
self.input_size = input_size
self.output_size = output_size
self.hidden_size = hidden_size
self.offset_loss_weight = offset_loss_weight
self.secondary_code_multiplier = secondary_code_multiplier
def __init__(self, config: VQBeTConfig):
"""
TODO: add explanation for each value.
"""
self._G = vqvae_groups # G(number of groups)
self._C = vqvae_n_embed # C(number of code integers)
self._D = vqvae_embedding_dim # D(embedding dims)
self.n_action_pred_chunk = n_action_pred_chunk # action chunk size
super().__init__()
self.input_size = config.gpt_output_dim
self.output_size = config.output_shapes["action"][0]
self.hidden_size = config.mlp_hidden_dim
self.offset_loss_weight = config.offset_loss_weight
self.secondary_code_multiplier = config.secondary_code_multiplier
self.vqvae_groups = config.vqvae_groups
self.vqvae_n_embed = config.vqvae_n_embed # C(number of code integers)
self.vqvae_embedding_dim = config.vqvae_embedding_dim # D(embedding dims)
self.n_action_pred_chunk = config.n_action_pred_chunk # action chunk size
self._map_to_cbet_preds_bin = MLP(
in_channels=self.input_size,
hidden_channels=[self._G * self._C],
hidden_channels=[self.vqvae_groups * self.vqvae_n_embed],
)
self._map_to_cbet_preds_offset = MLP(
in_channels=self.input_size,
hidden_channels=[
self._G * self._C * n_action_pred_chunk * self.output_size,
self.vqvae_groups * self.vqvae_n_embed * config.n_action_pred_chunk * self.output_size,
],
)
# init vqvae
vqvae_config = {
"action_chunk": self.n_action_pred_chunk,
"action_dim": self.output_size,
"vqvae_n_latent_dims": self._D,
"vqvae_n_embed": self._C,
"vqvae_groups": self._G,
"device": get_device_from_parameters(self),
}
self._vqvae_model = init_vqvae(vqvae_config)
self._vqvae_model = VqVae(config)
# loss
self._criterion = FocalLoss(gamma=2.0)
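FocalLoss is defined elsewhere in this file; a minimal sketch of the loss it is expected to compute, assuming the standard formulation with gamma=2.0 and no class weighting.

import torch
import torch.nn.functional as F

def focal_loss(logits: torch.Tensor, targets: torch.Tensor, gamma: float = 2.0) -> torch.Tensor:
    # Focal loss down-weights well-classified examples by (1 - p_t) ** gamma.
    log_probs = F.log_softmax(logits, dim=-1)
    log_pt = log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
    pt = log_pt.exp()
    return (-((1 - pt) ** gamma) * log_pt).mean()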
@ -307,10 +270,10 @@ class VQBeTHead(nn.Module):
cbet_logits = self._map_to_cbet_preds_bin(x)
cbet_offsets = self._map_to_cbet_preds_offset(x)
cbet_logits = einops.rearrange(
cbet_logits, "(NT) (G C) -> (NT) G C", G=self._G
cbet_logits, "(NT) (G C) -> (NT) G C", G=self.vqvae_groups
)
cbet_offsets = einops.rearrange(
cbet_offsets, "(NT) (G C WA) -> (NT) G C WA", G=self._G, C=self._C
cbet_offsets, "(NT) (G C WA) -> (NT) G C WA", G=self.vqvae_groups, C=self.vqvae_n_embed
)
cbet_probs = torch.softmax(cbet_logits, dim=-1)
NT, G, choices = cbet_probs.shape
@ -322,7 +285,7 @@ class VQBeTHead(nn.Module):
indices = (
torch.arange(NT).unsqueeze(1).cuda(),
torch.arange(self._G).unsqueeze(0).cuda(),
torch.arange(self.vqvae_groups).unsqueeze(0).cuda(),
sampled_centers,
)
# Use advanced indexing to sample the values
@ -330,7 +293,7 @@ class VQBeTHead(nn.Module):
sampled_offsets = sampled_offsets.sum(dim=1)
centers = self._vqvae_model.draw_code_forward(sampled_centers).view(
NT, -1, self._D
NT, -1, self.vqvae_embedding_dim
)
return_decoder_input = einops.rearrange(
centers.clone().detach(), "NT 1 D -> NT D"
@ -341,7 +304,7 @@ class VQBeTHead(nn.Module):
.detach()
) # NT, A
sampled_offsets = einops.rearrange(
sampled_offsets, "NT (W A) -> NT W A", W=self._vqvae_model.input_dim_h
sampled_offsets, "NT (W A) -> NT W A", W=self.n_action_pred_chunk
)
predicted_action = decoded_action + sampled_offsets
predicted_action = einops.rearrange(
@ -349,7 +312,7 @@ class VQBeTHead(nn.Module):
"(N T) W A -> N T (W A)",
N=N,
T=T,
W=self._vqvae_model.input_dim_h,
W=self.n_action_pred_chunk,
)
return {
@ -357,10 +320,6 @@ class VQBeTHead(nn.Module):
"predicted_action": predicted_action,
"sampled_centers": sampled_centers,
"decoded_action": decoded_action,
"G": G,
"NT": NT,
"N": N,
"T": T,
}
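A small self-contained sketch of the per-group code sampling and offset gathering performed in forward(); shapes are illustrative (NT is the flattened batch x time dimension).

import torch

NT, G, C, W, A = 8, 2, 16, 1, 7                     # batch*time, groups, codes, chunk size, action dim
cbet_probs = torch.softmax(torch.randn(NT, G, C), dim=-1)
cbet_offsets = torch.randn(NT, G, C, W * A)

# One code index is sampled per example and per residual group.
sampled_centers = torch.multinomial(cbet_probs.view(-1, C), num_samples=1).view(NT, G)

# Advanced indexing picks the offset vector of each sampled code; groups are then summed.
idx = (torch.arange(NT).unsqueeze(1), torch.arange(G).unsqueeze(0), sampled_centers)
sampled_offsets = cbet_offsets[idx].sum(dim=1)      # (NT, W * A)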
def loss_fn(self, pred, target, **kwargs):
@ -369,11 +328,12 @@ class VQBeTHead(nn.Module):
predicted_action = pred["predicted_action"]
sampled_centers = pred["sampled_centers"]
decoded_action = pred["decoded_action"]
G, NT, N, T = pred["G"], pred["NT"], pred["N"], pred["T"]
NT = predicted_action.shape[0]
T = predicted_action.shape[1]
cbet_logits = pred["cbet_logits"]
predicted_action = einops.rearrange(
predicted_action, "N T (W A) -> (N T) W A", W=self._vqvae_model.input_dim_h
predicted_action, "N T (W A) -> (N T) W A", W=self.n_action_pred_chunk
)
action_seq = einops.rearrange(action_seq, "N T W A -> (N T) W A")
@ -400,74 +360,48 @@ class VQBeTHead(nn.Module):
)
cbet_loss = cbet_loss1 * 5 + cbet_loss2 * self.secondary_code_multiplier
equal_total_code_rate = (
torch.sum(
(torch.sum((action_bins == sampled_centers).int(), axis=1) == G).int()
)
/ NT
)
equal_single_code_rate = torch.sum(
equal_primary_code_rate = torch.sum(
(action_bins[:, 0] == sampled_centers[:, 0]).int()
) / (NT)
equal_single_code_rate2 = torch.sum(
equal_secondary_code_rate = torch.sum(
(action_bins[:, 1] == sampled_centers[:, 1]).int()
) / (NT)
action_diff = F.mse_loss(
einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T)[:, 4, :, :],
einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=T)[
:, 4, :, :
],
) # batch, time, windowsize (t ... t+N), action dim -> [:, -1, 0, :] is for rollout
action_diff_tot = F.mse_loss(
einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T)[:, :, :, :],
einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=T)[
:, :, :, :
],
) # batch, time, windowsize (t ... t+N), action dim -> [:, -1, 0, :] is for rollout
action_diff_mean_res1 = (
action_mse_error = F.mse_loss(
einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T),
einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=T),
)
vq_action_error = (
abs(
einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T)[
:, 4, :, :
]
- einops.rearrange(decoded_action, "(N T) W A -> N T W A", T=T)[
:, 4, :, :
]
einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T) - einops.rearrange(decoded_action, "(N T) W A -> N T W A", T=T)
)
).mean()
action_diff_mean_res2 = (
offset_action_error = (
abs(
einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T)[
:, 4, :, :
]
- einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=T)[
:, 4, :, :
]
einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T) - einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=T)
)
).mean()
action_diff_max = (
action_error_max = (
abs(
einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T)[
:, 4, :, :
]
- einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=T)[
:, 4, :, :
]
einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T) - einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=T)
)
).max()
loss = cbet_loss + self.offset_loss_weight * offset_loss
loss_dict = {
"loss": loss,
"classification_loss": cbet_loss.detach().cpu().item(),
"offset_loss": offset_loss.detach().cpu().item(),
"total_loss": loss.detach().cpu().item(),
"equal_total_code_rate": equal_total_code_rate,
"equal_single_code_rate": equal_single_code_rate,
"equal_single_code_rate2": equal_single_code_rate2,
"equal_primary_code_rate": equal_primary_code_rate.detach().cpu().item(),
"equal_secondary_code_rate": equal_secondary_code_rate.detach().cpu().item(),
"vq_action_error": vq_action_error.detach().cpu().item(),
"offset_action_error": offset_action_error.detach().cpu().item(),
"action_error_max": action_error_max.detach().cpu().item(),
"action_mse_error": action_mse_error.detach().cpu().item(),
}
return {"actor_loss": loss, "equal_single_code_rate": equal_single_code_rate, "equal_single_code_rate2": equal_single_code_rate2, "offset_loss_weight": self.offset_loss_weight, \
"action_diff": action_diff, "action_diff_tot": action_diff_tot, "action_diff_mean_res1": action_diff_mean_res1, "action_diff_mean_res2": action_diff_mean_res2, \
"action_diff_max": action_diff_max}, loss_dict
return loss_dict
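To make the weighting explicit, a tiny numeric sketch of how the final loss is assembled from the terms above, using the default config values (secondary_code_multiplier=0.5, offset_loss_weight=10000); the numbers themselves are illustrative.

cbet_loss1, cbet_loss2 = 2.3, 2.7       # focal loss on the primary / secondary code group
offset_loss = 1.2e-4                    # offset regression term (computed earlier in loss_fn)

cbet_loss = cbet_loss1 * 5 + cbet_loss2 * 0.5       # 12.85; the primary group is weighted 5x
loss = cbet_loss + 10000.0 * offset_loss            # 12.85 + 1.2 = 14.05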
class VQBeTOptimizer:
def __init__(self, policy, cfg):
@ -688,71 +622,43 @@ def _replace_submodules(
class VqVae(nn.Module):
def __init__(
self,
input_dim_h=10, # length of action chunk
input_dim_w=9, # action dim
n_latent_dims=512,
vqvae_n_embed=32,
vqvae_groups=4,
eval=True,
load_dir=None,
encoder_loss_multiplier=1.0,
act_scale=1.0,
self, config: VQBeTConfig,
):
super(VqVae, self).__init__()
self.n_latent_dims = n_latent_dims
self.input_dim_h = input_dim_h
self.input_dim_w = input_dim_w
self.rep_dim = self.n_latent_dims
self.vqvae_n_embed = vqvae_n_embed
self.vqvae_lr = 1e-3
self.vqvae_groups = vqvae_groups
self.encoder_loss_multiplier = encoder_loss_multiplier
self.act_scale = act_scale
self.n_action_pred_chunk = config.n_action_pred_chunk
self.action_dim = config.output_shapes["action"][0]
self.discretized = False
self.optimized_steps = 0
discrete_cfg = {"groups": self.vqvae_groups, "n_embed": self.vqvae_n_embed}
self.vq_layer = ResidualVQ(
dim=self.n_latent_dims,
num_quantizers=discrete_cfg["groups"],
codebook_size=self.vqvae_n_embed,
dim=config.vqvae_embedding_dim,
num_quantizers=config.vqvae_groups,
codebook_size=config.vqvae_n_embed,
)
self.embedding_dim = self.n_latent_dims
self.embedding_dim = config.vqvae_embedding_dim
if self.input_dim_h == 1:
if self.n_action_pred_chunk == 1:
self.encoder = MLP(
in_channels=input_dim_w,
hidden_channels=[128, 128, n_latent_dims],
in_channels=self.action_dim,
hidden_channels=[config.vqvae_enc_hidden_dim, config.vqvae_enc_hidden_dim, config.vqvae_embedding_dim],
)
self.decoder = MLP(
in_channels=n_latent_dims,
hidden_channels=[128, 128, input_dim_w],
in_channels=config.vqvae_embedding_dim,
hidden_channels=[config.vqvae_enc_hidden_dim, config.vqvae_enc_hidden_dim, self.action_dim],
)
else:
self.encoder = MLP(
in_channels=input_dim_w * self.input_dim_h,
hidden_channels=[128, 128, n_latent_dims],
in_channels=self.action_dim * self.n_action_pred_chunk,
hidden_channels=[config.vqvae_enc_hidden_dim, config.vqvae_enc_hidden_dim, config.vqvae_embedding_dim],
)
self.decoder = MLP(
in_channels=n_latent_dims,
hidden_channels=[128, 128, input_dim_w * self.input_dim_h],
in_channels=config.vqvae_embedding_dim,
hidden_channels=[config.vqvae_enc_hidden_dim, config.vqvae_enc_hidden_dim, self.action_dim * self.n_action_pred_chunk],
)
if load_dir is not None:
try:
state_dict = torch.load(load_dir)
except RuntimeError:
state_dict = torch.load(load_dir, map_location=torch.device("cpu"))
self.load_state_dict(state_dict)
if eval:
self.eval()
else:
self.train()
self.train()
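The residual VQ layer comes from the vector-quantize-pytorch package; a minimal usage sketch with the config values above (the return signature is taken from that package's documented API and should be treated as an assumption here).

import torch
from vector_quantize_pytorch import ResidualVQ

vq = ResidualVQ(dim=256, num_quantizers=2, codebook_size=16)

z = torch.randn(32, 1, 256)               # encoded action chunks: (batch, seq, vqvae_embedding_dim)
quantized, indices, commit_loss = vq(z)   # quantized matches z's shape; indices hold one code id per quantizer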
def eval(self):
self.training = False
@ -783,23 +689,22 @@ class VqVae(nn.Module):
return z_embed
def get_action_from_latent(self, latent):
output = self.decoder(latent) * self.act_scale
if self.input_dim_h == 1:
return einops.rearrange(output, "N (T A) -> N T A", A=self.input_dim_w)
output = self.decoder(latent)
if self.n_action_pred_chunk == 1:
return einops.rearrange(output, "N (T A) -> N T A", A=self.action_dim)
else:
return einops.rearrange(output, "N (T A) -> N T A", A=self.input_dim_w)
return einops.rearrange(output, "N (T A) -> N T A", A=self.action_dim)
def preprocess(self, state):
if not torch.is_tensor(state):
state = torch.FloatTensor(state.copy())
if self.input_dim_h == 1:
if self.n_action_pred_chunk == 1:
state = state.squeeze(-2) # state.squeeze(-1)
else:
state = einops.rearrange(state, "N T A -> N (T A)")
return state
def get_code(self, state, required_recon=False):
state = state / self.act_scale
state = self.preprocess(state)
with torch.no_grad():
state_rep = self.encoder(state)
@ -810,9 +715,9 @@ class VqVae(nn.Module):
vq_code = vq_code.view(*state_rep_shape, -1)
vq_loss_state = torch.sum(vq_loss_state)
if required_recon:
recon_state = self.decoder(state_vq) * self.act_scale
recon_state_ae = self.decoder(state_rep) * self.act_scale
if self.input_dim_h == 1:
recon_state = self.decoder(state_vq)
recon_state_ae = self.decoder(state_rep)
if self.n_action_pred_chunk == 1:
return state_vq, vq_code, recon_state, recon_state_ae
else:
return (
@ -826,7 +731,6 @@ class VqVae(nn.Module):
return state_vq, vq_code
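get_code is what the head uses to turn ground-truth action chunks into discrete classification targets; a usage sketch with illustrative shapes (vqvae_model is assumed to be an instantiated VqVae, and the output shapes are an assumption based on the surrounding code).

import torch

action_chunks = torch.randn(256, 5, 7)                      # (N*T, n_action_pred_chunk, action_dim)
state_vq, action_bins = vqvae_model.get_code(action_chunks)
# action_bins: (256, vqvae_groups) integer code ids, one per residual group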
def vqvae_forward(self, state):
state = state / self.act_scale
state = self.preprocess(state)
state_rep = self.encoder(state)
state_rep_shape = state_rep.shape[:-1]
@ -839,7 +743,7 @@ class VqVae(nn.Module):
dec_out = self.decoder(state_vq)
encoder_loss = (state - dec_out).abs().mean()
rep_loss = encoder_loss * self.encoder_loss_multiplier + (vq_loss_state * 5)
rep_loss = encoder_loss * vq_loss_state * 5
metric = (
encoder_loss.clone().detach(),
@ -857,28 +761,16 @@ class VqVae(nn.Module):
def init_vqvae(config):
# model
vqvae_model = VqVae(
input_dim_h=config["action_chunk"],
input_dim_w=config["action_dim"],
n_latent_dims=config["vqvae_n_latent_dims"],
vqvae_n_embed=config["vqvae_n_embed"],
vqvae_groups=config["vqvae_groups"],
eval=False,
)
return vqvae_model
def pretrain_vqvae(vqvae_model, discretize_step, actions):
if vqvae_model.input_dim_h == 1:
if vqvae_model.n_action_pred_chunk == 1:
# not using action chunk
actions = actions.reshape(-1, 1, actions.shape[-1])
else:
# using action chunk
slices = []
slices.extend([actions[:, j:j+vqvae_model.input_dim_h, :] for j in range(actions.shape[1]+1-vqvae_model.input_dim_h)])
slices.extend([actions[:, j:j+vqvae_model.n_action_pred_chunk, :] for j in range(actions.shape[1]+1-vqvae_model.n_action_pred_chunk)])
actions = torch.cat(slices, dim=0)
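The slicing above turns every length-n_action_pred_chunk window of the action sequence into a training sample; a self-contained sketch of the same operation with illustrative sizes.

import torch

n_action_pred_chunk = 5
actions = torch.randn(64, 100, 7)        # (batch, horizon, action_dim)

slices = [
    actions[:, j : j + n_action_pred_chunk, :]
    for j in range(actions.shape[1] + 1 - n_action_pred_chunk)
]
chunks = torch.cat(slices, dim=0)        # (64 * 96, 5, 7): one sample per overlapping window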
@ -2166,40 +2058,40 @@ class MLP(torch.nn.Sequential):
class CausalSelfAttention(nn.Module):
def __init__(self, config):
super().__init__()
assert config.n_embd % config.n_head == 0
assert config.gpt_n_embed % config.gpt_n_head == 0
# key, query, value projections for all heads, but in a batch
self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
self.c_attn = nn.Linear(config.gpt_n_embed, 3 * config.gpt_n_embed)
# output projection
self.c_proj = nn.Linear(config.n_embd, config.n_embd)
self.c_proj = nn.Linear(config.gpt_n_embed, config.gpt_n_embed)
# regularization
self.attn_dropout = nn.Dropout(config.dropout)
self.resid_dropout = nn.Dropout(config.dropout)
# causal mask to ensure that attention is only applied to the left in the input sequence
self.register_buffer(
"bias",
torch.tril(torch.ones(config.block_size, config.block_size)).view(
1, 1, config.block_size, config.block_size
torch.tril(torch.ones(config.gpt_block_size, config.gpt_block_size)).view(
1, 1, config.gpt_block_size, config.gpt_block_size
),
)
self.n_head = config.n_head
self.n_embd = config.n_embd
self.gpt_n_head = config.gpt_n_head
self.gpt_n_embed = config.gpt_n_embed
def forward(self, x):
(
B,
T,
C,
) = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
) = x.size() # batch size, sequence length, embedding dimensionality (gpt_n_embed)
# calculate query, key, values for all heads in batch and move head forward to be the batch dim
q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
k = k.view(B, T, self.n_head, C // self.n_head).transpose(
q, k, v = self.c_attn(x).split(self.gpt_n_embed, dim=2)
k = k.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(
1, 2
) # (B, nh, T, hs)
q = q.view(B, T, self.n_head, C // self.n_head).transpose(
q = q.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(
1, 2
) # (B, nh, T, hs)
v = v.view(B, T, self.n_head, C // self.n_head).transpose(
v = v.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(
1, 2
) # (B, nh, T, hs)
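The remainder of CausalSelfAttention.forward is not shown in this hunk; a minimal sketch of the masked scaled dot-product attention it is assumed to apply to q, k, v.

import math
import torch
import torch.nn.functional as F

def causal_attention(q, k, v, bias, attn_dropout):
    # q, k, v: (B, gpt_n_head, T, head_size); bias: lower-triangular (1, 1, block, block) mask buffer.
    T = q.size(-2)
    att = (q @ k.transpose(-2, -1)) / math.sqrt(k.size(-1))
    att = att.masked_fill(bias[:, :, :T, :T] == 0, float("-inf"))
    att = attn_dropout(F.softmax(att, dim=-1))
    return att @ v                        # (B, gpt_n_head, T, head_size)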
@ -2222,13 +2114,13 @@ class CausalSelfAttention(nn.Module):
class Block(nn.Module):
def __init__(self, config):
super().__init__()
self.ln_1 = nn.LayerNorm(config.n_embd)
self.ln_1 = nn.LayerNorm(config.gpt_n_embed)
self.attn = CausalSelfAttention(config)
self.ln_2 = nn.LayerNorm(config.n_embd)
self.ln_2 = nn.LayerNorm(config.gpt_n_embed)
self.mlp = nn.Sequential(
nn.Linear(config.n_embd, 4 * config.n_embd),
nn.Linear(config.gpt_n_embed, 4 * config.gpt_n_embed),
nn.GELU(),
nn.Linear(4 * config.n_embd, config.n_embd),
nn.Linear(4 * config.gpt_n_embed, config.gpt_n_embed),
nn.Dropout(config.dropout)
)
@ -2238,15 +2130,6 @@ class Block(nn.Module):
return x
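For context, Block.forward (outside this hunk) is assumed to follow the usual pre-LayerNorm residual pattern; a sketch under that assumption.

def block_forward(block, x):
    # Hypothetical restatement of Block.forward: attention and MLP sub-layers, each with a residual connection.
    x = x + block.attn(block.ln_1(x))
    x = x + block.mlp(block.ln_2(x))
    return x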
@dataclass
class GPTConfig:
block_size: int = 1024
input_dim: int = 256
output_dim: int = 256
n_layer: int = 12
n_head: int = 12
n_embd: int = 768
dropout: float = 0.1
class GPT(nn.Module):
@ -2287,29 +2170,28 @@ class GPT(nn.Module):
"""
def __init__(self, config):
def __init__(self, config: VQBeTConfig):
super().__init__()
assert config.input_dim is not None
assert config.output_dim is not None
assert config.block_size is not None
assert config.gpt_output_dim is not None
assert config.gpt_block_size is not None
self.config = config
self.transformer = nn.ModuleDict(
dict(
wte=nn.Linear(config.input_dim, config.n_embd),
wpe=nn.Embedding(config.block_size, config.n_embd),
wte=nn.Linear(config.gpt_input_dim, config.gpt_n_embed),
wpe=nn.Embedding(config.gpt_block_size, config.gpt_n_embed),
drop=nn.Dropout(config.dropout),
h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
ln_f=nn.LayerNorm(config.n_embd),
h=nn.ModuleList([Block(config) for _ in range(config.gpt_n_layer)]),
ln_f=nn.LayerNorm(config.gpt_n_embed),
)
)
self.lm_head = nn.Linear(config.n_embd, config.output_dim, bias=False)
self.lm_head = nn.Linear(config.gpt_n_embed, config.gpt_output_dim, bias=False)
# init all weights, and apply a special scaled init to the residual projections, per GPT-2 paper
self.apply(self._init_weights)
for pn, p in self.named_parameters():
if pn.endswith("c_proj.weight"):
torch.nn.init.normal_(
p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer)
p, mean=0.0, std=0.02 / math.sqrt(2 * config.gpt_n_layer)
)
# report number of parameters
@ -2320,8 +2202,8 @@ class GPT(nn.Module):
device = input.device
b, t, d = input.size()
assert (
t <= self.config.block_size
), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
t <= self.config.gpt_block_size
), f"Cannot forward sequence of length {t}, block size is only {self.config.gpt_block_size}"
pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(
0
) # shape (1, t)
@ -2329,10 +2211,10 @@ class GPT(nn.Module):
# forward the GPT model itself
tok_emb = self.transformer.wte(
input
) # token embeddings of shape (b, t, n_embd)
) # token embeddings of shape (b, t, gpt_n_embed)
pos_emb = self.transformer.wpe(
pos
) # position embeddings of shape (1, t, n_embd)
) # position embeddings of shape (1, t, gpt_n_embed)
x = self.transformer.drop(tok_emb + pos_emb)
for block in self.transformer.h:
x = block(x)
@ -2351,14 +2233,14 @@ class GPT(nn.Module):
torch.nn.init.zeros_(module.bias)
torch.nn.init.ones_(module.weight)
def crop_block_size(self, block_size):
assert block_size <= self.config.block_size
self.config.block_size = block_size
def crop_block_size(self, gpt_block_size):
assert gpt_block_size <= self.config.gpt_block_size
self.config.gpt_block_size = gpt_block_size
self.transformer.wpe.weight = nn.Parameter(
self.transformer.wpe.weight[:block_size]
self.transformer.wpe.weight[:gpt_block_size]
)
for block in self.transformer.h:
block.attn.bias = block.attn.bias[:, :, :block_size, :block_size]
block.attn.bias = block.attn.bias[:, :, :gpt_block_size, :gpt_block_size]
def configure_optimizers(self, weight_decay, learning_rate, betas, optimizer="Adamw", eps=None):
"""


@ -22,7 +22,7 @@ override_dataset_stats:
max: [511.0, 511.0]
training:
offline_steps: 800000
offline_steps: 250000
online_steps: 0
eval_freq: 20000
save_freq: 20000
@ -90,12 +90,15 @@ policy:
vqvae_groups: 2
vqvae_n_embed: 16
vqvae_embedding_dim: 256
vqvae_enc_hidden_dim: 128
# VQ-BeT
block_size: 500
output_dim: 512
n_layer: 8 # 8
n_head: 8 # 4
n_embd: 512
gpt_block_size: 500
gpt_input_dim: 512
gpt_output_dim: 512
gpt_n_layer: 8
gpt_n_head: 8
gpt_n_embed: 512
dropout: 0.1
mlp_hidden_dim: 1024 # 512
offset_loss_weight: 10000.
mlp_hidden_dim: 1024
offset_loss_weight: 10000.
secondary_code_multiplier: 0.5