Merge branch 'user/alexander-soare/train_pusht' into user/alexander-soare/multistep_policy_and_serial_env

This commit is contained in:
Alexander Soare 2024-03-14 16:06:21 +00:00
commit a222c88c99
5 changed files with 241 additions and 19 deletions

View File

@ -1,15 +1,37 @@
import copy import copy
from typing import Dict, Tuple, Union from typing import Dict, Optional, Tuple, Union
import timm
import torch import torch
import torch.nn as nn import torch.nn as nn
import torchvision import torchvision
from robomimic.models.base_nets import SpatialSoftmax
from lerobot.common.policies.diffusion.model.crop_randomizer import CropRandomizer from lerobot.common.policies.diffusion.model.crop_randomizer import CropRandomizer
from lerobot.common.policies.diffusion.model.module_attr_mixin import ModuleAttrMixin from lerobot.common.policies.diffusion.model.module_attr_mixin import ModuleAttrMixin
from lerobot.common.policies.diffusion.pytorch_utils import replace_submodules from lerobot.common.policies.diffusion.pytorch_utils import replace_submodules
class RgbEncoder(nn.Module):
"""Following `VisualCore` from Robomimic 0.2.0."""
def __init__(self, input_shape, model_name="resnet18", pretrained=False, num_keypoints=32):
"""
resnet_name: a timm model name.
pretrained: whether to use timm pretrained weights.
num_keypoints: Number of keypoints for SpatialSoftmax (default value of 32 matches PushT Image).
"""
super().__init__()
self.backbone = timm.create_model(model_name, pretrained, num_classes=0, global_pool="")
# Figure out the feature map shape.
with torch.inference_mode():
feat_map_shape = tuple(self.backbone(torch.zeros(size=(1, *input_shape))).shape[1:])
self.pool = SpatialSoftmax(feat_map_shape, num_kp=num_keypoints)
def forward(self, x):
return torch.flatten(self.pool(self.backbone(x)), start_dim=1)
class MultiImageObsEncoder(ModuleAttrMixin): class MultiImageObsEncoder(ModuleAttrMixin):
def __init__( def __init__(
self, self,
@ -24,7 +46,7 @@ class MultiImageObsEncoder(ModuleAttrMixin):
share_rgb_model: bool = False, share_rgb_model: bool = False,
# renormalize rgb input with imagenet normalization # renormalize rgb input with imagenet normalization
# assuming input in [0,1] # assuming input in [0,1]
imagenet_norm: bool = False, norm_mean_std: Optional[tuple[float, float]] = None,
): ):
""" """
Assumes rgb input: B,C,H,W Assumes rgb input: B,C,H,W
@ -98,10 +120,9 @@ class MultiImageObsEncoder(ModuleAttrMixin):
this_normalizer = torchvision.transforms.CenterCrop(size=(h, w)) this_normalizer = torchvision.transforms.CenterCrop(size=(h, w))
# configure normalizer # configure normalizer
this_normalizer = nn.Identity() this_normalizer = nn.Identity()
if imagenet_norm: if norm_mean_std is not None:
# TODO(rcadene): move normalizer to dataset and env
this_normalizer = torchvision.transforms.Normalize( this_normalizer = torchvision.transforms.Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] mean=norm_mean_std[0], std=norm_mean_std[1]
) )
this_transform = nn.Sequential(this_resizer, this_randomizer, this_normalizer) this_transform = nn.Sequential(this_resizer, this_randomizer, this_normalizer)

View File

@ -7,7 +7,7 @@ import torch
from lerobot.common.policies.abstract import AbstractPolicy from lerobot.common.policies.abstract import AbstractPolicy
from lerobot.common.policies.diffusion.diffusion_unet_image_policy import DiffusionUnetImagePolicy from lerobot.common.policies.diffusion.diffusion_unet_image_policy import DiffusionUnetImagePolicy
from lerobot.common.policies.diffusion.model.lr_scheduler import get_scheduler from lerobot.common.policies.diffusion.model.lr_scheduler import get_scheduler
from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder, RgbEncoder
class DiffusionPolicy(AbstractPolicy): class DiffusionPolicy(AbstractPolicy):
@ -38,7 +38,7 @@ class DiffusionPolicy(AbstractPolicy):
self.cfg = cfg self.cfg = cfg
noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler) noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler)
rgb_model = hydra.utils.instantiate(cfg_rgb_model) rgb_model = RgbEncoder(input_shape=shape_meta.obs.image.shape, **cfg_rgb_model)
obs_encoder = MultiImageObsEncoder( obs_encoder = MultiImageObsEncoder(
rgb_model=rgb_model, rgb_model=rgb_model,
**cfg_obs_encoder, **cfg_obs_encoder,

View File

@ -42,8 +42,8 @@ policy:
num_inference_steps: 100 num_inference_steps: 100
obs_as_global_cond: ${obs_as_global_cond} obs_as_global_cond: ${obs_as_global_cond}
# crop_shape: null # crop_shape: null
diffusion_step_embed_dim: 256 # before 128 diffusion_step_embed_dim: 128
down_dims: [256, 512, 1024] # before [512, 1024, 2048] down_dims: [512, 1024, 2048]
kernel_size: 5 kernel_size: 5
n_groups: 8 n_groups: 8
cond_predict_scale: True cond_predict_scale: True
@ -81,12 +81,12 @@ obs_encoder:
# random_crop: True # random_crop: True
use_group_norm: True use_group_norm: True
share_rgb_model: False share_rgb_model: False
imagenet_norm: True norm_mean_std: [0.5, 0.5] # for PushT the original impl normalizes to [-1, 1] (maybe not the case for robomimic envs)
rgb_model: rgb_model:
_target_: lerobot.common.policies.diffusion.pytorch_utils.get_resnet model_name: resnet18
name: resnet18 pretrained: false
weights: null num_keypoints: 32
ema: ema:
_target_: lerobot.common.policies.diffusion.model.ema_model.EMAModel _target_: lerobot.common.policies.diffusion.model.ema_model.EMAModel
@ -109,13 +109,13 @@ training:
debug: False debug: False
resume: True resume: True
# optimization # optimization
# lr_scheduler: cosine lr_scheduler: cosine
# lr_warmup_steps: 500 lr_warmup_steps: 500
num_epochs: 8000 num_epochs: 500
# gradient_accumulate_every: 1 # gradient_accumulate_every: 1
# EMA destroys performance when used with BatchNorm # EMA destroys performance when used with BatchNorm
# replace BatchNorm with GroupNorm. # replace BatchNorm with GroupNorm.
# use_ema: True use_ema: True
freeze_encoder: False freeze_encoder: False
# training loop control # training loop control
# in epochs # in epochs

203
poetry.lock generated
View File

@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.8.1 and should not be changed by hand. # This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
[[package]] [[package]]
name = "absl-py" name = "absl-py"
@ -604,6 +604,16 @@ files = [
[package.dependencies] [package.dependencies]
six = ">=1.4.0" six = ">=1.4.0"
[[package]]
name = "egl-probe"
version = "1.0.2"
description = ""
optional = false
python-versions = "*"
files = [
{file = "egl_probe-1.0.2.tar.gz", hash = "sha256:29bdca7b08da1e060cfb42cd46af8300a7ac4f3b1b2eeb16e545ea16d9a5ac93"},
]
[[package]] [[package]]
name = "einops" name = "einops"
version = "0.7.0" version = "0.7.0"
@ -763,6 +773,72 @@ files = [
[package.extras] [package.extras]
preview = ["glfw-preview"] preview = ["glfw-preview"]
[[package]]
name = "grpcio"
version = "1.62.1"
description = "HTTP/2-based RPC framework"
optional = false
python-versions = ">=3.7"
files = [
{file = "grpcio-1.62.1-cp310-cp310-linux_armv7l.whl", hash = "sha256:179bee6f5ed7b5f618844f760b6acf7e910988de77a4f75b95bbfaa8106f3c1e"},
{file = "grpcio-1.62.1-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:48611e4fa010e823ba2de8fd3f77c1322dd60cb0d180dc6630a7e157b205f7ea"},
{file = "grpcio-1.62.1-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:b2a0e71b0a2158aa4bce48be9f8f9eb45cbd17c78c7443616d00abbe2a509f6d"},
{file = "grpcio-1.62.1-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fbe80577c7880911d3ad65e5ecc997416c98f354efeba2f8d0f9112a67ed65a5"},
{file = "grpcio-1.62.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:58f6c693d446964e3292425e1d16e21a97a48ba9172f2d0df9d7b640acb99243"},
{file = "grpcio-1.62.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:77c339403db5a20ef4fed02e4d1a9a3d9866bf9c0afc77a42234677313ea22f3"},
{file = "grpcio-1.62.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:b5a4ea906db7dec694098435d84bf2854fe158eb3cd51e1107e571246d4d1d70"},
{file = "grpcio-1.62.1-cp310-cp310-win32.whl", hash = "sha256:4187201a53f8561c015bc745b81a1b2d278967b8de35f3399b84b0695e281d5f"},
{file = "grpcio-1.62.1-cp310-cp310-win_amd64.whl", hash = "sha256:844d1f3fb11bd1ed362d3fdc495d0770cfab75761836193af166fee113421d66"},
{file = "grpcio-1.62.1-cp311-cp311-linux_armv7l.whl", hash = "sha256:833379943d1728a005e44103f17ecd73d058d37d95783eb8f0b28ddc1f54d7b2"},
{file = "grpcio-1.62.1-cp311-cp311-macosx_10_10_universal2.whl", hash = "sha256:c7fcc6a32e7b7b58f5a7d27530669337a5d587d4066060bcb9dee7a8c833dfb7"},
{file = "grpcio-1.62.1-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:fa7d28eb4d50b7cbe75bb8b45ed0da9a1dc5b219a0af59449676a29c2eed9698"},
{file = "grpcio-1.62.1-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:48f7135c3de2f298b833be8b4ae20cafe37091634e91f61f5a7eb3d61ec6f660"},
{file = "grpcio-1.62.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:71f11fd63365ade276c9d4a7b7df5c136f9030e3457107e1791b3737a9b9ed6a"},
{file = "grpcio-1.62.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:4b49fd8fe9f9ac23b78437da94c54aa7e9996fbb220bac024a67469ce5d0825f"},
{file = "grpcio-1.62.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:482ae2ae78679ba9ed5752099b32e5fe580443b4f798e1b71df412abf43375db"},
{file = "grpcio-1.62.1-cp311-cp311-win32.whl", hash = "sha256:1faa02530b6c7426404372515fe5ddf66e199c2ee613f88f025c6f3bd816450c"},
{file = "grpcio-1.62.1-cp311-cp311-win_amd64.whl", hash = "sha256:5bd90b8c395f39bc82a5fb32a0173e220e3f401ff697840f4003e15b96d1befc"},
{file = "grpcio-1.62.1-cp312-cp312-linux_armv7l.whl", hash = "sha256:b134d5d71b4e0837fff574c00e49176051a1c532d26c052a1e43231f252d813b"},
{file = "grpcio-1.62.1-cp312-cp312-macosx_10_10_universal2.whl", hash = "sha256:d1f6c96573dc09d50dbcbd91dbf71d5cf97640c9427c32584010fbbd4c0e0037"},
{file = "grpcio-1.62.1-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:359f821d4578f80f41909b9ee9b76fb249a21035a061a327f91c953493782c31"},
{file = "grpcio-1.62.1-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a485f0c2010c696be269184bdb5ae72781344cb4e60db976c59d84dd6354fac9"},
{file = "grpcio-1.62.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b50b09b4dc01767163d67e1532f948264167cd27f49e9377e3556c3cba1268e1"},
{file = "grpcio-1.62.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:3227c667dccbe38f2c4d943238b887bac588d97c104815aecc62d2fd976e014b"},
{file = "grpcio-1.62.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:3952b581eb121324853ce2b191dae08badb75cd493cb4e0243368aa9e61cfd41"},
{file = "grpcio-1.62.1-cp312-cp312-win32.whl", hash = "sha256:83a17b303425104d6329c10eb34bba186ffa67161e63fa6cdae7776ff76df73f"},
{file = "grpcio-1.62.1-cp312-cp312-win_amd64.whl", hash = "sha256:6696ffe440333a19d8d128e88d440f91fb92c75a80ce4b44d55800e656a3ef1d"},
{file = "grpcio-1.62.1-cp37-cp37m-linux_armv7l.whl", hash = "sha256:e3393b0823f938253370ebef033c9fd23d27f3eae8eb9a8f6264900c7ea3fb5a"},
{file = "grpcio-1.62.1-cp37-cp37m-macosx_10_10_universal2.whl", hash = "sha256:83e7ccb85a74beaeae2634f10eb858a0ed1a63081172649ff4261f929bacfd22"},
{file = "grpcio-1.62.1-cp37-cp37m-manylinux_2_17_aarch64.whl", hash = "sha256:882020c87999d54667a284c7ddf065b359bd00251fcd70279ac486776dbf84ec"},
{file = "grpcio-1.62.1-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a10383035e864f386fe096fed5c47d27a2bf7173c56a6e26cffaaa5a361addb1"},
{file = "grpcio-1.62.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:960edebedc6b9ada1ef58e1c71156f28689978188cd8cff3b646b57288a927d9"},
{file = "grpcio-1.62.1-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:23e2e04b83f347d0aadde0c9b616f4726c3d76db04b438fd3904b289a725267f"},
{file = "grpcio-1.62.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:978121758711916d34fe57c1f75b79cdfc73952f1481bb9583399331682d36f7"},
{file = "grpcio-1.62.1-cp37-cp37m-win_amd64.whl", hash = "sha256:9084086190cc6d628f282e5615f987288b95457292e969b9205e45b442276407"},
{file = "grpcio-1.62.1-cp38-cp38-linux_armv7l.whl", hash = "sha256:22bccdd7b23c420a27fd28540fb5dcbc97dc6be105f7698cb0e7d7a420d0e362"},
{file = "grpcio-1.62.1-cp38-cp38-macosx_10_10_universal2.whl", hash = "sha256:8999bf1b57172dbc7c3e4bb3c732658e918f5c333b2942243f10d0d653953ba9"},
{file = "grpcio-1.62.1-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:d9e52558b8b8c2f4ac05ac86344a7417ccdd2b460a59616de49eb6933b07a0bd"},
{file = "grpcio-1.62.1-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1714e7bc935780bc3de1b3fcbc7674209adf5208ff825799d579ffd6cd0bd505"},
{file = "grpcio-1.62.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c8842ccbd8c0e253c1f189088228f9b433f7a93b7196b9e5b6f87dba393f5d5d"},
{file = "grpcio-1.62.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:1f1e7b36bdff50103af95a80923bf1853f6823dd62f2d2a2524b66ed74103e49"},
{file = "grpcio-1.62.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:bba97b8e8883a8038606480d6b6772289f4c907f6ba780fa1f7b7da7dfd76f06"},
{file = "grpcio-1.62.1-cp38-cp38-win32.whl", hash = "sha256:a7f615270fe534548112a74e790cd9d4f5509d744dd718cd442bf016626c22e4"},
{file = "grpcio-1.62.1-cp38-cp38-win_amd64.whl", hash = "sha256:e6c8c8693df718c5ecbc7babb12c69a4e3677fd11de8886f05ab22d4e6b1c43b"},
{file = "grpcio-1.62.1-cp39-cp39-linux_armv7l.whl", hash = "sha256:73db2dc1b201d20ab7083e7041946910bb991e7e9761a0394bbc3c2632326483"},
{file = "grpcio-1.62.1-cp39-cp39-macosx_10_10_universal2.whl", hash = "sha256:407b26b7f7bbd4f4751dbc9767a1f0716f9fe72d3d7e96bb3ccfc4aace07c8de"},
{file = "grpcio-1.62.1-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:f8de7c8cef9261a2d0a62edf2ccea3d741a523c6b8a6477a340a1f2e417658de"},
{file = "grpcio-1.62.1-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9bd5c8a1af40ec305d001c60236308a67e25419003e9bb3ebfab5695a8d0b369"},
{file = "grpcio-1.62.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:be0477cb31da67846a33b1a75c611f88bfbcd427fe17701b6317aefceee1b96f"},
{file = "grpcio-1.62.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:60dcd824df166ba266ee0cfaf35a31406cd16ef602b49f5d4dfb21f014b0dedd"},
{file = "grpcio-1.62.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:973c49086cabab773525f6077f95e5a993bfc03ba8fc32e32f2c279497780585"},
{file = "grpcio-1.62.1-cp39-cp39-win32.whl", hash = "sha256:12859468e8918d3bd243d213cd6fd6ab07208195dc140763c00dfe901ce1e1b4"},
{file = "grpcio-1.62.1-cp39-cp39-win_amd64.whl", hash = "sha256:b7209117bbeebdfa5d898205cc55153a51285757902dd73c47de498ad4d11332"},
{file = "grpcio-1.62.1.tar.gz", hash = "sha256:6c455e008fa86d9e9a9d85bb76da4277c0d7d9668a3bfa70dbe86e9f3c759947"},
]
[package.extras]
protobuf = ["grpcio-tools (>=1.62.1)"]
[[package]] [[package]]
name = "gym" name = "gym"
version = "0.26.2" version = "0.26.2"
@ -1192,6 +1268,21 @@ html5 = ["html5lib"]
htmlsoup = ["BeautifulSoup4"] htmlsoup = ["BeautifulSoup4"]
source = ["Cython (>=3.0.7)"] source = ["Cython (>=3.0.7)"]
[[package]]
name = "markdown"
version = "3.5.2"
description = "Python implementation of John Gruber's Markdown."
optional = false
python-versions = ">=3.8"
files = [
{file = "Markdown-3.5.2-py3-none-any.whl", hash = "sha256:d43323865d89fc0cb9b20c75fc8ad313af307cc087e84b657d9eec768eddeadd"},
{file = "Markdown-3.5.2.tar.gz", hash = "sha256:e1ac7b3dc550ee80e602e71c1d168002f062e49f1b11e26a36264dafd4df2ef8"},
]
[package.extras]
docs = ["mdx-gh-links (>=0.2)", "mkdocs (>=1.5)", "mkdocs-gen-files", "mkdocs-literate-nav", "mkdocs-nature (>=0.6)", "mkdocs-section-index", "mkdocstrings[python]"]
testing = ["coverage", "pyyaml"]
[[package]] [[package]]
name = "markupsafe" name = "markupsafe"
version = "2.1.5" version = "2.1.5"
@ -2387,6 +2478,30 @@ urllib3 = ">=1.21.1,<3"
socks = ["PySocks (>=1.5.6,!=1.5.7)"] socks = ["PySocks (>=1.5.6,!=1.5.7)"]
use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
[[package]]
name = "robomimic"
version = "0.2.0"
description = "robomimic: A Modular Framework for Robot Learning from Demonstration"
optional = false
python-versions = ">=3"
files = [
{file = "robomimic-0.2.0.tar.gz", hash = "sha256:ee3bb5cf9c3e1feead6b57b43c5db738fd0a8e0c015fdf6419808af8fffdc463"},
]
[package.dependencies]
egl_probe = ">=1.0.1"
h5py = "*"
imageio = "*"
imageio-ffmpeg = "*"
numpy = ">=1.13.3"
psutil = "*"
tensorboard = "*"
tensorboardX = "*"
termcolor = "*"
torch = "*"
torchvision = "*"
tqdm = "*"
[[package]] [[package]]
name = "safetensors" name = "safetensors"
version = "0.4.2" version = "0.4.2"
@ -2874,6 +2989,55 @@ files = [
[package.dependencies] [package.dependencies]
mpmath = ">=0.19" mpmath = ">=0.19"
[[package]]
name = "tensorboard"
version = "2.16.2"
description = "TensorBoard lets you watch Tensors Flow"
optional = false
python-versions = ">=3.9"
files = [
{file = "tensorboard-2.16.2-py3-none-any.whl", hash = "sha256:9f2b4e7dad86667615c0e5cd072f1ea8403fc032a299f0072d6f74855775cc45"},
]
[package.dependencies]
absl-py = ">=0.4"
grpcio = ">=1.48.2"
markdown = ">=2.6.8"
numpy = ">=1.12.0"
protobuf = ">=3.19.6,<4.24.0 || >4.24.0"
setuptools = ">=41.0.0"
six = ">1.9"
tensorboard-data-server = ">=0.7.0,<0.8.0"
werkzeug = ">=1.0.1"
[[package]]
name = "tensorboard-data-server"
version = "0.7.2"
description = "Fast data loading for TensorBoard"
optional = false
python-versions = ">=3.7"
files = [
{file = "tensorboard_data_server-0.7.2-py3-none-any.whl", hash = "sha256:7e0610d205889588983836ec05dc098e80f97b7e7bbff7e994ebb78f578d0ddb"},
{file = "tensorboard_data_server-0.7.2-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:9fe5d24221b29625dbc7328b0436ca7fc1c23de4acf4d272f1180856e32f9f60"},
{file = "tensorboard_data_server-0.7.2-py3-none-manylinux_2_31_x86_64.whl", hash = "sha256:ef687163c24185ae9754ed5650eb5bc4d84ff257aabdc33f0cc6f74d8ba54530"},
]
[[package]]
name = "tensorboardx"
version = "2.6.2.2"
description = "TensorBoardX lets you watch Tensors Flow without Tensorflow"
optional = false
python-versions = "*"
files = [
{file = "tensorboardX-2.6.2.2-py2.py3-none-any.whl", hash = "sha256:160025acbf759ede23fd3526ae9d9bfbfd8b68eb16c38a010ebe326dc6395db8"},
{file = "tensorboardX-2.6.2.2.tar.gz", hash = "sha256:c6476d7cd0d529b0b72f4acadb1269f9ed8b22f441e87a84f2a3b940bb87b666"},
]
[package.dependencies]
numpy = "*"
packaging = "*"
protobuf = ">=3.20"
[[package]] [[package]]
name = "tensordict" name = "tensordict"
version = "0.4.0+551331d" version = "0.4.0+551331d"
@ -2930,6 +3094,24 @@ numpy = "*"
[package.extras] [package.extras]
all = ["defusedxml", "fsspec", "imagecodecs (>=2023.8.12)", "lxml", "matplotlib", "zarr"] all = ["defusedxml", "fsspec", "imagecodecs (>=2023.8.12)", "lxml", "matplotlib", "zarr"]
[[package]]
name = "timm"
version = "0.9.16"
description = "PyTorch Image Models"
optional = false
python-versions = ">=3.8"
files = [
{file = "timm-0.9.16-py3-none-any.whl", hash = "sha256:bf5704014476ab011589d3c14172ee4c901fd18f9110a928019cac5be2945914"},
{file = "timm-0.9.16.tar.gz", hash = "sha256:891e54f375d55adf31a71ab0c117761f0e472f9f3971858ecdd1e7376b7071e6"},
]
[package.dependencies]
huggingface_hub = "*"
pyyaml = "*"
safetensors = "*"
torch = "*"
torchvision = "*"
[[package]] [[package]]
name = "tomli" name = "tomli"
version = "2.0.1" version = "2.0.1"
@ -3215,6 +3397,23 @@ perf = ["orjson"]
reports = ["pydantic (>=2.0.0)"] reports = ["pydantic (>=2.0.0)"]
sweeps = ["sweeps (>=0.2.0)"] sweeps = ["sweeps (>=0.2.0)"]
[[package]]
name = "werkzeug"
version = "3.0.1"
description = "The comprehensive WSGI web application library."
optional = false
python-versions = ">=3.8"
files = [
{file = "werkzeug-3.0.1-py3-none-any.whl", hash = "sha256:90a285dc0e42ad56b34e696398b8122ee4c681833fb35b8334a095d82c56da10"},
{file = "werkzeug-3.0.1.tar.gz", hash = "sha256:507e811ecea72b18a404947aded4b3390e1db8f826b494d76550ef45bb3b1dcc"},
]
[package.dependencies]
MarkupSafe = ">=2.1.1"
[package.extras]
watchdog = ["watchdog (>=2.3)"]
[[package]] [[package]]
name = "zarr" name = "zarr"
version = "2.17.1" version = "2.17.1"
@ -3254,4 +3453,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = "^3.10" python-versions = "^3.10"
content-hash = "3d82309a7b2388d774b56ceb6f6906ef0732d8cedda0d76cc84a30e239949be8" content-hash = "d7c551c44380ac26e784390f077d902287c40582127db91f9a85ab5d7707ad59"

View File

@ -49,6 +49,8 @@ opencv-python = "^4.9.0.80"
diffusers = "^0.26.3" diffusers = "^0.26.3"
torchvision = "^0.17.1" torchvision = "^0.17.1"
h5py = "^3.10.0" h5py = "^3.10.0"
robomimic = "0.2.0"
timm = "^0.9.16"
dm-control = "1.0.14" dm-control = "1.0.14"