first draft
commit 8357dae26a (parent 26d9a070d8)
@@ -130,14 +130,21 @@ class DiffusionConfig:
             raise ValueError(
                 f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
             )
+        # There should only be one image key.
+        image_keys = {k for k in self.input_shapes if k.startswith("observation.image")}
+        if len(image_keys) != 1:
+            raise ValueError(
+                f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}."
+            )
+        image_key = next(iter(image_keys))
         if (
-            self.crop_shape[0] > self.input_shapes["observation.image"][1]
-            or self.crop_shape[1] > self.input_shapes["observation.image"][2]
+            self.crop_shape[0] > self.input_shapes[image_key][1]
+            or self.crop_shape[1] > self.input_shapes[image_key][2]
         ):
             raise ValueError(
-                f'`crop_shape` should fit within `input_shapes["observation.image"]`. Got {self.crop_shape} '
-                f'for `crop_shape` and {self.input_shapes["observation.image"]} for '
-                '`input_shapes["observation.image"]`.'
+                f"`crop_shape` should fit within `input_shapes[{image_key}]`. Got {self.crop_shape} "
+                f"for `crop_shape` and {self.input_shapes[image_key]} for "
+                f"`input_shapes[{image_key}]`."
             )
         supported_prediction_types = ["epsilon", "sample"]
         if self.prediction_type not in supported_prediction_types:
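For context, here is roughly how the new validation behaves; this is an illustrative sketch, not code from the commit, and the camera key name and shapes are made up:

    # Hypothetical input_shapes with one (arbitrarily named) camera stream.
    input_shapes = {
        "observation.image.cam_high": (3, 96, 96),
        "observation.state": (2,),
    }
    image_keys = {k for k in input_shapes if k.startswith("observation.image")}
    assert len(image_keys) == 1  # a second camera key would trigger the ValueError above
    image_key = next(iter(image_keys))  # -> "observation.image.cam_high"
    crop_shape = (84, 84)  # fits within (96, 96), so the crop check passes
    assert crop_shape[0] <= input_shapes[image_key][1] and crop_shape[1] <= input_shapes[image_key][2]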
@@ -67,15 +67,31 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
         self.diffusion = DiffusionModel(config)

     def reset(self):
-        """
-        Clear observation and action queues. Should be called on `env.reset()`
-        """
+        """Clear observation and action queues. Should be called on `env.reset()`"""
         self._queues = {
             "observation.image": deque(maxlen=self.config.n_obs_steps),
             "observation.state": deque(maxlen=self.config.n_obs_steps),
             "action": deque(maxlen=self.config.n_action_steps),
         }

+    def _preprocess_batch_keys(self, batch: dict[str, Tensor], train_mode: bool = False):
+        """Check that the keys can be handled by this policy and standardize the image key.
+
+        This should be run after input normalization.
+        """
+        assert "observation.state" in batch
+        # There should only be one image key.
+        image_keys = {k for k in batch if k.startswith("observation.image") and not k.endswith("_is_pad")}
+        assert (
+            len(image_keys) == 1
+        ), f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}."
+        if train_mode:
+            assert "action" in batch
+            assert "action_is_pad" in batch
+        image_key = next(iter(image_keys))
+        batch["observation.image"] = batch[image_key]
+        del batch[image_key]
+
     @torch.no_grad
     def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         """Select a single action given environment observations.
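A rough standalone sketch of what the new helper does to a batch; the tensors and the non-canonical key are made up, and the real method (added above) additionally checks `action`/`action_is_pad` in train mode:

    import torch

    batch = {
        "observation.state": torch.zeros(1, 2, 2),
        "observation.image.cam": torch.zeros(1, 2, 3, 96, 96),  # hypothetical key
    }
    # Core of _preprocess_batch_keys: find the single image key and rename it.
    image_keys = {k for k in batch if k.startswith("observation.image") and not k.endswith("_is_pad")}
    assert len(image_keys) == 1
    image_key = next(iter(image_keys))
    batch["observation.image"] = batch.pop(image_key)  # pop-then-assign is safe even for an already-canonical key
    assert set(batch) == {"observation.state", "observation.image"}

Note that the committed version assigns and then `del`s: if the incoming key is already exactly "observation.image", that order deletes the entry, whereas the `pop` variant above avoids the edge case.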
@@ -98,10 +114,8 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
         "horizon" may not be the best name to describe what the variable actually means, because this period is
         actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
         """
-        assert "observation.image" in batch
-        assert "observation.state" in batch
-
         batch = self.normalize_inputs(batch)
+        self._preprocess_batch_keys(batch)

         self._queues = populate_queues(self._queues, batch)
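The queues are what make `select_action` amortized: observations accumulate up to `n_obs_steps`, and whenever the action queue runs dry a single diffusion rollout refills it with `n_action_steps` actions that are then served one per environment step. A minimal sketch of that pattern with stand-in values:

    from collections import deque

    n_obs_steps, n_action_steps = 2, 8
    obs_queue = deque(maxlen=n_obs_steps)        # oldest observations fall off automatically
    action_queue = deque(maxlen=n_action_steps)

    for step in range(20):
        obs_queue.append(f"obs_{step}")          # stand-in for an observation tensor
        if not action_queue:
            # Stand-in for one diffusion rollout conditioned on the queued observations.
            action_queue.extend(f"act_{step}_{i}" for i in range(n_action_steps))
        action = action_queue.popleft()          # one action per environment step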
@@ -121,6 +135,7 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
     def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
         """Run the batch through the model and compute the loss for training or validation."""
         batch = self.normalize_inputs(batch)
+        self._preprocess_batch_keys(batch)
         batch = self.normalize_targets(batch)
         loss = self.diffusion.compute_loss(batch)
         return {"loss": loss}
@@ -185,13 +200,12 @@ class DiffusionModel(nn.Module):

     def generate_actions(self, batch: dict[str, Tensor]) -> Tensor:
         """
-        This function expects `batch` to have (at least):
+        This function expects `batch` to have:
         {
             "observation.state": (B, n_obs_steps, state_dim)
             "observation.image": (B, n_obs_steps, C, H, W)
         }
         """
-        assert set(batch).issuperset({"observation.state", "observation.image"})
         batch_size, n_obs_steps = batch["observation.state"].shape[:2]
         assert n_obs_steps == self.config.n_obs_steps

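A shape-only sketch of a batch meeting the documented contract (all dimensions illustrative):

    import torch

    B, n_obs_steps, state_dim = 4, 2, 7
    batch = {
        "observation.state": torch.zeros(B, n_obs_steps, state_dim),
        "observation.image": torch.zeros(B, n_obs_steps, 3, 96, 96),
    }
    batch_size, n_obs = batch["observation.state"].shape[:2]
    assert (batch_size, n_obs) == (B, n_obs_steps)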
@@ -315,9 +329,13 @@ class DiffusionRgbEncoder(nn.Module):

         # Set up pooling and final layers.
         # Use a dry run to get the feature map shape.
+        image_keys = {k for k in config.input_shapes if k.startswith("observation.image")}
+        assert len(image_keys) == 1
         with torch.inference_mode():
             feat_map_shape = tuple(
-                self.backbone(torch.zeros(size=(1, *config.input_shapes["observation.image"]))).shape[1:]
+                self.backbone(
+                    torch.zeros(size=(1, config.input_shapes[next(iter(image_keys))][0], *config.crop_shape))
+                ).shape[1:]
             )
         self.pool = SpatialSoftmax(feat_map_shape, num_kp=config.spatial_softmax_num_keypoints)
         self.feature_dim = config.spatial_softmax_num_keypoints * 2
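The dry run here is a standard trick: instead of hand-deriving the backbone's output resolution, push a correctly shaped dummy tensor through it once and read the feature-map shape off the result. The change also matters substantively: the dummy input now uses `config.crop_shape` rather than the raw input shape, since images are cropped before reaching the backbone. A self-contained sketch with a toy backbone (the real code uses a ResNet; shapes are illustrative):

    import torch
    from torch import nn

    # Toy stand-in for the vision backbone.
    backbone = nn.Sequential(nn.Conv2d(3, 16, kernel_size=3, stride=2), nn.ReLU())
    crop_shape = (84, 84)
    with torch.inference_mode():
        feat_map_shape = tuple(backbone(torch.zeros(1, 3, *crop_shape)).shape[1:])
    print(feat_map_shape)  # (16, 41, 41) for this toy backbone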
@@ -8,6 +8,7 @@ import hydra
 import torch
 from datasets import concatenate_datasets
 from datasets.utils import disable_progress_bars, enable_progress_bars
+from omegaconf import DictConfig

 from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.datasets.utils import cycle
@@ -290,7 +291,7 @@ def add_episodes_inplace(
     sampler.num_samples = len(concat_dataset)


-def train(cfg: dict, out_dir=None, job_name=None):
+def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
     if out_dir is None:
         raise NotImplementedError()
     if job_name is None:
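The signature change is typing hygiene rather than behavior: Hydra hands `train` an `omegaconf.DictConfig`, not a plain `dict`. A quick sketch of the distinction (assuming `omegaconf` is installed):

    from omegaconf import DictConfig, OmegaConf

    cfg = OmegaConf.create({"n_obs_steps": 2, "device": "cpu"})
    assert isinstance(cfg, DictConfig)  # what @hydra.main passes to train()
    assert not isinstance(cfg, dict)    # DictConfig is not a dict subclass
    print(cfg.n_obs_steps)              # attribute-style access, unlike a plain dict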