Sanitize cfg.policy, Fix skip_frame pusht.yaml

Cadene 2024-02-25 11:09:02 +00:00
parent fc4b98544b
commit e765e26b0b
3 changed files with 78 additions and 72 deletions

View File

@@ -2,20 +2,20 @@ from lerobot.common.policies.tdmpc import TDMPC
 def make_policy(cfg):
-    if cfg.policy == "tdmpc":
-        policy = TDMPC(cfg)
+    if cfg.policy.name == "tdmpc":
+        policy = TDMPC(cfg.policy)
     else:
-        raise ValueError(cfg.policy)
+        raise ValueError(cfg.policy.name)
-    if cfg.pretrained_model_path:
+    if cfg.policy.pretrained_model_path:
         # TODO(rcadene): hack for old pretrained models from fowm
-        if cfg.policy == "tdmpc" and "fowm" in cfg.pretrained_model_path:
+        if cfg.policy.name == "tdmpc" and "fowm" in cfg.policy.pretrained_model_path:
             if "offline" in cfg.pretrained_model_path:
                 policy.step[0] = 25000
             elif "final" in cfg.pretrained_model_path:
                 policy.step[0] = 100000
             else:
                 raise NotImplementedError()
-        policy.load(cfg.pretrained_model_path)
+        policy.load(cfg.policy.pretrained_model_path)
     return policy
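
For context, the sanitized layout nests every policy option under a `policy` key, so the factory reads `cfg.policy.name` and `cfg.policy.pretrained_model_path` rather than top-level keys. A minimal sketch of that access pattern, assuming a Hydra/OmegaConf config shaped like the new yaml (only the keys used by make_policy are included; the values are the ones visible in this diff):

from omegaconf import OmegaConf

# Hypothetical, reduced config mirroring the new nested layout.
cfg = OmegaConf.create(
    {
        "policy": {
            "name": "tdmpc",
            "pretrained_model_path": "",  # empty string -> no checkpoint load
        }
    }
)

assert cfg.policy.name == "tdmpc"
assert not cfg.policy.pretrained_model_path  # falsy, so the load() branch is skipped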

View File

@@ -22,84 +22,90 @@ pixels_only: False
image_size: 84
fps: 15
reward_scale: 1.0
# xarm_lift
action_repeat: 2
episode_length: 25
modality: 'all'
action_repeat: 2 # TODO(rcadene): verify we use this
discount: 0.9
train_steps: 50000
# pixels
frame_stack: 1
num_channels: 32
img_size: ${image_size}
state_dim: 4
action_dim: 4
# TDMPC
policy: tdmpc
policy:
name: tdmpc
# planning
mpc: true
iterations: 6
num_samples: 512
num_elites: 50
mixture_coef: 0.1
min_std: 0.05
max_std: 2.0
temperature: 0.5
momentum: 0.1
uncertainty_cost: 1
reward_scale: 1.0
# actor
log_std_min: -10
log_std_max: 2
# xarm_lift
train_steps: ${train_steps}
episode_length: ${episode_length}
discount: 0.9
modality: 'all'
# learning
batch_size: 256
max_buffer_size: 10000
horizon: 5
reward_coef: 0.5
value_coef: 0.1
consistency_coef: 20
rho: 0.5
kappa: 0.1
lr: 3e-4
std_schedule: ${min_std}
horizon_schedule: ${horizon}
per: true
per_alpha: 0.6
per_beta: 0.4
grad_clip_norm: 10
seed_steps: 0
update_freq: 2
tau: 0.01
utd: 1
# pixels
frame_stack: 1
num_channels: 32
img_size: ${image_size}
state_dim: 4
action_dim: 4
# offline rl
# dataset_dir: ???
data_first_percent: 1.0
is_data_clip: true
data_clip_eps: 1e-5
expectile: 0.9
A_scaling: 3.0
# planning
mpc: true
iterations: 6
num_samples: 512
num_elites: 50
mixture_coef: 0.1
min_std: 0.05
max_std: 2.0
temperature: 0.5
momentum: 0.1
uncertainty_cost: 1
# offline->online
offline_steps: 25000 # ${train_steps}/2
pretrained_model_path: ""
# pretrained_model_path: "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/offline.pt"
# pretrained_model_path: "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/final.pt"
balanced_sampling: true
demo_schedule: 0.5
# actor
log_std_min: -10
log_std_max: 2
# architecture
enc_dim: 256
num_q: 5
mlp_dim: 512
latent_dim: 50
# learning
batch_size: 256
max_buffer_size: 10000
horizon: 5
reward_coef: 0.5
value_coef: 0.1
consistency_coef: 20
rho: 0.5
kappa: 0.1
lr: 3e-4
std_schedule: ${policy.min_std}
horizon_schedule: ${policy.horizon}
per: true
per_alpha: 0.6
per_beta: 0.4
grad_clip_norm: 10
seed_steps: 0
update_freq: 2
tau: 0.01
utd: 1
# offline rl
# dataset_dir: ???
data_first_percent: 1.0
is_data_clip: true
data_clip_eps: 1e-5
expectile: 0.9
A_scaling: 3.0
# offline->online
offline_steps: 25000 # ${train_steps}/2
pretrained_model_path: ""
# pretrained_model_path: "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/offline.pt"
# pretrained_model_path: "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/final.pt"
balanced_sampling: true
demo_schedule: 0.5
# architecture
enc_dim: 256
num_q: 5
mlp_dim: 512
latent_dim: 50
# wandb
use_wandb: true
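
One consequence of the nesting worth noting: interpolations that previously pointed at top-level keys (`${min_std}`, `${horizon}`) must now use the nested path (`${policy.min_std}`, `${policy.horizon}`), because OmegaConf resolves interpolations from the config root. A small sketch of the resolution, using a reduced, hypothetical excerpt of this file:

from omegaconf import OmegaConf

# Only the keys needed to demonstrate the interpolation are included.
cfg = OmegaConf.create(
    "policy:\n"
    "  min_std: 0.05\n"
    "  horizon: 5\n"
    "  std_schedule: ${policy.min_std}\n"
    "  horizon_schedule: ${policy.horizon}\n"
)

print(cfg.policy.std_schedule)      # 0.05
print(cfg.policy.horizon_schedule)  # 5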

View File

@@ -10,7 +10,7 @@ hydra:
 env: pusht
 task: pusht
 image_size: 96
-frame_skip: 1
+action_repeat: 1
 state_dim: 2
 action_dim: 2
 fps: 10
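
For reference, `action_repeat` is the conventional name for applying each policy action to the environment for that many consecutive steps, so `action_repeat: 1` keeps pusht at one environment step per action. A generic, hypothetical gym-style sketch of the idea (not lerobot's actual rollout code):

def step_with_repeat(env, action, action_repeat=1):
    # Apply the same action `action_repeat` times and sum the rewards.
    total_reward = 0.0
    for _ in range(action_repeat):
        obs, reward, terminated, truncated, info = env.step(action)
        total_reward += reward
        if terminated or truncated:
            break  # stop early if the episode ends mid-repeat
    return obs, total_reward, terminated, truncated, info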