import copy
import logging
import time
from collections import deque

import hydra
import torch
from torch import nn

from lerobot.common.policies.diffusion.diffusion_unet_image_policy import DiffusionUnetImagePolicy
from lerobot.common.policies.diffusion.model.lr_scheduler import get_scheduler
from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder, RgbEncoder
from lerobot.common.policies.utils import populate_queues
from lerobot.common.utils import get_safe_torch_device


class DiffusionPolicy(nn.Module):
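    """Diffusion policy module.

    Wraps `DiffusionUnetImagePolicy` with the rollout observation/action queues, an optional
    EMA copy of the model weights, and the optimizer/LR-scheduler state used for training.
    """
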
    name = "diffusion"

    def __init__(
        self,
        cfg,
        cfg_device,
        cfg_noise_scheduler,
        cfg_rgb_model,
        cfg_obs_encoder,
        cfg_optimizer,
        cfg_ema,
        shape_meta: dict,
        horizon,
        n_action_steps,
        n_obs_steps,
        num_inference_steps=None,
        obs_as_global_cond=True,
        diffusion_step_embed_dim=256,
        down_dims=(256, 512, 1024),
        kernel_size=5,
        n_groups=8,
        cond_predict_scale=True,
        # parameters passed to step
        **kwargs,
    ):
        super().__init__()
        self.cfg = cfg
        self.n_obs_steps = n_obs_steps
        self.n_action_steps = n_action_steps
        # Queues are populated during rollout of the policy; they contain the n latest observations and actions.
        self._queues = None

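        # Build the noise scheduler and the (optionally cropped) image observation encoder
        # from their Hydra configs.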
        noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler)
        rgb_model_input_shape = copy.deepcopy(shape_meta.obs.image.shape)
        if cfg_obs_encoder.crop_shape is not None:
            rgb_model_input_shape[1:] = cfg_obs_encoder.crop_shape
        rgb_model = RgbEncoder(input_shape=rgb_model_input_shape, **cfg_rgb_model)
        obs_encoder = MultiImageObsEncoder(
            rgb_model=rgb_model,
            **cfg_obs_encoder,
        )

        self.diffusion = DiffusionUnetImagePolicy(
            shape_meta=shape_meta,
            noise_scheduler=noise_scheduler,
            obs_encoder=obs_encoder,
            horizon=horizon,
            n_action_steps=n_action_steps,
            n_obs_steps=n_obs_steps,
            num_inference_steps=num_inference_steps,
            obs_as_global_cond=obs_as_global_cond,
            diffusion_step_embed_dim=diffusion_step_embed_dim,
            down_dims=down_dims,
            kernel_size=kernel_size,
            n_groups=n_groups,
            cond_predict_scale=cond_predict_scale,
            # parameters passed to step
            **kwargs,
        )

        self.device = get_safe_torch_device(cfg_device)
        self.diffusion.to(self.device)

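        # When enabled via `cfg.use_ema`, keep an EMA (exponential moving average) copy of the
        # diffusion model; it is used for action prediction in eval mode.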
        self.ema_diffusion = None
        self.ema = None
        if self.cfg.use_ema:
            self.ema_diffusion = copy.deepcopy(self.diffusion)
            self.ema = hydra.utils.instantiate(
                cfg_ema,
                model=self.ema_diffusion,
            )

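        # Note: the optimizer and LR scheduler below only update the online `self.diffusion`
        # weights; the EMA copy is updated separately via `self.ema.step(...)` in `forward`.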
        self.optimizer = hydra.utils.instantiate(
            cfg_optimizer,
            params=self.diffusion.parameters(),
        )

        # TODO(rcadene): modify the lr scheduler so that it doesn't depend on epochs but on steps
        self.global_step = 0

        # configure the lr scheduler
        self.lr_scheduler = get_scheduler(
            cfg.lr_scheduler,
            optimizer=self.optimizer,
            num_warmup_steps=cfg.lr_warmup_steps,
            num_training_steps=cfg.offline_steps,
            # PyTorch assumes the LR scheduler is stepped once per epoch,
            # whereas HuggingFace diffusers steps it once per batch.
            last_epoch=self.global_step - 1,
        )

    def reset(self):
        """Clear observation and action queues. Should be called on `env.reset()`."""
        self._queues = {
            "observation.image": deque(maxlen=self.n_obs_steps),
            "observation.state": deque(maxlen=self.n_obs_steps),
            "action": deque(maxlen=self.n_action_steps),
        }

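    # `select_action` is receding-horizon: when the action queue is empty, the diffusion model
    # predicts a chunk of `n_action_steps` actions at once; subsequent calls pop one queued
    # action per step until the queue empties and the policy replans.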
    @torch.no_grad()
    def select_action(self, batch, step):
        """
        Note: this uses the EMA model weights if `self.training == False`, otherwise the non-EMA model weights.
        """
        # TODO(rcadene): remove unused step
        del step
        assert "observation.image" in batch
        assert "observation.state" in batch
        assert len(batch) == 2

        self._queues = populate_queues(self._queues, batch)

if len(self._queues["action"]) == 0:
|
|
# stack n latest observations from the queue
|
|
batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
|
|
|
|
obs_dict = {
|
|
"image": batch["observation.image"],
|
|
"agent_pos": batch["observation.state"],
|
|
}
|
|
if self.training:
|
|
out = self.diffusion.predict_action(obs_dict)
|
|
else:
|
|
out = self.ema_diffusion.predict_action(obs_dict)
|
|
self._queues["action"].extend(out["action"].transpose(0, 1))
|
|
|
|
action = self._queues["action"].popleft()
|
|
return action
|
|
|
|
    def forward(self, batch, step):
        start_time = time.time()

        self.diffusion.train()

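        # Note: `data_s` only measures the time elapsed since `start_time` (essentially just
        # the `.train()` call above), so it is near zero here rather than true data-loading time.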
        data_s = time.time() - start_time
        loss = self.diffusion.compute_loss(batch)
        loss.backward()

        grad_norm = torch.nn.utils.clip_grad_norm_(
            self.diffusion.parameters(),
            self.cfg.grad_clip_norm,
            error_if_nonfinite=False,
        )

        self.optimizer.step()
        self.optimizer.zero_grad()
        self.lr_scheduler.step()

        if self.ema is not None:
            self.ema.step(self.diffusion)

        info = {
            "loss": loss.item(),
            "grad_norm": float(grad_norm),
            "lr": self.lr_scheduler.get_last_lr()[0],
            "data_s": data_s,
            "update_s": time.time() - start_time,
        }

        # TODO(rcadene): remove hardcoding
        # in diffusion_policy, len(dataloader) is 168 for a batch_size of 64
        if step % 168 == 0:
            self.global_step += 1

        return info

    def save(self, fp):
        torch.save(self.state_dict(), fp)

    def load(self, fp):
        d = torch.load(fp)
        missing_keys, unexpected_keys = self.load_state_dict(d, strict=False)
        if len(missing_keys) > 0:
            assert all(k.startswith("ema_diffusion.") for k in missing_keys)
            logging.warning(
                "DiffusionPolicy.load expected EMA parameters in the loaded state dict but none were found."
            )
        assert len(unexpected_keys) == 0
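

# Example rollout loop (illustrative sketch; assumes a hypothetical `env` returning tensor
# observations and a Hydra `cfg` providing the sub-configs passed to `__init__`):
#
#     policy = DiffusionPolicy(cfg, cfg.device, cfg.noise_scheduler, cfg.rgb_model,
#                              cfg.obs_encoder, cfg.optimizer, cfg.ema, shape_meta,
#                              horizon=cfg.horizon, n_action_steps=cfg.n_action_steps,
#                              n_obs_steps=cfg.n_obs_steps)
#     policy.eval()
#     policy.reset()  # clear the observation/action queues on env.reset()
#     observation = env.reset()
#     for t in range(max_steps):
#         batch = {
#             "observation.image": observation["image"],
#             "observation.state": observation["state"],
#         }
#         action = policy.select_action(batch, step=t)
#         observation = env.step(action)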