# @package _global_

# TD-MPC policy configuration (the policy section below sets `name: tdmpc`).
# NOTE(review): Hydra stops reading `# @...` header directives at the first
# non-header line, so the `@package` line above must stay the first line of
# the file — do not put any other comment or content before it.

# Top-level run settings.
seed: 1
# NOTE(review): looks like a Hugging Face Hub dataset repo id — confirm.
dataset_repo_id: lerobot/xarm_lift_medium
training:
  offline_steps: 25000
  # TODO(alexander-soare): set back to 25000 when online training gets
  # reinstated.
  online_steps: 0  # 25000 not implemented yet
  eval_freq: 5000
  online_steps_between_rollouts: 1
  online_sampling_ratio: 0.5
  online_env_seed: 10000
  batch_size: 256
  grad_clip_norm: 10.0
  # Written as `3.0e-4` rather than `3e-4`: YAML 1.1 loaders (e.g. PyYAML)
  # only resolve exponent floats that contain a decimal point and would
  # otherwise load this value as the string "3e-4".
  lr: 3.0e-4

  # Frame offsets requested for each training sample, per data key. Each
  # value is a string holding a Python list-comprehension expression with
  # `${...}` interpolations: observations span `policy.horizon + 1` steps,
  # actions and rewards span `policy.horizon` steps.
  # NOTE(review): presumably evaluated by the dataset loader into offsets in
  # seconds relative to the current frame (`i / fps`) — confirm.
  delta_timestamps:
    observation.image: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
    observation.state: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
    action: "[i / ${fps} for i in range(${policy.horizon})]"
    next.reward: "[i / ${fps} for i in range(${policy.horizon})]"
policy:
  name: tdmpc

  # Explicitly null by default (previously a bare `key:`, which also loads as
  # null).
  # NOTE(review): presumably a path to pretrained weights to load — confirm
  # against the policy factory.
  pretrained_model_path: null

  # Input / output structure.
  n_action_repeats: 2
  # Also sizes the sequences requested by `training.delta_timestamps`, which
  # interpolate `${policy.horizon}`.
  horizon: 5

  input_shapes:
    # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
    observation.image: [3, 84, 84]
    # Interpolated from the `env` config. Quoted because `{` and `}` cannot
    # appear in a plain scalar inside a flow sequence.
    observation.state: ["${env.state_dim}"]
  output_shapes:
    action: ["${env.action_dim}"]

  # Normalization / Unnormalization
  input_normalization_modes: null
  output_normalization_modes:
    action: min_max

  # Architecture / modeling.
  # Neural networks.
  image_encoder_hidden_dim: 32
  state_encoder_hidden_dim: 256
  latent_dim: 50
  q_ensemble_size: 5
  mlp_dim: 512
  # Reinforcement learning.
  discount: 0.9

  # Inference.
  use_mpc: true
  cem_iterations: 6
  max_std: 2.0
  min_std: 0.05
  n_gaussian_samples: 512
  n_pi_samples: 51
  uncertainty_regularizer_coeff: 1.0
  n_elites: 50
  elite_weighting_temperature: 0.5
  gaussian_mean_momentum: 0.1

  # Training and loss computation.
  max_random_shift_ratio: 0.0476
  # Loss coefficients.
  reward_coeff: 0.5
  expectile_weight: 0.9
  value_coeff: 0.1
  consistency_coeff: 20.0
  advantage_scaling: 3.0
  pi_coeff: 0.5
  temporal_decay_coeff: 0.5
  # Target model.
  target_model_momentum: 0.995