backup wip

This commit is contained in:
  parent 91ff69d64c
  commit 976a197f98
@@ -1,3 +1,8 @@
"""Helper code for loading PushT dataset from Diffusion Policy (https://diffusion-policy.cs.columbia.edu/)

Copied from the original Diffusion Policy repository.
"""

from __future__ import annotations

import math
@@ -8,8 +8,10 @@ import torch
import tqdm
from gym_pusht.envs.pusht import pymunk_to_shapely

from lerobot.common.datasets._diffusion_policy_replay_buffer import (
    ReplayBuffer as DiffusionPolicyReplayBuffer,
)
from lerobot.common.datasets.utils import download_and_extract_zip, load_data_with_delta_timestamps
from lerobot.common.policies.diffusion.replay_buffer import ReplayBuffer as DiffusionPolicyReplayBuffer

# As defined in the env.
SUCCESS_THRESHOLD = 0.95  # 95% coverage
@@ -176,7 +176,8 @@ class ActionChunkingTransformerPolicy(nn.Module):
        if self.n_action_steps is not None:
            self._action_queue = deque([], maxlen=self.n_action_steps)

    def select_action(self, batch: dict[str, Tensor], *_, **__) -> Tensor:
    @torch.no_grad
    def select_action(self, batch: dict[str, Tensor], **_) -> Tensor:
        """
        This method wraps `select_actions` in order to return one action at a time for execution in the
        environment. It works by managing the actions in a queue and only calling `select_actions` when the
@@ -188,7 +189,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
            self._action_queue.extend(self.select_actions(batch).transpose(0, 1))
        return self._action_queue.popleft()

    @torch.no_grad()
    @torch.no_grad
    def select_actions(self, batch: dict[str, Tensor]) -> Tensor:
        """Use the action chunking transformer to generate a sequence of actions."""
        self.eval()
@@ -223,8 +224,6 @@ class ActionChunkingTransformerPolicy(nn.Module):
            {
                "observation.state": (B, 1, J) OR (B, J) tensor of robot states (joint configuration).
                "observation.images.top": (B, 1, C, H, W) OR (B, C, H, W) tensor of images.
                "action": (B, H, J) tensor of actions (positional target for robot joint configuration)
                "action_is_pad": (B, H) mask for whether the actions are padding outside of the episode bounds.
            }
        """
        if add_obs_steps_dim:
@@ -244,7 +243,8 @@ class ActionChunkingTransformerPolicy(nn.Module):
        # Note: no squeeze is required for "observation.images.top" because then we'd have to unsqueeze to get
        # the image index dimension.

    def update(self, batch, *_, **__) -> dict:
    def update(self, batch, **_) -> dict:
        """Run the model in train mode, compute the loss, and do an optimization step."""
        start_time = time.time()
        self._preprocess_batch(batch)
@@ -278,6 +278,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
        return info

    def forward(self, batch: dict[str, Tensor], return_loss: bool = False) -> dict | Tensor:
        """A forward pass through the DNN part of this policy with optional loss computation."""
        images = self.image_normalizer(batch["observation.images.top"])

        if return_loss:  # training time
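The queue-based `select_action` above boils down to the following standalone pattern (class and attribute names here are illustrative, not the repository's; `policy.select_actions` is assumed to return a `(batch, n_action_steps, action_dim)` tensor):

```python
from collections import deque

import torch


class ChunkedActionWrapper:
    """Illustrative stand-in for the queue logic in `select_action` above."""

    def __init__(self, policy, n_action_steps: int):
        # `policy.select_actions(batch)` is assumed to return (batch, n_action_steps, action_dim).
        self.policy = policy
        self._action_queue = deque([], maxlen=n_action_steps)

    @torch.no_grad()
    def select_action(self, batch: dict) -> torch.Tensor:
        if len(self._action_queue) == 0:
            # Transpose to (n_action_steps, batch, action_dim) so extend() splits along the time axis.
            self._action_queue.extend(self.policy.select_actions(batch).transpose(0, 1))
        return self._action_queue.popleft()
```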
@@ -1,315 +0,0 @@
|
|||
"""Code from the original diffusion policy project.
|
||||
|
||||
Notes on how to load a checkpoint from the original repository:
|
||||
|
||||
In the original repository, run the eval and use a breakpoint to extract the policy weights.
|
||||
|
||||
```
|
||||
torch.save(policy.state_dict(), "weights.pt")
|
||||
```
|
||||
|
||||
In this repository, add a breakpoint somewhere after creating an equivalent policy and load in the weights:
|
||||
|
||||
```
|
||||
loaded = torch.load("weights.pt")
|
||||
aligned = {}
|
||||
their_prefix = "obs_encoder.obs_nets.image.backbone"
|
||||
our_prefix = "obs_encoder.key_model_map.image.backbone"
|
||||
aligned.update({our_prefix + k.removeprefix(their_prefix): v for k, v in loaded.items() if k.startswith(their_prefix)})
|
||||
their_prefix = "obs_encoder.obs_nets.image.pool"
|
||||
our_prefix = "obs_encoder.key_model_map.image.pool"
|
||||
aligned.update({our_prefix + k.removeprefix(their_prefix): v for k, v in loaded.items() if k.startswith(their_prefix)})
|
||||
their_prefix = "obs_encoder.obs_nets.image.nets.3"
|
||||
our_prefix = "obs_encoder.key_model_map.image.out"
|
||||
aligned.update({our_prefix + k.removeprefix(their_prefix): v for k, v in loaded.items() if k.startswith(their_prefix)})
|
||||
aligned.update({k: v for k, v in loaded.items() if k.startswith('model.')})
|
||||
# Note: here you are loading into the ema model.
|
||||
missing_keys, unexpected_keys = policy.ema_diffusion.load_state_dict(aligned, strict=False)
|
||||
assert all('_dummy_variable' in k for k in missing_keys)
|
||||
assert len(unexpected_keys) == 0
|
||||
```
|
||||
|
||||
Then in that same runtime you can also save the weights with the new aligned state_dict:
|
||||
|
||||
```
|
||||
policy.save("weights.pt")
|
||||
```
|
||||
|
||||
Now you can remove the breakpoint and extra code and load in the weights just like with any other lerobot checkpoint.
|
||||
|
||||
"""
|
||||
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
||||
from einops import reduce
|
||||
|
||||
from lerobot.common.policies.diffusion.model.conditional_unet1d import ConditionalUnet1D
|
||||
from lerobot.common.policies.diffusion.model.mask_generator import LowdimMaskGenerator
|
||||
from lerobot.common.policies.diffusion.model.module_attr_mixin import ModuleAttrMixin
|
||||
from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder
|
||||
from lerobot.common.policies.diffusion.model.normalizer import LinearNormalizer
|
||||
from lerobot.common.policies.diffusion.pytorch_utils import dict_apply
|
||||
|
||||
|
||||
class BaseImagePolicy(ModuleAttrMixin):
|
||||
# init accepts keyword argument shape_meta, see config/task/*_image.yaml
|
||||
|
||||
def predict_action(self, obs_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
obs_dict:
|
||||
str: B,To,*
|
||||
return: B,Ta,Da
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
# reset state for stateful policies
|
||||
def reset(self):
|
||||
pass
|
||||
|
||||
# ========== training ===========
|
||||
# no standard training interface except setting normalizer
|
||||
def set_normalizer(self, normalizer: LinearNormalizer):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class DiffusionUnetImagePolicy(BaseImagePolicy):
|
||||
def __init__(
|
||||
self,
|
||||
shape_meta: dict,
|
||||
noise_scheduler: DDPMScheduler,
|
||||
obs_encoder: MultiImageObsEncoder,
|
||||
horizon,
|
||||
n_action_steps,
|
||||
n_obs_steps,
|
||||
num_inference_steps=None,
|
||||
obs_as_global_cond=True,
|
||||
diffusion_step_embed_dim=256,
|
||||
down_dims=(256, 512, 1024),
|
||||
kernel_size=5,
|
||||
n_groups=8,
|
||||
cond_predict_scale=True,
|
||||
# parameters passed to step
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# parse shapes
|
||||
action_shape = shape_meta["action"]["shape"]
|
||||
assert len(action_shape) == 1
|
||||
action_dim = action_shape[0]
|
||||
# get feature dim
|
||||
obs_feature_dim = obs_encoder.output_shape()[0]
|
||||
|
||||
# create diffusion model
|
||||
input_dim = action_dim + obs_feature_dim
|
||||
global_cond_dim = None
|
||||
if obs_as_global_cond:
|
||||
input_dim = action_dim
|
||||
global_cond_dim = obs_feature_dim * n_obs_steps
|
||||
|
||||
model = ConditionalUnet1D(
|
||||
input_dim=input_dim,
|
||||
local_cond_dim=None,
|
||||
global_cond_dim=global_cond_dim,
|
||||
diffusion_step_embed_dim=diffusion_step_embed_dim,
|
||||
down_dims=down_dims,
|
||||
kernel_size=kernel_size,
|
||||
n_groups=n_groups,
|
||||
cond_predict_scale=cond_predict_scale,
|
||||
)
|
||||
|
||||
self.obs_encoder = obs_encoder
|
||||
self.model = model
|
||||
self.noise_scheduler = noise_scheduler
|
||||
self.mask_generator = LowdimMaskGenerator(
|
||||
action_dim=action_dim,
|
||||
obs_dim=0 if obs_as_global_cond else obs_feature_dim,
|
||||
max_n_obs_steps=n_obs_steps,
|
||||
fix_obs_steps=True,
|
||||
action_visible=False,
|
||||
)
|
||||
self.horizon = horizon
|
||||
self.obs_feature_dim = obs_feature_dim
|
||||
self.action_dim = action_dim
|
||||
self.n_action_steps = n_action_steps
|
||||
self.n_obs_steps = n_obs_steps
|
||||
self.obs_as_global_cond = obs_as_global_cond
|
||||
self.kwargs = kwargs
|
||||
|
||||
if num_inference_steps is None:
|
||||
num_inference_steps = noise_scheduler.config.num_train_timesteps
|
||||
self.num_inference_steps = num_inference_steps
|
||||
|
||||
# ========= inference ============
|
||||
def conditional_sample(
|
||||
self,
|
||||
condition_data,
|
||||
condition_mask,
|
||||
local_cond=None,
|
||||
global_cond=None,
|
||||
generator=None,
|
||||
# keyword arguments to scheduler.step
|
||||
**kwargs,
|
||||
):
|
||||
model = self.model
|
||||
scheduler = self.noise_scheduler
|
||||
|
||||
trajectory = torch.randn(
|
||||
size=condition_data.shape,
|
||||
dtype=condition_data.dtype,
|
||||
device=condition_data.device,
|
||||
generator=generator,
|
||||
)
|
||||
|
||||
# set step values
|
||||
scheduler.set_timesteps(self.num_inference_steps)
|
||||
|
||||
for t in scheduler.timesteps:
|
||||
# 1. apply conditioning
|
||||
trajectory[condition_mask] = condition_data[condition_mask]
|
||||
|
||||
# 2. predict model output
|
||||
model_output = model(trajectory, t, local_cond=local_cond, global_cond=global_cond)
|
||||
|
||||
# 3. compute previous image: x_t -> x_t-1
|
||||
trajectory = scheduler.step(
|
||||
model_output,
|
||||
t,
|
||||
trajectory,
|
||||
generator=generator,
|
||||
# **kwargs # TODO(rcadene): in diffusion_policy, expected to be {}
|
||||
).prev_sample
|
||||
|
||||
# finally make sure conditioning is enforced
|
||||
trajectory[condition_mask] = condition_data[condition_mask]
|
||||
|
||||
return trajectory
|
||||
|
||||
def predict_action(self, obs_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
obs_dict: must include "obs" key
|
||||
result: must include "action" key
|
||||
"""
|
||||
assert "past_action" not in obs_dict # not implemented yet
|
||||
nobs = obs_dict
|
||||
value = next(iter(nobs.values()))
|
||||
bsize, n_obs_steps = value.shape[:2]
|
||||
horizon = self.horizon
|
||||
action_dim = self.action_dim
|
||||
obs_dim = self.obs_feature_dim
|
||||
assert self.n_obs_steps == n_obs_steps
|
||||
|
||||
# build input
|
||||
device = self.device
|
||||
dtype = self.dtype
|
||||
|
||||
# handle different ways of passing observation
|
||||
local_cond = None
|
||||
global_cond = None
|
||||
if self.obs_as_global_cond:
|
||||
# condition through global feature
|
||||
this_nobs = dict_apply(nobs, lambda x: x[:, :n_obs_steps, ...].reshape(-1, *x.shape[2:]))
|
||||
nobs_features = self.obs_encoder(this_nobs)
|
||||
# reshape back to B, Do
|
||||
global_cond = nobs_features.reshape(bsize, -1)
|
||||
# empty data for action
|
||||
cond_data = torch.zeros(size=(bsize, horizon, action_dim), device=device, dtype=dtype)
|
||||
cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
|
||||
else:
|
||||
# condition through inpainting
|
||||
this_nobs = dict_apply(nobs, lambda x: x[:, :n_obs_steps, ...].reshape(-1, *x.shape[2:]))
|
||||
nobs_features = self.obs_encoder(this_nobs)
|
||||
# reshape back to B, T, Do
|
||||
nobs_features = nobs_features.reshape(bsize, n_obs_steps, -1)
|
||||
cond_data = torch.zeros(size=(bsize, horizon, action_dim + obs_dim), device=device, dtype=dtype)
|
||||
cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
|
||||
cond_data[:, :n_obs_steps, action_dim:] = nobs_features
|
||||
cond_mask[:, :n_obs_steps, action_dim:] = True
|
||||
|
||||
# run sampling
|
||||
nsample = self.conditional_sample(
|
||||
cond_data, cond_mask, local_cond=local_cond, global_cond=global_cond
|
||||
)
|
||||
|
||||
action_pred = nsample[..., :action_dim]
|
||||
# get action
|
||||
start = n_obs_steps - 1
|
||||
end = start + self.n_action_steps
|
||||
action = action_pred[:, start:end]
|
||||
|
||||
result = {"action": action, "action_pred": action_pred}
|
||||
return result
|
||||
|
||||
def compute_loss(self, batch):
|
||||
nobs = {
|
||||
"image": batch["observation.image"],
|
||||
"agent_pos": batch["observation.state"],
|
||||
}
|
||||
nactions = batch["action"]
|
||||
batch_size = nactions.shape[0]
|
||||
horizon = nactions.shape[1]
|
||||
|
||||
# handle different ways of passing observation
|
||||
local_cond = None
|
||||
global_cond = None
|
||||
trajectory = nactions
|
||||
cond_data = trajectory
|
||||
if self.obs_as_global_cond:
|
||||
# reshape B, T, ... to B*T
|
||||
this_nobs = dict_apply(nobs, lambda x: x[:, : self.n_obs_steps, ...].reshape(-1, *x.shape[2:]))
|
||||
nobs_features = self.obs_encoder(this_nobs)
|
||||
# reshape back to B, Do
|
||||
global_cond = nobs_features.reshape(batch_size, -1)
|
||||
else:
|
||||
# reshape B, T, ... to B*T
|
||||
this_nobs = dict_apply(nobs, lambda x: x.reshape(-1, *x.shape[2:]))
|
||||
nobs_features = self.obs_encoder(this_nobs)
|
||||
# reshape back to B, T, Do
|
||||
nobs_features = nobs_features.reshape(batch_size, horizon, -1)
|
||||
cond_data = torch.cat([nactions, nobs_features], dim=-1)
|
||||
trajectory = cond_data.detach()
|
||||
|
||||
# generate inpainting mask
|
||||
condition_mask = self.mask_generator(trajectory.shape)
|
||||
|
||||
# Sample noise that we'll add to the images
|
||||
noise = torch.randn(trajectory.shape, device=trajectory.device)
|
||||
bsz = trajectory.shape[0]
|
||||
# Sample a random timestep for each image
|
||||
timesteps = torch.randint(
|
||||
0, self.noise_scheduler.config.num_train_timesteps, (bsz,), device=trajectory.device
|
||||
).long()
|
||||
# Add noise to the clean images according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process)
|
||||
noisy_trajectory = self.noise_scheduler.add_noise(trajectory, noise, timesteps)
|
||||
|
||||
# compute loss mask
|
||||
loss_mask = ~condition_mask
|
||||
|
||||
# apply conditioning
|
||||
noisy_trajectory[condition_mask] = cond_data[condition_mask]
|
||||
|
||||
# Predict the noise residual
|
||||
pred = self.model(noisy_trajectory, timesteps, local_cond=local_cond, global_cond=global_cond)
|
||||
|
||||
pred_type = self.noise_scheduler.config.prediction_type
|
||||
if pred_type == "epsilon":
|
||||
target = noise
|
||||
elif pred_type == "sample":
|
||||
target = trajectory
|
||||
else:
|
||||
raise ValueError(f"Unsupported prediction type {pred_type}")
|
||||
|
||||
loss = F.mse_loss(pred, target, reduction="none")
|
||||
loss = loss * loss_mask.type(loss.dtype)
|
||||
|
||||
if "action_is_pad" in batch:
|
||||
in_episode_bound = ~batch["action_is_pad"]
|
||||
loss = loss * in_episode_bound[:, :, None].type(loss.dtype)
|
||||
|
||||
loss = reduce(loss, "b t c -> b", "mean", b=batch_size)
|
||||
loss = loss.mean()
|
||||
return loss
@@ -1,286 +1,307 @@
|
|||
import logging
|
||||
from typing import Union
|
||||
import math
|
||||
|
||||
import einops
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
from lerobot.common.policies.diffusion.model.conv1d_components import Conv1dBlock, Downsample1d, Upsample1d
|
||||
from lerobot.common.policies.diffusion.model.positional_embedding import SinusoidalPosEmb
|
||||
from torch import Tensor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ConditionalResidualBlock1D(nn.Module):
|
||||
class _SinusoidalPosEmb(nn.Module):
|
||||
# TODO(now): consolidate?
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
|
||||
def forward(self, x):
|
||||
device = x.device
|
||||
half_dim = self.dim // 2
|
||||
emb = math.log(10000) / (half_dim - 1)
|
||||
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
|
||||
emb = x[:, None] * emb[None, :]
|
||||
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
||||
return emb
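For reference, the embedding computed by `_SinusoidalPosEmb.forward` above is, writing D for `self.dim` (assumed even) and x for the input position:

```latex
\[
\mathrm{PE}(x)_{j} = \sin\!\big(x \cdot 10000^{-j/(D/2-1)}\big), \qquad
\mathrm{PE}(x)_{D/2+j} = \cos\!\big(x \cdot 10000^{-j/(D/2-1)}\big), \qquad
j = 0, \dots, D/2 - 1 .
\]
```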
|
||||
|
||||
|
||||
class _Conv1dBlock(nn.Module):
|
||||
"""Conv1d --> GroupNorm --> Mish"""
|
||||
|
||||
def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
|
||||
super().__init__()
|
||||
|
||||
self.block = nn.Sequential(
|
||||
nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
|
||||
nn.GroupNorm(n_groups, out_channels),
|
||||
nn.Mish(),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.block(x)
|
||||
|
||||
|
||||
class _ConditionalResidualBlock1D(nn.Module):
|
||||
"""ResNet style 1D convolutional block with FiLM modulation for conditioning."""
|
||||
|
||||
def __init__(
|
||||
self, in_channels, out_channels, cond_dim, kernel_size=3, n_groups=8, cond_predict_scale=False
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
cond_dim: int,
|
||||
kernel_size: int = 3,
|
||||
n_groups: int = 8,
|
||||
# Set to True to do scale modulation with FiLM as well as bias modulation (defaults to False meaning
|
||||
# FiLM just modulates bias).
|
||||
film_scale_modulation: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups),
|
||||
Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups),
|
||||
]
|
||||
)
|
||||
|
||||
# FiLM modulation https://arxiv.org/abs/1709.07871
|
||||
# predicts per-channel scale and bias
|
||||
cond_channels = out_channels
|
||||
if cond_predict_scale:
|
||||
cond_channels = out_channels * 2
|
||||
self.cond_predict_scale = cond_predict_scale
|
||||
self.film_scale_modulation = film_scale_modulation
|
||||
self.out_channels = out_channels
|
||||
self.cond_encoder = nn.Sequential(
|
||||
nn.Mish(),
|
||||
nn.Linear(cond_dim, cond_channels),
|
||||
Rearrange("batch t -> batch t 1"),
|
||||
)
|
||||
|
||||
# make sure dimensions compatible
|
||||
self.conv1 = _Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups)
|
||||
|
||||
# FiLM modulation (https://arxiv.org/abs/1709.07871) outputs per-channel bias and (maybe) scale.
|
||||
cond_channels = out_channels * 2 if film_scale_modulation else out_channels
|
||||
self.cond_encoder = nn.Sequential(nn.Mish(), nn.Linear(cond_dim, cond_channels))
|
||||
|
||||
self.conv2 = _Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups)
|
||||
|
||||
# A final convolution for dimension matching the residual (if needed).
|
||||
self.residual_conv = (
|
||||
nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
|
||||
)
|
||||
|
||||
def forward(self, x, cond):
|
||||
def forward(self, x: Tensor, cond: Tensor) -> Tensor:
|
||||
"""
|
||||
x : [ batch_size x in_channels x horizon ]
|
||||
cond : [ batch_size x cond_dim]
|
||||
Args:
|
||||
x: (B, in_channels, T)
|
||||
cond: (B, cond_dim)
|
||||
Returns:
|
||||
(B, out_channels, T)
|
||||
"""
|
||||
out = self.conv1(x)
|
||||
|
||||
returns:
|
||||
out : [ batch_size x out_channels x horizon ]
|
||||
"""
|
||||
out = self.blocks[0](x)
|
||||
embed = self.cond_encoder(cond)
|
||||
if self.cond_predict_scale:
|
||||
embed = embed.reshape(embed.shape[0], 2, self.out_channels, 1)
|
||||
scale = embed[:, 0, ...]
|
||||
bias = embed[:, 1, ...]
|
||||
# Get condition embedding. Unsqueeze for broadcasting to `out`, resulting in (B, out_channels, 1).
|
||||
cond_embed = self.cond_encoder(cond).unsqueeze(-1)
|
||||
if self.film_scale_modulation:
|
||||
# Treat the embedding as a list of scales and biases.
|
||||
scale = cond_embed[:, : self.out_channels]
|
||||
bias = cond_embed[:, self.out_channels :]
|
||||
out = scale * out + bias
|
||||
else:
|
||||
out = out + embed
|
||||
out = self.blocks[1](out)
|
||||
# Treat the embedding as biases.
|
||||
out = out + cond_embed
|
||||
|
||||
out = self.conv2(out)
|
||||
out = out + self.residual_conv(x)
|
||||
return out
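The FiLM modulation used in the residual block above reduces to predicting a per-channel scale and bias from the conditioning vector and applying `scale * h + bias`. A minimal standalone sketch (this `FiLM` class is illustrative, not the repository's):

```python
import torch
from torch import nn


class FiLM(nn.Module):
    """Minimal FiLM layer: per-channel (scale, bias) from a conditioning vector, applied to (B, C, T)."""

    def __init__(self, cond_dim: int, channels: int):
        super().__init__()
        self.proj = nn.Linear(cond_dim, 2 * channels)

    def forward(self, h: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        # (B, 2C) -> (B, 2C, 1) -> two (B, C, 1) tensors that broadcast over the time axis.
        scale, bias = self.proj(cond).unsqueeze(-1).chunk(2, dim=1)
        return scale * h + bias


film = FiLM(cond_dim=16, channels=64)
out = film(torch.randn(2, 64, 32), torch.randn(2, 16))
print(out.shape)  # torch.Size([2, 64, 32])
```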
|
||||
|
||||
|
||||
class ConditionalUnet1D(nn.Module):
|
||||
"""A 1D convolutional UNet with FiLM modulation for conditioning.
|
||||
|
||||
Two types of conditioning can be applied:
|
||||
- Global: Conditioning information that is aggregated over the whole observation window. This is
|
||||
incorporated via the FiLM technique in the residual convolution blocks of the Unet's encoder/decoder.
|
||||
- Local: Conditioning information for each timestep in the observation window. This is incorporated
|
||||
by encoding the information via 1D convolutions and adding the resulting embeddings to the inputs and
|
||||
outputs of the Unet's encoder/decoder.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_dim,
|
||||
local_cond_dim=None,
|
||||
global_cond_dim=None,
|
||||
diffusion_step_embed_dim=256,
|
||||
down_dims=None,
|
||||
kernel_size=3,
|
||||
n_groups=8,
|
||||
cond_predict_scale=False,
|
||||
input_dim: int,
|
||||
local_cond_dim: int | None = None,
|
||||
global_cond_dim: int | None = None,
|
||||
diffusion_step_embed_dim: int = 256,
|
||||
down_dims: int | None = None,
|
||||
kernel_size: int = 3,
|
||||
n_groups: int = 8,
|
||||
film_scale_modulation: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if down_dims is None:
|
||||
down_dims = [256, 512, 1024]
|
||||
|
||||
all_dims = [input_dim] + list(down_dims)
|
||||
start_dim = down_dims[0]
|
||||
|
||||
dsed = diffusion_step_embed_dim
|
||||
diffusion_step_encoder = nn.Sequential(
|
||||
SinusoidalPosEmb(dsed),
|
||||
nn.Linear(dsed, dsed * 4),
|
||||
# Encoder for the diffusion timestep.
|
||||
self.diffusion_step_encoder = nn.Sequential(
|
||||
_SinusoidalPosEmb(diffusion_step_embed_dim),
|
||||
nn.Linear(diffusion_step_embed_dim, diffusion_step_embed_dim * 4),
|
||||
nn.Mish(),
|
||||
nn.Linear(dsed * 4, dsed),
|
||||
nn.Linear(diffusion_step_embed_dim * 4, diffusion_step_embed_dim),
|
||||
)
|
||||
cond_dim = dsed
|
||||
|
||||
# The FiLM conditioning dimension.
|
||||
cond_dim = diffusion_step_embed_dim
|
||||
if global_cond_dim is not None:
|
||||
cond_dim += global_cond_dim
|
||||
|
||||
in_out = list(zip(all_dims[:-1], all_dims[1:], strict=False))
|
||||
|
||||
local_cond_encoder = None
|
||||
self.local_cond_down_encoder = None
|
||||
self.local_cond_up_encoder = None
|
||||
if local_cond_dim is not None:
|
||||
_, dim_out = in_out[0]
|
||||
dim_in = local_cond_dim
|
||||
local_cond_encoder = nn.ModuleList(
|
||||
[
|
||||
# down encoder
|
||||
ConditionalResidualBlock1D(
|
||||
dim_in,
|
||||
dim_out,
|
||||
cond_dim=cond_dim,
|
||||
kernel_size=kernel_size,
|
||||
n_groups=n_groups,
|
||||
cond_predict_scale=cond_predict_scale,
|
||||
),
|
||||
# up encoder
|
||||
ConditionalResidualBlock1D(
|
||||
dim_in,
|
||||
dim_out,
|
||||
cond_dim=cond_dim,
|
||||
kernel_size=kernel_size,
|
||||
n_groups=n_groups,
|
||||
cond_predict_scale=cond_predict_scale,
|
||||
),
|
||||
]
|
||||
# Encoder for the local conditioning. The output gets added to the Unet encoder input.
|
||||
self.local_cond_down_encoder = _ConditionalResidualBlock1D(
|
||||
local_cond_dim,
|
||||
down_dims[0],
|
||||
cond_dim=cond_dim,
|
||||
kernel_size=kernel_size,
|
||||
n_groups=n_groups,
|
||||
film_scale_modulation=film_scale_modulation,
|
||||
)
|
||||
# Encoder for the local conditioning. The output gets added to the Unet encoder output.
|
||||
self.local_cond_up_encoder = _ConditionalResidualBlock1D(
|
||||
local_cond_dim,
|
||||
down_dims[0],
|
||||
cond_dim=cond_dim,
|
||||
kernel_size=kernel_size,
|
||||
n_groups=n_groups,
|
||||
film_scale_modulation=film_scale_modulation,
|
||||
)
|
||||
|
||||
mid_dim = all_dims[-1]
|
||||
# In channels / out channels for each downsampling block in the Unet's encoder. For the decoder, we
|
||||
# just reverse these.
|
||||
in_out = [(input_dim, down_dims[0])] + list(zip(down_dims[:-1], down_dims[1:], strict=True))
|
||||
|
||||
# Unet encoder.
|
||||
self.down_modules = nn.ModuleList([])
|
||||
for ind, (dim_in, dim_out) in enumerate(in_out):
|
||||
is_last = ind >= (len(in_out) - 1)
|
||||
self.down_modules.append(
|
||||
nn.ModuleList(
|
||||
[
|
||||
_ConditionalResidualBlock1D(
|
||||
dim_in,
|
||||
dim_out,
|
||||
cond_dim=cond_dim,
|
||||
kernel_size=kernel_size,
|
||||
n_groups=n_groups,
|
||||
film_scale_modulation=film_scale_modulation,
|
||||
),
|
||||
_ConditionalResidualBlock1D(
|
||||
dim_out,
|
||||
dim_out,
|
||||
cond_dim=cond_dim,
|
||||
kernel_size=kernel_size,
|
||||
n_groups=n_groups,
|
||||
film_scale_modulation=film_scale_modulation,
|
||||
),
|
||||
# Downsample as long as it is not the last block.
|
||||
nn.Conv1d(dim_out, dim_out, 3, 2, 1) if not is_last else nn.Identity(),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
# Processing in the middle of the auto-encoder.
|
||||
self.mid_modules = nn.ModuleList(
|
||||
[
|
||||
ConditionalResidualBlock1D(
|
||||
mid_dim,
|
||||
mid_dim,
|
||||
_ConditionalResidualBlock1D(
|
||||
down_dims[-1],
|
||||
down_dims[-1],
|
||||
cond_dim=cond_dim,
|
||||
kernel_size=kernel_size,
|
||||
n_groups=n_groups,
|
||||
cond_predict_scale=cond_predict_scale,
|
||||
film_scale_modulation=film_scale_modulation,
|
||||
),
|
||||
ConditionalResidualBlock1D(
|
||||
mid_dim,
|
||||
mid_dim,
|
||||
_ConditionalResidualBlock1D(
|
||||
down_dims[-1],
|
||||
down_dims[-1],
|
||||
cond_dim=cond_dim,
|
||||
kernel_size=kernel_size,
|
||||
n_groups=n_groups,
|
||||
cond_predict_scale=cond_predict_scale,
|
||||
film_scale_modulation=film_scale_modulation,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
down_modules = nn.ModuleList([])
|
||||
for ind, (dim_in, dim_out) in enumerate(in_out):
|
||||
# Unet decoder.
|
||||
self.up_modules = nn.ModuleList([])
|
||||
for ind, (dim_out, dim_in) in enumerate(reversed(in_out[1:])):
|
||||
is_last = ind >= (len(in_out) - 1)
|
||||
down_modules.append(
|
||||
self.up_modules.append(
|
||||
nn.ModuleList(
|
||||
[
|
||||
ConditionalResidualBlock1D(
|
||||
dim_in,
|
||||
_ConditionalResidualBlock1D(
|
||||
dim_in * 2, # x2 as it takes the encoder's skip connection as well
|
||||
dim_out,
|
||||
cond_dim=cond_dim,
|
||||
kernel_size=kernel_size,
|
||||
n_groups=n_groups,
|
||||
cond_predict_scale=cond_predict_scale,
|
||||
film_scale_modulation=film_scale_modulation,
|
||||
),
|
||||
ConditionalResidualBlock1D(
|
||||
_ConditionalResidualBlock1D(
|
||||
dim_out,
|
||||
dim_out,
|
||||
cond_dim=cond_dim,
|
||||
kernel_size=kernel_size,
|
||||
n_groups=n_groups,
|
||||
cond_predict_scale=cond_predict_scale,
|
||||
film_scale_modulation=film_scale_modulation,
|
||||
),
|
||||
Downsample1d(dim_out) if not is_last else nn.Identity(),
|
||||
# Upsample as long as it is not the last block.
|
||||
nn.ConvTranspose1d(dim_out, dim_out, 4, 2, 1) if not is_last else nn.Identity(),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
up_modules = nn.ModuleList([])
|
||||
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
|
||||
is_last = ind >= (len(in_out) - 1)
|
||||
up_modules.append(
|
||||
nn.ModuleList(
|
||||
[
|
||||
ConditionalResidualBlock1D(
|
||||
dim_out * 2,
|
||||
dim_in,
|
||||
cond_dim=cond_dim,
|
||||
kernel_size=kernel_size,
|
||||
n_groups=n_groups,
|
||||
cond_predict_scale=cond_predict_scale,
|
||||
),
|
||||
ConditionalResidualBlock1D(
|
||||
dim_in,
|
||||
dim_in,
|
||||
cond_dim=cond_dim,
|
||||
kernel_size=kernel_size,
|
||||
n_groups=n_groups,
|
||||
cond_predict_scale=cond_predict_scale,
|
||||
),
|
||||
Upsample1d(dim_in) if not is_last else nn.Identity(),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
final_conv = nn.Sequential(
|
||||
Conv1dBlock(start_dim, start_dim, kernel_size=kernel_size),
|
||||
nn.Conv1d(start_dim, input_dim, 1),
|
||||
self.final_conv = nn.Sequential(
|
||||
_Conv1dBlock(down_dims[0], down_dims[0], kernel_size=kernel_size),
|
||||
nn.Conv1d(down_dims[0], input_dim, 1),
|
||||
)
|
||||
|
||||
self.diffusion_step_encoder = diffusion_step_encoder
|
||||
self.local_cond_encoder = local_cond_encoder
|
||||
self.up_modules = up_modules
|
||||
self.down_modules = down_modules
|
||||
self.final_conv = final_conv
|
||||
|
||||
logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
local_cond=None,
|
||||
global_cond=None,
|
||||
**kwargs,
|
||||
):
|
||||
def forward(self, x: Tensor, timestep: Tensor | int, local_cond=None, global_cond=None) -> Tensor:
|
||||
"""
|
||||
x: (B,T,input_dim)
|
||||
timestep: (B,) or int, diffusion step
|
||||
local_cond: (B,T,local_cond_dim)
|
||||
global_cond: (B,global_cond_dim)
|
||||
output: (B,T,input_dim)
|
||||
Args:
|
||||
x: (B, T, input_dim) tensor for input to the Unet.
|
||||
timestep: (B,) tensor of (timestep_we_are_denoising_from - 1).
|
||||
local_cond: (B, T, local_cond_dim)
|
||||
global_cond: (B, global_cond_dim)
|
||||
output: (B, T, input_dim)
|
||||
Returns:
|
||||
(B, T, input_dim)
|
||||
"""
|
||||
sample = einops.rearrange(sample, "b h t -> b t h")
|
||||
|
||||
# 1. time
|
||||
timesteps = timestep
|
||||
if not torch.is_tensor(timesteps):
|
||||
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
||||
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
|
||||
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
|
||||
timesteps = timesteps[None].to(sample.device)
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timesteps = timesteps.expand(sample.shape[0])
|
||||
|
||||
global_feature = self.diffusion_step_encoder(timesteps)
|
||||
|
||||
if global_cond is not None:
|
||||
global_feature = torch.cat([global_feature, global_cond], axis=-1)
|
||||
|
||||
# encode local features
|
||||
h_local = []
|
||||
# For 1D convolutions we'll need feature dimension first.
|
||||
x = einops.rearrange(x, "b t d -> b d t")
|
||||
if local_cond is not None:
|
||||
local_cond = einops.rearrange(local_cond, "b h t -> b t h")
|
||||
resnet, resnet2 = self.local_cond_encoder
|
||||
x = resnet(local_cond, global_feature)
|
||||
h_local.append(x)
|
||||
x = resnet2(local_cond, global_feature)
|
||||
h_local.append(x)
|
||||
if self.local_cond_down_encoder is None or self.local_cond_up_encoder is None:
|
||||
raise ValueError(
|
||||
"`local_cond` was provided but the relevant encoders weren't built at initialization."
|
||||
)
|
||||
local_cond = einops.rearrange(local_cond, "b t d -> b d t")
|
||||
|
||||
x = sample
|
||||
h = []
|
||||
timesteps_embed = self.diffusion_step_encoder(timestep)
|
||||
|
||||
# If there is a global conditioning feature, concatenate it to the timestep embedding.
|
||||
if global_cond is not None:
|
||||
global_feature = torch.cat([timesteps_embed, global_cond], axis=-1)
|
||||
else:
|
||||
global_feature = timesteps_embed
|
||||
|
||||
encoder_skip_features: list[Tensor] = []
|
||||
for idx, (resnet, resnet2, downsample) in enumerate(self.down_modules):
|
||||
x = resnet(x, global_feature)
|
||||
if idx == 0 and len(h_local) > 0:
|
||||
x = x + h_local[0]
|
||||
if idx == 0 and local_cond is not None:
|
||||
x = x + self.local_cond_down_encoder(local_cond, global_feature)
|
||||
x = resnet2(x, global_feature)
|
||||
h.append(x)
|
||||
encoder_skip_features.append(x)
|
||||
x = downsample(x)
|
||||
|
||||
for mid_module in self.mid_modules:
|
||||
x = mid_module(x, global_feature)
|
||||
|
||||
for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules):
|
||||
x = torch.cat((x, h.pop()), dim=1)
|
||||
x = torch.cat((x, encoder_skip_features.pop()), dim=1)
|
||||
x = resnet(x, global_feature)
|
||||
# The correct condition should be:
|
||||
# if idx == (len(self.up_modules)-1) and len(h_local) > 0:
|
||||
# However this change will break compatibility with published checkpoints.
|
||||
# Therefore it is left as a comment.
|
||||
if idx == len(self.up_modules) and len(h_local) > 0:
|
||||
x = x + h_local[1]
|
||||
# Note: The condition in the original implementation is:
|
||||
# if idx == len(self.up_modules) and local_cond is not None:
|
||||
# But as they mention in their comments, this is incorrect. We use the correct condition here.
|
||||
if idx == (len(self.up_modules) - 1) and local_cond is not None:
|
||||
x = x + self.local_cond_up_encoder(local_cond, global_feature)
|
||||
x = resnet2(x, global_feature)
|
||||
x = upsample(x)
|
||||
|
||||
x = self.final_conv(x)
|
||||
|
||||
x = einops.rearrange(x, "b t h -> b h t")
|
||||
x = einops.rearrange(x, "b d t -> b t d")
|
||||
return x
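Assuming the updated signature shown in this diff, the module can be exercised end to end with dummy tensors (all sizes below are illustrative, not values used by the repository):

```python
import torch

# Import path as used elsewhere in this diff.
from lerobot.common.policies.diffusion.model.conditional_unet1d import ConditionalUnet1D

unet = ConditionalUnet1D(
    input_dim=2,                # e.g. a 2-D action
    global_cond_dim=132,        # flattened (state + image feature) over the observation window
    down_dims=(256, 512, 1024),
    kernel_size=5,
    n_groups=8,
    film_scale_modulation=True,
)
x = torch.randn(8, 16, 2)                 # (B, T, input_dim) noisy action trajectory
timesteps = torch.randint(0, 100, (8,))   # (B,) diffusion steps
global_cond = torch.randn(8, 132)         # (B, global_cond_dim)
out = unet(x, timesteps, global_cond=global_cond)
print(out.shape)  # torch.Size([8, 16, 2])
```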
|
||||
@@ -1,47 +0,0 @@
|
|||
import torch.nn as nn
|
||||
|
||||
# from einops.layers.torch import Rearrange
|
||||
|
||||
|
||||
class Downsample1d(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv1d(dim, dim, 3, 2, 1)
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class Upsample1d(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class Conv1dBlock(nn.Module):
|
||||
"""
|
||||
Conv1d --> GroupNorm --> Mish
|
||||
"""
|
||||
|
||||
def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
|
||||
super().__init__()
|
||||
|
||||
self.block = nn.Sequential(
|
||||
nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
|
||||
# Rearrange('batch channels horizon -> batch channels 1 horizon'),
|
||||
nn.GroupNorm(n_groups, out_channels),
|
||||
# Rearrange('batch channels 1 horizon -> batch channels horizon'),
|
||||
nn.Mish(),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.block(x)
|
||||
|
||||
|
||||
# def test():
|
||||
# cb = Conv1dBlock(256, 128, kernel_size=3)
|
||||
# x = torch.zeros((1,256,16))
|
||||
# o = cb(x)
@@ -1,294 +0,0 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision.transforms.functional as ttf
|
||||
|
||||
import lerobot.common.policies.diffusion.model.tensor_utils as tu
|
||||
|
||||
|
||||
class CropRandomizer(nn.Module):
|
||||
"""
|
||||
Randomly sample crops at input, and then average across crop features at output.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_shape,
|
||||
crop_height,
|
||||
crop_width,
|
||||
num_crops=1,
|
||||
pos_enc=False,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
input_shape (tuple, list): shape of input (not including batch dimension)
|
||||
crop_height (int): crop height
|
||||
crop_width (int): crop width
|
||||
num_crops (int): number of random crops to take
|
||||
pos_enc (bool): if True, add 2 channels to the output to encode the spatial
|
||||
location of the cropped pixels in the source image
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
assert len(input_shape) == 3 # (C, H, W)
|
||||
assert crop_height < input_shape[1]
|
||||
assert crop_width < input_shape[2]
|
||||
|
||||
self.input_shape = input_shape
|
||||
self.crop_height = crop_height
|
||||
self.crop_width = crop_width
|
||||
self.num_crops = num_crops
|
||||
self.pos_enc = pos_enc
|
||||
|
||||
def output_shape_in(self, input_shape=None):
|
||||
"""
|
||||
Function to compute output shape from inputs to this module. Corresponds to
|
||||
the @forward_in operation, where raw inputs (usually observation modalities)
|
||||
are passed in.
|
||||
|
||||
Args:
|
||||
input_shape (iterable of int): shape of input. Does not include batch dimension.
|
||||
Some modules may not need this argument, if their output does not depend
|
||||
on the size of the input, or if they assume fixed size input.
|
||||
|
||||
Returns:
|
||||
out_shape ([int]): list of integers corresponding to output shape
|
||||
"""
|
||||
|
||||
# outputs are shape (C, CH, CW), or maybe C + 2 if using position encoding, because
|
||||
# the number of crops are reshaped into the batch dimension, increasing the batch
|
||||
# size from B to B * N
|
||||
out_c = self.input_shape[0] + 2 if self.pos_enc else self.input_shape[0]
|
||||
return [out_c, self.crop_height, self.crop_width]
|
||||
|
||||
def output_shape_out(self, input_shape=None):
|
||||
"""
|
||||
Function to compute output shape from inputs to this module. Corresponds to
|
||||
the @forward_out operation, where processed inputs (usually encoded observation
|
||||
modalities) are passed in.
|
||||
|
||||
Args:
|
||||
input_shape (iterable of int): shape of input. Does not include batch dimension.
|
||||
Some modules may not need this argument, if their output does not depend
|
||||
on the size of the input, or if they assume fixed size input.
|
||||
|
||||
Returns:
|
||||
out_shape ([int]): list of integers corresponding to output shape
|
||||
"""
|
||||
|
||||
# since the forward_out operation splits [B * N, ...] -> [B, N, ...]
|
||||
# and then pools to result in [B, ...], only the batch dimension changes,
|
||||
# and so the other dimensions retain their shape.
|
||||
return list(input_shape)
|
||||
|
||||
def forward_in(self, inputs):
|
||||
"""
|
||||
Samples N random crops for each input in the batch, and then reshapes
|
||||
inputs to [B * N, ...].
|
||||
"""
|
||||
assert len(inputs.shape) >= 3 # must have at least (C, H, W) dimensions
|
||||
if self.training:
|
||||
# generate random crops
|
||||
out, _ = sample_random_image_crops(
|
||||
images=inputs,
|
||||
crop_height=self.crop_height,
|
||||
crop_width=self.crop_width,
|
||||
num_crops=self.num_crops,
|
||||
pos_enc=self.pos_enc,
|
||||
)
|
||||
# [B, N, ...] -> [B * N, ...]
|
||||
return tu.join_dimensions(out, 0, 1)
|
||||
else:
|
||||
# take center crop during eval
|
||||
out = ttf.center_crop(img=inputs, output_size=(self.crop_height, self.crop_width))
|
||||
if self.num_crops > 1:
|
||||
B, C, H, W = out.shape # noqa: N806
|
||||
out = out.unsqueeze(1).expand(B, self.num_crops, C, H, W).reshape(-1, C, H, W)
|
||||
# [B * N, ...]
|
||||
return out
|
||||
|
||||
def forward_out(self, inputs):
|
||||
"""
|
||||
Splits the outputs from shape [B * N, ...] -> [B, N, ...] and then average across N
|
||||
to result in shape [B, ...] to make sure the network output is consistent with
|
||||
what would have happened if there were no randomization.
|
||||
"""
|
||||
if self.num_crops <= 1:
|
||||
return inputs
|
||||
else:
|
||||
batch_size = inputs.shape[0] // self.num_crops
|
||||
out = tu.reshape_dimensions(
|
||||
inputs, begin_axis=0, end_axis=0, target_dims=(batch_size, self.num_crops)
|
||||
)
|
||||
return out.mean(dim=1)
|
||||
|
||||
def forward(self, inputs):
|
||||
return self.forward_in(inputs)
|
||||
|
||||
def __repr__(self):
|
||||
"""Pretty print network."""
|
||||
header = "{}".format(str(self.__class__.__name__))
|
||||
msg = header + "(input_shape={}, crop_size=[{}, {}], num_crops={})".format(
|
||||
self.input_shape, self.crop_height, self.crop_width, self.num_crops
|
||||
)
|
||||
return msg
|
||||
|
||||
|
||||
def crop_image_from_indices(images, crop_indices, crop_height, crop_width):
|
||||
"""
|
||||
Crops images at the locations specified by @crop_indices. Crops will be
|
||||
taken across all channels.
|
||||
|
||||
Args:
|
||||
images (torch.Tensor): batch of images of shape [..., C, H, W]
|
||||
|
||||
crop_indices (torch.Tensor): batch of indices of shape [..., N, 2] where
|
||||
N is the number of crops to take per image and each entry corresponds
|
||||
to the pixel height and width of where to take the crop. Note that
|
||||
the indices can also be of shape [..., 2] if only 1 crop should
|
||||
be taken per image. Leading dimensions must be consistent with
|
||||
@images argument. Each index specifies the top left of the crop.
|
||||
Values must be in range [0, H - CH - 1] x [0, W - CW - 1] where
|
||||
H and W are the height and width of @images and CH and CW are
|
||||
@crop_height and @crop_width.
|
||||
|
||||
crop_height (int): height of crop to take
|
||||
|
||||
crop_width (int): width of crop to take
|
||||
|
||||
Returns:
|
||||
crops (torch.Tensor): cropped images of shape [..., C, @crop_height, @crop_width]
|
||||
"""
|
||||
|
||||
# make sure length of input shapes is consistent
|
||||
assert crop_indices.shape[-1] == 2
|
||||
ndim_im_shape = len(images.shape)
|
||||
ndim_indices_shape = len(crop_indices.shape)
|
||||
assert (ndim_im_shape == ndim_indices_shape + 1) or (ndim_im_shape == ndim_indices_shape + 2)
|
||||
|
||||
# maybe pad so that @crop_indices is shape [..., N, 2]
|
||||
is_padded = False
|
||||
if ndim_im_shape == ndim_indices_shape + 2:
|
||||
crop_indices = crop_indices.unsqueeze(-2)
|
||||
is_padded = True
|
||||
|
||||
# make sure leading dimensions between images and indices are consistent
|
||||
assert images.shape[:-3] == crop_indices.shape[:-2]
|
||||
|
||||
device = images.device
|
||||
image_c, image_h, image_w = images.shape[-3:]
|
||||
num_crops = crop_indices.shape[-2]
|
||||
|
||||
# make sure @crop_indices are in valid range
|
||||
assert (crop_indices[..., 0] >= 0).all().item()
|
||||
assert (crop_indices[..., 0] < (image_h - crop_height)).all().item()
|
||||
assert (crop_indices[..., 1] >= 0).all().item()
|
||||
assert (crop_indices[..., 1] < (image_w - crop_width)).all().item()
|
||||
|
||||
# convert each crop index (ch, cw) into a list of pixel indices that correspond to the entire window.
|
||||
|
||||
# 2D index array with columns [0, 1, ..., CH - 1] and shape [CH, CW]
|
||||
crop_ind_grid_h = torch.arange(crop_height).to(device)
|
||||
crop_ind_grid_h = tu.unsqueeze_expand_at(crop_ind_grid_h, size=crop_width, dim=-1)
|
||||
# 2D index array with rows [0, 1, ..., CW - 1] and shape [CH, CW]
|
||||
crop_ind_grid_w = torch.arange(crop_width).to(device)
|
||||
crop_ind_grid_w = tu.unsqueeze_expand_at(crop_ind_grid_w, size=crop_height, dim=0)
|
||||
# combine into shape [CH, CW, 2]
|
||||
crop_in_grid = torch.cat((crop_ind_grid_h.unsqueeze(-1), crop_ind_grid_w.unsqueeze(-1)), dim=-1)
|
||||
|
||||
# Add above grid with the offset index of each sampled crop to get 2d indices for each crop.
|
||||
# After broadcasting, this will be shape [..., N, CH, CW, 2] and each crop has a [CH, CW, 2]
|
||||
# shape array that tells us which pixels from the corresponding source image to grab.
|
||||
grid_reshape = [1] * len(crop_indices.shape[:-1]) + [crop_height, crop_width, 2]
|
||||
all_crop_inds = crop_indices.unsqueeze(-2).unsqueeze(-2) + crop_in_grid.reshape(grid_reshape)
|
||||
|
||||
# For using @torch.gather, convert to flat indices from 2D indices, and also
|
||||
# repeat across the channel dimension. To get flat index of each pixel to grab for
|
||||
# each sampled crop, we just use the mapping: ind = h_ind * @image_w + w_ind
|
||||
all_crop_inds = all_crop_inds[..., 0] * image_w + all_crop_inds[..., 1] # shape [..., N, CH, CW]
|
||||
all_crop_inds = tu.unsqueeze_expand_at(all_crop_inds, size=image_c, dim=-3) # shape [..., N, C, CH, CW]
|
||||
all_crop_inds = tu.flatten(all_crop_inds, begin_axis=-2) # shape [..., N, C, CH * CW]
|
||||
|
||||
# Repeat and flatten the source images -> [..., N, C, H * W] and then use gather to index with crop pixel inds
|
||||
images_to_crop = tu.unsqueeze_expand_at(images, size=num_crops, dim=-4)
|
||||
images_to_crop = tu.flatten(images_to_crop, begin_axis=-2)
|
||||
crops = torch.gather(images_to_crop, dim=-1, index=all_crop_inds)
|
||||
# [..., N, C, CH * CW] -> [..., N, C, CH, CW]
|
||||
reshape_axis = len(crops.shape) - 1
|
||||
crops = tu.reshape_dimensions(
|
||||
crops, begin_axis=reshape_axis, end_axis=reshape_axis, target_dims=(crop_height, crop_width)
|
||||
)
|
||||
|
||||
if is_padded:
|
||||
# undo padding -> [..., C, CH, CW]
|
||||
crops = crops.squeeze(-4)
|
||||
return crops
|
||||
|
||||
|
||||
def sample_random_image_crops(images, crop_height, crop_width, num_crops, pos_enc=False):
|
||||
"""
|
||||
For each image, randomly sample @num_crops crops of size (@crop_height, @crop_width), from
|
||||
@images.
|
||||
|
||||
Args:
|
||||
images (torch.Tensor): batch of images of shape [..., C, H, W]
|
||||
|
||||
crop_height (int): height of crop to take
|
||||
|
||||
crop_width (int): width of crop to take
|
||||
|
||||
num_crops (int): number of crops to sample
|
||||
|
||||
pos_enc (bool): if True, also add 2 channels to the outputs that gives a spatial
|
||||
encoding of the original source pixel locations. This means that the
|
||||
output crops will contain information about where in the source image
|
||||
it was sampled from.
|
||||
|
||||
Returns:
|
||||
crops (torch.Tensor): crops of shape (..., @num_crops, C, @crop_height, @crop_width)
|
||||
if @pos_enc is False, otherwise (..., @num_crops, C + 2, @crop_height, @crop_width)
|
||||
|
||||
crop_inds (torch.Tensor): sampled crop indices of shape (..., N, 2)
|
||||
"""
|
||||
device = images.device
|
||||
|
||||
# maybe add 2 channels of spatial encoding to the source image
|
||||
source_im = images
|
||||
if pos_enc:
|
||||
# spatial encoding [y, x] in [0, 1]
|
||||
h, w = source_im.shape[-2:]
|
||||
pos_y, pos_x = torch.meshgrid(torch.arange(h), torch.arange(w))
|
||||
pos_y = pos_y.float().to(device) / float(h)
|
||||
pos_x = pos_x.float().to(device) / float(w)
|
||||
position_enc = torch.stack((pos_y, pos_x)) # shape [C, H, W]
|
||||
|
||||
# unsqueeze and expand to match leading dimensions -> shape [..., C, H, W]
|
||||
leading_shape = source_im.shape[:-3]
|
||||
position_enc = position_enc[(None,) * len(leading_shape)]
|
||||
position_enc = position_enc.expand(*leading_shape, -1, -1, -1)
|
||||
|
||||
# concat across channel dimension with input
|
||||
source_im = torch.cat((source_im, position_enc), dim=-3)
|
||||
|
||||
# make sure sample boundaries ensure crops are fully within the images
|
||||
image_c, image_h, image_w = source_im.shape[-3:]
|
||||
max_sample_h = image_h - crop_height
|
||||
max_sample_w = image_w - crop_width
|
||||
|
||||
# Sample crop locations for all tensor dimensions up to the last 3, which are [C, H, W].
|
||||
# Each gets @num_crops samples - typically this will just be the batch dimension (B), so
|
||||
# we will sample [B, N] indices, but this supports having more than one leading dimension,
|
||||
# or possibly no leading dimension.
|
||||
#
|
||||
# Trick: sample in [0, 1) with rand, then re-scale to [0, M) and convert to long to get sampled ints
|
||||
crop_inds_h = (max_sample_h * torch.rand(*source_im.shape[:-3], num_crops).to(device)).long()
|
||||
crop_inds_w = (max_sample_w * torch.rand(*source_im.shape[:-3], num_crops).to(device)).long()
|
||||
crop_inds = torch.cat((crop_inds_h.unsqueeze(-1), crop_inds_w.unsqueeze(-1)), dim=-1) # shape [..., N, 2]
|
||||
|
||||
crops = crop_image_from_indices(
|
||||
images=source_im,
|
||||
crop_indices=crop_inds,
|
||||
crop_height=crop_height,
|
||||
crop_width=crop_width,
|
||||
)
|
||||
|
||||
return crops, crop_inds
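The index-sampling trick described in the comment above (draw in [0, 1) with `rand`, rescale to [0, max), cast to long) can be tried in isolation; the shapes below are illustrative:

```python
import torch

images = torch.randn(4, 3, 96, 96)  # (B, C, H, W)
crop_h, crop_w, num_crops = 84, 84, 2
max_h = images.shape[-2] - crop_h
max_w = images.shape[-1] - crop_w

# Sample integer top-left corners uniformly in [0, max_h) x [0, max_w).
inds_h = (max_h * torch.rand(images.shape[0], num_crops)).long()
inds_w = (max_w * torch.rand(images.shape[0], num_crops)).long()
crop_inds = torch.stack((inds_h, inds_w), dim=-1)  # (B, N, 2)
print(crop_inds.min().item() >= 0, crop_inds.shape)
```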
@@ -1,41 +0,0 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class DictOfTensorMixin(nn.Module):
|
||||
def __init__(self, params_dict=None):
|
||||
super().__init__()
|
||||
if params_dict is None:
|
||||
params_dict = nn.ParameterDict()
|
||||
self.params_dict = params_dict
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return next(iter(self.parameters())).device
|
||||
|
||||
def _load_from_state_dict(
|
||||
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
|
||||
):
|
||||
def dfs_add(dest, keys, value: torch.Tensor):
|
||||
if len(keys) == 1:
|
||||
dest[keys[0]] = value
|
||||
return
|
||||
|
||||
if keys[0] not in dest:
|
||||
dest[keys[0]] = nn.ParameterDict()
|
||||
dfs_add(dest[keys[0]], keys[1:], value)
|
||||
|
||||
def load_dict(state_dict, prefix):
|
||||
out_dict = nn.ParameterDict()
|
||||
for key, value in state_dict.items():
|
||||
value: torch.Tensor
|
||||
if key.startswith(prefix):
|
||||
param_keys = key[len(prefix) :].split(".")[1:]
|
||||
# if len(param_keys) == 0:
|
||||
# import pdb; pdb.set_trace()
|
||||
dfs_add(out_dict, param_keys, value.clone())
|
||||
return out_dict
|
||||
|
||||
self.params_dict = load_dict(state_dict, prefix + "params_dict")
|
||||
self.params_dict.requires_grad_(False)
|
||||
return
@@ -0,0 +1,220 @@
|
|||
import einops
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.common.policies.diffusion.model.conditional_unet1d import ConditionalUnet1D
|
||||
from lerobot.common.policies.diffusion.model.rgb_encoder import RgbEncoder
|
||||
from lerobot.common.policies.utils import get_device_from_parameters, get_dtype_from_parameters
|
||||
|
||||
|
||||
class DiffusionUnetImagePolicy(nn.Module):
|
||||
"""
|
||||
TODO(now): Add DDIM scheduler.
|
||||
|
||||
Changes: TODO(now)
|
||||
- Use single image encoder for now instead of generic obs_encoder. We may generalize again when/if
|
||||
needed. Code for a general observation encoder can be found at:
|
||||
https://github.com/huggingface/lerobot/blob/920e0d118b493e4cc3058a9b1b764f38ae145d8e/lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py
|
||||
- Uses the observation as global conditioning for the Unet by default.
|
||||
- Does not do any inpainting (which would be applicable if the observation were not used to condition
|
||||
the Unet).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cfg,
|
||||
shape_meta: dict,
|
||||
noise_scheduler: DDPMScheduler,
|
||||
horizon,
|
||||
n_action_steps,
|
||||
n_obs_steps,
|
||||
num_inference_steps=None,
|
||||
diffusion_step_embed_dim=256,
|
||||
down_dims=(256, 512, 1024),
|
||||
kernel_size=5,
|
||||
n_groups=8,
|
||||
film_scale_modulation=True,
|
||||
):
|
||||
super().__init__()
|
||||
action_shape = shape_meta["action"]["shape"]
|
||||
assert len(action_shape) == 1
|
||||
action_dim = action_shape[0]
|
||||
|
||||
self.rgb_encoder = RgbEncoder(input_shape=shape_meta.obs.image.shape, **cfg.rgb_encoder)
|
||||
|
||||
self.unet = ConditionalUnet1D(
|
||||
input_dim=action_dim,
|
||||
global_cond_dim=(action_dim + self.rgb_encoder.feature_dim) * n_obs_steps,
|
||||
diffusion_step_embed_dim=diffusion_step_embed_dim,
|
||||
down_dims=down_dims,
|
||||
kernel_size=kernel_size,
|
||||
n_groups=n_groups,
|
||||
film_scale_modulation=film_scale_modulation,
|
||||
)
|
||||
|
||||
self.noise_scheduler = noise_scheduler
|
||||
self.horizon = horizon
|
||||
self.action_dim = action_dim
|
||||
self.n_action_steps = n_action_steps
|
||||
self.n_obs_steps = n_obs_steps
|
||||
|
||||
if num_inference_steps is None:
|
||||
num_inference_steps = noise_scheduler.config.num_train_timesteps
|
||||
|
||||
self.num_inference_steps = num_inference_steps
|
||||
|
||||
# ========= inference ============
|
||||
def conditional_sample(
|
||||
self,
|
||||
condition_data,
|
||||
inpainting_mask,
|
||||
local_cond=None,
|
||||
global_cond=None,
|
||||
generator=None,
|
||||
):
|
||||
model = self.unet
|
||||
scheduler = self.noise_scheduler
|
||||
|
||||
trajectory = torch.randn(
|
||||
size=condition_data.shape,
|
||||
dtype=condition_data.dtype,
|
||||
device=condition_data.device,
|
||||
generator=generator,
|
||||
)
|
||||
|
||||
# set step values
|
||||
scheduler.set_timesteps(self.num_inference_steps)
|
||||
|
||||
for t in scheduler.timesteps:
|
||||
# 1. apply conditioning
|
||||
trajectory[inpainting_mask] = condition_data[inpainting_mask]
|
||||
|
||||
# 2. predict model output
|
||||
model_output = model(
|
||||
trajectory,
|
||||
torch.full(trajectory.shape[:1], t, dtype=torch.long, device=trajectory.device),
|
||||
local_cond=local_cond,
|
||||
global_cond=global_cond,
|
||||
)
|
||||
|
||||
# 3. compute previous image: x_t -> x_t-1
|
||||
trajectory = scheduler.step(
|
||||
model_output,
|
||||
t,
|
||||
trajectory,
|
||||
generator=generator,
|
||||
).prev_sample
|
||||
|
||||
# finally make sure conditioning is enforced
|
||||
trajectory[inpainting_mask] = condition_data[inpainting_mask]
|
||||
|
||||
return trajectory
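The loop above is the standard reverse-diffusion sampling pattern: start from noise, repeatedly predict the model output and let the scheduler take one x_t -> x_{t-1} step. A minimal, self-contained sketch of the same pattern with a toy denoiser (in the policy this role is played by the conditional 1D UNet; names and sizes below are illustrative):

```python
import torch
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler


def toy_model(x, t):
    # Stand-in denoiser: pretend the predicted noise is zero.
    return torch.zeros_like(x)


scheduler = DDPMScheduler(num_train_timesteps=100, prediction_type="epsilon")
scheduler.set_timesteps(10)  # fewer inference steps than training steps

sample = torch.randn(4, 16, 2)  # (B, horizon, action_dim), pure noise to start
for t in scheduler.timesteps:
    model_output = toy_model(sample, t)
    # x_t -> x_{t-1}
    sample = scheduler.step(model_output, t, sample).prev_sample
print(sample.shape)  # torch.Size([4, 16, 2])
```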
|
||||
|
||||
def predict_action(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
"""
|
||||
This function expects `batch` to have (at least):
|
||||
{
|
||||
"observation.state": (B, n_obs_steps, state_dim)
|
||||
"observation.image": (B, n_obs_steps, C, H, W)
|
||||
}
|
||||
"""
|
||||
assert set(batch).issuperset({"observation.state", "observation.image"})
|
||||
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
|
||||
assert n_obs_steps == self.n_obs_steps
|
||||
assert self.n_obs_steps == n_obs_steps
|
||||
|
||||
# build input
|
||||
device = get_device_from_parameters(self)
|
||||
dtype = get_dtype_from_parameters(self)
|
||||
|
||||
# Extract image feature (first combine batch and sequence dims).
|
||||
img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ..."))
|
||||
# Separate batch and sequence dims.
|
||||
img_features = einops.rearrange(img_features, "(b n) ... -> b n ...", b=batch_size)
|
||||
# Concatenate state and image features then flatten to (B, global_cond_dim).
|
||||
global_cond = torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1)
|
||||
# reshape back to B, Do
|
||||
# empty data for action
|
||||
cond_data = torch.zeros(size=(batch_size, self.horizon, self.action_dim), device=device, dtype=dtype)
|
||||
cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
|
||||
|
||||
# run sampling
|
||||
nsample = self.conditional_sample(cond_data, cond_mask, global_cond=global_cond)
|
||||
|
||||
# `horizon` steps worth of actions (from the first observation).
|
||||
action = nsample[..., : self.action_dim]
|
||||
# Extract `n_action_steps` steps worth of action (from the current observation).
|
||||
start = n_obs_steps - 1
|
||||
end = start + self.n_action_steps
|
||||
action = action[:, start:end]
|
||||
|
||||
return action
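The global conditioning vector used above is built by folding the observation window into the batch dimension for the image encoder and then flattening (state, image feature) pairs. An illustrative sketch of the shape bookkeeping (all dimensions made up):

```python
import einops
import torch

B, n_obs_steps, state_dim, feat_dim = 8, 2, 2, 64
images = torch.randn(B, n_obs_steps, 3, 96, 96)
states = torch.randn(B, n_obs_steps, state_dim)

# Fold the observation window into the batch dim for a per-image encoder...
flat = einops.rearrange(images, "b n ... -> (b n) ...")   # (B * n, 3, 96, 96)
feats = torch.randn(flat.shape[0], feat_dim)               # stand-in for rgb_encoder(flat)
# ...then unfold and flatten (state, image feature) pairs into one conditioning vector.
feats = einops.rearrange(feats, "(b n) d -> b n d", b=B)   # (B, n, feat_dim)
global_cond = torch.cat([states, feats], dim=-1).flatten(start_dim=1)
print(global_cond.shape)  # (8, 132) = (B, n_obs_steps * (state_dim + feat_dim))
```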
|
||||
|
||||
def compute_loss(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""
|
||||
This function expects `batch` to have (at least):
|
||||
{
|
||||
"observation.state": (B, n_obs_steps, state_dim)
|
||||
"observation.image": (B, n_obs_steps, C, H, W)
|
||||
"action": (B, horizon, action_dim)
|
||||
"action_is_pad": (B, horizon) # TODO(now) maybe this is (B, horizon, 1)
|
||||
}
|
||||
"""
|
||||
assert set(batch).issuperset({"observation.state", "observation.image", "action", "action_is_pad"})
|
||||
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
|
||||
horizon = batch["action"].shape[1]
|
||||
assert horizon == self.horizon
|
||||
assert n_obs_steps == self.n_obs_steps
|
||||
assert self.n_obs_steps == n_obs_steps
|
||||
|
||||
# handle different ways of passing observation
|
||||
local_cond = None
|
||||
global_cond = None
|
||||
trajectory = batch["action"]
|
||||
cond_data = trajectory
|
||||
|
||||
# Extract image feature (first combine batch and sequence dims).
|
||||
img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ..."))
|
||||
# Separate batch and sequence dims.
|
||||
img_features = einops.rearrange(img_features, "(b n) ... -> b n ...", b=batch_size)
|
||||
# Concatenate state and image features then flatten to (B, global_cond_dim).
|
||||
global_cond = torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1)
|
||||
|
||||
# Sample noise that we'll add to the images
|
||||
noise = torch.randn(trajectory.shape, device=trajectory.device)
|
||||
# Sample a random timestep for each image
|
||||
timesteps = torch.randint(
|
||||
0,
|
||||
self.noise_scheduler.config.num_train_timesteps,
|
||||
(trajectory.shape[0],),
|
||||
device=trajectory.device,
|
||||
).long()
|
||||
# Add noise to the clean images according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process)
|
||||
noisy_trajectory = self.noise_scheduler.add_noise(trajectory, noise, timesteps)
|
||||
|
||||
# Apply inpainting. TODO(now): implement?
|
||||
inpainting_mask = torch.zeros_like(trajectory, dtype=bool)
|
||||
noisy_trajectory[inpainting_mask] = cond_data[inpainting_mask]
|
||||
|
||||
# Predict the noise residual
|
||||
pred = self.unet(noisy_trajectory, timesteps, local_cond=local_cond, global_cond=global_cond)
|
||||
|
||||
pred_type = self.noise_scheduler.config.prediction_type
|
||||
if pred_type == "epsilon":
|
||||
target = noise
|
||||
elif pred_type == "sample":
|
||||
target = trajectory
|
||||
else:
|
||||
raise ValueError(f"Unsupported prediction type {pred_type}")
|
||||
|
||||
loss = F.mse_loss(pred, target, reduction="none")
|
||||
loss = loss * (~inpainting_mask)
|
||||
|
||||
if "action_is_pad" in batch:
|
||||
in_episode_bound = ~batch["action_is_pad"]
|
||||
loss = loss * in_episode_bound[:, :, None].type(loss.dtype)
|
||||
|
||||
return loss.mean()
|
|
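# A minimal training-loop sketch (illustrative only, not part of this diff): it assumes a `policy`
# object exposing the `compute_loss` defined above, plus a hypothetical `dataloader` yielding
# batches with the keys listed in the docstring and an `optimizer` built over `policy.parameters()`.
def train_one_epoch(policy, dataloader, optimizer, device="cuda"):
    policy.train()
    for batch in dataloader:
        # Move every tensor in the batch to the training device.
        batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
        loss = policy.compute_loss(batch)  # scalar MSE over the predicted noise (or sample)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
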
@@ -51,13 +51,6 @@ class EMAModel:
    def step(self, new_model):
        self.decay = self.get_decay(self.optimization_step)

        # old_all_dataptrs = set()
        # for param in new_model.parameters():
        #     data_ptr = param.data_ptr()
        #     if data_ptr != 0:
        #         old_all_dataptrs.add(data_ptr)

        # all_dataptrs = set()
        for module, ema_module in zip(new_model.modules(), self.averaged_model.modules(), strict=False):
            for param, ema_param in zip(
                module.parameters(recurse=False), ema_module.parameters(recurse=False), strict=False

@@ -66,10 +59,6 @@ class EMAModel:
                if isinstance(param, dict):
                    raise RuntimeError("Dict parameter not supported")

                # data_ptr = param.data_ptr()
                # if data_ptr != 0:
                #     all_dataptrs.add(data_ptr)

                if isinstance(module, _BatchNorm):
                    # skip batchnorms
                    ema_param.copy_(param.to(dtype=ema_param.dtype).data)

@@ -79,6 +68,4 @@ class EMAModel:
                    ema_param.mul_(self.decay)
                    ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=1 - self.decay)

        # verify that iterating over module and then parameters is identical to parameters recursively.
        # assert old_all_dataptrs == all_dataptrs
        self.optimization_step += 1

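# The loop above implements the standard exponential moving average update,
# ema <- decay * ema + (1 - decay) * param. A minimal self-contained equivalent
# (illustrative sketch; `decay` is a fixed constant here rather than the scheduled
# value that EMAModel.get_decay(optimization_step) produces):
import torch

@torch.no_grad()
def ema_update(ema_model: torch.nn.Module, model: torch.nn.Module, decay: float = 0.999):
    for ema_param, param in zip(ema_model.parameters(), model.parameters()):
        ema_param.mul_(decay)
        ema_param.add_(param.to(dtype=ema_param.dtype), alpha=1 - decay)
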
@@ -1,46 +0,0 @@
from diffusers.optimization import TYPE_TO_SCHEDULER_FUNCTION, Optimizer, Optional, SchedulerType, Union


def get_scheduler(
    name: Union[str, SchedulerType],
    optimizer: Optimizer,
    num_warmup_steps: Optional[int] = None,
    num_training_steps: Optional[int] = None,
    **kwargs,
):
    """
    Added kwargs vs diffusers' original implementation.

    Unified API to get any scheduler from its name.

    Args:
        name (`str` or `SchedulerType`):
            The name of the scheduler to use.
        optimizer (`torch.optim.Optimizer`):
            The optimizer that will be used during training.
        num_warmup_steps (`int`, *optional*):
            The number of warmup steps to do. This is not required by all schedulers (hence the argument being
            optional), the function will raise an error if it's unset and the scheduler type requires it.
        num_training_steps (`int`, *optional*):
            The number of training steps to do. This is not required by all schedulers (hence the argument being
            optional), the function will raise an error if it's unset and the scheduler type requires it.
    """
    name = SchedulerType(name)
    schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
    if name == SchedulerType.CONSTANT:
        return schedule_func(optimizer, **kwargs)

    # All other schedulers require `num_warmup_steps`.
    if num_warmup_steps is None:
        raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")

    if name == SchedulerType.CONSTANT_WITH_WARMUP:
        return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, **kwargs)

    # All other schedulers require `num_training_steps`.
    if num_training_steps is None:
        raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")

    return schedule_func(
        optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, **kwargs
    )

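# Usage sketch for the deleted helper above (illustrative; the toy model, optimizer and step
# counts are placeholders, not values from this repository):
import torch

model = torch.nn.Linear(2, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
lr_scheduler = get_scheduler(
    "cosine",  # any diffusers SchedulerType, passed by name
    optimizer=optimizer,
    num_warmup_steps=500,
    num_training_steps=200_000,
)
for _ in range(3):  # one lr_scheduler.step() per optimization step
    optimizer.step()
    lr_scheduler.step()
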
@@ -1,65 +0,0 @@
import torch

from lerobot.common.policies.diffusion.model.module_attr_mixin import ModuleAttrMixin


class LowdimMaskGenerator(ModuleAttrMixin):
    def __init__(
        self,
        action_dim,
        obs_dim,
        # obs mask setup
        max_n_obs_steps=2,
        fix_obs_steps=True,
        # action mask
        action_visible=False,
    ):
        super().__init__()
        self.action_dim = action_dim
        self.obs_dim = obs_dim
        self.max_n_obs_steps = max_n_obs_steps
        self.fix_obs_steps = fix_obs_steps
        self.action_visible = action_visible

    @torch.no_grad()
    def forward(self, shape, seed=None):
        device = self.device
        B, T, D = shape  # noqa: N806
        assert (self.action_dim + self.obs_dim) == D

        # create all tensors on this device
        rng = torch.Generator(device=device)
        if seed is not None:
            rng = rng.manual_seed(seed)

        # generate dim mask
        dim_mask = torch.zeros(size=shape, dtype=torch.bool, device=device)
        is_action_dim = dim_mask.clone()
        is_action_dim[..., : self.action_dim] = True
        is_obs_dim = ~is_action_dim

        # generate obs mask
        if self.fix_obs_steps:
            obs_steps = torch.full((B,), fill_value=self.max_n_obs_steps, device=device)
        else:
            obs_steps = torch.randint(
                low=1, high=self.max_n_obs_steps + 1, size=(B,), generator=rng, device=device
            )

        steps = torch.arange(0, T, device=device).reshape(1, T).expand(B, T)
        obs_mask = (obs_steps > steps.T).T.reshape(B, T, 1).expand(B, T, D)
        obs_mask = obs_mask & is_obs_dim

        # generate action mask
        if self.action_visible:
            action_steps = torch.maximum(
                obs_steps - 1, torch.tensor(0, dtype=obs_steps.dtype, device=obs_steps.device)
            )
            action_mask = (action_steps > steps.T).T.reshape(B, T, 1).expand(B, T, D)
            action_mask = action_mask & is_action_dim

        mask = obs_mask
        if self.action_visible:
            mask = mask | action_mask

        return mask

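# Illustrative sketch of what the mask generator above produces (the dimensions are arbitrary
# example values). With fix_obs_steps=True, the first `max_n_obs_steps` timesteps of the
# observation dims are marked True (i.e. conditioned on); everything else is False.
mask_gen = LowdimMaskGenerator(action_dim=2, obs_dim=20, max_n_obs_steps=2, fix_obs_steps=True)
mask = mask_gen(shape=(8, 16, 22))  # (B, T, action_dim + obs_dim)
assert mask.shape == (8, 16, 22)
assert mask[..., :2].sum() == 0  # action dims are never conditioned on when action_visible=False
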
@@ -1,15 +0,0 @@
import torch.nn as nn


class ModuleAttrMixin(nn.Module):
    def __init__(self):
        super().__init__()
        self._dummy_variable = nn.Parameter()

    @property
    def device(self):
        return next(iter(self.parameters())).device

    @property
    def dtype(self):
        return next(iter(self.parameters())).dtype

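# Why the dummy parameter above matters (illustrative note): an otherwise parameter-free module
# would make `next(iter(self.parameters()))` raise StopIteration, so the empty nn.Parameter()
# guarantees that `.device` and `.dtype` always resolve.
import torch

m = ModuleAttrMixin()
assert m.device.type == "cpu" and m.dtype == torch.float32
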
@@ -1,214 +0,0 @@
import copy
from typing import Dict, Optional, Tuple, Union

import torch
import torch.nn as nn
import torchvision
from robomimic.models.base_nets import ResNet18Conv, SpatialSoftmax

from lerobot.common.policies.diffusion.model.crop_randomizer import CropRandomizer
from lerobot.common.policies.diffusion.model.module_attr_mixin import ModuleAttrMixin
from lerobot.common.policies.diffusion.pytorch_utils import replace_submodules


class RgbEncoder(nn.Module):
    """Following `VisualCore` from Robomimic 0.2.0."""

    def __init__(self, input_shape, relu=True, pretrained=False, num_keypoints=32):
        """
        input_shape: channel-first input shape (C, H, W)
        resnet_name: a timm model name.
        pretrained: whether to use timm pretrained weights.
        relu: whether to use relu as a final step.
        num_keypoints: Number of keypoints for SpatialSoftmax (default value of 32 matches PushT Image).
        """
        super().__init__()
        self.backbone = ResNet18Conv(input_channel=input_shape[0], pretrained=pretrained)
        # Figure out the feature map shape.
        with torch.inference_mode():
            feat_map_shape = tuple(self.backbone(torch.zeros(size=(1, *input_shape))).shape[1:])
        self.pool = SpatialSoftmax(feat_map_shape, num_kp=num_keypoints)
        self.out = nn.Linear(num_keypoints * 2, num_keypoints * 2)
        self.relu = nn.ReLU() if relu else nn.Identity()

    def forward(self, x):
        return self.relu(self.out(torch.flatten(self.pool(self.backbone(x)), start_dim=1)))


class MultiImageObsEncoder(ModuleAttrMixin):
    def __init__(
        self,
        shape_meta: dict,
        rgb_model: Union[nn.Module, Dict[str, nn.Module]],
        resize_shape: Union[Tuple[int, int], Dict[str, tuple], None] = None,
        crop_shape: Union[Tuple[int, int], Dict[str, tuple], None] = None,
        random_crop: bool = True,
        # replace BatchNorm with GroupNorm
        use_group_norm: bool = False,
        # use single rgb model for all rgb inputs
        share_rgb_model: bool = False,
        # renormalize rgb input with imagenet normalization
        # assuming input in [0,1]
        norm_mean_std: Optional[tuple[float, float]] = None,
    ):
        """
        Assumes rgb input: B,C,H,W
        Assumes low_dim input: B,D
        """
        super().__init__()

        rgb_keys = []
        low_dim_keys = []
        key_model_map = nn.ModuleDict()
        key_transform_map = nn.ModuleDict()
        key_shape_map = {}

        # handle sharing vision backbone
        if share_rgb_model:
            assert isinstance(rgb_model, nn.Module)
            key_model_map["rgb"] = rgb_model

        obs_shape_meta = shape_meta["obs"]
        for key, attr in obs_shape_meta.items():
            shape = tuple(attr["shape"])
            type = attr.get("type", "low_dim")
            key_shape_map[key] = shape
            if type == "rgb":
                rgb_keys.append(key)
                # configure model for this key
                this_model = None
                if not share_rgb_model:
                    if isinstance(rgb_model, dict):
                        # have provided model for each key
                        this_model = rgb_model[key]
                    else:
                        assert isinstance(rgb_model, nn.Module)
                        # have a copy of the rgb model
                        this_model = copy.deepcopy(rgb_model)

                if this_model is not None:
                    if use_group_norm:
                        this_model = replace_submodules(
                            root_module=this_model,
                            predicate=lambda x: isinstance(x, nn.BatchNorm2d),
                            func=lambda x: nn.GroupNorm(
                                num_groups=x.num_features // 16, num_channels=x.num_features
                            ),
                        )
                    key_model_map[key] = this_model

                # configure resize
                input_shape = shape
                this_resizer = nn.Identity()
                if resize_shape is not None:
                    if isinstance(resize_shape, dict):
                        h, w = resize_shape[key]
                    else:
                        h, w = resize_shape
                    this_resizer = torchvision.transforms.Resize(size=(h, w))
                    input_shape = (shape[0], h, w)

                # configure randomizer
                this_randomizer = nn.Identity()
                if crop_shape is not None:
                    if isinstance(crop_shape, dict):
                        h, w = crop_shape[key]
                    else:
                        h, w = crop_shape
                    if random_crop:
                        this_randomizer = CropRandomizer(
                            input_shape=input_shape, crop_height=h, crop_width=w, num_crops=1, pos_enc=False
                        )
                    else:
                        # Note: upstream assigned this to `this_normalizer`, which is immediately overwritten
                        # below; `this_randomizer` appears to be the intended target.
                        this_randomizer = torchvision.transforms.CenterCrop(size=(h, w))
                # configure normalizer
                this_normalizer = nn.Identity()
                if norm_mean_std is not None:
                    this_normalizer = torchvision.transforms.Normalize(
                        mean=norm_mean_std[0], std=norm_mean_std[1]
                    )

                this_transform = nn.Sequential(this_resizer, this_randomizer, this_normalizer)
                key_transform_map[key] = this_transform
            elif type == "low_dim":
                low_dim_keys.append(key)
            else:
                raise RuntimeError(f"Unsupported obs type: {type}")
        rgb_keys = sorted(rgb_keys)
        low_dim_keys = sorted(low_dim_keys)

        self.shape_meta = shape_meta
        self.key_model_map = key_model_map
        self.key_transform_map = key_transform_map
        self.share_rgb_model = share_rgb_model
        self.rgb_keys = rgb_keys
        self.low_dim_keys = low_dim_keys
        self.key_shape_map = key_shape_map

    def forward(self, obs_dict):
        batch_size = None
        features = []

        # process lowdim input
        for key in self.low_dim_keys:
            data = obs_dict[key]
            if batch_size is None:
                batch_size = data.shape[0]
            else:
                assert batch_size == data.shape[0]
            assert data.shape[1:] == self.key_shape_map[key]
            features.append(data)

        # process rgb input
        if self.share_rgb_model:
            # pass all rgb obs to rgb model
            imgs = []
            for key in self.rgb_keys:
                img = obs_dict[key]
                if batch_size is None:
                    batch_size = img.shape[0]
                else:
                    assert batch_size == img.shape[0]
                assert img.shape[1:] == self.key_shape_map[key]
                img = self.key_transform_map[key](img)
                imgs.append(img)
            # (N*B,C,H,W)
            imgs = torch.cat(imgs, dim=0)
            # (N*B,D)
            feature = self.key_model_map["rgb"](imgs)
            # (N,B,D)
            feature = feature.reshape(-1, batch_size, *feature.shape[1:])
            # (B,N,D)
            feature = torch.moveaxis(feature, 0, 1)
            # (B,N*D)
            feature = feature.reshape(batch_size, -1)
            features.append(feature)
        else:
            # run each rgb obs through its own model
            for key in self.rgb_keys:
                img = obs_dict[key]
                if batch_size is None:
                    batch_size = img.shape[0]
                else:
                    assert batch_size == img.shape[0]
                assert img.shape[1:] == self.key_shape_map[key]
                img = self.key_transform_map[key](img)
                feature = self.key_model_map[key](img)
                features.append(feature)

        # concatenate all features
        result = torch.cat(features, dim=-1)
        return result

    @torch.no_grad()
    def output_shape(self):
        example_obs_dict = {}
        obs_shape_meta = self.shape_meta["obs"]
        batch_size = 1
        for key, attr in obs_shape_meta.items():
            shape = tuple(attr["shape"])
            this_obs = torch.zeros((batch_size,) + shape, dtype=self.dtype, device=self.device)
            example_obs_dict[key] = this_obs
        example_output = self.forward(example_obs_dict)
        output_shape = example_output.shape[1:]
        return output_shape

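# Illustrative sketch of how the deleted encoder above was typically wired up (the keys and
# shapes are placeholders in the style of a PushT-image config, not values from this diff;
# requires robomimic for RgbEncoder):
import torch

shape_meta = {
    "obs": {
        "image": {"shape": (3, 96, 96), "type": "rgb"},
        "agent_pos": {"shape": (2,), "type": "low_dim"},
    }
}
encoder = MultiImageObsEncoder(
    shape_meta=shape_meta,
    rgb_model=RgbEncoder(input_shape=(3, 96, 96)),
)
obs = {
    "image": torch.rand(4, 3, 96, 96),
    "agent_pos": torch.rand(4, 2),
}
features = encoder(obs)  # low_dim (4, 2) concatenated with rgb keypoints (4, 64) -> (4, 66)
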
@ -1,358 +0,0 @@
|
|||
from typing import Dict, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import zarr
|
||||
|
||||
from lerobot.common.policies.diffusion.model.dict_of_tensor_mixin import DictOfTensorMixin
|
||||
from lerobot.common.policies.diffusion.pytorch_utils import dict_apply
|
||||
|
||||
|
||||
class LinearNormalizer(DictOfTensorMixin):
|
||||
avaliable_modes = ["limits", "gaussian"]
|
||||
|
||||
@torch.no_grad()
|
||||
def fit(
|
||||
self,
|
||||
data: Union[Dict, torch.Tensor, np.ndarray, zarr.Array],
|
||||
last_n_dims=1,
|
||||
dtype=torch.float32,
|
||||
mode="limits",
|
||||
output_max=1.0,
|
||||
output_min=-1.0,
|
||||
range_eps=1e-4,
|
||||
fit_offset=True,
|
||||
):
|
||||
if isinstance(data, dict):
|
||||
for key, value in data.items():
|
||||
self.params_dict[key] = _fit(
|
||||
value,
|
||||
last_n_dims=last_n_dims,
|
||||
dtype=dtype,
|
||||
mode=mode,
|
||||
output_max=output_max,
|
||||
output_min=output_min,
|
||||
range_eps=range_eps,
|
||||
fit_offset=fit_offset,
|
||||
)
|
||||
else:
|
||||
self.params_dict["_default"] = _fit(
|
||||
data,
|
||||
last_n_dims=last_n_dims,
|
||||
dtype=dtype,
|
||||
mode=mode,
|
||||
output_max=output_max,
|
||||
output_min=output_min,
|
||||
range_eps=range_eps,
|
||||
fit_offset=fit_offset,
|
||||
)
|
||||
|
||||
def __call__(self, x: Union[Dict, torch.Tensor, np.ndarray]) -> torch.Tensor:
|
||||
return self.normalize(x)
|
||||
|
||||
def __getitem__(self, key: str):
|
||||
return SingleFieldLinearNormalizer(self.params_dict[key])
|
||||
|
||||
def __setitem__(self, key: str, value: "SingleFieldLinearNormalizer"):
|
||||
self.params_dict[key] = value.params_dict
|
||||
|
||||
def _normalize_impl(self, x, forward=True):
|
||||
if isinstance(x, dict):
|
||||
result = {}
|
||||
for key, value in x.items():
|
||||
params = self.params_dict[key]
|
||||
result[key] = _normalize(value, params, forward=forward)
|
||||
return result
|
||||
else:
|
||||
if "_default" not in self.params_dict:
|
||||
raise RuntimeError("Not initialized")
|
||||
params = self.params_dict["_default"]
|
||||
return _normalize(x, params, forward=forward)
|
||||
|
||||
def normalize(self, x: Union[Dict, torch.Tensor, np.ndarray]) -> torch.Tensor:
|
||||
return self._normalize_impl(x, forward=True)
|
||||
|
||||
def unnormalize(self, x: Union[Dict, torch.Tensor, np.ndarray]) -> torch.Tensor:
|
||||
return self._normalize_impl(x, forward=False)
|
||||
|
||||
def get_input_stats(self) -> Dict:
|
||||
if len(self.params_dict) == 0:
|
||||
raise RuntimeError("Not initialized")
|
||||
if len(self.params_dict) == 1 and "_default" in self.params_dict:
|
||||
return self.params_dict["_default"]["input_stats"]
|
||||
|
||||
result = {}
|
||||
for key, value in self.params_dict.items():
|
||||
if key != "_default":
|
||||
result[key] = value["input_stats"]
|
||||
return result
|
||||
|
||||
def get_output_stats(self, key="_default"):
|
||||
input_stats = self.get_input_stats()
|
||||
if "min" in input_stats:
|
||||
# no dict
|
||||
return dict_apply(input_stats, self.normalize)
|
||||
|
||||
result = {}
|
||||
for key, group in input_stats.items():
|
||||
this_dict = {}
|
||||
for name, value in group.items():
|
||||
this_dict[name] = self.normalize({key: value})[key]
|
||||
result[key] = this_dict
|
||||
return result
|
||||
|
||||
|
||||
class SingleFieldLinearNormalizer(DictOfTensorMixin):
|
||||
avaliable_modes = ["limits", "gaussian"]
|
||||
|
||||
@torch.no_grad()
|
||||
def fit(
|
||||
self,
|
||||
data: Union[torch.Tensor, np.ndarray, zarr.Array],
|
||||
last_n_dims=1,
|
||||
dtype=torch.float32,
|
||||
mode="limits",
|
||||
output_max=1.0,
|
||||
output_min=-1.0,
|
||||
range_eps=1e-4,
|
||||
fit_offset=True,
|
||||
):
|
||||
self.params_dict = _fit(
|
||||
data,
|
||||
last_n_dims=last_n_dims,
|
||||
dtype=dtype,
|
||||
mode=mode,
|
||||
output_max=output_max,
|
||||
output_min=output_min,
|
||||
range_eps=range_eps,
|
||||
fit_offset=fit_offset,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def create_fit(cls, data: Union[torch.Tensor, np.ndarray, zarr.Array], **kwargs):
|
||||
obj = cls()
|
||||
obj.fit(data, **kwargs)
|
||||
return obj
|
||||
|
||||
@classmethod
|
||||
def create_manual(
|
||||
cls,
|
||||
scale: Union[torch.Tensor, np.ndarray],
|
||||
offset: Union[torch.Tensor, np.ndarray],
|
||||
input_stats_dict: Dict[str, Union[torch.Tensor, np.ndarray]],
|
||||
):
|
||||
def to_tensor(x):
|
||||
if not isinstance(x, torch.Tensor):
|
||||
x = torch.from_numpy(x)
|
||||
x = x.flatten()
|
||||
return x
|
||||
|
||||
# check
|
||||
for x in [offset] + list(input_stats_dict.values()):
|
||||
assert x.shape == scale.shape
|
||||
assert x.dtype == scale.dtype
|
||||
|
||||
params_dict = nn.ParameterDict(
|
||||
{
|
||||
"scale": to_tensor(scale),
|
||||
"offset": to_tensor(offset),
|
||||
"input_stats": nn.ParameterDict(dict_apply(input_stats_dict, to_tensor)),
|
||||
}
|
||||
)
|
||||
return cls(params_dict)
|
||||
|
||||
@classmethod
|
||||
def create_identity(cls, dtype=torch.float32):
|
||||
scale = torch.tensor([1], dtype=dtype)
|
||||
offset = torch.tensor([0], dtype=dtype)
|
||||
input_stats_dict = {
|
||||
"min": torch.tensor([-1], dtype=dtype),
|
||||
"max": torch.tensor([1], dtype=dtype),
|
||||
"mean": torch.tensor([0], dtype=dtype),
|
||||
"std": torch.tensor([1], dtype=dtype),
|
||||
}
|
||||
return cls.create_manual(scale, offset, input_stats_dict)
|
||||
|
||||
def normalize(self, x: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
|
||||
return _normalize(x, self.params_dict, forward=True)
|
||||
|
||||
def unnormalize(self, x: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
|
||||
return _normalize(x, self.params_dict, forward=False)
|
||||
|
||||
def get_input_stats(self):
|
||||
return self.params_dict["input_stats"]
|
||||
|
||||
def get_output_stats(self):
|
||||
return dict_apply(self.params_dict["input_stats"], self.normalize)
|
||||
|
||||
def __call__(self, x: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
|
||||
return self.normalize(x)
|
||||
|
||||
|
||||
def _fit(
|
||||
data: Union[torch.Tensor, np.ndarray, zarr.Array],
|
||||
last_n_dims=1,
|
||||
dtype=torch.float32,
|
||||
mode="limits",
|
||||
output_max=1.0,
|
||||
output_min=-1.0,
|
||||
range_eps=1e-4,
|
||||
fit_offset=True,
|
||||
):
|
||||
assert mode in ["limits", "gaussian"]
|
||||
assert last_n_dims >= 0
|
||||
assert output_max > output_min
|
||||
|
||||
# convert data to torch and type
|
||||
if isinstance(data, zarr.Array):
|
||||
data = data[:]
|
||||
if isinstance(data, np.ndarray):
|
||||
data = torch.from_numpy(data)
|
||||
if dtype is not None:
|
||||
data = data.type(dtype)
|
||||
|
||||
# convert shape
|
||||
dim = 1
|
||||
if last_n_dims > 0:
|
||||
dim = np.prod(data.shape[-last_n_dims:])
|
||||
data = data.reshape(-1, dim)
|
||||
|
||||
# compute input stats min max mean std
|
||||
input_min, _ = data.min(axis=0)
|
||||
input_max, _ = data.max(axis=0)
|
||||
input_mean = data.mean(axis=0)
|
||||
input_std = data.std(axis=0)
|
||||
|
||||
# compute scale and offset
|
||||
if mode == "limits":
|
||||
if fit_offset:
|
||||
# unit scale
|
||||
input_range = input_max - input_min
|
||||
ignore_dim = input_range < range_eps
|
||||
input_range[ignore_dim] = output_max - output_min
|
||||
scale = (output_max - output_min) / input_range
|
||||
offset = output_min - scale * input_min
|
||||
offset[ignore_dim] = (output_max + output_min) / 2 - input_min[ignore_dim]
|
||||
# ignore dims scaled to mean of output max and min
|
||||
else:
|
||||
# use this when data is pre-zero-centered.
|
||||
assert output_max > 0
|
||||
assert output_min < 0
|
||||
# unit abs
|
||||
output_abs = min(abs(output_min), abs(output_max))
|
||||
input_abs = torch.maximum(torch.abs(input_min), torch.abs(input_max))
|
||||
ignore_dim = input_abs < range_eps
|
||||
input_abs[ignore_dim] = output_abs
|
||||
# don't scale constant channels
|
||||
scale = output_abs / input_abs
|
||||
offset = torch.zeros_like(input_mean)
|
||||
elif mode == "gaussian":
|
||||
ignore_dim = input_std < range_eps
|
||||
scale = input_std.clone()
|
||||
scale[ignore_dim] = 1
|
||||
scale = 1 / scale
|
||||
|
||||
offset = -input_mean * scale if fit_offset else torch.zeros_like(input_mean)
|
||||
|
||||
# save
|
||||
this_params = nn.ParameterDict(
|
||||
{
|
||||
"scale": scale,
|
||||
"offset": offset,
|
||||
"input_stats": nn.ParameterDict(
|
||||
{"min": input_min, "max": input_max, "mean": input_mean, "std": input_std}
|
||||
),
|
||||
}
|
||||
)
|
||||
for p in this_params.parameters():
|
||||
p.requires_grad_(False)
|
||||
return this_params
|
||||
|
||||
|
||||
def _normalize(x, params, forward=True):
|
||||
assert "scale" in params
|
||||
if isinstance(x, np.ndarray):
|
||||
x = torch.from_numpy(x)
|
||||
scale = params["scale"]
|
||||
offset = params["offset"]
|
||||
x = x.to(device=scale.device, dtype=scale.dtype)
|
||||
src_shape = x.shape
|
||||
x = x.reshape(-1, scale.shape[0])
|
||||
x = x * scale + offset if forward else (x - offset) / scale
|
||||
x = x.reshape(src_shape)
|
||||
return x
|
||||
|
||||
|
||||
def test():
|
||||
data = torch.zeros((100, 10, 9, 2)).uniform_()
|
||||
data[..., 0, 0] = 0
|
||||
|
||||
normalizer = SingleFieldLinearNormalizer()
|
||||
normalizer.fit(data, mode="limits", last_n_dims=2)
|
||||
datan = normalizer.normalize(data)
|
||||
assert datan.shape == data.shape
|
||||
assert np.allclose(datan.max(), 1.0)
|
||||
assert np.allclose(datan.min(), -1.0)
|
||||
dataun = normalizer.unnormalize(datan)
|
||||
assert torch.allclose(data, dataun, atol=1e-7)
|
||||
|
||||
_ = normalizer.get_input_stats()
|
||||
_ = normalizer.get_output_stats()
|
||||
|
||||
normalizer = SingleFieldLinearNormalizer()
|
||||
normalizer.fit(data, mode="limits", last_n_dims=1, fit_offset=False)
|
||||
datan = normalizer.normalize(data)
|
||||
assert datan.shape == data.shape
|
||||
assert np.allclose(datan.max(), 1.0, atol=1e-3)
|
||||
assert np.allclose(datan.min(), 0.0, atol=1e-3)
|
||||
dataun = normalizer.unnormalize(datan)
|
||||
assert torch.allclose(data, dataun, atol=1e-7)
|
||||
|
||||
data = torch.zeros((100, 10, 9, 2)).uniform_()
|
||||
normalizer = SingleFieldLinearNormalizer()
|
||||
normalizer.fit(data, mode="gaussian", last_n_dims=0)
|
||||
datan = normalizer.normalize(data)
|
||||
assert datan.shape == data.shape
|
||||
assert np.allclose(datan.mean(), 0.0, atol=1e-3)
|
||||
assert np.allclose(datan.std(), 1.0, atol=1e-3)
|
||||
dataun = normalizer.unnormalize(datan)
|
||||
assert torch.allclose(data, dataun, atol=1e-7)
|
||||
|
||||
# dict
|
||||
data = torch.zeros((100, 10, 9, 2)).uniform_()
|
||||
data[..., 0, 0] = 0
|
||||
|
||||
normalizer = LinearNormalizer()
|
||||
normalizer.fit(data, mode="limits", last_n_dims=2)
|
||||
datan = normalizer.normalize(data)
|
||||
assert datan.shape == data.shape
|
||||
assert np.allclose(datan.max(), 1.0)
|
||||
assert np.allclose(datan.min(), -1.0)
|
||||
dataun = normalizer.unnormalize(datan)
|
||||
assert torch.allclose(data, dataun, atol=1e-7)
|
||||
|
||||
_ = normalizer.get_input_stats()
|
||||
_ = normalizer.get_output_stats()
|
||||
|
||||
data = {
|
||||
"obs": torch.zeros((1000, 128, 9, 2)).uniform_() * 512,
|
||||
"action": torch.zeros((1000, 128, 2)).uniform_() * 512,
|
||||
}
|
||||
normalizer = LinearNormalizer()
|
||||
normalizer.fit(data)
|
||||
datan = normalizer.normalize(data)
|
||||
dataun = normalizer.unnormalize(datan)
|
||||
for key in data:
|
||||
assert torch.allclose(data[key], dataun[key], atol=1e-4)
|
||||
|
||||
_ = normalizer.get_input_stats()
|
||||
_ = normalizer.get_output_stats()
|
||||
|
||||
state_dict = normalizer.state_dict()
|
||||
n = LinearNormalizer()
|
||||
n.load_state_dict(state_dict)
|
||||
datan = n.normalize(data)
|
||||
dataun = n.unnormalize(datan)
|
||||
for key in data:
|
||||
assert torch.allclose(data[key], dataun[key], atol=1e-4)
|
|
@@ -1,19 +0,0 @@
import math

import torch
import torch.nn as nn


class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb

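# Quick shape/identity check for the deleted embedding above (illustrative): for a batch of
# scalar timesteps t and embedding size D, the output is
#   [sin(t * w_0), ..., sin(t * w_{D/2 - 1}), cos(t * w_0), ..., cos(t * w_{D/2 - 1})],
# with w_j = 10000 ** (-j / (D/2 - 1)).
import torch

pos_emb = SinusoidalPosEmb(dim=128)
t = torch.arange(4, dtype=torch.float32)  # e.g. four diffusion timesteps
e = pos_emb(t)
assert e.shape == (4, 128)
assert torch.allclose(e[:, 0], torch.sin(t))  # w_0 == 1, so the first channel is sin(t)
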
@@ -0,0 +1,147 @@
from typing import Callable

import torch
import torchvision
from robomimic.models.base_nets import SpatialSoftmax
from torch import Tensor, nn
from torchvision.transforms import CenterCrop, RandomCrop


class RgbEncoder(nn.Module):
    """Encodes an RGB image into a 1D feature vector.

    Includes the ability to normalize and crop the image first.
    """

    def __init__(
        self,
        input_shape: tuple[int, int, int],
        norm_mean_std: tuple[float, float] = (1.0, 1.0),
        crop_shape: tuple[int, int] | None = None,
        random_crop: bool = False,
        backbone_name: str = "resnet18",
        pretrained_backbone: bool = False,
        use_group_norm: bool = False,
        relu: bool = True,
        num_keypoints: int = 32,
    ):
        """
        Args:
            input_shape: channel-first input shape (C, H, W)
            norm_mean_std: mean and standard deviation used for image normalization. Images are normalized as
                (image - mean) / std.
            crop_shape: (H, W) shape to crop to (must fit within the input shape). If not provided, no
                cropping is done.
            random_crop: Whether the crop should be random at training time (it's always a center crop in
                eval mode).
            backbone_name: The name of one of the available resnet models from torchvision (e.g. resnet18).
            pretrained_backbone: whether to use torchvision pretrained weights.
            use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
                The group sizes are set to be about 16 (to be precise, feature_dim // 16).
            relu: whether to use relu as a final step.
            num_keypoints: Number of keypoints for SpatialSoftmax (default value of 32 matches PushT Image).
        """
        super().__init__()
        if input_shape[0] != 3:
            raise ValueError("Only RGB images are handled")
        if not backbone_name.startswith("resnet"):
            raise ValueError(
                "Only resnet is supported for now (because of the assumption that 'layer4' is the output layer)"
            )

        # Set up optional preprocessing.
        if tuple(norm_mean_std) == (1.0, 1.0):
            self.normalizer = nn.Identity()
        else:
            self.normalizer = torchvision.transforms.Normalize(mean=norm_mean_std[0], std=norm_mean_std[1])

        if crop_shape is not None:
            self.do_crop = True
            self.center_crop = CenterCrop(crop_shape)  # always use center crop for eval
            if random_crop:
                self.maybe_random_crop = RandomCrop(crop_shape)
            else:
                self.maybe_random_crop = self.center_crop
        else:
            self.do_crop = False

        # Set up backbone.
        backbone_model = getattr(torchvision.models, backbone_name)(pretrained=pretrained_backbone)
        # Note: This assumes that the layer4 feature map is children()[-3]
        # TODO(alexander-soare): Use a safer alternative.
        self.backbone = nn.Sequential(*(list(backbone_model.children())[:-2]))
        if use_group_norm:
            if pretrained_backbone:
                raise ValueError(
                    "You can't replace BatchNorm in a pretrained model without ruining the weights!"
                )
            self.backbone = _replace_submodules(
                root_module=self.backbone,
                predicate=lambda x: isinstance(x, nn.BatchNorm2d),
                func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features),
            )

        # Set up pooling and final layers.
        # Use a dry run to get the feature map shape.
        with torch.inference_mode():
            feat_map_shape = tuple(self.backbone(torch.zeros(size=(1, *input_shape))).shape[1:])
        self.pool = SpatialSoftmax(feat_map_shape, num_kp=num_keypoints)
        self.feature_dim = num_keypoints * 2
        self.out = nn.Linear(num_keypoints * 2, self.feature_dim)
        self.maybe_relu = nn.ReLU() if relu else nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: (B, C, H, W) image tensor with pixel values in [0, 1].
        Returns:
            (B, D) image feature.
        """
        # Preprocess: normalize and maybe crop (if it was set up in the __init__).
        x = self.normalizer(x)
        if self.do_crop:
            if self.training:  # noqa: SIM108
                x = self.maybe_random_crop(x)
            else:
                # Always use center crop for eval.
                x = self.center_crop(x)
        # Extract backbone feature.
        x = torch.flatten(self.pool(self.backbone(x)), start_dim=1)
        # Final linear layer.
        x = self.out(x)
        # Maybe a final non-linearity.
        x = self.maybe_relu(x)
        return x


def _replace_submodules(
    root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module]
) -> nn.Module:
    """
    Args:
        root_module: The module for which the submodules need to be replaced
        predicate: Takes a module as an argument and must return True if that module is to be replaced.
        func: Takes a module as an argument and returns a new module to replace it with.
    Returns:
        The root module with its submodules replaced.
    """
    if predicate(root_module):
        return func(root_module)

    replace_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)]
    for *parents, k in replace_list:
        parent_module = root_module
        if len(parents) > 0:
            parent_module = root_module.get_submodule(".".join(parents))
        if isinstance(parent_module, nn.Sequential):
            src_module = parent_module[int(k)]
        else:
            src_module = getattr(parent_module, k)
        tgt_module = func(src_module)
        if isinstance(parent_module, nn.Sequential):
            parent_module[int(k)] = tgt_module
        else:
            setattr(parent_module, k, tgt_module)
    # verify that all BN are replaced
    assert not any(predicate(m) for _, m in root_module.named_modules(remove_duplicate=True))
    return root_module

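# Usage sketch for the new encoder above (illustrative; the shapes and arguments are example
# values, and pixel values are expected in [0, 1]):
import torch

encoder = RgbEncoder(
    input_shape=(3, 96, 96),
    crop_shape=(84, 84),
    random_crop=True,     # random crop in train mode, center crop in eval mode
    use_group_norm=True,  # only valid with pretrained_backbone=False
)
imgs = torch.rand(8, 3, 96, 96)
feats = encoder(imgs)
assert feats.shape == (8, encoder.feature_dim)  # feature_dim == num_keypoints * 2 == 64
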
@ -1,972 +0,0 @@
|
|||
"""
|
||||
A collection of utilities for working with nested tensor structures consisting
|
||||
of numpy arrays and torch tensors.
|
||||
"""
|
||||
|
||||
import collections
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def recursive_dict_list_tuple_apply(x, type_func_dict):
|
||||
"""
|
||||
Recursively apply functions to a nested dictionary or list or tuple, given a dictionary of
|
||||
{data_type: function_to_apply}.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
type_func_dict (dict): a mapping from data types to the functions to be
|
||||
applied for each data type.
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
assert list not in type_func_dict
|
||||
assert tuple not in type_func_dict
|
||||
assert dict not in type_func_dict
|
||||
|
||||
if isinstance(x, (dict, collections.OrderedDict)):
|
||||
new_x = collections.OrderedDict() if isinstance(x, collections.OrderedDict) else {}
|
||||
for k, v in x.items():
|
||||
new_x[k] = recursive_dict_list_tuple_apply(v, type_func_dict)
|
||||
return new_x
|
||||
elif isinstance(x, (list, tuple)):
|
||||
ret = [recursive_dict_list_tuple_apply(v, type_func_dict) for v in x]
|
||||
if isinstance(x, tuple):
|
||||
ret = tuple(ret)
|
||||
return ret
|
||||
else:
|
||||
for t, f in type_func_dict.items():
|
||||
if isinstance(x, t):
|
||||
return f(x)
|
||||
else:
|
||||
raise NotImplementedError("Cannot handle data type %s" % str(type(x)))
|
||||
|
||||
|
||||
def map_tensor(x, func):
|
||||
"""
|
||||
Apply function @func to torch.Tensor objects in a nested dictionary or
|
||||
list or tuple.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
func (function): function to apply to each tensor
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return recursive_dict_list_tuple_apply(
|
||||
x,
|
||||
{
|
||||
torch.Tensor: func,
|
||||
type(None): lambda x: x,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def map_ndarray(x, func):
|
||||
"""
|
||||
Apply function @func to np.ndarray objects in a nested dictionary or
|
||||
list or tuple.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
func (function): function to apply to each array
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return recursive_dict_list_tuple_apply(
|
||||
x,
|
||||
{
|
||||
np.ndarray: func,
|
||||
type(None): lambda x: x,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def map_tensor_ndarray(x, tensor_func, ndarray_func):
|
||||
"""
|
||||
Apply function @tensor_func to torch.Tensor objects and @ndarray_func to
|
||||
np.ndarray objects in a nested dictionary or list or tuple.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
tensor_func (function): function to apply to each tensor
|
||||
ndarray_Func (function): function to apply to each array
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return recursive_dict_list_tuple_apply(
|
||||
x,
|
||||
{
|
||||
torch.Tensor: tensor_func,
|
||||
np.ndarray: ndarray_func,
|
||||
type(None): lambda x: x,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def clone(x):
|
||||
"""
|
||||
Clones all torch tensors and numpy arrays in nested dictionary or list
|
||||
or tuple and returns a new nested structure.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return recursive_dict_list_tuple_apply(
|
||||
x,
|
||||
{
|
||||
torch.Tensor: lambda x: x.clone(),
|
||||
np.ndarray: lambda x: x.copy(),
|
||||
type(None): lambda x: x,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def detach(x):
|
||||
"""
|
||||
Detaches all torch tensors in nested dictionary or list
|
||||
or tuple and returns a new nested structure.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return recursive_dict_list_tuple_apply(
|
||||
x,
|
||||
{
|
||||
torch.Tensor: lambda x: x.detach(),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def to_batch(x):
|
||||
"""
|
||||
Introduces a leading batch dimension of 1 for all torch tensors and numpy
|
||||
arrays in nested dictionary or list or tuple and returns a new nested structure.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return recursive_dict_list_tuple_apply(
|
||||
x,
|
||||
{
|
||||
torch.Tensor: lambda x: x[None, ...],
|
||||
np.ndarray: lambda x: x[None, ...],
|
||||
type(None): lambda x: x,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def to_sequence(x):
|
||||
"""
|
||||
Introduces a time dimension of 1 at dimension 1 for all torch tensors and numpy
|
||||
arrays in nested dictionary or list or tuple and returns a new nested structure.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return recursive_dict_list_tuple_apply(
|
||||
x,
|
||||
{
|
||||
torch.Tensor: lambda x: x[:, None, ...],
|
||||
np.ndarray: lambda x: x[:, None, ...],
|
||||
type(None): lambda x: x,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def index_at_time(x, ind):
|
||||
"""
|
||||
Indexes all torch tensors and numpy arrays in dimension 1 with index @ind in
|
||||
nested dictionary or list or tuple and returns a new nested structure.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
ind (int): index
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return recursive_dict_list_tuple_apply(
|
||||
x,
|
||||
{
|
||||
torch.Tensor: lambda x: x[:, ind, ...],
|
||||
np.ndarray: lambda x: x[:, ind, ...],
|
||||
type(None): lambda x: x,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def unsqueeze(x, dim):
|
||||
"""
|
||||
Adds dimension of size 1 at dimension @dim in all torch tensors and numpy arrays
|
||||
in nested dictionary or list or tuple and returns a new nested structure.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
dim (int): dimension
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return recursive_dict_list_tuple_apply(
|
||||
x,
|
||||
{
|
||||
torch.Tensor: lambda x: x.unsqueeze(dim=dim),
|
||||
np.ndarray: lambda x: np.expand_dims(x, axis=dim),
|
||||
type(None): lambda x: x,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def contiguous(x):
|
||||
"""
|
||||
Makes all torch tensors and numpy arrays contiguous in nested dictionary or
|
||||
list or tuple and returns a new nested structure.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return recursive_dict_list_tuple_apply(
|
||||
x,
|
||||
{
|
||||
torch.Tensor: lambda x: x.contiguous(),
|
||||
np.ndarray: lambda x: np.ascontiguousarray(x),
|
||||
type(None): lambda x: x,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def to_device(x, device):
|
||||
"""
|
||||
Sends all torch tensors in nested dictionary or list or tuple to device
|
||||
@device, and returns a new nested structure.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
device (torch.Device): device to send tensors to
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return recursive_dict_list_tuple_apply(
|
||||
x,
|
||||
{
|
||||
torch.Tensor: lambda x, d=device: x.to(d),
|
||||
type(None): lambda x: x,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def to_tensor(x):
|
||||
"""
|
||||
Converts all numpy arrays in nested dictionary or list or tuple to
|
||||
torch tensors (and leaves existing torch Tensors as-is), and returns
|
||||
a new nested structure.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return recursive_dict_list_tuple_apply(
|
||||
x,
|
||||
{
|
||||
torch.Tensor: lambda x: x,
|
||||
np.ndarray: lambda x: torch.from_numpy(x),
|
||||
type(None): lambda x: x,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def to_numpy(x):
|
||||
"""
|
||||
Converts all torch tensors in nested dictionary or list or tuple to
|
||||
numpy (and leaves existing numpy arrays as-is), and returns
|
||||
a new nested structure.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
|
||||
def f(tensor):
|
||||
if tensor.is_cuda:
|
||||
return tensor.detach().cpu().numpy()
|
||||
else:
|
||||
return tensor.detach().numpy()
|
||||
|
||||
return recursive_dict_list_tuple_apply(
|
||||
x,
|
||||
{
|
||||
torch.Tensor: f,
|
||||
np.ndarray: lambda x: x,
|
||||
type(None): lambda x: x,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def to_list(x):
|
||||
"""
|
||||
Converts all torch tensors and numpy arrays in nested dictionary or list
|
||||
or tuple to a list, and returns a new nested structure. Useful for
|
||||
json encoding.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
|
||||
def f(tensor):
|
||||
if tensor.is_cuda:
|
||||
return tensor.detach().cpu().numpy().tolist()
|
||||
else:
|
||||
return tensor.detach().numpy().tolist()
|
||||
|
||||
return recursive_dict_list_tuple_apply(
|
||||
x,
|
||||
{
|
||||
torch.Tensor: f,
|
||||
np.ndarray: lambda x: x.tolist(),
|
||||
type(None): lambda x: x,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def to_float(x):
|
||||
"""
|
||||
Converts all torch tensors and numpy arrays in nested dictionary or list
|
||||
or tuple to float type entries, and returns a new nested structure.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return recursive_dict_list_tuple_apply(
|
||||
x,
|
||||
{
|
||||
torch.Tensor: lambda x: x.float(),
|
||||
np.ndarray: lambda x: x.astype(np.float32),
|
||||
type(None): lambda x: x,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def to_uint8(x):
|
||||
"""
|
||||
Converts all torch tensors and numpy arrays in nested dictionary or list
|
||||
or tuple to uint8 type entries, and returns a new nested structure.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return recursive_dict_list_tuple_apply(
|
||||
x,
|
||||
{
|
||||
torch.Tensor: lambda x: x.byte(),
|
||||
np.ndarray: lambda x: x.astype(np.uint8),
|
||||
type(None): lambda x: x,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def to_torch(x, device):
|
||||
"""
|
||||
Converts all numpy arrays and torch tensors in nested dictionary or list or tuple to
|
||||
torch tensors on device @device and returns a new nested structure.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
device (torch.Device): device to send tensors to
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return to_device(to_float(to_tensor(x)), device)
|
||||
|
||||
|
||||
def to_one_hot_single(tensor, num_class):
|
||||
"""
|
||||
Convert tensor to one-hot representation, assuming a certain number of total class labels.
|
||||
|
||||
Args:
|
||||
tensor (torch.Tensor): tensor containing integer labels
|
||||
num_class (int): number of classes
|
||||
|
||||
Returns:
|
||||
x (torch.Tensor): tensor containing one-hot representation of labels
|
||||
"""
|
||||
x = torch.zeros(tensor.size() + (num_class,)).to(tensor.device)
|
||||
x.scatter_(-1, tensor.unsqueeze(-1), 1)
|
||||
return x
|
||||
|
||||
|
||||
def to_one_hot(tensor, num_class):
|
||||
"""
|
||||
Convert all tensors in nested dictionary or list or tuple to one-hot representation,
|
||||
assuming a certain number of total class labels.
|
||||
|
||||
Args:
|
||||
tensor (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
num_class (int): number of classes
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return map_tensor(tensor, func=lambda x, nc=num_class: to_one_hot_single(x, nc))
|
||||
|
||||
|
||||
def flatten_single(x, begin_axis=1):
|
||||
"""
|
||||
Flatten a tensor in all dimensions from @begin_axis onwards.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): tensor to flatten
|
||||
begin_axis (int): which axis to flatten from
|
||||
|
||||
Returns:
|
||||
y (torch.Tensor): flattened tensor
|
||||
"""
|
||||
fixed_size = x.size()[:begin_axis]
|
||||
_s = list(fixed_size) + [-1]
|
||||
return x.reshape(*_s)
|
||||
|
||||
|
||||
def flatten(x, begin_axis=1):
|
||||
"""
|
||||
Flatten all tensors in nested dictionary or list or tuple, from @begin_axis onwards.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
begin_axis (int): which axis to flatten from
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return recursive_dict_list_tuple_apply(
|
||||
x,
|
||||
{
|
||||
torch.Tensor: lambda x, b=begin_axis: flatten_single(x, begin_axis=b),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def reshape_dimensions_single(x, begin_axis, end_axis, target_dims):
|
||||
"""
|
||||
Reshape selected dimensions in a tensor to a target dimension.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): tensor to reshape
|
||||
begin_axis (int): begin dimension
|
||||
end_axis (int): end dimension
|
||||
target_dims (tuple or list): target shape for the range of dimensions
|
||||
(@begin_axis, @end_axis)
|
||||
|
||||
Returns:
|
||||
y (torch.Tensor): reshaped tensor
|
||||
"""
|
||||
assert begin_axis <= end_axis
|
||||
assert begin_axis >= 0
|
||||
assert end_axis < len(x.shape)
|
||||
assert isinstance(target_dims, (tuple, list))
|
||||
s = x.shape
|
||||
final_s = []
|
||||
for i in range(len(s)):
|
||||
if i == begin_axis:
|
||||
final_s.extend(target_dims)
|
||||
elif i < begin_axis or i > end_axis:
|
||||
final_s.append(s[i])
|
||||
return x.reshape(*final_s)
|
||||
|
||||
|
||||
def reshape_dimensions(x, begin_axis, end_axis, target_dims):
|
||||
"""
|
||||
Reshape selected dimensions for all tensors in nested dictionary or list or tuple
|
||||
to a target dimension.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
begin_axis (int): begin dimension
|
||||
end_axis (int): end dimension
|
||||
target_dims (tuple or list): target shape for the range of dimensions
|
||||
(@begin_axis, @end_axis)
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return recursive_dict_list_tuple_apply(
|
||||
x,
|
||||
{
|
||||
torch.Tensor: lambda x, b=begin_axis, e=end_axis, t=target_dims: reshape_dimensions_single(
|
||||
x, begin_axis=b, end_axis=e, target_dims=t
|
||||
),
|
||||
np.ndarray: lambda x, b=begin_axis, e=end_axis, t=target_dims: reshape_dimensions_single(
|
||||
x, begin_axis=b, end_axis=e, target_dims=t
|
||||
),
|
||||
type(None): lambda x: x,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def join_dimensions(x, begin_axis, end_axis):
|
||||
"""
|
||||
Joins all dimensions between dimensions (@begin_axis, @end_axis) into a flat dimension, for
|
||||
all tensors in nested dictionary or list or tuple.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
begin_axis (int): begin dimension
|
||||
end_axis (int): end dimension
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return recursive_dict_list_tuple_apply(
|
||||
x,
|
||||
{
|
||||
torch.Tensor: lambda x, b=begin_axis, e=end_axis: reshape_dimensions_single(
|
||||
x, begin_axis=b, end_axis=e, target_dims=[-1]
|
||||
),
|
||||
np.ndarray: lambda x, b=begin_axis, e=end_axis: reshape_dimensions_single(
|
||||
x, begin_axis=b, end_axis=e, target_dims=[-1]
|
||||
),
|
||||
type(None): lambda x: x,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def expand_at_single(x, size, dim):
|
||||
"""
|
||||
Expand a tensor at a single dimension @dim by @size
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): input tensor
|
||||
size (int): size to expand
|
||||
dim (int): dimension to expand
|
||||
|
||||
Returns:
|
||||
y (torch.Tensor): expanded tensor
|
||||
"""
|
||||
assert dim < x.ndimension()
|
||||
assert x.shape[dim] == 1
|
||||
expand_dims = [-1] * x.ndimension()
|
||||
expand_dims[dim] = size
|
||||
return x.expand(*expand_dims)
|
||||
|
||||
|
||||
def expand_at(x, size, dim):
|
||||
"""
|
||||
Expand all tensors in nested dictionary or list or tuple at a single
|
||||
dimension @dim by @size.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
size (int): size to expand
|
||||
dim (int): dimension to expand
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return map_tensor(x, lambda t, s=size, d=dim: expand_at_single(t, s, d))
|
||||
|
||||
|
||||
def unsqueeze_expand_at(x, size, dim):
|
||||
"""
|
||||
Unsqueeze and expand a tensor at a dimension @dim by @size.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
size (int): size to expand
|
||||
dim (int): dimension to unsqueeze and expand
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
x = unsqueeze(x, dim)
|
||||
return expand_at(x, size, dim)
|
||||
|
||||
|
||||
def repeat_by_expand_at(x, repeats, dim):
|
||||
"""
|
||||
Repeat a dimension by combining expand and reshape operations.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
repeats (int): number of times to repeat the target dimension
|
||||
dim (int): dimension to repeat on
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
x = unsqueeze_expand_at(x, repeats, dim + 1)
|
||||
return join_dimensions(x, dim, dim + 1)
|
||||
|
||||
|
||||
def named_reduce_single(x, reduction, dim):
|
||||
"""
|
||||
Reduce tensor at a dimension by named reduction functions.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): tensor to be reduced
|
||||
reduction (str): one of ["sum", "max", "mean", "flatten"]
|
||||
dim (int): dimension to be reduced (or begin axis for flatten)
|
||||
|
||||
Returns:
|
||||
y (torch.Tensor): reduced tensor
|
||||
"""
|
||||
assert x.ndimension() > dim
|
||||
assert reduction in ["sum", "max", "mean", "flatten"]
|
||||
if reduction == "flatten":
|
||||
x = flatten(x, begin_axis=dim)
|
||||
elif reduction == "max":
|
||||
x = torch.max(x, dim=dim)[0] # [B, D]
|
||||
elif reduction == "sum":
|
||||
x = torch.sum(x, dim=dim)
|
||||
else:
|
||||
x = torch.mean(x, dim=dim)
|
||||
return x
|
||||
|
||||
|
||||
def named_reduce(x, reduction, dim):
|
||||
"""
|
||||
Reduces all tensors in nested dictionary or list or tuple at a dimension
|
||||
using a named reduction function.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
reduction (str): one of ["sum", "max", "mean", "flatten"]
|
||||
dim (int): dimension to be reduced (or begin axis for flatten)
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return map_tensor(x, func=lambda t, r=reduction, d=dim: named_reduce_single(t, r, d))
|
||||
|
||||
|
||||
def gather_along_dim_with_dim_single(x, target_dim, source_dim, indices):
|
||||
"""
|
||||
This function indexes out a target dimension of a tensor in a structured way,
|
||||
by allowing a different value to be selected for each member of a flat index
|
||||
tensor (@indices) corresponding to a source dimension. This can be interpreted
|
||||
as moving along the source dimension, using the corresponding index value
|
||||
in @indices to select values for all other dimensions outside of the
|
||||
source and target dimensions. A common use case is to gather values
|
||||
in target dimension 1 for each batch member (target dimension 0).
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): tensor to gather values for
|
||||
target_dim (int): dimension to gather values along
|
||||
source_dim (int): dimension to hold constant and use for gathering values
|
||||
from the other dimensions
|
||||
indices (torch.Tensor): flat index tensor with same shape as tensor @x along
|
||||
@source_dim
|
||||
|
||||
Returns:
|
||||
y (torch.Tensor): gathered tensor, with dimension @target_dim indexed out
|
||||
"""
|
||||
assert len(indices.shape) == 1
|
||||
assert x.shape[source_dim] == indices.shape[0]
|
||||
|
||||
# unsqueeze in all dimensions except the source dimension
|
||||
new_shape = [1] * x.ndimension()
|
||||
new_shape[source_dim] = -1
|
||||
indices = indices.reshape(*new_shape)
|
||||
|
||||
# repeat in all dimensions - but preserve shape of source dimension,
|
||||
# and make sure target_dimension has singleton dimension
|
||||
expand_shape = list(x.shape)
|
||||
expand_shape[source_dim] = -1
|
||||
expand_shape[target_dim] = 1
|
||||
indices = indices.expand(*expand_shape)
|
||||
|
||||
out = x.gather(dim=target_dim, index=indices)
|
||||
return out.squeeze(target_dim)
|
||||
|
||||
|
||||
def gather_along_dim_with_dim(x, target_dim, source_dim, indices):
|
||||
"""
|
||||
Apply @gather_along_dim_with_dim_single to all tensors in a nested
|
||||
dictionary or list or tuple.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
target_dim (int): dimension to gather values along
|
||||
source_dim (int): dimension to hold constant and use for gathering values
|
||||
from the other dimensions
|
||||
indices (torch.Tensor): flat index tensor with same shape as tensor @x along
|
||||
@source_dim
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return map_tensor(
|
||||
x, lambda y, t=target_dim, s=source_dim, i=indices: gather_along_dim_with_dim_single(y, t, s, i)
|
||||
)
|
||||
|
||||
|
||||
def gather_sequence_single(seq, indices):
|
||||
"""
|
||||
Given a tensor with leading dimensions [B, T, ...], gather an element from each sequence in
|
||||
the batch given an index for each sequence.
|
||||
|
||||
Args:
|
||||
seq (torch.Tensor): tensor with leading dimensions [B, T, ...]
|
||||
indices (torch.Tensor): tensor indices of shape [B]
|
||||
|
||||
Return:
|
||||
y (torch.Tensor): indexed tensor of shape [B, ....]
|
||||
"""
|
||||
return gather_along_dim_with_dim_single(seq, target_dim=1, source_dim=0, indices=indices)
|
||||
|
||||
|
||||
def gather_sequence(seq, indices):
|
||||
"""
|
||||
Given a nested dictionary or list or tuple, gathers an element from each sequence of the batch
|
||||
for tensors with leading dimensions [B, T, ...].
|
||||
|
||||
Args:
|
||||
seq (dict or list or tuple): a possibly nested dictionary or list or tuple with tensors
|
||||
of leading dimensions [B, T, ...]
|
||||
indices (torch.Tensor): tensor indices of shape [B]
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple with tensors of shape [B, ...]
|
||||
"""
|
||||
return gather_along_dim_with_dim(seq, target_dim=1, source_dim=0, indices=indices)
|
||||
|
||||
|
||||
def pad_sequence_single(seq, padding, batched=False, pad_same=True, pad_values=None):
    """
    Pad input tensor or array @seq in the time dimension (dimension 1).

    Args:
        seq (np.ndarray or torch.Tensor): sequence to be padded
        padding (tuple): begin and end padding, e.g. [1, 1] pads both begin and end of the sequence by 1
        batched (bool): if sequence has the batch dimension
        pad_same (bool): if pad by duplicating
        pad_values (scalar or (ndarray, Tensor)): values to be padded if not pad_same

    Returns:
        padded sequence (np.ndarray or torch.Tensor)
    """
    assert isinstance(seq, (np.ndarray, torch.Tensor))
    assert pad_same or pad_values is not None
    if pad_values is not None:
        assert isinstance(pad_values, float)
    repeat_func = np.repeat if isinstance(seq, np.ndarray) else torch.repeat_interleave
    concat_func = np.concatenate if isinstance(seq, np.ndarray) else torch.cat
    ones_like_func = np.ones_like if isinstance(seq, np.ndarray) else torch.ones_like
    seq_dim = 1 if batched else 0

    begin_pad = []
    end_pad = []

    if padding[0] > 0:
        pad = seq[[0]] if pad_same else ones_like_func(seq[[0]]) * pad_values
        begin_pad.append(repeat_func(pad, padding[0], seq_dim))
    if padding[1] > 0:
        pad = seq[[-1]] if pad_same else ones_like_func(seq[[-1]]) * pad_values
        end_pad.append(repeat_func(pad, padding[1], seq_dim))

    return concat_func(begin_pad + [seq] + end_pad, seq_dim)

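# --- Illustrative usage (added for clarity; not part of the original file) ---
# Repeat the first frame once and the last frame twice of an unbatched [T, D] sequence:
#
#   seq = torch.arange(6.0).reshape(3, 2)              # T=3, D=2
#   padded = pad_sequence_single(seq, padding=(1, 2))  # pad_same=True duplicates edge frames
#   # padded.shape == (6, 2); padded[0] == seq[0]; padded[-1] == seq[-1]
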
def pad_sequence(seq, padding, batched=False, pad_same=True, pad_values=None):
    """
    Pad a nested dictionary or list or tuple of sequence tensors in the time dimension (dimension 1).

    Args:
        seq (dict or list or tuple): a possibly nested dictionary or list or tuple with tensors
            of leading dimensions [B, T, ...]
        padding (tuple): begin and end padding, e.g. [1, 1] pads both begin and end of the sequence by 1
        batched (bool): if sequence has the batch dimension
        pad_same (bool): if pad by duplicating
        pad_values (scalar or (ndarray, Tensor)): values to be padded if not pad_same

    Returns:
        padded sequence (dict or list or tuple)
    """
    return recursive_dict_list_tuple_apply(
        seq,
        {
            torch.Tensor: lambda x, p=padding, b=batched, ps=pad_same, pv=pad_values: pad_sequence_single(
                x, p, b, ps, pv
            ),
            np.ndarray: lambda x, p=padding, b=batched, ps=pad_same, pv=pad_values: pad_sequence_single(
                x, p, b, ps, pv
            ),
            type(None): lambda x: x,
        },
    )


def assert_size_at_dim_single(x, size, dim, msg):
    """
    Ensure that array or tensor @x has size @size in dim @dim.

    Args:
        x (np.ndarray or torch.Tensor): input array or tensor
        size (int): size that tensors should have at @dim
        dim (int): dimension to check
        msg (str): text to display if assertion fails
    """
    assert x.shape[dim] == size, msg


def assert_size_at_dim(x, size, dim, msg):
    """
    Ensure that arrays and tensors in nested dictionary or list or tuple have
    size @size in dim @dim.

    Args:
        x (dict or list or tuple): a possibly nested dictionary or list or tuple
        size (int): size that tensors should have at @dim
        dim (int): dimension to check
    """
    map_tensor(x, lambda t, s=size, d=dim, m=msg: assert_size_at_dim_single(t, s, d, m))


def get_shape(x):
    """
    Get all shapes of arrays and tensors in nested dictionary or list or tuple.

    Args:
        x (dict or list or tuple): a possibly nested dictionary or list or tuple

    Returns:
        y (dict or list or tuple): new nested dict-list-tuple that contains each array or
            tensor's shape
    """
    return recursive_dict_list_tuple_apply(
        x,
        {
            torch.Tensor: lambda x: x.shape,
            np.ndarray: lambda x: x.shape,
            type(None): lambda x: x,
        },
    )


def list_of_flat_dict_to_dict_of_list(list_of_dict):
    """
    Helper function to go from a list of flat dictionaries to a dictionary of lists.
    By "flat" we mean that none of the values are dictionaries, but are numpy arrays,
    floats, etc.

    Args:
        list_of_dict (list): list of flat dictionaries

    Returns:
        dict_of_list (dict): dictionary of lists
    """
    assert isinstance(list_of_dict, list)
    dic = collections.OrderedDict()
    for i in range(len(list_of_dict)):
        for k in list_of_dict[i]:
            if k not in dic:
                dic[k] = []
            dic[k].append(list_of_dict[i][k])
    return dic

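# --- Illustrative usage (added for clarity; not part of the original file) ---
#   rollouts = [{"reward": 1.0, "done": False}, {"reward": 0.5, "done": True}]
#   list_of_flat_dict_to_dict_of_list(rollouts)
#   # -> OrderedDict([("reward", [1.0, 0.5]), ("done", [False, True])])
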
def flatten_nested_dict_list(d, parent_key="", sep="_", item_key=""):
    """
    Flatten a nested dict or list to a list.

    For example, given a dict
    {
        a: 1
        b: {
            c: 2
        }
        c: 3
    }

    the function would return [(a, 1), (b_c, 2), (c, 3)]

    Args:
        d (dict, list): a nested dict or list to be flattened
        parent_key (str): recursion helper
        sep (str): separator for nesting keys
        item_key (str): recursion helper
    Returns:
        list: a list of (key, value) tuples
    """
    items = []
    if isinstance(d, (tuple, list)):
        new_key = parent_key + sep + item_key if len(parent_key) > 0 else item_key
        for i, v in enumerate(d):
            items.extend(flatten_nested_dict_list(v, new_key, sep=sep, item_key=str(i)))
        return items
    elif isinstance(d, dict):
        new_key = parent_key + sep + item_key if len(parent_key) > 0 else item_key
        for k, v in d.items():
            assert isinstance(k, str)
            items.extend(flatten_nested_dict_list(v, new_key, sep=sep, item_key=k))
        return items
    else:
        new_key = parent_key + sep + item_key if len(parent_key) > 0 else item_key
        return [(new_key, d)]

def time_distributed(inputs, op, activation=None, inputs_as_kwargs=False, inputs_as_args=False, **kwargs):
    """
    Apply function @op to all tensors in nested dictionary or list or tuple @inputs in both the
    batch (B) and time (T) dimension, where the tensors are expected to have shape [B, T, ...].
    Will do this by reshaping tensors to [B * T, ...], passing through the op, and then reshaping
    outputs to [B, T, ...].

    Args:
        inputs (list or tuple or dict): a possibly nested dictionary or list or tuple with tensors
            of leading dimensions [B, T, ...]
        op: a layer op that accepts inputs
        activation: activation to apply at the output
        inputs_as_kwargs (bool): whether to feed input as a kwargs dict to the op
        inputs_as_args (bool): whether to feed input as an args list to the op
        kwargs (dict): other kwargs to supply to the op

    Returns:
        outputs (dict or list or tuple): new nested dict-list-tuple with tensors of leading dimension [B, T].
    """
    batch_size, seq_len = flatten_nested_dict_list(inputs)[0][1].shape[:2]
    inputs = join_dimensions(inputs, 0, 1)
    if inputs_as_kwargs:
        outputs = op(**inputs, **kwargs)
    elif inputs_as_args:
        outputs = op(*inputs, **kwargs)
    else:
        outputs = op(inputs, **kwargs)

    if activation is not None:
        outputs = map_tensor(outputs, activation)
    outputs = reshape_dimensions(outputs, begin_axis=0, end_axis=0, target_dims=(batch_size, seq_len))
    return outputs

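# --- Illustrative usage (added for clarity; not part of the original file) ---
# Run a single nn.Linear over every timestep of a [B, T, D] tensor. This assumes the
# module's join_dimensions / reshape_dimensions helpers (defined earlier in this file)
# accept bare tensors, as they do in the original robomimic implementation:
#
#   op = torch.nn.Linear(16, 4)
#   x = torch.randn(8, 10, 16)    # B=8, T=10
#   y = time_distributed(x, op)   # y.shape == (8, 10, 4)
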
@@ -5,11 +5,10 @@ from collections import deque

import hydra
import torch
from torch import nn
from diffusers.optimization import get_scheduler
from torch import Tensor, nn

from lerobot.common.policies.diffusion.diffusion_unet_image_policy import DiffusionUnetImagePolicy
from lerobot.common.policies.diffusion.model.lr_scheduler import get_scheduler
from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder, RgbEncoder
from lerobot.common.policies.diffusion.model.diffusion_unet_image_policy import DiffusionUnetImagePolicy
from lerobot.common.policies.utils import populate_queues
from lerobot.common.utils import get_safe_torch_device

@@ -22,8 +21,6 @@ class DiffusionPolicy(nn.Module):
        cfg,
        cfg_device,
        cfg_noise_scheduler,
        cfg_rgb_model,
        cfg_obs_encoder,
        cfg_optimizer,
        cfg_ema,
        shape_meta: dict,
@@ -31,53 +28,43 @@ class DiffusionPolicy(nn.Module):
        n_action_steps,
        n_obs_steps,
        num_inference_steps=None,
        obs_as_global_cond=True,
        diffusion_step_embed_dim=256,
        down_dims=(256, 512, 1024),
        kernel_size=5,
        n_groups=8,
        cond_predict_scale=True,
        # parameters passed to step
        **kwargs,
        film_scale_modulation=True,
        **_,
    ):
        super().__init__()
        self.cfg = cfg
        self.n_obs_steps = n_obs_steps
        self.n_action_steps = n_action_steps

        # queues are populated during rollout of the policy, they contain the n latest observations and actions
        self._queues = None

        # TODO(now): In-house this.
        noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler)
        rgb_model_input_shape = copy.deepcopy(shape_meta.obs.image.shape)
        if cfg_obs_encoder.crop_shape is not None:
            rgb_model_input_shape[1:] = cfg_obs_encoder.crop_shape
        rgb_model = RgbEncoder(input_shape=rgb_model_input_shape, **cfg_rgb_model)
        obs_encoder = MultiImageObsEncoder(
            rgb_model=rgb_model,
            **cfg_obs_encoder,
        )

        self.diffusion = DiffusionUnetImagePolicy(
            cfg,
            shape_meta=shape_meta,
            noise_scheduler=noise_scheduler,
            obs_encoder=obs_encoder,
            horizon=horizon,
            n_action_steps=n_action_steps,
            n_obs_steps=n_obs_steps,
            num_inference_steps=num_inference_steps,
            obs_as_global_cond=obs_as_global_cond,
            diffusion_step_embed_dim=diffusion_step_embed_dim,
            down_dims=down_dims,
            kernel_size=kernel_size,
            n_groups=n_groups,
            cond_predict_scale=cond_predict_scale,
            # parameters passed to step
            **kwargs,
            film_scale_modulation=film_scale_modulation,
        )

        self.device = get_safe_torch_device(cfg_device)
        self.diffusion.to(self.device)

        # TODO(alexander-soare): This should probably be managed outside of the policy class.
        self.ema_diffusion = None
        self.ema = None
        if self.cfg.use_ema:
@@ -116,42 +103,45 @@ class DiffusionPolicy(nn.Module):
            "action": deque(maxlen=self.n_action_steps),
        }

    @torch.no_grad()
    def select_action(self, batch, step):
    def forward(self, batch: dict[str, Tensor], **_) -> Tensor:
        """A forward pass through the DNN part of this policy with optional loss computation."""
        return self.select_action(batch)

    @torch.no_grad
    def select_action(self, batch, **_):
        """
        Note: this uses the ema model weights if self.training == False, otherwise the non-ema model weights.
        # TODO(now): Handle a batch
        """
        # TODO(rcadene): remove unused step
        del step
        assert "observation.image" in batch
        assert "observation.state" in batch
        assert len(batch) == 2
        assert len(batch) == 2 # TODO(now): Does this not have a batch dim?

        self._queues = populate_queues(self._queues, batch)

        if len(self._queues["action"]) == 0:
            # stack n latest observations from the queue
            batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}

            obs_dict = {
                "image": batch["observation.image"],
                "agent_pos": batch["observation.state"],
            }
            if self.training:
                out = self.diffusion.predict_action(obs_dict)
            else:
                out = self.ema_diffusion.predict_action(obs_dict)
            self._queues["action"].extend(out["action"].transpose(0, 1))
            actions = self._generate_actions(batch)
            self._queues["action"].extend(actions.transpose(0, 1))

        action = self._queues["action"].popleft()
        return action

    def forward(self, batch, step):
    def _generate_actions(self, batch):
        if not self.training and self.ema_diffusion is not None:
            return self.ema_diffusion.predict_action(batch)
        else:
            return self.diffusion.predict_action(batch)

    def update(self, batch, **_):
        """Run the model in train mode, compute the loss, and do an optimization step."""
        start_time = time.time()

        self.diffusion.train()

        loss = self.diffusion.compute_loss(batch)
        loss = self.compute_loss(batch)

        loss.backward()

        grad_norm = torch.nn.utils.clip_grad_norm_(
@@ -174,13 +164,11 @@ class DiffusionPolicy(nn.Module):
            "update_s": time.time() - start_time,
        }

        # TODO(rcadene): remove hardcoding
        # in diffusion_policy, len(dataloader) is 168 for a batch_size of 64
        if step % 168 == 0:
            self.global_step += 1

        return info

    def compute_loss(self, batch: dict[str, Tensor]) -> Tensor:
        return self.diffusion.compute_loss(batch)

    def save(self, fp):
        torch.save(self.state_dict(), fp)

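# --- Illustrative usage sketch (added for clarity; not part of the diff) ---
# The rollout loop is expected to call select_action once per environment step; the
# policy refills its internal action queue whenever it runs empty (every n_action_steps
# calls), then pops one action per call, e.g.:
#
#   policy.eval()
#   for _ in range(rollout_steps):                  # rollout_steps is hypothetical
#       action = policy.select_action(observation)  # pops one action per call
#       observation = env_step(action)              # env_step is hypothetical
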
@@ -1,76 +0,0 @@
from typing import Callable, Dict

import torch
import torch.nn as nn
import torchvision


def get_resnet(name, weights=None, **kwargs):
    """
    name: resnet18, resnet34, resnet50
    weights: "IMAGENET1K_V1", "r3m"
    """
    # load r3m weights
    if (weights == "r3m") or (weights == "R3M"):
        return get_r3m(name=name, **kwargs)

    func = getattr(torchvision.models, name)
    resnet = func(weights=weights, **kwargs)
    resnet.fc = torch.nn.Identity()
    return resnet


def get_r3m(name, **kwargs):
    """
    name: resnet18, resnet34, resnet50
    """
    import r3m

    r3m.device = "cpu"
    model = r3m.load_r3m(name)
    r3m_model = model.module
    resnet_model = r3m_model.convnet
    resnet_model = resnet_model.to("cpu")
    return resnet_model


def dict_apply(
    x: Dict[str, torch.Tensor], func: Callable[[torch.Tensor], torch.Tensor]
) -> Dict[str, torch.Tensor]:
    result = {}
    for key, value in x.items():
        if isinstance(value, dict):
            result[key] = dict_apply(value, func)
        else:
            result[key] = func(value)
    return result


def replace_submodules(
    root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module]
) -> nn.Module:
    """
    predicate: Return true if the module is to be replaced.
    func: Return new module to use.
    """
    if predicate(root_module):
        return func(root_module)

    bn_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)]
    for *parent, k in bn_list:
        parent_module = root_module
        if len(parent) > 0:
            parent_module = root_module.get_submodule(".".join(parent))
        if isinstance(parent_module, nn.Sequential):
            src_module = parent_module[int(k)]
        else:
            src_module = getattr(parent_module, k)
        tgt_module = func(src_module)
        if isinstance(parent_module, nn.Sequential):
            parent_module[int(k)] = tgt_module
        else:
            setattr(parent_module, k, tgt_module)
    # verify that all BN are replaced
    bn_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)]
    assert len(bn_list) == 0
    return root_module

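# --- Illustrative usage (added for clarity; not part of the diff) ---
# A typical call in the original diffusion policy code swaps BatchNorm for GroupNorm
# inside a ResNet backbone (the num_groups choice below mirrors that usage):
#
#   replace_submodules(
#       root_module=resnet,
#       predicate=lambda m: isinstance(m, nn.BatchNorm2d),
#       func=lambda m: nn.GroupNorm(num_groups=m.num_features // 16, num_channels=m.num_features),
#   )
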
@@ -12,8 +12,6 @@ def make_policy(cfg):
            cfg=cfg.policy,
            cfg_device=cfg.device,
            cfg_noise_scheduler=cfg.noise_scheduler,
            cfg_rgb_model=cfg.rgb_model,
            cfg_obs_encoder=cfg.obs_encoder,
            cfg_optimizer=cfg.optimizer,
            cfg_ema=cfg.ema,
            # n_obs_steps=cfg.n_obs_steps,

@@ -1,3 +1,7 @@
import torch
from torch import nn


def populate_queues(queues, batch):
    for key in batch:
        if len(queues[key]) != queues[key].maxlen:
@@ -8,3 +12,21 @@ def populate_queues(queues, batch):
            # add latest observation to the queue
            queues[key].append(batch[key])
    return queues


def get_device_from_parameters(module: nn.Module) -> torch.device:
    """Get a module's device by checking one of its parameters.

    Note: assumes that all parameters have the same device
    TODO(now): Add test.
    """
    return next(iter(module.parameters())).device


def get_dtype_from_parameters(module: nn.Module) -> torch.dtype:
    """Get a module's parameter dtype by checking one of its parameters.

    Note: assumes that all parameters have the same dtype.
    TODO(now): Add test.
    """
    return next(iter(module.parameters())).dtype

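# --- Illustrative usage sketch (added for clarity; not part of the diff) ---
# The policies above keep a per-key deque of the latest observations; img and state
# below are hypothetical placeholder tensors:
#
#   queues = {"observation.image": deque(maxlen=2), "observation.state": deque(maxlen=2)}
#   queues = populate_queues(queues, {"observation.image": img, "observation.state": state})
#   # each call appends the newest tensors; a full deque drops its oldest entry
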
@@ -11,6 +11,7 @@ from omegaconf import DictConfig


def get_safe_torch_device(cfg_device: str, log: bool = False) -> torch.device:
    """Given a string, return a torch.device with checks on whether the device is available."""
    match cfg_device:
        case "cuda":
            assert torch.cuda.is_available()

@@ -19,7 +19,6 @@ n_action_steps: 8
dataset_obs_steps: ${n_obs_steps}
past_action_visible: False
keypoint_visible_rate: 1.0
obs_as_global_cond: True

eval_episodes: 50
eval_freq: 5000
@@ -40,13 +39,12 @@ policy:
  n_obs_steps: ${n_obs_steps}
  n_action_steps: ${n_action_steps}
  num_inference_steps: 100
  obs_as_global_cond: ${obs_as_global_cond}
  # crop_shape: null
  diffusion_step_embed_dim: 128
  down_dims: [512, 1024, 2048]
  kernel_size: 5
  n_groups: 8
  cond_predict_scale: True
  film_scale_modulation: True

  pretrained_model_path:

@@ -68,6 +66,16 @@ policy:
    observation.state: [-0.1, 0]
    action: [-0.1, 0, .1, .2, .3, .4, .5, .6, .7, .8, .9, 1.0, 1.1, 1.2, 1.3, 1.4]

rgb_encoder:
  backbone_name: resnet18
  pretrained_backbone: false
  use_group_norm: True
  num_keypoints: 32
  relu: true
  norm_mean_std: [0.5, 0.5] # for PushT the original impl normalizes to [-1, 1] (maybe not the case for robomimic envs)
  crop_shape: [84, 84]
  random_crop: True

noise_scheduler:
  _target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
  num_train_timesteps: 100
@@ -78,16 +86,6 @@ noise_scheduler:
  clip_sample: True # required when predict_epsilon=False
  prediction_type: epsilon # or sample

obs_encoder:
  shape_meta: ${shape_meta}
  # resize_shape: null
  crop_shape: [84, 84]
  # constant center crop
  random_crop: True
  use_group_norm: True
  share_rgb_model: False
  norm_mean_std: [0.5, 0.5] # for PushT the original impl normalizes to [-1, 1] (maybe not the case for robomimic envs)

rgb_model:
  pretrained: false
  num_keypoints: 32

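# --- Illustrative sketch (added for clarity; not part of the diff) ---
# The noise_scheduler block above is consumed by DiffusionPolicy.__init__ via hydra
# (see the policy diff earlier in this commit):
#
#   noise_scheduler = hydra.utils.instantiate(cfg.noise_scheduler)
#   # -> diffusers.schedulers.scheduling_ddpm.DDPMScheduler(num_train_timesteps=100, ...)
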
@@ -121,7 +121,7 @@ def eval_policy(

        # get the next action for the environment
        with torch.inference_mode():
            action = policy.select_action(observation, step)
            action = policy.select_action(observation, step=step)

        # apply inverse transform to unnormalize the action
        action = postprocess_action(action, transform)

@@ -213,7 +213,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
        for key in batch:
            batch[key] = batch[key].to(cfg.device, non_blocking=True)

        train_info = policy(batch, step)
        train_info = policy.update(batch, step=step)

        # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
        if step % cfg.log_freq == 0: