From a4891095e4c118cb27909453443c0e16632764f3 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Wed, 1 May 2024 16:17:18 +0100 Subject: [PATCH] Use PytorchModelHubMixin to save models as safetensors (#125) Co-authored-by: Remi --- Makefile | 18 +- README.md | 27 +- examples/2_evaluate_pretrained_policy.py | 21 +- examples/3_train_policy.py | 7 +- lerobot/common/logger.py | 15 +- .../common/policies/act/configuration_act.py | 4 +- lerobot/common/policies/act/modeling_act.py | 164 +++++----- .../policies/diffusion/modeling_diffusion.py | 177 +++++------ lerobot/common/policies/factory.py | 66 ++-- lerobot/common/policies/normalize.py | 60 ++-- lerobot/common/policies/policy_protocol.py | 5 +- lerobot/configs/policy/act.yaml | 4 +- lerobot/configs/policy/diffusion.yaml | 2 - lerobot/scripts/eval.py | 123 +++++--- lerobot/scripts/train.py | 11 +- poetry.lock | 296 ++++++++---------- tests/test_examples.py | 14 +- tests/test_policies.py | 69 ++-- 18 files changed, 556 insertions(+), 527 deletions(-) diff --git a/Makefile b/Makefile index 7c82e29a..ea6c3091 100644 --- a/Makefile +++ b/Makefile @@ -24,6 +24,7 @@ test-end-to-end: ${MAKE} test-diffusion-ete-eval # ${MAKE} test-tdmpc-ete-train # ${MAKE} test-tdmpc-ete-eval + ${MAKE} test-default-ete-eval test-act-ete-train: python lerobot/scripts/train.py \ @@ -43,11 +44,10 @@ test-act-ete-train: test-act-ete-eval: python lerobot/scripts/eval.py \ - --config tests/outputs/act/.hydra/config.yaml \ + -p tests/outputs/act/checkpoints/000002 \ eval.n_episodes=1 \ env.episode_length=8 \ device=cpu \ - policy.pretrained_model_path=tests/outputs/act/models/2.pt test-diffusion-ete-train: python lerobot/scripts/train.py \ @@ -65,11 +65,10 @@ test-diffusion-ete-train: test-diffusion-ete-eval: python lerobot/scripts/eval.py \ - --config tests/outputs/diffusion/.hydra/config.yaml \ + -p tests/outputs/diffusion/checkpoints/000002 \ eval.n_episodes=1 \ env.episode_length=8 \ device=cpu \ - policy.pretrained_model_path=tests/outputs/diffusion/models/2.pt test-tdmpc-ete-train: python lerobot/scripts/train.py \ @@ -88,8 +87,15 @@ test-tdmpc-ete-train: test-tdmpc-ete-eval: python lerobot/scripts/eval.py \ - --config tests/outputs/tdmpc/.hydra/config.yaml \ + -p tests/outputs/tdmpc/checkpoints/000002 \ + eval.n_episodes=1 \ + env.episode_length=8 \ + device=cpu \ + + +test-default-ete-eval: + python lerobot/scripts/eval.py \ + --config lerobot/configs/default.yaml \ eval.n_episodes=1 \ env.episode_length=8 \ device=cpu \ - policy.pretrained_model_path=tests/outputs/tdmpc/models/2.pt diff --git a/README.md b/README.md index 35a2f422..96378c44 100644 --- a/README.md +++ b/README.md @@ -135,16 +135,16 @@ Check out [examples](./examples) to see how you can load a pretrained policy fro Or you can achieve the same result by executing our script from the command line: ```bash python lerobot/scripts/eval.py \ ---hub-id lerobot/diffusion_policy_pusht_image \ +-p lerobot/diffusion_policy_pusht_image \ eval_episodes=10 \ hydra.run.dir=outputs/eval/example_hub ``` After training your own policy, you can also re-evaluate the checkpoints with: + ```bash python lerobot/scripts/eval.py \ ---config PATH/TO/FOLDER/config.yaml \ -policy.pretrained_model_path=PATH/TO/FOLDER/weights.pth \ +-p PATH/TO/TRAIN/OUTPUT/FOLDER \ eval_episodes=10 \ hydra.run.dir=outputs/eval/example_dir ``` @@ -246,29 +246,22 @@ Once you have trained a policy you may upload it to the HuggingFace hub. Firstly, make sure you have a model repository set up on the hub. The hub ID looks like HF_USER/REPO_NAME. -Secondly, assuming you have trained a policy, you need: +Secondly, assuming you have trained a policy, you need the following (which should all be in any of the subdirectories of `checkpoints` in your training output folder, if you've used the LeRobot training script): -- `config.yaml` which you can get from the `.hydra` directory of your training output folder. -- `model.pt` which should be one of the saved models in the `models` directory of your training output folder (they won't be named `model.pt` but you will need to choose one). +- `config.json`: A serialized version of the policy configuration (following the policy's dataclass config). +- `model.safetensors`: The `torch.nn.Module` parameters saved in [Hugging Face Safetensors](https://huggingface.co/docs/safetensors/index) format. +- `config.yaml`: This is the consolidated Hydra training configuration containing the policy, environment, and dataset configs. The policy configuration should match `config.json` exactly. The environment config is useful for anyone who wants to evaluate your policy. The dataset config just serves as a paper trail for reproducibility. -To upload these to the hub, prepare a folder with the following structure (you can use symlinks rather than copying): - -``` -to_upload - ├── config.yaml - └── model.pt -``` - -With the folder prepared, run the following with a desired revision ID. +To upload these to the hub, run the following with a desired revision ID. ```bash -huggingface-cli upload $HUB_ID to_upload --revision $REVISION_ID +huggingface-cli upload $HUB_ID PATH/TO/OUTPUT/DIR --revision $REVISION_ID ``` If you want this to be the default revision also run the following (don't worry, it won't upload the files again; it will just adjust the file pointers): ```bash -huggingface-cli upload $HUB_ID to_upload +huggingface-cli upload $HUB_ID PATH/TO/OUTPUT/DIR ``` See `eval.py` for an example of how a user may use your policy. diff --git a/examples/2_evaluate_pretrained_policy.py b/examples/2_evaluate_pretrained_policy.py index 9aae7e1e..8a72201f 100644 --- a/examples/2_evaluate_pretrained_policy.py +++ b/examples/2_evaluate_pretrained_policy.py @@ -7,32 +7,21 @@ from pathlib import Path from huggingface_hub import snapshot_download -from lerobot.common.utils.utils import init_hydra_config from lerobot.scripts.eval import eval # Get a pretrained policy from the hub. -# TODO(alexander-soare): This no longer works until we upload a new model that uses the current configs. -hub_id = "lerobot/diffusion_policy_pusht_image" -folder = Path(snapshot_download(hub_id)) +pretrained_policy_name = "lerobot/diffusion_policy_pusht_image" +pretrained_policy_path = Path(snapshot_download(pretrained_policy_name)) # OR uncomment the following to evaluate a policy from the local outputs/train folder. -# folder = Path("outputs/train/example_pusht_diffusion") - -config_path = folder / "config.yaml" -weights_path = folder / "model.pt" +# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion") # Override some config parameters to do with evaluation. overrides = [ - f"policy.pretrained_model_path={weights_path}", "eval.n_episodes=10", "eval.batch_size=10", "device=cuda", ] -# Create a Hydra config. -cfg = init_hydra_config(config_path, overrides) - # Evaluate the policy and save the outputs including metrics and videos. -eval( - cfg, - out_dir=f"outputs/eval/example_{cfg.env.name}_{cfg.policy.name}", -) +# TODO(rcadene, alexander-soare): dont call eval, but add the minimal code snippet to rollout +eval(pretrained_policy_path=pretrained_policy_path) diff --git a/examples/3_train_policy.py b/examples/3_train_policy.py index 03cd91d1..69e3d34c 100644 --- a/examples/3_train_policy.py +++ b/examples/3_train_policy.py @@ -34,7 +34,7 @@ dataset = make_dataset(hydra_cfg) # If you're doing something different, you will likely need to change at least some of the defaults. cfg = DiffusionConfig() # TODO(alexander-soare): Remove LR scheduler from the policy. -policy = DiffusionPolicy(cfg, lr_scheduler_num_training_steps=training_steps, dataset_stats=dataset.stats) +policy = DiffusionPolicy(cfg, dataset_stats=dataset.stats) policy.train() policy.to(device) @@ -69,6 +69,7 @@ while not done: done = True break -# Save the policy and configuration for later use. -policy.save(output_directory / "model.pt") +# Save the policy. +policy.save_pretrained(output_directory) +# Save the Hydra configuration so we have the environment configuration for eval. OmegaConf.save(hydra_cfg, output_directory / "config.yaml") diff --git a/lerobot/common/logger.py b/lerobot/common/logger.py index 8420685c..4c27fe7f 100644 --- a/lerobot/common/logger.py +++ b/lerobot/common/logger.py @@ -2,9 +2,12 @@ import logging import os from pathlib import Path +from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE from omegaconf import OmegaConf from termcolor import colored +from lerobot.common.policies.policy_protocol import Policy + def log_output_dir(out_dir): logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {out_dir}") @@ -27,7 +30,7 @@ class Logger: self._log_dir = Path(log_dir) self._log_dir.mkdir(parents=True, exist_ok=True) self._job_name = job_name - self._model_dir = self._log_dir / "models" + self._model_dir = self._log_dir / "checkpoints" self._buffer_dir = self._log_dir / "buffers" self._save_model = cfg.training.save_model self._disable_wandb_artifact = cfg.wandb.disable_artifact @@ -67,18 +70,20 @@ class Logger: logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}") self._wandb = wandb - def save_model(self, policy, identifier): + def save_model(self, policy: Policy, identifier): if self._save_model: self._model_dir.mkdir(parents=True, exist_ok=True) - fp = self._model_dir / f"{str(identifier)}.pt" - policy.save(fp) + save_dir = self._model_dir / str(identifier) + policy.save_pretrained(save_dir) + # Also save the full Hydra config for the env configuration. + OmegaConf.save(self._cfg, save_dir / "config.yaml") if self._wandb and not self._disable_wandb_artifact: # note wandb artifact does not accept ":" in its name artifact = self._wandb.Artifact( self._group.replace(":", "_") + "-" + str(self._seed) + "-" + str(identifier), type="model", ) - artifact.add_file(fp) + artifact.add_file(save_dir / SAFETENSORS_SINGLE_FILE) self._wandb.log_artifact(artifact) def save_buffer(self, buffer, identifier): diff --git a/lerobot/common/policies/act/configuration_act.py b/lerobot/common/policies/act/configuration_act.py index 4b0b0d19..b3700a26 100644 --- a/lerobot/common/policies/act/configuration_act.py +++ b/lerobot/common/policies/act/configuration_act.py @@ -38,7 +38,7 @@ class ACTConfig: replace_final_stride_with_dilation: Whether to replace the ResNet's final 2x2 stride with a dilated convolution. pre_norm: Whether to use "pre-norm" in the transformer blocks. - d_model: The transformer blocks' main hidden dimension. + dim_model: The transformer blocks' main hidden dimension. n_heads: The number of heads to use in the transformer blocks' multi-head attention. dim_feedforward: The dimension to expand the transformer's hidden dimension to in the feed-forward layers. @@ -94,7 +94,7 @@ class ACTConfig: replace_final_stride_with_dilation: int = False # Transformer layers. pre_norm: bool = False - d_model: int = 512 + dim_model: int = 512 n_heads: int = 8 dim_feedforward: int = 3200 feedforward_activation: str = "relu" diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py index f4564284..448bd2cb 100644 --- a/lerobot/common/policies/act/modeling_act.py +++ b/lerobot/common/policies/act/modeling_act.py @@ -14,6 +14,7 @@ import numpy as np import torch import torch.nn.functional as F # noqa: N812 import torchvision +from huggingface_hub import PyTorchModelHubMixin from torch import Tensor, nn from torchvision.models._utils import IntermediateLayerGetter from torchvision.ops.misc import FrozenBatchNorm2d @@ -22,7 +23,7 @@ from lerobot.common.policies.act.configuration_act import ACTConfig from lerobot.common.policies.normalize import Normalize, Unnormalize -class ACTPolicy(nn.Module): +class ACTPolicy(nn.Module, PyTorchModelHubMixin): """ Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware (paper: https://arxiv.org/abs/2304.13705, code: https://github.com/tonyzhaozh/act) @@ -30,27 +31,31 @@ class ACTPolicy(nn.Module): name = "act" - def __init__(self, cfg: ACTConfig | None = None, dataset_stats=None): + def __init__(self, config: ACTConfig | None = None, dataset_stats=None): """ Args: - cfg: Policy configuration class instance or None, in which case the default instantiation of the - configuration class is used. + config: Policy configuration class instance or None, in which case the default instantiation of + the configuration class is used. """ super().__init__() - if cfg is None: - cfg = ACTConfig() - self.cfg = cfg - self.normalize_inputs = Normalize(cfg.input_shapes, cfg.input_normalization_modes, dataset_stats) - self.normalize_targets = Normalize(cfg.output_shapes, cfg.output_normalization_modes, dataset_stats) - self.unnormalize_outputs = Unnormalize( - cfg.output_shapes, cfg.output_normalization_modes, dataset_stats + if config is None: + config = ACTConfig() + self.config = config + self.normalize_inputs = Normalize( + config.input_shapes, config.input_normalization_modes, dataset_stats ) - self.model = ACT(cfg) + self.normalize_targets = Normalize( + config.output_shapes, config.output_normalization_modes, dataset_stats + ) + self.unnormalize_outputs = Unnormalize( + config.output_shapes, config.output_normalization_modes, dataset_stats + ) + self.model = ACT(config) def reset(self): """This should be called whenever the environment is reset.""" - if self.cfg.n_action_steps is not None: - self._action_queue = deque([], maxlen=self.cfg.n_action_steps) + if self.config.n_action_steps is not None: + self._action_queue = deque([], maxlen=self.config.n_action_steps) @torch.no_grad def select_action(self, batch: dict[str, Tensor], **_) -> Tensor: @@ -68,7 +73,7 @@ class ACTPolicy(nn.Module): if len(self._action_queue) == 0: # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue # effectively has shape (n_action_steps, batch_size, *), hence the transpose. - actions = self.model(batch)[0][: self.cfg.n_action_steps] + actions = self.model(batch)[0][: self.config.n_action_steps] # TODO(rcadene): make _forward return output dictionary? actions = self.unnormalize_outputs({"action": actions})["action"] @@ -88,7 +93,7 @@ class ACTPolicy(nn.Module): ).mean() loss_dict = {"l1_loss": l1_loss} - if self.cfg.use_vae: + if self.config.use_vae: # Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for # each dimension independently, we sum over the latent dimension to get the total # KL-divergence per batch element, then take the mean over the batch. @@ -97,7 +102,7 @@ class ACTPolicy(nn.Module): (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean() ) loss_dict["kld_loss"] = mean_kld - loss_dict["loss"] = l1_loss + mean_kld * self.cfg.kl_weight + loss_dict["loss"] = l1_loss + mean_kld * self.config.kl_weight else: loss_dict["loss"] = l1_loss @@ -114,17 +119,10 @@ class ACTPolicy(nn.Module): """ # Stack images in the order dictated by input_shapes. batch["observation.images"] = torch.stack( - [batch[k] for k in self.cfg.input_shapes if k.startswith("observation.images.")], + [batch[k] for k in self.config.input_shapes if k.startswith("observation.images.")], dim=-4, ) - def save(self, fp): - torch.save(self.state_dict(), fp) - - def load(self, fp): - d = torch.load(fp) - self.load_state_dict(d) - class ACT(nn.Module): """Action Chunking Transformer: The underlying neural network for ACTPolicy. @@ -161,36 +159,36 @@ class ACT(nn.Module): └───────────────────────┘ """ - def __init__(self, cfg: ACTConfig): + def __init__(self, config: ACTConfig): super().__init__() - self.cfg = cfg + self.config = config # BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence]. # The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]). - if self.cfg.use_vae: - self.vae_encoder = ACTEncoder(cfg) - self.vae_encoder_cls_embed = nn.Embedding(1, cfg.d_model) + if self.config.use_vae: + self.vae_encoder = ACTEncoder(config) + self.vae_encoder_cls_embed = nn.Embedding(1, config.dim_model) # Projection layer for joint-space configuration to hidden dimension. self.vae_encoder_robot_state_input_proj = nn.Linear( - cfg.input_shapes["observation.state"][0], cfg.d_model + config.input_shapes["observation.state"][0], config.dim_model ) # Projection layer for action (joint-space target) to hidden dimension. self.vae_encoder_action_input_proj = nn.Linear( - cfg.input_shapes["observation.state"][0], cfg.d_model + config.input_shapes["observation.state"][0], config.dim_model ) - self.latent_dim = cfg.latent_dim + self.latent_dim = config.latent_dim # Projection layer from the VAE encoder's output to the latent distribution's parameter space. - self.vae_encoder_latent_output_proj = nn.Linear(cfg.d_model, self.latent_dim * 2) + self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, self.latent_dim * 2) # Fixed sinusoidal positional embedding the whole input to the VAE encoder. Unsqueeze for batch # dimension. self.register_buffer( "vae_encoder_pos_enc", - create_sinusoidal_position_embedding(1 + 1 + cfg.chunk_size, cfg.d_model).unsqueeze(0), + create_sinusoidal_pos_embedding(1 + 1 + config.chunk_size, config.dim_model).unsqueeze(0), ) # Backbone for image feature extraction. - backbone_model = getattr(torchvision.models, cfg.vision_backbone)( - replace_stride_with_dilation=[False, False, cfg.replace_final_stride_with_dilation], - weights=cfg.pretrained_backbone_weights, + backbone_model = getattr(torchvision.models, config.vision_backbone)( + replace_stride_with_dilation=[False, False, config.replace_final_stride_with_dilation], + weights=config.pretrained_backbone_weights, norm_layer=FrozenBatchNorm2d, ) # Note: The assumption here is that we are using a ResNet model (and hence layer4 is the final feature @@ -199,26 +197,28 @@ class ACT(nn.Module): self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"}) # Transformer (acts as VAE decoder when training with the variational objective). - self.encoder = ACTEncoder(cfg) - self.decoder = ACTDecoder(cfg) + self.encoder = ACTEncoder(config) + self.decoder = ACTDecoder(config) # Transformer encoder input projections. The tokens will be structured like # [latent, robot_state, image_feature_map_pixels]. - self.encoder_robot_state_input_proj = nn.Linear(cfg.input_shapes["observation.state"][0], cfg.d_model) - self.encoder_latent_input_proj = nn.Linear(self.latent_dim, cfg.d_model) + self.encoder_robot_state_input_proj = nn.Linear( + config.input_shapes["observation.state"][0], config.dim_model + ) + self.encoder_latent_input_proj = nn.Linear(self.latent_dim, config.dim_model) self.encoder_img_feat_input_proj = nn.Conv2d( - backbone_model.fc.in_features, cfg.d_model, kernel_size=1 + backbone_model.fc.in_features, config.dim_model, kernel_size=1 ) # Transformer encoder positional embeddings. - self.encoder_robot_and_latent_pos_embed = nn.Embedding(2, cfg.d_model) - self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(cfg.d_model // 2) + self.encoder_robot_and_latent_pos_embed = nn.Embedding(2, config.dim_model) + self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2) # Transformer decoder. # Learnable positional embedding for the transformer's decoder (in the style of DETR object queries). - self.decoder_pos_embed = nn.Embedding(cfg.chunk_size, cfg.d_model) + self.decoder_pos_embed = nn.Embedding(config.chunk_size, config.dim_model) # Final action regression head on the output of the transformer's decoder. - self.action_head = nn.Linear(cfg.d_model, cfg.output_shapes["action"][0]) + self.action_head = nn.Linear(config.dim_model, config.output_shapes["action"][0]) self._reset_parameters() @@ -244,7 +244,7 @@ class ACT(nn.Module): Tuple containing the latent PDF's parameters (mean, log(σ²)) both as (B, L) tensors where L is the latent dimension. """ - if self.cfg.use_vae and self.training: + if self.config.use_vae and self.training: assert ( "action" in batch ), "actions must be provided when using the variational objective in training mode." @@ -252,7 +252,7 @@ class ACT(nn.Module): batch_size = batch["observation.state"].shape[0] # Prepare the latent for input to the transformer encoder. - if self.cfg.use_vae and "action" in batch: + if self.config.use_vae and "action" in batch: # Prepare the input to the VAE encoder: [cls, *joint_space_configuration, *action_sequence]. cls_embed = einops.repeat( self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size @@ -322,7 +322,7 @@ class ACT(nn.Module): # Forward pass through the transformer modules. encoder_out = self.encoder(encoder_in, pos_embed=pos_embed) decoder_in = torch.zeros( - (self.cfg.chunk_size, batch_size, self.cfg.d_model), + (self.config.chunk_size, batch_size, self.config.dim_model), dtype=pos_embed.dtype, device=pos_embed.device, ) @@ -344,10 +344,10 @@ class ACT(nn.Module): class ACTEncoder(nn.Module): """Convenience module for running multiple encoder layers, maybe followed by normalization.""" - def __init__(self, cfg: ACTConfig): + def __init__(self, config: ACTConfig): super().__init__() - self.layers = nn.ModuleList([ACTEncoderLayer(cfg) for _ in range(cfg.n_encoder_layers)]) - self.norm = nn.LayerNorm(cfg.d_model) if cfg.pre_norm else nn.Identity() + self.layers = nn.ModuleList([ACTEncoderLayer(config) for _ in range(config.n_encoder_layers)]) + self.norm = nn.LayerNorm(config.dim_model) if config.pre_norm else nn.Identity() def forward(self, x: Tensor, pos_embed: Tensor | None = None) -> Tensor: for layer in self.layers: @@ -357,22 +357,22 @@ class ACTEncoder(nn.Module): class ACTEncoderLayer(nn.Module): - def __init__(self, cfg: ACTConfig): + def __init__(self, config: ACTConfig): super().__init__() - self.self_attn = nn.MultiheadAttention(cfg.d_model, cfg.n_heads, dropout=cfg.dropout) + self.self_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout) # Feed forward layers. - self.linear1 = nn.Linear(cfg.d_model, cfg.dim_feedforward) - self.dropout = nn.Dropout(cfg.dropout) - self.linear2 = nn.Linear(cfg.dim_feedforward, cfg.d_model) + self.linear1 = nn.Linear(config.dim_model, config.dim_feedforward) + self.dropout = nn.Dropout(config.dropout) + self.linear2 = nn.Linear(config.dim_feedforward, config.dim_model) - self.norm1 = nn.LayerNorm(cfg.d_model) - self.norm2 = nn.LayerNorm(cfg.d_model) - self.dropout1 = nn.Dropout(cfg.dropout) - self.dropout2 = nn.Dropout(cfg.dropout) + self.norm1 = nn.LayerNorm(config.dim_model) + self.norm2 = nn.LayerNorm(config.dim_model) + self.dropout1 = nn.Dropout(config.dropout) + self.dropout2 = nn.Dropout(config.dropout) - self.activation = get_activation_fn(cfg.feedforward_activation) - self.pre_norm = cfg.pre_norm + self.activation = get_activation_fn(config.feedforward_activation) + self.pre_norm = config.pre_norm def forward(self, x, pos_embed: Tensor | None = None) -> Tensor: skip = x @@ -395,11 +395,11 @@ class ACTEncoderLayer(nn.Module): class ACTDecoder(nn.Module): - def __init__(self, cfg: ACTConfig): + def __init__(self, config: ACTConfig): """Convenience module for running multiple decoder layers followed by normalization.""" super().__init__() - self.layers = nn.ModuleList([ACTDecoderLayer(cfg) for _ in range(cfg.n_decoder_layers)]) - self.norm = nn.LayerNorm(cfg.d_model) + self.layers = nn.ModuleList([ACTDecoderLayer(config) for _ in range(config.n_decoder_layers)]) + self.norm = nn.LayerNorm(config.dim_model) def forward( self, @@ -418,25 +418,25 @@ class ACTDecoder(nn.Module): class ACTDecoderLayer(nn.Module): - def __init__(self, cfg: ACTConfig): + def __init__(self, config: ACTConfig): super().__init__() - self.self_attn = nn.MultiheadAttention(cfg.d_model, cfg.n_heads, dropout=cfg.dropout) - self.multihead_attn = nn.MultiheadAttention(cfg.d_model, cfg.n_heads, dropout=cfg.dropout) + self.self_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout) + self.multihead_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout) # Feed forward layers. - self.linear1 = nn.Linear(cfg.d_model, cfg.dim_feedforward) - self.dropout = nn.Dropout(cfg.dropout) - self.linear2 = nn.Linear(cfg.dim_feedforward, cfg.d_model) + self.linear1 = nn.Linear(config.dim_model, config.dim_feedforward) + self.dropout = nn.Dropout(config.dropout) + self.linear2 = nn.Linear(config.dim_feedforward, config.dim_model) - self.norm1 = nn.LayerNorm(cfg.d_model) - self.norm2 = nn.LayerNorm(cfg.d_model) - self.norm3 = nn.LayerNorm(cfg.d_model) - self.dropout1 = nn.Dropout(cfg.dropout) - self.dropout2 = nn.Dropout(cfg.dropout) - self.dropout3 = nn.Dropout(cfg.dropout) + self.norm1 = nn.LayerNorm(config.dim_model) + self.norm2 = nn.LayerNorm(config.dim_model) + self.norm3 = nn.LayerNorm(config.dim_model) + self.dropout1 = nn.Dropout(config.dropout) + self.dropout2 = nn.Dropout(config.dropout) + self.dropout3 = nn.Dropout(config.dropout) - self.activation = get_activation_fn(cfg.feedforward_activation) - self.pre_norm = cfg.pre_norm + self.activation = get_activation_fn(config.feedforward_activation) + self.pre_norm = config.pre_norm def maybe_add_pos_embed(self, tensor: Tensor, pos_embed: Tensor | None) -> Tensor: return tensor if pos_embed is None else tensor + pos_embed @@ -489,7 +489,7 @@ class ACTDecoderLayer(nn.Module): return x -def create_sinusoidal_position_embedding(num_positions: int, dimension: int) -> Tensor: +def create_sinusoidal_pos_embedding(num_positions: int, dimension: int) -> Tensor: """1D sinusoidal positional embeddings as in Attention is All You Need. Args: diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py index 728fa97a..f57daf63 100644 --- a/lerobot/common/policies/diffusion/modeling_diffusion.py +++ b/lerobot/common/policies/diffusion/modeling_diffusion.py @@ -9,7 +9,6 @@ TODO(alexander-soare): """ import copy -import logging import math from collections import deque from typing import Callable @@ -19,6 +18,7 @@ import torch import torch.nn.functional as F # noqa: N812 import torchvision from diffusers.schedulers.scheduling_ddpm import DDPMScheduler +from huggingface_hub import PyTorchModelHubMixin from robomimic.models.base_nets import SpatialSoftmax from torch import Tensor, nn from torch.nn.modules.batchnorm import _BatchNorm @@ -32,7 +32,7 @@ from lerobot.common.policies.utils import ( ) -class DiffusionPolicy(nn.Module): +class DiffusionPolicy(nn.Module, PyTorchModelHubMixin): """ Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion" (paper: https://arxiv.org/abs/2303.04137, code: https://github.com/real-stanford/diffusion_policy). @@ -41,45 +41,50 @@ class DiffusionPolicy(nn.Module): name = "diffusion" def __init__( - self, cfg: DiffusionConfig | None = None, lr_scheduler_num_training_steps: int = 0, dataset_stats=None + self, + config: DiffusionConfig | None = None, + dataset_stats=None, ): """ Args: - cfg: Policy configuration class instance or None, in which case the default instantiation of the - configuration class is used. + config: Policy configuration class instance or None, in which case the default instantiation of + the configuration class is used. """ super().__init__() # TODO(alexander-soare): LR scheduler will be removed. - assert lr_scheduler_num_training_steps > 0 - if cfg is None: - cfg = DiffusionConfig() - self.cfg = cfg - self.normalize_inputs = Normalize(cfg.input_shapes, cfg.input_normalization_modes, dataset_stats) - self.normalize_targets = Normalize(cfg.output_shapes, cfg.output_normalization_modes, dataset_stats) + if config is None: + config = DiffusionConfig() + self.config = config + self.normalize_inputs = Normalize( + config.input_shapes, config.input_normalization_modes, dataset_stats + ) + self.normalize_targets = Normalize( + config.output_shapes, config.output_normalization_modes, dataset_stats + ) self.unnormalize_outputs = Unnormalize( - cfg.output_shapes, cfg.output_normalization_modes, dataset_stats + config.output_shapes, config.output_normalization_modes, dataset_stats ) # queues are populated during rollout of the policy, they contain the n latest observations and actions self._queues = None - self.diffusion = DiffusionModel(cfg) + self.diffusion = DiffusionModel(config) # TODO(alexander-soare): This should probably be managed outside of the policy class. self.ema_diffusion = None self.ema = None - if self.cfg.use_ema: + if self.config.use_ema: self.ema_diffusion = copy.deepcopy(self.diffusion) - self.ema = DiffusionEMA(cfg, model=self.ema_diffusion) + self.ema = DiffusionEMA(config, model=self.ema_diffusion) def reset(self): """ Clear observation and action queues. Should be called on `env.reset()` """ self._queues = { - "observation.image": deque(maxlen=self.cfg.n_obs_steps), - "observation.state": deque(maxlen=self.cfg.n_obs_steps), - "action": deque(maxlen=self.cfg.n_action_steps), + "observation.image": deque(maxlen=self.config.n_obs_steps), + "observation.state": deque(maxlen=self.config.n_obs_steps), + "action": deque(maxlen=self.config.n_action_steps), } @torch.no_grad @@ -138,46 +143,34 @@ class DiffusionPolicy(nn.Module): loss = self.diffusion.compute_loss(batch) return {"loss": loss} - def save(self, fp): - torch.save(self.state_dict(), fp) - - def load(self, fp): - d = torch.load(fp) - missing_keys, unexpected_keys = self.load_state_dict(d, strict=False) - if len(missing_keys) > 0: - assert all(k.startswith("ema_diffusion.") for k in missing_keys) - logging.warning( - "DiffusionPolicy.load expected ema parameters in loaded state dict but none were found." - ) - assert len(unexpected_keys) == 0 - class DiffusionModel(nn.Module): - def __init__(self, cfg: DiffusionConfig): + def __init__(self, config: DiffusionConfig): super().__init__() - self.cfg = cfg + self.config = config - self.rgb_encoder = DiffusionRgbEncoder(cfg) + self.rgb_encoder = DiffusionRgbEncoder(config) self.unet = DiffusionConditionalUnet1d( - cfg, - global_cond_dim=(cfg.output_shapes["action"][0] + self.rgb_encoder.feature_dim) * cfg.n_obs_steps, + config, + global_cond_dim=(config.output_shapes["action"][0] + self.rgb_encoder.feature_dim) + * config.n_obs_steps, ) self.noise_scheduler = DDPMScheduler( - num_train_timesteps=cfg.num_train_timesteps, - beta_start=cfg.beta_start, - beta_end=cfg.beta_end, - beta_schedule=cfg.beta_schedule, + num_train_timesteps=config.num_train_timesteps, + beta_start=config.beta_start, + beta_end=config.beta_end, + beta_schedule=config.beta_schedule, variance_type="fixed_small", - clip_sample=cfg.clip_sample, - clip_sample_range=cfg.clip_sample_range, - prediction_type=cfg.prediction_type, + clip_sample=config.clip_sample, + clip_sample_range=config.clip_sample_range, + prediction_type=config.prediction_type, ) - if cfg.num_inference_steps is None: + if config.num_inference_steps is None: self.num_inference_steps = self.noise_scheduler.config.num_train_timesteps else: - self.num_inference_steps = cfg.num_inference_steps + self.num_inference_steps = config.num_inference_steps # ========= inference ============ def conditional_sample( @@ -188,7 +181,7 @@ class DiffusionModel(nn.Module): # Sample prior. sample = torch.randn( - size=(batch_size, self.cfg.horizon, self.cfg.output_shapes["action"][0]), + size=(batch_size, self.config.horizon, self.config.output_shapes["action"][0]), dtype=dtype, device=device, generator=generator, @@ -218,7 +211,7 @@ class DiffusionModel(nn.Module): """ assert set(batch).issuperset({"observation.state", "observation.image"}) batch_size, n_obs_steps = batch["observation.state"].shape[:2] - assert n_obs_steps == self.cfg.n_obs_steps + assert n_obs_steps == self.config.n_obs_steps # Extract image feature (first combine batch and sequence dims). img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ...")) @@ -231,10 +224,10 @@ class DiffusionModel(nn.Module): sample = self.conditional_sample(batch_size, global_cond=global_cond) # `horizon` steps worth of actions (from the first observation). - actions = sample[..., : self.cfg.output_shapes["action"][0]] + actions = sample[..., : self.config.output_shapes["action"][0]] # Extract `n_action_steps` steps worth of actions (from the current observation). start = n_obs_steps - 1 - end = start + self.cfg.n_action_steps + end = start + self.config.n_action_steps actions = actions[:, start:end] return actions @@ -253,8 +246,8 @@ class DiffusionModel(nn.Module): assert set(batch).issuperset({"observation.state", "observation.image", "action", "action_is_pad"}) batch_size, n_obs_steps = batch["observation.state"].shape[:2] horizon = batch["action"].shape[1] - assert horizon == self.cfg.horizon - assert n_obs_steps == self.cfg.n_obs_steps + assert horizon == self.config.horizon + assert n_obs_steps == self.config.n_obs_steps # Extract image feature (first combine batch and sequence dims). img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ...")) @@ -283,12 +276,12 @@ class DiffusionModel(nn.Module): # Compute the loss. # The target is either the original trajectory, or the noise. - if self.cfg.prediction_type == "epsilon": + if self.config.prediction_type == "epsilon": target = eps - elif self.cfg.prediction_type == "sample": + elif self.config.prediction_type == "sample": target = batch["action"] else: - raise ValueError(f"Unsupported prediction type {self.cfg.prediction_type}") + raise ValueError(f"Unsupported prediction type {self.config.prediction_type}") loss = F.mse_loss(pred, target, reduction="none") @@ -306,29 +299,29 @@ class DiffusionRgbEncoder(nn.Module): Includes the ability to normalize and crop the image first. """ - def __init__(self, cfg: DiffusionConfig): + def __init__(self, config: DiffusionConfig): super().__init__() # Set up optional preprocessing. - if cfg.crop_shape is not None: + if config.crop_shape is not None: self.do_crop = True # Always use center crop for eval - self.center_crop = torchvision.transforms.CenterCrop(cfg.crop_shape) - if cfg.crop_is_random: - self.maybe_random_crop = torchvision.transforms.RandomCrop(cfg.crop_shape) + self.center_crop = torchvision.transforms.CenterCrop(config.crop_shape) + if config.crop_is_random: + self.maybe_random_crop = torchvision.transforms.RandomCrop(config.crop_shape) else: self.maybe_random_crop = self.center_crop else: self.do_crop = False # Set up backbone. - backbone_model = getattr(torchvision.models, cfg.vision_backbone)( - weights=cfg.pretrained_backbone_weights + backbone_model = getattr(torchvision.models, config.vision_backbone)( + weights=config.pretrained_backbone_weights ) # Note: This assumes that the layer4 feature map is children()[-3] # TODO(alexander-soare): Use a safer alternative. self.backbone = nn.Sequential(*(list(backbone_model.children())[:-2])) - if cfg.use_group_norm: - if cfg.pretrained_backbone_weights: + if config.use_group_norm: + if config.pretrained_backbone_weights: raise ValueError( "You can't replace BatchNorm in a pretrained model without ruining the weights!" ) @@ -342,11 +335,11 @@ class DiffusionRgbEncoder(nn.Module): # Use a dry run to get the feature map shape. with torch.inference_mode(): feat_map_shape = tuple( - self.backbone(torch.zeros(size=(1, *cfg.input_shapes["observation.image"]))).shape[1:] + self.backbone(torch.zeros(size=(1, *config.input_shapes["observation.image"]))).shape[1:] ) - self.pool = SpatialSoftmax(feat_map_shape, num_kp=cfg.spatial_softmax_num_keypoints) - self.feature_dim = cfg.spatial_softmax_num_keypoints * 2 - self.out = nn.Linear(cfg.spatial_softmax_num_keypoints * 2, self.feature_dim) + self.pool = SpatialSoftmax(feat_map_shape, num_kp=config.spatial_softmax_num_keypoints) + self.feature_dim = config.spatial_softmax_num_keypoints * 2 + self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim) self.relu = nn.ReLU() def forward(self, x: Tensor) -> Tensor: @@ -442,34 +435,34 @@ class DiffusionConditionalUnet1d(nn.Module): Note: this removes local conditioning as compared to the original diffusion policy code. """ - def __init__(self, cfg: DiffusionConfig, global_cond_dim: int): + def __init__(self, config: DiffusionConfig, global_cond_dim: int): super().__init__() - self.cfg = cfg + self.config = config # Encoder for the diffusion timestep. self.diffusion_step_encoder = nn.Sequential( - DiffusionSinusoidalPosEmb(cfg.diffusion_step_embed_dim), - nn.Linear(cfg.diffusion_step_embed_dim, cfg.diffusion_step_embed_dim * 4), + DiffusionSinusoidalPosEmb(config.diffusion_step_embed_dim), + nn.Linear(config.diffusion_step_embed_dim, config.diffusion_step_embed_dim * 4), nn.Mish(), - nn.Linear(cfg.diffusion_step_embed_dim * 4, cfg.diffusion_step_embed_dim), + nn.Linear(config.diffusion_step_embed_dim * 4, config.diffusion_step_embed_dim), ) # The FiLM conditioning dimension. - cond_dim = cfg.diffusion_step_embed_dim + global_cond_dim + cond_dim = config.diffusion_step_embed_dim + global_cond_dim # In channels / out channels for each downsampling block in the Unet's encoder. For the decoder, we # just reverse these. - in_out = [(cfg.output_shapes["action"][0], cfg.down_dims[0])] + list( - zip(cfg.down_dims[:-1], cfg.down_dims[1:], strict=True) + in_out = [(config.output_shapes["action"][0], config.down_dims[0])] + list( + zip(config.down_dims[:-1], config.down_dims[1:], strict=True) ) # Unet encoder. common_res_block_kwargs = { "cond_dim": cond_dim, - "kernel_size": cfg.kernel_size, - "n_groups": cfg.n_groups, - "use_film_scale_modulation": cfg.use_film_scale_modulation, + "kernel_size": config.kernel_size, + "n_groups": config.n_groups, + "use_film_scale_modulation": config.use_film_scale_modulation, } self.down_modules = nn.ModuleList([]) for ind, (dim_in, dim_out) in enumerate(in_out): @@ -489,10 +482,10 @@ class DiffusionConditionalUnet1d(nn.Module): self.mid_modules = nn.ModuleList( [ DiffusionConditionalResidualBlock1d( - cfg.down_dims[-1], cfg.down_dims[-1], **common_res_block_kwargs + config.down_dims[-1], config.down_dims[-1], **common_res_block_kwargs ), DiffusionConditionalResidualBlock1d( - cfg.down_dims[-1], cfg.down_dims[-1], **common_res_block_kwargs + config.down_dims[-1], config.down_dims[-1], **common_res_block_kwargs ), ] ) @@ -514,8 +507,8 @@ class DiffusionConditionalUnet1d(nn.Module): ) self.final_conv = nn.Sequential( - DiffusionConv1dBlock(cfg.down_dims[0], cfg.down_dims[0], kernel_size=cfg.kernel_size), - nn.Conv1d(cfg.down_dims[0], cfg.output_shapes["action"][0], 1), + DiffusionConv1dBlock(config.down_dims[0], config.down_dims[0], kernel_size=config.kernel_size), + nn.Conv1d(config.down_dims[0], config.output_shapes["action"][0], 1), ) def forward(self, x: Tensor, timestep: Tensor | int, global_cond=None) -> Tensor: @@ -626,13 +619,13 @@ class DiffusionEMA: Exponential Moving Average of models weights """ - def __init__(self, cfg: DiffusionConfig, model: nn.Module): + def __init__(self, config: DiffusionConfig, model: nn.Module): """ @crowsonkb's notes on EMA Warmup: - If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan - to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps), - gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 - at 215.4k steps). + If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models + you plan to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 + at 1M steps), gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 + at 10K steps, 0.9999 at 215.4k steps). Args: inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1. power (float): Exponential factor of EMA warmup. Default: 2/3. @@ -643,11 +636,11 @@ class DiffusionEMA: self.averaged_model.eval() self.averaged_model.requires_grad_(False) - self.update_after_step = cfg.ema_update_after_step - self.inv_gamma = cfg.ema_inv_gamma - self.power = cfg.ema_power - self.min_alpha = cfg.ema_min_alpha - self.max_alpha = cfg.ema_max_alpha + self.update_after_step = config.ema_update_after_step + self.inv_gamma = config.ema_inv_gamma + self.power = config.ema_power + self.min_alpha = config.ema_min_alpha + self.max_alpha = config.ema_max_alpha self.alpha = 0.0 self.optimization_step = 0 diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py index 069c9fc1..727aa80b 100644 --- a/lerobot/common/policies/factory.py +++ b/lerobot/common/policies/factory.py @@ -2,6 +2,7 @@ import inspect from omegaconf import DictConfig, OmegaConf +from lerobot.common.policies.policy_protocol import Policy from lerobot.common.utils.utils import get_safe_torch_device @@ -20,42 +21,49 @@ def _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg): return policy_cfg -def make_policy(hydra_cfg: DictConfig, dataset_stats=None): - if hydra_cfg.policy.name == "tdmpc": - from lerobot.common.policies.tdmpc.policy import TDMPCPolicy - - policy = TDMPCPolicy( - hydra_cfg.policy, - n_obs_steps=hydra_cfg.n_obs_steps, - n_action_steps=hydra_cfg.n_action_steps, - device=hydra_cfg.device, - ) - elif hydra_cfg.policy.name == "diffusion": +def get_policy_and_config_classes(name: str) -> tuple[Policy, object]: + """Get the policy's class and config class given a name (matching the policy class' `name` attribute).""" + if name == "tdmpc": + raise NotImplementedError("Coming soon!") + elif name == "diffusion": from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy - policy_cfg = _policy_cfg_from_hydra_cfg(DiffusionConfig, hydra_cfg) - policy = DiffusionPolicy(policy_cfg, hydra_cfg.training.offline_steps, dataset_stats) - policy.to(get_safe_torch_device(hydra_cfg.device)) - elif hydra_cfg.policy.name == "act": + return DiffusionPolicy, DiffusionConfig + elif name == "act": from lerobot.common.policies.act.configuration_act import ACTConfig from lerobot.common.policies.act.modeling_act import ACTPolicy - policy_cfg = _policy_cfg_from_hydra_cfg(ACTConfig, hydra_cfg) - policy = ACTPolicy(policy_cfg, dataset_stats) + return ACTPolicy, ACTConfig + else: + raise NotImplementedError(f"Policy with name {name} is not implemented.") + + +def make_policy( + hydra_cfg: DictConfig, pretrained_policy_name_or_path: str | None = None, dataset_stats=None +) -> Policy: + """Make an instance of a policy class. + + Args: + hydra_cfg: A parsed Hydra configuration (see scripts). If `pretrained_policy_name_or_path` is + provided, only `hydra_cfg.policy.name` is used while everything else is ignored. + pretrained_policy_name_or_path: Either the repo ID of a model hosted on the Hub or a path to a + directory containing weights saved using `Policy.save_pretrained`. Note that providing this + argument overrides everything in `hydra_cfg.policy` apart from `hydra_cfg.policy.name`. + dataset_stats: Dataset statistics to use for (un)normalization of inputs/outputs in the policy. Must + be provided when initializing a new policy, and must not be provided when loading a pretrained + policy. Therefore, this argument is mutually exclusive with `pretrained_policy_name_or_path`. + """ + if not (pretrained_policy_name_or_path is None) ^ (dataset_stats is None): + raise ValueError("Only one of `pretrained_policy_name_or_path` and `dataset_stats` may be provided.") + + policy_cls, policy_cfg_class = get_policy_and_config_classes(hydra_cfg.policy.name) + + if pretrained_policy_name_or_path is None: + policy_cfg = _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg) + policy = policy_cls(policy_cfg, dataset_stats) policy.to(get_safe_torch_device(hydra_cfg.device)) else: - raise ValueError(hydra_cfg.policy.name) - - if hydra_cfg.policy.pretrained_model_path: - # TODO(rcadene): hack for old pretrained models from fowm - if hydra_cfg.policy.name == "tdmpc" and "fowm" in hydra_cfg.policy.pretrained_model_path: - if "offline" in hydra_cfg.policy.pretrained_model_path: - policy.step[0] = 25000 - elif "final" in hydra_cfg.policy.pretrained_model_path: - policy.step[0] = 100000 - else: - raise NotImplementedError() - policy.load(hydra_cfg.policy.pretrained_model_path) + policy = policy_cls.from_pretrained(pretrained_policy_name_or_path) return policy diff --git a/lerobot/common/policies/normalize.py b/lerobot/common/policies/normalize.py index df615a21..ab57c8ba 100644 --- a/lerobot/common/policies/normalize.py +++ b/lerobot/common/policies/normalize.py @@ -57,17 +57,28 @@ def create_stats_buffers( ) if stats is not None: + # Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated + # tensors anywhere (for example, when we use the same stats for normalization and + # unnormalization). See the logic here + # https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97. if mode == "mean_std": - buffer["mean"].data = stats[key]["mean"] - buffer["std"].data = stats[key]["std"] + buffer["mean"].data = stats[key]["mean"].clone() + buffer["std"].data = stats[key]["std"].clone() elif mode == "min_max": - buffer["min"].data = stats[key]["min"] - buffer["max"].data = stats[key]["max"] + buffer["min"].data = stats[key]["min"].clone() + buffer["max"].data = stats[key]["max"].clone() stats_buffers[key] = buffer return stats_buffers +def _no_stats_error_str(name: str) -> str: + return ( + f"`{name}` is infinity. You should either initialize with `stats` as an argument, or use a " + "pretrained model." + ) + + class Normalize(nn.Module): """Normalizes data (e.g. "observation.image") for more stable and faster convergence during training.""" @@ -99,7 +110,6 @@ class Normalize(nn.Module): self.shapes = shapes self.modes = modes self.stats = stats - # `self.buffer_observation_state["mean"]` contains `torch.tensor(state_dim)` stats_buffers = create_stats_buffers(shapes, modes, stats) for key, buffer in stats_buffers.items(): setattr(self, "buffer_" + key.replace(".", "_"), buffer) @@ -113,26 +123,14 @@ class Normalize(nn.Module): if mode == "mean_std": mean = buffer["mean"] std = buffer["std"] - assert not torch.isinf(mean).any(), ( - "`mean` is infinity. You forgot to initialize with `stats` as argument, or called " - "`policy.load_state_dict`." - ) - assert not torch.isinf(std).any(), ( - "`std` is infinity. You forgot to initialize with `stats` as argument, or called " - "`policy.load_state_dict`." - ) + assert not torch.isinf(mean).any(), _no_stats_error_str("mean") + assert not torch.isinf(std).any(), _no_stats_error_str("std") batch[key] = (batch[key] - mean) / (std + 1e-8) elif mode == "min_max": min = buffer["min"] max = buffer["max"] - assert not torch.isinf(min).any(), ( - "`min` is infinity. You forgot to initialize with `stats` as argument, or called " - "`policy.load_state_dict`." - ) - assert not torch.isinf(max).any(), ( - "`max` is infinity. You forgot to initialize with `stats` as argument, or called " - "`policy.load_state_dict`." - ) + assert not torch.isinf(min).any(), _no_stats_error_str("min") + assert not torch.isinf(max).any(), _no_stats_error_str("max") # normalize to [0,1] batch[key] = (batch[key] - min) / (max - min) # normalize to [-1, 1] @@ -190,26 +188,14 @@ class Unnormalize(nn.Module): if mode == "mean_std": mean = buffer["mean"] std = buffer["std"] - assert not torch.isinf(mean).any(), ( - "`mean` is infinity. You forgot to initialize with `stats` as argument, or called " - "`policy.load_state_dict`." - ) - assert not torch.isinf(std).any(), ( - "`std` is infinity. You forgot to initialize with `stats` as argument, or called " - "`policy.load_state_dict`." - ) + assert not torch.isinf(mean).any(), _no_stats_error_str("mean") + assert not torch.isinf(std).any(), _no_stats_error_str("std") batch[key] = batch[key] * std + mean elif mode == "min_max": min = buffer["min"] max = buffer["max"] - assert not torch.isinf(min).any(), ( - "`min` is infinity. You forgot to initialize with `stats` as argument, or called " - "`policy.load_state_dict`." - ) - assert not torch.isinf(max).any(), ( - "`max` is infinity. You forgot to initialize with `stats` as argument, or called " - "`policy.load_state_dict`." - ) + assert not torch.isinf(min).any(), _no_stats_error_str("min") + assert not torch.isinf(max).any(), _no_stats_error_str("max") batch[key] = (batch[key] + 1) / 2 batch[key] = batch[key] * (max - min) + min else: diff --git a/lerobot/common/policies/policy_protocol.py b/lerobot/common/policies/policy_protocol.py index 29317fa0..62bc9dfc 100644 --- a/lerobot/common/policies/policy_protocol.py +++ b/lerobot/common/policies/policy_protocol.py @@ -14,7 +14,10 @@ from torch import Tensor @runtime_checkable class Policy(Protocol): - """The required interface for implementing a policy.""" + """The required interface for implementing a policy. + + We also expect all policies to subclass torch.nn.Module and PyTorchModelHubMixin. + """ name: str diff --git a/lerobot/configs/policy/act.yaml b/lerobot/configs/policy/act.yaml index c03cd58f..c0f47c44 100644 --- a/lerobot/configs/policy/act.yaml +++ b/lerobot/configs/policy/act.yaml @@ -34,8 +34,6 @@ eval: policy: name: act - pretrained_model_path: - # Input / output structure. n_obs_steps: 1 chunk_size: 100 # chunk_size @@ -62,7 +60,7 @@ policy: replace_final_stride_with_dilation: false # Transformer layers. pre_norm: false - d_model: 512 + dim_model: 512 n_heads: 8 dim_feedforward: 3200 feedforward_activation: relu diff --git a/lerobot/configs/policy/diffusion.yaml b/lerobot/configs/policy/diffusion.yaml index 49dc8650..f1b05185 100644 --- a/lerobot/configs/policy/diffusion.yaml +++ b/lerobot/configs/policy/diffusion.yaml @@ -46,8 +46,6 @@ override_dataset_stats: policy: name: diffusion - pretrained_model_path: - # Input / output structure. n_obs_steps: 2 horizon: 16 diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index 9be2be3a..c74af290 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -1,30 +1,29 @@ """Evaluate a policy on an environment by running rollouts and computing metrics. -The script may be run in one of two ways: +Usage examples: -1. By providing the path to a config file with the --config argument. -2. By providing a HuggingFace Hub ID with the --hub-id argument. You may also provide a revision number with the - --revision argument. +You want to evaluate a model from the hub (eg: https://huggingface.co/lerobot/diffusion_policy_pusht_image) +for 10 episodes. -In either case, it is possible to override config arguments by adding a list of config.key=value arguments. +``` +python lerobot/scripts/eval.py -p lerobot/diffusion_policy_pusht_image eval.n_episodes=10 +``` -Examples: - -You have a specific config file to go with trained model weights, and want to run 10 episodes. +OR, you want to evaluate a model checkpoint from the LeRobot training script for 10 episodes. ``` python lerobot/scripts/eval.py \ ---config PATH/TO/FOLDER/config.yaml \ -policy.pretrained_model_path=PATH/TO/FOLDER/weights.pth \ -eval.n_episodes=10 + -p outputs/train/diffusion_policy_pusht_image/checkpoints/005000 \ + eval.n_episodes=10 ``` -You have a HuggingFace Hub ID, you know which revision you want, and want to run 10 episodes (note that in this case, -you don't need to specify which weights to use): +Note that in both examples, the repo/folder should contain at least `config.json`, `config.yaml` and +`model.safetensors`. -``` -python lerobot/scripts/eval.py --hub-id HUB/ID --revision v1.0 eval.n_episodes=10 -``` +Note the formatting for providing the number of episodes. Generally, you may provide any number of arguments +with `qualified.parameter.name=value`. In this case, the parameter eval.n_episodes appears as `n_episodes` +nested under `eval` in the `config.yaml` found at +https://huggingface.co/lerobot/diffusion_policy_pusht_image/tree/main. """ import argparse @@ -42,9 +41,12 @@ import numpy as np import torch from datasets import Dataset, Features, Image, Sequence, Value from huggingface_hub import snapshot_download +from huggingface_hub.utils._errors import RepositoryNotFoundError +from huggingface_hub.utils._validators import HFValidationError from PIL import Image as PILImage from tqdm import trange +from lerobot.common.datasets.factory import make_dataset from lerobot.common.datasets.utils import hf_transform_to_torch from lerobot.common.envs.factory import make_env from lerobot.common.envs.utils import postprocess_action, preprocess_observation @@ -349,26 +351,41 @@ def eval_policy( return info -def eval(cfg: dict, out_dir=None): +def eval( + pretrained_policy_path: str | None = None, + hydra_cfg_path: str | None = None, + config_overrides: list[str] | None = None, +): + assert (pretrained_policy_path is None) ^ (hydra_cfg_path is None) + if hydra_cfg_path is None: + hydra_cfg = init_hydra_config(pretrained_policy_path / "config.yaml", config_overrides) + else: + hydra_cfg = init_hydra_config(hydra_cfg_path, config_overrides) + out_dir = ( + f"outputs/eval/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{hydra_cfg.env.name}_{hydra_cfg.policy.name}" + ) + if out_dir is None: raise NotImplementedError() - init_logging() - # Check device is available - get_safe_torch_device(cfg.device, log=True) + get_safe_torch_device(hydra_cfg.device, log=True) torch.backends.cudnn.benchmark = True torch.backends.cuda.matmul.allow_tf32 = True - set_global_seed(cfg.seed) + set_global_seed(hydra_cfg.seed) log_output_dir(out_dir) logging.info("Making environment.") - env = make_env(cfg, num_parallel_envs=cfg.eval.n_episodes) + env = make_env(hydra_cfg, num_parallel_envs=hydra_cfg.eval.n_episodes) logging.info("Making policy.") - policy = make_policy(cfg) + if hydra_cfg_path is None: + policy = make_policy(hydra_cfg=hydra_cfg, pretrained_policy_name_or_path=pretrained_policy_path) + else: + # Note: We need the dataset stats to pass to the policy's normalization modules. + policy = make_policy(hydra_cfg=hydra_cfg, dataset_stats=make_dataset(hydra_cfg).stats) info = eval_policy( env, @@ -376,7 +393,7 @@ def eval(cfg: dict, out_dir=None): max_episodes_rendered=10, video_dir=Path(out_dir) / "eval", return_episode_data=False, - seed=cfg.seed, + seed=hydra_cfg.seed, ) print(info["aggregated"]) @@ -390,13 +407,29 @@ def eval(cfg: dict, out_dir=None): if __name__ == "__main__": + init_logging() + parser = argparse.ArgumentParser( description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter ) group = parser.add_mutually_exclusive_group(required=True) - group.add_argument("--config", help="Path to a specific yaml config you want to use.") - group.add_argument("--hub-id", help="HuggingFace Hub ID for a pretrained model.") - parser.add_argument("--revision", help="Optionally provide the HuggingFace Hub revision ID.") + group.add_argument( + "-p", + "--pretrained-policy-name-or-path", + help=( + "Either the repo ID of a model hosted on the Hub or a path to a directory containing weights " + "saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch " + "(useful for debugging). This argument is mutually exclusive with `--config`." + ), + ) + group.add_argument( + "--config", + help=( + "Path to a yaml config you want to use for initializing a policy from scratch (useful for " + "debugging). This argument is mutually exclusive with `--pretrained-policy-name-or-path` (`-p`)." + ), + ) + parser.add_argument("--revision", help="Optionally provide the Hugging Face Hub revision ID.") parser.add_argument( "overrides", nargs="*", @@ -404,16 +437,28 @@ if __name__ == "__main__": ) args = parser.parse_args() - if args.config is not None: - # Note: For the config_path, Hydra wants a path relative to this script file. - cfg = init_hydra_config(args.config, args.overrides) - elif args.hub_id is not None: - folder = Path(snapshot_download(args.hub_id, revision=args.revision)) - cfg = init_hydra_config( - folder / "config.yaml", [f"policy.pretrained_model_path={folder / 'model.pt'}", *args.overrides] - ) + if args.pretrained_policy_name_or_path is None: + eval(hydra_cfg_path=args.config, config_overrides=args.overrides) + else: + try: + pretrained_policy_path = Path( + snapshot_download(args.pretrained_policy_name_or_path, revision=args.revision) + ) + except HFValidationError: + logging.warning( + "The provided pretrained_policy_name_or_path is not a valid Hugging Face Hub repo ID. " + "Treating it as a local directory." + ) + except RepositoryNotFoundError: + logging.warning( + "The provided pretrained_policy_name_or_path was not found on the Hugging Face Hub. Treating " + "it as a local directory." + ) + pretrained_policy_path = Path(args.pretrained_policy_name_or_path) + if not pretrained_policy_path.is_dir() or not pretrained_policy_path.exists(): + raise ValueError( + "The provided pretrained_policy_name_or_path is not a valid/existing Hugging Face Hub " + "repo ID, nor is it an existing local directory." + ) - eval( - cfg, - out_dir=f"outputs/eval/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{cfg.env.name}_{cfg.policy.name}", - ) + eval(pretrained_policy_path=pretrained_policy_path, config_overrides=args.overrides) diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 52c29fb3..565c5f3a 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -265,7 +265,7 @@ def train(cfg: dict, out_dir=None, job_name=None): env = make_env(cfg, num_parallel_envs=cfg.eval.n_episodes) logging.info("make_policy") - policy = make_policy(cfg, dataset_stats=offline_dataset.stats) + policy = make_policy(hydra_cfg=cfg, dataset_stats=offline_dataset.stats) # Create optimizer and scheduler # Temporary hack to move optimizer out of policy @@ -340,7 +340,14 @@ def train(cfg: dict, out_dir=None, job_name=None): if cfg.training.save_model and step % cfg.training.save_freq == 0: logging.info(f"Checkpoint policy after step {step}") - logger.save_model(policy, identifier=step) + # Note: Save with step as the identifier, and format it to have at least 6 digits but more if + # needed (choose 6 as a minimum for consistency without being overkill). + logger.save_model( + policy, + identifier=str(step).zfill( + max(6, len(str(cfg.training.offline_steps + cfg.training.online_steps))) + ), + ) logging.info("Resume training") # create dataloader for offline training diff --git a/poetry.lock b/poetry.lock index 79c48641..89d35a55 100644 --- a/poetry.lock +++ b/poetry.lock @@ -826,13 +826,13 @@ files = [ [[package]] name = "filelock" -version = "3.13.4" +version = "3.14.0" description = "A platform independent file lock." optional = false python-versions = ">=3.8" files = [ - {file = "filelock-3.13.4-py3-none-any.whl", hash = "sha256:404e5e9253aa60ad457cae1be07c0f0ca90a63931200a47d9b6a6af84fd7b45f"}, - {file = "filelock-3.13.4.tar.gz", hash = "sha256:d13f466618bfde72bd2c18255e269f72542c6e70e7bac83a0232d6b1cc5c8cf4"}, + {file = "filelock-3.14.0-py3-none-any.whl", hash = "sha256:43339835842f110ca7ae60f1e1c160714c5a6afd15a2873419ab185334975c0f"}, + {file = "filelock-3.14.0.tar.gz", hash = "sha256:6ea72da3be9b8c82afd3edcf99f2fffbb5076335a5ae4d03248bb5b6c3eae78a"}, ] [package.extras] @@ -1039,69 +1039,61 @@ preview = ["glfw-preview"] [[package]] name = "grpcio" -version = "1.62.2" +version = "1.63.0" description = "HTTP/2-based RPC framework" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "grpcio-1.62.2-cp310-cp310-linux_armv7l.whl", hash = "sha256:66344ea741124c38588a664237ac2fa16dfd226964cca23ddc96bd4accccbde5"}, - {file = "grpcio-1.62.2-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:5dab7ac2c1e7cb6179c6bfad6b63174851102cbe0682294e6b1d6f0981ad7138"}, - {file = "grpcio-1.62.2-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:3ad00f3f0718894749d5a8bb0fa125a7980a2f49523731a9b1fabf2b3522aa43"}, - {file = "grpcio-1.62.2-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2e72ddfee62430ea80133d2cbe788e0d06b12f865765cb24a40009668bd8ea05"}, - {file = "grpcio-1.62.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:53d3a59a10af4c2558a8e563aed9f256259d2992ae0d3037817b2155f0341de1"}, - {file = "grpcio-1.62.2-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:a1511a303f8074f67af4119275b4f954189e8313541da7b88b1b3a71425cdb10"}, - {file = "grpcio-1.62.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:b94d41b7412ef149743fbc3178e59d95228a7064c5ab4760ae82b562bdffb199"}, - {file = "grpcio-1.62.2-cp310-cp310-win32.whl", hash = "sha256:a75af2fc7cb1fe25785be7bed1ab18cef959a376cdae7c6870184307614caa3f"}, - {file = "grpcio-1.62.2-cp310-cp310-win_amd64.whl", hash = "sha256:80407bc007754f108dc2061e37480238b0dc1952c855e86a4fc283501ee6bb5d"}, - {file = "grpcio-1.62.2-cp311-cp311-linux_armv7l.whl", hash = "sha256:c1624aa686d4b36790ed1c2e2306cc3498778dffaf7b8dd47066cf819028c3ad"}, - {file = "grpcio-1.62.2-cp311-cp311-macosx_10_10_universal2.whl", hash = "sha256:1c1bb80299bdef33309dff03932264636450c8fdb142ea39f47e06a7153d3063"}, - {file = "grpcio-1.62.2-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:db068bbc9b1fa16479a82e1ecf172a93874540cb84be69f0b9cb9b7ac3c82670"}, - {file = "grpcio-1.62.2-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e2cc8a308780edbe2c4913d6a49dbdb5befacdf72d489a368566be44cadaef1a"}, - {file = "grpcio-1.62.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d0695ae31a89f1a8fc8256050329a91a9995b549a88619263a594ca31b76d756"}, - {file = "grpcio-1.62.2-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:88b4f9ee77191dcdd8810241e89340a12cbe050be3e0d5f2f091c15571cd3930"}, - {file = "grpcio-1.62.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:2a0204532aa2f1afd467024b02b4069246320405bc18abec7babab03e2644e75"}, - {file = "grpcio-1.62.2-cp311-cp311-win32.whl", hash = "sha256:6e784f60e575a0de554ef9251cbc2ceb8790914fe324f11e28450047f264ee6f"}, - {file = "grpcio-1.62.2-cp311-cp311-win_amd64.whl", hash = "sha256:112eaa7865dd9e6d7c0556c8b04ae3c3a2dc35d62ad3373ab7f6a562d8199200"}, - {file = "grpcio-1.62.2-cp312-cp312-linux_armv7l.whl", hash = "sha256:65034473fc09628a02fb85f26e73885cf1ed39ebd9cf270247b38689ff5942c5"}, - {file = "grpcio-1.62.2-cp312-cp312-macosx_10_10_universal2.whl", hash = "sha256:d2c1771d0ee3cf72d69bb5e82c6a82f27fbd504c8c782575eddb7839729fbaad"}, - {file = "grpcio-1.62.2-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:3abe6838196da518863b5d549938ce3159d809218936851b395b09cad9b5d64a"}, - {file = "grpcio-1.62.2-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c5ffeb269f10cedb4f33142b89a061acda9f672fd1357331dbfd043422c94e9e"}, - {file = "grpcio-1.62.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:404d3b4b6b142b99ba1cff0b2177d26b623101ea2ce51c25ef6e53d9d0d87bcc"}, - {file = "grpcio-1.62.2-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:262cda97efdabb20853d3b5a4c546a535347c14b64c017f628ca0cc7fa780cc6"}, - {file = "grpcio-1.62.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:17708db5b11b966373e21519c4c73e5a750555f02fde82276ea2a267077c68ad"}, - {file = "grpcio-1.62.2-cp312-cp312-win32.whl", hash = "sha256:b7ec9e2f8ffc8436f6b642a10019fc513722858f295f7efc28de135d336ac189"}, - {file = "grpcio-1.62.2-cp312-cp312-win_amd64.whl", hash = "sha256:aa787b83a3cd5e482e5c79be030e2b4a122ecc6c5c6c4c42a023a2b581fdf17b"}, - {file = "grpcio-1.62.2-cp37-cp37m-linux_armv7l.whl", hash = "sha256:cfd23ad29bfa13fd4188433b0e250f84ec2c8ba66b14a9877e8bce05b524cf54"}, - {file = "grpcio-1.62.2-cp37-cp37m-macosx_10_10_universal2.whl", hash = "sha256:af15e9efa4d776dfcecd1d083f3ccfb04f876d613e90ef8432432efbeeac689d"}, - {file = "grpcio-1.62.2-cp37-cp37m-manylinux_2_17_aarch64.whl", hash = "sha256:f4aa94361bb5141a45ca9187464ae81a92a2a135ce2800b2203134f7a1a1d479"}, - {file = "grpcio-1.62.2-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:82af3613a219512a28ee5c95578eb38d44dd03bca02fd918aa05603c41018051"}, - {file = "grpcio-1.62.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:55ddaf53474e8caeb29eb03e3202f9d827ad3110475a21245f3c7712022882a9"}, - {file = "grpcio-1.62.2-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:c79b518c56dddeec79e5500a53d8a4db90da995dfe1738c3ac57fe46348be049"}, - {file = "grpcio-1.62.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:a5eb4844e5e60bf2c446ef38c5b40d7752c6effdee882f716eb57ae87255d20a"}, - {file = "grpcio-1.62.2-cp37-cp37m-win_amd64.whl", hash = "sha256:aaae70364a2d1fb238afd6cc9fcb10442b66e397fd559d3f0968d28cc3ac929c"}, - {file = "grpcio-1.62.2-cp38-cp38-linux_armv7l.whl", hash = "sha256:1bcfe5070e4406f489e39325b76caeadab28c32bf9252d3ae960c79935a4cc36"}, - {file = "grpcio-1.62.2-cp38-cp38-macosx_10_10_universal2.whl", hash = "sha256:da6a7b6b938c15fa0f0568e482efaae9c3af31963eec2da4ff13a6d8ec2888e4"}, - {file = "grpcio-1.62.2-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:41955b641c34db7d84db8d306937b72bc4968eef1c401bea73081a8d6c3d8033"}, - {file = "grpcio-1.62.2-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c772f225483905f675cb36a025969eef9712f4698364ecd3a63093760deea1bc"}, - {file = "grpcio-1.62.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:07ce1f775d37ca18c7a141300e5b71539690efa1f51fe17f812ca85b5e73262f"}, - {file = "grpcio-1.62.2-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:26f415f40f4a93579fd648f48dca1c13dfacdfd0290f4a30f9b9aeb745026811"}, - {file = "grpcio-1.62.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:db707e3685ff16fc1eccad68527d072ac8bdd2e390f6daa97bc394ea7de4acea"}, - {file = "grpcio-1.62.2-cp38-cp38-win32.whl", hash = "sha256:589ea8e75de5fd6df387de53af6c9189c5231e212b9aa306b6b0d4f07520fbb9"}, - {file = "grpcio-1.62.2-cp38-cp38-win_amd64.whl", hash = "sha256:3c3ed41f4d7a3aabf0f01ecc70d6b5d00ce1800d4af652a549de3f7cf35c4abd"}, - {file = "grpcio-1.62.2-cp39-cp39-linux_armv7l.whl", hash = "sha256:162ccf61499c893831b8437120600290a99c0bc1ce7b51f2c8d21ec87ff6af8b"}, - {file = "grpcio-1.62.2-cp39-cp39-macosx_10_10_universal2.whl", hash = "sha256:f27246d7da7d7e3bd8612f63785a7b0c39a244cf14b8dd9dd2f2fab939f2d7f1"}, - {file = "grpcio-1.62.2-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:2507006c8a478f19e99b6fe36a2464696b89d40d88f34e4b709abe57e1337467"}, - {file = "grpcio-1.62.2-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a90ac47a8ce934e2c8d71e317d2f9e7e6aaceb2d199de940ce2c2eb611b8c0f4"}, - {file = "grpcio-1.62.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:99701979bcaaa7de8d5f60476487c5df8f27483624f1f7e300ff4669ee44d1f2"}, - {file = "grpcio-1.62.2-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:af7dc3f7a44f10863b1b0ecab4078f0a00f561aae1edbd01fd03ad4dcf61c9e9"}, - {file = "grpcio-1.62.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:fa63245271920786f4cb44dcada4983a3516be8f470924528cf658731864c14b"}, - {file = "grpcio-1.62.2-cp39-cp39-win32.whl", hash = "sha256:c6ad9c39704256ed91a1cffc1379d63f7d0278d6a0bad06b0330f5d30291e3a3"}, - {file = "grpcio-1.62.2-cp39-cp39-win_amd64.whl", hash = "sha256:16da954692fd61aa4941fbeda405a756cd96b97b5d95ca58a92547bba2c1624f"}, - {file = "grpcio-1.62.2.tar.gz", hash = "sha256:c77618071d96b7a8be2c10701a98537823b9c65ba256c0b9067e0594cdbd954d"}, + {file = "grpcio-1.63.0-cp310-cp310-linux_armv7l.whl", hash = "sha256:2e93aca840c29d4ab5db93f94ed0a0ca899e241f2e8aec6334ab3575dc46125c"}, + {file = "grpcio-1.63.0-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:91b73d3f1340fefa1e1716c8c1ec9930c676d6b10a3513ab6c26004cb02d8b3f"}, + {file = "grpcio-1.63.0-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:b3afbd9d6827fa6f475a4f91db55e441113f6d3eb9b7ebb8fb806e5bb6d6bd0d"}, + {file = "grpcio-1.63.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8f3f6883ce54a7a5f47db43289a0a4c776487912de1a0e2cc83fdaec9685cc9f"}, + {file = "grpcio-1.63.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cf8dae9cc0412cb86c8de5a8f3be395c5119a370f3ce2e69c8b7d46bb9872c8d"}, + {file = "grpcio-1.63.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:08e1559fd3b3b4468486b26b0af64a3904a8dbc78d8d936af9c1cf9636eb3e8b"}, + {file = "grpcio-1.63.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:5c039ef01516039fa39da8a8a43a95b64e288f79f42a17e6c2904a02a319b357"}, + {file = "grpcio-1.63.0-cp310-cp310-win32.whl", hash = "sha256:ad2ac8903b2eae071055a927ef74121ed52d69468e91d9bcbd028bd0e554be6d"}, + {file = "grpcio-1.63.0-cp310-cp310-win_amd64.whl", hash = "sha256:b2e44f59316716532a993ca2966636df6fbe7be4ab6f099de6815570ebe4383a"}, + {file = "grpcio-1.63.0-cp311-cp311-linux_armv7l.whl", hash = "sha256:f28f8b2db7b86c77916829d64ab21ff49a9d8289ea1564a2b2a3a8ed9ffcccd3"}, + {file = "grpcio-1.63.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:65bf975639a1f93bee63ca60d2e4951f1b543f498d581869922910a476ead2f5"}, + {file = "grpcio-1.63.0-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:b5194775fec7dc3dbd6a935102bb156cd2c35efe1685b0a46c67b927c74f0cfb"}, + {file = "grpcio-1.63.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e4cbb2100ee46d024c45920d16e888ee5d3cf47c66e316210bc236d5bebc42b3"}, + {file = "grpcio-1.63.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1ff737cf29b5b801619f10e59b581869e32f400159e8b12d7a97e7e3bdeee6a2"}, + {file = "grpcio-1.63.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:cd1e68776262dd44dedd7381b1a0ad09d9930ffb405f737d64f505eb7f77d6c7"}, + {file = "grpcio-1.63.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:93f45f27f516548e23e4ec3fbab21b060416007dbe768a111fc4611464cc773f"}, + {file = "grpcio-1.63.0-cp311-cp311-win32.whl", hash = "sha256:878b1d88d0137df60e6b09b74cdb73db123f9579232c8456f53e9abc4f62eb3c"}, + {file = "grpcio-1.63.0-cp311-cp311-win_amd64.whl", hash = "sha256:756fed02dacd24e8f488f295a913f250b56b98fb793f41d5b2de6c44fb762434"}, + {file = "grpcio-1.63.0-cp312-cp312-linux_armv7l.whl", hash = "sha256:93a46794cc96c3a674cdfb59ef9ce84d46185fe9421baf2268ccb556f8f81f57"}, + {file = "grpcio-1.63.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:a7b19dfc74d0be7032ca1eda0ed545e582ee46cd65c162f9e9fc6b26ef827dc6"}, + {file = "grpcio-1.63.0-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:8064d986d3a64ba21e498b9a376cbc5d6ab2e8ab0e288d39f266f0fca169b90d"}, + {file = "grpcio-1.63.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:219bb1848cd2c90348c79ed0a6b0ea51866bc7e72fa6e205e459fedab5770172"}, + {file = "grpcio-1.63.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a2d60cd1d58817bc5985fae6168d8b5655c4981d448d0f5b6194bbcc038090d2"}, + {file = "grpcio-1.63.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:9e350cb096e5c67832e9b6e018cf8a0d2a53b2a958f6251615173165269a91b0"}, + {file = "grpcio-1.63.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:56cdf96ff82e3cc90dbe8bac260352993f23e8e256e063c327b6cf9c88daf7a9"}, + {file = "grpcio-1.63.0-cp312-cp312-win32.whl", hash = "sha256:3a6d1f9ea965e750db7b4ee6f9fdef5fdf135abe8a249e75d84b0a3e0c668a1b"}, + {file = "grpcio-1.63.0-cp312-cp312-win_amd64.whl", hash = "sha256:d2497769895bb03efe3187fb1888fc20e98a5f18b3d14b606167dacda5789434"}, + {file = "grpcio-1.63.0-cp38-cp38-linux_armv7l.whl", hash = "sha256:fdf348ae69c6ff484402cfdb14e18c1b0054ac2420079d575c53a60b9b2853ae"}, + {file = "grpcio-1.63.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:a3abfe0b0f6798dedd2e9e92e881d9acd0fdb62ae27dcbbfa7654a57e24060c0"}, + {file = "grpcio-1.63.0-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:6ef0ad92873672a2a3767cb827b64741c363ebaa27e7f21659e4e31f4d750280"}, + {file = "grpcio-1.63.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b416252ac5588d9dfb8a30a191451adbf534e9ce5f56bb02cd193f12d8845b7f"}, + {file = "grpcio-1.63.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e3b77eaefc74d7eb861d3ffbdf91b50a1bb1639514ebe764c47773b833fa2d91"}, + {file = "grpcio-1.63.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:b005292369d9c1f80bf70c1db1c17c6c342da7576f1c689e8eee4fb0c256af85"}, + {file = "grpcio-1.63.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:cdcda1156dcc41e042d1e899ba1f5c2e9f3cd7625b3d6ebfa619806a4c1aadda"}, + {file = "grpcio-1.63.0-cp38-cp38-win32.whl", hash = "sha256:01799e8649f9e94ba7db1aeb3452188048b0019dc37696b0f5ce212c87c560c3"}, + {file = "grpcio-1.63.0-cp38-cp38-win_amd64.whl", hash = "sha256:6a1a3642d76f887aa4009d92f71eb37809abceb3b7b5a1eec9c554a246f20e3a"}, + {file = "grpcio-1.63.0-cp39-cp39-linux_armv7l.whl", hash = "sha256:75f701ff645858a2b16bc8c9fc68af215a8bb2d5a9b647448129de6e85d52bce"}, + {file = "grpcio-1.63.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:cacdef0348a08e475a721967f48206a2254a1b26ee7637638d9e081761a5ba86"}, + {file = "grpcio-1.63.0-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:0697563d1d84d6985e40ec5ec596ff41b52abb3fd91ec240e8cb44a63b895094"}, + {file = "grpcio-1.63.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6426e1fb92d006e47476d42b8f240c1d916a6d4423c5258ccc5b105e43438f61"}, + {file = "grpcio-1.63.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e48cee31bc5f5a31fb2f3b573764bd563aaa5472342860edcc7039525b53e46a"}, + {file = "grpcio-1.63.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:50344663068041b34a992c19c600236e7abb42d6ec32567916b87b4c8b8833b3"}, + {file = "grpcio-1.63.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:259e11932230d70ef24a21b9fb5bb947eb4703f57865a404054400ee92f42f5d"}, + {file = "grpcio-1.63.0-cp39-cp39-win32.whl", hash = "sha256:a44624aad77bf8ca198c55af811fd28f2b3eaf0a50ec5b57b06c034416ef2d0a"}, + {file = "grpcio-1.63.0-cp39-cp39-win_amd64.whl", hash = "sha256:166e5c460e5d7d4656ff9e63b13e1f6029b122104c1633d5f37eaea348d7356d"}, + {file = "grpcio-1.63.0.tar.gz", hash = "sha256:f3023e14805c61bc439fb40ca545ac3d5740ce66120a678a3c6c2c55b70343d1"}, ] [package.extras] -protobuf = ["grpcio-tools (>=1.62.2)"] +protobuf = ["grpcio-tools (>=1.63.0)"] [[package]] name = "gym-aloha" @@ -2379,7 +2371,6 @@ optional = false python-versions = ">=3.9" files = [ {file = "pandas-2.2.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:90c6fca2acf139569e74e8781709dccb6fe25940488755716d1d354d6bc58bce"}, - {file = "pandas-2.2.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c7adfc142dac335d8c1e0dcbd37eb8617eac386596eb9e1a1b77791cf2498238"}, {file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4abfe0be0d7221be4f12552995e58723c7422c80a659da13ca382697de830c08"}, {file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8635c16bf3d99040fdf3ca3db669a7250ddf49c55dc4aa8fe0ae0fa8d6dcc1f0"}, {file = "pandas-2.2.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:40ae1dffb3967a52203105a077415a86044a2bea011b5f321c6aa64b379a3f51"}, @@ -2400,7 +2391,6 @@ files = [ {file = "pandas-2.2.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:43498c0bdb43d55cb162cdc8c06fac328ccb5d2eabe3cadeb3529ae6f0517c32"}, {file = "pandas-2.2.2-cp312-cp312-win_amd64.whl", hash = "sha256:d187d355ecec3629624fccb01d104da7d7f391db0311145817525281e2804d23"}, {file = "pandas-2.2.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:0ca6377b8fca51815f382bd0b697a0814c8bda55115678cbc94c30aacbb6eff2"}, - {file = "pandas-2.2.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9057e6aa78a584bc93a13f0a9bf7e753a5e9770a30b4d758b8d5f2a62a9433cd"}, {file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:001910ad31abc7bf06f49dcc903755d2f7f3a9186c0c040b827e522e9cef0863"}, {file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:66b479b0bd07204e37583c191535505410daa8df638fd8e75ae1b383851fe921"}, {file = "pandas-2.2.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:a77e9d1c386196879aa5eb712e77461aaee433e54c68cf253053a73b7e49c33a"}, @@ -3023,104 +3013,90 @@ files = [ [[package]] name = "regex" -version = "2024.4.16" +version = "2024.4.28" description = "Alternative regular expression module, to replace re." optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "regex-2024.4.16-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:fb83cc090eac63c006871fd24db5e30a1f282faa46328572661c0a24a2323a08"}, - {file = "regex-2024.4.16-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8c91e1763696c0eb66340c4df98623c2d4e77d0746b8f8f2bee2c6883fd1fe18"}, - {file = "regex-2024.4.16-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:10188fe732dec829c7acca7422cdd1bf57d853c7199d5a9e96bb4d40db239c73"}, - {file = "regex-2024.4.16-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:956b58d692f235cfbf5b4f3abd6d99bf102f161ccfe20d2fd0904f51c72c4c66"}, - {file = "regex-2024.4.16-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a70b51f55fd954d1f194271695821dd62054d949efd6368d8be64edd37f55c86"}, - {file = "regex-2024.4.16-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5c02fcd2bf45162280613d2e4a1ca3ac558ff921ae4e308ecb307650d3a6ee51"}, - {file = "regex-2024.4.16-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c4ed75ea6892a56896d78f11006161eea52c45a14994794bcfa1654430984b22"}, - {file = "regex-2024.4.16-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bd727ad276bb91928879f3aa6396c9a1d34e5e180dce40578421a691eeb77f47"}, - {file = "regex-2024.4.16-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:7cbc5d9e8a1781e7be17da67b92580d6ce4dcef5819c1b1b89f49d9678cc278c"}, - {file = "regex-2024.4.16-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:78fddb22b9ef810b63ef341c9fcf6455232d97cfe03938cbc29e2672c436670e"}, - {file = "regex-2024.4.16-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:445ca8d3c5a01309633a0c9db57150312a181146315693273e35d936472df912"}, - {file = "regex-2024.4.16-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:95399831a206211d6bc40224af1c635cb8790ddd5c7493e0bd03b85711076a53"}, - {file = "regex-2024.4.16-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:7731728b6568fc286d86745f27f07266de49603a6fdc4d19c87e8c247be452af"}, - {file = "regex-2024.4.16-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:4facc913e10bdba42ec0aee76d029aedda628161a7ce4116b16680a0413f658a"}, - {file = "regex-2024.4.16-cp310-cp310-win32.whl", hash = "sha256:911742856ce98d879acbea33fcc03c1d8dc1106234c5e7d068932c945db209c0"}, - {file = "regex-2024.4.16-cp310-cp310-win_amd64.whl", hash = "sha256:e0a2df336d1135a0b3a67f3bbf78a75f69562c1199ed9935372b82215cddd6e2"}, - {file = "regex-2024.4.16-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:1210365faba7c2150451eb78ec5687871c796b0f1fa701bfd2a4a25420482d26"}, - {file = "regex-2024.4.16-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:9ab40412f8cd6f615bfedea40c8bf0407d41bf83b96f6fc9ff34976d6b7037fd"}, - {file = "regex-2024.4.16-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:fd80d1280d473500d8086d104962a82d77bfbf2b118053824b7be28cd5a79ea5"}, - {file = "regex-2024.4.16-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7bb966fdd9217e53abf824f437a5a2d643a38d4fd5fd0ca711b9da683d452969"}, - {file = "regex-2024.4.16-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:20b7a68444f536365af42a75ccecb7ab41a896a04acf58432db9e206f4e525d6"}, - {file = "regex-2024.4.16-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b74586dd0b039c62416034f811d7ee62810174bb70dffcca6439f5236249eb09"}, - {file = "regex-2024.4.16-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0c8290b44d8b0af4e77048646c10c6e3aa583c1ca67f3b5ffb6e06cf0c6f0f89"}, - {file = "regex-2024.4.16-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f2d80a6749724b37853ece57988b39c4e79d2b5fe2869a86e8aeae3bbeef9eb0"}, - {file = "regex-2024.4.16-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:3a1018e97aeb24e4f939afcd88211ace472ba566efc5bdf53fd8fd7f41fa7170"}, - {file = "regex-2024.4.16-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:8d015604ee6204e76569d2f44e5a210728fa917115bef0d102f4107e622b08d5"}, - {file = "regex-2024.4.16-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:3d5ac5234fb5053850d79dd8eb1015cb0d7d9ed951fa37aa9e6249a19aa4f336"}, - {file = "regex-2024.4.16-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:0a38d151e2cdd66d16dab550c22f9521ba79761423b87c01dae0a6e9add79c0d"}, - {file = "regex-2024.4.16-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:159dc4e59a159cb8e4e8f8961eb1fa5d58f93cb1acd1701d8aff38d45e1a84a6"}, - {file = "regex-2024.4.16-cp311-cp311-win32.whl", hash = "sha256:ba2336d6548dee3117520545cfe44dc28a250aa091f8281d28804aa8d707d93d"}, - {file = "regex-2024.4.16-cp311-cp311-win_amd64.whl", hash = "sha256:8f83b6fd3dc3ba94d2b22717f9c8b8512354fd95221ac661784df2769ea9bba9"}, - {file = "regex-2024.4.16-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:80b696e8972b81edf0af2a259e1b2a4a661f818fae22e5fa4fa1a995fb4a40fd"}, - {file = "regex-2024.4.16-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:d61ae114d2a2311f61d90c2ef1358518e8f05eafda76eaf9c772a077e0b465ec"}, - {file = "regex-2024.4.16-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8ba6745440b9a27336443b0c285d705ce73adb9ec90e2f2004c64d95ab5a7598"}, - {file = "regex-2024.4.16-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6295004b2dd37b0835ea5c14a33e00e8cfa3c4add4d587b77287825f3418d310"}, - {file = "regex-2024.4.16-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4aba818dcc7263852aabb172ec27b71d2abca02a593b95fa79351b2774eb1d2b"}, - {file = "regex-2024.4.16-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d0800631e565c47520aaa04ae38b96abc5196fe8b4aa9bd864445bd2b5848a7a"}, - {file = "regex-2024.4.16-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:08dea89f859c3df48a440dbdcd7b7155bc675f2fa2ec8c521d02dc69e877db70"}, - {file = "regex-2024.4.16-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:eeaa0b5328b785abc344acc6241cffde50dc394a0644a968add75fcefe15b9d4"}, - {file = "regex-2024.4.16-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:4e819a806420bc010489f4e741b3036071aba209f2e0989d4750b08b12a9343f"}, - {file = "regex-2024.4.16-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:c2d0e7cbb6341e830adcbfa2479fdeebbfbb328f11edd6b5675674e7a1e37730"}, - {file = "regex-2024.4.16-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:91797b98f5e34b6a49f54be33f72e2fb658018ae532be2f79f7c63b4ae225145"}, - {file = "regex-2024.4.16-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:d2da13568eff02b30fd54fccd1e042a70fe920d816616fda4bf54ec705668d81"}, - {file = "regex-2024.4.16-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:370c68dc5570b394cbaadff50e64d705f64debed30573e5c313c360689b6aadc"}, - {file = "regex-2024.4.16-cp312-cp312-win32.whl", hash = "sha256:904c883cf10a975b02ab3478bce652f0f5346a2c28d0a8521d97bb23c323cc8b"}, - {file = "regex-2024.4.16-cp312-cp312-win_amd64.whl", hash = "sha256:785c071c982dce54d44ea0b79cd6dfafddeccdd98cfa5f7b86ef69b381b457d9"}, - {file = "regex-2024.4.16-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:e2f142b45c6fed48166faeb4303b4b58c9fcd827da63f4cf0a123c3480ae11fb"}, - {file = "regex-2024.4.16-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e87ab229332ceb127a165612d839ab87795972102cb9830e5f12b8c9a5c1b508"}, - {file = "regex-2024.4.16-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:81500ed5af2090b4a9157a59dbc89873a25c33db1bb9a8cf123837dcc9765047"}, - {file = "regex-2024.4.16-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b340cccad138ecb363324aa26893963dcabb02bb25e440ebdf42e30963f1a4e0"}, - {file = "regex-2024.4.16-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2c72608e70f053643437bd2be0608f7f1c46d4022e4104d76826f0839199347a"}, - {file = "regex-2024.4.16-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a01fe2305e6232ef3e8f40bfc0f0f3a04def9aab514910fa4203bafbc0bb4682"}, - {file = "regex-2024.4.16-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:03576e3a423d19dda13e55598f0fd507b5d660d42c51b02df4e0d97824fdcae3"}, - {file = "regex-2024.4.16-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:549c3584993772e25f02d0656ac48abdda73169fe347263948cf2b1cead622f3"}, - {file = "regex-2024.4.16-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:34422d5a69a60b7e9a07a690094e824b66f5ddc662a5fc600d65b7c174a05f04"}, - {file = "regex-2024.4.16-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:5f580c651a72b75c39e311343fe6875d6f58cf51c471a97f15a938d9fe4e0d37"}, - {file = "regex-2024.4.16-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:3399dd8a7495bbb2bacd59b84840eef9057826c664472e86c91d675d007137f5"}, - {file = "regex-2024.4.16-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:8d1f86f3f4e2388aa3310b50694ac44daefbd1681def26b4519bd050a398dc5a"}, - {file = "regex-2024.4.16-cp37-cp37m-win32.whl", hash = "sha256:dd5acc0a7d38fdc7a3a6fd3ad14c880819008ecb3379626e56b163165162cc46"}, - {file = "regex-2024.4.16-cp37-cp37m-win_amd64.whl", hash = "sha256:ba8122e3bb94ecda29a8de4cf889f600171424ea586847aa92c334772d200331"}, - {file = "regex-2024.4.16-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:743deffdf3b3481da32e8a96887e2aa945ec6685af1cfe2bcc292638c9ba2f48"}, - {file = "regex-2024.4.16-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:7571f19f4a3fd00af9341c7801d1ad1967fc9c3f5e62402683047e7166b9f2b4"}, - {file = "regex-2024.4.16-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:df79012ebf6f4efb8d307b1328226aef24ca446b3ff8d0e30202d7ebcb977a8c"}, - {file = "regex-2024.4.16-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e757d475953269fbf4b441207bb7dbdd1c43180711b6208e129b637792ac0b93"}, - {file = "regex-2024.4.16-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4313ab9bf6a81206c8ac28fdfcddc0435299dc88cad12cc6305fd0e78b81f9e4"}, - {file = "regex-2024.4.16-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d83c2bc678453646f1a18f8db1e927a2d3f4935031b9ad8a76e56760461105dd"}, - {file = "regex-2024.4.16-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9df1bfef97db938469ef0a7354b2d591a2d438bc497b2c489471bec0e6baf7c4"}, - {file = "regex-2024.4.16-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:62120ed0de69b3649cc68e2965376048793f466c5a6c4370fb27c16c1beac22d"}, - {file = "regex-2024.4.16-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:c2ef6f7990b6e8758fe48ad08f7e2f66c8f11dc66e24093304b87cae9037bb4a"}, - {file = "regex-2024.4.16-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:8fc6976a3395fe4d1fbeb984adaa8ec652a1e12f36b56ec8c236e5117b585427"}, - {file = "regex-2024.4.16-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:03e68f44340528111067cecf12721c3df4811c67268b897fbe695c95f860ac42"}, - {file = "regex-2024.4.16-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:ec7e0043b91115f427998febaa2beb82c82df708168b35ece3accb610b91fac1"}, - {file = "regex-2024.4.16-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:c21fc21a4c7480479d12fd8e679b699f744f76bb05f53a1d14182b31f55aac76"}, - {file = "regex-2024.4.16-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:12f6a3f2f58bb7344751919a1876ee1b976fe08b9ffccb4bbea66f26af6017b9"}, - {file = "regex-2024.4.16-cp38-cp38-win32.whl", hash = "sha256:479595a4fbe9ed8f8f72c59717e8cf222da2e4c07b6ae5b65411e6302af9708e"}, - {file = "regex-2024.4.16-cp38-cp38-win_amd64.whl", hash = "sha256:0534b034fba6101611968fae8e856c1698da97ce2efb5c2b895fc8b9e23a5834"}, - {file = "regex-2024.4.16-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:a7ccdd1c4a3472a7533b0a7aa9ee34c9a2bef859ba86deec07aff2ad7e0c3b94"}, - {file = "regex-2024.4.16-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6f2f017c5be19984fbbf55f8af6caba25e62c71293213f044da3ada7091a4455"}, - {file = "regex-2024.4.16-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:803b8905b52de78b173d3c1e83df0efb929621e7b7c5766c0843704d5332682f"}, - {file = "regex-2024.4.16-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:684008ec44ad275832a5a152f6e764bbe1914bea10968017b6feaecdad5736e0"}, - {file = "regex-2024.4.16-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:65436dce9fdc0aeeb0a0effe0839cb3d6a05f45aa45a4d9f9c60989beca78b9c"}, - {file = "regex-2024.4.16-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ea355eb43b11764cf799dda62c658c4d2fdb16af41f59bb1ccfec517b60bcb07"}, - {file = "regex-2024.4.16-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:98c1165f3809ce7774f05cb74e5408cd3aa93ee8573ae959a97a53db3ca3180d"}, - {file = "regex-2024.4.16-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cccc79a9be9b64c881f18305a7c715ba199e471a3973faeb7ba84172abb3f317"}, - {file = "regex-2024.4.16-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:00169caa125f35d1bca6045d65a662af0202704489fada95346cfa092ec23f39"}, - {file = "regex-2024.4.16-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:6cc38067209354e16c5609b66285af17a2863a47585bcf75285cab33d4c3b8df"}, - {file = "regex-2024.4.16-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:23cff1b267038501b179ccbbd74a821ac4a7192a1852d1d558e562b507d46013"}, - {file = "regex-2024.4.16-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:b9d320b3bf82a39f248769fc7f188e00f93526cc0fe739cfa197868633d44701"}, - {file = "regex-2024.4.16-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:89ec7f2c08937421bbbb8b48c54096fa4f88347946d4747021ad85f1b3021b3c"}, - {file = "regex-2024.4.16-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:4918fd5f8b43aa7ec031e0fef1ee02deb80b6afd49c85f0790be1dc4ce34cb50"}, - {file = "regex-2024.4.16-cp39-cp39-win32.whl", hash = "sha256:684e52023aec43bdf0250e843e1fdd6febbe831bd9d52da72333fa201aaa2335"}, - {file = "regex-2024.4.16-cp39-cp39-win_amd64.whl", hash = "sha256:e697e1c0238133589e00c244a8b676bc2cfc3ab4961318d902040d099fec7483"}, - {file = "regex-2024.4.16.tar.gz", hash = "sha256:fa454d26f2e87ad661c4f0c5a5fe4cf6aab1e307d1b94f16ffdfcb089ba685c0"}, + {file = "regex-2024.4.28-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:cd196d056b40af073d95a2879678585f0b74ad35190fac04ca67954c582c6b61"}, + {file = "regex-2024.4.28-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8bb381f777351bd534462f63e1c6afb10a7caa9fa2a421ae22c26e796fe31b1f"}, + {file = "regex-2024.4.28-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:47af45b6153522733aa6e92543938e97a70ce0900649ba626cf5aad290b737b6"}, + {file = "regex-2024.4.28-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:99d6a550425cc51c656331af0e2b1651e90eaaa23fb4acde577cf15068e2e20f"}, + {file = "regex-2024.4.28-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bf29304a8011feb58913c382902fde3395957a47645bf848eea695839aa101b7"}, + {file = "regex-2024.4.28-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:92da587eee39a52c91aebea8b850e4e4f095fe5928d415cb7ed656b3460ae79a"}, + {file = "regex-2024.4.28-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6277d426e2f31bdbacb377d17a7475e32b2d7d1f02faaecc48d8e370c6a3ff31"}, + {file = "regex-2024.4.28-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:28e1f28d07220c0f3da0e8fcd5a115bbb53f8b55cecf9bec0c946eb9a059a94c"}, + {file = "regex-2024.4.28-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:aaa179975a64790c1f2701ac562b5eeb733946eeb036b5bcca05c8d928a62f10"}, + {file = "regex-2024.4.28-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:6f435946b7bf7a1b438b4e6b149b947c837cb23c704e780c19ba3e6855dbbdd3"}, + {file = "regex-2024.4.28-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:19d6c11bf35a6ad077eb23852827f91c804eeb71ecb85db4ee1386825b9dc4db"}, + {file = "regex-2024.4.28-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:fdae0120cddc839eb8e3c15faa8ad541cc6d906d3eb24d82fb041cfe2807bc1e"}, + {file = "regex-2024.4.28-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:e672cf9caaf669053121f1766d659a8813bd547edef6e009205378faf45c67b8"}, + {file = "regex-2024.4.28-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:f57515750d07e14743db55d59759893fdb21d2668f39e549a7d6cad5d70f9fea"}, + {file = "regex-2024.4.28-cp310-cp310-win32.whl", hash = "sha256:a1409c4eccb6981c7baabc8888d3550df518add6e06fe74fa1d9312c1838652d"}, + {file = "regex-2024.4.28-cp310-cp310-win_amd64.whl", hash = "sha256:1f687a28640f763f23f8a9801fe9e1b37338bb1ca5d564ddd41619458f1f22d1"}, + {file = "regex-2024.4.28-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:84077821c85f222362b72fdc44f7a3a13587a013a45cf14534df1cbbdc9a6796"}, + {file = "regex-2024.4.28-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b45d4503de8f4f3dc02f1d28a9b039e5504a02cc18906cfe744c11def942e9eb"}, + {file = "regex-2024.4.28-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:457c2cd5a646dd4ed536c92b535d73548fb8e216ebee602aa9f48e068fc393f3"}, + {file = "regex-2024.4.28-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2b51739ddfd013c6f657b55a508de8b9ea78b56d22b236052c3a85a675102dc6"}, + {file = "regex-2024.4.28-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:459226445c7d7454981c4c0ce0ad1a72e1e751c3e417f305722bbcee6697e06a"}, + {file = "regex-2024.4.28-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:670fa596984b08a4a769491cbdf22350431970d0112e03d7e4eeaecaafcd0fec"}, + {file = "regex-2024.4.28-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fe00f4fe11c8a521b173e6324d862ee7ee3412bf7107570c9b564fe1119b56fb"}, + {file = "regex-2024.4.28-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:36f392dc7763fe7924575475736bddf9ab9f7a66b920932d0ea50c2ded2f5636"}, + {file = "regex-2024.4.28-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:23a412b7b1a7063f81a742463f38821097b6a37ce1e5b89dd8e871d14dbfd86b"}, + {file = "regex-2024.4.28-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:f1d6e4b7b2ae3a6a9df53efbf199e4bfcff0959dbdb5fd9ced34d4407348e39a"}, + {file = "regex-2024.4.28-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:499334ad139557de97cbc4347ee921c0e2b5e9c0f009859e74f3f77918339257"}, + {file = "regex-2024.4.28-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:0940038bec2fe9e26b203d636c44d31dd8766abc1fe66262da6484bd82461ccf"}, + {file = "regex-2024.4.28-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:66372c2a01782c5fe8e04bff4a2a0121a9897e19223d9eab30c54c50b2ebeb7f"}, + {file = "regex-2024.4.28-cp311-cp311-win32.whl", hash = "sha256:c77d10ec3c1cf328b2f501ca32583625987ea0f23a0c2a49b37a39ee5c4c4630"}, + {file = "regex-2024.4.28-cp311-cp311-win_amd64.whl", hash = "sha256:fc0916c4295c64d6890a46e02d4482bb5ccf33bf1a824c0eaa9e83b148291f90"}, + {file = "regex-2024.4.28-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:08a1749f04fee2811c7617fdd46d2e46d09106fa8f475c884b65c01326eb15c5"}, + {file = "regex-2024.4.28-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b8eb28995771c087a73338f695a08c9abfdf723d185e57b97f6175c5051ff1ae"}, + {file = "regex-2024.4.28-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:dd7ef715ccb8040954d44cfeff17e6b8e9f79c8019daae2fd30a8806ef5435c0"}, + {file = "regex-2024.4.28-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fb0315a2b26fde4005a7c401707c5352df274460f2f85b209cf6024271373013"}, + {file = "regex-2024.4.28-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f2fc053228a6bd3a17a9b0a3f15c3ab3cf95727b00557e92e1cfe094b88cc662"}, + {file = "regex-2024.4.28-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7fe9739a686dc44733d52d6e4f7b9c77b285e49edf8570754b322bca6b85b4cc"}, + {file = "regex-2024.4.28-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a74fcf77d979364f9b69fcf8200849ca29a374973dc193a7317698aa37d8b01c"}, + {file = "regex-2024.4.28-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:965fd0cf4694d76f6564896b422724ec7b959ef927a7cb187fc6b3f4e4f59833"}, + {file = "regex-2024.4.28-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:2fef0b38c34ae675fcbb1b5db760d40c3fc3612cfa186e9e50df5782cac02bcd"}, + {file = "regex-2024.4.28-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:bc365ce25f6c7c5ed70e4bc674f9137f52b7dd6a125037f9132a7be52b8a252f"}, + {file = "regex-2024.4.28-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:ac69b394764bb857429b031d29d9604842bc4cbfd964d764b1af1868eeebc4f0"}, + {file = "regex-2024.4.28-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:144a1fc54765f5c5c36d6d4b073299832aa1ec6a746a6452c3ee7b46b3d3b11d"}, + {file = "regex-2024.4.28-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:2630ca4e152c221072fd4a56d4622b5ada876f668ecd24d5ab62544ae6793ed6"}, + {file = "regex-2024.4.28-cp312-cp312-win32.whl", hash = "sha256:7f3502f03b4da52bbe8ba962621daa846f38489cae5c4a7b5d738f15f6443d17"}, + {file = "regex-2024.4.28-cp312-cp312-win_amd64.whl", hash = "sha256:0dd3f69098511e71880fb00f5815db9ed0ef62c05775395968299cb400aeab82"}, + {file = "regex-2024.4.28-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:374f690e1dd0dbdcddea4a5c9bdd97632cf656c69113f7cd6a361f2a67221cb6"}, + {file = "regex-2024.4.28-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:25f87ae6b96374db20f180eab083aafe419b194e96e4f282c40191e71980c666"}, + {file = "regex-2024.4.28-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:5dbc1bcc7413eebe5f18196e22804a3be1bfdfc7e2afd415e12c068624d48247"}, + {file = "regex-2024.4.28-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f85151ec5a232335f1be022b09fbbe459042ea1951d8a48fef251223fc67eee1"}, + {file = "regex-2024.4.28-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:57ba112e5530530fd175ed550373eb263db4ca98b5f00694d73b18b9a02e7185"}, + {file = "regex-2024.4.28-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:224803b74aab56aa7be313f92a8d9911dcade37e5f167db62a738d0c85fdac4b"}, + {file = "regex-2024.4.28-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0a54a047b607fd2d2d52a05e6ad294602f1e0dec2291152b745870afc47c1397"}, + {file = "regex-2024.4.28-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0a2a512d623f1f2d01d881513af9fc6a7c46e5cfffb7dc50c38ce959f9246c94"}, + {file = "regex-2024.4.28-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:c06bf3f38f0707592898428636cbb75d0a846651b053a1cf748763e3063a6925"}, + {file = "regex-2024.4.28-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:1031a5e7b048ee371ab3653aad3030ecfad6ee9ecdc85f0242c57751a05b0ac4"}, + {file = "regex-2024.4.28-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:d7a353ebfa7154c871a35caca7bfd8f9e18666829a1dc187115b80e35a29393e"}, + {file = "regex-2024.4.28-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:7e76b9cfbf5ced1aca15a0e5b6f229344d9b3123439ffce552b11faab0114a02"}, + {file = "regex-2024.4.28-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:5ce479ecc068bc2a74cb98dd8dba99e070d1b2f4a8371a7dfe631f85db70fe6e"}, + {file = "regex-2024.4.28-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:7d77b6f63f806578c604dca209280e4c54f0fa9a8128bb8d2cc5fb6f99da4150"}, + {file = "regex-2024.4.28-cp38-cp38-win32.whl", hash = "sha256:d84308f097d7a513359757c69707ad339da799e53b7393819ec2ea36bc4beb58"}, + {file = "regex-2024.4.28-cp38-cp38-win_amd64.whl", hash = "sha256:2cc1b87bba1dd1a898e664a31012725e48af826bf3971e786c53e32e02adae6c"}, + {file = "regex-2024.4.28-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:7413167c507a768eafb5424413c5b2f515c606be5bb4ef8c5dee43925aa5718b"}, + {file = "regex-2024.4.28-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:108e2dcf0b53a7c4ab8986842a8edcb8ab2e59919a74ff51c296772e8e74d0ae"}, + {file = "regex-2024.4.28-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:f1c5742c31ba7d72f2dedf7968998730664b45e38827637e0f04a2ac7de2f5f1"}, + {file = "regex-2024.4.28-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ecc6148228c9ae25ce403eade13a0961de1cb016bdb35c6eafd8e7b87ad028b1"}, + {file = "regex-2024.4.28-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b7d893c8cf0e2429b823ef1a1d360a25950ed11f0e2a9df2b5198821832e1947"}, + {file = "regex-2024.4.28-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4290035b169578ffbbfa50d904d26bec16a94526071ebec3dadbebf67a26b25e"}, + {file = "regex-2024.4.28-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:44a22ae1cfd82e4ffa2066eb3390777dc79468f866f0625261a93e44cdf6482b"}, + {file = "regex-2024.4.28-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fd24fd140b69f0b0bcc9165c397e9b2e89ecbeda83303abf2a072609f60239e2"}, + {file = "regex-2024.4.28-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:39fb166d2196413bead229cd64a2ffd6ec78ebab83fff7d2701103cf9f4dfd26"}, + {file = "regex-2024.4.28-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:9301cc6db4d83d2c0719f7fcda37229691745168bf6ae849bea2e85fc769175d"}, + {file = "regex-2024.4.28-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:7c3d389e8d76a49923683123730c33e9553063d9041658f23897f0b396b2386f"}, + {file = "regex-2024.4.28-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:99ef6289b62042500d581170d06e17f5353b111a15aa6b25b05b91c6886df8fc"}, + {file = "regex-2024.4.28-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:b91d529b47798c016d4b4c1d06cc826ac40d196da54f0de3c519f5a297c5076a"}, + {file = "regex-2024.4.28-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:43548ad74ea50456e1c68d3c67fff3de64c6edb85bcd511d1136f9b5376fc9d1"}, + {file = "regex-2024.4.28-cp39-cp39-win32.whl", hash = "sha256:05d9b6578a22db7dedb4df81451f360395828b04f4513980b6bd7a1412c679cc"}, + {file = "regex-2024.4.28-cp39-cp39-win_amd64.whl", hash = "sha256:3986217ec830c2109875be740531feb8ddafe0dfa49767cdcd072ed7e8927962"}, + {file = "regex-2024.4.28.tar.gz", hash = "sha256:83ab366777ea45d58f72593adf35d36ca911ea8bd838483c1823b883a121b0e4"}, ] [[package]] @@ -3959,13 +3935,13 @@ zstd = ["zstandard (>=0.18.0)"] [[package]] name = "virtualenv" -version = "20.26.0" +version = "20.26.1" description = "Virtual Python Environment builder" optional = true python-versions = ">=3.7" files = [ - {file = "virtualenv-20.26.0-py3-none-any.whl", hash = "sha256:0846377ea76e818daaa3e00a4365c018bc3ac9760cbb3544de542885aad61fb3"}, - {file = "virtualenv-20.26.0.tar.gz", hash = "sha256:ec25a9671a5102c8d2657f62792a27b48f016664c6873f6beed3800008577210"}, + {file = "virtualenv-20.26.1-py3-none-any.whl", hash = "sha256:7aa9982a728ae5892558bff6a2839c00b9ed145523ece2274fad6f414690ae75"}, + {file = "virtualenv-20.26.1.tar.gz", hash = "sha256:604bfdceaeece392802e6ae48e69cec49168b9c5f4a44e483963f9242eb0e78b"}, ] [package.dependencies] diff --git a/tests/test_examples.py b/tests/test_examples.py index 70dfdd35..1fca45f5 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -46,7 +46,7 @@ def test_examples_3_and_2(): # Pass empty globals to allow dictionary comprehension https://stackoverflow.com/a/32897127/4391249. exec(file_contents, {}) - for file_name in ["model.pt", "config.yaml"]: + for file_name in ["model.safetensors", "config.json", "config.yaml"]: assert Path(f"outputs/train/example_pusht_diffusion/{file_name}").exists() path = "examples/2_evaluate_pretrained_policy.py" @@ -58,15 +58,15 @@ def test_examples_3_and_2(): file_contents = _find_and_replace( file_contents, [ + ('pretrained_policy_name = "lerobot/diffusion_policy_pusht_image"', ""), + ("pretrained_policy_path = Path(snapshot_download(pretrained_policy_name))", ""), + ( + '# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")', + 'pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")', + ), ('"eval.n_episodes=10"', '"eval.n_episodes=1"'), ('"eval.batch_size=10"', '"eval.batch_size=1"'), ('"device=cuda"', '"device=cpu"'), - ( - '# folder = Path("outputs/train/example_pusht_diffusion")', - 'folder = Path("outputs/train/example_pusht_diffusion")', - ), - ('hub_id = "lerobot/diffusion_policy_pusht_image"', ""), - ("folder = Path(snapshot_download(hub_id)", ""), ], ) diff --git a/tests/test_policies.py b/tests/test_policies.py index 9151c666..ed046659 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -1,19 +1,33 @@ +import inspect + import pytest import torch +from huggingface_hub import PyTorchModelHubMixin +from lerobot import available_policies from lerobot.common.datasets.factory import make_dataset from lerobot.common.datasets.utils import cycle from lerobot.common.envs.factory import make_env from lerobot.common.envs.utils import postprocess_action, preprocess_observation -from lerobot.common.policies.act.modeling_act import ACTPolicy -from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy -from lerobot.common.policies.factory import make_policy +from lerobot.common.policies.factory import get_policy_and_config_classes, make_policy from lerobot.common.policies.normalize import Normalize, Unnormalize from lerobot.common.policies.policy_protocol import Policy from lerobot.common.utils.utils import init_hydra_config from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_env +@pytest.mark.parametrize("policy_name", available_policies) +def test_get_policy_and_config_classes(policy_name: str): + """Check that the correct policy and config classes are returned.""" + if policy_name == "tdmpc": + with pytest.raises(NotImplementedError): + get_policy_and_config_classes(policy_name) + return + policy_cls, config_cls = get_policy_and_config_classes(policy_name) + assert policy_cls.name == policy_name + assert issubclass(config_cls, inspect.signature(policy_cls.__init__).parameters["config"].annotation) + + # TODO(aliberts): refactor using lerobot/__init__.py variables @pytest.mark.parametrize( "env_name,policy_name,extra_overrides", @@ -44,7 +58,8 @@ def test_policy(env_name, policy_name, extra_overrides): """ Tests: - Making the policy object. - - Checking that the policy follows the correct protocol. + - Checking that the policy follows the correct protocol and subclasses nn.Module + and PyTorchModelHubMixin. - Updating the policy. - Using the policy to select actions at inference time. - Test the action can be applied to the policy @@ -61,11 +76,13 @@ def test_policy(env_name, policy_name, extra_overrides): # Check that we can make the policy object. dataset = make_dataset(cfg) - policy = make_policy(cfg, dataset_stats=dataset.stats) + policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats) # Check that the policy follows the required protocol. assert isinstance( policy, Policy ), f"The policy does not follow the required protocol. Please see {Policy.__module__}.{Policy.__name__}." + assert isinstance(policy, torch.nn.Module) + assert isinstance(policy, PyTorchModelHubMixin) # Check that we run select_actions and get the appropriate output. env = make_env(cfg, num_parallel_envs=2) @@ -108,29 +125,33 @@ def test_policy(env_name, policy_name, extra_overrides): # Test step through policy env.step(action) - # Test load state_dict - if policy_name != "tdmpc": - # TODO(rcadene, alexander-soare): make it work for tdmpc - new_policy = make_policy(cfg) - new_policy.load_state_dict(policy.state_dict()) + +@pytest.mark.parametrize("policy_name", available_policies) +def test_policy_defaults(policy_name: str): + """Check that the policy can be instantiated with defaults.""" + if policy_name == "tdmpc": + with pytest.raises(NotImplementedError): + get_policy_and_config_classes(policy_name) + return + policy_cls, _ = get_policy_and_config_classes(policy_name) + policy_cls() -@pytest.mark.parametrize("policy_cls", [DiffusionPolicy, ACTPolicy]) -def test_policy_defaults(policy_cls): - kwargs = {} - # TODO(alexander-soare): Remove this kwargs hack when we move the scheduler out of DP. - if policy_cls is DiffusionPolicy: - kwargs = {"lr_scheduler_num_training_steps": 1} - policy_cls(**kwargs) +@pytest.mark.parametrize("policy_name", available_policies) +def test_save_and_load_pretrained(policy_name: str): + if policy_name == "tdmpc": + with pytest.raises(NotImplementedError): + get_policy_and_config_classes(policy_name) + return + policy_cls, _ = get_policy_and_config_classes(policy_name) + policy: Policy = policy_cls() + save_dir = "/tmp/test_save_and_load_pretrained_{policy_cls.__name__}" + policy.save_pretrained(save_dir) + policy_ = policy_cls.from_pretrained(save_dir) + assert all(torch.equal(p, p_) for p, p_ in zip(policy.parameters(), policy_.parameters(), strict=True)) -@pytest.mark.parametrize( - "insert_temporal_dim", - [ - False, - True, - ], -) +@pytest.mark.parametrize("insert_temporal_dim", [False, True]) def test_normalize(insert_temporal_dim): """ Test that normalize/unnormalize can run without exceptions when properly set up, and that they raise