From 2a014874949a58c14499052f07bb7a620c2e0faa Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Mon, 11 Mar 2024 13:34:04 +0000 Subject: [PATCH 01/16] early training loss as expected --- .../model/multi_image_obs_encoder.py | 25 ++- lerobot/common/policies/diffusion/policy.py | 4 +- lerobot/configs/policy/diffusion.yaml | 6 +- poetry.lock | 203 +++++++++++++++++- pyproject.toml | 2 + 5 files changed, 232 insertions(+), 8 deletions(-) diff --git a/lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py b/lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py index 94dc6f49..0b4bba7d 100644 --- a/lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py +++ b/lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py @@ -1,15 +1,37 @@ import copy from typing import Dict, Tuple, Union +import timm import torch import torch.nn as nn import torchvision +from robomimic.models.base_nets import SpatialSoftmax from lerobot.common.policies.diffusion.model.crop_randomizer import CropRandomizer from lerobot.common.policies.diffusion.model.module_attr_mixin import ModuleAttrMixin from lerobot.common.policies.diffusion.pytorch_utils import replace_submodules +class RgbEncoder(nn.Module): + """Following `VisualCore` from Robomimic 0.2.0.""" + + def __init__(self, input_shape, model_name="resnet18", pretrained=False, num_keypoints=32): + """ + resnet_name: a timm model name. + pretrained: whether to use timm pretrained weights. + num_keypoints: Number of keypoints for SpatialSoftmax (default value of 32 matches PushT Image). + """ + super().__init__() + self.backbone = timm.create_model(model_name, pretrained, num_classes=0, global_pool="") + # Figure out the feature map shape. + with torch.inference_mode(): + feat_map_shape = tuple(self.backbone(torch.zeros(size=(1, *input_shape))).shape[1:]) + self.pool = SpatialSoftmax(feat_map_shape, num_kp=num_keypoints) + + def forward(self, x): + return torch.flatten(self.pool(self.backbone(x)), start_dim=1) + + class MultiImageObsEncoder(ModuleAttrMixin): def __init__( self, @@ -101,7 +123,8 @@ class MultiImageObsEncoder(ModuleAttrMixin): if imagenet_norm: # TODO(rcadene): move normalizer to dataset and env this_normalizer = torchvision.transforms.Normalize( - mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + mean=[127.5, 127.5, 127.5], + std=[127.5, 127.5, 127.5], ) this_transform = nn.Sequential(this_resizer, this_randomizer, this_normalizer) diff --git a/lerobot/common/policies/diffusion/policy.py b/lerobot/common/policies/diffusion/policy.py index 3df76aa4..1a7e7772 100644 --- a/lerobot/common/policies/diffusion/policy.py +++ b/lerobot/common/policies/diffusion/policy.py @@ -7,7 +7,7 @@ import torch.nn as nn from lerobot.common.policies.diffusion.diffusion_unet_image_policy import DiffusionUnetImagePolicy from lerobot.common.policies.diffusion.model.lr_scheduler import get_scheduler -from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder +from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder, RgbEncoder class DiffusionPolicy(nn.Module): @@ -38,7 +38,7 @@ class DiffusionPolicy(nn.Module): self.cfg = cfg noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler) - rgb_model = hydra.utils.instantiate(cfg_rgb_model) + rgb_model = RgbEncoder(input_shape=shape_meta.obs.image.shape, **cfg_rgb_model) obs_encoder = MultiImageObsEncoder( rgb_model=rgb_model, **cfg_obs_encoder, diff --git a/lerobot/configs/policy/diffusion.yaml 
b/lerobot/configs/policy/diffusion.yaml index 0dae5056..28fd4e4e 100644 --- a/lerobot/configs/policy/diffusion.yaml +++ b/lerobot/configs/policy/diffusion.yaml @@ -84,9 +84,9 @@ obs_encoder: imagenet_norm: True rgb_model: - _target_: lerobot.common.policies.diffusion.pytorch_utils.get_resnet - name: resnet18 - weights: null + model_name: resnet18 + pretrained: false + num_keypoints: 32 ema: _target_: lerobot.common.policies.diffusion.model.ema_model.EMAModel diff --git a/poetry.lock b/poetry.lock index db4f8f3e..85ab6fa0 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. [[package]] name = "absl-py" @@ -573,6 +573,16 @@ files = [ [package.dependencies] six = ">=1.4.0" +[[package]] +name = "egl-probe" +version = "1.0.2" +description = "" +optional = false +python-versions = "*" +files = [ + {file = "egl_probe-1.0.2.tar.gz", hash = "sha256:29bdca7b08da1e060cfb42cd46af8300a7ac4f3b1b2eeb16e545ea16d9a5ac93"}, +] + [[package]] name = "einops" version = "0.7.0" @@ -769,6 +779,72 @@ files = [ [package.extras] preview = ["glfw-preview"] +[[package]] +name = "grpcio" +version = "1.62.1" +description = "HTTP/2-based RPC framework" +optional = false +python-versions = ">=3.7" +files = [ + {file = "grpcio-1.62.1-cp310-cp310-linux_armv7l.whl", hash = "sha256:179bee6f5ed7b5f618844f760b6acf7e910988de77a4f75b95bbfaa8106f3c1e"}, + {file = "grpcio-1.62.1-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:48611e4fa010e823ba2de8fd3f77c1322dd60cb0d180dc6630a7e157b205f7ea"}, + {file = "grpcio-1.62.1-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:b2a0e71b0a2158aa4bce48be9f8f9eb45cbd17c78c7443616d00abbe2a509f6d"}, + {file = "grpcio-1.62.1-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fbe80577c7880911d3ad65e5ecc997416c98f354efeba2f8d0f9112a67ed65a5"}, + {file = "grpcio-1.62.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:58f6c693d446964e3292425e1d16e21a97a48ba9172f2d0df9d7b640acb99243"}, + {file = "grpcio-1.62.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:77c339403db5a20ef4fed02e4d1a9a3d9866bf9c0afc77a42234677313ea22f3"}, + {file = "grpcio-1.62.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:b5a4ea906db7dec694098435d84bf2854fe158eb3cd51e1107e571246d4d1d70"}, + {file = "grpcio-1.62.1-cp310-cp310-win32.whl", hash = "sha256:4187201a53f8561c015bc745b81a1b2d278967b8de35f3399b84b0695e281d5f"}, + {file = "grpcio-1.62.1-cp310-cp310-win_amd64.whl", hash = "sha256:844d1f3fb11bd1ed362d3fdc495d0770cfab75761836193af166fee113421d66"}, + {file = "grpcio-1.62.1-cp311-cp311-linux_armv7l.whl", hash = "sha256:833379943d1728a005e44103f17ecd73d058d37d95783eb8f0b28ddc1f54d7b2"}, + {file = "grpcio-1.62.1-cp311-cp311-macosx_10_10_universal2.whl", hash = "sha256:c7fcc6a32e7b7b58f5a7d27530669337a5d587d4066060bcb9dee7a8c833dfb7"}, + {file = "grpcio-1.62.1-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:fa7d28eb4d50b7cbe75bb8b45ed0da9a1dc5b219a0af59449676a29c2eed9698"}, + {file = "grpcio-1.62.1-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:48f7135c3de2f298b833be8b4ae20cafe37091634e91f61f5a7eb3d61ec6f660"}, + {file = "grpcio-1.62.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:71f11fd63365ade276c9d4a7b7df5c136f9030e3457107e1791b3737a9b9ed6a"}, + {file = "grpcio-1.62.1-cp311-cp311-musllinux_1_1_i686.whl", 
hash = "sha256:4b49fd8fe9f9ac23b78437da94c54aa7e9996fbb220bac024a67469ce5d0825f"}, + {file = "grpcio-1.62.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:482ae2ae78679ba9ed5752099b32e5fe580443b4f798e1b71df412abf43375db"}, + {file = "grpcio-1.62.1-cp311-cp311-win32.whl", hash = "sha256:1faa02530b6c7426404372515fe5ddf66e199c2ee613f88f025c6f3bd816450c"}, + {file = "grpcio-1.62.1-cp311-cp311-win_amd64.whl", hash = "sha256:5bd90b8c395f39bc82a5fb32a0173e220e3f401ff697840f4003e15b96d1befc"}, + {file = "grpcio-1.62.1-cp312-cp312-linux_armv7l.whl", hash = "sha256:b134d5d71b4e0837fff574c00e49176051a1c532d26c052a1e43231f252d813b"}, + {file = "grpcio-1.62.1-cp312-cp312-macosx_10_10_universal2.whl", hash = "sha256:d1f6c96573dc09d50dbcbd91dbf71d5cf97640c9427c32584010fbbd4c0e0037"}, + {file = "grpcio-1.62.1-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:359f821d4578f80f41909b9ee9b76fb249a21035a061a327f91c953493782c31"}, + {file = "grpcio-1.62.1-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a485f0c2010c696be269184bdb5ae72781344cb4e60db976c59d84dd6354fac9"}, + {file = "grpcio-1.62.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b50b09b4dc01767163d67e1532f948264167cd27f49e9377e3556c3cba1268e1"}, + {file = "grpcio-1.62.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:3227c667dccbe38f2c4d943238b887bac588d97c104815aecc62d2fd976e014b"}, + {file = "grpcio-1.62.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:3952b581eb121324853ce2b191dae08badb75cd493cb4e0243368aa9e61cfd41"}, + {file = "grpcio-1.62.1-cp312-cp312-win32.whl", hash = "sha256:83a17b303425104d6329c10eb34bba186ffa67161e63fa6cdae7776ff76df73f"}, + {file = "grpcio-1.62.1-cp312-cp312-win_amd64.whl", hash = "sha256:6696ffe440333a19d8d128e88d440f91fb92c75a80ce4b44d55800e656a3ef1d"}, + {file = "grpcio-1.62.1-cp37-cp37m-linux_armv7l.whl", hash = "sha256:e3393b0823f938253370ebef033c9fd23d27f3eae8eb9a8f6264900c7ea3fb5a"}, + {file = "grpcio-1.62.1-cp37-cp37m-macosx_10_10_universal2.whl", hash = "sha256:83e7ccb85a74beaeae2634f10eb858a0ed1a63081172649ff4261f929bacfd22"}, + {file = "grpcio-1.62.1-cp37-cp37m-manylinux_2_17_aarch64.whl", hash = "sha256:882020c87999d54667a284c7ddf065b359bd00251fcd70279ac486776dbf84ec"}, + {file = "grpcio-1.62.1-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a10383035e864f386fe096fed5c47d27a2bf7173c56a6e26cffaaa5a361addb1"}, + {file = "grpcio-1.62.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:960edebedc6b9ada1ef58e1c71156f28689978188cd8cff3b646b57288a927d9"}, + {file = "grpcio-1.62.1-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:23e2e04b83f347d0aadde0c9b616f4726c3d76db04b438fd3904b289a725267f"}, + {file = "grpcio-1.62.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:978121758711916d34fe57c1f75b79cdfc73952f1481bb9583399331682d36f7"}, + {file = "grpcio-1.62.1-cp37-cp37m-win_amd64.whl", hash = "sha256:9084086190cc6d628f282e5615f987288b95457292e969b9205e45b442276407"}, + {file = "grpcio-1.62.1-cp38-cp38-linux_armv7l.whl", hash = "sha256:22bccdd7b23c420a27fd28540fb5dcbc97dc6be105f7698cb0e7d7a420d0e362"}, + {file = "grpcio-1.62.1-cp38-cp38-macosx_10_10_universal2.whl", hash = "sha256:8999bf1b57172dbc7c3e4bb3c732658e918f5c333b2942243f10d0d653953ba9"}, + {file = "grpcio-1.62.1-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:d9e52558b8b8c2f4ac05ac86344a7417ccdd2b460a59616de49eb6933b07a0bd"}, + {file = "grpcio-1.62.1-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = 
"sha256:1714e7bc935780bc3de1b3fcbc7674209adf5208ff825799d579ffd6cd0bd505"}, + {file = "grpcio-1.62.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c8842ccbd8c0e253c1f189088228f9b433f7a93b7196b9e5b6f87dba393f5d5d"}, + {file = "grpcio-1.62.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:1f1e7b36bdff50103af95a80923bf1853f6823dd62f2d2a2524b66ed74103e49"}, + {file = "grpcio-1.62.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:bba97b8e8883a8038606480d6b6772289f4c907f6ba780fa1f7b7da7dfd76f06"}, + {file = "grpcio-1.62.1-cp38-cp38-win32.whl", hash = "sha256:a7f615270fe534548112a74e790cd9d4f5509d744dd718cd442bf016626c22e4"}, + {file = "grpcio-1.62.1-cp38-cp38-win_amd64.whl", hash = "sha256:e6c8c8693df718c5ecbc7babb12c69a4e3677fd11de8886f05ab22d4e6b1c43b"}, + {file = "grpcio-1.62.1-cp39-cp39-linux_armv7l.whl", hash = "sha256:73db2dc1b201d20ab7083e7041946910bb991e7e9761a0394bbc3c2632326483"}, + {file = "grpcio-1.62.1-cp39-cp39-macosx_10_10_universal2.whl", hash = "sha256:407b26b7f7bbd4f4751dbc9767a1f0716f9fe72d3d7e96bb3ccfc4aace07c8de"}, + {file = "grpcio-1.62.1-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:f8de7c8cef9261a2d0a62edf2ccea3d741a523c6b8a6477a340a1f2e417658de"}, + {file = "grpcio-1.62.1-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9bd5c8a1af40ec305d001c60236308a67e25419003e9bb3ebfab5695a8d0b369"}, + {file = "grpcio-1.62.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:be0477cb31da67846a33b1a75c611f88bfbcd427fe17701b6317aefceee1b96f"}, + {file = "grpcio-1.62.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:60dcd824df166ba266ee0cfaf35a31406cd16ef602b49f5d4dfb21f014b0dedd"}, + {file = "grpcio-1.62.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:973c49086cabab773525f6077f95e5a993bfc03ba8fc32e32f2c279497780585"}, + {file = "grpcio-1.62.1-cp39-cp39-win32.whl", hash = "sha256:12859468e8918d3bd243d213cd6fd6ab07208195dc140763c00dfe901ce1e1b4"}, + {file = "grpcio-1.62.1-cp39-cp39-win_amd64.whl", hash = "sha256:b7209117bbeebdfa5d898205cc55153a51285757902dd73c47de498ad4d11332"}, + {file = "grpcio-1.62.1.tar.gz", hash = "sha256:6c455e008fa86d9e9a9d85bb76da4277c0d7d9668a3bfa70dbe86e9f3c759947"}, +] + +[package.extras] +protobuf = ["grpcio-tools (>=1.62.1)"] + [[package]] name = "gym" version = "0.26.2" @@ -1076,6 +1152,21 @@ files = [ {file = "llvmlite-0.42.0.tar.gz", hash = "sha256:f92b09243c0cc3f457da8b983f67bd8e1295d0f5b3746c7a1861d7a99403854a"}, ] +[[package]] +name = "markdown" +version = "3.5.2" +description = "Python implementation of John Gruber's Markdown." 
+optional = false +python-versions = ">=3.8" +files = [ + {file = "Markdown-3.5.2-py3-none-any.whl", hash = "sha256:d43323865d89fc0cb9b20c75fc8ad313af307cc087e84b657d9eec768eddeadd"}, + {file = "Markdown-3.5.2.tar.gz", hash = "sha256:e1ac7b3dc550ee80e602e71c1d168002f062e49f1b11e26a36264dafd4df2ef8"}, +] + +[package.extras] +docs = ["mdx-gh-links (>=0.2)", "mkdocs (>=1.5)", "mkdocs-gen-files", "mkdocs-literate-nav", "mkdocs-nature (>=0.6)", "mkdocs-section-index", "mkdocstrings[python]"] +testing = ["coverage", "pyyaml"] + [[package]] name = "markupsafe" version = "2.1.5" @@ -2259,6 +2350,30 @@ urllib3 = ">=1.21.1,<3" socks = ["PySocks (>=1.5.6,!=1.5.7)"] use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] +[[package]] +name = "robomimic" +version = "0.2.0" +description = "robomimic: A Modular Framework for Robot Learning from Demonstration" +optional = false +python-versions = ">=3" +files = [ + {file = "robomimic-0.2.0.tar.gz", hash = "sha256:ee3bb5cf9c3e1feead6b57b43c5db738fd0a8e0c015fdf6419808af8fffdc463"}, +] + +[package.dependencies] +egl_probe = ">=1.0.1" +h5py = "*" +imageio = "*" +imageio-ffmpeg = "*" +numpy = ">=1.13.3" +psutil = "*" +tensorboard = "*" +tensorboardX = "*" +termcolor = "*" +torch = "*" +torchvision = "*" +tqdm = "*" + [[package]] name = "safetensors" version = "0.4.2" @@ -2746,6 +2861,55 @@ files = [ [package.dependencies] mpmath = ">=0.19" +[[package]] +name = "tensorboard" +version = "2.16.2" +description = "TensorBoard lets you watch Tensors Flow" +optional = false +python-versions = ">=3.9" +files = [ + {file = "tensorboard-2.16.2-py3-none-any.whl", hash = "sha256:9f2b4e7dad86667615c0e5cd072f1ea8403fc032a299f0072d6f74855775cc45"}, +] + +[package.dependencies] +absl-py = ">=0.4" +grpcio = ">=1.48.2" +markdown = ">=2.6.8" +numpy = ">=1.12.0" +protobuf = ">=3.19.6,<4.24.0 || >4.24.0" +setuptools = ">=41.0.0" +six = ">1.9" +tensorboard-data-server = ">=0.7.0,<0.8.0" +werkzeug = ">=1.0.1" + +[[package]] +name = "tensorboard-data-server" +version = "0.7.2" +description = "Fast data loading for TensorBoard" +optional = false +python-versions = ">=3.7" +files = [ + {file = "tensorboard_data_server-0.7.2-py3-none-any.whl", hash = "sha256:7e0610d205889588983836ec05dc098e80f97b7e7bbff7e994ebb78f578d0ddb"}, + {file = "tensorboard_data_server-0.7.2-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:9fe5d24221b29625dbc7328b0436ca7fc1c23de4acf4d272f1180856e32f9f60"}, + {file = "tensorboard_data_server-0.7.2-py3-none-manylinux_2_31_x86_64.whl", hash = "sha256:ef687163c24185ae9754ed5650eb5bc4d84ff257aabdc33f0cc6f74d8ba54530"}, +] + +[[package]] +name = "tensorboardx" +version = "2.6.2.2" +description = "TensorBoardX lets you watch Tensors Flow without Tensorflow" +optional = false +python-versions = "*" +files = [ + {file = "tensorboardX-2.6.2.2-py2.py3-none-any.whl", hash = "sha256:160025acbf759ede23fd3526ae9d9bfbfd8b68eb16c38a010ebe326dc6395db8"}, + {file = "tensorboardX-2.6.2.2.tar.gz", hash = "sha256:c6476d7cd0d529b0b72f4acadb1269f9ed8b22f441e87a84f2a3b940bb87b666"}, +] + +[package.dependencies] +numpy = "*" +packaging = "*" +protobuf = ">=3.20" + [[package]] name = "tensordict" version = "0.4.0+551331d" @@ -2802,6 +2966,24 @@ numpy = "*" [package.extras] all = ["defusedxml", "fsspec", "imagecodecs (>=2023.8.12)", "lxml", "matplotlib", "zarr"] +[[package]] +name = "timm" +version = "0.9.16" +description = "PyTorch Image Models" +optional = false +python-versions = ">=3.8" +files = [ + {file = "timm-0.9.16-py3-none-any.whl", hash = 
"sha256:bf5704014476ab011589d3c14172ee4c901fd18f9110a928019cac5be2945914"}, + {file = "timm-0.9.16.tar.gz", hash = "sha256:891e54f375d55adf31a71ab0c117761f0e472f9f3971858ecdd1e7376b7071e6"}, +] + +[package.dependencies] +huggingface_hub = "*" +pyyaml = "*" +safetensors = "*" +torch = "*" +torchvision = "*" + [[package]] name = "tomli" version = "2.0.1" @@ -3086,6 +3268,23 @@ perf = ["orjson"] reports = ["pydantic (>=2.0.0)"] sweeps = ["sweeps (>=0.2.0)"] +[[package]] +name = "werkzeug" +version = "3.0.1" +description = "The comprehensive WSGI web application library." +optional = false +python-versions = ">=3.8" +files = [ + {file = "werkzeug-3.0.1-py3-none-any.whl", hash = "sha256:90a285dc0e42ad56b34e696398b8122ee4c681833fb35b8334a095d82c56da10"}, + {file = "werkzeug-3.0.1.tar.gz", hash = "sha256:507e811ecea72b18a404947aded4b3390e1db8f826b494d76550ef45bb3b1dcc"}, +] + +[package.dependencies] +MarkupSafe = ">=2.1.1" + +[package.extras] +watchdog = ["watchdog (>=2.3)"] + [[package]] name = "zarr" version = "2.17.0" @@ -3125,4 +3324,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "c4d83579aed1c8c2e54cad7c8ec81b95a09ab8faff74fc9a4cb20bd00e4ddec6" +content-hash = "adc2cbe447c2ebe4a7273a4a849d725f6df56106e0f6bf178cf798de5d6337e2" diff --git a/pyproject.toml b/pyproject.toml index ebce8f32..ef31ece0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,6 +48,8 @@ opencv-python = "^4.9.0.80" diffusers = "^0.26.3" torchvision = "^0.17.1" h5py = "^3.10.0" +robomimic = "0.2.0" +timm = "^0.9.16" [tool.poetry.group.dev.dependencies] From 87fcc536f91e857873c5e1cc9c05a23335dd27e6 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Mon, 11 Mar 2024 18:45:21 +0000 Subject: [PATCH 02/16] wip - still need to verify full training run --- lerobot/common/envs/pusht/pusht_image_env.py | 2 +- .../diffusion/model/multi_image_obs_encoder.py | 2 ++ lerobot/configs/policy/diffusion.yaml | 12 ++++++------ 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/lerobot/common/envs/pusht/pusht_image_env.py b/lerobot/common/envs/pusht/pusht_image_env.py index 5f7bc03c..2d52c89e 100644 --- a/lerobot/common/envs/pusht/pusht_image_env.py +++ b/lerobot/common/envs/pusht/pusht_image_env.py @@ -25,7 +25,7 @@ class PushTImageEnv(PushTEnv): img = super()._render_frame(mode="rgb_array") agent_pos = np.array(self.agent.position) - img_obs = np.moveaxis(img.astype(np.float32) / 255, -1, 0) + img_obs = np.moveaxis(img.astype(np.float32), -1, 0) obs = {"image": img_obs, "agent_pos": agent_pos} # draw action diff --git a/lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py b/lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py index 0b4bba7d..91472dd5 100644 --- a/lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py +++ b/lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py @@ -123,6 +123,8 @@ class MultiImageObsEncoder(ModuleAttrMixin): if imagenet_norm: # TODO(rcadene): move normalizer to dataset and env this_normalizer = torchvision.transforms.Normalize( + # Note: This matches the normalization in the original impl. for PushT Image. This may not be + # the case for other tasks. 
mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], ) diff --git a/lerobot/configs/policy/diffusion.yaml b/lerobot/configs/policy/diffusion.yaml index 28fd4e4e..f07e4754 100644 --- a/lerobot/configs/policy/diffusion.yaml +++ b/lerobot/configs/policy/diffusion.yaml @@ -42,8 +42,8 @@ policy: num_inference_steps: 100 obs_as_global_cond: ${obs_as_global_cond} # crop_shape: null - diffusion_step_embed_dim: 256 # before 128 - down_dims: [256, 512, 1024] # before [512, 1024, 2048] + diffusion_step_embed_dim: 128 + down_dims: [512, 1024, 2048] kernel_size: 5 n_groups: 8 cond_predict_scale: True @@ -109,13 +109,13 @@ training: debug: False resume: True # optimization - # lr_scheduler: cosine - # lr_warmup_steps: 500 - num_epochs: 8000 + lr_scheduler: cosine + lr_warmup_steps: 500 + num_epochs: 500 # gradient_accumulate_every: 1 # EMA destroys performance when used with BatchNorm # replace BatchNorm with GroupNorm. - # use_ema: True + use_ema: True freeze_encoder: False # training loop control # in epochs From 98484ac68ed36247b3721c072b5c0637ce33570c Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Tue, 12 Mar 2024 21:59:01 +0000 Subject: [PATCH 03/16] ready for review --- .../diffusion/model/multi_image_obs_encoder.py | 12 ++++-------- lerobot/configs/policy/diffusion.yaml | 2 +- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py b/lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py index 91472dd5..6a1d3c0d 100644 --- a/lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py +++ b/lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py @@ -1,5 +1,5 @@ import copy -from typing import Dict, Tuple, Union +from typing import Dict, Optional, Tuple, Union import timm import torch @@ -46,7 +46,7 @@ class MultiImageObsEncoder(ModuleAttrMixin): share_rgb_model: bool = False, # renormalize rgb input with imagenet normalization # assuming input in [0,1] - imagenet_norm: bool = False, + norm_mean_std: Optional[tuple[float, float]] = None, ): """ Assumes rgb input: B,C,H,W @@ -120,13 +120,9 @@ class MultiImageObsEncoder(ModuleAttrMixin): this_normalizer = torchvision.transforms.CenterCrop(size=(h, w)) # configure normalizer this_normalizer = nn.Identity() - if imagenet_norm: - # TODO(rcadene): move normalizer to dataset and env + if norm_mean_std is not None: this_normalizer = torchvision.transforms.Normalize( - # Note: This matches the normalization in the original impl. for PushT Image. This may not be - # the case for other tasks. 
- mean=[127.5, 127.5, 127.5], - std=[127.5, 127.5, 127.5], + mean=norm_mean_std[0], std=norm_mean_std[1] ) this_transform = nn.Sequential(this_resizer, this_randomizer, this_normalizer) diff --git a/lerobot/configs/policy/diffusion.yaml b/lerobot/configs/policy/diffusion.yaml index f07e4754..7de44102 100644 --- a/lerobot/configs/policy/diffusion.yaml +++ b/lerobot/configs/policy/diffusion.yaml @@ -81,7 +81,7 @@ obs_encoder: # random_crop: True use_group_norm: True share_rgb_model: False - imagenet_norm: True + norm_mean_std: [0.5, 0.5] # for PushT the original impl normalizes to [-1, 1] (maybe not the case for robomimic envs) rgb_model: model_name: resnet18 From ba91976944bebfcaca477e57105740cbd7876aca Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Thu, 14 Mar 2024 15:22:55 +0000 Subject: [PATCH 04/16] wip: still needs batch logic for act and tdmp --- lerobot/common/envs/aloha/env.py | 43 +++++------ lerobot/common/envs/factory.py | 8 ++- lerobot/common/envs/pusht/env.py | 41 ++++------- lerobot/common/policies/abstract.py | 54 ++++++++++++++ lerobot/common/policies/act/policy.py | 6 +- lerobot/common/policies/diffusion/policy.py | 14 ++-- lerobot/common/policies/tdmpc/policy.py | 5 +- lerobot/configs/default.yaml | 2 + lerobot/scripts/eval.py | 76 ++++++++++++-------- lerobot/scripts/train.py | 11 ++- tests/test_policies.py | 80 +++++++++++++++++++++ 11 files changed, 240 insertions(+), 100 deletions(-) create mode 100644 lerobot/common/policies/abstract.py diff --git a/lerobot/common/envs/aloha/env.py b/lerobot/common/envs/aloha/env.py index 1211a37a..7ef24f2d 100644 --- a/lerobot/common/envs/aloha/env.py +++ b/lerobot/common/envs/aloha/env.py @@ -168,42 +168,31 @@ class AlohaEnv(AbstractEnv): def _step(self, tensordict: TensorDict): td = tensordict action = td["action"].numpy() - # step expects shape=(4,) so we pad if necessary + assert action.ndim == 1 # TODO(rcadene): add info["is_success"] and info["success"] ? 
- sum_reward = 0 - if action.ndim == 1: - action = einops.repeat(action, "c -> t c", t=self.frame_skip) - else: - if self.frame_skip > 1: - raise NotImplementedError() + _, reward, _, raw_obs = self._env.step(action) - num_action_steps = action.shape[0] - for i in range(num_action_steps): - _, reward, discount, raw_obs = self._env.step(action[i]) - del discount # not used + # TODO(rcadene): add an enum + success = done = reward == 4 + obs = self._format_raw_obs(raw_obs) - # TOOD(rcadene): add an enum - success = done = reward == 4 - sum_reward += reward - obs = self._format_raw_obs(raw_obs) + if self.num_prev_obs > 0: + stacked_obs = {} + if "image" in obs: + self._prev_obs_image_queue.append(obs["image"]["top"]) + stacked_obs["image"] = {"top": torch.stack(list(self._prev_obs_image_queue))} + if "state" in obs: + self._prev_obs_state_queue.append(obs["state"]) + stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue)) + obs = stacked_obs - if self.num_prev_obs > 0: - stacked_obs = {} - if "image" in obs: - self._prev_obs_image_queue.append(obs["image"]["top"]) - stacked_obs["image"] = {"top": torch.stack(list(self._prev_obs_image_queue))} - if "state" in obs: - self._prev_obs_state_queue.append(obs["state"]) - stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue)) - obs = stacked_obs - - self.call_rendering_hooks() + self.call_rendering_hooks() td = TensorDict( { "observation": TensorDict(obs, batch_size=[]), - "reward": torch.tensor([sum_reward], dtype=torch.float32), + "reward": torch.tensor([reward], dtype=torch.float32), # succes and done are true when coverage > self.success_threshold in env "done": torch.tensor([done], dtype=torch.bool), "success": torch.tensor([success], dtype=torch.bool), diff --git a/lerobot/common/envs/factory.py b/lerobot/common/envs/factory.py index 921cbad7..d6b294eb 100644 --- a/lerobot/common/envs/factory.py +++ b/lerobot/common/envs/factory.py @@ -1,15 +1,17 @@ from torchrl.envs.transforms import Compose, StepCounter, Transform, TransformedEnv -def make_env(cfg, transform=None): +def make_env(cfg, seed=None, transform=None): + """ + Provide seed to override the seed in the cfg (useful for batched environments). + """ kwargs = { "frame_skip": cfg.env.action_repeat, "from_pixels": cfg.env.from_pixels, "pixels_only": cfg.env.pixels_only, "image_size": cfg.env.image_size, - # TODO(rcadene): do we want a specific eval_env_seed? - "seed": cfg.seed, "num_prev_obs": cfg.n_obs_steps - 1, + "seed": seed if seed is not None else cfg.seed, } if cfg.env.name == "simxarm": diff --git a/lerobot/common/envs/pusht/env.py b/lerobot/common/envs/pusht/env.py index 4a7ccb2c..2fe05233 100644 --- a/lerobot/common/envs/pusht/env.py +++ b/lerobot/common/envs/pusht/env.py @@ -2,7 +2,6 @@ import importlib from collections import deque from typing import Optional -import einops import torch from tensordict import TensorDict from torchrl.data.tensor_specs import ( @@ -120,40 +119,30 @@ class PushtEnv(AbstractEnv): def _step(self, tensordict: TensorDict): td = tensordict action = td["action"].numpy() - # step expects shape=(4,) so we pad if necessary + assert action.ndim == 1 # TODO(rcadene): add info["is_success"] and info["success"] ? 
- sum_reward = 0 - if action.ndim == 1: - action = einops.repeat(action, "c -> t c", t=self.frame_skip) - else: - if self.frame_skip > 1: - raise NotImplementedError() + raw_obs, reward, done, info = self._env.step(action) - num_action_steps = action.shape[0] - for i in range(num_action_steps): - raw_obs, reward, done, info = self._env.step(action[i]) - sum_reward += reward + obs = self._format_raw_obs(raw_obs) - obs = self._format_raw_obs(raw_obs) + if self.num_prev_obs > 0: + stacked_obs = {} + if "image" in obs: + self._prev_obs_image_queue.append(obs["image"]) + stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue)) + if "state" in obs: + self._prev_obs_state_queue.append(obs["state"]) + stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue)) + obs = stacked_obs - if self.num_prev_obs > 0: - stacked_obs = {} - if "image" in obs: - self._prev_obs_image_queue.append(obs["image"]) - stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue)) - if "state" in obs: - self._prev_obs_state_queue.append(obs["state"]) - stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue)) - obs = stacked_obs - - self.call_rendering_hooks() + self.call_rendering_hooks() td = TensorDict( { "observation": TensorDict(obs, batch_size=[]), - "reward": torch.tensor([sum_reward], dtype=torch.float32), - # succes and done are true when coverage > self.success_threshold in env + "reward": torch.tensor([reward], dtype=torch.float32), + # success and done are true when coverage > self.success_threshold in env "done": torch.tensor([done], dtype=torch.bool), "success": torch.tensor([done], dtype=torch.bool), }, diff --git a/lerobot/common/policies/abstract.py b/lerobot/common/policies/abstract.py new file mode 100644 index 00000000..4956530a --- /dev/null +++ b/lerobot/common/policies/abstract.py @@ -0,0 +1,54 @@ +from abc import abstractmethod +from collections import deque + +import torch +from torch import Tensor, nn + + +class AbstractPolicy(nn.Module): + @abstractmethod + def update(self, replay_buffer, step): + """One step of the policy's learning algorithm.""" + pass + + def save(self, fp): + torch.save(self.state_dict(), fp) + + def load(self, fp): + d = torch.load(fp) + self.load_state_dict(d) + + @abstractmethod + def select_action(self, observation) -> Tensor: + """Select an action (or trajectory of actions) based on an observation during rollout. + + Should return a (batch_size, n_action_steps, *) tensor of actions. + """ + pass + + def forward(self, *args, **kwargs): + """Inference step that makes multi-step policies compatible with their single-step environments. + + WARNING: In general, this should not be overriden. + + Consider a "policy" that observes the environment then charts a course of N actions to take. To make this fit + into the formalism of a TorchRL environment, we view it as being effectively a policy that (1) makes an + observation and prepares a queue of actions, (2) consumes that queue when queried, regardless of the environment + observation, (3) repopulates the action queue when empty. This method handles the aforementioned logic so that + the subclass doesn't have to. + + This method effectively wraps the `select_action` method of the subclass. The following assumptions are made: + 1. The `select_action` method returns a Tensor of actions with shape (B, H, *) where B is the batch size, H is + the action trajectory horizon and * is the action dimensions. + 2. 
Prior to the `select_action` method being called, theres is an `n_action_steps` instance attribute defined. + """ + n_action_steps_attr = "n_action_steps" + if not hasattr(self, n_action_steps_attr): + raise RuntimeError(f"Underlying policy must have an `{n_action_steps_attr}` attribute") + if not hasattr(self, "_action_queue"): + self._action_queue = deque([], maxlen=getattr(self, n_action_steps_attr)) + if len(self._action_queue) == 0: + # Each element in the queue has shape (B, *). + self._action_queue.extend(self.select_action(*args, **kwargs).transpose(0, 1)) + + return self._action_queue.popleft() diff --git a/lerobot/common/policies/act/policy.py b/lerobot/common/policies/act/policy.py index d011cb76..e87f155e 100644 --- a/lerobot/common/policies/act/policy.py +++ b/lerobot/common/policies/act/policy.py @@ -2,10 +2,10 @@ import logging import time import torch -import torch.nn as nn import torch.nn.functional as F # noqa: N812 import torchvision.transforms as transforms +from lerobot.common.policies.abstract import AbstractPolicy from lerobot.common.policies.act.detr_vae import build @@ -40,7 +40,7 @@ def kl_divergence(mu, logvar): return total_kld, dimension_wise_kld, mean_kld -class ActionChunkingTransformerPolicy(nn.Module): +class ActionChunkingTransformerPolicy(AbstractPolicy): def __init__(self, cfg, device, n_action_steps=1): super().__init__() self.cfg = cfg @@ -147,7 +147,7 @@ class ActionChunkingTransformerPolicy(nn.Module): return loss @torch.no_grad() - def forward(self, observation, step_count): + def select_action(self, observation, step_count): # TODO(rcadene): remove unused step_count del step_count diff --git a/lerobot/common/policies/diffusion/policy.py b/lerobot/common/policies/diffusion/policy.py index 3df76aa4..db004a71 100644 --- a/lerobot/common/policies/diffusion/policy.py +++ b/lerobot/common/policies/diffusion/policy.py @@ -3,14 +3,14 @@ import time import hydra import torch -import torch.nn as nn +from lerobot.common.policies.abstract import AbstractPolicy from lerobot.common.policies.diffusion.diffusion_unet_image_policy import DiffusionUnetImagePolicy from lerobot.common.policies.diffusion.model.lr_scheduler import get_scheduler from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder -class DiffusionPolicy(nn.Module): +class DiffusionPolicy(AbstractPolicy): def __init__( self, cfg, @@ -44,6 +44,7 @@ class DiffusionPolicy(nn.Module): **cfg_obs_encoder, ) + self.n_action_steps = n_action_steps # needed for the parent class self.diffusion = DiffusionUnetImagePolicy( shape_meta=shape_meta, noise_scheduler=noise_scheduler, @@ -93,21 +94,16 @@ class DiffusionPolicy(nn.Module): ) @torch.no_grad() - def forward(self, observation, step_count): + def select_action(self, observation, step_count): # TODO(rcadene): remove unused step_count del step_count - # TODO(rcadene): remove unsqueeze hack to add bsize=1 - observation["image"] = observation["image"].unsqueeze(0) - observation["state"] = observation["state"].unsqueeze(0) - obs_dict = { "image": observation["image"], "agent_pos": observation["state"], } out = self.diffusion.predict_action(obs_dict) - - action = out["action"].squeeze(0) + action = out["action"] return action def update(self, replay_buffer, step): diff --git a/lerobot/common/policies/tdmpc/policy.py b/lerobot/common/policies/tdmpc/policy.py index ae9888a5..48955459 100644 --- a/lerobot/common/policies/tdmpc/policy.py +++ b/lerobot/common/policies/tdmpc/policy.py @@ -9,6 +9,7 @@ import torch import torch.nn as 
nn import lerobot.common.policies.tdmpc.helper as h +from lerobot.common.policies.abstract import AbstractPolicy FIRST_FRAME = 0 @@ -85,7 +86,7 @@ class TOLD(nn.Module): return torch.min(Q1, Q2) if return_type == "min" else (Q1 + Q2) / 2 -class TDMPC(nn.Module): +class TDMPC(AbstractPolicy): """Implementation of TD-MPC learning + inference.""" def __init__(self, cfg, device): @@ -124,7 +125,7 @@ class TDMPC(nn.Module): self.model_target.load_state_dict(d["model_target"]) @torch.no_grad() - def forward(self, observation, step_count): + def select_action(self, observation, step_count): t0 = step_count.item() == 0 # TODO(rcadene): remove unsqueeze hack... diff --git a/lerobot/configs/default.yaml b/lerobot/configs/default.yaml index 6841cb82..2a7aab6c 100644 --- a/lerobot/configs/default.yaml +++ b/lerobot/configs/default.yaml @@ -10,6 +10,8 @@ hydra: name: default seed: 1337 +# batch size for TorchRL SerialEnv. Each underlying env will get the seed = seed + env_index +rollout_batch_size: 10 device: cuda # cpu prefetch: 4 eval_freq: ??? diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index 7ba2812e..e9d57cba 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -9,7 +9,8 @@ import numpy as np import torch import tqdm from tensordict.nn import TensorDictModule -from torchrl.envs import EnvBase +from torchrl.envs import EnvBase, SerialEnv +from torchrl.envs.batched_envs import BatchedEnvBase from lerobot.common.datasets.factory import make_offline_buffer from lerobot.common.envs.factory import make_env @@ -23,7 +24,7 @@ def write_video(video_path, stacked_frames, fps): def eval_policy( - env: EnvBase, + env: BatchedEnvBase, policy: TensorDictModule = None, num_episodes: int = 10, max_steps: int = 30, @@ -36,45 +37,55 @@ def eval_policy( sum_rewards = [] max_rewards = [] successes = [] - threads = [] - for i in tqdm.tqdm(range(num_episodes)): + threads = [] # for video saving threads + episode_counter = 0 # for saving the correct number of videos + + # TODO(alexander-soare): if num_episodes is not evenly divisible by the batch size, this will do more work than + # needed as I'm currently taking a ceil. 
+ for i in tqdm.tqdm(range(-(-num_episodes // env.batch_size[0]))): ep_frames = [] - if save_video or (return_first_video and i == 0): - def render_frame(env): + def maybe_render_frame(env: EnvBase, _): + if save_video or (return_first_video and i == 0): # noqa: B023 ep_frames.append(env.render()) # noqa: B023 - env.register_rendering_hook(render_frame) - with torch.inference_mode(): rollout = env.rollout( max_steps=max_steps, policy=policy, auto_cast_to_device=True, + callback=maybe_render_frame, ) # print(", ".join([f"{x:.3f}" for x in rollout["next", "reward"][:,0].tolist()])) - ep_sum_reward = rollout["next", "reward"].sum() - ep_max_reward = rollout["next", "reward"].max() - ep_success = rollout["next", "success"].any() - sum_rewards.append(ep_sum_reward.item()) - max_rewards.append(ep_max_reward.item()) - successes.append(ep_success.item()) + batch_sum_reward = rollout["next", "reward"].flatten(start_dim=1).sum(dim=-1) + batch_max_reward = rollout["next", "reward"].flatten(start_dim=1).max(dim=-1)[0] + batch_success = rollout["next", "success"].flatten(start_dim=1).any(dim=-1) + sum_rewards.extend(batch_sum_reward.tolist()) + max_rewards.extend(batch_max_reward.tolist()) + successes.extend(batch_success.tolist()) if save_video or (return_first_video and i == 0): - stacked_frames = np.stack(ep_frames) + batch_stacked_frames = np.stack(ep_frames) # (t, b, *) + batch_stacked_frames = batch_stacked_frames.transpose( + 1, 0, *range(2, batch_stacked_frames.ndim) + ) # (b, t, *) if save_video: - video_dir.mkdir(parents=True, exist_ok=True) - video_path = video_dir / f"eval_episode_{i}.mp4" - thread = threading.Thread( - target=write_video, - args=(str(video_path), stacked_frames, fps), - ) - thread.start() - threads.append(thread) + for stacked_frames in batch_stacked_frames: + if episode_counter >= num_episodes: + continue + video_dir.mkdir(parents=True, exist_ok=True) + video_path = video_dir / f"eval_episode_{episode_counter}.mp4" + thread = threading.Thread( + target=write_video, + args=(str(video_path), stacked_frames, fps), + ) + thread.start() + threads.append(thread) + episode_counter += 1 if return_first_video and i == 0: - first_video = stacked_frames.transpose(0, 3, 1, 2) + first_video = batch_stacked_frames[0].transpose(0, 3, 1, 2) env.reset_rendering_hooks() @@ -82,9 +93,9 @@ def eval_policy( thread.join() info = { - "avg_sum_reward": np.nanmean(sum_rewards), - "avg_max_reward": np.nanmean(max_rewards), - "pc_success": np.nanmean(successes) * 100, + "avg_sum_reward": np.nanmean(sum_rewards[:num_episodes]), + "avg_max_reward": np.nanmean(max_rewards[:num_episodes]), + "pc_success": np.nanmean(successes[:num_episodes]) * 100, "eval_s": time.time() - start, "eval_ep_s": (time.time() - start) / num_episodes, } @@ -119,7 +130,14 @@ def eval(cfg: dict, out_dir=None): offline_buffer = make_offline_buffer(cfg) logging.info("make_env") - env = make_env(cfg, transform=offline_buffer.transform) + env = SerialEnv( + cfg.rollout_batch_size, + create_env_fn=make_env, + create_env_kwargs=[ + {"cfg": cfg, "seed": s, "transform": offline_buffer.transform} + for s in range(cfg.seed, cfg.seed + cfg.rollout_batch_size) + ], + ) if cfg.policy.pretrained_model_path: policy = make_policy(cfg) @@ -138,7 +156,7 @@ def eval(cfg: dict, out_dir=None): save_video=True, video_dir=Path(out_dir) / "eval", fps=cfg.env.fps, - max_steps=cfg.env.episode_length // cfg.n_action_steps, + max_steps=cfg.env.episode_length, num_episodes=cfg.eval_episodes, ) print(metrics) diff --git a/lerobot/scripts/train.py 
b/lerobot/scripts/train.py index c063caf8..579f5a58 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -7,6 +7,7 @@ import torch from tensordict.nn import TensorDictModule from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer from torchrl.data.replay_buffers import PrioritizedSliceSampler +from torchrl.envs import SerialEnv from lerobot.common.datasets.factory import make_offline_buffer from lerobot.common.envs.factory import make_env @@ -148,6 +149,14 @@ def train(cfg: dict, out_dir=None, job_name=None): logging.info("make_env") env = make_env(cfg, transform=offline_buffer.transform) + env = SerialEnv( + cfg.rollout_batch_size, + create_env_fn=make_env, + create_env_kwargs=[ + {"cfg": cfg, "seed": s, "transform": offline_buffer.transform} + for s in range(cfg.seed, cfg.seed + cfg.rollout_batch_size) + ], + ) logging.info("make_policy") policy = make_policy(cfg) @@ -191,7 +200,7 @@ def train(cfg: dict, out_dir=None, job_name=None): env, td_policy, num_episodes=cfg.eval_episodes, - max_steps=cfg.env.episode_length // cfg.n_action_steps, + max_steps=cfg.env.episode_length, return_first_video=True, video_dir=Path(out_dir) / "eval", save_video=True, diff --git a/tests/test_policies.py b/tests/test_policies.py index f00429bc..7d9a4dce 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -1,7 +1,15 @@ + import pytest +from tensordict import TensorDict +from tensordict.nn import TensorDictModule +import torch +from torchrl.data import UnboundedContinuousTensorSpec +from torchrl.envs import EnvBase from lerobot.common.policies.factory import make_policy +from lerobot.common.policies.abstract import AbstractPolicy + from .utils import DEVICE, init_config @@ -23,3 +31,75 @@ def test_factory(env_name, policy_name): ] ) policy = make_policy(cfg) + + +def test_abstract_policy_forward(): + """ + Given an underlying policy that produces an action trajectory with n_action_steps actions, checks that: + - The policy is invoked the expected number of times during a rollout. + - The environment's termination condition is respected even when part way through an action trajectory. + - The observations are returned correctly. + """ + + n_action_steps = 8 # our test policy will output 8 action step horizons + terminate_at = 10 # some number that is more than n_action_steps but not a multiple + rollout_max_steps = terminate_at + 1 # some number greater than terminate_at + + # A minimal environment for testing. 
+ class StubEnv(EnvBase): + + def __init__(self): + super().__init__() + self.action_spec = UnboundedContinuousTensorSpec(shape=(1,)) + self.reward_spec = UnboundedContinuousTensorSpec(shape=(1,)) + + def _step(self, tensordict: TensorDict) -> TensorDict: + self.invocation_count += 1 + return TensorDict( + { + "observation": torch.tensor([self.invocation_count]), + "reward": torch.tensor([self.invocation_count]), + "terminated": torch.tensor( + tensordict["action"].item() == terminate_at + ), + } + ) + + def _reset(self, tensordict: TensorDict) -> TensorDict: + self.invocation_count = 0 + return TensorDict( + { + "observation": torch.tensor([self.invocation_count]), + "reward": torch.tensor([self.invocation_count]), + } + ) + + def _set_seed(self, seed: int | None): + return + + + class StubPolicy(AbstractPolicy): + def __init__(self): + super().__init__() + self.n_action_steps = n_action_steps + self.n_policy_invocations = 0 + + def select_action(self): + self.n_policy_invocations += 1 + return torch.stack([torch.tensor([i]) for i in range(self.n_action_steps)]).unsqueeze(0) + + + env = StubEnv() + policy = StubPolicy() + policy = TensorDictModule( + policy, + in_keys=[], + out_keys=["action"], + ) + + # Keep track to make sure the policy is called the expected number of times + rollout = env.rollout(rollout_max_steps, policy) + + assert len(rollout) == terminate_at + 1 # +1 for the reset observation + assert policy.n_policy_invocations == (terminate_at // n_action_steps) + 1 + assert torch.equal(rollout['observation'].flatten(), torch.arange(terminate_at + 1)) From 88347965c2d829274934bface616dacda337ee40 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Mon, 18 Mar 2024 19:18:21 +0000 Subject: [PATCH 05/16] revert dp changes, make act and tdmpc batch friendly --- lerobot/common/policies/abstract.py | 12 ++++--- lerobot/common/policies/act/policy.py | 9 +----- .../model/multi_image_obs_encoder.py | 31 +++---------------- lerobot/common/policies/diffusion/policy.py | 4 +-- lerobot/common/policies/tdmpc/policy.py | 7 +---- lerobot/configs/policy/diffusion.yaml | 20 ++++++------ lerobot/scripts/eval.py | 4 +-- tests/test_policies.py | 3 ++ 8 files changed, 32 insertions(+), 58 deletions(-) diff --git a/lerobot/common/policies/abstract.py b/lerobot/common/policies/abstract.py index 4956530a..9c652c0a 100644 --- a/lerobot/common/policies/abstract.py +++ b/lerobot/common/policies/abstract.py @@ -1,15 +1,20 @@ -from abc import abstractmethod +from abc import ABC, abstractmethod from collections import deque import torch from torch import Tensor, nn -class AbstractPolicy(nn.Module): +class AbstractPolicy(nn.Module, ABC): + """Base policy which all policies should be derived from. + + The forward method should generally not be overriden as it plays the role of handling multi-step policies. See its + documentation for more information. + """ + @abstractmethod def update(self, replay_buffer, step): """One step of the policy's learning algorithm.""" - pass def save(self, fp): torch.save(self.state_dict(), fp) @@ -24,7 +29,6 @@ class AbstractPolicy(nn.Module): Should return a (batch_size, n_action_steps, *) tensor of actions. """ - pass def forward(self, *args, **kwargs): """Inference step that makes multi-step policies compatible with their single-step environments. 
diff --git a/lerobot/common/policies/act/policy.py b/lerobot/common/policies/act/policy.py index e87f155e..e0499cdb 100644 --- a/lerobot/common/policies/act/policy.py +++ b/lerobot/common/policies/act/policy.py @@ -153,10 +153,6 @@ class ActionChunkingTransformerPolicy(AbstractPolicy): self.eval() - # TODO(rcadene): remove unsqueeze hack to add bsize=1 - observation["image", "top"] = observation["image", "top"].unsqueeze(0) - # observation["state"] = observation["state"].unsqueeze(0) - # TODO(rcadene): remove hack # add 1 camera dimension observation["image", "top"] = observation["image", "top"].unsqueeze(1) @@ -180,11 +176,8 @@ class ActionChunkingTransformerPolicy(AbstractPolicy): # exp_weights = torch.from_numpy(exp_weights).cuda().unsqueeze(dim=1) # raw_action = (actions_for_curr_step * exp_weights).sum(dim=0, keepdim=True) - # remove bsize=1 - action = action.squeeze(0) - # take first predicted action or n first actions - action = action[0] if self.n_action_steps == 1 else action[: self.n_action_steps] + action = action[: self.n_action_steps] return action def _forward(self, qpos, image, actions=None, is_pad=None): diff --git a/lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py b/lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py index 6a1d3c0d..94dc6f49 100644 --- a/lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py +++ b/lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py @@ -1,37 +1,15 @@ import copy -from typing import Dict, Optional, Tuple, Union +from typing import Dict, Tuple, Union -import timm import torch import torch.nn as nn import torchvision -from robomimic.models.base_nets import SpatialSoftmax from lerobot.common.policies.diffusion.model.crop_randomizer import CropRandomizer from lerobot.common.policies.diffusion.model.module_attr_mixin import ModuleAttrMixin from lerobot.common.policies.diffusion.pytorch_utils import replace_submodules -class RgbEncoder(nn.Module): - """Following `VisualCore` from Robomimic 0.2.0.""" - - def __init__(self, input_shape, model_name="resnet18", pretrained=False, num_keypoints=32): - """ - resnet_name: a timm model name. - pretrained: whether to use timm pretrained weights. - num_keypoints: Number of keypoints for SpatialSoftmax (default value of 32 matches PushT Image). - """ - super().__init__() - self.backbone = timm.create_model(model_name, pretrained, num_classes=0, global_pool="") - # Figure out the feature map shape. 
- with torch.inference_mode(): - feat_map_shape = tuple(self.backbone(torch.zeros(size=(1, *input_shape))).shape[1:]) - self.pool = SpatialSoftmax(feat_map_shape, num_kp=num_keypoints) - - def forward(self, x): - return torch.flatten(self.pool(self.backbone(x)), start_dim=1) - - class MultiImageObsEncoder(ModuleAttrMixin): def __init__( self, @@ -46,7 +24,7 @@ class MultiImageObsEncoder(ModuleAttrMixin): share_rgb_model: bool = False, # renormalize rgb input with imagenet normalization # assuming input in [0,1] - norm_mean_std: Optional[tuple[float, float]] = None, + imagenet_norm: bool = False, ): """ Assumes rgb input: B,C,H,W @@ -120,9 +98,10 @@ class MultiImageObsEncoder(ModuleAttrMixin): this_normalizer = torchvision.transforms.CenterCrop(size=(h, w)) # configure normalizer this_normalizer = nn.Identity() - if norm_mean_std is not None: + if imagenet_norm: + # TODO(rcadene): move normalizer to dataset and env this_normalizer = torchvision.transforms.Normalize( - mean=norm_mean_std[0], std=norm_mean_std[1] + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] ) this_transform = nn.Sequential(this_resizer, this_randomizer, this_normalizer) diff --git a/lerobot/common/policies/diffusion/policy.py b/lerobot/common/policies/diffusion/policy.py index e779596c..db004a71 100644 --- a/lerobot/common/policies/diffusion/policy.py +++ b/lerobot/common/policies/diffusion/policy.py @@ -7,7 +7,7 @@ import torch from lerobot.common.policies.abstract import AbstractPolicy from lerobot.common.policies.diffusion.diffusion_unet_image_policy import DiffusionUnetImagePolicy from lerobot.common.policies.diffusion.model.lr_scheduler import get_scheduler -from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder, RgbEncoder +from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder class DiffusionPolicy(AbstractPolicy): @@ -38,7 +38,7 @@ class DiffusionPolicy(AbstractPolicy): self.cfg = cfg noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler) - rgb_model = RgbEncoder(input_shape=shape_meta.obs.image.shape, **cfg_rgb_model) + rgb_model = hydra.utils.instantiate(cfg_rgb_model) obs_encoder = MultiImageObsEncoder( rgb_model=rgb_model, **cfg_obs_encoder, diff --git a/lerobot/common/policies/tdmpc/policy.py b/lerobot/common/policies/tdmpc/policy.py index 48955459..4c104bcd 100644 --- a/lerobot/common/policies/tdmpc/policy.py +++ b/lerobot/common/policies/tdmpc/policy.py @@ -128,11 +128,6 @@ class TDMPC(AbstractPolicy): def select_action(self, observation, step_count): t0 = step_count.item() == 0 - # TODO(rcadene): remove unsqueeze hack... - if observation["image"].ndim == 3: - observation["image"] = observation["image"].unsqueeze(0) - observation["state"] = observation["state"].unsqueeze(0) - obs = { # TODO(rcadene): remove contiguous hack... 
"rgb": observation["image"].contiguous(), @@ -149,7 +144,7 @@ class TDMPC(AbstractPolicy): if self.cfg.mpc: a = self.plan(z, t0=t0, step=step) else: - a = self.model.pi(z, self.cfg.min_std * self.model.training).squeeze(0) + a = self.model.pi(z, self.cfg.min_std * self.model.training) return a @torch.no_grad() diff --git a/lerobot/configs/policy/diffusion.yaml b/lerobot/configs/policy/diffusion.yaml index 7de44102..0dae5056 100644 --- a/lerobot/configs/policy/diffusion.yaml +++ b/lerobot/configs/policy/diffusion.yaml @@ -42,8 +42,8 @@ policy: num_inference_steps: 100 obs_as_global_cond: ${obs_as_global_cond} # crop_shape: null - diffusion_step_embed_dim: 128 - down_dims: [512, 1024, 2048] + diffusion_step_embed_dim: 256 # before 128 + down_dims: [256, 512, 1024] # before [512, 1024, 2048] kernel_size: 5 n_groups: 8 cond_predict_scale: True @@ -81,12 +81,12 @@ obs_encoder: # random_crop: True use_group_norm: True share_rgb_model: False - norm_mean_std: [0.5, 0.5] # for PushT the original impl normalizes to [-1, 1] (maybe not the case for robomimic envs) + imagenet_norm: True rgb_model: - model_name: resnet18 - pretrained: false - num_keypoints: 32 + _target_: lerobot.common.policies.diffusion.pytorch_utils.get_resnet + name: resnet18 + weights: null ema: _target_: lerobot.common.policies.diffusion.model.ema_model.EMAModel @@ -109,13 +109,13 @@ training: debug: False resume: True # optimization - lr_scheduler: cosine - lr_warmup_steps: 500 - num_epochs: 500 + # lr_scheduler: cosine + # lr_warmup_steps: 500 + num_epochs: 8000 # gradient_accumulate_every: 1 # EMA destroys performance when used with BatchNorm # replace BatchNorm with GroupNorm. - use_ema: True + # use_ema: True freeze_encoder: False # training loop control # in epochs diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index 839c12bb..7cfb796a 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -135,8 +135,8 @@ def eval(cfg: dict, out_dir=None): cfg.rollout_batch_size, create_env_fn=make_env, create_env_kwargs=[ - {"cfg": cfg, "seed": s, "transform": offline_buffer.transform} - for s in range(cfg.seed, cfg.seed + cfg.rollout_batch_size) + {"cfg": cfg, "seed": env_seed, "transform": offline_buffer.transform} + for env_seed in range(cfg.seed, cfg.seed + cfg.rollout_batch_size) ], ) diff --git a/tests/test_policies.py b/tests/test_policies.py index 7d9a4dce..92324485 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -84,6 +84,9 @@ def test_abstract_policy_forward(): self.n_action_steps = n_action_steps self.n_policy_invocations = 0 + def update(self): + pass + def select_action(self): self.n_policy_invocations += 1 return torch.stack([torch.tensor([i]) for i in range(self.n_action_steps)]).unsqueeze(0) From ea17f4ce501afe867c73954a86457f12a95fcf42 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Tue, 19 Mar 2024 16:02:09 +0000 Subject: [PATCH 06/16] backup wip --- lerobot/common/datasets/abstract.py | 6 +- lerobot/common/datasets/aloha.py | 6 +- lerobot/common/envs/factory.py | 55 +++++++++++++----- lerobot/common/policies/abstract.py | 2 +- lerobot/configs/default.yaml | 8 ++- lerobot/configs/policy/diffusion.yaml | 4 +- lerobot/scripts/eval.py | 11 +--- lerobot/scripts/train.py | 9 --- .../data/aloha_sim_insertion_human/stats.pth | Bin 4434 -> 4306 bytes tests/data/pusht/stats.pth | Bin 4306 -> 4242 bytes tests/test_policies.py | 16 ++++- 11 files changed, 71 insertions(+), 46 deletions(-) diff --git a/lerobot/common/datasets/abstract.py b/lerobot/common/datasets/abstract.py 
index 34b33c2e..5db97497 100644 --- a/lerobot/common/datasets/abstract.py +++ b/lerobot/common/datasets/abstract.py @@ -49,9 +49,9 @@ class AbstractExperienceReplay(TensorDictReplayBuffer): @property def stats_patterns(self) -> dict: return { - ("observation", "state"): "b c -> 1 c", - ("observation", "image"): "b c h w -> 1 c 1 1", - ("action",): "b c -> 1 c", + ("observation", "state"): "b c -> c", + ("observation", "image"): "b c h w -> c", + ("action",): "b c -> c", } @property diff --git a/lerobot/common/datasets/aloha.py b/lerobot/common/datasets/aloha.py index 52a5676e..0637f8a3 100644 --- a/lerobot/common/datasets/aloha.py +++ b/lerobot/common/datasets/aloha.py @@ -113,11 +113,11 @@ class AlohaExperienceReplay(AbstractExperienceReplay): @property def stats_patterns(self) -> dict: d = { - ("observation", "state"): "b c -> 1 c", - ("action",): "b c -> 1 c", + ("observation", "state"): "b c -> c", + ("action",): "b c -> c", } for cam in CAMERAS[self.dataset_id]: - d[("observation", "image", cam)] = "b c h w -> 1 c 1 1" + d[("observation", "image", cam)] = "b c h w -> c" return d @property diff --git a/lerobot/common/envs/factory.py b/lerobot/common/envs/factory.py index d6b294eb..de86b3ad 100644 --- a/lerobot/common/envs/factory.py +++ b/lerobot/common/envs/factory.py @@ -1,17 +1,31 @@ from torchrl.envs.transforms import Compose, StepCounter, Transform, TransformedEnv -def make_env(cfg, seed=None, transform=None): +def make_env(cfg, transform=None): """ Provide seed to override the seed in the cfg (useful for batched environments). """ + # assert cfg.rollout_batch_size == 1, \ + # """ + # For the time being, rollout batch sizes of > 1 are not supported. This is because the SerialEnv rollout does not + # correctly handle terminated environments. If you really want to use a larger batch size, read on... + + # When calling `EnvBase.rollout` with `break_when_any_done == True` all environments stop rolling out as soon as the + # first is terminated or truncated. This almost certainly results in incorrect success metrics, as all but the first + # environment get an opportunity to reach the goal. A possible work around is to comment out `if any_done: break` + # inf `EnvBase._rollout_stop_early`. One potential downside is that the environments `step` function will continue + # to be called and the outputs will continue to be added to the rollout. + + # When calling `EnvBase.rollout` with `break_when_any_done == False` environments are reset when done. 
+ # """ + kwargs = { "frame_skip": cfg.env.action_repeat, "from_pixels": cfg.env.from_pixels, "pixels_only": cfg.env.pixels_only, "image_size": cfg.env.image_size, "num_prev_obs": cfg.n_obs_steps - 1, - "seed": seed if seed is not None else cfg.seed, + "seed": cfg.seed, } if cfg.env.name == "simxarm": @@ -33,22 +47,33 @@ def make_env(cfg, seed=None, transform=None): else: raise ValueError(cfg.env.name) - env = clsfunc(**kwargs) + def _make_env(seed): + nonlocal kwargs + kwargs["seed"] = seed + env = clsfunc(**kwargs) - # limit rollout to max_steps - env = TransformedEnv(env, StepCounter(max_steps=cfg.env.episode_length)) + # limit rollout to max_steps + env = TransformedEnv(env, StepCounter(max_steps=cfg.env.episode_length)) - if transform is not None: - # useful to add normalization - if isinstance(transform, Compose): - for tf in transform: - env.append_transform(tf.clone()) - elif isinstance(transform, Transform): - env.append_transform(transform.clone()) - else: - raise NotImplementedError() + if transform is not None: + # useful to add normalization + if isinstance(transform, Compose): + for tf in transform: + env.append_transform(tf.clone()) + elif isinstance(transform, Transform): + env.append_transform(transform.clone()) + else: + raise NotImplementedError() - return env + return env + + # return SerialEnv( + # cfg.rollout_batch_size, + # create_env_fn=_make_env, + # create_env_kwargs={ + # "seed": env_seed for env_seed in range(cfg.seed, cfg.seed + cfg.rollout_batch_size) + # }, + # ) # def make_env(env_name, frame_skip, device, is_test=False): diff --git a/lerobot/common/policies/abstract.py b/lerobot/common/policies/abstract.py index 9c652c0a..ca2d8570 100644 --- a/lerobot/common/policies/abstract.py +++ b/lerobot/common/policies/abstract.py @@ -30,7 +30,7 @@ class AbstractPolicy(nn.Module, ABC): Should return a (batch_size, n_action_steps, *) tensor of actions. """ - def forward(self, *args, **kwargs): + def forward(self, *args, **kwargs) -> Tensor: """Inference step that makes multi-step policies compatible with their single-step environments. WARNING: In general, this should not be overriden. diff --git a/lerobot/configs/default.yaml b/lerobot/configs/default.yaml index 5cc8acd2..27b75c88 100644 --- a/lerobot/configs/default.yaml +++ b/lerobot/configs/default.yaml @@ -11,14 +11,16 @@ hydra: seed: 1337 # batch size for TorchRL SerialEnv. Each underlying env will get the seed = seed + env_index -rollout_batch_size: 10 +# NOTE: batch size of 1 is not yet supported! This is just a placeholder for future support. See +# `lerobot.common.envs.factory.make_env` for more information. +rollout_batch_size: 1 device: cuda # cpu prefetch: 4 eval_freq: ??? save_freq: ??? eval_episodes: ??? save_video: false -save_model: false +save_model: true save_buffer: false train_steps: ??? fps: ??? @@ -31,7 +33,7 @@ env: ??? policy: ??? 
wandb: - enable: true + enable: false # Set to true to disable saving an artifact despite save_model == True disable_artifact: false project: lerobot diff --git a/lerobot/configs/policy/diffusion.yaml b/lerobot/configs/policy/diffusion.yaml index 0dae5056..ce8acbd4 100644 --- a/lerobot/configs/policy/diffusion.yaml +++ b/lerobot/configs/policy/diffusion.yaml @@ -22,8 +22,8 @@ keypoint_visible_rate: 1.0 obs_as_global_cond: True eval_episodes: 1 -eval_freq: 10000 -save_freq: 100000 +eval_freq: 5000 +save_freq: 5000 log_freq: 250 offline_steps: 1344000 diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index 7cfb796a..c0199c0c 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -9,7 +9,7 @@ import numpy as np import torch import tqdm from tensordict.nn import TensorDictModule -from torchrl.envs import EnvBase, SerialEnv +from torchrl.envs import EnvBase from torchrl.envs.batched_envs import BatchedEnvBase from lerobot.common.datasets.factory import make_offline_buffer @@ -131,14 +131,7 @@ def eval(cfg: dict, out_dir=None): offline_buffer = make_offline_buffer(cfg) logging.info("make_env") - env = SerialEnv( - cfg.rollout_batch_size, - create_env_fn=make_env, - create_env_kwargs=[ - {"cfg": cfg, "seed": env_seed, "transform": offline_buffer.transform} - for env_seed in range(cfg.seed, cfg.seed + cfg.rollout_batch_size) - ], - ) + env = make_env(cfg, transform=offline_buffer.transform) if cfg.policy.pretrained_model_path: policy = make_policy(cfg) diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 2c7bb575..5ecd616d 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -7,7 +7,6 @@ import torch from tensordict.nn import TensorDictModule from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer from torchrl.data.replay_buffers import PrioritizedSliceSampler -from torchrl.envs import SerialEnv from lerobot.common.datasets.factory import make_offline_buffer from lerobot.common.envs.factory import make_env @@ -149,14 +148,6 @@ def train(cfg: dict, out_dir=None, job_name=None): logging.info("make_env") env = make_env(cfg, transform=offline_buffer.transform) - env = SerialEnv( - cfg.rollout_batch_size, - create_env_fn=make_env, - create_env_kwargs=[ - {"cfg": cfg, "seed": s, "transform": offline_buffer.transform} - for s in range(cfg.seed, cfg.seed + cfg.rollout_batch_size) - ], - ) logging.info("make_policy") policy = make_policy(cfg) diff --git a/tests/data/aloha_sim_insertion_human/stats.pth b/tests/data/aloha_sim_insertion_human/stats.pth index 869d26cd0528611dc3ffb548c5b35d0cff50f6ef..f909ed075ce48cf7f677cc82a8859615b15924c1 100644 GIT binary patch delta 1186 zcmY*Ye@qis9KS-LKj_`HYi(7i45zlvA%*dy?cMbbNSsp3Dw82oXZW?aMF-e+h8Uu3 zYfxE`U?2@-xAa8-g;UG#GP`}1-Po0p?D zO^S)c44BVf;rC@!Hu!654^-D2*i@j!(C-OtT=+s#_`_{p;pdN=*d#{FPg7qo&6Jui z?QD2O)-Xni_GWes3EUgp*;1b}Tq_}sbnk1d7#rDJViN}HPO?=7Gjru5C2PHgxz~v-&ZLoAi z=A3HAaxWGJky~kh79&o{aoiRay*K$;p>gvm78BbyK2jQgGK5{gH%-`tv)4|tyB`{v ze7l<6w_fHNy~i;ivfGoJnE^P7?Sp{dE~hAWJa#qtKVOUVTTK0_29m*4nD~Zs;_d3p`P&Lb#CEm`eBwy zC-^ekB4Y_@cx7LytEsXp{Q6z5@T0GRHCuY=u`-40Q1>M!lBtpG!nL|uG@8V54Qr>V zWsVT7d3A;!Khi>_lpke+-A&Z6>pp#hvr&_lZu*a`Y5LZg>&_#16Z5!jmRi-Nr{PF% z=>|NA)BKQ}muJtl=RTj6?Qmq~WX3oEhZ20p3mUIceRmP$&e6?IEz!+|f zYfnH?e3e|Pr9Q%hj6^C9HDMRMf{r64VMcohom_fk2eA%yibQ&>7;Zu>1cAO61CsH5 z#9BmK{8 delta 1279 zcmY*Y3rt&87{0f(P-t&UEw&WOIyPY417VCuk$cX$YZY*HSqqGyVLEUNy0AE)F54t4 z;ZgQ5!Z?{5An#2>z=dh%vK4OYjES%WrHrA5x#2WRm=ea27&XTGm=11oZqE0A_k90< 
zzW=|0*uc7>1Ehkqt2`AtyUI>_a0R(j#3m7cMW%{)gMeEkPhgTrXGsQhW;ph#VOXR~q$df+nh5C9bNZ(xa~ZZ&&WyUzu&S5O_kCgzw2hagD5o zIecIs>(aM0t8J!{G|x1AseKkkJlEi*=GUO-e;yy1{;rUXPDwz;b=mN;--wPZMzYJN zZ$pPYM6Q{kdN32rFq*a9^shNvKx|tiSlir2%je@kL-sBDkJtaAiKubf@KzNFLuFvH zbDm!AIt=WKH2AS674RbU>wlTVeoiESdtvwJAXxCUp~CA;u>YbPm3=w`_ZVByV*CZz zTAc^IG41HwCn?A|*n(`%=5xLKx=a$C$NQT|QfFJG^?gkuBO7ja~@M!olJb$fLLbJ0IsmZD%`5@$W=K*DzY@AyMG7lR_mM{s}pY89BpD zJUqiT-FU|r6W#$!6Flsw{%`hS{}fAbPwu8AHtL6JUx(srFgEUB@7wRNk6XD)Zp04r zm3%*zQIGv{#^5<0Z#9p`S-ZUnR-U-x?V3??0!pk_#^ChO9R_n}Gi)?q_>sOv=*cVE z-y@!W@p-62TW%u=w~aro^>qzD<V!-|&YUjHC>Ta5>7$9_c%YlsgH0QI=_2`h zkkdCscf7F`RBOiRX2l_p@l6@X&t0Y`ZdQT0+ca?NDuwP_$h;`*S4>53O;1TnPRqKS#WBv*;`;ihQ*_GV6(Uq(K`|4$2chVRl?C^6u$2Z!*x1$;R? II#`VU3rHHKMF0Q* diff --git a/tests/data/pusht/stats.pth b/tests/data/pusht/stats.pth index 037e02f0fbe9a84fb24d2119174addd4a5a913dd..8846b8f65ae9b52dc74d369c239d64f42ff214fb 100644 GIT binary patch delta 828 zcmZuvO=uHQ5PrMeWH-s~<_VJ8Ce5ZGSgkSD4@Dbk3JTJ68_`<9 zSPVV*7o9_kUM<*Ku~BKUqJp4aRNCMnhZZV$6;!M@!JU`VdT?Oa@0htC4ObgB#=s zlDWDCoHc;h5?T?TMLaF+L7YeYrG`flH#l0{^ADdrni2?^5`fBqdM^o*;d~}P96mLY z&kkDsSu2@nB(Usggo2|$zZw(frB~x-IweYa=4v?tC8wpn zFiByJ>3+(;DnLa`&6&NHK_UJ`(^uP$4X+C-9kqIF(c;zP8@9Jt;ne6LK=2|;vX5Yqe#qFcY*p_16 z^`U+&K}f|;$|l?|K>44(yNhb&NppM04c0yGDz^jm2Tv3ApS_X%`*j}{T z=nqnUMS^uT%Jo8n0mNMt>%s2}jc@j>i*9cqBW}8(`bq vuPD$XaeD@gZYz~pj>^3v^7(Yn1(l_ zH!~13mK54+Kv+QbS15k>G@f=;Vn!zD&NXlOOOXvVgz=< z+GH7Cc@}U;0l7>PY?H%z<)lHu!~t|12=jtm2?sMKPvBLM0fiZMc?tH(XMhHR!w+Qs zjLB?#@+{ym2Aj_@*^y6<1r%hH1^GeJRX}M__<*=fGMtmQ@>wv>nfwkY3JMj7DA#07 zemNF!H~?xx qlm%fKkTN(B;$fKl9vF270tSLRK=A+s0p4uvAi7|3p@1|SL=*sgDe|%a diff --git a/tests/test_policies.py b/tests/test_policies.py index 92324485..ee5abdb7 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -1,4 +1,5 @@ +from omegaconf import open_dict import pytest from tensordict import TensorDict from tensordict.nn import TensorDictModule @@ -7,7 +8,8 @@ from torchrl.data import UnboundedContinuousTensorSpec from torchrl.envs import EnvBase from lerobot.common.policies.factory import make_policy - +from lerobot.common.envs.factory import make_env +from lerobot.common.datasets.factory import make_offline_buffer from lerobot.common.policies.abstract import AbstractPolicy from .utils import DEVICE, init_config @@ -30,7 +32,19 @@ def test_factory(env_name, policy_name): f"device={DEVICE}", ] ) + # Check that we can make the policy object. policy = make_policy(cfg) + # Check that we run select_action and get the appropriate output. + if env_name == "simxarm": + # TODO(rcadene): Not implemented + return + if policy_name == "tdmpc": + # TODO(alexander-soare): TDMPC does not use n_obs_steps but the environment requires this. 
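# Illustrative sketch, not part of the patch: how omegaconf's `open_dict` (used in the
# test just below) temporarily lifts struct mode so a key such as `n_obs_steps` can be
# added to a Hydra-style config. The config contents here are invented for the example.
from omegaconf import OmegaConf, open_dict

cfg = OmegaConf.create({"policy": {"name": "tdmpc"}})
OmegaConf.set_struct(cfg, True)      # Hydra configs are struct-locked like this
with open_dict(cfg):
    cfg.n_obs_steps = 1              # allowed here; outside open_dict it would raise
assert cfg.n_obs_steps == 1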
+ with open_dict(cfg): + cfg['n_obs_steps'] = 1 + offline_buffer = make_offline_buffer(cfg) + env = make_env(cfg, transform=offline_buffer.transform) + policy.select_action(env.observation_spec.rand()['observation'].to(DEVICE), torch.tensor(0, device=DEVICE)) def test_abstract_policy_forward(): From 896a11f60e3a0f9d7107642d692de1c20ee4fd48 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Tue, 19 Mar 2024 18:50:04 +0000 Subject: [PATCH 07/16] backup wip --- lerobot/common/datasets/abstract.py | 2 +- lerobot/common/datasets/aloha.py | 2 +- lerobot/common/envs/aloha/env.py | 73 +++++++++--------- lerobot/common/envs/factory.py | 34 +++----- lerobot/common/envs/pusht/env.py | 62 ++++++++------- lerobot/common/policies/abstract.py | 32 +++++--- lerobot/common/policies/act/policy.py | 7 +- lerobot/common/policies/diffusion/policy.py | 5 +- lerobot/common/policies/factory.py | 3 + lerobot/common/policies/tdmpc/policy.py | 12 ++- lerobot/configs/default.yaml | 7 +- lerobot/configs/policy/diffusion.yaml | 4 +- lerobot/scripts/eval.py | 17 +++- .../data/aloha_sim_insertion_human/stats.pth | Bin 4306 -> 4370 bytes tests/data/pusht/stats.pth | Bin 4242 -> 4306 bytes tests/test_policies.py | 47 +++++++---- 16 files changed, 169 insertions(+), 138 deletions(-) diff --git a/lerobot/common/datasets/abstract.py b/lerobot/common/datasets/abstract.py index 5db97497..4ce447bf 100644 --- a/lerobot/common/datasets/abstract.py +++ b/lerobot/common/datasets/abstract.py @@ -50,7 +50,7 @@ class AbstractExperienceReplay(TensorDictReplayBuffer): def stats_patterns(self) -> dict: return { ("observation", "state"): "b c -> c", - ("observation", "image"): "b c h w -> c", + ("observation", "image"): "b c h w -> c 1 1", ("action",): "b c -> c", } diff --git a/lerobot/common/datasets/aloha.py b/lerobot/common/datasets/aloha.py index 0637f8a3..b1a5806f 100644 --- a/lerobot/common/datasets/aloha.py +++ b/lerobot/common/datasets/aloha.py @@ -117,7 +117,7 @@ class AlohaExperienceReplay(AbstractExperienceReplay): ("action",): "b c -> c", } for cam in CAMERAS[self.dataset_id]: - d[("observation", "image", cam)] = "b c h w -> c" + d[("observation", "image", cam)] = "b c h w -> c 1 1" return d @property diff --git a/lerobot/common/envs/aloha/env.py b/lerobot/common/envs/aloha/env.py index 7ef24f2d..001b2ba2 100644 --- a/lerobot/common/envs/aloha/env.py +++ b/lerobot/common/envs/aloha/env.py @@ -58,6 +58,7 @@ class AlohaEnv(AbstractEnv): num_prev_obs=num_prev_obs, num_prev_action=num_prev_action, ) + self._reset_warning_issued = False def _make_env(self): if not _has_gym: @@ -120,47 +121,47 @@ class AlohaEnv(AbstractEnv): return obs def _reset(self, tensordict: Optional[TensorDict] = None): - td = tensordict - if td is None or td.is_empty(): - # we need to handle seed iteration, since self._env.reset() rely an internal _seed. - self._current_seed += 1 - self.set_seed(self._current_seed) + if tensordict is not None and not self._reset_warning_issued: + logging.warning(f"{self.__class__.__name__}._reset ignores the provided tensordict.") + self._reset_warning_issued = True - # TODO(rcadene): do not use global variable for this - if "sim_transfer_cube" in self.task: - BOX_POSE[0] = sample_box_pose() # used in sim reset - elif "sim_insertion" in self.task: - BOX_POSE[0] = np.concatenate(sample_insertion_pose()) # used in sim reset + # we need to handle seed iteration, since self._env.reset() rely an internal _seed. 
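# Illustrative sketch, not part of the patch: what the einops strings in the updated
# `stats_patterns` above compute. "b c h w -> c 1 1" collapses the batch and spatial
# dims into per-channel statistics with a broadcastable (c, 1, 1) shape, while
# "b c -> c" keeps plain per-dimension statistics. Tensor shapes are arbitrary.
import torch
from einops import reduce

images = torch.rand(8, 3, 96, 96)   # (b, c, h, w)
states = torch.rand(8, 2)           # (b, c)
image_mean = reduce(images, "b c h w -> c 1 1", "mean")
state_mean = reduce(states, "b c -> c", "mean")
assert image_mean.shape == (3, 1, 1)   # broadcastable against (c, h, w) images
assert state_mean.shape == (2,)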
+ self._current_seed += 1 + self.set_seed(self._current_seed) - raw_obs = self._env.reset() - # TODO(rcadene): add assert - # assert self._current_seed == self._env._seed + # TODO(rcadene): do not use global variable for this + if "sim_transfer_cube" in self.task: + BOX_POSE[0] = sample_box_pose() # used in sim reset + elif "sim_insertion" in self.task: + BOX_POSE[0] = np.concatenate(sample_insertion_pose()) # used in sim reset - obs = self._format_raw_obs(raw_obs.observation) + raw_obs = self._env.reset() + # TODO(rcadene): add assert + # assert self._current_seed == self._env._seed - if self.num_prev_obs > 0: - stacked_obs = {} - if "image" in obs: - self._prev_obs_image_queue = deque( - [obs["image"]["top"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1) - ) - stacked_obs["image"] = {"top": torch.stack(list(self._prev_obs_image_queue))} - if "state" in obs: - self._prev_obs_state_queue = deque( - [obs["state"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1) - ) - stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue)) - obs = stacked_obs + obs = self._format_raw_obs(raw_obs.observation) - td = TensorDict( - { - "observation": TensorDict(obs, batch_size=[]), - "done": torch.tensor([False], dtype=torch.bool), - }, - batch_size=[], - ) - else: - raise NotImplementedError() + if self.num_prev_obs > 0: + stacked_obs = {} + if "image" in obs: + self._prev_obs_image_queue = deque( + [obs["image"]["top"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1) + ) + stacked_obs["image"] = {"top": torch.stack(list(self._prev_obs_image_queue))} + if "state" in obs: + self._prev_obs_state_queue = deque( + [obs["state"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1) + ) + stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue)) + obs = stacked_obs + + td = TensorDict( + { + "observation": TensorDict(obs, batch_size=[]), + "done": torch.tensor([False], dtype=torch.bool), + }, + batch_size=[], + ) self.call_rendering_hooks() return td diff --git a/lerobot/common/envs/factory.py b/lerobot/common/envs/factory.py index de86b3ad..689f5869 100644 --- a/lerobot/common/envs/factory.py +++ b/lerobot/common/envs/factory.py @@ -1,31 +1,20 @@ +from torchrl.envs import SerialEnv from torchrl.envs.transforms import Compose, StepCounter, Transform, TransformedEnv def make_env(cfg, transform=None): """ - Provide seed to override the seed in the cfg (useful for batched environments). + Note: The returned environment is wrapped in a torchrl.SerialEnv with cfg.rollout_batch_size underlying + environments. The env therefore returns batches.` """ - # assert cfg.rollout_batch_size == 1, \ - # """ - # For the time being, rollout batch sizes of > 1 are not supported. This is because the SerialEnv rollout does not - # correctly handle terminated environments. If you really want to use a larger batch size, read on... - - # When calling `EnvBase.rollout` with `break_when_any_done == True` all environments stop rolling out as soon as the - # first is terminated or truncated. This almost certainly results in incorrect success metrics, as all but the first - # environment get an opportunity to reach the goal. A possible work around is to comment out `if any_done: break` - # inf `EnvBase._rollout_stop_early`. One potential downside is that the environments `step` function will continue - # to be called and the outputs will continue to be added to the rollout. - - # When calling `EnvBase.rollout` with `break_when_any_done == False` environments are reset when done. 
- # """ kwargs = { "frame_skip": cfg.env.action_repeat, "from_pixels": cfg.env.from_pixels, "pixels_only": cfg.env.pixels_only, "image_size": cfg.env.image_size, - "num_prev_obs": cfg.n_obs_steps - 1, "seed": cfg.seed, + "num_prev_obs": cfg.n_obs_steps - 1, } if cfg.env.name == "simxarm": @@ -67,13 +56,14 @@ def make_env(cfg, transform=None): return env - # return SerialEnv( - # cfg.rollout_batch_size, - # create_env_fn=_make_env, - # create_env_kwargs={ - # "seed": env_seed for env_seed in range(cfg.seed, cfg.seed + cfg.rollout_batch_size) - # }, - # ) + return SerialEnv( + cfg.rollout_batch_size, + create_env_fn=_make_env, + create_env_kwargs={ + "seed": env_seed # noqa: B035 + for env_seed in range(cfg.seed, cfg.seed + cfg.rollout_batch_size) + }, + ) # def make_env(env_name, frame_skip, device, is_test=False): diff --git a/lerobot/common/envs/pusht/env.py b/lerobot/common/envs/pusht/env.py index 2fe05233..6c348cd6 100644 --- a/lerobot/common/envs/pusht/env.py +++ b/lerobot/common/envs/pusht/env.py @@ -1,4 +1,5 @@ import importlib +import logging from collections import deque from typing import Optional @@ -42,6 +43,7 @@ class PushtEnv(AbstractEnv): num_prev_obs=num_prev_obs, num_prev_action=num_prev_action, ) + self._reset_warning_issued = False def _make_env(self): if not _has_gym: @@ -79,39 +81,39 @@ class PushtEnv(AbstractEnv): return obs def _reset(self, tensordict: Optional[TensorDict] = None): - td = tensordict - if td is None or td.is_empty(): - # we need to handle seed iteration, since self._env.reset() rely an internal _seed. - self._current_seed += 1 - self.set_seed(self._current_seed) - raw_obs = self._env.reset() - assert self._current_seed == self._env._seed + if tensordict is not None and not self._reset_warning_issued: + logging.warning(f"{self.__class__.__name__}._reset ignores the provided tensordict.") + self._reset_warning_issued = True - obs = self._format_raw_obs(raw_obs) + # we need to handle seed iteration, since self._env.reset() rely an internal _seed. 
+ self._current_seed += 1 + self.set_seed(self._current_seed) + raw_obs = self._env.reset() + assert self._current_seed == self._env._seed - if self.num_prev_obs > 0: - stacked_obs = {} - if "image" in obs: - self._prev_obs_image_queue = deque( - [obs["image"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1) - ) - stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue)) - if "state" in obs: - self._prev_obs_state_queue = deque( - [obs["state"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1) - ) - stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue)) - obs = stacked_obs + obs = self._format_raw_obs(raw_obs) - td = TensorDict( - { - "observation": TensorDict(obs, batch_size=[]), - "done": torch.tensor([False], dtype=torch.bool), - }, - batch_size=[], - ) - else: - raise NotImplementedError() + if self.num_prev_obs > 0: + stacked_obs = {} + if "image" in obs: + self._prev_obs_image_queue = deque( + [obs["image"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1) + ) + stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue)) + if "state" in obs: + self._prev_obs_state_queue = deque( + [obs["state"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1) + ) + stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue)) + obs = stacked_obs + + td = TensorDict( + { + "observation": TensorDict(obs, batch_size=[]), + "done": torch.tensor([False], dtype=torch.bool), + }, + batch_size=[], + ) self.call_rendering_hooks() return td diff --git a/lerobot/common/policies/abstract.py b/lerobot/common/policies/abstract.py index ca2d8570..9f16f5d7 100644 --- a/lerobot/common/policies/abstract.py +++ b/lerobot/common/policies/abstract.py @@ -12,6 +12,17 @@ class AbstractPolicy(nn.Module, ABC): documentation for more information. """ + def __init__(self, n_action_steps: int | None): + """ + n_action_steps: Sets the cache size for storing action trajectories. If None, it is assumed that a single + action is returned by `select_actions` and that doesn't have a horizon dimension. The `forward` method then + adds that dimension. + """ + super().__init__() + self.n_action_steps = n_action_steps + if n_action_steps is not None: + self._action_queue = deque([], maxlen=n_action_steps) + @abstractmethod def update(self, replay_buffer, step): """One step of the policy's learning algorithm.""" @@ -24,10 +35,11 @@ class AbstractPolicy(nn.Module, ABC): self.load_state_dict(d) @abstractmethod - def select_action(self, observation) -> Tensor: + def select_actions(self, observation) -> Tensor: """Select an action (or trajectory of actions) based on an observation during rollout. - Should return a (batch_size, n_action_steps, *) tensor of actions. + If n_action_steps was provided at initialization, this should return a (batch_size, n_action_steps, *) tensor of + actions. Otherwise if n_actions_steps is None, this should return a (batch_size, *) tensor of actions. """ def forward(self, *args, **kwargs) -> Tensor: @@ -41,18 +53,14 @@ class AbstractPolicy(nn.Module, ABC): observation, (3) repopulates the action queue when empty. This method handles the aforementioned logic so that the subclass doesn't have to. - This method effectively wraps the `select_action` method of the subclass. The following assumptions are made: - 1. The `select_action` method returns a Tensor of actions with shape (B, H, *) where B is the batch size, H is + This method effectively wraps the `select_actions` method of the subclass. The following assumptions are made: + 1. 
The `select_actions` method returns a Tensor of actions with shape (B, H, *) where B is the batch size, H is the action trajectory horizon and * is the action dimensions. - 2. Prior to the `select_action` method being called, theres is an `n_action_steps` instance attribute defined. + 2. Prior to the `select_actions` method being called, theres is an `n_action_steps` instance attribute defined. """ - n_action_steps_attr = "n_action_steps" - if not hasattr(self, n_action_steps_attr): - raise RuntimeError(f"Underlying policy must have an `{n_action_steps_attr}` attribute") - if not hasattr(self, "_action_queue"): - self._action_queue = deque([], maxlen=getattr(self, n_action_steps_attr)) + if self.n_action_steps is None: + return self.select_actions(*args, **kwargs) if len(self._action_queue) == 0: # Each element in the queue has shape (B, *). - self._action_queue.extend(self.select_action(*args, **kwargs).transpose(0, 1)) - + self._action_queue.extend(self.select_actions(*args, **kwargs).transpose(0, 1)) return self._action_queue.popleft() diff --git a/lerobot/common/policies/act/policy.py b/lerobot/common/policies/act/policy.py index e0499cdb..539cdcf5 100644 --- a/lerobot/common/policies/act/policy.py +++ b/lerobot/common/policies/act/policy.py @@ -42,7 +42,7 @@ def kl_divergence(mu, logvar): class ActionChunkingTransformerPolicy(AbstractPolicy): def __init__(self, cfg, device, n_action_steps=1): - super().__init__() + super().__init__(n_action_steps) self.cfg = cfg self.n_action_steps = n_action_steps self.device = device @@ -147,7 +147,10 @@ class ActionChunkingTransformerPolicy(AbstractPolicy): return loss @torch.no_grad() - def select_action(self, observation, step_count): + def select_actions(self, observation, step_count): + if observation["image"].shape[0] != 1: + raise NotImplementedError("Batch size > 1 not handled") + # TODO(rcadene): remove unused step_count del step_count diff --git a/lerobot/common/policies/diffusion/policy.py b/lerobot/common/policies/diffusion/policy.py index db004a71..2c47f172 100644 --- a/lerobot/common/policies/diffusion/policy.py +++ b/lerobot/common/policies/diffusion/policy.py @@ -34,7 +34,7 @@ class DiffusionPolicy(AbstractPolicy): # parameters passed to step **kwargs, ): - super().__init__() + super().__init__(n_action_steps) self.cfg = cfg noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler) @@ -44,7 +44,6 @@ class DiffusionPolicy(AbstractPolicy): **cfg_obs_encoder, ) - self.n_action_steps = n_action_steps # needed for the parent class self.diffusion = DiffusionUnetImagePolicy( shape_meta=shape_meta, noise_scheduler=noise_scheduler, @@ -94,7 +93,7 @@ class DiffusionPolicy(AbstractPolicy): ) @torch.no_grad() - def select_action(self, observation, step_count): + def select_actions(self, observation, step_count): # TODO(rcadene): remove unused step_count del step_count diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py index c5e45300..085baab5 100644 --- a/lerobot/common/policies/factory.py +++ b/lerobot/common/policies/factory.py @@ -1,4 +1,7 @@ def make_policy(cfg): + if cfg.policy.name != "diffusion" and cfg.rollout_batch_size > 1: + raise NotImplementedError("Only diffusion policy supports rollout_batch_size > 1 for the time being.") + if cfg.policy.name == "tdmpc": from lerobot.common.policies.tdmpc.policy import TDMPC diff --git a/lerobot/common/policies/tdmpc/policy.py b/lerobot/common/policies/tdmpc/policy.py index 4c104bcd..320f6f2b 100644 --- a/lerobot/common/policies/tdmpc/policy.py +++ 
b/lerobot/common/policies/tdmpc/policy.py @@ -90,7 +90,7 @@ class TDMPC(AbstractPolicy): """Implementation of TD-MPC learning + inference.""" def __init__(self, cfg, device): - super().__init__() + super().__init__(None) self.action_dim = cfg.action_dim self.cfg = cfg @@ -125,7 +125,10 @@ class TDMPC(AbstractPolicy): self.model_target.load_state_dict(d["model_target"]) @torch.no_grad() - def select_action(self, observation, step_count): + def select_actions(self, observation, step_count): + if observation["image"].shape[0] != 1: + raise NotImplementedError("Batch size > 1 not handled") + t0 = step_count.item() == 0 obs = { @@ -133,7 +136,8 @@ class TDMPC(AbstractPolicy): "rgb": observation["image"].contiguous(), "state": observation["state"].contiguous(), } - action = self.act(obs, t0=t0, step=self.step.item()) + # Note: unsqueeze needed because `act` still uses non-batch logic. + action = self.act(obs, t0=t0, step=self.step.item()).unsqueeze(0) return action @torch.no_grad() @@ -144,7 +148,7 @@ class TDMPC(AbstractPolicy): if self.cfg.mpc: a = self.plan(z, t0=t0, step=step) else: - a = self.model.pi(z, self.cfg.min_std * self.model.training) + a = self.model.pi(z, self.cfg.min_std * self.model.training).squeeze(0) return a @torch.no_grad() diff --git a/lerobot/configs/default.yaml b/lerobot/configs/default.yaml index 27b75c88..52fd1d60 100644 --- a/lerobot/configs/default.yaml +++ b/lerobot/configs/default.yaml @@ -11,8 +11,7 @@ hydra: seed: 1337 # batch size for TorchRL SerialEnv. Each underlying env will get the seed = seed + env_index -# NOTE: batch size of 1 is not yet supported! This is just a placeholder for future support. See -# `lerobot.common.envs.factory.make_env` for more information. +# NOTE: only diffusion policy supports rollout_batch_size > 1 rollout_batch_size: 1 device: cuda # cpu prefetch: 4 @@ -20,7 +19,7 @@ eval_freq: ??? save_freq: ??? eval_episodes: ??? save_video: false -save_model: true +save_model: false save_buffer: false train_steps: ??? fps: ??? @@ -33,7 +32,7 @@ env: ??? policy: ??? wandb: - enable: false + enable: true # Set to true to disable saving an artifact despite save_model == True disable_artifact: false project: lerobot diff --git a/lerobot/configs/policy/diffusion.yaml b/lerobot/configs/policy/diffusion.yaml index ce8acbd4..0dae5056 100644 --- a/lerobot/configs/policy/diffusion.yaml +++ b/lerobot/configs/policy/diffusion.yaml @@ -22,8 +22,8 @@ keypoint_visible_rate: 1.0 obs_as_global_cond: True eval_episodes: 1 -eval_freq: 5000 -save_freq: 5000 +eval_freq: 10000 +save_freq: 100000 log_freq: 250 offline_steps: 1344000 diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index c0199c0c..2c564da0 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -51,16 +51,25 @@ def eval_policy( ep_frames.append(env.render()) # noqa: B023 with torch.inference_mode(): + # TODO(alexander-soare): Due the `break_when_any_done == False` this rolls out for max_steps even when all + # envs are done the first time. But we only use the first rollout. This is a waste of compute. 
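# Illustrative sketch, not part of the patch: the first-done masking built in the eval
# hunk just below. With `break_when_any_done == False` every env rolls out for
# max_steps, so rewards and successes after each env's first `done` must be masked out.
# As the patch notes, this assumes argmax returns the first occurrence among ties.
# The toy tensors are invented.
import torch
from einops import reduce

done = torch.tensor([[0, 0, 1, 0, 1],
                     [0, 1, 0, 0, 1]]).unsqueeze(-1)       # (batch, steps, 1)
reward = torch.ones_like(done, dtype=torch.float)          # pretend reward of 1 per step
rollout_steps = done.shape[1]
done_indices = torch.argmax(done.to(int), dim=1)           # (batch, 1): step of first done
mask = (torch.arange(rollout_steps) <= done_indices).unsqueeze(-1)   # (batch, steps, 1)
sum_reward = reduce(reward * mask, "b n 1 -> b", "sum")    # same reduction a later commit writes with einops
assert sum_reward.tolist() == [3.0, 2.0]                   # steps up to and including the first done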
rollout = env.rollout( max_steps=max_steps, policy=policy, auto_cast_to_device=True, callback=maybe_render_frame, + break_when_any_done=False, ) - # print(", ".join([f"{x:.3f}" for x in rollout["next", "reward"][:,0].tolist()])) - batch_sum_reward = rollout["next", "reward"].flatten(start_dim=1).sum(dim=-1) - batch_max_reward = rollout["next", "reward"].flatten(start_dim=1).max(dim=-1)[0] - batch_success = rollout["next", "success"].flatten(start_dim=1).any(dim=-1) + # Figure out where in each rollout sequence the first done condition was encountered (results after this won't + # be included). + # Note: this assumes that the shape of the done key is (batch_size, max_steps, 1). + # Note: this relies on a property of argmax: that it returns the first occurrence as a tiebreaker. + rollout_steps = rollout["next", "done"].shape[1] + done_indices = torch.argmax(rollout["next", "done"].to(int), axis=1) # (batch_size, rollout_steps) + mask = (torch.arange(rollout_steps) <= done_indices).unsqueeze(-1) # (batch_size, rollout_steps, 1) + batch_sum_reward = (rollout["next", "reward"] * mask).flatten(start_dim=1).sum(dim=-1) + batch_max_reward = (rollout["next", "reward"] * mask).flatten(start_dim=1).max(dim=-1)[0] + batch_success = (rollout["next", "success"] * mask).flatten(start_dim=1).any(dim=-1) sum_rewards.extend(batch_sum_reward.tolist()) max_rewards.extend(batch_max_reward.tolist()) successes.extend(batch_success.tolist()) diff --git a/tests/data/aloha_sim_insertion_human/stats.pth b/tests/data/aloha_sim_insertion_human/stats.pth index f909ed075ce48cf7f677cc82a8859615b15924c1..d41ac18cdfeb94b610369f116db8b267cf642af8 100644 GIT binary patch delta 754 zcmcblI7w;43no5qMsLRULT3;KWK4d?qze`Df(a=y!-N81LLqRW2&mBJcIGvVT#o|X zb@^Et82DKy3-HK@^D~$*6qh8H6zeBdmZatvrKA?QIT+aHvyI&< zz?+>zD$J5;vICoggIM8IYaPqOHj_o!Y!7A4w?3@1(Ds$&9BYLS5jMgp64p^%J8ZIZ z7;LBB{c3gd?+jb@hrBl1kG}vLk}lTSx%mK_1Cyli4@IX(Ulg1gvz47d#;;rRUV(>! 
zfq`c-KacEWGfpqb-7b4=uLbV2O$^^{3sO;EcHb1PV)HgmFGdH|wq<+n%<|qZal76= zM2UB=@2uChuOqnkZqf_4Q;eUtrzIe8zlTVb{e7!7Hn(`0><)KswheM)-2*c2qe9=y z$qw8O4ioy`TEB=nY9qOu(^f!uk@fmROKoqr&$Lc>^vx!cJXi9crBSY*e74)wFGjw`K*{Y zI3~OESpYc=eDW-yz?vM$FUG{dIe9mq91AFDKzzoS$zS*^8ShNiGEdsn5nSkNVz`+3`85qD`fmnWrA80wxWN`s`X^{IhfU+Pg15ySD dL6ZXo`~`o2Vgv{RyxG`6^oGeB1*F*^q5#o~@`eBa delta 716 zcmbQFbV+f;3nsSKLT7Kr*2&+Pbb(AS5Yvzu#0&&6bHL09AanC_<~595USWwPe}7e@J-4mE z-fPyoY}>85_C{{tw~Id*x+kSQbpQ4Jwe}((r`d45y<{u0b))U(Mb_H`yxBRX=c#X* z+`#7GuyIq5b@$B0HnKbK+PMGkvKHoTwGAq5wsuPPwh0g6v{tz=%ci66vdxBdpRH_@ zYHbe|ezM-dch3rBh_~M3$D7z4m?T3!DLTo&S8)1~s_X=kc(!oqY#s&%2A;|OJhGF^ zIK3n#?3dc!Ena48TE5s8q$2g3<#D)*&EGh^7#)5rUb^S6n)m)wy7l(9x_o=X4Sv|F zFXq|Hu|CwUF=_su%$EWCZxmPB%c`%nIT+4vSG{nH?LI@EJs{(F1qAaaH*h;RRH|OH z7PVY#!xQ(&=5#~1wM=A(ZR^5D>m7C*$kK|V_$$D7ZJ=>+@ac0LOr=K!BP z3n;86PvjS4I>9mdH=i5}C}coH4DpC?l z67_+uV`KsbHv@`2=dYi~!Jb?#bZ-^3ovpYXD_ISO%mF4umF86z~_^0g4bH S2=HcO2hjzSKMF{*K|}$D81YR2 diff --git a/tests/data/pusht/stats.pth b/tests/data/pusht/stats.pth index 8846b8f65ae9b52dc74d369c239d64f42ff214fb..039d5db3d53a93a92163692e0702cc1c5de65298 100644 GIT binary patch delta 575 zcmbQFcu8@?NhUsTMsLRULVFMeWK2HKqze^tg9$M)!-RZcLZ)z`5U9}R4CY2gE)%a@ z=c6nP3`bceU*MJz=V!2EC@x7XDb`P_EJ@8TN=YqpbFyNXJdss@@&Z=w$quXy0((8D zJ3n@7bq?@m=Lq-i`#<>tYk|Q0$GncFqPHADf=rVx-`^a-R=^~AV12o>nMIXzwM`{h z`TeDL!+97O74p@Vs30=W^8O|WM*z^ zU=A`h@`JnOO@28Ru*-p_s?6bHfCh0&Vo9RDMSwRW6EH{^I5fRw?35D&xT_rTCJ5HJwj0SY=G2=HcO2hjzS3k9UvAff=0DxYrv delta 539 zcmcblI7xBCNhY?|LVIt<*2&kIbb(Aa5L1X5#PkI*eZkBSAaipEb0Z_yw$)qKE@fe0 zSjsY4fJa80pFxhHxFoTpSU;(h;Pw{dzg3dnDq;Vfd*3pOB}bEDy8 z0qz2U2P`~}0Ur;6?GxMkp?&iNZWcxv150yDO9NALQ&R&2Lvu@WV%v7EEr=Z>+xGM-Qk>^ z$uGwOb~(^gl^8AtXz->amL%$11b8zt0fU5rg9Ah|Fo1jmv7T%4U4D@DVgd@XARlM| qWkFa5qzn!&aWhQT2Zm;WfPr8IDCmG7z?+R7L>o-rC?L%S5d{DdfR>v8 diff --git a/tests/test_policies.py b/tests/test_policies.py index ee5abdb7..953684ed 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -1,4 +1,3 @@ - from omegaconf import open_dict import pytest from tensordict import TensorDict @@ -16,35 +15,50 @@ from .utils import DEVICE, init_config @pytest.mark.parametrize( - "env_name,policy_name", + "env_name,policy_name,extra_overrides", [ - ("simxarm", "tdmpc"), - ("pusht", "tdmpc"), - ("simxarm", "diffusion"), - ("pusht", "diffusion"), + ("simxarm", "tdmpc", ["policy.mpc=true"]), + ("pusht", "tdmpc", ["policy.mpc=false"]), + ("simxarm", "diffusion", []), + ("pusht", "diffusion", []), + ("aloha", "act", ["env.task=sim_insertion_scripted"]), ], ) -def test_factory(env_name, policy_name): +def test_concrete_policy(env_name, policy_name, extra_overrides): + """ + Tests: + - Making the policy object. + - Updating the policy. + - Using the policy to select actions at inference time. + """ cfg = init_config( overrides=[ f"env={env_name}", f"policy={policy_name}", f"device={DEVICE}", ] + + extra_overrides ) # Check that we can make the policy object. policy = make_policy(cfg) - # Check that we run select_action and get the appropriate output. + # Check that we run select_actions and get the appropriate output. if env_name == "simxarm": # TODO(rcadene): Not implemented return if policy_name == "tdmpc": # TODO(alexander-soare): TDMPC does not use n_obs_steps but the environment requires this. 
with open_dict(cfg): - cfg['n_obs_steps'] = 1 + cfg["n_obs_steps"] = 1 offline_buffer = make_offline_buffer(cfg) env = make_env(cfg, transform=offline_buffer.transform) - policy.select_action(env.observation_spec.rand()['observation'].to(DEVICE), torch.tensor(0, device=DEVICE)) + + policy.update(offline_buffer, torch.tensor(0, device=DEVICE)) + + action = policy( + env.observation_spec.rand()["observation"].to(DEVICE), + torch.tensor(0, device=DEVICE), + ) + assert action.shape == env.action_spec.shape def test_abstract_policy_forward(): @@ -90,21 +104,20 @@ def test_abstract_policy_forward(): def _set_seed(self, seed: int | None): return - class StubPolicy(AbstractPolicy): def __init__(self): - super().__init__() - self.n_action_steps = n_action_steps + super().__init__(n_action_steps) self.n_policy_invocations = 0 def update(self): pass - def select_action(self): + def select_actions(self): self.n_policy_invocations += 1 - return torch.stack([torch.tensor([i]) for i in range(self.n_action_steps)]).unsqueeze(0) - + return torch.stack( + [torch.tensor([i]) for i in range(self.n_action_steps)] + ).unsqueeze(0) env = StubEnv() policy = StubPolicy() @@ -119,4 +132,4 @@ def test_abstract_policy_forward(): assert len(rollout) == terminate_at + 1 # +1 for the reset observation assert policy.n_policy_invocations == (terminate_at // n_action_steps) + 1 - assert torch.equal(rollout['observation'].flatten(), torch.arange(terminate_at + 1)) + assert torch.equal(rollout["observation"].flatten(), torch.arange(terminate_at + 1)) From 46ac87d2a68ec1b67943f0d22d73c6821f477386 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Tue, 19 Mar 2024 18:59:08 +0000 Subject: [PATCH 08/16] ready for review --- tests/test_policies.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/test_policies.py b/tests/test_policies.py index 953684ed..f2ebcfcc 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -52,7 +52,10 @@ def test_concrete_policy(env_name, policy_name, extra_overrides): offline_buffer = make_offline_buffer(cfg) env = make_env(cfg, transform=offline_buffer.transform) - policy.update(offline_buffer, torch.tensor(0, device=DEVICE)) + if policy_name != "aloha": + # TODO(alexander-soare): Fix this part of the test. PrioritizedSliceSampler raises NotImplementedError: + # seq_length as a list is not supported for now. + policy.update(offline_buffer, torch.tensor(0, device=DEVICE)) action = policy( env.observation_spec.rand()["observation"].to(DEVICE), From b54cdc9a0fe9584faa27780b1bb112539f5e435c Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Tue, 19 Mar 2024 19:08:25 +0000 Subject: [PATCH 09/16] break_when_any_done==True for batch_size==1 --- lerobot/scripts/eval.py | 4 ++-- tests/test_policies.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index 2c564da0..86d4158e 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -51,14 +51,14 @@ def eval_policy( ep_frames.append(env.render()) # noqa: B023 with torch.inference_mode(): - # TODO(alexander-soare): Due the `break_when_any_done == False` this rolls out for max_steps even when all + # TODO(alexander-soare): When `break_when_any_done == False` this rolls out for max_steps even when all # envs are done the first time. But we only use the first rollout. This is a waste of compute. 
rollout = env.rollout( max_steps=max_steps, policy=policy, auto_cast_to_device=True, callback=maybe_render_frame, - break_when_any_done=False, + break_when_any_done=env.batch_size[0] == 1, ) # Figure out where in each rollout sequence the first done condition was encountered (results after this won't # be included). diff --git a/tests/test_policies.py b/tests/test_policies.py index f2ebcfcc..e6cfdfbc 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -52,7 +52,7 @@ def test_concrete_policy(env_name, policy_name, extra_overrides): offline_buffer = make_offline_buffer(cfg) env = make_env(cfg, transform=offline_buffer.transform) - if policy_name != "aloha": + if env_name != "aloha": # TODO(alexander-soare): Fix this part of the test. PrioritizedSliceSampler raises NotImplementedError: # seq_length as a list is not supported for now. policy.update(offline_buffer, torch.tensor(0, device=DEVICE)) From 18fa88475b29d51570a2d6349367a5d409979f40 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Wed, 20 Mar 2024 08:09:38 +0000 Subject: [PATCH 10/16] Move reset_warning_issued flag to class attribute --- lerobot/common/envs/aloha/env.py | 7 ++++--- lerobot/common/envs/pusht/env.py | 7 ++++--- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/lerobot/common/envs/aloha/env.py b/lerobot/common/envs/aloha/env.py index 001b2ba2..6f8fded1 100644 --- a/lerobot/common/envs/aloha/env.py +++ b/lerobot/common/envs/aloha/env.py @@ -35,6 +35,8 @@ _has_gym = importlib.util.find_spec("gym") is not None class AlohaEnv(AbstractEnv): + _reset_warning_issued = False + def __init__( self, task, @@ -58,7 +60,6 @@ class AlohaEnv(AbstractEnv): num_prev_obs=num_prev_obs, num_prev_action=num_prev_action, ) - self._reset_warning_issued = False def _make_env(self): if not _has_gym: @@ -121,9 +122,9 @@ class AlohaEnv(AbstractEnv): return obs def _reset(self, tensordict: Optional[TensorDict] = None): - if tensordict is not None and not self._reset_warning_issued: + if tensordict is not None and not AlohaEnv._reset_warning_issued: logging.warning(f"{self.__class__.__name__}._reset ignores the provided tensordict.") - self._reset_warning_issued = True + AlohaEnv._reset_warning_issued = True # we need to handle seed iteration, since self._env.reset() rely an internal _seed. self._current_seed += 1 diff --git a/lerobot/common/envs/pusht/env.py b/lerobot/common/envs/pusht/env.py index 6c348cd6..aadf626c 100644 --- a/lerobot/common/envs/pusht/env.py +++ b/lerobot/common/envs/pusht/env.py @@ -20,6 +20,8 @@ _has_gym = importlib.util.find_spec("gym") is not None class PushtEnv(AbstractEnv): + _reset_warning_issued = False + def __init__( self, task="pusht", @@ -43,7 +45,6 @@ class PushtEnv(AbstractEnv): num_prev_obs=num_prev_obs, num_prev_action=num_prev_action, ) - self._reset_warning_issued = False def _make_env(self): if not _has_gym: @@ -81,9 +82,9 @@ class PushtEnv(AbstractEnv): return obs def _reset(self, tensordict: Optional[TensorDict] = None): - if tensordict is not None and not self._reset_warning_issued: + if tensordict is not None and not PushtEnv._reset_warning_issued: logging.warning(f"{self.__class__.__name__}._reset ignores the provided tensordict.") - self._reset_warning_issued = True + PushtEnv._reset_warning_issued = True # we need to handle seed iteration, since self._env.reset() rely an internal _seed. 
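# Illustrative sketch, not part of the patch: the warn-once pattern behind moving
# `_reset_warning_issued` to a class attribute above. Each sub-env of a batched env is
# a separate instance, so an instance-level flag would emit the warning once per
# sub-env; a class-level flag emits it once per process. `ToyEnv` and the message are
# invented for the example.
import logging

class ToyEnv:
    _reset_warning_issued = False   # shared by all instances

    def reset(self, tensordict=None):
        if tensordict is not None and not ToyEnv._reset_warning_issued:
            logging.warning("%s.reset ignores the provided tensordict.", type(self).__name__)
            ToyEnv._reset_warning_issued = True

for env in [ToyEnv() for _ in range(4)]:
    env.reset(tensordict={})        # only the first call logs the warning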
self._current_seed += 1 From c5010fee9a40c64d61e69918ee64691f69a4b4c8 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Wed, 20 Mar 2024 08:20:56 +0000 Subject: [PATCH 11/16] fix seeding --- lerobot/common/envs/factory.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/lerobot/common/envs/factory.py b/lerobot/common/envs/factory.py index 689f5869..e187d713 100644 --- a/lerobot/common/envs/factory.py +++ b/lerobot/common/envs/factory.py @@ -13,7 +13,6 @@ def make_env(cfg, transform=None): "from_pixels": cfg.env.from_pixels, "pixels_only": cfg.env.pixels_only, "image_size": cfg.env.image_size, - "seed": cfg.seed, "num_prev_obs": cfg.n_obs_steps - 1, } @@ -59,10 +58,9 @@ def make_env(cfg, transform=None): return SerialEnv( cfg.rollout_batch_size, create_env_fn=_make_env, - create_env_kwargs={ - "seed": env_seed # noqa: B035 - for env_seed in range(cfg.seed, cfg.seed + cfg.rollout_batch_size) - }, + create_env_kwargs=[ + {"seed": env_seed} for env_seed in range(cfg.seed, cfg.seed + cfg.rollout_batch_size) + ], ) From 4f1955edfdc7515f85ec9d70361932cd45e1c327 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Wed, 20 Mar 2024 08:31:06 +0000 Subject: [PATCH 12/16] Clear action queue when environment is reset --- lerobot/common/policies/abstract.py | 8 ++++++-- lerobot/scripts/eval.py | 4 +++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/lerobot/common/policies/abstract.py b/lerobot/common/policies/abstract.py index 9f16f5d7..272ffcf4 100644 --- a/lerobot/common/policies/abstract.py +++ b/lerobot/common/policies/abstract.py @@ -20,8 +20,7 @@ class AbstractPolicy(nn.Module, ABC): """ super().__init__() self.n_action_steps = n_action_steps - if n_action_steps is not None: - self._action_queue = deque([], maxlen=n_action_steps) + self.clear_action_queue() @abstractmethod def update(self, replay_buffer, step): @@ -42,6 +41,11 @@ class AbstractPolicy(nn.Module, ABC): actions. Otherwise if n_actions_steps is None, this should return a (batch_size, *) tensor of actions. """ + def clear_action_queue(self): + """This should be called whenever the environment is reset.""" + if self.n_action_steps is not None: + self._action_queue = deque([], maxlen=self.n_action_steps) + def forward(self, *args, **kwargs) -> Tensor: """Inference step that makes multi-step policies compatible with their single-step environments. diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index 86d4158e..1e44c5df 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -15,6 +15,7 @@ from torchrl.envs.batched_envs import BatchedEnvBase from lerobot.common.datasets.factory import make_offline_buffer from lerobot.common.envs.factory import make_env from lerobot.common.logger import log_output_dir +from lerobot.common.policies.abstract import AbstractPolicy from lerobot.common.policies.factory import make_policy from lerobot.common.utils import init_logging, set_seed @@ -25,7 +26,7 @@ def write_video(video_path, stacked_frames, fps): def eval_policy( env: BatchedEnvBase, - policy: TensorDictModule = None, + policy: AbstractPolicy, num_episodes: int = 10, max_steps: int = 30, save_video: bool = False, @@ -53,6 +54,7 @@ def eval_policy( with torch.inference_mode(): # TODO(alexander-soare): When `break_when_any_done == False` this rolls out for max_steps even when all # envs are done the first time. But we only use the first rollout. This is a waste of compute. 
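# Illustrative sketch, not part of the patch: why the "fix seeding" commit above swaps
# the dict comprehension for a list of dicts in `create_env_kwargs`. A dict
# comprehension with a static key (the B035 the earlier noqa pointed at) keeps
# overwriting the same entry, so every sub-env would have ended up with the last seed.
seeds = range(1337, 1341)

buggy = {"seed": s for s in seeds}        # single entry: {"seed": 1340}
fixed = [{"seed": s} for s in seeds]      # one kwargs dict per sub-env

assert buggy == {"seed": 1340}
assert fixed == [{"seed": 1337}, {"seed": 1338}, {"seed": 1339}, {"seed": 1340}]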
+ policy.clear_action_queue() rollout = env.rollout( max_steps=max_steps, policy=policy, From 52e149fbfdf038406cfb1439868063c5d49b4d9c Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Wed, 20 Mar 2024 08:32:11 +0000 Subject: [PATCH 13/16] Only save video frames in first rollout --- lerobot/scripts/eval.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index 1e44c5df..7127b24d 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -83,14 +83,16 @@ def eval_policy( ) # (b, t, *) if save_video: - for stacked_frames in batch_stacked_frames: + for stacked_frames, done_index in zip( + batch_stacked_frames, done_indices.flatten().tolist(), strict=False + ): if episode_counter >= num_episodes: continue video_dir.mkdir(parents=True, exist_ok=True) video_path = video_dir / f"eval_episode_{episode_counter}.mp4" thread = threading.Thread( target=write_video, - args=(str(video_path), stacked_frames, fps), + args=(str(video_path), stacked_frames[:done_index], fps), ) thread.start() threads.append(thread) From b1ec3da0358a65efa90f6f02f88b852917ba1854 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Wed, 20 Mar 2024 09:23:23 +0000 Subject: [PATCH 14/16] remove internal rendering hooks --- lerobot/common/envs/abstract.py | 11 ----------- lerobot/common/envs/aloha/env.py | 3 --- lerobot/common/envs/pusht/env.py | 3 --- lerobot/common/envs/simxarm.py | 3 --- lerobot/scripts/eval.py | 2 -- 5 files changed, 22 deletions(-) diff --git a/lerobot/common/envs/abstract.py b/lerobot/common/envs/abstract.py index 0754fb76..8d1a09de 100644 --- a/lerobot/common/envs/abstract.py +++ b/lerobot/common/envs/abstract.py @@ -27,7 +27,6 @@ class AbstractEnv(EnvBase): self.image_size = image_size self.num_prev_obs = num_prev_obs self.num_prev_action = num_prev_action - self._rendering_hooks = [] if pixels_only: assert from_pixels @@ -45,16 +44,6 @@ class AbstractEnv(EnvBase): raise NotImplementedError() # self._prev_action_queue = deque(maxlen=self.num_prev_action) - def register_rendering_hook(self, func): - self._rendering_hooks.append(func) - - def call_rendering_hooks(self): - for func in self._rendering_hooks: - func(self) - - def reset_rendering_hooks(self): - self._rendering_hooks = [] - @abc.abstractmethod def render(self, mode="rgb_array", width=640, height=480): raise NotImplementedError() diff --git a/lerobot/common/envs/aloha/env.py b/lerobot/common/envs/aloha/env.py index 6f8fded1..e09564fb 100644 --- a/lerobot/common/envs/aloha/env.py +++ b/lerobot/common/envs/aloha/env.py @@ -164,7 +164,6 @@ class AlohaEnv(AbstractEnv): batch_size=[], ) - self.call_rendering_hooks() return td def _step(self, tensordict: TensorDict): @@ -189,8 +188,6 @@ class AlohaEnv(AbstractEnv): stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue)) obs = stacked_obs - self.call_rendering_hooks() - td = TensorDict( { "observation": TensorDict(obs, batch_size=[]), diff --git a/lerobot/common/envs/pusht/env.py b/lerobot/common/envs/pusht/env.py index aadf626c..f440d443 100644 --- a/lerobot/common/envs/pusht/env.py +++ b/lerobot/common/envs/pusht/env.py @@ -116,7 +116,6 @@ class PushtEnv(AbstractEnv): batch_size=[], ) - self.call_rendering_hooks() return td def _step(self, tensordict: TensorDict): @@ -139,8 +138,6 @@ class PushtEnv(AbstractEnv): stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue)) obs = stacked_obs - self.call_rendering_hooks() - td = TensorDict( { "observation": TensorDict(obs, batch_size=[]), diff 
--git a/lerobot/common/envs/simxarm.py b/lerobot/common/envs/simxarm.py index d0612625..eac3666d 100644 --- a/lerobot/common/envs/simxarm.py +++ b/lerobot/common/envs/simxarm.py @@ -118,7 +118,6 @@ class SimxarmEnv(AbstractEnv): else: raise NotImplementedError() - self.call_rendering_hooks() return td def _step(self, tensordict: TensorDict): @@ -152,8 +151,6 @@ class SimxarmEnv(AbstractEnv): stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue)) obs = stacked_obs - self.call_rendering_hooks() - td = TensorDict( { "observation": self._format_raw_obs(raw_obs), diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index 7127b24d..e98df19c 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -101,8 +101,6 @@ def eval_policy( if return_first_video and i == 0: first_video = batch_stacked_frames[0].transpose(0, 3, 1, 2) - env.reset_rendering_hooks() - for thread in threads: thread.join() From 5332766a8241eaf08af74940bef29a48f4600fa2 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Wed, 20 Mar 2024 09:45:45 +0000 Subject: [PATCH 15/16] revision --- lerobot/common/policies/abstract.py | 3 +- lerobot/scripts/eval.py | 7 +- poetry.lock | 255 +++------------------------- pyproject.toml | 2 - 4 files changed, 34 insertions(+), 233 deletions(-) diff --git a/lerobot/common/policies/abstract.py b/lerobot/common/policies/abstract.py index 272ffcf4..1c300dbe 100644 --- a/lerobot/common/policies/abstract.py +++ b/lerobot/common/policies/abstract.py @@ -65,6 +65,7 @@ class AbstractPolicy(nn.Module, ABC): if self.n_action_steps is None: return self.select_actions(*args, **kwargs) if len(self._action_queue) == 0: - # Each element in the queue has shape (B, *). + # `select_actions` returns a (batch_size, n_action_steps, *) tensor, but the queue effectively has shape + # (n_action_steps, batch_size, *), hence the transpose. 
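# Illustrative sketch, not part of the patch: the action-queue mechanics described in
# the comment above. `select_actions` returns (batch_size, n_action_steps, *); the
# transpose makes each queue element a (batch_size, *) action, and `popleft` serves one
# env step at a time until the queue is refilled. The numbers are invented.
from collections import deque
import torch

batch_size, n_action_steps, action_dim = 2, 3, 1
actions = torch.arange(batch_size * n_action_steps * action_dim, dtype=torch.float)
actions = actions.reshape(batch_size, n_action_steps, action_dim)

queue = deque([], maxlen=n_action_steps)
queue.extend(actions.transpose(0, 1))      # n_action_steps items, each (batch_size, action_dim)
first_step = queue.popleft()
assert first_step.shape == (batch_size, action_dim)
assert len(queue) == n_action_steps - 1

queue.clear()                              # same effect as clear_action_queue() rebuilding the deque on env reset
assert len(queue) == 0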
self._action_queue.extend(self.select_actions(*args, **kwargs).transpose(0, 1)) return self._action_queue.popleft() diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index e98df19c..41d58b91 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -3,6 +3,7 @@ import threading import time from pathlib import Path +import einops import hydra import imageio import numpy as np @@ -69,9 +70,9 @@ def eval_policy( rollout_steps = rollout["next", "done"].shape[1] done_indices = torch.argmax(rollout["next", "done"].to(int), axis=1) # (batch_size, rollout_steps) mask = (torch.arange(rollout_steps) <= done_indices).unsqueeze(-1) # (batch_size, rollout_steps, 1) - batch_sum_reward = (rollout["next", "reward"] * mask).flatten(start_dim=1).sum(dim=-1) - batch_max_reward = (rollout["next", "reward"] * mask).flatten(start_dim=1).max(dim=-1)[0] - batch_success = (rollout["next", "success"] * mask).flatten(start_dim=1).any(dim=-1) + batch_sum_reward = einops.reduce((rollout["next", "reward"] * mask), "b n 1 -> b", "sum") + batch_max_reward = einops.reduce((rollout["next", "reward"] * mask), "b n 1 -> b", "max") + batch_success = einops.reduce((rollout["next", "success"] * mask), "b n 1 -> b", "any") sum_rewards.extend(batch_sum_reward.tolist()) max_rewards.extend(batch_max_reward.tolist()) successes.extend(batch_success.tolist()) diff --git a/poetry.lock b/poetry.lock index 5e84ebef..ddb0a0e3 100644 --- a/poetry.lock +++ b/poetry.lock @@ -604,16 +604,6 @@ files = [ [package.dependencies] six = ">=1.4.0" -[[package]] -name = "egl-probe" -version = "1.0.2" -description = "" -optional = false -python-versions = "*" -files = [ - {file = "egl_probe-1.0.2.tar.gz", hash = "sha256:29bdca7b08da1e060cfb42cd46af8300a7ac4f3b1b2eeb16e545ea16d9a5ac93"}, -] - [[package]] name = "einops" version = "0.7.0" @@ -668,13 +658,13 @@ typing = ["typing-extensions (>=4.8)"] [[package]] name = "fsspec" -version = "2024.2.0" +version = "2024.3.1" description = "File-system specification" optional = false python-versions = ">=3.8" files = [ - {file = "fsspec-2024.2.0-py3-none-any.whl", hash = "sha256:817f969556fa5916bc682e02ca2045f96ff7f586d45110fcb76022063ad2c7d8"}, - {file = "fsspec-2024.2.0.tar.gz", hash = "sha256:b6ad1a679f760dda52b1168c859d01b7b80648ea6f7f7c7f5a8a91dc3f3ecb84"}, + {file = "fsspec-2024.3.1-py3-none-any.whl", hash = "sha256:918d18d41bf73f0e2b261824baeb1b124bcf771767e3a26425cd7dec3332f512"}, + {file = "fsspec-2024.3.1.tar.gz", hash = "sha256:f39780e282d7d117ffb42bb96992f8a90795e4d0fb0f661a70ca39fe9c43ded9"}, ] [package.extras] @@ -773,72 +763,6 @@ files = [ [package.extras] preview = ["glfw-preview"] -[[package]] -name = "grpcio" -version = "1.62.1" -description = "HTTP/2-based RPC framework" -optional = false -python-versions = ">=3.7" -files = [ - {file = "grpcio-1.62.1-cp310-cp310-linux_armv7l.whl", hash = "sha256:179bee6f5ed7b5f618844f760b6acf7e910988de77a4f75b95bbfaa8106f3c1e"}, - {file = "grpcio-1.62.1-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:48611e4fa010e823ba2de8fd3f77c1322dd60cb0d180dc6630a7e157b205f7ea"}, - {file = "grpcio-1.62.1-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:b2a0e71b0a2158aa4bce48be9f8f9eb45cbd17c78c7443616d00abbe2a509f6d"}, - {file = "grpcio-1.62.1-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fbe80577c7880911d3ad65e5ecc997416c98f354efeba2f8d0f9112a67ed65a5"}, - {file = "grpcio-1.62.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:58f6c693d446964e3292425e1d16e21a97a48ba9172f2d0df9d7b640acb99243"}, - {file = "grpcio-1.62.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:77c339403db5a20ef4fed02e4d1a9a3d9866bf9c0afc77a42234677313ea22f3"}, - {file = "grpcio-1.62.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:b5a4ea906db7dec694098435d84bf2854fe158eb3cd51e1107e571246d4d1d70"}, - {file = "grpcio-1.62.1-cp310-cp310-win32.whl", hash = "sha256:4187201a53f8561c015bc745b81a1b2d278967b8de35f3399b84b0695e281d5f"}, - {file = "grpcio-1.62.1-cp310-cp310-win_amd64.whl", hash = "sha256:844d1f3fb11bd1ed362d3fdc495d0770cfab75761836193af166fee113421d66"}, - {file = "grpcio-1.62.1-cp311-cp311-linux_armv7l.whl", hash = "sha256:833379943d1728a005e44103f17ecd73d058d37d95783eb8f0b28ddc1f54d7b2"}, - {file = "grpcio-1.62.1-cp311-cp311-macosx_10_10_universal2.whl", hash = "sha256:c7fcc6a32e7b7b58f5a7d27530669337a5d587d4066060bcb9dee7a8c833dfb7"}, - {file = "grpcio-1.62.1-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:fa7d28eb4d50b7cbe75bb8b45ed0da9a1dc5b219a0af59449676a29c2eed9698"}, - {file = "grpcio-1.62.1-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:48f7135c3de2f298b833be8b4ae20cafe37091634e91f61f5a7eb3d61ec6f660"}, - {file = "grpcio-1.62.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:71f11fd63365ade276c9d4a7b7df5c136f9030e3457107e1791b3737a9b9ed6a"}, - {file = "grpcio-1.62.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:4b49fd8fe9f9ac23b78437da94c54aa7e9996fbb220bac024a67469ce5d0825f"}, - {file = "grpcio-1.62.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:482ae2ae78679ba9ed5752099b32e5fe580443b4f798e1b71df412abf43375db"}, - {file = "grpcio-1.62.1-cp311-cp311-win32.whl", hash = "sha256:1faa02530b6c7426404372515fe5ddf66e199c2ee613f88f025c6f3bd816450c"}, - {file = "grpcio-1.62.1-cp311-cp311-win_amd64.whl", hash = "sha256:5bd90b8c395f39bc82a5fb32a0173e220e3f401ff697840f4003e15b96d1befc"}, - {file = "grpcio-1.62.1-cp312-cp312-linux_armv7l.whl", hash = "sha256:b134d5d71b4e0837fff574c00e49176051a1c532d26c052a1e43231f252d813b"}, - {file = "grpcio-1.62.1-cp312-cp312-macosx_10_10_universal2.whl", hash = "sha256:d1f6c96573dc09d50dbcbd91dbf71d5cf97640c9427c32584010fbbd4c0e0037"}, - {file = "grpcio-1.62.1-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:359f821d4578f80f41909b9ee9b76fb249a21035a061a327f91c953493782c31"}, - {file = "grpcio-1.62.1-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a485f0c2010c696be269184bdb5ae72781344cb4e60db976c59d84dd6354fac9"}, - {file = "grpcio-1.62.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b50b09b4dc01767163d67e1532f948264167cd27f49e9377e3556c3cba1268e1"}, - {file = "grpcio-1.62.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:3227c667dccbe38f2c4d943238b887bac588d97c104815aecc62d2fd976e014b"}, - {file = "grpcio-1.62.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:3952b581eb121324853ce2b191dae08badb75cd493cb4e0243368aa9e61cfd41"}, - {file = "grpcio-1.62.1-cp312-cp312-win32.whl", hash = "sha256:83a17b303425104d6329c10eb34bba186ffa67161e63fa6cdae7776ff76df73f"}, - {file = "grpcio-1.62.1-cp312-cp312-win_amd64.whl", hash = "sha256:6696ffe440333a19d8d128e88d440f91fb92c75a80ce4b44d55800e656a3ef1d"}, - {file = "grpcio-1.62.1-cp37-cp37m-linux_armv7l.whl", hash = "sha256:e3393b0823f938253370ebef033c9fd23d27f3eae8eb9a8f6264900c7ea3fb5a"}, - {file = "grpcio-1.62.1-cp37-cp37m-macosx_10_10_universal2.whl", hash = 
"sha256:83e7ccb85a74beaeae2634f10eb858a0ed1a63081172649ff4261f929bacfd22"}, - {file = "grpcio-1.62.1-cp37-cp37m-manylinux_2_17_aarch64.whl", hash = "sha256:882020c87999d54667a284c7ddf065b359bd00251fcd70279ac486776dbf84ec"}, - {file = "grpcio-1.62.1-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a10383035e864f386fe096fed5c47d27a2bf7173c56a6e26cffaaa5a361addb1"}, - {file = "grpcio-1.62.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:960edebedc6b9ada1ef58e1c71156f28689978188cd8cff3b646b57288a927d9"}, - {file = "grpcio-1.62.1-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:23e2e04b83f347d0aadde0c9b616f4726c3d76db04b438fd3904b289a725267f"}, - {file = "grpcio-1.62.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:978121758711916d34fe57c1f75b79cdfc73952f1481bb9583399331682d36f7"}, - {file = "grpcio-1.62.1-cp37-cp37m-win_amd64.whl", hash = "sha256:9084086190cc6d628f282e5615f987288b95457292e969b9205e45b442276407"}, - {file = "grpcio-1.62.1-cp38-cp38-linux_armv7l.whl", hash = "sha256:22bccdd7b23c420a27fd28540fb5dcbc97dc6be105f7698cb0e7d7a420d0e362"}, - {file = "grpcio-1.62.1-cp38-cp38-macosx_10_10_universal2.whl", hash = "sha256:8999bf1b57172dbc7c3e4bb3c732658e918f5c333b2942243f10d0d653953ba9"}, - {file = "grpcio-1.62.1-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:d9e52558b8b8c2f4ac05ac86344a7417ccdd2b460a59616de49eb6933b07a0bd"}, - {file = "grpcio-1.62.1-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1714e7bc935780bc3de1b3fcbc7674209adf5208ff825799d579ffd6cd0bd505"}, - {file = "grpcio-1.62.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c8842ccbd8c0e253c1f189088228f9b433f7a93b7196b9e5b6f87dba393f5d5d"}, - {file = "grpcio-1.62.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:1f1e7b36bdff50103af95a80923bf1853f6823dd62f2d2a2524b66ed74103e49"}, - {file = "grpcio-1.62.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:bba97b8e8883a8038606480d6b6772289f4c907f6ba780fa1f7b7da7dfd76f06"}, - {file = "grpcio-1.62.1-cp38-cp38-win32.whl", hash = "sha256:a7f615270fe534548112a74e790cd9d4f5509d744dd718cd442bf016626c22e4"}, - {file = "grpcio-1.62.1-cp38-cp38-win_amd64.whl", hash = "sha256:e6c8c8693df718c5ecbc7babb12c69a4e3677fd11de8886f05ab22d4e6b1c43b"}, - {file = "grpcio-1.62.1-cp39-cp39-linux_armv7l.whl", hash = "sha256:73db2dc1b201d20ab7083e7041946910bb991e7e9761a0394bbc3c2632326483"}, - {file = "grpcio-1.62.1-cp39-cp39-macosx_10_10_universal2.whl", hash = "sha256:407b26b7f7bbd4f4751dbc9767a1f0716f9fe72d3d7e96bb3ccfc4aace07c8de"}, - {file = "grpcio-1.62.1-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:f8de7c8cef9261a2d0a62edf2ccea3d741a523c6b8a6477a340a1f2e417658de"}, - {file = "grpcio-1.62.1-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9bd5c8a1af40ec305d001c60236308a67e25419003e9bb3ebfab5695a8d0b369"}, - {file = "grpcio-1.62.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:be0477cb31da67846a33b1a75c611f88bfbcd427fe17701b6317aefceee1b96f"}, - {file = "grpcio-1.62.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:60dcd824df166ba266ee0cfaf35a31406cd16ef602b49f5d4dfb21f014b0dedd"}, - {file = "grpcio-1.62.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:973c49086cabab773525f6077f95e5a993bfc03ba8fc32e32f2c279497780585"}, - {file = "grpcio-1.62.1-cp39-cp39-win32.whl", hash = "sha256:12859468e8918d3bd243d213cd6fd6ab07208195dc140763c00dfe901ce1e1b4"}, - {file = "grpcio-1.62.1-cp39-cp39-win_amd64.whl", hash = 
"sha256:b7209117bbeebdfa5d898205cc55153a51285757902dd73c47de498ad4d11332"}, - {file = "grpcio-1.62.1.tar.gz", hash = "sha256:6c455e008fa86d9e9a9d85bb76da4277c0d7d9668a3bfa70dbe86e9f3c759947"}, -] - -[package.extras] -protobuf = ["grpcio-tools (>=1.62.1)"] - [[package]] name = "gym" version = "0.26.2" @@ -1341,21 +1265,6 @@ html5 = ["html5lib"] htmlsoup = ["BeautifulSoup4"] source = ["Cython (>=3.0.7)"] -[[package]] -name = "markdown" -version = "3.6" -description = "Python implementation of John Gruber's Markdown." -optional = false -python-versions = ">=3.8" -files = [ - {file = "Markdown-3.6-py3-none-any.whl", hash = "sha256:48f276f4d8cfb8ce6527c8f79e2ee29708508bf4d40aa410fbc3b4ee832c850f"}, - {file = "Markdown-3.6.tar.gz", hash = "sha256:ed4f41f6daecbeeb96e576ce414c41d2d876daa9a16cb35fa8ed8c2ddfad0224"}, -] - -[package.extras] -docs = ["mdx-gh-links (>=0.2)", "mkdocs (>=1.5)", "mkdocs-gen-files", "mkdocs-literate-nav", "mkdocs-nature (>=0.6)", "mkdocs-section-index", "mkdocstrings[python]"] -testing = ["coverage", "pyyaml"] - [[package]] name = "markupsafe" version = "2.1.5" @@ -1559,32 +1468,32 @@ setuptools = "*" [[package]] name = "numba" -version = "0.59.0" +version = "0.59.1" description = "compiling Python code using LLVM" optional = false python-versions = ">=3.9" files = [ - {file = "numba-0.59.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8d061d800473fb8fef76a455221f4ad649a53f5e0f96e3f6c8b8553ee6fa98fa"}, - {file = "numba-0.59.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c086a434e7d3891ce5dfd3d1e7ee8102ac1e733962098578b507864120559ceb"}, - {file = "numba-0.59.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:9e20736bf62e61f8353fb71b0d3a1efba636c7a303d511600fc57648b55823ed"}, - {file = "numba-0.59.0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e86e6786aec31d2002122199486e10bbc0dc40f78d76364cded375912b13614c"}, - {file = "numba-0.59.0-cp310-cp310-win_amd64.whl", hash = "sha256:0307ee91b24500bb7e64d8a109848baf3a3905df48ce142b8ac60aaa406a0400"}, - {file = "numba-0.59.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:d540f69a8245fb714419c2209e9af6104e568eb97623adc8943642e61f5d6d8e"}, - {file = "numba-0.59.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1192d6b2906bf3ff72b1d97458724d98860ab86a91abdd4cfd9328432b661e31"}, - {file = "numba-0.59.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:90efb436d3413809fcd15298c6d395cb7d98184350472588356ccf19db9e37c8"}, - {file = "numba-0.59.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:cd3dac45e25d927dcb65d44fb3a973994f5add2b15add13337844afe669dd1ba"}, - {file = "numba-0.59.0-cp311-cp311-win_amd64.whl", hash = "sha256:753dc601a159861808cc3207bad5c17724d3b69552fd22768fddbf302a817a4c"}, - {file = "numba-0.59.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:ce62bc0e6dd5264e7ff7f34f41786889fa81a6b860662f824aa7532537a7bee0"}, - {file = "numba-0.59.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8cbef55b73741b5eea2dbaf1b0590b14977ca95a13a07d200b794f8f6833a01c"}, - {file = "numba-0.59.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:70d26ba589f764be45ea8c272caa467dbe882b9676f6749fe6f42678091f5f21"}, - {file = "numba-0.59.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e125f7d69968118c28ec0eed9fbedd75440e64214b8d2eac033c22c04db48492"}, - {file = "numba-0.59.0-cp312-cp312-win_amd64.whl", hash = 
"sha256:4981659220b61a03c1e557654027d271f56f3087448967a55c79a0e5f926de62"}, - {file = "numba-0.59.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:fe4d7562d1eed754a7511ed7ba962067f198f86909741c5c6e18c4f1819b1f47"}, - {file = "numba-0.59.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:6feb1504bb432280f900deaf4b1dadcee68812209500ed3f81c375cbceab24dc"}, - {file = "numba-0.59.0-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:944faad25ee23ea9dda582bfb0189fb9f4fc232359a80ab2a028b94c14ce2b1d"}, - {file = "numba-0.59.0-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5516a469514bfae52a9d7989db4940653a5cbfac106f44cb9c50133b7ad6224b"}, - {file = "numba-0.59.0-cp39-cp39-win_amd64.whl", hash = "sha256:32bd0a41525ec0b1b853da244808f4e5333867df3c43c30c33f89cf20b9c2b63"}, - {file = "numba-0.59.0.tar.gz", hash = "sha256:12b9b064a3e4ad00e2371fc5212ef0396c80f41caec9b5ec391c8b04b6eaf2a8"}, + {file = "numba-0.59.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:97385a7f12212c4f4bc28f648720a92514bee79d7063e40ef66c2d30600fd18e"}, + {file = "numba-0.59.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:0b77aecf52040de2a1eb1d7e314497b9e56fba17466c80b457b971a25bb1576d"}, + {file = "numba-0.59.1-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:3476a4f641bfd58f35ead42f4dcaf5f132569c4647c6f1360ccf18ee4cda3990"}, + {file = "numba-0.59.1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:525ef3f820931bdae95ee5379c670d5c97289c6520726bc6937a4a7d4230ba24"}, + {file = "numba-0.59.1-cp310-cp310-win_amd64.whl", hash = "sha256:990e395e44d192a12105eca3083b61307db7da10e093972ca285c85bef0963d6"}, + {file = "numba-0.59.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:43727e7ad20b3ec23ee4fc642f5b61845c71f75dd2825b3c234390c6d8d64051"}, + {file = "numba-0.59.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:411df625372c77959570050e861981e9d196cc1da9aa62c3d6a836b5cc338966"}, + {file = "numba-0.59.1-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:2801003caa263d1e8497fb84829a7ecfb61738a95f62bc05693fcf1733e978e4"}, + {file = "numba-0.59.1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:dd2842fac03be4e5324ebbbd4d2d0c8c0fc6e0df75c09477dd45b288a0777389"}, + {file = "numba-0.59.1-cp311-cp311-win_amd64.whl", hash = "sha256:0594b3dfb369fada1f8bb2e3045cd6c61a564c62e50cf1f86b4666bc721b3450"}, + {file = "numba-0.59.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:1cce206a3b92836cdf26ef39d3a3242fec25e07f020cc4feec4c4a865e340569"}, + {file = "numba-0.59.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8c8b4477763cb1fbd86a3be7050500229417bf60867c93e131fd2626edb02238"}, + {file = "numba-0.59.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:7d80bce4ef7e65bf895c29e3889ca75a29ee01da80266a01d34815918e365835"}, + {file = "numba-0.59.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:f7ad1d217773e89a9845886401eaaab0a156a90aa2f179fdc125261fd1105096"}, + {file = "numba-0.59.1-cp312-cp312-win_amd64.whl", hash = "sha256:5bf68f4d69dd3a9f26a9b23548fa23e3bcb9042e2935257b471d2a8d3c424b7f"}, + {file = "numba-0.59.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:4e0318ae729de6e5dbe64c75ead1a95eb01fabfe0e2ebed81ebf0344d32db0ae"}, + {file = "numba-0.59.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:0f68589740a8c38bb7dc1b938b55d1145244c8353078eea23895d4f82c8b9ec1"}, + {file = 
"numba-0.59.1-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:649913a3758891c77c32e2d2a3bcbedf4a69f5fea276d11f9119677c45a422e8"}, + {file = "numba-0.59.1-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:9712808e4545270291d76b9a264839ac878c5eb7d8b6e02c970dc0ac29bc8187"}, + {file = "numba-0.59.1-cp39-cp39-win_amd64.whl", hash = "sha256:8d51ccd7008a83105ad6a0082b6a2b70f1142dc7cfd76deb8c5a862367eb8c86"}, + {file = "numba-0.59.1.tar.gz", hash = "sha256:76f69132b96028d2774ed20415e8c528a34e3299a40581bae178f0994a2f370b"}, ] [package.dependencies] @@ -2551,30 +2460,6 @@ urllib3 = ">=1.21.1,<3" socks = ["PySocks (>=1.5.6,!=1.5.7)"] use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] -[[package]] -name = "robomimic" -version = "0.2.0" -description = "robomimic: A Modular Framework for Robot Learning from Demonstration" -optional = false -python-versions = ">=3" -files = [ - {file = "robomimic-0.2.0.tar.gz", hash = "sha256:ee3bb5cf9c3e1feead6b57b43c5db738fd0a8e0c015fdf6419808af8fffdc463"}, -] - -[package.dependencies] -egl_probe = ">=1.0.1" -h5py = "*" -imageio = "*" -imageio-ffmpeg = "*" -numpy = ">=1.13.3" -psutil = "*" -tensorboard = "*" -tensorboardX = "*" -termcolor = "*" -torch = "*" -torchvision = "*" -tqdm = "*" - [[package]] name = "safetensors" version = "0.4.2" @@ -3063,58 +2948,9 @@ files = [ [package.dependencies] mpmath = ">=0.19" -[[package]] -name = "tensorboard" -version = "2.16.2" -description = "TensorBoard lets you watch Tensors Flow" -optional = false -python-versions = ">=3.9" -files = [ - {file = "tensorboard-2.16.2-py3-none-any.whl", hash = "sha256:9f2b4e7dad86667615c0e5cd072f1ea8403fc032a299f0072d6f74855775cc45"}, -] - -[package.dependencies] -absl-py = ">=0.4" -grpcio = ">=1.48.2" -markdown = ">=2.6.8" -numpy = ">=1.12.0" -protobuf = ">=3.19.6,<4.24.0 || >4.24.0" -setuptools = ">=41.0.0" -six = ">1.9" -tensorboard-data-server = ">=0.7.0,<0.8.0" -werkzeug = ">=1.0.1" - -[[package]] -name = "tensorboard-data-server" -version = "0.7.2" -description = "Fast data loading for TensorBoard" -optional = false -python-versions = ">=3.7" -files = [ - {file = "tensorboard_data_server-0.7.2-py3-none-any.whl", hash = "sha256:7e0610d205889588983836ec05dc098e80f97b7e7bbff7e994ebb78f578d0ddb"}, - {file = "tensorboard_data_server-0.7.2-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:9fe5d24221b29625dbc7328b0436ca7fc1c23de4acf4d272f1180856e32f9f60"}, - {file = "tensorboard_data_server-0.7.2-py3-none-manylinux_2_31_x86_64.whl", hash = "sha256:ef687163c24185ae9754ed5650eb5bc4d84ff257aabdc33f0cc6f74d8ba54530"}, -] - -[[package]] -name = "tensorboardx" -version = "2.6.2.2" -description = "TensorBoardX lets you watch Tensors Flow without Tensorflow" -optional = false -python-versions = "*" -files = [ - {file = "tensorboardX-2.6.2.2-py2.py3-none-any.whl", hash = "sha256:160025acbf759ede23fd3526ae9d9bfbfd8b68eb16c38a010ebe326dc6395db8"}, - {file = "tensorboardX-2.6.2.2.tar.gz", hash = "sha256:c6476d7cd0d529b0b72f4acadb1269f9ed8b22f441e87a84f2a3b940bb87b666"}, -] - -[package.dependencies] -numpy = "*" -packaging = "*" -protobuf = ">=3.20" - [[package]] name = "tensordict" -version = "0.4.0+f1c833e" +version = "0.4.0+ca4256e" description = "" optional = false python-versions = "*" @@ -3135,7 +2971,7 @@ tests = ["pytest", "pytest-benchmark", "pytest-instafail", "pytest-rerunfailures type = "git" url = "https://github.com/pytorch/tensordict" reference = "HEAD" -resolved_reference = "f1c833ecf495aa61f3f76bf09f94dd708db496ec" +resolved_reference = 
"b4c91e8828c538ca0a50d8383fd99311a9afb078" [[package]] name = "termcolor" @@ -3168,24 +3004,6 @@ numpy = "*" [package.extras] all = ["defusedxml", "fsspec", "imagecodecs (>=2023.8.12)", "lxml", "matplotlib", "zarr"] -[[package]] -name = "timm" -version = "0.9.16" -description = "PyTorch Image Models" -optional = false -python-versions = ">=3.8" -files = [ - {file = "timm-0.9.16-py3-none-any.whl", hash = "sha256:bf5704014476ab011589d3c14172ee4c901fd18f9110a928019cac5be2945914"}, - {file = "timm-0.9.16.tar.gz", hash = "sha256:891e54f375d55adf31a71ab0c117761f0e472f9f3971858ecdd1e7376b7071e6"}, -] - -[package.dependencies] -huggingface_hub = "*" -pyyaml = "*" -safetensors = "*" -torch = "*" -torchvision = "*" - [[package]] name = "tomli" version = "2.0.1" @@ -3471,23 +3289,6 @@ perf = ["orjson"] reports = ["pydantic (>=2.0.0)"] sweeps = ["sweeps (>=0.2.0)"] -[[package]] -name = "werkzeug" -version = "3.0.1" -description = "The comprehensive WSGI web application library." -optional = false -python-versions = ">=3.8" -files = [ - {file = "werkzeug-3.0.1-py3-none-any.whl", hash = "sha256:90a285dc0e42ad56b34e696398b8122ee4c681833fb35b8334a095d82c56da10"}, - {file = "werkzeug-3.0.1.tar.gz", hash = "sha256:507e811ecea72b18a404947aded4b3390e1db8f826b494d76550ef45bb3b1dcc"}, -] - -[package.dependencies] -MarkupSafe = ">=2.1.1" - -[package.extras] -watchdog = ["watchdog (>=2.3)"] - [[package]] name = "zarr" version = "2.17.1" @@ -3527,4 +3328,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "3bf6532037cfea563819989806d5cd171e33ceb077b0d6afddf54710cbbb3c74" +content-hash = "ee86b84a795e6a3e9c2d79f244a87b55589adbe46d549ac38adf48be27c04cf9" diff --git a/pyproject.toml b/pyproject.toml index 5cef3ef4..2e818a44 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,8 +49,6 @@ opencv-python = "^4.9.0.80" diffusers = "^0.26.3" torchvision = "^0.17.1" h5py = "^3.10.0" -robomimic = "0.2.0" -timm = "^0.9.16" dm-control = "1.0.14" huggingface-hub = {extras = ["hf-transfer"], version = "^0.21.4"} From 4b7ec81dde7c4c567bae2b0e70d7d1508f753863 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Wed, 20 Mar 2024 14:49:41 +0000 Subject: [PATCH 16/16] remove abstracmethods, fix online training --- lerobot/common/envs/abstract.py | 19 ++++++------------- lerobot/common/policies/abstract.py | 7 +++---- lerobot/scripts/train.py | 6 ++++-- 3 files changed, 13 insertions(+), 19 deletions(-) diff --git a/lerobot/common/envs/abstract.py b/lerobot/common/envs/abstract.py index 8d1a09de..a449e23f 100644 --- a/lerobot/common/envs/abstract.py +++ b/lerobot/common/envs/abstract.py @@ -1,4 +1,3 @@ -import abc from collections import deque from typing import Optional @@ -44,26 +43,20 @@ class AbstractEnv(EnvBase): raise NotImplementedError() # self._prev_action_queue = deque(maxlen=self.num_prev_action) - @abc.abstractmethod def render(self, mode="rgb_array", width=640, height=480): - raise NotImplementedError() + raise NotImplementedError("Abstract method") - @abc.abstractmethod def _reset(self, tensordict: Optional[TensorDict] = None): - raise NotImplementedError() + raise NotImplementedError("Abstract method") - @abc.abstractmethod def _step(self, tensordict: TensorDict): - raise NotImplementedError() + raise NotImplementedError("Abstract method") - @abc.abstractmethod def _make_env(self): - raise NotImplementedError() + raise NotImplementedError("Abstract method") - @abc.abstractmethod def _make_spec(self): - raise 
NotImplementedError()
+        raise NotImplementedError("Abstract method")
 
-    @abc.abstractmethod
     def _set_seed(self, seed: Optional[int]):
-        raise NotImplementedError()
+        raise NotImplementedError("Abstract method")
diff --git a/lerobot/common/policies/abstract.py b/lerobot/common/policies/abstract.py
index 1c300dbe..e9c331a0 100644
--- a/lerobot/common/policies/abstract.py
+++ b/lerobot/common/policies/abstract.py
@@ -1,11 +1,10 @@
-from abc import ABC, abstractmethod
 from collections import deque
 
 import torch
 from torch import Tensor, nn
 
 
-class AbstractPolicy(nn.Module, ABC):
+class AbstractPolicy(nn.Module):
     """Base policy which all policies should be derived from.
 
     The forward method should generally not be overriden as it plays the role of handling multi-step policies. See its
@@ -22,9 +21,9 @@ class AbstractPolicy(nn.Module, ABC):
         self.n_action_steps = n_action_steps
         self.clear_action_queue()
 
-    @abstractmethod
     def update(self, replay_buffer, step):
         """One step of the policy's learning algorithm."""
+        raise NotImplementedError("Abstract method")
 
     def save(self, fp):
         torch.save(self.state_dict(), fp)
@@ -33,13 +32,13 @@
         d = torch.load(fp)
         self.load_state_dict(d)
 
-    @abstractmethod
     def select_actions(self, observation) -> Tensor:
         """Select an action (or trajectory of actions) based on an observation during rollout.
 
         If n_action_steps was provided at initialization, this should return a (batch_size, n_action_steps, *)
         tensor of actions. Otherwise if n_actions_steps is None, this should return a (batch_size, *) tensor of
         actions.
         """
+        raise NotImplementedError("Abstract method")
 
     def clear_action_queue(self):
         """This should be called whenever the environment is reset."""
diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index 5ecd616d..242c77bc 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -112,6 +112,8 @@ def train(cfg: dict, out_dir=None, job_name=None):
         raise NotImplementedError()
     if job_name is None:
         raise NotImplementedError()
+    if cfg.online_steps > 0:
+        assert cfg.rollout_batch_size == 1, "rollout_batch_size > 1 not supported for online training steps"
 
     init_logging()
 
@@ -218,11 +220,11 @@ def train(cfg: dict, out_dir=None, job_name=None):
             # TODO: add configurable number of rollout? (default=1)
             with torch.no_grad():
                 rollout = env.rollout(
-                    max_steps=cfg.env.episode_length // cfg.n_action_steps,
+                    max_steps=cfg.env.episode_length,
                     policy=td_policy,
                     auto_cast_to_device=True,
                 )
-            assert len(rollout) <= cfg.env.episode_length // cfg.n_action_steps
+            assert len(rollout) <= cfg.env.episode_length
             # set same episode index for all time steps contained in this rollout
             rollout["episode"] = torch.tensor([env_step] * len(rollout), dtype=torch.int)
             online_buffer.extend(rollout)
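
One practical effect of dropping the ABC machinery in lerobot/common/policies/abstract.py: a subclass that forgets to override update or select_actions can now be instantiated and only fails at call time with NotImplementedError("Abstract method"), rather than at construction time as it did with @abstractmethod. As an illustration only (not code from the patch: RandomPolicy and action_dim are hypothetical, observation is assumed to be a single batched tensor, and the base __init__ is assumed to take just n_action_steps, as the hunk context suggests), a concrete subclass under the new convention could look like this:

import torch
from torch import Tensor

from lerobot.common.policies.abstract import AbstractPolicy


class RandomPolicy(AbstractPolicy):
    """Toy policy that emits random action trajectories."""

    def __init__(self, action_dim: int, n_action_steps: int = 2):
        super().__init__(n_action_steps)
        self.action_dim = action_dim

    def update(self, replay_buffer, step):
        # A random policy has nothing to learn; the override replaces the
        # NotImplementedError stub in the base class.
        pass

    def select_actions(self, observation: Tensor) -> Tensor:
        # Per the base class docstring, return a (batch_size, n_action_steps, *) trajectory
        # when n_action_steps is set; here * is a flat action_dim.
        batch_size = observation.shape[0]
        return torch.rand(batch_size, self.n_action_steps, self.action_dim)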
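
The train.py hunk also changes the online rollout horizon from cfg.env.episode_length // cfg.n_action_steps to the full cfg.env.episode_length. That is consistent with the queued multi-step handling the AbstractPolicy docstring describes (forward, not select_actions, is what gets called each step), under which every environment step consumes exactly one queued action, so the horizon is naturally counted in environment steps. The sketch below illustrates that plan-then-drain queue pattern; it is a standalone toy (the class name, the zero-filled trajectory, and the fixed action dimension are invented for the example) and not the repository's actual forward implementation:

from collections import deque

import torch
from torch import Tensor


class QueuedPolicySketch:
    """Minimal illustration of the plan-then-drain action queue pattern."""

    def __init__(self, n_action_steps: int):
        self.n_action_steps = n_action_steps
        self.clear_action_queue()

    def clear_action_queue(self):
        # Called whenever the environment is reset.
        self._action_queue = deque(maxlen=self.n_action_steps)

    def select_actions(self, observation: Tensor) -> Tensor:
        # Stand-in planner: a (batch_size, n_action_steps, action_dim) trajectory of zeros.
        return torch.zeros(observation.shape[0], self.n_action_steps, 2)

    def forward(self, observation: Tensor) -> Tensor:
        # Plan a new trajectory only when the queue is empty, then emit one action per call,
        # so each call corresponds to one environment step and rollout length is measured
        # in environment steps (hence max_steps=cfg.env.episode_length).
        if len(self._action_queue) == 0:
            # (batch, n_action_steps, action_dim) -> n_action_steps items of (batch, action_dim).
            self._action_queue.extend(self.select_actions(observation).transpose(0, 1))
        return self._action_queue.popleft()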