Merge remote-tracking branch 'upstream/main' into policy_compatibility
This commit is contained in:
commit
4b4f922fa7
|
@ -7,6 +7,11 @@ ARG DEBIAN_FRONTEND=noninteractive
|
|||
# Install apt dependencies
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
build-essential cmake \
|
||||
git git-lfs openssh-client \
|
||||
nano vim \
|
||||
htop atop nvtop \
|
||||
sed gawk grep curl wget \
|
||||
tcpdump sysstat screen \
|
||||
libglib2.0-0 libgl1-mesa-glx libegl1-mesa \
|
||||
python${PYTHON_VERSION} python${PYTHON_VERSION}-venv \
|
||||
&& apt-get clean && rm -rf /var/lib/apt/lists/*
|
||||
|
@ -18,7 +23,8 @@ ENV PATH="/opt/venv/bin:$PATH"
|
|||
RUN echo "source /opt/venv/bin/activate" >> /root/.bashrc
|
||||
|
||||
# Install LeRobot
|
||||
COPY . /lerobot
|
||||
RUN git lfs install
|
||||
RUN git clone https://github.com/huggingface/lerobot.git
|
||||
WORKDIR /lerobot
|
||||
RUN pip install --upgrade --no-cache-dir pip
|
||||
RUN pip install --no-cache-dir ".[test, aloha, xarm, pusht]"
|
||||
|
|
|
@ -51,6 +51,7 @@ class DiffusionConfig:
|
|||
use_film_scale_modulation: FiLM (https://arxiv.org/abs/1709.07871) is used for the Unet conditioning.
|
||||
Bias modulation is used be default, while this parameter indicates whether to also use scale
|
||||
modulation.
|
||||
noise_scheduler_type: Name of the noise scheduler to use. Supported options: ["DDPM", "DDIM"].
|
||||
num_train_timesteps: Number of diffusion steps for the forward diffusion schedule.
|
||||
beta_schedule: Name of the diffusion beta schedule as per DDPMScheduler from Hugging Face diffusers.
|
||||
beta_start: Beta value for the first forward-diffusion step.
|
||||
|
@ -110,6 +111,7 @@ class DiffusionConfig:
|
|||
diffusion_step_embed_dim: int = 128
|
||||
use_film_scale_modulation: bool = True
|
||||
# Noise scheduler.
|
||||
noise_scheduler_type: str = "DDPM"
|
||||
num_train_timesteps: int = 100
|
||||
beta_schedule: str = "squaredcos_cap_v2"
|
||||
beta_start: float = 0.0001
|
||||
|
@ -151,3 +153,9 @@ class DiffusionConfig:
|
|||
raise ValueError(
|
||||
f"`prediction_type` must be one of {supported_prediction_types}. Got {self.prediction_type}."
|
||||
)
|
||||
supported_noise_schedulers = ["DDPM", "DDIM"]
|
||||
if self.noise_scheduler_type not in supported_noise_schedulers:
|
||||
raise ValueError(
|
||||
f"`noise_scheduler_type` must be one of {supported_noise_schedulers}. "
|
||||
f"Got {self.noise_scheduler_type}."
|
||||
)
|
||||
|
|
|
@ -13,6 +13,7 @@ import einops
|
|||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
import torchvision
|
||||
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
|
||||
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
||||
from huggingface_hub import PyTorchModelHubMixin
|
||||
from robomimic.models.base_nets import SpatialSoftmax
|
||||
|
@ -144,6 +145,19 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
|
|||
return {"loss": loss}
|
||||
|
||||
|
||||
def _make_noise_scheduler(name: str, **kwargs: dict) -> DDPMScheduler | DDIMScheduler:
|
||||
"""
|
||||
Factory for noise scheduler instances of the requested type. All kwargs are passed
|
||||
to the scheduler.
|
||||
"""
|
||||
if name == "DDPM":
|
||||
return DDPMScheduler(**kwargs)
|
||||
elif name == "DDIM":
|
||||
return DDIMScheduler(**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Unsupported noise scheduler type {name}")
|
||||
|
||||
|
||||
class DiffusionModel(nn.Module):
|
||||
def __init__(self, config: DiffusionConfig):
|
||||
super().__init__()
|
||||
|
@ -156,12 +170,12 @@ class DiffusionModel(nn.Module):
|
|||
* config.n_obs_steps,
|
||||
)
|
||||
|
||||
self.noise_scheduler = DDPMScheduler(
|
||||
self.noise_scheduler = _make_noise_scheduler(
|
||||
config.noise_scheduler_type,
|
||||
num_train_timesteps=config.num_train_timesteps,
|
||||
beta_start=config.beta_start,
|
||||
beta_end=config.beta_end,
|
||||
beta_schedule=config.beta_schedule,
|
||||
variance_type="fixed_small",
|
||||
clip_sample=config.clip_sample,
|
||||
clip_sample_range=config.clip_sample_range,
|
||||
prediction_type=config.prediction_type,
|
||||
|
@ -332,15 +346,16 @@ class DiffusionRgbEncoder(nn.Module):
|
|||
|
||||
# Set up pooling and final layers.
|
||||
# Use a dry run to get the feature map shape.
|
||||
# The dummy input should take the number of image channels from `config.input_shapes` and it should
|
||||
# use the height and width from `config.crop_shape`.
|
||||
image_keys = {k for k in config.input_shapes if k.startswith("observation.image")}
|
||||
assert len(image_keys) == 1
|
||||
image_key = next(iter(image_keys))
|
||||
dummy_input = torch.zeros(size=(1, config.input_shapes[image_key][0], *config.crop_shape))
|
||||
with torch.inference_mode():
|
||||
feat_map_shape = tuple(
|
||||
self.backbone(
|
||||
torch.zeros(size=(1, config.input_shapes[next(iter(image_keys))][0], *config.crop_shape))
|
||||
).shape[1:]
|
||||
)
|
||||
self.pool = SpatialSoftmax(feat_map_shape, num_kp=config.spatial_softmax_num_keypoints)
|
||||
dummy_feature_map = self.backbone(dummy_input)
|
||||
feature_map_shape = tuple(dummy_feature_map.shape[1:])
|
||||
self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints)
|
||||
self.feature_dim = config.spatial_softmax_num_keypoints * 2
|
||||
self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim)
|
||||
self.relu = nn.ReLU()
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import inspect
|
||||
import logging
|
||||
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
|
||||
|
@ -8,9 +9,10 @@ from lerobot.common.utils.utils import get_safe_torch_device
|
|||
|
||||
def _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg):
|
||||
expected_kwargs = set(inspect.signature(policy_cfg_class).parameters)
|
||||
assert set(hydra_cfg.policy).issuperset(
|
||||
expected_kwargs
|
||||
), f"Hydra config is missing arguments: {set(expected_kwargs).difference(hydra_cfg.policy)}"
|
||||
if not set(hydra_cfg.policy).issuperset(expected_kwargs):
|
||||
logging.warning(
|
||||
f"Hydra config is missing arguments: {set(expected_kwargs).difference(hydra_cfg.policy)}"
|
||||
)
|
||||
policy_cfg = policy_cfg_class(
|
||||
**{
|
||||
k: v
|
||||
|
@ -62,11 +64,18 @@ def make_policy(
|
|||
|
||||
policy_cls, policy_cfg_class = get_policy_and_config_classes(hydra_cfg.policy.name)
|
||||
|
||||
policy_cfg = _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg)
|
||||
if pretrained_policy_name_or_path is None:
|
||||
policy_cfg = _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg)
|
||||
# Make a fresh policy.
|
||||
policy = policy_cls(policy_cfg, dataset_stats)
|
||||
else:
|
||||
policy = policy_cls.from_pretrained(pretrained_policy_name_or_path)
|
||||
# Load a pretrained policy and override the config if needed (for example, if there are inference-time
|
||||
# hyperparameters that we want to vary).
|
||||
# TODO(alexander-soare): This hack makes use of huggingface_hub's tooling to load the policy with, pretrained
|
||||
# weights which are then loaded into a fresh policy with the desired config. This PR in huggingface_hub should
|
||||
# make it possible to avoid the hack: https://github.com/huggingface/huggingface_hub/pull/2274.
|
||||
policy = policy_cls(policy_cfg)
|
||||
policy.load_state_dict(policy_cls.from_pretrained(pretrained_policy_name_or_path).state_dict())
|
||||
|
||||
policy.to(get_safe_torch_device(hydra_cfg.device))
|
||||
|
||||
|
|
|
@ -47,7 +47,7 @@ class TDMPCConfig:
|
|||
elite_weighting_temperature: The temperature to use for softmax weighting (by trajectory value) of the
|
||||
elites, when updating the gaussian parameters for CEM.
|
||||
gaussian_mean_momentum: Momentum (α) used for EMA updates of the mean parameter μ of the gaussian
|
||||
paramters optimized in CEM. Updates are calculated as μ⁻ ← αμ⁻ + (1-α)μ.
|
||||
parameters optimized in CEM. Updates are calculated as μ⁻ ← αμ⁻ + (1-α)μ.
|
||||
max_random_shift_ratio: Maximum random shift (as a proportion of the image size) to apply to the
|
||||
image(s) (in units of pixels) for training-time augmentation. If set to 0, no such augmentation
|
||||
is applied. Note that the input images are assumed to be square for this augmentation.
|
||||
|
|
|
@ -85,6 +85,7 @@ policy:
|
|||
diffusion_step_embed_dim: 128
|
||||
use_film_scale_modulation: True
|
||||
# Noise scheduler.
|
||||
noise_scheduler_type: DDPM
|
||||
num_train_timesteps: 100
|
||||
beta_schedule: squaredcos_cap_v2
|
||||
beta_start: 0.0001
|
||||
|
|
|
@ -131,17 +131,6 @@ files = [
|
|||
{file = "antlr4-python3-runtime-4.9.3.tar.gz", hash = "sha256:f224469b4168294902bb1efa80a8bf7855f24c99aef99cbefc1bcd3cce77881b"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "appdirs"
|
||||
version = "1.4.4"
|
||||
description = "A small Python module for determining appropriate platform-specific dirs, e.g. a \"user data dir\"."
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "appdirs-1.4.4-py2.py3-none-any.whl", hash = "sha256:a841dacd6b99318a741b166adb07e19ee71a274450e68237b4650ca1055ab128"},
|
||||
{file = "appdirs-1.4.4.tar.gz", hash = "sha256:7d5d0167b2b1ba821647616af46a749d1c653740dd0d2415100fe26e27afdf41"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "asciitree"
|
||||
version = "0.3.3"
|
||||
|
@ -1108,67 +1097,67 @@ protobuf = ["grpcio-tools (>=1.63.0)"]
|
|||
|
||||
[[package]]
|
||||
name = "gym-aloha"
|
||||
version = "0.1.0"
|
||||
version = "0.1.1"
|
||||
description = "A gym environment for ALOHA"
|
||||
optional = true
|
||||
python-versions = "<4.0,>=3.10"
|
||||
files = [
|
||||
{file = "gym_aloha-0.1.0-py3-none-any.whl", hash = "sha256:62e36eeb09284422cbb7baca0292c6f65e38ec8774bf9b0bf7159ad5990cf29a"},
|
||||
{file = "gym_aloha-0.1.0.tar.gz", hash = "sha256:bab332f469ba5ffe655fc3e9647aead05d2cb3b950dfb1f299b9539b3857ad7e"},
|
||||
{file = "gym_aloha-0.1.1-py3-none-any.whl", hash = "sha256:2698037246dbb106828f0bc229b61007b0a21d5967c72cc373f7bc1083203584"},
|
||||
{file = "gym_aloha-0.1.1.tar.gz", hash = "sha256:614ae1cf116323e7b5ae2f0e9bd282c4f052aee15e839e5587ddce45995359bc"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
dm-control = "1.0.14"
|
||||
gymnasium = ">=0.29.1,<0.30.0"
|
||||
imageio = {version = ">=2.34.0,<3.0.0", extras = ["ffmpeg"]}
|
||||
dm-control = ">=1.0.14"
|
||||
gymnasium = ">=0.29.1"
|
||||
imageio = {version = ">=2.34.0", extras = ["ffmpeg"]}
|
||||
mujoco = ">=2.3.7,<3.0.0"
|
||||
|
||||
[package.extras]
|
||||
dev = ["debugpy (>=1.8.1,<2.0.0)", "pre-commit (>=3.7.0,<4.0.0)"]
|
||||
test = ["pytest (>=8.1.0,<9.0.0)", "pytest-cov (>=5.0.0,<6.0.0)"]
|
||||
dev = ["debugpy (>=1.8.1)", "pre-commit (>=3.7.0)"]
|
||||
test = ["pytest (>=8.1.0)", "pytest-cov (>=5.0.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "gym-pusht"
|
||||
version = "0.1.1"
|
||||
version = "0.1.3"
|
||||
description = "A gymnasium environment for PushT."
|
||||
optional = true
|
||||
python-versions = "<4.0,>=3.10"
|
||||
files = [
|
||||
{file = "gym_pusht-0.1.1-py3-none-any.whl", hash = "sha256:dcf8644713db48286e907aabb11e005b0592632e323baa40d1a4f2dfbbc76c3d"},
|
||||
{file = "gym_pusht-0.1.1.tar.gz", hash = "sha256:0d1c9ffd4ad0e2411efcc724003a365a853f20b6d596980c113e7ec181ac021f"},
|
||||
{file = "gym_pusht-0.1.3-py3-none-any.whl", hash = "sha256:feeb02493a03d1aacc45d43d6397962c50ed779ab7e4019d73af11d2f0b3831b"},
|
||||
{file = "gym_pusht-0.1.3.tar.gz", hash = "sha256:c8e9a5256035ba49841ebbc7c32a06c4fa2daa52f5fad80da941b607c4553e28"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
gymnasium = ">=0.29.1,<0.30.0"
|
||||
opencv-python = ">=4.9.0.80,<5.0.0.0"
|
||||
pygame = ">=2.5.2,<3.0.0"
|
||||
pymunk = ">=6.6.0,<7.0.0"
|
||||
gymnasium = ">=0.29.1"
|
||||
opencv-python = ">=4.9.0"
|
||||
pygame = ">=2.5.2"
|
||||
pymunk = ">=6.6.0"
|
||||
scikit-image = ">=0.22.0"
|
||||
shapely = ">=2.0.3,<3.0.0"
|
||||
shapely = ">=2.0.3"
|
||||
|
||||
[package.extras]
|
||||
dev = ["debugpy (>=1.8.1,<2.0.0)", "pre-commit (>=3.7.0,<4.0.0)"]
|
||||
test = ["pytest (>=8.1.0,<9.0.0)", "pytest-cov (>=5.0.0,<6.0.0)"]
|
||||
dev = ["debugpy (>=1.8.1)", "pre-commit (>=3.7.0)"]
|
||||
test = ["pytest (>=8.1.0)", "pytest-cov (>=5.0.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "gym-xarm"
|
||||
version = "0.1.0"
|
||||
version = "0.1.1"
|
||||
description = "A gym environment for xArm"
|
||||
optional = true
|
||||
python-versions = "<4.0,>=3.10"
|
||||
files = [
|
||||
{file = "gym_xarm-0.1.0-py3-none-any.whl", hash = "sha256:d10ac19a59d302201a9b8bd913530211b1058467b787ad91a657907e40cdbc13"},
|
||||
{file = "gym_xarm-0.1.0.tar.gz", hash = "sha256:fc05f9d02af1f0205275311669dc191ce431be484e221a96401eb544764eb986"},
|
||||
{file = "gym_xarm-0.1.1-py3-none-any.whl", hash = "sha256:3bd7e3c1c5521ba80a56536f01a5e11321580704d72160355ce47a828a8808ad"},
|
||||
{file = "gym_xarm-0.1.1.tar.gz", hash = "sha256:e455524561b02d06b92a4f7d524f448d84a7484d9a2dbc78600e3c66240e0fb7"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
gymnasium = ">=0.29.1,<0.30.0"
|
||||
gymnasium-robotics = ">=1.2.4,<2.0.0"
|
||||
gymnasium = ">=0.29.1"
|
||||
gymnasium-robotics = ">=1.2.4"
|
||||
mujoco = ">=2.3.7,<3.0.0"
|
||||
|
||||
[package.extras]
|
||||
dev = ["debugpy (>=1.8.1,<2.0.0)", "pre-commit (>=3.7.0,<4.0.0)"]
|
||||
test = ["pytest (>=8.1.0,<9.0.0)", "pytest-cov (>=5.0.0,<6.0.0)"]
|
||||
dev = ["debugpy (>=1.8.1)", "pre-commit (>=3.7.0)"]
|
||||
test = ["pytest (>=8.1.0)", "pytest-cov (>=5.0.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "gymnasium"
|
||||
|
@ -1258,13 +1247,13 @@ numpy = ">=1.17.3"
|
|||
|
||||
[[package]]
|
||||
name = "huggingface-hub"
|
||||
version = "0.21.4"
|
||||
version = "0.23.0"
|
||||
description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub"
|
||||
optional = false
|
||||
python-versions = ">=3.8.0"
|
||||
files = [
|
||||
{file = "huggingface_hub-0.21.4-py3-none-any.whl", hash = "sha256:df37c2c37fc6c82163cdd8a67ede261687d80d1e262526d6c0ce73b6b3630a7b"},
|
||||
{file = "huggingface_hub-0.21.4.tar.gz", hash = "sha256:e1f4968c93726565a80edf6dc309763c7b546d0cfe79aa221206034d50155531"},
|
||||
{file = "huggingface_hub-0.23.0-py3-none-any.whl", hash = "sha256:075c30d48ee7db2bba779190dc526d2c11d422aed6f9044c5e2fdc2c432fdb91"},
|
||||
{file = "huggingface_hub-0.23.0.tar.gz", hash = "sha256:7126dedd10a4c6fac796ced4d87a8cf004efc722a5125c2c09299017fa366fa9"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
|
@ -1277,15 +1266,16 @@ tqdm = ">=4.42.1"
|
|||
typing-extensions = ">=3.7.4.3"
|
||||
|
||||
[package.extras]
|
||||
all = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "mypy (==1.5.1)", "numpy", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.1.3)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"]
|
||||
all = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio", "jedi", "minijinja (>=1.0)", "mypy (==1.5.1)", "numpy", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.3.0)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"]
|
||||
cli = ["InquirerPy (==0.3.4)"]
|
||||
dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "mypy (==1.5.1)", "numpy", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.1.3)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"]
|
||||
dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio", "jedi", "minijinja (>=1.0)", "mypy (==1.5.1)", "numpy", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.3.0)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"]
|
||||
fastai = ["fastai (>=2.4)", "fastcore (>=1.3.27)", "toml"]
|
||||
hf-transfer = ["hf-transfer (>=0.1.4)"]
|
||||
inference = ["aiohttp", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)"]
|
||||
quality = ["mypy (==1.5.1)", "ruff (>=0.1.3)"]
|
||||
inference = ["aiohttp", "minijinja (>=1.0)"]
|
||||
quality = ["mypy (==1.5.1)", "ruff (>=0.3.0)"]
|
||||
tensorflow = ["graphviz", "pydot", "tensorflow"]
|
||||
testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "numpy", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "soundfile", "urllib3 (<2.0)"]
|
||||
tensorflow-testing = ["keras (<3.0)", "tensorflow"]
|
||||
testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio", "jedi", "minijinja (>=1.0)", "numpy", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "soundfile", "urllib3 (<2.0)"]
|
||||
torch = ["safetensors", "torch"]
|
||||
typing = ["types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)"]
|
||||
|
||||
|
@ -2587,7 +2577,7 @@ xmp = ["defusedxml"]
|
|||
name = "platformdirs"
|
||||
version = "4.2.1"
|
||||
description = "A small Python package for determining appropriate platform-specific dirs, e.g. a `user data dir`."
|
||||
optional = true
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "platformdirs-4.2.1-py3-none-any.whl", hash = "sha256:17d5a1161b3fd67b390023cb2d3b026bbd40abde6fdb052dfbd3a29c3ba22ee1"},
|
||||
|
@ -4034,36 +4024,40 @@ test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess
|
|||
|
||||
[[package]]
|
||||
name = "wandb"
|
||||
version = "0.16.6"
|
||||
version = "0.17.0"
|
||||
description = "A CLI and library for interacting with the Weights & Biases API."
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "wandb-0.16.6-py3-none-any.whl", hash = "sha256:5810019a3b981c796e98ea58557a7c380f18834e0c6bdaed15df115522e5616e"},
|
||||
{file = "wandb-0.16.6.tar.gz", hash = "sha256:86f491e3012d715e0d7d7421a4d6de41abef643b7403046261f962f3e512fe1c"},
|
||||
{file = "wandb-0.17.0-py3-none-any.whl", hash = "sha256:b1b056b4cad83b00436cb76049fd29ecedc6045999dcaa5eba40db6680960ac2"},
|
||||
{file = "wandb-0.17.0-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:e1e6f04e093a6a027dcb100618ca23b122d032204b2ed4c62e4e991a48041a6b"},
|
||||
{file = "wandb-0.17.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:feeb60d4ff506d2a6bc67f953b310d70b004faa789479c03ccd1559c6f1a9633"},
|
||||
{file = "wandb-0.17.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b7bed8a3dd404a639e6bf5fea38c6efe2fb98d416ff1db4fb51be741278ed328"},
|
||||
{file = "wandb-0.17.0-py3-none-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:56a1dd6e0e635cba3f6ed30b52c71739bdc2a3e57df155619d2d80ee952b4201"},
|
||||
{file = "wandb-0.17.0-py3-none-win32.whl", hash = "sha256:1f692d3063a0d50474022cfe6668e1828260436d1cd40827d1e136b7f730c74c"},
|
||||
{file = "wandb-0.17.0-py3-none-win_amd64.whl", hash = "sha256:ab582ca0d54d52ef5b991de0717350b835400d9ac2d3adab210022b68338d694"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
appdirs = ">=1.4.3"
|
||||
Click = ">=7.1,<8.0.0 || >8.0.0"
|
||||
click = ">=7.1,<8.0.0 || >8.0.0"
|
||||
docker-pycreds = ">=0.4.0"
|
||||
GitPython = ">=1.0.0,<3.1.29 || >3.1.29"
|
||||
gitpython = ">=1.0.0,<3.1.29 || >3.1.29"
|
||||
platformdirs = "*"
|
||||
protobuf = {version = ">=3.19.0,<4.21.0 || >4.21.0,<5", markers = "python_version > \"3.9\" or sys_platform != \"linux\""}
|
||||
psutil = ">=5.0.0"
|
||||
PyYAML = "*"
|
||||
pyyaml = "*"
|
||||
requests = ">=2.0.0,<3"
|
||||
sentry-sdk = ">=1.0.0"
|
||||
setproctitle = "*"
|
||||
setuptools = "*"
|
||||
|
||||
[package.extras]
|
||||
async = ["httpx (>=0.23.0)"]
|
||||
aws = ["boto3"]
|
||||
azure = ["azure-identity", "azure-storage-blob"]
|
||||
gcp = ["google-cloud-storage"]
|
||||
importers = ["filelock", "mlflow", "polars", "rich", "tenacity"]
|
||||
kubeflow = ["google-cloud-storage", "kubernetes", "minio", "sh"]
|
||||
launch = ["PyYAML (>=6.0.0)", "awscli", "azure-containerregistry", "azure-identity", "azure-storage-blob", "boto3", "botocore", "chardet", "google-auth", "google-cloud-aiplatform", "google-cloud-artifact-registry", "google-cloud-compute", "google-cloud-storage", "iso8601", "kubernetes", "kubernetes-asyncio", "nbconvert", "nbformat", "optuna", "pydantic", "tomli", "typing-extensions"]
|
||||
launch = ["awscli", "azure-containerregistry", "azure-identity", "azure-storage-blob", "boto3", "botocore", "chardet", "google-auth", "google-cloud-aiplatform", "google-cloud-artifact-registry", "google-cloud-compute", "google-cloud-storage", "iso8601", "kubernetes", "kubernetes-asyncio", "nbconvert", "nbformat", "optuna", "pydantic", "pyyaml (>=6.0.0)", "tomli", "typing-extensions"]
|
||||
media = ["bokeh", "moviepy", "numpy", "pillow", "plotly (>=5.18.0)", "rdkit-pypi", "soundfile"]
|
||||
models = ["cloudpickle"]
|
||||
perf = ["orjson"]
|
||||
|
@ -4309,13 +4303,13 @@ multidict = ">=4.0"
|
|||
|
||||
[[package]]
|
||||
name = "zarr"
|
||||
version = "2.17.2"
|
||||
version = "2.18.0"
|
||||
description = "An implementation of chunked, compressed, N-dimensional arrays for Python"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
files = [
|
||||
{file = "zarr-2.17.2-py3-none-any.whl", hash = "sha256:70d7cc07c24280c380ef80644151d136b7503b0d83c9f214e8000ddc0f57f69b"},
|
||||
{file = "zarr-2.17.2.tar.gz", hash = "sha256:2cbaa6cb4e342d45152d4a7a4b2013c337fcd3a8e7bc98253560180de60552ce"},
|
||||
{file = "zarr-2.18.0-py3-none-any.whl", hash = "sha256:7f8532b6a3f50f22e809e130e09353637ec8b5bb5e95a5a0bfaae91f63978b5d"},
|
||||
{file = "zarr-2.18.0.tar.gz", hash = "sha256:c3b7d2c85b8a42b0ad0ad268a36fb6886ca852098358c125c6b126a417e0a598"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
|
@ -4354,4 +4348,4 @@ xarm = ["gym-xarm"]
|
|||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.10,<3.13"
|
||||
content-hash = "008a6af5ad9d9eafbd933c922c2c5d84fddae85aff8a9eefc0538b1319966f6e"
|
||||
content-hash = "21dd1d7404ac774bd1139e8cda44ea8e3ed97c30e524f2ed862de431d3d5fa87"
|
||||
|
|
|
@ -28,37 +28,37 @@ packages = [{include = "lerobot"}]
|
|||
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.10,<3.13"
|
||||
termcolor = "^2.4.0"
|
||||
omegaconf = "^2.3.0"
|
||||
wandb = "^0.16.3"
|
||||
imageio = {extras = ["ffmpeg"], version = "^2.34.0"}
|
||||
gdown = "^5.1.0"
|
||||
hydra-core = "^1.3.2"
|
||||
einops = "^0.8.0"
|
||||
pymunk = "^6.6.0"
|
||||
zarr = "^2.17.0"
|
||||
numba = "^0.59.0"
|
||||
termcolor = ">=2.4.0"
|
||||
omegaconf = ">=2.3.0"
|
||||
wandb = ">=0.16.3"
|
||||
imageio = {extras = ["ffmpeg"], version = ">=2.34.0"}
|
||||
gdown = ">=5.1.0"
|
||||
hydra-core = ">=1.3.2"
|
||||
einops = ">=0.8.0"
|
||||
pymunk = ">=6.6.0"
|
||||
zarr = ">=2.17.0"
|
||||
numba = ">=0.59.0"
|
||||
torch = "^2.2.1"
|
||||
opencv-python = "^4.9.0.80"
|
||||
opencv-python = ">=4.9.0"
|
||||
diffusers = "^0.27.2"
|
||||
torchvision = "^0.18.0"
|
||||
h5py = "^3.10.0"
|
||||
huggingface-hub = "^0.21.4"
|
||||
torchvision = ">=0.18.0"
|
||||
h5py = ">=3.10.0"
|
||||
huggingface-hub = ">=0.21.4"
|
||||
robomimic = "0.2.0"
|
||||
gymnasium = "^0.29.1"
|
||||
cmake = "^3.29.0.1"
|
||||
gym-pusht = { version = "^0.1.1", optional = true}
|
||||
gym-xarm = { version = "^0.1.0", optional = true}
|
||||
gym-aloha = { version = "^0.1.0", optional = true}
|
||||
pre-commit = {version = "^3.7.0", optional = true}
|
||||
debugpy = {version = "^1.8.1", optional = true}
|
||||
pytest = {version = "^8.1.0", optional = true}
|
||||
pytest-cov = {version = "^5.0.0", optional = true}
|
||||
datasets = "^2.19.0"
|
||||
imagecodecs = { version = "^2024.1.1", optional = true }
|
||||
pyav = "^12.0.5"
|
||||
moviepy = "^1.0.3"
|
||||
rerun-sdk = "^0.15.1"
|
||||
gymnasium = ">=0.29.1"
|
||||
cmake = ">=3.29.0.1"
|
||||
gym-pusht = { version = ">=0.1.3", optional = true}
|
||||
gym-xarm = { version = ">=0.1.1", optional = true}
|
||||
gym-aloha = { version = ">=0.1.1", optional = true}
|
||||
pre-commit = {version = ">=3.7.0", optional = true}
|
||||
debugpy = {version = ">=1.8.1", optional = true}
|
||||
pytest = {version = ">=8.1.0", optional = true}
|
||||
pytest-cov = {version = ">=5.0.0", optional = true}
|
||||
datasets = ">=2.19.0"
|
||||
imagecodecs = { version = ">=2024.1.1", optional = true }
|
||||
pyav = ">=12.0.5"
|
||||
moviepy = ">=1.0.3"
|
||||
rerun-sdk = ">=0.15.1"
|
||||
|
||||
|
||||
[tool.poetry.extras]
|
||||
|
@ -104,5 +104,5 @@ ignore-init-module-imports = true
|
|||
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core>=1.5.0"]
|
||||
requires = ["poetry-core"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
|
|
Loading…
Reference in New Issue