Sanitize cfg.policy, Fix skip_frame pusht.yaml
This commit is contained in:
parent
fc4b98544b
commit
e765e26b0b
|
@ -2,20 +2,20 @@ from lerobot.common.policies.tdmpc import TDMPC
|
||||||
|
|
||||||
|
|
||||||
def make_policy(cfg):
|
def make_policy(cfg):
|
||||||
if cfg.policy == "tdmpc":
|
if cfg.policy.name == "tdmpc":
|
||||||
policy = TDMPC(cfg)
|
policy = TDMPC(cfg.policy)
|
||||||
else:
|
else:
|
||||||
raise ValueError(cfg.policy)
|
raise ValueError(cfg.policy.name)
|
||||||
|
|
||||||
if cfg.pretrained_model_path:
|
if cfg.policy.pretrained_model_path:
|
||||||
# TODO(rcadene): hack for old pretrained models from fowm
|
# TODO(rcadene): hack for old pretrained models from fowm
|
||||||
if cfg.policy == "tdmpc" and "fowm" in cfg.pretrained_model_path:
|
if cfg.policy.name == "tdmpc" and "fowm" in cfg.policy.pretrained_model_path:
|
||||||
if "offline" in cfg.pretrained_model_path:
|
if "offline" in cfg.pretrained_model_path:
|
||||||
policy.step[0] = 25000
|
policy.step[0] = 25000
|
||||||
elif "final" in cfg.pretrained_model_path:
|
elif "final" in cfg.pretrained_model_path:
|
||||||
policy.step[0] = 100000
|
policy.step[0] = 100000
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
policy.load(cfg.pretrained_model_path)
|
policy.load(cfg.policy.pretrained_model_path)
|
||||||
|
|
||||||
return policy
|
return policy
|
||||||
|
|
|
@ -22,16 +22,25 @@ pixels_only: False
|
||||||
image_size: 84
|
image_size: 84
|
||||||
fps: 15
|
fps: 15
|
||||||
|
|
||||||
reward_scale: 1.0
|
|
||||||
|
|
||||||
|
|
||||||
# xarm_lift
|
# xarm_lift
|
||||||
|
action_repeat: 2
|
||||||
episode_length: 25
|
episode_length: 25
|
||||||
modality: 'all'
|
|
||||||
action_repeat: 2 # TODO(rcadene): verify we use this
|
|
||||||
discount: 0.9
|
|
||||||
train_steps: 50000
|
train_steps: 50000
|
||||||
|
|
||||||
|
|
||||||
|
policy:
|
||||||
|
name: tdmpc
|
||||||
|
|
||||||
|
reward_scale: 1.0
|
||||||
|
|
||||||
|
# xarm_lift
|
||||||
|
train_steps: ${train_steps}
|
||||||
|
episode_length: ${episode_length}
|
||||||
|
discount: 0.9
|
||||||
|
modality: 'all'
|
||||||
|
|
||||||
# pixels
|
# pixels
|
||||||
frame_stack: 1
|
frame_stack: 1
|
||||||
num_channels: 32
|
num_channels: 32
|
||||||
|
@ -39,9 +48,6 @@ img_size: ${image_size}
|
||||||
state_dim: 4
|
state_dim: 4
|
||||||
action_dim: 4
|
action_dim: 4
|
||||||
|
|
||||||
# TDMPC
|
|
||||||
policy: tdmpc
|
|
||||||
|
|
||||||
# planning
|
# planning
|
||||||
mpc: true
|
mpc: true
|
||||||
iterations: 6
|
iterations: 6
|
||||||
|
@ -68,8 +74,8 @@ consistency_coef: 20
|
||||||
rho: 0.5
|
rho: 0.5
|
||||||
kappa: 0.1
|
kappa: 0.1
|
||||||
lr: 3e-4
|
lr: 3e-4
|
||||||
std_schedule: ${min_std}
|
std_schedule: ${policy.min_std}
|
||||||
horizon_schedule: ${horizon}
|
horizon_schedule: ${policy.horizon}
|
||||||
per: true
|
per: true
|
||||||
per_alpha: 0.6
|
per_alpha: 0.6
|
||||||
per_beta: 0.4
|
per_beta: 0.4
|
||||||
|
|
|
@ -10,7 +10,7 @@ hydra:
|
||||||
env: pusht
|
env: pusht
|
||||||
task: pusht
|
task: pusht
|
||||||
image_size: 96
|
image_size: 96
|
||||||
frame_skip: 1
|
action_repeat: 1
|
||||||
state_dim: 2
|
state_dim: 2
|
||||||
action_dim: 2
|
action_dim: 2
|
||||||
fps: 10
|
fps: 10
|
||||||
|
|
Loading…
Reference in New Issue