diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py index 1659b68e..2ae03f22 100644 --- a/lerobot/common/policies/diffusion/modeling_diffusion.py +++ b/lerobot/common/policies/diffusion/modeling_diffusion.py @@ -17,7 +17,6 @@ """Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion" TODO(alexander-soare): - - Remove reliance on Robomimic for SpatialSoftmax. - Remove reliance on diffusers for DDPMScheduler and LR scheduler. - Make compatible with multiple image keys. """ @@ -27,13 +26,13 @@ from collections import deque from typing import Callable import einops +import numpy as np import torch import torch.nn.functional as F # noqa: N812 import torchvision from diffusers.schedulers.scheduling_ddim import DDIMScheduler from diffusers.schedulers.scheduling_ddpm import DDPMScheduler from huggingface_hub import PyTorchModelHubMixin -from robomimic.models.base_nets import SpatialSoftmax from torch import Tensor, nn from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig @@ -312,6 +311,77 @@ class DiffusionModel(nn.Module): return loss.mean() +class SpatialSoftmax(nn.Module): + """ + Spatial Soft Argmax operation described in "Deep Spatial Autoencoders for Visuomotor Learning" by Finn et al. + (https://arxiv.org/pdf/1509.06113). A minimal port of the robomimic implementation. + + At a high level, this takes 2D feature maps (from a convnet/ViT) and returns the "center of mass" + of activations of each channel, i.e., keypoints in the image space for the policy to focus on. + + Example: take feature maps of size (512x10x12). We generate a grid of normalized coordinates (10x12x2): + ----------------------------------------------------- + | (-1., -1.) | (-0.82, -1.) | ... | (1., -1.) | + | (-1., -0.78) | (-0.82, -0.78) | ... | (1., -0.78) | + | ... | ... | ... | ... | + | (-1., 1.) | (-0.82, 1.) | ... | (1., 1.) | + ----------------------------------------------------- + This is achieved by applying channel-wise softmax over the activations (512x120) and computing the dot + product with the coordinates (120x2) to get expected points of maximal activation (512x2). + + The example above results in 512 keypoints (corresponding to the 512 input channels). We can optionally + provide num_kp != None to control the number of keypoints. This is achieved by a first applying a learnable + linear mapping (in_channels, H, W) -> (num_kp, H, W). + """ + + def __init__(self, input_shape, num_kp=None): + """ + Args: + input_shape (list): (C, H, W) input feature map shape. + num_kp (int): number of keypoints in output. If None, output will have the same number of channels as input. + """ + super().__init__() + + assert len(input_shape) == 3 + self._in_c, self._in_h, self._in_w = input_shape + + if num_kp is not None: + self.nets = torch.nn.Conv2d(self._in_c, num_kp, kernel_size=1) + self._out_c = num_kp + else: + self.nets = None + self._out_c = self._in_c + + # we could use torch.linspace directly but that seems to behave slightly differently than numpy + # and causes a small degradation in pc_success of pre-trained models. + pos_x, pos_y = np.meshgrid(np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h)) + pos_x = torch.from_numpy(pos_x.reshape(self._in_h * self._in_w, 1)).float() + pos_y = torch.from_numpy(pos_y.reshape(self._in_h * self._in_w, 1)).float() + # register as buffer so it's moved to the correct device. + self.register_buffer("pos_grid", torch.cat([pos_x, pos_y], dim=1)) + + def forward(self, features: Tensor) -> Tensor: + """ + Args: + features: (B, C, H, W) input feature maps. + Returns: + (B, K, 2) image-space coordinates of keypoints. + """ + if self.nets is not None: + features = self.nets(features) + + # [B, K, H, W] -> [B * K, H * W] where K is number of keypoints + features = features.reshape(-1, self._in_h * self._in_w) + # 2d softmax normalization + attention = F.softmax(features, dim=-1) + # [B * K, H * W] x [H * W, 2] -> [B * K, 2] for spatial coordinate mean in x and y dimensions + expected_xy = attention @ self.pos_grid + # reshape to [B, K, 2] + feature_keypoints = expected_xy.view(-1, self._out_c, 2) + + return feature_keypoints + + class DiffusionRgbEncoder(nn.Module): """Encoder an RGB image into a 1D feature vector. diff --git a/poetry.lock b/poetry.lock index 388e03f4..e0b27f15 100644 --- a/poetry.lock +++ b/poetry.lock @@ -4,7 +4,7 @@ name = "absl-py" version = "2.1.0" description = "Abseil Python Common Libraries, see https://github.com/abseil/abseil-py." -optional = false +optional = true python-versions = ">=3.7" files = [ {file = "absl-py-2.1.0.tar.gz", hash = "sha256:7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff"}, @@ -767,16 +767,6 @@ files = [ [package.dependencies] six = ">=1.4.0" -[[package]] -name = "egl-probe" -version = "1.0.2" -description = "" -optional = false -python-versions = "*" -files = [ - {file = "egl_probe-1.0.2.tar.gz", hash = "sha256:29bdca7b08da1e060cfb42cd46af8300a7ac4f3b1b2eeb16e545ea16d9a5ac93"}, -] - [[package]] name = "einops" version = "0.8.0" @@ -1037,64 +1027,6 @@ files = [ [package.extras] preview = ["glfw-preview"] -[[package]] -name = "grpcio" -version = "1.63.0" -description = "HTTP/2-based RPC framework" -optional = false -python-versions = ">=3.8" -files = [ - {file = "grpcio-1.63.0-cp310-cp310-linux_armv7l.whl", hash = "sha256:2e93aca840c29d4ab5db93f94ed0a0ca899e241f2e8aec6334ab3575dc46125c"}, - {file = "grpcio-1.63.0-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:91b73d3f1340fefa1e1716c8c1ec9930c676d6b10a3513ab6c26004cb02d8b3f"}, - {file = "grpcio-1.63.0-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:b3afbd9d6827fa6f475a4f91db55e441113f6d3eb9b7ebb8fb806e5bb6d6bd0d"}, - {file = "grpcio-1.63.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8f3f6883ce54a7a5f47db43289a0a4c776487912de1a0e2cc83fdaec9685cc9f"}, - {file = "grpcio-1.63.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cf8dae9cc0412cb86c8de5a8f3be395c5119a370f3ce2e69c8b7d46bb9872c8d"}, - {file = "grpcio-1.63.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:08e1559fd3b3b4468486b26b0af64a3904a8dbc78d8d936af9c1cf9636eb3e8b"}, - {file = "grpcio-1.63.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:5c039ef01516039fa39da8a8a43a95b64e288f79f42a17e6c2904a02a319b357"}, - {file = "grpcio-1.63.0-cp310-cp310-win32.whl", hash = "sha256:ad2ac8903b2eae071055a927ef74121ed52d69468e91d9bcbd028bd0e554be6d"}, - {file = "grpcio-1.63.0-cp310-cp310-win_amd64.whl", hash = "sha256:b2e44f59316716532a993ca2966636df6fbe7be4ab6f099de6815570ebe4383a"}, - {file = "grpcio-1.63.0-cp311-cp311-linux_armv7l.whl", hash = "sha256:f28f8b2db7b86c77916829d64ab21ff49a9d8289ea1564a2b2a3a8ed9ffcccd3"}, - {file = "grpcio-1.63.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:65bf975639a1f93bee63ca60d2e4951f1b543f498d581869922910a476ead2f5"}, - {file = "grpcio-1.63.0-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:b5194775fec7dc3dbd6a935102bb156cd2c35efe1685b0a46c67b927c74f0cfb"}, - {file = "grpcio-1.63.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e4cbb2100ee46d024c45920d16e888ee5d3cf47c66e316210bc236d5bebc42b3"}, - {file = "grpcio-1.63.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1ff737cf29b5b801619f10e59b581869e32f400159e8b12d7a97e7e3bdeee6a2"}, - {file = "grpcio-1.63.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:cd1e68776262dd44dedd7381b1a0ad09d9930ffb405f737d64f505eb7f77d6c7"}, - {file = "grpcio-1.63.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:93f45f27f516548e23e4ec3fbab21b060416007dbe768a111fc4611464cc773f"}, - {file = "grpcio-1.63.0-cp311-cp311-win32.whl", hash = "sha256:878b1d88d0137df60e6b09b74cdb73db123f9579232c8456f53e9abc4f62eb3c"}, - {file = "grpcio-1.63.0-cp311-cp311-win_amd64.whl", hash = "sha256:756fed02dacd24e8f488f295a913f250b56b98fb793f41d5b2de6c44fb762434"}, - {file = "grpcio-1.63.0-cp312-cp312-linux_armv7l.whl", hash = "sha256:93a46794cc96c3a674cdfb59ef9ce84d46185fe9421baf2268ccb556f8f81f57"}, - {file = "grpcio-1.63.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:a7b19dfc74d0be7032ca1eda0ed545e582ee46cd65c162f9e9fc6b26ef827dc6"}, - {file = "grpcio-1.63.0-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:8064d986d3a64ba21e498b9a376cbc5d6ab2e8ab0e288d39f266f0fca169b90d"}, - {file = "grpcio-1.63.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:219bb1848cd2c90348c79ed0a6b0ea51866bc7e72fa6e205e459fedab5770172"}, - {file = "grpcio-1.63.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a2d60cd1d58817bc5985fae6168d8b5655c4981d448d0f5b6194bbcc038090d2"}, - {file = "grpcio-1.63.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:9e350cb096e5c67832e9b6e018cf8a0d2a53b2a958f6251615173165269a91b0"}, - {file = "grpcio-1.63.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:56cdf96ff82e3cc90dbe8bac260352993f23e8e256e063c327b6cf9c88daf7a9"}, - {file = "grpcio-1.63.0-cp312-cp312-win32.whl", hash = "sha256:3a6d1f9ea965e750db7b4ee6f9fdef5fdf135abe8a249e75d84b0a3e0c668a1b"}, - {file = "grpcio-1.63.0-cp312-cp312-win_amd64.whl", hash = "sha256:d2497769895bb03efe3187fb1888fc20e98a5f18b3d14b606167dacda5789434"}, - {file = "grpcio-1.63.0-cp38-cp38-linux_armv7l.whl", hash = "sha256:fdf348ae69c6ff484402cfdb14e18c1b0054ac2420079d575c53a60b9b2853ae"}, - {file = "grpcio-1.63.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:a3abfe0b0f6798dedd2e9e92e881d9acd0fdb62ae27dcbbfa7654a57e24060c0"}, - {file = "grpcio-1.63.0-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:6ef0ad92873672a2a3767cb827b64741c363ebaa27e7f21659e4e31f4d750280"}, - {file = "grpcio-1.63.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b416252ac5588d9dfb8a30a191451adbf534e9ce5f56bb02cd193f12d8845b7f"}, - {file = "grpcio-1.63.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e3b77eaefc74d7eb861d3ffbdf91b50a1bb1639514ebe764c47773b833fa2d91"}, - {file = "grpcio-1.63.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:b005292369d9c1f80bf70c1db1c17c6c342da7576f1c689e8eee4fb0c256af85"}, - {file = "grpcio-1.63.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:cdcda1156dcc41e042d1e899ba1f5c2e9f3cd7625b3d6ebfa619806a4c1aadda"}, - {file = "grpcio-1.63.0-cp38-cp38-win32.whl", hash = "sha256:01799e8649f9e94ba7db1aeb3452188048b0019dc37696b0f5ce212c87c560c3"}, - {file = "grpcio-1.63.0-cp38-cp38-win_amd64.whl", hash = "sha256:6a1a3642d76f887aa4009d92f71eb37809abceb3b7b5a1eec9c554a246f20e3a"}, - {file = "grpcio-1.63.0-cp39-cp39-linux_armv7l.whl", hash = "sha256:75f701ff645858a2b16bc8c9fc68af215a8bb2d5a9b647448129de6e85d52bce"}, - {file = "grpcio-1.63.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:cacdef0348a08e475a721967f48206a2254a1b26ee7637638d9e081761a5ba86"}, - {file = "grpcio-1.63.0-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:0697563d1d84d6985e40ec5ec596ff41b52abb3fd91ec240e8cb44a63b895094"}, - {file = "grpcio-1.63.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6426e1fb92d006e47476d42b8f240c1d916a6d4423c5258ccc5b105e43438f61"}, - {file = "grpcio-1.63.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e48cee31bc5f5a31fb2f3b573764bd563aaa5472342860edcc7039525b53e46a"}, - {file = "grpcio-1.63.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:50344663068041b34a992c19c600236e7abb42d6ec32567916b87b4c8b8833b3"}, - {file = "grpcio-1.63.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:259e11932230d70ef24a21b9fb5bb947eb4703f57865a404054400ee92f42f5d"}, - {file = "grpcio-1.63.0-cp39-cp39-win32.whl", hash = "sha256:a44624aad77bf8ca198c55af811fd28f2b3eaf0a50ec5b57b06c034416ef2d0a"}, - {file = "grpcio-1.63.0-cp39-cp39-win_amd64.whl", hash = "sha256:166e5c460e5d7d4656ff9e63b13e1f6029b122104c1633d5f37eaea348d7356d"}, - {file = "grpcio-1.63.0.tar.gz", hash = "sha256:f3023e14805c61bc439fb40ca545ac3d5740ce66120a678a3c6c2c55b70343d1"}, -] - -[package.extras] -protobuf = ["grpcio-tools (>=1.63.0)"] - [[package]] name = "gym-aloha" version = "0.1.1" @@ -1668,7 +1600,6 @@ files = [ {file = "lxml-5.2.1-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:9e2addd2d1866fe112bc6f80117bcc6bc25191c5ed1bfbcf9f1386a884252ae8"}, {file = "lxml-5.2.1-cp37-cp37m-win32.whl", hash = "sha256:f51969bac61441fd31f028d7b3b45962f3ecebf691a510495e5d2cd8c8092dbd"}, {file = "lxml-5.2.1-cp37-cp37m-win_amd64.whl", hash = "sha256:b0b58fbfa1bf7367dde8a557994e3b1637294be6cf2169810375caf8571a085c"}, - {file = "lxml-5.2.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:3e183c6e3298a2ed5af9d7a356ea823bccaab4ec2349dc9ed83999fd289d14d5"}, {file = "lxml-5.2.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:804f74efe22b6a227306dd890eecc4f8c59ff25ca35f1f14e7482bbce96ef10b"}, {file = "lxml-5.2.1-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:08802f0c56ed150cc6885ae0788a321b73505d2263ee56dad84d200cab11c07a"}, {file = "lxml-5.2.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0f8c09ed18ecb4ebf23e02b8e7a22a05d6411911e6fabef3a36e4f371f4f2585"}, @@ -1740,21 +1671,6 @@ html5 = ["html5lib"] htmlsoup = ["BeautifulSoup4"] source = ["Cython (>=3.0.10)"] -[[package]] -name = "markdown" -version = "3.6" -description = "Python implementation of John Gruber's Markdown." -optional = false -python-versions = ">=3.8" -files = [ - {file = "Markdown-3.6-py3-none-any.whl", hash = "sha256:48f276f4d8cfb8ce6527c8f79e2ee29708508bf4d40aa410fbc3b4ee832c850f"}, - {file = "Markdown-3.6.tar.gz", hash = "sha256:ed4f41f6daecbeeb96e576ce414c41d2d876daa9a16cb35fa8ed8c2ddfad0224"}, -] - -[package.extras] -docs = ["mdx-gh-links (>=0.2)", "mkdocs (>=1.5)", "mkdocs-gen-files", "mkdocs-literate-nav", "mkdocs-nature (>=0.6)", "mkdocs-section-index", "mkdocstrings[python]"] -testing = ["coverage", "pyyaml"] - [[package]] name = "markupsafe" version = "2.1.5" @@ -3056,6 +2972,7 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -3224,30 +3141,6 @@ typing-extensions = ">=4.5" [package.extras] tests = ["pytest (==7.1.2)"] -[[package]] -name = "robomimic" -version = "0.2.0" -description = "robomimic: A Modular Framework for Robot Learning from Demonstration" -optional = false -python-versions = ">=3" -files = [ - {file = "robomimic-0.2.0.tar.gz", hash = "sha256:ee3bb5cf9c3e1feead6b57b43c5db738fd0a8e0c015fdf6419808af8fffdc463"}, -] - -[package.dependencies] -egl_probe = ">=1.0.1" -h5py = "*" -imageio = "*" -imageio-ffmpeg = "*" -numpy = ">=1.13.3" -psutil = "*" -tensorboard = "*" -tensorboardX = "*" -termcolor = "*" -torch = "*" -torchvision = "*" -tqdm = "*" - [[package]] name = "safetensors" version = "0.4.3" @@ -3738,55 +3631,6 @@ files = [ {file = "tbb-2021.12.0-py3-none-win_amd64.whl", hash = "sha256:fc2772d850229f2f3df85f1109c4844c495a2db7433d38200959ee9265b34789"}, ] -[[package]] -name = "tensorboard" -version = "2.16.2" -description = "TensorBoard lets you watch Tensors Flow" -optional = false -python-versions = ">=3.9" -files = [ - {file = "tensorboard-2.16.2-py3-none-any.whl", hash = "sha256:9f2b4e7dad86667615c0e5cd072f1ea8403fc032a299f0072d6f74855775cc45"}, -] - -[package.dependencies] -absl-py = ">=0.4" -grpcio = ">=1.48.2" -markdown = ">=2.6.8" -numpy = ">=1.12.0" -protobuf = ">=3.19.6,<4.24.0 || >4.24.0" -setuptools = ">=41.0.0" -six = ">1.9" -tensorboard-data-server = ">=0.7.0,<0.8.0" -werkzeug = ">=1.0.1" - -[[package]] -name = "tensorboard-data-server" -version = "0.7.2" -description = "Fast data loading for TensorBoard" -optional = false -python-versions = ">=3.7" -files = [ - {file = "tensorboard_data_server-0.7.2-py3-none-any.whl", hash = "sha256:7e0610d205889588983836ec05dc098e80f97b7e7bbff7e994ebb78f578d0ddb"}, - {file = "tensorboard_data_server-0.7.2-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:9fe5d24221b29625dbc7328b0436ca7fc1c23de4acf4d272f1180856e32f9f60"}, - {file = "tensorboard_data_server-0.7.2-py3-none-manylinux_2_31_x86_64.whl", hash = "sha256:ef687163c24185ae9754ed5650eb5bc4d84ff257aabdc33f0cc6f74d8ba54530"}, -] - -[[package]] -name = "tensorboardx" -version = "2.6.2.2" -description = "TensorBoardX lets you watch Tensors Flow without Tensorflow" -optional = false -python-versions = "*" -files = [ - {file = "tensorboardX-2.6.2.2-py2.py3-none-any.whl", hash = "sha256:160025acbf759ede23fd3526ae9d9bfbfd8b68eb16c38a010ebe326dc6395db8"}, - {file = "tensorboardX-2.6.2.2.tar.gz", hash = "sha256:c6476d7cd0d529b0b72f4acadb1269f9ed8b22f441e87a84f2a3b940bb87b666"}, -] - -[package.dependencies] -numpy = "*" -packaging = "*" -protobuf = ">=3.20" - [[package]] name = "termcolor" version = "2.4.0" @@ -4064,23 +3908,6 @@ perf = ["orjson"] reports = ["pydantic (>=2.0.0)"] sweeps = ["sweeps (>=0.2.0)"] -[[package]] -name = "werkzeug" -version = "3.0.3" -description = "The comprehensive WSGI web application library." -optional = false -python-versions = ">=3.8" -files = [ - {file = "werkzeug-3.0.3-py3-none-any.whl", hash = "sha256:fc9645dc43e03e4d630d23143a04a7f947a9a3b5727cd535fdfe155a17cc48c8"}, - {file = "werkzeug-3.0.3.tar.gz", hash = "sha256:097e5bfda9f0aba8da6b8545146def481d06aa7d3266e7448e2cccf67dd8bd18"}, -] - -[package.dependencies] -MarkupSafe = ">=2.1.1" - -[package.extras] -watchdog = ["watchdog (>=2.3)"] - [[package]] name = "xxhash" version = "3.4.1" @@ -4348,4 +4175,4 @@ xarm = ["gym-xarm"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "2f0d2cbf4a2dec546e25b29b9b108ff1f97b4c278b718360b3f7f6a2bf9dcef8" +content-hash = "e3e3c306a5519e4f716a1ac086ad9b734efedcac077a0ec71e5bc16349a1e559" diff --git a/pyproject.toml b/pyproject.toml index 24d9452d..5b80d06f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,7 +44,6 @@ diffusers = "^0.27.2" torchvision = ">=0.18.0" h5py = ">=3.10.0" huggingface-hub = ">=0.21.4" -robomimic = "0.2.0" gymnasium = ">=0.29.1" cmake = ">=3.29.0.1" gym-pusht = { version = ">=0.1.3", optional = true} diff --git a/tests/data/save_policy_to_safetensors/pusht_diffusion/actions.safetensors b/tests/data/save_policy_to_safetensors/pusht_diffusion/actions.safetensors index 730f5b2b..8f039903 100644 Binary files a/tests/data/save_policy_to_safetensors/pusht_diffusion/actions.safetensors and b/tests/data/save_policy_to_safetensors/pusht_diffusion/actions.safetensors differ diff --git a/tests/data/save_policy_to_safetensors/pusht_diffusion/grad_stats.safetensors b/tests/data/save_policy_to_safetensors/pusht_diffusion/grad_stats.safetensors index f27cd678..2b659396 100644 Binary files a/tests/data/save_policy_to_safetensors/pusht_diffusion/grad_stats.safetensors and b/tests/data/save_policy_to_safetensors/pusht_diffusion/grad_stats.safetensors differ diff --git a/tests/data/save_policy_to_safetensors/pusht_diffusion/output_dict.safetensors b/tests/data/save_policy_to_safetensors/pusht_diffusion/output_dict.safetensors index 5f33535d..a9f61b36 100644 Binary files a/tests/data/save_policy_to_safetensors/pusht_diffusion/output_dict.safetensors and b/tests/data/save_policy_to_safetensors/pusht_diffusion/output_dict.safetensors differ diff --git a/tests/data/save_policy_to_safetensors/pusht_diffusion/param_stats.safetensors b/tests/data/save_policy_to_safetensors/pusht_diffusion/param_stats.safetensors index ade6a9e0..a9f4608f 100644 Binary files a/tests/data/save_policy_to_safetensors/pusht_diffusion/param_stats.safetensors and b/tests/data/save_policy_to_safetensors/pusht_diffusion/param_stats.safetensors differ