first draft

This commit is contained in:
Alexander Soare 2024-05-08 12:42:40 +01:00
parent 26d9a070d8
commit 8357dae26a
3 changed files with 41 additions and 15 deletions

View File

@ -130,14 +130,21 @@ class DiffusionConfig:
raise ValueError(
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
)
# There should only be one image key.
image_keys = {k for k in self.input_shapes if k.startswith("observation.image")}
if len(image_keys) != 1:
raise ValueError(
f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}."
)
image_key = next(iter(image_keys))
if (
self.crop_shape[0] > self.input_shapes["observation.image"][1]
or self.crop_shape[1] > self.input_shapes["observation.image"][2]
self.crop_shape[0] > self.input_shapes[image_key][1]
or self.crop_shape[1] > self.input_shapes[image_key][2]
):
raise ValueError(
f'`crop_shape` should fit within `input_shapes["observation.image"]`. Got {self.crop_shape} '
f'for `crop_shape` and {self.input_shapes["observation.image"]} for '
'`input_shapes["observation.image"]`.'
f"`crop_shape` should fit within `input_shapes[{image_key}]`. Got {self.crop_shape} "
f"for `crop_shape` and {self.input_shapes[image_key]} for "
"`input_shapes[{image_key}]`."
)
supported_prediction_types = ["epsilon", "sample"]
if self.prediction_type not in supported_prediction_types:

View File

@ -67,15 +67,31 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
self.diffusion = DiffusionModel(config)
def reset(self):
"""
Clear observation and action queues. Should be called on `env.reset()`
"""
"""Clear observation and action queues. Should be called on `env.reset()`"""
self._queues = {
"observation.image": deque(maxlen=self.config.n_obs_steps),
"observation.state": deque(maxlen=self.config.n_obs_steps),
"action": deque(maxlen=self.config.n_action_steps),
}
def _preprocess_batch_keys(self, batch: dict[str, Tensor], train_mode: bool = False):
"""Check that the keys can be handled by this policy and standardize the image key.
This should be run after input normalization.
"""
assert "observation.state" in batch
# There should only be one image key.
image_keys = {k for k in batch if k.startswith("observation.image") and not k.endswith("_is_pad")}
assert (
len(image_keys) == 1
), f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}."
if train_mode:
assert "action" in batch
assert "action_is_pad" in batch
image_key = next(iter(image_keys))
batch["observation.image"] = batch[image_key]
del batch[image_key]
@torch.no_grad
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a single action given environment observations.
@ -98,10 +114,8 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
"horizon" may not the best name to describe what the variable actually means, because this period is
actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
"""
assert "observation.image" in batch
assert "observation.state" in batch
batch = self.normalize_inputs(batch)
self._preprocess_batch_keys(batch)
self._queues = populate_queues(self._queues, batch)
@ -121,6 +135,7 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
"""Run the batch through the model and compute the loss for training or validation."""
batch = self.normalize_inputs(batch)
self._preprocess_batch_keys(batch)
batch = self.normalize_targets(batch)
loss = self.diffusion.compute_loss(batch)
return {"loss": loss}
@ -185,13 +200,12 @@ class DiffusionModel(nn.Module):
def generate_actions(self, batch: dict[str, Tensor]) -> Tensor:
"""
This function expects `batch` to have (at least):
This function expects `batch` to have:
{
"observation.state": (B, n_obs_steps, state_dim)
"observation.image": (B, n_obs_steps, C, H, W)
}
"""
assert set(batch).issuperset({"observation.state", "observation.image"})
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
assert n_obs_steps == self.config.n_obs_steps
@ -315,9 +329,13 @@ class DiffusionRgbEncoder(nn.Module):
# Set up pooling and final layers.
# Use a dry run to get the feature map shape.
image_keys = {k for k in config.input_shapes if k.startswith("observation.image")}
assert len(image_keys) == 1
with torch.inference_mode():
feat_map_shape = tuple(
self.backbone(torch.zeros(size=(1, *config.input_shapes["observation.image"]))).shape[1:]
self.backbone(
torch.zeros(size=(1, config.input_shapes[next(iter(image_keys))][0], *config.crop_shape))
).shape[1:]
)
self.pool = SpatialSoftmax(feat_map_shape, num_kp=config.spatial_softmax_num_keypoints)
self.feature_dim = config.spatial_softmax_num_keypoints * 2

View File

@ -8,6 +8,7 @@ import hydra
import torch
from datasets import concatenate_datasets
from datasets.utils import disable_progress_bars, enable_progress_bars
from omegaconf import DictConfig
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.utils import cycle
@ -290,7 +291,7 @@ def add_episodes_inplace(
sampler.num_samples = len(concat_dataset)
def train(cfg: dict, out_dir=None, job_name=None):
def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
if out_dir is None:
raise NotImplementedError()
if job_name is None: