Add multi-image support to diffusion policy (#218)

parent e28fa2344c
commit 15dd682714
@@ -28,7 +28,9 @@ class DiffusionConfig:
 
     Notes on the inputs and outputs:
         - "observation.state" is required as an input key.
-        - A key starting with "observation.image" is required as an input.
+        - At least one key starting with "observation.image" is required as an input.
+        - If there are multiple keys beginning with "observation.image" they are treated as multiple camera
+          views. Right now we only support all images having the same shape.
         - "action" is required as an output key.
 
     Args:
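For illustration only (the key names and shapes below are hypothetical, not part of this commit), a two-camera setup satisfying these docstring requirements could look like:

    input_shapes = {
        "observation.image.top": [3, 96, 96],    # camera view 1
        "observation.image.wrist": [3, 96, 96],  # camera view 2 (same shape required)
        "observation.state": [2],
    }
    output_shapes = {
        "action": [2],
    }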
@@ -153,22 +155,26 @@ class DiffusionConfig:
             raise ValueError(
                 f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
             )
-        # There should only be one image key.
         image_keys = {k for k in self.input_shapes if k.startswith("observation.image")}
-        if len(image_keys) != 1:
-            raise ValueError(
-                f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}."
-            )
-        image_key = next(iter(image_keys))
-        if self.crop_shape is not None and (
-            self.crop_shape[0] > self.input_shapes[image_key][1]
-            or self.crop_shape[1] > self.input_shapes[image_key][2]
-        ):
-            raise ValueError(
-                f"`crop_shape` should fit within `input_shapes[{image_key}]`. Got {self.crop_shape} "
-                f"for `crop_shape` and {self.input_shapes[image_key]} for "
-                "`input_shapes[{image_key}]`."
-            )
+        if self.crop_shape is not None:
+            for image_key in image_keys:
+                if (
+                    self.crop_shape[0] > self.input_shapes[image_key][1]
+                    or self.crop_shape[1] > self.input_shapes[image_key][2]
+                ):
+                    raise ValueError(
+                        f"`crop_shape` should fit within `input_shapes[{image_key}]`. Got {self.crop_shape} "
+                        f"for `crop_shape` and {self.input_shapes[image_key]} for "
+                        f"`input_shapes[{image_key}]`."
+                    )
+        # Check that all input images have the same shape.
+        first_image_key = next(iter(image_keys))
+        for image_key in image_keys:
+            if self.input_shapes[image_key] != self.input_shapes[first_image_key]:
+                raise ValueError(
+                    f"`input_shapes[{image_key}]` does not match `input_shapes[{first_image_key}]`, but we "
+                    "expect all image shapes to match."
+                )
         supported_prediction_types = ["epsilon", "sample"]
        if self.prediction_type not in supported_prediction_types:
             raise ValueError(
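A sketch of what the new post-init checks reject (shapes made up for illustration):

    # Hypothetical: the crop check now runs per image key, and a mismatch between
    # image shapes raises ValueError, whereas the old code rejected any config with
    # more than one image key outright.
    input_shapes = {
        "observation.image.top": [3, 96, 96],
        "observation.image.wrist": [3, 84, 84],  # shape differs from the first key -> ValueError
    }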
@@ -18,7 +18,6 @@
 
 TODO(alexander-soare):
   - Remove reliance on diffusers for DDPMScheduler and LR scheduler.
-  - Make compatible with multiple image keys.
 """
 
 import math
@@ -83,20 +82,14 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
 
         self.diffusion = DiffusionModel(config)
 
-        image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
-        # Note: This check is covered in the post-init of the config but have a sanity check just in case.
-        if len(image_keys) != 1:
-            raise NotImplementedError(
-                f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}."
-            )
-        self.input_image_key = image_keys[0]
+        self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
 
         self.reset()
 
     def reset(self):
         """Clear observation and action queues. Should be called on `env.reset()`"""
         self._queues = {
-            "observation.image": deque(maxlen=self.config.n_obs_steps),
+            "observation.images": deque(maxlen=self.config.n_obs_steps),
             "observation.state": deque(maxlen=self.config.n_obs_steps),
             "action": deque(maxlen=self.config.n_action_steps),
         }
@@ -124,8 +117,8 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
         actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
         """
         batch = self.normalize_inputs(batch)
-        batch["observation.image"] = batch[self.input_image_key]
+        batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
+        # Note: It's important that this happens after stacking the images into a single key.
         self._queues = populate_queues(self._queues, batch)
 
         if len(self._queues["action"]) == 0:
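A minimal shape sketch of the new stacking step (camera key names are hypothetical):

    import torch

    b, c, h, w = 8, 3, 96, 96
    batch = {
        "observation.image.top": torch.zeros(b, c, h, w),
        "observation.image.wrist": torch.zeros(b, c, h, w),
    }
    expected_image_keys = ["observation.image.top", "observation.image.wrist"]
    # dim=-4 inserts the camera dim just before (C, H, W), so the same call works
    # once a sequence dim is present: (B, S, C, H, W) -> (B, S, N, C, H, W).
    images = torch.stack([batch[k] for k in expected_image_keys], dim=-4)
    print(images.shape)  # torch.Size([8, 2, 3, 96, 96])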
@@ -144,7 +137,7 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
     def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
         """Run the batch through the model and compute the loss for training or validation."""
         batch = self.normalize_inputs(batch)
-        batch["observation.image"] = batch[self.input_image_key]
+        batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
         batch = self.normalize_targets(batch)
         loss = self.diffusion.compute_loss(batch)
         return {"loss": loss}
@@ -169,9 +162,10 @@ class DiffusionModel(nn.Module):
         self.config = config
 
         self.rgb_encoder = DiffusionRgbEncoder(config)
+        num_images = len([k for k in config.input_shapes if k.startswith("observation.image")])
         self.unet = DiffusionConditionalUnet1d(
             config,
-            global_cond_dim=(config.output_shapes["action"][0] + self.rgb_encoder.feature_dim)
+            global_cond_dim=(config.output_shapes["action"][0] + self.rgb_encoder.feature_dim * num_images)
             * config.n_obs_steps,
         )
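To make the conditioning-dimension arithmetic concrete, a hedged example with made-up sizes:

    # Hypothetical sizes: action_dim=2, ResNet feature_dim=512, two cameras, n_obs_steps=2.
    action_dim, feature_dim, num_images, n_obs_steps = 2, 512, 2, 2
    global_cond_dim = (action_dim + feature_dim * num_images) * n_obs_steps
    print(global_cond_dim)  # (2 + 1024) * 2 = 2052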
@@ -220,23 +214,34 @@ class DiffusionModel(nn.Module):
 
         return sample
 
+    def _prepare_global_conditioning(self, batch: dict[str, Tensor]) -> Tensor:
+        """Encode image features and concatenate them all together along with the state vector."""
+        batch_size, n_obs_steps = batch["observation.state"].shape[:2]
+        # Extract image feature (first combine batch, sequence, and camera index dims).
+        img_features = self.rgb_encoder(
+            einops.rearrange(batch["observation.images"], "b s n ... -> (b s n) ...")
+        )
+        # Separate batch dim and sequence dim back out. The camera index dim gets absorbed into the feature
+        # dim (effectively concatenating the camera features).
+        img_features = einops.rearrange(
+            img_features, "(b s n) ... -> b s (n ...)", b=batch_size, s=n_obs_steps
+        )
+        # Concatenate state and image features then flatten to (B, global_cond_dim).
+        return torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1)
+
     def generate_actions(self, batch: dict[str, Tensor]) -> Tensor:
         """
         This function expects `batch` to have:
         {
             "observation.state": (B, n_obs_steps, state_dim)
-            "observation.image": (B, n_obs_steps, C, H, W)
+            "observation.images": (B, n_obs_steps, num_cameras, C, H, W)
         }
         """
         batch_size, n_obs_steps = batch["observation.state"].shape[:2]
         assert n_obs_steps == self.config.n_obs_steps
 
-        # Extract image feature (first combine batch and sequence dims).
-        img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ..."))
-        # Separate batch and sequence dims.
-        img_features = einops.rearrange(img_features, "(b n) ... -> b n ...", b=batch_size)
-        # Concatenate state and image features then flatten to (B, global_cond_dim).
-        global_cond = torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1)
+        # Encode image features and concatenate them all together along with the state vector.
+        global_cond = self._prepare_global_conditioning(batch)  # (B, global_cond_dim)
 
         # run sampling
         actions = self.conditional_sample(batch_size, global_cond=global_cond)
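A standalone sketch of the shape flow through the new `_prepare_global_conditioning` (all sizes hypothetical; a zeros tensor stands in for the ResNet encoder output):

    import einops
    import torch

    b, s, n = 8, 2, 2  # batch, n_obs_steps, num_cameras
    images = torch.zeros(b, s, n, 3, 96, 96)
    # Fold batch, sequence, and camera dims together before the shared encoder.
    flat = einops.rearrange(images, "b s n ... -> (b s n) ...")               # (32, 3, 96, 96)
    feats = torch.zeros(b * s * n, 512)                                       # encoder output stand-in
    # Unfold, absorbing the camera dim into the feature dim.
    feats = einops.rearrange(feats, "(b s n) ... -> b s (n ...)", b=b, s=s)   # (8, 2, 1024)
    state = torch.zeros(b, s, 2)                                              # hypothetical state_dim=2
    global_cond = torch.cat([state, feats], dim=-1).flatten(start_dim=1)      # (8, 2052)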
@@ -253,28 +258,23 @@ class DiffusionModel(nn.Module):
         This function expects `batch` to have (at least):
         {
             "observation.state": (B, n_obs_steps, state_dim)
-            "observation.image": (B, n_obs_steps, C, H, W)
+            "observation.images": (B, n_obs_steps, num_cameras, C, H, W)
             "action": (B, horizon, action_dim)
             "action_is_pad": (B, horizon)
         }
         """
         # Input validation.
-        assert set(batch).issuperset({"observation.state", "observation.image", "action", "action_is_pad"})
-        batch_size, n_obs_steps = batch["observation.state"].shape[:2]
+        assert set(batch).issuperset({"observation.state", "observation.images", "action", "action_is_pad"})
+        n_obs_steps = batch["observation.state"].shape[1]
         horizon = batch["action"].shape[1]
         assert horizon == self.config.horizon
         assert n_obs_steps == self.config.n_obs_steps
 
-        # Extract image feature (first combine batch and sequence dims).
-        img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ..."))
-        # Separate batch and sequence dims.
-        img_features = einops.rearrange(img_features, "(b n) ... -> b n ...", b=batch_size)
-        # Concatenate state and image features then flatten to (B, global_cond_dim).
-        global_cond = torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1)
-
-        trajectory = batch["action"]
+        # Encode image features and concatenate them all together along with the state vector.
+        global_cond = self._prepare_global_conditioning(batch)  # (B, global_cond_dim)
 
         # Forward diffusion.
+        trajectory = batch["action"]
         # Sample noise to add to the trajectory.
         eps = torch.randn(trajectory.shape, device=trajectory.device)
         # Sample a random noising timestep for each item in the batch.
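For context, the forward-diffusion step that follows this hunk typically pairs the sampled noise with random timesteps via diffusers' DDPMScheduler; a hedged sketch (scheduler settings and shapes are made up, not taken from this commit):

    import torch
    from diffusers import DDPMScheduler

    noise_scheduler = DDPMScheduler(num_train_timesteps=100)
    trajectory = torch.zeros(8, 16, 2)  # hypothetical (B, horizon, action_dim)
    eps = torch.randn(trajectory.shape)
    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (trajectory.shape[0],))
    noisy_trajectory = noise_scheduler.add_noise(trajectory, eps, timesteps)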
@@ -305,7 +305,8 @@ class DiffusionModel(nn.Module):
         if self.config.do_mask_loss_for_padding:
             if "action_is_pad" not in batch:
                 raise ValueError(
-                    f"You need to provide 'action_is_pad' in the batch when {self.config.do_mask_loss_for_padding=}."
+                    "You need to provide 'action_is_pad' in the batch when "
+                    f"{self.config.do_mask_loss_for_padding=}."
                 )
             in_episode_bound = ~batch["action_is_pad"]
             loss = loss * in_episode_bound.unsqueeze(-1)
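A minimal sketch of how this padding mask zeroes out loss terms (shapes hypothetical):

    import torch

    loss = torch.rand(8, 16, 2)                    # per-element loss (B, horizon, action_dim)
    action_is_pad = torch.zeros(8, 16, dtype=torch.bool)
    action_is_pad[:, 12:] = True                   # pretend the last 4 steps are padding
    in_episode_bound = ~action_is_pad
    loss = loss * in_episode_bound.unsqueeze(-1)   # padded steps contribute zero loss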
@@ -428,7 +429,7 @@ class DiffusionRgbEncoder(nn.Module):
         # use the height and width from `config.crop_shape` if it is provided, otherwise it should use the
         # height and width from `config.input_shapes`.
         image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
-        assert len(image_keys) == 1
+        # Note: we have a check in the config class to make sure all images have the same shape.
         image_key = image_keys[0]
         dummy_input_h_w = (
             config.crop_shape if config.crop_shape is not None else config.input_shapes[image_key][1:]
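For context, `dummy_input_h_w` is typically used to infer the encoder's output feature size with a throwaway forward pass; a hedged sketch of that pattern with a toy backbone (not the actual ResNet used here):

    import torch
    import torch.nn as nn

    backbone = nn.Sequential(nn.Conv2d(3, 8, 3, stride=2), nn.AdaptiveAvgPool2d(1), nn.Flatten())
    dummy_input = torch.zeros(1, 3, 84, 84)  # (1, C, *dummy_input_h_w); 84x84 is a made-up crop
    with torch.inference_mode():
        feature_dim = backbone(dummy_input).shape[-1]  # 8 for this toy backbone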