From cf5063e50e423d29eeb7560744fba0fb5d7fc3d4 Mon Sep 17 00:00:00 2001
From: Cadene
Date: Wed, 28 Feb 2024 15:21:30 +0000
Subject: [PATCH] Add diffusion policy (train and eval works, TODO: reproduce results)

---
 lerobot/common/policies/diffusion.py  | 110 +++++++++++++++++++++++++-
 lerobot/common/policies/factory.py    |  23 ++----
 lerobot/configs/policy/diffusion.yaml |  14 ++--
 lerobot/configs/policy/tdmpc.yaml     |   4 +-
 lerobot/scripts/train.py              |   5 --
 5 files changed, 125 insertions(+), 31 deletions(-)

diff --git a/lerobot/common/policies/diffusion.py b/lerobot/common/policies/diffusion.py
index d05e7ac2..65d7085c 100644
--- a/lerobot/common/policies/diffusion.py
+++ b/lerobot/common/policies/diffusion.py
@@ -1,7 +1,12 @@
+import copy
+
+import hydra
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
+from diffusion_policy.model.common.lr_scheduler import get_scheduler
+from diffusion_policy.model.vision.model_getter import get_resnet
 from diffusion_policy.model.vision.multi_image_obs_encoder import MultiImageObsEncoder
 from diffusion_policy.policy.diffusion_unet_image_policy import DiffusionUnetImagePolicy
 
@@ -10,9 +15,13 @@ class DiffusionPolicy(nn.Module):
 
     def __init__(
         self,
+        cfg,
+        cfg_noise_scheduler,
+        cfg_rgb_model,
+        cfg_obs_encoder,
+        cfg_optimizer,
+        cfg_ema,
         shape_meta: dict,
-        noise_scheduler: DDPMScheduler,
-        obs_encoder: MultiImageObsEncoder,
         horizon,
         n_action_steps,
         n_obs_steps,
@@ -27,6 +36,15 @@ class DiffusionPolicy(nn.Module):
         **kwargs,
     ):
         super().__init__()
+        self.cfg = cfg
+
+        noise_scheduler = DDPMScheduler(**cfg_noise_scheduler)
+        rgb_model = get_resnet(**cfg_rgb_model)
+        obs_encoder = MultiImageObsEncoder(
+            rgb_model=rgb_model,
+            **cfg_obs_encoder,
+        )
+
         self.diffusion = DiffusionUnetImagePolicy(
             shape_meta=shape_meta,
             noise_scheduler=noise_scheduler,
@@ -44,3 +62,91 @@ class DiffusionPolicy(nn.Module):
             # parameters passed to step
             **kwargs,
         )
+
+        self.device = torch.device("cuda")
+        self.diffusion.cuda()
+
+        self.ema = None
+        if self.cfg.use_ema:
+            self.ema = hydra.utils.instantiate(
+                cfg_ema,
+                model=copy.deepcopy(self.diffusion),
+            )
+
+        self.optimizer = hydra.utils.instantiate(
+            cfg_optimizer,
+            params=self.diffusion.parameters(),
+        )
+
+        # TODO(rcadene): modify lr scheduler so that it doesn't depend on epochs but steps
+        self.global_step = 0
+
+        # configure lr scheduler
+        self.lr_scheduler = get_scheduler(
+            cfg.lr_scheduler,
+            optimizer=self.optimizer,
+            num_warmup_steps=cfg.lr_warmup_steps,
+            num_training_steps=cfg.offline_steps,
+            # pytorch assumes stepping LRScheduler every epoch
+            # however huggingface diffusers steps it every batch
+            last_epoch=self.global_step - 1,
+        )
+
+    @torch.no_grad()
+    def forward(self, observation, step_count):
+        # TODO(rcadene): remove unused step_count
+        del step_count
+
+        obs_dict = {
+            # c h w -> b t c h w (b=1, t=1)
+            "image": observation["image"][None, None, ...],
+            "agent_pos": observation["state"][None, None, ...],
+        }
+        out = self.diffusion.predict_action(obs_dict)
+
+        # TODO(rcadene): add possibility to return >1 timesteps
+        FIRST_ACTION = 0
+        action = out["action"].squeeze(0)[FIRST_ACTION]
+        return action
+
+    def update(self, replay_buffer, step):
+        self.diffusion.train()
+
+        num_slices = self.cfg.batch_size
+        batch_size = self.cfg.horizon * num_slices
+
+        assert batch_size % self.cfg.horizon == 0
+        assert batch_size % num_slices == 0
+
+        def process_batch(batch, horizon, num_slices):
+            # trajectory t = 256, horizon h = 5
+            # (t h) ... -> h t ...
+            batch = batch.reshape(num_slices, horizon).transpose(1, 0).contiguous()
+
+            out = {
+                "obs": {
+                    "image": batch["observation", "image"].to(self.device),
+                    "agent_pos": batch["observation", "state"].to(self.device),
+                },
+                "action": batch["action"].to(self.device),
+            }
+            return out
+
+        batch = replay_buffer.sample(batch_size)
+        batch = process_batch(batch, self.cfg.horizon, num_slices)
+
+        loss = self.diffusion.compute_loss(batch)
+        loss.backward()
+
+        self.optimizer.step()
+        self.optimizer.zero_grad()
+        self.lr_scheduler.step()
+
+        if self.ema is not None:
+            self.ema.step(self.diffusion)
+
+        metrics = {
+            "total_loss": loss.item(),
+            "lr": self.lr_scheduler.get_last_lr()[0],
+        }
+        return metrics
diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py
index 7a87b5af..9d5afe35 100644
--- a/lerobot/common/policies/factory.py
+++ b/lerobot/common/policies/factory.py
@@ -4,26 +4,15 @@ def make_policy(cfg):
 
         policy = TDMPC(cfg.policy)
     elif cfg.policy.name == "diffusion":
-        from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
-        from diffusion_policy.model.vision.model_getter import get_resnet
-        from diffusion_policy.model.vision.multi_image_obs_encoder import (
-            MultiImageObsEncoder,
-        )
-
         from lerobot.common.policies.diffusion import DiffusionPolicy
 
-        noise_scheduler = DDPMScheduler(**cfg.noise_scheduler)
-
-        rgb_model = get_resnet(**cfg.rgb_model)
-
-        obs_encoder = MultiImageObsEncoder(
-            rgb_model=rgb_model,
-            **cfg.obs_encoder,
-        )
-
         policy = DiffusionPolicy(
-            noise_scheduler=noise_scheduler,
-            obs_encoder=obs_encoder,
+            cfg=cfg.policy,
+            cfg_noise_scheduler=cfg.noise_scheduler,
+            cfg_rgb_model=cfg.rgb_model,
+            cfg_obs_encoder=cfg.obs_encoder,
+            cfg_optimizer=cfg.optimizer,
+            cfg_ema=cfg.ema,
             n_action_steps=cfg.n_action_steps + cfg.n_latency_steps,
             **cfg.policy,
         )
diff --git a/lerobot/configs/policy/diffusion.yaml b/lerobot/configs/policy/diffusion.yaml
index cfd37ab1..c3b18298 100644
--- a/lerobot/configs/policy/diffusion.yaml
+++ b/lerobot/configs/policy/diffusion.yaml
@@ -13,7 +13,7 @@ shape_meta:
     shape: [2]
 
 horizon: 16
-n_obs_steps: 2
+n_obs_steps: 1 # TODO(rcadene): before 2
 n_action_steps: 8
 n_latency_steps: 0
 dataset_obs_steps: ${n_obs_steps}
@@ -51,6 +51,10 @@ policy:
 
   balanced_sampling: true
   utd: 1
+  offline_steps: ${offline_steps}
+  use_ema: true
+  lr_scheduler: cosine
+  lr_warmup_steps: 500
 
   noise_scheduler:
     # _target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
@@ -99,13 +103,13 @@ training:
   debug: False
   resume: True
   # optimization
-  lr_scheduler: cosine
-  lr_warmup_steps: 500
+  # lr_scheduler: cosine
+  # lr_warmup_steps: 500
   num_epochs: 8000
-  gradient_accumulate_every: 1
+  # gradient_accumulate_every: 1
   # EMA destroys performance when used with BatchNorm
   # replace BatchNorm with GroupNorm.
-  use_ema: True
+  # use_ema: True
   freeze_encoder: False
   # training loop control
   # in epochs
diff --git a/lerobot/configs/policy/tdmpc.yaml b/lerobot/configs/policy/tdmpc.yaml
index f4bb46ed..26dc4e51 100644
--- a/lerobot/configs/policy/tdmpc.yaml
+++ b/lerobot/configs/policy/tdmpc.yaml
@@ -62,7 +62,7 @@ policy:
   A_scaling: 3.0
 
   # offline->online
-  offline_steps: 25000 # ${train_steps}/2
+  offline_steps: ${offline_steps}
   pretrained_model_path: ""
   # pretrained_model_path: "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/offline.pt"
   # pretrained_model_path: "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/final.pt"
@@ -73,4 +73,4 @@ policy:
   enc_dim: 256
   num_q: 5
   mlp_dim: 512
-  latent_dim: 50
\ No newline at end of file
+  latent_dim: 50
diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index 9af2f79e..26f13f37 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -122,11 +122,6 @@ def train(cfg: dict, out_dir=None, job_name=None):
     start_time = time.time()
     step = 0  # number of policy update
 
-    print("First eval_policy_and_log with a random model or pretrained")
-    eval_policy_and_log(
-        env, td_policy, step, online_episode_idx, start_time, cfg, L, is_offline=True
-    )
-
    for offline_step in range(cfg.offline_steps):
         if offline_step == 0:
             print("Start offline training on a fixed dataset")
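A minimal sketch (not part of the diff; toy sizes are hypothetical) of the reordering done by batch.reshape(num_slices, horizon).transpose(1, 0) inside process_batch, assuming the replay buffer returns frames slice-major (all horizon frames of the first slice, then the second, and so on):

    import torch

    num_slices, horizon = 3, 4

    # Flat sample as the buffer would return it, slice-major:
    # [s0f0, s0f1, s0f2, s0f3, s1f0, ...]
    flat = torch.arange(num_slices * horizon)  # shape (12,)

    # Same reshape/transpose chain as in DiffusionPolicy.update -> time-major layout.
    time_major = flat.reshape(num_slices, horizon).transpose(1, 0).contiguous()
    print(time_major.shape)  # torch.Size([4, 3]), i.e. (horizon, num_slices)
    print(time_major[0])     # tensor([0, 4, 8]): frame 0 of every slice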