Sanitize cfg.policy, fix frame_skip in pusht.yaml

This commit is contained in:
Cadene 2024-02-25 11:09:02 +00:00
parent fc4b98544b
commit e765e26b0b
3 changed files with 78 additions and 72 deletions

View File

@ -2,20 +2,20 @@ from lerobot.common.policies.tdmpc import TDMPC
def make_policy(cfg):
    """Instantiate the policy selected by ``cfg.policy`` and optionally load weights.

    Args:
        cfg: Config object exposing a ``policy`` node with at least ``name``
            and ``pretrained_model_path``. The ``cfg.policy`` node itself is
            what gets passed to the policy constructor.

    Returns:
        The constructed policy, with ``policy.load(...)`` already called when
        ``cfg.policy.pretrained_model_path`` is truthy.

    Raises:
        ValueError: If ``cfg.policy.name`` is not ``"tdmpc"``.
        NotImplementedError: If the checkpoint path contains ``"fowm"`` but
            neither ``"offline"`` nor ``"final"``.
    """
    if cfg.policy.name == "tdmpc":
        policy = TDMPC(cfg.policy)
    else:
        raise ValueError(cfg.policy.name)

    # Read the path once from the policy node so every check below looks at
    # the same config entry as the final `policy.load(...)` call.
    pretrained_model_path = cfg.policy.pretrained_model_path
    if pretrained_model_path:
        # TODO(rcadene): hack for old pretrained models from fowm
        if cfg.policy.name == "tdmpc" and "fowm" in pretrained_model_path:
            if "offline" in pretrained_model_path:
                policy.step[0] = 25000
            elif "final" in pretrained_model_path:
                policy.step[0] = 100000
            else:
                raise NotImplementedError()
        policy.load(pretrained_model_path)
    return policy

View File

@ -22,16 +22,25 @@ pixels_only: False
image_size: 84 image_size: 84
fps: 15 fps: 15
reward_scale: 1.0
# xarm_lift # xarm_lift
action_repeat: 2
episode_length: 25 episode_length: 25
modality: 'all'
action_repeat: 2 # TODO(rcadene): verify we use this
discount: 0.9
train_steps: 50000 train_steps: 50000
policy:
name: tdmpc
reward_scale: 1.0
# xarm_lift
train_steps: ${train_steps}
episode_length: ${episode_length}
discount: 0.9
modality: 'all'
# pixels # pixels
frame_stack: 1 frame_stack: 1
num_channels: 32 num_channels: 32
@ -39,9 +48,6 @@ img_size: ${image_size}
state_dim: 4 state_dim: 4
action_dim: 4 action_dim: 4
# TDMPC
policy: tdmpc
# planning # planning
mpc: true mpc: true
iterations: 6 iterations: 6
@ -68,8 +74,8 @@ consistency_coef: 20
rho: 0.5 rho: 0.5
kappa: 0.1 kappa: 0.1
lr: 3e-4 lr: 3e-4
std_schedule: ${min_std} std_schedule: ${policy.min_std}
horizon_schedule: ${horizon} horizon_schedule: ${policy.horizon}
per: true per: true
per_alpha: 0.6 per_alpha: 0.6
per_beta: 0.4 per_beta: 0.4

View File

@ -10,7 +10,7 @@ hydra:
env: pusht env: pusht
task: pusht task: pusht
image_size: 96 image_size: 96
frame_skip: 1 action_repeat: 1
state_dim: 2 state_dim: 2
action_dim: 2 action_dim: 2
fps: 10 fps: 10