clean code

This commit is contained in:
jayLEE0301 2024-05-22 17:50:37 -04:00
parent 311f79a874
commit 1d689512af
3 changed files with 159 additions and 957 deletions

View File

@ -117,6 +117,7 @@ class VQBeTConfig:
n_embd: int = 120 n_embd: int = 120
dropout: float = 0.1 dropout: float = 0.1
mlp_hidden_dim: int = 1024 mlp_hidden_dim: int = 1024
offset_loss_weight: float = 10000.
def __post_init__(self): def __post_init__(self):
"""Input validation (not exhaustive).""" """Input validation (not exhaustive)."""

File diff suppressed because it is too large Load Diff

View File

@ -52,7 +52,7 @@ training:
action: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, ${policy.n_action_pred_token} + ${policy.n_action_pred_chunk} - 1)]" action: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, ${policy.n_action_pred_token} + ${policy.n_action_pred_chunk} - 1)]"
eval: eval:
n_episodes: 50 n_episodes: 500
batch_size: 50 batch_size: 50
policy: policy:
@ -61,8 +61,8 @@ policy:
# Input / output structure. # Input / output structure.
n_obs_steps: 5 n_obs_steps: 5
# n_action_steps: 7 # n_action_pred_token + n_action_pred_window - 1 # n_action_steps: 7 # n_action_pred_token + n_action_pred_window - 1
n_action_pred_token: 3 # 3 5 n_action_pred_token: 7
n_action_pred_chunk: 5 # 5 1 jay temp2 n_action_pred_chunk: 5
input_shapes: input_shapes:
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env? # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
@ -93,9 +93,10 @@ policy:
vqvae_embedding_dim: 256 vqvae_embedding_dim: 256
# VQ-BeT # VQ-BeT
block_size: 500 block_size: 500
output_dim: 256 # 512 output_dim: 512
n_layer: 6 # 8 n_layer: 8 # 8
n_head: 6 # 4 n_head: 8 # 4
n_embd: 120 # 512 n_embd: 512
dropout: 0.1 dropout: 0.1
mlp_hidden_dim: 1024 # 512 mlp_hidden_dim: 1024 # 512
offset_loss_weight: 10000.