backup wip
This commit is contained in:
parent
5608e659e6
commit
03b08eb74e
|
@ -11,54 +11,54 @@ import torch
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
from lerobot.common.datasets.factory import make_dataset
|
from lerobot.common.datasets.factory import make_dataset
|
||||||
|
from lerobot.common.datasets.utils import cycle
|
||||||
|
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||||
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||||
from lerobot.common.utils import init_hydra_config
|
from lerobot.common.utils import init_hydra_config
|
||||||
|
|
||||||
output_directory = Path("outputs/train/example_pusht_diffusion")
|
output_directory = Path("outputs/train/example_pusht_diffusion")
|
||||||
os.makedirs(output_directory, exist_ok=True)
|
os.makedirs(output_directory, exist_ok=True)
|
||||||
|
|
||||||
overrides = [
|
# Number of offline training steps (we'll only do offline training for this example.
|
||||||
"env=pusht",
|
# Adjust as you prefer. 5000 steps are needed to get something worth evaluating.
|
||||||
"policy=diffusion",
|
training_steps = 5000
|
||||||
# Adjust as you prefer. 5000 steps are needed to get something worth evaluating.
|
device = torch.device("cuda")
|
||||||
"offline_steps=5000",
|
log_freq = 250
|
||||||
"log_freq=250",
|
|
||||||
"device=cuda",
|
|
||||||
]
|
|
||||||
|
|
||||||
cfg = init_hydra_config("lerobot/configs/default.yaml", overrides)
|
|
||||||
|
|
||||||
policy = DiffusionPolicy(
|
|
||||||
cfg=cfg.policy,
|
|
||||||
cfg_device=cfg.device,
|
|
||||||
cfg_noise_scheduler=cfg.noise_scheduler,
|
|
||||||
cfg_optimizer=cfg.optimizer,
|
|
||||||
cfg_ema=cfg.ema,
|
|
||||||
**cfg.policy,
|
|
||||||
)
|
|
||||||
policy.train()
|
|
||||||
|
|
||||||
|
# Set up the dataset.
|
||||||
|
cfg = init_hydra_config("lerobot/configs/default.yaml", overrides=["env=pusht"])
|
||||||
dataset = make_dataset(cfg)
|
dataset = make_dataset(cfg)
|
||||||
|
|
||||||
# create dataloader for offline training
|
# Set up the the policy.
|
||||||
|
# Policies are initialized with a configuration class, in this case `DiffusionConfig`.
|
||||||
|
# For this example, no arguments need to be passed because the defaults are set up for PushT.
|
||||||
|
# If you're doing something different, you will likely need to change at least some of the defaults.
|
||||||
|
cfg = DiffusionConfig()
|
||||||
|
# TODO(alexander-soare): Remove LR scheduler from the policy.
|
||||||
|
policy = DiffusionPolicy(cfg, lr_scheduler_num_training_steps=training_steps)
|
||||||
|
policy.train()
|
||||||
|
policy.to(device)
|
||||||
|
|
||||||
|
# Create dataloader for offline training.
|
||||||
dataloader = torch.utils.data.DataLoader(
|
dataloader = torch.utils.data.DataLoader(
|
||||||
dataset,
|
dataset,
|
||||||
num_workers=4,
|
num_workers=4,
|
||||||
batch_size=cfg.policy.batch_size,
|
batch_size=cfg.batch_size,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
pin_memory=cfg.device != "cpu",
|
pin_memory=device != torch.device("cpu"),
|
||||||
drop_last=True,
|
drop_last=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
for step, batch in enumerate(dataloader):
|
# Run training loop.
|
||||||
info = policy(batch, step)
|
dataloader = cycle(dataloader)
|
||||||
|
for step in range(training_steps):
|
||||||
if step % cfg.log_freq == 0:
|
batch = {k: v.to(device, non_blocking=True) for k, v in next(dataloader).items()}
|
||||||
num_samples = (step + 1) * cfg.policy.batch_size
|
info = policy(batch)
|
||||||
|
if step % log_freq == 0:
|
||||||
|
num_samples = (step + 1) * cfg.batch_size
|
||||||
loss = info["loss"]
|
loss = info["loss"]
|
||||||
update_s = info["update_s"]
|
update_s = info["update_s"]
|
||||||
print(f"step:{step} samples:{num_samples} loss:{loss:.3f} update_time:{update_s:.3f}(seconds)")
|
print(f"step: {step} samples: {num_samples} loss: {loss:.3f} update_time: {update_s:.3f} (seconds)")
|
||||||
|
|
||||||
|
|
||||||
# Save the policy, configuration, and normalization stats for later use.
|
# Save the policy, configuration, and normalization stats for later use.
|
||||||
policy.save(output_directory / "model.pt")
|
policy.save(output_directory / "model.pt")
|
||||||
|
|
|
@ -208,6 +208,10 @@ def compute_stats(dataset, batch_size=32, max_num_samples=None):
|
||||||
|
|
||||||
|
|
||||||
def cycle(iterable):
|
def cycle(iterable):
|
||||||
|
"""The equivalent of itertools.cycle, but safe for Pytorch dataloaders.
|
||||||
|
|
||||||
|
See https://github.com/pytorch/pytorch/issues/23900 for information on why itertools.cycle is not safe.
|
||||||
|
"""
|
||||||
iterator = iter(iterable)
|
iterator = iter(iterable)
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -26,8 +26,8 @@ class ActionChunkingTransformerConfig:
|
||||||
image_normalization_std: Value by which to divide the input image pixels (after the mean has been
|
image_normalization_std: Value by which to divide the input image pixels (after the mean has been
|
||||||
subtracted).
|
subtracted).
|
||||||
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
|
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
|
||||||
use_pretrained_backbone: Whether the backbone should be initialized with ImageNet, pretrained weights
|
use_pretrained_backbone: Whether the backbone should be initialized with pretrained weights from
|
||||||
from torchvision.
|
torchvision.
|
||||||
replace_final_stride_with_dilation: Whether to replace the ResNet's final 2x2 stride with a dilated
|
replace_final_stride_with_dilation: Whether to replace the ResNet's final 2x2 stride with a dilated
|
||||||
convolution.
|
convolution.
|
||||||
pre_norm: Whether to use "pre-norm" in the transformer blocks.
|
pre_norm: Whether to use "pre-norm" in the transformer blocks.
|
||||||
|
|
|
@ -13,9 +13,49 @@ class DiffusionConfig:
|
||||||
Args:
|
Args:
|
||||||
state_dim: Dimensionality of the observation state space (excluding images).
|
state_dim: Dimensionality of the observation state space (excluding images).
|
||||||
action_dim: Dimensionality of the action space.
|
action_dim: Dimensionality of the action space.
|
||||||
|
image_size: (H, W) size of the input images.
|
||||||
n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
|
n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
|
||||||
current step and additional steps going back).
|
current step and additional steps going back).
|
||||||
horizon: Diffusion model action prediction horizon as detailed in the main policy documentation.
|
horizon: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`.
|
||||||
|
n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
|
||||||
|
See `DiffusionPolicy.select_action` for more details.
|
||||||
|
image_normalization_mean: Value to subtract from the input image pixels (inputs are assumed to be in
|
||||||
|
[0, 1]) for normalization.
|
||||||
|
image_normalization_std: Value by which to divide the input image pixels (after the mean has been
|
||||||
|
subtracted).
|
||||||
|
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
|
||||||
|
crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit
|
||||||
|
within the image size. If None, no cropping is done.
|
||||||
|
crop_is_random: Whether the crop should be random at training time (it's always a center crop in eval
|
||||||
|
mode).
|
||||||
|
use_pretrained_backbone: Whether the backbone should be initialized with pretrained weights from
|
||||||
|
torchvision.
|
||||||
|
use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
|
||||||
|
The group sizes are set to be about 16 (to be precise, feature_dim // 16).
|
||||||
|
spatial_softmax_num_keypoints: Number of keypoints for SpatialSoftmax.
|
||||||
|
down_dims: Feature dimension for each stage of temporal downsampling in the diffusion modeling Unet.
|
||||||
|
You may provide a variable number of dimensions, therefore also controlling the degree of
|
||||||
|
downsampling.
|
||||||
|
kernel_size: The convolutional kernel size of the diffusion modeling Unet.
|
||||||
|
n_groups: Number of groups used in the group norm of the Unet's convolutional blocks.
|
||||||
|
diffusion_step_embed_dim: The Unet is conditioned on the diffusion timestep via a small non-linear
|
||||||
|
network. This is the output dimension of that network, i.e., the embedding dimension.
|
||||||
|
use_film_scale_modulation: FiLM (https://arxiv.org/abs/1709.07871) is used for the Unet conditioning.
|
||||||
|
Bias modulation is used be default, while this parameter indicates whether to also use scale
|
||||||
|
modulation.
|
||||||
|
num_train_timesteps: Number of diffusion steps for the forward diffusion schedule.
|
||||||
|
beta_schedule: Name of the diffusion beta schedule as per DDPMScheduler from Hugging Face diffusers.
|
||||||
|
beta_start: Beta value for the first forward-diffusion step.
|
||||||
|
beta_end: Beta value for the last forward-diffusion step.
|
||||||
|
prediction_type: The type of prediction that the diffusion modeling Unet makes. Choose from "epsilon"
|
||||||
|
or "sample". These have equivalent outcomes from a latent variable modeling perspective, but
|
||||||
|
"epsilon" has been shown to work better in many deep neural network settings.
|
||||||
|
clip_sample: Whether to clip the sample to [-`clip_sample_range`, +`clip_sample_range`] for each
|
||||||
|
denoising step at inference time. WARNING: you will need to make sure your action-space is
|
||||||
|
normalized to fit within this range.
|
||||||
|
clip_sample_range: The magnitude of the clipping range as described above.
|
||||||
|
num_inference_steps: Number of reverse diffusion steps to use at inference time (steps are evenly
|
||||||
|
spaced). If not provided, this defaults to be the same as `num_train_timesteps`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Environment.
|
# Environment.
|
||||||
|
@ -36,7 +76,7 @@ class DiffusionConfig:
|
||||||
# Architecture / modeling.
|
# Architecture / modeling.
|
||||||
# Vision backbone.
|
# Vision backbone.
|
||||||
vision_backbone: str = "resnet18"
|
vision_backbone: str = "resnet18"
|
||||||
crop_shape: tuple[int, int] = (84, 84)
|
crop_shape: tuple[int, int] | None = (84, 84)
|
||||||
crop_is_random: bool = True
|
crop_is_random: bool = True
|
||||||
use_pretrained_backbone: bool = False
|
use_pretrained_backbone: bool = False
|
||||||
use_group_norm: bool = True
|
use_group_norm: bool = True
|
||||||
|
@ -46,18 +86,18 @@ class DiffusionConfig:
|
||||||
kernel_size: int = 5
|
kernel_size: int = 5
|
||||||
n_groups: int = 8
|
n_groups: int = 8
|
||||||
diffusion_step_embed_dim: int = 128
|
diffusion_step_embed_dim: int = 128
|
||||||
film_scale_modulation: bool = True
|
use_film_scale_modulation: bool = True
|
||||||
# Noise scheduler.
|
# Noise scheduler.
|
||||||
num_train_timesteps: int = 100
|
num_train_timesteps: int = 100
|
||||||
beta_schedule: str = "squaredcos_cap_v2"
|
beta_schedule: str = "squaredcos_cap_v2"
|
||||||
beta_start: float = 0.0001
|
beta_start: float = 0.0001
|
||||||
beta_end: float = 0.02
|
beta_end: float = 0.02
|
||||||
variance_type: str = "fixed_small"
|
|
||||||
prediction_type: str = "epsilon"
|
prediction_type: str = "epsilon"
|
||||||
clip_sample: True
|
clip_sample: bool = True
|
||||||
|
clip_sample_range: float = 1.0
|
||||||
|
|
||||||
# Inference
|
# Inference
|
||||||
num_inference_steps: int = 100
|
num_inference_steps: int | None = None
|
||||||
|
|
||||||
# ---
|
# ---
|
||||||
# TODO(alexander-soare): Remove these from the policy config.
|
# TODO(alexander-soare): Remove these from the policy config.
|
||||||
|
@ -72,12 +112,24 @@ class DiffusionConfig:
|
||||||
utd: int = 1
|
utd: int = 1
|
||||||
use_ema: bool = True
|
use_ema: bool = True
|
||||||
ema_update_after_step: int = 0
|
ema_update_after_step: int = 0
|
||||||
ema_min_rate: float = 0.0
|
ema_min_alpha: float = 0.0
|
||||||
ema_max_rate: float = 0.9999
|
ema_max_alpha: float = 0.9999
|
||||||
ema_inv_gamma: float = 1.0
|
ema_inv_gamma: float = 1.0
|
||||||
ema_power: float = 0.75
|
ema_power: float = 0.75
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
"""Input validation (not exhaustive)."""
|
"""Input validation (not exhaustive)."""
|
||||||
if not self.vision_backbone.startswith("resnet"):
|
if not self.vision_backbone.startswith("resnet"):
|
||||||
raise ValueError("`vision_backbone` must be one of the ResNet variants.")
|
raise ValueError(
|
||||||
|
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
|
||||||
|
)
|
||||||
|
if self.crop_shape[0] > self.image_size[0] or self.crop_shape[1] > self.image_size[1]:
|
||||||
|
raise ValueError(
|
||||||
|
f"`crop_shape` should fit within `image_size`. Got {self.crop_shape} for `crop_shape` and "
|
||||||
|
f"{self.image_size} for `image_size`."
|
||||||
|
)
|
||||||
|
supported_prediction_types = ["epsilon", "sample"]
|
||||||
|
if self.prediction_type not in supported_prediction_types:
|
||||||
|
raise ValueError(
|
||||||
|
f"`prediction_type` must be one of {supported_prediction_types}. Got {self.prediction_type}."
|
||||||
|
)
|
||||||
|
|
|
@ -1,8 +1,10 @@
|
||||||
"""
|
"""
|
||||||
TODO(alexander-soare):
|
TODO(alexander-soare):
|
||||||
- Remove reliance on Robomimic for SpatialSoftmax.
|
- Remove reliance on Robomimic for SpatialSoftmax.
|
||||||
- Remove reliance on diffusers for DDPMScheduler.
|
- Remove reliance on diffusers for DDPMScheduler and LR scheduler.
|
||||||
- Move EMA out of policy.
|
- Move EMA out of policy.
|
||||||
|
- Consolidate _DiffusionUnetImagePolicy into DiffusionPolicy.
|
||||||
|
- One more pass on comments and documentation.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
|
@ -10,10 +12,10 @@ import logging
|
||||||
import math
|
import math
|
||||||
import time
|
import time
|
||||||
from collections import deque
|
from collections import deque
|
||||||
|
from itertools import chain
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
|
||||||
import einops
|
import einops
|
||||||
import hydra
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F # noqa: N812
|
import torch.nn.functional as F # noqa: N812
|
||||||
import torchvision
|
import torchvision
|
||||||
|
@ -23,12 +25,12 @@ from robomimic.models.base_nets import SpatialSoftmax
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
from torch.nn.modules.batchnorm import _BatchNorm
|
from torch.nn.modules.batchnorm import _BatchNorm
|
||||||
|
|
||||||
|
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||||
from lerobot.common.policies.utils import (
|
from lerobot.common.policies.utils import (
|
||||||
get_device_from_parameters,
|
get_device_from_parameters,
|
||||||
get_dtype_from_parameters,
|
get_dtype_from_parameters,
|
||||||
populate_queues,
|
populate_queues,
|
||||||
)
|
)
|
||||||
from lerobot.common.utils import get_safe_torch_device
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -41,69 +43,29 @@ class DiffusionPolicy(nn.Module):
|
||||||
|
|
||||||
name = "diffusion"
|
name = "diffusion"
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, cfg: DiffusionConfig, lr_scheduler_num_training_steps: int):
|
||||||
self,
|
|
||||||
cfg,
|
|
||||||
cfg_device,
|
|
||||||
cfg_noise_scheduler,
|
|
||||||
cfg_optimizer,
|
|
||||||
cfg_ema,
|
|
||||||
shape_meta: dict,
|
|
||||||
horizon,
|
|
||||||
n_action_steps,
|
|
||||||
n_obs_steps,
|
|
||||||
num_inference_steps=None,
|
|
||||||
diffusion_step_embed_dim=256,
|
|
||||||
down_dims=(256, 512, 1024),
|
|
||||||
kernel_size=5,
|
|
||||||
n_groups=8,
|
|
||||||
film_scale_modulation=True,
|
|
||||||
**_,
|
|
||||||
):
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.cfg = cfg
|
self.cfg = cfg
|
||||||
self.n_obs_steps = n_obs_steps
|
|
||||||
self.n_action_steps = n_action_steps
|
|
||||||
|
|
||||||
# queues are populated during rollout of the policy, they contain the n latest observations and actions
|
# queues are populated during rollout of the policy, they contain the n latest observations and actions
|
||||||
self._queues = None
|
self._queues = None
|
||||||
|
|
||||||
noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler)
|
self.diffusion = _DiffusionUnetImagePolicy(cfg)
|
||||||
|
|
||||||
self.diffusion = _DiffusionUnetImagePolicy(
|
|
||||||
cfg,
|
|
||||||
shape_meta=shape_meta,
|
|
||||||
noise_scheduler=noise_scheduler,
|
|
||||||
horizon=horizon,
|
|
||||||
n_action_steps=n_action_steps,
|
|
||||||
n_obs_steps=n_obs_steps,
|
|
||||||
num_inference_steps=num_inference_steps,
|
|
||||||
diffusion_step_embed_dim=diffusion_step_embed_dim,
|
|
||||||
down_dims=down_dims,
|
|
||||||
kernel_size=kernel_size,
|
|
||||||
n_groups=n_groups,
|
|
||||||
film_scale_modulation=film_scale_modulation,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.device = get_safe_torch_device(cfg_device)
|
|
||||||
self.diffusion.to(self.device)
|
|
||||||
|
|
||||||
# TODO(alexander-soare): This should probably be managed outside of the policy class.
|
# TODO(alexander-soare): This should probably be managed outside of the policy class.
|
||||||
self.ema_diffusion = None
|
self.ema_diffusion = None
|
||||||
self.ema = None
|
self.ema = None
|
||||||
if self.cfg.use_ema:
|
if self.cfg.use_ema:
|
||||||
self.ema_diffusion = copy.deepcopy(self.diffusion)
|
self.ema_diffusion = copy.deepcopy(self.diffusion)
|
||||||
self.ema = hydra.utils.instantiate(
|
self.ema = _EMA(cfg, model=self.ema_diffusion)
|
||||||
cfg_ema,
|
|
||||||
model=self.ema_diffusion,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.optimizer = hydra.utils.instantiate(
|
# TODO(alexander-soare): Move optimizer out of policy.
|
||||||
cfg_optimizer,
|
self.optimizer = torch.optim.Adam(
|
||||||
params=self.diffusion.parameters(),
|
self.diffusion.parameters(), cfg.lr, cfg.adam_betas, cfg.adam_eps, cfg.adam_weight_decay
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO(rcadene): modify lr scheduler so that it doesnt depend on epochs but steps
|
# TODO(alexander-soare): Move LR scheduler out of policy.
|
||||||
|
# TODO(rcadene): modify lr scheduler so that it doesn't depend on epochs but steps
|
||||||
self.global_step = 0
|
self.global_step = 0
|
||||||
|
|
||||||
# configure lr scheduler
|
# configure lr scheduler
|
||||||
|
@ -111,7 +73,7 @@ class DiffusionPolicy(nn.Module):
|
||||||
cfg.lr_scheduler,
|
cfg.lr_scheduler,
|
||||||
optimizer=self.optimizer,
|
optimizer=self.optimizer,
|
||||||
num_warmup_steps=cfg.lr_warmup_steps,
|
num_warmup_steps=cfg.lr_warmup_steps,
|
||||||
num_training_steps=cfg.offline_steps,
|
num_training_steps=lr_scheduler_num_training_steps,
|
||||||
# pytorch assumes stepping LRScheduler every epoch
|
# pytorch assumes stepping LRScheduler every epoch
|
||||||
# however huggingface diffusers steps it every batch
|
# however huggingface diffusers steps it every batch
|
||||||
last_epoch=self.global_step - 1,
|
last_epoch=self.global_step - 1,
|
||||||
|
@ -122,9 +84,9 @@ class DiffusionPolicy(nn.Module):
|
||||||
Clear observation and action queues. Should be called on `env.reset()`
|
Clear observation and action queues. Should be called on `env.reset()`
|
||||||
"""
|
"""
|
||||||
self._queues = {
|
self._queues = {
|
||||||
"observation.image": deque(maxlen=self.n_obs_steps),
|
"observation.image": deque(maxlen=self.cfg.n_obs_steps),
|
||||||
"observation.state": deque(maxlen=self.n_obs_steps),
|
"observation.state": deque(maxlen=self.cfg.n_obs_steps),
|
||||||
"action": deque(maxlen=self.n_action_steps),
|
"action": deque(maxlen=self.cfg.n_action_steps),
|
||||||
}
|
}
|
||||||
|
|
||||||
@torch.no_grad
|
@torch.no_grad
|
||||||
|
@ -138,11 +100,13 @@ class DiffusionPolicy(nn.Module):
|
||||||
- The diffusion model generates `horizon` steps worth of actions.
|
- The diffusion model generates `horizon` steps worth of actions.
|
||||||
- `n_action_steps` worth of actions are actually kept for execution, starting from the current step.
|
- `n_action_steps` worth of actions are actually kept for execution, starting from the current step.
|
||||||
Schematically this looks like:
|
Schematically this looks like:
|
||||||
|
----------------------------------------------------------------------------------------------
|
||||||
(legend: o = n_obs_steps, h = horizon, a = n_action_steps)
|
(legend: o = n_obs_steps, h = horizon, a = n_action_steps)
|
||||||
|timestep | n-o+1 | n-o+2 | ..... | n | ..... | n+a-1 | n+a | ..... |n-o+1+h|
|
|timestep | n-o+1 | n-o+2 | ..... | n | ..... | n+a-1 | n+a | ..... |n-o+1+h|
|
||||||
|observation is used | YES | YES | ..... | NO | NO | NO | NO | NO | NO |
|
|observation is used | YES | YES | YES | NO | NO | NO | NO | NO | NO |
|
||||||
|action is generated | YES | YES | YES | YES | YES | YES | YES | YES | YES |
|
|action is generated | YES | YES | YES | YES | YES | YES | YES | YES | YES |
|
||||||
|action is used | NO | NO | NO | YES | YES | YES | NO | NO | NO |
|
|action is used | NO | NO | NO | YES | YES | YES | NO | NO | NO |
|
||||||
|
----------------------------------------------------------------------------------------------
|
||||||
Note that this means we require: `n_action_steps < horizon - n_obs_steps + 1`. Also, note that
|
Note that this means we require: `n_action_steps < horizon - n_obs_steps + 1`. Also, note that
|
||||||
"horizon" may not the best name to describe what the variable actually means, because this period is
|
"horizon" may not the best name to describe what the variable actually means, because this period is
|
||||||
actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
|
actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
|
||||||
|
@ -213,57 +177,41 @@ class DiffusionPolicy(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class _DiffusionUnetImagePolicy(nn.Module):
|
class _DiffusionUnetImagePolicy(nn.Module):
|
||||||
def __init__(
|
def __init__(self, cfg: DiffusionConfig):
|
||||||
self,
|
|
||||||
cfg,
|
|
||||||
shape_meta: dict,
|
|
||||||
noise_scheduler: DDPMScheduler,
|
|
||||||
horizon,
|
|
||||||
n_action_steps,
|
|
||||||
n_obs_steps,
|
|
||||||
num_inference_steps=None,
|
|
||||||
diffusion_step_embed_dim=256,
|
|
||||||
down_dims=(256, 512, 1024),
|
|
||||||
kernel_size=5,
|
|
||||||
n_groups=8,
|
|
||||||
film_scale_modulation=True,
|
|
||||||
):
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
action_shape = shape_meta["action"]["shape"]
|
self.cfg = cfg
|
||||||
assert len(action_shape) == 1
|
|
||||||
action_dim = action_shape[0]
|
|
||||||
|
|
||||||
self.rgb_encoder = _RgbEncoder(input_shape=shape_meta.obs.image.shape, **cfg.rgb_encoder)
|
|
||||||
|
|
||||||
|
self.rgb_encoder = _RgbEncoder(cfg)
|
||||||
self.unet = _ConditionalUnet1D(
|
self.unet = _ConditionalUnet1D(
|
||||||
input_dim=action_dim,
|
cfg, global_cond_dim=(cfg.action_dim + self.rgb_encoder.feature_dim) * cfg.n_obs_steps
|
||||||
global_cond_dim=(action_dim + self.rgb_encoder.feature_dim) * n_obs_steps,
|
|
||||||
diffusion_step_embed_dim=diffusion_step_embed_dim,
|
|
||||||
down_dims=down_dims,
|
|
||||||
kernel_size=kernel_size,
|
|
||||||
n_groups=n_groups,
|
|
||||||
film_scale_modulation=film_scale_modulation,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.noise_scheduler = noise_scheduler
|
self.noise_scheduler = DDPMScheduler(
|
||||||
self.horizon = horizon
|
num_train_timesteps=cfg.num_train_timesteps,
|
||||||
self.action_dim = action_dim
|
beta_start=cfg.beta_start,
|
||||||
self.n_action_steps = n_action_steps
|
beta_end=cfg.beta_end,
|
||||||
self.n_obs_steps = n_obs_steps
|
beta_schedule=cfg.beta_schedule,
|
||||||
|
variance_type="fixed_small",
|
||||||
|
clip_sample=cfg.clip_sample,
|
||||||
|
clip_sample_range=cfg.clip_sample_range,
|
||||||
|
prediction_type=cfg.prediction_type,
|
||||||
|
)
|
||||||
|
|
||||||
if num_inference_steps is None:
|
if cfg.num_inference_steps is None:
|
||||||
num_inference_steps = noise_scheduler.config.num_train_timesteps
|
self.num_inference_steps = self.noise_scheduler.config.num_train_timesteps
|
||||||
|
else:
|
||||||
self.num_inference_steps = num_inference_steps
|
self.num_inference_steps = cfg.num_inference_steps
|
||||||
|
|
||||||
# ========= inference ============
|
# ========= inference ============
|
||||||
def conditional_sample(self, batch_size, global_cond=None, generator=None):
|
def conditional_sample(
|
||||||
|
self, batch_size: int, global_cond: Tensor | None = None, generator: torch.Generator | None = None
|
||||||
|
) -> Tensor:
|
||||||
device = get_device_from_parameters(self)
|
device = get_device_from_parameters(self)
|
||||||
dtype = get_dtype_from_parameters(self)
|
dtype = get_dtype_from_parameters(self)
|
||||||
|
|
||||||
# Sample prior.
|
# Sample prior.
|
||||||
sample = torch.randn(
|
sample = torch.randn(
|
||||||
size=(batch_size, self.horizon, self.action_dim),
|
size=(batch_size, self.cfg.horizon, self.cfg.action_dim),
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=device,
|
device=device,
|
||||||
generator=generator,
|
generator=generator,
|
||||||
|
@ -283,7 +231,7 @@ class _DiffusionUnetImagePolicy(nn.Module):
|
||||||
|
|
||||||
return sample
|
return sample
|
||||||
|
|
||||||
def generate_actions(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
def generate_actions(self, batch: dict[str, Tensor]) -> Tensor:
|
||||||
"""
|
"""
|
||||||
This function expects `batch` to have (at least):
|
This function expects `batch` to have (at least):
|
||||||
{
|
{
|
||||||
|
@ -293,8 +241,7 @@ class _DiffusionUnetImagePolicy(nn.Module):
|
||||||
"""
|
"""
|
||||||
assert set(batch).issuperset({"observation.state", "observation.image"})
|
assert set(batch).issuperset({"observation.state", "observation.image"})
|
||||||
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
|
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
|
||||||
assert n_obs_steps == self.n_obs_steps
|
assert n_obs_steps == self.cfg.n_obs_steps
|
||||||
assert self.n_obs_steps == n_obs_steps
|
|
||||||
|
|
||||||
# Extract image feature (first combine batch and sequence dims).
|
# Extract image feature (first combine batch and sequence dims).
|
||||||
img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ..."))
|
img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ..."))
|
||||||
|
@ -307,13 +254,13 @@ class _DiffusionUnetImagePolicy(nn.Module):
|
||||||
sample = self.conditional_sample(batch_size, global_cond=global_cond)
|
sample = self.conditional_sample(batch_size, global_cond=global_cond)
|
||||||
|
|
||||||
# `horizon` steps worth of actions (from the first observation).
|
# `horizon` steps worth of actions (from the first observation).
|
||||||
action = sample[..., : self.action_dim]
|
actions = sample[..., : self.cfg.action_dim]
|
||||||
# Extract `n_action_steps` steps worth of actions (from the current observation).
|
# Extract `n_action_steps` steps worth of actions (from the current observation).
|
||||||
start = n_obs_steps - 1
|
start = n_obs_steps - 1
|
||||||
end = start + self.n_action_steps
|
end = start + self.cfg.n_action_steps
|
||||||
action = action[:, start:end]
|
actions = actions[:, start:end]
|
||||||
|
|
||||||
return action
|
return actions
|
||||||
|
|
||||||
def compute_loss(self, batch: dict[str, Tensor]) -> Tensor:
|
def compute_loss(self, batch: dict[str, Tensor]) -> Tensor:
|
||||||
"""
|
"""
|
||||||
|
@ -329,9 +276,8 @@ class _DiffusionUnetImagePolicy(nn.Module):
|
||||||
assert set(batch).issuperset({"observation.state", "observation.image", "action", "action_is_pad"})
|
assert set(batch).issuperset({"observation.state", "observation.image", "action", "action_is_pad"})
|
||||||
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
|
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
|
||||||
horizon = batch["action"].shape[1]
|
horizon = batch["action"].shape[1]
|
||||||
assert horizon == self.horizon
|
assert horizon == self.cfg.horizon
|
||||||
assert n_obs_steps == self.n_obs_steps
|
assert n_obs_steps == self.cfg.n_obs_steps
|
||||||
assert self.n_obs_steps == n_obs_steps
|
|
||||||
|
|
||||||
# Extract image feature (first combine batch and sequence dims).
|
# Extract image feature (first combine batch and sequence dims).
|
||||||
img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ..."))
|
img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ..."))
|
||||||
|
@ -359,14 +305,13 @@ class _DiffusionUnetImagePolicy(nn.Module):
|
||||||
pred = self.unet(noisy_trajectory, timesteps, global_cond=global_cond)
|
pred = self.unet(noisy_trajectory, timesteps, global_cond=global_cond)
|
||||||
|
|
||||||
# Compute the loss.
|
# Compute the loss.
|
||||||
# The targe is either the original trajectory, or the noise.
|
# The target is either the original trajectory, or the noise.
|
||||||
pred_type = self.noise_scheduler.config.prediction_type
|
if self.cfg.prediction_type == "epsilon":
|
||||||
if pred_type == "epsilon":
|
|
||||||
target = eps
|
target = eps
|
||||||
elif pred_type == "sample":
|
elif self.cfg.prediction_type == "sample":
|
||||||
target = batch["action"]
|
target = batch["action"]
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported prediction type {pred_type}")
|
raise ValueError(f"Unsupported prediction type {self.cfg.prediction_type}")
|
||||||
|
|
||||||
loss = F.mse_loss(pred, target, reduction="none")
|
loss = F.mse_loss(pred, target, reduction="none")
|
||||||
|
|
||||||
|
@ -384,64 +329,35 @@ class _RgbEncoder(nn.Module):
|
||||||
Includes the ability to normalize and crop the image first.
|
Includes the ability to normalize and crop the image first.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, cfg: DiffusionConfig):
|
||||||
self,
|
|
||||||
input_shape: tuple[int, int, int],
|
|
||||||
norm_mean_std: tuple[float, float] = [1.0, 1.0],
|
|
||||||
crop_shape: tuple[int, int] | None = None,
|
|
||||||
random_crop: bool = False,
|
|
||||||
backbone_name: str = "resnet18",
|
|
||||||
pretrained_backbone: bool = False,
|
|
||||||
use_group_norm: bool = False,
|
|
||||||
num_keypoints: int = 32,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
input_shape: channel-first input shape (C, H, W)
|
|
||||||
norm_mean_std: mean and standard deviation used for image normalization. Images are normalized as
|
|
||||||
(image - mean) / std.
|
|
||||||
crop_shape: (H, W) shape to crop to (must fit within the input shape). If not provided, no
|
|
||||||
cropping is done.
|
|
||||||
random_crop: Whether the crop should be random at training time (it's always a center crop in
|
|
||||||
eval mode).
|
|
||||||
backbone_name: The name of one of the available resnet models from torchvision (eg resnet18).
|
|
||||||
pretrained_backbone: whether to use timm pretrained weights.
|
|
||||||
use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
|
|
||||||
The group sizes are set to be about 16 (to be precise, feature_dim // 16).
|
|
||||||
num_keypoints: Number of keypoints for SpatialSoftmax (default value of 32 matches PushT Image).
|
|
||||||
"""
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if input_shape[0] != 3:
|
|
||||||
raise ValueError("Only RGB images are handled")
|
|
||||||
if not backbone_name.startswith("resnet"):
|
|
||||||
raise ValueError(
|
|
||||||
"Only resnet is supported for now (because of the assumption that 'layer4' is the output layer)"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Set up optional preprocessing.
|
# Set up optional preprocessing.
|
||||||
if norm_mean_std == [1.0, 1.0]:
|
if all(v == 1.0 for v in chain(cfg.image_normalization_mean, cfg.image_normalization_std)):
|
||||||
self.normalizer = nn.Identity()
|
self.normalizer = nn.Identity()
|
||||||
else:
|
else:
|
||||||
self.normalizer = torchvision.transforms.Normalize(mean=norm_mean_std[0], std=norm_mean_std[1])
|
self.normalizer = torchvision.transforms.Normalize(
|
||||||
|
mean=cfg.image_normalization_mean, std=cfg.image_normalization_std
|
||||||
if crop_shape is not None:
|
)
|
||||||
|
if cfg.crop_shape is not None:
|
||||||
self.do_crop = True
|
self.do_crop = True
|
||||||
# Always use center crop for eval
|
# Always use center crop for eval
|
||||||
self.center_crop = torchvision.transforms.CenterCrop(crop_shape)
|
self.center_crop = torchvision.transforms.CenterCrop(cfg.crop_shape)
|
||||||
if random_crop:
|
if cfg.crop_is_random:
|
||||||
self.maybe_random_crop = torchvision.transforms.RandomCrop(crop_shape)
|
self.maybe_random_crop = torchvision.transforms.RandomCrop(cfg.crop_shape)
|
||||||
else:
|
else:
|
||||||
self.maybe_random_crop = self.center_crop
|
self.maybe_random_crop = self.center_crop
|
||||||
else:
|
else:
|
||||||
self.do_crop = False
|
self.do_crop = False
|
||||||
|
|
||||||
# Set up backbone.
|
# Set up backbone.
|
||||||
backbone_model = getattr(torchvision.models, backbone_name)(pretrained=pretrained_backbone)
|
backbone_model = getattr(torchvision.models, cfg.vision_backbone)(
|
||||||
|
pretrained=cfg.use_pretrained_backbone
|
||||||
|
)
|
||||||
# Note: This assumes that the layer4 feature map is children()[-3]
|
# Note: This assumes that the layer4 feature map is children()[-3]
|
||||||
# TODO(alexander-soare): Use a safer alternative.
|
# TODO(alexander-soare): Use a safer alternative.
|
||||||
self.backbone = nn.Sequential(*(list(backbone_model.children())[:-2]))
|
self.backbone = nn.Sequential(*(list(backbone_model.children())[:-2]))
|
||||||
if use_group_norm:
|
if cfg.use_group_norm:
|
||||||
if pretrained_backbone:
|
if cfg.use_pretrained_backbone:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"You can't replace BatchNorm in a pretrained model without ruining the weights!"
|
"You can't replace BatchNorm in a pretrained model without ruining the weights!"
|
||||||
)
|
)
|
||||||
|
@ -454,10 +370,10 @@ class _RgbEncoder(nn.Module):
|
||||||
# Set up pooling and final layers.
|
# Set up pooling and final layers.
|
||||||
# Use a dry run to get the feature map shape.
|
# Use a dry run to get the feature map shape.
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
feat_map_shape = tuple(self.backbone(torch.zeros(size=(1, *input_shape))).shape[1:])
|
feat_map_shape = tuple(self.backbone(torch.zeros(size=(1, 3, *cfg.image_size))).shape[1:])
|
||||||
self.pool = SpatialSoftmax(feat_map_shape, num_kp=num_keypoints)
|
self.pool = SpatialSoftmax(feat_map_shape, num_kp=cfg.spatial_softmax_num_keypoints)
|
||||||
self.feature_dim = num_keypoints * 2
|
self.feature_dim = cfg.spatial_softmax_num_keypoints * 2
|
||||||
self.out = nn.Linear(num_keypoints * 2, self.feature_dim)
|
self.out = nn.Linear(cfg.spatial_softmax_num_keypoints * 2, self.feature_dim)
|
||||||
self.relu = nn.ReLU()
|
self.relu = nn.ReLU()
|
||||||
|
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
@ -516,16 +432,18 @@ def _replace_submodules(
|
||||||
|
|
||||||
|
|
||||||
class _SinusoidalPosEmb(nn.Module):
|
class _SinusoidalPosEmb(nn.Module):
|
||||||
def __init__(self, dim):
|
"""1D sinusoidal positional embeddings as in Attention is All You Need."""
|
||||||
|
|
||||||
|
def __init__(self, dim: int):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
device = x.device
|
device = x.device
|
||||||
half_dim = self.dim // 2
|
half_dim = self.dim // 2
|
||||||
emb = math.log(10000) / (half_dim - 1)
|
emb = math.log(10000) / (half_dim - 1)
|
||||||
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
|
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
|
||||||
emb = x[:, None] * emb[None, :]
|
emb = x.unsqueeze(-1) * emb.unsqueeze(0)
|
||||||
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
||||||
return emb
|
return emb
|
||||||
|
|
||||||
|
@ -549,92 +467,46 @@ class _Conv1dBlock(nn.Module):
|
||||||
class _ConditionalUnet1D(nn.Module):
|
class _ConditionalUnet1D(nn.Module):
|
||||||
"""A 1D convolutional UNet with FiLM modulation for conditioning.
|
"""A 1D convolutional UNet with FiLM modulation for conditioning.
|
||||||
|
|
||||||
Two types of conditioning can be applied:
|
Note: this removes local conditioning as compared to the original diffusion policy code.
|
||||||
- Global: Conditioning information that is aggregated over the whole observation window. This is
|
|
||||||
incorporated via the FiLM technique in the residual convolution blocks of the Unet's encoder/decoder.
|
|
||||||
- Local: Conditioning information for each timestep in the observation window. This is incorporated
|
|
||||||
by encoding the information via 1D convolutions and adding the resulting embeddings to the inputs and
|
|
||||||
outputs of the Unet's encoder/decoder.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, cfg: DiffusionConfig, global_cond_dim: int):
|
||||||
self,
|
|
||||||
input_dim: int,
|
|
||||||
local_cond_dim: int | None = None,
|
|
||||||
global_cond_dim: int | None = None,
|
|
||||||
diffusion_step_embed_dim: int = 256,
|
|
||||||
down_dims: int | None = None,
|
|
||||||
kernel_size: int = 3,
|
|
||||||
n_groups: int = 8,
|
|
||||||
film_scale_modulation: bool = False,
|
|
||||||
):
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
if down_dims is None:
|
self.cfg = cfg
|
||||||
down_dims = [256, 512, 1024]
|
|
||||||
|
|
||||||
# Encoder for the diffusion timestep.
|
# Encoder for the diffusion timestep.
|
||||||
self.diffusion_step_encoder = nn.Sequential(
|
self.diffusion_step_encoder = nn.Sequential(
|
||||||
_SinusoidalPosEmb(diffusion_step_embed_dim),
|
_SinusoidalPosEmb(cfg.diffusion_step_embed_dim),
|
||||||
nn.Linear(diffusion_step_embed_dim, diffusion_step_embed_dim * 4),
|
nn.Linear(cfg.diffusion_step_embed_dim, cfg.diffusion_step_embed_dim * 4),
|
||||||
nn.Mish(),
|
nn.Mish(),
|
||||||
nn.Linear(diffusion_step_embed_dim * 4, diffusion_step_embed_dim),
|
nn.Linear(cfg.diffusion_step_embed_dim * 4, cfg.diffusion_step_embed_dim),
|
||||||
)
|
)
|
||||||
|
|
||||||
# The FiLM conditioning dimension.
|
# The FiLM conditioning dimension.
|
||||||
cond_dim = diffusion_step_embed_dim
|
cond_dim = cfg.diffusion_step_embed_dim + global_cond_dim
|
||||||
if global_cond_dim is not None:
|
|
||||||
cond_dim += global_cond_dim
|
|
||||||
|
|
||||||
self.local_cond_down_encoder = None
|
|
||||||
self.local_cond_up_encoder = None
|
|
||||||
if local_cond_dim is not None:
|
|
||||||
# Encoder for the local conditioning. The output gets added to the Unet encoder input.
|
|
||||||
self.local_cond_down_encoder = _ConditionalResidualBlock1D(
|
|
||||||
local_cond_dim,
|
|
||||||
down_dims[0],
|
|
||||||
cond_dim=cond_dim,
|
|
||||||
kernel_size=kernel_size,
|
|
||||||
n_groups=n_groups,
|
|
||||||
film_scale_modulation=film_scale_modulation,
|
|
||||||
)
|
|
||||||
# Encoder for the local conditioning. The output gets added to the Unet encoder output.
|
|
||||||
self.local_cond_up_encoder = _ConditionalResidualBlock1D(
|
|
||||||
local_cond_dim,
|
|
||||||
down_dims[0],
|
|
||||||
cond_dim=cond_dim,
|
|
||||||
kernel_size=kernel_size,
|
|
||||||
n_groups=n_groups,
|
|
||||||
film_scale_modulation=film_scale_modulation,
|
|
||||||
)
|
|
||||||
|
|
||||||
# In channels / out channels for each downsampling block in the Unet's encoder. For the decoder, we
|
# In channels / out channels for each downsampling block in the Unet's encoder. For the decoder, we
|
||||||
# just reverse these.
|
# just reverse these.
|
||||||
in_out = [(input_dim, down_dims[0])] + list(zip(down_dims[:-1], down_dims[1:], strict=True))
|
in_out = [(cfg.action_dim, cfg.down_dims[0])] + list(
|
||||||
|
zip(cfg.down_dims[:-1], cfg.down_dims[1:], strict=True)
|
||||||
|
)
|
||||||
|
|
||||||
# Unet encoder.
|
# Unet encoder.
|
||||||
|
common_res_block_kwargs = {
|
||||||
|
"cond_dim": cond_dim,
|
||||||
|
"kernel_size": cfg.kernel_size,
|
||||||
|
"n_groups": cfg.n_groups,
|
||||||
|
"use_film_scale_modulation": cfg.use_film_scale_modulation,
|
||||||
|
}
|
||||||
self.down_modules = nn.ModuleList([])
|
self.down_modules = nn.ModuleList([])
|
||||||
for ind, (dim_in, dim_out) in enumerate(in_out):
|
for ind, (dim_in, dim_out) in enumerate(in_out):
|
||||||
is_last = ind >= (len(in_out) - 1)
|
is_last = ind >= (len(in_out) - 1)
|
||||||
self.down_modules.append(
|
self.down_modules.append(
|
||||||
nn.ModuleList(
|
nn.ModuleList(
|
||||||
[
|
[
|
||||||
_ConditionalResidualBlock1D(
|
_ConditionalResidualBlock1D(dim_in, dim_out, **common_res_block_kwargs),
|
||||||
dim_in,
|
_ConditionalResidualBlock1D(dim_out, dim_out, **common_res_block_kwargs),
|
||||||
dim_out,
|
|
||||||
cond_dim=cond_dim,
|
|
||||||
kernel_size=kernel_size,
|
|
||||||
n_groups=n_groups,
|
|
||||||
film_scale_modulation=film_scale_modulation,
|
|
||||||
),
|
|
||||||
_ConditionalResidualBlock1D(
|
|
||||||
dim_out,
|
|
||||||
dim_out,
|
|
||||||
cond_dim=cond_dim,
|
|
||||||
kernel_size=kernel_size,
|
|
||||||
n_groups=n_groups,
|
|
||||||
film_scale_modulation=film_scale_modulation,
|
|
||||||
),
|
|
||||||
# Downsample as long as it is not the last block.
|
# Downsample as long as it is not the last block.
|
||||||
nn.Conv1d(dim_out, dim_out, 3, 2, 1) if not is_last else nn.Identity(),
|
nn.Conv1d(dim_out, dim_out, 3, 2, 1) if not is_last else nn.Identity(),
|
||||||
]
|
]
|
||||||
|
@ -644,22 +516,8 @@ class _ConditionalUnet1D(nn.Module):
|
||||||
# Processing in the middle of the auto-encoder.
|
# Processing in the middle of the auto-encoder.
|
||||||
self.mid_modules = nn.ModuleList(
|
self.mid_modules = nn.ModuleList(
|
||||||
[
|
[
|
||||||
_ConditionalResidualBlock1D(
|
_ConditionalResidualBlock1D(cfg.down_dims[-1], cfg.down_dims[-1], **common_res_block_kwargs),
|
||||||
down_dims[-1],
|
_ConditionalResidualBlock1D(cfg.down_dims[-1], cfg.down_dims[-1], **common_res_block_kwargs),
|
||||||
down_dims[-1],
|
|
||||||
cond_dim=cond_dim,
|
|
||||||
kernel_size=kernel_size,
|
|
||||||
n_groups=n_groups,
|
|
||||||
film_scale_modulation=film_scale_modulation,
|
|
||||||
),
|
|
||||||
_ConditionalResidualBlock1D(
|
|
||||||
down_dims[-1],
|
|
||||||
down_dims[-1],
|
|
||||||
cond_dim=cond_dim,
|
|
||||||
kernel_size=kernel_size,
|
|
||||||
n_groups=n_groups,
|
|
||||||
film_scale_modulation=film_scale_modulation,
|
|
||||||
),
|
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -670,22 +528,9 @@ class _ConditionalUnet1D(nn.Module):
|
||||||
self.up_modules.append(
|
self.up_modules.append(
|
||||||
nn.ModuleList(
|
nn.ModuleList(
|
||||||
[
|
[
|
||||||
_ConditionalResidualBlock1D(
|
# dim_in * 2, because it takes the encoder's skip connection as well
|
||||||
dim_in * 2, # x2 as it takes the encoder's skip connection as well
|
_ConditionalResidualBlock1D(dim_in * 2, dim_out, **common_res_block_kwargs),
|
||||||
dim_out,
|
_ConditionalResidualBlock1D(dim_out, dim_out, **common_res_block_kwargs),
|
||||||
cond_dim=cond_dim,
|
|
||||||
kernel_size=kernel_size,
|
|
||||||
n_groups=n_groups,
|
|
||||||
film_scale_modulation=film_scale_modulation,
|
|
||||||
),
|
|
||||||
_ConditionalResidualBlock1D(
|
|
||||||
dim_out,
|
|
||||||
dim_out,
|
|
||||||
cond_dim=cond_dim,
|
|
||||||
kernel_size=kernel_size,
|
|
||||||
n_groups=n_groups,
|
|
||||||
film_scale_modulation=film_scale_modulation,
|
|
||||||
),
|
|
||||||
# Upsample as long as it is not the last block.
|
# Upsample as long as it is not the last block.
|
||||||
nn.ConvTranspose1d(dim_out, dim_out, 4, 2, 1) if not is_last else nn.Identity(),
|
nn.ConvTranspose1d(dim_out, dim_out, 4, 2, 1) if not is_last else nn.Identity(),
|
||||||
]
|
]
|
||||||
|
@ -693,29 +538,22 @@ class _ConditionalUnet1D(nn.Module):
|
||||||
)
|
)
|
||||||
|
|
||||||
self.final_conv = nn.Sequential(
|
self.final_conv = nn.Sequential(
|
||||||
_Conv1dBlock(down_dims[0], down_dims[0], kernel_size=kernel_size),
|
_Conv1dBlock(cfg.down_dims[0], cfg.down_dims[0], kernel_size=cfg.kernel_size),
|
||||||
nn.Conv1d(down_dims[0], input_dim, 1),
|
nn.Conv1d(cfg.down_dims[0], cfg.action_dim, 1),
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x: Tensor, timestep: Tensor | int, local_cond=None, global_cond=None) -> Tensor:
|
def forward(self, x: Tensor, timestep: Tensor | int, global_cond=None) -> Tensor:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
x: (B, T, input_dim) tensor for input to the Unet.
|
x: (B, T, input_dim) tensor for input to the Unet.
|
||||||
timestep: (B,) tensor of (timestep_we_are_denoising_from - 1).
|
timestep: (B,) tensor of (timestep_we_are_denoising_from - 1).
|
||||||
local_cond: (B, T, local_cond_dim)
|
|
||||||
global_cond: (B, global_cond_dim)
|
global_cond: (B, global_cond_dim)
|
||||||
output: (B, T, input_dim)
|
output: (B, T, input_dim)
|
||||||
Returns:
|
Returns:
|
||||||
(B, T, input_dim)
|
(B, T, input_dim) diffusion model prediction.
|
||||||
"""
|
"""
|
||||||
# For 1D convolutions we'll need feature dimension first.
|
# For 1D convolutions we'll need feature dimension first.
|
||||||
x = einops.rearrange(x, "b t d -> b d t")
|
x = einops.rearrange(x, "b t d -> b d t")
|
||||||
if local_cond is not None:
|
|
||||||
if self.local_cond_down_encoder is None or self.local_cond_up_encoder is None:
|
|
||||||
raise ValueError(
|
|
||||||
"`local_cond` was provided but the relevant encoders weren't built at initialization."
|
|
||||||
)
|
|
||||||
local_cond = einops.rearrange(local_cond, "b t d -> b d t")
|
|
||||||
|
|
||||||
timesteps_embed = self.diffusion_step_encoder(timestep)
|
timesteps_embed = self.diffusion_step_encoder(timestep)
|
||||||
|
|
||||||
|
@ -725,11 +563,10 @@ class _ConditionalUnet1D(nn.Module):
|
||||||
else:
|
else:
|
||||||
global_feature = timesteps_embed
|
global_feature = timesteps_embed
|
||||||
|
|
||||||
|
# Run encoder, keeping track of skip features to pass to the decoder.
|
||||||
encoder_skip_features: list[Tensor] = []
|
encoder_skip_features: list[Tensor] = []
|
||||||
for idx, (resnet, resnet2, downsample) in enumerate(self.down_modules):
|
for resnet, resnet2, downsample in self.down_modules:
|
||||||
x = resnet(x, global_feature)
|
x = resnet(x, global_feature)
|
||||||
if idx == 0 and local_cond is not None:
|
|
||||||
x = x + self.local_cond_down_encoder(local_cond, global_feature)
|
|
||||||
x = resnet2(x, global_feature)
|
x = resnet2(x, global_feature)
|
||||||
encoder_skip_features.append(x)
|
encoder_skip_features.append(x)
|
||||||
x = downsample(x)
|
x = downsample(x)
|
||||||
|
@ -737,14 +574,10 @@ class _ConditionalUnet1D(nn.Module):
|
||||||
for mid_module in self.mid_modules:
|
for mid_module in self.mid_modules:
|
||||||
x = mid_module(x, global_feature)
|
x = mid_module(x, global_feature)
|
||||||
|
|
||||||
for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules):
|
# Run decoder, using the skip features from the encoder.
|
||||||
|
for resnet, resnet2, upsample in self.up_modules:
|
||||||
x = torch.cat((x, encoder_skip_features.pop()), dim=1)
|
x = torch.cat((x, encoder_skip_features.pop()), dim=1)
|
||||||
x = resnet(x, global_feature)
|
x = resnet(x, global_feature)
|
||||||
# Note: The condition in the original implementation is:
|
|
||||||
# if idx == len(self.up_modules) and local_cond is not None:
|
|
||||||
# But as they mention in their comments, this is incorrect. We use the correct condition here.
|
|
||||||
if idx == (len(self.up_modules) - 1) and local_cond is not None:
|
|
||||||
x = x + self.local_cond_up_encoder(local_cond, global_feature)
|
|
||||||
x = resnet2(x, global_feature)
|
x = resnet2(x, global_feature)
|
||||||
x = upsample(x)
|
x = upsample(x)
|
||||||
|
|
||||||
|
@ -766,17 +599,17 @@ class _ConditionalResidualBlock1D(nn.Module):
|
||||||
n_groups: int = 8,
|
n_groups: int = 8,
|
||||||
# Set to True to do scale modulation with FiLM as well as bias modulation (defaults to False meaning
|
# Set to True to do scale modulation with FiLM as well as bias modulation (defaults to False meaning
|
||||||
# FiLM just modulates bias).
|
# FiLM just modulates bias).
|
||||||
film_scale_modulation: bool = False,
|
use_film_scale_modulation: bool = False,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.film_scale_modulation = film_scale_modulation
|
self.use_film_scale_modulation = use_film_scale_modulation
|
||||||
self.out_channels = out_channels
|
self.out_channels = out_channels
|
||||||
|
|
||||||
self.conv1 = _Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups)
|
self.conv1 = _Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups)
|
||||||
|
|
||||||
# FiLM modulation (https://arxiv.org/abs/1709.07871) outputs per-channel bias and (maybe) scale.
|
# FiLM modulation (https://arxiv.org/abs/1709.07871) outputs per-channel bias and (maybe) scale.
|
||||||
cond_channels = out_channels * 2 if film_scale_modulation else out_channels
|
cond_channels = out_channels * 2 if use_film_scale_modulation else out_channels
|
||||||
self.cond_encoder = nn.Sequential(nn.Mish(), nn.Linear(cond_dim, cond_channels))
|
self.cond_encoder = nn.Sequential(nn.Mish(), nn.Linear(cond_dim, cond_channels))
|
||||||
|
|
||||||
self.conv2 = _Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups)
|
self.conv2 = _Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups)
|
||||||
|
@ -798,7 +631,7 @@ class _ConditionalResidualBlock1D(nn.Module):
|
||||||
|
|
||||||
# Get condition embedding. Unsqueeze for broadcasting to `out`, resulting in (B, out_channels, 1).
|
# Get condition embedding. Unsqueeze for broadcasting to `out`, resulting in (B, out_channels, 1).
|
||||||
cond_embed = self.cond_encoder(cond).unsqueeze(-1)
|
cond_embed = self.cond_encoder(cond).unsqueeze(-1)
|
||||||
if self.film_scale_modulation:
|
if self.use_film_scale_modulation:
|
||||||
# Treat the embedding as a list of scales and biases.
|
# Treat the embedding as a list of scales and biases.
|
||||||
scale = cond_embed[:, : self.out_channels]
|
scale = cond_embed[:, : self.out_channels]
|
||||||
bias = cond_embed[:, self.out_channels :]
|
bias = cond_embed[:, self.out_channels :]
|
||||||
|
@ -817,9 +650,7 @@ class _EMA:
|
||||||
Exponential Moving Average of models weights
|
Exponential Moving Average of models weights
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, cfg: DiffusionConfig, model: nn.Module):
|
||||||
self, model, update_after_step=0, inv_gamma=1.0, power=2 / 3, min_value=0.0, max_value=0.9999
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
@crowsonkb's notes on EMA Warmup:
|
@crowsonkb's notes on EMA Warmup:
|
||||||
If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
|
If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
|
||||||
|
@ -829,18 +660,18 @@ class _EMA:
|
||||||
Args:
|
Args:
|
||||||
inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
|
inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
|
||||||
power (float): Exponential factor of EMA warmup. Default: 2/3.
|
power (float): Exponential factor of EMA warmup. Default: 2/3.
|
||||||
min_value (float): The minimum EMA decay rate. Default: 0.
|
min_alpha (float): The minimum EMA decay rate. Default: 0.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self.averaged_model = model
|
self.averaged_model = model
|
||||||
self.averaged_model.eval()
|
self.averaged_model.eval()
|
||||||
self.averaged_model.requires_grad_(False)
|
self.averaged_model.requires_grad_(False)
|
||||||
|
|
||||||
self.update_after_step = update_after_step
|
self.update_after_step = cfg.ema_update_after_step
|
||||||
self.inv_gamma = inv_gamma
|
self.inv_gamma = cfg.ema_inv_gamma
|
||||||
self.power = power
|
self.power = cfg.ema_power
|
||||||
self.min_value = min_value
|
self.min_alpha = cfg.ema_min_alpha
|
||||||
self.max_value = max_value
|
self.max_alpha = cfg.ema_max_alpha
|
||||||
|
|
||||||
self.alpha = 0.0
|
self.alpha = 0.0
|
||||||
self.optimization_step = 0
|
self.optimization_step = 0
|
||||||
|
@ -855,7 +686,7 @@ class _EMA:
|
||||||
if step <= 0:
|
if step <= 0:
|
||||||
return 0.0
|
return 0.0
|
||||||
|
|
||||||
return max(self.min_value, min(value, self.max_value))
|
return max(self.min_alpha, min(value, self.max_alpha))
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def step(self, new_model):
|
def step(self, new_model):
|
||||||
|
|
|
@ -9,7 +9,7 @@ def _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg):
|
||||||
expected_kwargs = set(inspect.signature(policy_cfg_class).parameters)
|
expected_kwargs = set(inspect.signature(policy_cfg_class).parameters)
|
||||||
assert set(hydra_cfg.policy).issuperset(
|
assert set(hydra_cfg.policy).issuperset(
|
||||||
expected_kwargs
|
expected_kwargs
|
||||||
), f"Hydra config is missing arguments: {set(hydra_cfg.policy).difference(expected_kwargs)}"
|
), f"Hydra config is missing arguments: {set(expected_kwargs).difference(hydra_cfg.policy)}"
|
||||||
policy_cfg = policy_cfg_class(
|
policy_cfg = policy_cfg_class(
|
||||||
**{
|
**{
|
||||||
k: v
|
k: v
|
||||||
|
@ -35,7 +35,7 @@ def make_policy(hydra_cfg: DictConfig):
|
||||||
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||||
|
|
||||||
policy_cfg = _policy_cfg_from_hydra_cfg(DiffusionConfig, hydra_cfg)
|
policy_cfg = _policy_cfg_from_hydra_cfg(DiffusionConfig, hydra_cfg)
|
||||||
policy = DiffusionPolicy(policy_cfg)
|
policy = DiffusionPolicy(policy_cfg, hydra_cfg.offline_steps)
|
||||||
policy.to(get_safe_torch_device(hydra_cfg.device))
|
policy.to(get_safe_torch_device(hydra_cfg.device))
|
||||||
elif hydra_cfg.policy.name == "act":
|
elif hydra_cfg.policy.name == "act":
|
||||||
from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
|
from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
|
||||||
|
|
|
@ -44,7 +44,7 @@ policy:
|
||||||
# Vision backbone.
|
# Vision backbone.
|
||||||
vision_backbone: resnet18
|
vision_backbone: resnet18
|
||||||
crop_shape: [84, 84]
|
crop_shape: [84, 84]
|
||||||
random_crop: True
|
crop_is_random: True
|
||||||
use_pretrained_backbone: false
|
use_pretrained_backbone: false
|
||||||
use_group_norm: True
|
use_group_norm: True
|
||||||
spatial_softmax_num_keypoints: 32
|
spatial_softmax_num_keypoints: 32
|
||||||
|
@ -53,15 +53,15 @@ policy:
|
||||||
kernel_size: 5
|
kernel_size: 5
|
||||||
n_groups: 8
|
n_groups: 8
|
||||||
diffusion_step_embed_dim: 128
|
diffusion_step_embed_dim: 128
|
||||||
film_scale_modulation: True
|
use_film_scale_modulation: True
|
||||||
# Noise scheduler.
|
# Noise scheduler.
|
||||||
num_train_timesteps: 100
|
num_train_timesteps: 100
|
||||||
beta_schedule: squaredcos_cap_v2
|
beta_schedule: squaredcos_cap_v2
|
||||||
beta_start: 0.0001
|
beta_start: 0.0001
|
||||||
beta_end: 0.02
|
beta_end: 0.02
|
||||||
variance_type: fixed_small
|
|
||||||
prediction_type: epsilon # epsilon / sample
|
prediction_type: epsilon # epsilon / sample
|
||||||
clip_sample: True
|
clip_sample: True
|
||||||
|
clip_sample_range: 1.0
|
||||||
|
|
||||||
# Inference
|
# Inference
|
||||||
num_inference_steps: 100
|
num_inference_steps: 100
|
||||||
|
@ -79,12 +79,12 @@ policy:
|
||||||
utd: 1
|
utd: 1
|
||||||
use_ema: true
|
use_ema: true
|
||||||
ema_update_after_step: 0
|
ema_update_after_step: 0
|
||||||
ema_min_rate: 0.0
|
ema_min_alpha: 0.0
|
||||||
ema_max_rate: 0.9999
|
ema_max_alpha: 0.9999
|
||||||
ema_inv_gamma: 1.0
|
ema_inv_gamma: 1.0
|
||||||
ema_power: 0.75
|
ema_power: 0.75
|
||||||
|
|
||||||
delta_timestamps:
|
delta_timestamps:
|
||||||
observation.images: "[i / ${fps} for i in range(1 - ${n_obs_steps}, 1)]"
|
observation.image: "[i / ${fps} for i in range(1 - ${n_obs_steps}, 1)]"
|
||||||
observation.state: "[i / ${fps} for i in range(1 - ${n_obs_steps}, 1)]"
|
observation.state: "[i / ${fps} for i in range(1 - ${n_obs_steps}, 1)]"
|
||||||
action: "[i / ${fps} for i in range(1 - ${n_obs_steps}, 1 - ${n_obs_steps} + ${policy.horizon})]"
|
action: "[i / ${fps} for i in range(1 - ${n_obs_steps}, 1 - ${n_obs_steps} + ${policy.horizon})]"
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
def _find_and_replace(text: str, finds: list[str], replaces: list[str]) -> str:
|
def _find_and_replace(text: str, finds_and_replaces: list[tuple[str, str]]) -> str:
|
||||||
for f, r in zip(finds, replaces):
|
for f, r in finds_and_replaces:
|
||||||
assert f in text
|
assert f in text
|
||||||
text = text.replace(f, r)
|
text = text.replace(f, r)
|
||||||
return text
|
return text
|
||||||
|
@ -32,8 +32,10 @@ def test_examples_3_and_2():
|
||||||
# Do less steps and use CPU.
|
# Do less steps and use CPU.
|
||||||
file_contents = _find_and_replace(
|
file_contents = _find_and_replace(
|
||||||
file_contents,
|
file_contents,
|
||||||
['"offline_steps=5000"', '"device=cuda"'],
|
[
|
||||||
['"offline_steps=1"', '"device=cpu"'],
|
("offline_steps = 5000", "offline_steps = 1"),
|
||||||
|
('device = torch.device("cuda")', 'device = torch.device("cpu")'),
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
exec(file_contents)
|
exec(file_contents)
|
||||||
|
@ -50,20 +52,15 @@ def test_examples_3_and_2():
|
||||||
file_contents = _find_and_replace(
|
file_contents = _find_and_replace(
|
||||||
file_contents,
|
file_contents,
|
||||||
[
|
[
|
||||||
'"eval_episodes=10"',
|
('"eval_episodes=10"', '"eval_episodes=1"'),
|
||||||
'"rollout_batch_size=10"',
|
('"rollout_batch_size=10"', '"rollout_batch_size=1"'),
|
||||||
'"device=cuda"',
|
('"device=cuda"', '"device=cpu"'),
|
||||||
'# folder = Path("outputs/train/example_pusht_diffusion")',
|
(
|
||||||
'hub_id = "lerobot/diffusion_policy_pusht_image"',
|
'# folder = Path("outputs/train/example_pusht_diffusion")',
|
||||||
"folder = Path(snapshot_download(hub_id)",
|
'folder = Path("outputs/train/example_pusht_diffusion")',
|
||||||
],
|
),
|
||||||
[
|
('hub_id = "lerobot/diffusion_policy_pusht_image"', ""),
|
||||||
'"eval_episodes=1"',
|
("folder = Path(snapshot_download(hub_id)", ""),
|
||||||
'"rollout_batch_size=1"',
|
|
||||||
'"device=cpu"',
|
|
||||||
'folder = Path("outputs/train/example_pusht_diffusion")',
|
|
||||||
"",
|
|
||||||
"",
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue