Add aloha + improve readme

Cadene 2024-03-15 00:30:11 +00:00
parent 19730b3412
commit a311d38796
8 changed files with 115 additions and 37 deletions

README.md

@@ -146,6 +146,25 @@ Run tests
 DATA_DIR="tests/data" pytest -sx tests
 ```
+
+**Datasets**
+
+To add a PyTorch RL dataset to the hub, first log in with a write-access token generated from your [Hugging Face settings](https://huggingface.co/settings/tokens):
+```
+huggingface-cli login --token $HUGGINGFACE_TOKEN --add-to-git-credential
+```
+
+Then you can upload the dataset to the hub with:
+```
+HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli upload --repo-type dataset $HF_USER/$DATASET data/$DATASET
+```
+
+For instance, for [cadene/pusht](https://huggingface.co/datasets/cadene/pusht), we used:
+```
+HF_USER=cadene
+DATASET=pusht
+```
 
 ## Acknowledgment
 - Our Diffusion policy and Pusht environment are adapted from [Diffusion Policy](https://diffusion-policy.cs.columbia.edu/)
 - Our TDMPC policy and Simxarm environment are adapted from [FOWM](https://www.yunhaifeng.com/FOWM/)
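With the example values above substituted, the two commands from this new README section expand to:

```
huggingface-cli login --token $HUGGINGFACE_TOKEN --add-to-git-credential
HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli upload --repo-type dataset cadene/pusht data/pusht
```

`HF_HUB_ENABLE_HF_TRANSFER=1` routes the transfer through the Rust-based `hf_transfer` backend, which this commit adds as a dependency (see the poetry.lock and pyproject.toml changes below).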

lerobot/common/datasets/abstract.py

@@ -1,4 +1,3 @@
-import abc
 import logging
 from pathlib import Path
 from typing import Callable
@@ -7,8 +6,8 @@ import einops
 import torch
 import torchrl
 import tqdm
+from huggingface_hub import snapshot_download
 from tensordict import TensorDict
-from torchrl.data.datasets.utils import _get_root_dir
 from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
 from torchrl.data.replay_buffers.samplers import SliceSampler
 from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
@@ -33,11 +32,8 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
     ):
         self.dataset_id = dataset_id
         self.shuffle = shuffle
-        self.root = _get_root_dir(self.dataset_id) if root is None else root
-        self.root = Path(self.root)
-        self.data_dir = self.root / self.dataset_id
-        storage = self._download_or_load_storage()
+        self.root = root
+        storage = self._download_or_load_dataset()
 
         super().__init__(
             storage=storage,
@@ -98,19 +94,12 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
         torch.save(stats, stats_path)
         return stats
 
-    @abc.abstractmethod
-    def _download_and_preproc(self) -> torch.StorageBase:
-        raise NotImplementedError()
-
-    def _download_or_load_storage(self):
-        if not self._is_downloaded():
-            storage = self._download_and_preproc()
+    def _download_or_load_dataset(self) -> torch.StorageBase:
+        if self.root is None:
+            data_dir = snapshot_download(repo_id=f"cadene/{self.dataset_id}", repo_type="dataset")
         else:
-            storage = TensorStorage(TensorDict.load_memmap(self.data_dir))
-        return storage
-
-    def _is_downloaded(self) -> bool:
-        return self.data_dir.is_dir()
+            data_dir = Path(self.root) / self.dataset_id
+        return TensorStorage(TensorDict.load_memmap(data_dir))
 
     def _compute_stats(self, num_batch=100, batch_size=32):
         rb = TensorDictReplayBuffer(
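The net effect: with `root=None`, a dataset is fetched from the hub repo `cadene/<dataset_id>` into the local Hugging Face cache; otherwise it is memory-mapped straight from `<root>/<dataset_id>`. A standalone sketch of the same loading logic (the helper name `load_dataset_storage` is ours, for illustration):

```python
from pathlib import Path

from huggingface_hub import snapshot_download
from tensordict import TensorDict
from torchrl.data.replay_buffers.storages import TensorStorage


def load_dataset_storage(dataset_id: str, root: str | None = None) -> TensorStorage:
    if root is None:
        # Download the hub dataset (or reuse an already cached snapshot).
        data_dir = snapshot_download(repo_id=f"cadene/{dataset_id}", repo_type="dataset")
    else:
        # A local root was given (e.g. DATA_DIR="tests/data"): load from disk.
        data_dir = Path(root) / dataset_id
    # On disk, a dataset is a memory-mapped TensorDict; wrap it as a storage.
    return TensorStorage(TensorDict.load_memmap(data_dir))
```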

lerobot/common/datasets/aloha.py

@@ -124,8 +124,8 @@ class AlohaExperienceReplay(AbstractExperienceReplay):
     def image_keys(self) -> list:
         return [("observation", "image", cam) for cam in CAMERAS[self.dataset_id]]
 
-    def _download_and_preproc(self):
-        raw_dir = self.data_dir.parent / f"{self.data_dir.name}_raw"
+    def _download_and_preproc_obsolete(self, data_dir="data"):
+        raw_dir = Path(data_dir) / f"{self.dataset_id}_raw"
         if not raw_dir.is_dir():
             download(raw_dir, self.dataset_id)
@@ -174,7 +174,9 @@ class AlohaExperienceReplay(AbstractExperienceReplay):
             if ep_id == 0:
                 # hack to initialize tensordict data structure to store episodes
-                td_data = ep_td[0].expand(total_num_frames).memmap_like(self.data_dir)
+                td_data = (
+                    ep_td[0].expand(total_num_frames).memmap_like(Path(self.root) / f"{self.dataset_id}")
+                )
 
             td_data[idxtd : idxtd + len(ep_td)] = ep_td
             idxtd = idxtd + len(ep_td)
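This `memmap_like` idiom (repeated in the pusht and simxarm hunks below) preallocates a single disk-backed TensorDict for the whole dataset, then fills it episode by episode. A self-contained sketch with toy shapes and an illustrative output path:

```python
import torch
from tensordict import TensorDict

# Toy stand-ins for the per-episode TensorDicts built during preprocessing.
episodes = [
    TensorDict({"obs": torch.randn(50, 4), "action": torch.randn(50, 2)}, batch_size=[50])
    for _ in range(3)
]
total_num_frames = sum(len(ep) for ep in episodes)

idxtd = 0
for ep_id, ep_td in enumerate(episodes):
    if ep_id == 0:
        # Use the first frame as a template: expand it to the full dataset
        # length, then materialize a memory-mapped copy of it on disk.
        td_data = ep_td[0].expand(total_num_frames).memmap_like("data/my_dataset")
    td_data[idxtd : idxtd + len(ep_td)] = ep_td
    idxtd += len(ep_td)
```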

lerobot/common/datasets/factory.py

@@ -1,13 +1,13 @@
 import logging
 import os
-from pathlib import Path
 
 import torch
 from torchrl.data.replay_buffers import PrioritizedSliceSampler, SliceSampler
 
 from lerobot.common.envs.transforms import NormalizeTransform, Prod
 
-DATA_DIR = Path(os.environ.get("DATA_DIR", "data"))
+# used for unit tests
+DATA_DIR = os.environ.get("DATA_DIR", None)
 
 
 def make_offline_buffer(
@@ -77,9 +77,9 @@ def make_offline_buffer(
     offline_buffer = clsfunc(
         dataset_id=dataset_id,
+        root=DATA_DIR,
         sampler=sampler,
         batch_size=batch_size,
-        root=DATA_DIR,
         pin_memory=pin_memory,
         prefetch=prefetch if isinstance(prefetch, int) else None,
     )
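With `DATA_DIR` unset, `root=None` flows into the dataset classes and triggers the hub-download path added in abstract.py; the unit tests pin it to the repository's fixture data instead, as in the README:

```
DATA_DIR="tests/data" pytest -sx tests
```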

lerobot/common/datasets/pusht.py

@@ -8,7 +8,6 @@ import pymunk
 import torch
 import torchrl
 import tqdm
-from huggingface_hub import snapshot_download
 from tensordict import TensorDict
 from torchrl.data.replay_buffers.samplers import SliceSampler
 from torchrl.data.replay_buffers.storages import TensorStorage
@@ -112,12 +111,8 @@ class PushtExperienceReplay(AbstractExperienceReplay):
             transform=transform,
         )
 
-    def _download_and_preproc(self):
-        snapshot_download(repo_id="cadene/pusht", local_dir=self.data_dir)
-        return TensorStorage(TensorDict.load_memmap(self.data_dir))
-
     def _download_and_preproc_obsolete(self):
-        raw_dir = self.data_dir.parent / f"{self.data_dir.name}_raw"
+        raw_dir = Path(self.root) / f"{self.dataset_id}_raw"
         zarr_path = (raw_dir / PUSHT_ZARR).resolve()
         if not zarr_path.is_dir():
             raw_dir.mkdir(parents=True, exist_ok=True)
@@ -213,7 +208,7 @@ class PushtExperienceReplay(AbstractExperienceReplay):
             if episode_id == 0:
                 # hack to initialize tensordict data structure to store episodes
-                td_data = ep_td[0].expand(total_frames).memmap_like(self.data_dir)
+                td_data = ep_td[0].expand(total_frames).memmap_like(Path(self.root) / f"{self.dataset_id}")
 
             td_data[idxtd : idxtd + len(ep_td)] = ep_td

lerobot/common/datasets/simxarm.py

@@ -64,11 +64,11 @@ class SimxarmExperienceReplay(AbstractExperienceReplay):
             transform=transform,
         )
 
-    def _download_and_preproc(self):
+    def _download_and_preproc_obsolete(self):
         # TODO(rcadene): finish download
         download()
 
-        dataset_path = self.data_dir / "buffer.pkl"
+        dataset_path = Path(self.root) / "data" / "buffer.pkl"
         print(f"Using offline dataset '{dataset_path}'")
         with open(dataset_path, "rb") as f:
             dataset_dict = pickle.load(f)
@@ -110,7 +110,7 @@ class SimxarmExperienceReplay(AbstractExperienceReplay):
             if episode_id == 0:
                 # hack to initialize tensordict data structure to store episodes
-                td_data = episode[0].expand(total_frames).memmap_like(self.data_dir)
+                td_data = episode[0].expand(total_frames).memmap_like(Path(self.root) / f"{self.dataset_id}")
 
             td_data[idx0:idx1] = episode

poetry.lock (generated)

@@ -838,6 +838,78 @@ files = [
 [package.dependencies]
 numpy = ">=1.17.3"
 
+[[package]]
+name = "hf-transfer"
+version = "0.1.6"
+description = ""
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "hf_transfer-0.1.6-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:6fd3d61f9229d27def007e53540412507b74ac2fdb1a29985ae0b6a5137749a2"},
+    {file = "hf_transfer-0.1.6-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:b043bb78df1225de043eb041de9d97783fcca14a0bdc1b1d560fc172fc21b648"},
+    {file = "hf_transfer-0.1.6-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7db60dd18eae4fa6ea157235fb82196cde5313995b396d1b591aad3b790a7f8f"},
+    {file = "hf_transfer-0.1.6-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:30d31dbab9b5a558cce407b8728e39d87d7af1ef8745ddb90187e9ae0b9e1e90"},
+    {file = "hf_transfer-0.1.6-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f6b368bddd757efc7af3126ba81f9ac8f9435e2cc00902cb3d64f2be28d8f719"},
+    {file = "hf_transfer-0.1.6-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:aa2086d8aefaaa3e144e167324574882004c0cec49bf2d0638ec4b74732d8da0"},
+    {file = "hf_transfer-0.1.6-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:45d8985a0940bfe1535cb4ca781f5c11e47c83798ef3373ee1f5d57bbe527a9c"},
+    {file = "hf_transfer-0.1.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2f42b89735f1cde22f2a795d1f0915741023235666be7de45879e533c7d6010c"},
+    {file = "hf_transfer-0.1.6-cp310-none-win32.whl", hash = "sha256:2d2c4c4613f3ad45b6ce6291e347b2d3ba1b86816635681436567e461cb3c961"},
+    {file = "hf_transfer-0.1.6-cp310-none-win_amd64.whl", hash = "sha256:78b0eed8d8dce60168a46e584b9742b816af127d7e410a713e12c31249195342"},
+    {file = "hf_transfer-0.1.6-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:f1d8c172153f9a6cdaecf137612c42796076f61f6bea1072c90ac2e17c1ab6fa"},
+    {file = "hf_transfer-0.1.6-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2c601996351f90c514a75a0eeb02bf700b1ad1db2d946cbfe4b60b79e29f0b2f"},
+    {file = "hf_transfer-0.1.6-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8e585c808405557d3f5488f385706abb696997bbae262ea04520757e30836d9d"},
+    {file = "hf_transfer-0.1.6-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ec51af1e8cf4268c268bd88932ade3d7ca895a3c661b42493503f02610ae906b"},
+    {file = "hf_transfer-0.1.6-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d106fdf996332f6df3ed3fab6d6332df82e8c1fb4b20fd81a491ca4d2ab5616a"},
+    {file = "hf_transfer-0.1.6-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e9c2ee9e9fde5a0319cc0e8ddfea10897482bc06d5709b10a238f1bc2ebcbc0b"},
+    {file = "hf_transfer-0.1.6-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f394ea32bc7802b061e549d3133efc523b4ae4fd19bf4b74b183ca6066eef94e"},
+    {file = "hf_transfer-0.1.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4282f09902114cd67fca98a1a1bad569a44521a8395fedf327e966714f68b977"},
+    {file = "hf_transfer-0.1.6-cp311-none-win32.whl", hash = "sha256:276dbf307d5ab6f1bcbf57b5918bfcf9c59d6848ccb28242349e1bb5985f983b"},
+    {file = "hf_transfer-0.1.6-cp311-none-win_amd64.whl", hash = "sha256:fa475175c51451186bea804471995fa8e7b2a48a61dcca55534911dc25955527"},
+    {file = "hf_transfer-0.1.6-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:23d157a67acfa00007799323a1c441b2bbacc7dee625b016b7946fe0e25e6c89"},
+    {file = "hf_transfer-0.1.6-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6067342a2864b988f861cd2d31bd78eb1e84d153a3f6df38485b6696d9ad3013"},
+    {file = "hf_transfer-0.1.6-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:91cfcb3070e205b58fa8dc8bcb6a62ccc40913fcdb9cd1ff7c364c8e3aa85345"},
+    {file = "hf_transfer-0.1.6-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:eb76064ac5165d5eeaaf8d0903e8bf55477221ecc2a4a4d69f0baca065ab905b"},
+    {file = "hf_transfer-0.1.6-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9dabd3a177d83028f164984cf4dd859f77ec1e20c97a6f307ff8fcada0785ef1"},
+    {file = "hf_transfer-0.1.6-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d0bf4254e44f64a26e0a5b73b5d7e8d91bb36870718fb4f8e126ec943ff4c805"},
+    {file = "hf_transfer-0.1.6-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5d32c1b106f38f336ceb21531f4db9b57d777b9a33017dafdb6a5316388ebe50"},
+    {file = "hf_transfer-0.1.6-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ff05aba3c83921e5c7635ba9f07c693cc893350c447644824043aeac27b285f5"},
+    {file = "hf_transfer-0.1.6-cp312-none-win32.whl", hash = "sha256:051ef0c55607652cb5974f59638da035773254b9a07d7ee5b574fe062de4c9d1"},
+    {file = "hf_transfer-0.1.6-cp312-none-win_amd64.whl", hash = "sha256:716fb5c574fcbdd8092ce73f9b6c66f42e3544337490f77c60ec07df02bd081b"},
+    {file = "hf_transfer-0.1.6-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6c0c981134a55965e279cb7be778c1ccaf93f902fc9ebe31da4f30caf824cc4d"},
+    {file = "hf_transfer-0.1.6-cp37-cp37m-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:1ef1f145f04c5b573915bcb1eb5db4039c74f6b46fce73fc473c4287e613b623"},
+    {file = "hf_transfer-0.1.6-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d0a7609b004db3347dbb7796df45403eceb171238210d054d93897d6d84c63a4"},
+    {file = "hf_transfer-0.1.6-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:60f0864bf5996773dbd5f8ae4d1649041f773fe9d5769f4c0eeb5553100acef3"},
+    {file = "hf_transfer-0.1.6-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5d01e55d630ffe70a4f5d0ed576a04c6a48d7c65ca9a7d18f2fca385f20685a9"},
+    {file = "hf_transfer-0.1.6-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d855946c5062b665190de15b2bdbd4c8eddfee35350bfb7564592e23d36fbbd3"},
+    {file = "hf_transfer-0.1.6-cp37-none-win32.whl", hash = "sha256:fd40b2409cfaf3e8aba20169ee09552f69140e029adeec261b988903ff0c8f6f"},
+    {file = "hf_transfer-0.1.6-cp37-none-win_amd64.whl", hash = "sha256:0e0eba49d46d3b5481919aea0794aec625fbc6ecdf13fe7e0e9f3fc5d5ad5971"},
+    {file = "hf_transfer-0.1.6-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7e669fecb29fc454449739f9f53ed9253197e7c19e6a6eaa0f08334207af4287"},
+    {file = "hf_transfer-0.1.6-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:89f701802892e5eb84f89f402686861f87dc227d6082b05f4e9d9b4e8015a3c3"},
+    {file = "hf_transfer-0.1.6-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b6f2b0c8b95b01409275d789a9b74d5f2e146346f985d384bf50ec727caf1ccc"},
+    {file = "hf_transfer-0.1.6-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:aa855a2fa262792a230f9efcdb5da6d431b747d1861d2a69fe7834b19aea077e"},
+    {file = "hf_transfer-0.1.6-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4aa8ca349afb2f0713475426946261eb2035e4efb50ebd2c1d5ad04f395f4217"},
+    {file = "hf_transfer-0.1.6-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:01255f043996bc7d1bae62d8afc5033a90c7e36ce308b988eeb84afe0a69562f"},
+    {file = "hf_transfer-0.1.6-cp38-none-win32.whl", hash = "sha256:60b1db183e8a7540cd4f8b2160ff4de55f77cb0c3fc6a10be1e7c30eb1b2bdeb"},
+    {file = "hf_transfer-0.1.6-cp38-none-win_amd64.whl", hash = "sha256:fb8be3cba6aaa50ab2e9dffbd25c8eb2046785eeff642cf0cdd0dd9ae6be3539"},
+    {file = "hf_transfer-0.1.6-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d09af35e3e3f09b664e6429e9a0dc200f29c5bdfd88bdd9666de51183b1fe202"},
+    {file = "hf_transfer-0.1.6-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:a4505bd707cc14d85c800f961fad8ca76f804a8ad22fbb7b1a217d8d0c15e6a5"},
+    {file = "hf_transfer-0.1.6-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2c453fd8b0be9740faa23cecd1f28ee9ead7d900cefa64ff836960c503a744c9"},
+    {file = "hf_transfer-0.1.6-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:13cb8884e718a78c3b81a8cdec9c7ac196dd42961fce55c3ccff3dd783e5ad7a"},
+    {file = "hf_transfer-0.1.6-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:39cd39df171a2b5404de69c4e6cd14eee47f6fe91c1692f939bfb9e59a0110d8"},
+    {file = "hf_transfer-0.1.6-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8ff0629ee9f98df57a783599602eb498f9ec3619dc69348b12e4d9d754abf0e9"},
+    {file = "hf_transfer-0.1.6-cp39-none-win32.whl", hash = "sha256:164a6ce445eb0cc7c645f5b6e1042c003d33292520c90052b6325f30c98e4c5f"},
+    {file = "hf_transfer-0.1.6-cp39-none-win_amd64.whl", hash = "sha256:11b8b4b73bf455f13218c5f827698a30ae10998ca31b8264b51052868c7a9f11"},
+    {file = "hf_transfer-0.1.6-pp310-pypy310_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:16957ba057376a99ea361074ce1094f61b58e769defa6be2422ae59c0b6a6530"},
+    {file = "hf_transfer-0.1.6-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7db952112e3b8ee1a5cbf500d2443e9ce4fb893281c5310a3e31469898628005"},
+    {file = "hf_transfer-0.1.6-pp37-pypy37_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d39d826a7344f5e39f438d62632acd00467aa54a083b66496f61ef67a9885a56"},
+    {file = "hf_transfer-0.1.6-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a4e2653fbfa92e7651db73d99b697c8684e7345c479bd6857da80bed6138abb2"},
+    {file = "hf_transfer-0.1.6-pp38-pypy38_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:144277e6a86add10b90ec3b583253aec777130312256bfc8d5ade5377e253807"},
+    {file = "hf_transfer-0.1.6-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3bb53bcd16365313b2aa0dbdc28206f577d70770f31249cdabc387ac5841edcc"},
+    {file = "hf_transfer-0.1.6-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:990d73a5a68d8261980f146c51f4c5f9995314011cb225222021ad7c39f3af2d"},
+    {file = "hf_transfer-0.1.6-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:652406037029ab9b4097b4c5f29321bad5f64c2b46fbff142509d918aec87c29"},
+    {file = "hf_transfer-0.1.6.tar.gz", hash = "sha256:deb505a7d417d7055fd7b3549eadb91dfe782941261f3344025c486c16d1d2f9"},
+]
+
 [[package]]
 name = "huggingface-hub"
 version = "0.21.4"
@@ -852,6 +924,7 @@ files = [
 [package.dependencies]
 filelock = "*"
 fsspec = ">=2023.5.0"
+hf-transfer = {version = ">=0.1.4", optional = true, markers = "extra == \"hf_transfer\""}
 packaging = ">=20.9"
 pyyaml = ">=5.1"
 requests = "*"
@@ -3254,4 +3327,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.10"
-content-hash = "0794a87fd309dffa0ad2982b6902bed7f35ae9e2a82433420516798da04c7197"
+content-hash = "ee86b84a795e6a3e9c2d79f244a87b55589adbe46d549ac38adf48be27c04cf9"

pyproject.toml

@@ -50,7 +50,7 @@ diffusers = "^0.26.3"
 torchvision = "^0.17.1"
 h5py = "^3.10.0"
 dm-control = "1.0.14"
-huggingface-hub = "^0.21.4"
+huggingface-hub = {extras = ["hf-transfer"], version = "^0.21.4"}
 
 [tool.poetry.group.dev.dependencies]
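Outside Poetry, the rough pip equivalent of this dependency change would be the following (note that `hf_transfer` remains opt-in at runtime via the `HF_HUB_ENABLE_HF_TRANSFER=1` environment variable used in the README):

```
pip install "huggingface-hub[hf-transfer]>=0.21.4,<0.22"
```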