From f52f4f2cd2975686f8f8037d8396544712231475 Mon Sep 17 00:00:00 2001 From: Simon Alibert <75076266+aliberts@users.noreply.github.com> Date: Wed, 15 May 2024 12:13:09 +0200 Subject: [PATCH 1/4] Add copyrights (#157) --- lerobot/__init__.py | 15 +++++++++++++++ lerobot/__version__.py | 15 +++++++++++++++ .../_video_benchmark/run_video_benchmark.py | 15 +++++++++++++++ lerobot/common/datasets/factory.py | 15 +++++++++++++++ lerobot/common/datasets/lerobot_dataset.py | 15 +++++++++++++++ .../_diffusion_policy_replay_buffer.py | 15 +++++++++++++++ .../push_dataset_to_hub/_download_raw.py | 15 +++++++++++++++ .../_umi_imagecodecs_numcodecs.py | 15 +++++++++++++++ .../push_dataset_to_hub/aloha_hdf5_format.py | 15 +++++++++++++++ .../push_dataset_to_hub/compute_stats.py | 15 +++++++++++++++ .../push_dataset_to_hub/pusht_zarr_format.py | 15 +++++++++++++++ .../push_dataset_to_hub/umi_zarr_format.py | 15 +++++++++++++++ .../common/datasets/push_dataset_to_hub/utils.py | 15 +++++++++++++++ .../push_dataset_to_hub/xarm_pkl_format.py | 15 +++++++++++++++ lerobot/common/datasets/utils.py | 15 +++++++++++++++ lerobot/common/datasets/video_utils.py | 15 +++++++++++++++ lerobot/common/envs/factory.py | 15 +++++++++++++++ lerobot/common/envs/utils.py | 15 +++++++++++++++ lerobot/common/logger.py | 15 +++++++++++++++ lerobot/common/policies/act/configuration_act.py | 15 +++++++++++++++ lerobot/common/policies/act/modeling_act.py | 15 +++++++++++++++ .../diffusion/configuration_diffusion.py | 16 ++++++++++++++++ .../policies/diffusion/modeling_diffusion.py | 16 ++++++++++++++++ lerobot/common/policies/factory.py | 15 +++++++++++++++ lerobot/common/policies/normalize.py | 15 +++++++++++++++ lerobot/common/policies/policy_protocol.py | 15 +++++++++++++++ .../common/policies/tdmpc/configuration_tdmpc.py | 16 ++++++++++++++++ lerobot/common/policies/tdmpc/modeling_tdmpc.py | 16 ++++++++++++++++ lerobot/common/policies/utils.py | 15 +++++++++++++++ lerobot/common/utils/import_utils.py | 15 +++++++++++++++ lerobot/common/utils/io_utils.py | 15 +++++++++++++++ lerobot/common/utils/utils.py | 15 +++++++++++++++ lerobot/scripts/display_sys_info.py | 15 +++++++++++++++ lerobot/scripts/eval.py | 15 +++++++++++++++ lerobot/scripts/push_dataset_to_hub.py | 15 +++++++++++++++ lerobot/scripts/train.py | 15 +++++++++++++++ lerobot/scripts/visualize_dataset.py | 15 +++++++++++++++ tests/conftest.py | 15 +++++++++++++++ tests/scripts/save_dataset_to_safetensors.py | 15 +++++++++++++++ tests/scripts/save_policy_to_safetensor.py | 15 +++++++++++++++ tests/test_available.py | 15 +++++++++++++++ tests/test_datasets.py | 15 +++++++++++++++ tests/test_envs.py | 15 +++++++++++++++ tests/test_examples.py | 15 +++++++++++++++ tests/test_policies.py | 15 +++++++++++++++ tests/test_visualize_dataset.py | 15 +++++++++++++++ tests/utils.py | 15 +++++++++++++++ 47 files changed, 709 insertions(+) diff --git a/lerobot/__init__.py b/lerobot/__init__.py index 072f4bc7..e188bc52 100644 --- a/lerobot/__init__.py +++ b/lerobot/__init__.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """ This file contains lists of available environments, dataset and policies to reflect the current state of LeRobot library. We do not want to import all the dependencies, but instead we keep it lightweight to ensure fast access to these variables. diff --git a/lerobot/__version__.py b/lerobot/__version__.py index 6232b699..d12aafaa 100644 --- a/lerobot/__version__.py +++ b/lerobot/__version__.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """To enable `lerobot.__version__`""" from importlib.metadata import PackageNotFoundError, version diff --git a/lerobot/common/datasets/_video_benchmark/run_video_benchmark.py b/lerobot/common/datasets/_video_benchmark/run_video_benchmark.py index 85d48fcf..8be251dc 100644 --- a/lerobot/common/datasets/_video_benchmark/run_video_benchmark.py +++ b/lerobot/common/datasets/_video_benchmark/run_video_benchmark.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import json import random import shutil diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py index 22dd1789..78967db6 100644 --- a/lerobot/common/datasets/factory.py +++ b/lerobot/common/datasets/factory.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import logging import torch diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index f7bc5bd2..21d09879 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import os from pathlib import Path diff --git a/lerobot/common/datasets/push_dataset_to_hub/_diffusion_policy_replay_buffer.py b/lerobot/common/datasets/push_dataset_to_hub/_diffusion_policy_replay_buffer.py index 2f532650..33b4c974 100644 --- a/lerobot/common/datasets/push_dataset_to_hub/_diffusion_policy_replay_buffer.py +++ b/lerobot/common/datasets/push_dataset_to_hub/_diffusion_policy_replay_buffer.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """Helper code for loading PushT dataset from Diffusion Policy (https://diffusion-policy.cs.columbia.edu/) Copied from the original Diffusion Policy repository and used in our `download_and_upload_dataset.py` script. diff --git a/lerobot/common/datasets/push_dataset_to_hub/_download_raw.py b/lerobot/common/datasets/push_dataset_to_hub/_download_raw.py index d26f3d23..232fd055 100644 --- a/lerobot/common/datasets/push_dataset_to_hub/_download_raw.py +++ b/lerobot/common/datasets/push_dataset_to_hub/_download_raw.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """ This file contains all obsolete download scripts. They are centralized here to not have to load useless dependencies when using datasets. diff --git a/lerobot/common/datasets/push_dataset_to_hub/_umi_imagecodecs_numcodecs.py b/lerobot/common/datasets/push_dataset_to_hub/_umi_imagecodecs_numcodecs.py index 1561fb88..a118b7e7 100644 --- a/lerobot/common/datasets/push_dataset_to_hub/_umi_imagecodecs_numcodecs.py +++ b/lerobot/common/datasets/push_dataset_to_hub/_umi_imagecodecs_numcodecs.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. # imagecodecs/numcodecs.py # Copyright (c) 2021-2022, Christoph Gohlke diff --git a/lerobot/common/datasets/push_dataset_to_hub/aloha_hdf5_format.py b/lerobot/common/datasets/push_dataset_to_hub/aloha_hdf5_format.py index f51a59cd..4efadc9e 100644 --- a/lerobot/common/datasets/push_dataset_to_hub/aloha_hdf5_format.py +++ b/lerobot/common/datasets/push_dataset_to_hub/aloha_hdf5_format.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """ Contains utilities to process raw data format of HDF5 files like in: https://github.com/tonyzhaozh/act """ diff --git a/lerobot/common/datasets/push_dataset_to_hub/compute_stats.py b/lerobot/common/datasets/push_dataset_to_hub/compute_stats.py index a7a952fb..ec296658 100644 --- a/lerobot/common/datasets/push_dataset_to_hub/compute_stats.py +++ b/lerobot/common/datasets/push_dataset_to_hub/compute_stats.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from copy import deepcopy from math import ceil diff --git a/lerobot/common/datasets/push_dataset_to_hub/pusht_zarr_format.py b/lerobot/common/datasets/push_dataset_to_hub/pusht_zarr_format.py index 0c3a8d19..8133a36a 100644 --- a/lerobot/common/datasets/push_dataset_to_hub/pusht_zarr_format.py +++ b/lerobot/common/datasets/push_dataset_to_hub/pusht_zarr_format.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """Process zarr files formatted like in: https://github.com/real-stanford/diffusion_policy""" import shutil diff --git a/lerobot/common/datasets/push_dataset_to_hub/umi_zarr_format.py b/lerobot/common/datasets/push_dataset_to_hub/umi_zarr_format.py index 00828750..cab2bdc5 100644 --- a/lerobot/common/datasets/push_dataset_to_hub/umi_zarr_format.py +++ b/lerobot/common/datasets/push_dataset_to_hub/umi_zarr_format.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """Process UMI (Universal Manipulation Interface) data stored in Zarr format like in: https://github.com/real-stanford/universal_manipulation_interface""" import logging diff --git a/lerobot/common/datasets/push_dataset_to_hub/utils.py b/lerobot/common/datasets/push_dataset_to_hub/utils.py index 1b12c0b7..4feb1dcf 100644 --- a/lerobot/common/datasets/push_dataset_to_hub/utils.py +++ b/lerobot/common/datasets/push_dataset_to_hub/utils.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from concurrent.futures import ThreadPoolExecutor from pathlib import Path diff --git a/lerobot/common/datasets/push_dataset_to_hub/xarm_pkl_format.py b/lerobot/common/datasets/push_dataset_to_hub/xarm_pkl_format.py index 686edf4c..899ebdde 100644 --- a/lerobot/common/datasets/push_dataset_to_hub/xarm_pkl_format.py +++ b/lerobot/common/datasets/push_dataset_to_hub/xarm_pkl_format.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """Process pickle files formatted like in: https://github.com/fyhMer/fowm""" import pickle diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py index 96b8fbbc..5cdd5f7c 100644 --- a/lerobot/common/datasets/utils.py +++ b/lerobot/common/datasets/utils.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import json from pathlib import Path diff --git a/lerobot/common/datasets/video_utils.py b/lerobot/common/datasets/video_utils.py index 0252be2e..edfca918 100644 --- a/lerobot/common/datasets/video_utils.py +++ b/lerobot/common/datasets/video_utils.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import logging import subprocess import warnings diff --git a/lerobot/common/envs/factory.py b/lerobot/common/envs/factory.py index c5fd4671..83f94cfe 100644 --- a/lerobot/common/envs/factory.py +++ b/lerobot/common/envs/factory.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import importlib import gymnasium as gym diff --git a/lerobot/common/envs/utils.py b/lerobot/common/envs/utils.py index 5370d385..8fce0369 100644 --- a/lerobot/common/envs/utils.py +++ b/lerobot/common/envs/utils.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import einops import numpy as np import torch diff --git a/lerobot/common/logger.py b/lerobot/common/logger.py index ea8db050..109f6951 100644 --- a/lerobot/common/logger.py +++ b/lerobot/common/logger.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. # TODO(rcadene, alexander-soare): clean this file """Borrowed from https://github.com/fyhMer/fowm/blob/main/src/logger.py""" diff --git a/lerobot/common/policies/act/configuration_act.py b/lerobot/common/policies/act/configuration_act.py index a3980b14..95f443da 100644 --- a/lerobot/common/policies/act/configuration_act.py +++ b/lerobot/common/policies/act/configuration_act.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 Tony Z. Zhao and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from dataclasses import dataclass, field diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py index a795d87b..e85a3736 100644 --- a/lerobot/common/policies/act/modeling_act.py +++ b/lerobot/common/policies/act/modeling_act.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 Tony Z. Zhao and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """Action Chunking Transformer Policy As per Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware (https://arxiv.org/abs/2304.13705). diff --git a/lerobot/common/policies/diffusion/configuration_diffusion.py b/lerobot/common/policies/diffusion/configuration_diffusion.py index 28a514ab..d0554942 100644 --- a/lerobot/common/policies/diffusion/configuration_diffusion.py +++ b/lerobot/common/policies/diffusion/configuration_diffusion.py @@ -1,3 +1,19 @@ +#!/usr/bin/env python + +# Copyright 2024 Columbia Artificial Intelligence, Robotics Lab, +# and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from dataclasses import dataclass, field diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py index 3115160f..c67040b6 100644 --- a/lerobot/common/policies/diffusion/modeling_diffusion.py +++ b/lerobot/common/policies/diffusion/modeling_diffusion.py @@ -1,3 +1,19 @@ +#!/usr/bin/env python + +# Copyright 2024 Columbia Artificial Intelligence, Robotics Lab, +# and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion" TODO(alexander-soare): diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py index a819d18f..4c124b61 100644 --- a/lerobot/common/policies/factory.py +++ b/lerobot/common/policies/factory.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import inspect import logging diff --git a/lerobot/common/policies/normalize.py b/lerobot/common/policies/normalize.py index ab57c8ba..d638c541 100644 --- a/lerobot/common/policies/normalize.py +++ b/lerobot/common/policies/normalize.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import torch from torch import Tensor, nn diff --git a/lerobot/common/policies/policy_protocol.py b/lerobot/common/policies/policy_protocol.py index b00cff5c..38738a90 100644 --- a/lerobot/common/policies/policy_protocol.py +++ b/lerobot/common/policies/policy_protocol.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """A protocol that all policies should follow. This provides a mechanism for type-hinting and isinstance checks without requiring the policies classes diff --git a/lerobot/common/policies/tdmpc/configuration_tdmpc.py b/lerobot/common/policies/tdmpc/configuration_tdmpc.py index 00d00913..ddf52248 100644 --- a/lerobot/common/policies/tdmpc/configuration_tdmpc.py +++ b/lerobot/common/policies/tdmpc/configuration_tdmpc.py @@ -1,3 +1,19 @@ +#!/usr/bin/env python + +# Copyright 2024 Nicklas Hansen, Xiaolong Wang, Hao Su, +# and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from dataclasses import dataclass, field diff --git a/lerobot/common/policies/tdmpc/modeling_tdmpc.py b/lerobot/common/policies/tdmpc/modeling_tdmpc.py index 1fba43d0..70e78c98 100644 --- a/lerobot/common/policies/tdmpc/modeling_tdmpc.py +++ b/lerobot/common/policies/tdmpc/modeling_tdmpc.py @@ -1,3 +1,19 @@ +#!/usr/bin/env python + +# Copyright 2024 Nicklas Hansen, Xiaolong Wang, Hao Su, +# and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """Implementation of Finetuning Offline World Models in the Real World. The comments in this code may sometimes refer to these references: diff --git a/lerobot/common/policies/utils.py b/lerobot/common/policies/utils.py index b23c1336..8f7b6eec 100644 --- a/lerobot/common/policies/utils.py +++ b/lerobot/common/policies/utils.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import torch from torch import nn diff --git a/lerobot/common/utils/import_utils.py b/lerobot/common/utils/import_utils.py index 642e0ff1..cd5f8245 100644 --- a/lerobot/common/utils/import_utils.py +++ b/lerobot/common/utils/import_utils.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import importlib import logging diff --git a/lerobot/common/utils/io_utils.py b/lerobot/common/utils/io_utils.py index 5d727bd7..b85f17c7 100644 --- a/lerobot/common/utils/io_utils.py +++ b/lerobot/common/utils/io_utils.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import warnings import imageio diff --git a/lerobot/common/utils/utils.py b/lerobot/common/utils/utils.py index 8fe621f4..d62507b5 100644 --- a/lerobot/common/utils/utils.py +++ b/lerobot/common/utils/utils.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import logging import os.path as osp import random diff --git a/lerobot/scripts/display_sys_info.py b/lerobot/scripts/display_sys_info.py index e4ea4260..4d8b4850 100644 --- a/lerobot/scripts/display_sys_info.py +++ b/lerobot/scripts/display_sys_info.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import platform import huggingface_hub diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index e4a9bfef..9c95633a 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """Evaluate a policy on an environment by running rollouts and computing metrics. Usage examples: diff --git a/lerobot/scripts/push_dataset_to_hub.py b/lerobot/scripts/push_dataset_to_hub.py index dfac410b..16d890a7 100644 --- a/lerobot/scripts/push_dataset_to_hub.py +++ b/lerobot/scripts/push_dataset_to_hub.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """ Use this script to convert your dataset into LeRobot dataset format and upload it to the Hugging Face hub, or store it locally. LeRobot dataset format is lightweight, fast to load from, and does not require any diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 7319e03f..ab07695b 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import logging import time from copy import deepcopy diff --git a/lerobot/scripts/visualize_dataset.py b/lerobot/scripts/visualize_dataset.py index d4fafe67..58da6a47 100644 --- a/lerobot/scripts/visualize_dataset.py +++ b/lerobot/scripts/visualize_dataset.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """ Visualize data of **all** frames of any episode of a dataset of type LeRobotDataset. Note: The last frame of the episode doesnt always correspond to a final state. diff --git a/tests/conftest.py b/tests/conftest.py index 856ca455..62f831aa 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from .utils import DEVICE diff --git a/tests/scripts/save_dataset_to_safetensors.py b/tests/scripts/save_dataset_to_safetensors.py index 17cf2b38..554efe75 100644 --- a/tests/scripts/save_dataset_to_safetensors.py +++ b/tests/scripts/save_dataset_to_safetensors.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """ This script provides a utility for saving a dataset as safetensors files for the purpose of testing backward compatibility when updating the data format. It uses the `PushtDataset` to create a DataLoader and saves selected frame from the diff --git a/tests/scripts/save_policy_to_safetensor.py b/tests/scripts/save_policy_to_safetensor.py index 29e9a34f..e79a94ff 100644 --- a/tests/scripts/save_policy_to_safetensor.py +++ b/tests/scripts/save_policy_to_safetensor.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import shutil from pathlib import Path diff --git a/tests/test_available.py b/tests/test_available.py index ead9296a..db5bd520 100644 --- a/tests/test_available.py +++ b/tests/test_available.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import importlib import gymnasium as gym diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 1d93d48f..afea16a5 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import json import logging from copy import deepcopy diff --git a/tests/test_envs.py b/tests/test_envs.py index f172a645..aec9999d 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import importlib import gymnasium as gym diff --git a/tests/test_examples.py b/tests/test_examples.py index 543eb022..de95a991 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. # TODO(aliberts): Mute logging for these tests import subprocess import sys diff --git a/tests/test_policies.py b/tests/test_policies.py index f0fa7c56..c8457854 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import inspect from pathlib import Path diff --git a/tests/test_visualize_dataset.py b/tests/test_visualize_dataset.py index 0124afd3..99954040 100644 --- a/tests/test_visualize_dataset.py +++ b/tests/test_visualize_dataset.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import pytest from lerobot.scripts.visualize_dataset import visualize_dataset diff --git a/tests/utils.py b/tests/utils.py index 74e3ba8f..ba49ee70 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import platform from functools import wraps From 68c1b13406068b9d88afbfcb2366f927141514f3 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Thu, 16 May 2024 13:51:53 +0100 Subject: [PATCH 2/4] Make policies compatible with other/multiple image keys (#149) --- .../common/policies/act/configuration_act.py | 7 --- lerobot/common/policies/act/modeling_act.py | 46 +++++++------------ .../diffusion/configuration_diffusion.py | 17 +++++-- .../policies/diffusion/modeling_diffusion.py | 34 +++++++++----- .../policies/tdmpc/configuration_tdmpc.py | 12 +++-- .../common/policies/tdmpc/modeling_tdmpc.py | 20 ++++---- lerobot/common/policies/utils.py | 4 ++ lerobot/scripts/train.py | 3 +- tests/test_policies.py | 33 +++++++++++++ 9 files changed, 107 insertions(+), 69 deletions(-) diff --git a/lerobot/common/policies/act/configuration_act.py b/lerobot/common/policies/act/configuration_act.py index 95f443da..be444b06 100644 --- a/lerobot/common/policies/act/configuration_act.py +++ b/lerobot/common/policies/act/configuration_act.py @@ -145,10 +145,3 @@ class ACTConfig: raise ValueError( f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`" ) - # Check that there is only one image. - # TODO(alexander-soare): generalize this to multiple images. - if ( - sum(k.startswith("observation.images.") for k in self.input_shapes) != 1 - or "observation.images.top" not in self.input_shapes - ): - raise ValueError('For now, only "observation.images.top" is accepted for an image input.') diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py index e85a3736..4a8df1ce 100644 --- a/lerobot/common/policies/act/modeling_act.py +++ b/lerobot/common/policies/act/modeling_act.py @@ -62,6 +62,7 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin): if config is None: config = ACTConfig() self.config = config + self.normalize_inputs = Normalize( config.input_shapes, config.input_normalization_modes, dataset_stats ) @@ -71,8 +72,13 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin): self.unnormalize_outputs = Unnormalize( config.output_shapes, config.output_normalization_modes, dataset_stats ) + self.model = ACT(config) + self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")] + + self.reset() + def reset(self): """This should be called whenever the environment is reset.""" if self.config.n_action_steps is not None: @@ -86,13 +92,10 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin): environment. It works by managing the actions in a queue and only calling `select_actions` when the queue is empty. """ - assert "observation.images.top" in batch - assert "observation.state" in batch - self.eval() batch = self.normalize_inputs(batch) - self._stack_images(batch) + batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4) if len(self._action_queue) == 0: # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue @@ -108,8 +111,8 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin): def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: """Run the batch through the model and compute the loss for training or validation.""" batch = self.normalize_inputs(batch) + batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4) batch = self.normalize_targets(batch) - self._stack_images(batch) actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch) l1_loss = ( @@ -132,21 +135,6 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin): return loss_dict - def _stack_images(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: - """Stacks all the images in a batch and puts them in a new key: "observation.images". - - This function expects `batch` to have (at least): - { - "observation.state": (B, state_dim) batch of robot states. - "observation.images.{name}": (B, C, H, W) tensor of images. - } - """ - # Stack images in the order dictated by input_shapes. - batch["observation.images"] = torch.stack( - [batch[k] for k in self.config.input_shapes if k.startswith("observation.images.")], - dim=-4, - ) - class ACT(nn.Module): """Action Chunking Transformer: The underlying neural network for ACTPolicy. @@ -176,10 +164,10 @@ class ACT(nn.Module): │ encoder │ │ │ │Transf.│ │ │ │ │ │ │encoder│ │ └───▲─────┘ │ │ │ │ │ - │ │ │ └───▲───┘ │ - │ │ │ │ │ - inputs └─────┼─────┘ │ - │ │ + │ │ │ └▲──▲─▲─┘ │ + │ │ │ │ │ │ │ + inputs └─────┼──┘ │ image emb. │ + │ state emb. │ └───────────────────────┘ """ @@ -321,18 +309,18 @@ class ACT(nn.Module): all_cam_features.append(cam_features) all_cam_pos_embeds.append(cam_pos_embed) # Concatenate camera observation feature maps and positional embeddings along the width dimension. - encoder_in = torch.cat(all_cam_features, axis=3) - cam_pos_embed = torch.cat(all_cam_pos_embeds, axis=3) + encoder_in = torch.cat(all_cam_features, axis=-1) + cam_pos_embed = torch.cat(all_cam_pos_embeds, axis=-1) # Get positional embeddings for robot state and latent. - robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"]) - latent_embed = self.encoder_latent_input_proj(latent_sample) + robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"]) # (B, C) + latent_embed = self.encoder_latent_input_proj(latent_sample) # (B, C) # Stack encoder input and positional embeddings moving to (S, B, C). encoder_in = torch.cat( [ torch.stack([latent_embed, robot_state_embed], axis=0), - encoder_in.flatten(2).permute(2, 0, 1), + einops.rearrange(encoder_in, "b c h w -> (h w) b c"), ] ) pos_embed = torch.cat( diff --git a/lerobot/common/policies/diffusion/configuration_diffusion.py b/lerobot/common/policies/diffusion/configuration_diffusion.py index d0554942..632f6cd6 100644 --- a/lerobot/common/policies/diffusion/configuration_diffusion.py +++ b/lerobot/common/policies/diffusion/configuration_diffusion.py @@ -148,14 +148,21 @@ class DiffusionConfig: raise ValueError( f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}." ) + # There should only be one image key. + image_keys = {k for k in self.input_shapes if k.startswith("observation.image")} + if len(image_keys) != 1: + raise ValueError( + f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}." + ) + image_key = next(iter(image_keys)) if ( - self.crop_shape[0] > self.input_shapes["observation.image"][1] - or self.crop_shape[1] > self.input_shapes["observation.image"][2] + self.crop_shape[0] > self.input_shapes[image_key][1] + or self.crop_shape[1] > self.input_shapes[image_key][2] ): raise ValueError( - f'`crop_shape` should fit within `input_shapes["observation.image"]`. Got {self.crop_shape} ' - f'for `crop_shape` and {self.input_shapes["observation.image"]} for ' - '`input_shapes["observation.image"]`.' + f"`crop_shape` should fit within `input_shapes[{image_key}]`. Got {self.crop_shape} " + f"for `crop_shape` and {self.input_shapes[image_key]} for " + "`input_shapes[{image_key}]`." ) supported_prediction_types = ["epsilon", "sample"] if self.prediction_type not in supported_prediction_types: diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py index c67040b6..1659b68e 100644 --- a/lerobot/common/policies/diffusion/modeling_diffusion.py +++ b/lerobot/common/policies/diffusion/modeling_diffusion.py @@ -19,6 +19,7 @@ TODO(alexander-soare): - Remove reliance on Robomimic for SpatialSoftmax. - Remove reliance on diffusers for DDPMScheduler and LR scheduler. + - Make compatible with multiple image keys. """ import math @@ -83,10 +84,18 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin): self.diffusion = DiffusionModel(config) + image_keys = [k for k in config.input_shapes if k.startswith("observation.image")] + # Note: This check is covered in the post-init of the config but have a sanity check just in case. + if len(image_keys) != 1: + raise NotImplementedError( + f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}." + ) + self.input_image_key = image_keys[0] + + self.reset() + def reset(self): - """ - Clear observation and action queues. Should be called on `env.reset()` - """ + """Clear observation and action queues. Should be called on `env.reset()`""" self._queues = { "observation.image": deque(maxlen=self.config.n_obs_steps), "observation.state": deque(maxlen=self.config.n_obs_steps), @@ -115,16 +124,14 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin): "horizon" may not the best name to describe what the variable actually means, because this period is actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past. """ - assert "observation.image" in batch - assert "observation.state" in batch - batch = self.normalize_inputs(batch) + batch["observation.image"] = batch[self.input_image_key] self._queues = populate_queues(self._queues, batch) if len(self._queues["action"]) == 0: # stack n latest observations from the queue - batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch} + batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues} actions = self.diffusion.generate_actions(batch) # TODO(rcadene): make above methods return output dictionary? @@ -138,6 +145,7 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin): def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: """Run the batch through the model and compute the loss for training or validation.""" batch = self.normalize_inputs(batch) + batch["observation.image"] = batch[self.input_image_key] batch = self.normalize_targets(batch) loss = self.diffusion.compute_loss(batch) return {"loss": loss} @@ -215,13 +223,12 @@ class DiffusionModel(nn.Module): def generate_actions(self, batch: dict[str, Tensor]) -> Tensor: """ - This function expects `batch` to have (at least): + This function expects `batch` to have: { "observation.state": (B, n_obs_steps, state_dim) "observation.image": (B, n_obs_steps, C, H, W) } """ - assert set(batch).issuperset({"observation.state", "observation.image"}) batch_size, n_obs_steps = batch["observation.state"].shape[:2] assert n_obs_steps == self.config.n_obs_steps @@ -345,9 +352,12 @@ class DiffusionRgbEncoder(nn.Module): # Set up pooling and final layers. # Use a dry run to get the feature map shape. - # The dummy input should take the number of image channels from `config.input_shapes` and it should use the - # height and width from `config.crop_shape`. - dummy_input = torch.zeros(size=(1, config.input_shapes["observation.image"][0], *config.crop_shape)) + # The dummy input should take the number of image channels from `config.input_shapes` and it should + # use the height and width from `config.crop_shape`. + image_keys = [k for k in config.input_shapes if k.startswith("observation.image")] + assert len(image_keys) == 1 + image_key = image_keys[0] + dummy_input = torch.zeros(size=(1, config.input_shapes[image_key][0], *config.crop_shape)) with torch.inference_mode(): dummy_feature_map = self.backbone(dummy_input) feature_map_shape = tuple(dummy_feature_map.shape[1:]) diff --git a/lerobot/common/policies/tdmpc/configuration_tdmpc.py b/lerobot/common/policies/tdmpc/configuration_tdmpc.py index ddf52248..cf76fb08 100644 --- a/lerobot/common/policies/tdmpc/configuration_tdmpc.py +++ b/lerobot/common/policies/tdmpc/configuration_tdmpc.py @@ -147,12 +147,18 @@ class TDMPCConfig: def __post_init__(self): """Input validation (not exhaustive).""" - if self.input_shapes["observation.image"][-2] != self.input_shapes["observation.image"][-1]: + # There should only be one image key. + image_keys = {k for k in self.input_shapes if k.startswith("observation.image")} + if len(image_keys) != 1: + raise ValueError( + f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}." + ) + image_key = next(iter(image_keys)) + if self.input_shapes[image_key][-2] != self.input_shapes[image_key][-1]: # TODO(alexander-soare): This limitation is solely because of code in the random shift # augmentation. It should be able to be removed. raise ValueError( - "Only square images are handled now. Got image shape " - f"{self.input_shapes['observation.image']}." + f"Only square images are handled now. Got image shape {self.input_shapes[image_key]}." ) if self.n_gaussian_samples <= 0: raise ValueError( diff --git a/lerobot/common/policies/tdmpc/modeling_tdmpc.py b/lerobot/common/policies/tdmpc/modeling_tdmpc.py index 70e78c98..7c873bf2 100644 --- a/lerobot/common/policies/tdmpc/modeling_tdmpc.py +++ b/lerobot/common/policies/tdmpc/modeling_tdmpc.py @@ -112,13 +112,12 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin): config.output_shapes, config.output_normalization_modes, dataset_stats ) - def save(self, fp): - """Save state dict of TOLD model to filepath.""" - torch.save(self.state_dict(), fp) + image_keys = [k for k in config.input_shapes if k.startswith("observation.image")] + # Note: This check is covered in the post-init of the config but have a sanity check just in case. + assert len(image_keys) == 1 + self.input_image_key = image_keys[0] - def load(self, fp): - """Load a saved state dict from filepath into current agent.""" - self.load_state_dict(torch.load(fp)) + self.reset() def reset(self): """ @@ -137,10 +136,8 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin): @torch.no_grad() def select_action(self, batch: dict[str, Tensor]): """Select a single action given environment observations.""" - assert "observation.image" in batch - assert "observation.state" in batch - batch = self.normalize_inputs(batch) + batch["observation.image"] = batch[self.input_image_key] self._queues = populate_queues(self._queues, batch) @@ -319,13 +316,11 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin): device = get_device_from_parameters(self) batch = self.normalize_inputs(batch) + batch["observation.image"] = batch[self.input_image_key] batch = self.normalize_targets(batch) info = {} - # TODO(alexander-soare): Refactor TDMPC and make it comply with the policy interface documentation. - batch_size = batch["index"].shape[0] - # (b, t) -> (t, b) for key in batch: if batch[key].ndim > 1: @@ -353,6 +348,7 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin): # Run latent rollout using the latent dynamics model and policy model. # Note this has shape `horizon+1` because there are `horizon` actions and a current `z`. Each action # gives us a next `z`. + batch_size = batch["index"].shape[0] z_preds = torch.empty(horizon + 1, batch_size, self.config.latent_dim, device=device) z_preds[0] = self.model.encode(current_observation) reward_preds = torch.empty_like(reward, device=device) diff --git a/lerobot/common/policies/utils.py b/lerobot/common/policies/utils.py index 8f7b6eec..5a62daa2 100644 --- a/lerobot/common/policies/utils.py +++ b/lerobot/common/policies/utils.py @@ -19,6 +19,10 @@ from torch import nn def populate_queues(queues, batch): for key in batch: + # Ignore keys not in the queues already (leaving the responsibility to the caller to make sure the + # queues have the keys they want). + if key not in queues: + continue if len(queues[key]) != queues[key].maxlen: # initialize by copying the first observation several times until the queue is full while len(queues[key]) != queues[key].maxlen: diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index ab07695b..7ca7a0b3 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -23,6 +23,7 @@ import hydra import torch from datasets import concatenate_datasets from datasets.utils import disable_progress_bars, enable_progress_bars +from omegaconf import DictConfig from lerobot.common.datasets.factory import make_dataset from lerobot.common.datasets.utils import cycle @@ -307,7 +308,7 @@ def add_episodes_inplace( sampler.num_samples = len(concat_dataset) -def train(cfg: dict, out_dir=None, job_name=None): +def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None): if out_dir is None: raise NotImplementedError() if job_name is None: diff --git a/tests/test_policies.py b/tests/test_policies.py index c8457854..75633fe6 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -64,6 +64,14 @@ def test_get_policy_and_config_classes(policy_name: str): "act", ["env.task=AlohaTransferCube-v0", "dataset_repo_id=lerobot/aloha_sim_transfer_cube_scripted"], ), + # Note: these parameters also need custom logic in the test function for overriding the Hydra config. + ( + "aloha", + "diffusion", + ["env.task=AlohaInsertion-v0", "dataset_repo_id=lerobot/aloha_sim_insertion_human"], + ), + # Note: these parameters also need custom logic in the test function for overriding the Hydra config. + ("pusht", "act", ["env.task=PushT-v0", "dataset_repo_id=lerobot/pusht"]), ], ) @require_env @@ -87,6 +95,31 @@ def test_policy(env_name, policy_name, extra_overrides): + extra_overrides, ) + # Additional config override logic. + if env_name == "aloha" and policy_name == "diffusion": + for keys in [ + ("training", "delta_timestamps"), + ("policy", "input_shapes"), + ("policy", "input_normalization_modes"), + ]: + dct = dict(cfg[keys[0]][keys[1]]) + dct["observation.images.top"] = dct["observation.image"] + del dct["observation.image"] + cfg[keys[0]][keys[1]] = dct + cfg.override_dataset_stats = None + + # Additional config override logic. + if env_name == "pusht" and policy_name == "act": + for keys in [ + ("policy", "input_shapes"), + ("policy", "input_normalization_modes"), + ]: + dct = dict(cfg[keys[0]][keys[1]]) + dct["observation.image"] = dct["observation.images.top"] + del dct["observation.images.top"] + cfg[keys[0]][keys[1]] = dct + cfg.override_dataset_stats = None + # Check that we can make the policy object. dataset = make_dataset(cfg) policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats) From c9069df9f1e09a98f193eacc7241adead2d10553 Mon Sep 17 00:00:00 2001 From: Akshay Kashyap Date: Thu, 16 May 2024 10:34:10 -0400 Subject: [PATCH 3/4] Port SpatialSoftmax and remove Robomimic dependency (#182) Co-authored-by: Alexander Soare --- .../policies/diffusion/modeling_diffusion.py | 74 +++++++- poetry.lock | 179 +----------------- pyproject.toml | 1 - .../pusht_diffusion/actions.safetensors | Bin 4600 -> 4600 bytes .../pusht_diffusion/grad_stats.safetensors | Bin 47424 -> 47424 bytes .../pusht_diffusion/output_dict.safetensors | Bin 68 -> 68 bytes .../pusht_diffusion/param_stats.safetensors | Bin 49120 -> 49120 bytes 7 files changed, 75 insertions(+), 179 deletions(-) diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py index 1659b68e..2ae03f22 100644 --- a/lerobot/common/policies/diffusion/modeling_diffusion.py +++ b/lerobot/common/policies/diffusion/modeling_diffusion.py @@ -17,7 +17,6 @@ """Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion" TODO(alexander-soare): - - Remove reliance on Robomimic for SpatialSoftmax. - Remove reliance on diffusers for DDPMScheduler and LR scheduler. - Make compatible with multiple image keys. """ @@ -27,13 +26,13 @@ from collections import deque from typing import Callable import einops +import numpy as np import torch import torch.nn.functional as F # noqa: N812 import torchvision from diffusers.schedulers.scheduling_ddim import DDIMScheduler from diffusers.schedulers.scheduling_ddpm import DDPMScheduler from huggingface_hub import PyTorchModelHubMixin -from robomimic.models.base_nets import SpatialSoftmax from torch import Tensor, nn from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig @@ -312,6 +311,77 @@ class DiffusionModel(nn.Module): return loss.mean() +class SpatialSoftmax(nn.Module): + """ + Spatial Soft Argmax operation described in "Deep Spatial Autoencoders for Visuomotor Learning" by Finn et al. + (https://arxiv.org/pdf/1509.06113). A minimal port of the robomimic implementation. + + At a high level, this takes 2D feature maps (from a convnet/ViT) and returns the "center of mass" + of activations of each channel, i.e., keypoints in the image space for the policy to focus on. + + Example: take feature maps of size (512x10x12). We generate a grid of normalized coordinates (10x12x2): + ----------------------------------------------------- + | (-1., -1.) | (-0.82, -1.) | ... | (1., -1.) | + | (-1., -0.78) | (-0.82, -0.78) | ... | (1., -0.78) | + | ... | ... | ... | ... | + | (-1., 1.) | (-0.82, 1.) | ... | (1., 1.) | + ----------------------------------------------------- + This is achieved by applying channel-wise softmax over the activations (512x120) and computing the dot + product with the coordinates (120x2) to get expected points of maximal activation (512x2). + + The example above results in 512 keypoints (corresponding to the 512 input channels). We can optionally + provide num_kp != None to control the number of keypoints. This is achieved by a first applying a learnable + linear mapping (in_channels, H, W) -> (num_kp, H, W). + """ + + def __init__(self, input_shape, num_kp=None): + """ + Args: + input_shape (list): (C, H, W) input feature map shape. + num_kp (int): number of keypoints in output. If None, output will have the same number of channels as input. + """ + super().__init__() + + assert len(input_shape) == 3 + self._in_c, self._in_h, self._in_w = input_shape + + if num_kp is not None: + self.nets = torch.nn.Conv2d(self._in_c, num_kp, kernel_size=1) + self._out_c = num_kp + else: + self.nets = None + self._out_c = self._in_c + + # we could use torch.linspace directly but that seems to behave slightly differently than numpy + # and causes a small degradation in pc_success of pre-trained models. + pos_x, pos_y = np.meshgrid(np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h)) + pos_x = torch.from_numpy(pos_x.reshape(self._in_h * self._in_w, 1)).float() + pos_y = torch.from_numpy(pos_y.reshape(self._in_h * self._in_w, 1)).float() + # register as buffer so it's moved to the correct device. + self.register_buffer("pos_grid", torch.cat([pos_x, pos_y], dim=1)) + + def forward(self, features: Tensor) -> Tensor: + """ + Args: + features: (B, C, H, W) input feature maps. + Returns: + (B, K, 2) image-space coordinates of keypoints. + """ + if self.nets is not None: + features = self.nets(features) + + # [B, K, H, W] -> [B * K, H * W] where K is number of keypoints + features = features.reshape(-1, self._in_h * self._in_w) + # 2d softmax normalization + attention = F.softmax(features, dim=-1) + # [B * K, H * W] x [H * W, 2] -> [B * K, 2] for spatial coordinate mean in x and y dimensions + expected_xy = attention @ self.pos_grid + # reshape to [B, K, 2] + feature_keypoints = expected_xy.view(-1, self._out_c, 2) + + return feature_keypoints + + class DiffusionRgbEncoder(nn.Module): """Encoder an RGB image into a 1D feature vector. diff --git a/poetry.lock b/poetry.lock index 388e03f4..e0b27f15 100644 --- a/poetry.lock +++ b/poetry.lock @@ -4,7 +4,7 @@ name = "absl-py" version = "2.1.0" description = "Abseil Python Common Libraries, see https://github.com/abseil/abseil-py." -optional = false +optional = true python-versions = ">=3.7" files = [ {file = "absl-py-2.1.0.tar.gz", hash = "sha256:7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff"}, @@ -767,16 +767,6 @@ files = [ [package.dependencies] six = ">=1.4.0" -[[package]] -name = "egl-probe" -version = "1.0.2" -description = "" -optional = false -python-versions = "*" -files = [ - {file = "egl_probe-1.0.2.tar.gz", hash = "sha256:29bdca7b08da1e060cfb42cd46af8300a7ac4f3b1b2eeb16e545ea16d9a5ac93"}, -] - [[package]] name = "einops" version = "0.8.0" @@ -1037,64 +1027,6 @@ files = [ [package.extras] preview = ["glfw-preview"] -[[package]] -name = "grpcio" -version = "1.63.0" -description = "HTTP/2-based RPC framework" -optional = false -python-versions = ">=3.8" -files = [ - {file = "grpcio-1.63.0-cp310-cp310-linux_armv7l.whl", hash = "sha256:2e93aca840c29d4ab5db93f94ed0a0ca899e241f2e8aec6334ab3575dc46125c"}, - {file = "grpcio-1.63.0-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:91b73d3f1340fefa1e1716c8c1ec9930c676d6b10a3513ab6c26004cb02d8b3f"}, - {file = "grpcio-1.63.0-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:b3afbd9d6827fa6f475a4f91db55e441113f6d3eb9b7ebb8fb806e5bb6d6bd0d"}, - {file = "grpcio-1.63.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8f3f6883ce54a7a5f47db43289a0a4c776487912de1a0e2cc83fdaec9685cc9f"}, - {file = "grpcio-1.63.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cf8dae9cc0412cb86c8de5a8f3be395c5119a370f3ce2e69c8b7d46bb9872c8d"}, - {file = "grpcio-1.63.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:08e1559fd3b3b4468486b26b0af64a3904a8dbc78d8d936af9c1cf9636eb3e8b"}, - {file = "grpcio-1.63.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:5c039ef01516039fa39da8a8a43a95b64e288f79f42a17e6c2904a02a319b357"}, - {file = "grpcio-1.63.0-cp310-cp310-win32.whl", hash = "sha256:ad2ac8903b2eae071055a927ef74121ed52d69468e91d9bcbd028bd0e554be6d"}, - {file = "grpcio-1.63.0-cp310-cp310-win_amd64.whl", hash = "sha256:b2e44f59316716532a993ca2966636df6fbe7be4ab6f099de6815570ebe4383a"}, - {file = "grpcio-1.63.0-cp311-cp311-linux_armv7l.whl", hash = "sha256:f28f8b2db7b86c77916829d64ab21ff49a9d8289ea1564a2b2a3a8ed9ffcccd3"}, - {file = "grpcio-1.63.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:65bf975639a1f93bee63ca60d2e4951f1b543f498d581869922910a476ead2f5"}, - {file = "grpcio-1.63.0-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:b5194775fec7dc3dbd6a935102bb156cd2c35efe1685b0a46c67b927c74f0cfb"}, - {file = "grpcio-1.63.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e4cbb2100ee46d024c45920d16e888ee5d3cf47c66e316210bc236d5bebc42b3"}, - {file = "grpcio-1.63.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1ff737cf29b5b801619f10e59b581869e32f400159e8b12d7a97e7e3bdeee6a2"}, - {file = "grpcio-1.63.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:cd1e68776262dd44dedd7381b1a0ad09d9930ffb405f737d64f505eb7f77d6c7"}, - {file = "grpcio-1.63.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:93f45f27f516548e23e4ec3fbab21b060416007dbe768a111fc4611464cc773f"}, - {file = "grpcio-1.63.0-cp311-cp311-win32.whl", hash = "sha256:878b1d88d0137df60e6b09b74cdb73db123f9579232c8456f53e9abc4f62eb3c"}, - {file = "grpcio-1.63.0-cp311-cp311-win_amd64.whl", hash = "sha256:756fed02dacd24e8f488f295a913f250b56b98fb793f41d5b2de6c44fb762434"}, - {file = "grpcio-1.63.0-cp312-cp312-linux_armv7l.whl", hash = "sha256:93a46794cc96c3a674cdfb59ef9ce84d46185fe9421baf2268ccb556f8f81f57"}, - {file = "grpcio-1.63.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:a7b19dfc74d0be7032ca1eda0ed545e582ee46cd65c162f9e9fc6b26ef827dc6"}, - {file = "grpcio-1.63.0-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:8064d986d3a64ba21e498b9a376cbc5d6ab2e8ab0e288d39f266f0fca169b90d"}, - {file = "grpcio-1.63.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:219bb1848cd2c90348c79ed0a6b0ea51866bc7e72fa6e205e459fedab5770172"}, - {file = "grpcio-1.63.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a2d60cd1d58817bc5985fae6168d8b5655c4981d448d0f5b6194bbcc038090d2"}, - {file = "grpcio-1.63.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:9e350cb096e5c67832e9b6e018cf8a0d2a53b2a958f6251615173165269a91b0"}, - {file = "grpcio-1.63.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:56cdf96ff82e3cc90dbe8bac260352993f23e8e256e063c327b6cf9c88daf7a9"}, - {file = "grpcio-1.63.0-cp312-cp312-win32.whl", hash = "sha256:3a6d1f9ea965e750db7b4ee6f9fdef5fdf135abe8a249e75d84b0a3e0c668a1b"}, - {file = "grpcio-1.63.0-cp312-cp312-win_amd64.whl", hash = "sha256:d2497769895bb03efe3187fb1888fc20e98a5f18b3d14b606167dacda5789434"}, - {file = "grpcio-1.63.0-cp38-cp38-linux_armv7l.whl", hash = "sha256:fdf348ae69c6ff484402cfdb14e18c1b0054ac2420079d575c53a60b9b2853ae"}, - {file = "grpcio-1.63.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:a3abfe0b0f6798dedd2e9e92e881d9acd0fdb62ae27dcbbfa7654a57e24060c0"}, - {file = "grpcio-1.63.0-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:6ef0ad92873672a2a3767cb827b64741c363ebaa27e7f21659e4e31f4d750280"}, - {file = "grpcio-1.63.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b416252ac5588d9dfb8a30a191451adbf534e9ce5f56bb02cd193f12d8845b7f"}, - {file = "grpcio-1.63.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e3b77eaefc74d7eb861d3ffbdf91b50a1bb1639514ebe764c47773b833fa2d91"}, - {file = "grpcio-1.63.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:b005292369d9c1f80bf70c1db1c17c6c342da7576f1c689e8eee4fb0c256af85"}, - {file = "grpcio-1.63.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:cdcda1156dcc41e042d1e899ba1f5c2e9f3cd7625b3d6ebfa619806a4c1aadda"}, - {file = "grpcio-1.63.0-cp38-cp38-win32.whl", hash = "sha256:01799e8649f9e94ba7db1aeb3452188048b0019dc37696b0f5ce212c87c560c3"}, - {file = "grpcio-1.63.0-cp38-cp38-win_amd64.whl", hash = "sha256:6a1a3642d76f887aa4009d92f71eb37809abceb3b7b5a1eec9c554a246f20e3a"}, - {file = "grpcio-1.63.0-cp39-cp39-linux_armv7l.whl", hash = "sha256:75f701ff645858a2b16bc8c9fc68af215a8bb2d5a9b647448129de6e85d52bce"}, - {file = "grpcio-1.63.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:cacdef0348a08e475a721967f48206a2254a1b26ee7637638d9e081761a5ba86"}, - {file = "grpcio-1.63.0-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:0697563d1d84d6985e40ec5ec596ff41b52abb3fd91ec240e8cb44a63b895094"}, - {file = "grpcio-1.63.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6426e1fb92d006e47476d42b8f240c1d916a6d4423c5258ccc5b105e43438f61"}, - {file = "grpcio-1.63.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e48cee31bc5f5a31fb2f3b573764bd563aaa5472342860edcc7039525b53e46a"}, - {file = "grpcio-1.63.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:50344663068041b34a992c19c600236e7abb42d6ec32567916b87b4c8b8833b3"}, - {file = "grpcio-1.63.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:259e11932230d70ef24a21b9fb5bb947eb4703f57865a404054400ee92f42f5d"}, - {file = "grpcio-1.63.0-cp39-cp39-win32.whl", hash = "sha256:a44624aad77bf8ca198c55af811fd28f2b3eaf0a50ec5b57b06c034416ef2d0a"}, - {file = "grpcio-1.63.0-cp39-cp39-win_amd64.whl", hash = "sha256:166e5c460e5d7d4656ff9e63b13e1f6029b122104c1633d5f37eaea348d7356d"}, - {file = "grpcio-1.63.0.tar.gz", hash = "sha256:f3023e14805c61bc439fb40ca545ac3d5740ce66120a678a3c6c2c55b70343d1"}, -] - -[package.extras] -protobuf = ["grpcio-tools (>=1.63.0)"] - [[package]] name = "gym-aloha" version = "0.1.1" @@ -1668,7 +1600,6 @@ files = [ {file = "lxml-5.2.1-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:9e2addd2d1866fe112bc6f80117bcc6bc25191c5ed1bfbcf9f1386a884252ae8"}, {file = "lxml-5.2.1-cp37-cp37m-win32.whl", hash = "sha256:f51969bac61441fd31f028d7b3b45962f3ecebf691a510495e5d2cd8c8092dbd"}, {file = "lxml-5.2.1-cp37-cp37m-win_amd64.whl", hash = "sha256:b0b58fbfa1bf7367dde8a557994e3b1637294be6cf2169810375caf8571a085c"}, - {file = "lxml-5.2.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:3e183c6e3298a2ed5af9d7a356ea823bccaab4ec2349dc9ed83999fd289d14d5"}, {file = "lxml-5.2.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:804f74efe22b6a227306dd890eecc4f8c59ff25ca35f1f14e7482bbce96ef10b"}, {file = "lxml-5.2.1-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:08802f0c56ed150cc6885ae0788a321b73505d2263ee56dad84d200cab11c07a"}, {file = "lxml-5.2.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0f8c09ed18ecb4ebf23e02b8e7a22a05d6411911e6fabef3a36e4f371f4f2585"}, @@ -1740,21 +1671,6 @@ html5 = ["html5lib"] htmlsoup = ["BeautifulSoup4"] source = ["Cython (>=3.0.10)"] -[[package]] -name = "markdown" -version = "3.6" -description = "Python implementation of John Gruber's Markdown." -optional = false -python-versions = ">=3.8" -files = [ - {file = "Markdown-3.6-py3-none-any.whl", hash = "sha256:48f276f4d8cfb8ce6527c8f79e2ee29708508bf4d40aa410fbc3b4ee832c850f"}, - {file = "Markdown-3.6.tar.gz", hash = "sha256:ed4f41f6daecbeeb96e576ce414c41d2d876daa9a16cb35fa8ed8c2ddfad0224"}, -] - -[package.extras] -docs = ["mdx-gh-links (>=0.2)", "mkdocs (>=1.5)", "mkdocs-gen-files", "mkdocs-literate-nav", "mkdocs-nature (>=0.6)", "mkdocs-section-index", "mkdocstrings[python]"] -testing = ["coverage", "pyyaml"] - [[package]] name = "markupsafe" version = "2.1.5" @@ -3056,6 +2972,7 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -3224,30 +3141,6 @@ typing-extensions = ">=4.5" [package.extras] tests = ["pytest (==7.1.2)"] -[[package]] -name = "robomimic" -version = "0.2.0" -description = "robomimic: A Modular Framework for Robot Learning from Demonstration" -optional = false -python-versions = ">=3" -files = [ - {file = "robomimic-0.2.0.tar.gz", hash = "sha256:ee3bb5cf9c3e1feead6b57b43c5db738fd0a8e0c015fdf6419808af8fffdc463"}, -] - -[package.dependencies] -egl_probe = ">=1.0.1" -h5py = "*" -imageio = "*" -imageio-ffmpeg = "*" -numpy = ">=1.13.3" -psutil = "*" -tensorboard = "*" -tensorboardX = "*" -termcolor = "*" -torch = "*" -torchvision = "*" -tqdm = "*" - [[package]] name = "safetensors" version = "0.4.3" @@ -3738,55 +3631,6 @@ files = [ {file = "tbb-2021.12.0-py3-none-win_amd64.whl", hash = "sha256:fc2772d850229f2f3df85f1109c4844c495a2db7433d38200959ee9265b34789"}, ] -[[package]] -name = "tensorboard" -version = "2.16.2" -description = "TensorBoard lets you watch Tensors Flow" -optional = false -python-versions = ">=3.9" -files = [ - {file = "tensorboard-2.16.2-py3-none-any.whl", hash = "sha256:9f2b4e7dad86667615c0e5cd072f1ea8403fc032a299f0072d6f74855775cc45"}, -] - -[package.dependencies] -absl-py = ">=0.4" -grpcio = ">=1.48.2" -markdown = ">=2.6.8" -numpy = ">=1.12.0" -protobuf = ">=3.19.6,<4.24.0 || >4.24.0" -setuptools = ">=41.0.0" -six = ">1.9" -tensorboard-data-server = ">=0.7.0,<0.8.0" -werkzeug = ">=1.0.1" - -[[package]] -name = "tensorboard-data-server" -version = "0.7.2" -description = "Fast data loading for TensorBoard" -optional = false -python-versions = ">=3.7" -files = [ - {file = "tensorboard_data_server-0.7.2-py3-none-any.whl", hash = "sha256:7e0610d205889588983836ec05dc098e80f97b7e7bbff7e994ebb78f578d0ddb"}, - {file = "tensorboard_data_server-0.7.2-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:9fe5d24221b29625dbc7328b0436ca7fc1c23de4acf4d272f1180856e32f9f60"}, - {file = "tensorboard_data_server-0.7.2-py3-none-manylinux_2_31_x86_64.whl", hash = "sha256:ef687163c24185ae9754ed5650eb5bc4d84ff257aabdc33f0cc6f74d8ba54530"}, -] - -[[package]] -name = "tensorboardx" -version = "2.6.2.2" -description = "TensorBoardX lets you watch Tensors Flow without Tensorflow" -optional = false -python-versions = "*" -files = [ - {file = "tensorboardX-2.6.2.2-py2.py3-none-any.whl", hash = "sha256:160025acbf759ede23fd3526ae9d9bfbfd8b68eb16c38a010ebe326dc6395db8"}, - {file = "tensorboardX-2.6.2.2.tar.gz", hash = "sha256:c6476d7cd0d529b0b72f4acadb1269f9ed8b22f441e87a84f2a3b940bb87b666"}, -] - -[package.dependencies] -numpy = "*" -packaging = "*" -protobuf = ">=3.20" - [[package]] name = "termcolor" version = "2.4.0" @@ -4064,23 +3908,6 @@ perf = ["orjson"] reports = ["pydantic (>=2.0.0)"] sweeps = ["sweeps (>=0.2.0)"] -[[package]] -name = "werkzeug" -version = "3.0.3" -description = "The comprehensive WSGI web application library." -optional = false -python-versions = ">=3.8" -files = [ - {file = "werkzeug-3.0.3-py3-none-any.whl", hash = "sha256:fc9645dc43e03e4d630d23143a04a7f947a9a3b5727cd535fdfe155a17cc48c8"}, - {file = "werkzeug-3.0.3.tar.gz", hash = "sha256:097e5bfda9f0aba8da6b8545146def481d06aa7d3266e7448e2cccf67dd8bd18"}, -] - -[package.dependencies] -MarkupSafe = ">=2.1.1" - -[package.extras] -watchdog = ["watchdog (>=2.3)"] - [[package]] name = "xxhash" version = "3.4.1" @@ -4348,4 +4175,4 @@ xarm = ["gym-xarm"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "2f0d2cbf4a2dec546e25b29b9b108ff1f97b4c278b718360b3f7f6a2bf9dcef8" +content-hash = "e3e3c306a5519e4f716a1ac086ad9b734efedcac077a0ec71e5bc16349a1e559" diff --git a/pyproject.toml b/pyproject.toml index 24d9452d..5b80d06f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,7 +44,6 @@ diffusers = "^0.27.2" torchvision = ">=0.18.0" h5py = ">=3.10.0" huggingface-hub = ">=0.21.4" -robomimic = "0.2.0" gymnasium = ">=0.29.1" cmake = ">=3.29.0.1" gym-pusht = { version = ">=0.1.3", optional = true} diff --git a/tests/data/save_policy_to_safetensors/pusht_diffusion/actions.safetensors b/tests/data/save_policy_to_safetensors/pusht_diffusion/actions.safetensors index 730f5b2bc2a801d15b4ade3593c90f95650f5472..8f03990351292611f702c163ee387b3d7248b5f0 100644 GIT binary patch literal 4600 zcmb7HdsK~C8;?kl8b&kjn)4BI(3Q}bv-eM`!5}mx8pn{7GKiQiLYGl#GK3n1XneVq z3`We;lu`u=p*I%~iCd7j^M+rPcv{kHYf;6Fc+iD<5$ zXoJt5?>36&iWZuiiNvBE?owgqGRs12wpJ|KAoY=cv(?RQ$3~wWOll&wG&NfrAQqW^ zD3amCrY2?<0@&*&%l z8K~x9-1n?Uh2w1uHh#|IggVv2^@}vFQ6BHw2> zlA`}Au-~f>ZuUKBEiA{Jqu*oq&S+eEw+54^eMeu1#Gx=vhhZfoI4_%zHpj~O_+Wnr zt3^TrEJ{eXbq!hDzaEKpzw!FkYM6+C4=WO*FF}>h(LxvKWujOzZ7f5IAZbzx|1YghSu|LsQZbsC%>1R}= zrQzYi>Ab%}KRibapq}QBU~|xp``xxy4v!hWh_h@Z8OGx4-_sdga3l$ld0D)`9sO`b`Zf)>Lvn7^Z|HyFG#DSNEd^O_=cfeOw{G zLhVZWqfrZndRAg{QUMw-$8dgH{n?1O`Z0L7Ri9pWmZN;T8JCZ}bev@Bt9iLjdKV?S zr8pbZfV}mQWUM#<)&))6F6%1e0;_vK>r0%mm1Cn@HaI3uz$p($Vw^sS zMlUv@|DF5*R|>bn__JIXJ-^TKlI68L-)r@W)Lz!W>sT?n59Om*l2O0Nu}`{-j9>JO z%Y}GOS1Zt2^@z_O^tREWV|p-Ic8_YA^>Bam)KNIuHomtPX`e@AnR=e)|2NF1OvD>m z4<0nTKvAaPcGd@5@K|vI$pe&V+Y*C~rAnNxaw8q7y6_#-1>IlTF-g+_H5XgCy@T^D z9Clod)l-v^@NgeL7kyINu>E-$*Wb~X^Yzu4+J}H98B#Nqu$ONmjo~U3&@OBV>w{v~ zO>`~j!HO~JeXY1R5zf}lSeCn<6i_8@1Sqg0wi~x5%qh`!I(+39~QyMqd8M;!T$1z(w5p_~^3$l=a2$6AMUdbde)N+}H7~p48TXn0m?tO6g0MGa<$vUGNv|fq8(imp zc4Y`XRKHKkU0TpD?J++>a=Y_~Gis1T;(ZZP^BUZLo_!R@F}OO8pt_g_?8&BO zClx5!6iltl<}q6dIpQ>o=PvSmiin_d>V3`Rh36ZH<0iI${>fWB-ML>EtFyRtyAS*4U4>g^7Al*g;9>d>Q${so@A?|_ z{85LzOThhs&p4jPW4+wcRE__f-i8Dlf9U9JqZfW#Nh=)VQM`E%J(|-_HEr(F+`ac< zu<9I-(>pVrj&nYQB_FGx6xCsLFC8}-)ne0Nxh%eTI1T&^#y5zaITv0C;FYlmf`V->rPu=#JDLTwW6SM_!PdzO(Q z@;zAD*-1{0Jji4Gt2fJDz~?0^-Xd_CN_}p!Jh!^qk%NX22u+(r>pt~-@9)X4>T}CW z*Ov^F4P;!BkUi#=NG@RST!s_=`#U?3G4vtnY1tY4{tEA6Z*flC^jbvvXDe|q zE|}yf)cyX~zci#7l7tRA{V>mkzU|=ijM$oN41hlczfHdNfkh z=eK}KwsxX=&pw8>Jd}*LO~kR>ebmtH6?s+8IGYuR5yx&4uN*1#W*&xkVHt_N5RTmP zX~<1YgPC6{DhyNbs$P8ukskJh-IqU5$GAu;3mQpZTBh^9NE?NBxPL~7J{RJd20fuc^JV19>(|_VMv*uDGM4!i)}HA{*$8(?85x}tfVgz`n$Ur zKjt@}KK^$eM{8XUQuisa$RmvRt;S1*hoQonsn5}fZ!V|yhgMq^ z^oI$xONiE(Hu@&xZ?0QpyOX#y^kHwHHGM$ta=n0QFUk6+R}ymAQi;;25cSzspg>CM zBjkN5rm9g_F|w$P4oQzhaz+?hb}I2#Vr^WL5ga@6R8>2sG_cE5ppcMhM_p9+xowcX3$0}q*7huY6 zR!`SWk{DBf&@5i8)T->q>!Ce0%yk7fL5_H+MkyJl9u`;v8?oLB^A5z~zIG32O zrS2*u_D@#(X7vejhuX&|wk_ncfwxrH%Xofo-aGu!nga`GwwIYr$*?-Q?Bh;!S?t5& zdYiG=A)C)aiFphf8wYv&ez>ttxJ}Np_s` zM{leOFS|M{N&5{%JLO>>+iG@!MA^)hCkbKK8o_H0+8UeD*!0da4MX$DwcQG42KBK7Y2_Oan4Xh;d*wk9WqHeK)?g zq37(ix!v`&W z3w(td)aSP-u$=ollxRc$2;6`tn`C_MU3Y^zH6+oG8bfJKzhNi{Wcz}}6S*>u$RASM z=vN4r5}vmp*K-IZJ_f~PqUsd4o0(q&(ZEKs`9KAtBXeqIY7hiB%)%Va9P>=nK_WaNLY{X200 literal 4600 zcmb7Hd0dUzAC5^=A;d7YwArt;sD><^^AI&ng%Beu*|JPZB4o59LTO4dEeaK3k|^BL zaLZOn5-M8IK5bg=t>62eYntD``h5DF_q^wMzTfA&oO9mW)K7-}?N`@UpR-@xarZvw zjp}pMP4slt1?sL_?D^TR2D1dZwgPoW``z~If7`Ulb>nVV?ya^!Uq{#Wpg>*cTbUeB zprfrjizn-TE14$>boBKMc(&d*vpJ$bS9{h^JbTu+vU#FFSKmOJXX}46no4kopUeE- zi_3Yp(U>Hm9kq#c&d_r9e($jr;MdHrH_$tFR?r+xe;$8|5`bOtobrueaI1jOmh*5fOf_W+TRae{+OQ z(V%HnS;(``r02_ekT6pdYqsgZY+(l!BHO6^UrkV5n?yGi%*UriA;j%*4`f_KRIv9m zIy$@IYrF&>Y9w@vqcT3_T;sHe$lRw(zUD`dW;IWwJ_qed)8KaKnEoF&`?CvEw65eLu1|c-_1{m8ks)-XrF6FV*|`Mhx9)5=Q<>UyXrJRG@h~m-M^; z3tE;R#{Cs3RK`mT+3mK}aitw~ukON)4l&%G3TbIv0<~XW2VIFRRl5HN(J~ z+U01C9HR#;FTxq6#3EaxFJF&?-(a(=3o%tor5~nz!i)X_C{5c=Hz<^l{<&eyPL($x z(^5-{UM+>z>F@EI@D3TcDFWE_oUQY5c%{v!W^TXH3(w`Ss(KNzyxGO_9}zi>s`?b+ z!RdD-e4v2g6Bo5%&lWT8*_h~#kih)lioW692N|6|Gu)&@&q%mhJC>)2 zG4A~lcza0ulJC#@n+^DDk2@-dU*|MR@aB;xr&&ZYbY^3~0X3FS9;fW155HX>kX(CS z+E?5=m$ue;g43Kb3e%=qSi*AO>LO~@J+<6WT*Rka>K`=`nvMSl=lCp9xYyr1VK^~lybMSh)` zgOQ~bIFT(+HRsO522l^1^)IqGcwAm#DLEXylFgHwyXj2lIw+i_NVCj?Rjg8MSCZ(c!x)Zz)$Wu@`M>IPsQ3^#P>`&YVPJj7$ZUd`<+PrY)DOpwlI8a z(HW{&9}UH^UaSt6SHD5#*d36IaUjRquOM4vA?a1qK+?x}rYHB)C~WiSf)AIQ#&wmn z#@&e7xsoCwo&$U_H;n5$t{&{4tLvDbtGFDA$l!-!Hd>x3!wKcjq#~;fmlk+q;PlYG zcq|O8af>#xdG$O6Tr(WmXD22xjqy8L?aAv9?%d{lDBhK$8&3=A%jf5C-z1Z*Z`4hJ zLtrPcY@)>52fcuUe-C1b4`gx4U+av@2U%Io%O}n z^V8;t52ghEiNU|jrBTZ_qk4usZSxvQf6Sgl^SrJy{NdiOb>?H|cuLQ>yX-6fWxfn; z`@Rzk-EC>)TuujvDb49Y#*APZkwM6`xNf#@U(8TM&$3tu_di0e;!Ma_TR_D%TWW{n z2t(aqGhLN_s4rr*==W?7%e~DD2kLzG1W6w_0Jp}B)R*{Bjzl zlZv3XZZbq2!=ZCjs%74~Vsd?f7!5BTpuf>=gon(h2A+fIrInA+u|pnBmmbmeZT?WU zI?e2uPPqY%mh-G%tlhZ0aoU2lJK!+Ym_74xxBL51{SDm=|8E|oJ@J40mcvI4}8X+n5aVxr1e6 zHZtE%pAI4F)gzDzm+1I?m6-6Vlq}Zi#m@6~)cR-@3gTt^>Y_8LjQP76(g{E18fGuK zCJk$D6~gAwJ`$i8s+C~$eGgPN2jS(`DTqriWqA+2@Dnx2 zu0(?GF0lIRrOk>hz2WHpQ1%3ItxIL8NQg$U%-EpsQDkgz03jGW% zu4WHf@51HE0rK6*B8Z=!hwjlV{OI0<2}!v~*nN!ISnP5Ojk6OGxTlLcrTxI{$Bn2& zLg5Bfe#yn!fWyS7dwpL#nlZnVF)^>%JpJi%lx9|8dsHLqL&3+P)Ymr(OPpHBhN=^E zO44@hw4lV&^*P=p8NyPg*wu=fAl`YR)pY?y9mK<>M^KC}(@LZ0kn5Z2{9YXu(>l@I(mHAuF)$O9Qh% zZAdDMr?U1k?j$dye#V8=DmE5_CZ?g?Ae{03-7h)=OsL1lU>aoDB8}aR|CTT&a>ioZ z_YXw8rZZ0ee3fK9?a_EF5ujxFQeJn3-e{0~l!B(oGrV8_f zO9))8%yQUY5Xf|y+jZi0dJauh6Ou*VPq174FQ(HWfUkjw*d;gPqK7Q?%sx%WE-Ate z(O}#vucAkmHbY{zkWRmR0=2Z4v~Asw{X4n)m8+lG;xu~KYYDj*a1f^cJusMk3yvv1 zWKF9Hi}U8GU2I%AsRt1~T`)N*fy@3Q5Zk^+a`kB3*rZFR+n-_ntKT8~I9Y!>?~?>sV}~Nyn7eN||Ci^# z?bFe(xC6%nEa~Q%S1~W;G@hhJA>u&`Mo&(sR(rYcR1SaOMF&mF{GR^)rUD_8^Kq`T8PDd7M_FqUGKMT64?d@3 znwl5~HZ`MZOEuI^%F@;MF5*%98`>|wA3SD-Al5$&d0K^Z<}i*wPzx<~{BIW#@$}{H zGHx6nRzb7I{X^otrTxhJ&)YH;oxz0+m*1O5MG@<%-^gp^ZRK$~J|PrxHqS`;mxEMc z^)9+Q^8~r6*#(uA8LV&l`bNht!j#TWp&puqVYB@4Bi4`l>h`LR3d&wKMo^3*Z4Rci#k}&W5R_rVX}0; zIUasN*Qa%|Ug=zwPcsjcL1XrxsH^+}i>4H6o@7sSN}OOmfvY3uFK%fy;%E3mrr;GB zdT}dejw4}jpsL>Y3WcCX8k8*qs2CIyGa-toArVo!!8m{o z3SwkmWD!~H08MvSvx$*NTu z;g*FL75b|7Z_aPTR`vE}5q!J~s}^?G0?|3>^8LAT@YZ12^|4hDWU7G{E1Dqcr%Z*P zFcB~MD zXIL*NdtoFe@-{(Vsy#cRNdjSg-bFui3xqrjqYXrbmT>pe1GK;CX87!Z8MUOh4fMi7 z6#;x*@IN7S?-+CYKge;XRl0Pz+705?8_RE9wBuHNS=#l=5zanuCLN6LqsG9M%BGQZ z5aDJccipcIJ#SOxK1DikZ>JTLp5DR;t9b0p@id$j-OGx!K2c$pU{+<` zei8eAr)K*i?Xl52C9MxNdzipFPz?-}k zY702mFif7F`5AddP}JKK|3w7h)v$H$JfpBo_ShJF5t7&3r~H$Ja3r@_sj3cuw6DA9 zJ1vzUHA)gu%Tfzq`s$??Pe)58cCuaRs`F1S;C-j5K|d(y6$&)ZI0EtNQM2id+bAQv zjkJ;8=K^vXj(RJZ#BY1$D|wf|=Rap?*QqdgeYlN2ZF&T*_!3}qeFdJ(iAEQaSK}iN z+W6LVBx6`&L6_(pHu;d94rig~zlSMHgXzo!m9fn{w8O_$HbvdBlSyNZTb6y~p zT^M=-kD>;Yb=$5$;{91lI+6=Ikts@Fy?1cC{wR69z6>-Hj9WIX`;hT}%2y)T!JW5u z+U_oeH&dqaLCG!)ZR_#&Q$DpbyA$2?^8;Ke^ z^PBay7@|u_?X1DRamd(wL(aERa|VQcg|858cbODi^P8Zj>dz?#d(YF$-ZU#7^CX<^ zZeMnzMTiquU1I)Z<=utTpe;iKkrW^UnIPu#cj7`%j2i$&m#lqk@P|^ zOO5i7$f`L$ug=Z5nwG`Q$+(AszPCrqiXj!~H>cKIhk_+z>vR zkNAPK+WZ?WDb}N2KAWZIs!Qqi`UAMM?50N^d7NuF=CIaZj81;JO&;l=glu={lTR$( zL;G)`xdz~{x-_~DuhZu)Qxtru9RA|eH{cHUPd84 zfRE-GR;%k9JaKVCPWDMS197c6lZCPj+}(E#?L+k?U2<26DGFXXNCwMn{?Z*a_mB*+ zzfIs7!h_5G`4lc3Kt@HY>H6YfL_Qm(=a*~YXG<(~%`z9XzP{2bs26xj)^xGa5gwQz zCZwV)d>9L}_fqBEHaK%u&5r3mg_VX?Z1ME?U$?knK8*DY-4IIdIaY;L>8Ywl-4vYd zLa5qG#aL23t%6D=4v&1ye!@$|HMtjB`I~e-A04{N-m%r^0=%3SlpAo}?dPN!^#b%p zd*}@APpFtHrR(?BaQ>Z+4PQ!&J&E*8_xBFv;TY3uRm=L5`17NsssP7>xOZnQWphB6 N2@G4MTx_#u{sckZ-xdG> delta 1719 zcmX|=FZ%Ts}s(r%@ctUTNxBoQ++_csHpQSn$ri!Pw3z@mVGB2_pdT1l!1 zU6hxAgg|)&L6k=zl4Js+0TB@jg-RbrWZ8<}bPE3k%456~IhJuLO}3ZrYK@q>Y-c2To)`eb)(AExJpp-iABym}58q@p zq6POSU>Ypd7V#Tkg~e`7*ScDWzIU4#3)2HONwpk16kIq4Qr@nBpa-`! z6L$OIqwb&ZCySzBWNIHnM~K;B2Hpn-$qA^mM`SK15l-)QC%eSi5V!t}W@r3ia0tJS zPp+rHbjOmYERp`r#V;nAVE1^K^gKcazrTeoN>Fw}1lW9+Q9CcRTuqJEAw0-u|PizIPi$CDAS&q<-S&>IS5i>T1lB2<{~L`4Pv1mWN9 zp^u$^PHY}ZWcxqxJ=Hx8(oTu9-}I>jxUn^ckvw#UxgYYF<(C{;K+5NEEd1|oYi=Kl znSz2OXKnpt10BBfuf^EH)t{pBKCYmv-w#$ z#1sswGmU`|IA_D~?#!UD%LxqMbA}xd&BIIQ=?Tdcca^bl9Nzi8pV(w-fT44qWUuyH z_@c?1oZom3nI;N}3nDpM{oi|Jp{qGH7Hy5&uC-wcMplG0Hle_8y8Lc_XNOvjtMN^e z4)S~MA)&EEN3t*E`?*&YrMhTT#h*+=sdXFiNQVIQWhkck#ZiQ2FN+C#$VHOgI(v1} zNozEDB2gQnz6;$}mVR-wKeGiG7SI)h=KtgBH9$4W0nZA}EsIlZUeVIz zSwwO4;?6{nX^ae`u=h}VwgjBes&znt|n4BL<4Ny zb&gU557X1H#wB#@rY-bttjhoN@Mi7T9~Z-pTTFZTfl5~3*0}gcKRN$m!=-^;3Nm5Q zM0?HN&E&DX?uD#D_M9N2A!|7}_D@Rpp>z|^SojpW7-HQTac*N4;0n}SkLWE01Z-F`~f*5aI3D<3( zGkMXO)K`6rG^;QEpDmDjE1ZhZk4O6j?!0|yRm10M(GUMbGV3s5>4@K-0-s-LU;1!h znBT#zHgI-lCy_NijLwcZllage$~p9i6flp$9F&CpS9$Q8y{xl4L*RcpLc!?51r&y5 zDJ5p1$yAi+gy!am8#R=oVp5VP!P59xdu`be`-?vnG^A?%Q=W;{${7V!-g-gzGCqTf zm{?$P#xkhLH{a^u;4M^j(h}Xhb}99*g{IDybSwWg8GTH?(2f;Y@j|i%)bY2S2vemA z<~9?C>( Pt?Bc-gS6{CoasLR6e#%s diff --git a/tests/data/save_policy_to_safetensors/pusht_diffusion/output_dict.safetensors b/tests/data/save_policy_to_safetensors/pusht_diffusion/output_dict.safetensors index 5f33535d5ad316f4cef55925a215428643ed2f52..a9f61b36ed2735850072768de10a153ef1aea7aa 100644 GIT binary patch delta 9 QcmZ>9nc%?kuA##o01%i1k^lez delta 9 QcmZ>9nc%>3p0ULq01lS|@Bjb+ diff --git a/tests/data/save_policy_to_safetensors/pusht_diffusion/param_stats.safetensors b/tests/data/save_policy_to_safetensors/pusht_diffusion/param_stats.safetensors index ade6a9e03118ac410f6f0ff9357b6779074437b6..a9f4608f6c1c544e8324803f154f57cb9f6acd96 100644 GIT binary patch delta 173 zcmV;e08;h_}JR*~kx<3gn zoKLrTs;WFyv)Q_T0RcFZS-c+sB9nQ%F9A%Gsk~niIFjc#(xDPP0|sZgbL}@g8j~5l z9|3BUIlWI30y9E5VzClEsMYT|bOt~?3VxHCy*?4GIj%Q@a}qt!tyH;GpgcShlj*%r b0gsb8z8?W&lUcq$0dSLlVQ3i2|LL%w>h_}JS&rtx<3gr zoKLrTs;WFIv)Q_T0RcCYS-c+sAd`8#F9A)Hsk~niIg;l$(xDPP0|sZgbL}@g9+Mfp z9|3NYIlWI3{4zo~VzClEsMYT|bOt~?|9z91y*?4RIj%Q@a}qt!tyH;GpgcSalj*%r b0g;nAz8?W%lUcq$0d$j@zFz?!lli{q5xGhY From 4d7d41cdee1e2406746ff38739fda2c58586e811 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Thu, 16 May 2024 15:43:25 +0100 Subject: [PATCH 4/4] Fix act action queue (#185) --- lerobot/common/policies/act/modeling_act.py | 6 +++--- .../aloha_act/actions.safetensors | Bin 5104 -> 5104 bytes .../aloha_act/grad_stats.safetensors | Bin 31688 -> 31688 bytes 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py index 4a8df1ce..3aab03cf 100644 --- a/lerobot/common/policies/act/modeling_act.py +++ b/lerobot/common/policies/act/modeling_act.py @@ -98,13 +98,13 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin): batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4) if len(self._action_queue) == 0: - # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue - # effectively has shape (n_action_steps, batch_size, *), hence the transpose. - actions = self.model(batch)[0][: self.config.n_action_steps] + actions = self.model(batch)[0][:, : self.config.n_action_steps] # TODO(rcadene): make _forward return output dictionary? actions = self.unnormalize_outputs({"action": actions})["action"] + # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue + # effectively has shape (n_action_steps, batch_size, *), hence the transpose. self._action_queue.extend(actions.transpose(0, 1)) return self._action_queue.popleft() diff --git a/tests/data/save_policy_to_safetensors/aloha_act/actions.safetensors b/tests/data/save_policy_to_safetensors/aloha_act/actions.safetensors index 7e7ad8e1df015d0ff52d689b317b8d77b3f380fa..3c9447d7fa0b68143216f21c4d9cf5c075253fe4 100644 GIT binary patch literal 5104 zcmb8y`CpCc8wc=ECJFH+l}a6H;gptB=P1v0x19DxyO6P*bCPsUY15ExgprtL#?n}# zO<7tjW9_<7N>MWzW*AG85mBVbl-BPt|AFVX=Xt%}ulw~}_kDd{_Z2--<-cb=&zk4B zo+nNECSJyKO=;fc^7K`{cf+c*Oba~S9rEyVFiL#_bC{;ci4RDXtiTL;m9 zC3E;VeSuRc&AN7 z{j+4uYd#-ORt><@(V>`-UkWFTJZRT-1sQLXv-8$wC>z@Vm3PKN-14d5{K*_&ymF!H zQYqVXY!-s*Ak@4LMxoVND3E(nvTQXeCdrw8=UANftOle%jD(nNqv1%)Y&49qBfBwC zjCVi>otO&C_6foHukzrxSRcB$HI?SQNo3MLXQB^R3(cYrT>4=(*j&fQ#@BNx^0thP z+%**)T}H65zCdhsI0^?+{U~F7JS8+GG0m0snAuzdnvOl(?_ZAxD=i_$@AjeGUK#UA z(nXaED$Ft`9Oqstf|_Vw>bkmy94{v@@lrA3jVcg7?cxM~O@v*KgsAQ5OZ6|J*!?se zymnrN4gL{|=Pnii&(D`ei{t6xp9w5Y>WeN{Zo=N5KXMN<^q@FlHjX;$PCh~jyI-e^ z5y=DKb0`cWZyg2ASMJmrxQep>Q81J9({Z8wWyo$G;=Fh9pf%GTvo?Cs%Lh{S+{h4h zXQ;5}b3;+II3KK%y(m|rpgl2h%+JyhPv+FX;UiyxMUxTO|89wo{$k{nEM>nn&cc9` zeK4&l1jFM`!wGdynrIh8o1^2Hc(w`l7vF)LRyEl0SQ|#Xw89Q!Px`e-!j=`7qyAU@ zaQok29KE*;*135R$D&AkY#f_+Wi0-9`zBIRo=v6}VzE|AaUv#0})EEPwIns}LQg(a2A*LSeg{|TR_`azSB7@!N z>u2%gTo%hRt9f|EunHyztAk-D;6B56>BR{O|zWi3`W^1w~M&??Klr zlW1>mEaShMi^~lg!APbBfrXPHQ^Lc4Ivwcbex;M!;M#&7Sn3jvy_e2_eZL!hyC|N* zz2$6PJ`cC5-Gkao6X52(F(6hoLGxdL4$YP^hjG@Ja;^td48t(n_ZKMUxs!iIEUi$J zv!tCGc*?W^DtbqP^hF2PwAB=|c-9oVB8ug78(^PhKl~9Ch9Sv?U@_B+R@$wl(%2+6 zu-gL1uKpFQp_^N`q=T!QXofz8LYiw4#f;V&;D>R8kT4C9U8+{@L&|it zJA4h2QZzsuH5CFnXJfj*6Aj;xu@ZkXoTjP5B9;f^qo6|YKjcB?W@{*Ohnx)*i!gKe z4h)^01i$%Bh0e4Y_|Kp71QrSu@sXlXE*mMYc{nJ zE`68;ho)(O@YYN;;W<*>BPk1fXoD-0`{B;tfmlEBEC}zo)7ERzq|1{tzZ5k*`9~f2 zG^oNx|1NG{y*Ww)EGc$R6w_8W#_x@X!9*(veeUGLDCR{KO%(Aa&nL)vx@JmF!fR$TtD2; zrFDPeI`ssYG{v2+#YeGa;l|j$Wf)R7g`(807`9Y;Q~w+VwbduEdPhLvpEV%J>*4ek zj)arP1;`KapigNl*|H->xGZuQM0vre`mO*jHh7b3WGwlOOJpq<-SOL$Tkr(>IC0)s zSbTRD^5Z!=dvOJGE1iaQ)dP@mBN*qcJqDXg80{%tP3yQgmYVUZqPZO5;2d=jY|??H zn)Y~3&xNu)q|Cw42($H7Sg>OdcHPg1jLoj}r&b(U*2J-j{s|cIst&};j8I&h3X==@ z_}BMB3Og@hu2(FvKD-Y!W(J{1R07HKT_|%Pmd@UYW8WR^udvbo6(ar~56!I_a7ku` zfch-oF|FqrO2Y8zIOp(Kt)xYku9x7>n=J0l^Z=mR&P?A7+W zHuN|~%DNBeVr2FpcrFXVShYM@YUWN=_fkl2dkWK8{l4O+pbk`ae&k{rJ2|MWw=jJ}HNIVku);Xal(Tk2} zNLk!NU5r~j1h))Bv8?bo>`?QhNM+wTijtYwKo>i{y9v$%-P|z;b+CNzgj!dR|0x(ZZ_9^gogNg_C8t1#WcJ3x7;jvwhHhamR~s}DzWaG5dPLjPR+)r3 zm1v{$oqpKg7>cdaioy1vn9AE$QGutNWuD-nX>SE+r=F-@HDSO|j zhu*V>!E9>?)|@^G< zz^y14({^5jnL{3w)+M8M<$9-EG;nP|4J3Jtgpt#pbEi%CI6lCFcC3)GmP=Ex)_D-# z90|snWyP>W%ZGk_yoSb2PiFU{XJU+1Ef^_2a9_`R$+{QUK-1rs@(xB*WKRO~d*p(*=2SuLwr(!iO}br3wdseVgdIjWB=$OsQfzw z-`AgjmT+Ivez%ewOcR-(t2dhMx(P`edbr~oG~j}!Io_&tra!eM%%pHChWYk^@sm(o z@bnl!x|o6wt)_BW95Z=lhG`L3LE1LRbsm}oFI{ahOC~1uT~gK-sgKSxhu}Mt5S&q- z4-pDa%0mV1WAV%+nc-HKS}?c#3MLgyg4`cw;df7+$jesB^y+7!x6LQ;S{{NGMW^6} zwFgbEm6Kq5JUhJH9L>hpL&k+spmtFmR;}S(rg6%9&$gBscvW`r4*0^=Nh$exQU3g`HdmB~QmBIjAIwu#T*FCA!HJ0r4*Rsr} z|KXVhH=yM}AJ@D<9rP|aV#-4?DSnl*H~&t@5n3utF-bWG+%E=*^Q5+>STcQ?#1115 z{rV6x$9?3uNOefpb3&_I9u!+8WAROfSY@cfCVK|sl~KoG9(Ynh@M^mKDv9wMEil~Z z8g%n|xs@v>L2HB&o*#CgampMFbuz|niv~fwFbt=iD}Y8nF$r!a(&6uwbJc4b>={`B zCYw}Y->0##T{H{3KRMA}K@=O%HAS5z%KJ-K2o{zXfxU(&9rjHj*8w@J(DBAoWwp?) zqY0DM#zO8hV8v_}86K;!34Ey1oECgSSI0s`D%Kv_#px?9POe)dF4@;_Hp6LLm zS^I#SanB47Oq@-Q%c7XgkP*H;JqUNxL-0q;gO5+V$;xv*J$RbL{FHrKsk|3Rx4-5p zrat9zw9T<;qb>cU8^wO8fM zTtlNR+Prn6#9b>H&(Rq7st-fK^e~(!EQHc5AL?*hLznj@va~-U@UKlbfzLX)2RS35 zyI>}|EpVjAtrFJ8>7(?_0Gu=o#t5r?z_H3aW^&R>S1|qVaky*Z74UmC3LLJg!zR zGLEiCN}0E>C0dpCLv!)FDQd4+^Slz?DEhd**Y%w01wM*2 zt;sY-#w}0y6ZvhLCG}LHHD>J8jHMuYVh-GV-xYy7by<|xd(>_xpFVrh|~GIz@iFu%GAybiqM22MZY9;pe@t;Ulw z_R5&sCO!OP=`cL*3qgyyMNl}#i`JOPsrUONc3#;J?$I|uaIu}cul|~w5Glmzi@ZrA vBZ{rh(8bkZ!(cEz7|*`WhvcW8q<3{SHIyk`wgdjHS_PY}-f@d24RZen2?GM@ literal 5104 zcmb7{`CpCc8^=#k6d^5WG3{DT+D>&^p6k}4PDIXGT~etenfW$oNElI+$kM)ikNFQgzdg_E_5S5^-`D&3T(PRW|DFwIj%J=4%p{7h zlcZ*zW&tkFW_+`h*p;8&mI(Pyf@OR&$x6k_<%uydDN;p>@~b0XAP_Fw#5Z<vEyL zg&uC7Sm)1Wl^oy6MIiVDyL>ME3FJGudpLih1)ohTLB6xI!0{99`nmKcknij!{4X-M z&!&|i-^I!0ztG%2m;MCuU0j9#r6&ArS_$$6jxPVD=JC1oC-BovQKq)Z#63J7uXsy5GJxp@I&J{)!?{s0-R}@1}BCr@!DxulB-DB zxF`n{W%h#S{0J-@J_F4wW)phI>CTEo)>37QN5r-8V5S!AIjaStM>BA3D@R#}B&^8P z9@`UoL4x6^@k-p~o?pyz68XT9d|de7!JFb^|8ayyw=uRfn%E-7woAkOtmMSy7-d`ZW!L z!=#0HptAr}bLP@zT^UubRIrjM5g1oq4xzR} zuF%beqO7^pwk4ig(iM#Fu^87LssK^O2d)UUA%3eF+Gcu@`rT;e*lvP~kU==3DaP_I zb76m-9|=w;k++td83`@XTS4IVZVbHppbZvJ_~`nEkuONuRW&QTXvt$Y8p3gVd;$1Y z1dzq}R4O~cc>gaH`78iv2QmcaT~vq^eg zPC7|)<~^u^fi^YpL{k;cZEffF2istCo)g)BCuROiO>w!}2;6=WhQ1lOkku%n_>zq@ zusM|-ZZ*b9EIzuzsv7p1c@D=Obd6wa^5Y>Wsj>tr6JXnFsr;gDB&NHRP$I zV29raU_eV1H0|x?ROe}e^EN9qZFZw!el)Y5Ivu0Z`=LQijFwutu%y<9I^APvr%cYq zj(byT=yVb4#bbbHG#Lcr9I+|Ki*~J*FxPozc=7HK47Y~k&%fk@pxB>GQsOD_p7I=g zC!>o`1rTB0Eq*3eC&a(o5OM0*>iAsyDvqWt|Hx)a&|el zv-Ee5D!BJR zma5v6fv86nV^kjfO3RB$My z3@UT_xDSUua&wJ@_na)p=TBLp>uXQCQxU~1zA(lQYx>~Iui?m5<^iwUmyY|yQKM=iGb%DfpGgFn zd(}bkzbVl8hXWq8VI=jIFr)LPxEO{&YkU|k)-Hssf!X9|E~oVUiA)wRz=L(Q@N@G7 zxKgbLp>8v9&esC!d>PFe7dv7_aSzNBh2fvk#ZWu!M`vfqsJ14N)p>q_Uh+GTwr~Q> zPS$`u|5)SWPH#$E9?kNO&BP5IJ&>|76h|dqgge2rY1-yh^e#1#`RA))AJ#y|X%)zm zzTs#bA2re)X>_=RL5KmGWDWp-ZYbWKQ2+vG<@~uOQ{XRYOnh+~zI}KT2A_9vi_<%} zukY~j&nIrQW0{1tXBePDHUz7i=Ha}tIbgt_Lm8!U)VyOYYf`#=cVG?hFZFQKTZg!k zUGCU;JAig%Nm%3&L%fhT2n$z5pkMtd=u`T7+w*uTcTQ#BuLUe~y8*K1E-vjKRp|Z8 z9nDQdRChwkTDBVCiQ++cksgLm#U~*AUJwPHiKAb?RNj|#e-ufpV5!eXZu6z_aHGu> zLwr05XG$4=w*~TAhrs@4rFZ|$hKLa#DqX6eX}1$t(>7zY_^}iUs?;DbbPSm8wZrCc zpo@oMSpF9_*!_aXoF0Z_ta%>X)0j=FS!-#kwVXB8x}r_=P1xwI1=Z`+z(!z>+ACZs z*eiy02Dszm#e<-l5su$%J_}7UKdQN#Of6k+V8tG%16*k_3kl7LLk zNLlnWSG=YCKZq`eVOV$(EPCos@1G>n^Ib}=YYdhit%7rPL)@%A_qlazt+7ewME5Vn zu*)CJu~U=Bw)uqP=2OREnQb8b(72v1pGsvVcc!ED)f(`B)yWxo-sjGpvcbeYcRKic z44Zt!9KV~a!fsWCV8yAUV3!(1*Z*BZe;!O@l}BdcP;nKA_rBvYeA>9PRRRPZ5j}qt z!~S?^fg9KGSaawC^wli@FHaFw-&sSAjtZt2Np-o2x>}>4ftQ=H;3e8Z|?Lm!qqgeSj z(~u?i!oa%)xIZHw@-=;F{-SuAaW;XS`c@zP`>#RS5e@jJQ4f+f@p0pHAdjUI#u=L6 z_Td3AeAL`@cKMUsG?ttmCoswJQE0fL76cxnA#(dWZpUnE6h3gE{xk{eoN0(NNBSYw zU5raQ@*s0TAcY=Ep|W-b^UpTKr>Ad0$v`is`>~6AaLyXdF1XPmg@n~=8{+xuL3naK z42_4fA!Bn8Z3$mZNA9FD-7Z^v5LgR}mqT3T3LYHsa={Nr11P^i!u(Z?@Ib);e6?&5 zjvhG);xZ9A?~kWaixg&I%Hj5hRUk6$;o2sThL$!Ld@@c%_e!N~BN?FjkpW0QyZ{d! zDS+myB0Bvvj=W_lOuWk%mk(FNy3BrVS)vw%eYC_vPj7F$qO zwlCKYA8A*CHID~YDTXk@S)=yP-Zc9AD0Xwlbkx!ufr*I`SQ303_Pv}%yKc#-WxkwM z4v)rB76ifi7|t`g8kq ztXy3SXH;|{wL>5NxM+iCg54NNzxbXZS{B$Q2SHCQRRC8rdJ{L#12jpz(qQ|BE z0o7nsrUw7B>f(NlQ2H&?nI`)y=g89-_m_=8t!5amRvd@j`AW~br_n{1G-h#L4Ud%8 zz~EnfT$!?)TY5U+vLl^fxsP#()T5RM0mbKsm^An8=d=-@WxeTkllV_#LmVC4vx z@{Kn9)NO{7CwkL{vM5$xHy!VK4?>}`?{2u84O`lMD1KKW1wWFpmIf7Ed>CQU2~Egr z(+9qxJw}h2MYBW_W_`mPFI0@czYb!wzMc!MfBRFax`L8i5?E=wA?{pP3-{_LgV_mP zhzPL2^;g~KSVc6eb9cs`!hYcCiE;9lvoOEGm*#FuBHb9J2S<&;XFuNtg~?>lzdaV_ zY_P(gUjeQ5ie_fn4tVBXKQu3kzU3 zfPE0XA}Q+)H^Slz!=RC(?5&S-A#r()FJ7?3*j@6`*kia@d+2b`$gm0By=~S#f#C6zVRlN_% zGNp`H`~{{i?S~&`EX3`%3n3y#L{FTPNN19qDW=ax)iXE2`IrW5I;jhSi{@zG>`o@i zy>aSGd)zzeBRD23K<|AeQ1D{_<@w5JaE@}{47bJhtG8hK6%B|IjDbd5D;y_eq}-pR_6sSyUCLYp zQ*gI>KU669#v)Z^|J0jLSNF!!gh{#=lj<<)iIfRfQ(}A0S}~5&!@I diff --git a/tests/data/save_policy_to_safetensors/aloha_act/grad_stats.safetensors b/tests/data/save_policy_to_safetensors/aloha_act/grad_stats.safetensors index 5188d8f428cdddf509da7f91f29950d9c80ad309..7dfbc3b35cc394e87cd77018d4f427928b9aa637 100644 GIT binary patch delta 110 zcmV-!0FnR5_W{WF0k9BvQS|F*yWiJSJZkQmIefA!Jk{F(xgs{;JMQ7>IWSb)JK|an zIg=~7I~5V|xpws>JPm4_xc^&QI}Rt-xeTT(Je8;;IfRu%JHo>fxUd`sI~9{*cNYn| QUje!FRLwdIWJV(JKkCj zIg=~7I}{P{xpws>JPvA`xc^&QI}az;xeTT(Je8;;IfRu%JHo>fxUd`sI}(#&cNYn` QUje!FRLwd=vw?R<1in}@TmS$7