clean code
This commit is contained in:
parent
311f79a874
commit
1d689512af
|
@ -117,6 +117,7 @@ class VQBeTConfig:
|
|||
n_embd: int = 120
|
||||
dropout: float = 0.1
|
||||
mlp_hidden_dim: int = 1024
|
||||
offset_loss_weight: float = 10000.
|
||||
|
||||
def __post_init__(self):
|
||||
"""Input validation (not exhaustive)."""
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -52,7 +52,7 @@ training:
|
|||
action: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, ${policy.n_action_pred_token} + ${policy.n_action_pred_chunk} - 1)]"
|
||||
|
||||
eval:
|
||||
n_episodes: 50
|
||||
n_episodes: 500
|
||||
batch_size: 50
|
||||
|
||||
policy:
|
||||
|
@ -61,8 +61,8 @@ policy:
|
|||
# Input / output structure.
|
||||
n_obs_steps: 5
|
||||
# n_action_steps: 7 # n_action_pred_token + n_action_pred_window - 1
|
||||
n_action_pred_token: 3 # 3 5
|
||||
n_action_pred_chunk: 5 # 5 1 jay temp2
|
||||
n_action_pred_token: 7
|
||||
n_action_pred_chunk: 5
|
||||
|
||||
input_shapes:
|
||||
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
|
||||
|
@ -93,9 +93,10 @@ policy:
|
|||
vqvae_embedding_dim: 256
|
||||
# VQ-BeT
|
||||
block_size: 500
|
||||
output_dim: 256 # 512
|
||||
n_layer: 6 # 8
|
||||
n_head: 6 # 4
|
||||
n_embd: 120 # 512
|
||||
output_dim: 512
|
||||
n_layer: 8 # 8
|
||||
n_head: 8 # 4
|
||||
n_embd: 512
|
||||
dropout: 0.1
|
||||
mlp_hidden_dim: 1024 # 512
|
||||
offset_loss_weight: 10000.
|
Loading…
Reference in New Issue