Port SpatialSoftmax and remove Robomimic dependency (#182)
Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
This commit is contained in:
parent
68c1b13406
commit
c9069df9f1
|
@ -17,7 +17,6 @@
|
||||||
"""Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion"
|
"""Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion"
|
||||||
|
|
||||||
TODO(alexander-soare):
|
TODO(alexander-soare):
|
||||||
- Remove reliance on Robomimic for SpatialSoftmax.
|
|
||||||
- Remove reliance on diffusers for DDPMScheduler and LR scheduler.
|
- Remove reliance on diffusers for DDPMScheduler and LR scheduler.
|
||||||
- Make compatible with multiple image keys.
|
- Make compatible with multiple image keys.
|
||||||
"""
|
"""
|
||||||
|
@ -27,13 +26,13 @@ from collections import deque
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
|
||||||
import einops
|
import einops
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F # noqa: N812
|
import torch.nn.functional as F # noqa: N812
|
||||||
import torchvision
|
import torchvision
|
||||||
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
|
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
|
||||||
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
||||||
from huggingface_hub import PyTorchModelHubMixin
|
from huggingface_hub import PyTorchModelHubMixin
|
||||||
from robomimic.models.base_nets import SpatialSoftmax
|
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
|
|
||||||
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
|
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||||
|
@ -312,6 +311,77 @@ class DiffusionModel(nn.Module):
|
||||||
return loss.mean()
|
return loss.mean()
|
||||||
|
|
||||||
|
|
||||||
|
class SpatialSoftmax(nn.Module):
|
||||||
|
"""
|
||||||
|
Spatial Soft Argmax operation described in "Deep Spatial Autoencoders for Visuomotor Learning" by Finn et al.
|
||||||
|
(https://arxiv.org/pdf/1509.06113). A minimal port of the robomimic implementation.
|
||||||
|
|
||||||
|
At a high level, this takes 2D feature maps (from a convnet/ViT) and returns the "center of mass"
|
||||||
|
of activations of each channel, i.e., keypoints in the image space for the policy to focus on.
|
||||||
|
|
||||||
|
Example: take feature maps of size (512x10x12). We generate a grid of normalized coordinates (10x12x2):
|
||||||
|
-----------------------------------------------------
|
||||||
|
| (-1., -1.) | (-0.82, -1.) | ... | (1., -1.) |
|
||||||
|
| (-1., -0.78) | (-0.82, -0.78) | ... | (1., -0.78) |
|
||||||
|
| ... | ... | ... | ... |
|
||||||
|
| (-1., 1.) | (-0.82, 1.) | ... | (1., 1.) |
|
||||||
|
-----------------------------------------------------
|
||||||
|
This is achieved by applying channel-wise softmax over the activations (512x120) and computing the dot
|
||||||
|
product with the coordinates (120x2) to get expected points of maximal activation (512x2).
|
||||||
|
|
||||||
|
The example above results in 512 keypoints (corresponding to the 512 input channels). We can optionally
|
||||||
|
provide num_kp != None to control the number of keypoints. This is achieved by a first applying a learnable
|
||||||
|
linear mapping (in_channels, H, W) -> (num_kp, H, W).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, input_shape, num_kp=None):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
input_shape (list): (C, H, W) input feature map shape.
|
||||||
|
num_kp (int): number of keypoints in output. If None, output will have the same number of channels as input.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
assert len(input_shape) == 3
|
||||||
|
self._in_c, self._in_h, self._in_w = input_shape
|
||||||
|
|
||||||
|
if num_kp is not None:
|
||||||
|
self.nets = torch.nn.Conv2d(self._in_c, num_kp, kernel_size=1)
|
||||||
|
self._out_c = num_kp
|
||||||
|
else:
|
||||||
|
self.nets = None
|
||||||
|
self._out_c = self._in_c
|
||||||
|
|
||||||
|
# we could use torch.linspace directly but that seems to behave slightly differently than numpy
|
||||||
|
# and causes a small degradation in pc_success of pre-trained models.
|
||||||
|
pos_x, pos_y = np.meshgrid(np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h))
|
||||||
|
pos_x = torch.from_numpy(pos_x.reshape(self._in_h * self._in_w, 1)).float()
|
||||||
|
pos_y = torch.from_numpy(pos_y.reshape(self._in_h * self._in_w, 1)).float()
|
||||||
|
# register as buffer so it's moved to the correct device.
|
||||||
|
self.register_buffer("pos_grid", torch.cat([pos_x, pos_y], dim=1))
|
||||||
|
|
||||||
|
def forward(self, features: Tensor) -> Tensor:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
features: (B, C, H, W) input feature maps.
|
||||||
|
Returns:
|
||||||
|
(B, K, 2) image-space coordinates of keypoints.
|
||||||
|
"""
|
||||||
|
if self.nets is not None:
|
||||||
|
features = self.nets(features)
|
||||||
|
|
||||||
|
# [B, K, H, W] -> [B * K, H * W] where K is number of keypoints
|
||||||
|
features = features.reshape(-1, self._in_h * self._in_w)
|
||||||
|
# 2d softmax normalization
|
||||||
|
attention = F.softmax(features, dim=-1)
|
||||||
|
# [B * K, H * W] x [H * W, 2] -> [B * K, 2] for spatial coordinate mean in x and y dimensions
|
||||||
|
expected_xy = attention @ self.pos_grid
|
||||||
|
# reshape to [B, K, 2]
|
||||||
|
feature_keypoints = expected_xy.view(-1, self._out_c, 2)
|
||||||
|
|
||||||
|
return feature_keypoints
|
||||||
|
|
||||||
|
|
||||||
class DiffusionRgbEncoder(nn.Module):
|
class DiffusionRgbEncoder(nn.Module):
|
||||||
"""Encoder an RGB image into a 1D feature vector.
|
"""Encoder an RGB image into a 1D feature vector.
|
||||||
|
|
||||||
|
|
|
@ -4,7 +4,7 @@
|
||||||
name = "absl-py"
|
name = "absl-py"
|
||||||
version = "2.1.0"
|
version = "2.1.0"
|
||||||
description = "Abseil Python Common Libraries, see https://github.com/abseil/abseil-py."
|
description = "Abseil Python Common Libraries, see https://github.com/abseil/abseil-py."
|
||||||
optional = false
|
optional = true
|
||||||
python-versions = ">=3.7"
|
python-versions = ">=3.7"
|
||||||
files = [
|
files = [
|
||||||
{file = "absl-py-2.1.0.tar.gz", hash = "sha256:7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff"},
|
{file = "absl-py-2.1.0.tar.gz", hash = "sha256:7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff"},
|
||||||
|
@ -767,16 +767,6 @@ files = [
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
six = ">=1.4.0"
|
six = ">=1.4.0"
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "egl-probe"
|
|
||||||
version = "1.0.2"
|
|
||||||
description = ""
|
|
||||||
optional = false
|
|
||||||
python-versions = "*"
|
|
||||||
files = [
|
|
||||||
{file = "egl_probe-1.0.2.tar.gz", hash = "sha256:29bdca7b08da1e060cfb42cd46af8300a7ac4f3b1b2eeb16e545ea16d9a5ac93"},
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "einops"
|
name = "einops"
|
||||||
version = "0.8.0"
|
version = "0.8.0"
|
||||||
|
@ -1037,64 +1027,6 @@ files = [
|
||||||
[package.extras]
|
[package.extras]
|
||||||
preview = ["glfw-preview"]
|
preview = ["glfw-preview"]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "grpcio"
|
|
||||||
version = "1.63.0"
|
|
||||||
description = "HTTP/2-based RPC framework"
|
|
||||||
optional = false
|
|
||||||
python-versions = ">=3.8"
|
|
||||||
files = [
|
|
||||||
{file = "grpcio-1.63.0-cp310-cp310-linux_armv7l.whl", hash = "sha256:2e93aca840c29d4ab5db93f94ed0a0ca899e241f2e8aec6334ab3575dc46125c"},
|
|
||||||
{file = "grpcio-1.63.0-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:91b73d3f1340fefa1e1716c8c1ec9930c676d6b10a3513ab6c26004cb02d8b3f"},
|
|
||||||
{file = "grpcio-1.63.0-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:b3afbd9d6827fa6f475a4f91db55e441113f6d3eb9b7ebb8fb806e5bb6d6bd0d"},
|
|
||||||
{file = "grpcio-1.63.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8f3f6883ce54a7a5f47db43289a0a4c776487912de1a0e2cc83fdaec9685cc9f"},
|
|
||||||
{file = "grpcio-1.63.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cf8dae9cc0412cb86c8de5a8f3be395c5119a370f3ce2e69c8b7d46bb9872c8d"},
|
|
||||||
{file = "grpcio-1.63.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:08e1559fd3b3b4468486b26b0af64a3904a8dbc78d8d936af9c1cf9636eb3e8b"},
|
|
||||||
{file = "grpcio-1.63.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:5c039ef01516039fa39da8a8a43a95b64e288f79f42a17e6c2904a02a319b357"},
|
|
||||||
{file = "grpcio-1.63.0-cp310-cp310-win32.whl", hash = "sha256:ad2ac8903b2eae071055a927ef74121ed52d69468e91d9bcbd028bd0e554be6d"},
|
|
||||||
{file = "grpcio-1.63.0-cp310-cp310-win_amd64.whl", hash = "sha256:b2e44f59316716532a993ca2966636df6fbe7be4ab6f099de6815570ebe4383a"},
|
|
||||||
{file = "grpcio-1.63.0-cp311-cp311-linux_armv7l.whl", hash = "sha256:f28f8b2db7b86c77916829d64ab21ff49a9d8289ea1564a2b2a3a8ed9ffcccd3"},
|
|
||||||
{file = "grpcio-1.63.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:65bf975639a1f93bee63ca60d2e4951f1b543f498d581869922910a476ead2f5"},
|
|
||||||
{file = "grpcio-1.63.0-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:b5194775fec7dc3dbd6a935102bb156cd2c35efe1685b0a46c67b927c74f0cfb"},
|
|
||||||
{file = "grpcio-1.63.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e4cbb2100ee46d024c45920d16e888ee5d3cf47c66e316210bc236d5bebc42b3"},
|
|
||||||
{file = "grpcio-1.63.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1ff737cf29b5b801619f10e59b581869e32f400159e8b12d7a97e7e3bdeee6a2"},
|
|
||||||
{file = "grpcio-1.63.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:cd1e68776262dd44dedd7381b1a0ad09d9930ffb405f737d64f505eb7f77d6c7"},
|
|
||||||
{file = "grpcio-1.63.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:93f45f27f516548e23e4ec3fbab21b060416007dbe768a111fc4611464cc773f"},
|
|
||||||
{file = "grpcio-1.63.0-cp311-cp311-win32.whl", hash = "sha256:878b1d88d0137df60e6b09b74cdb73db123f9579232c8456f53e9abc4f62eb3c"},
|
|
||||||
{file = "grpcio-1.63.0-cp311-cp311-win_amd64.whl", hash = "sha256:756fed02dacd24e8f488f295a913f250b56b98fb793f41d5b2de6c44fb762434"},
|
|
||||||
{file = "grpcio-1.63.0-cp312-cp312-linux_armv7l.whl", hash = "sha256:93a46794cc96c3a674cdfb59ef9ce84d46185fe9421baf2268ccb556f8f81f57"},
|
|
||||||
{file = "grpcio-1.63.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:a7b19dfc74d0be7032ca1eda0ed545e582ee46cd65c162f9e9fc6b26ef827dc6"},
|
|
||||||
{file = "grpcio-1.63.0-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:8064d986d3a64ba21e498b9a376cbc5d6ab2e8ab0e288d39f266f0fca169b90d"},
|
|
||||||
{file = "grpcio-1.63.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:219bb1848cd2c90348c79ed0a6b0ea51866bc7e72fa6e205e459fedab5770172"},
|
|
||||||
{file = "grpcio-1.63.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a2d60cd1d58817bc5985fae6168d8b5655c4981d448d0f5b6194bbcc038090d2"},
|
|
||||||
{file = "grpcio-1.63.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:9e350cb096e5c67832e9b6e018cf8a0d2a53b2a958f6251615173165269a91b0"},
|
|
||||||
{file = "grpcio-1.63.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:56cdf96ff82e3cc90dbe8bac260352993f23e8e256e063c327b6cf9c88daf7a9"},
|
|
||||||
{file = "grpcio-1.63.0-cp312-cp312-win32.whl", hash = "sha256:3a6d1f9ea965e750db7b4ee6f9fdef5fdf135abe8a249e75d84b0a3e0c668a1b"},
|
|
||||||
{file = "grpcio-1.63.0-cp312-cp312-win_amd64.whl", hash = "sha256:d2497769895bb03efe3187fb1888fc20e98a5f18b3d14b606167dacda5789434"},
|
|
||||||
{file = "grpcio-1.63.0-cp38-cp38-linux_armv7l.whl", hash = "sha256:fdf348ae69c6ff484402cfdb14e18c1b0054ac2420079d575c53a60b9b2853ae"},
|
|
||||||
{file = "grpcio-1.63.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:a3abfe0b0f6798dedd2e9e92e881d9acd0fdb62ae27dcbbfa7654a57e24060c0"},
|
|
||||||
{file = "grpcio-1.63.0-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:6ef0ad92873672a2a3767cb827b64741c363ebaa27e7f21659e4e31f4d750280"},
|
|
||||||
{file = "grpcio-1.63.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b416252ac5588d9dfb8a30a191451adbf534e9ce5f56bb02cd193f12d8845b7f"},
|
|
||||||
{file = "grpcio-1.63.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e3b77eaefc74d7eb861d3ffbdf91b50a1bb1639514ebe764c47773b833fa2d91"},
|
|
||||||
{file = "grpcio-1.63.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:b005292369d9c1f80bf70c1db1c17c6c342da7576f1c689e8eee4fb0c256af85"},
|
|
||||||
{file = "grpcio-1.63.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:cdcda1156dcc41e042d1e899ba1f5c2e9f3cd7625b3d6ebfa619806a4c1aadda"},
|
|
||||||
{file = "grpcio-1.63.0-cp38-cp38-win32.whl", hash = "sha256:01799e8649f9e94ba7db1aeb3452188048b0019dc37696b0f5ce212c87c560c3"},
|
|
||||||
{file = "grpcio-1.63.0-cp38-cp38-win_amd64.whl", hash = "sha256:6a1a3642d76f887aa4009d92f71eb37809abceb3b7b5a1eec9c554a246f20e3a"},
|
|
||||||
{file = "grpcio-1.63.0-cp39-cp39-linux_armv7l.whl", hash = "sha256:75f701ff645858a2b16bc8c9fc68af215a8bb2d5a9b647448129de6e85d52bce"},
|
|
||||||
{file = "grpcio-1.63.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:cacdef0348a08e475a721967f48206a2254a1b26ee7637638d9e081761a5ba86"},
|
|
||||||
{file = "grpcio-1.63.0-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:0697563d1d84d6985e40ec5ec596ff41b52abb3fd91ec240e8cb44a63b895094"},
|
|
||||||
{file = "grpcio-1.63.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6426e1fb92d006e47476d42b8f240c1d916a6d4423c5258ccc5b105e43438f61"},
|
|
||||||
{file = "grpcio-1.63.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e48cee31bc5f5a31fb2f3b573764bd563aaa5472342860edcc7039525b53e46a"},
|
|
||||||
{file = "grpcio-1.63.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:50344663068041b34a992c19c600236e7abb42d6ec32567916b87b4c8b8833b3"},
|
|
||||||
{file = "grpcio-1.63.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:259e11932230d70ef24a21b9fb5bb947eb4703f57865a404054400ee92f42f5d"},
|
|
||||||
{file = "grpcio-1.63.0-cp39-cp39-win32.whl", hash = "sha256:a44624aad77bf8ca198c55af811fd28f2b3eaf0a50ec5b57b06c034416ef2d0a"},
|
|
||||||
{file = "grpcio-1.63.0-cp39-cp39-win_amd64.whl", hash = "sha256:166e5c460e5d7d4656ff9e63b13e1f6029b122104c1633d5f37eaea348d7356d"},
|
|
||||||
{file = "grpcio-1.63.0.tar.gz", hash = "sha256:f3023e14805c61bc439fb40ca545ac3d5740ce66120a678a3c6c2c55b70343d1"},
|
|
||||||
]
|
|
||||||
|
|
||||||
[package.extras]
|
|
||||||
protobuf = ["grpcio-tools (>=1.63.0)"]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "gym-aloha"
|
name = "gym-aloha"
|
||||||
version = "0.1.1"
|
version = "0.1.1"
|
||||||
|
@ -1668,7 +1600,6 @@ files = [
|
||||||
{file = "lxml-5.2.1-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:9e2addd2d1866fe112bc6f80117bcc6bc25191c5ed1bfbcf9f1386a884252ae8"},
|
{file = "lxml-5.2.1-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:9e2addd2d1866fe112bc6f80117bcc6bc25191c5ed1bfbcf9f1386a884252ae8"},
|
||||||
{file = "lxml-5.2.1-cp37-cp37m-win32.whl", hash = "sha256:f51969bac61441fd31f028d7b3b45962f3ecebf691a510495e5d2cd8c8092dbd"},
|
{file = "lxml-5.2.1-cp37-cp37m-win32.whl", hash = "sha256:f51969bac61441fd31f028d7b3b45962f3ecebf691a510495e5d2cd8c8092dbd"},
|
||||||
{file = "lxml-5.2.1-cp37-cp37m-win_amd64.whl", hash = "sha256:b0b58fbfa1bf7367dde8a557994e3b1637294be6cf2169810375caf8571a085c"},
|
{file = "lxml-5.2.1-cp37-cp37m-win_amd64.whl", hash = "sha256:b0b58fbfa1bf7367dde8a557994e3b1637294be6cf2169810375caf8571a085c"},
|
||||||
{file = "lxml-5.2.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:3e183c6e3298a2ed5af9d7a356ea823bccaab4ec2349dc9ed83999fd289d14d5"},
|
|
||||||
{file = "lxml-5.2.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:804f74efe22b6a227306dd890eecc4f8c59ff25ca35f1f14e7482bbce96ef10b"},
|
{file = "lxml-5.2.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:804f74efe22b6a227306dd890eecc4f8c59ff25ca35f1f14e7482bbce96ef10b"},
|
||||||
{file = "lxml-5.2.1-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:08802f0c56ed150cc6885ae0788a321b73505d2263ee56dad84d200cab11c07a"},
|
{file = "lxml-5.2.1-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:08802f0c56ed150cc6885ae0788a321b73505d2263ee56dad84d200cab11c07a"},
|
||||||
{file = "lxml-5.2.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0f8c09ed18ecb4ebf23e02b8e7a22a05d6411911e6fabef3a36e4f371f4f2585"},
|
{file = "lxml-5.2.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0f8c09ed18ecb4ebf23e02b8e7a22a05d6411911e6fabef3a36e4f371f4f2585"},
|
||||||
|
@ -1740,21 +1671,6 @@ html5 = ["html5lib"]
|
||||||
htmlsoup = ["BeautifulSoup4"]
|
htmlsoup = ["BeautifulSoup4"]
|
||||||
source = ["Cython (>=3.0.10)"]
|
source = ["Cython (>=3.0.10)"]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "markdown"
|
|
||||||
version = "3.6"
|
|
||||||
description = "Python implementation of John Gruber's Markdown."
|
|
||||||
optional = false
|
|
||||||
python-versions = ">=3.8"
|
|
||||||
files = [
|
|
||||||
{file = "Markdown-3.6-py3-none-any.whl", hash = "sha256:48f276f4d8cfb8ce6527c8f79e2ee29708508bf4d40aa410fbc3b4ee832c850f"},
|
|
||||||
{file = "Markdown-3.6.tar.gz", hash = "sha256:ed4f41f6daecbeeb96e576ce414c41d2d876daa9a16cb35fa8ed8c2ddfad0224"},
|
|
||||||
]
|
|
||||||
|
|
||||||
[package.extras]
|
|
||||||
docs = ["mdx-gh-links (>=0.2)", "mkdocs (>=1.5)", "mkdocs-gen-files", "mkdocs-literate-nav", "mkdocs-nature (>=0.6)", "mkdocs-section-index", "mkdocstrings[python]"]
|
|
||||||
testing = ["coverage", "pyyaml"]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "markupsafe"
|
name = "markupsafe"
|
||||||
version = "2.1.5"
|
version = "2.1.5"
|
||||||
|
@ -3056,6 +2972,7 @@ files = [
|
||||||
{file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
|
{file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
|
||||||
{file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
|
{file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
|
||||||
{file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
|
{file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
|
||||||
|
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
|
||||||
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
|
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
|
||||||
{file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
|
{file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
|
||||||
{file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
|
{file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
|
||||||
|
@ -3224,30 +3141,6 @@ typing-extensions = ">=4.5"
|
||||||
[package.extras]
|
[package.extras]
|
||||||
tests = ["pytest (==7.1.2)"]
|
tests = ["pytest (==7.1.2)"]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "robomimic"
|
|
||||||
version = "0.2.0"
|
|
||||||
description = "robomimic: A Modular Framework for Robot Learning from Demonstration"
|
|
||||||
optional = false
|
|
||||||
python-versions = ">=3"
|
|
||||||
files = [
|
|
||||||
{file = "robomimic-0.2.0.tar.gz", hash = "sha256:ee3bb5cf9c3e1feead6b57b43c5db738fd0a8e0c015fdf6419808af8fffdc463"},
|
|
||||||
]
|
|
||||||
|
|
||||||
[package.dependencies]
|
|
||||||
egl_probe = ">=1.0.1"
|
|
||||||
h5py = "*"
|
|
||||||
imageio = "*"
|
|
||||||
imageio-ffmpeg = "*"
|
|
||||||
numpy = ">=1.13.3"
|
|
||||||
psutil = "*"
|
|
||||||
tensorboard = "*"
|
|
||||||
tensorboardX = "*"
|
|
||||||
termcolor = "*"
|
|
||||||
torch = "*"
|
|
||||||
torchvision = "*"
|
|
||||||
tqdm = "*"
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "safetensors"
|
name = "safetensors"
|
||||||
version = "0.4.3"
|
version = "0.4.3"
|
||||||
|
@ -3738,55 +3631,6 @@ files = [
|
||||||
{file = "tbb-2021.12.0-py3-none-win_amd64.whl", hash = "sha256:fc2772d850229f2f3df85f1109c4844c495a2db7433d38200959ee9265b34789"},
|
{file = "tbb-2021.12.0-py3-none-win_amd64.whl", hash = "sha256:fc2772d850229f2f3df85f1109c4844c495a2db7433d38200959ee9265b34789"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "tensorboard"
|
|
||||||
version = "2.16.2"
|
|
||||||
description = "TensorBoard lets you watch Tensors Flow"
|
|
||||||
optional = false
|
|
||||||
python-versions = ">=3.9"
|
|
||||||
files = [
|
|
||||||
{file = "tensorboard-2.16.2-py3-none-any.whl", hash = "sha256:9f2b4e7dad86667615c0e5cd072f1ea8403fc032a299f0072d6f74855775cc45"},
|
|
||||||
]
|
|
||||||
|
|
||||||
[package.dependencies]
|
|
||||||
absl-py = ">=0.4"
|
|
||||||
grpcio = ">=1.48.2"
|
|
||||||
markdown = ">=2.6.8"
|
|
||||||
numpy = ">=1.12.0"
|
|
||||||
protobuf = ">=3.19.6,<4.24.0 || >4.24.0"
|
|
||||||
setuptools = ">=41.0.0"
|
|
||||||
six = ">1.9"
|
|
||||||
tensorboard-data-server = ">=0.7.0,<0.8.0"
|
|
||||||
werkzeug = ">=1.0.1"
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "tensorboard-data-server"
|
|
||||||
version = "0.7.2"
|
|
||||||
description = "Fast data loading for TensorBoard"
|
|
||||||
optional = false
|
|
||||||
python-versions = ">=3.7"
|
|
||||||
files = [
|
|
||||||
{file = "tensorboard_data_server-0.7.2-py3-none-any.whl", hash = "sha256:7e0610d205889588983836ec05dc098e80f97b7e7bbff7e994ebb78f578d0ddb"},
|
|
||||||
{file = "tensorboard_data_server-0.7.2-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:9fe5d24221b29625dbc7328b0436ca7fc1c23de4acf4d272f1180856e32f9f60"},
|
|
||||||
{file = "tensorboard_data_server-0.7.2-py3-none-manylinux_2_31_x86_64.whl", hash = "sha256:ef687163c24185ae9754ed5650eb5bc4d84ff257aabdc33f0cc6f74d8ba54530"},
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "tensorboardx"
|
|
||||||
version = "2.6.2.2"
|
|
||||||
description = "TensorBoardX lets you watch Tensors Flow without Tensorflow"
|
|
||||||
optional = false
|
|
||||||
python-versions = "*"
|
|
||||||
files = [
|
|
||||||
{file = "tensorboardX-2.6.2.2-py2.py3-none-any.whl", hash = "sha256:160025acbf759ede23fd3526ae9d9bfbfd8b68eb16c38a010ebe326dc6395db8"},
|
|
||||||
{file = "tensorboardX-2.6.2.2.tar.gz", hash = "sha256:c6476d7cd0d529b0b72f4acadb1269f9ed8b22f441e87a84f2a3b940bb87b666"},
|
|
||||||
]
|
|
||||||
|
|
||||||
[package.dependencies]
|
|
||||||
numpy = "*"
|
|
||||||
packaging = "*"
|
|
||||||
protobuf = ">=3.20"
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "termcolor"
|
name = "termcolor"
|
||||||
version = "2.4.0"
|
version = "2.4.0"
|
||||||
|
@ -4064,23 +3908,6 @@ perf = ["orjson"]
|
||||||
reports = ["pydantic (>=2.0.0)"]
|
reports = ["pydantic (>=2.0.0)"]
|
||||||
sweeps = ["sweeps (>=0.2.0)"]
|
sweeps = ["sweeps (>=0.2.0)"]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "werkzeug"
|
|
||||||
version = "3.0.3"
|
|
||||||
description = "The comprehensive WSGI web application library."
|
|
||||||
optional = false
|
|
||||||
python-versions = ">=3.8"
|
|
||||||
files = [
|
|
||||||
{file = "werkzeug-3.0.3-py3-none-any.whl", hash = "sha256:fc9645dc43e03e4d630d23143a04a7f947a9a3b5727cd535fdfe155a17cc48c8"},
|
|
||||||
{file = "werkzeug-3.0.3.tar.gz", hash = "sha256:097e5bfda9f0aba8da6b8545146def481d06aa7d3266e7448e2cccf67dd8bd18"},
|
|
||||||
]
|
|
||||||
|
|
||||||
[package.dependencies]
|
|
||||||
MarkupSafe = ">=2.1.1"
|
|
||||||
|
|
||||||
[package.extras]
|
|
||||||
watchdog = ["watchdog (>=2.3)"]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "xxhash"
|
name = "xxhash"
|
||||||
version = "3.4.1"
|
version = "3.4.1"
|
||||||
|
@ -4348,4 +4175,4 @@ xarm = ["gym-xarm"]
|
||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = ">=3.10,<3.13"
|
python-versions = ">=3.10,<3.13"
|
||||||
content-hash = "2f0d2cbf4a2dec546e25b29b9b108ff1f97b4c278b718360b3f7f6a2bf9dcef8"
|
content-hash = "e3e3c306a5519e4f716a1ac086ad9b734efedcac077a0ec71e5bc16349a1e559"
|
||||||
|
|
|
@ -44,7 +44,6 @@ diffusers = "^0.27.2"
|
||||||
torchvision = ">=0.18.0"
|
torchvision = ">=0.18.0"
|
||||||
h5py = ">=3.10.0"
|
h5py = ">=3.10.0"
|
||||||
huggingface-hub = ">=0.21.4"
|
huggingface-hub = ">=0.21.4"
|
||||||
robomimic = "0.2.0"
|
|
||||||
gymnasium = ">=0.29.1"
|
gymnasium = ">=0.29.1"
|
||||||
cmake = ">=3.29.0.1"
|
cmake = ">=3.29.0.1"
|
||||||
gym-pusht = { version = ">=0.1.3", optional = true}
|
gym-pusht = { version = ">=0.1.3", optional = true}
|
||||||
|
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Loading…
Reference in New Issue