first draft

Alexander Soare 2024-05-08 12:42:40 +01:00
parent 26d9a070d8
commit 8357dae26a
3 changed files with 41 additions and 15 deletions

lerobot/common/policies/diffusion/configuration_diffusion.py

@@ -130,14 +130,21 @@ class DiffusionConfig:
             raise ValueError(
                 f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
             )
+        # There should only be one image key.
+        image_keys = {k for k in self.input_shapes if k.startswith("observation.image")}
+        if len(image_keys) != 1:
+            raise ValueError(
+                f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}."
+            )
+        image_key = next(iter(image_keys))
         if (
-            self.crop_shape[0] > self.input_shapes["observation.image"][1]
-            or self.crop_shape[1] > self.input_shapes["observation.image"][2]
+            self.crop_shape[0] > self.input_shapes[image_key][1]
+            or self.crop_shape[1] > self.input_shapes[image_key][2]
         ):
             raise ValueError(
-                f'`crop_shape` should fit within `input_shapes["observation.image"]`. Got {self.crop_shape} '
-                f'for `crop_shape` and {self.input_shapes["observation.image"]} for '
-                '`input_shapes["observation.image"]`.'
+                f"`crop_shape` should fit within `input_shapes[{image_key}]`. Got {self.crop_shape} "
+                f"for `crop_shape` and {self.input_shapes[image_key]} for "
+                f"`input_shapes[{image_key}]`."
             )
         supported_prediction_types = ["epsilon", "sample"]
         if self.prediction_type not in supported_prediction_types:
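
With this change the config itself enforces the single-image assumption: any one `observation.image*` key is accepted, and the crop-shape check uses whichever key is present. A minimal sketch of the new behavior, assuming `DiffusionConfig` is importable and that its other fields keep workable defaults (the key names and shapes below are illustrative, not from the commit):

from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig

# One image key, arbitrarily named: passes validation.
DiffusionConfig(input_shapes={"observation.images.top": [3, 96, 96], "observation.state": [2]})

# Two image keys: __post_init__ now raises.
try:
    DiffusionConfig(
        input_shapes={
            "observation.images.top": [3, 96, 96],
            "observation.images.wrist": [3, 96, 96],
            "observation.state": [2],
        }
    )
except ValueError as err:
    print(err)  # "DiffusionConfig only handles one image for now. Got image keys {...}"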

lerobot/common/policies/diffusion/modeling_diffusion.py

@@ -67,15 +67,31 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
         self.diffusion = DiffusionModel(config)

     def reset(self):
-        """
-        Clear observation and action queues. Should be called on `env.reset()`
-        """
+        """Clear observation and action queues. Should be called on `env.reset()`"""
         self._queues = {
             "observation.image": deque(maxlen=self.config.n_obs_steps),
             "observation.state": deque(maxlen=self.config.n_obs_steps),
             "action": deque(maxlen=self.config.n_action_steps),
         }

+    def _preprocess_batch_keys(self, batch: dict[str, Tensor], train_mode: bool = False):
+        """Check that the keys can be handled by this policy and standardize the image key.
+
+        This should be run after input normalization.
+        """
+        assert "observation.state" in batch
+        # There should only be one image key.
+        image_keys = {k for k in batch if k.startswith("observation.image") and not k.endswith("_is_pad")}
+        assert (
+            len(image_keys) == 1
+        ), f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}."
+        if train_mode:
+            assert "action" in batch
+            assert "action_is_pad" in batch
+        image_key = next(iter(image_keys))
+        batch["observation.image"] = batch[image_key]
+        del batch[image_key]
+
     @torch.no_grad
     def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         """Select a single action given environment observations.
@@ -98,10 +114,8 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
         "horizon" may not be the best name to describe what the variable actually means, because this period is
         actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
         """
-        assert "observation.image" in batch
-        assert "observation.state" in batch
-
         batch = self.normalize_inputs(batch)
+        self._preprocess_batch_keys(batch)

         self._queues = populate_queues(self._queues, batch)
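
Since key standardization now happens inside `select_action` right after input normalization, inference callers may use any single image key without renaming it first. A hedged call-site sketch, where `policy`, `state`, and `image` are hypothetical and the key name is illustrative:

batch = {
    "observation.state": state,       # (B, state_dim) tensor, hypothetical
    "observation.images.top": image,  # (B, C, H, W) tensor; renamed internally
}
action = policy.select_action(batch)  # policy: a constructed DiffusionPolicy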
@@ -121,6 +135,7 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
     def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
         """Run the batch through the model and compute the loss for training or validation."""
         batch = self.normalize_inputs(batch)
+        self._preprocess_batch_keys(batch)
         batch = self.normalize_targets(batch)
         loss = self.diffusion.compute_loss(batch)
         return {"loss": loss}
@@ -185,13 +200,12 @@ class DiffusionModel(nn.Module):
     def generate_actions(self, batch: dict[str, Tensor]) -> Tensor:
         """
-        This function expects `batch` to have (at least):
+        This function expects `batch` to have:
         {
             "observation.state": (B, n_obs_steps, state_dim)
             "observation.image": (B, n_obs_steps, C, H, W)
         }
         """
-        assert set(batch).issuperset({"observation.state", "observation.image"})
         batch_size, n_obs_steps = batch["observation.state"].shape[:2]
         assert n_obs_steps == self.config.n_obs_steps
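
The docstring now states the exact batch contract rather than a lower bound, since the superset assert moved upstream into `_preprocess_batch_keys`. The contract rehearsed with dummy tensors (all sizes illustrative):

import torch

B, n_obs_steps, state_dim = 4, 2, 2
batch = {
    "observation.state": torch.zeros(B, n_obs_steps, state_dim),
    "observation.image": torch.zeros(B, n_obs_steps, 3, 96, 96),
}
# generate_actions reads the batch size and observation horizon off the state tensor:
batch_size, n = batch["observation.state"].shape[:2]
assert (batch_size, n) == (B, n_obs_steps)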
@@ -315,9 +329,13 @@ class DiffusionRgbEncoder(nn.Module):
         # Set up pooling and final layers.
         # Use a dry run to get the feature map shape.
+        image_keys = {k for k in config.input_shapes if k.startswith("observation.image")}
+        assert len(image_keys) == 1
         with torch.inference_mode():
             feat_map_shape = tuple(
-                self.backbone(torch.zeros(size=(1, *config.input_shapes["observation.image"]))).shape[1:]
+                self.backbone(
+                    torch.zeros(size=(1, config.input_shapes[next(iter(image_keys))][0], *config.crop_shape))
+                ).shape[1:]
             )
         self.pool = SpatialSoftmax(feat_map_shape, num_kp=config.spatial_softmax_num_keypoints)
         self.feature_dim = config.spatial_softmax_num_keypoints * 2
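
Two things change here: the channel count is read from whichever image key the config declares, and the dummy input now uses `crop_shape`, which is what the backbone actually sees after cropping. The dry-run trick itself, in a self-contained form (torchvision's resnet18 stands in for the configured backbone; sizes are illustrative):

import torch
import torchvision

# Backbone with its average-pool and classifier stripped, as DiffusionRgbEncoder does.
backbone = torch.nn.Sequential(*list(torchvision.models.resnet18().children())[:-2])
crop_shape = (84, 84)  # the post-crop size the encoder actually sees
with torch.inference_mode():
    feat_map_shape = tuple(backbone(torch.zeros(1, 3, *crop_shape)).shape[1:])
print(feat_map_shape)  # (512, 3, 3) for an 84x84 ResNet-18 input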

lerobot/scripts/train.py

@@ -8,6 +8,7 @@ import hydra
 import torch
 from datasets import concatenate_datasets
 from datasets.utils import disable_progress_bars, enable_progress_bars
+from omegaconf import DictConfig

 from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.datasets.utils import cycle
@@ -290,7 +291,7 @@ def add_episodes_inplace(
     sampler.num_samples = len(concat_dataset)


-def train(cfg: dict, out_dir=None, job_name=None):
+def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
     if out_dir is None:
         raise NotImplementedError()
     if job_name is None:
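
The annotation change reflects how `train` is actually invoked: Hydra hands the entry point an `omegaconf.DictConfig`, not a plain `dict`. A hedged sketch of such a caller, where the decorator arguments and the output-dir lookup are illustrative rather than copied from the repo:

import hydra
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig


@hydra.main(version_base="1.2", config_path="../configs", config_name="default")
def train_cli(cfg: DictConfig):
    # Hydra passes a DictConfig here, which is why `cfg: dict` was misleading.
    train(cfg, out_dir=HydraConfig.get().run.dir, job_name=HydraConfig.get().job.name)


if __name__ == "__main__":
    train_cli()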