From b16c3348250a27b7268fd243e42db6bda9417598 Mon Sep 17 00:00:00 2001 From: Cadene Date: Sun, 25 Feb 2024 17:42:47 +0000 Subject: [PATCH] Refactor configs to have env in separate yaml + Fix training --- lerobot/common/datasets/factory.py | 15 +++++----- lerobot/common/envs/factory.py | 6 ++-- lerobot/common/logger.py | 4 +-- lerobot/common/policies/diffusion.py | 44 ++++++++++++++++++++++++++++ lerobot/common/policies/factory.py | 9 ++++-- lerobot/configs/default.yaml | 31 ++++++++------------ lerobot/configs/{ => env}/pusht.yaml | 8 +++-- lerobot/configs/env/simxarm.yaml | 26 ++++++++++++++++ lerobot/scripts/train.py | 20 ++++++++----- tests/test_datasets.py | 17 +++++++++++ tests/test_envs.py | 8 ++--- tests/test_policies.py | 8 ++--- tests/utils.py | 4 +-- 13 files changed, 146 insertions(+), 54 deletions(-) create mode 100644 lerobot/common/policies/diffusion.py rename lerobot/configs/{ => env}/pusht.yaml (60%) create mode 100644 lerobot/configs/env/simxarm.yaml create mode 100644 tests/test_datasets.py diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py index 66d81527..e8b61135 100644 --- a/lerobot/common/datasets/factory.py +++ b/lerobot/common/datasets/factory.py @@ -10,28 +10,29 @@ def make_offline_buffer(cfg, sampler=None): overwrite_sampler = sampler is not None if not overwrite_sampler: - num_traj_per_batch = cfg.batch_size # // cfg.horizon + # TODO(rcadene): move batch_size outside + num_traj_per_batch = cfg.policy.batch_size # // cfg.horizon # TODO(rcadene): Sampler outputs a batch_size <= cfg.batch_size. # We would need to add a transform to pad the tensordict to ensure batch_size == cfg.batch_size. 
sampler = PrioritizedSliceSampler( max_capacity=100_000, - alpha=cfg.per_alpha, - beta=cfg.per_beta, + alpha=cfg.policy.per_alpha, + beta=cfg.policy.per_beta, num_slices=num_traj_per_batch, strict_length=False, ) - if cfg.env == "simxarm": + if cfg.env.name == "simxarm": # TODO(rcadene): add PrioritizedSliceSampler inside Simxarm to not have to `sampler.extend(index)` here offline_buffer = SimxarmExperienceReplay( - f"xarm_{cfg.task}_medium", + f"xarm_{cfg.env.task}_medium", # download="force", download=True, streaming=False, root="data", sampler=sampler, ) - elif cfg.env == "pusht": + elif cfg.env.name == "pusht": offline_buffer = PushtExperienceReplay( "pusht", # download="force", @@ -41,7 +42,7 @@ def make_offline_buffer(cfg, sampler=None): sampler=sampler, ) else: - raise ValueError(cfg.env) + raise ValueError(cfg.env.name) if not overwrite_sampler: num_steps = len(offline_buffer) diff --git a/lerobot/common/envs/factory.py b/lerobot/common/envs/factory.py index fa094734..b93f3541 100644 --- a/lerobot/common/envs/factory.py +++ b/lerobot/common/envs/factory.py @@ -1,7 +1,5 @@ from torchrl.envs.transforms import StepCounter, TransformedEnv -from lerobot.common.envs.pusht import PushtEnv -from lerobot.common.envs.simxarm import SimxarmEnv from lerobot.common.envs.transforms import Prod @@ -14,9 +12,13 @@ def make_env(cfg): } if cfg.env.name == "simxarm": + from lerobot.common.envs.simxarm import SimxarmEnv + kwargs["task"] = cfg.env.task clsfunc = SimxarmEnv elif cfg.env.name == "pusht": + from lerobot.common.envs.pusht import PushtEnv + clsfunc = PushtEnv else: raise ValueError(cfg.env.name) diff --git a/lerobot/common/logger.py b/lerobot/common/logger.py index 031e062d..cb1bf0eb 100644 --- a/lerobot/common/logger.py +++ b/lerobot/common/logger.py @@ -50,7 +50,7 @@ def print_run(cfg, reward=None): ) kvs = [ - ("task", cfg.task), + ("task", cfg.env.task), ("train steps", f"{int(cfg.train_steps * cfg.env.action_repeat):,}"), # ('observations', 'x'.join([str(s) for 
s in cfg.obs_shape])), # ('actions', cfg.action_dim), @@ -72,7 +72,7 @@ def cfg_to_group(cfg, return_list=False): """Return a wandb-safe group name for logging. Optionally returns group name as list.""" # lst = [cfg.task, cfg.modality, re.sub("[^0-9a-zA-Z]+", "-", cfg.exp_name)] lst = [ - f"env:{cfg.env}", + f"env:{cfg.env.name}", f"seed:{cfg.seed}", ] return lst if return_list else "-".join(lst) diff --git a/lerobot/common/policies/diffusion.py b/lerobot/common/policies/diffusion.py new file mode 100644 index 00000000..b8272453 --- /dev/null +++ b/lerobot/common/policies/diffusion.py @@ -0,0 +1,44 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from diffusion_policy.policy.diffusion_unet_image_policy import DiffusionUnetImagePolicy + + +class DiffusionPolicy(nn.Module): + + def __init__( + self, + shape_meta: dict, + noise_scheduler: DDPMScheduler, + obs_encoder: MultiImageObsEncoder, + horizon, + n_action_steps, + n_obs_steps, + num_inference_steps=None, + obs_as_global_cond=True, + diffusion_step_embed_dim=256, + down_dims=(256, 512, 1024), + kernel_size=5, + n_groups=8, + cond_predict_scale=True, + # parameters passed to step + **kwargs, + ): + super().__init__() + self.diffusion = DiffusionUnetImagePolicy( + shape_meta=shape_meta, + noise_scheduler=noise_scheduler, + obs_encoder=obs_encoder, + horizon=horizon, + n_action_steps=n_action_steps, + n_obs_steps=n_obs_steps, + num_inference_steps=num_inference_steps, + obs_as_global_cond=obs_as_global_cond, + diffusion_step_embed_dim=diffusion_step_embed_dim, + down_dims=down_dims, + kernel_size=kernel_size, + n_groups=n_groups, + cond_predict_scale=cond_predict_scale, + # parameters passed to step + **kwargs, + ) diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py index 79ef2720..d2407e1f 100644 --- a/lerobot/common/policies/factory.py +++ b/lerobot/common/policies/factory.py @@ -1,9 +1,12 @@ -from lerobot.common.policies.tdmpc import TDMPC - - def 
make_policy(cfg): if cfg.policy.name == "tdmpc": + from lerobot.common.policies.tdmpc import TDMPC + policy = TDMPC(cfg.policy) + elif cfg.policy.name == "diffusion": + from lerobot.common.policies.diffusion import DiffusionPolicy + + policy = DiffusionPolicy(cfg.policy) else: raise ValueError(cfg.policy.name) diff --git a/lerobot/configs/default.yaml b/lerobot/configs/default.yaml index 690d417f..1bd7dd89 100644 --- a/lerobot/configs/default.yaml +++ b/lerobot/configs/default.yaml @@ -1,31 +1,24 @@ +defaults: + - _self_ + - env: simxarm + hydra: run: dir: outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}_${hydra.job.name} - job: - name: default seed: 1337 device: cuda buffer_device: cuda -eval_freq: 1000 -save_freq: 10000 -eval_episodes: 20 +eval_freq: ??? +save_freq: ??? +eval_episodes: ??? save_video: false save_model: false save_buffer: false -train_steps: 50000 -fps: 15 - -env: - name: simxarm - task: lift - from_pixels: True - pixels_only: False - image_size: 84 - action_repeat: 2 - episode_length: 25 - fps: ${fps} +train_steps: ??? +fps: ??? +env: ??? policy: name: tdmpc @@ -42,8 +35,8 @@ policy: frame_stack: 1 num_channels: 32 img_size: ${env.image_size} - state_dim: 4 - action_dim: 4 + state_dim: ??? + action_dim: ??? 
# planning mpc: true diff --git a/lerobot/configs/pusht.yaml b/lerobot/configs/env/pusht.yaml similarity index 60% rename from lerobot/configs/pusht.yaml rename to lerobot/configs/env/pusht.yaml index e856aca7..7b2ac7ba 100644 --- a/lerobot/configs/pusht.yaml +++ b/lerobot/configs/env/pusht.yaml @@ -1,6 +1,4 @@ -defaults: - - default - - _self_ +# @package _global_ hydra: job: @@ -9,11 +7,15 @@ hydra: eval_episodes: 50 eval_freq: 7500 save_freq: 75000 +train_steps: 50000 # TODO: same as simxarm, need to adjust + fps: 10 env: name: pusht task: pusht + from_pixels: True + pixels_only: False image_size: 96 action_repeat: 1 episode_length: 300 diff --git a/lerobot/configs/env/simxarm.yaml b/lerobot/configs/env/simxarm.yaml new file mode 100644 index 00000000..80324e78 --- /dev/null +++ b/lerobot/configs/env/simxarm.yaml @@ -0,0 +1,26 @@ +# @package _global_ + +hydra: + job: + name: simxarm + +eval_episodes: 20 +eval_freq: 1000 +save_freq: 10000 +train_steps: 50000 + +fps: 15 + +env: + name: simxarm + task: lift + from_pixels: True + pixels_only: False + image_size: 84 + action_repeat: 2 + episode_length: 25 + fps: ${fps} + +policy: + state_dim: 4 + action_dim: 4 \ No newline at end of file diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 61cc1e63..5a0e2c16 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -62,13 +62,14 @@ def train(cfg: dict, out_dir=None, job_name=None): offline_buffer = make_offline_buffer(cfg) - if cfg.balanced_sampling: - num_traj_per_batch = cfg.batch_size + # TODO(rcadene): move balanced_sampling, per_alpha, per_beta outside policy + if cfg.policy.balanced_sampling: + num_traj_per_batch = cfg.policy.batch_size online_sampler = PrioritizedSliceSampler( max_capacity=100_000, - alpha=cfg.per_alpha, - beta=cfg.per_beta, + alpha=cfg.policy.per_alpha, + beta=cfg.policy.per_beta, num_slices=num_traj_per_batch, strict_length=True, ) @@ -92,7 +93,8 @@ def train(cfg: dict, out_dir=None, job_name=None): _step = 
step + num_updates rollout_metrics = {} - if step >= cfg.offline_steps: + # TODO(rcadene): move offline_steps outside policy + if step >= cfg.policy.offline_steps: is_offline = False # TODO: use SyncDataCollector for that? @@ -118,7 +120,7 @@ def train(cfg: dict, out_dir=None, job_name=None): "avg_max_reward": np.nanmean(ep_max_reward), "pc_success": np.nanmean(ep_success) * 100, } - num_updates = len(rollout) * cfg.utd + num_updates = len(rollout) * cfg.policy.utd _step = min(step + len(rollout), cfg.train_steps) # Update model @@ -128,8 +130,10 @@ def train(cfg: dict, out_dir=None, job_name=None): else: train_metrics = policy.update( online_buffer, - step + i // cfg.utd, - demo_buffer=offline_buffer if cfg.balanced_sampling else None, + step + i // cfg.policy.utd, + demo_buffer=( + offline_buffer if cfg.policy.balanced_sampling else None + ), ) # Log training metrics diff --git a/tests/test_datasets.py b/tests/test_datasets.py new file mode 100644 index 00000000..b61873fb --- /dev/null +++ b/tests/test_datasets.py @@ -0,0 +1,17 @@ +import pytest + +from lerobot.common.datasets.factory import make_offline_buffer + +from .utils import init_config + + +@pytest.mark.parametrize( + "env_name", + [ + "simxarm", + "pusht", + ], +) +def test_factory(env_name): + cfg = init_config(overrides=[f"env={env_name}"]) + offline_buffer = make_offline_buffer(cfg) diff --git a/tests/test_envs.py b/tests/test_envs.py index 433b719d..d9fa4854 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -78,13 +78,13 @@ def test_pusht(from_pixels, pixels_only): @pytest.mark.parametrize( - "config_name", + "env_name", [ - "default", + "simxarm", "pusht", ], ) -def test_factory(config_name): - cfg = init_config(config_name) +def test_factory(env_name): + cfg = init_config(overrides=[f"env={env_name}"]) env = make_env(cfg) check_env_specs(env) diff --git a/tests/test_policies.py b/tests/test_policies.py index 062e58f0..7408c729 100644 --- a/tests/test_policies.py +++ 
b/tests/test_policies.py @@ -6,12 +6,12 @@ from .utils import init_config @pytest.mark.parametrize( - "config_name", + "env_name", [ - "default", + "simxarm", "pusht", ], ) -def test_factory(config_name): - cfg = init_config(config_name) +def test_factory(env_name): + cfg = init_config(overrides=[f"env={env_name}"]) policy = make_policy(cfg) diff --git a/tests/utils.py b/tests/utils.py index be8583f5..40dc6de0 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -4,8 +4,8 @@ from hydra import compose, initialize CONFIG_PATH = "../lerobot/configs" -def init_config(config_name): +def init_config(config_name="default", overrides=None): hydra.core.global_hydra.GlobalHydra.instance().clear() initialize(config_path=CONFIG_PATH) - cfg = compose(config_name=config_name) + cfg = compose(config_name=config_name, overrides=overrides) return cfg