backup wip

parent 94cc22da9e
commit 5666ec3ec7
@@ -66,53 +66,33 @@ class DiffusionUnetImagePolicy(nn.Module):
         self.num_inference_steps = num_inference_steps

     # ========= inference ============
-    def conditional_sample(
-        self,
-        condition_data,
-        inpainting_mask,
-        local_cond=None,
-        global_cond=None,
-        generator=None,
-    ):
-        model = self.unet
-        scheduler = self.noise_scheduler
+    def conditional_sample(self, batch_size, global_cond=None, generator=None):
+        device = get_device_from_parameters(self)
+        dtype = get_dtype_from_parameters(self)

-        trajectory = torch.randn(
-            size=condition_data.shape,
-            dtype=condition_data.dtype,
-            device=condition_data.device,
+        # Sample prior.
+        sample = torch.randn(
+            size=(batch_size, self.horizon, self.action_dim),
+            dtype=dtype,
+            device=device,
             generator=generator,
         )

-        # set step values
-        scheduler.set_timesteps(self.num_inference_steps)
+        self.noise_scheduler.set_timesteps(self.num_inference_steps)

-        for t in scheduler.timesteps:
-            # 1. apply conditioning
-            trajectory[inpainting_mask] = condition_data[inpainting_mask]
-
-            # 2. predict model output
-            model_output = model(
-                trajectory,
-                torch.full(trajectory.shape[:1], t, dtype=torch.long, device=trajectory.device),
-                local_cond=local_cond,
+        for t in self.noise_scheduler.timesteps:
+            # Predict model output.
+            model_output = self.unet(
+                sample,
+                torch.full(sample.shape[:1], t, dtype=torch.long, device=sample.device),
                 global_cond=global_cond,
             )
-
-            # 3. compute previous image: x_t -> x_t-1
-            trajectory = scheduler.step(
-                model_output,
-                t,
-                trajectory,
-                generator=generator,
-            ).prev_sample
-
-        # finally make sure conditioning is enforced
-        trajectory[inpainting_mask] = condition_data[inpainting_mask]
-
-        return trajectory
+            # Compute previous image: x_t -> x_t-1 # TODO(now): Is this right?
+            sample = self.noise_scheduler.step(model_output, t, sample, generator=generator).prev_sample
+
+        return sample

-    def predict_action(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
+    def generate_actions(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
         """
         This function expects `batch` to have (at least):
         {
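For orientation, the refactored `conditional_sample` follows the usual diffusers denoising loop. Below is a minimal, self-contained sketch of the same pattern; `denoise_net` and the tensor sizes are illustrative stand-ins, not names from this repo:

```python
import torch
from diffusers import DDPMScheduler


def denoise_net(sample, timesteps):
    """Toy stand-in for the policy's conditioned U-Net (illustrative only)."""
    return torch.zeros_like(sample)


scheduler = DDPMScheduler(num_train_timesteps=100)
scheduler.set_timesteps(10)  # Inference may use fewer steps than training.

# Start from a pure-noise trajectory of shape (batch, horizon, action_dim).
sample = torch.randn(4, 16, 7)
for t in scheduler.timesteps:
    model_output = denoise_net(sample, torch.full(sample.shape[:1], t, dtype=torch.long))
    # One reverse-diffusion step: x_t -> x_{t-1}.
    sample = scheduler.step(model_output, t, sample).prev_sample
```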
@@ -125,27 +105,19 @@ class DiffusionUnetImagePolicy(nn.Module):
         assert n_obs_steps == self.n_obs_steps
         assert self.n_obs_steps == n_obs_steps

-        # build input
-        device = get_device_from_parameters(self)
-        dtype = get_dtype_from_parameters(self)

         # Extract image feature (first combine batch and sequence dims).
         img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ..."))
         # Separate batch and sequence dims.
         img_features = einops.rearrange(img_features, "(b n) ... -> b n ...", b=batch_size)
         # Concatenate state and image features then flatten to (B, global_cond_dim).
         global_cond = torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1)
-        # reshape back to B, Do
-        # empty data for action
-        cond_data = torch.zeros(size=(batch_size, self.horizon, self.action_dim), device=device, dtype=dtype)
-        cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)

         # run sampling
-        nsample = self.conditional_sample(cond_data, cond_mask, global_cond=global_cond)
+        sample = self.conditional_sample(batch_size, global_cond=global_cond)

         # `horizon` steps worth of actions (from the first observation).
-        action = nsample[..., : self.action_dim]
-        # Extract `n_action_steps` steps worth of action (from the current observation).
+        action = sample[..., : self.action_dim]
+        # Extract `n_action_steps` steps worth of actions (from the current observation).
         start = n_obs_steps - 1
         end = start + self.n_action_steps
         action = action[:, start:end]
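The `start`/`end` slicing above picks the executable chunk out of the sampled horizon. A worked example with made-up sizes (`n_obs_steps=2`, `horizon=16`, `n_action_steps=8`):

```python
import torch

n_obs_steps, horizon, n_action_steps, action_dim = 2, 16, 8, 7
action = torch.randn(1, horizon, action_dim)  # Full sampled horizon.

# The first n_obs_steps - 1 actions are "in the past" relative to the current
# observation, so execution starts at the step aligned with the latest observation.
start = n_obs_steps - 1  # 1
end = start + n_action_steps  # 9
chunk = action[:, start:end]
assert chunk.shape == (1, n_action_steps, action_dim)
```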
@@ -159,9 +131,10 @@ class DiffusionUnetImagePolicy(nn.Module):
             "observation.state": (B, n_obs_steps, state_dim)
             "observation.image": (B, n_obs_steps, C, H, W)
             "action": (B, horizon, action_dim)
-            "action_is_pad": (B, horizon) # TODO(now) maybe this is (B, horizon, 1)
+            "action_is_pad": (B, horizon)
         }
         """
+        # Input validation.
         assert set(batch).issuperset({"observation.state", "observation.image", "action", "action_is_pad"})
         batch_size, n_obs_steps = batch["observation.state"].shape[:2]
         horizon = batch["action"].shape[1]
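For reference, a dummy batch satisfying the docstring's contract; all sizes are arbitrary placeholders:

```python
import torch

B, n_obs_steps, horizon = 8, 2, 16
state_dim, action_dim, C, H, W = 7, 7, 3, 96, 96

batch = {
    "observation.state": torch.randn(B, n_obs_steps, state_dim),
    "observation.image": torch.randn(B, n_obs_steps, C, H, W),
    "action": torch.randn(B, horizon, action_dim),
    "action_is_pad": torch.zeros(B, horizon, dtype=torch.bool),
}
assert set(batch).issuperset({"observation.state", "observation.image", "action", "action_is_pad"})
```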
@@ -169,12 +142,6 @@ class DiffusionUnetImagePolicy(nn.Module):
         assert n_obs_steps == self.n_obs_steps
         assert self.n_obs_steps == n_obs_steps

-        # handle different ways of passing observation
-        local_cond = None
-        global_cond = None
-        trajectory = batch["action"]
-        cond_data = trajectory
-
         # Extract image feature (first combine batch and sequence dims).
         img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ..."))
         # Separate batch and sequence dims.
@@ -182,39 +149,39 @@ class DiffusionUnetImagePolicy(nn.Module):
         # Concatenate state and image features then flatten to (B, global_cond_dim).
         global_cond = torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1)

-        # Sample noise that we'll add to the images
-        noise = torch.randn(trajectory.shape, device=trajectory.device)
-        # Sample a random timestep for each image
+        trajectory = batch["action"]
+
+        # Forward diffusion.
+        # Sample noise to add to the trajectory.
+        eps = torch.randn(trajectory.shape, device=trajectory.device)
+        # Sample a random noising timestep for each item in the batch.
         timesteps = torch.randint(
-            0,
-            self.noise_scheduler.config.num_train_timesteps,
-            (trajectory.shape[0],),
+            low=0,
+            high=self.noise_scheduler.config.num_train_timesteps,
+            size=(trajectory.shape[0],),
             device=trajectory.device,
         ).long()
-        # Add noise to the clean images according to the noise magnitude at each timestep
-        # (this is the forward diffusion process)
-        noisy_trajectory = self.noise_scheduler.add_noise(trajectory, noise, timesteps)
+        # Add noise to the clean trajectories according to the noise magnitude at each timestep.
+        noisy_trajectory = self.noise_scheduler.add_noise(trajectory, eps, timesteps)

-        # Apply inpainting. TODO(now): implement?
-        inpainting_mask = torch.zeros_like(trajectory, dtype=bool)
-        noisy_trajectory[inpainting_mask] = cond_data[inpainting_mask]
-
-        # Predict the noise residual
-        pred = self.unet(noisy_trajectory, timesteps, local_cond=local_cond, global_cond=global_cond)
+        # Run the denoising network (it may predict the denoised trajectory, or the noise).
+        pred = self.unet(noisy_trajectory, timesteps, global_cond=global_cond)

+        # Compute the loss.
+        # The target is either the original trajectory, or the noise.
         pred_type = self.noise_scheduler.config.prediction_type
         if pred_type == "epsilon":
-            target = noise
+            target = eps
         elif pred_type == "sample":
-            target = trajectory
+            target = batch["action"]
         else:
             raise ValueError(f"Unsupported prediction type {pred_type}")

         loss = F.mse_loss(pred, target, reduction="none")
-        loss = loss * (~inpainting_mask)

+        # Mask loss wherever the action is padded with copies (edges of the dataset trajectory).
         if "action_is_pad" in batch:
             in_episode_bound = ~batch["action_is_pad"]
-            loss = loss * in_episode_bound[:, :, None].type(loss.dtype)
+            loss = loss * in_episode_bound.unsqueeze(-1)

         return loss.mean()
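The training path above is the standard DDPM objective. A condensed sketch of the same computation against the diffusers scheduler API; `unet` and the shapes here are placeholders, not this repo's modules:

```python
import torch
import torch.nn.functional as F
from diffusers import DDPMScheduler


def unet(x, timesteps):
    """Placeholder denoiser (illustrative only)."""
    return torch.zeros_like(x)


noise_scheduler = DDPMScheduler(num_train_timesteps=100, prediction_type="epsilon")

trajectory = torch.randn(8, 16, 7)  # Clean action trajectories (B, horizon, action_dim).
eps = torch.randn_like(trajectory)
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (trajectory.shape[0],))

# Forward diffusion: corrupt each clean trajectory to its sampled timestep.
noisy_trajectory = noise_scheduler.add_noise(trajectory, eps, timesteps)
pred = unet(noisy_trajectory, timesteps)

# With prediction_type="epsilon" the network regresses the injected noise;
# with prediction_type="sample" it would regress the clean trajectory instead.
target = eps if noise_scheduler.config.prediction_type == "epsilon" else trajectory
loss = F.mse_loss(pred, target, reduction="none")

# Zero out the loss on padded action steps before reducing.
action_is_pad = torch.zeros(8, 16, dtype=torch.bool)
loss = (loss * ~action_is_pad.unsqueeze(-1)).mean()
```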
@@ -32,7 +32,7 @@ class EMAModel:
         self.min_value = min_value
         self.max_value = max_value

-        self.decay = 0.0
+        self.alpha = 0.0
         self.optimization_step = 0

     def get_decay(self, optimization_step):
@@ -49,23 +49,20 @@ class EMAModel:

     @torch.no_grad()
     def step(self, new_model):
-        self.decay = self.get_decay(self.optimization_step)
+        self.alpha = self.get_decay(self.optimization_step)

-        for module, ema_module in zip(new_model.modules(), self.averaged_model.modules(), strict=False):
+        for module, ema_module in zip(new_model.modules(), self.averaged_model.modules(), strict=True):
+            # Iterate over immediate parameters only.
             for param, ema_param in zip(
-                module.parameters(recurse=False), ema_module.parameters(recurse=False), strict=False
+                module.parameters(recurse=False), ema_module.parameters(recurse=False), strict=True
             ):
-                # iterative over immediate parameters only.
                 if isinstance(param, dict):
                     raise RuntimeError("Dict parameter not supported")
-                if isinstance(module, _BatchNorm):
-                    # skip batchnorms
-                    ema_param.copy_(param.to(dtype=ema_param.dtype).data)
-                elif not param.requires_grad:
+                if isinstance(module, _BatchNorm) or not param.requires_grad:
+                    # Copy BatchNorm parameters, and non-trainable parameters directly.
                     ema_param.copy_(param.to(dtype=ema_param.dtype).data)
                 else:
-                    ema_param.mul_(self.decay)
-                    ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=1 - self.decay)
+                    ema_param.mul_(self.alpha)
+                    ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=1 - self.alpha)

         self.optimization_step += 1
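The rename from `decay` to `alpha` doesn't change the update rule: each step performs `ema = alpha * ema + (1 - alpha) * param` in place. A minimal sketch of that per-parameter update, detached from the class:

```python
import torch


@torch.no_grad()
def ema_update(ema_params, params, alpha: float):
    """In-place EMA: ema <- alpha * ema + (1 - alpha) * param."""
    for ema_param, param in zip(ema_params, params, strict=True):  # strict= needs Python 3.10+.
        ema_param.mul_(alpha)
        ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=1 - alpha)


# Usage: with alpha close to 1, the average moves slowly toward the live weights.
ema = [torch.zeros(3)]
live = [torch.ones(3)]
ema_update(ema, live, alpha=0.99)
assert torch.allclose(ema[0], torch.full((3,), 0.01))
```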
@@ -130,9 +130,9 @@ class DiffusionPolicy(nn.Module):

     def _generate_actions(self, batch):
         if not self.training and self.ema_diffusion is not None:
-            return self.ema_diffusion.predict_action(batch)
+            return self.ema_diffusion.generate_actions(batch)
         else:
-            return self.diffusion.predict_action(batch)
+            return self.diffusion.generate_actions(batch)

     def update(self, batch, **_):
         """Run the model in train mode, compute the loss, and do an optimization step."""
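Design note: at evaluation time the policy dispatches to the EMA copy of the diffusion model when one exists, since the smoothed weights are typically what diffusion-policy implementations evaluate with; training always goes through the live `self.diffusion`.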