Sanitize cfg.policy, Fix frame_skip in pusht.yaml

This commit is contained in:
Cadene 2024-02-25 11:09:02 +00:00
parent fc4b98544b
commit e765e26b0b
3 changed files with 78 additions and 72 deletions

View File

@ -2,20 +2,20 @@ from lerobot.common.policies.tdmpc import TDMPC
def make_policy(cfg): def make_policy(cfg):
if cfg.policy == "tdmpc": if cfg.policy.name == "tdmpc":
policy = TDMPC(cfg) policy = TDMPC(cfg.policy)
else: else:
raise ValueError(cfg.policy) raise ValueError(cfg.policy.name)
if cfg.pretrained_model_path: if cfg.policy.pretrained_model_path:
# TODO(rcadene): hack for old pretrained models from fowm # TODO(rcadene): hack for old pretrained models from fowm
if cfg.policy == "tdmpc" and "fowm" in cfg.pretrained_model_path: if cfg.policy.name == "tdmpc" and "fowm" in cfg.policy.pretrained_model_path:
if "offline" in cfg.pretrained_model_path: if "offline" in cfg.pretrained_model_path:
policy.step[0] = 25000 policy.step[0] = 25000
elif "final" in cfg.pretrained_model_path: elif "final" in cfg.pretrained_model_path:
policy.step[0] = 100000 policy.step[0] = 100000
else: else:
raise NotImplementedError() raise NotImplementedError()
policy.load(cfg.pretrained_model_path) policy.load(cfg.policy.pretrained_model_path)
return policy return policy

View File

@ -22,84 +22,90 @@ pixels_only: False
image_size: 84 image_size: 84
fps: 15 fps: 15
reward_scale: 1.0
# xarm_lift # xarm_lift
action_repeat: 2
episode_length: 25 episode_length: 25
modality: 'all'
action_repeat: 2 # TODO(rcadene): verify we use this
discount: 0.9
train_steps: 50000 train_steps: 50000
# pixels
frame_stack: 1
num_channels: 32
img_size: ${image_size}
state_dim: 4
action_dim: 4
# TDMPC policy:
policy: tdmpc name: tdmpc
# planning reward_scale: 1.0
mpc: true
iterations: 6
num_samples: 512
num_elites: 50
mixture_coef: 0.1
min_std: 0.05
max_std: 2.0
temperature: 0.5
momentum: 0.1
uncertainty_cost: 1
# actor # xarm_lift
log_std_min: -10 train_steps: ${train_steps}
log_std_max: 2 episode_length: ${episode_length}
discount: 0.9
modality: 'all'
# learning # pixels
batch_size: 256 frame_stack: 1
max_buffer_size: 10000 num_channels: 32
horizon: 5 img_size: ${image_size}
reward_coef: 0.5 state_dim: 4
value_coef: 0.1 action_dim: 4
consistency_coef: 20
rho: 0.5
kappa: 0.1
lr: 3e-4
std_schedule: ${min_std}
horizon_schedule: ${horizon}
per: true
per_alpha: 0.6
per_beta: 0.4
grad_clip_norm: 10
seed_steps: 0
update_freq: 2
tau: 0.01
utd: 1
# offline rl # planning
# dataset_dir: ??? mpc: true
data_first_percent: 1.0 iterations: 6
is_data_clip: true num_samples: 512
data_clip_eps: 1e-5 num_elites: 50
expectile: 0.9 mixture_coef: 0.1
A_scaling: 3.0 min_std: 0.05
max_std: 2.0
temperature: 0.5
momentum: 0.1
uncertainty_cost: 1
# offline->online # actor
offline_steps: 25000 # ${train_steps}/2 log_std_min: -10
pretrained_model_path: "" log_std_max: 2
# pretrained_model_path: "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/offline.pt"
# pretrained_model_path: "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/final.pt"
balanced_sampling: true
demo_schedule: 0.5
# architecture # learning
enc_dim: 256 batch_size: 256
num_q: 5 max_buffer_size: 10000
mlp_dim: 512 horizon: 5
latent_dim: 50 reward_coef: 0.5
value_coef: 0.1
consistency_coef: 20
rho: 0.5
kappa: 0.1
lr: 3e-4
std_schedule: ${policy.min_std}
horizon_schedule: ${policy.horizon}
per: true
per_alpha: 0.6
per_beta: 0.4
grad_clip_norm: 10
seed_steps: 0
update_freq: 2
tau: 0.01
utd: 1
# offline rl
# dataset_dir: ???
data_first_percent: 1.0
is_data_clip: true
data_clip_eps: 1e-5
expectile: 0.9
A_scaling: 3.0
# offline->online
offline_steps: 25000 # ${train_steps}/2
pretrained_model_path: ""
# pretrained_model_path: "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/offline.pt"
# pretrained_model_path: "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/final.pt"
balanced_sampling: true
demo_schedule: 0.5
# architecture
enc_dim: 256
num_q: 5
mlp_dim: 512
latent_dim: 50
# wandb # wandb
use_wandb: true use_wandb: true

View File

@ -10,7 +10,7 @@ hydra:
env: pusht env: pusht
task: pusht task: pusht
image_size: 96 image_size: 96
frame_skip: 1 action_repeat: 1
state_dim: 2 state_dim: 2
action_dim: 2 action_dim: 2
fps: 10 fps: 10