clean code

This commit is contained in:
jayLEE0301 2024-05-22 17:50:37 -04:00
parent 311f79a874
commit 1d689512af
3 changed files with 159 additions and 957 deletions

View File

@ -117,6 +117,7 @@ class VQBeTConfig:
n_embd: int = 120
dropout: float = 0.1
mlp_hidden_dim: int = 1024
offset_loss_weight: float = 10000.
def __post_init__(self):
"""Input validation (not exhaustive)."""

File diff suppressed because it is too large Load Diff

View File

@ -52,7 +52,7 @@ training:
action: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, ${policy.n_action_pred_token} + ${policy.n_action_pred_chunk} - 1)]"
eval:
n_episodes: 50
n_episodes: 500
batch_size: 50
policy:
@ -61,8 +61,8 @@ policy:
# Input / output structure.
n_obs_steps: 5
# n_action_steps: 7 # n_action_pred_token + n_action_pred_window - 1
n_action_pred_token: 3 # 3 5
n_action_pred_chunk: 5 # 5 1 jay temp2
n_action_pred_token: 7
n_action_pred_chunk: 5
input_shapes:
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
@ -93,9 +93,10 @@ policy:
vqvae_embedding_dim: 256
# VQ-BeT
block_size: 500
output_dim: 256 # 512
n_layer: 6 # 8
n_head: 6 # 4
n_embd: 120 # 512
output_dim: 512
n_layer: 8 # 8
n_head: 8 # 4
n_embd: 512
dropout: 0.1
mlp_hidden_dim: 1024 # 512
mlp_hidden_dim: 1024 # 512
offset_loss_weight: 10000.