From 3d625ae6d3f6b875a68f43a0b697cd7fbd3d0059 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Tue, 28 May 2024 18:27:33 +0100 Subject: [PATCH] Handle `crop_shape=None` in Diffusion Policy (#219) --- .../common/policies/diffusion/configuration_diffusion.py | 2 +- lerobot/common/policies/diffusion/modeling_diffusion.py | 8 ++++++-- poetry.lock | 4 +--- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/lerobot/common/policies/diffusion/configuration_diffusion.py b/lerobot/common/policies/diffusion/configuration_diffusion.py index 632f6cd6..81ff5de7 100644 --- a/lerobot/common/policies/diffusion/configuration_diffusion.py +++ b/lerobot/common/policies/diffusion/configuration_diffusion.py @@ -155,7 +155,7 @@ class DiffusionConfig: f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}." ) image_key = next(iter(image_keys)) - if ( + if self.crop_shape is not None and ( self.crop_shape[0] > self.input_shapes[image_key][1] or self.crop_shape[1] > self.input_shapes[image_key][2] ): diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py index 2bf45bb6..273f4f75 100644 --- a/lerobot/common/policies/diffusion/modeling_diffusion.py +++ b/lerobot/common/policies/diffusion/modeling_diffusion.py @@ -427,11 +427,15 @@ class DiffusionRgbEncoder(nn.Module): # Set up pooling and final layers. # Use a dry run to get the feature map shape. # The dummy input should take the number of image channels from `config.input_shapes` and it should - # use the height and width from `config.crop_shape`. + # use the height and width from `config.crop_shape` if it is provided, otherwise it should use the + # height and width from `config.input_shapes`. image_keys = [k for k in config.input_shapes if k.startswith("observation.image")] assert len(image_keys) == 1 image_key = image_keys[0] - dummy_input = torch.zeros(size=(1, config.input_shapes[image_key][0], *config.crop_shape)) + dummy_input_h_w = ( + config.crop_shape if config.crop_shape is not None else config.input_shapes[image_key][1:] + ) + dummy_input = torch.zeros(size=(1, config.input_shapes[image_key][0], *dummy_input_h_w)) with torch.inference_mode(): dummy_feature_map = self.backbone(dummy_input) feature_map_shape = tuple(dummy_feature_map.shape[1:]) diff --git a/poetry.lock b/poetry.lock index 2be04ee9..3a04e3d1 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. [[package]] name = "absl-py" @@ -2406,7 +2406,6 @@ optional = false python-versions = ">=3.9" files = [ {file = "pandas-2.2.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:90c6fca2acf139569e74e8781709dccb6fe25940488755716d1d354d6bc58bce"}, - {file = "pandas-2.2.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c7adfc142dac335d8c1e0dcbd37eb8617eac386596eb9e1a1b77791cf2498238"}, {file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4abfe0be0d7221be4f12552995e58723c7422c80a659da13ca382697de830c08"}, {file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8635c16bf3d99040fdf3ca3db669a7250ddf49c55dc4aa8fe0ae0fa8d6dcc1f0"}, {file = "pandas-2.2.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:40ae1dffb3967a52203105a077415a86044a2bea011b5f321c6aa64b379a3f51"}, @@ -2427,7 +2426,6 @@ files = [ {file = "pandas-2.2.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:43498c0bdb43d55cb162cdc8c06fac328ccb5d2eabe3cadeb3529ae6f0517c32"}, {file = "pandas-2.2.2-cp312-cp312-win_amd64.whl", hash = "sha256:d187d355ecec3629624fccb01d104da7d7f391db0311145817525281e2804d23"}, {file = "pandas-2.2.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:0ca6377b8fca51815f382bd0b697a0814c8bda55115678cbc94c30aacbb6eff2"}, - {file = "pandas-2.2.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9057e6aa78a584bc93a13f0a9bf7e753a5e9770a30b4d758b8d5f2a62a9433cd"}, {file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:001910ad31abc7bf06f49dcc903755d2f7f3a9186c0c040b827e522e9cef0863"}, {file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:66b479b0bd07204e37583c191535505410daa8df638fd8e75ae1b383851fe921"}, {file = "pandas-2.2.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:a77e9d1c386196879aa5eb712e77461aaee433e54c68cf253053a73b7e49c33a"},