backup wip

This commit is contained in:
Alexander Soare 2024-04-11 17:51:35 +01:00
parent 91ff69d64c
commit 976a197f98
26 changed files with 661 additions and 2733 deletions

View File

@@ -1,3 +1,8 @@
"""Helper code for loading PushT dataset from Diffusion Policy (https://diffusion-policy.cs.columbia.edu/)
Copied from the original Diffusion Policy repository.
"""
from __future__ import annotations
import math

View File

@@ -8,8 +8,10 @@ import torch
import tqdm
from gym_pusht.envs.pusht import pymunk_to_shapely
from lerobot.common.datasets._diffusion_policy_replay_buffer import (
ReplayBuffer as DiffusionPolicyReplayBuffer,
)
from lerobot.common.datasets.utils import download_and_extract_zip, load_data_with_delta_timestamps
from lerobot.common.policies.diffusion.replay_buffer import ReplayBuffer as DiffusionPolicyReplayBuffer
# as defined in the env
SUCCESS_THRESHOLD = 0.95 # 95% coverage,

View File

@@ -176,7 +176,8 @@ class ActionChunkingTransformerPolicy(nn.Module):
if self.n_action_steps is not None:
self._action_queue = deque([], maxlen=self.n_action_steps)
def select_action(self, batch: dict[str, Tensor], *_, **__) -> Tensor:
@torch.no_grad
def select_action(self, batch: dict[str, Tensor], **_) -> Tensor:
"""
This method wraps `select_actions` in order to return one action at a time for execution in the
environment. It works by managing the actions in a queue and only calling `select_actions` when the
@@ -188,7 +189,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
self._action_queue.extend(self.select_actions(batch).transpose(0, 1))
return self._action_queue.popleft()
@torch.no_grad()
@torch.no_grad
def select_actions(self, batch: dict[str, Tensor]) -> Tensor:
"""Use the action chunking transformer to generate a sequence of actions."""
self.eval()
@@ -223,8 +224,6 @@ class ActionChunkingTransformerPolicy(nn.Module):
{
"observation.state": (B, 1, J) OR (B, J) tensor of robot states (joint configuration).
"observation.images.top": (B, 1, C, H, W) OR (B, C, H, W) tensor of images.
"action": (B, H, J) tensor of actions (positional target for robot joint configuration)
"action_is_pad": (B, H) mask for whether the actions are padding outside of the episode bounds.
}
"""
if add_obs_steps_dim:
@@ -244,7 +243,8 @@ class ActionChunkingTransformerPolicy(nn.Module):
# Note: no squeeze is required for "observation.images.top" because then we'd have to unsqueeze to get
# the image index dimension.
def update(self, batch, *_, **__) -> dict:
def update(self, batch, **_) -> dict:
"""Run the model in train mode, compute the loss, and do an optimization step."""
start_time = time.time()
self._preprocess_batch(batch)
@@ -278,6 +278,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
return info
def forward(self, batch: dict[str, Tensor], return_loss: bool = False) -> dict | Tensor:
"""A forward pass through the DNN part of this policy with optional loss computation."""
images = self.image_normalizer(batch["observation.images.top"])
if return_loss: # training time

View File

@@ -1,315 +0,0 @@
"""Code from the original diffusion policy project.
Notes on how to load a checkpoint from the original repository:
In the original repository, run the eval and use a breakpoint to extract the policy weights.
```
torch.save(policy.state_dict(), "weights.pt")
```
In this repository, add a breakpoint somewhere after creating an equivalent policy and load in the weights:
```
loaded = torch.load("weights.pt")
aligned = {}
their_prefix = "obs_encoder.obs_nets.image.backbone"
our_prefix = "obs_encoder.key_model_map.image.backbone"
aligned.update({our_prefix + k.removeprefix(their_prefix): v for k, v in loaded.items() if k.startswith(their_prefix)})
their_prefix = "obs_encoder.obs_nets.image.pool"
our_prefix = "obs_encoder.key_model_map.image.pool"
aligned.update({our_prefix + k.removeprefix(their_prefix): v for k, v in loaded.items() if k.startswith(their_prefix)})
their_prefix = "obs_encoder.obs_nets.image.nets.3"
our_prefix = "obs_encoder.key_model_map.image.out"
aligned.update({our_prefix + k.removeprefix(their_prefix): v for k, v in loaded.items() if k.startswith(their_prefix)})
aligned.update({k: v for k, v in loaded.items() if k.startswith('model.')})
# Note: here you are loading into the ema model.
missing_keys, unexpected_keys = policy.ema_diffusion.load_state_dict(aligned, strict=False)
assert all('_dummy_variable' in k for k in missing_keys)
assert len(unexpected_keys) == 0
```
Then in that same runtime you can also save the weights with the new aligned state_dict:
```
policy.save("weights.pt")
```
Now you can remove the breakpoint and extra code and load in the weights just like with any other lerobot checkpoint.
"""
from typing import Dict
import torch
import torch.nn.functional as F # noqa: N812
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from einops import reduce
from lerobot.common.policies.diffusion.model.conditional_unet1d import ConditionalUnet1D
from lerobot.common.policies.diffusion.model.mask_generator import LowdimMaskGenerator
from lerobot.common.policies.diffusion.model.module_attr_mixin import ModuleAttrMixin
from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder
from lerobot.common.policies.diffusion.model.normalizer import LinearNormalizer
from lerobot.common.policies.diffusion.pytorch_utils import dict_apply
class BaseImagePolicy(ModuleAttrMixin):
# init accepts keyword argument shape_meta, see config/task/*_image.yaml
def predict_action(self, obs_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""
obs_dict:
str: B,To,*
return: B,Ta,Da
"""
raise NotImplementedError()
# reset state for stateful policies
def reset(self):
pass
# ========== training ===========
# no standard training interface except setting normalizer
def set_normalizer(self, normalizer: LinearNormalizer):
raise NotImplementedError()
class DiffusionUnetImagePolicy(BaseImagePolicy):
def __init__(
self,
shape_meta: dict,
noise_scheduler: DDPMScheduler,
obs_encoder: MultiImageObsEncoder,
horizon,
n_action_steps,
n_obs_steps,
num_inference_steps=None,
obs_as_global_cond=True,
diffusion_step_embed_dim=256,
down_dims=(256, 512, 1024),
kernel_size=5,
n_groups=8,
cond_predict_scale=True,
# parameters passed to step
**kwargs,
):
super().__init__()
# parse shapes
action_shape = shape_meta["action"]["shape"]
assert len(action_shape) == 1
action_dim = action_shape[0]
# get feature dim
obs_feature_dim = obs_encoder.output_shape()[0]
# create diffusion model
input_dim = action_dim + obs_feature_dim
global_cond_dim = None
if obs_as_global_cond:
input_dim = action_dim
global_cond_dim = obs_feature_dim * n_obs_steps
model = ConditionalUnet1D(
input_dim=input_dim,
local_cond_dim=None,
global_cond_dim=global_cond_dim,
diffusion_step_embed_dim=diffusion_step_embed_dim,
down_dims=down_dims,
kernel_size=kernel_size,
n_groups=n_groups,
cond_predict_scale=cond_predict_scale,
)
self.obs_encoder = obs_encoder
self.model = model
self.noise_scheduler = noise_scheduler
self.mask_generator = LowdimMaskGenerator(
action_dim=action_dim,
obs_dim=0 if obs_as_global_cond else obs_feature_dim,
max_n_obs_steps=n_obs_steps,
fix_obs_steps=True,
action_visible=False,
)
self.horizon = horizon
self.obs_feature_dim = obs_feature_dim
self.action_dim = action_dim
self.n_action_steps = n_action_steps
self.n_obs_steps = n_obs_steps
self.obs_as_global_cond = obs_as_global_cond
self.kwargs = kwargs
if num_inference_steps is None:
num_inference_steps = noise_scheduler.config.num_train_timesteps
self.num_inference_steps = num_inference_steps
# ========= inference ============
def conditional_sample(
self,
condition_data,
condition_mask,
local_cond=None,
global_cond=None,
generator=None,
# keyword arguments to scheduler.step
**kwargs,
):
model = self.model
scheduler = self.noise_scheduler
trajectory = torch.randn(
size=condition_data.shape,
dtype=condition_data.dtype,
device=condition_data.device,
generator=generator,
)
# set step values
scheduler.set_timesteps(self.num_inference_steps)
for t in scheduler.timesteps:
# 1. apply conditioning
trajectory[condition_mask] = condition_data[condition_mask]
# 2. predict model output
model_output = model(trajectory, t, local_cond=local_cond, global_cond=global_cond)
# 3. compute previous image: x_t -> x_t-1
trajectory = scheduler.step(
model_output,
t,
trajectory,
generator=generator,
# **kwargs # TODO(rcadene): in diffusion_policy, expected to be {}
).prev_sample
# finally make sure conditioning is enforced
trajectory[condition_mask] = condition_data[condition_mask]
return trajectory
def predict_action(self, obs_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""
obs_dict: must include "obs" key
result: must include "action" key
"""
assert "past_action" not in obs_dict # not implemented yet
nobs = obs_dict
value = next(iter(nobs.values()))
bsize, n_obs_steps = value.shape[:2]
horizon = self.horizon
action_dim = self.action_dim
obs_dim = self.obs_feature_dim
assert self.n_obs_steps == n_obs_steps
# build input
device = self.device
dtype = self.dtype
# handle different ways of passing observation
local_cond = None
global_cond = None
if self.obs_as_global_cond:
# condition through global feature
this_nobs = dict_apply(nobs, lambda x: x[:, :n_obs_steps, ...].reshape(-1, *x.shape[2:]))
nobs_features = self.obs_encoder(this_nobs)
# reshape back to B, Do
global_cond = nobs_features.reshape(bsize, -1)
# empty data for action
cond_data = torch.zeros(size=(bsize, horizon, action_dim), device=device, dtype=dtype)
cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
else:
# condition through inpainting
this_nobs = dict_apply(nobs, lambda x: x[:, :n_obs_steps, ...].reshape(-1, *x.shape[2:]))
nobs_features = self.obs_encoder(this_nobs)
# reshape back to B, T, Do
nobs_features = nobs_features.reshape(bsize, n_obs_steps, -1)
cond_data = torch.zeros(size=(bsize, horizon, action_dim + obs_dim), device=device, dtype=dtype)
cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
cond_data[:, :n_obs_steps, action_dim:] = nobs_features
cond_mask[:, :n_obs_steps, action_dim:] = True
# run sampling
nsample = self.conditional_sample(
cond_data, cond_mask, local_cond=local_cond, global_cond=global_cond
)
action_pred = nsample[..., :action_dim]
# get action
start = n_obs_steps - 1
end = start + self.n_action_steps
action = action_pred[:, start:end]
result = {"action": action, "action_pred": action_pred}
return result
def compute_loss(self, batch):
nobs = {
"image": batch["observation.image"],
"agent_pos": batch["observation.state"],
}
nactions = batch["action"]
batch_size = nactions.shape[0]
horizon = nactions.shape[1]
# handle different ways of passing observation
local_cond = None
global_cond = None
trajectory = nactions
cond_data = trajectory
if self.obs_as_global_cond:
# reshape B, T, ... to B*T
this_nobs = dict_apply(nobs, lambda x: x[:, : self.n_obs_steps, ...].reshape(-1, *x.shape[2:]))
nobs_features = self.obs_encoder(this_nobs)
# reshape back to B, Do
global_cond = nobs_features.reshape(batch_size, -1)
else:
# reshape B, T, ... to B*T
this_nobs = dict_apply(nobs, lambda x: x.reshape(-1, *x.shape[2:]))
nobs_features = self.obs_encoder(this_nobs)
# reshape back to B, T, Do
nobs_features = nobs_features.reshape(batch_size, horizon, -1)
cond_data = torch.cat([nactions, nobs_features], dim=-1)
trajectory = cond_data.detach()
# generate inpainting mask
condition_mask = self.mask_generator(trajectory.shape)
# Sample noise that we'll add to the images
noise = torch.randn(trajectory.shape, device=trajectory.device)
bsz = trajectory.shape[0]
# Sample a random timestep for each image
timesteps = torch.randint(
0, self.noise_scheduler.config.num_train_timesteps, (bsz,), device=trajectory.device
).long()
# Add noise to the clean images according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_trajectory = self.noise_scheduler.add_noise(trajectory, noise, timesteps)
# compute loss mask
loss_mask = ~condition_mask
# apply conditioning
noisy_trajectory[condition_mask] = cond_data[condition_mask]
# Predict the noise residual
pred = self.model(noisy_trajectory, timesteps, local_cond=local_cond, global_cond=global_cond)
pred_type = self.noise_scheduler.config.prediction_type
if pred_type == "epsilon":
target = noise
elif pred_type == "sample":
target = trajectory
else:
raise ValueError(f"Unsupported prediction type {pred_type}")
loss = F.mse_loss(pred, target, reduction="none")
loss = loss * loss_mask.type(loss.dtype)
if "action_is_pad" in batch:
in_episode_bound = ~batch["action_is_pad"]
loss = loss * in_episode_bound[:, :, None].type(loss.dtype)
loss = reduce(loss, "b t c -> b", "mean", b=batch_size)
loss = loss.mean()
return loss
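The training path of `compute_loss` above hinges on the forward diffusion step and a masked regression target. A compressed sketch under illustrative shapes, using the same `DDPMScheduler` API (the random `pred` is a stand-in for the UNet output):

```python
import torch
import torch.nn.functional as F  # noqa: N812
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler

noise_scheduler = DDPMScheduler(num_train_timesteps=100, prediction_type="epsilon")

trajectory = torch.randn(8, 16, 2)  # (batch, horizon, action_dim), illustrative shapes
noise = torch.randn_like(trajectory)
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (8,)).long()

# Forward diffusion: mix the clean actions with noise according to the timestep schedule.
noisy_trajectory = noise_scheduler.add_noise(trajectory, noise, timesteps)

# With "epsilon" prediction the network regresses the injected noise; with "sample" it
# would regress the clean trajectory instead.
pred = torch.randn_like(noisy_trajectory)  # stand-in for the UNet output
target = noise if noise_scheduler.config.prediction_type == "epsilon" else trajectory

# Mask out padded actions before averaging, as done with `action_is_pad` above.
action_is_pad = torch.zeros(8, 16, dtype=torch.bool)
loss = (F.mse_loss(pred, target, reduction="none") * ~action_is_pad[:, :, None]).mean()
```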

View File

@@ -1,286 +1,307 @@
import logging
from typing import Union
import math
import einops
import torch
import torch.nn as nn
from einops.layers.torch import Rearrange
from lerobot.common.policies.diffusion.model.conv1d_components import Conv1dBlock, Downsample1d, Upsample1d
from lerobot.common.policies.diffusion.model.positional_embedding import SinusoidalPosEmb
from torch import Tensor
logger = logging.getLogger(__name__)
class ConditionalResidualBlock1D(nn.Module):
class _SinusoidalPosEmb(nn.Module):
# TODO(now): consolidate?
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, x):
device = x.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
emb = x[:, None] * emb[None, :]
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
class _Conv1dBlock(nn.Module):
"""Conv1d --> GroupNorm --> Mish"""
def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
super().__init__()
self.block = nn.Sequential(
nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
nn.GroupNorm(n_groups, out_channels),
nn.Mish(),
)
def forward(self, x):
return self.block(x)
class _ConditionalResidualBlock1D(nn.Module):
"""ResNet style 1D convolutional block with FiLM modulation for conditioning."""
def __init__(
self, in_channels, out_channels, cond_dim, kernel_size=3, n_groups=8, cond_predict_scale=False
self,
in_channels: int,
out_channels: int,
cond_dim: int,
kernel_size: int = 3,
n_groups: int = 8,
# Set to True to do scale modulation with FiLM as well as bias modulation (defaults to False meaning
# FiLM just modulates bias).
film_scale_modulation: bool = False,
):
super().__init__()
self.blocks = nn.ModuleList(
[
Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups),
Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups),
]
)
# FiLM modulation https://arxiv.org/abs/1709.07871
# predicts per-channel scale and bias
cond_channels = out_channels
if cond_predict_scale:
cond_channels = out_channels * 2
self.cond_predict_scale = cond_predict_scale
self.film_scale_modulation = film_scale_modulation
self.out_channels = out_channels
self.cond_encoder = nn.Sequential(
nn.Mish(),
nn.Linear(cond_dim, cond_channels),
Rearrange("batch t -> batch t 1"),
)
# make sure dimensions compatible
self.conv1 = _Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups)
# FiLM modulation (https://arxiv.org/abs/1709.07871) outputs per-channel bias and (maybe) scale.
cond_channels = out_channels * 2 if film_scale_modulation else out_channels
self.cond_encoder = nn.Sequential(nn.Mish(), nn.Linear(cond_dim, cond_channels))
self.conv2 = _Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups)
# A final convolution for dimension matching the residual (if needed).
self.residual_conv = (
nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
)
def forward(self, x, cond):
def forward(self, x: Tensor, cond: Tensor) -> Tensor:
"""
x : [ batch_size x in_channels x horizon ]
cond : [ batch_size x cond_dim]
Args:
x: (B, in_channels, T)
cond: (B, cond_dim)
Returns:
(B, out_channels, T)
"""
out = self.conv1(x)
returns:
out : [ batch_size x out_channels x horizon ]
"""
out = self.blocks[0](x)
embed = self.cond_encoder(cond)
if self.cond_predict_scale:
embed = embed.reshape(embed.shape[0], 2, self.out_channels, 1)
scale = embed[:, 0, ...]
bias = embed[:, 1, ...]
# Get condition embedding. Unsqueeze for broadcasting to `out`, resulting in (B, out_channels, 1).
cond_embed = self.cond_encoder(cond).unsqueeze(-1)
if self.film_scale_modulation:
# Treat the embedding as a list of scales and biases.
scale = cond_embed[:, : self.out_channels]
bias = cond_embed[:, self.out_channels :]
out = scale * out + bias
else:
out = out + embed
out = self.blocks[1](out)
# Treat the embedding as biases.
out = out + cond_embed
out = self.conv2(out)
out = out + self.residual_conv(x)
return out
class ConditionalUnet1D(nn.Module):
"""A 1D convolutional UNet with FiLM modulation for conditioning.
Two types of conditioning can be applied:
- Global: Conditioning information that is aggregated over the whole observation window. This is
incorporated via the FiLM technique in the residual convolution blocks of the Unet's encoder/decoder.
- Local: Conditioning information for each timestep in the observation window. This is incorporated
by encoding the information via 1D convolutions and adding the resulting embeddings to the inputs and
outputs of the Unet's encoder/decoder.
"""
def __init__(
self,
input_dim,
local_cond_dim=None,
global_cond_dim=None,
diffusion_step_embed_dim=256,
down_dims=None,
kernel_size=3,
n_groups=8,
cond_predict_scale=False,
input_dim: int,
local_cond_dim: int | None = None,
global_cond_dim: int | None = None,
diffusion_step_embed_dim: int = 256,
down_dims: int | None = None,
kernel_size: int = 3,
n_groups: int = 8,
film_scale_modulation: bool = False,
):
super().__init__()
if down_dims is None:
down_dims = [256, 512, 1024]
all_dims = [input_dim] + list(down_dims)
start_dim = down_dims[0]
dsed = diffusion_step_embed_dim
diffusion_step_encoder = nn.Sequential(
SinusoidalPosEmb(dsed),
nn.Linear(dsed, dsed * 4),
# Encoder for the diffusion timestep.
self.diffusion_step_encoder = nn.Sequential(
_SinusoidalPosEmb(diffusion_step_embed_dim),
nn.Linear(diffusion_step_embed_dim, diffusion_step_embed_dim * 4),
nn.Mish(),
nn.Linear(dsed * 4, dsed),
nn.Linear(diffusion_step_embed_dim * 4, diffusion_step_embed_dim),
)
cond_dim = dsed
# The FiLM conditioning dimension.
cond_dim = diffusion_step_embed_dim
if global_cond_dim is not None:
cond_dim += global_cond_dim
in_out = list(zip(all_dims[:-1], all_dims[1:], strict=False))
local_cond_encoder = None
self.local_cond_down_encoder = None
self.local_cond_up_encoder = None
if local_cond_dim is not None:
_, dim_out = in_out[0]
dim_in = local_cond_dim
local_cond_encoder = nn.ModuleList(
[
# down encoder
ConditionalResidualBlock1D(
dim_in,
dim_out,
# Encoder for the local conditioning. The output gets added to the Unet encoder input.
self.local_cond_down_encoder = _ConditionalResidualBlock1D(
local_cond_dim,
down_dims[0],
cond_dim=cond_dim,
kernel_size=kernel_size,
n_groups=n_groups,
cond_predict_scale=cond_predict_scale,
),
# up encoder
ConditionalResidualBlock1D(
dim_in,
dim_out,
film_scale_modulation=film_scale_modulation,
)
# Encoder for the local conditioning. The output gets added to the Unet encoder output.
self.local_cond_up_encoder = _ConditionalResidualBlock1D(
local_cond_dim,
down_dims[0],
cond_dim=cond_dim,
kernel_size=kernel_size,
n_groups=n_groups,
cond_predict_scale=cond_predict_scale,
),
]
film_scale_modulation=film_scale_modulation,
)
mid_dim = all_dims[-1]
self.mid_modules = nn.ModuleList(
[
ConditionalResidualBlock1D(
mid_dim,
mid_dim,
cond_dim=cond_dim,
kernel_size=kernel_size,
n_groups=n_groups,
cond_predict_scale=cond_predict_scale,
),
ConditionalResidualBlock1D(
mid_dim,
mid_dim,
cond_dim=cond_dim,
kernel_size=kernel_size,
n_groups=n_groups,
cond_predict_scale=cond_predict_scale,
),
]
)
# In channels / out channels for each downsampling block in the Unet's encoder. For the decoder, we
# just reverse these.
in_out = [(input_dim, down_dims[0])] + list(zip(down_dims[:-1], down_dims[1:], strict=True))
down_modules = nn.ModuleList([])
# Unet encoder.
self.down_modules = nn.ModuleList([])
for ind, (dim_in, dim_out) in enumerate(in_out):
is_last = ind >= (len(in_out) - 1)
down_modules.append(
self.down_modules.append(
nn.ModuleList(
[
ConditionalResidualBlock1D(
_ConditionalResidualBlock1D(
dim_in,
dim_out,
cond_dim=cond_dim,
kernel_size=kernel_size,
n_groups=n_groups,
cond_predict_scale=cond_predict_scale,
film_scale_modulation=film_scale_modulation,
),
ConditionalResidualBlock1D(
_ConditionalResidualBlock1D(
dim_out,
dim_out,
cond_dim=cond_dim,
kernel_size=kernel_size,
n_groups=n_groups,
cond_predict_scale=cond_predict_scale,
film_scale_modulation=film_scale_modulation,
),
Downsample1d(dim_out) if not is_last else nn.Identity(),
# Downsample as long as it is not the last block.
nn.Conv1d(dim_out, dim_out, 3, 2, 1) if not is_last else nn.Identity(),
]
)
)
up_modules = nn.ModuleList([])
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
# Processing in the middle of the auto-encoder.
self.mid_modules = nn.ModuleList(
[
_ConditionalResidualBlock1D(
down_dims[-1],
down_dims[-1],
cond_dim=cond_dim,
kernel_size=kernel_size,
n_groups=n_groups,
film_scale_modulation=film_scale_modulation,
),
_ConditionalResidualBlock1D(
down_dims[-1],
down_dims[-1],
cond_dim=cond_dim,
kernel_size=kernel_size,
n_groups=n_groups,
film_scale_modulation=film_scale_modulation,
),
]
)
# Unet decoder.
self.up_modules = nn.ModuleList([])
for ind, (dim_out, dim_in) in enumerate(reversed(in_out[1:])):
is_last = ind >= (len(in_out) - 1)
up_modules.append(
self.up_modules.append(
nn.ModuleList(
[
ConditionalResidualBlock1D(
dim_out * 2,
dim_in,
_ConditionalResidualBlock1D(
dim_in * 2, # x2 as it takes the encoder's skip connection as well
dim_out,
cond_dim=cond_dim,
kernel_size=kernel_size,
n_groups=n_groups,
cond_predict_scale=cond_predict_scale,
film_scale_modulation=film_scale_modulation,
),
ConditionalResidualBlock1D(
dim_in,
dim_in,
_ConditionalResidualBlock1D(
dim_out,
dim_out,
cond_dim=cond_dim,
kernel_size=kernel_size,
n_groups=n_groups,
cond_predict_scale=cond_predict_scale,
film_scale_modulation=film_scale_modulation,
),
Upsample1d(dim_in) if not is_last else nn.Identity(),
# Upsample as long as it is not the last block.
nn.ConvTranspose1d(dim_out, dim_out, 4, 2, 1) if not is_last else nn.Identity(),
]
)
)
final_conv = nn.Sequential(
Conv1dBlock(start_dim, start_dim, kernel_size=kernel_size),
nn.Conv1d(start_dim, input_dim, 1),
self.final_conv = nn.Sequential(
_Conv1dBlock(down_dims[0], down_dims[0], kernel_size=kernel_size),
nn.Conv1d(down_dims[0], input_dim, 1),
)
self.diffusion_step_encoder = diffusion_step_encoder
self.local_cond_encoder = local_cond_encoder
self.up_modules = up_modules
self.down_modules = down_modules
self.final_conv = final_conv
logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
def forward(
self,
sample: torch.Tensor,
timestep: Union[torch.Tensor, float, int],
local_cond=None,
global_cond=None,
**kwargs,
):
def forward(self, x: Tensor, timestep: Tensor | int, local_cond=None, global_cond=None) -> Tensor:
"""
x: (B,T,input_dim)
timestep: (B,) or int, diffusion step
local_cond: (B,T,local_cond_dim)
global_cond: (B,global_cond_dim)
output: (B,T,input_dim)
Args:
x: (B, T, input_dim) tensor for input to the Unet.
timestep: (B,) tensor of (timestep_we_are_denoising_from - 1).
local_cond: (B, T, local_cond_dim)
global_cond: (B, global_cond_dim)
output: (B, T, input_dim)
Returns:
(B, T, input_dim)
"""
sample = einops.rearrange(sample, "b h t -> b t h")
# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps.expand(sample.shape[0])
global_feature = self.diffusion_step_encoder(timesteps)
if global_cond is not None:
global_feature = torch.cat([global_feature, global_cond], axis=-1)
# encode local features
h_local = []
# For 1D convolutions we'll need feature dimension first.
x = einops.rearrange(x, "b t d -> b d t")
if local_cond is not None:
local_cond = einops.rearrange(local_cond, "b h t -> b t h")
resnet, resnet2 = self.local_cond_encoder
x = resnet(local_cond, global_feature)
h_local.append(x)
x = resnet2(local_cond, global_feature)
h_local.append(x)
if self.local_cond_down_encoder is None or self.local_cond_up_encoder is None:
raise ValueError(
"`local_cond` was provided but the relevant encoders weren't built at initialization."
)
local_cond = einops.rearrange(local_cond, "b t d -> b d t")
x = sample
h = []
timesteps_embed = self.diffusion_step_encoder(timestep)
# If there is a global conditioning feature, concatenate it to the timestep embedding.
if global_cond is not None:
global_feature = torch.cat([timesteps_embed, global_cond], axis=-1)
else:
global_feature = timesteps_embed
encoder_skip_features: list[Tensor] = []
for idx, (resnet, resnet2, downsample) in enumerate(self.down_modules):
x = resnet(x, global_feature)
if idx == 0 and len(h_local) > 0:
x = x + h_local[0]
if idx == 0 and local_cond is not None:
x = x + self.local_cond_down_encoder(local_cond, global_feature)
x = resnet2(x, global_feature)
h.append(x)
encoder_skip_features.append(x)
x = downsample(x)
for mid_module in self.mid_modules:
x = mid_module(x, global_feature)
for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules):
x = torch.cat((x, h.pop()), dim=1)
x = torch.cat((x, encoder_skip_features.pop()), dim=1)
x = resnet(x, global_feature)
# The correct condition should be:
# if idx == (len(self.up_modules)-1) and len(h_local) > 0:
# However this change will break compatibility with published checkpoints.
# Therefore it is left as a comment.
if idx == len(self.up_modules) and len(h_local) > 0:
x = x + h_local[1]
# Note: The condition in the original implementation is:
# if idx == len(self.up_modules) and local_cond is not None:
# But as they mention in their comments, this is incorrect. We use the correct condition here.
if idx == (len(self.up_modules) - 1) and local_cond is not None:
x = x + self.local_cond_up_encoder(local_cond, global_feature)
x = resnet2(x, global_feature)
x = upsample(x)
x = self.final_conv(x)
x = einops.rearrange(x, "b t h -> b h t")
x = einops.rearrange(x, "b d t -> b t d")
return x
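The FiLM conditioning referenced throughout `_ConditionalResidualBlock1D` reduces to a per-channel affine transform of the convolutional features, broadcast over time. A minimal sketch with made-up dimensions:

```python
import torch
from torch import nn

out_channels, cond_dim = 64, 32
film_scale_modulation = True

# Condition encoder: maps the conditioning vector to per-channel (scale, bias), or bias only.
cond_encoder = nn.Sequential(
    nn.Mish(),
    nn.Linear(cond_dim, out_channels * 2 if film_scale_modulation else out_channels),
)

features = torch.randn(4, out_channels, 16)  # (B, C, T) conv features
cond = torch.randn(4, cond_dim)              # (B, cond_dim) global conditioning

cond_embed = cond_encoder(cond).unsqueeze(-1)  # (B, C or 2C, 1), broadcast over time
if film_scale_modulation:
    scale = cond_embed[:, :out_channels]
    bias = cond_embed[:, out_channels:]
    modulated = scale * features + bias
else:
    modulated = features + cond_embed
```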

View File

@@ -1,47 +0,0 @@
import torch.nn as nn
# from einops.layers.torch import Rearrange
class Downsample1d(nn.Module):
def __init__(self, dim):
super().__init__()
self.conv = nn.Conv1d(dim, dim, 3, 2, 1)
def forward(self, x):
return self.conv(x)
class Upsample1d(nn.Module):
def __init__(self, dim):
super().__init__()
self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)
def forward(self, x):
return self.conv(x)
class Conv1dBlock(nn.Module):
"""
Conv1d --> GroupNorm --> Mish
"""
def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
super().__init__()
self.block = nn.Sequential(
nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
# Rearrange('batch channels horizon -> batch channels 1 horizon'),
nn.GroupNorm(n_groups, out_channels),
# Rearrange('batch channels 1 horizon -> batch channels horizon'),
nn.Mish(),
)
def forward(self, x):
return self.block(x)
# def test():
# cb = Conv1dBlock(256, 128, kernel_size=3)
# x = torch.zeros((1,256,16))
# o = cb(x)
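These removed wrappers are now inlined as plain `nn.Conv1d`/`nn.ConvTranspose1d` calls in `ConditionalUnet1D`. A quick shape check of the stride-2 arithmetic they implement (sizes are illustrative):

```python
import torch
import torch.nn as nn

downsample = nn.Conv1d(128, 128, 3, 2, 1)         # Downsample1d: halves the time dimension
upsample = nn.ConvTranspose1d(128, 128, 4, 2, 1)  # Upsample1d: doubles the time dimension

x = torch.zeros(1, 128, 16)  # (B, C, T)
assert downsample(x).shape == (1, 128, 8)
assert upsample(x).shape == (1, 128, 32)
```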

View File

@@ -1,294 +0,0 @@
import torch
import torch.nn as nn
import torchvision.transforms.functional as ttf
import lerobot.common.policies.diffusion.model.tensor_utils as tu
class CropRandomizer(nn.Module):
"""
Randomly sample crops at input, and then average across crop features at output.
"""
def __init__(
self,
input_shape,
crop_height,
crop_width,
num_crops=1,
pos_enc=False,
):
"""
Args:
input_shape (tuple, list): shape of input (not including batch dimension)
crop_height (int): crop height
crop_width (int): crop width
num_crops (int): number of random crops to take
pos_enc (bool): if True, add 2 channels to the output to encode the spatial
location of the cropped pixels in the source image
"""
super().__init__()
assert len(input_shape) == 3 # (C, H, W)
assert crop_height < input_shape[1]
assert crop_width < input_shape[2]
self.input_shape = input_shape
self.crop_height = crop_height
self.crop_width = crop_width
self.num_crops = num_crops
self.pos_enc = pos_enc
def output_shape_in(self, input_shape=None):
"""
Function to compute output shape from inputs to this module. Corresponds to
the @forward_in operation, where raw inputs (usually observation modalities)
are passed in.
Args:
input_shape (iterable of int): shape of input. Does not include batch dimension.
Some modules may not need this argument, if their output does not depend
on the size of the input, or if they assume fixed size input.
Returns:
out_shape ([int]): list of integers corresponding to output shape
"""
# outputs are shape (C, CH, CW), or maybe C + 2 if using position encoding, because
# the number of crops are reshaped into the batch dimension, increasing the batch
# size from B to B * N
out_c = self.input_shape[0] + 2 if self.pos_enc else self.input_shape[0]
return [out_c, self.crop_height, self.crop_width]
def output_shape_out(self, input_shape=None):
"""
Function to compute output shape from inputs to this module. Corresponds to
the @forward_out operation, where processed inputs (usually encoded observation
modalities) are passed in.
Args:
input_shape (iterable of int): shape of input. Does not include batch dimension.
Some modules may not need this argument, if their output does not depend
on the size of the input, or if they assume fixed size input.
Returns:
out_shape ([int]): list of integers corresponding to output shape
"""
# since the forward_out operation splits [B * N, ...] -> [B, N, ...]
# and then pools to result in [B, ...], only the batch dimension changes,
# and so the other dimensions retain their shape.
return list(input_shape)
def forward_in(self, inputs):
"""
Samples N random crops for each input in the batch, and then reshapes
inputs to [B * N, ...].
"""
assert len(inputs.shape) >= 3 # must have at least (C, H, W) dimensions
if self.training:
# generate random crops
out, _ = sample_random_image_crops(
images=inputs,
crop_height=self.crop_height,
crop_width=self.crop_width,
num_crops=self.num_crops,
pos_enc=self.pos_enc,
)
# [B, N, ...] -> [B * N, ...]
return tu.join_dimensions(out, 0, 1)
else:
# take center crop during eval
out = ttf.center_crop(img=inputs, output_size=(self.crop_height, self.crop_width))
if self.num_crops > 1:
B, C, H, W = out.shape # noqa: N806
out = out.unsqueeze(1).expand(B, self.num_crops, C, H, W).reshape(-1, C, H, W)
# [B * N, ...]
return out
def forward_out(self, inputs):
"""
Splits the outputs from shape [B * N, ...] -> [B, N, ...] and then average across N
to result in shape [B, ...] to make sure the network output is consistent with
what would have happened if there were no randomization.
"""
if self.num_crops <= 1:
return inputs
else:
batch_size = inputs.shape[0] // self.num_crops
out = tu.reshape_dimensions(
inputs, begin_axis=0, end_axis=0, target_dims=(batch_size, self.num_crops)
)
return out.mean(dim=1)
def forward(self, inputs):
return self.forward_in(inputs)
def __repr__(self):
"""Pretty print network."""
header = "{}".format(str(self.__class__.__name__))
msg = header + "(input_shape={}, crop_size=[{}, {}], num_crops={})".format(
self.input_shape, self.crop_height, self.crop_width, self.num_crops
)
return msg
def crop_image_from_indices(images, crop_indices, crop_height, crop_width):
"""
Crops images at the locations specified by @crop_indices. Crops will be
taken across all channels.
Args:
images (torch.Tensor): batch of images of shape [..., C, H, W]
crop_indices (torch.Tensor): batch of indices of shape [..., N, 2] where
N is the number of crops to take per image and each entry corresponds
to the pixel height and width of where to take the crop. Note that
the indices can also be of shape [..., 2] if only 1 crop should
be taken per image. Leading dimensions must be consistent with
@images argument. Each index specifies the top left of the crop.
Values must be in range [0, H - CH - 1] x [0, W - CW - 1] where
H and W are the height and width of @images and CH and CW are
@crop_height and @crop_width.
crop_height (int): height of crop to take
crop_width (int): width of crop to take
Returns:
crops (torch.Tensor): cropped images of shape [..., C, @crop_height, @crop_width]
"""
# make sure length of input shapes is consistent
assert crop_indices.shape[-1] == 2
ndim_im_shape = len(images.shape)
ndim_indices_shape = len(crop_indices.shape)
assert (ndim_im_shape == ndim_indices_shape + 1) or (ndim_im_shape == ndim_indices_shape + 2)
# maybe pad so that @crop_indices is shape [..., N, 2]
is_padded = False
if ndim_im_shape == ndim_indices_shape + 2:
crop_indices = crop_indices.unsqueeze(-2)
is_padded = True
# make sure leading dimensions between images and indices are consistent
assert images.shape[:-3] == crop_indices.shape[:-2]
device = images.device
image_c, image_h, image_w = images.shape[-3:]
num_crops = crop_indices.shape[-2]
# make sure @crop_indices are in valid range
assert (crop_indices[..., 0] >= 0).all().item()
assert (crop_indices[..., 0] < (image_h - crop_height)).all().item()
assert (crop_indices[..., 1] >= 0).all().item()
assert (crop_indices[..., 1] < (image_w - crop_width)).all().item()
# convert each crop index (ch, cw) into a list of pixel indices that correspond to the entire window.
# 2D index array with columns [0, 1, ..., CH - 1] and shape [CH, CW]
crop_ind_grid_h = torch.arange(crop_height).to(device)
crop_ind_grid_h = tu.unsqueeze_expand_at(crop_ind_grid_h, size=crop_width, dim=-1)
# 2D index array with rows [0, 1, ..., CW - 1] and shape [CH, CW]
crop_ind_grid_w = torch.arange(crop_width).to(device)
crop_ind_grid_w = tu.unsqueeze_expand_at(crop_ind_grid_w, size=crop_height, dim=0)
# combine into shape [CH, CW, 2]
crop_in_grid = torch.cat((crop_ind_grid_h.unsqueeze(-1), crop_ind_grid_w.unsqueeze(-1)), dim=-1)
# Add above grid with the offset index of each sampled crop to get 2d indices for each crop.
# After broadcasting, this will be shape [..., N, CH, CW, 2] and each crop has a [CH, CW, 2]
# shape array that tells us which pixels from the corresponding source image to grab.
grid_reshape = [1] * len(crop_indices.shape[:-1]) + [crop_height, crop_width, 2]
all_crop_inds = crop_indices.unsqueeze(-2).unsqueeze(-2) + crop_in_grid.reshape(grid_reshape)
# For using @torch.gather, convert to flat indices from 2D indices, and also
# repeat across the channel dimension. To get flat index of each pixel to grab for
# each sampled crop, we just use the mapping: ind = h_ind * @image_w + w_ind
all_crop_inds = all_crop_inds[..., 0] * image_w + all_crop_inds[..., 1] # shape [..., N, CH, CW]
all_crop_inds = tu.unsqueeze_expand_at(all_crop_inds, size=image_c, dim=-3) # shape [..., N, C, CH, CW]
all_crop_inds = tu.flatten(all_crop_inds, begin_axis=-2) # shape [..., N, C, CH * CW]
# Repeat and flatten the source images -> [..., N, C, H * W] and then use gather to index with crop pixel inds
images_to_crop = tu.unsqueeze_expand_at(images, size=num_crops, dim=-4)
images_to_crop = tu.flatten(images_to_crop, begin_axis=-2)
crops = torch.gather(images_to_crop, dim=-1, index=all_crop_inds)
# [..., N, C, CH * CW] -> [..., N, C, CH, CW]
reshape_axis = len(crops.shape) - 1
crops = tu.reshape_dimensions(
crops, begin_axis=reshape_axis, end_axis=reshape_axis, target_dims=(crop_height, crop_width)
)
if is_padded:
# undo padding -> [..., C, CH, CW]
crops = crops.squeeze(-4)
return crops
def sample_random_image_crops(images, crop_height, crop_width, num_crops, pos_enc=False):
"""
For each image, randomly sample @num_crops crops of size (@crop_height, @crop_width), from
@images.
Args:
images (torch.Tensor): batch of images of shape [..., C, H, W]
crop_height (int): height of crop to take
crop_width (int): width of crop to take
num_crops (int): number of crops to sample
pos_enc (bool): if True, also add 2 channels to the outputs that gives a spatial
encoding of the original source pixel locations. This means that the
output crops will contain information about where in the source image
it was sampled from.
Returns:
crops (torch.Tensor): crops of shape (..., @num_crops, C, @crop_height, @crop_width)
if @pos_enc is False, otherwise (..., @num_crops, C + 2, @crop_height, @crop_width)
crop_inds (torch.Tensor): sampled crop indices of shape (..., N, 2)
"""
device = images.device
# maybe add 2 channels of spatial encoding to the source image
source_im = images
if pos_enc:
# spatial encoding [y, x] in [0, 1]
h, w = source_im.shape[-2:]
pos_y, pos_x = torch.meshgrid(torch.arange(h), torch.arange(w))
pos_y = pos_y.float().to(device) / float(h)
pos_x = pos_x.float().to(device) / float(w)
position_enc = torch.stack((pos_y, pos_x)) # shape [C, H, W]
# unsqueeze and expand to match leading dimensions -> shape [..., C, H, W]
leading_shape = source_im.shape[:-3]
position_enc = position_enc[(None,) * len(leading_shape)]
position_enc = position_enc.expand(*leading_shape, -1, -1, -1)
# concat across channel dimension with input
source_im = torch.cat((source_im, position_enc), dim=-3)
# make sure sample boundaries ensure crops are fully within the images
image_c, image_h, image_w = source_im.shape[-3:]
max_sample_h = image_h - crop_height
max_sample_w = image_w - crop_width
# Sample crop locations for all tensor dimensions up to the last 3, which are [C, H, W].
# Each gets @num_crops samples - typically this will just be the batch dimension (B), so
# we will sample [B, N] indices, but this supports having more than one leading dimension,
# or possibly no leading dimension.
#
# Trick: sample in [0, 1) with rand, then re-scale to [0, M) and convert to long to get sampled ints
crop_inds_h = (max_sample_h * torch.rand(*source_im.shape[:-3], num_crops).to(device)).long()
crop_inds_w = (max_sample_w * torch.rand(*source_im.shape[:-3], num_crops).to(device)).long()
crop_inds = torch.cat((crop_inds_h.unsqueeze(-1), crop_inds_w.unsqueeze(-1)), dim=-1) # shape [..., N, 2]
crops = crop_image_from_indices(
images=source_im,
crop_indices=crop_inds,
crop_height=crop_height,
crop_width=crop_width,
)
return crops, crop_inds
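The index-sampling trick in `sample_random_image_crops` (draw uniform floats in [0, 1), rescale to [0, M), truncate to integers) can be exercised in isolation; the sketch below uses made-up image and crop sizes:

```python
import torch

batch_size, num_crops = 4, 2
image_h, image_w = 96, 96
crop_height, crop_width = 84, 84
max_sample_h = image_h - crop_height
max_sample_w = image_w - crop_width

# Sample in [0, 1) with rand, rescale to [0, M), and convert to long to get integer indices.
crop_inds_h = (max_sample_h * torch.rand(batch_size, num_crops)).long()
crop_inds_w = (max_sample_w * torch.rand(batch_size, num_crops)).long()
crop_inds = torch.stack((crop_inds_h, crop_inds_w), dim=-1)  # (batch, num_crops, 2)

assert crop_inds.shape == (batch_size, num_crops, 2)
assert (crop_inds[..., 0] < max_sample_h).all() and (crop_inds[..., 1] < max_sample_w).all()
```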

View File

@@ -1,41 +0,0 @@
import torch
import torch.nn as nn
class DictOfTensorMixin(nn.Module):
def __init__(self, params_dict=None):
super().__init__()
if params_dict is None:
params_dict = nn.ParameterDict()
self.params_dict = params_dict
@property
def device(self):
return next(iter(self.parameters())).device
def _load_from_state_dict(
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
):
def dfs_add(dest, keys, value: torch.Tensor):
if len(keys) == 1:
dest[keys[0]] = value
return
if keys[0] not in dest:
dest[keys[0]] = nn.ParameterDict()
dfs_add(dest[keys[0]], keys[1:], value)
def load_dict(state_dict, prefix):
out_dict = nn.ParameterDict()
for key, value in state_dict.items():
value: torch.Tensor
if key.startswith(prefix):
param_keys = key[len(prefix) :].split(".")[1:]
# if len(param_keys) == 0:
# import pdb; pdb.set_trace()
dfs_add(out_dict, param_keys, value.clone())
return out_dict
self.params_dict = load_dict(state_dict, prefix + "params_dict")
self.params_dict.requires_grad_(False)
return

View File

@@ -0,0 +1,220 @@
import einops
import torch
import torch.nn.functional as F # noqa: N812
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from torch import Tensor, nn
from lerobot.common.policies.diffusion.model.conditional_unet1d import ConditionalUnet1D
from lerobot.common.policies.diffusion.model.rgb_encoder import RgbEncoder
from lerobot.common.policies.utils import get_device_from_parameters, get_dtype_from_parameters
class DiffusionUnetImagePolicy(nn.Module):
"""
TODO(now): Add DDIM scheduler.
Changes: TODO(now)
- Use single image encoder for now instead of generic obs_encoder. We may generalize again when/if
needed. Code for a general observation encoder can be found at:
https://github.com/huggingface/lerobot/blob/920e0d118b493e4cc3058a9b1b764f38ae145d8e/lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py
- Uses the observation as global conditioning for the Unet by default.
- Does not do any inpainting (which would be applicable if the observation were not used to condition
the Unet).
"""
def __init__(
self,
cfg,
shape_meta: dict,
noise_scheduler: DDPMScheduler,
horizon,
n_action_steps,
n_obs_steps,
num_inference_steps=None,
diffusion_step_embed_dim=256,
down_dims=(256, 512, 1024),
kernel_size=5,
n_groups=8,
film_scale_modulation=True,
):
super().__init__()
action_shape = shape_meta["action"]["shape"]
assert len(action_shape) == 1
action_dim = action_shape[0]
self.rgb_encoder = RgbEncoder(input_shape=shape_meta.obs.image.shape, **cfg.rgb_encoder)
self.unet = ConditionalUnet1D(
input_dim=action_dim,
global_cond_dim=(action_dim + self.rgb_encoder.feature_dim) * n_obs_steps,
diffusion_step_embed_dim=diffusion_step_embed_dim,
down_dims=down_dims,
kernel_size=kernel_size,
n_groups=n_groups,
film_scale_modulation=film_scale_modulation,
)
self.noise_scheduler = noise_scheduler
self.horizon = horizon
self.action_dim = action_dim
self.n_action_steps = n_action_steps
self.n_obs_steps = n_obs_steps
if num_inference_steps is None:
num_inference_steps = noise_scheduler.config.num_train_timesteps
self.num_inference_steps = num_inference_steps
# ========= inference ============
def conditional_sample(
self,
condition_data,
inpainting_mask,
local_cond=None,
global_cond=None,
generator=None,
):
model = self.unet
scheduler = self.noise_scheduler
trajectory = torch.randn(
size=condition_data.shape,
dtype=condition_data.dtype,
device=condition_data.device,
generator=generator,
)
# set step values
scheduler.set_timesteps(self.num_inference_steps)
for t in scheduler.timesteps:
# 1. apply conditioning
trajectory[inpainting_mask] = condition_data[inpainting_mask]
# 2. predict model output
model_output = model(
trajectory,
torch.full(trajectory.shape[:1], t, dtype=torch.long, device=trajectory.device),
local_cond=local_cond,
global_cond=global_cond,
)
# 3. compute previous image: x_t -> x_t-1
trajectory = scheduler.step(
model_output,
t,
trajectory,
generator=generator,
).prev_sample
# finally make sure conditioning is enforced
trajectory[inpainting_mask] = condition_data[inpainting_mask]
return trajectory
def predict_action(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
"""
This function expects `batch` to have (at least):
{
"observation.state": (B, n_obs_steps, state_dim)
"observation.image": (B, n_obs_steps, C, H, W)
}
"""
assert set(batch).issuperset({"observation.state", "observation.image"})
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
assert n_obs_steps == self.n_obs_steps
assert self.n_obs_steps == n_obs_steps
# build input
device = get_device_from_parameters(self)
dtype = get_dtype_from_parameters(self)
# Extract image feature (first combine batch and sequence dims).
img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ..."))
# Separate batch and sequence dims.
img_features = einops.rearrange(img_features, "(b n) ... -> b n ...", b=batch_size)
# Concatenate state and image features then flatten to (B, global_cond_dim).
global_cond = torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1)
# reshape back to B, Do
# empty data for action
cond_data = torch.zeros(size=(batch_size, self.horizon, self.action_dim), device=device, dtype=dtype)
cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
# run sampling
nsample = self.conditional_sample(cond_data, cond_mask, global_cond=global_cond)
# `horizon` steps worth of actions (from the first observation).
action = nsample[..., : self.action_dim]
# Extract `n_action_steps` steps worth of action (from the current observation).
start = n_obs_steps - 1
end = start + self.n_action_steps
action = action[:, start:end]
return action
def compute_loss(self, batch: dict[str, Tensor]) -> Tensor:
"""
This function expects `batch` to have (at least):
{
"observation.state": (B, n_obs_steps, state_dim)
"observation.image": (B, n_obs_steps, C, H, W)
"action": (B, horizon, action_dim)
"action_is_pad": (B, horizon) # TODO(now) maybe this is (B, horizon, 1)
}
"""
assert set(batch).issuperset({"observation.state", "observation.image", "action", "action_is_pad"})
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
horizon = batch["action"].shape[1]
assert horizon == self.horizon
assert n_obs_steps == self.n_obs_steps
assert self.n_obs_steps == n_obs_steps
# handle different ways of passing observation
local_cond = None
global_cond = None
trajectory = batch["action"]
cond_data = trajectory
# Extract image feature (first combine batch and sequence dims).
img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ..."))
# Separate batch and sequence dims.
img_features = einops.rearrange(img_features, "(b n) ... -> b n ...", b=batch_size)
# Concatenate state and image features then flatten to (B, global_cond_dim).
global_cond = torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1)
# Sample noise that we'll add to the images
noise = torch.randn(trajectory.shape, device=trajectory.device)
# Sample a random timestep for each image
timesteps = torch.randint(
0,
self.noise_scheduler.config.num_train_timesteps,
(trajectory.shape[0],),
device=trajectory.device,
).long()
# Add noise to the clean images according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_trajectory = self.noise_scheduler.add_noise(trajectory, noise, timesteps)
# Apply inpainting. TODO(now): implement?
inpainting_mask = torch.zeros_like(trajectory, dtype=bool)
noisy_trajectory[inpainting_mask] = cond_data[inpainting_mask]
# Predict the noise residual
pred = self.unet(noisy_trajectory, timesteps, local_cond=local_cond, global_cond=global_cond)
pred_type = self.noise_scheduler.config.prediction_type
if pred_type == "epsilon":
target = noise
elif pred_type == "sample":
target = trajectory
else:
raise ValueError(f"Unsupported prediction type {pred_type}")
loss = F.mse_loss(pred, target, reduction="none")
loss = loss * (~inpainting_mask)
if "action_is_pad" in batch:
in_episode_bound = ~batch["action_is_pad"]
loss = loss * in_episode_bound[:, :, None].type(loss.dtype)
return loss.mean()
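Two shape conventions above are easy to get wrong: how the global conditioning vector is flattened, and which slice of the sampled horizon is actually executed. A worked example with illustrative dimensions:

```python
import torch

batch_size, n_obs_steps, state_dim, img_feat_dim = 8, 2, 2, 64

states = torch.randn(batch_size, n_obs_steps, state_dim)
img_features = torch.randn(batch_size, n_obs_steps, img_feat_dim)

# Concatenate per-step features, then flatten the step dimension into one vector per sample.
global_cond = torch.cat([states, img_features], dim=-1).flatten(start_dim=1)
assert global_cond.shape == (batch_size, (state_dim + img_feat_dim) * n_obs_steps)

# The sampled trajectory covers `horizon` steps, but only the window starting at the
# current observation is executed: start = n_obs_steps - 1, end = start + n_action_steps.
horizon, n_action_steps, action_dim = 16, 8, 2
action_pred = torch.randn(batch_size, horizon, action_dim)
start = n_obs_steps - 1
action = action_pred[:, start : start + n_action_steps]
assert action.shape == (batch_size, n_action_steps, action_dim)
```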

View File

@@ -51,13 +51,6 @@ class EMAModel:
def step(self, new_model):
self.decay = self.get_decay(self.optimization_step)
# old_all_dataptrs = set()
# for param in new_model.parameters():
# data_ptr = param.data_ptr()
# if data_ptr != 0:
# old_all_dataptrs.add(data_ptr)
# all_dataptrs = set()
for module, ema_module in zip(new_model.modules(), self.averaged_model.modules(), strict=False):
for param, ema_param in zip(
module.parameters(recurse=False), ema_module.parameters(recurse=False), strict=False
@@ -66,10 +59,6 @@ class EMAModel:
if isinstance(param, dict):
raise RuntimeError("Dict parameter not supported")
# data_ptr = param.data_ptr()
# if data_ptr != 0:
# all_dataptrs.add(data_ptr)
if isinstance(module, _BatchNorm):
# skip batchnorms
ema_param.copy_(param.to(dtype=ema_param.dtype).data)
@@ -79,6 +68,4 @@ class EMAModel:
ema_param.mul_(self.decay)
ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=1 - self.decay)
# verify that iterating over module and then parameters is identical to parameters recursively.
# assert old_all_dataptrs == all_dataptrs
self.optimization_step += 1
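The surviving update in `EMAModel.step` is the standard exponential moving average, written with in-place ops. A minimal numeric check of the same two calls:

```python
import torch

decay = 0.99
param = torch.full((3,), 2.0)
ema_param = torch.zeros(3)

# ema <- decay * ema + (1 - decay) * param, using the same in-place operations as above.
ema_param.mul_(decay)
ema_param.add_(param, alpha=1 - decay)
assert torch.allclose(ema_param, (1 - decay) * param)  # EMA started from zeros here
```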

View File

@@ -1,46 +0,0 @@
from diffusers.optimization import TYPE_TO_SCHEDULER_FUNCTION, Optimizer, Optional, SchedulerType, Union
def get_scheduler(
name: Union[str, SchedulerType],
optimizer: Optimizer,
num_warmup_steps: Optional[int] = None,
num_training_steps: Optional[int] = None,
**kwargs,
):
"""
Added kwargs vs diffusers' original implementation
Unified API to get any scheduler from its name.
Args:
name (`str` or `SchedulerType`):
The name of the scheduler to use.
optimizer (`torch.optim.Optimizer`):
The optimizer that will be used during training.
num_warmup_steps (`int`, *optional*):
The number of warmup steps to do. This is not required by all schedulers (hence the argument being
optional), the function will raise an error if it's unset and the scheduler type requires it.
num_training_steps (`int`, *optional*):
The number of training steps to do. This is not required by all schedulers (hence the argument being
optional), the function will raise an error if it's unset and the scheduler type requires it.
"""
name = SchedulerType(name)
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
if name == SchedulerType.CONSTANT:
return schedule_func(optimizer, **kwargs)
# All other schedulers require `num_warmup_steps`
if num_warmup_steps is None:
raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
if name == SchedulerType.CONSTANT_WITH_WARMUP:
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, **kwargs)
# All other schedulers require `num_training_steps`
if num_training_steps is None:
raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
return schedule_func(
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, **kwargs
)
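This deleted wrapper forwarded to diffusers' scheduler factory while adding `**kwargs`. For reference, a hedged usage sketch of the underlying `diffusers.optimization.get_scheduler` call pattern (the optimizer, step counts, and schedule name are made up):

```python
import torch
from diffusers.optimization import get_scheduler

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# "cosine" needs both warmup and total step counts, as enforced by the checks above.
lr_scheduler = get_scheduler(
    "cosine",
    optimizer=optimizer,
    num_warmup_steps=500,
    num_training_steps=200_000,
)

for _ in range(10):
    optimizer.step()
    lr_scheduler.step()
```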

View File

@@ -1,65 +0,0 @@
import torch
from lerobot.common.policies.diffusion.model.module_attr_mixin import ModuleAttrMixin
class LowdimMaskGenerator(ModuleAttrMixin):
def __init__(
self,
action_dim,
obs_dim,
# obs mask setup
max_n_obs_steps=2,
fix_obs_steps=True,
# action mask
action_visible=False,
):
super().__init__()
self.action_dim = action_dim
self.obs_dim = obs_dim
self.max_n_obs_steps = max_n_obs_steps
self.fix_obs_steps = fix_obs_steps
self.action_visible = action_visible
@torch.no_grad()
def forward(self, shape, seed=None):
device = self.device
B, T, D = shape # noqa: N806
assert (self.action_dim + self.obs_dim) == D
# create all tensors on this device
rng = torch.Generator(device=device)
if seed is not None:
rng = rng.manual_seed(seed)
# generate dim mask
dim_mask = torch.zeros(size=shape, dtype=torch.bool, device=device)
is_action_dim = dim_mask.clone()
is_action_dim[..., : self.action_dim] = True
is_obs_dim = ~is_action_dim
# generate obs mask
if self.fix_obs_steps:
obs_steps = torch.full((B,), fill_value=self.max_n_obs_steps, device=device)
else:
obs_steps = torch.randint(
low=1, high=self.max_n_obs_steps + 1, size=(B,), generator=rng, device=device
)
steps = torch.arange(0, T, device=device).reshape(1, T).expand(B, T)
obs_mask = (obs_steps > steps.T).T.reshape(B, T, 1).expand(B, T, D)
obs_mask = obs_mask & is_obs_dim
# generate action mask
if self.action_visible:
action_steps = torch.maximum(
obs_steps - 1, torch.tensor(0, dtype=obs_steps.dtype, device=obs_steps.device)
)
action_mask = (action_steps > steps.T).T.reshape(B, T, 1).expand(B, T, D)
action_mask = action_mask & is_action_dim
mask = obs_mask
if self.action_visible:
mask = mask | action_mask
return mask
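With `fix_obs_steps=True` and `action_visible=False` (the configuration the diffusion policy uses), the mask produced above simply marks the observation dimensions of the first `max_n_obs_steps` timesteps. An illustrative, self-contained reproduction of that case:

```python
import torch

B, T, action_dim, obs_dim = 1, 4, 2, 3
D = action_dim + obs_dim
max_n_obs_steps = 2

# Observation dims are everything after the action dims.
is_obs_dim = torch.zeros(B, T, D, dtype=torch.bool)
is_obs_dim[..., action_dim:] = True

# Mark the first `max_n_obs_steps` timesteps, then restrict to observation dims.
obs_steps = torch.full((B,), max_n_obs_steps)
steps = torch.arange(T).reshape(1, T).expand(B, T)
obs_mask = (obs_steps > steps.T).T.reshape(B, T, 1).expand(B, T, D)
mask = obs_mask & is_obs_dim

assert mask.sum() == max_n_obs_steps * obs_dim
```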

View File

@@ -1,15 +0,0 @@
import torch.nn as nn
class ModuleAttrMixin(nn.Module):
def __init__(self):
super().__init__()
self._dummy_variable = nn.Parameter()
@property
def device(self):
return next(iter(self.parameters())).device
@property
def dtype(self):
return next(iter(self.parameters())).dtype

View File

@@ -1,214 +0,0 @@
import copy
from typing import Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
import torchvision
from robomimic.models.base_nets import ResNet18Conv, SpatialSoftmax
from lerobot.common.policies.diffusion.model.crop_randomizer import CropRandomizer
from lerobot.common.policies.diffusion.model.module_attr_mixin import ModuleAttrMixin
from lerobot.common.policies.diffusion.pytorch_utils import replace_submodules
class RgbEncoder(nn.Module):
"""Following `VisualCore` from Robomimic 0.2.0."""
def __init__(self, input_shape, relu=True, pretrained=False, num_keypoints=32):
"""
input_shape: channel-first input shape (C, H, W)
resnet_name: a timm model name.
pretrained: whether to use timm pretrained weights.
relu: whether to use relu as a final step.
num_keypoints: Number of keypoints for SpatialSoftmax (default value of 32 matches PushT Image).
"""
super().__init__()
self.backbone = ResNet18Conv(input_channel=input_shape[0], pretrained=pretrained)
# Figure out the feature map shape.
with torch.inference_mode():
feat_map_shape = tuple(self.backbone(torch.zeros(size=(1, *input_shape))).shape[1:])
self.pool = SpatialSoftmax(feat_map_shape, num_kp=num_keypoints)
self.out = nn.Linear(num_keypoints * 2, num_keypoints * 2)
self.relu = nn.ReLU() if relu else nn.Identity()
def forward(self, x):
return self.relu(self.out(torch.flatten(self.pool(self.backbone(x)), start_dim=1)))
class MultiImageObsEncoder(ModuleAttrMixin):
def __init__(
self,
shape_meta: dict,
rgb_model: Union[nn.Module, Dict[str, nn.Module]],
resize_shape: Union[Tuple[int, int], Dict[str, tuple], None] = None,
crop_shape: Union[Tuple[int, int], Dict[str, tuple], None] = None,
random_crop: bool = True,
# replace BatchNorm with GroupNorm
use_group_norm: bool = False,
# use single rgb model for all rgb inputs
share_rgb_model: bool = False,
# renormalize rgb input with imagenet normalization
# assuming input in [0,1]
norm_mean_std: Optional[tuple[float, float]] = None,
):
"""
Assumes rgb input: B,C,H,W
Assumes low_dim input: B,D
"""
super().__init__()
rgb_keys = []
low_dim_keys = []
key_model_map = nn.ModuleDict()
key_transform_map = nn.ModuleDict()
key_shape_map = {}
# handle sharing vision backbone
if share_rgb_model:
assert isinstance(rgb_model, nn.Module)
key_model_map["rgb"] = rgb_model
obs_shape_meta = shape_meta["obs"]
for key, attr in obs_shape_meta.items():
shape = tuple(attr["shape"])
type = attr.get("type", "low_dim")
key_shape_map[key] = shape
if type == "rgb":
rgb_keys.append(key)
# configure model for this key
this_model = None
if not share_rgb_model:
if isinstance(rgb_model, dict):
# have provided model for each key
this_model = rgb_model[key]
else:
assert isinstance(rgb_model, nn.Module)
# have a copy of the rgb model
this_model = copy.deepcopy(rgb_model)
if this_model is not None:
if use_group_norm:
this_model = replace_submodules(
root_module=this_model,
predicate=lambda x: isinstance(x, nn.BatchNorm2d),
func=lambda x: nn.GroupNorm(
num_groups=x.num_features // 16, num_channels=x.num_features
),
)
key_model_map[key] = this_model
# configure resize
input_shape = shape
this_resizer = nn.Identity()
if resize_shape is not None:
if isinstance(resize_shape, dict):
h, w = resize_shape[key]
else:
h, w = resize_shape
this_resizer = torchvision.transforms.Resize(size=(h, w))
input_shape = (shape[0], h, w)
# configure randomizer
this_randomizer = nn.Identity()
if crop_shape is not None:
if isinstance(crop_shape, dict):
h, w = crop_shape[key]
else:
h, w = crop_shape
if random_crop:
this_randomizer = CropRandomizer(
input_shape=input_shape, crop_height=h, crop_width=w, num_crops=1, pos_enc=False
)
else:
this_normalizer = torchvision.transforms.CenterCrop(size=(h, w))
# configure normalizer
this_normalizer = nn.Identity()
if norm_mean_std is not None:
this_normalizer = torchvision.transforms.Normalize(
mean=norm_mean_std[0], std=norm_mean_std[1]
)
this_transform = nn.Sequential(this_resizer, this_randomizer, this_normalizer)
key_transform_map[key] = this_transform
elif type == "low_dim":
low_dim_keys.append(key)
else:
raise RuntimeError(f"Unsupported obs type: {type}")
rgb_keys = sorted(rgb_keys)
low_dim_keys = sorted(low_dim_keys)
self.shape_meta = shape_meta
self.key_model_map = key_model_map
self.key_transform_map = key_transform_map
self.share_rgb_model = share_rgb_model
self.rgb_keys = rgb_keys
self.low_dim_keys = low_dim_keys
self.key_shape_map = key_shape_map
def forward(self, obs_dict):
batch_size = None
features = []
# process lowdim input
for key in self.low_dim_keys:
data = obs_dict[key]
if batch_size is None:
batch_size = data.shape[0]
else:
assert batch_size == data.shape[0]
assert data.shape[1:] == self.key_shape_map[key]
features.append(data)
# process rgb input
if self.share_rgb_model:
# pass all rgb obs to rgb model
imgs = []
for key in self.rgb_keys:
img = obs_dict[key]
if batch_size is None:
batch_size = img.shape[0]
else:
assert batch_size == img.shape[0]
assert img.shape[1:] == self.key_shape_map[key]
img = self.key_transform_map[key](img)
imgs.append(img)
# (N*B,C,H,W)
imgs = torch.cat(imgs, dim=0)
# (N*B,D)
feature = self.key_model_map["rgb"](imgs)
# (N,B,D)
feature = feature.reshape(-1, batch_size, *feature.shape[1:])
# (B,N,D)
feature = torch.moveaxis(feature, 0, 1)
# (B,N*D)
feature = feature.reshape(batch_size, -1)
features.append(feature)
else:
# run each rgb obs to independent models
for key in self.rgb_keys:
img = obs_dict[key]
if batch_size is None:
batch_size = img.shape[0]
else:
assert batch_size == img.shape[0]
assert img.shape[1:] == self.key_shape_map[key]
img = self.key_transform_map[key](img)
feature = self.key_model_map[key](img)
features.append(feature)
# concatenate all features
result = torch.cat(features, dim=-1)
return result
@torch.no_grad()
def output_shape(self):
example_obs_dict = {}
obs_shape_meta = self.shape_meta["obs"]
batch_size = 1
for key, attr in obs_shape_meta.items():
shape = tuple(attr["shape"])
this_obs = torch.zeros((batch_size,) + shape, dtype=self.dtype, device=self.device)
example_obs_dict[key] = this_obs
example_output = self.forward(example_obs_dict)
output_shape = example_output.shape[1:]
return output_shape
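# Illustrative usage sketch, assuming MultiImageObsEncoder and its imports above are in
# scope; the shape_meta keys, resnet18 backbone, and argument values are example choices.
import torch
import torchvision
from torch import nn

shape_meta = {
    "obs": {
        "image": {"shape": (3, 96, 96), "type": "rgb"},
        "agent_pos": {"shape": (2,), "type": "low_dim"},
    }
}
rgb_model = torchvision.models.resnet18(weights=None)
rgb_model.fc = nn.Identity()  # expose the pooled 512-d feature instead of class logits
encoder = MultiImageObsEncoder(
    shape_meta=shape_meta,
    rgb_model=rgb_model,
    resize_shape=None,
    crop_shape=None,
    random_crop=False,
    use_group_norm=True,
    share_rgb_model=False,
    norm_mean_std=(0.5, 0.5),
)
obs = {"image": torch.rand(8, 3, 96, 96), "agent_pos": torch.rand(8, 2)}
features = encoder(obs)
assert features.shape == (8, 512 + 2)  # image feature concatenated with low_dim input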

View File

@@ -1,358 +0,0 @@
from typing import Dict, Union
import numpy as np
import torch
import torch.nn as nn
import zarr
from lerobot.common.policies.diffusion.model.dict_of_tensor_mixin import DictOfTensorMixin
from lerobot.common.policies.diffusion.pytorch_utils import dict_apply
class LinearNormalizer(DictOfTensorMixin):
    available_modes = ["limits", "gaussian"]
@torch.no_grad()
def fit(
self,
data: Union[Dict, torch.Tensor, np.ndarray, zarr.Array],
last_n_dims=1,
dtype=torch.float32,
mode="limits",
output_max=1.0,
output_min=-1.0,
range_eps=1e-4,
fit_offset=True,
):
if isinstance(data, dict):
for key, value in data.items():
self.params_dict[key] = _fit(
value,
last_n_dims=last_n_dims,
dtype=dtype,
mode=mode,
output_max=output_max,
output_min=output_min,
range_eps=range_eps,
fit_offset=fit_offset,
)
else:
self.params_dict["_default"] = _fit(
data,
last_n_dims=last_n_dims,
dtype=dtype,
mode=mode,
output_max=output_max,
output_min=output_min,
range_eps=range_eps,
fit_offset=fit_offset,
)
def __call__(self, x: Union[Dict, torch.Tensor, np.ndarray]) -> torch.Tensor:
return self.normalize(x)
def __getitem__(self, key: str):
return SingleFieldLinearNormalizer(self.params_dict[key])
def __setitem__(self, key: str, value: "SingleFieldLinearNormalizer"):
self.params_dict[key] = value.params_dict
def _normalize_impl(self, x, forward=True):
if isinstance(x, dict):
result = {}
for key, value in x.items():
params = self.params_dict[key]
result[key] = _normalize(value, params, forward=forward)
return result
else:
if "_default" not in self.params_dict:
raise RuntimeError("Not initialized")
params = self.params_dict["_default"]
return _normalize(x, params, forward=forward)
def normalize(self, x: Union[Dict, torch.Tensor, np.ndarray]) -> torch.Tensor:
return self._normalize_impl(x, forward=True)
def unnormalize(self, x: Union[Dict, torch.Tensor, np.ndarray]) -> torch.Tensor:
return self._normalize_impl(x, forward=False)
def get_input_stats(self) -> Dict:
if len(self.params_dict) == 0:
raise RuntimeError("Not initialized")
if len(self.params_dict) == 1 and "_default" in self.params_dict:
return self.params_dict["_default"]["input_stats"]
result = {}
for key, value in self.params_dict.items():
if key != "_default":
result[key] = value["input_stats"]
return result
def get_output_stats(self, key="_default"):
input_stats = self.get_input_stats()
if "min" in input_stats:
# no dict
return dict_apply(input_stats, self.normalize)
result = {}
for key, group in input_stats.items():
this_dict = {}
for name, value in group.items():
this_dict[name] = self.normalize({key: value})[key]
result[key] = this_dict
return result
class SingleFieldLinearNormalizer(DictOfTensorMixin):
    available_modes = ["limits", "gaussian"]
@torch.no_grad()
def fit(
self,
data: Union[torch.Tensor, np.ndarray, zarr.Array],
last_n_dims=1,
dtype=torch.float32,
mode="limits",
output_max=1.0,
output_min=-1.0,
range_eps=1e-4,
fit_offset=True,
):
self.params_dict = _fit(
data,
last_n_dims=last_n_dims,
dtype=dtype,
mode=mode,
output_max=output_max,
output_min=output_min,
range_eps=range_eps,
fit_offset=fit_offset,
)
@classmethod
def create_fit(cls, data: Union[torch.Tensor, np.ndarray, zarr.Array], **kwargs):
obj = cls()
obj.fit(data, **kwargs)
return obj
@classmethod
def create_manual(
cls,
scale: Union[torch.Tensor, np.ndarray],
offset: Union[torch.Tensor, np.ndarray],
input_stats_dict: Dict[str, Union[torch.Tensor, np.ndarray]],
):
def to_tensor(x):
if not isinstance(x, torch.Tensor):
x = torch.from_numpy(x)
x = x.flatten()
return x
# check
for x in [offset] + list(input_stats_dict.values()):
assert x.shape == scale.shape
assert x.dtype == scale.dtype
params_dict = nn.ParameterDict(
{
"scale": to_tensor(scale),
"offset": to_tensor(offset),
"input_stats": nn.ParameterDict(dict_apply(input_stats_dict, to_tensor)),
}
)
return cls(params_dict)
@classmethod
def create_identity(cls, dtype=torch.float32):
scale = torch.tensor([1], dtype=dtype)
offset = torch.tensor([0], dtype=dtype)
input_stats_dict = {
"min": torch.tensor([-1], dtype=dtype),
"max": torch.tensor([1], dtype=dtype),
"mean": torch.tensor([0], dtype=dtype),
"std": torch.tensor([1], dtype=dtype),
}
return cls.create_manual(scale, offset, input_stats_dict)
def normalize(self, x: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
return _normalize(x, self.params_dict, forward=True)
def unnormalize(self, x: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
return _normalize(x, self.params_dict, forward=False)
def get_input_stats(self):
return self.params_dict["input_stats"]
def get_output_stats(self):
return dict_apply(self.params_dict["input_stats"], self.normalize)
def __call__(self, x: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
return self.normalize(x)
def _fit(
data: Union[torch.Tensor, np.ndarray, zarr.Array],
last_n_dims=1,
dtype=torch.float32,
mode="limits",
output_max=1.0,
output_min=-1.0,
range_eps=1e-4,
fit_offset=True,
):
assert mode in ["limits", "gaussian"]
assert last_n_dims >= 0
assert output_max > output_min
# convert data to torch and type
if isinstance(data, zarr.Array):
data = data[:]
if isinstance(data, np.ndarray):
data = torch.from_numpy(data)
if dtype is not None:
data = data.type(dtype)
# convert shape
dim = 1
if last_n_dims > 0:
dim = np.prod(data.shape[-last_n_dims:])
data = data.reshape(-1, dim)
# compute input stats min max mean std
input_min, _ = data.min(axis=0)
input_max, _ = data.max(axis=0)
input_mean = data.mean(axis=0)
input_std = data.std(axis=0)
# compute scale and offset
if mode == "limits":
if fit_offset:
# unit scale
input_range = input_max - input_min
ignore_dim = input_range < range_eps
input_range[ignore_dim] = output_max - output_min
scale = (output_max - output_min) / input_range
offset = output_min - scale * input_min
offset[ignore_dim] = (output_max + output_min) / 2 - input_min[ignore_dim]
# ignore dims scaled to mean of output max and min
else:
# use this when data is pre-zero-centered.
assert output_max > 0
assert output_min < 0
# unit abs
output_abs = min(abs(output_min), abs(output_max))
input_abs = torch.maximum(torch.abs(input_min), torch.abs(input_max))
ignore_dim = input_abs < range_eps
input_abs[ignore_dim] = output_abs
# don't scale constant channels
scale = output_abs / input_abs
offset = torch.zeros_like(input_mean)
elif mode == "gaussian":
ignore_dim = input_std < range_eps
scale = input_std.clone()
scale[ignore_dim] = 1
scale = 1 / scale
offset = -input_mean * scale if fit_offset else torch.zeros_like(input_mean)
# save
this_params = nn.ParameterDict(
{
"scale": scale,
"offset": offset,
"input_stats": nn.ParameterDict(
{"min": input_min, "max": input_max, "mean": input_mean, "std": input_std}
),
}
)
for p in this_params.parameters():
p.requires_grad_(False)
return this_params
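# Worked micro-example of the "limits" branch above (pure arithmetic, no tensors):
# a channel observed in the range [2, 6] is mapped onto the output range [-1, 1].
input_min, input_max = 2.0, 6.0
output_min, output_max = -1.0, 1.0
scale = (output_max - output_min) / (input_max - input_min)  # 2 / 4 = 0.5
offset = output_min - scale * input_min                      # -1 - 0.5 * 2 = -2
assert scale * input_min + offset == output_min
assert scale * input_max + offset == output_max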
def _normalize(x, params, forward=True):
assert "scale" in params
if isinstance(x, np.ndarray):
x = torch.from_numpy(x)
scale = params["scale"]
offset = params["offset"]
x = x.to(device=scale.device, dtype=scale.dtype)
src_shape = x.shape
x = x.reshape(-1, scale.shape[0])
x = x * scale + offset if forward else (x - offset) / scale
x = x.reshape(src_shape)
return x
def test():
data = torch.zeros((100, 10, 9, 2)).uniform_()
data[..., 0, 0] = 0
normalizer = SingleFieldLinearNormalizer()
normalizer.fit(data, mode="limits", last_n_dims=2)
datan = normalizer.normalize(data)
assert datan.shape == data.shape
assert np.allclose(datan.max(), 1.0)
assert np.allclose(datan.min(), -1.0)
dataun = normalizer.unnormalize(datan)
assert torch.allclose(data, dataun, atol=1e-7)
_ = normalizer.get_input_stats()
_ = normalizer.get_output_stats()
normalizer = SingleFieldLinearNormalizer()
normalizer.fit(data, mode="limits", last_n_dims=1, fit_offset=False)
datan = normalizer.normalize(data)
assert datan.shape == data.shape
assert np.allclose(datan.max(), 1.0, atol=1e-3)
assert np.allclose(datan.min(), 0.0, atol=1e-3)
dataun = normalizer.unnormalize(datan)
assert torch.allclose(data, dataun, atol=1e-7)
data = torch.zeros((100, 10, 9, 2)).uniform_()
normalizer = SingleFieldLinearNormalizer()
normalizer.fit(data, mode="gaussian", last_n_dims=0)
datan = normalizer.normalize(data)
assert datan.shape == data.shape
assert np.allclose(datan.mean(), 0.0, atol=1e-3)
assert np.allclose(datan.std(), 1.0, atol=1e-3)
dataun = normalizer.unnormalize(datan)
assert torch.allclose(data, dataun, atol=1e-7)
# dict
data = torch.zeros((100, 10, 9, 2)).uniform_()
data[..., 0, 0] = 0
normalizer = LinearNormalizer()
normalizer.fit(data, mode="limits", last_n_dims=2)
datan = normalizer.normalize(data)
assert datan.shape == data.shape
assert np.allclose(datan.max(), 1.0)
assert np.allclose(datan.min(), -1.0)
dataun = normalizer.unnormalize(datan)
assert torch.allclose(data, dataun, atol=1e-7)
_ = normalizer.get_input_stats()
_ = normalizer.get_output_stats()
data = {
"obs": torch.zeros((1000, 128, 9, 2)).uniform_() * 512,
"action": torch.zeros((1000, 128, 2)).uniform_() * 512,
}
normalizer = LinearNormalizer()
normalizer.fit(data)
datan = normalizer.normalize(data)
dataun = normalizer.unnormalize(datan)
for key in data:
assert torch.allclose(data[key], dataun[key], atol=1e-4)
_ = normalizer.get_input_stats()
_ = normalizer.get_output_stats()
state_dict = normalizer.state_dict()
n = LinearNormalizer()
n.load_state_dict(state_dict)
datan = n.normalize(data)
dataun = n.unnormalize(datan)
for key in data:
assert torch.allclose(data[key], dataun[key], atol=1e-4)

View File

@@ -1,19 +0,0 @@
import math
import torch
import torch.nn as nn
class SinusoidalPosEmb(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, x):
device = x.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
emb = x[:, None] * emb[None, :]
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
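# Quick illustrative check, assuming SinusoidalPosEmb above is in scope: embed a batch of
# diffusion timesteps into 128-dimensional sinusoidal features.
import torch

pos_emb = SinusoidalPosEmb(dim=128)
timesteps = torch.tensor([0.0, 10.0, 99.0])  # (B,)
emb = pos_emb(timesteps)                     # (B, 128): first half sin, second half cos
assert emb.shape == (3, 128)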

View File

@@ -0,0 +1,147 @@
from typing import Callable
import torch
import torchvision
from robomimic.models.base_nets import SpatialSoftmax
from torch import Tensor, nn
from torchvision.transforms import CenterCrop, RandomCrop
class RgbEncoder(nn.Module):
"""Encoder an RGB image into a 1D feature vector.
Includes the ability to normalize and crop the image first.
"""
def __init__(
self,
input_shape: tuple[int, int, int],
        norm_mean_std: tuple[float, float] = (1.0, 1.0),
crop_shape: tuple[int, int] | None = None,
random_crop: bool = False,
backbone_name: str = "resnet18",
pretrained_backbone: bool = False,
use_group_norm: bool = False,
relu: bool = True,
num_keypoints: int = 32,
):
"""
Args:
input_shape: channel-first input shape (C, H, W)
norm_mean_std: mean and standard deviation used for image normalization. Images are normalized as
(image - mean) / std.
crop_shape: (H, W) shape to crop to (must fit within the input shape). If not provided, no
cropping is done.
random_crop: Whether the crop should be random at training time (it's always a center crop in
eval mode).
            backbone_name: The name of one of the available resnet models from torchvision (e.g. resnet18).
            pretrained_backbone: whether to use torchvision's ImageNet pretrained weights.
use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
                The number of groups is set to num_features // 16 (so each group has roughly 16 channels).
relu: whether to use relu as a final step.
num_keypoints: Number of keypoints for SpatialSoftmax (default value of 32 matches PushT Image).
"""
super().__init__()
if input_shape[0] != 3:
raise ValueError("Only RGB images are handled")
if not backbone_name.startswith("resnet"):
raise ValueError(
"Only resnet is supported for now (because of the assumption that 'layer4' is the output layer)"
)
# Set up optional preprocessing.
        if tuple(norm_mean_std) == (1.0, 1.0):
self.normalizer = nn.Identity()
else:
self.normalizer = torchvision.transforms.Normalize(mean=norm_mean_std[0], std=norm_mean_std[1])
if crop_shape is not None:
self.do_crop = True
self.center_crop = CenterCrop(crop_shape) # always use center crop for eval
if random_crop:
self.maybe_random_crop = RandomCrop(crop_shape)
else:
self.maybe_random_crop = self.center_crop
else:
self.do_crop = False
# Set up backbone.
backbone_model = getattr(torchvision.models, backbone_name)(pretrained=pretrained_backbone)
# Note: This assumes that the layer4 feature map is children()[-3]
# TODO(alexander-soare): Use a safer alternative.
self.backbone = nn.Sequential(*(list(backbone_model.children())[:-2]))
if use_group_norm:
if pretrained_backbone:
raise ValueError(
"You can't replace BatchNorm in a pretrained model without ruining the weights!"
)
self.backbone = _replace_submodules(
root_module=self.backbone,
predicate=lambda x: isinstance(x, nn.BatchNorm2d),
func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features),
)
# Set up pooling and final layers.
# Use a dry run to get the feature map shape.
with torch.inference_mode():
feat_map_shape = tuple(self.backbone(torch.zeros(size=(1, *input_shape))).shape[1:])
self.pool = SpatialSoftmax(feat_map_shape, num_kp=num_keypoints)
self.feature_dim = num_keypoints * 2
self.out = nn.Linear(num_keypoints * 2, self.feature_dim)
self.maybe_relu = nn.ReLU() if relu else nn.Identity()
def forward(self, x: Tensor) -> Tensor:
"""
Args:
x: (B, C, H, W) image tensor with pixel values in [0, 1].
Returns:
(B, D) image feature.
"""
# Preprocess: normalize and maybe crop (if it was set up in the __init__).
x = self.normalizer(x)
if self.do_crop:
if self.training: # noqa: SIM108
x = self.maybe_random_crop(x)
else:
# Always use center crop for eval.
x = self.center_crop(x)
# Extract backbone feature.
x = torch.flatten(self.pool(self.backbone(x)), start_dim=1)
# Final linear layer.
x = self.out(x)
# Maybe a final non-linearity.
x = self.maybe_relu(x)
return x
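# Illustrative usage sketch, assuming RgbEncoder above is importable (robomimic provides
# SpatialSoftmax); the 96x96 input and 84x84 crop are example values in the PushT style.
import torch

encoder = RgbEncoder(
    input_shape=(3, 96, 96),
    norm_mean_std=(0.5, 0.5),
    crop_shape=(84, 84),
    random_crop=True,
    backbone_name="resnet18",
    pretrained_backbone=False,
    use_group_norm=True,
    relu=True,
    num_keypoints=32,
)
features = encoder(torch.rand(8, 3, 96, 96))
assert features.shape == (8, 64)  # feature_dim = num_keypoints * 2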
def _replace_submodules(
root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module]
) -> nn.Module:
"""
Args:
root_module: The module for which the submodules need to be replaced
        predicate: Takes a module as an argument and must return True if that module is to be replaced.
func: Takes a module as an argument and returns a new module to replace it with.
Returns:
The root module with its submodules replaced.
"""
if predicate(root_module):
return func(root_module)
replace_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)]
for *parents, k in replace_list:
parent_module = root_module
if len(parents) > 0:
parent_module = root_module.get_submodule(".".join(parents))
if isinstance(parent_module, nn.Sequential):
src_module = parent_module[int(k)]
else:
src_module = getattr(parent_module, k)
tgt_module = func(src_module)
if isinstance(parent_module, nn.Sequential):
parent_module[int(k)] = tgt_module
else:
setattr(parent_module, k, tgt_module)
# verify that all BN are replaced
assert not any(predicate(m) for _, m in root_module.named_modules(remove_duplicate=True))
return root_module
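# Illustrative sketch, assuming _replace_submodules above is in scope: swap every
# BatchNorm2d in a resnet18 for GroupNorm, mirroring the use_group_norm path above.
import torchvision
from torch import nn

net = torchvision.models.resnet18(weights=None)
net = _replace_submodules(
    root_module=net,
    predicate=lambda m: isinstance(m, nn.BatchNorm2d),
    func=lambda m: nn.GroupNorm(num_groups=m.num_features // 16, num_channels=m.num_features),
)
assert not any(isinstance(m, nn.BatchNorm2d) for m in net.modules())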

View File

@@ -1,972 +0,0 @@
"""
A collection of utilities for working with nested tensor structures consisting
of numpy arrays and torch tensors.
"""
import collections
import numpy as np
import torch
def recursive_dict_list_tuple_apply(x, type_func_dict):
"""
Recursively apply functions to a nested dictionary or list or tuple, given a dictionary of
{data_type: function_to_apply}.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
type_func_dict (dict): a mapping from data types to the functions to be
applied for each data type.
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
assert list not in type_func_dict
assert tuple not in type_func_dict
assert dict not in type_func_dict
if isinstance(x, (dict, collections.OrderedDict)):
new_x = collections.OrderedDict() if isinstance(x, collections.OrderedDict) else {}
for k, v in x.items():
new_x[k] = recursive_dict_list_tuple_apply(v, type_func_dict)
return new_x
elif isinstance(x, (list, tuple)):
ret = [recursive_dict_list_tuple_apply(v, type_func_dict) for v in x]
if isinstance(x, tuple):
ret = tuple(ret)
return ret
else:
for t, f in type_func_dict.items():
if isinstance(x, t):
return f(x)
else:
raise NotImplementedError("Cannot handle data type %s" % str(type(x)))
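# Illustrative sketch, assuming recursive_dict_list_tuple_apply above is in scope: cast
# every tensor in a nested structure to float16; None leaves pass through unchanged.
import torch

nested = {"obs": {"image": torch.zeros(2, 3), "mask": None}, "extra": (torch.ones(4),)}
halved = recursive_dict_list_tuple_apply(
    nested,
    {
        torch.Tensor: lambda t: t.half(),
        type(None): lambda v: v,
    },
)
assert halved["obs"]["image"].dtype == torch.float16
assert isinstance(halved["extra"], tuple)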
def map_tensor(x, func):
"""
Apply function @func to torch.Tensor objects in a nested dictionary or
list or tuple.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
func (function): function to apply to each tensor
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x,
{
torch.Tensor: func,
type(None): lambda x: x,
},
)
def map_ndarray(x, func):
"""
Apply function @func to np.ndarray objects in a nested dictionary or
list or tuple.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
func (function): function to apply to each array
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x,
{
np.ndarray: func,
type(None): lambda x: x,
},
)
def map_tensor_ndarray(x, tensor_func, ndarray_func):
"""
Apply function @tensor_func to torch.Tensor objects and @ndarray_func to
np.ndarray objects in a nested dictionary or list or tuple.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
tensor_func (function): function to apply to each tensor
        ndarray_func (function): function to apply to each array
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x,
{
torch.Tensor: tensor_func,
np.ndarray: ndarray_func,
type(None): lambda x: x,
},
)
def clone(x):
"""
Clones all torch tensors and numpy arrays in nested dictionary or list
or tuple and returns a new nested structure.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x,
{
torch.Tensor: lambda x: x.clone(),
np.ndarray: lambda x: x.copy(),
type(None): lambda x: x,
},
)
def detach(x):
"""
Detaches all torch tensors in nested dictionary or list
or tuple and returns a new nested structure.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x,
{
torch.Tensor: lambda x: x.detach(),
},
)
def to_batch(x):
"""
Introduces a leading batch dimension of 1 for all torch tensors and numpy
arrays in nested dictionary or list or tuple and returns a new nested structure.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x,
{
torch.Tensor: lambda x: x[None, ...],
np.ndarray: lambda x: x[None, ...],
type(None): lambda x: x,
},
)
def to_sequence(x):
"""
Introduces a time dimension of 1 at dimension 1 for all torch tensors and numpy
arrays in nested dictionary or list or tuple and returns a new nested structure.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x,
{
torch.Tensor: lambda x: x[:, None, ...],
np.ndarray: lambda x: x[:, None, ...],
type(None): lambda x: x,
},
)
def index_at_time(x, ind):
"""
Indexes all torch tensors and numpy arrays in dimension 1 with index @ind in
nested dictionary or list or tuple and returns a new nested structure.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
ind (int): index
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x,
{
torch.Tensor: lambda x: x[:, ind, ...],
np.ndarray: lambda x: x[:, ind, ...],
type(None): lambda x: x,
},
)
def unsqueeze(x, dim):
"""
Adds dimension of size 1 at dimension @dim in all torch tensors and numpy arrays
in nested dictionary or list or tuple and returns a new nested structure.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
dim (int): dimension
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x,
{
torch.Tensor: lambda x: x.unsqueeze(dim=dim),
np.ndarray: lambda x: np.expand_dims(x, axis=dim),
type(None): lambda x: x,
},
)
def contiguous(x):
"""
Makes all torch tensors and numpy arrays contiguous in nested dictionary or
list or tuple and returns a new nested structure.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x,
{
torch.Tensor: lambda x: x.contiguous(),
np.ndarray: lambda x: np.ascontiguousarray(x),
type(None): lambda x: x,
},
)
def to_device(x, device):
"""
Sends all torch tensors in nested dictionary or list or tuple to device
@device, and returns a new nested structure.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
device (torch.Device): device to send tensors to
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x,
{
torch.Tensor: lambda x, d=device: x.to(d),
type(None): lambda x: x,
},
)
def to_tensor(x):
"""
Converts all numpy arrays in nested dictionary or list or tuple to
torch tensors (and leaves existing torch Tensors as-is), and returns
a new nested structure.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x,
{
torch.Tensor: lambda x: x,
np.ndarray: lambda x: torch.from_numpy(x),
type(None): lambda x: x,
},
)
def to_numpy(x):
"""
Converts all torch tensors in nested dictionary or list or tuple to
numpy (and leaves existing numpy arrays as-is), and returns
a new nested structure.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
def f(tensor):
if tensor.is_cuda:
return tensor.detach().cpu().numpy()
else:
return tensor.detach().numpy()
return recursive_dict_list_tuple_apply(
x,
{
torch.Tensor: f,
np.ndarray: lambda x: x,
type(None): lambda x: x,
},
)
def to_list(x):
"""
Converts all torch tensors and numpy arrays in nested dictionary or list
or tuple to a list, and returns a new nested structure. Useful for
json encoding.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
def f(tensor):
if tensor.is_cuda:
return tensor.detach().cpu().numpy().tolist()
else:
return tensor.detach().numpy().tolist()
return recursive_dict_list_tuple_apply(
x,
{
torch.Tensor: f,
np.ndarray: lambda x: x.tolist(),
type(None): lambda x: x,
},
)
def to_float(x):
"""
Converts all torch tensors and numpy arrays in nested dictionary or list
or tuple to float type entries, and returns a new nested structure.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x,
{
torch.Tensor: lambda x: x.float(),
np.ndarray: lambda x: x.astype(np.float32),
type(None): lambda x: x,
},
)
def to_uint8(x):
"""
Converts all torch tensors and numpy arrays in nested dictionary or list
or tuple to uint8 type entries, and returns a new nested structure.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x,
{
torch.Tensor: lambda x: x.byte(),
np.ndarray: lambda x: x.astype(np.uint8),
type(None): lambda x: x,
},
)
def to_torch(x, device):
"""
Converts all numpy arrays and torch tensors in nested dictionary or list or tuple to
torch tensors on device @device and returns a new nested structure.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
device (torch.Device): device to send tensors to
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return to_device(to_float(to_tensor(x)), device)
def to_one_hot_single(tensor, num_class):
"""
Convert tensor to one-hot representation, assuming a certain number of total class labels.
Args:
tensor (torch.Tensor): tensor containing integer labels
num_class (int): number of classes
Returns:
x (torch.Tensor): tensor containing one-hot representation of labels
"""
x = torch.zeros(tensor.size() + (num_class,)).to(tensor.device)
x.scatter_(-1, tensor.unsqueeze(-1), 1)
return x
def to_one_hot(tensor, num_class):
"""
Convert all tensors in nested dictionary or list or tuple to one-hot representation,
assuming a certain number of total class labels.
Args:
tensor (dict or list or tuple): a possibly nested dictionary or list or tuple
num_class (int): number of classes
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return map_tensor(tensor, func=lambda x, nc=num_class: to_one_hot_single(x, nc))
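# Illustrative sketch, assuming to_one_hot above is in scope: integer labels to one-hot.
import torch

labels = {"cls": torch.tensor([0, 2, 1])}
one_hot = to_one_hot(labels, num_class=3)
assert one_hot["cls"].shape == (3, 3)
assert one_hot["cls"][1].tolist() == [0.0, 0.0, 1.0]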
def flatten_single(x, begin_axis=1):
"""
Flatten a tensor in all dimensions from @begin_axis onwards.
Args:
x (torch.Tensor): tensor to flatten
begin_axis (int): which axis to flatten from
Returns:
y (torch.Tensor): flattened tensor
"""
fixed_size = x.size()[:begin_axis]
_s = list(fixed_size) + [-1]
return x.reshape(*_s)
def flatten(x, begin_axis=1):
"""
Flatten all tensors in nested dictionary or list or tuple, from @begin_axis onwards.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
begin_axis (int): which axis to flatten from
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x,
{
torch.Tensor: lambda x, b=begin_axis: flatten_single(x, begin_axis=b),
},
)
def reshape_dimensions_single(x, begin_axis, end_axis, target_dims):
"""
Reshape selected dimensions in a tensor to a target dimension.
Args:
x (torch.Tensor): tensor to reshape
begin_axis (int): begin dimension
end_axis (int): end dimension
target_dims (tuple or list): target shape for the range of dimensions
(@begin_axis, @end_axis)
Returns:
y (torch.Tensor): reshaped tensor
"""
assert begin_axis <= end_axis
assert begin_axis >= 0
assert end_axis < len(x.shape)
assert isinstance(target_dims, (tuple, list))
s = x.shape
final_s = []
for i in range(len(s)):
if i == begin_axis:
final_s.extend(target_dims)
elif i < begin_axis or i > end_axis:
final_s.append(s[i])
return x.reshape(*final_s)
def reshape_dimensions(x, begin_axis, end_axis, target_dims):
"""
Reshape selected dimensions for all tensors in nested dictionary or list or tuple
to a target dimension.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
begin_axis (int): begin dimension
end_axis (int): end dimension
target_dims (tuple or list): target shape for the range of dimensions
(@begin_axis, @end_axis)
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x,
{
torch.Tensor: lambda x, b=begin_axis, e=end_axis, t=target_dims: reshape_dimensions_single(
x, begin_axis=b, end_axis=e, target_dims=t
),
np.ndarray: lambda x, b=begin_axis, e=end_axis, t=target_dims: reshape_dimensions_single(
x, begin_axis=b, end_axis=e, target_dims=t
),
type(None): lambda x: x,
},
)
def join_dimensions(x, begin_axis, end_axis):
"""
Joins all dimensions between dimensions (@begin_axis, @end_axis) into a flat dimension, for
all tensors in nested dictionary or list or tuple.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
begin_axis (int): begin dimension
end_axis (int): end dimension
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x,
{
torch.Tensor: lambda x, b=begin_axis, e=end_axis: reshape_dimensions_single(
x, begin_axis=b, end_axis=e, target_dims=[-1]
),
np.ndarray: lambda x, b=begin_axis, e=end_axis: reshape_dimensions_single(
x, begin_axis=b, end_axis=e, target_dims=[-1]
),
type(None): lambda x: x,
},
)
def expand_at_single(x, size, dim):
"""
Expand a tensor at a single dimension @dim by @size
Args:
x (torch.Tensor): input tensor
size (int): size to expand
dim (int): dimension to expand
Returns:
y (torch.Tensor): expanded tensor
"""
assert dim < x.ndimension()
assert x.shape[dim] == 1
expand_dims = [-1] * x.ndimension()
expand_dims[dim] = size
return x.expand(*expand_dims)
def expand_at(x, size, dim):
"""
Expand all tensors in nested dictionary or list or tuple at a single
dimension @dim by @size.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
size (int): size to expand
dim (int): dimension to expand
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return map_tensor(x, lambda t, s=size, d=dim: expand_at_single(t, s, d))
def unsqueeze_expand_at(x, size, dim):
"""
Unsqueeze and expand a tensor at a dimension @dim by @size.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
size (int): size to expand
dim (int): dimension to unsqueeze and expand
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
x = unsqueeze(x, dim)
return expand_at(x, size, dim)
def repeat_by_expand_at(x, repeats, dim):
"""
Repeat a dimension by combining expand and reshape operations.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
repeats (int): number of times to repeat the target dimension
dim (int): dimension to repeat on
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
x = unsqueeze_expand_at(x, repeats, dim + 1)
return join_dimensions(x, dim, dim + 1)
def named_reduce_single(x, reduction, dim):
"""
Reduce tensor at a dimension by named reduction functions.
Args:
x (torch.Tensor): tensor to be reduced
reduction (str): one of ["sum", "max", "mean", "flatten"]
dim (int): dimension to be reduced (or begin axis for flatten)
Returns:
y (torch.Tensor): reduced tensor
"""
assert x.ndimension() > dim
assert reduction in ["sum", "max", "mean", "flatten"]
if reduction == "flatten":
x = flatten(x, begin_axis=dim)
elif reduction == "max":
x = torch.max(x, dim=dim)[0] # [B, D]
elif reduction == "sum":
x = torch.sum(x, dim=dim)
else:
x = torch.mean(x, dim=dim)
return x
def named_reduce(x, reduction, dim):
"""
Reduces all tensors in nested dictionary or list or tuple at a dimension
using a named reduction function.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
reduction (str): one of ["sum", "max", "mean", "flatten"]
dim (int): dimension to be reduced (or begin axis for flatten)
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return map_tensor(x, func=lambda t, r=reduction, d=dim: named_reduce_single(t, r, d))
def gather_along_dim_with_dim_single(x, target_dim, source_dim, indices):
"""
This function indexes out a target dimension of a tensor in a structured way,
by allowing a different value to be selected for each member of a flat index
tensor (@indices) corresponding to a source dimension. This can be interpreted
as moving along the source dimension, using the corresponding index value
in @indices to select values for all other dimensions outside of the
source and target dimensions. A common use case is to gather values
in target dimension 1 for each batch member (target dimension 0).
Args:
x (torch.Tensor): tensor to gather values for
target_dim (int): dimension to gather values along
source_dim (int): dimension to hold constant and use for gathering values
from the other dimensions
indices (torch.Tensor): flat index tensor with same shape as tensor @x along
@source_dim
Returns:
y (torch.Tensor): gathered tensor, with dimension @target_dim indexed out
"""
assert len(indices.shape) == 1
assert x.shape[source_dim] == indices.shape[0]
# unsqueeze in all dimensions except the source dimension
new_shape = [1] * x.ndimension()
new_shape[source_dim] = -1
indices = indices.reshape(*new_shape)
# repeat in all dimensions - but preserve shape of source dimension,
# and make sure target_dimension has singleton dimension
expand_shape = list(x.shape)
expand_shape[source_dim] = -1
expand_shape[target_dim] = 1
indices = indices.expand(*expand_shape)
out = x.gather(dim=target_dim, index=indices)
return out.squeeze(target_dim)
def gather_along_dim_with_dim(x, target_dim, source_dim, indices):
"""
Apply @gather_along_dim_with_dim_single to all tensors in a nested
dictionary or list or tuple.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
target_dim (int): dimension to gather values along
source_dim (int): dimension to hold constant and use for gathering values
from the other dimensions
indices (torch.Tensor): flat index tensor with same shape as tensor @x along
@source_dim
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return map_tensor(
x, lambda y, t=target_dim, s=source_dim, i=indices: gather_along_dim_with_dim_single(y, t, s, i)
)
def gather_sequence_single(seq, indices):
"""
Given a tensor with leading dimensions [B, T, ...], gather an element from each sequence in
the batch given an index for each sequence.
Args:
seq (torch.Tensor): tensor with leading dimensions [B, T, ...]
indices (torch.Tensor): tensor indices of shape [B]
Return:
y (torch.Tensor): indexed tensor of shape [B, ....]
"""
return gather_along_dim_with_dim_single(seq, target_dim=1, source_dim=0, indices=indices)
def gather_sequence(seq, indices):
"""
Given a nested dictionary or list or tuple, gathers an element from each sequence of the batch
for tensors with leading dimensions [B, T, ...].
Args:
seq (dict or list or tuple): a possibly nested dictionary or list or tuple with tensors
of leading dimensions [B, T, ...]
indices (torch.Tensor): tensor indices of shape [B]
Returns:
y (dict or list or tuple): new nested dict-list-tuple with tensors of shape [B, ...]
"""
return gather_along_dim_with_dim(seq, target_dim=1, source_dim=0, indices=indices)
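# Illustrative sketch, assuming gather_sequence above is in scope: pick one timestep per
# batch element from a (B, T, ...) structure.
import torch

seq = {"action": torch.arange(2 * 3 * 4).reshape(2, 3, 4)}  # (B=2, T=3, D=4)
indices = torch.tensor([0, 2])                              # timestep chosen per batch element
picked = gather_sequence(seq, indices)
assert picked["action"].shape == (2, 4)
assert torch.equal(picked["action"][1], seq["action"][1, 2])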
def pad_sequence_single(seq, padding, batched=False, pad_same=True, pad_values=None):
"""
Pad input tensor or array @seq in the time dimension (dimension 1).
Args:
seq (np.ndarray or torch.Tensor): sequence to be padded
padding (tuple): begin and end padding, e.g. [1, 1] pads both begin and end of the sequence by 1
batched (bool): if sequence has the batch dimension
pad_same (bool): if pad by duplicating
pad_values (scalar or (ndarray, Tensor)): values to be padded if not pad_same
Returns:
padded sequence (np.ndarray or torch.Tensor)
"""
assert isinstance(seq, (np.ndarray, torch.Tensor))
assert pad_same or pad_values is not None
if pad_values is not None:
assert isinstance(pad_values, float)
repeat_func = np.repeat if isinstance(seq, np.ndarray) else torch.repeat_interleave
concat_func = np.concatenate if isinstance(seq, np.ndarray) else torch.cat
ones_like_func = np.ones_like if isinstance(seq, np.ndarray) else torch.ones_like
seq_dim = 1 if batched else 0
begin_pad = []
end_pad = []
if padding[0] > 0:
pad = seq[[0]] if pad_same else ones_like_func(seq[[0]]) * pad_values
begin_pad.append(repeat_func(pad, padding[0], seq_dim))
if padding[1] > 0:
pad = seq[[-1]] if pad_same else ones_like_func(seq[[-1]]) * pad_values
end_pad.append(repeat_func(pad, padding[1], seq_dim))
return concat_func(begin_pad + [seq] + end_pad, seq_dim)
def pad_sequence(seq, padding, batched=False, pad_same=True, pad_values=None):
"""
Pad a nested dictionary or list or tuple of sequence tensors in the time dimension (dimension 1).
Args:
seq (dict or list or tuple): a possibly nested dictionary or list or tuple with tensors
of leading dimensions [B, T, ...]
padding (tuple): begin and end padding, e.g. [1, 1] pads both begin and end of the sequence by 1
batched (bool): if sequence has the batch dimension
pad_same (bool): if pad by duplicating
pad_values (scalar or (ndarray, Tensor)): values to be padded if not pad_same
Returns:
padded sequence (dict or list or tuple)
"""
return recursive_dict_list_tuple_apply(
seq,
{
torch.Tensor: lambda x, p=padding, b=batched, ps=pad_same, pv=pad_values: pad_sequence_single(
x, p, b, ps, pv
),
np.ndarray: lambda x, p=padding, b=batched, ps=pad_same, pv=pad_values: pad_sequence_single(
x, p, b, ps, pv
),
type(None): lambda x: x,
},
)
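# Illustrative sketch, assuming pad_sequence above is in scope: pad a single (T, D)
# trajectory by repeating its first and last steps once each along the time axis.
import torch

traj = {"obs": torch.rand(10, 3)}
padded = pad_sequence(traj, padding=(1, 1), batched=False, pad_same=True)
assert padded["obs"].shape == (12, 3)
assert torch.equal(padded["obs"][0], traj["obs"][0])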
def assert_size_at_dim_single(x, size, dim, msg):
"""
Ensure that array or tensor @x has size @size in dim @dim.
Args:
x (np.ndarray or torch.Tensor): input array or tensor
size (int): size that tensors should have at @dim
dim (int): dimension to check
msg (str): text to display if assertion fails
"""
assert x.shape[dim] == size, msg
def assert_size_at_dim(x, size, dim, msg):
"""
Ensure that arrays and tensors in nested dictionary or list or tuple have
size @size in dim @dim.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
size (int): size that tensors should have at @dim
dim (int): dimension to check
"""
map_tensor(x, lambda t, s=size, d=dim, m=msg: assert_size_at_dim_single(t, s, d, m))
def get_shape(x):
"""
Get all shapes of arrays and tensors in nested dictionary or list or tuple.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
Returns:
y (dict or list or tuple): new nested dict-list-tuple that contains each array or
tensor's shape
"""
return recursive_dict_list_tuple_apply(
x,
{
torch.Tensor: lambda x: x.shape,
np.ndarray: lambda x: x.shape,
type(None): lambda x: x,
},
)
def list_of_flat_dict_to_dict_of_list(list_of_dict):
"""
Helper function to go from a list of flat dictionaries to a dictionary of lists.
By "flat" we mean that none of the values are dictionaries, but are numpy arrays,
floats, etc.
Args:
list_of_dict (list): list of flat dictionaries
Returns:
dict_of_list (dict): dictionary of lists
"""
assert isinstance(list_of_dict, list)
dic = collections.OrderedDict()
for i in range(len(list_of_dict)):
for k in list_of_dict[i]:
if k not in dic:
dic[k] = []
dic[k].append(list_of_dict[i][k])
return dic
def flatten_nested_dict_list(d, parent_key="", sep="_", item_key=""):
"""
Flatten a nested dict or list to a list.
For example, given a dict
{
a: 1
b: {
c: 2
}
c: 3
}
the function would return [(a, 1), (b_c, 2), (c, 3)]
Args:
d (dict, list): a nested dict or list to be flattened
parent_key (str): recursion helper
sep (str): separator for nesting keys
item_key (str): recursion helper
Returns:
list: a list of (key, value) tuples
"""
items = []
if isinstance(d, (tuple, list)):
new_key = parent_key + sep + item_key if len(parent_key) > 0 else item_key
for i, v in enumerate(d):
items.extend(flatten_nested_dict_list(v, new_key, sep=sep, item_key=str(i)))
return items
elif isinstance(d, dict):
new_key = parent_key + sep + item_key if len(parent_key) > 0 else item_key
for k, v in d.items():
assert isinstance(k, str)
items.extend(flatten_nested_dict_list(v, new_key, sep=sep, item_key=k))
return items
else:
new_key = parent_key + sep + item_key if len(parent_key) > 0 else item_key
return [(new_key, d)]
def time_distributed(inputs, op, activation=None, inputs_as_kwargs=False, inputs_as_args=False, **kwargs):
"""
Apply function @op to all tensors in nested dictionary or list or tuple @inputs in both the
batch (B) and time (T) dimension, where the tensors are expected to have shape [B, T, ...].
Will do this by reshaping tensors to [B * T, ...], passing through the op, and then reshaping
outputs to [B, T, ...].
Args:
inputs (list or tuple or dict): a possibly nested dictionary or list or tuple with tensors
of leading dimensions [B, T, ...]
op: a layer op that accepts inputs
activation: activation to apply at the output
inputs_as_kwargs (bool): whether to feed input as a kwargs dict to the op
        inputs_as_args (bool): whether to feed input as an args list to the op
kwargs (dict): other kwargs to supply to the op
Returns:
outputs (dict or list or tuple): new nested dict-list-tuple with tensors of leading dimension [B, T].
"""
batch_size, seq_len = flatten_nested_dict_list(inputs)[0][1].shape[:2]
inputs = join_dimensions(inputs, 0, 1)
if inputs_as_kwargs:
outputs = op(**inputs, **kwargs)
elif inputs_as_args:
outputs = op(*inputs, **kwargs)
else:
outputs = op(inputs, **kwargs)
if activation is not None:
outputs = map_tensor(outputs, activation)
outputs = reshape_dimensions(outputs, begin_axis=0, end_axis=0, target_dims=(batch_size, seq_len))
return outputs
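# Illustrative sketch, assuming time_distributed above is in scope: apply an nn.Linear to
# every (B, T) element by folding time into the batch dimension and unfolding afterwards.
import torch
from torch import nn

states = torch.rand(8, 16, 10)  # (B, T, D)
proj = nn.Linear(10, 32)
out = time_distributed(states, op=proj)
assert out.shape == (8, 16, 32)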

View File

@@ -5,11 +5,10 @@ from collections import deque
import hydra
import torch
from torch import nn
from diffusers.optimization import get_scheduler
from torch import Tensor, nn
from lerobot.common.policies.diffusion.diffusion_unet_image_policy import DiffusionUnetImagePolicy
from lerobot.common.policies.diffusion.model.lr_scheduler import get_scheduler
from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder, RgbEncoder
from lerobot.common.policies.diffusion.model.diffusion_unet_image_policy import DiffusionUnetImagePolicy
from lerobot.common.policies.utils import populate_queues
from lerobot.common.utils import get_safe_torch_device
@@ -22,8 +21,6 @@ class DiffusionPolicy(nn.Module):
cfg,
cfg_device,
cfg_noise_scheduler,
cfg_rgb_model,
cfg_obs_encoder,
cfg_optimizer,
cfg_ema,
shape_meta: dict,
@@ -31,53 +28,43 @@ class DiffusionPolicy(nn.Module):
n_action_steps,
n_obs_steps,
num_inference_steps=None,
obs_as_global_cond=True,
diffusion_step_embed_dim=256,
down_dims=(256, 512, 1024),
kernel_size=5,
n_groups=8,
cond_predict_scale=True,
# parameters passed to step
**kwargs,
film_scale_modulation=True,
**_,
):
super().__init__()
self.cfg = cfg
self.n_obs_steps = n_obs_steps
self.n_action_steps = n_action_steps
# queues are populated during rollout of the policy, they contain the n latest observations and actions
self._queues = None
# TODO(now): In-house this.
noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler)
rgb_model_input_shape = copy.deepcopy(shape_meta.obs.image.shape)
if cfg_obs_encoder.crop_shape is not None:
rgb_model_input_shape[1:] = cfg_obs_encoder.crop_shape
rgb_model = RgbEncoder(input_shape=rgb_model_input_shape, **cfg_rgb_model)
obs_encoder = MultiImageObsEncoder(
rgb_model=rgb_model,
**cfg_obs_encoder,
)
self.diffusion = DiffusionUnetImagePolicy(
cfg,
shape_meta=shape_meta,
noise_scheduler=noise_scheduler,
obs_encoder=obs_encoder,
horizon=horizon,
n_action_steps=n_action_steps,
n_obs_steps=n_obs_steps,
num_inference_steps=num_inference_steps,
obs_as_global_cond=obs_as_global_cond,
diffusion_step_embed_dim=diffusion_step_embed_dim,
down_dims=down_dims,
kernel_size=kernel_size,
n_groups=n_groups,
cond_predict_scale=cond_predict_scale,
# parameters passed to step
**kwargs,
film_scale_modulation=film_scale_modulation,
)
self.device = get_safe_torch_device(cfg_device)
self.diffusion.to(self.device)
# TODO(alexander-soare): This should probably be managed outside of the policy class.
self.ema_diffusion = None
self.ema = None
if self.cfg.use_ema:
@@ -116,42 +103,45 @@ class DiffusionPolicy(nn.Module):
"action": deque(maxlen=self.n_action_steps),
}
@torch.no_grad()
def select_action(self, batch, step):
def forward(self, batch: dict[str, Tensor], **_) -> Tensor:
"""A forward pass through the DNN part of this policy with optional loss computation."""
return self.select_action(batch)
@torch.no_grad
def select_action(self, batch, **_):
"""
        Note: this uses the EMA model weights when the policy is in eval mode (self.training is False), and the non-EMA model weights otherwise.
# TODO(now): Handle a batch
"""
# TODO(rcadene): remove unused step
del step
assert "observation.image" in batch
assert "observation.state" in batch
assert len(batch) == 2
assert len(batch) == 2 # TODO(now): Does this not have a batch dim?
self._queues = populate_queues(self._queues, batch)
if len(self._queues["action"]) == 0:
# stack n latest observations from the queue
batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
obs_dict = {
"image": batch["observation.image"],
"agent_pos": batch["observation.state"],
}
if self.training:
out = self.diffusion.predict_action(obs_dict)
else:
out = self.ema_diffusion.predict_action(obs_dict)
self._queues["action"].extend(out["action"].transpose(0, 1))
actions = self._generate_actions(batch)
self._queues["action"].extend(actions.transpose(0, 1))
action = self._queues["action"].popleft()
return action
def forward(self, batch, step):
def _generate_actions(self, batch):
if not self.training and self.ema_diffusion is not None:
return self.ema_diffusion.predict_action(batch)
else:
return self.diffusion.predict_action(batch)
def update(self, batch, **_):
"""Run the model in train mode, compute the loss, and do an optimization step."""
start_time = time.time()
self.diffusion.train()
loss = self.diffusion.compute_loss(batch)
loss = self.compute_loss(batch)
loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(
@@ -174,13 +164,11 @@ class DiffusionPolicy(nn.Module):
"update_s": time.time() - start_time,
}
# TODO(rcadene): remove hardcoding
# in diffusion_policy, len(dataloader) is 168 for a batch_size of 64
if step % 168 == 0:
self.global_step += 1
return info
def compute_loss(self, batch: dict[str, Tensor]) -> Tensor:
return self.diffusion.compute_loss(batch)
def save(self, fp):
torch.save(self.state_dict(), fp)

View File

@@ -1,76 +0,0 @@
from typing import Callable, Dict
import torch
import torch.nn as nn
import torchvision
def get_resnet(name, weights=None, **kwargs):
"""
name: resnet18, resnet34, resnet50
weights: "IMAGENET1K_V1", "r3m"
"""
# load r3m weights
if (weights == "r3m") or (weights == "R3M"):
return get_r3m(name=name, **kwargs)
func = getattr(torchvision.models, name)
resnet = func(weights=weights, **kwargs)
resnet.fc = torch.nn.Identity()
return resnet
def get_r3m(name, **kwargs):
"""
name: resnet18, resnet34, resnet50
"""
import r3m
r3m.device = "cpu"
model = r3m.load_r3m(name)
r3m_model = model.module
resnet_model = r3m_model.convnet
resnet_model = resnet_model.to("cpu")
return resnet_model
def dict_apply(
x: Dict[str, torch.Tensor], func: Callable[[torch.Tensor], torch.Tensor]
) -> Dict[str, torch.Tensor]:
result = {}
for key, value in x.items():
if isinstance(value, dict):
result[key] = dict_apply(value, func)
else:
result[key] = func(value)
return result
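# Illustrative sketch, assuming dict_apply above is in scope: apply one function to every
# tensor in a (possibly nested) dict, e.g. detaching and moving to CPU.
import torch

batch = {"obs": {"image": torch.rand(2, 3)}, "action": torch.rand(2, 4)}
batch_cpu = dict_apply(batch, lambda t: t.detach().to("cpu"))
assert batch_cpu["obs"]["image"].device.type == "cpu"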
def replace_submodules(
root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module]
) -> nn.Module:
"""
predicate: Return true if the module is to be replaced.
func: Return new module to use.
"""
if predicate(root_module):
return func(root_module)
bn_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)]
for *parent, k in bn_list:
parent_module = root_module
if len(parent) > 0:
parent_module = root_module.get_submodule(".".join(parent))
if isinstance(parent_module, nn.Sequential):
src_module = parent_module[int(k)]
else:
src_module = getattr(parent_module, k)
tgt_module = func(src_module)
if isinstance(parent_module, nn.Sequential):
parent_module[int(k)] = tgt_module
else:
setattr(parent_module, k, tgt_module)
# verify that all BN are replaced
bn_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)]
assert len(bn_list) == 0
return root_module

View File

@@ -12,8 +12,6 @@ def make_policy(cfg):
cfg=cfg.policy,
cfg_device=cfg.device,
cfg_noise_scheduler=cfg.noise_scheduler,
cfg_rgb_model=cfg.rgb_model,
cfg_obs_encoder=cfg.obs_encoder,
cfg_optimizer=cfg.optimizer,
cfg_ema=cfg.ema,
# n_obs_steps=cfg.n_obs_steps,

View File

@@ -1,3 +1,7 @@
import torch
from torch import nn
def populate_queues(queues, batch):
for key in batch:
if len(queues[key]) != queues[key].maxlen:
@@ -8,3 +12,21 @@ def populate_queues(queues, batch):
# add latest observation to the queue
queues[key].append(batch[key])
return queues
def get_device_from_parameters(module: nn.Module) -> torch.device:
"""Get a module's device by checking one of its parameters.
Note: assumes that all parameters have the same device
TODO(now): Add test.
"""
return next(iter(module.parameters())).device
def get_dtype_from_parameters(module: nn.Module) -> torch.dtype:
"""Get a module's parameter dtype by checking one of its parameters.
Note: assumes that all parameters have the same dtype.
TODO(now): Add test.
"""
return next(iter(module.parameters())).dtype
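# Illustrative sketch, assuming the two helpers above are in scope; both rely on all
# parameters sharing a device/dtype, which holds for a plain nn.Linear.
import torch
from torch import nn

layer = nn.Linear(4, 2)
assert get_device_from_parameters(layer) == torch.device("cpu")
assert get_dtype_from_parameters(layer) == torch.float32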

View File

@@ -11,6 +11,7 @@ from omegaconf import DictConfig
def get_safe_torch_device(cfg_device: str, log: bool = False) -> torch.device:
"""Given a string, return a torch.device with checks on whether the device is available."""
match cfg_device:
case "cuda":
assert torch.cuda.is_available()

View File

@@ -19,7 +19,6 @@ n_action_steps: 8
dataset_obs_steps: ${n_obs_steps}
past_action_visible: False
keypoint_visible_rate: 1.0
obs_as_global_cond: True
eval_episodes: 50
eval_freq: 5000
@@ -40,13 +39,12 @@ policy:
n_obs_steps: ${n_obs_steps}
n_action_steps: ${n_action_steps}
num_inference_steps: 100
obs_as_global_cond: ${obs_as_global_cond}
# crop_shape: null
diffusion_step_embed_dim: 128
down_dims: [512, 1024, 2048]
kernel_size: 5
n_groups: 8
cond_predict_scale: True
film_scale_modulation: True
pretrained_model_path:
@@ -68,6 +66,16 @@ policy:
observation.state: [-0.1, 0]
action: [-0.1, 0, .1, .2, .3, .4, .5, .6, .7, .8, .9, 1.0, 1.1, 1.2, 1.3, 1.4]
rgb_encoder:
backbone_name: resnet18
pretrained_backbone: false
use_group_norm: True
num_keypoints: 32
relu: true
norm_mean_std: [0.5, 0.5] # for PushT the original impl normalizes to [-1, 1] (maybe not the case for robomimic envs)
crop_shape: [84, 84]
random_crop: True
noise_scheduler:
_target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
num_train_timesteps: 100
@@ -78,16 +86,6 @@ noise_scheduler:
clip_sample: True # required when predict_epsilon=False
prediction_type: epsilon # or sample
obs_encoder:
shape_meta: ${shape_meta}
# resize_shape: null
crop_shape: [84, 84]
# constant center crop
random_crop: True
use_group_norm: True
share_rgb_model: False
norm_mean_std: [0.5, 0.5] # for PushT the original impl normalizes to [-1, 1] (maybe not the case for robomimic envs)
rgb_model:
pretrained: false
num_keypoints: 32

View File

@@ -121,7 +121,7 @@ def eval_policy(
# get the next action for the environment
with torch.inference_mode():
action = policy.select_action(observation, step)
action = policy.select_action(observation, step=step)
# apply inverse transform to unnormalize the action
action = postprocess_action(action, transform)

View File

@@ -213,7 +213,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
for key in batch:
batch[key] = batch[key].to(cfg.device, non_blocking=True)
train_info = policy(batch, step)
train_info = policy.update(batch, step=step)
# TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
if step % cfg.log_freq == 0: