Merge pull request #3 from jayLEE0301/debugging-vqbet

Debugging vqbet
Seungjae Lee, 2024-05-22 18:25:16 -04:00 (committed by GitHub)
commit f127837d48
3 changed files with 236 additions and 715 deletions

Changed file: Python policy configuration (class VQBeTConfig)

@@ -71,7 +71,8 @@ class VQBeTConfig:
     # Inputs / output structure.
     n_obs_steps: int = 5
     n_action_steps: int = 5
     n_action_pred_token: int = 3
+    n_action_pred_chunk: int = 5
     input_shapes: dict[str, list[int]] = field(
         default_factory=lambda: {
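The new n_action_pred_chunk field works together with the existing n_action_pred_token and with the delta_timestamps change further down in the YAML config. As a rough sketch of the assumed relationship (the modeling diff is suppressed, so the chunking scheme below is an assumption, not something stated in this PR): if each of the n_action_pred_token prediction tokens decodes a chunk of n_action_pred_chunk consecutive actions, and consecutive tokens are shifted by one step, the distinct predicted action steps number n_action_pred_token + n_action_pred_chunk - 1.

    # Assumed chunking scheme (illustration only, not taken from this PR):
    # token t covers action steps t .. t + n_action_pred_chunk - 1.
    n_action_pred_token = 3   # default in this file
    n_action_pred_chunk = 5   # new field

    covered = set()
    for t in range(n_action_pred_token):
        covered.update(range(t, t + n_action_pred_chunk))

    # Distinct action steps covered by all tokens:
    assert len(covered) == n_action_pred_token + n_action_pred_chunk - 1  # 7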
@@ -115,6 +116,7 @@ class VQBeTConfig:
     n_embd: int = 120
     dropout: float = 0.1
     mlp_hidden_dim: int = 1024
+    offset_loss_weight: float = 10000.
     def __post_init__(self):
         """Input validation (not exhaustive)."""

Changed file: (diff suppressed because it is too large)

Changed file: training/policy YAML configuration

@@ -22,10 +22,10 @@ override_dataset_stats:
     max: [511.0, 511.0]
 training:
-  offline_steps: 200000
+  offline_steps: 800000
   online_steps: 0
-  eval_freq: 10000
-  save_freq: 5000
+  eval_freq: 20000
+  save_freq: 20000
   log_freq: 250
   save_model: true
@@ -41,7 +41,7 @@ training:
   # VQ-BeT specific
   vqvae_lr: 1.0e-3
-  discretize_step: 3000
+  discretize_step: 20000
   bet_weight_decay: 2e-4
   bet_learning_rate: 5.5e-5
   bet_betas: [0.9, 0.999]
@@ -49,10 +49,10 @@ training:
   delta_timestamps:
     observation.image: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
     observation.state: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
-    action: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, ${policy.n_action_steps})]"
+    action: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, ${policy.n_action_pred_token} + ${policy.n_action_pred_chunk} - 1)]"
 eval:
-  n_episodes: 50
+  n_episodes: 500
   batch_size: 50
 policy:
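To make the delta_timestamps change concrete: with this config's values and assuming fps = 10 (the PushT default this file appears to target), the number of action frames sampled per training item grows from 9 to 15, and the prediction horizon extends from 0.4 s to 1.0 s past the current step.

    fps = 10                  # assumption: PushT runs at 10 fps
    n_obs_steps = 5
    n_action_steps = 5        # used by the old formula
    n_action_pred_token = 7   # used by the new formula
    n_action_pred_chunk = 5

    old = [i / fps for i in range(1 - n_obs_steps, n_action_steps)]
    new = [i / fps for i in range(1 - n_obs_steps, n_action_pred_token + n_action_pred_chunk - 1)]

    print(len(old), old[-1])  # 9 frames, last target 0.4 s ahead
    print(len(new), new[-1])  # 15 frames, last target 1.0 s ahead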
@@ -60,7 +60,8 @@ policy:
   # Input / output structure.
   n_obs_steps: 5
   n_action_steps: 5
   n_action_pred_token: 7
+  n_action_pred_chunk: 5
   input_shapes:
     # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
@@ -90,10 +91,11 @@ policy:
   vqvae_n_embed: 16
   vqvae_embedding_dim: 256
   # VQ-BeT
-  block_size: 50
-  output_dim: 256 # 512
-  n_layer: 6 # 8
-  n_head: 6 # 4
-  n_embd: 120 # 512
+  block_size: 500
+  output_dim: 512
+  n_layer: 8 # 8
+  n_head: 8 # 4
+  n_embd: 512
   dropout: 0.1
   mlp_hidden_dim: 1024 # 512
+  offset_loss_weight: 10000.
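These backbone values amount to a large scale-up of the GPT-style transformer: a 10x longer context window via block_size and a wider, deeper stack. A back-of-the-envelope parameter estimate, using the standard 12 * n_layer * n_embd^2 approximation (embeddings and output heads ignored; the exact architecture is in the suppressed modeling diff):

    def approx_transformer_params(n_layer: int, n_embd: int) -> int:
        # Rough estimate for a GPT block stack: 12 * n_layer * n_embd^2.
        return 12 * n_layer * n_embd ** 2

    old_params = approx_transformer_params(n_layer=6, n_embd=120)   # ~1.0M
    new_params = approx_transformer_params(n_layer=8, n_embd=512)   # ~25.2M
    print(f"{old_params / 1e6:.1f}M -> {new_params / 1e6:.1f}M parameters")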