From c9069df9f1e09a98f193eacc7241adead2d10553 Mon Sep 17 00:00:00 2001 From: Akshay Kashyap Date: Thu, 16 May 2024 10:34:10 -0400 Subject: [PATCH] Port SpatialSoftmax and remove Robomimic dependency (#182) Co-authored-by: Alexander Soare --- .../policies/diffusion/modeling_diffusion.py | 74 +++++++- poetry.lock | 179 +----------------- pyproject.toml | 1 - .../pusht_diffusion/actions.safetensors | Bin 4600 -> 4600 bytes .../pusht_diffusion/grad_stats.safetensors | Bin 47424 -> 47424 bytes .../pusht_diffusion/output_dict.safetensors | Bin 68 -> 68 bytes .../pusht_diffusion/param_stats.safetensors | Bin 49120 -> 49120 bytes 7 files changed, 75 insertions(+), 179 deletions(-) diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py index 1659b68e..2ae03f22 100644 --- a/lerobot/common/policies/diffusion/modeling_diffusion.py +++ b/lerobot/common/policies/diffusion/modeling_diffusion.py @@ -17,7 +17,6 @@ """Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion" TODO(alexander-soare): - - Remove reliance on Robomimic for SpatialSoftmax. - Remove reliance on diffusers for DDPMScheduler and LR scheduler. - Make compatible with multiple image keys. """ @@ -27,13 +26,13 @@ from collections import deque from typing import Callable import einops +import numpy as np import torch import torch.nn.functional as F # noqa: N812 import torchvision from diffusers.schedulers.scheduling_ddim import DDIMScheduler from diffusers.schedulers.scheduling_ddpm import DDPMScheduler from huggingface_hub import PyTorchModelHubMixin -from robomimic.models.base_nets import SpatialSoftmax from torch import Tensor, nn from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig @@ -312,6 +311,77 @@ class DiffusionModel(nn.Module): return loss.mean() +class SpatialSoftmax(nn.Module): + """ + Spatial Soft Argmax operation described in "Deep Spatial Autoencoders for Visuomotor Learning" by Finn et al. + (https://arxiv.org/pdf/1509.06113). A minimal port of the robomimic implementation. + + At a high level, this takes 2D feature maps (from a convnet/ViT) and returns the "center of mass" + of activations of each channel, i.e., keypoints in the image space for the policy to focus on. + + Example: take feature maps of size (512x10x12). We generate a grid of normalized coordinates (10x12x2): + ----------------------------------------------------- + | (-1., -1.) | (-0.82, -1.) | ... | (1., -1.) | + | (-1., -0.78) | (-0.82, -0.78) | ... | (1., -0.78) | + | ... | ... | ... | ... | + | (-1., 1.) | (-0.82, 1.) | ... | (1., 1.) | + ----------------------------------------------------- + This is achieved by applying channel-wise softmax over the activations (512x120) and computing the dot + product with the coordinates (120x2) to get expected points of maximal activation (512x2). + + The example above results in 512 keypoints (corresponding to the 512 input channels). We can optionally + provide num_kp != None to control the number of keypoints. This is achieved by a first applying a learnable + linear mapping (in_channels, H, W) -> (num_kp, H, W). + """ + + def __init__(self, input_shape, num_kp=None): + """ + Args: + input_shape (list): (C, H, W) input feature map shape. + num_kp (int): number of keypoints in output. If None, output will have the same number of channels as input. + """ + super().__init__() + + assert len(input_shape) == 3 + self._in_c, self._in_h, self._in_w = input_shape + + if num_kp is not None: + self.nets = torch.nn.Conv2d(self._in_c, num_kp, kernel_size=1) + self._out_c = num_kp + else: + self.nets = None + self._out_c = self._in_c + + # we could use torch.linspace directly but that seems to behave slightly differently than numpy + # and causes a small degradation in pc_success of pre-trained models. + pos_x, pos_y = np.meshgrid(np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h)) + pos_x = torch.from_numpy(pos_x.reshape(self._in_h * self._in_w, 1)).float() + pos_y = torch.from_numpy(pos_y.reshape(self._in_h * self._in_w, 1)).float() + # register as buffer so it's moved to the correct device. + self.register_buffer("pos_grid", torch.cat([pos_x, pos_y], dim=1)) + + def forward(self, features: Tensor) -> Tensor: + """ + Args: + features: (B, C, H, W) input feature maps. + Returns: + (B, K, 2) image-space coordinates of keypoints. + """ + if self.nets is not None: + features = self.nets(features) + + # [B, K, H, W] -> [B * K, H * W] where K is number of keypoints + features = features.reshape(-1, self._in_h * self._in_w) + # 2d softmax normalization + attention = F.softmax(features, dim=-1) + # [B * K, H * W] x [H * W, 2] -> [B * K, 2] for spatial coordinate mean in x and y dimensions + expected_xy = attention @ self.pos_grid + # reshape to [B, K, 2] + feature_keypoints = expected_xy.view(-1, self._out_c, 2) + + return feature_keypoints + + class DiffusionRgbEncoder(nn.Module): """Encoder an RGB image into a 1D feature vector. diff --git a/poetry.lock b/poetry.lock index 388e03f4..e0b27f15 100644 --- a/poetry.lock +++ b/poetry.lock @@ -4,7 +4,7 @@ name = "absl-py" version = "2.1.0" description = "Abseil Python Common Libraries, see https://github.com/abseil/abseil-py." -optional = false +optional = true python-versions = ">=3.7" files = [ {file = "absl-py-2.1.0.tar.gz", hash = "sha256:7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff"}, @@ -767,16 +767,6 @@ files = [ [package.dependencies] six = ">=1.4.0" -[[package]] -name = "egl-probe" -version = "1.0.2" -description = "" -optional = false -python-versions = "*" -files = [ - {file = "egl_probe-1.0.2.tar.gz", hash = "sha256:29bdca7b08da1e060cfb42cd46af8300a7ac4f3b1b2eeb16e545ea16d9a5ac93"}, -] - [[package]] name = "einops" version = "0.8.0" @@ -1037,64 +1027,6 @@ files = [ [package.extras] preview = ["glfw-preview"] -[[package]] -name = "grpcio" -version = "1.63.0" -description = "HTTP/2-based RPC framework" -optional = false -python-versions = ">=3.8" -files = [ - {file = "grpcio-1.63.0-cp310-cp310-linux_armv7l.whl", hash = "sha256:2e93aca840c29d4ab5db93f94ed0a0ca899e241f2e8aec6334ab3575dc46125c"}, - {file = "grpcio-1.63.0-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:91b73d3f1340fefa1e1716c8c1ec9930c676d6b10a3513ab6c26004cb02d8b3f"}, - {file = "grpcio-1.63.0-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:b3afbd9d6827fa6f475a4f91db55e441113f6d3eb9b7ebb8fb806e5bb6d6bd0d"}, - {file = "grpcio-1.63.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8f3f6883ce54a7a5f47db43289a0a4c776487912de1a0e2cc83fdaec9685cc9f"}, - {file = "grpcio-1.63.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cf8dae9cc0412cb86c8de5a8f3be395c5119a370f3ce2e69c8b7d46bb9872c8d"}, - {file = "grpcio-1.63.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:08e1559fd3b3b4468486b26b0af64a3904a8dbc78d8d936af9c1cf9636eb3e8b"}, - {file = "grpcio-1.63.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:5c039ef01516039fa39da8a8a43a95b64e288f79f42a17e6c2904a02a319b357"}, - {file = "grpcio-1.63.0-cp310-cp310-win32.whl", hash = "sha256:ad2ac8903b2eae071055a927ef74121ed52d69468e91d9bcbd028bd0e554be6d"}, - {file = "grpcio-1.63.0-cp310-cp310-win_amd64.whl", hash = "sha256:b2e44f59316716532a993ca2966636df6fbe7be4ab6f099de6815570ebe4383a"}, - {file = "grpcio-1.63.0-cp311-cp311-linux_armv7l.whl", hash = "sha256:f28f8b2db7b86c77916829d64ab21ff49a9d8289ea1564a2b2a3a8ed9ffcccd3"}, - {file = "grpcio-1.63.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:65bf975639a1f93bee63ca60d2e4951f1b543f498d581869922910a476ead2f5"}, - {file = "grpcio-1.63.0-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:b5194775fec7dc3dbd6a935102bb156cd2c35efe1685b0a46c67b927c74f0cfb"}, - {file = "grpcio-1.63.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e4cbb2100ee46d024c45920d16e888ee5d3cf47c66e316210bc236d5bebc42b3"}, - {file = "grpcio-1.63.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1ff737cf29b5b801619f10e59b581869e32f400159e8b12d7a97e7e3bdeee6a2"}, - {file = "grpcio-1.63.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:cd1e68776262dd44dedd7381b1a0ad09d9930ffb405f737d64f505eb7f77d6c7"}, - {file = "grpcio-1.63.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:93f45f27f516548e23e4ec3fbab21b060416007dbe768a111fc4611464cc773f"}, - {file = "grpcio-1.63.0-cp311-cp311-win32.whl", hash = "sha256:878b1d88d0137df60e6b09b74cdb73db123f9579232c8456f53e9abc4f62eb3c"}, - {file = "grpcio-1.63.0-cp311-cp311-win_amd64.whl", hash = "sha256:756fed02dacd24e8f488f295a913f250b56b98fb793f41d5b2de6c44fb762434"}, - {file = "grpcio-1.63.0-cp312-cp312-linux_armv7l.whl", hash = "sha256:93a46794cc96c3a674cdfb59ef9ce84d46185fe9421baf2268ccb556f8f81f57"}, - {file = "grpcio-1.63.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:a7b19dfc74d0be7032ca1eda0ed545e582ee46cd65c162f9e9fc6b26ef827dc6"}, - {file = "grpcio-1.63.0-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:8064d986d3a64ba21e498b9a376cbc5d6ab2e8ab0e288d39f266f0fca169b90d"}, - {file = "grpcio-1.63.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:219bb1848cd2c90348c79ed0a6b0ea51866bc7e72fa6e205e459fedab5770172"}, - {file = "grpcio-1.63.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a2d60cd1d58817bc5985fae6168d8b5655c4981d448d0f5b6194bbcc038090d2"}, - {file = "grpcio-1.63.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:9e350cb096e5c67832e9b6e018cf8a0d2a53b2a958f6251615173165269a91b0"}, - {file = "grpcio-1.63.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:56cdf96ff82e3cc90dbe8bac260352993f23e8e256e063c327b6cf9c88daf7a9"}, - {file = "grpcio-1.63.0-cp312-cp312-win32.whl", hash = "sha256:3a6d1f9ea965e750db7b4ee6f9fdef5fdf135abe8a249e75d84b0a3e0c668a1b"}, - {file = "grpcio-1.63.0-cp312-cp312-win_amd64.whl", hash = "sha256:d2497769895bb03efe3187fb1888fc20e98a5f18b3d14b606167dacda5789434"}, - {file = "grpcio-1.63.0-cp38-cp38-linux_armv7l.whl", hash = "sha256:fdf348ae69c6ff484402cfdb14e18c1b0054ac2420079d575c53a60b9b2853ae"}, - {file = "grpcio-1.63.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:a3abfe0b0f6798dedd2e9e92e881d9acd0fdb62ae27dcbbfa7654a57e24060c0"}, - {file = "grpcio-1.63.0-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:6ef0ad92873672a2a3767cb827b64741c363ebaa27e7f21659e4e31f4d750280"}, - {file = "grpcio-1.63.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b416252ac5588d9dfb8a30a191451adbf534e9ce5f56bb02cd193f12d8845b7f"}, - {file = "grpcio-1.63.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e3b77eaefc74d7eb861d3ffbdf91b50a1bb1639514ebe764c47773b833fa2d91"}, - {file = "grpcio-1.63.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:b005292369d9c1f80bf70c1db1c17c6c342da7576f1c689e8eee4fb0c256af85"}, - {file = "grpcio-1.63.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:cdcda1156dcc41e042d1e899ba1f5c2e9f3cd7625b3d6ebfa619806a4c1aadda"}, - {file = "grpcio-1.63.0-cp38-cp38-win32.whl", hash = "sha256:01799e8649f9e94ba7db1aeb3452188048b0019dc37696b0f5ce212c87c560c3"}, - {file = "grpcio-1.63.0-cp38-cp38-win_amd64.whl", hash = "sha256:6a1a3642d76f887aa4009d92f71eb37809abceb3b7b5a1eec9c554a246f20e3a"}, - {file = "grpcio-1.63.0-cp39-cp39-linux_armv7l.whl", hash = "sha256:75f701ff645858a2b16bc8c9fc68af215a8bb2d5a9b647448129de6e85d52bce"}, - {file = "grpcio-1.63.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:cacdef0348a08e475a721967f48206a2254a1b26ee7637638d9e081761a5ba86"}, - {file = "grpcio-1.63.0-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:0697563d1d84d6985e40ec5ec596ff41b52abb3fd91ec240e8cb44a63b895094"}, - {file = "grpcio-1.63.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6426e1fb92d006e47476d42b8f240c1d916a6d4423c5258ccc5b105e43438f61"}, - {file = "grpcio-1.63.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e48cee31bc5f5a31fb2f3b573764bd563aaa5472342860edcc7039525b53e46a"}, - {file = "grpcio-1.63.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:50344663068041b34a992c19c600236e7abb42d6ec32567916b87b4c8b8833b3"}, - {file = "grpcio-1.63.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:259e11932230d70ef24a21b9fb5bb947eb4703f57865a404054400ee92f42f5d"}, - {file = "grpcio-1.63.0-cp39-cp39-win32.whl", hash = "sha256:a44624aad77bf8ca198c55af811fd28f2b3eaf0a50ec5b57b06c034416ef2d0a"}, - {file = "grpcio-1.63.0-cp39-cp39-win_amd64.whl", hash = "sha256:166e5c460e5d7d4656ff9e63b13e1f6029b122104c1633d5f37eaea348d7356d"}, - {file = "grpcio-1.63.0.tar.gz", hash = "sha256:f3023e14805c61bc439fb40ca545ac3d5740ce66120a678a3c6c2c55b70343d1"}, -] - -[package.extras] -protobuf = ["grpcio-tools (>=1.63.0)"] - [[package]] name = "gym-aloha" version = "0.1.1" @@ -1668,7 +1600,6 @@ files = [ {file = "lxml-5.2.1-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:9e2addd2d1866fe112bc6f80117bcc6bc25191c5ed1bfbcf9f1386a884252ae8"}, {file = "lxml-5.2.1-cp37-cp37m-win32.whl", hash = "sha256:f51969bac61441fd31f028d7b3b45962f3ecebf691a510495e5d2cd8c8092dbd"}, {file = "lxml-5.2.1-cp37-cp37m-win_amd64.whl", hash = "sha256:b0b58fbfa1bf7367dde8a557994e3b1637294be6cf2169810375caf8571a085c"}, - {file = "lxml-5.2.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:3e183c6e3298a2ed5af9d7a356ea823bccaab4ec2349dc9ed83999fd289d14d5"}, {file = "lxml-5.2.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:804f74efe22b6a227306dd890eecc4f8c59ff25ca35f1f14e7482bbce96ef10b"}, {file = "lxml-5.2.1-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:08802f0c56ed150cc6885ae0788a321b73505d2263ee56dad84d200cab11c07a"}, {file = "lxml-5.2.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0f8c09ed18ecb4ebf23e02b8e7a22a05d6411911e6fabef3a36e4f371f4f2585"}, @@ -1740,21 +1671,6 @@ html5 = ["html5lib"] htmlsoup = ["BeautifulSoup4"] source = ["Cython (>=3.0.10)"] -[[package]] -name = "markdown" -version = "3.6" -description = "Python implementation of John Gruber's Markdown." -optional = false -python-versions = ">=3.8" -files = [ - {file = "Markdown-3.6-py3-none-any.whl", hash = "sha256:48f276f4d8cfb8ce6527c8f79e2ee29708508bf4d40aa410fbc3b4ee832c850f"}, - {file = "Markdown-3.6.tar.gz", hash = "sha256:ed4f41f6daecbeeb96e576ce414c41d2d876daa9a16cb35fa8ed8c2ddfad0224"}, -] - -[package.extras] -docs = ["mdx-gh-links (>=0.2)", "mkdocs (>=1.5)", "mkdocs-gen-files", "mkdocs-literate-nav", "mkdocs-nature (>=0.6)", "mkdocs-section-index", "mkdocstrings[python]"] -testing = ["coverage", "pyyaml"] - [[package]] name = "markupsafe" version = "2.1.5" @@ -3056,6 +2972,7 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -3224,30 +3141,6 @@ typing-extensions = ">=4.5" [package.extras] tests = ["pytest (==7.1.2)"] -[[package]] -name = "robomimic" -version = "0.2.0" -description = "robomimic: A Modular Framework for Robot Learning from Demonstration" -optional = false -python-versions = ">=3" -files = [ - {file = "robomimic-0.2.0.tar.gz", hash = "sha256:ee3bb5cf9c3e1feead6b57b43c5db738fd0a8e0c015fdf6419808af8fffdc463"}, -] - -[package.dependencies] -egl_probe = ">=1.0.1" -h5py = "*" -imageio = "*" -imageio-ffmpeg = "*" -numpy = ">=1.13.3" -psutil = "*" -tensorboard = "*" -tensorboardX = "*" -termcolor = "*" -torch = "*" -torchvision = "*" -tqdm = "*" - [[package]] name = "safetensors" version = "0.4.3" @@ -3738,55 +3631,6 @@ files = [ {file = "tbb-2021.12.0-py3-none-win_amd64.whl", hash = "sha256:fc2772d850229f2f3df85f1109c4844c495a2db7433d38200959ee9265b34789"}, ] -[[package]] -name = "tensorboard" -version = "2.16.2" -description = "TensorBoard lets you watch Tensors Flow" -optional = false -python-versions = ">=3.9" -files = [ - {file = "tensorboard-2.16.2-py3-none-any.whl", hash = "sha256:9f2b4e7dad86667615c0e5cd072f1ea8403fc032a299f0072d6f74855775cc45"}, -] - -[package.dependencies] -absl-py = ">=0.4" -grpcio = ">=1.48.2" -markdown = ">=2.6.8" -numpy = ">=1.12.0" -protobuf = ">=3.19.6,<4.24.0 || >4.24.0" -setuptools = ">=41.0.0" -six = ">1.9" -tensorboard-data-server = ">=0.7.0,<0.8.0" -werkzeug = ">=1.0.1" - -[[package]] -name = "tensorboard-data-server" -version = "0.7.2" -description = "Fast data loading for TensorBoard" -optional = false -python-versions = ">=3.7" -files = [ - {file = "tensorboard_data_server-0.7.2-py3-none-any.whl", hash = "sha256:7e0610d205889588983836ec05dc098e80f97b7e7bbff7e994ebb78f578d0ddb"}, - {file = "tensorboard_data_server-0.7.2-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:9fe5d24221b29625dbc7328b0436ca7fc1c23de4acf4d272f1180856e32f9f60"}, - {file = "tensorboard_data_server-0.7.2-py3-none-manylinux_2_31_x86_64.whl", hash = "sha256:ef687163c24185ae9754ed5650eb5bc4d84ff257aabdc33f0cc6f74d8ba54530"}, -] - -[[package]] -name = "tensorboardx" -version = "2.6.2.2" -description = "TensorBoardX lets you watch Tensors Flow without Tensorflow" -optional = false -python-versions = "*" -files = [ - {file = "tensorboardX-2.6.2.2-py2.py3-none-any.whl", hash = "sha256:160025acbf759ede23fd3526ae9d9bfbfd8b68eb16c38a010ebe326dc6395db8"}, - {file = "tensorboardX-2.6.2.2.tar.gz", hash = "sha256:c6476d7cd0d529b0b72f4acadb1269f9ed8b22f441e87a84f2a3b940bb87b666"}, -] - -[package.dependencies] -numpy = "*" -packaging = "*" -protobuf = ">=3.20" - [[package]] name = "termcolor" version = "2.4.0" @@ -4064,23 +3908,6 @@ perf = ["orjson"] reports = ["pydantic (>=2.0.0)"] sweeps = ["sweeps (>=0.2.0)"] -[[package]] -name = "werkzeug" -version = "3.0.3" -description = "The comprehensive WSGI web application library." -optional = false -python-versions = ">=3.8" -files = [ - {file = "werkzeug-3.0.3-py3-none-any.whl", hash = "sha256:fc9645dc43e03e4d630d23143a04a7f947a9a3b5727cd535fdfe155a17cc48c8"}, - {file = "werkzeug-3.0.3.tar.gz", hash = "sha256:097e5bfda9f0aba8da6b8545146def481d06aa7d3266e7448e2cccf67dd8bd18"}, -] - -[package.dependencies] -MarkupSafe = ">=2.1.1" - -[package.extras] -watchdog = ["watchdog (>=2.3)"] - [[package]] name = "xxhash" version = "3.4.1" @@ -4348,4 +4175,4 @@ xarm = ["gym-xarm"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "2f0d2cbf4a2dec546e25b29b9b108ff1f97b4c278b718360b3f7f6a2bf9dcef8" +content-hash = "e3e3c306a5519e4f716a1ac086ad9b734efedcac077a0ec71e5bc16349a1e559" diff --git a/pyproject.toml b/pyproject.toml index 24d9452d..5b80d06f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,7 +44,6 @@ diffusers = "^0.27.2" torchvision = ">=0.18.0" h5py = ">=3.10.0" huggingface-hub = ">=0.21.4" -robomimic = "0.2.0" gymnasium = ">=0.29.1" cmake = ">=3.29.0.1" gym-pusht = { version = ">=0.1.3", optional = true} diff --git a/tests/data/save_policy_to_safetensors/pusht_diffusion/actions.safetensors b/tests/data/save_policy_to_safetensors/pusht_diffusion/actions.safetensors index 730f5b2bc2a801d15b4ade3593c90f95650f5472..8f03990351292611f702c163ee387b3d7248b5f0 100644 GIT binary patch literal 4600 zcmb7HdsK~C8;?kl8b&kjn)4BI(3Q}bv-eM`!5}mx8pn{7GKiQiLYGl#GK3n1XneVq z3`We;lu`u=p*I%~iCd7j^M+rPcv{kHYf;6Fc+iD<5$ zXoJt5?>36&iWZuiiNvBE?owgqGRs12wpJ|KAoY=cv(?RQ$3~wWOll&wG&NfrAQqW^ zD3amCrY2?<0@&*&%l z8K~x9-1n?Uh2w1uHh#|IggVv2^@}vFQ6BHw2> zlA`}Au-~f>ZuUKBEiA{Jqu*oq&S+eEw+54^eMeu1#Gx=vhhZfoI4_%zHpj~O_+Wnr zt3^TrEJ{eXbq!hDzaEKpzw!FkYM6+C4=WO*FF}>h(LxvKWujOzZ7f5IAZbzx|1YghSu|LsQZbsC%>1R}= zrQzYi>Ab%}KRibapq}QBU~|xp``xxy4v!hWh_h@Z8OGx4-_sdga3l$ld0D)`9sO`b`Zf)>Lvn7^Z|HyFG#DSNEd^O_=cfeOw{G zLhVZWqfrZndRAg{QUMw-$8dgH{n?1O`Z0L7Ri9pWmZN;T8JCZ}bev@Bt9iLjdKV?S zr8pbZfV}mQWUM#<)&))6F6%1e0;_vK>r0%mm1Cn@HaI3uz$p($Vw^sS zMlUv@|DF5*R|>bn__JIXJ-^TKlI68L-)r@W)Lz!W>sT?n59Om*l2O0Nu}`{-j9>JO z%Y}GOS1Zt2^@z_O^tREWV|p-Ic8_YA^>Bam)KNIuHomtPX`e@AnR=e)|2NF1OvD>m z4<0nTKvAaPcGd@5@K|vI$pe&V+Y*C~rAnNxaw8q7y6_#-1>IlTF-g+_H5XgCy@T^D z9Clod)l-v^@NgeL7kyINu>E-$*Wb~X^Yzu4+J}H98B#Nqu$ONmjo~U3&@OBV>w{v~ zO>`~j!HO~JeXY1R5zf}lSeCn<6i_8@1Sqg0wi~x5%qh`!I(+39~QyMqd8M;!T$1z(w5p_~^3$l=a2$6AMUdbde)N+}H7~p48TXn0m?tO6g0MGa<$vUGNv|fq8(imp zc4Y`XRKHKkU0TpD?J++>a=Y_~Gis1T;(ZZP^BUZLo_!R@F}OO8pt_g_?8&BO zClx5!6iltl<}q6dIpQ>o=PvSmiin_d>V3`Rh36ZH<0iI${>fWB-ML>EtFyRtyAS*4U4>g^7Al*g;9>d>Q${so@A?|_ z{85LzOThhs&p4jPW4+wcRE__f-i8Dlf9U9JqZfW#Nh=)VQM`E%J(|-_HEr(F+`ac< zu<9I-(>pVrj&nYQB_FGx6xCsLFC8}-)ne0Nxh%eTI1T&^#y5zaITv0C;FYlmf`V->rPu=#JDLTwW6SM_!PdzO(Q z@;zAD*-1{0Jji4Gt2fJDz~?0^-Xd_CN_}p!Jh!^qk%NX22u+(r>pt~-@9)X4>T}CW z*Ov^F4P;!BkUi#=NG@RST!s_=`#U?3G4vtnY1tY4{tEA6Z*flC^jbvvXDe|q zE|}yf)cyX~zci#7l7tRA{V>mkzU|=ijM$oN41hlczfHdNfkh z=eK}KwsxX=&pw8>Jd}*LO~kR>ebmtH6?s+8IGYuR5yx&4uN*1#W*&xkVHt_N5RTmP zX~<1YgPC6{DhyNbs$P8ukskJh-IqU5$GAu;3mQpZTBh^9NE?NBxPL~7J{RJd20fuc^JV19>(|_VMv*uDGM4!i)}HA{*$8(?85x}tfVgz`n$Ur zKjt@}KK^$eM{8XUQuisa$RmvRt;S1*hoQonsn5}fZ!V|yhgMq^ z^oI$xONiE(Hu@&xZ?0QpyOX#y^kHwHHGM$ta=n0QFUk6+R}ymAQi;;25cSzspg>CM zBjkN5rm9g_F|w$P4oQzhaz+?hb}I2#Vr^WL5ga@6R8>2sG_cE5ppcMhM_p9+xowcX3$0}q*7huY6 zR!`SWk{DBf&@5i8)T->q>!Ce0%yk7fL5_H+MkyJl9u`;v8?oLB^A5z~zIG32O zrS2*u_D@#(X7vejhuX&|wk_ncfwxrH%Xofo-aGu!nga`GwwIYr$*?-Q?Bh;!S?t5& zdYiG=A)C)aiFphf8wYv&ez>ttxJ}Np_s` zM{leOFS|M{N&5{%JLO>>+iG@!MA^)hCkbKK8o_H0+8UeD*!0da4MX$DwcQG42KBK7Y2_Oan4Xh;d*wk9WqHeK)?g zq37(ix!v`&W z3w(td)aSP-u$=ollxRc$2;6`tn`C_MU3Y^zH6+oG8bfJKzhNi{Wcz}}6S*>u$RASM z=vN4r5}vmp*K-IZJ_f~PqUsd4o0(q&(ZEKs`9KAtBXeqIY7hiB%)%Va9P>=nK_WaNLY{X200 literal 4600 zcmb7Hd0dUzAC5^=A;d7YwArt;sD><^^AI&ng%Beu*|JPZB4o59LTO4dEeaK3k|^BL zaLZOn5-M8IK5bg=t>62eYntD``h5DF_q^wMzTfA&oO9mW)K7-}?N`@UpR-@xarZvw zjp}pMP4slt1?sL_?D^TR2D1dZwgPoW``z~If7`Ulb>nVV?ya^!Uq{#Wpg>*cTbUeB zprfrjizn-TE14$>boBKMc(&d*vpJ$bS9{h^JbTu+vU#FFSKmOJXX}46no4kopUeE- zi_3Yp(U>Hm9kq#c&d_r9e($jr;MdHrH_$tFR?r+xe;$8|5`bOtobrueaI1jOmh*5fOf_W+TRae{+OQ z(V%HnS;(``r02_ekT6pdYqsgZY+(l!BHO6^UrkV5n?yGi%*UriA;j%*4`f_KRIv9m zIy$@IYrF&>Y9w@vqcT3_T;sHe$lRw(zUD`dW;IWwJ_qed)8KaKnEoF&`?CvEw65eLu1|c-_1{m8ks)-XrF6FV*|`Mhx9)5=Q<>UyXrJRG@h~m-M^; z3tE;R#{Cs3RK`mT+3mK}aitw~ukON)4l&%G3TbIv0<~XW2VIFRRl5HN(J~ z+U01C9HR#;FTxq6#3EaxFJF&?-(a(=3o%tor5~nz!i)X_C{5c=Hz<^l{<&eyPL($x z(^5-{UM+>z>F@EI@D3TcDFWE_oUQY5c%{v!W^TXH3(w`Ss(KNzyxGO_9}zi>s`?b+ z!RdD-e4v2g6Bo5%&lWT8*_h~#kih)lioW692N|6|Gu)&@&q%mhJC>)2 zG4A~lcza0ulJC#@n+^DDk2@-dU*|MR@aB;xr&&ZYbY^3~0X3FS9;fW155HX>kX(CS z+E?5=m$ue;g43Kb3e%=qSi*AO>LO~@J+<6WT*Rka>K`=`nvMSl=lCp9xYyr1VK^~lybMSh)` zgOQ~bIFT(+HRsO522l^1^)IqGcwAm#DLEXylFgHwyXj2lIw+i_NVCj?Rjg8MSCZ(c!x)Zz)$Wu@`M>IPsQ3^#P>`&YVPJj7$ZUd`<+PrY)DOpwlI8a z(HW{&9}UH^UaSt6SHD5#*d36IaUjRquOM4vA?a1qK+?x}rYHB)C~WiSf)AIQ#&wmn z#@&e7xsoCwo&$U_H;n5$t{&{4tLvDbtGFDA$l!-!Hd>x3!wKcjq#~;fmlk+q;PlYG zcq|O8af>#xdG$O6Tr(WmXD22xjqy8L?aAv9?%d{lDBhK$8&3=A%jf5C-z1Z*Z`4hJ zLtrPcY@)>52fcuUe-C1b4`gx4U+av@2U%Io%O}n z^V8;t52ghEiNU|jrBTZ_qk4usZSxvQf6Sgl^SrJy{NdiOb>?H|cuLQ>yX-6fWxfn; z`@Rzk-EC>)TuujvDb49Y#*APZkwM6`xNf#@U(8TM&$3tu_di0e;!Ma_TR_D%TWW{n z2t(aqGhLN_s4rr*==W?7%e~DD2kLzG1W6w_0Jp}B)R*{Bjzl zlZv3XZZbq2!=ZCjs%74~Vsd?f7!5BTpuf>=gon(h2A+fIrInA+u|pnBmmbmeZT?WU zI?e2uPPqY%mh-G%tlhZ0aoU2lJK!+Ym_74xxBL51{SDm=|8E|oJ@J40mcvI4}8X+n5aVxr1e6 zHZtE%pAI4F)gzDzm+1I?m6-6Vlq}Zi#m@6~)cR-@3gTt^>Y_8LjQP76(g{E18fGuK zCJk$D6~gAwJ`$i8s+C~$eGgPN2jS(`DTqriWqA+2@Dnx2 zu0(?GF0lIRrOk>hz2WHpQ1%3ItxIL8NQg$U%-EpsQDkgz03jGW% zu4WHf@51HE0rK6*B8Z=!hwjlV{OI0<2}!v~*nN!ISnP5Ojk6OGxTlLcrTxI{$Bn2& zLg5Bfe#yn!fWyS7dwpL#nlZnVF)^>%JpJi%lx9|8dsHLqL&3+P)Ymr(OPpHBhN=^E zO44@hw4lV&^*P=p8NyPg*wu=fAl`YR)pY?y9mK<>M^KC}(@LZ0kn5Z2{9YXu(>l@I(mHAuF)$O9Qh% zZAdDMr?U1k?j$dye#V8=DmE5_CZ?g?Ae{03-7h)=OsL1lU>aoDB8}aR|CTT&a>ioZ z_YXw8rZZ0ee3fK9?a_EF5ujxFQeJn3-e{0~l!B(oGrV8_f zO9))8%yQUY5Xf|y+jZi0dJauh6Ou*VPq174FQ(HWfUkjw*d;gPqK7Q?%sx%WE-Ate z(O}#vucAkmHbY{zkWRmR0=2Z4v~Asw{X4n)m8+lG;xu~KYYDj*a1f^cJusMk3yvv1 zWKF9Hi}U8GU2I%AsRt1~T`)N*fy@3Q5Zk^+a`kB3*rZFR+n-_ntKT8~I9Y!>?~?>sV}~Nyn7eN||Ci^# z?bFe(xC6%nEa~Q%S1~W;G@hhJA>u&`Mo&(sR(rYcR1SaOMF&mF{GR^)rUD_8^Kq`T8PDd7M_FqUGKMT64?d@3 znwl5~HZ`MZOEuI^%F@;MF5*%98`>|wA3SD-Al5$&d0K^Z<}i*wPzx<~{BIW#@$}{H zGHx6nRzb7I{X^otrTxhJ&)YH;oxz0+m*1O5MG@<%-^gp^ZRK$~J|PrxHqS`;mxEMc z^)9+Q^8~r6*#(uA8LV&l`bNht!j#TWp&puqVYB@4Bi4`l>h`LR3d&wKMo^3*Z4Rci#k}&W5R_rVX}0; zIUasN*Qa%|Ug=zwPcsjcL1XrxsH^+}i>4H6o@7sSN}OOmfvY3uFK%fy;%E3mrr;GB zdT}dejw4}jpsL>Y3WcCX8k8*qs2CIyGa-toArVo!!8m{o z3SwkmWD!~H08MvSvx$*NTu z;g*FL75b|7Z_aPTR`vE}5q!J~s}^?G0?|3>^8LAT@YZ12^|4hDWU7G{E1Dqcr%Z*P zFcB~MD zXIL*NdtoFe@-{(Vsy#cRNdjSg-bFui3xqrjqYXrbmT>pe1GK;CX87!Z8MUOh4fMi7 z6#;x*@IN7S?-+CYKge;XRl0Pz+705?8_RE9wBuHNS=#l=5zanuCLN6LqsG9M%BGQZ z5aDJccipcIJ#SOxK1DikZ>JTLp5DR;t9b0p@id$j-OGx!K2c$pU{+<` zei8eAr)K*i?Xl52C9MxNdzipFPz?-}k zY702mFif7F`5AddP}JKK|3w7h)v$H$JfpBo_ShJF5t7&3r~H$Ja3r@_sj3cuw6DA9 zJ1vzUHA)gu%Tfzq`s$??Pe)58cCuaRs`F1S;C-j5K|d(y6$&)ZI0EtNQM2id+bAQv zjkJ;8=K^vXj(RJZ#BY1$D|wf|=Rap?*QqdgeYlN2ZF&T*_!3}qeFdJ(iAEQaSK}iN z+W6LVBx6`&L6_(pHu;d94rig~zlSMHgXzo!m9fn{w8O_$HbvdBlSyNZTb6y~p zT^M=-kD>;Yb=$5$;{91lI+6=Ikts@Fy?1cC{wR69z6>-Hj9WIX`;hT}%2y)T!JW5u z+U_oeH&dqaLCG!)ZR_#&Q$DpbyA$2?^8;Ke^ z^PBay7@|u_?X1DRamd(wL(aERa|VQcg|858cbODi^P8Zj>dz?#d(YF$-ZU#7^CX<^ zZeMnzMTiquU1I)Z<=utTpe;iKkrW^UnIPu#cj7`%j2i$&m#lqk@P|^ zOO5i7$f`L$ug=Z5nwG`Q$+(AszPCrqiXj!~H>cKIhk_+z>vR zkNAPK+WZ?WDb}N2KAWZIs!Qqi`UAMM?50N^d7NuF=CIaZj81;JO&;l=glu={lTR$( zL;G)`xdz~{x-_~DuhZu)Qxtru9RA|eH{cHUPd84 zfRE-GR;%k9JaKVCPWDMS197c6lZCPj+}(E#?L+k?U2<26DGFXXNCwMn{?Z*a_mB*+ zzfIs7!h_5G`4lc3Kt@HY>H6YfL_Qm(=a*~YXG<(~%`z9XzP{2bs26xj)^xGa5gwQz zCZwV)d>9L}_fqBEHaK%u&5r3mg_VX?Z1ME?U$?knK8*DY-4IIdIaY;L>8Ywl-4vYd zLa5qG#aL23t%6D=4v&1ye!@$|HMtjB`I~e-A04{N-m%r^0=%3SlpAo}?dPN!^#b%p zd*}@APpFtHrR(?BaQ>Z+4PQ!&J&E*8_xBFv;TY3uRm=L5`17NsssP7>xOZnQWphB6 N2@G4MTx_#u{sckZ-xdG> delta 1719 zcmX|=FZ%Ts}s(r%@ctUTNxBoQ++_csHpQSn$ri!Pw3z@mVGB2_pdT1l!1 zU6hxAgg|)&L6k=zl4Js+0TB@jg-RbrWZ8<}bPE3k%456~IhJuLO}3ZrYK@q>Y-c2To)`eb)(AExJpp-iABym}58q@p zq6POSU>Ypd7V#Tkg~e`7*ScDWzIU4#3)2HONwpk16kIq4Qr@nBpa-`! z6L$OIqwb&ZCySzBWNIHnM~K;B2Hpn-$qA^mM`SK15l-)QC%eSi5V!t}W@r3ia0tJS zPp+rHbjOmYERp`r#V;nAVE1^K^gKcazrTeoN>Fw}1lW9+Q9CcRTuqJEAw0-u|PizIPi$CDAS&q<-S&>IS5i>T1lB2<{~L`4Pv1mWN9 zp^u$^PHY}ZWcxqxJ=Hx8(oTu9-}I>jxUn^ckvw#UxgYYF<(C{;K+5NEEd1|oYi=Kl znSz2OXKnpt10BBfuf^EH)t{pBKCYmv-w#$ z#1sswGmU`|IA_D~?#!UD%LxqMbA}xd&BIIQ=?Tdcca^bl9Nzi8pV(w-fT44qWUuyH z_@c?1oZom3nI;N}3nDpM{oi|Jp{qGH7Hy5&uC-wcMplG0Hle_8y8Lc_XNOvjtMN^e z4)S~MA)&EEN3t*E`?*&YrMhTT#h*+=sdXFiNQVIQWhkck#ZiQ2FN+C#$VHOgI(v1} zNozEDB2gQnz6;$}mVR-wKeGiG7SI)h=KtgBH9$4W0nZA}EsIlZUeVIz zSwwO4;?6{nX^ae`u=h}VwgjBes&znt|n4BL<4Ny zb&gU557X1H#wB#@rY-bttjhoN@Mi7T9~Z-pTTFZTfl5~3*0}gcKRN$m!=-^;3Nm5Q zM0?HN&E&DX?uD#D_M9N2A!|7}_D@Rpp>z|^SojpW7-HQTac*N4;0n}SkLWE01Z-F`~f*5aI3D<3( zGkMXO)K`6rG^;QEpDmDjE1ZhZk4O6j?!0|yRm10M(GUMbGV3s5>4@K-0-s-LU;1!h znBT#zHgI-lCy_NijLwcZllage$~p9i6flp$9F&CpS9$Q8y{xl4L*RcpLc!?51r&y5 zDJ5p1$yAi+gy!am8#R=oVp5VP!P59xdu`be`-?vnG^A?%Q=W;{${7V!-g-gzGCqTf zm{?$P#xkhLH{a^u;4M^j(h}Xhb}99*g{IDybSwWg8GTH?(2f;Y@j|i%)bY2S2vemA z<~9?C>( Pt?Bc-gS6{CoasLR6e#%s diff --git a/tests/data/save_policy_to_safetensors/pusht_diffusion/output_dict.safetensors b/tests/data/save_policy_to_safetensors/pusht_diffusion/output_dict.safetensors index 5f33535d5ad316f4cef55925a215428643ed2f52..a9f61b36ed2735850072768de10a153ef1aea7aa 100644 GIT binary patch delta 9 QcmZ>9nc%?kuA##o01%i1k^lez delta 9 QcmZ>9nc%>3p0ULq01lS|@Bjb+ diff --git a/tests/data/save_policy_to_safetensors/pusht_diffusion/param_stats.safetensors b/tests/data/save_policy_to_safetensors/pusht_diffusion/param_stats.safetensors index ade6a9e03118ac410f6f0ff9357b6779074437b6..a9f4608f6c1c544e8324803f154f57cb9f6acd96 100644 GIT binary patch delta 173 zcmV;e08;h_}JR*~kx<3gn zoKLrTs;WFyv)Q_T0RcFZS-c+sB9nQ%F9A%Gsk~niIFjc#(xDPP0|sZgbL}@g8j~5l z9|3BUIlWI30y9E5VzClEsMYT|bOt~?3VxHCy*?4GIj%Q@a}qt!tyH;GpgcShlj*%r b0gsb8z8?W&lUcq$0dSLlVQ3i2|LL%w>h_}JS&rtx<3gr zoKLrTs;WFIv)Q_T0RcCYS-c+sAd`8#F9A)Hsk~niIg;l$(xDPP0|sZgbL}@g9+Mfp z9|3NYIlWI3{4zo~VzClEsMYT|bOt~?|9z91y*?4RIj%Q@a}qt!tyH;GpgcSalj*%r b0g;nAz8?W%lUcq$0d$j@zFz?!lli{q5xGhY