delete redundant kwargs, change parameter names
This commit is contained in: parent 2f4f137586, commit d209c0f73c
@@ -108,15 +108,18 @@ class VQBeTConfig:
vqvae_groups: int = 2
vqvae_n_embed: int = 16
vqvae_embedding_dim: int = 256
vqvae_enc_hidden_dim: int = 128
# VQ-BeT
block_size: int = 50
output_dim: int = 256
n_layer: int = 6
n_head: int = 6
n_embd: int = 120
gpt_block_size: int = 500
gpt_input_dim: int = 512
gpt_output_dim: int = 512
gpt_n_layer: int = 8
gpt_n_head: int = 8
gpt_n_embed: int = 512
dropout: float = 0.1
mlp_hidden_dim: int = 1024
offset_loss_weight: float = 10000.
secondary_code_multiplier: float = 0.5

def __post_init__(self):
"""Input validation (not exhaustive)."""
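The rename above splits the transformer settings from the VQ-VAE settings: the old `block_size` / `output_dim` / `n_layer` / `n_head` / `n_embd` fields become `gpt_*` fields so they can no longer be confused with the tokenizer's sizes. A minimal sketch of the resulting grouping, with field names taken from this hunk and everything else illustrative:

from dataclasses import dataclass

@dataclass
class VQBeTConfigSketch:
    # VQ-VAE action tokenizer
    vqvae_groups: int = 2
    vqvae_n_embed: int = 16
    vqvae_embedding_dim: int = 256
    vqvae_enc_hidden_dim: int = 128
    # GPT backbone (previously block_size / n_layer / n_head / n_embd)
    gpt_block_size: int = 500
    gpt_input_dim: int = 512
    gpt_output_dim: int = 512
    gpt_n_layer: int = 8
    gpt_n_head: int = 8
    gpt_n_embed: int = 512
    dropout: float = 0.1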
@@ -62,7 +62,7 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
# queues are populated during rollout of the policy, they contain the n latest observations and actions
self._obs_queues = None
self._queues = None
self.vqbet = VQBeTModel(config)

@@ -74,12 +74,11 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
"""
Clear observation and action queues. Should be called on `env.reset()`
"""
self._obs_queues = {
self._queues = {
"observation.image": deque(maxlen=self.config.n_obs_steps),
"observation.state": deque(maxlen=self.config.n_obs_steps),
"action": deque(maxlen=self.config.n_action_pred_chunk),
}
if self.config.n_action_pred_chunk is not None:
self._action_queue = deque([], maxlen=self.config.n_action_pred_chunk)

@torch.no_grad
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
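With the separate `_obs_queues` and `_action_queue` folded into a single `_queues` dict, rollout caches observations and the predicted action chunk in one place. A minimal sketch of that pattern, assuming deque-based caching as above (`predict_chunk` is a stand-in, not the repo's API):

from collections import deque
import torch

n_obs_steps, n_action_pred_chunk = 2, 5
queues = {
    "observation.image": deque(maxlen=n_obs_steps),
    "observation.state": deque(maxlen=n_obs_steps),
    "action": deque(maxlen=n_action_pred_chunk),
}

def select_action_sketch(batch, predict_chunk):
    # keep only the n_obs_steps latest observations
    for key in ("observation.image", "observation.state"):
        queues[key].append(batch[key])
    # when the cached chunk is exhausted, predict a new chunk of actions
    if len(queues["action"]) == 0:
        stacked = {k: torch.stack(list(queues[k]), dim=1) for k in queues if k != "action"}
        actions = predict_chunk(stacked)  # (batch, n_action_pred_chunk, action_dim)
        queues["action"].extend(actions.transpose(0, 1))  # one queue entry per future step
    return queues["action"].popleft()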
@@ -93,7 +92,7 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
self.eval()
batch = self.normalize_inputs(batch)
self._obs_queues = populate_queues(self._obs_queues, batch)
self._queues = populate_queues(self._queues, batch)
if not self.check_discretized():
self.vqbet._action_head._vqvae_model.discretized = True

@@ -101,16 +100,16 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
assert "observation.image" in batch
assert "observation.state" in batch
if len(self._action_queue) == 0:
if len(self._queues["action"]) == 0:
batch = {key: torch.stack(list(self._obs_queues[key]), dim=1) for key in batch}
batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
actions = self.vqbet(batch, rollout=True)[:, : self.config.n_action_pred_chunk]
actions = self.unnormalize_outputs({"action": actions})["action"]
self._action_queue.extend(actions.transpose(0, 1))
self._queues["action"].extend(actions.transpose(0, 1))
action = self._action_queue.popleft()
action = self._queues["action"].popleft()
return action

def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:

@@ -123,9 +122,7 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
_, loss = self.vqbet(batch, rollout=False)
return {"loss": loss['actor_loss'], 'equal_single_code_rate': loss['equal_single_code_rate'], 'equal_single_code_rate2': loss['equal_single_code_rate2'], "offset_loss_weight": loss['offset_loss_weight'], \
"action_diff": loss['action_diff'], "action_diff_tot": loss['action_diff_tot'], "action_diff_mean_res1": loss['action_diff_mean_res1'], "action_diff_mean_res2": loss['action_diff_mean_res2'], \
"action_diff_max": loss['action_diff_max']}
return loss

class VQBeTModel(nn.Module):

@@ -138,36 +135,19 @@ class VQBeTModel(nn.Module):
self.global_cond_dim = self.rgb_encoder.feature_dim
# action token and EOS token
self._action_token = nn.Parameter(torch.randn(1, 1, self.config.n_embd)) # Batch, Timestep, Data type, GPT input dim
self._eos_token = nn.Parameter(torch.randn(1, 1, self.config.n_embd))
self._action_token = nn.Parameter(torch.randn(1, 1, self.config.gpt_input_dim)) # Batch, Timestep, Data type, GPT input dim
self._eos_token = nn.Parameter(torch.randn(1, 1, self.config.gpt_input_dim))
self.state_projector = MLP(
config.output_shapes["action"][0], hidden_channels=[self.config.n_embd]
config.output_shapes["action"][0],
hidden_channels=[self.config.gpt_input_dim]
)
self.obs_projector = MLP(
self.global_cond_dim, hidden_channels=[self.config.n_embd]
self.global_cond_dim,
hidden_channels=[self.config.gpt_input_dim]
)
self._policy = GPT(
GPTConfig(
block_size=self.config.block_size,
input_dim=self.config.n_embd,
output_dim=self.config.output_dim,
n_layer=self.config.n_layer,
n_head=self.config.n_head,
n_embd=self.config.n_embd,
dropout=self.config.dropout,
)
)
self._action_head = VQBeTHead(
config.output_dim,
config.output_shapes["action"][0],
offset_loss_weight=config.offset_loss_weight,
hidden_size=config.mlp_hidden_dim,
vqvae_groups=config.vqvae_groups,
vqvae_n_embed=config.vqvae_n_embed,
vqvae_embedding_dim=config.vqvae_embedding_dim,
n_action_pred_chunk=config.n_action_pred_chunk
)
self._policy = GPT(config)
self._action_head = VQBeTHead(config)

def discretize(self, discretize_step, actions):
return self._action_head.discretize(discretize_step, actions)
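This hunk is the heart of the commit: `GPT` and `VQBeTHead` no longer take an unpacked kwargs list (or a separate `GPTConfig`), they read the shared `VQBeTConfig` directly. A hedged sketch of the pattern, with the class bodies reduced to the reads that matter here:

import torch.nn as nn

class GPTSketch(nn.Module):
    def __init__(self, config):
        super().__init__()
        # every size comes straight from the shared config object
        self.wte = nn.Linear(config.gpt_input_dim, config.gpt_n_embed)
        self.lm_head = nn.Linear(config.gpt_n_embed, config.gpt_output_dim, bias=False)

class VQBeTHeadSketch(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.input_size = config.gpt_output_dim
        self.output_size = config.output_shapes["action"][0]

# before: GPT(GPTConfig(block_size=..., n_embd=..., ...)) and VQBeTHead(input_size, output_size, ...)
# after:  self._policy = GPT(config); self._action_head = VQBeTHead(config)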
@@ -190,7 +170,7 @@ class VQBeTModel(nn.Module):
torch.unsqueeze(self.obs_projector(img_features), dim=2),
torch.unsqueeze(self.state_projector(batch["observation.state"]), dim=2),
self._action_token.repeat(batch_size, n_obs_steps, 1, 1)
], dim=-2).view(batch_size, -1, self.config.n_embd)
], dim=-2).view(batch_size, -1, self.config.gpt_input_dim)
if img_features.shape[1] != n_obs_steps:
raise NotImplementedError
# eos_token = self._eos_token.repeat(batch_size, 1, 1) # TODO remove EOS token

@@ -231,57 +211,40 @@ class VQBeTModel(nn.Module):
action,
reduction="mean",
)
return pred_action, loss[0] if isinstance(loss, tuple) else loss
return pred_action, loss

class VQBeTHead(nn.Module):
def __init__(
self,
# network_kwargs
input_size,
output_size,
hidden_size=1024,
# loss_kwargs
offset_loss_weight=100.0,
secondary_code_multiplier=0.5,
vqvae_groups=2, # G(number of groups)
vqvae_n_embed=16, # C(number of code integers)
vqvae_embedding_dim=512, # D(embedding dims)
n_action_pred_chunk=1, # action chunk size
):
super().__init__()
self.input_size = input_size
self.output_size = output_size
self.hidden_size = hidden_size
self.offset_loss_weight = offset_loss_weight
self.secondary_code_multiplier = secondary_code_multiplier
def __init__(self, config: VQBeTConfig):
"""
TODO: add explanation for each value.
"""
self._G = vqvae_groups # G(number of groups)
self._C = vqvae_n_embed # C(number of code integers)
self._D = vqvae_embedding_dim # D(embedding dims)
self.n_action_pred_chunk = n_action_pred_chunk # action chunk size
super().__init__()
self.input_size = config.gpt_output_dim
self.output_size = config.output_shapes["action"][0]
self.hidden_size = config.mlp_hidden_dim
self.offset_loss_weight = config.offset_loss_weight
self.secondary_code_multiplier = config.secondary_code_multiplier
self.vqvae_groups = config.vqvae_groups
self.vqvae_n_embed = config.vqvae_n_embed # C(number of code integers)
self.vqvae_embedding_dim = config.vqvae_embedding_dim # D(embedding dims)
self.n_action_pred_chunk = config.n_action_pred_chunk # action chunk size
self._map_to_cbet_preds_bin = MLP(
in_channels=self.input_size,
hidden_channels=[self._G * self._C],
hidden_channels=[self.vqvae_groups * self.vqvae_n_embed],
)
self._map_to_cbet_preds_offset = MLP(
in_channels=self.input_size,
hidden_channels=[
self._G * self._C * n_action_pred_chunk * self.output_size,
self.vqvae_groups * self.vqvae_n_embed * config.n_action_pred_chunk * self.output_size,
],
)
# init vqvae
vqvae_config = {
"action_chunk": self.n_action_pred_chunk,
"action_dim": self.output_size,
"vqvae_n_latent_dims": self._D,
"vqvae_n_embed": self._C,
"vqvae_groups": self._G,
"device": get_device_from_parameters(self),
}
self._vqvae_model = init_vqvae(vqvae_config)
self._vqvae_model = VqVae(config)
# loss
self._criterion = FocalLoss(gamma=2.0)
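The two prediction heads keep their original sizing, now expressed through the config: the bin head emits `vqvae_groups * vqvae_n_embed` logits (one categorical distribution per residual group) and the offset head emits one offset vector per (group, code, chunk step, action dim). A quick shape check with illustrative numbers:

# illustrative sizes, not the repo defaults
G, C, chunk, action_dim = 2, 16, 5, 2

bin_head_out = G * C                          # 32 logits -> reshaped to (NT, G, C)
offset_head_out = G * C * chunk * action_dim  # 320 values -> reshaped to (NT, G, C, chunk * action_dim)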
@@ -307,10 +270,10 @@ class VQBeTHead(nn.Module):
cbet_logits = self._map_to_cbet_preds_bin(x)
cbet_offsets = self._map_to_cbet_preds_offset(x)
cbet_logits = einops.rearrange(
cbet_logits, "(NT) (G C) -> (NT) G C", G=self._G
cbet_logits, "(NT) (G C) -> (NT) G C", G=self.vqvae_groups
)
cbet_offsets = einops.rearrange(
cbet_offsets, "(NT) (G C WA) -> (NT) G C WA", G=self._G, C=self._C
cbet_offsets, "(NT) (G C WA) -> (NT) G C WA", G=self.vqvae_groups, C=self.vqvae_n_embed
)
cbet_probs = torch.softmax(cbet_logits, dim=-1)
NT, G, choices = cbet_probs.shape

@@ -322,7 +285,7 @@ class VQBeTHead(nn.Module):
indices = (
torch.arange(NT).unsqueeze(1).cuda(),
torch.arange(self._G).unsqueeze(0).cuda(),
torch.arange(self.vqvae_groups).unsqueeze(0).cuda(),
sampled_centers,
)
# Use advanced indexing to sample the values

@@ -330,7 +293,7 @@ class VQBeTHead(nn.Module):
sampled_offsets = sampled_offsets.sum(dim=1)
centers = self._vqvae_model.draw_code_forward(sampled_centers).view(
NT, -1, self._D
NT, -1, self.vqvae_embedding_dim
)
return_decoder_input = einops.rearrange(
centers.clone().detach(), "NT 1 D -> NT D"

@@ -341,7 +304,7 @@ class VQBeTHead(nn.Module):
.detach()
) # NT, A
sampled_offsets = einops.rearrange(
sampled_offsets, "NT (W A) -> NT W A", W=self._vqvae_model.input_dim_h
sampled_offsets, "NT (W A) -> NT W A", W=self.n_action_pred_chunk
)
predicted_action = decoded_action + sampled_offsets
predicted_action = einops.rearrange(

@@ -349,7 +312,7 @@ class VQBeTHead(nn.Module):
"(N T) W A -> N T (W A)",
N=N,
T=T,
W=self._vqvae_model.input_dim_h,
W=self.n_action_pred_chunk,
)
return {

@@ -357,10 +320,6 @@ class VQBeTHead(nn.Module):
"predicted_action": predicted_action,
"sampled_centers": sampled_centers,
"decoded_action": decoded_action,
"G": G,
"NT": NT,
"N": N,
"T": T,
}

def loss_fn(self, pred, target, **kwargs):

@@ -369,11 +328,12 @@ class VQBeTHead(nn.Module):
predicted_action = pred["predicted_action"]
sampled_centers = pred["sampled_centers"]
decoded_action = pred["decoded_action"]
G, NT, N, T = pred["G"], pred["NT"], pred["N"], pred["T"]
NT = predicted_action.shape[0]
T = predicted_action.shape[1]
cbet_logits = pred["cbet_logits"]
predicted_action = einops.rearrange(
predicted_action, "N T (W A) -> (N T) W A", W=self._vqvae_model.input_dim_h
predicted_action, "N T (W A) -> (N T) W A", W=self.n_action_pred_chunk
)
action_seq = einops.rearrange(action_seq, "N T W A -> (N T) W A")

@@ -400,74 +360,48 @@ class VQBeTHead(nn.Module):
)
cbet_loss = cbet_loss1 * 5 + cbet_loss2 * self.secondary_code_multiplier
equal_total_code_rate = (
torch.sum(
(torch.sum((action_bins == sampled_centers).int(), axis=1) == G).int()
)
/ NT
)
equal_single_code_rate = torch.sum(
equal_primary_code_rate = torch.sum(
(action_bins[:, 0] == sampled_centers[:, 0]).int()
) / (NT)
equal_single_code_rate2 = torch.sum(
equal_secondary_code_rate = torch.sum(
(action_bins[:, 1] == sampled_centers[:, 1]).int()
) / (NT)
action_diff = F.mse_loss(
einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T)[:, 4, :, :],
einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=T)[
:, 4, :, :
],
) # batch, time, windowsize (t ... t+N), action dim -> [:, -1, 0, :] is for rollout
action_diff_tot = F.mse_loss(
einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T)[:, :, :, :],
einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=T)[
:, :, :, :
],
) # batch, time, windowsize (t ... t+N), action dim -> [:, -1, 0, :] is for rollout
action_diff_mean_res1 = (
action_mse_error = F.mse_loss(
einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T),
einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=T),
)
vq_action_error = (
abs(
einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T)[
:, 4, :, :
]
- einops.rearrange(decoded_action, "(N T) W A -> N T W A", T=T)[
:, 4, :, :
]
einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T) - einops.rearrange(decoded_action, "(N T) W A -> N T W A", T=T)
)
).mean()
action_diff_mean_res2 = (
offset_action_error = (
abs(
einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T)[
:, 4, :, :
]
- einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=T)[
:, 4, :, :
]
einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T) - einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=T)
)
).mean()
action_diff_max = (
action_error_max = (
abs(
einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T)[
:, 4, :, :
]
- einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=T)[
:, 4, :, :
]
einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T) - einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=T)
)
).max()
loss = cbet_loss + self.offset_loss_weight * offset_loss
loss_dict = {
"loss": loss,
"classification_loss": cbet_loss.detach().cpu().item(),
"offset_loss": offset_loss.detach().cpu().item(),
"total_loss": loss.detach().cpu().item(),
"equal_total_code_rate": equal_total_code_rate,
"equal_single_code_rate": equal_single_code_rate,
"equal_single_code_rate2": equal_single_code_rate2,
"equal_primary_code_rate": equal_primary_code_rate.detach().cpu().item(),
"equal_secondary_code_rate": equal_secondary_code_rate.detach().cpu().item(),
"vq_action_error": vq_action_error.detach().cpu().item(),
"offset_action_error": offset_action_error.detach().cpu().item(),
"action_error_max": action_error_max.detach().cpu().item(),
"action_mse_error": action_mse_error.detach().cpu().item(),
}
return {"actor_loss": loss, "equal_single_code_rate": equal_single_code_rate, "equal_single_code_rate2": equal_single_code_rate2, "offset_loss_weight": self.offset_loss_weight, \
"action_diff": action_diff, "action_diff_tot": action_diff_tot, "action_diff_mean_res1": action_diff_mean_res1, "action_diff_mean_res2": action_diff_mean_res2, \
"action_diff_max": action_diff_max}, loss_dict
return loss_dict

class VQBeTOptimizer:
def __init__(self, policy, cfg):
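The renamed metrics in the loss_fn hunk above compute the same quantities under names that describe what they measure, and everything except `loss` is detached to a Python float so it can be logged directly. A minimal sketch of how the returned dict is composed, assuming `cbet_loss` and `offset_loss` are already available:

import torch

# stand-ins for the real focal loss and offset loss
cbet_loss = torch.tensor(1.3, requires_grad=True)
offset_loss = torch.tensor(2e-4, requires_grad=True)
offset_loss_weight = 10000.0

loss = cbet_loss + offset_loss_weight * offset_loss
loss_dict = {
    "loss": loss,  # kept as a tensor so the trainer can call loss.backward()
    "classification_loss": cbet_loss.detach().cpu().item(),
    "offset_loss": offset_loss.detach().cpu().item(),
    "total_loss": loss.detach().cpu().item(),
}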
@@ -688,71 +622,43 @@ def _replace_submodules(
class VqVae(nn.Module):
def __init__(
self,
input_dim_h=10, # length of action chunk
input_dim_w=9, # action dim
n_latent_dims=512,
vqvae_n_embed=32,
vqvae_groups=4,
eval=True,
load_dir=None,
encoder_loss_multiplier=1.0,
act_scale=1.0,
self, config: VQBeTConfig,
):
super(VqVae, self).__init__()
self.n_latent_dims = n_latent_dims
self.input_dim_h = input_dim_h
self.input_dim_w = input_dim_w
self.rep_dim = self.n_latent_dims
self.vqvae_n_embed = vqvae_n_embed
self.vqvae_lr = 1e-3
self.vqvae_groups = vqvae_groups
self.encoder_loss_multiplier = encoder_loss_multiplier
self.act_scale = act_scale
self.n_action_pred_chunk = config.n_action_pred_chunk
self.action_dim = config.output_shapes["action"][0]
self.discretized = False
self.optimized_steps = 0
discrete_cfg = {"groups": self.vqvae_groups, "n_embed": self.vqvae_n_embed}
self.vq_layer = ResidualVQ(
dim=self.n_latent_dims,
num_quantizers=discrete_cfg["groups"],
codebook_size=self.vqvae_n_embed,
dim=config.vqvae_embedding_dim,
num_quantizers=config.vqvae_groups,
codebook_size=config.vqvae_n_embed,
)
self.embedding_dim = self.n_latent_dims
self.embedding_dim = config.vqvae_embedding_dim
if self.input_dim_h == 1:
if self.n_action_pred_chunk == 1:
self.encoder = MLP(
in_channels=input_dim_w,
hidden_channels=[128, 128, n_latent_dims],
in_channels=self.action_dim,
hidden_channels=[config.vqvae_enc_hidden_dim, config.vqvae_enc_hidden_dim, config.vqvae_embedding_dim],
)
self.decoder = MLP(
in_channels=n_latent_dims,
hidden_channels=[128, 128, input_dim_w],
in_channels=config.vqvae_embedding_dim,
hidden_channels=[config.vqvae_enc_hidden_dim, config.vqvae_enc_hidden_dim, self.action_dim],
)
else:
self.encoder = MLP(
in_channels=input_dim_w * self.input_dim_h,
hidden_channels=[128, 128, n_latent_dims],
in_channels=self.action_dim * self.n_action_pred_chunk,
hidden_channels=[config.vqvae_enc_hidden_dim, config.vqvae_enc_hidden_dim, config.vqvae_embedding_dim],
)
self.decoder = MLP(
in_channels=n_latent_dims,
hidden_channels=[128, 128, input_dim_w * self.input_dim_h],
in_channels=config.vqvae_embedding_dim,
hidden_channels=[config.vqvae_enc_hidden_dim, config.vqvae_enc_hidden_dim, self.action_dim * self.n_action_pred_chunk],
)
if load_dir is not None:
try:
state_dict = torch.load(load_dir)
except RuntimeError:
state_dict = torch.load(load_dir, map_location=torch.device("cpu"))
self.load_state_dict(state_dict)
if eval:
self.eval()
else:
self.train()
self.train()

def eval(self):
self.training = False
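`VqVae` now also builds itself from the shared config, and its residual quantizer is constructed with the embedding dim, number of groups, and codebook size shown above. A small usage sketch, assuming `ResidualVQ` is the class from the vector_quantize_pytorch package (as the constructor arguments suggest) and a (batch, seq, dim) input layout:

import torch
from vector_quantize_pytorch import ResidualVQ  # assumed source of ResidualVQ

vq_layer = ResidualVQ(dim=256, num_quantizers=2, codebook_size=16)

latents = torch.randn(8, 1, 256)                   # one encoded action chunk per row
quantized, codes, commit_loss = vq_layer(latents)  # codes hold one index per quantizer group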
@@ -783,23 +689,22 @@ class VqVae(nn.Module):
return z_embed

def get_action_from_latent(self, latent):
output = self.decoder(latent) * self.act_scale
if self.input_dim_h == 1:
return einops.rearrange(output, "N (T A) -> N T A", A=self.input_dim_w)
output = self.decoder(latent)
if self.n_action_pred_chunk == 1:
return einops.rearrange(output, "N (T A) -> N T A", A=self.action_dim)
else:
return einops.rearrange(output, "N (T A) -> N T A", A=self.input_dim_w)
return einops.rearrange(output, "N (T A) -> N T A", A=self.action_dim)

def preprocess(self, state):
if not torch.is_tensor(state):
state = torch.FloatTensor(state.copy())
if self.input_dim_h == 1:
if self.n_action_pred_chunk == 1:
state = state.squeeze(-2) # state.squeeze(-1)
else:
state = einops.rearrange(state, "N T A -> N (T A)")
return state

def get_code(self, state, required_recon=False):
state = state / self.act_scale
state = self.preprocess(state)
with torch.no_grad():
state_rep = self.encoder(state)

@@ -810,9 +715,9 @@ class VqVae(nn.Module):
vq_code = vq_code.view(*state_rep_shape, -1)
vq_loss_state = torch.sum(vq_loss_state)
if required_recon:
recon_state = self.decoder(state_vq) * self.act_scale
recon_state_ae = self.decoder(state_rep) * self.act_scale
if self.input_dim_h == 1:
recon_state = self.decoder(state_vq)
recon_state_ae = self.decoder(state_rep)
if self.n_action_pred_chunk == 1:
return state_vq, vq_code, recon_state, recon_state_ae
else:
return (

@@ -826,7 +731,6 @@ class VqVae(nn.Module):
return state_vq, vq_code

def vqvae_forward(self, state):
state = state / self.act_scale
state = self.preprocess(state)
state_rep = self.encoder(state)
state_rep_shape = state_rep.shape[:-1]

@@ -839,7 +743,7 @@ class VqVae(nn.Module):
dec_out = self.decoder(state_vq)
encoder_loss = (state - dec_out).abs().mean()
rep_loss = encoder_loss * self.encoder_loss_multiplier + (vq_loss_state * 5)
rep_loss = encoder_loss * vq_loss_state * 5
metric = (
encoder_loss.clone().detach(),

@@ -857,28 +761,16 @@ class VqVae(nn.Module):
def init_vqvae(config):
# model
vqvae_model = VqVae(
input_dim_h=config["action_chunk"],
input_dim_w=config["action_dim"],
n_latent_dims=config["vqvae_n_latent_dims"],
vqvae_n_embed=config["vqvae_n_embed"],
vqvae_groups=config["vqvae_groups"],
eval=False,
)
return vqvae_model

def pretrain_vqvae(vqvae_model, discretize_step, actions):
if vqvae_model.input_dim_h == 1:
if vqvae_model.n_action_pred_chunk == 1:
# not using action chunk
actions = actions.reshape(-1, 1, actions.shape[-1])
else:
# using action chunk
slices = []
slices.extend([actions[:, j:j+vqvae_model.input_dim_h, :] for j in range(actions.shape[1]+1-vqvae_model.input_dim_h)])
slices.extend([actions[:, j:j+vqvae_model.n_action_pred_chunk, :] for j in range(actions.shape[1]+1-vqvae_model.n_action_pred_chunk)])
actions = torch.cat(slices, dim=0)
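The chunked branch of `pretrain_vqvae` slides a window of `n_action_pred_chunk` steps over each trajectory and stacks every window into the pretraining batch. A tiny worked example of that slicing with illustrative sizes:

import torch

actions = torch.arange(2 * 6 * 3, dtype=torch.float32).view(2, 6, 3)  # (N=2, T=6, A=3)
chunk = 5

slices = [actions[:, j:j + chunk, :] for j in range(actions.shape[1] + 1 - chunk)]
chunks = torch.cat(slices, dim=0)
print(chunks.shape)  # torch.Size([4, 5, 3]): 2 trajectories x 2 overlapping windows each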
@@ -2166,40 +2058,40 @@ class MLP(torch.nn.Sequential):
class CausalSelfAttention(nn.Module):
def __init__(self, config):
super().__init__()
assert config.n_embd % config.n_head == 0
assert config.gpt_n_embed % config.gpt_n_head == 0
# key, query, value projections for all heads, but in a batch
self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
self.c_attn = nn.Linear(config.gpt_n_embed, 3 * config.gpt_n_embed)
# output projection
self.c_proj = nn.Linear(config.n_embd, config.n_embd)
self.c_proj = nn.Linear(config.gpt_n_embed, config.gpt_n_embed)
# regularization
self.attn_dropout = nn.Dropout(config.dropout)
self.resid_dropout = nn.Dropout(config.dropout)
# causal mask to ensure that attention is only applied to the left in the input sequence
self.register_buffer(
"bias",
torch.tril(torch.ones(config.block_size, config.block_size)).view(
1, 1, config.block_size, config.block_size
torch.tril(torch.ones(config.gpt_block_size, config.gpt_block_size)).view(
1, 1, config.gpt_block_size, config.gpt_block_size
),
)
self.n_head = config.n_head
self.n_embd = config.n_embd
self.gpt_n_head = config.gpt_n_head
self.gpt_n_embed = config.gpt_n_embed

def forward(self, x):
(
B,
T,
C,
) = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
) = x.size() # batch size, sequence length, embedding dimensionality (gpt_n_embed)
# calculate query, key, values for all heads in batch and move head forward to be the batch dim
q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
k = k.view(B, T, self.n_head, C // self.n_head).transpose(
q, k, v = self.c_attn(x).split(self.gpt_n_embed, dim=2)
k = k.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(
1, 2
) # (B, nh, T, hs)
q = q.view(B, T, self.n_head, C // self.n_head).transpose(
q = q.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(
1, 2
) # (B, nh, T, hs)
v = v.view(B, T, self.n_head, C // self.n_head).transpose(
v = v.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(
1, 2
) # (B, nh, T, hs)

@@ -2222,13 +2114,13 @@ class CausalSelfAttention(nn.Module):
class Block(nn.Module):
def __init__(self, config):
super().__init__()
self.ln_1 = nn.LayerNorm(config.n_embd)
self.ln_1 = nn.LayerNorm(config.gpt_n_embed)
self.attn = CausalSelfAttention(config)
self.ln_2 = nn.LayerNorm(config.n_embd)
self.ln_2 = nn.LayerNorm(config.gpt_n_embed)
self.mlp = nn.Sequential(
nn.Linear(config.n_embd, 4 * config.n_embd),
nn.Linear(config.gpt_n_embed, 4 * config.gpt_n_embed),
nn.GELU(),
nn.Linear(4 * config.n_embd, config.n_embd),
nn.Linear(4 * config.gpt_n_embed, config.gpt_n_embed),
nn.Dropout(config.dropout)
)

@@ -2238,15 +2130,6 @@ class Block(nn.Module):
return x

@dataclass
class GPTConfig:
block_size: int = 1024
input_dim: int = 256
output_dim: int = 256
n_layer: int = 12
n_head: int = 12
n_embd: int = 768
dropout: float = 0.1

class GPT(nn.Module):

@@ -2287,29 +2170,28 @@ class GPT(nn.Module):
"""

def __init__(self, config):
def __init__(self, config: VQBeTConfig):
super().__init__()
assert config.input_dim is not None
assert config.output_dim is not None
assert config.block_size is not None
assert config.gpt_output_dim is not None
assert config.gpt_block_size is not None
self.config = config
self.transformer = nn.ModuleDict(
dict(
wte=nn.Linear(config.input_dim, config.n_embd),
wpe=nn.Embedding(config.block_size, config.n_embd),
wte=nn.Linear(config.gpt_input_dim, config.gpt_n_embed),
wpe=nn.Embedding(config.gpt_block_size, config.gpt_n_embed),
drop=nn.Dropout(config.dropout),
h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
ln_f=nn.LayerNorm(config.n_embd),
h=nn.ModuleList([Block(config) for _ in range(config.gpt_n_layer)]),
ln_f=nn.LayerNorm(config.gpt_n_embed),
)
)
self.lm_head = nn.Linear(config.n_embd, config.output_dim, bias=False)
self.lm_head = nn.Linear(config.gpt_n_embed, config.gpt_output_dim, bias=False)
# init all weights, and apply a special scaled init to the residual projections, per GPT-2 paper
self.apply(self._init_weights)
for pn, p in self.named_parameters():
if pn.endswith("c_proj.weight"):
torch.nn.init.normal_(
p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer)
p, mean=0.0, std=0.02 / math.sqrt(2 * config.gpt_n_layer)
)
# report number of parameters
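With the `gpt_*` names the GPT trunk is sized directly by the policy config; the values below are the ones set in the yaml hunk further down, listed here only as a reading aid:

# sizes from the yaml config below (illustrative recap, not new defaults)
gpt_input_dim, gpt_n_embed, gpt_output_dim = 512, 512, 512
gpt_block_size, gpt_n_layer, gpt_n_head = 500, 8, 8

# wte: Linear(gpt_input_dim -> gpt_n_embed), wpe: Embedding(gpt_block_size, gpt_n_embed)
# h:   gpt_n_layer Blocks, each with gpt_n_head heads; lm_head: Linear(gpt_n_embed -> gpt_output_dim)
assert gpt_n_embed % gpt_n_head == 0  # required by CausalSelfAttention above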
@@ -2320,8 +2202,8 @@ class GPT(nn.Module):
device = input.device
b, t, d = input.size()
assert (
t <= self.config.block_size
), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
t <= self.config.gpt_block_size
), f"Cannot forward sequence of length {t}, block size is only {self.config.gpt_block_size}"
pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(
0
) # shape (1, t)

@@ -2329,10 +2211,10 @@ class GPT(nn.Module):
# forward the GPT model itself
tok_emb = self.transformer.wte(
input
) # token embeddings of shape (b, t, n_embd)
) # token embeddings of shape (b, t, gpt_n_embed)
pos_emb = self.transformer.wpe(
pos
) # position embeddings of shape (1, t, n_embd)
) # position embeddings of shape (1, t, gpt_n_embed)
x = self.transformer.drop(tok_emb + pos_emb)
for block in self.transformer.h:
x = block(x)

@@ -2351,14 +2233,14 @@ class GPT(nn.Module):
torch.nn.init.zeros_(module.bias)
torch.nn.init.ones_(module.weight)

def crop_block_size(self, block_size):
assert block_size <= self.config.block_size
self.config.block_size = block_size
def crop_block_size(self, gpt_block_size):
assert gpt_block_size <= self.config.gpt_block_size
self.config.gpt_block_size = gpt_block_size
self.transformer.wpe.weight = nn.Parameter(
self.transformer.wpe.weight[:block_size]
self.transformer.wpe.weight[:gpt_block_size]
)
for block in self.transformer.h:
block.attn.bias = block.attn.bias[:, :, :block_size, :block_size]
block.attn.bias = block.attn.bias[:, :, :gpt_block_size, :gpt_block_size]

def configure_optimizers(self, weight_decay, learning_rate, betas, optimizer="Adamw", eps=None):
"""

@@ -22,7 +22,7 @@ override_dataset_stats:
max: [511.0, 511.0]
training:
offline_steps: 800000
offline_steps: 250000
online_steps: 0
eval_freq: 20000
save_freq: 20000

@@ -90,12 +90,15 @@ policy:
vqvae_groups: 2
vqvae_n_embed: 16
vqvae_embedding_dim: 256
vqvae_enc_hidden_dim: 128
# VQ-BeT
block_size: 500
output_dim: 512
n_layer: 8 # 8
n_head: 8 # 4
n_embd: 512
gpt_block_size: 500
gpt_input_dim: 512
gpt_output_dim: 512
gpt_n_layer: 8
gpt_n_head: 8
gpt_n_embed: 512
dropout: 0.1
mlp_hidden_dim: 1024 # 512
offset_loss_weight: 10000.
mlp_hidden_dim: 1024
offset_loss_weight: 10000.
secondary_code_multiplier: 0.5