commit
f127837d48
|
@ -71,7 +71,8 @@ class VQBeTConfig:
|
|||
|
||||
# Inputs / output structure.
|
||||
n_obs_steps: int = 5
|
||||
n_action_steps: int = 5
|
||||
n_action_pred_token: int = 3
|
||||
n_action_pred_chunk: int = 5
|
||||
|
||||
input_shapes: dict[str, list[int]] = field(
|
||||
default_factory=lambda: {
|
||||
|
@ -115,6 +116,7 @@ class VQBeTConfig:
|
|||
n_embd: int = 120
|
||||
dropout: float = 0.1
|
||||
mlp_hidden_dim: int = 1024
|
||||
offset_loss_weight: float = 10000.
|
||||
|
||||
def __post_init__(self):
|
||||
"""Input validation (not exhaustive)."""
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -22,10 +22,10 @@ override_dataset_stats:
|
|||
max: [511.0, 511.0]
|
||||
|
||||
training:
|
||||
offline_steps: 200000
|
||||
offline_steps: 800000
|
||||
online_steps: 0
|
||||
eval_freq: 10000
|
||||
save_freq: 5000
|
||||
eval_freq: 20000
|
||||
save_freq: 20000
|
||||
log_freq: 250
|
||||
save_model: true
|
||||
|
||||
|
@ -41,7 +41,7 @@ training:
|
|||
|
||||
# VQ-BeT specific
|
||||
vqvae_lr: 1.0e-3
|
||||
discretize_step: 3000
|
||||
discretize_step: 20000
|
||||
bet_weight_decay: 2e-4
|
||||
bet_learning_rate: 5.5e-5
|
||||
bet_betas: [0.9, 0.999]
|
||||
|
@ -49,10 +49,10 @@ training:
|
|||
delta_timestamps:
|
||||
observation.image: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
|
||||
observation.state: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
|
||||
action: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, ${policy.n_action_steps})]"
|
||||
action: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, ${policy.n_action_pred_token} + ${policy.n_action_pred_chunk} - 1)]"
|
||||
|
||||
eval:
|
||||
n_episodes: 50
|
||||
n_episodes: 500
|
||||
batch_size: 50
|
||||
|
||||
policy:
|
||||
|
@ -60,7 +60,8 @@ policy:
|
|||
|
||||
# Input / output structure.
|
||||
n_obs_steps: 5
|
||||
n_action_steps: 5
|
||||
n_action_pred_token: 7
|
||||
n_action_pred_chunk: 5
|
||||
|
||||
input_shapes:
|
||||
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
|
||||
|
@ -90,10 +91,11 @@ policy:
|
|||
vqvae_n_embed: 16
|
||||
vqvae_embedding_dim: 256
|
||||
# VQ-BeT
|
||||
block_size: 50
|
||||
output_dim: 256 # 512
|
||||
n_layer: 6 # 8
|
||||
n_head: 6 # 4
|
||||
n_embd: 120 # 512
|
||||
block_size: 500
|
||||
output_dim: 512
|
||||
n_layer: 8 # 8
|
||||
n_head: 8 # 4
|
||||
n_embd: 512
|
||||
dropout: 0.1
|
||||
mlp_hidden_dim: 1024 # 512
|
||||
offset_loss_weight: 10000.
|
Loading…
Reference in New Issue