first draft
This commit is contained in:
parent
26d9a070d8
commit
8357dae26a
|
@ -130,14 +130,21 @@ class DiffusionConfig:
|
|||
raise ValueError(
|
||||
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
|
||||
)
|
||||
# There should only be one image key.
|
||||
image_keys = {k for k in self.input_shapes if k.startswith("observation.image")}
|
||||
if len(image_keys) != 1:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}."
|
||||
)
|
||||
image_key = next(iter(image_keys))
|
||||
if (
|
||||
self.crop_shape[0] > self.input_shapes["observation.image"][1]
|
||||
or self.crop_shape[1] > self.input_shapes["observation.image"][2]
|
||||
self.crop_shape[0] > self.input_shapes[image_key][1]
|
||||
or self.crop_shape[1] > self.input_shapes[image_key][2]
|
||||
):
|
||||
raise ValueError(
|
||||
f'`crop_shape` should fit within `input_shapes["observation.image"]`. Got {self.crop_shape} '
|
||||
f'for `crop_shape` and {self.input_shapes["observation.image"]} for '
|
||||
'`input_shapes["observation.image"]`.'
|
||||
f"`crop_shape` should fit within `input_shapes[{image_key}]`. Got {self.crop_shape} "
|
||||
f"for `crop_shape` and {self.input_shapes[image_key]} for "
|
||||
"`input_shapes[{image_key}]`."
|
||||
)
|
||||
supported_prediction_types = ["epsilon", "sample"]
|
||||
if self.prediction_type not in supported_prediction_types:
|
||||
|
|
|
@ -67,15 +67,31 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
|
|||
self.diffusion = DiffusionModel(config)
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Clear observation and action queues. Should be called on `env.reset()`
|
||||
"""
|
||||
"""Clear observation and action queues. Should be called on `env.reset()`"""
|
||||
self._queues = {
|
||||
"observation.image": deque(maxlen=self.config.n_obs_steps),
|
||||
"observation.state": deque(maxlen=self.config.n_obs_steps),
|
||||
"action": deque(maxlen=self.config.n_action_steps),
|
||||
}
|
||||
|
||||
def _preprocess_batch_keys(self, batch: dict[str, Tensor], train_mode: bool = False):
|
||||
"""Check that the keys can be handled by this policy and standardize the image key.
|
||||
|
||||
This should be run after input normalization.
|
||||
"""
|
||||
assert "observation.state" in batch
|
||||
# There should only be one image key.
|
||||
image_keys = {k for k in batch if k.startswith("observation.image") and not k.endswith("_is_pad")}
|
||||
assert (
|
||||
len(image_keys) == 1
|
||||
), f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}."
|
||||
if train_mode:
|
||||
assert "action" in batch
|
||||
assert "action_is_pad" in batch
|
||||
image_key = next(iter(image_keys))
|
||||
batch["observation.image"] = batch[image_key]
|
||||
del batch[image_key]
|
||||
|
||||
@torch.no_grad
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Select a single action given environment observations.
|
||||
|
@ -98,10 +114,8 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
|
|||
"horizon" may not the best name to describe what the variable actually means, because this period is
|
||||
actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
|
||||
"""
|
||||
assert "observation.image" in batch
|
||||
assert "observation.state" in batch
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
self._preprocess_batch_keys(batch)
|
||||
|
||||
self._queues = populate_queues(self._queues, batch)
|
||||
|
||||
|
@ -121,6 +135,7 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
|
|||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
self._preprocess_batch_keys(batch)
|
||||
batch = self.normalize_targets(batch)
|
||||
loss = self.diffusion.compute_loss(batch)
|
||||
return {"loss": loss}
|
||||
|
@ -185,13 +200,12 @@ class DiffusionModel(nn.Module):
|
|||
|
||||
def generate_actions(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""
|
||||
This function expects `batch` to have (at least):
|
||||
This function expects `batch` to have:
|
||||
{
|
||||
"observation.state": (B, n_obs_steps, state_dim)
|
||||
"observation.image": (B, n_obs_steps, C, H, W)
|
||||
}
|
||||
"""
|
||||
assert set(batch).issuperset({"observation.state", "observation.image"})
|
||||
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
|
||||
assert n_obs_steps == self.config.n_obs_steps
|
||||
|
||||
|
@ -315,9 +329,13 @@ class DiffusionRgbEncoder(nn.Module):
|
|||
|
||||
# Set up pooling and final layers.
|
||||
# Use a dry run to get the feature map shape.
|
||||
image_keys = {k for k in config.input_shapes if k.startswith("observation.image")}
|
||||
assert len(image_keys) == 1
|
||||
with torch.inference_mode():
|
||||
feat_map_shape = tuple(
|
||||
self.backbone(torch.zeros(size=(1, *config.input_shapes["observation.image"]))).shape[1:]
|
||||
self.backbone(
|
||||
torch.zeros(size=(1, config.input_shapes[next(iter(image_keys))][0], *config.crop_shape))
|
||||
).shape[1:]
|
||||
)
|
||||
self.pool = SpatialSoftmax(feat_map_shape, num_kp=config.spatial_softmax_num_keypoints)
|
||||
self.feature_dim = config.spatial_softmax_num_keypoints * 2
|
||||
|
|
|
@ -8,6 +8,7 @@ import hydra
|
|||
import torch
|
||||
from datasets import concatenate_datasets
|
||||
from datasets.utils import disable_progress_bars, enable_progress_bars
|
||||
from omegaconf import DictConfig
|
||||
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.datasets.utils import cycle
|
||||
|
@ -290,7 +291,7 @@ def add_episodes_inplace(
|
|||
sampler.num_samples = len(concat_dataset)
|
||||
|
||||
|
||||
def train(cfg: dict, out_dir=None, job_name=None):
|
||||
def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
|
||||
if out_dir is None:
|
||||
raise NotImplementedError()
|
||||
if job_name is None:
|
||||
|
|
Loading…
Reference in New Issue