commit
f127837d48
|
@ -71,7 +71,8 @@ class VQBeTConfig:
|
||||||
|
|
||||||
# Inputs / output structure.
|
# Inputs / output structure.
|
||||||
n_obs_steps: int = 5
|
n_obs_steps: int = 5
|
||||||
n_action_steps: int = 5
|
n_action_pred_token: int = 3
|
||||||
|
n_action_pred_chunk: int = 5
|
||||||
|
|
||||||
input_shapes: dict[str, list[int]] = field(
|
input_shapes: dict[str, list[int]] = field(
|
||||||
default_factory=lambda: {
|
default_factory=lambda: {
|
||||||
|
@ -115,6 +116,7 @@ class VQBeTConfig:
|
||||||
n_embd: int = 120
|
n_embd: int = 120
|
||||||
dropout: float = 0.1
|
dropout: float = 0.1
|
||||||
mlp_hidden_dim: int = 1024
|
mlp_hidden_dim: int = 1024
|
||||||
|
offset_loss_weight: float = 10000.
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
"""Input validation (not exhaustive)."""
|
"""Input validation (not exhaustive)."""
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -22,10 +22,10 @@ override_dataset_stats:
|
||||||
max: [511.0, 511.0]
|
max: [511.0, 511.0]
|
||||||
|
|
||||||
training:
|
training:
|
||||||
offline_steps: 200000
|
offline_steps: 800000
|
||||||
online_steps: 0
|
online_steps: 0
|
||||||
eval_freq: 10000
|
eval_freq: 20000
|
||||||
save_freq: 5000
|
save_freq: 20000
|
||||||
log_freq: 250
|
log_freq: 250
|
||||||
save_model: true
|
save_model: true
|
||||||
|
|
||||||
|
@ -41,7 +41,7 @@ training:
|
||||||
|
|
||||||
# VQ-BeT specific
|
# VQ-BeT specific
|
||||||
vqvae_lr: 1.0e-3
|
vqvae_lr: 1.0e-3
|
||||||
discretize_step: 3000
|
discretize_step: 20000
|
||||||
bet_weight_decay: 2e-4
|
bet_weight_decay: 2e-4
|
||||||
bet_learning_rate: 5.5e-5
|
bet_learning_rate: 5.5e-5
|
||||||
bet_betas: [0.9, 0.999]
|
bet_betas: [0.9, 0.999]
|
||||||
|
@ -49,10 +49,10 @@ training:
|
||||||
delta_timestamps:
|
delta_timestamps:
|
||||||
observation.image: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
|
observation.image: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
|
||||||
observation.state: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
|
observation.state: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
|
||||||
action: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, ${policy.n_action_steps})]"
|
action: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, ${policy.n_action_pred_token} + ${policy.n_action_pred_chunk} - 1)]"
|
||||||
|
|
||||||
eval:
|
eval:
|
||||||
n_episodes: 50
|
n_episodes: 500
|
||||||
batch_size: 50
|
batch_size: 50
|
||||||
|
|
||||||
policy:
|
policy:
|
||||||
|
@ -60,7 +60,8 @@ policy:
|
||||||
|
|
||||||
# Input / output structure.
|
# Input / output structure.
|
||||||
n_obs_steps: 5
|
n_obs_steps: 5
|
||||||
n_action_steps: 5
|
n_action_pred_token: 7
|
||||||
|
n_action_pred_chunk: 5
|
||||||
|
|
||||||
input_shapes:
|
input_shapes:
|
||||||
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
|
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
|
||||||
|
@ -90,10 +91,11 @@ policy:
|
||||||
vqvae_n_embed: 16
|
vqvae_n_embed: 16
|
||||||
vqvae_embedding_dim: 256
|
vqvae_embedding_dim: 256
|
||||||
# VQ-BeT
|
# VQ-BeT
|
||||||
block_size: 50
|
block_size: 500
|
||||||
output_dim: 256 # 512
|
output_dim: 512
|
||||||
n_layer: 6 # 8
|
n_layer: 8 # 8
|
||||||
n_head: 6 # 4
|
n_head: 8 # 4
|
||||||
n_embd: 120 # 512
|
n_embd: 512
|
||||||
dropout: 0.1
|
dropout: 0.1
|
||||||
mlp_hidden_dim: 1024 # 512
|
mlp_hidden_dim: 1024 # 512
|
||||||
|
offset_loss_weight: 10000.
|
Loading…
Reference in New Issue