diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py index 82e54476..79ef2720 100644 --- a/lerobot/common/policies/factory.py +++ b/lerobot/common/policies/factory.py @@ -2,20 +2,20 @@ from lerobot.common.policies.tdmpc import TDMPC def make_policy(cfg): - if cfg.policy == "tdmpc": - policy = TDMPC(cfg) + if cfg.policy.name == "tdmpc": + policy = TDMPC(cfg.policy) else: - raise ValueError(cfg.policy) + raise ValueError(cfg.policy.name) - if cfg.pretrained_model_path: + if cfg.policy.pretrained_model_path: # TODO(rcadene): hack for old pretrained models from fowm - if cfg.policy == "tdmpc" and "fowm" in cfg.pretrained_model_path: + if cfg.policy.name == "tdmpc" and "fowm" in cfg.policy.pretrained_model_path: if "offline" in cfg.pretrained_model_path: policy.step[0] = 25000 elif "final" in cfg.pretrained_model_path: policy.step[0] = 100000 else: raise NotImplementedError() - policy.load(cfg.pretrained_model_path) + policy.load(cfg.policy.pretrained_model_path) return policy diff --git a/lerobot/configs/default.yaml b/lerobot/configs/default.yaml index f0339e5e..922efac8 100644 --- a/lerobot/configs/default.yaml +++ b/lerobot/configs/default.yaml @@ -22,84 +22,90 @@ pixels_only: False image_size: 84 fps: 15 -reward_scale: 1.0 # xarm_lift +action_repeat: 2 episode_length: 25 -modality: 'all' -action_repeat: 2 # TODO(rcadene): verify we use this -discount: 0.9 train_steps: 50000 -# pixels -frame_stack: 1 -num_channels: 32 -img_size: ${image_size} -state_dim: 4 -action_dim: 4 -# TDMPC -policy: tdmpc +policy: + name: tdmpc -# planning -mpc: true -iterations: 6 -num_samples: 512 -num_elites: 50 -mixture_coef: 0.1 -min_std: 0.05 -max_std: 2.0 -temperature: 0.5 -momentum: 0.1 -uncertainty_cost: 1 + reward_scale: 1.0 -# actor -log_std_min: -10 -log_std_max: 2 + # xarm_lift + train_steps: ${train_steps} + episode_length: ${episode_length} + discount: 0.9 + modality: 'all' -# learning -batch_size: 256 -max_buffer_size: 10000 -horizon: 5 -reward_coef: 0.5 -value_coef: 0.1 -consistency_coef: 20 -rho: 0.5 -kappa: 0.1 -lr: 3e-4 -std_schedule: ${min_std} -horizon_schedule: ${horizon} -per: true -per_alpha: 0.6 -per_beta: 0.4 -grad_clip_norm: 10 -seed_steps: 0 -update_freq: 2 -tau: 0.01 -utd: 1 + # pixels + frame_stack: 1 + num_channels: 32 + img_size: ${image_size} + state_dim: 4 + action_dim: 4 -# offline rl -# dataset_dir: ??? -data_first_percent: 1.0 -is_data_clip: true -data_clip_eps: 1e-5 -expectile: 0.9 -A_scaling: 3.0 + # planning + mpc: true + iterations: 6 + num_samples: 512 + num_elites: 50 + mixture_coef: 0.1 + min_std: 0.05 + max_std: 2.0 + temperature: 0.5 + momentum: 0.1 + uncertainty_cost: 1 -# offline->online -offline_steps: 25000 # ${train_steps}/2 -pretrained_model_path: "" -# pretrained_model_path: "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/offline.pt" -# pretrained_model_path: "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/final.pt" -balanced_sampling: true -demo_schedule: 0.5 + # actor + log_std_min: -10 + log_std_max: 2 -# architecture -enc_dim: 256 -num_q: 5 -mlp_dim: 512 -latent_dim: 50 + # learning + batch_size: 256 + max_buffer_size: 10000 + horizon: 5 + reward_coef: 0.5 + value_coef: 0.1 + consistency_coef: 20 + rho: 0.5 + kappa: 0.1 + lr: 3e-4 + std_schedule: ${policy.min_std} + horizon_schedule: ${policy.horizon} + per: true + per_alpha: 0.6 + per_beta: 0.4 + grad_clip_norm: 10 + seed_steps: 0 + update_freq: 2 + tau: 0.01 + utd: 1 + + # offline rl + # dataset_dir: ??? + data_first_percent: 1.0 + is_data_clip: true + data_clip_eps: 1e-5 + expectile: 0.9 + A_scaling: 3.0 + + # offline->online + offline_steps: 25000 # ${train_steps}/2 + pretrained_model_path: "" + # pretrained_model_path: "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/offline.pt" + # pretrained_model_path: "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/final.pt" + balanced_sampling: true + demo_schedule: 0.5 + + # architecture + enc_dim: 256 + num_q: 5 + mlp_dim: 512 + latent_dim: 50 # wandb use_wandb: true diff --git a/lerobot/configs/pusht.yaml b/lerobot/configs/pusht.yaml index d7166309..b2f7b766 100644 --- a/lerobot/configs/pusht.yaml +++ b/lerobot/configs/pusht.yaml @@ -10,7 +10,7 @@ hydra: env: pusht task: pusht image_size: 96 -frame_skip: 1 +action_repeat: 1 state_dim: 2 action_dim: 2 fps: 10