Use PyTorchModelHubMixin to save models as safetensors (#125)

Co-authored-by: Remi <re.cadene@gmail.com>
Authored by Alexander Soare on 2024-05-01 16:17:18 +01:00, committed by GitHub
parent 01d5490d44
commit a4891095e4
18 changed files with 556 additions and 527 deletions

View File

@ -24,6 +24,7 @@ test-end-to-end:
${MAKE} test-diffusion-ete-eval
# ${MAKE} test-tdmpc-ete-train
# ${MAKE} test-tdmpc-ete-eval
${MAKE} test-default-ete-eval
test-act-ete-train:
python lerobot/scripts/train.py \
@ -43,11 +44,10 @@ test-act-ete-train:
test-act-ete-eval:
python lerobot/scripts/eval.py \
--config tests/outputs/act/.hydra/config.yaml \
-p tests/outputs/act/checkpoints/000002 \
eval.n_episodes=1 \
env.episode_length=8 \
device=cpu \
policy.pretrained_model_path=tests/outputs/act/models/2.pt
test-diffusion-ete-train:
python lerobot/scripts/train.py \
@ -65,11 +65,10 @@ test-diffusion-ete-train:
test-diffusion-ete-eval:
python lerobot/scripts/eval.py \
--config tests/outputs/diffusion/.hydra/config.yaml \
-p tests/outputs/diffusion/checkpoints/000002 \
eval.n_episodes=1 \
env.episode_length=8 \
device=cpu \
policy.pretrained_model_path=tests/outputs/diffusion/models/2.pt
test-tdmpc-ete-train:
python lerobot/scripts/train.py \
@ -88,8 +87,15 @@ test-tdmpc-ete-train:
test-tdmpc-ete-eval:
python lerobot/scripts/eval.py \
--config tests/outputs/tdmpc/.hydra/config.yaml \
-p tests/outputs/tdmpc/checkpoints/000002 \
eval.n_episodes=1 \
env.episode_length=8 \
device=cpu \
policy.pretrained_model_path=tests/outputs/tdmpc/models/2.pt
test-default-ete-eval:
python lerobot/scripts/eval.py \
--config lerobot/configs/default.yaml \
eval.n_episodes=1 \
env.episode_length=8 \
device=cpu \

View File

@ -135,16 +135,16 @@ Check out [examples](./examples) to see how you can load a pretrained policy fro
Or you can achieve the same result by executing our script from the command line:
```bash
python lerobot/scripts/eval.py \
--hub-id lerobot/diffusion_policy_pusht_image \
-p lerobot/diffusion_policy_pusht_image \
eval_episodes=10 \
hydra.run.dir=outputs/eval/example_hub
```
After training your own policy, you can also re-evaluate the checkpoints with:
```bash
python lerobot/scripts/eval.py \
--config PATH/TO/FOLDER/config.yaml \
policy.pretrained_model_path=PATH/TO/FOLDER/weights.pth \
-p PATH/TO/TRAIN/OUTPUT/FOLDER \
eval_episodes=10 \
hydra.run.dir=outputs/eval/example_dir
```
@ -246,29 +246,22 @@ Once you have trained a policy you may upload it to the HuggingFace hub.
Firstly, make sure you have a model repository set up on the hub. The hub ID looks like HF_USER/REPO_NAME.
Secondly, assuming you have trained a policy, you need:
Secondly, assuming you have trained a policy, you need the following (which should all be in any of the subdirectories of `checkpoints` in your training output folder, if you've used the LeRobot training script):
- `config.yaml` which you can get from the `.hydra` directory of your training output folder.
- `model.pt` which should be one of the saved models in the `models` directory of your training output folder (they won't be named `model.pt` but you will need to choose one).
- `config.json`: A serialized version of the policy configuration (following the policy's dataclass config).
- `model.safetensors`: The `torch.nn.Module` parameters saved in [Hugging Face Safetensors](https://huggingface.co/docs/safetensors/index) format.
- `config.yaml`: This is the consolidated Hydra training configuration containing the policy, environment, and dataset configs. The policy configuration should match `config.json` exactly. The environment config is useful for anyone who wants to evaluate your policy. The dataset config just serves as a paper trail for reproducibility.
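For orientation, here is a minimal sketch of how these artifacts are produced and consumed via `PyTorchModelHubMixin` (using `DiffusionPolicy` as an example; the directory name is just a placeholder):
```python
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy

# A freshly initialized policy (normally you would train it first).
policy = DiffusionPolicy()

# save_pretrained writes config.json (the serialized dataclass config) and
# model.safetensors (the nn.Module parameters) into the target directory.
policy.save_pretrained("to_upload")

# from_pretrained accepts a local directory or a hub repo ID: it rebuilds the
# policy from config.json, then loads the safetensors weights.
policy = DiffusionPolicy.from_pretrained("to_upload")
```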
To upload these to the hub, prepare a folder with the following structure (you can use symlinks rather than copying):
```
to_upload
├── config.yaml
└── model.pt
```
With the folder prepared, run the following with a desired revision ID.
To upload these to the hub, run the following with a desired revision ID.
```bash
huggingface-cli upload $HUB_ID to_upload --revision $REVISION_ID
huggingface-cli upload $HUB_ID PATH/TO/OUTPUT/DIR --revision $REVISION_ID
```
If you want this to be the default revision also run the following (don't worry, it won't upload the files again; it will just adjust the file pointers):
```bash
huggingface-cli upload $HUB_ID to_upload
huggingface-cli upload $HUB_ID PATH/TO/OUTPUT/DIR
```
See `eval.py` for an example of how a user may use your policy.

View File

@ -7,32 +7,21 @@ from pathlib import Path
from huggingface_hub import snapshot_download
from lerobot.common.utils.utils import init_hydra_config
from lerobot.scripts.eval import eval
# Get a pretrained policy from the hub.
# TODO(alexander-soare): This no longer works until we upload a new model that uses the current configs.
hub_id = "lerobot/diffusion_policy_pusht_image"
folder = Path(snapshot_download(hub_id))
pretrained_policy_name = "lerobot/diffusion_policy_pusht_image"
pretrained_policy_path = Path(snapshot_download(pretrained_policy_name))
# OR uncomment the following to evaluate a policy from the local outputs/train folder.
# folder = Path("outputs/train/example_pusht_diffusion")
config_path = folder / "config.yaml"
weights_path = folder / "model.pt"
# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")
# Override some config parameters to do with evaluation.
overrides = [
f"policy.pretrained_model_path={weights_path}",
"eval.n_episodes=10",
"eval.batch_size=10",
"device=cuda",
]
# Create a Hydra config.
cfg = init_hydra_config(config_path, overrides)
# Evaluate the policy and save the outputs including metrics and videos.
eval(
cfg,
out_dir=f"outputs/eval/example_{cfg.env.name}_{cfg.policy.name}",
)
# TODO(rcadene, alexander-soare): don't call eval, but add the minimal code snippet to rollout
eval(pretrained_policy_path=pretrained_policy_path)

View File

@ -34,7 +34,7 @@ dataset = make_dataset(hydra_cfg)
# If you're doing something different, you will likely need to change at least some of the defaults.
cfg = DiffusionConfig()
# TODO(alexander-soare): Remove LR scheduler from the policy.
policy = DiffusionPolicy(cfg, lr_scheduler_num_training_steps=training_steps, dataset_stats=dataset.stats)
policy = DiffusionPolicy(cfg, dataset_stats=dataset.stats)
policy.train()
policy.to(device)
@ -69,6 +69,7 @@ while not done:
done = True
break
# Save the policy and configuration for later use.
policy.save(output_directory / "model.pt")
# Save the policy.
policy.save_pretrained(output_directory)
# Save the Hydra configuration so we have the environment configuration for eval.
OmegaConf.save(hydra_cfg, output_directory / "config.yaml")
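Reloading later is the mirror image. A sketch, assuming the same `output_directory` as in the example above (exact downstream usage may differ):
```python
from pathlib import Path

from omegaconf import OmegaConf

from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy

output_directory = Path("outputs/train/example_pusht_diffusion")
# Rebuild the policy from config.json + model.safetensors.
policy = DiffusionPolicy.from_pretrained(output_directory)
policy.eval()
# Recover the environment configuration saved alongside the weights.
hydra_cfg = OmegaConf.load(output_directory / "config.yaml")
```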

View File

@ -2,9 +2,12 @@ import logging
import os
from pathlib import Path
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from omegaconf import OmegaConf
from termcolor import colored
from lerobot.common.policies.policy_protocol import Policy
def log_output_dir(out_dir):
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {out_dir}")
@ -27,7 +30,7 @@ class Logger:
self._log_dir = Path(log_dir)
self._log_dir.mkdir(parents=True, exist_ok=True)
self._job_name = job_name
self._model_dir = self._log_dir / "models"
self._model_dir = self._log_dir / "checkpoints"
self._buffer_dir = self._log_dir / "buffers"
self._save_model = cfg.training.save_model
self._disable_wandb_artifact = cfg.wandb.disable_artifact
@ -67,18 +70,20 @@ class Logger:
logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
self._wandb = wandb
def save_model(self, policy, identifier):
def save_model(self, policy: Policy, identifier):
if self._save_model:
self._model_dir.mkdir(parents=True, exist_ok=True)
fp = self._model_dir / f"{str(identifier)}.pt"
policy.save(fp)
save_dir = self._model_dir / str(identifier)
policy.save_pretrained(save_dir)
# Also save the full Hydra config for the env configuration.
OmegaConf.save(self._cfg, save_dir / "config.yaml")
if self._wandb and not self._disable_wandb_artifact:
# note wandb artifact does not accept ":" in its name
artifact = self._wandb.Artifact(
self._group.replace(":", "_") + "-" + str(self._seed) + "-" + str(identifier),
type="model",
)
artifact.add_file(fp)
artifact.add_file(save_dir / SAFETENSORS_SINGLE_FILE)
self._wandb.log_artifact(artifact)
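With this change, each checkpoint identifier becomes a directory rather than a single `.pt` file. The expected layout (a sketch assembled from the code above; the run name and step are placeholders):
```
outputs/train/RUN/checkpoints/
└── 005000/
    ├── config.json        # policy dataclass config (from PyTorchModelHubMixin)
    ├── model.safetensors  # policy weights (logged to wandb as the artifact)
    └── config.yaml        # full Hydra config (env + dataset paper trail)
```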
def save_buffer(self, buffer, identifier):

View File

@ -38,7 +38,7 @@ class ACTConfig:
replace_final_stride_with_dilation: Whether to replace the ResNet's final 2x2 stride with a dilated
convolution.
pre_norm: Whether to use "pre-norm" in the transformer blocks.
d_model: The transformer blocks' main hidden dimension.
dim_model: The transformer blocks' main hidden dimension.
n_heads: The number of heads to use in the transformer blocks' multi-head attention.
dim_feedforward: The dimension to expand the transformer's hidden dimension to in the feed-forward
layers.
@ -94,7 +94,7 @@ class ACTConfig:
replace_final_stride_with_dilation: int = False
# Transformer layers.
pre_norm: bool = False
d_model: int = 512
dim_model: int = 512
n_heads: int = 8
dim_feedforward: int = 3200
feedforward_activation: str = "relu"

View File

@ -14,6 +14,7 @@ import numpy as np
import torch
import torch.nn.functional as F # noqa: N812
import torchvision
from huggingface_hub import PyTorchModelHubMixin
from torch import Tensor, nn
from torchvision.models._utils import IntermediateLayerGetter
from torchvision.ops.misc import FrozenBatchNorm2d
@ -22,7 +23,7 @@ from lerobot.common.policies.act.configuration_act import ACTConfig
from lerobot.common.policies.normalize import Normalize, Unnormalize
class ACTPolicy(nn.Module):
class ACTPolicy(nn.Module, PyTorchModelHubMixin):
"""
Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost
Hardware (paper: https://arxiv.org/abs/2304.13705, code: https://github.com/tonyzhaozh/act)
@ -30,27 +31,31 @@ class ACTPolicy(nn.Module):
name = "act"
def __init__(self, cfg: ACTConfig | None = None, dataset_stats=None):
def __init__(self, config: ACTConfig | None = None, dataset_stats=None):
"""
Args:
cfg: Policy configuration class instance or None, in which case the default instantiation of the
configuration class is used.
config: Policy configuration class instance or None, in which case the default instantiation of
the configuration class is used.
"""
super().__init__()
if cfg is None:
cfg = ACTConfig()
self.cfg = cfg
self.normalize_inputs = Normalize(cfg.input_shapes, cfg.input_normalization_modes, dataset_stats)
self.normalize_targets = Normalize(cfg.output_shapes, cfg.output_normalization_modes, dataset_stats)
self.unnormalize_outputs = Unnormalize(
cfg.output_shapes, cfg.output_normalization_modes, dataset_stats
if config is None:
config = ACTConfig()
self.config = config
self.normalize_inputs = Normalize(
config.input_shapes, config.input_normalization_modes, dataset_stats
)
self.model = ACT(cfg)
self.normalize_targets = Normalize(
config.output_shapes, config.output_normalization_modes, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_shapes, config.output_normalization_modes, dataset_stats
)
self.model = ACT(config)
def reset(self):
"""This should be called whenever the environment is reset."""
if self.cfg.n_action_steps is not None:
self._action_queue = deque([], maxlen=self.cfg.n_action_steps)
if self.config.n_action_steps is not None:
self._action_queue = deque([], maxlen=self.config.n_action_steps)
@torch.no_grad
def select_action(self, batch: dict[str, Tensor], **_) -> Tensor:
@ -68,7 +73,7 @@ class ACTPolicy(nn.Module):
if len(self._action_queue) == 0:
# `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
# effectively has shape (n_action_steps, batch_size, *), hence the transpose.
actions = self.model(batch)[0][: self.cfg.n_action_steps]
actions = self.model(batch)[0][: self.config.n_action_steps]
# TODO(rcadene): make _forward return output dictionary?
actions = self.unnormalize_outputs({"action": actions})["action"]
@ -88,7 +93,7 @@ class ACTPolicy(nn.Module):
).mean()
loss_dict = {"l1_loss": l1_loss}
if self.cfg.use_vae:
if self.config.use_vae:
# Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
# each dimension independently, we sum over the latent dimension to get the total
# KL-divergence per batch element, then take the mean over the batch.
@ -97,7 +102,7 @@ class ACTPolicy(nn.Module):
(-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
)
loss_dict["kld_loss"] = mean_kld
loss_dict["loss"] = l1_loss + mean_kld * self.cfg.kl_weight
loss_dict["loss"] = l1_loss + mean_kld * self.config.kl_weight
else:
loss_dict["loss"] = l1_loss
@ -114,17 +119,10 @@ class ACTPolicy(nn.Module):
"""
# Stack images in the order dictated by input_shapes.
batch["observation.images"] = torch.stack(
[batch[k] for k in self.cfg.input_shapes if k.startswith("observation.images.")],
[batch[k] for k in self.config.input_shapes if k.startswith("observation.images.")],
dim=-4,
)
def save(self, fp):
torch.save(self.state_dict(), fp)
def load(self, fp):
d = torch.load(fp)
self.load_state_dict(d)
class ACT(nn.Module):
"""Action Chunking Transformer: The underlying neural network for ACTPolicy.
@ -161,36 +159,36 @@ class ACT(nn.Module):
"""
def __init__(self, cfg: ACTConfig):
def __init__(self, config: ACTConfig):
super().__init__()
self.cfg = cfg
self.config = config
# BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence].
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
if self.cfg.use_vae:
self.vae_encoder = ACTEncoder(cfg)
self.vae_encoder_cls_embed = nn.Embedding(1, cfg.d_model)
if self.config.use_vae:
self.vae_encoder = ACTEncoder(config)
self.vae_encoder_cls_embed = nn.Embedding(1, config.dim_model)
# Projection layer for joint-space configuration to hidden dimension.
self.vae_encoder_robot_state_input_proj = nn.Linear(
cfg.input_shapes["observation.state"][0], cfg.d_model
config.input_shapes["observation.state"][0], config.dim_model
)
# Projection layer for action (joint-space target) to hidden dimension.
self.vae_encoder_action_input_proj = nn.Linear(
cfg.input_shapes["observation.state"][0], cfg.d_model
config.input_shapes["observation.state"][0], config.dim_model
)
self.latent_dim = cfg.latent_dim
self.latent_dim = config.latent_dim
# Projection layer from the VAE encoder's output to the latent distribution's parameter space.
self.vae_encoder_latent_output_proj = nn.Linear(cfg.d_model, self.latent_dim * 2)
self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, self.latent_dim * 2)
# Fixed sinusoidal positional embedding for the whole input to the VAE encoder. Unsqueeze for batch
# dimension.
self.register_buffer(
"vae_encoder_pos_enc",
create_sinusoidal_position_embedding(1 + 1 + cfg.chunk_size, cfg.d_model).unsqueeze(0),
create_sinusoidal_pos_embedding(1 + 1 + config.chunk_size, config.dim_model).unsqueeze(0),
)
# Backbone for image feature extraction.
backbone_model = getattr(torchvision.models, cfg.vision_backbone)(
replace_stride_with_dilation=[False, False, cfg.replace_final_stride_with_dilation],
weights=cfg.pretrained_backbone_weights,
backbone_model = getattr(torchvision.models, config.vision_backbone)(
replace_stride_with_dilation=[False, False, config.replace_final_stride_with_dilation],
weights=config.pretrained_backbone_weights,
norm_layer=FrozenBatchNorm2d,
)
# Note: The assumption here is that we are using a ResNet model (and hence layer4 is the final feature
@ -199,26 +197,28 @@ class ACT(nn.Module):
self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"})
# Transformer (acts as VAE decoder when training with the variational objective).
self.encoder = ACTEncoder(cfg)
self.decoder = ACTDecoder(cfg)
self.encoder = ACTEncoder(config)
self.decoder = ACTDecoder(config)
# Transformer encoder input projections. The tokens will be structured like
# [latent, robot_state, image_feature_map_pixels].
self.encoder_robot_state_input_proj = nn.Linear(cfg.input_shapes["observation.state"][0], cfg.d_model)
self.encoder_latent_input_proj = nn.Linear(self.latent_dim, cfg.d_model)
self.encoder_robot_state_input_proj = nn.Linear(
config.input_shapes["observation.state"][0], config.dim_model
)
self.encoder_latent_input_proj = nn.Linear(self.latent_dim, config.dim_model)
self.encoder_img_feat_input_proj = nn.Conv2d(
backbone_model.fc.in_features, cfg.d_model, kernel_size=1
backbone_model.fc.in_features, config.dim_model, kernel_size=1
)
# Transformer encoder positional embeddings.
self.encoder_robot_and_latent_pos_embed = nn.Embedding(2, cfg.d_model)
self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(cfg.d_model // 2)
self.encoder_robot_and_latent_pos_embed = nn.Embedding(2, config.dim_model)
self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)
# Transformer decoder.
# Learnable positional embedding for the transformer's decoder (in the style of DETR object queries).
self.decoder_pos_embed = nn.Embedding(cfg.chunk_size, cfg.d_model)
self.decoder_pos_embed = nn.Embedding(config.chunk_size, config.dim_model)
# Final action regression head on the output of the transformer's decoder.
self.action_head = nn.Linear(cfg.d_model, cfg.output_shapes["action"][0])
self.action_head = nn.Linear(config.dim_model, config.output_shapes["action"][0])
self._reset_parameters()
@ -244,7 +244,7 @@ class ACT(nn.Module):
Tuple containing the latent PDF's parameters (mean, log(σ²)) both as (B, L) tensors where L is the
latent dimension.
"""
if self.cfg.use_vae and self.training:
if self.config.use_vae and self.training:
assert (
"action" in batch
), "actions must be provided when using the variational objective in training mode."
@ -252,7 +252,7 @@ class ACT(nn.Module):
batch_size = batch["observation.state"].shape[0]
# Prepare the latent for input to the transformer encoder.
if self.cfg.use_vae and "action" in batch:
if self.config.use_vae and "action" in batch:
# Prepare the input to the VAE encoder: [cls, *joint_space_configuration, *action_sequence].
cls_embed = einops.repeat(
self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
@ -322,7 +322,7 @@ class ACT(nn.Module):
# Forward pass through the transformer modules.
encoder_out = self.encoder(encoder_in, pos_embed=pos_embed)
decoder_in = torch.zeros(
(self.cfg.chunk_size, batch_size, self.cfg.d_model),
(self.config.chunk_size, batch_size, self.config.dim_model),
dtype=pos_embed.dtype,
device=pos_embed.device,
)
@ -344,10 +344,10 @@ class ACT(nn.Module):
class ACTEncoder(nn.Module):
"""Convenience module for running multiple encoder layers, maybe followed by normalization."""
def __init__(self, cfg: ACTConfig):
def __init__(self, config: ACTConfig):
super().__init__()
self.layers = nn.ModuleList([ACTEncoderLayer(cfg) for _ in range(cfg.n_encoder_layers)])
self.norm = nn.LayerNorm(cfg.d_model) if cfg.pre_norm else nn.Identity()
self.layers = nn.ModuleList([ACTEncoderLayer(config) for _ in range(config.n_encoder_layers)])
self.norm = nn.LayerNorm(config.dim_model) if config.pre_norm else nn.Identity()
def forward(self, x: Tensor, pos_embed: Tensor | None = None) -> Tensor:
for layer in self.layers:
@ -357,22 +357,22 @@ class ACTEncoder(nn.Module):
class ACTEncoderLayer(nn.Module):
def __init__(self, cfg: ACTConfig):
def __init__(self, config: ACTConfig):
super().__init__()
self.self_attn = nn.MultiheadAttention(cfg.d_model, cfg.n_heads, dropout=cfg.dropout)
self.self_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout)
# Feed forward layers.
self.linear1 = nn.Linear(cfg.d_model, cfg.dim_feedforward)
self.dropout = nn.Dropout(cfg.dropout)
self.linear2 = nn.Linear(cfg.dim_feedforward, cfg.d_model)
self.linear1 = nn.Linear(config.dim_model, config.dim_feedforward)
self.dropout = nn.Dropout(config.dropout)
self.linear2 = nn.Linear(config.dim_feedforward, config.dim_model)
self.norm1 = nn.LayerNorm(cfg.d_model)
self.norm2 = nn.LayerNorm(cfg.d_model)
self.dropout1 = nn.Dropout(cfg.dropout)
self.dropout2 = nn.Dropout(cfg.dropout)
self.norm1 = nn.LayerNorm(config.dim_model)
self.norm2 = nn.LayerNorm(config.dim_model)
self.dropout1 = nn.Dropout(config.dropout)
self.dropout2 = nn.Dropout(config.dropout)
self.activation = get_activation_fn(cfg.feedforward_activation)
self.pre_norm = cfg.pre_norm
self.activation = get_activation_fn(config.feedforward_activation)
self.pre_norm = config.pre_norm
def forward(self, x, pos_embed: Tensor | None = None) -> Tensor:
skip = x
@ -395,11 +395,11 @@ class ACTEncoderLayer(nn.Module):
class ACTDecoder(nn.Module):
def __init__(self, cfg: ACTConfig):
def __init__(self, config: ACTConfig):
"""Convenience module for running multiple decoder layers followed by normalization."""
super().__init__()
self.layers = nn.ModuleList([ACTDecoderLayer(cfg) for _ in range(cfg.n_decoder_layers)])
self.norm = nn.LayerNorm(cfg.d_model)
self.layers = nn.ModuleList([ACTDecoderLayer(config) for _ in range(config.n_decoder_layers)])
self.norm = nn.LayerNorm(config.dim_model)
def forward(
self,
@ -418,25 +418,25 @@ class ACTDecoder(nn.Module):
class ACTDecoderLayer(nn.Module):
def __init__(self, cfg: ACTConfig):
def __init__(self, config: ACTConfig):
super().__init__()
self.self_attn = nn.MultiheadAttention(cfg.d_model, cfg.n_heads, dropout=cfg.dropout)
self.multihead_attn = nn.MultiheadAttention(cfg.d_model, cfg.n_heads, dropout=cfg.dropout)
self.self_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout)
self.multihead_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout)
# Feed forward layers.
self.linear1 = nn.Linear(cfg.d_model, cfg.dim_feedforward)
self.dropout = nn.Dropout(cfg.dropout)
self.linear2 = nn.Linear(cfg.dim_feedforward, cfg.d_model)
self.linear1 = nn.Linear(config.dim_model, config.dim_feedforward)
self.dropout = nn.Dropout(config.dropout)
self.linear2 = nn.Linear(config.dim_feedforward, config.dim_model)
self.norm1 = nn.LayerNorm(cfg.d_model)
self.norm2 = nn.LayerNorm(cfg.d_model)
self.norm3 = nn.LayerNorm(cfg.d_model)
self.dropout1 = nn.Dropout(cfg.dropout)
self.dropout2 = nn.Dropout(cfg.dropout)
self.dropout3 = nn.Dropout(cfg.dropout)
self.norm1 = nn.LayerNorm(config.dim_model)
self.norm2 = nn.LayerNorm(config.dim_model)
self.norm3 = nn.LayerNorm(config.dim_model)
self.dropout1 = nn.Dropout(config.dropout)
self.dropout2 = nn.Dropout(config.dropout)
self.dropout3 = nn.Dropout(config.dropout)
self.activation = get_activation_fn(cfg.feedforward_activation)
self.pre_norm = cfg.pre_norm
self.activation = get_activation_fn(config.feedforward_activation)
self.pre_norm = config.pre_norm
def maybe_add_pos_embed(self, tensor: Tensor, pos_embed: Tensor | None) -> Tensor:
return tensor if pos_embed is None else tensor + pos_embed
@ -489,7 +489,7 @@ class ACTDecoderLayer(nn.Module):
return x
def create_sinusoidal_position_embedding(num_positions: int, dimension: int) -> Tensor:
def create_sinusoidal_pos_embedding(num_positions: int, dimension: int) -> Tensor:
"""1D sinusoidal positional embeddings as in Attention is All You Need.
Args:
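For reference, a self-contained sketch of the standard formulation this helper implements (an illustration, not the repo's exact code; assumes an even `dimension`):
```python
import math

import torch
from torch import Tensor


def sinusoidal_pos_embedding(num_positions: int, dimension: int) -> Tensor:
    """Even channels get sine, odd channels get cosine, with geometrically
    spaced wavelengths from 2π up to 10000·2π."""
    position = torch.arange(num_positions, dtype=torch.float32).unsqueeze(1)
    div_term = torch.exp(
        torch.arange(0, dimension, 2, dtype=torch.float32) * (-math.log(10000.0) / dimension)
    )
    emb = torch.zeros(num_positions, dimension)
    emb[:, 0::2] = torch.sin(position * div_term)
    emb[:, 1::2] = torch.cos(position * div_term)
    return emb  # shape: (num_positions, dimension)
```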

View File

@ -9,7 +9,6 @@ TODO(alexander-soare):
"""
import copy
import logging
import math
from collections import deque
from typing import Callable
@ -19,6 +18,7 @@ import torch
import torch.nn.functional as F # noqa: N812
import torchvision
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from huggingface_hub import PyTorchModelHubMixin
from robomimic.models.base_nets import SpatialSoftmax
from torch import Tensor, nn
from torch.nn.modules.batchnorm import _BatchNorm
@ -32,7 +32,7 @@ from lerobot.common.policies.utils import (
)
class DiffusionPolicy(nn.Module):
class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
"""
Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion"
(paper: https://arxiv.org/abs/2303.04137, code: https://github.com/real-stanford/diffusion_policy).
@ -41,45 +41,50 @@ class DiffusionPolicy(nn.Module):
name = "diffusion"
def __init__(
self, cfg: DiffusionConfig | None = None, lr_scheduler_num_training_steps: int = 0, dataset_stats=None
self,
config: DiffusionConfig | None = None,
dataset_stats=None,
):
"""
Args:
cfg: Policy configuration class instance or None, in which case the default instantiation of the
configuration class is used.
config: Policy configuration class instance or None, in which case the default instantiation of
the configuration class is used.
"""
super().__init__()
# TODO(alexander-soare): LR scheduler will be removed.
assert lr_scheduler_num_training_steps > 0
if cfg is None:
cfg = DiffusionConfig()
self.cfg = cfg
self.normalize_inputs = Normalize(cfg.input_shapes, cfg.input_normalization_modes, dataset_stats)
self.normalize_targets = Normalize(cfg.output_shapes, cfg.output_normalization_modes, dataset_stats)
if config is None:
config = DiffusionConfig()
self.config = config
self.normalize_inputs = Normalize(
config.input_shapes, config.input_normalization_modes, dataset_stats
)
self.normalize_targets = Normalize(
config.output_shapes, config.output_normalization_modes, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
cfg.output_shapes, cfg.output_normalization_modes, dataset_stats
config.output_shapes, config.output_normalization_modes, dataset_stats
)
# queues are populated during rollout of the policy, they contain the n latest observations and actions
self._queues = None
self.diffusion = DiffusionModel(cfg)
self.diffusion = DiffusionModel(config)
# TODO(alexander-soare): This should probably be managed outside of the policy class.
self.ema_diffusion = None
self.ema = None
if self.cfg.use_ema:
if self.config.use_ema:
self.ema_diffusion = copy.deepcopy(self.diffusion)
self.ema = DiffusionEMA(cfg, model=self.ema_diffusion)
self.ema = DiffusionEMA(config, model=self.ema_diffusion)
def reset(self):
"""
Clear observation and action queues. Should be called on `env.reset()`
"""
self._queues = {
"observation.image": deque(maxlen=self.cfg.n_obs_steps),
"observation.state": deque(maxlen=self.cfg.n_obs_steps),
"action": deque(maxlen=self.cfg.n_action_steps),
"observation.image": deque(maxlen=self.config.n_obs_steps),
"observation.state": deque(maxlen=self.config.n_obs_steps),
"action": deque(maxlen=self.config.n_action_steps),
}
@torch.no_grad
@ -138,46 +143,34 @@ class DiffusionPolicy(nn.Module):
loss = self.diffusion.compute_loss(batch)
return {"loss": loss}
def save(self, fp):
torch.save(self.state_dict(), fp)
def load(self, fp):
d = torch.load(fp)
missing_keys, unexpected_keys = self.load_state_dict(d, strict=False)
if len(missing_keys) > 0:
assert all(k.startswith("ema_diffusion.") for k in missing_keys)
logging.warning(
"DiffusionPolicy.load expected ema parameters in loaded state dict but none were found."
)
assert len(unexpected_keys) == 0
class DiffusionModel(nn.Module):
def __init__(self, cfg: DiffusionConfig):
def __init__(self, config: DiffusionConfig):
super().__init__()
self.cfg = cfg
self.config = config
self.rgb_encoder = DiffusionRgbEncoder(cfg)
self.rgb_encoder = DiffusionRgbEncoder(config)
self.unet = DiffusionConditionalUnet1d(
cfg,
global_cond_dim=(cfg.output_shapes["action"][0] + self.rgb_encoder.feature_dim) * cfg.n_obs_steps,
config,
global_cond_dim=(config.output_shapes["action"][0] + self.rgb_encoder.feature_dim)
* config.n_obs_steps,
)
self.noise_scheduler = DDPMScheduler(
num_train_timesteps=cfg.num_train_timesteps,
beta_start=cfg.beta_start,
beta_end=cfg.beta_end,
beta_schedule=cfg.beta_schedule,
num_train_timesteps=config.num_train_timesteps,
beta_start=config.beta_start,
beta_end=config.beta_end,
beta_schedule=config.beta_schedule,
variance_type="fixed_small",
clip_sample=cfg.clip_sample,
clip_sample_range=cfg.clip_sample_range,
prediction_type=cfg.prediction_type,
clip_sample=config.clip_sample,
clip_sample_range=config.clip_sample_range,
prediction_type=config.prediction_type,
)
if cfg.num_inference_steps is None:
if config.num_inference_steps is None:
self.num_inference_steps = self.noise_scheduler.config.num_train_timesteps
else:
self.num_inference_steps = cfg.num_inference_steps
self.num_inference_steps = config.num_inference_steps
# ========= inference ============
def conditional_sample(
@ -188,7 +181,7 @@ class DiffusionModel(nn.Module):
# Sample prior.
sample = torch.randn(
size=(batch_size, self.cfg.horizon, self.cfg.output_shapes["action"][0]),
size=(batch_size, self.config.horizon, self.config.output_shapes["action"][0]),
dtype=dtype,
device=device,
generator=generator,
@ -218,7 +211,7 @@ class DiffusionModel(nn.Module):
"""
assert set(batch).issuperset({"observation.state", "observation.image"})
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
assert n_obs_steps == self.cfg.n_obs_steps
assert n_obs_steps == self.config.n_obs_steps
# Extract image feature (first combine batch and sequence dims).
img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ..."))
@ -231,10 +224,10 @@ class DiffusionModel(nn.Module):
sample = self.conditional_sample(batch_size, global_cond=global_cond)
# `horizon` steps worth of actions (from the first observation).
actions = sample[..., : self.cfg.output_shapes["action"][0]]
actions = sample[..., : self.config.output_shapes["action"][0]]
# Extract `n_action_steps` steps worth of actions (from the current observation).
start = n_obs_steps - 1
end = start + self.cfg.n_action_steps
end = start + self.config.n_action_steps
actions = actions[:, start:end]
return actions
@ -253,8 +246,8 @@ class DiffusionModel(nn.Module):
assert set(batch).issuperset({"observation.state", "observation.image", "action", "action_is_pad"})
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
horizon = batch["action"].shape[1]
assert horizon == self.cfg.horizon
assert n_obs_steps == self.cfg.n_obs_steps
assert horizon == self.config.horizon
assert n_obs_steps == self.config.n_obs_steps
# Extract image feature (first combine batch and sequence dims).
img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ..."))
@ -283,12 +276,12 @@ class DiffusionModel(nn.Module):
# Compute the loss.
# The target is either the original trajectory, or the noise.
if self.cfg.prediction_type == "epsilon":
if self.config.prediction_type == "epsilon":
target = eps
elif self.cfg.prediction_type == "sample":
elif self.config.prediction_type == "sample":
target = batch["action"]
else:
raise ValueError(f"Unsupported prediction type {self.cfg.prediction_type}")
raise ValueError(f"Unsupported prediction type {self.config.prediction_type}")
loss = F.mse_loss(pred, target, reduction="none")
@ -306,29 +299,29 @@ class DiffusionRgbEncoder(nn.Module):
Includes the ability to normalize and crop the image first.
"""
def __init__(self, cfg: DiffusionConfig):
def __init__(self, config: DiffusionConfig):
super().__init__()
# Set up optional preprocessing.
if cfg.crop_shape is not None:
if config.crop_shape is not None:
self.do_crop = True
# Always use center crop for eval
self.center_crop = torchvision.transforms.CenterCrop(cfg.crop_shape)
if cfg.crop_is_random:
self.maybe_random_crop = torchvision.transforms.RandomCrop(cfg.crop_shape)
self.center_crop = torchvision.transforms.CenterCrop(config.crop_shape)
if config.crop_is_random:
self.maybe_random_crop = torchvision.transforms.RandomCrop(config.crop_shape)
else:
self.maybe_random_crop = self.center_crop
else:
self.do_crop = False
# Set up backbone.
backbone_model = getattr(torchvision.models, cfg.vision_backbone)(
weights=cfg.pretrained_backbone_weights
backbone_model = getattr(torchvision.models, config.vision_backbone)(
weights=config.pretrained_backbone_weights
)
# Note: This assumes that the layer4 feature map is children()[-3]
# TODO(alexander-soare): Use a safer alternative.
self.backbone = nn.Sequential(*(list(backbone_model.children())[:-2]))
if cfg.use_group_norm:
if cfg.pretrained_backbone_weights:
if config.use_group_norm:
if config.pretrained_backbone_weights:
raise ValueError(
"You can't replace BatchNorm in a pretrained model without ruining the weights!"
)
@ -342,11 +335,11 @@ class DiffusionRgbEncoder(nn.Module):
# Use a dry run to get the feature map shape.
with torch.inference_mode():
feat_map_shape = tuple(
self.backbone(torch.zeros(size=(1, *cfg.input_shapes["observation.image"]))).shape[1:]
self.backbone(torch.zeros(size=(1, *config.input_shapes["observation.image"]))).shape[1:]
)
self.pool = SpatialSoftmax(feat_map_shape, num_kp=cfg.spatial_softmax_num_keypoints)
self.feature_dim = cfg.spatial_softmax_num_keypoints * 2
self.out = nn.Linear(cfg.spatial_softmax_num_keypoints * 2, self.feature_dim)
self.pool = SpatialSoftmax(feat_map_shape, num_kp=config.spatial_softmax_num_keypoints)
self.feature_dim = config.spatial_softmax_num_keypoints * 2
self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim)
self.relu = nn.ReLU()
def forward(self, x: Tensor) -> Tensor:
@ -442,34 +435,34 @@ class DiffusionConditionalUnet1d(nn.Module):
Note: this removes local conditioning as compared to the original diffusion policy code.
"""
def __init__(self, cfg: DiffusionConfig, global_cond_dim: int):
def __init__(self, config: DiffusionConfig, global_cond_dim: int):
super().__init__()
self.cfg = cfg
self.config = config
# Encoder for the diffusion timestep.
self.diffusion_step_encoder = nn.Sequential(
DiffusionSinusoidalPosEmb(cfg.diffusion_step_embed_dim),
nn.Linear(cfg.diffusion_step_embed_dim, cfg.diffusion_step_embed_dim * 4),
DiffusionSinusoidalPosEmb(config.diffusion_step_embed_dim),
nn.Linear(config.diffusion_step_embed_dim, config.diffusion_step_embed_dim * 4),
nn.Mish(),
nn.Linear(cfg.diffusion_step_embed_dim * 4, cfg.diffusion_step_embed_dim),
nn.Linear(config.diffusion_step_embed_dim * 4, config.diffusion_step_embed_dim),
)
# The FiLM conditioning dimension.
cond_dim = cfg.diffusion_step_embed_dim + global_cond_dim
cond_dim = config.diffusion_step_embed_dim + global_cond_dim
# In channels / out channels for each downsampling block in the Unet's encoder. For the decoder, we
# just reverse these.
in_out = [(cfg.output_shapes["action"][0], cfg.down_dims[0])] + list(
zip(cfg.down_dims[:-1], cfg.down_dims[1:], strict=True)
in_out = [(config.output_shapes["action"][0], config.down_dims[0])] + list(
zip(config.down_dims[:-1], config.down_dims[1:], strict=True)
)
# Unet encoder.
common_res_block_kwargs = {
"cond_dim": cond_dim,
"kernel_size": cfg.kernel_size,
"n_groups": cfg.n_groups,
"use_film_scale_modulation": cfg.use_film_scale_modulation,
"kernel_size": config.kernel_size,
"n_groups": config.n_groups,
"use_film_scale_modulation": config.use_film_scale_modulation,
}
self.down_modules = nn.ModuleList([])
for ind, (dim_in, dim_out) in enumerate(in_out):
@ -489,10 +482,10 @@ class DiffusionConditionalUnet1d(nn.Module):
self.mid_modules = nn.ModuleList(
[
DiffusionConditionalResidualBlock1d(
cfg.down_dims[-1], cfg.down_dims[-1], **common_res_block_kwargs
config.down_dims[-1], config.down_dims[-1], **common_res_block_kwargs
),
DiffusionConditionalResidualBlock1d(
cfg.down_dims[-1], cfg.down_dims[-1], **common_res_block_kwargs
config.down_dims[-1], config.down_dims[-1], **common_res_block_kwargs
),
]
)
@ -514,8 +507,8 @@ class DiffusionConditionalUnet1d(nn.Module):
)
self.final_conv = nn.Sequential(
DiffusionConv1dBlock(cfg.down_dims[0], cfg.down_dims[0], kernel_size=cfg.kernel_size),
nn.Conv1d(cfg.down_dims[0], cfg.output_shapes["action"][0], 1),
DiffusionConv1dBlock(config.down_dims[0], config.down_dims[0], kernel_size=config.kernel_size),
nn.Conv1d(config.down_dims[0], config.output_shapes["action"][0], 1),
)
def forward(self, x: Tensor, timestep: Tensor | int, global_cond=None) -> Tensor:
@ -626,13 +619,13 @@ class DiffusionEMA:
Exponential Moving Average of model weights
"""
def __init__(self, cfg: DiffusionConfig, model: nn.Module):
def __init__(self, config: DiffusionConfig, model: nn.Module):
"""
@crowsonkb's notes on EMA Warmup:
If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
at 215.4k steps).
If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models
you plan to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999
at 1M steps), gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999
at 10K steps, 0.9999 at 215.4k steps).
Args:
inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
power (float): Exponential factor of EMA warmup. Default: 2/3.
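The quoted decay figures follow from the usual EMA warmup schedule, `decay = 1 - (1 + step / inv_gamma) ** -power`, clamped to `[min_alpha, max_alpha]`. A quick check, assuming that formula:
```python
def ema_decay(step: int, inv_gamma: float = 1.0, power: float = 2 / 3) -> float:
    # Warmup schedule: the decay factor grows from 0 toward 1 over training.
    return 1 - (1 + step / inv_gamma) ** -power

print(round(ema_decay(31_600), 4))               # ~0.999  (power=2/3)
print(round(ema_decay(1_000_000), 5))            # ~0.9999 (power=2/3)
print(round(ema_decay(10_000, power=3 / 4), 4))  # ~0.999  (power=3/4)
```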
@ -643,11 +636,11 @@ class DiffusionEMA:
self.averaged_model.eval()
self.averaged_model.requires_grad_(False)
self.update_after_step = cfg.ema_update_after_step
self.inv_gamma = cfg.ema_inv_gamma
self.power = cfg.ema_power
self.min_alpha = cfg.ema_min_alpha
self.max_alpha = cfg.ema_max_alpha
self.update_after_step = config.ema_update_after_step
self.inv_gamma = config.ema_inv_gamma
self.power = config.ema_power
self.min_alpha = config.ema_min_alpha
self.max_alpha = config.ema_max_alpha
self.alpha = 0.0
self.optimization_step = 0

View File

@ -2,6 +2,7 @@ import inspect
from omegaconf import DictConfig, OmegaConf
from lerobot.common.policies.policy_protocol import Policy
from lerobot.common.utils.utils import get_safe_torch_device
@ -20,42 +21,49 @@ def _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg):
return policy_cfg
def make_policy(hydra_cfg: DictConfig, dataset_stats=None):
if hydra_cfg.policy.name == "tdmpc":
from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
policy = TDMPCPolicy(
hydra_cfg.policy,
n_obs_steps=hydra_cfg.n_obs_steps,
n_action_steps=hydra_cfg.n_action_steps,
device=hydra_cfg.device,
)
elif hydra_cfg.policy.name == "diffusion":
def get_policy_and_config_classes(name: str) -> tuple[Policy, object]:
"""Get the policy's class and config class given a name (matching the policy class' `name` attribute)."""
if name == "tdmpc":
raise NotImplementedError("Coming soon!")
elif name == "diffusion":
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
policy_cfg = _policy_cfg_from_hydra_cfg(DiffusionConfig, hydra_cfg)
policy = DiffusionPolicy(policy_cfg, hydra_cfg.training.offline_steps, dataset_stats)
policy.to(get_safe_torch_device(hydra_cfg.device))
elif hydra_cfg.policy.name == "act":
return DiffusionPolicy, DiffusionConfig
elif name == "act":
from lerobot.common.policies.act.configuration_act import ACTConfig
from lerobot.common.policies.act.modeling_act import ACTPolicy
policy_cfg = _policy_cfg_from_hydra_cfg(ACTConfig, hydra_cfg)
policy = ACTPolicy(policy_cfg, dataset_stats)
return ACTPolicy, ACTConfig
else:
raise NotImplementedError(f"Policy with name {name} is not implemented.")
def make_policy(
hydra_cfg: DictConfig, pretrained_policy_name_or_path: str | None = None, dataset_stats=None
) -> Policy:
"""Make an instance of a policy class.
Args:
hydra_cfg: A parsed Hydra configuration (see scripts). If `pretrained_policy_name_or_path` is
provided, only `hydra_cfg.policy.name` is used while everything else is ignored.
pretrained_policy_name_or_path: Either the repo ID of a model hosted on the Hub or a path to a
directory containing weights saved using `Policy.save_pretrained`. Note that providing this
argument overrides everything in `hydra_cfg.policy` apart from `hydra_cfg.policy.name`.
dataset_stats: Dataset statistics to use for (un)normalization of inputs/outputs in the policy. Must
be provided when initializing a new policy, and must not be provided when loading a pretrained
policy. Therefore, this argument is mutually exclusive with `pretrained_policy_name_or_path`.
"""
if not (pretrained_policy_name_or_path is None) ^ (dataset_stats is None):
raise ValueError("Only one of `pretrained_policy_name_or_path` and `dataset_stats` may be provided.")
policy_cls, policy_cfg_class = get_policy_and_config_classes(hydra_cfg.policy.name)
if pretrained_policy_name_or_path is None:
policy_cfg = _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg)
policy = policy_cls(policy_cfg, dataset_stats)
policy.to(get_safe_torch_device(hydra_cfg.device))
else:
raise ValueError(hydra_cfg.policy.name)
if hydra_cfg.policy.pretrained_model_path:
# TODO(rcadene): hack for old pretrained models from fowm
if hydra_cfg.policy.name == "tdmpc" and "fowm" in hydra_cfg.policy.pretrained_model_path:
if "offline" in hydra_cfg.policy.pretrained_model_path:
policy.step[0] = 25000
elif "final" in hydra_cfg.policy.pretrained_model_path:
policy.step[0] = 100000
else:
raise NotImplementedError()
policy.load(hydra_cfg.policy.pretrained_model_path)
policy = policy_cls.from_pretrained(pretrained_policy_name_or_path)
return policy
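A usage sketch of the two mutually exclusive modes (the config path comes from this repo; the hub repo ID is illustrative):
```python
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.policies.factory import make_policy
from lerobot.common.utils.utils import init_hydra_config

hydra_cfg = init_hydra_config("lerobot/configs/default.yaml")

# Mode 1: initialize from scratch. Dataset stats are required so that the
# normalization buffers are populated with something other than infinity.
policy = make_policy(hydra_cfg=hydra_cfg, dataset_stats=make_dataset(hydra_cfg).stats)

# Mode 2: load pretrained weights (hub repo ID or local save_pretrained dir).
# The stats ship inside the checkpoint, so dataset_stats must be omitted.
policy = make_policy(
    hydra_cfg=hydra_cfg, pretrained_policy_name_or_path="lerobot/diffusion_policy_pusht_image"
)
```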

View File

@ -57,17 +57,28 @@ def create_stats_buffers(
)
if stats is not None:
# Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
# tensors anywhere (for example, when we use the same stats for normalization and
# unnormalization). See the logic here
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
if mode == "mean_std":
buffer["mean"].data = stats[key]["mean"]
buffer["std"].data = stats[key]["std"]
buffer["mean"].data = stats[key]["mean"].clone()
buffer["std"].data = stats[key]["std"].clone()
elif mode == "min_max":
buffer["min"].data = stats[key]["min"]
buffer["max"].data = stats[key]["max"]
buffer["min"].data = stats[key]["min"].clone()
buffer["max"].data = stats[key]["max"].clone()
stats_buffers[key] = buffer
return stats_buffers
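To see why the clones matter, a small demonstration of safetensors' behavior with aliased tensors (assuming current safetensors semantics, which raise rather than silently deduplicate):
```python
import torch
from safetensors.torch import save_file

shared = torch.zeros(3)
try:
    # Two state-dict entries backed by the same storage: safetensors refuses
    # to serialize aliased tensors, hence the .clone() calls above.
    save_file({"a": shared, "b": shared}, "demo.safetensors")
except RuntimeError as err:
    print(err)
```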
def _no_stats_error_str(name: str) -> str:
return (
f"`{name}` is infinity. You should either initialize with `stats` as an argument, or use a "
"pretrained model."
)
class Normalize(nn.Module):
"""Normalizes data (e.g. "observation.image") for more stable and faster convergence during training."""
@ -99,7 +110,6 @@ class Normalize(nn.Module):
self.shapes = shapes
self.modes = modes
self.stats = stats
# `self.buffer_observation_state["mean"]` contains `torch.tensor(state_dim)`
stats_buffers = create_stats_buffers(shapes, modes, stats)
for key, buffer in stats_buffers.items():
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
@ -113,26 +123,14 @@ class Normalize(nn.Module):
if mode == "mean_std":
mean = buffer["mean"]
std = buffer["std"]
assert not torch.isinf(mean).any(), (
"`mean` is infinity. You forgot to initialize with `stats` as argument, or called "
"`policy.load_state_dict`."
)
assert not torch.isinf(std).any(), (
"`std` is infinity. You forgot to initialize with `stats` as argument, or called "
"`policy.load_state_dict`."
)
assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
assert not torch.isinf(std).any(), _no_stats_error_str("std")
batch[key] = (batch[key] - mean) / (std + 1e-8)
elif mode == "min_max":
min = buffer["min"]
max = buffer["max"]
assert not torch.isinf(min).any(), (
"`min` is infinity. You forgot to initialize with `stats` as argument, or called "
"`policy.load_state_dict`."
)
assert not torch.isinf(max).any(), (
"`max` is infinity. You forgot to initialize with `stats` as argument, or called "
"`policy.load_state_dict`."
)
assert not torch.isinf(min).any(), _no_stats_error_str("min")
assert not torch.isinf(max).any(), _no_stats_error_str("max")
# normalize to [0,1]
batch[key] = (batch[key] - min) / (max - min)
# normalize to [-1, 1]
@ -190,26 +188,14 @@ class Unnormalize(nn.Module):
if mode == "mean_std":
mean = buffer["mean"]
std = buffer["std"]
assert not torch.isinf(mean).any(), (
"`mean` is infinity. You forgot to initialize with `stats` as argument, or called "
"`policy.load_state_dict`."
)
assert not torch.isinf(std).any(), (
"`std` is infinity. You forgot to initialize with `stats` as argument, or called "
"`policy.load_state_dict`."
)
assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
assert not torch.isinf(std).any(), _no_stats_error_str("std")
batch[key] = batch[key] * std + mean
elif mode == "min_max":
min = buffer["min"]
max = buffer["max"]
assert not torch.isinf(min).any(), (
"`min` is infinity. You forgot to initialize with `stats` as argument, or called "
"`policy.load_state_dict`."
)
assert not torch.isinf(max).any(), (
"`max` is infinity. You forgot to initialize with `stats` as argument, or called "
"`policy.load_state_dict`."
)
assert not torch.isinf(min).any(), _no_stats_error_str("min")
assert not torch.isinf(max).any(), _no_stats_error_str("max")
batch[key] = (batch[key] + 1) / 2
batch[key] = batch[key] * (max - min) + min
else:

View File

@ -14,7 +14,10 @@ from torch import Tensor
@runtime_checkable
class Policy(Protocol):
"""The required interface for implementing a policy."""
"""The required interface for implementing a policy.
We also expect all policies to subclass torch.nn.Module and PyTorchModelHubMixin.
"""
name: str
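Since `Policy` is a `runtime_checkable` Protocol, conformance is structural. A sketch of what that buys (assuming `ACTPolicy` satisfies the protocol, as intended in this repo):
```python
from lerobot.common.policies.act.modeling_act import ACTPolicy
from lerobot.common.policies.policy_protocol import Policy

policy = ACTPolicy()
# With a runtime-checkable Protocol, isinstance only verifies that the required
# attributes/methods exist on the instance; ACTPolicy never subclasses Policy.
assert isinstance(policy, Policy)
```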

View File

@ -34,8 +34,6 @@ eval:
policy:
name: act
pretrained_model_path:
# Input / output structure.
n_obs_steps: 1
chunk_size: 100 # chunk_size
@ -62,7 +60,7 @@ policy:
replace_final_stride_with_dilation: false
# Transformer layers.
pre_norm: false
d_model: 512
dim_model: 512
n_heads: 8
dim_feedforward: 3200
feedforward_activation: relu

View File

@ -46,8 +46,6 @@ override_dataset_stats:
policy:
name: diffusion
pretrained_model_path:
# Input / output structure.
n_obs_steps: 2
horizon: 16

View File

@ -1,30 +1,29 @@
"""Evaluate a policy on an environment by running rollouts and computing metrics.
The script may be run in one of two ways:
Usage examples:
1. By providing the path to a config file with the --config argument.
2. By providing a HuggingFace Hub ID with the --hub-id argument. You may also provide a revision number with the
--revision argument.
You want to evaluate a model from the hub (eg: https://huggingface.co/lerobot/diffusion_policy_pusht_image)
for 10 episodes.
In either case, it is possible to override config arguments by adding a list of config.key=value arguments.
```
python lerobot/scripts/eval.py -p lerobot/diffusion_policy_pusht_image eval.n_episodes=10
```
Examples:
You have a specific config file to go with trained model weights, and want to run 10 episodes.
OR, you want to evaluate a model checkpoint from the LeRobot training script for 10 episodes.
```
python lerobot/scripts/eval.py \
--config PATH/TO/FOLDER/config.yaml \
policy.pretrained_model_path=PATH/TO/FOLDER/weights.pth \
eval.n_episodes=10
-p outputs/train/diffusion_policy_pusht_image/checkpoints/005000 \
eval.n_episodes=10
```
You have a HuggingFace Hub ID, you know which revision you want, and want to run 10 episodes (note that in this case,
you don't need to specify which weights to use):
Note that in both examples, the repo/folder should contain at least `config.json`, `config.yaml` and
`model.safetensors`.
```
python lerobot/scripts/eval.py --hub-id HUB/ID --revision v1.0 eval.n_episodes=10
```
Note the formatting for providing the number of episodes. Generally, you may provide any number of arguments
with `qualified.parameter.name=value`. In this case, the parameter eval.n_episodes appears as `n_episodes`
nested under `eval` in the `config.yaml` found at
https://huggingface.co/lerobot/diffusion_policy_pusht_image/tree/main.
"""
import argparse
@ -42,9 +41,12 @@ import numpy as np
import torch
from datasets import Dataset, Features, Image, Sequence, Value
from huggingface_hub import snapshot_download
from huggingface_hub.utils._errors import RepositoryNotFoundError
from huggingface_hub.utils._validators import HFValidationError
from PIL import Image as PILImage
from tqdm import trange
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.utils import hf_transform_to_torch
from lerobot.common.envs.factory import make_env
from lerobot.common.envs.utils import postprocess_action, preprocess_observation
@ -349,26 +351,41 @@ def eval_policy(
return info
def eval(cfg: dict, out_dir=None):
def eval(
pretrained_policy_path: str | None = None,
hydra_cfg_path: str | None = None,
config_overrides: list[str] | None = None,
):
assert (pretrained_policy_path is None) ^ (hydra_cfg_path is None)
if hydra_cfg_path is None:
hydra_cfg = init_hydra_config(pretrained_policy_path / "config.yaml", config_overrides)
else:
hydra_cfg = init_hydra_config(hydra_cfg_path, config_overrides)
out_dir = (
f"outputs/eval/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{hydra_cfg.env.name}_{hydra_cfg.policy.name}"
)
if out_dir is None:
raise NotImplementedError()
init_logging()
# Check device is available
get_safe_torch_device(cfg.device, log=True)
get_safe_torch_device(hydra_cfg.device, log=True)
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
set_global_seed(cfg.seed)
set_global_seed(hydra_cfg.seed)
log_output_dir(out_dir)
logging.info("Making environment.")
env = make_env(cfg, num_parallel_envs=cfg.eval.n_episodes)
env = make_env(hydra_cfg, num_parallel_envs=hydra_cfg.eval.n_episodes)
logging.info("Making policy.")
policy = make_policy(cfg)
if hydra_cfg_path is None:
policy = make_policy(hydra_cfg=hydra_cfg, pretrained_policy_name_or_path=pretrained_policy_path)
else:
# Note: We need the dataset stats to pass to the policy's normalization modules.
policy = make_policy(hydra_cfg=hydra_cfg, dataset_stats=make_dataset(hydra_cfg).stats)
info = eval_policy(
env,
@ -376,7 +393,7 @@ def eval(cfg: dict, out_dir=None):
max_episodes_rendered=10,
video_dir=Path(out_dir) / "eval",
return_episode_data=False,
seed=cfg.seed,
seed=hydra_cfg.seed,
)
print(info["aggregated"])
@ -390,13 +407,29 @@ def eval(cfg: dict, out_dir=None):
if __name__ == "__main__":
init_logging()
parser = argparse.ArgumentParser(
description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
)
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument("--config", help="Path to a specific yaml config you want to use.")
group.add_argument("--hub-id", help="HuggingFace Hub ID for a pretrained model.")
parser.add_argument("--revision", help="Optionally provide the HuggingFace Hub revision ID.")
group.add_argument(
"-p",
"--pretrained-policy-name-or-path",
help=(
"Either the repo ID of a model hosted on the Hub or a path to a directory containing weights "
"saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch "
"(useful for debugging). This argument is mutually exclusive with `--config`."
),
)
group.add_argument(
"--config",
help=(
"Path to a yaml config you want to use for initializing a policy from scratch (useful for "
"debugging). This argument is mutually exclusive with `--pretrained-policy-name-or-path` (`-p`)."
),
)
parser.add_argument("--revision", help="Optionally provide the Hugging Face Hub revision ID.")
parser.add_argument(
"overrides",
nargs="*",
@ -404,16 +437,28 @@ if __name__ == "__main__":
)
args = parser.parse_args()
if args.config is not None:
# Note: For the config_path, Hydra wants a path relative to this script file.
cfg = init_hydra_config(args.config, args.overrides)
elif args.hub_id is not None:
folder = Path(snapshot_download(args.hub_id, revision=args.revision))
cfg = init_hydra_config(
folder / "config.yaml", [f"policy.pretrained_model_path={folder / 'model.pt'}", *args.overrides]
)
if args.pretrained_policy_name_or_path is None:
eval(hydra_cfg_path=args.config, config_overrides=args.overrides)
else:
try:
pretrained_policy_path = Path(
snapshot_download(args.pretrained_policy_name_or_path, revision=args.revision)
)
except HFValidationError:
logging.warning(
"The provided pretrained_policy_name_or_path is not a valid Hugging Face Hub repo ID. "
"Treating it as a local directory."
)
except RepositoryNotFoundError:
logging.warning(
"The provided pretrained_policy_name_or_path was not found on the Hugging Face Hub. Treating "
"it as a local directory."
)
pretrained_policy_path = Path(args.pretrained_policy_name_or_path)
if not pretrained_policy_path.is_dir() or not pretrained_policy_path.exists():
raise ValueError(
"The provided pretrained_policy_name_or_path is not a valid/existing Hugging Face Hub "
"repo ID, nor is it an existing local directory."
)
eval(
cfg,
out_dir=f"outputs/eval/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{cfg.env.name}_{cfg.policy.name}",
)
eval(pretrained_policy_path=pretrained_policy_path, config_overrides=args.overrides)

View File

@ -265,7 +265,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
env = make_env(cfg, num_parallel_envs=cfg.eval.n_episodes)
logging.info("make_policy")
policy = make_policy(cfg, dataset_stats=offline_dataset.stats)
policy = make_policy(hydra_cfg=cfg, dataset_stats=offline_dataset.stats)
# Create optimizer and scheduler
# Temporary hack to move optimizer out of policy
@ -340,7 +340,14 @@ def train(cfg: dict, out_dir=None, job_name=None):
if cfg.training.save_model and step % cfg.training.save_freq == 0:
logging.info(f"Checkpoint policy after step {step}")
logger.save_model(policy, identifier=step)
# Note: Save with step as the identifier, and format it to have at least 6 digits but more if
# needed (choose 6 as a minimum for consistency without being overkill).
logger.save_model(
policy,
identifier=str(step).zfill(
max(6, len(str(cfg.training.offline_steps + cfg.training.online_steps)))
),
)
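For illustration, the padding behaves like this (step counts are made up):
```python
offline_steps, online_steps = 80_000, 20_000  # made-up totals
width = max(6, len(str(offline_steps + online_steps)))  # -> 6
print(str(5_000).zfill(width))    # "005000"
print(str(100_000).zfill(width))  # "100000"
```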
logging.info("Resume training")
# create dataloader for offline training

poetry.lock (generated, 296 lines changed)
View File

@ -826,13 +826,13 @@ files = [
[[package]]
name = "filelock"
version = "3.13.4"
version = "3.14.0"
description = "A platform independent file lock."
optional = false
python-versions = ">=3.8"
files = [
{file = "filelock-3.13.4-py3-none-any.whl", hash = "sha256:404e5e9253aa60ad457cae1be07c0f0ca90a63931200a47d9b6a6af84fd7b45f"},
{file = "filelock-3.13.4.tar.gz", hash = "sha256:d13f466618bfde72bd2c18255e269f72542c6e70e7bac83a0232d6b1cc5c8cf4"},
{file = "filelock-3.14.0-py3-none-any.whl", hash = "sha256:43339835842f110ca7ae60f1e1c160714c5a6afd15a2873419ab185334975c0f"},
{file = "filelock-3.14.0.tar.gz", hash = "sha256:6ea72da3be9b8c82afd3edcf99f2fffbb5076335a5ae4d03248bb5b6c3eae78a"},
]
[package.extras]
@ -1039,69 +1039,61 @@ preview = ["glfw-preview"]
[[package]]
name = "grpcio"
version = "1.62.2"
version = "1.63.0"
description = "HTTP/2-based RPC framework"
optional = false
python-versions = ">=3.7"
python-versions = ">=3.8"
files = [
{file = "grpcio-1.62.2-cp310-cp310-linux_armv7l.whl", hash = "sha256:66344ea741124c38588a664237ac2fa16dfd226964cca23ddc96bd4accccbde5"},
{file = "grpcio-1.62.2-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:5dab7ac2c1e7cb6179c6bfad6b63174851102cbe0682294e6b1d6f0981ad7138"},
{file = "grpcio-1.62.2-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:3ad00f3f0718894749d5a8bb0fa125a7980a2f49523731a9b1fabf2b3522aa43"},
{file = "grpcio-1.62.2-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2e72ddfee62430ea80133d2cbe788e0d06b12f865765cb24a40009668bd8ea05"},
{file = "grpcio-1.62.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:53d3a59a10af4c2558a8e563aed9f256259d2992ae0d3037817b2155f0341de1"},
{file = "grpcio-1.62.2-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:a1511a303f8074f67af4119275b4f954189e8313541da7b88b1b3a71425cdb10"},
{file = "grpcio-1.62.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:b94d41b7412ef149743fbc3178e59d95228a7064c5ab4760ae82b562bdffb199"},
{file = "grpcio-1.62.2-cp310-cp310-win32.whl", hash = "sha256:a75af2fc7cb1fe25785be7bed1ab18cef959a376cdae7c6870184307614caa3f"},
{file = "grpcio-1.62.2-cp310-cp310-win_amd64.whl", hash = "sha256:80407bc007754f108dc2061e37480238b0dc1952c855e86a4fc283501ee6bb5d"},
{file = "grpcio-1.62.2-cp311-cp311-linux_armv7l.whl", hash = "sha256:c1624aa686d4b36790ed1c2e2306cc3498778dffaf7b8dd47066cf819028c3ad"},
{file = "grpcio-1.62.2-cp311-cp311-macosx_10_10_universal2.whl", hash = "sha256:1c1bb80299bdef33309dff03932264636450c8fdb142ea39f47e06a7153d3063"},
{file = "grpcio-1.62.2-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:db068bbc9b1fa16479a82e1ecf172a93874540cb84be69f0b9cb9b7ac3c82670"},
{file = "grpcio-1.62.2-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e2cc8a308780edbe2c4913d6a49dbdb5befacdf72d489a368566be44cadaef1a"},
{file = "grpcio-1.62.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d0695ae31a89f1a8fc8256050329a91a9995b549a88619263a594ca31b76d756"},
{file = "grpcio-1.62.2-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:88b4f9ee77191dcdd8810241e89340a12cbe050be3e0d5f2f091c15571cd3930"},
{file = "grpcio-1.62.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:2a0204532aa2f1afd467024b02b4069246320405bc18abec7babab03e2644e75"},
{file = "grpcio-1.62.2-cp311-cp311-win32.whl", hash = "sha256:6e784f60e575a0de554ef9251cbc2ceb8790914fe324f11e28450047f264ee6f"},
{file = "grpcio-1.62.2-cp311-cp311-win_amd64.whl", hash = "sha256:112eaa7865dd9e6d7c0556c8b04ae3c3a2dc35d62ad3373ab7f6a562d8199200"},
{file = "grpcio-1.62.2-cp312-cp312-linux_armv7l.whl", hash = "sha256:65034473fc09628a02fb85f26e73885cf1ed39ebd9cf270247b38689ff5942c5"},
{file = "grpcio-1.62.2-cp312-cp312-macosx_10_10_universal2.whl", hash = "sha256:d2c1771d0ee3cf72d69bb5e82c6a82f27fbd504c8c782575eddb7839729fbaad"},
{file = "grpcio-1.62.2-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:3abe6838196da518863b5d549938ce3159d809218936851b395b09cad9b5d64a"},
{file = "grpcio-1.62.2-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c5ffeb269f10cedb4f33142b89a061acda9f672fd1357331dbfd043422c94e9e"},
{file = "grpcio-1.62.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:404d3b4b6b142b99ba1cff0b2177d26b623101ea2ce51c25ef6e53d9d0d87bcc"},
{file = "grpcio-1.62.2-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:262cda97efdabb20853d3b5a4c546a535347c14b64c017f628ca0cc7fa780cc6"},
{file = "grpcio-1.62.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:17708db5b11b966373e21519c4c73e5a750555f02fde82276ea2a267077c68ad"},
{file = "grpcio-1.62.2-cp312-cp312-win32.whl", hash = "sha256:b7ec9e2f8ffc8436f6b642a10019fc513722858f295f7efc28de135d336ac189"},
{file = "grpcio-1.62.2-cp312-cp312-win_amd64.whl", hash = "sha256:aa787b83a3cd5e482e5c79be030e2b4a122ecc6c5c6c4c42a023a2b581fdf17b"},
{file = "grpcio-1.62.2-cp37-cp37m-linux_armv7l.whl", hash = "sha256:cfd23ad29bfa13fd4188433b0e250f84ec2c8ba66b14a9877e8bce05b524cf54"},
{file = "grpcio-1.62.2-cp37-cp37m-macosx_10_10_universal2.whl", hash = "sha256:af15e9efa4d776dfcecd1d083f3ccfb04f876d613e90ef8432432efbeeac689d"},
{file = "grpcio-1.62.2-cp37-cp37m-manylinux_2_17_aarch64.whl", hash = "sha256:f4aa94361bb5141a45ca9187464ae81a92a2a135ce2800b2203134f7a1a1d479"},
{file = "grpcio-1.62.2-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:82af3613a219512a28ee5c95578eb38d44dd03bca02fd918aa05603c41018051"},
{file = "grpcio-1.62.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:55ddaf53474e8caeb29eb03e3202f9d827ad3110475a21245f3c7712022882a9"},
{file = "grpcio-1.62.2-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:c79b518c56dddeec79e5500a53d8a4db90da995dfe1738c3ac57fe46348be049"},
{file = "grpcio-1.62.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:a5eb4844e5e60bf2c446ef38c5b40d7752c6effdee882f716eb57ae87255d20a"},
{file = "grpcio-1.62.2-cp37-cp37m-win_amd64.whl", hash = "sha256:aaae70364a2d1fb238afd6cc9fcb10442b66e397fd559d3f0968d28cc3ac929c"},
{file = "grpcio-1.62.2-cp38-cp38-linux_armv7l.whl", hash = "sha256:1bcfe5070e4406f489e39325b76caeadab28c32bf9252d3ae960c79935a4cc36"},
{file = "grpcio-1.62.2-cp38-cp38-macosx_10_10_universal2.whl", hash = "sha256:da6a7b6b938c15fa0f0568e482efaae9c3af31963eec2da4ff13a6d8ec2888e4"},
{file = "grpcio-1.62.2-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:41955b641c34db7d84db8d306937b72bc4968eef1c401bea73081a8d6c3d8033"},
{file = "grpcio-1.62.2-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c772f225483905f675cb36a025969eef9712f4698364ecd3a63093760deea1bc"},
{file = "grpcio-1.62.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:07ce1f775d37ca18c7a141300e5b71539690efa1f51fe17f812ca85b5e73262f"},
{file = "grpcio-1.62.2-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:26f415f40f4a93579fd648f48dca1c13dfacdfd0290f4a30f9b9aeb745026811"},
{file = "grpcio-1.62.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:db707e3685ff16fc1eccad68527d072ac8bdd2e390f6daa97bc394ea7de4acea"},
{file = "grpcio-1.62.2-cp38-cp38-win32.whl", hash = "sha256:589ea8e75de5fd6df387de53af6c9189c5231e212b9aa306b6b0d4f07520fbb9"},
{file = "grpcio-1.62.2-cp38-cp38-win_amd64.whl", hash = "sha256:3c3ed41f4d7a3aabf0f01ecc70d6b5d00ce1800d4af652a549de3f7cf35c4abd"},
{file = "grpcio-1.62.2-cp39-cp39-linux_armv7l.whl", hash = "sha256:162ccf61499c893831b8437120600290a99c0bc1ce7b51f2c8d21ec87ff6af8b"},
{file = "grpcio-1.62.2-cp39-cp39-macosx_10_10_universal2.whl", hash = "sha256:f27246d7da7d7e3bd8612f63785a7b0c39a244cf14b8dd9dd2f2fab939f2d7f1"},
{file = "grpcio-1.62.2-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:2507006c8a478f19e99b6fe36a2464696b89d40d88f34e4b709abe57e1337467"},
{file = "grpcio-1.62.2-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a90ac47a8ce934e2c8d71e317d2f9e7e6aaceb2d199de940ce2c2eb611b8c0f4"},
{file = "grpcio-1.62.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:99701979bcaaa7de8d5f60476487c5df8f27483624f1f7e300ff4669ee44d1f2"},
{file = "grpcio-1.62.2-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:af7dc3f7a44f10863b1b0ecab4078f0a00f561aae1edbd01fd03ad4dcf61c9e9"},
{file = "grpcio-1.62.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:fa63245271920786f4cb44dcada4983a3516be8f470924528cf658731864c14b"},
{file = "grpcio-1.62.2-cp39-cp39-win32.whl", hash = "sha256:c6ad9c39704256ed91a1cffc1379d63f7d0278d6a0bad06b0330f5d30291e3a3"},
{file = "grpcio-1.62.2-cp39-cp39-win_amd64.whl", hash = "sha256:16da954692fd61aa4941fbeda405a756cd96b97b5d95ca58a92547bba2c1624f"},
{file = "grpcio-1.62.2.tar.gz", hash = "sha256:c77618071d96b7a8be2c10701a98537823b9c65ba256c0b9067e0594cdbd954d"},
{file = "grpcio-1.63.0-cp310-cp310-linux_armv7l.whl", hash = "sha256:2e93aca840c29d4ab5db93f94ed0a0ca899e241f2e8aec6334ab3575dc46125c"},
{file = "grpcio-1.63.0-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:91b73d3f1340fefa1e1716c8c1ec9930c676d6b10a3513ab6c26004cb02d8b3f"},
{file = "grpcio-1.63.0-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:b3afbd9d6827fa6f475a4f91db55e441113f6d3eb9b7ebb8fb806e5bb6d6bd0d"},
{file = "grpcio-1.63.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8f3f6883ce54a7a5f47db43289a0a4c776487912de1a0e2cc83fdaec9685cc9f"},
{file = "grpcio-1.63.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cf8dae9cc0412cb86c8de5a8f3be395c5119a370f3ce2e69c8b7d46bb9872c8d"},
{file = "grpcio-1.63.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:08e1559fd3b3b4468486b26b0af64a3904a8dbc78d8d936af9c1cf9636eb3e8b"},
{file = "grpcio-1.63.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:5c039ef01516039fa39da8a8a43a95b64e288f79f42a17e6c2904a02a319b357"},
{file = "grpcio-1.63.0-cp310-cp310-win32.whl", hash = "sha256:ad2ac8903b2eae071055a927ef74121ed52d69468e91d9bcbd028bd0e554be6d"},
{file = "grpcio-1.63.0-cp310-cp310-win_amd64.whl", hash = "sha256:b2e44f59316716532a993ca2966636df6fbe7be4ab6f099de6815570ebe4383a"},
{file = "grpcio-1.63.0-cp311-cp311-linux_armv7l.whl", hash = "sha256:f28f8b2db7b86c77916829d64ab21ff49a9d8289ea1564a2b2a3a8ed9ffcccd3"},
{file = "grpcio-1.63.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:65bf975639a1f93bee63ca60d2e4951f1b543f498d581869922910a476ead2f5"},
{file = "grpcio-1.63.0-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:b5194775fec7dc3dbd6a935102bb156cd2c35efe1685b0a46c67b927c74f0cfb"},
{file = "grpcio-1.63.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e4cbb2100ee46d024c45920d16e888ee5d3cf47c66e316210bc236d5bebc42b3"},
{file = "grpcio-1.63.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1ff737cf29b5b801619f10e59b581869e32f400159e8b12d7a97e7e3bdeee6a2"},
{file = "grpcio-1.63.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:cd1e68776262dd44dedd7381b1a0ad09d9930ffb405f737d64f505eb7f77d6c7"},
{file = "grpcio-1.63.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:93f45f27f516548e23e4ec3fbab21b060416007dbe768a111fc4611464cc773f"},
{file = "grpcio-1.63.0-cp311-cp311-win32.whl", hash = "sha256:878b1d88d0137df60e6b09b74cdb73db123f9579232c8456f53e9abc4f62eb3c"},
{file = "grpcio-1.63.0-cp311-cp311-win_amd64.whl", hash = "sha256:756fed02dacd24e8f488f295a913f250b56b98fb793f41d5b2de6c44fb762434"},
{file = "grpcio-1.63.0-cp312-cp312-linux_armv7l.whl", hash = "sha256:93a46794cc96c3a674cdfb59ef9ce84d46185fe9421baf2268ccb556f8f81f57"},
{file = "grpcio-1.63.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:a7b19dfc74d0be7032ca1eda0ed545e582ee46cd65c162f9e9fc6b26ef827dc6"},
{file = "grpcio-1.63.0-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:8064d986d3a64ba21e498b9a376cbc5d6ab2e8ab0e288d39f266f0fca169b90d"},
{file = "grpcio-1.63.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:219bb1848cd2c90348c79ed0a6b0ea51866bc7e72fa6e205e459fedab5770172"},
{file = "grpcio-1.63.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a2d60cd1d58817bc5985fae6168d8b5655c4981d448d0f5b6194bbcc038090d2"},
{file = "grpcio-1.63.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:9e350cb096e5c67832e9b6e018cf8a0d2a53b2a958f6251615173165269a91b0"},
{file = "grpcio-1.63.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:56cdf96ff82e3cc90dbe8bac260352993f23e8e256e063c327b6cf9c88daf7a9"},
{file = "grpcio-1.63.0-cp312-cp312-win32.whl", hash = "sha256:3a6d1f9ea965e750db7b4ee6f9fdef5fdf135abe8a249e75d84b0a3e0c668a1b"},
{file = "grpcio-1.63.0-cp312-cp312-win_amd64.whl", hash = "sha256:d2497769895bb03efe3187fb1888fc20e98a5f18b3d14b606167dacda5789434"},
{file = "grpcio-1.63.0-cp38-cp38-linux_armv7l.whl", hash = "sha256:fdf348ae69c6ff484402cfdb14e18c1b0054ac2420079d575c53a60b9b2853ae"},
{file = "grpcio-1.63.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:a3abfe0b0f6798dedd2e9e92e881d9acd0fdb62ae27dcbbfa7654a57e24060c0"},
{file = "grpcio-1.63.0-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:6ef0ad92873672a2a3767cb827b64741c363ebaa27e7f21659e4e31f4d750280"},
{file = "grpcio-1.63.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b416252ac5588d9dfb8a30a191451adbf534e9ce5f56bb02cd193f12d8845b7f"},
{file = "grpcio-1.63.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e3b77eaefc74d7eb861d3ffbdf91b50a1bb1639514ebe764c47773b833fa2d91"},
{file = "grpcio-1.63.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:b005292369d9c1f80bf70c1db1c17c6c342da7576f1c689e8eee4fb0c256af85"},
{file = "grpcio-1.63.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:cdcda1156dcc41e042d1e899ba1f5c2e9f3cd7625b3d6ebfa619806a4c1aadda"},
{file = "grpcio-1.63.0-cp38-cp38-win32.whl", hash = "sha256:01799e8649f9e94ba7db1aeb3452188048b0019dc37696b0f5ce212c87c560c3"},
{file = "grpcio-1.63.0-cp38-cp38-win_amd64.whl", hash = "sha256:6a1a3642d76f887aa4009d92f71eb37809abceb3b7b5a1eec9c554a246f20e3a"},
{file = "grpcio-1.63.0-cp39-cp39-linux_armv7l.whl", hash = "sha256:75f701ff645858a2b16bc8c9fc68af215a8bb2d5a9b647448129de6e85d52bce"},
{file = "grpcio-1.63.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:cacdef0348a08e475a721967f48206a2254a1b26ee7637638d9e081761a5ba86"},
{file = "grpcio-1.63.0-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:0697563d1d84d6985e40ec5ec596ff41b52abb3fd91ec240e8cb44a63b895094"},
{file = "grpcio-1.63.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6426e1fb92d006e47476d42b8f240c1d916a6d4423c5258ccc5b105e43438f61"},
{file = "grpcio-1.63.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e48cee31bc5f5a31fb2f3b573764bd563aaa5472342860edcc7039525b53e46a"},
{file = "grpcio-1.63.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:50344663068041b34a992c19c600236e7abb42d6ec32567916b87b4c8b8833b3"},
{file = "grpcio-1.63.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:259e11932230d70ef24a21b9fb5bb947eb4703f57865a404054400ee92f42f5d"},
{file = "grpcio-1.63.0-cp39-cp39-win32.whl", hash = "sha256:a44624aad77bf8ca198c55af811fd28f2b3eaf0a50ec5b57b06c034416ef2d0a"},
{file = "grpcio-1.63.0-cp39-cp39-win_amd64.whl", hash = "sha256:166e5c460e5d7d4656ff9e63b13e1f6029b122104c1633d5f37eaea348d7356d"},
{file = "grpcio-1.63.0.tar.gz", hash = "sha256:f3023e14805c61bc439fb40ca545ac3d5740ce66120a678a3c6c2c55b70343d1"},
]
[package.extras]
protobuf = ["grpcio-tools (>=1.62.2)"]
protobuf = ["grpcio-tools (>=1.63.0)"]
[[package]]
name = "gym-aloha"
@ -2379,7 +2371,6 @@ optional = false
python-versions = ">=3.9"
files = [
{file = "pandas-2.2.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:90c6fca2acf139569e74e8781709dccb6fe25940488755716d1d354d6bc58bce"},
{file = "pandas-2.2.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c7adfc142dac335d8c1e0dcbd37eb8617eac386596eb9e1a1b77791cf2498238"},
{file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4abfe0be0d7221be4f12552995e58723c7422c80a659da13ca382697de830c08"},
{file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8635c16bf3d99040fdf3ca3db669a7250ddf49c55dc4aa8fe0ae0fa8d6dcc1f0"},
{file = "pandas-2.2.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:40ae1dffb3967a52203105a077415a86044a2bea011b5f321c6aa64b379a3f51"},
@ -2400,7 +2391,6 @@ files = [
{file = "pandas-2.2.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:43498c0bdb43d55cb162cdc8c06fac328ccb5d2eabe3cadeb3529ae6f0517c32"},
{file = "pandas-2.2.2-cp312-cp312-win_amd64.whl", hash = "sha256:d187d355ecec3629624fccb01d104da7d7f391db0311145817525281e2804d23"},
{file = "pandas-2.2.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:0ca6377b8fca51815f382bd0b697a0814c8bda55115678cbc94c30aacbb6eff2"},
{file = "pandas-2.2.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9057e6aa78a584bc93a13f0a9bf7e753a5e9770a30b4d758b8d5f2a62a9433cd"},
{file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:001910ad31abc7bf06f49dcc903755d2f7f3a9186c0c040b827e522e9cef0863"},
{file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:66b479b0bd07204e37583c191535505410daa8df638fd8e75ae1b383851fe921"},
{file = "pandas-2.2.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:a77e9d1c386196879aa5eb712e77461aaee433e54c68cf253053a73b7e49c33a"},
@ -3023,104 +3013,90 @@ files = [
[[package]]
name = "regex"
version = "2024.4.16"
version = "2024.4.28"
description = "Alternative regular expression module, to replace re."
optional = false
python-versions = ">=3.7"
python-versions = ">=3.8"
files = [
{file = "regex-2024.4.16-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:fb83cc090eac63c006871fd24db5e30a1f282faa46328572661c0a24a2323a08"},
{file = "regex-2024.4.16-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8c91e1763696c0eb66340c4df98623c2d4e77d0746b8f8f2bee2c6883fd1fe18"},
{file = "regex-2024.4.16-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:10188fe732dec829c7acca7422cdd1bf57d853c7199d5a9e96bb4d40db239c73"},
{file = "regex-2024.4.16-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:956b58d692f235cfbf5b4f3abd6d99bf102f161ccfe20d2fd0904f51c72c4c66"},
{file = "regex-2024.4.16-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a70b51f55fd954d1f194271695821dd62054d949efd6368d8be64edd37f55c86"},
{file = "regex-2024.4.16-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5c02fcd2bf45162280613d2e4a1ca3ac558ff921ae4e308ecb307650d3a6ee51"},
{file = "regex-2024.4.16-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c4ed75ea6892a56896d78f11006161eea52c45a14994794bcfa1654430984b22"},
{file = "regex-2024.4.16-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bd727ad276bb91928879f3aa6396c9a1d34e5e180dce40578421a691eeb77f47"},
{file = "regex-2024.4.16-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:7cbc5d9e8a1781e7be17da67b92580d6ce4dcef5819c1b1b89f49d9678cc278c"},
{file = "regex-2024.4.16-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:78fddb22b9ef810b63ef341c9fcf6455232d97cfe03938cbc29e2672c436670e"},
{file = "regex-2024.4.16-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:445ca8d3c5a01309633a0c9db57150312a181146315693273e35d936472df912"},
{file = "regex-2024.4.16-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:95399831a206211d6bc40224af1c635cb8790ddd5c7493e0bd03b85711076a53"},
{file = "regex-2024.4.16-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:7731728b6568fc286d86745f27f07266de49603a6fdc4d19c87e8c247be452af"},
{file = "regex-2024.4.16-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:4facc913e10bdba42ec0aee76d029aedda628161a7ce4116b16680a0413f658a"},
{file = "regex-2024.4.16-cp310-cp310-win32.whl", hash = "sha256:911742856ce98d879acbea33fcc03c1d8dc1106234c5e7d068932c945db209c0"},
{file = "regex-2024.4.16-cp310-cp310-win_amd64.whl", hash = "sha256:e0a2df336d1135a0b3a67f3bbf78a75f69562c1199ed9935372b82215cddd6e2"},
{file = "regex-2024.4.16-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:1210365faba7c2150451eb78ec5687871c796b0f1fa701bfd2a4a25420482d26"},
{file = "regex-2024.4.16-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:9ab40412f8cd6f615bfedea40c8bf0407d41bf83b96f6fc9ff34976d6b7037fd"},
{file = "regex-2024.4.16-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:fd80d1280d473500d8086d104962a82d77bfbf2b118053824b7be28cd5a79ea5"},
{file = "regex-2024.4.16-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7bb966fdd9217e53abf824f437a5a2d643a38d4fd5fd0ca711b9da683d452969"},
{file = "regex-2024.4.16-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:20b7a68444f536365af42a75ccecb7ab41a896a04acf58432db9e206f4e525d6"},
{file = "regex-2024.4.16-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b74586dd0b039c62416034f811d7ee62810174bb70dffcca6439f5236249eb09"},
{file = "regex-2024.4.16-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0c8290b44d8b0af4e77048646c10c6e3aa583c1ca67f3b5ffb6e06cf0c6f0f89"},
{file = "regex-2024.4.16-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f2d80a6749724b37853ece57988b39c4e79d2b5fe2869a86e8aeae3bbeef9eb0"},
{file = "regex-2024.4.16-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:3a1018e97aeb24e4f939afcd88211ace472ba566efc5bdf53fd8fd7f41fa7170"},
{file = "regex-2024.4.16-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:8d015604ee6204e76569d2f44e5a210728fa917115bef0d102f4107e622b08d5"},
{file = "regex-2024.4.16-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:3d5ac5234fb5053850d79dd8eb1015cb0d7d9ed951fa37aa9e6249a19aa4f336"},
{file = "regex-2024.4.16-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:0a38d151e2cdd66d16dab550c22f9521ba79761423b87c01dae0a6e9add79c0d"},
{file = "regex-2024.4.16-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:159dc4e59a159cb8e4e8f8961eb1fa5d58f93cb1acd1701d8aff38d45e1a84a6"},
{file = "regex-2024.4.16-cp311-cp311-win32.whl", hash = "sha256:ba2336d6548dee3117520545cfe44dc28a250aa091f8281d28804aa8d707d93d"},
{file = "regex-2024.4.16-cp311-cp311-win_amd64.whl", hash = "sha256:8f83b6fd3dc3ba94d2b22717f9c8b8512354fd95221ac661784df2769ea9bba9"},
{file = "regex-2024.4.16-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:80b696e8972b81edf0af2a259e1b2a4a661f818fae22e5fa4fa1a995fb4a40fd"},
{file = "regex-2024.4.16-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:d61ae114d2a2311f61d90c2ef1358518e8f05eafda76eaf9c772a077e0b465ec"},
{file = "regex-2024.4.16-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8ba6745440b9a27336443b0c285d705ce73adb9ec90e2f2004c64d95ab5a7598"},
{file = "regex-2024.4.16-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6295004b2dd37b0835ea5c14a33e00e8cfa3c4add4d587b77287825f3418d310"},
{file = "regex-2024.4.16-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4aba818dcc7263852aabb172ec27b71d2abca02a593b95fa79351b2774eb1d2b"},
{file = "regex-2024.4.16-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d0800631e565c47520aaa04ae38b96abc5196fe8b4aa9bd864445bd2b5848a7a"},
{file = "regex-2024.4.16-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:08dea89f859c3df48a440dbdcd7b7155bc675f2fa2ec8c521d02dc69e877db70"},
{file = "regex-2024.4.16-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:eeaa0b5328b785abc344acc6241cffde50dc394a0644a968add75fcefe15b9d4"},
{file = "regex-2024.4.16-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:4e819a806420bc010489f4e741b3036071aba209f2e0989d4750b08b12a9343f"},
{file = "regex-2024.4.16-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:c2d0e7cbb6341e830adcbfa2479fdeebbfbb328f11edd6b5675674e7a1e37730"},
{file = "regex-2024.4.16-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:91797b98f5e34b6a49f54be33f72e2fb658018ae532be2f79f7c63b4ae225145"},
{file = "regex-2024.4.16-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:d2da13568eff02b30fd54fccd1e042a70fe920d816616fda4bf54ec705668d81"},
{file = "regex-2024.4.16-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:370c68dc5570b394cbaadff50e64d705f64debed30573e5c313c360689b6aadc"},
{file = "regex-2024.4.16-cp312-cp312-win32.whl", hash = "sha256:904c883cf10a975b02ab3478bce652f0f5346a2c28d0a8521d97bb23c323cc8b"},
{file = "regex-2024.4.16-cp312-cp312-win_amd64.whl", hash = "sha256:785c071c982dce54d44ea0b79cd6dfafddeccdd98cfa5f7b86ef69b381b457d9"},
{file = "regex-2024.4.16-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:e2f142b45c6fed48166faeb4303b4b58c9fcd827da63f4cf0a123c3480ae11fb"},
{file = "regex-2024.4.16-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e87ab229332ceb127a165612d839ab87795972102cb9830e5f12b8c9a5c1b508"},
{file = "regex-2024.4.16-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:81500ed5af2090b4a9157a59dbc89873a25c33db1bb9a8cf123837dcc9765047"},
{file = "regex-2024.4.16-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b340cccad138ecb363324aa26893963dcabb02bb25e440ebdf42e30963f1a4e0"},
{file = "regex-2024.4.16-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2c72608e70f053643437bd2be0608f7f1c46d4022e4104d76826f0839199347a"},
{file = "regex-2024.4.16-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a01fe2305e6232ef3e8f40bfc0f0f3a04def9aab514910fa4203bafbc0bb4682"},
{file = "regex-2024.4.16-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:03576e3a423d19dda13e55598f0fd507b5d660d42c51b02df4e0d97824fdcae3"},
{file = "regex-2024.4.16-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:549c3584993772e25f02d0656ac48abdda73169fe347263948cf2b1cead622f3"},
{file = "regex-2024.4.16-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:34422d5a69a60b7e9a07a690094e824b66f5ddc662a5fc600d65b7c174a05f04"},
{file = "regex-2024.4.16-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:5f580c651a72b75c39e311343fe6875d6f58cf51c471a97f15a938d9fe4e0d37"},
{file = "regex-2024.4.16-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:3399dd8a7495bbb2bacd59b84840eef9057826c664472e86c91d675d007137f5"},
{file = "regex-2024.4.16-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:8d1f86f3f4e2388aa3310b50694ac44daefbd1681def26b4519bd050a398dc5a"},
{file = "regex-2024.4.16-cp37-cp37m-win32.whl", hash = "sha256:dd5acc0a7d38fdc7a3a6fd3ad14c880819008ecb3379626e56b163165162cc46"},
{file = "regex-2024.4.16-cp37-cp37m-win_amd64.whl", hash = "sha256:ba8122e3bb94ecda29a8de4cf889f600171424ea586847aa92c334772d200331"},
{file = "regex-2024.4.16-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:743deffdf3b3481da32e8a96887e2aa945ec6685af1cfe2bcc292638c9ba2f48"},
{file = "regex-2024.4.16-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:7571f19f4a3fd00af9341c7801d1ad1967fc9c3f5e62402683047e7166b9f2b4"},
{file = "regex-2024.4.16-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:df79012ebf6f4efb8d307b1328226aef24ca446b3ff8d0e30202d7ebcb977a8c"},
{file = "regex-2024.4.16-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e757d475953269fbf4b441207bb7dbdd1c43180711b6208e129b637792ac0b93"},
{file = "regex-2024.4.16-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4313ab9bf6a81206c8ac28fdfcddc0435299dc88cad12cc6305fd0e78b81f9e4"},
{file = "regex-2024.4.16-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d83c2bc678453646f1a18f8db1e927a2d3f4935031b9ad8a76e56760461105dd"},
{file = "regex-2024.4.16-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9df1bfef97db938469ef0a7354b2d591a2d438bc497b2c489471bec0e6baf7c4"},
{file = "regex-2024.4.16-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:62120ed0de69b3649cc68e2965376048793f466c5a6c4370fb27c16c1beac22d"},
{file = "regex-2024.4.16-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:c2ef6f7990b6e8758fe48ad08f7e2f66c8f11dc66e24093304b87cae9037bb4a"},
{file = "regex-2024.4.16-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:8fc6976a3395fe4d1fbeb984adaa8ec652a1e12f36b56ec8c236e5117b585427"},
{file = "regex-2024.4.16-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:03e68f44340528111067cecf12721c3df4811c67268b897fbe695c95f860ac42"},
{file = "regex-2024.4.16-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:ec7e0043b91115f427998febaa2beb82c82df708168b35ece3accb610b91fac1"},
{file = "regex-2024.4.16-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:c21fc21a4c7480479d12fd8e679b699f744f76bb05f53a1d14182b31f55aac76"},
{file = "regex-2024.4.16-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:12f6a3f2f58bb7344751919a1876ee1b976fe08b9ffccb4bbea66f26af6017b9"},
{file = "regex-2024.4.16-cp38-cp38-win32.whl", hash = "sha256:479595a4fbe9ed8f8f72c59717e8cf222da2e4c07b6ae5b65411e6302af9708e"},
{file = "regex-2024.4.16-cp38-cp38-win_amd64.whl", hash = "sha256:0534b034fba6101611968fae8e856c1698da97ce2efb5c2b895fc8b9e23a5834"},
{file = "regex-2024.4.16-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:a7ccdd1c4a3472a7533b0a7aa9ee34c9a2bef859ba86deec07aff2ad7e0c3b94"},
{file = "regex-2024.4.16-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6f2f017c5be19984fbbf55f8af6caba25e62c71293213f044da3ada7091a4455"},
{file = "regex-2024.4.16-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:803b8905b52de78b173d3c1e83df0efb929621e7b7c5766c0843704d5332682f"},
{file = "regex-2024.4.16-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:684008ec44ad275832a5a152f6e764bbe1914bea10968017b6feaecdad5736e0"},
{file = "regex-2024.4.16-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:65436dce9fdc0aeeb0a0effe0839cb3d6a05f45aa45a4d9f9c60989beca78b9c"},
{file = "regex-2024.4.16-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ea355eb43b11764cf799dda62c658c4d2fdb16af41f59bb1ccfec517b60bcb07"},
{file = "regex-2024.4.16-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:98c1165f3809ce7774f05cb74e5408cd3aa93ee8573ae959a97a53db3ca3180d"},
{file = "regex-2024.4.16-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cccc79a9be9b64c881f18305a7c715ba199e471a3973faeb7ba84172abb3f317"},
{file = "regex-2024.4.16-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:00169caa125f35d1bca6045d65a662af0202704489fada95346cfa092ec23f39"},
{file = "regex-2024.4.16-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:6cc38067209354e16c5609b66285af17a2863a47585bcf75285cab33d4c3b8df"},
{file = "regex-2024.4.16-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:23cff1b267038501b179ccbbd74a821ac4a7192a1852d1d558e562b507d46013"},
{file = "regex-2024.4.16-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:b9d320b3bf82a39f248769fc7f188e00f93526cc0fe739cfa197868633d44701"},
{file = "regex-2024.4.16-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:89ec7f2c08937421bbbb8b48c54096fa4f88347946d4747021ad85f1b3021b3c"},
{file = "regex-2024.4.16-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:4918fd5f8b43aa7ec031e0fef1ee02deb80b6afd49c85f0790be1dc4ce34cb50"},
{file = "regex-2024.4.16-cp39-cp39-win32.whl", hash = "sha256:684e52023aec43bdf0250e843e1fdd6febbe831bd9d52da72333fa201aaa2335"},
{file = "regex-2024.4.16-cp39-cp39-win_amd64.whl", hash = "sha256:e697e1c0238133589e00c244a8b676bc2cfc3ab4961318d902040d099fec7483"},
{file = "regex-2024.4.16.tar.gz", hash = "sha256:fa454d26f2e87ad661c4f0c5a5fe4cf6aab1e307d1b94f16ffdfcb089ba685c0"},
{file = "regex-2024.4.28-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:cd196d056b40af073d95a2879678585f0b74ad35190fac04ca67954c582c6b61"},
{file = "regex-2024.4.28-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8bb381f777351bd534462f63e1c6afb10a7caa9fa2a421ae22c26e796fe31b1f"},
{file = "regex-2024.4.28-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:47af45b6153522733aa6e92543938e97a70ce0900649ba626cf5aad290b737b6"},
{file = "regex-2024.4.28-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:99d6a550425cc51c656331af0e2b1651e90eaaa23fb4acde577cf15068e2e20f"},
{file = "regex-2024.4.28-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bf29304a8011feb58913c382902fde3395957a47645bf848eea695839aa101b7"},
{file = "regex-2024.4.28-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:92da587eee39a52c91aebea8b850e4e4f095fe5928d415cb7ed656b3460ae79a"},
{file = "regex-2024.4.28-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6277d426e2f31bdbacb377d17a7475e32b2d7d1f02faaecc48d8e370c6a3ff31"},
{file = "regex-2024.4.28-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:28e1f28d07220c0f3da0e8fcd5a115bbb53f8b55cecf9bec0c946eb9a059a94c"},
{file = "regex-2024.4.28-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:aaa179975a64790c1f2701ac562b5eeb733946eeb036b5bcca05c8d928a62f10"},
{file = "regex-2024.4.28-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:6f435946b7bf7a1b438b4e6b149b947c837cb23c704e780c19ba3e6855dbbdd3"},
{file = "regex-2024.4.28-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:19d6c11bf35a6ad077eb23852827f91c804eeb71ecb85db4ee1386825b9dc4db"},
{file = "regex-2024.4.28-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:fdae0120cddc839eb8e3c15faa8ad541cc6d906d3eb24d82fb041cfe2807bc1e"},
{file = "regex-2024.4.28-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:e672cf9caaf669053121f1766d659a8813bd547edef6e009205378faf45c67b8"},
{file = "regex-2024.4.28-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:f57515750d07e14743db55d59759893fdb21d2668f39e549a7d6cad5d70f9fea"},
{file = "regex-2024.4.28-cp310-cp310-win32.whl", hash = "sha256:a1409c4eccb6981c7baabc8888d3550df518add6e06fe74fa1d9312c1838652d"},
{file = "regex-2024.4.28-cp310-cp310-win_amd64.whl", hash = "sha256:1f687a28640f763f23f8a9801fe9e1b37338bb1ca5d564ddd41619458f1f22d1"},
{file = "regex-2024.4.28-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:84077821c85f222362b72fdc44f7a3a13587a013a45cf14534df1cbbdc9a6796"},
{file = "regex-2024.4.28-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b45d4503de8f4f3dc02f1d28a9b039e5504a02cc18906cfe744c11def942e9eb"},
{file = "regex-2024.4.28-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:457c2cd5a646dd4ed536c92b535d73548fb8e216ebee602aa9f48e068fc393f3"},
{file = "regex-2024.4.28-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2b51739ddfd013c6f657b55a508de8b9ea78b56d22b236052c3a85a675102dc6"},
{file = "regex-2024.4.28-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:459226445c7d7454981c4c0ce0ad1a72e1e751c3e417f305722bbcee6697e06a"},
{file = "regex-2024.4.28-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:670fa596984b08a4a769491cbdf22350431970d0112e03d7e4eeaecaafcd0fec"},
{file = "regex-2024.4.28-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fe00f4fe11c8a521b173e6324d862ee7ee3412bf7107570c9b564fe1119b56fb"},
{file = "regex-2024.4.28-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:36f392dc7763fe7924575475736bddf9ab9f7a66b920932d0ea50c2ded2f5636"},
{file = "regex-2024.4.28-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:23a412b7b1a7063f81a742463f38821097b6a37ce1e5b89dd8e871d14dbfd86b"},
{file = "regex-2024.4.28-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:f1d6e4b7b2ae3a6a9df53efbf199e4bfcff0959dbdb5fd9ced34d4407348e39a"},
{file = "regex-2024.4.28-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:499334ad139557de97cbc4347ee921c0e2b5e9c0f009859e74f3f77918339257"},
{file = "regex-2024.4.28-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:0940038bec2fe9e26b203d636c44d31dd8766abc1fe66262da6484bd82461ccf"},
{file = "regex-2024.4.28-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:66372c2a01782c5fe8e04bff4a2a0121a9897e19223d9eab30c54c50b2ebeb7f"},
{file = "regex-2024.4.28-cp311-cp311-win32.whl", hash = "sha256:c77d10ec3c1cf328b2f501ca32583625987ea0f23a0c2a49b37a39ee5c4c4630"},
{file = "regex-2024.4.28-cp311-cp311-win_amd64.whl", hash = "sha256:fc0916c4295c64d6890a46e02d4482bb5ccf33bf1a824c0eaa9e83b148291f90"},
{file = "regex-2024.4.28-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:08a1749f04fee2811c7617fdd46d2e46d09106fa8f475c884b65c01326eb15c5"},
{file = "regex-2024.4.28-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b8eb28995771c087a73338f695a08c9abfdf723d185e57b97f6175c5051ff1ae"},
{file = "regex-2024.4.28-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:dd7ef715ccb8040954d44cfeff17e6b8e9f79c8019daae2fd30a8806ef5435c0"},
{file = "regex-2024.4.28-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fb0315a2b26fde4005a7c401707c5352df274460f2f85b209cf6024271373013"},
{file = "regex-2024.4.28-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f2fc053228a6bd3a17a9b0a3f15c3ab3cf95727b00557e92e1cfe094b88cc662"},
{file = "regex-2024.4.28-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7fe9739a686dc44733d52d6e4f7b9c77b285e49edf8570754b322bca6b85b4cc"},
{file = "regex-2024.4.28-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a74fcf77d979364f9b69fcf8200849ca29a374973dc193a7317698aa37d8b01c"},
{file = "regex-2024.4.28-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:965fd0cf4694d76f6564896b422724ec7b959ef927a7cb187fc6b3f4e4f59833"},
{file = "regex-2024.4.28-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:2fef0b38c34ae675fcbb1b5db760d40c3fc3612cfa186e9e50df5782cac02bcd"},
{file = "regex-2024.4.28-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:bc365ce25f6c7c5ed70e4bc674f9137f52b7dd6a125037f9132a7be52b8a252f"},
{file = "regex-2024.4.28-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:ac69b394764bb857429b031d29d9604842bc4cbfd964d764b1af1868eeebc4f0"},
{file = "regex-2024.4.28-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:144a1fc54765f5c5c36d6d4b073299832aa1ec6a746a6452c3ee7b46b3d3b11d"},
{file = "regex-2024.4.28-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:2630ca4e152c221072fd4a56d4622b5ada876f668ecd24d5ab62544ae6793ed6"},
{file = "regex-2024.4.28-cp312-cp312-win32.whl", hash = "sha256:7f3502f03b4da52bbe8ba962621daa846f38489cae5c4a7b5d738f15f6443d17"},
{file = "regex-2024.4.28-cp312-cp312-win_amd64.whl", hash = "sha256:0dd3f69098511e71880fb00f5815db9ed0ef62c05775395968299cb400aeab82"},
{file = "regex-2024.4.28-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:374f690e1dd0dbdcddea4a5c9bdd97632cf656c69113f7cd6a361f2a67221cb6"},
{file = "regex-2024.4.28-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:25f87ae6b96374db20f180eab083aafe419b194e96e4f282c40191e71980c666"},
{file = "regex-2024.4.28-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:5dbc1bcc7413eebe5f18196e22804a3be1bfdfc7e2afd415e12c068624d48247"},
{file = "regex-2024.4.28-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f85151ec5a232335f1be022b09fbbe459042ea1951d8a48fef251223fc67eee1"},
{file = "regex-2024.4.28-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:57ba112e5530530fd175ed550373eb263db4ca98b5f00694d73b18b9a02e7185"},
{file = "regex-2024.4.28-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:224803b74aab56aa7be313f92a8d9911dcade37e5f167db62a738d0c85fdac4b"},
{file = "regex-2024.4.28-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0a54a047b607fd2d2d52a05e6ad294602f1e0dec2291152b745870afc47c1397"},
{file = "regex-2024.4.28-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0a2a512d623f1f2d01d881513af9fc6a7c46e5cfffb7dc50c38ce959f9246c94"},
{file = "regex-2024.4.28-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:c06bf3f38f0707592898428636cbb75d0a846651b053a1cf748763e3063a6925"},
{file = "regex-2024.4.28-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:1031a5e7b048ee371ab3653aad3030ecfad6ee9ecdc85f0242c57751a05b0ac4"},
{file = "regex-2024.4.28-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:d7a353ebfa7154c871a35caca7bfd8f9e18666829a1dc187115b80e35a29393e"},
{file = "regex-2024.4.28-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:7e76b9cfbf5ced1aca15a0e5b6f229344d9b3123439ffce552b11faab0114a02"},
{file = "regex-2024.4.28-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:5ce479ecc068bc2a74cb98dd8dba99e070d1b2f4a8371a7dfe631f85db70fe6e"},
{file = "regex-2024.4.28-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:7d77b6f63f806578c604dca209280e4c54f0fa9a8128bb8d2cc5fb6f99da4150"},
{file = "regex-2024.4.28-cp38-cp38-win32.whl", hash = "sha256:d84308f097d7a513359757c69707ad339da799e53b7393819ec2ea36bc4beb58"},
{file = "regex-2024.4.28-cp38-cp38-win_amd64.whl", hash = "sha256:2cc1b87bba1dd1a898e664a31012725e48af826bf3971e786c53e32e02adae6c"},
{file = "regex-2024.4.28-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:7413167c507a768eafb5424413c5b2f515c606be5bb4ef8c5dee43925aa5718b"},
{file = "regex-2024.4.28-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:108e2dcf0b53a7c4ab8986842a8edcb8ab2e59919a74ff51c296772e8e74d0ae"},
{file = "regex-2024.4.28-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:f1c5742c31ba7d72f2dedf7968998730664b45e38827637e0f04a2ac7de2f5f1"},
{file = "regex-2024.4.28-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ecc6148228c9ae25ce403eade13a0961de1cb016bdb35c6eafd8e7b87ad028b1"},
{file = "regex-2024.4.28-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b7d893c8cf0e2429b823ef1a1d360a25950ed11f0e2a9df2b5198821832e1947"},
{file = "regex-2024.4.28-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4290035b169578ffbbfa50d904d26bec16a94526071ebec3dadbebf67a26b25e"},
{file = "regex-2024.4.28-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:44a22ae1cfd82e4ffa2066eb3390777dc79468f866f0625261a93e44cdf6482b"},
{file = "regex-2024.4.28-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fd24fd140b69f0b0bcc9165c397e9b2e89ecbeda83303abf2a072609f60239e2"},
{file = "regex-2024.4.28-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:39fb166d2196413bead229cd64a2ffd6ec78ebab83fff7d2701103cf9f4dfd26"},
{file = "regex-2024.4.28-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:9301cc6db4d83d2c0719f7fcda37229691745168bf6ae849bea2e85fc769175d"},
{file = "regex-2024.4.28-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:7c3d389e8d76a49923683123730c33e9553063d9041658f23897f0b396b2386f"},
{file = "regex-2024.4.28-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:99ef6289b62042500d581170d06e17f5353b111a15aa6b25b05b91c6886df8fc"},
{file = "regex-2024.4.28-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:b91d529b47798c016d4b4c1d06cc826ac40d196da54f0de3c519f5a297c5076a"},
{file = "regex-2024.4.28-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:43548ad74ea50456e1c68d3c67fff3de64c6edb85bcd511d1136f9b5376fc9d1"},
{file = "regex-2024.4.28-cp39-cp39-win32.whl", hash = "sha256:05d9b6578a22db7dedb4df81451f360395828b04f4513980b6bd7a1412c679cc"},
{file = "regex-2024.4.28-cp39-cp39-win_amd64.whl", hash = "sha256:3986217ec830c2109875be740531feb8ddafe0dfa49767cdcd072ed7e8927962"},
{file = "regex-2024.4.28.tar.gz", hash = "sha256:83ab366777ea45d58f72593adf35d36ca911ea8bd838483c1823b883a121b0e4"},
]
[[package]]
@ -3959,13 +3935,13 @@ zstd = ["zstandard (>=0.18.0)"]
[[package]]
name = "virtualenv"
version = "20.26.0"
version = "20.26.1"
description = "Virtual Python Environment builder"
optional = true
python-versions = ">=3.7"
files = [
{file = "virtualenv-20.26.0-py3-none-any.whl", hash = "sha256:0846377ea76e818daaa3e00a4365c018bc3ac9760cbb3544de542885aad61fb3"},
{file = "virtualenv-20.26.0.tar.gz", hash = "sha256:ec25a9671a5102c8d2657f62792a27b48f016664c6873f6beed3800008577210"},
{file = "virtualenv-20.26.1-py3-none-any.whl", hash = "sha256:7aa9982a728ae5892558bff6a2839c00b9ed145523ece2274fad6f414690ae75"},
{file = "virtualenv-20.26.1.tar.gz", hash = "sha256:604bfdceaeece392802e6ae48e69cec49168b9c5f4a44e483963f9242eb0e78b"},
]
[package.dependencies]

View File

@ -46,7 +46,7 @@ def test_examples_3_and_2():
# Pass empty globals to allow dictionary comprehension https://stackoverflow.com/a/32897127/4391249.
exec(file_contents, {})
for file_name in ["model.pt", "config.yaml"]:
for file_name in ["model.safetensors", "config.json", "config.yaml"]:
assert Path(f"outputs/train/example_pusht_diffusion/{file_name}").exists()
path = "examples/2_evaluate_pretrained_policy.py"
@ -58,15 +58,15 @@ def test_examples_3_and_2():
file_contents = _find_and_replace(
file_contents,
[
('pretrained_policy_name = "lerobot/diffusion_policy_pusht_image"', ""),
("pretrained_policy_path = Path(snapshot_download(pretrained_policy_name))", ""),
(
'# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")',
'pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")',
),
('"eval.n_episodes=10"', '"eval.n_episodes=1"'),
('"eval.batch_size=10"', '"eval.batch_size=1"'),
('"device=cuda"', '"device=cpu"'),
(
'# folder = Path("outputs/train/example_pusht_diffusion")',
'folder = Path("outputs/train/example_pusht_diffusion")',
),
('hub_id = "lerobot/diffusion_policy_pusht_image"', ""),
("folder = Path(snapshot_download(hub_id)", ""),
],
)
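The updated assertions reflect the new checkpoint layout: a pretrained-policy directory now contains `model.safetensors` and `config.json` (plus the Hydra `config.yaml`) instead of a `model.pt`. A minimal sketch of loading such a directory through the mixin, assuming the example output folder used by this test exists:

```python
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy

# from_pretrained is inherited from PyTorchModelHubMixin; it reads config.json
# and model.safetensors from a local directory (or a hub repo ID).
policy = DiffusionPolicy.from_pretrained("outputs/train/example_pusht_diffusion")
policy.eval()
```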

View File

@ -1,19 +1,33 @@
import inspect
import pytest
import torch
from huggingface_hub import PyTorchModelHubMixin
from lerobot import available_policies
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.utils import cycle
from lerobot.common.envs.factory import make_env
from lerobot.common.envs.utils import postprocess_action, preprocess_observation
from lerobot.common.policies.act.modeling_act import ACTPolicy
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.factory import get_policy_and_config_classes, make_policy
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.policy_protocol import Policy
from lerobot.common.utils.utils import init_hydra_config
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_env
@pytest.mark.parametrize("policy_name", available_policies)
def test_get_policy_and_config_classes(policy_name: str):
"""Check that the correct policy and config classes are returned."""
if policy_name == "tdmpc":
with pytest.raises(NotImplementedError):
get_policy_and_config_classes(policy_name)
return
policy_cls, config_cls = get_policy_and_config_classes(policy_name)
assert policy_cls.name == policy_name
assert issubclass(config_cls, inspect.signature(policy_cls.__init__).parameters["config"].annotation)
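The `issubclass` assertion relies on each policy annotating the `config` parameter of its `__init__` with its concrete config class. A toy sketch of the same introspection pattern (`FooPolicy` and `FooConfig` are made-up names for illustration):

```python
import inspect

class FooConfig:
    pass

class FooPolicy:
    name = "foo"

    def __init__(self, config: FooConfig = None):
        # Fall back to a default config, mirroring the "defaults" behavior tested below.
        self.config = config if config is not None else FooConfig()

# Pull the annotation off the `config` parameter and check the config class against it.
annotation = inspect.signature(FooPolicy.__init__).parameters["config"].annotation
assert issubclass(FooConfig, annotation)
```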
# TODO(aliberts): refactor using lerobot/__init__.py variables
@pytest.mark.parametrize(
"env_name,policy_name,extra_overrides",
@ -44,7 +58,8 @@ def test_policy(env_name, policy_name, extra_overrides):
"""
Tests:
- Making the policy object.
- Checking that the policy follows the correct protocol.
- Checking that the policy follows the correct protocol and subclasses nn.Module
and PyTorchModelHubMixin.
- Updating the policy.
- Using the policy to select actions at inference time.
- Checking that the selected action can be applied in the environment.
@ -61,11 +76,13 @@ def test_policy(env_name, policy_name, extra_overrides):
# Check that we can make the policy object.
dataset = make_dataset(cfg)
policy = make_policy(cfg, dataset_stats=dataset.stats)
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)
# Check that the policy follows the required protocol.
assert isinstance(
policy, Policy
), f"The policy does not follow the required protocol. Please see {Policy.__module__}.{Policy.__name__}."
assert isinstance(policy, torch.nn.Module)
assert isinstance(policy, PyTorchModelHubMixin)
# Check that we run select_actions and get the appropriate output.
env = make_env(cfg, num_parallel_envs=2)
@ -108,29 +125,33 @@ def test_policy(env_name, policy_name, extra_overrides):
# Test step through policy
env.step(action)
# Test load state_dict
if policy_name != "tdmpc":
# TODO(rcadene, alexander-soare): make it work for tdmpc
new_policy = make_policy(cfg)
new_policy.load_state_dict(policy.state_dict())
@pytest.mark.parametrize("policy_name", available_policies)
def test_policy_defaults(policy_name: str):
"""Check that the policy can be instantiated with defaults."""
if policy_name == "tdmpc":
with pytest.raises(NotImplementedError):
get_policy_and_config_classes(policy_name)
return
policy_cls, _ = get_policy_and_config_classes(policy_name)
policy_cls()
@pytest.mark.parametrize("policy_cls", [DiffusionPolicy, ACTPolicy])
def test_policy_defaults(policy_cls):
kwargs = {}
# TODO(alexander-soare): Remove this kwargs hack when we move the scheduler out of DP.
if policy_cls is DiffusionPolicy:
kwargs = {"lr_scheduler_num_training_steps": 1}
policy_cls(**kwargs)
@pytest.mark.parametrize("policy_name", available_policies)
def test_save_and_load_pretrained(policy_name: str):
if policy_name == "tdmpc":
with pytest.raises(NotImplementedError):
get_policy_and_config_classes(policy_name)
return
policy_cls, _ = get_policy_and_config_classes(policy_name)
policy: Policy = policy_cls()
save_dir = "/tmp/test_save_and_load_pretrained_{policy_cls.__name__}"
policy.save_pretrained(save_dir)
policy_ = policy_cls.from_pretrained(save_dir)
assert all(torch.equal(p, p_) for p, p_ in zip(policy.parameters(), policy_.parameters(), strict=True))
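The same mixin methods exercised by this round-trip test also cover hub upload. A hedged sketch (the output path and repo ID are placeholders; `push_to_hub` assumes you are authenticated, e.g. via `huggingface-cli login`):

```python
from lerobot.common.policies.act.modeling_act import ACTPolicy

policy = ACTPolicy()  # instantiated with defaults, as in test_policy_defaults above
policy.save_pretrained("outputs/act_scratch")   # writes model.safetensors + config.json
reloaded = ACTPolicy.from_pretrained("outputs/act_scratch")
# policy.push_to_hub("HF_USER/REPO_NAME")       # optional upload to the hub
```

All three methods are inherited from PyTorchModelHubMixin, which is what makes the parameter-equality check after the save/load round trip possible.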
@pytest.mark.parametrize(
"insert_temporal_dim",
[
False,
True,
],
)
@pytest.mark.parametrize("insert_temporal_dim", [False, True])
def test_normalize(insert_temporal_dim):
"""
Test that normalize/unnormalize can run without exceptions when properly set up, and that they raise