Add diffusion policy (train and eval work, TODO: reproduce results)
parent f1708c8a37
commit cf5063e50e
@@ -1,7 +1,12 @@
import copy

import hydra
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusion_policy.model.common.lr_scheduler import get_scheduler
from diffusion_policy.model.vision.model_getter import get_resnet
from diffusion_policy.model.vision.multi_image_obs_encoder import MultiImageObsEncoder
from diffusion_policy.policy.diffusion_unet_image_policy import DiffusionUnetImagePolicy
@@ -10,9 +15,13 @@ class DiffusionPolicy(nn.Module):
    def __init__(
        self,
        cfg,
        cfg_noise_scheduler,
        cfg_rgb_model,
        cfg_obs_encoder,
        cfg_optimizer,
        cfg_ema,
        shape_meta: dict,
        noise_scheduler: DDPMScheduler,
        obs_encoder: MultiImageObsEncoder,
        horizon,
        n_action_steps,
        n_obs_steps,
@@ -27,6 +36,15 @@ class DiffusionPolicy(nn.Module):
        **kwargs,
    ):
        super().__init__()
        self.cfg = cfg

        noise_scheduler = DDPMScheduler(**cfg_noise_scheduler)
        rgb_model = get_resnet(**cfg_rgb_model)
        obs_encoder = MultiImageObsEncoder(
            rgb_model=rgb_model,
            **cfg_obs_encoder,
        )

        self.diffusion = DiffusionUnetImagePolicy(
            shape_meta=shape_meta,
            noise_scheduler=noise_scheduler,
@@ -44,3 +62,91 @@ class DiffusionPolicy(nn.Module):
            # parameters passed to step
            **kwargs,
        )

        self.device = torch.device("cuda")
        self.diffusion.cuda()

        self.ema = None
        if self.cfg.use_ema:
            self.ema = hydra.utils.instantiate(
                cfg_ema,
                model=copy.deepcopy(self.diffusion),
            )

        self.optimizer = hydra.utils.instantiate(
            cfg_optimizer,
            params=self.diffusion.parameters(),
        )

        # TODO(rcadene): modify lr scheduler so that it doesn't depend on epochs but steps
        self.global_step = 0

        # configure lr scheduler
        self.lr_scheduler = get_scheduler(
            cfg.lr_scheduler,
            optimizer=self.optimizer,
            num_warmup_steps=cfg.lr_warmup_steps,
            num_training_steps=cfg.offline_steps,
            # pytorch assumes stepping LRScheduler every epoch
            # however huggingface diffusers steps it every batch
            last_epoch=self.global_step - 1,
        )

    @torch.no_grad()
    def forward(self, observation, step_count):
        # TODO(rcadene): remove unused step_count
        del step_count

        obs_dict = {
            # c h w -> b t c h w (b=1, t=1)
            "image": observation["image"][None, None, ...],
            "agent_pos": observation["state"][None, None, ...],
        }
        out = self.diffusion.predict_action(obs_dict)

        # TODO(rcadene): add possibility to return >1 timesteps
        FIRST_ACTION = 0
        action = out["action"].squeeze(0)[FIRST_ACTION]
        return action

    def update(self, replay_buffer, step):
        self.diffusion.train()

        num_slices = self.cfg.batch_size
        batch_size = self.cfg.horizon * num_slices

        assert batch_size % self.cfg.horizon == 0
        assert batch_size % num_slices == 0

        def process_batch(batch, horizon, num_slices):
            # trajectory t = 256, horizon h = 5
            # (t h) ... -> h t ...
            batch = batch.reshape(num_slices, horizon).transpose(1, 0).contiguous()

            out = {
                "obs": {
                    "image": batch["observation", "image"].to(self.device),
                    "agent_pos": batch["observation", "state"].to(self.device),
                },
                "action": batch["action"].to(self.device),
            }
            return out

        batch = replay_buffer.sample(batch_size)
        batch = process_batch(batch, self.cfg.horizon, num_slices)

        loss = self.diffusion.compute_loss(batch)
        loss.backward()

        self.optimizer.step()
        self.optimizer.zero_grad()
        self.lr_scheduler.step()

        if self.ema is not None:
            self.ema.step(self.diffusion)

        metrics = {
            "total_loss": loss.item(),
            "lr": self.lr_scheduler.get_last_lr()[0],
        }
        return metrics
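Note on the two shape manipulations above: forward() prepends batch and time dimensions to a single observation, and process_batch() inside update() regroups the flat sampled batch per the "(t h) ... -> h t ..." comment. A minimal sketch with plain tensors and made-up sizes (the real code operates on a TensorDict-style batch):

import torch

# forward(): c h w -> b t c h w with b=1, t=1 (sizes below are illustrative)
image = torch.zeros(3, 96, 96)
batched_image = image[None, None, ...]
print(batched_image.shape)  # torch.Size([1, 1, 3, 96, 96])

# update()/process_batch(): regroup num_slices * horizon flat transitions
# into (horizon, num_slices), i.e. "(t h) ... -> h t ..."
num_slices, horizon = 4, 5
flat = torch.arange(num_slices * horizon)
regrouped = flat.reshape(num_slices, horizon).transpose(1, 0).contiguous()
print(regrouped.shape)  # torch.Size([5, 4])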
@@ -4,26 +4,15 @@ def make_policy(cfg):
        policy = TDMPC(cfg.policy)
    elif cfg.policy.name == "diffusion":
        from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
        from diffusion_policy.model.vision.model_getter import get_resnet
        from diffusion_policy.model.vision.multi_image_obs_encoder import (
            MultiImageObsEncoder,
        )

        from lerobot.common.policies.diffusion import DiffusionPolicy

        noise_scheduler = DDPMScheduler(**cfg.noise_scheduler)

        rgb_model = get_resnet(**cfg.rgb_model)

        obs_encoder = MultiImageObsEncoder(
            rgb_model=rgb_model,
            **cfg.obs_encoder,
        )

        policy = DiffusionPolicy(
            noise_scheduler=noise_scheduler,
            obs_encoder=obs_encoder,
            cfg=cfg.policy,
            cfg_noise_scheduler=cfg.noise_scheduler,
            cfg_rgb_model=cfg.rgb_model,
            cfg_obs_encoder=cfg.obs_encoder,
            cfg_optimizer=cfg.optimizer,
            cfg_ema=cfg.ema,
            n_action_steps=cfg.n_action_steps + cfg.n_latency_steps,
            **cfg.policy,
        )
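Since the optimizer and EMA wrapper are now instantiated inside DiffusionPolicy via hydra.utils.instantiate, the cfg_optimizer / cfg_ema nodes are expected to carry a _target_ plus keyword arguments. A minimal sketch of that pattern; the _target_ and values below are illustrative, not the ones shipped in this commit:

import hydra
import torch
from omegaconf import OmegaConf

# Illustrative optimizer node; the commit's yaml may use a different target.
cfg_optimizer = OmegaConf.create(
    {"_target_": "torch.optim.AdamW", "lr": 1.0e-4, "weight_decay": 1.0e-6}
)

model = torch.nn.Linear(8, 8)
# extra kwargs (here the parameters to optimize) are forwarded to the target
optimizer = hydra.utils.instantiate(cfg_optimizer, params=model.parameters())
print(type(optimizer).__name__)  # AdamW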
@@ -13,7 +13,7 @@ shape_meta:
shape: [2]

horizon: 16
n_obs_steps: 2
n_obs_steps: 1 # TODO(rcadene): before 2
n_action_steps: 8
n_latency_steps: 0
dataset_obs_steps: ${n_obs_steps}
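For context on how these values interact in the upstream DiffusionUnetImagePolicy (roughly): the model conditions on the last n_obs_steps observations, predicts horizon future actions, and only the first n_action_steps of that prediction are executed before re-planning. A small sketch of that slicing with the values configured above:

# values from the config above
horizon, n_obs_steps, n_action_steps = 16, 1, 8

# rough upstream convention: the executed chunk starts at the current
# (last observed) timestep and spans n_action_steps actions
start = n_obs_steps - 1
end = start + n_action_steps
print(list(range(horizon))[start:end])  # [0, 1, 2, 3, 4, 5, 6, 7]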
@@ -51,6 +51,10 @@ policy:
balanced_sampling: true

utd: 1
offline_steps: ${offline_steps}
use_ema: true
lr_scheduler: cosine
lr_warmup_steps: 500

noise_scheduler:
# _target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
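The noise_scheduler block feeds straight into DDPMScheduler(**cfg_noise_scheduler) in the policy constructor (the commented _target_ line points at the same class). A minimal sketch of that construction; the kwargs below are illustrative defaults, not necessarily this config's values:

from diffusers.schedulers.scheduling_ddpm import DDPMScheduler

# Illustrative kwargs; check the yaml in this commit for the actual values.
cfg_noise_scheduler = {
    "num_train_timesteps": 100,
    "beta_schedule": "squaredcos_cap_v2",
    "clip_sample": True,
    "prediction_type": "epsilon",
}
noise_scheduler = DDPMScheduler(**cfg_noise_scheduler)
print(noise_scheduler.config.num_train_timesteps)  # 100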
@@ -99,13 +103,13 @@ training:
debug: False
resume: True
# optimization
lr_scheduler: cosine
lr_warmup_steps: 500
# lr_scheduler: cosine
# lr_warmup_steps: 500
num_epochs: 8000
gradient_accumulate_every: 1
# gradient_accumulate_every: 1
# EMA destroys performance when used with BatchNorm
# replace BatchNorm with GroupNorm.
use_ema: True
# use_ema: True
freeze_encoder: False
# training loop control
# in epochs
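The EMA/BatchNorm comment above refers to a known pitfall: exponentially averaging BatchNorm's learnable parameters while its running statistics are not averaged the same way can degrade the EMA model, hence the advice to use GroupNorm in the vision encoder. A hedged sketch of such a swap; the helper name and group sizing below are illustrative, not the upstream diffusion_policy implementation:

import torch
import torchvision

def replace_bn_with_gn(module: torch.nn.Module, features_per_group: int = 16) -> torch.nn.Module:
    # recursively swap every BatchNorm2d for a GroupNorm of the same width
    for name, child in module.named_children():
        if isinstance(child, torch.nn.BatchNorm2d):
            num_groups = max(1, child.num_features // features_per_group)
            setattr(module, name, torch.nn.GroupNorm(num_groups, child.num_features))
        else:
            replace_bn_with_gn(child, features_per_group)
    return module

backbone = replace_bn_with_gn(torchvision.models.resnet18(weights=None))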
@@ -62,7 +62,7 @@ policy:
A_scaling: 3.0

# offline->online
offline_steps: 25000 # ${train_steps}/2
offline_steps: ${offline_steps}
pretrained_model_path: ""
# pretrained_model_path: "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/offline.pt"
# pretrained_model_path: "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/final.pt"
@@ -73,4 +73,4 @@ policy:
enc_dim: 256
num_q: 5
mlp_dim: 512
latent_dim: 50
latent_dim: 50
@@ -122,11 +122,6 @@ def train(cfg: dict, out_dir=None, job_name=None):
    start_time = time.time()
    step = 0  # number of policy updates

    print("First eval_policy_and_log with a random model or pretrained")
    eval_policy_and_log(
        env, td_policy, step, online_episode_idx, start_time, cfg, L, is_offline=True
    )

    for offline_step in range(cfg.offline_steps):
        if offline_step == 0:
            print("Start offline training on a fixed dataset")
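The contract this offline loop relies on is small: the policy exposes update(replay_buffer, step) and returns a metrics dict. A self-contained sketch of that control flow with stand-in objects (not the real lerobot classes):

# Stand-ins so the sketch runs on its own.
class StubPolicy:
    def update(self, replay_buffer, step):
        return {"total_loss": 0.0, "lr": 1e-4}

class StubReplayBuffer:
    def sample(self, batch_size):
        return None

offline_steps = 3  # illustrative; the configs interpolate ${offline_steps}
policy, replay_buffer = StubPolicy(), StubReplayBuffer()

for offline_step in range(offline_steps):
    if offline_step == 0:
        print("Start offline training on a fixed dataset")
    metrics = policy.update(replay_buffer, offline_step)
    print(offline_step, metrics["total_loss"])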