diff --git a/Makefile b/Makefile index 79e39c0b..20d2c553 100644 --- a/Makefile +++ b/Makefile @@ -22,74 +22,82 @@ test-end-to-end: ${MAKE} test-act-ete-eval ${MAKE} test-diffusion-ete-train ${MAKE} test-diffusion-ete-eval - # ${MAKE} test-tdmpc-ete-train - # ${MAKE} test-tdmpc-ete-eval + ${MAKE} test-tdmpc-ete-train + ${MAKE} test-tdmpc-ete-eval + ${MAKE} test-default-ete-eval test-act-ete-train: python lerobot/scripts/train.py \ policy=act \ env=aloha \ wandb.enable=False \ - offline_steps=2 \ - online_steps=0 \ - eval_episodes=1 \ + training.offline_steps=2 \ + training.online_steps=0 \ + eval.n_episodes=1 \ device=cpu \ - save_model=true \ - save_freq=2 \ + training.save_model=true \ + training.save_freq=2 \ policy.n_action_steps=20 \ policy.chunk_size=20 \ - policy.batch_size=2 \ + training.batch_size=2 \ hydra.run.dir=tests/outputs/act/ test-act-ete-eval: python lerobot/scripts/eval.py \ - --config tests/outputs/act/.hydra/config.yaml \ - eval_episodes=1 \ + -p tests/outputs/act/checkpoints/000002 \ + eval.n_episodes=1 \ env.episode_length=8 \ device=cpu \ - policy.pretrained_model_path=tests/outputs/act/models/2.pt test-diffusion-ete-train: python lerobot/scripts/train.py \ policy=diffusion \ env=pusht \ wandb.enable=False \ - offline_steps=2 \ - online_steps=0 \ - eval_episodes=1 \ + training.offline_steps=2 \ + training.online_steps=0 \ + eval.n_episodes=1 \ device=cpu \ - save_model=true \ - save_freq=2 \ - policy.batch_size=2 \ + training.save_model=true \ + training.save_freq=2 \ + training.batch_size=2 \ hydra.run.dir=tests/outputs/diffusion/ test-diffusion-ete-eval: python lerobot/scripts/eval.py \ - --config tests/outputs/diffusion/.hydra/config.yaml \ - eval_episodes=1 \ + -p tests/outputs/diffusion/checkpoints/000002 \ + eval.n_episodes=1 \ env.episode_length=8 \ device=cpu \ - policy.pretrained_model_path=tests/outputs/diffusion/models/2.pt test-tdmpc-ete-train: python lerobot/scripts/train.py \ policy=tdmpc \ env=xarm \ + env.task=XarmLift-v0 \ + dataset_repo_id=lerobot/xarm_lift_medium_replay \ wandb.enable=False \ - offline_steps=1 \ - online_steps=2 \ - eval_episodes=1 \ + training.offline_steps=2 \ + training.online_steps=2 \ + eval.n_episodes=1 \ env.episode_length=2 \ device=cpu \ - save_model=true \ - save_freq=2 \ - policy.batch_size=2 \ + training.save_model=true \ + training.save_freq=2 \ + training.batch_size=2 \ hydra.run.dir=tests/outputs/tdmpc/ test-tdmpc-ete-eval: python lerobot/scripts/eval.py \ - --config tests/outputs/tdmpc/.hydra/config.yaml \ - eval_episodes=1 \ + -p tests/outputs/tdmpc/checkpoints/000002 \ + eval.n_episodes=1 \ + env.episode_length=8 \ + device=cpu \ + + +test-default-ete-eval: + python lerobot/scripts/eval.py \ + --config lerobot/configs/default.yaml \ + eval.n_episodes=1 \ env.episode_length=8 \ device=cpu \ - policy.pretrained_model_path=tests/outputs/tdmpc/models/2.pt diff --git a/README.md b/README.md index 35a2f422..96378c44 100644 --- a/README.md +++ b/README.md @@ -135,16 +135,16 @@ Check out [examples](./examples) to see how you can load a pretrained policy fro Or you can achieve the same result by executing our script from the command line: ```bash python lerobot/scripts/eval.py \ ---hub-id lerobot/diffusion_policy_pusht_image \ +-p lerobot/diffusion_policy_pusht_image \ eval_episodes=10 \ hydra.run.dir=outputs/eval/example_hub ``` After training your own policy, you can also re-evaluate the checkpoints with: + ```bash python lerobot/scripts/eval.py \ ---config PATH/TO/FOLDER/config.yaml \ -policy.pretrained_model_path=PATH/TO/FOLDER/weights.pth \ +-p PATH/TO/TRAIN/OUTPUT/FOLDER \ eval_episodes=10 \ hydra.run.dir=outputs/eval/example_dir ``` @@ -246,29 +246,22 @@ Once you have trained a policy you may upload it to the HuggingFace hub. Firstly, make sure you have a model repository set up on the hub. The hub ID looks like HF_USER/REPO_NAME. -Secondly, assuming you have trained a policy, you need: +Secondly, assuming you have trained a policy, you need the following (which should all be in any of the subdirectories of `checkpoints` in your training output folder, if you've used the LeRobot training script): -- `config.yaml` which you can get from the `.hydra` directory of your training output folder. -- `model.pt` which should be one of the saved models in the `models` directory of your training output folder (they won't be named `model.pt` but you will need to choose one). +- `config.json`: A serialized version of the policy configuration (following the policy's dataclass config). +- `model.safetensors`: The `torch.nn.Module` parameters saved in [Hugging Face Safetensors](https://huggingface.co/docs/safetensors/index) format. +- `config.yaml`: This is the consolidated Hydra training configuration containing the policy, environment, and dataset configs. The policy configuration should match `config.json` exactly. The environment config is useful for anyone who wants to evaluate your policy. The dataset config just serves as a paper trail for reproducibility. -To upload these to the hub, prepare a folder with the following structure (you can use symlinks rather than copying): - -``` -to_upload - ├── config.yaml - └── model.pt -``` - -With the folder prepared, run the following with a desired revision ID. +To upload these to the hub, run the following with a desired revision ID. ```bash -huggingface-cli upload $HUB_ID to_upload --revision $REVISION_ID +huggingface-cli upload $HUB_ID PATH/TO/OUTPUT/DIR --revision $REVISION_ID ``` If you want this to be the default revision also run the following (don't worry, it won't upload the files again; it will just adjust the file pointers): ```bash -huggingface-cli upload $HUB_ID to_upload +huggingface-cli upload $HUB_ID PATH/TO/OUTPUT/DIR ``` See `eval.py` for an example of how a user may use your policy. diff --git a/examples/2_evaluate_pretrained_policy.py b/examples/2_evaluate_pretrained_policy.py index 392ad1c6..8a72201f 100644 --- a/examples/2_evaluate_pretrained_policy.py +++ b/examples/2_evaluate_pretrained_policy.py @@ -7,32 +7,21 @@ from pathlib import Path from huggingface_hub import snapshot_download -from lerobot.common.utils.utils import init_hydra_config from lerobot.scripts.eval import eval # Get a pretrained policy from the hub. -# TODO(alexander-soare): This no longer works until we upload a new model that uses the current configs. -hub_id = "lerobot/diffusion_policy_pusht_image" -folder = Path(snapshot_download(hub_id)) +pretrained_policy_name = "lerobot/diffusion_policy_pusht_image" +pretrained_policy_path = Path(snapshot_download(pretrained_policy_name)) # OR uncomment the following to evaluate a policy from the local outputs/train folder. -# folder = Path("outputs/train/example_pusht_diffusion") - -config_path = folder / "config.yaml" -weights_path = folder / "model.pt" +# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion") # Override some config parameters to do with evaluation. overrides = [ - f"policy.pretrained_model_path={weights_path}", - "eval_episodes=10", - "rollout_batch_size=10", + "eval.n_episodes=10", + "eval.batch_size=10", "device=cuda", ] -# Create a Hydra config. -cfg = init_hydra_config(config_path, overrides) - # Evaluate the policy and save the outputs including metrics and videos. -eval( - cfg, - out_dir=f"outputs/eval/example_{cfg.env.name}_{cfg.policy.name}", -) +# TODO(rcadene, alexander-soare): dont call eval, but add the minimal code snippet to rollout +eval(pretrained_policy_path=pretrained_policy_path) diff --git a/examples/3_train_policy.py b/examples/3_train_policy.py index fb2a4419..69e3d34c 100644 --- a/examples/3_train_policy.py +++ b/examples/3_train_policy.py @@ -34,19 +34,17 @@ dataset = make_dataset(hydra_cfg) # If you're doing something different, you will likely need to change at least some of the defaults. cfg = DiffusionConfig() # TODO(alexander-soare): Remove LR scheduler from the policy. -policy = DiffusionPolicy(cfg, lr_scheduler_num_training_steps=training_steps, dataset_stats=dataset.stats) +policy = DiffusionPolicy(cfg, dataset_stats=dataset.stats) policy.train() policy.to(device) -optimizer = torch.optim.Adam( - policy.diffusion.parameters(), cfg.lr, cfg.adam_betas, cfg.adam_eps, cfg.adam_weight_decay -) +optimizer = torch.optim.Adam(policy.parameters(), lr=1e-4) # Create dataloader for offline training. dataloader = torch.utils.data.DataLoader( dataset, num_workers=4, - batch_size=cfg.batch_size, + batch_size=64, shuffle=True, pin_memory=device != torch.device("cpu"), drop_last=True, @@ -71,6 +69,7 @@ while not done: done = True break -# Save the policy and configuration for later use. -policy.save(output_directory / "model.pt") +# Save the policy. +policy.save_pretrained(output_directory) +# Save the Hydra configuration so we have the environment configuration for eval. OmegaConf.save(hydra_cfg, output_directory / "config.yaml") diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py index 0da17b8e..c9711ca3 100644 --- a/lerobot/common/datasets/factory.py +++ b/lerobot/common/datasets/factory.py @@ -14,12 +14,13 @@ def make_dataset( cfg, split="train", ): - if cfg.env.name not in cfg.dataset.repo_id: + if cfg.env.name not in cfg.dataset_repo_id: logging.warning( - f"There might be a mismatch between your training dataset ({cfg.dataset.repo_id=}) and your environment ({cfg.env.name=})." + f"There might be a mismatch between your training dataset ({cfg.dataset_repo_id=}) and your " + f"environment ({cfg.env.name=})." ) - delta_timestamps = cfg.policy.get("delta_timestamps") + delta_timestamps = cfg.training.get("delta_timestamps") if delta_timestamps is not None: for key in delta_timestamps: if isinstance(delta_timestamps[key], str): @@ -28,7 +29,7 @@ def make_dataset( # TODO(rcadene): add data augmentations dataset = LeRobotDataset( - cfg.dataset.repo_id, + cfg.dataset_repo_id, split=split, root=DATA_DIR, delta_timestamps=delta_timestamps, diff --git a/lerobot/common/logger.py b/lerobot/common/logger.py index fbab0580..bf1d51aa 100644 --- a/lerobot/common/logger.py +++ b/lerobot/common/logger.py @@ -5,9 +5,12 @@ import logging import os from pathlib import Path +from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE from omegaconf import OmegaConf from termcolor import colored +from lerobot.common.policies.policy_protocol import Policy + def log_output_dir(out_dir): logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {out_dir}") @@ -30,11 +33,11 @@ class Logger: self._log_dir = Path(log_dir) self._log_dir.mkdir(parents=True, exist_ok=True) self._job_name = job_name - self._model_dir = self._log_dir / "models" + self._model_dir = self._log_dir / "checkpoints" self._buffer_dir = self._log_dir / "buffers" - self._save_model = cfg.save_model + self._save_model = cfg.training.save_model self._disable_wandb_artifact = cfg.wandb.disable_artifact - self._save_buffer = cfg.save_buffer + self._save_buffer = cfg.training.get("save_buffer", False) self._group = cfg_to_group(cfg) self._seed = cfg.seed self._cfg = cfg @@ -70,18 +73,20 @@ class Logger: logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}") self._wandb = wandb - def save_model(self, policy, identifier): + def save_model(self, policy: Policy, identifier): if self._save_model: self._model_dir.mkdir(parents=True, exist_ok=True) - fp = self._model_dir / f"{str(identifier)}.pt" - policy.save(fp) + save_dir = self._model_dir / str(identifier) + policy.save_pretrained(save_dir) + # Also save the full Hydra config for the env configuration. + OmegaConf.save(self._cfg, save_dir / "config.yaml") if self._wandb and not self._disable_wandb_artifact: # note wandb artifact does not accept ":" in its name artifact = self._wandb.Artifact( self._group.replace(":", "_") + "-" + str(self._seed) + "-" + str(identifier), type="model", ) - artifact.add_file(fp) + artifact.add_file(save_dir / SAFETENSORS_SINGLE_FILE) self._wandb.log_artifact(artifact) def save_buffer(self, buffer, identifier): diff --git a/lerobot/common/policies/act/configuration_act.py b/lerobot/common/policies/act/configuration_act.py index 16be36df..a3980b14 100644 --- a/lerobot/common/policies/act/configuration_act.py +++ b/lerobot/common/policies/act/configuration_act.py @@ -2,7 +2,7 @@ from dataclasses import dataclass, field @dataclass -class ActionChunkingTransformerConfig: +class ACTConfig: """Configuration class for the Action Chunking Transformers policy. Defaults are configured for training on bimanual Aloha tasks like "insertion" or "transfer". @@ -22,23 +22,24 @@ class ActionChunkingTransformerConfig: The key represents the input data name, and the value is a list indicating the dimensions of the corresponding data. For example, "observation.images.top" refers to an input from the "top" camera with dimensions [3, 96, 96], indicating it has three color channels and 96x96 resolution. - Importantly, shapes doesnt include batch dimension or temporal dimension. + Importantly, shapes doesn't include batch dimension or temporal dimension. output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents the output data name, and the value is a list indicating the dimensions of the corresponding data. For example, "action" refers to an output shape of [14], indicating - 14-dimensional actions. Importantly, shapes doesnt include batch dimension or temporal dimension. - normalize_input_modes: A dictionary with key represents the modality (e.g. "observation.state"), - and the value specifies the normalization mode to apply. The two availables - modes are "mean_std" which substracts the mean and divide by the standard - deviation and "min_max" which rescale in a [-1, 1] range. - unnormalize_output_modes: Similar dictionary as `normalize_input_modes`, but to unormalize in original scale. + 14-dimensional actions. Importantly, shapes doesn't include batch dimension or temporal dimension. + input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"), + and the value specifies the normalization mode to apply. The two available modes are "mean_std" + which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a + [-1, 1] range. + output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the + original scale. Note that this is also used for normalizing the training targets. vision_backbone: Name of the torchvision resnet backbone to use for encoding images. pretrained_backbone_weights: Pretrained weights from torchvision to initalize the backbone. `None` means no pretrained weights. replace_final_stride_with_dilation: Whether to replace the ResNet's final 2x2 stride with a dilated convolution. pre_norm: Whether to use "pre-norm" in the transformer blocks. - d_model: The transformer blocks' main hidden dimension. + dim_model: The transformer blocks' main hidden dimension. n_heads: The number of heads to use in the transformer blocks' multi-head attention. dim_feedforward: The dimension to expand the transformer's hidden dimension to in the feed-forward layers. @@ -62,13 +63,13 @@ class ActionChunkingTransformerConfig: chunk_size: int = 100 n_action_steps: int = 100 - input_shapes: dict[str, list[str]] = field( + input_shapes: dict[str, list[int]] = field( default_factory=lambda: { "observation.images.top": [3, 480, 640], "observation.state": [14], } ) - output_shapes: dict[str, list[str]] = field( + output_shapes: dict[str, list[int]] = field( default_factory=lambda: { "action": [14], } @@ -94,7 +95,7 @@ class ActionChunkingTransformerConfig: replace_final_stride_with_dilation: int = False # Transformer layers. pre_norm: bool = False - d_model: int = 512 + dim_model: int = 512 n_heads: int = 8 dim_feedforward: int = 3200 feedforward_activation: str = "relu" @@ -112,15 +113,6 @@ class ActionChunkingTransformerConfig: dropout: float = 0.1 kl_weight: float = 10.0 - # --- - # TODO(alexander-soare): Remove these from the policy config. - batch_size: int = 8 - lr: float = 1e-5 - lr_backbone: float = 1e-5 - weight_decay: float = 1e-4 - grad_clip_norm: float = 10 - utd: int = 1 - def __post_init__(self): """Input validation (not exhaustive).""" if not self.vision_backbone.startswith("resnet"): diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py index 22384ca0..f9e52e02 100644 --- a/lerobot/common/policies/act/modeling_act.py +++ b/lerobot/common/policies/act/modeling_act.py @@ -14,18 +14,124 @@ import numpy as np import torch import torch.nn.functional as F # noqa: N812 import torchvision +from huggingface_hub import PyTorchModelHubMixin from torch import Tensor, nn from torchvision.models._utils import IntermediateLayerGetter from torchvision.ops.misc import FrozenBatchNorm2d -from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig +from lerobot.common.policies.act.configuration_act import ACTConfig from lerobot.common.policies.normalize import Normalize, Unnormalize -class ActionChunkingTransformerPolicy(nn.Module): +class ACTPolicy(nn.Module, PyTorchModelHubMixin): """ Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware (paper: https://arxiv.org/abs/2304.13705, code: https://github.com/tonyzhaozh/act) + """ + + name = "act" + + def __init__( + self, + config: ACTConfig | None = None, + dataset_stats: dict[str, dict[str, Tensor]] | None = None, + ): + """ + Args: + config: Policy configuration class instance or None, in which case the default instantiation of + the configuration class is used. + dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected + that they will be passed with a call to `load_state_dict` before the policy is used. + """ + super().__init__() + if config is None: + config = ACTConfig() + self.config = config + self.normalize_inputs = Normalize( + config.input_shapes, config.input_normalization_modes, dataset_stats + ) + self.normalize_targets = Normalize( + config.output_shapes, config.output_normalization_modes, dataset_stats + ) + self.unnormalize_outputs = Unnormalize( + config.output_shapes, config.output_normalization_modes, dataset_stats + ) + self.model = ACT(config) + + def reset(self): + """This should be called whenever the environment is reset.""" + if self.config.n_action_steps is not None: + self._action_queue = deque([], maxlen=self.config.n_action_steps) + + @torch.no_grad + def select_action(self, batch: dict[str, Tensor]) -> Tensor: + """Select a single action given environment observations. + + This method wraps `select_actions` in order to return one action at a time for execution in the + environment. It works by managing the actions in a queue and only calling `select_actions` when the + queue is empty. + """ + self.eval() + + batch = self.normalize_inputs(batch) + self._stack_images(batch) + + if len(self._action_queue) == 0: + # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue + # effectively has shape (n_action_steps, batch_size, *), hence the transpose. + actions = self.model(batch)[0][: self.config.n_action_steps] + + # TODO(rcadene): make _forward return output dictionary? + actions = self.unnormalize_outputs({"action": actions})["action"] + + self._action_queue.extend(actions.transpose(0, 1)) + return self._action_queue.popleft() + + def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: + """Run the batch through the model and compute the loss for training or validation.""" + batch = self.normalize_inputs(batch) + batch = self.normalize_targets(batch) + self._stack_images(batch) + actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch) + + l1_loss = ( + F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1) + ).mean() + + loss_dict = {"l1_loss": l1_loss} + if self.config.use_vae: + # Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for + # each dimension independently, we sum over the latent dimension to get the total + # KL-divergence per batch element, then take the mean over the batch. + # (See App. B of https://arxiv.org/abs/1312.6114 for more details). + mean_kld = ( + (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean() + ) + loss_dict["kld_loss"] = mean_kld + loss_dict["loss"] = l1_loss + mean_kld * self.config.kl_weight + else: + loss_dict["loss"] = l1_loss + + return loss_dict + + def _stack_images(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: + """Stacks all the images in a batch and puts them in a new key: "observation.images". + + This function expects `batch` to have (at least): + { + "observation.state": (B, state_dim) batch of robot states. + "observation.images.{name}": (B, C, H, W) tensor of images. + } + """ + # Stack images in the order dictated by input_shapes. + batch["observation.images"] = torch.stack( + [batch[k] for k in self.config.input_shapes if k.startswith("observation.images.")], + dim=-4, + ) + + +class ACT(nn.Module): + """Action Chunking Transformer: The underlying neural network for ACTPolicy. Note: In this code we use the terms `vae_encoder`, 'encoder', `decoder`. The meanings are as follows. - The `vae_encoder` is, as per the literature around variational auto-encoders (VAE), the part of the @@ -59,51 +165,36 @@ class ActionChunkingTransformerPolicy(nn.Module): └───────────────────────┘ """ - name = "act" - - def __init__(self, cfg: ActionChunkingTransformerConfig | None = None, dataset_stats=None): - """ - Args: - cfg: Policy configuration class instance or None, in which case the default instantiation of the - configuration class is used. - """ + def __init__(self, config: ACTConfig): super().__init__() - if cfg is None: - cfg = ActionChunkingTransformerConfig() - self.cfg = cfg - self.normalize_inputs = Normalize(cfg.input_shapes, cfg.input_normalization_modes, dataset_stats) - self.normalize_targets = Normalize(cfg.output_shapes, cfg.output_normalization_modes, dataset_stats) - self.unnormalize_outputs = Unnormalize( - cfg.output_shapes, cfg.output_normalization_modes, dataset_stats - ) - + self.config = config # BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence]. # The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]). - if self.cfg.use_vae: - self.vae_encoder = _TransformerEncoder(cfg) - self.vae_encoder_cls_embed = nn.Embedding(1, cfg.d_model) + if self.config.use_vae: + self.vae_encoder = ACTEncoder(config) + self.vae_encoder_cls_embed = nn.Embedding(1, config.dim_model) # Projection layer for joint-space configuration to hidden dimension. self.vae_encoder_robot_state_input_proj = nn.Linear( - cfg.input_shapes["observation.state"][0], cfg.d_model + config.input_shapes["observation.state"][0], config.dim_model ) # Projection layer for action (joint-space target) to hidden dimension. self.vae_encoder_action_input_proj = nn.Linear( - cfg.input_shapes["observation.state"][0], cfg.d_model + config.input_shapes["observation.state"][0], config.dim_model ) - self.latent_dim = cfg.latent_dim + self.latent_dim = config.latent_dim # Projection layer from the VAE encoder's output to the latent distribution's parameter space. - self.vae_encoder_latent_output_proj = nn.Linear(cfg.d_model, self.latent_dim * 2) + self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, self.latent_dim * 2) # Fixed sinusoidal positional embedding the whole input to the VAE encoder. Unsqueeze for batch # dimension. self.register_buffer( "vae_encoder_pos_enc", - _create_sinusoidal_position_embedding(1 + 1 + cfg.chunk_size, cfg.d_model).unsqueeze(0), + create_sinusoidal_pos_embedding(1 + 1 + config.chunk_size, config.dim_model).unsqueeze(0), ) # Backbone for image feature extraction. - backbone_model = getattr(torchvision.models, cfg.vision_backbone)( - replace_stride_with_dilation=[False, False, cfg.replace_final_stride_with_dilation], - weights=cfg.pretrained_backbone_weights, + backbone_model = getattr(torchvision.models, config.vision_backbone)( + replace_stride_with_dilation=[False, False, config.replace_final_stride_with_dilation], + weights=config.pretrained_backbone_weights, norm_layer=FrozenBatchNorm2d, ) # Note: The assumption here is that we are using a ResNet model (and hence layer4 is the final feature @@ -112,26 +203,28 @@ class ActionChunkingTransformerPolicy(nn.Module): self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"}) # Transformer (acts as VAE decoder when training with the variational objective). - self.encoder = _TransformerEncoder(cfg) - self.decoder = _TransformerDecoder(cfg) + self.encoder = ACTEncoder(config) + self.decoder = ACTDecoder(config) # Transformer encoder input projections. The tokens will be structured like # [latent, robot_state, image_feature_map_pixels]. - self.encoder_robot_state_input_proj = nn.Linear(cfg.input_shapes["observation.state"][0], cfg.d_model) - self.encoder_latent_input_proj = nn.Linear(self.latent_dim, cfg.d_model) + self.encoder_robot_state_input_proj = nn.Linear( + config.input_shapes["observation.state"][0], config.dim_model + ) + self.encoder_latent_input_proj = nn.Linear(self.latent_dim, config.dim_model) self.encoder_img_feat_input_proj = nn.Conv2d( - backbone_model.fc.in_features, cfg.d_model, kernel_size=1 + backbone_model.fc.in_features, config.dim_model, kernel_size=1 ) # Transformer encoder positional embeddings. - self.encoder_robot_and_latent_pos_embed = nn.Embedding(2, cfg.d_model) - self.encoder_cam_feat_pos_embed = _SinusoidalPositionEmbedding2D(cfg.d_model // 2) + self.encoder_robot_and_latent_pos_embed = nn.Embedding(2, config.dim_model) + self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2) # Transformer decoder. # Learnable positional embedding for the transformer's decoder (in the style of DETR object queries). - self.decoder_pos_embed = nn.Embedding(cfg.chunk_size, cfg.d_model) + self.decoder_pos_embed = nn.Embedding(config.chunk_size, config.dim_model) # Final action regression head on the output of the transformer's decoder. - self.action_head = nn.Linear(cfg.d_model, cfg.output_shapes["action"][0]) + self.action_head = nn.Linear(config.dim_model, config.output_shapes["action"][0]) self._reset_parameters() @@ -141,76 +234,7 @@ class ActionChunkingTransformerPolicy(nn.Module): if p.dim() > 1: nn.init.xavier_uniform_(p) - def reset(self): - """This should be called whenever the environment is reset.""" - if self.cfg.n_action_steps is not None: - self._action_queue = deque([], maxlen=self.cfg.n_action_steps) - - @torch.no_grad - def select_action(self, batch: dict[str, Tensor], **_) -> Tensor: - """Select a single action given environment observations. - - This method wraps `select_actions` in order to return one action at a time for execution in the - environment. It works by managing the actions in a queue and only calling `select_actions` when the - queue is empty. - """ - self.eval() - - batch = self.normalize_inputs(batch) - - if len(self._action_queue) == 0: - # `_forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue effectively - # has shape (n_action_steps, batch_size, *), hence the transpose. - actions = self._forward(batch)[0][: self.cfg.n_action_steps] - - # TODO(rcadene): make _forward return output dictionary? - actions = self.unnormalize_outputs({"action": actions})["action"] - - self._action_queue.extend(actions.transpose(0, 1)) - return self._action_queue.popleft() - - def forward(self, batch, **_) -> dict[str, Tensor]: - """Run the batch through the model and compute the loss for training or validation.""" - batch = self.normalize_inputs(batch) - batch = self.normalize_targets(batch) - actions_hat, (mu_hat, log_sigma_x2_hat) = self._forward(batch) - - l1_loss = ( - F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1) - ).mean() - - loss_dict = {"l1_loss": l1_loss} - if self.cfg.use_vae: - # Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for - # each dimension independently, we sum over the latent dimension to get the total - # KL-divergence per batch element, then take the mean over the batch. - # (See App. B of https://arxiv.org/abs/1312.6114 for more details). - mean_kld = ( - (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean() - ) - loss_dict["kld_loss"] = mean_kld - loss_dict["loss"] = l1_loss + mean_kld * self.cfg.kl_weight - else: - loss_dict["loss"] = l1_loss - - return loss_dict - - def _stack_images(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: - """Stacks all the images in a batch and puts them in a new key: "observation.images". - - This function expects `batch` to have (at least): - { - "observation.state": (B, state_dim) batch of robot states. - "observation.images.{name}": (B, C, H, W) tensor of images. - } - """ - # Stack images in the order dictated by input_shapes. - batch["observation.images"] = torch.stack( - [batch[k] for k in self.cfg.input_shapes if k.startswith("observation.images.")], - dim=-4, - ) - - def _forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tensor] | tuple[None, None]]: + def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tensor] | tuple[None, None]]: """A forward pass through the Action Chunking Transformer (with optional VAE encoder). `batch` should have the following structure: @@ -226,17 +250,15 @@ class ActionChunkingTransformerPolicy(nn.Module): Tuple containing the latent PDF's parameters (mean, log(σ²)) both as (B, L) tensors where L is the latent dimension. """ - if self.cfg.use_vae and self.training: + if self.config.use_vae and self.training: assert ( "action" in batch ), "actions must be provided when using the variational objective in training mode." - self._stack_images(batch) - batch_size = batch["observation.state"].shape[0] # Prepare the latent for input to the transformer encoder. - if self.cfg.use_vae and "action" in batch: + if self.config.use_vae and "action" in batch: # Prepare the input to the VAE encoder: [cls, *joint_space_configuration, *action_sequence]. cls_embed = einops.repeat( self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size @@ -306,7 +328,7 @@ class ActionChunkingTransformerPolicy(nn.Module): # Forward pass through the transformer modules. encoder_out = self.encoder(encoder_in, pos_embed=pos_embed) decoder_in = torch.zeros( - (self.cfg.chunk_size, batch_size, self.cfg.d_model), + (self.config.chunk_size, batch_size, self.config.dim_model), dtype=pos_embed.dtype, device=pos_embed.device, ) @@ -324,21 +346,14 @@ class ActionChunkingTransformerPolicy(nn.Module): return actions, (mu, log_sigma_x2) - def save(self, fp): - torch.save(self.state_dict(), fp) - def load(self, fp): - d = torch.load(fp) - self.load_state_dict(d) - - -class _TransformerEncoder(nn.Module): +class ACTEncoder(nn.Module): """Convenience module for running multiple encoder layers, maybe followed by normalization.""" - def __init__(self, cfg: ActionChunkingTransformerConfig): + def __init__(self, config: ACTConfig): super().__init__() - self.layers = nn.ModuleList([_TransformerEncoderLayer(cfg) for _ in range(cfg.n_encoder_layers)]) - self.norm = nn.LayerNorm(cfg.d_model) if cfg.pre_norm else nn.Identity() + self.layers = nn.ModuleList([ACTEncoderLayer(config) for _ in range(config.n_encoder_layers)]) + self.norm = nn.LayerNorm(config.dim_model) if config.pre_norm else nn.Identity() def forward(self, x: Tensor, pos_embed: Tensor | None = None) -> Tensor: for layer in self.layers: @@ -347,23 +362,23 @@ class _TransformerEncoder(nn.Module): return x -class _TransformerEncoderLayer(nn.Module): - def __init__(self, cfg: ActionChunkingTransformerConfig): +class ACTEncoderLayer(nn.Module): + def __init__(self, config: ACTConfig): super().__init__() - self.self_attn = nn.MultiheadAttention(cfg.d_model, cfg.n_heads, dropout=cfg.dropout) + self.self_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout) # Feed forward layers. - self.linear1 = nn.Linear(cfg.d_model, cfg.dim_feedforward) - self.dropout = nn.Dropout(cfg.dropout) - self.linear2 = nn.Linear(cfg.dim_feedforward, cfg.d_model) + self.linear1 = nn.Linear(config.dim_model, config.dim_feedforward) + self.dropout = nn.Dropout(config.dropout) + self.linear2 = nn.Linear(config.dim_feedforward, config.dim_model) - self.norm1 = nn.LayerNorm(cfg.d_model) - self.norm2 = nn.LayerNorm(cfg.d_model) - self.dropout1 = nn.Dropout(cfg.dropout) - self.dropout2 = nn.Dropout(cfg.dropout) + self.norm1 = nn.LayerNorm(config.dim_model) + self.norm2 = nn.LayerNorm(config.dim_model) + self.dropout1 = nn.Dropout(config.dropout) + self.dropout2 = nn.Dropout(config.dropout) - self.activation = _get_activation_fn(cfg.feedforward_activation) - self.pre_norm = cfg.pre_norm + self.activation = get_activation_fn(config.feedforward_activation) + self.pre_norm = config.pre_norm def forward(self, x, pos_embed: Tensor | None = None) -> Tensor: skip = x @@ -385,12 +400,12 @@ class _TransformerEncoderLayer(nn.Module): return x -class _TransformerDecoder(nn.Module): - def __init__(self, cfg: ActionChunkingTransformerConfig): +class ACTDecoder(nn.Module): + def __init__(self, config: ACTConfig): """Convenience module for running multiple decoder layers followed by normalization.""" super().__init__() - self.layers = nn.ModuleList([_TransformerDecoderLayer(cfg) for _ in range(cfg.n_decoder_layers)]) - self.norm = nn.LayerNorm(cfg.d_model) + self.layers = nn.ModuleList([ACTDecoderLayer(config) for _ in range(config.n_decoder_layers)]) + self.norm = nn.LayerNorm(config.dim_model) def forward( self, @@ -408,26 +423,26 @@ class _TransformerDecoder(nn.Module): return x -class _TransformerDecoderLayer(nn.Module): - def __init__(self, cfg: ActionChunkingTransformerConfig): +class ACTDecoderLayer(nn.Module): + def __init__(self, config: ACTConfig): super().__init__() - self.self_attn = nn.MultiheadAttention(cfg.d_model, cfg.n_heads, dropout=cfg.dropout) - self.multihead_attn = nn.MultiheadAttention(cfg.d_model, cfg.n_heads, dropout=cfg.dropout) + self.self_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout) + self.multihead_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout) # Feed forward layers. - self.linear1 = nn.Linear(cfg.d_model, cfg.dim_feedforward) - self.dropout = nn.Dropout(cfg.dropout) - self.linear2 = nn.Linear(cfg.dim_feedforward, cfg.d_model) + self.linear1 = nn.Linear(config.dim_model, config.dim_feedforward) + self.dropout = nn.Dropout(config.dropout) + self.linear2 = nn.Linear(config.dim_feedforward, config.dim_model) - self.norm1 = nn.LayerNorm(cfg.d_model) - self.norm2 = nn.LayerNorm(cfg.d_model) - self.norm3 = nn.LayerNorm(cfg.d_model) - self.dropout1 = nn.Dropout(cfg.dropout) - self.dropout2 = nn.Dropout(cfg.dropout) - self.dropout3 = nn.Dropout(cfg.dropout) + self.norm1 = nn.LayerNorm(config.dim_model) + self.norm2 = nn.LayerNorm(config.dim_model) + self.norm3 = nn.LayerNorm(config.dim_model) + self.dropout1 = nn.Dropout(config.dropout) + self.dropout2 = nn.Dropout(config.dropout) + self.dropout3 = nn.Dropout(config.dropout) - self.activation = _get_activation_fn(cfg.feedforward_activation) - self.pre_norm = cfg.pre_norm + self.activation = get_activation_fn(config.feedforward_activation) + self.pre_norm = config.pre_norm def maybe_add_pos_embed(self, tensor: Tensor, pos_embed: Tensor | None) -> Tensor: return tensor if pos_embed is None else tensor + pos_embed @@ -480,7 +495,7 @@ class _TransformerDecoderLayer(nn.Module): return x -def _create_sinusoidal_position_embedding(num_positions: int, dimension: int) -> Tensor: +def create_sinusoidal_pos_embedding(num_positions: int, dimension: int) -> Tensor: """1D sinusoidal positional embeddings as in Attention is All You Need. Args: @@ -498,7 +513,7 @@ def _create_sinusoidal_position_embedding(num_positions: int, dimension: int) -> return torch.from_numpy(sinusoid_table).float() -class _SinusoidalPositionEmbedding2D(nn.Module): +class ACTSinusoidalPositionEmbedding2d(nn.Module): """2D sinusoidal positional embeddings similar to what's presented in Attention Is All You Need. The variation is that the position indices are normalized in [0, 2π] (not quite: the lower bound is 1/H @@ -552,7 +567,7 @@ class _SinusoidalPositionEmbedding2D(nn.Module): return pos_embed -def _get_activation_fn(activation: str) -> Callable: +def get_activation_fn(activation: str) -> Callable: """Return an activation function given a string.""" if activation == "relu": return F.relu diff --git a/lerobot/common/policies/diffusion/configuration_diffusion.py b/lerobot/common/policies/diffusion/configuration_diffusion.py index 432afa21..b5188488 100644 --- a/lerobot/common/policies/diffusion/configuration_diffusion.py +++ b/lerobot/common/policies/diffusion/configuration_diffusion.py @@ -3,7 +3,7 @@ from dataclasses import dataclass, field @dataclass class DiffusionConfig: - """Configuration class for Diffusion Policy. + """Configuration class for DiffusionPolicy. Defaults are configured for training with PushT providing proprioceptive and single camera observations. @@ -25,11 +25,12 @@ class DiffusionConfig: The key represents the output data name, and the value is a list indicating the dimensions of the corresponding data. For example, "action" refers to an output shape of [14], indicating 14-dimensional actions. Importantly, shapes doesnt include batch dimension or temporal dimension. - normalize_input_modes: A dictionary with key represents the modality (e.g. "observation.state"), - and the value specifies the normalization mode to apply. The two availables - modes are "mean_std" which substracts the mean and divide by the standard - deviation and "min_max" which rescale in a [-1, 1] range. - unnormalize_output_modes: Similar dictionary as `normalize_input_modes`, but to unormalize in original scale. + input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"), + and the value specifies the normalization mode to apply. The two available modes are "mean_std" + which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a + [-1, 1] range. + output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the + original scale. Note that this is also used for normalizing the training targets. vision_backbone: Name of the torchvision resnet backbone to use for encoding images. crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit within the image size. If None, no cropping is done. @@ -70,13 +71,13 @@ class DiffusionConfig: horizon: int = 16 n_action_steps: int = 8 - input_shapes: dict[str, list[str]] = field( + input_shapes: dict[str, list[int]] = field( default_factory=lambda: { "observation.image": [3, 96, 96], "observation.state": [2], } ) - output_shapes: dict[str, list[str]] = field( + output_shapes: dict[str, list[int]] = field( default_factory=lambda: { "action": [2], } @@ -119,15 +120,6 @@ class DiffusionConfig: # --- # TODO(alexander-soare): Remove these from the policy config. - batch_size: int = 64 - grad_clip_norm: int = 10 - lr: float = 1.0e-4 - lr_scheduler: str = "cosine" - lr_warmup_steps: int = 500 - adam_betas: tuple[float, float] = (0.95, 0.999) - adam_eps: float = 1.0e-8 - adam_weight_decay: float = 1.0e-6 - utd: int = 1 use_ema: bool = True ema_update_after_step: int = 0 ema_min_alpha: float = 0.0 diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py index f9358198..5b6da771 100644 --- a/lerobot/common/policies/diffusion/modeling_diffusion.py +++ b/lerobot/common/policies/diffusion/modeling_diffusion.py @@ -9,7 +9,6 @@ TODO(alexander-soare): """ import copy -import logging import math from collections import deque from typing import Callable @@ -19,6 +18,7 @@ import torch import torch.nn.functional as F # noqa: N812 import torchvision from diffusers.schedulers.scheduling_ddpm import DDPMScheduler +from huggingface_hub import PyTorchModelHubMixin from robomimic.models.base_nets import SpatialSoftmax from torch import Tensor, nn from torch.nn.modules.batchnorm import _BatchNorm @@ -32,7 +32,7 @@ from lerobot.common.policies.utils import ( ) -class DiffusionPolicy(nn.Module): +class DiffusionPolicy(nn.Module, PyTorchModelHubMixin): """ Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion" (paper: https://arxiv.org/abs/2303.04137, code: https://github.com/real-stanford/diffusion_policy). @@ -41,49 +41,55 @@ class DiffusionPolicy(nn.Module): name = "diffusion" def __init__( - self, cfg: DiffusionConfig | None = None, lr_scheduler_num_training_steps: int = 0, dataset_stats=None + self, + config: DiffusionConfig | None = None, + dataset_stats: dict[str, dict[str, Tensor]] | None = None, ): """ Args: - cfg: Policy configuration class instance or None, in which case the default instantiation of the - configuration class is used. + config: Policy configuration class instance or None, in which case the default instantiation of + the configuration class is used. + dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected + that they will be passed with a call to `load_state_dict` before the policy is used. """ super().__init__() - # TODO(alexander-soare): LR scheduler will be removed. - assert lr_scheduler_num_training_steps > 0 - if cfg is None: - cfg = DiffusionConfig() - self.cfg = cfg - self.normalize_inputs = Normalize(cfg.input_shapes, cfg.input_normalization_modes, dataset_stats) - self.normalize_targets = Normalize(cfg.output_shapes, cfg.output_normalization_modes, dataset_stats) + if config is None: + config = DiffusionConfig() + self.config = config + self.normalize_inputs = Normalize( + config.input_shapes, config.input_normalization_modes, dataset_stats + ) + self.normalize_targets = Normalize( + config.output_shapes, config.output_normalization_modes, dataset_stats + ) self.unnormalize_outputs = Unnormalize( - cfg.output_shapes, cfg.output_normalization_modes, dataset_stats + config.output_shapes, config.output_normalization_modes, dataset_stats ) # queues are populated during rollout of the policy, they contain the n latest observations and actions self._queues = None - self.diffusion = _DiffusionUnetImagePolicy(cfg) + self.diffusion = DiffusionModel(config) # TODO(alexander-soare): This should probably be managed outside of the policy class. self.ema_diffusion = None self.ema = None - if self.cfg.use_ema: + if self.config.use_ema: self.ema_diffusion = copy.deepcopy(self.diffusion) - self.ema = _EMA(cfg, model=self.ema_diffusion) + self.ema = DiffusionEMA(config, model=self.ema_diffusion) def reset(self): """ Clear observation and action queues. Should be called on `env.reset()` """ self._queues = { - "observation.image": deque(maxlen=self.cfg.n_obs_steps), - "observation.state": deque(maxlen=self.cfg.n_obs_steps), - "action": deque(maxlen=self.cfg.n_action_steps), + "observation.image": deque(maxlen=self.config.n_obs_steps), + "observation.state": deque(maxlen=self.config.n_obs_steps), + "action": deque(maxlen=self.config.n_action_steps), } @torch.no_grad - def select_action(self, batch: dict[str, Tensor], **_) -> Tensor: + def select_action(self, batch: dict[str, Tensor]) -> Tensor: """Select a single action given environment observations. This method handles caching a history of observations and an action trajectory generated by the @@ -131,53 +137,41 @@ class DiffusionPolicy(nn.Module): action = self._queues["action"].popleft() return action - def forward(self, batch: dict[str, Tensor], **_) -> dict[str, Tensor]: + def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: """Run the batch through the model and compute the loss for training or validation.""" batch = self.normalize_inputs(batch) batch = self.normalize_targets(batch) loss = self.diffusion.compute_loss(batch) return {"loss": loss} - def save(self, fp): - torch.save(self.state_dict(), fp) - def load(self, fp): - d = torch.load(fp) - missing_keys, unexpected_keys = self.load_state_dict(d, strict=False) - if len(missing_keys) > 0: - assert all(k.startswith("ema_diffusion.") for k in missing_keys) - logging.warning( - "DiffusionPolicy.load expected ema parameters in loaded state dict but none were found." - ) - assert len(unexpected_keys) == 0 - - -class _DiffusionUnetImagePolicy(nn.Module): - def __init__(self, cfg: DiffusionConfig): +class DiffusionModel(nn.Module): + def __init__(self, config: DiffusionConfig): super().__init__() - self.cfg = cfg + self.config = config - self.rgb_encoder = _RgbEncoder(cfg) - self.unet = _ConditionalUnet1D( - cfg, - global_cond_dim=(cfg.output_shapes["action"][0] + self.rgb_encoder.feature_dim) * cfg.n_obs_steps, + self.rgb_encoder = DiffusionRgbEncoder(config) + self.unet = DiffusionConditionalUnet1d( + config, + global_cond_dim=(config.output_shapes["action"][0] + self.rgb_encoder.feature_dim) + * config.n_obs_steps, ) self.noise_scheduler = DDPMScheduler( - num_train_timesteps=cfg.num_train_timesteps, - beta_start=cfg.beta_start, - beta_end=cfg.beta_end, - beta_schedule=cfg.beta_schedule, + num_train_timesteps=config.num_train_timesteps, + beta_start=config.beta_start, + beta_end=config.beta_end, + beta_schedule=config.beta_schedule, variance_type="fixed_small", - clip_sample=cfg.clip_sample, - clip_sample_range=cfg.clip_sample_range, - prediction_type=cfg.prediction_type, + clip_sample=config.clip_sample, + clip_sample_range=config.clip_sample_range, + prediction_type=config.prediction_type, ) - if cfg.num_inference_steps is None: + if config.num_inference_steps is None: self.num_inference_steps = self.noise_scheduler.config.num_train_timesteps else: - self.num_inference_steps = cfg.num_inference_steps + self.num_inference_steps = config.num_inference_steps # ========= inference ============ def conditional_sample( @@ -188,7 +182,7 @@ class _DiffusionUnetImagePolicy(nn.Module): # Sample prior. sample = torch.randn( - size=(batch_size, self.cfg.horizon, self.cfg.output_shapes["action"][0]), + size=(batch_size, self.config.horizon, self.config.output_shapes["action"][0]), dtype=dtype, device=device, generator=generator, @@ -218,7 +212,7 @@ class _DiffusionUnetImagePolicy(nn.Module): """ assert set(batch).issuperset({"observation.state", "observation.image"}) batch_size, n_obs_steps = batch["observation.state"].shape[:2] - assert n_obs_steps == self.cfg.n_obs_steps + assert n_obs_steps == self.config.n_obs_steps # Extract image feature (first combine batch and sequence dims). img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ...")) @@ -231,10 +225,10 @@ class _DiffusionUnetImagePolicy(nn.Module): sample = self.conditional_sample(batch_size, global_cond=global_cond) # `horizon` steps worth of actions (from the first observation). - actions = sample[..., : self.cfg.output_shapes["action"][0]] + actions = sample[..., : self.config.output_shapes["action"][0]] # Extract `n_action_steps` steps worth of actions (from the current observation). start = n_obs_steps - 1 - end = start + self.cfg.n_action_steps + end = start + self.config.n_action_steps actions = actions[:, start:end] return actions @@ -253,8 +247,8 @@ class _DiffusionUnetImagePolicy(nn.Module): assert set(batch).issuperset({"observation.state", "observation.image", "action", "action_is_pad"}) batch_size, n_obs_steps = batch["observation.state"].shape[:2] horizon = batch["action"].shape[1] - assert horizon == self.cfg.horizon - assert n_obs_steps == self.cfg.n_obs_steps + assert horizon == self.config.horizon + assert n_obs_steps == self.config.n_obs_steps # Extract image feature (first combine batch and sequence dims). img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ...")) @@ -283,12 +277,12 @@ class _DiffusionUnetImagePolicy(nn.Module): # Compute the loss. # The target is either the original trajectory, or the noise. - if self.cfg.prediction_type == "epsilon": + if self.config.prediction_type == "epsilon": target = eps - elif self.cfg.prediction_type == "sample": + elif self.config.prediction_type == "sample": target = batch["action"] else: - raise ValueError(f"Unsupported prediction type {self.cfg.prediction_type}") + raise ValueError(f"Unsupported prediction type {self.config.prediction_type}") loss = F.mse_loss(pred, target, reduction="none") @@ -300,35 +294,35 @@ class _DiffusionUnetImagePolicy(nn.Module): return loss.mean() -class _RgbEncoder(nn.Module): +class DiffusionRgbEncoder(nn.Module): """Encoder an RGB image into a 1D feature vector. Includes the ability to normalize and crop the image first. """ - def __init__(self, cfg: DiffusionConfig): + def __init__(self, config: DiffusionConfig): super().__init__() # Set up optional preprocessing. - if cfg.crop_shape is not None: + if config.crop_shape is not None: self.do_crop = True # Always use center crop for eval - self.center_crop = torchvision.transforms.CenterCrop(cfg.crop_shape) - if cfg.crop_is_random: - self.maybe_random_crop = torchvision.transforms.RandomCrop(cfg.crop_shape) + self.center_crop = torchvision.transforms.CenterCrop(config.crop_shape) + if config.crop_is_random: + self.maybe_random_crop = torchvision.transforms.RandomCrop(config.crop_shape) else: self.maybe_random_crop = self.center_crop else: self.do_crop = False # Set up backbone. - backbone_model = getattr(torchvision.models, cfg.vision_backbone)( - weights=cfg.pretrained_backbone_weights + backbone_model = getattr(torchvision.models, config.vision_backbone)( + weights=config.pretrained_backbone_weights ) # Note: This assumes that the layer4 feature map is children()[-3] # TODO(alexander-soare): Use a safer alternative. self.backbone = nn.Sequential(*(list(backbone_model.children())[:-2])) - if cfg.use_group_norm: - if cfg.pretrained_backbone_weights: + if config.use_group_norm: + if config.pretrained_backbone_weights: raise ValueError( "You can't replace BatchNorm in a pretrained model without ruining the weights!" ) @@ -342,11 +336,11 @@ class _RgbEncoder(nn.Module): # Use a dry run to get the feature map shape. with torch.inference_mode(): feat_map_shape = tuple( - self.backbone(torch.zeros(size=(1, *cfg.input_shapes["observation.image"]))).shape[1:] + self.backbone(torch.zeros(size=(1, *config.input_shapes["observation.image"]))).shape[1:] ) - self.pool = SpatialSoftmax(feat_map_shape, num_kp=cfg.spatial_softmax_num_keypoints) - self.feature_dim = cfg.spatial_softmax_num_keypoints * 2 - self.out = nn.Linear(cfg.spatial_softmax_num_keypoints * 2, self.feature_dim) + self.pool = SpatialSoftmax(feat_map_shape, num_kp=config.spatial_softmax_num_keypoints) + self.feature_dim = config.spatial_softmax_num_keypoints * 2 + self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim) self.relu = nn.ReLU() def forward(self, x: Tensor) -> Tensor: @@ -403,7 +397,7 @@ def _replace_submodules( return root_module -class _SinusoidalPosEmb(nn.Module): +class DiffusionSinusoidalPosEmb(nn.Module): """1D sinusoidal positional embeddings as in Attention is All You Need.""" def __init__(self, dim: int): @@ -420,7 +414,7 @@ class _SinusoidalPosEmb(nn.Module): return emb -class _Conv1dBlock(nn.Module): +class DiffusionConv1dBlock(nn.Module): """Conv1d --> GroupNorm --> Mish""" def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8): @@ -436,40 +430,40 @@ class _Conv1dBlock(nn.Module): return self.block(x) -class _ConditionalUnet1D(nn.Module): +class DiffusionConditionalUnet1d(nn.Module): """A 1D convolutional UNet with FiLM modulation for conditioning. Note: this removes local conditioning as compared to the original diffusion policy code. """ - def __init__(self, cfg: DiffusionConfig, global_cond_dim: int): + def __init__(self, config: DiffusionConfig, global_cond_dim: int): super().__init__() - self.cfg = cfg + self.config = config # Encoder for the diffusion timestep. self.diffusion_step_encoder = nn.Sequential( - _SinusoidalPosEmb(cfg.diffusion_step_embed_dim), - nn.Linear(cfg.diffusion_step_embed_dim, cfg.diffusion_step_embed_dim * 4), + DiffusionSinusoidalPosEmb(config.diffusion_step_embed_dim), + nn.Linear(config.diffusion_step_embed_dim, config.diffusion_step_embed_dim * 4), nn.Mish(), - nn.Linear(cfg.diffusion_step_embed_dim * 4, cfg.diffusion_step_embed_dim), + nn.Linear(config.diffusion_step_embed_dim * 4, config.diffusion_step_embed_dim), ) # The FiLM conditioning dimension. - cond_dim = cfg.diffusion_step_embed_dim + global_cond_dim + cond_dim = config.diffusion_step_embed_dim + global_cond_dim # In channels / out channels for each downsampling block in the Unet's encoder. For the decoder, we # just reverse these. - in_out = [(cfg.output_shapes["action"][0], cfg.down_dims[0])] + list( - zip(cfg.down_dims[:-1], cfg.down_dims[1:], strict=True) + in_out = [(config.output_shapes["action"][0], config.down_dims[0])] + list( + zip(config.down_dims[:-1], config.down_dims[1:], strict=True) ) # Unet encoder. common_res_block_kwargs = { "cond_dim": cond_dim, - "kernel_size": cfg.kernel_size, - "n_groups": cfg.n_groups, - "use_film_scale_modulation": cfg.use_film_scale_modulation, + "kernel_size": config.kernel_size, + "n_groups": config.n_groups, + "use_film_scale_modulation": config.use_film_scale_modulation, } self.down_modules = nn.ModuleList([]) for ind, (dim_in, dim_out) in enumerate(in_out): @@ -477,8 +471,8 @@ class _ConditionalUnet1D(nn.Module): self.down_modules.append( nn.ModuleList( [ - _ConditionalResidualBlock1D(dim_in, dim_out, **common_res_block_kwargs), - _ConditionalResidualBlock1D(dim_out, dim_out, **common_res_block_kwargs), + DiffusionConditionalResidualBlock1d(dim_in, dim_out, **common_res_block_kwargs), + DiffusionConditionalResidualBlock1d(dim_out, dim_out, **common_res_block_kwargs), # Downsample as long as it is not the last block. nn.Conv1d(dim_out, dim_out, 3, 2, 1) if not is_last else nn.Identity(), ] @@ -488,8 +482,12 @@ class _ConditionalUnet1D(nn.Module): # Processing in the middle of the auto-encoder. self.mid_modules = nn.ModuleList( [ - _ConditionalResidualBlock1D(cfg.down_dims[-1], cfg.down_dims[-1], **common_res_block_kwargs), - _ConditionalResidualBlock1D(cfg.down_dims[-1], cfg.down_dims[-1], **common_res_block_kwargs), + DiffusionConditionalResidualBlock1d( + config.down_dims[-1], config.down_dims[-1], **common_res_block_kwargs + ), + DiffusionConditionalResidualBlock1d( + config.down_dims[-1], config.down_dims[-1], **common_res_block_kwargs + ), ] ) @@ -501,8 +499,8 @@ class _ConditionalUnet1D(nn.Module): nn.ModuleList( [ # dim_in * 2, because it takes the encoder's skip connection as well - _ConditionalResidualBlock1D(dim_in * 2, dim_out, **common_res_block_kwargs), - _ConditionalResidualBlock1D(dim_out, dim_out, **common_res_block_kwargs), + DiffusionConditionalResidualBlock1d(dim_in * 2, dim_out, **common_res_block_kwargs), + DiffusionConditionalResidualBlock1d(dim_out, dim_out, **common_res_block_kwargs), # Upsample as long as it is not the last block. nn.ConvTranspose1d(dim_out, dim_out, 4, 2, 1) if not is_last else nn.Identity(), ] @@ -510,8 +508,8 @@ class _ConditionalUnet1D(nn.Module): ) self.final_conv = nn.Sequential( - _Conv1dBlock(cfg.down_dims[0], cfg.down_dims[0], kernel_size=cfg.kernel_size), - nn.Conv1d(cfg.down_dims[0], cfg.output_shapes["action"][0], 1), + DiffusionConv1dBlock(config.down_dims[0], config.down_dims[0], kernel_size=config.kernel_size), + nn.Conv1d(config.down_dims[0], config.output_shapes["action"][0], 1), ) def forward(self, x: Tensor, timestep: Tensor | int, global_cond=None) -> Tensor: @@ -559,7 +557,7 @@ class _ConditionalUnet1D(nn.Module): return x -class _ConditionalResidualBlock1D(nn.Module): +class DiffusionConditionalResidualBlock1d(nn.Module): """ResNet style 1D convolutional block with FiLM modulation for conditioning.""" def __init__( @@ -578,13 +576,13 @@ class _ConditionalResidualBlock1D(nn.Module): self.use_film_scale_modulation = use_film_scale_modulation self.out_channels = out_channels - self.conv1 = _Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups) + self.conv1 = DiffusionConv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups) # FiLM modulation (https://arxiv.org/abs/1709.07871) outputs per-channel bias and (maybe) scale. cond_channels = out_channels * 2 if use_film_scale_modulation else out_channels self.cond_encoder = nn.Sequential(nn.Mish(), nn.Linear(cond_dim, cond_channels)) - self.conv2 = _Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups) + self.conv2 = DiffusionConv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups) # A final convolution for dimension matching the residual (if needed). self.residual_conv = ( @@ -617,18 +615,18 @@ class _ConditionalResidualBlock1D(nn.Module): return out -class _EMA: +class DiffusionEMA: """ Exponential Moving Average of models weights """ - def __init__(self, cfg: DiffusionConfig, model: nn.Module): + def __init__(self, config: DiffusionConfig, model: nn.Module): """ @crowsonkb's notes on EMA Warmup: - If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan - to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps), - gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 - at 215.4k steps). + If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models + you plan to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 + at 1M steps), gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 + at 10K steps, 0.9999 at 215.4k steps). Args: inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1. power (float): Exponential factor of EMA warmup. Default: 2/3. @@ -639,11 +637,11 @@ class _EMA: self.averaged_model.eval() self.averaged_model.requires_grad_(False) - self.update_after_step = cfg.ema_update_after_step - self.inv_gamma = cfg.ema_inv_gamma - self.power = cfg.ema_power - self.min_alpha = cfg.ema_min_alpha - self.max_alpha = cfg.ema_max_alpha + self.update_after_step = config.ema_update_after_step + self.inv_gamma = config.ema_inv_gamma + self.power = config.ema_power + self.min_alpha = config.ema_min_alpha + self.max_alpha = config.ema_max_alpha self.alpha = 0.0 self.optimization_step = 0 diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py index a8235388..4819ca80 100644 --- a/lerobot/common/policies/factory.py +++ b/lerobot/common/policies/factory.py @@ -2,6 +2,7 @@ import inspect from omegaconf import DictConfig, OmegaConf +from lerobot.common.policies.policy_protocol import Policy from lerobot.common.utils.utils import get_safe_torch_device @@ -20,42 +21,53 @@ def _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg): return policy_cfg -def make_policy(hydra_cfg: DictConfig, dataset_stats=None): - if hydra_cfg.policy.name == "tdmpc": - from lerobot.common.policies.tdmpc.policy import TDMPCPolicy +def get_policy_and_config_classes(name: str) -> tuple[Policy, object]: + """Get the policy's class and config class given a name (matching the policy class' `name` attribute).""" + if name == "tdmpc": + from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig + from lerobot.common.policies.tdmpc.modeling_tdmpc import TDMPCPolicy - policy = TDMPCPolicy( - hydra_cfg.policy, - n_obs_steps=hydra_cfg.n_obs_steps, - n_action_steps=hydra_cfg.n_action_steps, - device=hydra_cfg.device, - ) - elif hydra_cfg.policy.name == "diffusion": + return TDMPCPolicy, TDMPCConfig + elif name == "diffusion": from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy - policy_cfg = _policy_cfg_from_hydra_cfg(DiffusionConfig, hydra_cfg) - policy = DiffusionPolicy(policy_cfg, hydra_cfg.offline_steps, dataset_stats) - policy.to(get_safe_torch_device(hydra_cfg.device)) - elif hydra_cfg.policy.name == "act": - from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig - from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy + return DiffusionPolicy, DiffusionConfig + elif name == "act": + from lerobot.common.policies.act.configuration_act import ACTConfig + from lerobot.common.policies.act.modeling_act import ACTPolicy - policy_cfg = _policy_cfg_from_hydra_cfg(ActionChunkingTransformerConfig, hydra_cfg) - policy = ActionChunkingTransformerPolicy(policy_cfg, dataset_stats) - policy.to(get_safe_torch_device(hydra_cfg.device)) + return ACTPolicy, ACTConfig else: - raise ValueError(hydra_cfg.policy.name) + raise NotImplementedError(f"Policy with name {name} is not implemented.") - if hydra_cfg.policy.pretrained_model_path: - # TODO(rcadene): hack for old pretrained models from fowm - if hydra_cfg.policy.name == "tdmpc" and "fowm" in hydra_cfg.policy.pretrained_model_path: - if "offline" in hydra_cfg.policy.pretrained_model_path: - policy.step[0] = 25000 - elif "final" in hydra_cfg.policy.pretrained_model_path: - policy.step[0] = 100000 - else: - raise NotImplementedError() - policy.load(hydra_cfg.policy.pretrained_model_path) + +def make_policy( + hydra_cfg: DictConfig, pretrained_policy_name_or_path: str | None = None, dataset_stats=None +) -> Policy: + """Make an instance of a policy class. + + Args: + hydra_cfg: A parsed Hydra configuration (see scripts). If `pretrained_policy_name_or_path` is + provided, only `hydra_cfg.policy.name` is used while everything else is ignored. + pretrained_policy_name_or_path: Either the repo ID of a model hosted on the Hub or a path to a + directory containing weights saved using `Policy.save_pretrained`. Note that providing this + argument overrides everything in `hydra_cfg.policy` apart from `hydra_cfg.policy.name`. + dataset_stats: Dataset statistics to use for (un)normalization of inputs/outputs in the policy. Must + be provided when initializing a new policy, and must not be provided when loading a pretrained + policy. Therefore, this argument is mutually exclusive with `pretrained_policy_name_or_path`. + """ + if not (pretrained_policy_name_or_path is None) ^ (dataset_stats is None): + raise ValueError("Only one of `pretrained_policy_name_or_path` and `dataset_stats` may be provided.") + + policy_cls, policy_cfg_class = get_policy_and_config_classes(hydra_cfg.policy.name) + + if pretrained_policy_name_or_path is None: + policy_cfg = _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg) + policy = policy_cls(policy_cfg, dataset_stats) + else: + policy = policy_cls.from_pretrained(pretrained_policy_name_or_path) + + policy.to(get_safe_torch_device(hydra_cfg.device)) return policy diff --git a/lerobot/common/policies/normalize.py b/lerobot/common/policies/normalize.py index df615a21..ab57c8ba 100644 --- a/lerobot/common/policies/normalize.py +++ b/lerobot/common/policies/normalize.py @@ -57,17 +57,28 @@ def create_stats_buffers( ) if stats is not None: + # Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated + # tensors anywhere (for example, when we use the same stats for normalization and + # unnormalization). See the logic here + # https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97. if mode == "mean_std": - buffer["mean"].data = stats[key]["mean"] - buffer["std"].data = stats[key]["std"] + buffer["mean"].data = stats[key]["mean"].clone() + buffer["std"].data = stats[key]["std"].clone() elif mode == "min_max": - buffer["min"].data = stats[key]["min"] - buffer["max"].data = stats[key]["max"] + buffer["min"].data = stats[key]["min"].clone() + buffer["max"].data = stats[key]["max"].clone() stats_buffers[key] = buffer return stats_buffers +def _no_stats_error_str(name: str) -> str: + return ( + f"`{name}` is infinity. You should either initialize with `stats` as an argument, or use a " + "pretrained model." + ) + + class Normalize(nn.Module): """Normalizes data (e.g. "observation.image") for more stable and faster convergence during training.""" @@ -99,7 +110,6 @@ class Normalize(nn.Module): self.shapes = shapes self.modes = modes self.stats = stats - # `self.buffer_observation_state["mean"]` contains `torch.tensor(state_dim)` stats_buffers = create_stats_buffers(shapes, modes, stats) for key, buffer in stats_buffers.items(): setattr(self, "buffer_" + key.replace(".", "_"), buffer) @@ -113,26 +123,14 @@ class Normalize(nn.Module): if mode == "mean_std": mean = buffer["mean"] std = buffer["std"] - assert not torch.isinf(mean).any(), ( - "`mean` is infinity. You forgot to initialize with `stats` as argument, or called " - "`policy.load_state_dict`." - ) - assert not torch.isinf(std).any(), ( - "`std` is infinity. You forgot to initialize with `stats` as argument, or called " - "`policy.load_state_dict`." - ) + assert not torch.isinf(mean).any(), _no_stats_error_str("mean") + assert not torch.isinf(std).any(), _no_stats_error_str("std") batch[key] = (batch[key] - mean) / (std + 1e-8) elif mode == "min_max": min = buffer["min"] max = buffer["max"] - assert not torch.isinf(min).any(), ( - "`min` is infinity. You forgot to initialize with `stats` as argument, or called " - "`policy.load_state_dict`." - ) - assert not torch.isinf(max).any(), ( - "`max` is infinity. You forgot to initialize with `stats` as argument, or called " - "`policy.load_state_dict`." - ) + assert not torch.isinf(min).any(), _no_stats_error_str("min") + assert not torch.isinf(max).any(), _no_stats_error_str("max") # normalize to [0,1] batch[key] = (batch[key] - min) / (max - min) # normalize to [-1, 1] @@ -190,26 +188,14 @@ class Unnormalize(nn.Module): if mode == "mean_std": mean = buffer["mean"] std = buffer["std"] - assert not torch.isinf(mean).any(), ( - "`mean` is infinity. You forgot to initialize with `stats` as argument, or called " - "`policy.load_state_dict`." - ) - assert not torch.isinf(std).any(), ( - "`std` is infinity. You forgot to initialize with `stats` as argument, or called " - "`policy.load_state_dict`." - ) + assert not torch.isinf(mean).any(), _no_stats_error_str("mean") + assert not torch.isinf(std).any(), _no_stats_error_str("std") batch[key] = batch[key] * std + mean elif mode == "min_max": min = buffer["min"] max = buffer["max"] - assert not torch.isinf(min).any(), ( - "`min` is infinity. You forgot to initialize with `stats` as argument, or called " - "`policy.load_state_dict`." - ) - assert not torch.isinf(max).any(), ( - "`max` is infinity. You forgot to initialize with `stats` as argument, or called " - "`policy.load_state_dict`." - ) + assert not torch.isinf(min).any(), _no_stats_error_str("min") + assert not torch.isinf(max).any(), _no_stats_error_str("max") batch[key] = (batch[key] + 1) / 2 batch[key] = batch[key] * (max - min) + min else: diff --git a/lerobot/common/policies/policy_protocol.py b/lerobot/common/policies/policy_protocol.py index 29317fa0..5749c6a8 100644 --- a/lerobot/common/policies/policy_protocol.py +++ b/lerobot/common/policies/policy_protocol.py @@ -14,10 +14,21 @@ from torch import Tensor @runtime_checkable class Policy(Protocol): - """The required interface for implementing a policy.""" + """The required interface for implementing a policy. + + We also expect all policies to subclass torch.nn.Module and PyTorchModelHubMixin. + """ name: str + def __init__(self, cfg, dataset_stats: dict[str, dict[str, Tensor]] | None = None): + """ + Args: + cfg: Policy configuration class instance or None, in which case the default instantiation of the + configuration class is used. + dataset_stats: Dataset statistics to be used for normalization. + """ + def reset(self): """To be called whenever the environment is reset. @@ -36,3 +47,13 @@ class Policy(Protocol): When the model uses a history of observations, or outputs a sequence of actions, this method deals with caching. """ + + +@runtime_checkable +class PolicyWithUpdate(Policy, Protocol): + def update(self): + """An update method that is to be called after a training optimization step. + + Implements an additional updates the model parameters may need (for example, doing an EMA step for a + target model, or incrementing an internal buffer). + """ diff --git a/lerobot/common/policies/tdmpc/configuration_tdmpc.py b/lerobot/common/policies/tdmpc/configuration_tdmpc.py new file mode 100644 index 00000000..82e3a507 --- /dev/null +++ b/lerobot/common/policies/tdmpc/configuration_tdmpc.py @@ -0,0 +1,150 @@ +from dataclasses import dataclass, field + + +@dataclass +class TDMPCConfig: + """Configuration class for TDMPCPolicy. + + Defaults are configured for training with xarm_lift_medium_replay providing proprioceptive and single + camera observations. + + The parameters you will most likely need to change are the ones which depend on the environment / sensors. + Those are: `input_shapes`, `output_shapes`, and perhaps `max_random_shift`. + + Args: + n_action_repeats: The number of times to repeat the action returned by the planning. (hint: Google + action repeats in Q-learning or ask your favorite chatbot) + horizon: Horizon for model predictive control. + input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"), + and the value specifies the normalization mode to apply. The two available modes are "mean_std" + which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a + [-1, 1] range. Note that here this defaults to None meaning inputs are not normalized. This is to + match the original implementation. + output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the + original scale. Note that this is also used for normalizing the training targets. NOTE: Clipping + to [-1, +1] is used during MPPI/CEM. Therefore, it is recommended that you stick with "min_max" + normalization mode here. + image_encoder_hidden_dim: Number of channels for the convolutional layers used for image encoding. + state_encoder_hidden_dim: Hidden dimension for MLP used for state vector encoding. + latent_dim: Observation's latent embedding dimension. + q_ensemble_size: Number of Q function estimators to use in an ensemble for uncertainty estimation. + mlp_dim: Hidden dimension of MLPs used for modelling the dynamics encoder, reward function, policy + (π), Q ensemble, and V. + discount: Discount factor (γ) to use for the reinforcement learning formalism. + use_mpc: Whether to use model predictive control. The alternative is to just sample the policy model + (π) for each step. + cem_iterations: Number of iterations for the MPPI/CEM loop in MPC. + max_std: Maximum standard deviation for actions sampled from the gaussian PDF in CEM. + min_std: Minimum standard deviation for noise applied to actions sampled from the policy model (π). + Doubles up as the minimum standard deviation for actions sampled from the gaussian PDF in CEM. + n_gaussian_samples: Number of samples to draw from the gaussian distribution every CEM iteration. Must + be non-zero. + n_pi_samples: Number of samples to draw from the policy / world model rollout every CEM iteration. Can + be zero. + uncertainty_regularizer_coeff: Coefficient for the uncertainty regularization used when estimating + trajectory values (this is the λ coeffiecient in eqn 4 of FOWM). + n_elites: The number of elite samples to use for updating the gaussian parameters every CEM iteration. + elite_weighting_temperature: The temperature to use for softmax weighting (by trajectory value) of the + elites, when updating the gaussian parameters for CEM. + gaussian_mean_momentum: Momentum (α) used for EMA updates of the mean parameter μ of the gaussian + paramters optimized in CEM. Updates are calculated as μ⁻ ← αμ⁻ + (1-α)μ. + max_random_shift_ratio: Maximum random shift (as a proportion of the image size) to apply to the + image(s) (in units of pixels) for training-time augmentation. If set to 0, no such augmentation + is applied. Note that the input images are assumed to be square for this augmentation. + reward_coeff: Loss weighting coefficient for the reward regression loss. + expectile_weight: Weighting (τ) used in expectile regression for the state value function (V). + v_pred < v_target is weighted by τ and v_pred >= v_target is weighted by (1-τ). τ is expected to + be in [0, 1]. Setting τ closer to 1 results in a more "optimistic" V. This is sensible to do + because v_target is obtained by evaluating the learned state-action value functions (Q) with + in-sample actions that may not be always optimal. + value_coeff: Loss weighting coefficient for both the state-action value (Q) TD loss, and the state + value (V) expectile regression loss. + consistency_coeff: Loss weighting coefficient for the consistency loss. + advantage_scaling: A factor by which the advantages are scaled prior to exponentiation for advantage + weighted regression of the policy (π) estimator parameters. Note that the exponentiated advantages + are clamped at 100.0. + pi_coeff: Loss weighting coefficient for the action regression loss. + temporal_decay_coeff: Exponential decay coefficient for decaying the loss coefficient for future time- + steps. Hint: each loss computation involves `horizon` steps worth of actions starting from the + current time step. + target_model_momentum: Momentum (α) used for EMA updates of the target models. Updates are calculated + as ϕ ← αϕ + (1-α)θ where ϕ are the parameters of the target model and θ are the parameters of the + model being trained. + """ + + # Input / output structure. + n_action_repeats: int = 2 + horizon: int = 5 + + input_shapes: dict[str, list[int]] = field( + default_factory=lambda: { + "observation.image": [3, 84, 84], + "observation.state": [4], + } + ) + output_shapes: dict[str, list[int]] = field( + default_factory=lambda: { + "action": [4], + } + ) + + # Normalization / Unnormalization + input_normalization_modes: dict[str, str] | None = None + output_normalization_modes: dict[str, str] = field( + default_factory=lambda: {"action": "min_max"}, + ) + + # Architecture / modeling. + # Neural networks. + image_encoder_hidden_dim: int = 32 + state_encoder_hidden_dim: int = 256 + latent_dim: int = 50 + q_ensemble_size: int = 5 + mlp_dim: int = 512 + # Reinforcement learning. + discount: float = 0.9 + + # Inference. + use_mpc: bool = True + cem_iterations: int = 6 + max_std: float = 2.0 + min_std: float = 0.05 + n_gaussian_samples: int = 512 + n_pi_samples: int = 51 + uncertainty_regularizer_coeff: float = 1.0 + n_elites: int = 50 + elite_weighting_temperature: float = 0.5 + gaussian_mean_momentum: float = 0.1 + + # Training and loss computation. + max_random_shift_ratio: float = 0.0476 + # Loss coefficients. + reward_coeff: float = 0.5 + expectile_weight: float = 0.9 + value_coeff: float = 0.1 + consistency_coeff: float = 20.0 + advantage_scaling: float = 3.0 + pi_coeff: float = 0.5 + temporal_decay_coeff: float = 0.5 + # Target model. + target_model_momentum: float = 0.995 + + def __post_init__(self): + """Input validation (not exhaustive).""" + if self.input_shapes["observation.image"][-2] != self.input_shapes["observation.image"][-1]: + # TODO(alexander-soare): This limitation is solely because of code in the random shift + # augmentation. It should be able to be removed. + raise ValueError( + "Only square images are handled now. Got image shape " + f"{self.input_shapes['observation.image']}." + ) + if self.n_gaussian_samples <= 0: + raise ValueError( + f"The number of guassian samples for CEM should be non-zero. Got `{self.n_gaussian_samples=}`" + ) + if self.output_normalization_modes != {"action": "min_max"}: + raise ValueError( + "TD-MPC assumes the action space dimensions to all be in [-1, 1]. Therefore it is strongly " + f"advised that you stick with the default. See {self.__class__.__name__} docstring for more " + "information." + ) diff --git a/lerobot/common/policies/tdmpc/helper.py b/lerobot/common/policies/tdmpc/helper.py deleted file mode 100644 index 964f1718..00000000 --- a/lerobot/common/policies/tdmpc/helper.py +++ /dev/null @@ -1,576 +0,0 @@ -import os -import pickle -import re - -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F # noqa: N812 -from torch import distributions as pyd -from torch.distributions.utils import _standard_normal - -DEFAULT_ACT_FN = nn.Mish() - - -def __REDUCE__(b): # noqa: N802, N807 - return "mean" if b else "none" - - -def l1(pred, target, reduce=False): - """Computes the L1-loss between predictions and targets.""" - return F.l1_loss(pred, target, reduction=__REDUCE__(reduce)) - - -def mse(pred, target, reduce=False): - """Computes the MSE loss between predictions and targets.""" - return F.mse_loss(pred, target, reduction=__REDUCE__(reduce)) - - -def l2_expectile(diff, expectile=0.7, reduce=False): - weight = torch.where(diff > 0, expectile, (1 - expectile)) - loss = weight * (diff**2) - reduction = __REDUCE__(reduce) - if reduction == "mean": - return torch.mean(loss) - elif reduction == "sum": - return torch.sum(loss) - return loss - - -def _get_out_shape(in_shape, layers): - """Utility function. Returns the output shape of a network for a given input shape.""" - x = torch.randn(*in_shape).unsqueeze(0) - return (nn.Sequential(*layers) if isinstance(layers, list) else layers)(x).squeeze(0).shape - - -def gaussian_logprob(eps, log_std): - """Compute Gaussian log probability.""" - residual = (-0.5 * eps.pow(2) - log_std).sum(-1, keepdim=True) - return residual - 0.5 * np.log(2 * np.pi) * eps.size(-1) - - -def squash(mu, pi, log_pi): - """Apply squashing function.""" - mu = torch.tanh(mu) - pi = torch.tanh(pi) - log_pi -= torch.log(F.relu(1 - pi.pow(2)) + 1e-6).sum(-1, keepdim=True) - return mu, pi, log_pi - - -def orthogonal_init(m): - """Orthogonal layer initialization.""" - if isinstance(m, nn.Linear): - nn.init.orthogonal_(m.weight.data) - if m.bias is not None: - nn.init.zeros_(m.bias) - elif isinstance(m, nn.Conv2d): - gain = nn.init.calculate_gain("relu") - nn.init.orthogonal_(m.weight.data, gain) - if m.bias is not None: - nn.init.zeros_(m.bias) - - -def ema(m, m_target, tau): - """Update slow-moving average of online network (target network) at rate tau.""" - with torch.no_grad(): - # TODO(rcadene, aliberts): issue with strict=False - # for p, p_target in zip(m.parameters(), m_target.parameters(), strict=False): - # p_target.data.lerp_(p.data, tau) - m_params_iter = iter(m.parameters()) - m_target_params_iter = iter(m_target.parameters()) - - while True: - try: - p = next(m_params_iter) - p_target = next(m_target_params_iter) - p_target.data.lerp_(p.data, tau) - except StopIteration: - # If any iterator is exhausted, exit the loop - break - - -def set_requires_grad(net, value): - """Enable/disable gradients for a given (sub)network.""" - for param in net.parameters(): - param.requires_grad_(value) - - -class TruncatedNormal(pyd.Normal): - """Utility class implementing the truncated normal distribution.""" - - default_sample_shape = torch.Size() - - def __init__(self, loc, scale, low=-1.0, high=1.0, eps=1e-6): - super().__init__(loc, scale, validate_args=False) - self.low = low - self.high = high - self.eps = eps - - def _clamp(self, x): - clamped_x = torch.clamp(x, self.low + self.eps, self.high - self.eps) - x = x - x.detach() + clamped_x.detach() - return x - - def sample(self, clip=None, sample_shape=default_sample_shape): - shape = self._extended_shape(sample_shape) - eps = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device) - eps *= self.scale - if clip is not None: - eps = torch.clamp(eps, -clip, clip) - x = self.loc + eps - return self._clamp(x) - - -class NormalizeImg(nn.Module): - """Normalizes pixel observations to [0,1) range.""" - - def __init__(self): - super().__init__() - - def forward(self, x): - return x.div(255.0) - - -class Flatten(nn.Module): - """Flattens its input to a (batched) vector.""" - - def __init__(self): - super().__init__() - - def forward(self, x): - return x.view(x.size(0), -1) - - -def enc(cfg): - obs_shape = { - "rgb": (3, cfg.img_size, cfg.img_size), - "state": (cfg.state_dim,), - } - - """Returns a TOLD encoder.""" - pixels_enc_layers, state_enc_layers = None, None - if cfg.modality in {"pixels", "all"}: - C = int(3 * cfg.frame_stack) # noqa: N806 - pixels_enc_layers = [ - NormalizeImg(), - nn.Conv2d(C, cfg.num_channels, 7, stride=2), - nn.ReLU(), - nn.Conv2d(cfg.num_channels, cfg.num_channels, 5, stride=2), - nn.ReLU(), - nn.Conv2d(cfg.num_channels, cfg.num_channels, 3, stride=2), - nn.ReLU(), - nn.Conv2d(cfg.num_channels, cfg.num_channels, 3, stride=2), - nn.ReLU(), - ] - out_shape = _get_out_shape((C, cfg.img_size, cfg.img_size), pixels_enc_layers) - pixels_enc_layers.extend( - [ - Flatten(), - nn.Linear(np.prod(out_shape), cfg.latent_dim), - nn.LayerNorm(cfg.latent_dim), - nn.Sigmoid(), - ] - ) - if cfg.modality == "pixels": - return ConvExt(nn.Sequential(*pixels_enc_layers)) - if cfg.modality in {"state", "all"}: - state_dim = obs_shape[0] if cfg.modality == "state" else obs_shape["state"][0] - state_enc_layers = [ - nn.Linear(state_dim, cfg.enc_dim), - nn.ELU(), - nn.Linear(cfg.enc_dim, cfg.latent_dim), - nn.LayerNorm(cfg.latent_dim), - nn.Sigmoid(), - ] - if cfg.modality == "state": - return nn.Sequential(*state_enc_layers) - else: - raise NotImplementedError - - encoders = {} - for k in obs_shape: - if k == "state": - encoders[k] = nn.Sequential(*state_enc_layers) - elif k.endswith("rgb"): - encoders[k] = ConvExt(nn.Sequential(*pixels_enc_layers)) - else: - raise NotImplementedError - return Multiplexer(nn.ModuleDict(encoders)) - - -def mlp(in_dim, mlp_dim, out_dim, act_fn=DEFAULT_ACT_FN): - """Returns an MLP.""" - if isinstance(mlp_dim, int): - mlp_dim = [mlp_dim, mlp_dim] - return nn.Sequential( - nn.Linear(in_dim, mlp_dim[0]), - nn.LayerNorm(mlp_dim[0]), - act_fn, - nn.Linear(mlp_dim[0], mlp_dim[1]), - nn.LayerNorm(mlp_dim[1]), - act_fn, - nn.Linear(mlp_dim[1], out_dim), - ) - - -def dynamics(in_dim, mlp_dim, out_dim, act_fn=DEFAULT_ACT_FN): - """Returns a dynamics network.""" - return nn.Sequential( - mlp(in_dim, mlp_dim, out_dim, act_fn), - nn.LayerNorm(out_dim), - nn.Sigmoid(), - ) - - -def q(cfg): - action_dim = cfg.action_dim - """Returns a Q-function that uses Layer Normalization.""" - return nn.Sequential( - nn.Linear(cfg.latent_dim + action_dim, cfg.mlp_dim), - nn.LayerNorm(cfg.mlp_dim), - nn.Tanh(), - nn.Linear(cfg.mlp_dim, cfg.mlp_dim), - nn.ELU(), - nn.Linear(cfg.mlp_dim, 1), - ) - - -def v(cfg): - """Returns a state value function that uses Layer Normalization.""" - return nn.Sequential( - nn.Linear(cfg.latent_dim, cfg.mlp_dim), - nn.LayerNorm(cfg.mlp_dim), - nn.Tanh(), - nn.Linear(cfg.mlp_dim, cfg.mlp_dim), - nn.ELU(), - nn.Linear(cfg.mlp_dim, 1), - ) - - -def aug(cfg): - obs_shape = { - "rgb": (3, cfg.img_size, cfg.img_size), - "state": (4,), - } - - """Multiplex augmentation""" - if cfg.modality == "state": - return nn.Identity() - elif cfg.modality == "pixels": - return RandomShiftsAug(cfg) - else: - augs = {} - for k in obs_shape: - if k == "state": - augs[k] = nn.Identity() - elif k.endswith("rgb"): - augs[k] = RandomShiftsAug(cfg) - else: - raise NotImplementedError - return Multiplexer(nn.ModuleDict(augs)) - - -class ConvExt(nn.Module): - """Auxiliary conv net accommodating high-dim input""" - - def __init__(self, conv): - super().__init__() - self.conv = conv - - def forward(self, x): - if x.ndim > 4: - batch_shape = x.shape[:-3] - out = self.conv(x.view(-1, *x.shape[-3:])) - out = out.view(*batch_shape, *out.shape[1:]) - else: - out = self.conv(x) - return out - - -class Multiplexer(nn.Module): - """Model multiplexer""" - - def __init__(self, choices): - super().__init__() - self.choices = choices - - def forward(self, x, key=None): - if isinstance(x, dict): - if key is not None: - return self.choices[key](x) - return {k: self.choices[k](_x) for k, _x in x.items()} - return self.choices(x) - - -class RandomShiftsAug(nn.Module): - """ - Random shift image augmentation. - Adapted from https://github.com/facebookresearch/drqv2 - """ - - def __init__(self, cfg): - super().__init__() - assert cfg.modality in {"pixels", "all"} - self.pad = int(cfg.img_size / 21) - - def forward(self, x): - n, c, h, w = x.size() - assert h == w - padding = tuple([self.pad] * 4) - x = F.pad(x, padding, "replicate") - eps = 1.0 / (h + 2 * self.pad) - arange = torch.linspace( - -1.0 + eps, - 1.0 - eps, - h + 2 * self.pad, - device=x.device, - dtype=torch.float32, - )[:h] - arange = arange.unsqueeze(0).repeat(h, 1).unsqueeze(2) - base_grid = torch.cat([arange, arange.transpose(1, 0)], dim=2) - base_grid = base_grid.unsqueeze(0).repeat(n, 1, 1, 1) - shift = torch.randint( - 0, - 2 * self.pad + 1, - size=(n, 1, 1, 2), - device=x.device, - dtype=torch.float32, - ) - shift *= 2.0 / (h + 2 * self.pad) - grid = base_grid + shift - return F.grid_sample(x, grid, padding_mode="zeros", align_corners=False) - - -# TODO(aliberts): remove class -# class Episode: -# """Storage object for a single episode.""" - -# def __init__(self, cfg, init_obs): -# action_dim = cfg.action_dim - -# self.cfg = cfg -# self.device = torch.device(cfg.buffer_device) -# if cfg.modality in {"pixels", "state"}: -# dtype = torch.float32 if cfg.modality == "state" else torch.uint8 -# self.obses = torch.empty( -# (cfg.episode_length + 1, *init_obs.shape), -# dtype=dtype, -# device=self.device, -# ) -# self.obses[0] = torch.tensor(init_obs, dtype=dtype, device=self.device) -# elif cfg.modality == "all": -# self.obses = {} -# for k, v in init_obs.items(): -# assert k in {"rgb", "state"} -# dtype = torch.float32 if k == "state" else torch.uint8 -# self.obses[k] = torch.empty( -# (cfg.episode_length + 1, *v.shape), dtype=dtype, device=self.device -# ) -# self.obses[k][0] = torch.tensor(v, dtype=dtype, device=self.device) -# else: -# raise ValueError -# self.actions = torch.empty((cfg.episode_length, action_dim), dtype=torch.float32, device=self.device) -# self.rewards = torch.empty((cfg.episode_length,), dtype=torch.float32, device=self.device) -# self.dones = torch.empty((cfg.episode_length,), dtype=torch.bool, device=self.device) -# self.masks = torch.empty((cfg.episode_length,), dtype=torch.float32, device=self.device) -# self.cumulative_reward = 0 -# self.done = False -# self.success = False -# self._idx = 0 - -# def __len__(self): -# return self._idx - -# @classmethod -# def from_trajectory(cls, cfg, obses, actions, rewards, dones=None, masks=None): -# """Constructs an episode from a trajectory.""" - -# if cfg.modality in {"pixels", "state"}: -# episode = cls(cfg, obses[0]) -# episode.obses[1:] = torch.tensor(obses[1:], dtype=episode.obses.dtype, device=episode.device) -# elif cfg.modality == "all": -# episode = cls(cfg, {k: v[0] for k, v in obses.items()}) -# for k in obses: -# episode.obses[k][1:] = torch.tensor( -# obses[k][1:], dtype=episode.obses[k].dtype, device=episode.device -# ) -# else: -# raise NotImplementedError -# episode.actions = torch.tensor(actions, dtype=episode.actions.dtype, device=episode.device) -# episode.rewards = torch.tensor(rewards, dtype=episode.rewards.dtype, device=episode.device) -# episode.dones = ( -# torch.tensor(dones, dtype=episode.dones.dtype, device=episode.device) -# if dones is not None -# else torch.zeros_like(episode.dones) -# ) -# episode.masks = ( -# torch.tensor(masks, dtype=episode.masks.dtype, device=episode.device) -# if masks is not None -# else torch.ones_like(episode.masks) -# ) -# episode.cumulative_reward = torch.sum(episode.rewards) -# episode.done = True -# episode._idx = cfg.episode_length -# return episode - -# @property -# def first(self): -# return len(self) == 0 - -# def __add__(self, transition): -# self.add(*transition) -# return self - -# def add(self, obs, action, reward, done, mask=1.0, success=False): -# """Add a transition into the episode.""" -# if isinstance(obs, dict): -# for k, v in obs.items(): -# self.obses[k][self._idx + 1] = torch.tensor( -# v, dtype=self.obses[k].dtype, device=self.obses[k].device -# ) -# else: -# self.obses[self._idx + 1] = torch.tensor(obs, dtype=self.obses.dtype, device=self.obses.device) -# self.actions[self._idx] = action -# self.rewards[self._idx] = reward -# self.dones[self._idx] = done -# self.masks[self._idx] = mask -# self.cumulative_reward += reward -# self.done = done -# self.success = self.success or success -# self._idx += 1 - - -def get_dataset_dict(cfg, env, return_reward_normalizer=False): - """Construct a dataset for env""" - required_keys = [ - "observations", - "next_observations", - "actions", - "rewards", - "dones", - "masks", - ] - - if cfg.task.startswith("xarm"): - dataset_path = os.path.join(cfg.dataset_dir, "buffer.pkl") - print(f"Using offline dataset '{dataset_path}'") - with open(dataset_path, "rb") as f: - dataset_dict = pickle.load(f) - for k in required_keys: - if k not in dataset_dict and k[:-1] in dataset_dict: - dataset_dict[k] = dataset_dict.pop(k[:-1]) - elif cfg.task.startswith("legged"): - dataset_path = os.path.join(cfg.dataset_dir, "buffer.pkl") - print(f"Using offline dataset '{dataset_path}'") - with open(dataset_path, "rb") as f: - dataset_dict = pickle.load(f) - dataset_dict["actions"] /= env.unwrapped.clip_actions - print(f"clip_actions={env.unwrapped.clip_actions}") - else: - import d4rl - - dataset_dict = d4rl.qlearning_dataset(env) - dones = np.full_like(dataset_dict["rewards"], False, dtype=bool) - - for i in range(len(dones) - 1): - if ( - np.linalg.norm(dataset_dict["observations"][i + 1] - dataset_dict["next_observations"][i]) - > 1e-6 - or dataset_dict["terminals"][i] == 1.0 - ): - dones[i] = True - - dones[-1] = True - - dataset_dict["masks"] = 1.0 - dataset_dict["terminals"] - del dataset_dict["terminals"] - - for k, v in dataset_dict.items(): - dataset_dict[k] = v.astype(np.float32) - - dataset_dict["dones"] = dones - - if cfg.is_data_clip: - lim = 1 - cfg.data_clip_eps - dataset_dict["actions"] = np.clip(dataset_dict["actions"], -lim, lim) - reward_normalizer = get_reward_normalizer(cfg, dataset_dict) - dataset_dict["rewards"] = reward_normalizer(dataset_dict["rewards"]) - - for key in required_keys: - assert key in dataset_dict, f"Missing `{key}` in dataset." - - if return_reward_normalizer: - return dataset_dict, reward_normalizer - return dataset_dict - - -def get_trajectory_boundaries_and_returns(dataset): - """ - Split dataset into trajectories and compute returns - """ - episode_starts = [0] - episode_ends = [] - - episode_return = 0 - episode_returns = [] - - n_transitions = len(dataset["rewards"]) - - for i in range(n_transitions): - episode_return += dataset["rewards"][i] - - if dataset["dones"][i]: - episode_returns.append(episode_return) - episode_ends.append(i + 1) - if i + 1 < n_transitions: - episode_starts.append(i + 1) - episode_return = 0.0 - - return episode_starts, episode_ends, episode_returns - - -def normalize_returns(dataset, scaling=1000): - """ - Normalize returns in the dataset - """ - (_, _, episode_returns) = get_trajectory_boundaries_and_returns(dataset) - dataset["rewards"] /= np.max(episode_returns) - np.min(episode_returns) - dataset["rewards"] *= scaling - return dataset - - -def get_reward_normalizer(cfg, dataset): - """ - Get a reward normalizer for the dataset - """ - if cfg.task.startswith("xarm"): - return lambda x: x - elif "maze" in cfg.task: - return lambda x: x - 1.0 - elif cfg.task.split("-")[0] in ["hopper", "halfcheetah", "walker2d"]: - (_, _, episode_returns) = get_trajectory_boundaries_and_returns(dataset) - return lambda x: x / (np.max(episode_returns) - np.min(episode_returns)) * 1000.0 - elif hasattr(cfg, "reward_scale"): - return lambda x: x * cfg.reward_scale - return lambda x: x - - -def linear_schedule(schdl, step): - """ - Outputs values following a linear decay schedule. - Adapted from https://github.com/facebookresearch/drqv2 - """ - try: - return float(schdl) - except ValueError: - match = re.match(r"linear\((.+),(.+),(.+),(.+)\)", schdl) - if match: - init, final, start, end = (float(g) for g in match.groups()) - mix = np.clip((step - start) / (end - start), 0.0, 1.0) - return (1.0 - mix) * init + mix * final - match = re.match(r"linear\((.+),(.+),(.+)\)", schdl) - if match: - init, final, duration = (float(g) for g in match.groups()) - mix = np.clip(step / duration, 0.0, 1.0) - return (1.0 - mix) * init + mix * final - raise NotImplementedError(schdl) diff --git a/lerobot/common/policies/tdmpc/modeling_tdmpc.py b/lerobot/common/policies/tdmpc/modeling_tdmpc.py new file mode 100644 index 00000000..4205b4fc --- /dev/null +++ b/lerobot/common/policies/tdmpc/modeling_tdmpc.py @@ -0,0 +1,798 @@ +"""Implementation of Finetuning Offline World Models in the Real World. + +The comments in this code may sometimes refer to these references: + TD-MPC paper: Temporal Difference Learning for Model Predictive Control (https://arxiv.org/abs/2203.04955) + FOWM paper: Finetuning Offline World Models in the Real World (https://arxiv.org/abs/2310.16029) + +TODO(alexander-soare): Make rollout work for batch sizes larger than 1. +TODO(alexander-soare): Use batch-first throughout. +""" + +# ruff: noqa: N806 + +import logging +from collections import deque +from copy import deepcopy +from functools import partial +from typing import Callable + +import einops +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F # noqa: N812 +from huggingface_hub import PyTorchModelHubMixin +from torch import Tensor + +from lerobot.common.policies.normalize import Normalize, Unnormalize +from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig +from lerobot.common.policies.utils import get_device_from_parameters, populate_queues + + +class TDMPCPolicy(nn.Module, PyTorchModelHubMixin): + """Implementation of TD-MPC learning + inference. + + Please note several warnings for this policy. + - Evaluation of pretrained weights created with the original FOWM code + (https://github.com/fyhMer/fowm) works as expected. To be precise: we trained and evaluated a + model with the FOWM code for the xarm_lift_medium_replay dataset. We ported the weights across + to LeRobot, and were able to evaluate with the same success metric. BUT, we had to use inter- + process communication to use the xarm environment from FOWM. This is because our xarm + environment uses newer dependencies and does not match the environment in FOWM. See + https://github.com/huggingface/lerobot/pull/103 for implementation details. + - We have NOT checked that training on LeRobot reproduces SOTA results. This is a TODO. + - Our current xarm datasets were generated using the environment from FOWM. Therefore they do not + match our xarm environment. + """ + + name = "tdmpc" + + def __init__( + self, config: TDMPCConfig | None = None, dataset_stats: dict[str, dict[str, Tensor]] | None = None + ): + """ + Args: + config: Policy configuration class instance or None, in which case the default instantiation of + the configuration class is used. + dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected + that they will be passed with a call to `load_state_dict` before the policy is used. + """ + super().__init__() + logging.warning( + """ + Please note several warnings for this policy. + + - Evaluation of pretrained weights created with the original FOWM code + (https://github.com/fyhMer/fowm) works as expected. To be precise: we trained and evaluated a + model with the FOWM code for the xarm_lift_medium_replay dataset. We ported the weights across + to LeRobot, and were able to evaluate with the same success metric. BUT, we had to use inter- + process communication to use the xarm environment from FOWM. This is because our xarm + environment uses newer dependencies and does not match the environment in FOWM. See + https://github.com/huggingface/lerobot/pull/103 for implementation details. + - We have NOT checked that training on LeRobot reproduces SOTA results. This is a TODO. + - Our current xarm datasets were generated using the environment from FOWM. Therefore they do not + match our xarm environment. + """ + ) + + if config is None: + config = TDMPCConfig() + self.config = config + self.model = TDMPCTOLD(config) + self.model_target = deepcopy(self.model) + self.model_target.eval() + + if config.input_normalization_modes is not None: + self.normalize_inputs = Normalize( + config.input_shapes, config.input_normalization_modes, dataset_stats + ) + else: + self.normalize_inputs = nn.Identity() + self.normalize_targets = Normalize( + config.output_shapes, config.output_normalization_modes, dataset_stats + ) + self.unnormalize_outputs = Unnormalize( + config.output_shapes, config.output_normalization_modes, dataset_stats + ) + + def save(self, fp): + """Save state dict of TOLD model to filepath.""" + torch.save(self.state_dict(), fp) + + def load(self, fp): + """Load a saved state dict from filepath into current agent.""" + self.load_state_dict(torch.load(fp)) + + def reset(self): + """ + Clear observation and action queues. Clear previous means for warm starting of MPPI/CEM. Should be + called on `env.reset()` + """ + self._queues = { + "observation.image": deque(maxlen=1), + "observation.state": deque(maxlen=1), + "action": deque(maxlen=self.config.n_action_repeats), + } + # Previous mean obtained from the cross-entropy method (CEM) used during MPC. It is used to warm start + # CEM for the next step. + self._prev_mean: torch.Tensor | None = None + + @torch.no_grad() + def select_action(self, batch: dict[str, Tensor]): + """Select a single action given environment observations.""" + assert "observation.image" in batch + assert "observation.state" in batch + assert len(batch) == 2 + + batch = self.normalize_inputs(batch) + + self._queues = populate_queues(self._queues, batch) + + # When the action queue is depleted, populate it again by querying the policy. + if len(self._queues["action"]) == 0: + batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch} + + # Remove the time dimensions as it is not handled yet. + for key in batch: + assert batch[key].shape[1] == 1 + batch[key] = batch[key][:, 0] + + # NOTE: Order of observations matters here. + z = self.model.encode({k: batch[k] for k in ["observation.image", "observation.state"]}) + if self.config.use_mpc: + batch_size = batch["observation.image"].shape[0] + # Batch processing is not handled in MPC mode, so process the batch in a loop. + action = [] # will be a batch of actions for one step + for i in range(batch_size): + # Note: self.plan does not handle batches, hence the squeeze. + action.append(self.plan(z[i])) + action = torch.stack(action) + else: + # Plan with the policy (π) alone. + action = self.model.pi(z) + + self.unnormalize_outputs({"action": action})["action"] + + for _ in range(self.config.n_action_repeats): + self._queues["action"].append(action) + + action = self._queues["action"].popleft() + return torch.clamp(action, -1, 1) + + @torch.no_grad() + def plan(self, z: Tensor) -> Tensor: + """Plan next action using TD-MPC inference. + + Args: + z: (latent_dim,) tensor for the initial state. + Returns: + (action_dim,) tensor for the next action. + + TODO(alexander-soare) Extend this to be able to work with batches. + """ + device = get_device_from_parameters(self) + + # Sample Nπ trajectories from the policy. + pi_actions = torch.empty( + self.config.horizon, + self.config.n_pi_samples, + self.config.output_shapes["action"][0], + device=device, + ) + if self.config.n_pi_samples > 0: + _z = einops.repeat(z, "d -> n d", n=self.config.n_pi_samples) + for t in range(self.config.horizon): + # Note: Adding a small amount of noise here doesn't hurt during inference and may even be + # helpful for CEM. + pi_actions[t] = self.model.pi(_z, self.config.min_std) + _z = self.model.latent_dynamics(_z, pi_actions[t]) + + # In the CEM loop we will need this for a call to estimate_value with the gaussian sampled + # trajectories. + z = einops.repeat(z, "d -> n d", n=self.config.n_gaussian_samples + self.config.n_pi_samples) + + # Model Predictive Path Integral (MPPI) with the cross-entropy method (CEM) as the optimization + # algorithm. + # The initial mean and standard deviation for the cross-entropy method (CEM). + mean = torch.zeros(self.config.horizon, self.config.output_shapes["action"][0], device=device) + # Maybe warm start CEM with the mean from the previous step. + if self._prev_mean is not None: + mean[:-1] = self._prev_mean[1:] + std = self.config.max_std * torch.ones_like(mean) + + for _ in range(self.config.cem_iterations): + # Randomly sample action trajectories for the gaussian distribution. + std_normal_noise = torch.randn( + self.config.horizon, + self.config.n_gaussian_samples, + self.config.output_shapes["action"][0], + device=std.device, + ) + gaussian_actions = torch.clamp(mean.unsqueeze(1) + std.unsqueeze(1) * std_normal_noise, -1, 1) + + # Compute elite actions. + actions = torch.cat([gaussian_actions, pi_actions], dim=1) + value = self.estimate_value(z, actions).nan_to_num_(0) + elite_idxs = torch.topk(value, self.config.n_elites, dim=0).indices + elite_value, elite_actions = value[elite_idxs], actions[:, elite_idxs] + + # Update guassian PDF parameters to be the (weighted) mean and standard deviation of the elites. + max_value = elite_value.max(0)[0] + # The weighting is a softmax over trajectory values. Note that this is not the same as the usage + # of Ω in eqn 4 of the TD-MPC paper. Instead it is the normalized version of it: s = Ω/ΣΩ. This + # makes the equations: μ = Σ(s⋅Γ), σ = Σ(s⋅(Γ-μ)²). + score = torch.exp(self.config.elite_weighting_temperature * (elite_value - max_value)) + score /= score.sum() + _mean = torch.sum(einops.rearrange(score, "n -> n 1") * elite_actions, dim=1) + _std = torch.sqrt( + torch.sum( + einops.rearrange(score, "n -> n 1") + * (elite_actions - einops.rearrange(_mean, "h d -> h 1 d")) ** 2, + dim=1, + ) + ) + # Update mean with an exponential moving average, and std with a direct replacement. + mean = ( + self.config.gaussian_mean_momentum * mean + (1 - self.config.gaussian_mean_momentum) * _mean + ) + std = _std.clamp_(self.config.min_std, self.config.max_std) + + # Keep track of the mean for warm-starting subsequent steps. + self._prev_mean = mean + + # Randomly select one of the elite actions from the last iteration of MPPI/CEM using the softmax + # scores from the last iteration. + actions = elite_actions[:, torch.multinomial(score, 1).item()] + + # Select only the first action + action = actions[0] + return action + + @torch.no_grad() + def estimate_value(self, z: Tensor, actions: Tensor): + """Estimates the value of a trajectory as per eqn 4 of the FOWM paper. + + Args: + z: (batch, latent_dim) tensor of initial latent states. + actions: (horizon, batch, action_dim) tensor of action trajectories. + Returns: + (batch,) tensor of values. + """ + # Initialize return and running discount factor. + G, running_discount = 0, 1 + # Iterate over the actions in the trajectory to simulate the trajectory using the latent dynamics + # model. Keep track of return. + for t in range(actions.shape[0]): + # We will compute the reward in a moment. First compute the uncertainty regularizer from eqn 4 + # of the FOWM paper. + if self.config.uncertainty_regularizer_coeff > 0: + regularization = -( + self.config.uncertainty_regularizer_coeff * self.model.Qs(z, actions[t]).std(0) + ) + else: + regularization = 0 + # Estimate the next state (latent) and reward. + z, reward = self.model.latent_dynamics_and_reward(z, actions[t]) + # Update the return and running discount. + G += running_discount * (reward + regularization) + running_discount *= self.config.discount + # Add the estimated value of the final state (using the minimum for a conservative estimate). + # Do so by predicting the next action, then taking a minimum over the ensemble of state-action value + # estimators. + # Note: This small amount of added noise seems to help a bit at inference time as observed by success + # metrics over 50 episodes of xarm_lift_medium_replay. + next_action = self.model.pi(z, self.config.min_std) # (batch, action_dim) + terminal_values = self.model.Qs(z, next_action) # (ensemble, batch) + # Randomly choose 2 of the Qs for terminal value estimation (as in App C. of the FOWM paper). + if self.config.q_ensemble_size > 2: + G += ( + running_discount + * torch.min(terminal_values[torch.randint(0, self.config.q_ensemble_size, size=(2,))], dim=0)[ + 0 + ] + ) + else: + G += running_discount * torch.min(terminal_values, dim=0)[0] + # Finally, also regularize the terminal value. + if self.config.uncertainty_regularizer_coeff > 0: + G -= running_discount * self.config.uncertainty_regularizer_coeff * terminal_values.std(0) + return G + + def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: + """Run the batch through the model and compute the loss.""" + device = get_device_from_parameters(self) + + batch = self.normalize_inputs(batch) + batch = self.normalize_targets(batch) + + info = {} + + # TODO(alexander-soare): Refactor TDMPC and make it comply with the policy interface documentation. + batch_size = batch["index"].shape[0] + + # (b, t) -> (t, b) + for key in batch: + if batch[key].ndim > 1: + batch[key] = batch[key].transpose(1, 0) + + action = batch["action"] # (t, b) + reward = batch["next.reward"] # (t,) + observations = {k: v for k, v in batch.items() if k.startswith("observation.")} + + # Apply random image augmentations. + if self.config.max_random_shift_ratio > 0: + observations["observation.image"] = flatten_forward_unflatten( + partial(random_shifts_aug, max_random_shift_ratio=self.config.max_random_shift_ratio), + observations["observation.image"], + ) + + # Get the current observation for predicting trajectories, and all future observations for use in + # the latent consistency loss and TD loss. + current_observation, next_observations = {}, {} + for k in observations: + current_observation[k] = observations[k][0] + next_observations[k] = observations[k][1:] + horizon = next_observations["observation.image"].shape[0] + + # Run latent rollout using the latent dynamics model and policy model. + # Note this has shape `horizon+1` because there are `horizon` actions and a current `z`. Each action + # gives us a next `z`. + z_preds = torch.empty(horizon + 1, batch_size, self.config.latent_dim, device=device) + z_preds[0] = self.model.encode(current_observation) + reward_preds = torch.empty_like(reward, device=device) + for t in range(horizon): + z_preds[t + 1], reward_preds[t] = self.model.latent_dynamics_and_reward(z_preds[t], action[t]) + + # Compute Q and V value predictions based on the latent rollout. + q_preds_ensemble = self.model.Qs(z_preds[:-1], action) # (ensemble, horizon, batch) + v_preds = self.model.V(z_preds[:-1]) + info.update({"Q": q_preds_ensemble.mean().item(), "V": v_preds.mean().item()}) + + # Compute various targets with stopgrad. + with torch.no_grad(): + # Latent state consistency targets. + z_targets = self.model_target.encode(next_observations) + # State-action value targets (or TD targets) as in eqn 3 of the FOWM. Unlike TD-MPC which uses the + # learned state-action value function in conjunction with the learned policy: Q(z, π(z)), FOWM + # uses a learned state value function: V(z). This means the TD targets only depend on in-sample + # actions (not actions estimated by π). + # Note: Here we do not use self.model_target, but self.model. This is to follow the original code + # and the FOWM paper. + q_targets = reward + self.config.discount * self.model.V(self.model.encode(next_observations)) + # From eqn 3 of FOWM. These appear as Q(z, a). Here we call them v_targets to emphasize that we + # are using them to compute loss for V. + v_targets = self.model_target.Qs(z_preds[:-1].detach(), action, return_min=True) + + # Compute losses. + # Exponentially decay the loss weight with respect to the timestep. Steps that are more distant in the + # future have less impact on the loss. Note: unsqueeze will let us broadcast to (seq, batch). + temporal_loss_coeffs = torch.pow( + self.config.temporal_decay_coeff, torch.arange(horizon, device=device) + ).unsqueeze(-1) + # Compute consistency loss as MSE loss between latents predicted from the rollout and latents + # predicted from the (target model's) observation encoder. + consistency_loss = ( + ( + temporal_loss_coeffs + * F.mse_loss(z_preds[1:], z_targets, reduction="none").mean(dim=-1) + # `z_preds` depends on the current observation and the actions. + * ~batch["observation.state_is_pad"][0] + * ~batch["action_is_pad"] + # `z_targets` depends on the next observation. + * ~batch["observation.state_is_pad"][1:] + ) + .sum(0) + .mean() + ) + # Compute the reward loss as MSE loss between rewards predicted from the rollout and the dataset + # rewards. + reward_loss = ( + ( + temporal_loss_coeffs + * F.mse_loss(reward_preds, reward, reduction="none") + * ~batch["next.reward_is_pad"] + # `reward_preds` depends on the current observation and the actions. + * ~batch["observation.state_is_pad"][0] + * ~batch["action_is_pad"] + ) + .sum(0) + .mean() + ) + # Compute state-action value loss (TD loss) for all of the Q functions in the ensemble. + q_value_loss = ( + ( + F.mse_loss( + q_preds_ensemble, + einops.repeat(q_targets, "t b -> e t b", e=q_preds_ensemble.shape[0]), + reduction="none", + ).sum(0) # sum over ensemble + # `q_preds_ensemble` depends on the first observation and the actions. + * ~batch["observation.state_is_pad"][0] + * ~batch["action_is_pad"] + # q_targets depends on the reward and the next observations. + * ~batch["next.reward_is_pad"] + * ~batch["observation.state_is_pad"][1:] + ) + .sum(0) + .mean() + ) + # Compute state value loss as in eqn 3 of FOWM. + diff = v_targets - v_preds + # Expectile loss penalizes: + # - `v_preds < v_targets` with weighting `expectile_weight` + # - `v_preds >= v_targets` with weighting `1 - expectile_weight` + raw_v_value_loss = torch.where( + diff > 0, self.config.expectile_weight, (1 - self.config.expectile_weight) + ) * (diff**2) + v_value_loss = ( + ( + temporal_loss_coeffs + * raw_v_value_loss + # `v_targets` depends on the first observation and the actions, as does `v_preds`. + * ~batch["observation.state_is_pad"][0] + * ~batch["action_is_pad"] + ) + .sum(0) + .mean() + ) + + # Calculate the advantage weighted regression loss for π as detailed in FOWM 3.1. + # We won't need these gradients again so detach. + z_preds = z_preds.detach() + # Use stopgrad for the advantage calculation. + with torch.no_grad(): + advantage = self.model_target.Qs(z_preds[:-1], action, return_min=True) - self.model.V( + z_preds[:-1] + ) + info["advantage"] = advantage[0] + # (t, b) + exp_advantage = torch.clamp(torch.exp(advantage * self.config.advantage_scaling), max=100.0) + action_preds = self.model.pi(z_preds[:-1]) # (t, b, a) + # Calculate the MSE between the actions and the action predictions. + # Note: FOWM's original code calculates the log probability (wrt to a unit standard deviation + # gaussian) and sums over the action dimension. Computing the log probability amounts to multiplying + # the MSE by 0.5 and adding a constant offset (the log(2*pi) term) . Here we drop the constant offset + # as it doesn't change the optimization step, and we drop the 0.5 as we instead make a configuration + # parameter for it (see below where we compute the total loss). + mse = F.mse_loss(action_preds, action, reduction="none").sum(-1) # (t, b) + # NOTE: The original implementation does not take the sum over the temporal dimension like with the + # other losses. + # TODO(alexander-soare): Take the sum over the temporal dimension and check that training still works + # as well as expected. + pi_loss = ( + exp_advantage + * mse + * temporal_loss_coeffs + # `action_preds` depends on the first observation and the actions. + * ~batch["observation.state_is_pad"][0] + * ~batch["action_is_pad"] + ).mean() + + loss = ( + self.config.consistency_coeff * consistency_loss + + self.config.reward_coeff * reward_loss + + self.config.value_coeff * q_value_loss + + self.config.value_coeff * v_value_loss + + self.config.pi_coeff * pi_loss + ) + + info.update( + { + "consistency_loss": consistency_loss.item(), + "reward_loss": reward_loss.item(), + "Q_value_loss": q_value_loss.item(), + "V_value_loss": v_value_loss.item(), + "pi_loss": pi_loss.item(), + "loss": loss, + "sum_loss": loss.item() * self.config.horizon, + } + ) + + # Undo (b, t) -> (t, b). + for key in batch: + if batch[key].ndim > 1: + batch[key] = batch[key].transpose(1, 0) + + return info + + def update(self): + """Update the target model's parameters with an EMA step.""" + # Note a minor variation with respect to the original FOWM code. Here they do this based on an EMA + # update frequency parameter which is set to 2 (every 2 steps an update is done). To simplify the code + # we update every step and adjust the decay parameter `alpha` accordingly (0.99 -> 0.995) + update_ema_parameters(self.model_target, self.model, self.config.target_model_momentum) + + +class TDMPCTOLD(nn.Module): + """Task-Oriented Latent Dynamics (TOLD) model used in TD-MPC.""" + + def __init__(self, config: TDMPCConfig): + super().__init__() + self.config = config + self._encoder = TDMPCObservationEncoder(config) + self._dynamics = nn.Sequential( + nn.Linear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim), + nn.LayerNorm(config.mlp_dim), + nn.Mish(), + nn.Linear(config.mlp_dim, config.mlp_dim), + nn.LayerNorm(config.mlp_dim), + nn.Mish(), + nn.Linear(config.mlp_dim, config.latent_dim), + nn.LayerNorm(config.latent_dim), + nn.Sigmoid(), + ) + self._reward = nn.Sequential( + nn.Linear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim), + nn.LayerNorm(config.mlp_dim), + nn.Mish(), + nn.Linear(config.mlp_dim, config.mlp_dim), + nn.LayerNorm(config.mlp_dim), + nn.Mish(), + nn.Linear(config.mlp_dim, 1), + ) + self._pi = nn.Sequential( + nn.Linear(config.latent_dim, config.mlp_dim), + nn.LayerNorm(config.mlp_dim), + nn.Mish(), + nn.Linear(config.mlp_dim, config.mlp_dim), + nn.LayerNorm(config.mlp_dim), + nn.Mish(), + nn.Linear(config.mlp_dim, config.output_shapes["action"][0]), + ) + self._Qs = nn.ModuleList( + [ + nn.Sequential( + nn.Linear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim), + nn.LayerNorm(config.mlp_dim), + nn.Tanh(), + nn.Linear(config.mlp_dim, config.mlp_dim), + nn.ELU(), + nn.Linear(config.mlp_dim, 1), + ) + for _ in range(config.q_ensemble_size) + ] + ) + self._V = nn.Sequential( + nn.Linear(config.latent_dim, config.mlp_dim), + nn.LayerNorm(config.mlp_dim), + nn.Tanh(), + nn.Linear(config.mlp_dim, config.mlp_dim), + nn.ELU(), + nn.Linear(config.mlp_dim, 1), + ) + self._init_weights() + + def _init_weights(self): + """Initialize model weights. + + Orthogonal initialization for all linear and convolutional layers' weights (apart from final layers + of reward network and Q networks which get zero initialization). + Zero initialization for all linear and convolutional layers' biases. + """ + + def _apply_fn(m): + if isinstance(m, nn.Linear): + nn.init.orthogonal_(m.weight.data) + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Conv2d): + gain = nn.init.calculate_gain("relu") + nn.init.orthogonal_(m.weight.data, gain) + if m.bias is not None: + nn.init.zeros_(m.bias) + + self.apply(_apply_fn) + for m in [self._reward, *self._Qs]: + assert isinstance( + m[-1], nn.Linear + ), "Sanity check. The last linear layer needs 0 initialization on weights." + nn.init.zeros_(m[-1].weight) + nn.init.zeros_(m[-1].bias) # this has already been done, but keep this line here for good measure + + def encode(self, obs: dict[str, Tensor]) -> Tensor: + """Encodes an observation into its latent representation.""" + return self._encoder(obs) + + def latent_dynamics_and_reward(self, z: Tensor, a: Tensor) -> tuple[Tensor, Tensor]: + """Predict the next state's latent representation and the reward given a current latent and action. + + Args: + z: (*, latent_dim) tensor for the current state's latent representation. + a: (*, action_dim) tensor for the action to be applied. + Returns: + A tuple containing: + - (*, latent_dim) tensor for the next state's latent representation. + - (*,) tensor for the estimated reward. + """ + x = torch.cat([z, a], dim=-1) + return self._dynamics(x), self._reward(x).squeeze(-1) + + def latent_dynamics(self, z: Tensor, a: Tensor) -> Tensor: + """Predict the next state's latent representation given a current latent and action. + + Args: + z: (*, latent_dim) tensor for the current state's latent representation. + a: (*, action_dim) tensor for the action to be applied. + Returns: + (*, latent_dim) tensor for the next state's latent representation. + """ + x = torch.cat([z, a], dim=-1) + return self._dynamics(x) + + def pi(self, z: Tensor, std: float = 0.0) -> Tensor: + """Samples an action from the learned policy. + + The policy can also have added (truncated) Gaussian noise injected for encouraging exploration when + generating rollouts for online training. + + Args: + z: (*, latent_dim) tensor for the current state's latent representation. + std: The standard deviation of the injected noise. + Returns: + (*, action_dim) tensor for the sampled action. + """ + action = torch.tanh(self._pi(z)) + if std > 0: + std = torch.ones_like(action) * std + action += torch.randn_like(action) * std + return action + + def V(self, z: Tensor) -> Tensor: # noqa: N802 + """Predict state value (V). + + Args: + z: (*, latent_dim) tensor for the current state's latent representation. + Returns: + (*,) tensor of estimated state values. + """ + return self._V(z).squeeze(-1) + + def Qs(self, z: Tensor, a: Tensor, return_min: bool = False) -> Tensor: # noqa: N802 + """Predict state-action value for all of the learned Q functions. + + Args: + z: (*, latent_dim) tensor for the current state's latent representation. + a: (*, action_dim) tensor for the action to be applied. + return_min: Set to true for implementing the detail in App. C of the FOWM paper: randomly select + 2 of the Qs and return the minimum + Returns: + (q_ensemble, *) tensor for the value predictions of each learned Q function in the ensemble OR + (*,) tensor if return_min=True. + """ + x = torch.cat([z, a], dim=-1) + if not return_min: + return torch.stack([q(x).squeeze(-1) for q in self._Qs], dim=0) + else: + if len(self._Qs) > 2: # noqa: SIM108 + Qs = [self._Qs[i] for i in np.random.choice(len(self._Qs), size=2)] + else: + Qs = self._Qs + return torch.stack([q(x).squeeze(-1) for q in Qs], dim=0).min(dim=0)[0] + + +class TDMPCObservationEncoder(nn.Module): + """Encode image and/or state vector observations.""" + + def __init__(self, config: TDMPCConfig): + """ + Creates encoders for pixel and/or state modalities. + TODO(alexander-soare): The original work allows for multiple images by concatenating them along the + channel dimension. Re-implement this capability. + """ + super().__init__() + self.config = config + + if "observation.image" in config.input_shapes: + self.image_enc_layers = nn.Sequential( + nn.Conv2d( + config.input_shapes["observation.image"][0], config.image_encoder_hidden_dim, 7, stride=2 + ), + nn.ReLU(), + nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 5, stride=2), + nn.ReLU(), + nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2), + nn.ReLU(), + nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2), + nn.ReLU(), + ) + dummy_batch = torch.zeros(1, *config.input_shapes["observation.image"]) + with torch.inference_mode(): + out_shape = self.image_enc_layers(dummy_batch).shape[1:] + self.image_enc_layers.extend( + nn.Sequential( + nn.Flatten(), + nn.Linear(np.prod(out_shape), config.latent_dim), + nn.LayerNorm(config.latent_dim), + nn.Sigmoid(), + ) + ) + if "observation.state" in config.input_shapes: + self.state_enc_layers = nn.Sequential( + nn.Linear(config.input_shapes["observation.state"][0], config.state_encoder_hidden_dim), + nn.ELU(), + nn.Linear(config.state_encoder_hidden_dim, config.latent_dim), + nn.LayerNorm(config.latent_dim), + nn.Sigmoid(), + ) + + def forward(self, obs_dict: dict[str, Tensor]) -> Tensor: + """Encode the image and/or state vector. + + Each modality is encoded into a feature vector of size (latent_dim,) and then a uniform mean is taken + over all features. + """ + feat = [] + if "observation.image" in self.config.input_shapes: + feat.append(flatten_forward_unflatten(self.image_enc_layers, obs_dict["observation.image"])) + if "observation.state" in self.config.input_shapes: + feat.append(self.state_enc_layers(obs_dict["observation.state"])) + return torch.stack(feat, dim=0).mean(0) + + +def random_shifts_aug(x: Tensor, max_random_shift_ratio: float) -> Tensor: + """Randomly shifts images horizontally and vertically. + + Adapted from https://github.com/facebookresearch/drqv2 + """ + b, _, h, w = x.size() + assert h == w, "non-square images not handled yet" + pad = int(round(max_random_shift_ratio * h)) + x = F.pad(x, tuple([pad] * 4), "replicate") + eps = 1.0 / (h + 2 * pad) + arange = torch.linspace( + -1.0 + eps, + 1.0 - eps, + h + 2 * pad, + device=x.device, + dtype=torch.float32, + )[:h] + arange = einops.repeat(arange, "w -> h w 1", h=h) + base_grid = torch.cat([arange, arange.transpose(1, 0)], dim=2) + base_grid = einops.repeat(base_grid, "h w c -> b h w c", b=b) + # A random shift in units of pixels and within the boundaries of the padding. + shift = torch.randint( + 0, + 2 * pad + 1, + size=(b, 1, 1, 2), + device=x.device, + dtype=torch.float32, + ) + shift *= 2.0 / (h + 2 * pad) + grid = base_grid + shift + return F.grid_sample(x, grid, padding_mode="zeros", align_corners=False) + + +def update_ema_parameters(ema_net: nn.Module, net: nn.Module, alpha: float): + """Update EMA parameters in place with ema_param <- alpha * ema_param + (1 - alpha) * param.""" + for ema_module, module in zip(ema_net.modules(), net.modules(), strict=True): + for (n_p_ema, p_ema), (n_p, p) in zip( + ema_module.named_parameters(recurse=False), module.named_parameters(recurse=False), strict=True + ): + assert n_p_ema == n_p, "Parameter names don't match for EMA model update" + if isinstance(p, dict): + raise RuntimeError("Dict parameter not supported") + if isinstance(module, nn.modules.batchnorm._BatchNorm) or not p.requires_grad: + # Copy BatchNorm parameters, and non-trainable parameters directly. + p_ema.copy_(p.to(dtype=p_ema.dtype).data) + with torch.no_grad(): + p_ema.mul_(alpha) + p_ema.add_(p.to(dtype=p_ema.dtype).data, alpha=1 - alpha) + + +def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor: + """Helper to temporarily flatten extra dims at the start of the image tensor. + + Args: + fn: Callable that the image tensor will be passed to. It should accept (B, C, H, W) and return + (B, *), where * is any number of dimensions. + image_tensor: An image tensor of shape (**, C, H, W), where ** is any number of dimensions, generally + different from *. + Returns: + A return value from the callable reshaped to (**, *). + """ + if image_tensor.ndim == 4: + return fn(image_tensor) + start_dims = image_tensor.shape[:-3] + inp = torch.flatten(image_tensor, end_dim=-4) + flat_out = fn(inp) + return torch.reshape(flat_out, (*start_dims, *flat_out.shape[1:])) diff --git a/lerobot/common/policies/tdmpc/policy.py b/lerobot/common/policies/tdmpc/policy.py deleted file mode 100644 index adaa30c0..00000000 --- a/lerobot/common/policies/tdmpc/policy.py +++ /dev/null @@ -1,495 +0,0 @@ -# ruff: noqa: N806 - -import time -from collections import deque -from copy import deepcopy - -import einops -import numpy as np -import torch -import torch.nn as nn - -import lerobot.common.policies.tdmpc.helper as h -from lerobot.common.policies.utils import populate_queues -from lerobot.common.utils.utils import get_safe_torch_device - -FIRST_FRAME = 0 - - -class TOLD(nn.Module): - """Task-Oriented Latent Dynamics (TOLD) model used in TD-MPC.""" - - def __init__(self, cfg): - super().__init__() - action_dim = cfg.action_dim - - self.cfg = cfg - self._encoder = h.enc(cfg) - self._dynamics = h.dynamics(cfg.latent_dim + action_dim, cfg.mlp_dim, cfg.latent_dim) - self._reward = h.mlp(cfg.latent_dim + action_dim, cfg.mlp_dim, 1) - self._pi = h.mlp(cfg.latent_dim, cfg.mlp_dim, action_dim) - self._Qs = nn.ModuleList([h.q(cfg) for _ in range(cfg.num_q)]) - self._V = h.v(cfg) - self.apply(h.orthogonal_init) - for m in [self._reward, *self._Qs]: - m[-1].weight.data.fill_(0) - m[-1].bias.data.fill_(0) - - def track_q_grad(self, enable=True): - """Utility function. Enables/disables gradient tracking of Q-networks.""" - for m in self._Qs: - h.set_requires_grad(m, enable) - - def track_v_grad(self, enable=True): - """Utility function. Enables/disables gradient tracking of Q-networks.""" - if hasattr(self, "_V"): - h.set_requires_grad(self._V, enable) - - def encode(self, obs): - """Encodes an observation into its latent representation.""" - out = self._encoder(obs) - if isinstance(obs, dict): - # fusion - out = torch.stack([v for k, v in out.items()]).mean(dim=0) - return out - - def next(self, z, a): - """Predicts next latent state (d) and single-step reward (R).""" - x = torch.cat([z, a], dim=-1) - return self._dynamics(x), self._reward(x) - - def next_dynamics(self, z, a): - """Predicts next latent state (d).""" - x = torch.cat([z, a], dim=-1) - return self._dynamics(x) - - def pi(self, z, std=0): - """Samples an action from the learned policy (pi).""" - mu = torch.tanh(self._pi(z)) - if std > 0: - std = torch.ones_like(mu) * std - return h.TruncatedNormal(mu, std).sample(clip=0.3) - return mu - - def V(self, z): # noqa: N802 - """Predict state value (V).""" - return self._V(z) - - def Q(self, z, a, return_type): # noqa: N802 - """Predict state-action value (Q).""" - assert return_type in {"min", "avg", "all"} - x = torch.cat([z, a], dim=-1) - - if return_type == "all": - return torch.stack([q(x) for q in self._Qs], dim=0) - - idxs = np.random.choice(self.cfg.num_q, 2, replace=False) - Q1, Q2 = self._Qs[idxs[0]](x), self._Qs[idxs[1]](x) - return torch.min(Q1, Q2) if return_type == "min" else (Q1 + Q2) / 2 - - -class TDMPCPolicy(nn.Module): - """Implementation of TD-MPC learning + inference.""" - - name = "tdmpc" - - def __init__(self, cfg, n_obs_steps, n_action_steps, device): - super().__init__() - self.action_dim = cfg.action_dim - - self.cfg = cfg - self.n_obs_steps = n_obs_steps - self.n_action_steps = n_action_steps - self.device = get_safe_torch_device(device) - self.std = h.linear_schedule(cfg.std_schedule, 0) - self.model = TOLD(cfg) - self.model.to(self.device) - self.model_target = deepcopy(self.model) - self.optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr) - self.pi_optim = torch.optim.Adam(self.model._pi.parameters(), lr=self.cfg.lr) - # self.bc_optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr) - self.model.eval() - self.model_target.eval() - - self.register_buffer("step", torch.zeros(1)) - - def state_dict(self): - """Retrieve state dict of TOLD model, including slow-moving target network.""" - return { - "model": self.model.state_dict(), - "model_target": self.model_target.state_dict(), - } - - def save(self, fp): - """Save state dict of TOLD model to filepath.""" - torch.save(self.state_dict(), fp) - - def load(self, fp): - """Load a saved state dict from filepath into current agent.""" - d = torch.load(fp) - self.model.load_state_dict(d["model"]) - self.model_target.load_state_dict(d["model_target"]) - - def reset(self): - """ - Clear observation and action queues. Should be called on `env.reset()` - """ - self._queues = { - "observation.image": deque(maxlen=self.n_obs_steps), - "observation.state": deque(maxlen=self.n_obs_steps), - "action": deque(maxlen=self.n_action_steps), - } - - @torch.no_grad() - def select_action(self, batch, step): - assert "observation.image" in batch - assert "observation.state" in batch - assert len(batch) == 2 - - self._queues = populate_queues(self._queues, batch) - - t0 = step == 0 - - self.eval() - - if len(self._queues["action"]) == 0: - batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch} - - if self.n_obs_steps == 1: - # hack to remove the time dimension - for key in batch: - assert batch[key].shape[1] == 1 - batch[key] = batch[key][:, 0] - - actions = [] - batch_size = batch["observation.image"].shape[0] - for i in range(batch_size): - obs = { - "rgb": batch["observation.image"][[i]], - "state": batch["observation.state"][[i]], - } - # Note: unsqueeze needed because `act` still uses non-batch logic. - action = self.act(obs, t0=t0, step=self.step) - actions.append(action) - action = torch.stack(actions) - - # tdmpc returns an action for 1 timestep only, so we copy it over `n_action_steps` time - if i in range(self.n_action_steps): - self._queues["action"].append(action) - - action = self._queues["action"].popleft() - return action - - @torch.no_grad() - def act(self, obs, t0=False, step=None): - """Take an action. Uses either MPC or the learned policy, depending on the self.cfg.mpc flag.""" - obs = {k: o.detach() for k, o in obs.items()} if isinstance(obs, dict) else obs.detach() - z = self.model.encode(obs) - if self.cfg.mpc: - a = self.plan(z, t0=t0, step=step) - else: - a = self.model.pi(z, self.cfg.min_std * self.model.training).squeeze(0) - return a - - @torch.no_grad() - def estimate_value(self, z, actions, horizon): - """Estimate value of a trajectory starting at latent state z and executing given actions.""" - G, discount = 0, 1 - for t in range(horizon): - if self.cfg.uncertainty_cost > 0: - G -= ( - discount - * self.cfg.uncertainty_cost - * self.model.Q(z, actions[t], return_type="all").std(dim=0) - ) - z, reward = self.model.next(z, actions[t]) - G += discount * reward - discount *= self.cfg.discount - pi = self.model.pi(z, self.cfg.min_std) - G += discount * self.model.Q(z, pi, return_type="min") - if self.cfg.uncertainty_cost > 0: - G -= discount * self.cfg.uncertainty_cost * self.model.Q(z, pi, return_type="all").std(dim=0) - return G - - @torch.no_grad() - def plan(self, z, step=None, t0=True): - """ - Plan next action using TD-MPC inference. - z: latent state. - step: current time step. determines e.g. planning horizon. - t0: whether current step is the first step of an episode. - """ - # during eval: eval_mode: uniform sampling and action noise is disabled during evaluation. - - assert step is not None - # Seed steps - if step < self.cfg.seed_steps and self.model.training: - return torch.empty(self.action_dim, dtype=torch.float32, device=self.device).uniform_(-1, 1) - - # Sample policy trajectories - horizon = int(min(self.cfg.horizon, h.linear_schedule(self.cfg.horizon_schedule, step))) - num_pi_trajs = int(self.cfg.mixture_coef * self.cfg.num_samples) - if num_pi_trajs > 0: - pi_actions = torch.empty(horizon, num_pi_trajs, self.action_dim, device=self.device) - _z = z.repeat(num_pi_trajs, 1) - for t in range(horizon): - pi_actions[t] = self.model.pi(_z, self.cfg.min_std) - _z = self.model.next_dynamics(_z, pi_actions[t]) - - # Initialize state and parameters - z = z.repeat(self.cfg.num_samples + num_pi_trajs, 1) - mean = torch.zeros(horizon, self.action_dim, device=self.device) - std = self.cfg.max_std * torch.ones(horizon, self.action_dim, device=self.device) - if not t0 and hasattr(self, "_prev_mean"): - mean[:-1] = self._prev_mean[1:] - - # Iterate CEM - for _ in range(self.cfg.iterations): - actions = torch.clamp( - mean.unsqueeze(1) - + std.unsqueeze(1) - * torch.randn(horizon, self.cfg.num_samples, self.action_dim, device=std.device), - -1, - 1, - ) - if num_pi_trajs > 0: - actions = torch.cat([actions, pi_actions], dim=1) - - # Compute elite actions - value = self.estimate_value(z, actions, horizon).nan_to_num_(0) - elite_idxs = torch.topk(value.squeeze(1), self.cfg.num_elites, dim=0).indices - elite_value, elite_actions = value[elite_idxs], actions[:, elite_idxs] - - # Update parameters - max_value = elite_value.max(0)[0] - score = torch.exp(self.cfg.temperature * (elite_value - max_value)) - score /= score.sum(0) - _mean = torch.sum(score.unsqueeze(0) * elite_actions, dim=1) / (score.sum(0) + 1e-9) - _std = torch.sqrt( - torch.sum( - score.unsqueeze(0) * (elite_actions - _mean.unsqueeze(1)) ** 2, - dim=1, - ) - / (score.sum(0) + 1e-9) - ) - _std = _std.clamp_(self.std, self.cfg.max_std) - mean, std = self.cfg.momentum * mean + (1 - self.cfg.momentum) * _mean, _std - - # Outputs - # TODO(rcadene): remove numpy with - # # Convert score tensor to probabilities using softmax - # probabilities = torch.softmax(score, dim=0) - # # Generate a random sample index based on the probabilities - # sample_index = torch.multinomial(probabilities, 1).item() - score = score.squeeze(1).cpu().numpy() - actions = elite_actions[:, np.random.choice(np.arange(score.shape[0]), p=score)] - self._prev_mean = mean - mean, std = actions[0], _std[0] - a = mean - if self.model.training: - a += std * torch.randn(self.action_dim, device=std.device) - return torch.clamp(a, -1, 1) - - def update_pi(self, zs, acts=None): - """Update policy using a sequence of latent states.""" - self.pi_optim.zero_grad(set_to_none=True) - self.model.track_q_grad(False) - self.model.track_v_grad(False) - - info = {} - # Advantage Weighted Regression - assert acts is not None - vs = self.model.V(zs) - qs = self.model_target.Q(zs, acts, return_type="min") - adv = qs - vs - exp_a = torch.exp(adv * self.cfg.A_scaling) - exp_a = torch.clamp(exp_a, max=100.0) - log_probs = h.gaussian_logprob(self.model.pi(zs) - acts, 0) - rho = torch.pow(self.cfg.rho, torch.arange(len(qs), device=self.device)) - pi_loss = -((exp_a * log_probs).mean(dim=(1, 2)) * rho).mean() - info["adv"] = adv[0] - - pi_loss.backward() - torch.nn.utils.clip_grad_norm_( - self.model._pi.parameters(), - self.cfg.grad_clip_norm, - error_if_nonfinite=False, - ) - self.pi_optim.step() - self.model.track_q_grad(True) - self.model.track_v_grad(True) - - info["pi_loss"] = pi_loss.item() - return pi_loss.item(), info - - @torch.no_grad() - def _td_target(self, next_z, reward, mask): - """Compute the TD-target from a reward and the observation at the following time step.""" - next_v = self.model.V(next_z) - td_target = reward + self.cfg.discount * mask * next_v.squeeze(2) - return td_target - - def forward(self, batch, step): - # TODO(alexander-soare): Refactor TDMPC and make it comply with the policy interface documentation. - raise NotImplementedError() - - def update(self, batch, step): - """Main update function. Corresponds to one iteration of the model learning.""" - start_time = time.time() - - batch_size = batch["index"].shape[0] - - # TODO(rcadene): convert tdmpc with (batch size, time/horizon, channels) - # instead of currently (time/horizon, batch size, channels) which is not the pytorch convention - # batch size b = 256, time/horizon t = 5 - # b t ... -> t b ... - for key in batch: - if batch[key].ndim > 1: - batch[key] = batch[key].transpose(1, 0) - - action = batch["action"] - reward = batch["next.reward"] - # idxs = batch["index"] # TODO(rcadene): use idxs to update sampling weights - done = torch.zeros_like(reward, dtype=torch.bool, device=reward.device) - mask = torch.ones_like(reward, dtype=torch.bool, device=reward.device) - weights = torch.ones(batch_size, dtype=torch.bool, device=reward.device) - - obses = { - "rgb": batch["observation.image"], - "state": batch["observation.state"], - } - - shapes = {} - for k in obses: - shapes[k] = obses[k].shape - obses[k] = einops.rearrange(obses[k], "t b ... -> (t b) ... ") - - # Apply augmentations - aug_tf = h.aug(self.cfg) - obses = aug_tf(obses) - - for k in obses: - t, b = shapes[k][:2] - obses[k] = einops.rearrange(obses[k], "(t b) ... -> t b ... ", b=b, t=t) - - obs, next_obses = {}, {} - for k in obses: - obs[k] = obses[k][0] - next_obses[k] = obses[k][1:].clone() - - horizon = next_obses["rgb"].shape[0] - loss_mask = torch.ones_like(mask, device=self.device) - for t in range(1, horizon): - loss_mask[t] = loss_mask[t - 1] * (~done[t - 1]) - - self.optim.zero_grad(set_to_none=True) - self.std = h.linear_schedule(self.cfg.std_schedule, step) - self.model.train() - - data_s = time.time() - start_time - - # Compute targets - with torch.no_grad(): - next_z = self.model.encode(next_obses) - z_targets = self.model_target.encode(next_obses) - td_targets = self._td_target(next_z, reward, mask) - - # Latent rollout - zs = torch.empty(horizon + 1, batch_size, self.cfg.latent_dim, device=self.device) - reward_preds = torch.empty_like(reward, device=self.device) - assert reward.shape[0] == horizon - z = self.model.encode(obs) - zs[0] = z - value_info = {"Q": 0.0, "V": 0.0} - for t in range(horizon): - z, reward_pred = self.model.next(z, action[t]) - zs[t + 1] = z - reward_preds[t] = reward_pred.squeeze(1) - - with torch.no_grad(): - v_target = self.model_target.Q(zs[:-1].detach(), action, return_type="min") - - # Predictions - qs = self.model.Q(zs[:-1], action, return_type="all") - qs = qs.squeeze(3) - value_info["Q"] = qs.mean().item() - v = self.model.V(zs[:-1]) - value_info["V"] = v.mean().item() - - # Losses - rho = torch.pow(self.cfg.rho, torch.arange(horizon, device=self.device)).view(-1, 1) - consistency_loss = (rho * torch.mean(h.mse(zs[1:], z_targets), dim=2) * loss_mask).sum(dim=0) - reward_loss = (rho * h.mse(reward_preds, reward) * loss_mask).sum(dim=0) - q_value_loss, priority_loss = 0, 0 - for q in range(self.cfg.num_q): - q_value_loss += (rho * h.mse(qs[q], td_targets) * loss_mask).sum(dim=0) - priority_loss += (rho * h.l1(qs[q], td_targets) * loss_mask).sum(dim=0) - - expectile = h.linear_schedule(self.cfg.expectile, step) - v_value_loss = (rho * h.l2_expectile(v_target - v, expectile=expectile).squeeze(2) * loss_mask).sum( - dim=0 - ) - - total_loss = ( - self.cfg.consistency_coef * consistency_loss - + self.cfg.reward_coef * reward_loss - + self.cfg.value_coef * q_value_loss - + self.cfg.value_coef * v_value_loss - ) - - weighted_loss = (total_loss * weights).mean() - weighted_loss.register_hook(lambda grad: grad * (1 / self.cfg.horizon)) - has_nan = torch.isnan(weighted_loss).item() - if has_nan: - print(f"weighted_loss has nan: {total_loss=} {weights=}") - else: - weighted_loss.backward() - - grad_norm = torch.nn.utils.clip_grad_norm_( - self.model.parameters(), self.cfg.grad_clip_norm, error_if_nonfinite=False - ) - self.optim.step() - - # TODO(rcadene): implement PrioritizedSampling by modifying sampler.weights with priorities computed by a criterion - # if self.cfg.per: - # # Update priorities - # priorities = priority_loss.clamp(max=1e4).detach() - # has_nan = torch.isnan(priorities).any().item() - # if has_nan: - # print(f"priorities has nan: {priorities=}") - # else: - # replay_buffer.update_priority( - # idxs[:num_slices], - # priorities[:num_slices], - # ) - # if demo_batch_size > 0: - # demo_buffer.update_priority(demo_idxs, priorities[num_slices:]) - - # Update policy + target network - _, pi_update_info = self.update_pi(zs[:-1].detach(), acts=action) - - if step % self.cfg.update_freq == 0: - h.ema(self.model._encoder, self.model_target._encoder, self.cfg.tau) - h.ema(self.model._Qs, self.model_target._Qs, self.cfg.tau) - - self.model.eval() - - info = { - "consistency_loss": float(consistency_loss.mean().item()), - "reward_loss": float(reward_loss.mean().item()), - "Q_value_loss": float(q_value_loss.mean().item()), - "V_value_loss": float(v_value_loss.mean().item()), - "sum_loss": float(total_loss.mean().item()), - "loss": float(weighted_loss.mean().item()), - "grad_norm": float(grad_norm), - "lr": self.cfg.lr, - "data_s": data_s, - "update_s": time.time() - start_time, - } - # info["demo_batch_size"] = demo_batch_size - info["expectile"] = expectile - info.update(value_info) - info.update(pi_update_info) - - self.step[0] = step - return info diff --git a/lerobot/configs/default.yaml b/lerobot/configs/default.yaml index 21370e4b..b3b85c0c 100644 --- a/lerobot/configs/default.yaml +++ b/lerobot/configs/default.yaml @@ -9,31 +9,24 @@ hydra: job: name: default -seed: 1337 -# batch size for TorchRL SerialEnv. Each underlying env will get the seed = seed + env_index -# NOTE: only diffusion policy supports rollout_batch_size > 1 -rollout_batch_size: 1 device: cuda # cpu -prefetch: 4 -eval_freq: ??? -save_freq: ??? -eval_episodes: ??? -save_video: false -save_model: false -save_buffer: false -train_steps: ??? -fps: ??? +seed: ??? +dataset_repo_id: lerobot/pusht -offline_prioritized_sampler: true +training: + offline_steps: ??? + online_steps: ??? + online_steps_between_rollouts: ??? + online_sampling_ratio: 0.5 + eval_freq: ??? + save_freq: ??? + log_freq: 250 + save_model: false -dataset: - repo_id: ??? - -n_action_steps: ??? -n_obs_steps: ??? -env: ??? - -policy: ??? +eval: + n_episodes: 1 + # TODO(alexander-soare): Right now this does not work. Reinstate this. + batch_size: 1 wandb: enable: true diff --git a/lerobot/configs/env/aloha.yaml b/lerobot/configs/env/aloha.yaml index 41d44db8..95e4503d 100644 --- a/lerobot/configs/env/aloha.yaml +++ b/lerobot/configs/env/aloha.yaml @@ -1,18 +1,7 @@ # @package _global_ -eval_episodes: 50 -eval_freq: 7500 -save_freq: 75000 -log_freq: 250 -# TODO: same as xarm, need to adjust -offline_steps: 25000 -online_steps: 25000 - fps: 50 -dataset: - repo_id: lerobot/aloha_sim_insertion_human - env: name: aloha task: AlohaInsertion-v0 diff --git a/lerobot/configs/env/pusht.yaml b/lerobot/configs/env/pusht.yaml index 29c2a258..43e9d187 100644 --- a/lerobot/configs/env/pusht.yaml +++ b/lerobot/configs/env/pusht.yaml @@ -1,18 +1,7 @@ # @package _global_ -eval_episodes: 50 -eval_freq: 7500 -save_freq: 75000 -log_freq: 250 -# TODO: same as xarm, need to adjust -offline_steps: 25000 -online_steps: 25000 - fps: 10 -dataset: - repo_id: lerobot/pusht - env: name: pusht task: PushT-v0 diff --git a/lerobot/configs/env/xarm.yaml b/lerobot/configs/env/xarm.yaml index 00b8e2d5..098b0396 100644 --- a/lerobot/configs/env/xarm.yaml +++ b/lerobot/configs/env/xarm.yaml @@ -1,17 +1,7 @@ # @package _global_ -eval_episodes: 20 -eval_freq: 1000 -save_freq: 10000 -log_freq: 50 -offline_steps: 25000 -online_steps: 25000 - fps: 15 -dataset: - repo_id: lerobot/xarm_lift_medium - env: name: xarm task: XarmLift-v0 diff --git a/lerobot/configs/policy/act.yaml b/lerobot/configs/policy/act.yaml index c67793e4..c0f47c44 100644 --- a/lerobot/configs/policy/act.yaml +++ b/lerobot/configs/policy/act.yaml @@ -1,30 +1,41 @@ # @package _global_ -offline_steps: 80000 -online_steps: 0 +seed: 1000 +dataset_repo_id: lerobot/aloha_sim_insertion_human -eval_episodes: 1 -eval_freq: 10000 -save_freq: 100000 -log_freq: 250 +training: + offline_steps: 80000 + online_steps: 0 + eval_freq: 10000 + save_freq: 100000 + log_freq: 250 + save_model: true -n_obs_steps: 1 -# when temporal_agg=False, n_action_steps=horizon + batch_size: 8 + lr: 1e-5 + lr_backbone: 1e-5 + weight_decay: 1e-4 + grad_clip_norm: 10 + online_steps_between_rollouts: 1 -override_dataset_stats: - observation.images.top: - # stats from imagenet, since we use a pretrained vision model - mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1) - std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1) + override_dataset_stats: + observation.images.top: + # stats from imagenet, since we use a pretrained vision model + mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1) + std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1) + + delta_timestamps: + action: "[i / ${fps} for i in range(${policy.chunk_size})]" + +eval: + n_episodes:: 50 # See `configuration_act.py` for more details. policy: name: act - pretrained_model_path: - # Input / output structure. - n_obs_steps: ${n_obs_steps} + n_obs_steps: 1 chunk_size: 100 # chunk_size n_action_steps: 100 @@ -49,7 +60,7 @@ policy: replace_final_stride_with_dilation: false # Transformer layers. pre_norm: false - d_model: 512 + dim_model: 512 n_heads: 8 dim_feedforward: 3200 feedforward_activation: relu @@ -66,15 +77,3 @@ policy: # Training and loss computation. dropout: 0.1 kl_weight: 10.0 - - # --- - # TODO(alexander-soare): Remove these from the policy config. - batch_size: 8 - lr: 1e-5 - lr_backbone: 1e-5 - weight_decay: 1e-4 - grad_clip_norm: 10 - utd: 1 - - delta_timestamps: - action: "[i / ${fps} for i in range(${policy.chunk_size})]" diff --git a/lerobot/configs/policy/diffusion.yaml b/lerobot/configs/policy/diffusion.yaml index f96e21c2..f1b05185 100644 --- a/lerobot/configs/policy/diffusion.yaml +++ b/lerobot/configs/policy/diffusion.yaml @@ -1,22 +1,33 @@ # @package _global_ seed: 100000 -horizon: 16 -n_obs_steps: 2 -n_action_steps: 8 -dataset_obs_steps: ${n_obs_steps} -past_action_visible: False -keypoint_visible_rate: 1.0 +dataset_repo_id: lerobot/pusht -eval_episodes: 50 -eval_freq: 5000 -save_freq: 5000 -log_freq: 250 +training: + offline_steps: 200000 + online_steps: 0 + eval_freq: 5000 + save_freq: 5000 + log_freq: 250 + save_model: true -offline_steps: 200000 -online_steps: 0 + batch_size: 64 + grad_clip_norm: 10 + lr: 1.0e-4 + lr_scheduler: cosine + lr_warmup_steps: 500 + adam_betas: [0.95, 0.999] + adam_eps: 1.0e-8 + adam_weight_decay: 1.0e-6 + online_steps_between_rollouts: 1 -offline_prioritized_sampler: true + delta_timestamps: + observation.image: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]" + observation.state: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]" + action: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1 - ${policy.n_obs_steps} + ${policy.horizon})]" + +eval: + n_episodes: 50 override_dataset_stats: # TODO(rcadene, alexander-soare): should we remove image stats as well? do we use a pretrained vision model? @@ -35,12 +46,10 @@ override_dataset_stats: policy: name: diffusion - pretrained_model_path: - # Input / output structure. - n_obs_steps: ${n_obs_steps} - horizon: ${horizon} - n_action_steps: ${n_action_steps} + n_obs_steps: 2 + horizon: 16 + n_action_steps: 8 input_shapes: # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env? @@ -84,23 +93,9 @@ policy: # --- # TODO(alexander-soare): Remove these from the policy config. - batch_size: 64 - grad_clip_norm: 10 - lr: 1.0e-4 - lr_scheduler: cosine - lr_warmup_steps: 500 - adam_betas: [0.95, 0.999] - adam_eps: 1.0e-8 - adam_weight_decay: 1.0e-6 - utd: 1 use_ema: true ema_update_after_step: 0 ema_min_alpha: 0.0 ema_max_alpha: 0.9999 ema_inv_gamma: 1.0 ema_power: 0.75 - - delta_timestamps: - observation.image: "[i / ${fps} for i in range(1 - ${n_obs_steps}, 1)]" - observation.state: "[i / ${fps} for i in range(1 - ${n_obs_steps}, 1)]" - action: "[i / ${fps} for i in range(1 - ${n_obs_steps}, 1 - ${n_obs_steps} + ${policy.horizon})]" diff --git a/lerobot/configs/policy/tdmpc.yaml b/lerobot/configs/policy/tdmpc.yaml index c78a5d73..6387882c 100644 --- a/lerobot/configs/policy/tdmpc.yaml +++ b/lerobot/configs/policy/tdmpc.yaml @@ -1,85 +1,76 @@ # @package _global_ -n_action_steps: 2 -n_obs_steps: 1 +seed: 1 + +training: + offline_steps: 25000 + online_steps: 25000 + eval_freq: 5000 + online_steps_between_rollouts: 1 + online_sampling_ratio: 0.5 + + batch_size: 256 + grad_clip_norm: 10.0 + lr: 3e-4 + + delta_timestamps: + observation.image: "[i / ${fps} for i in range(${policy.horizon} + 1)]" + observation.state: "[i / ${fps} for i in range(${policy.horizon} + 1)]" + action: "[i / ${fps} for i in range(${policy.horizon})]" + next.reward: "[i / ${fps} for i in range(${policy.horizon})]" policy: name: tdmpc - reward_scale: 1.0 + pretrained_model_path: - episode_length: ${env.episode_length} - discount: 0.9 - modality: 'all' - - # pixels - frame_stack: 1 - num_channels: 32 - img_size: ${env.image_size} - state_dim: ${env.action_dim} - action_dim: ${env.action_dim} - - # planning - mpc: true - iterations: 6 - num_samples: 512 - num_elites: 50 - mixture_coef: 0.1 - min_std: 0.05 - max_std: 2.0 - temperature: 0.5 - momentum: 0.1 - uncertainty_cost: 1 - - # actor - log_std_min: -10 - log_std_max: 2 - - # learning - batch_size: 256 - max_buffer_size: 10000 + # Input / output structure. + n_action_repeats: 2 horizon: 5 - reward_coef: 0.5 - value_coef: 0.1 - consistency_coef: 20 - rho: 0.5 - kappa: 0.1 - lr: 3e-4 - std_schedule: ${policy.min_std} - horizon_schedule: ${policy.horizon} - per: true - per_alpha: 0.6 - per_beta: 0.4 - grad_clip_norm: 10 - seed_steps: 0 - update_freq: 2 - tau: 0.01 - utd: 1 - # offline rl - # dataset_dir: ??? - data_first_percent: 1.0 - is_data_clip: true - data_clip_eps: 1e-5 - expectile: 0.9 - A_scaling: 3.0 + input_shapes: + # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env? + observation.image: [3, 84, 84] + observation.state: ["${env.state_dim}"] + output_shapes: + action: ["${env.action_dim}"] - # offline->online - offline_steps: ${offline_steps} - pretrained_model_path: "" - # pretrained_model_path: "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/offline.pt" - # pretrained_model_path: "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/final.pt" - balanced_sampling: true - demo_schedule: 0.5 + # Normalization / Unnormalization + input_normalization_modes: null + output_normalization_modes: + action: min_max - # architecture - enc_dim: 256 - num_q: 5 - mlp_dim: 512 + # Architecture / modeling. + # Neural networks. + image_encoder_hidden_dim: 32 + state_encoder_hidden_dim: 256 latent_dim: 50 + q_ensemble_size: 5 + mlp_dim: 512 + # Reinforcement learning. + discount: 0.9 - delta_timestamps: - observation.image: "[i / ${fps} for i in range(6)]" - observation.state: "[i / ${fps} for i in range(6)]" - action: "[i / ${fps} for i in range(5)]" - next.reward: "[i / ${fps} for i in range(5)]" + # Inference. + use_mpc: false + cem_iterations: 6 + max_std: 2.0 + min_std: 0.05 + n_gaussian_samples: 512 + n_pi_samples: 51 + uncertainty_regularizer_coeff: 1.0 + n_elites: 50 + elite_weighting_temperature: 0.5 + gaussian_mean_momentum: 0.1 + + # Training and loss computation. + max_random_shift_ratio: 0.0476 + # Loss coefficients. + reward_coeff: 0.5 + expectile_weight: 0.9 + value_coeff: 0.1 + consistency_coeff: 20.0 + advantage_scaling: 3.0 + pi_coeff: 0.5 + temporal_decay_coeff: 0.5 + # Target model. + target_model_momentum: 0.995 diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index 0c10b7a5..6c9e28bf 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -1,30 +1,29 @@ """Evaluate a policy on an environment by running rollouts and computing metrics. -The script may be run in one of two ways: +Usage examples: -1. By providing the path to a config file with the --config argument. -2. By providing a HuggingFace Hub ID with the --hub-id argument. You may also provide a revision number with the - --revision argument. +You want to evaluate a model from the hub (eg: https://huggingface.co/lerobot/diffusion_policy_pusht_image) +for 10 episodes. -In either case, it is possible to override config arguments by adding a list of config.key=value arguments. +``` +python lerobot/scripts/eval.py -p lerobot/diffusion_policy_pusht_image eval.n_episodes=10 +``` -Examples: - -You have a specific config file to go with trained model weights, and want to run 10 episodes. +OR, you want to evaluate a model checkpoint from the LeRobot training script for 10 episodes. ``` python lerobot/scripts/eval.py \ ---config PATH/TO/FOLDER/config.yaml \ -policy.pretrained_model_path=PATH/TO/FOLDER/weights.pth \ -eval_episodes=10 + -p outputs/train/diffusion_policy_pusht_image/checkpoints/005000 \ + eval.n_episodes=10 ``` -You have a HuggingFace Hub ID, you know which revision you want, and want to run 10 episodes (note that in this case, -you don't need to specify which weights to use): +Note that in both examples, the repo/folder should contain at least `config.json`, `config.yaml` and +`model.safetensors`. -``` -python lerobot/scripts/eval.py --hub-id HUB/ID --revision v1.0 eval_episodes=10 -``` +Note the formatting for providing the number of episodes. Generally, you may provide any number of arguments +with `qualified.parameter.name=value`. In this case, the parameter eval.n_episodes appears as `n_episodes` +nested under `eval` in the `config.yaml` found at +https://huggingface.co/lerobot/diffusion_policy_pusht_image/tree/main. """ import argparse @@ -42,9 +41,12 @@ import numpy as np import torch from datasets import Dataset, Features, Image, Sequence, Value from huggingface_hub import snapshot_download +from huggingface_hub.utils._errors import RepositoryNotFoundError +from huggingface_hub.utils._validators import HFValidationError from PIL import Image as PILImage from tqdm import trange +from lerobot.common.datasets.factory import make_dataset from lerobot.common.datasets.utils import hf_transform_to_torch from lerobot.common.envs.factory import make_env from lerobot.common.envs.utils import postprocess_action, preprocess_observation @@ -65,10 +67,10 @@ def eval_policy( """ set `return_episode_data` to return a Hugging Face dataset object in an "episodes" key of the return dict. """ + policy.eval() + fps = env.unwrapped.metadata["render_fps"] - if policy is not None: - policy.eval() device = "cpu" if policy is None else next(policy.parameters()).device start = time.time() @@ -130,7 +132,7 @@ def eval_policy( # get the next action for the environment with torch.inference_mode(): - action = policy.select_action(observation, step=step) + action = policy.select_action(observation) # convert to cpu numpy action = postprocess_action(action) @@ -349,26 +351,42 @@ def eval_policy( return info -def eval(cfg: dict, out_dir=None): +def eval( + pretrained_policy_path: str | None = None, + hydra_cfg_path: str | None = None, + config_overrides: list[str] | None = None, +): + assert (pretrained_policy_path is None) ^ (hydra_cfg_path is None) + if hydra_cfg_path is None: + hydra_cfg = init_hydra_config(pretrained_policy_path / "config.yaml", config_overrides) + else: + hydra_cfg = init_hydra_config(hydra_cfg_path, config_overrides) + out_dir = ( + f"outputs/eval/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{hydra_cfg.env.name}_{hydra_cfg.policy.name}" + ) + if out_dir is None: raise NotImplementedError() - init_logging() - # Check device is available - get_safe_torch_device(cfg.device, log=True) + get_safe_torch_device(hydra_cfg.device, log=True) torch.backends.cudnn.benchmark = True torch.backends.cuda.matmul.allow_tf32 = True - set_global_seed(cfg.seed) + set_global_seed(hydra_cfg.seed) log_output_dir(out_dir) logging.info("Making environment.") - env = make_env(cfg, num_parallel_envs=cfg.eval_episodes) + env = make_env(hydra_cfg, num_parallel_envs=hydra_cfg.eval.n_episodes) logging.info("Making policy.") - policy = make_policy(cfg) + if hydra_cfg_path is None: + policy = make_policy(hydra_cfg=hydra_cfg, pretrained_policy_name_or_path=pretrained_policy_path) + else: + # Note: We need the dataset stats to pass to the policy's normalization modules. + policy = make_policy(hydra_cfg=hydra_cfg, dataset_stats=make_dataset(hydra_cfg).stats) + policy.eval() info = eval_policy( env, @@ -376,7 +394,7 @@ def eval(cfg: dict, out_dir=None): max_episodes_rendered=10, video_dir=Path(out_dir) / "eval", return_episode_data=False, - seed=cfg.seed, + seed=hydra_cfg.seed, ) print(info["aggregated"]) @@ -390,13 +408,29 @@ def eval(cfg: dict, out_dir=None): if __name__ == "__main__": + init_logging() + parser = argparse.ArgumentParser( description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter ) group = parser.add_mutually_exclusive_group(required=True) - group.add_argument("--config", help="Path to a specific yaml config you want to use.") - group.add_argument("--hub-id", help="HuggingFace Hub ID for a pretrained model.") - parser.add_argument("--revision", help="Optionally provide the HuggingFace Hub revision ID.") + group.add_argument( + "-p", + "--pretrained-policy-name-or-path", + help=( + "Either the repo ID of a model hosted on the Hub or a path to a directory containing weights " + "saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch " + "(useful for debugging). This argument is mutually exclusive with `--config`." + ), + ) + group.add_argument( + "--config", + help=( + "Path to a yaml config you want to use for initializing a policy from scratch (useful for " + "debugging). This argument is mutually exclusive with `--pretrained-policy-name-or-path` (`-p`)." + ), + ) + parser.add_argument("--revision", help="Optionally provide the Hugging Face Hub revision ID.") parser.add_argument( "overrides", nargs="*", @@ -404,16 +438,28 @@ if __name__ == "__main__": ) args = parser.parse_args() - if args.config is not None: - # Note: For the config_path, Hydra wants a path relative to this script file. - cfg = init_hydra_config(args.config, args.overrides) - elif args.hub_id is not None: - folder = Path(snapshot_download(args.hub_id, revision=args.revision)) - cfg = init_hydra_config( - folder / "config.yaml", [f"policy.pretrained_model_path={folder / 'model.pt'}", *args.overrides] - ) + if args.pretrained_policy_name_or_path is None: + eval(hydra_cfg_path=args.config, config_overrides=args.overrides) + else: + try: + pretrained_policy_path = Path( + snapshot_download(args.pretrained_policy_name_or_path, revision=args.revision) + ) + except HFValidationError: + logging.warning( + "The provided pretrained_policy_name_or_path is not a valid Hugging Face Hub repo ID. " + "Treating it as a local directory." + ) + except RepositoryNotFoundError: + logging.warning( + "The provided pretrained_policy_name_or_path was not found on the Hugging Face Hub. Treating " + "it as a local directory." + ) + pretrained_policy_path = Path(args.pretrained_policy_name_or_path) + if not pretrained_policy_path.is_dir() or not pretrained_policy_path.exists(): + raise ValueError( + "The provided pretrained_policy_name_or_path is not a valid/existing Hugging Face Hub " + "repo ID, nor is it an existing local directory." + ) - eval( - cfg, - out_dir=f"outputs/eval/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{cfg.env.name}_{cfg.policy.name}", - ) + eval(pretrained_policy_path=pretrained_policy_path, config_overrides=args.overrides) diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index b2c73504..bd27b28a 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -15,6 +15,7 @@ from lerobot.common.datasets.utils import cycle from lerobot.common.envs.factory import make_env from lerobot.common.logger import Logger, log_output_dir from lerobot.common.policies.factory import make_policy +from lerobot.common.policies.policy_protocol import PolicyWithUpdate from lerobot.common.utils.utils import ( format_big_number, get_safe_torch_device, @@ -39,12 +40,17 @@ def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None): optimizer.step() optimizer.zero_grad() + if lr_scheduler is not None: lr_scheduler.step() if hasattr(policy, "ema") and policy.ema is not None: policy.ema.step(policy.diffusion) + if isinstance(policy, PolicyWithUpdate): + # To possibly update an internal buffer (for instance an Exponential Moving Average like in TDMPC). + policy.update() + info = { "loss": loss.item(), "grad_norm": float(grad_norm), @@ -81,7 +87,7 @@ def log_train_info(logger, info, step, cfg, dataset, is_offline): # A sample is an (observation,action) pair, where observation and action # can be on multiple timestamps. In a batch, we have `batch_size`` number of samples. - num_samples = (step + 1) * cfg.policy.batch_size + num_samples = (step + 1) * cfg.training.batch_size avg_samples_per_ep = dataset.num_samples / dataset.num_episodes num_episodes = num_samples / avg_samples_per_ep num_epochs = num_samples / dataset.num_samples @@ -117,7 +123,7 @@ def log_eval_info(logger, info, step, cfg, dataset, is_offline): # A sample is an (observation,action) pair, where observation and action # can be on multiple timestamps. In a batch, we have `batch_size`` number of samples. - num_samples = (step + 1) * cfg.policy.batch_size + num_samples = (step + 1) * cfg.training.batch_size avg_samples_per_ep = dataset.num_samples / dataset.num_episodes num_episodes = num_samples / avg_samples_per_ep num_epochs = num_samples / dataset.num_samples @@ -246,11 +252,12 @@ def train(cfg: dict, out_dir=None, job_name=None): raise NotImplementedError() if job_name is None: raise NotImplementedError() - if cfg.online_steps > 0: - assert cfg.rollout_batch_size == 1, "rollout_batch_size > 1 not supported for online training steps" init_logging() + if cfg.training.online_steps > 0 and cfg.eval.batch_size > 1: + logging.warning("eval.batch_size > 1 not supported for online training steps") + # Check device is available get_safe_torch_device(cfg.device, log=True) @@ -262,10 +269,10 @@ def train(cfg: dict, out_dir=None, job_name=None): offline_dataset = make_dataset(cfg) logging.info("make_env") - env = make_env(cfg, num_parallel_envs=cfg.eval_episodes) + env = make_env(cfg, num_parallel_envs=cfg.eval.n_episodes) logging.info("make_policy") - policy = make_policy(cfg, dataset_stats=offline_dataset.stats) + policy = make_policy(hydra_cfg=cfg, dataset_stats=offline_dataset.stats) # Create optimizer and scheduler # Temporary hack to move optimizer out of policy @@ -282,34 +289,33 @@ def train(cfg: dict, out_dir=None, job_name=None): "params": [ p for n, p in policy.named_parameters() if n.startswith("backbone") and p.requires_grad ], - "lr": cfg.policy.lr_backbone, + "lr": cfg.training.lr_backbone, }, ] optimizer = torch.optim.AdamW( - optimizer_params_dicts, lr=cfg.policy.lr, weight_decay=cfg.policy.weight_decay + optimizer_params_dicts, lr=cfg.training.lr, weight_decay=cfg.training.weight_decay ) lr_scheduler = None elif cfg.policy.name == "diffusion": optimizer = torch.optim.Adam( policy.diffusion.parameters(), - cfg.policy.lr, - cfg.policy.adam_betas, - cfg.policy.adam_eps, - cfg.policy.adam_weight_decay, + cfg.training.lr, + cfg.training.adam_betas, + cfg.training.adam_eps, + cfg.training.adam_weight_decay, ) - # TODO(rcadene): modify lr scheduler so that it doesn't depend on epochs but steps - # configure lr scheduler + assert cfg.training.online_steps == 0, "Diffusion Policy does not handle online training." lr_scheduler = get_scheduler( - cfg.policy.lr_scheduler, + cfg.training.lr_scheduler, optimizer=optimizer, - num_warmup_steps=cfg.policy.lr_warmup_steps, - num_training_steps=cfg.offline_steps, - # pytorch assumes stepping LRScheduler every epoch - # however huggingface diffusers steps it every batch - last_epoch=-1, + num_warmup_steps=cfg.training.lr_warmup_steps, + num_training_steps=cfg.training.offline_steps, ) elif policy.name == "tdmpc": - raise NotImplementedError("TD-MPC not implemented yet.") + optimizer = torch.optim.Adam(policy.parameters(), cfg.training.lr) + lr_scheduler = None + else: + raise NotImplementedError() num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad) num_total_params = sum(p.numel() for p in policy.parameters()) @@ -319,8 +325,8 @@ def train(cfg: dict, out_dir=None, job_name=None): log_output_dir(out_dir) logging.info(f"{cfg.env.task=}") - logging.info(f"{cfg.offline_steps=} ({format_big_number(cfg.offline_steps)})") - logging.info(f"{cfg.online_steps=}") + logging.info(f"{cfg.training.offline_steps=} ({format_big_number(cfg.training.offline_steps)})") + logging.info(f"{cfg.training.online_steps=}") logging.info(f"{offline_dataset.num_samples=} ({format_big_number(offline_dataset.num_samples)})") logging.info(f"{offline_dataset.num_episodes=}") logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})") @@ -328,7 +334,7 @@ def train(cfg: dict, out_dir=None, job_name=None): # Note: this helper will be used in offline and online training loops. def _maybe_eval_and_maybe_save(step): - if step % cfg.eval_freq == 0: + if step % cfg.training.eval_freq == 0: logging.info(f"Eval policy at step {step}") eval_info = eval_policy( env, @@ -342,37 +348,44 @@ def train(cfg: dict, out_dir=None, job_name=None): logger.log_video(eval_info["videos"][0], step, mode="eval") logging.info("Resume training") - if cfg.save_model and step % cfg.save_freq == 0: + if cfg.training.save_model and step % cfg.training.save_freq == 0: logging.info(f"Checkpoint policy after step {step}") - logger.save_model(policy, identifier=step) + # Note: Save with step as the identifier, and format it to have at least 6 digits but more if + # needed (choose 6 as a minimum for consistency without being overkill). + logger.save_model( + policy, + identifier=str(step).zfill( + max(6, len(str(cfg.training.offline_steps + cfg.training.online_steps))) + ), + ) logging.info("Resume training") # create dataloader for offline training dataloader = torch.utils.data.DataLoader( offline_dataset, - num_workers=8, - batch_size=cfg.policy.batch_size, + num_workers=4, + batch_size=cfg.training.batch_size, shuffle=True, pin_memory=cfg.device != "cpu", drop_last=False, ) dl_iter = cycle(dataloader) + policy.train() step = 0 # number of policy update (forward + backward + optim) is_offline = True - for offline_step in range(cfg.offline_steps): + for offline_step in range(cfg.training.offline_steps): if offline_step == 0: logging.info("Start offline training on a fixed dataset") - policy.train() batch = next(dl_iter) for key in batch: batch[key] = batch[key].to(cfg.device, non_blocking=True) - train_info = update_policy(policy, batch, optimizer, cfg.policy.grad_clip_norm, lr_scheduler) + train_info = update_policy(policy, batch, optimizer, cfg.training.grad_clip_norm, lr_scheduler) # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done? - if step % cfg.log_freq == 0: + if step % cfg.training.log_freq == 0: log_train_info(logger, train_info, step, cfg, offline_dataset, is_offline) # Note: _maybe_eval_and_maybe_save happens **after** the `step`th training update has completed, so we pass in @@ -398,7 +411,7 @@ def train(cfg: dict, out_dir=None, job_name=None): dataloader = torch.utils.data.DataLoader( concat_dataset, num_workers=4, - batch_size=cfg.policy.batch_size, + batch_size=cfg.training.batch_size, sampler=sampler, pin_memory=cfg.device != "cpu", drop_last=False, @@ -407,10 +420,11 @@ def train(cfg: dict, out_dir=None, job_name=None): online_step = 0 is_offline = False - for env_step in range(cfg.online_steps): + for env_step in range(cfg.training.online_steps): if env_step == 0: logging.info("Start online training by interacting with environment") + policy.eval() with torch.no_grad(): eval_info = eval_policy( rollout_env, @@ -419,25 +433,25 @@ def train(cfg: dict, out_dir=None, job_name=None): seed=cfg.seed, ) - add_episodes_inplace( - online_dataset, - concat_dataset, - sampler, - hf_dataset=eval_info["episodes"]["hf_dataset"], - episode_data_index=eval_info["episodes"]["episode_data_index"], - pc_online_samples=cfg.get("demo_schedule", 0.5), - ) + add_episodes_inplace( + online_dataset, + concat_dataset, + sampler, + hf_dataset=eval_info["episodes"]["hf_dataset"], + episode_data_index=eval_info["episodes"]["episode_data_index"], + pc_online_samples=cfg.training.online_sampling_ratio, + ) - for _ in range(cfg.policy.utd): - policy.train() + policy.train() + for _ in range(cfg.training.online_steps_between_rollouts): batch = next(dl_iter) for key in batch: batch[key] = batch[key].to(cfg.device, non_blocking=True) - train_info = update_policy(policy, batch, optimizer, cfg.policy.grad_clip_norm, lr_scheduler) + train_info = update_policy(policy, batch, optimizer, cfg.training.grad_clip_norm, lr_scheduler) - if step % cfg.log_freq == 0: + if step % cfg.training.log_freq == 0: log_train_info(logger, train_info, step, cfg, online_dataset, is_offline) # Note: _maybe_eval_and_maybe_save happens **after** the `step`th training update has completed, so we pass diff --git a/poetry.lock b/poetry.lock index 5f4b198d..cb7cd6d1 100644 --- a/poetry.lock +++ b/poetry.lock @@ -837,13 +837,13 @@ files = [ [[package]] name = "filelock" -version = "3.13.4" +version = "3.14.0" description = "A platform independent file lock." optional = false python-versions = ">=3.8" files = [ - {file = "filelock-3.13.4-py3-none-any.whl", hash = "sha256:404e5e9253aa60ad457cae1be07c0f0ca90a63931200a47d9b6a6af84fd7b45f"}, - {file = "filelock-3.13.4.tar.gz", hash = "sha256:d13f466618bfde72bd2c18255e269f72542c6e70e7bac83a0232d6b1cc5c8cf4"}, + {file = "filelock-3.14.0-py3-none-any.whl", hash = "sha256:43339835842f110ca7ae60f1e1c160714c5a6afd15a2873419ab185334975c0f"}, + {file = "filelock-3.14.0.tar.gz", hash = "sha256:6ea72da3be9b8c82afd3edcf99f2fffbb5076335a5ae4d03248bb5b6c3eae78a"}, ] [package.extras] @@ -1050,69 +1050,61 @@ preview = ["glfw-preview"] [[package]] name = "grpcio" -version = "1.62.2" +version = "1.63.0" description = "HTTP/2-based RPC framework" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "grpcio-1.62.2-cp310-cp310-linux_armv7l.whl", hash = "sha256:66344ea741124c38588a664237ac2fa16dfd226964cca23ddc96bd4accccbde5"}, - {file = "grpcio-1.62.2-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:5dab7ac2c1e7cb6179c6bfad6b63174851102cbe0682294e6b1d6f0981ad7138"}, - {file = "grpcio-1.62.2-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:3ad00f3f0718894749d5a8bb0fa125a7980a2f49523731a9b1fabf2b3522aa43"}, - {file = "grpcio-1.62.2-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2e72ddfee62430ea80133d2cbe788e0d06b12f865765cb24a40009668bd8ea05"}, - {file = "grpcio-1.62.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:53d3a59a10af4c2558a8e563aed9f256259d2992ae0d3037817b2155f0341de1"}, - {file = "grpcio-1.62.2-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:a1511a303f8074f67af4119275b4f954189e8313541da7b88b1b3a71425cdb10"}, - {file = "grpcio-1.62.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:b94d41b7412ef149743fbc3178e59d95228a7064c5ab4760ae82b562bdffb199"}, - {file = "grpcio-1.62.2-cp310-cp310-win32.whl", hash = "sha256:a75af2fc7cb1fe25785be7bed1ab18cef959a376cdae7c6870184307614caa3f"}, - {file = "grpcio-1.62.2-cp310-cp310-win_amd64.whl", hash = "sha256:80407bc007754f108dc2061e37480238b0dc1952c855e86a4fc283501ee6bb5d"}, - {file = "grpcio-1.62.2-cp311-cp311-linux_armv7l.whl", hash = "sha256:c1624aa686d4b36790ed1c2e2306cc3498778dffaf7b8dd47066cf819028c3ad"}, - {file = "grpcio-1.62.2-cp311-cp311-macosx_10_10_universal2.whl", hash = "sha256:1c1bb80299bdef33309dff03932264636450c8fdb142ea39f47e06a7153d3063"}, - {file = "grpcio-1.62.2-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:db068bbc9b1fa16479a82e1ecf172a93874540cb84be69f0b9cb9b7ac3c82670"}, - {file = "grpcio-1.62.2-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e2cc8a308780edbe2c4913d6a49dbdb5befacdf72d489a368566be44cadaef1a"}, - {file = "grpcio-1.62.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d0695ae31a89f1a8fc8256050329a91a9995b549a88619263a594ca31b76d756"}, - {file = "grpcio-1.62.2-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:88b4f9ee77191dcdd8810241e89340a12cbe050be3e0d5f2f091c15571cd3930"}, - {file = "grpcio-1.62.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:2a0204532aa2f1afd467024b02b4069246320405bc18abec7babab03e2644e75"}, - {file = "grpcio-1.62.2-cp311-cp311-win32.whl", hash = "sha256:6e784f60e575a0de554ef9251cbc2ceb8790914fe324f11e28450047f264ee6f"}, - {file = "grpcio-1.62.2-cp311-cp311-win_amd64.whl", hash = "sha256:112eaa7865dd9e6d7c0556c8b04ae3c3a2dc35d62ad3373ab7f6a562d8199200"}, - {file = "grpcio-1.62.2-cp312-cp312-linux_armv7l.whl", hash = "sha256:65034473fc09628a02fb85f26e73885cf1ed39ebd9cf270247b38689ff5942c5"}, - {file = "grpcio-1.62.2-cp312-cp312-macosx_10_10_universal2.whl", hash = "sha256:d2c1771d0ee3cf72d69bb5e82c6a82f27fbd504c8c782575eddb7839729fbaad"}, - {file = "grpcio-1.62.2-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:3abe6838196da518863b5d549938ce3159d809218936851b395b09cad9b5d64a"}, - {file = "grpcio-1.62.2-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c5ffeb269f10cedb4f33142b89a061acda9f672fd1357331dbfd043422c94e9e"}, - {file = "grpcio-1.62.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:404d3b4b6b142b99ba1cff0b2177d26b623101ea2ce51c25ef6e53d9d0d87bcc"}, - {file = "grpcio-1.62.2-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:262cda97efdabb20853d3b5a4c546a535347c14b64c017f628ca0cc7fa780cc6"}, - {file = "grpcio-1.62.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:17708db5b11b966373e21519c4c73e5a750555f02fde82276ea2a267077c68ad"}, - {file = "grpcio-1.62.2-cp312-cp312-win32.whl", hash = "sha256:b7ec9e2f8ffc8436f6b642a10019fc513722858f295f7efc28de135d336ac189"}, - {file = "grpcio-1.62.2-cp312-cp312-win_amd64.whl", hash = "sha256:aa787b83a3cd5e482e5c79be030e2b4a122ecc6c5c6c4c42a023a2b581fdf17b"}, - {file = "grpcio-1.62.2-cp37-cp37m-linux_armv7l.whl", hash = "sha256:cfd23ad29bfa13fd4188433b0e250f84ec2c8ba66b14a9877e8bce05b524cf54"}, - {file = "grpcio-1.62.2-cp37-cp37m-macosx_10_10_universal2.whl", hash = "sha256:af15e9efa4d776dfcecd1d083f3ccfb04f876d613e90ef8432432efbeeac689d"}, - {file = "grpcio-1.62.2-cp37-cp37m-manylinux_2_17_aarch64.whl", hash = "sha256:f4aa94361bb5141a45ca9187464ae81a92a2a135ce2800b2203134f7a1a1d479"}, - {file = "grpcio-1.62.2-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:82af3613a219512a28ee5c95578eb38d44dd03bca02fd918aa05603c41018051"}, - {file = "grpcio-1.62.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:55ddaf53474e8caeb29eb03e3202f9d827ad3110475a21245f3c7712022882a9"}, - {file = "grpcio-1.62.2-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:c79b518c56dddeec79e5500a53d8a4db90da995dfe1738c3ac57fe46348be049"}, - {file = "grpcio-1.62.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:a5eb4844e5e60bf2c446ef38c5b40d7752c6effdee882f716eb57ae87255d20a"}, - {file = "grpcio-1.62.2-cp37-cp37m-win_amd64.whl", hash = "sha256:aaae70364a2d1fb238afd6cc9fcb10442b66e397fd559d3f0968d28cc3ac929c"}, - {file = "grpcio-1.62.2-cp38-cp38-linux_armv7l.whl", hash = "sha256:1bcfe5070e4406f489e39325b76caeadab28c32bf9252d3ae960c79935a4cc36"}, - {file = "grpcio-1.62.2-cp38-cp38-macosx_10_10_universal2.whl", hash = "sha256:da6a7b6b938c15fa0f0568e482efaae9c3af31963eec2da4ff13a6d8ec2888e4"}, - {file = "grpcio-1.62.2-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:41955b641c34db7d84db8d306937b72bc4968eef1c401bea73081a8d6c3d8033"}, - {file = "grpcio-1.62.2-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c772f225483905f675cb36a025969eef9712f4698364ecd3a63093760deea1bc"}, - {file = "grpcio-1.62.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:07ce1f775d37ca18c7a141300e5b71539690efa1f51fe17f812ca85b5e73262f"}, - {file = "grpcio-1.62.2-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:26f415f40f4a93579fd648f48dca1c13dfacdfd0290f4a30f9b9aeb745026811"}, - {file = "grpcio-1.62.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:db707e3685ff16fc1eccad68527d072ac8bdd2e390f6daa97bc394ea7de4acea"}, - {file = "grpcio-1.62.2-cp38-cp38-win32.whl", hash = "sha256:589ea8e75de5fd6df387de53af6c9189c5231e212b9aa306b6b0d4f07520fbb9"}, - {file = "grpcio-1.62.2-cp38-cp38-win_amd64.whl", hash = "sha256:3c3ed41f4d7a3aabf0f01ecc70d6b5d00ce1800d4af652a549de3f7cf35c4abd"}, - {file = "grpcio-1.62.2-cp39-cp39-linux_armv7l.whl", hash = "sha256:162ccf61499c893831b8437120600290a99c0bc1ce7b51f2c8d21ec87ff6af8b"}, - {file = "grpcio-1.62.2-cp39-cp39-macosx_10_10_universal2.whl", hash = "sha256:f27246d7da7d7e3bd8612f63785a7b0c39a244cf14b8dd9dd2f2fab939f2d7f1"}, - {file = "grpcio-1.62.2-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:2507006c8a478f19e99b6fe36a2464696b89d40d88f34e4b709abe57e1337467"}, - {file = "grpcio-1.62.2-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a90ac47a8ce934e2c8d71e317d2f9e7e6aaceb2d199de940ce2c2eb611b8c0f4"}, - {file = "grpcio-1.62.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:99701979bcaaa7de8d5f60476487c5df8f27483624f1f7e300ff4669ee44d1f2"}, - {file = "grpcio-1.62.2-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:af7dc3f7a44f10863b1b0ecab4078f0a00f561aae1edbd01fd03ad4dcf61c9e9"}, - {file = "grpcio-1.62.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:fa63245271920786f4cb44dcada4983a3516be8f470924528cf658731864c14b"}, - {file = "grpcio-1.62.2-cp39-cp39-win32.whl", hash = "sha256:c6ad9c39704256ed91a1cffc1379d63f7d0278d6a0bad06b0330f5d30291e3a3"}, - {file = "grpcio-1.62.2-cp39-cp39-win_amd64.whl", hash = "sha256:16da954692fd61aa4941fbeda405a756cd96b97b5d95ca58a92547bba2c1624f"}, - {file = "grpcio-1.62.2.tar.gz", hash = "sha256:c77618071d96b7a8be2c10701a98537823b9c65ba256c0b9067e0594cdbd954d"}, + {file = "grpcio-1.63.0-cp310-cp310-linux_armv7l.whl", hash = "sha256:2e93aca840c29d4ab5db93f94ed0a0ca899e241f2e8aec6334ab3575dc46125c"}, + {file = "grpcio-1.63.0-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:91b73d3f1340fefa1e1716c8c1ec9930c676d6b10a3513ab6c26004cb02d8b3f"}, + {file = "grpcio-1.63.0-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:b3afbd9d6827fa6f475a4f91db55e441113f6d3eb9b7ebb8fb806e5bb6d6bd0d"}, + {file = "grpcio-1.63.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8f3f6883ce54a7a5f47db43289a0a4c776487912de1a0e2cc83fdaec9685cc9f"}, + {file = "grpcio-1.63.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cf8dae9cc0412cb86c8de5a8f3be395c5119a370f3ce2e69c8b7d46bb9872c8d"}, + {file = "grpcio-1.63.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:08e1559fd3b3b4468486b26b0af64a3904a8dbc78d8d936af9c1cf9636eb3e8b"}, + {file = "grpcio-1.63.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:5c039ef01516039fa39da8a8a43a95b64e288f79f42a17e6c2904a02a319b357"}, + {file = "grpcio-1.63.0-cp310-cp310-win32.whl", hash = "sha256:ad2ac8903b2eae071055a927ef74121ed52d69468e91d9bcbd028bd0e554be6d"}, + {file = "grpcio-1.63.0-cp310-cp310-win_amd64.whl", hash = "sha256:b2e44f59316716532a993ca2966636df6fbe7be4ab6f099de6815570ebe4383a"}, + {file = "grpcio-1.63.0-cp311-cp311-linux_armv7l.whl", hash = "sha256:f28f8b2db7b86c77916829d64ab21ff49a9d8289ea1564a2b2a3a8ed9ffcccd3"}, + {file = "grpcio-1.63.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:65bf975639a1f93bee63ca60d2e4951f1b543f498d581869922910a476ead2f5"}, + {file = "grpcio-1.63.0-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:b5194775fec7dc3dbd6a935102bb156cd2c35efe1685b0a46c67b927c74f0cfb"}, + {file = "grpcio-1.63.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e4cbb2100ee46d024c45920d16e888ee5d3cf47c66e316210bc236d5bebc42b3"}, + {file = "grpcio-1.63.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1ff737cf29b5b801619f10e59b581869e32f400159e8b12d7a97e7e3bdeee6a2"}, + {file = "grpcio-1.63.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:cd1e68776262dd44dedd7381b1a0ad09d9930ffb405f737d64f505eb7f77d6c7"}, + {file = "grpcio-1.63.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:93f45f27f516548e23e4ec3fbab21b060416007dbe768a111fc4611464cc773f"}, + {file = "grpcio-1.63.0-cp311-cp311-win32.whl", hash = "sha256:878b1d88d0137df60e6b09b74cdb73db123f9579232c8456f53e9abc4f62eb3c"}, + {file = "grpcio-1.63.0-cp311-cp311-win_amd64.whl", hash = "sha256:756fed02dacd24e8f488f295a913f250b56b98fb793f41d5b2de6c44fb762434"}, + {file = "grpcio-1.63.0-cp312-cp312-linux_armv7l.whl", hash = "sha256:93a46794cc96c3a674cdfb59ef9ce84d46185fe9421baf2268ccb556f8f81f57"}, + {file = "grpcio-1.63.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:a7b19dfc74d0be7032ca1eda0ed545e582ee46cd65c162f9e9fc6b26ef827dc6"}, + {file = "grpcio-1.63.0-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:8064d986d3a64ba21e498b9a376cbc5d6ab2e8ab0e288d39f266f0fca169b90d"}, + {file = "grpcio-1.63.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:219bb1848cd2c90348c79ed0a6b0ea51866bc7e72fa6e205e459fedab5770172"}, + {file = "grpcio-1.63.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a2d60cd1d58817bc5985fae6168d8b5655c4981d448d0f5b6194bbcc038090d2"}, + {file = "grpcio-1.63.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:9e350cb096e5c67832e9b6e018cf8a0d2a53b2a958f6251615173165269a91b0"}, + {file = "grpcio-1.63.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:56cdf96ff82e3cc90dbe8bac260352993f23e8e256e063c327b6cf9c88daf7a9"}, + {file = "grpcio-1.63.0-cp312-cp312-win32.whl", hash = "sha256:3a6d1f9ea965e750db7b4ee6f9fdef5fdf135abe8a249e75d84b0a3e0c668a1b"}, + {file = "grpcio-1.63.0-cp312-cp312-win_amd64.whl", hash = "sha256:d2497769895bb03efe3187fb1888fc20e98a5f18b3d14b606167dacda5789434"}, + {file = "grpcio-1.63.0-cp38-cp38-linux_armv7l.whl", hash = "sha256:fdf348ae69c6ff484402cfdb14e18c1b0054ac2420079d575c53a60b9b2853ae"}, + {file = "grpcio-1.63.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:a3abfe0b0f6798dedd2e9e92e881d9acd0fdb62ae27dcbbfa7654a57e24060c0"}, + {file = "grpcio-1.63.0-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:6ef0ad92873672a2a3767cb827b64741c363ebaa27e7f21659e4e31f4d750280"}, + {file = "grpcio-1.63.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b416252ac5588d9dfb8a30a191451adbf534e9ce5f56bb02cd193f12d8845b7f"}, + {file = "grpcio-1.63.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e3b77eaefc74d7eb861d3ffbdf91b50a1bb1639514ebe764c47773b833fa2d91"}, + {file = "grpcio-1.63.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:b005292369d9c1f80bf70c1db1c17c6c342da7576f1c689e8eee4fb0c256af85"}, + {file = "grpcio-1.63.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:cdcda1156dcc41e042d1e899ba1f5c2e9f3cd7625b3d6ebfa619806a4c1aadda"}, + {file = "grpcio-1.63.0-cp38-cp38-win32.whl", hash = "sha256:01799e8649f9e94ba7db1aeb3452188048b0019dc37696b0f5ce212c87c560c3"}, + {file = "grpcio-1.63.0-cp38-cp38-win_amd64.whl", hash = "sha256:6a1a3642d76f887aa4009d92f71eb37809abceb3b7b5a1eec9c554a246f20e3a"}, + {file = "grpcio-1.63.0-cp39-cp39-linux_armv7l.whl", hash = "sha256:75f701ff645858a2b16bc8c9fc68af215a8bb2d5a9b647448129de6e85d52bce"}, + {file = "grpcio-1.63.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:cacdef0348a08e475a721967f48206a2254a1b26ee7637638d9e081761a5ba86"}, + {file = "grpcio-1.63.0-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:0697563d1d84d6985e40ec5ec596ff41b52abb3fd91ec240e8cb44a63b895094"}, + {file = "grpcio-1.63.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6426e1fb92d006e47476d42b8f240c1d916a6d4423c5258ccc5b105e43438f61"}, + {file = "grpcio-1.63.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e48cee31bc5f5a31fb2f3b573764bd563aaa5472342860edcc7039525b53e46a"}, + {file = "grpcio-1.63.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:50344663068041b34a992c19c600236e7abb42d6ec32567916b87b4c8b8833b3"}, + {file = "grpcio-1.63.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:259e11932230d70ef24a21b9fb5bb947eb4703f57865a404054400ee92f42f5d"}, + {file = "grpcio-1.63.0-cp39-cp39-win32.whl", hash = "sha256:a44624aad77bf8ca198c55af811fd28f2b3eaf0a50ec5b57b06c034416ef2d0a"}, + {file = "grpcio-1.63.0-cp39-cp39-win_amd64.whl", hash = "sha256:166e5c460e5d7d4656ff9e63b13e1f6029b122104c1633d5f37eaea348d7356d"}, + {file = "grpcio-1.63.0.tar.gz", hash = "sha256:f3023e14805c61bc439fb40ca545ac3d5740ce66120a678a3c6c2c55b70343d1"}, ] [package.extras] -protobuf = ["grpcio-tools (>=1.62.2)"] +protobuf = ["grpcio-tools (>=1.63.0)"] [[package]] name = "gym-aloha" @@ -2414,7 +2406,6 @@ optional = false python-versions = ">=3.9" files = [ {file = "pandas-2.2.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:90c6fca2acf139569e74e8781709dccb6fe25940488755716d1d354d6bc58bce"}, - {file = "pandas-2.2.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c7adfc142dac335d8c1e0dcbd37eb8617eac386596eb9e1a1b77791cf2498238"}, {file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4abfe0be0d7221be4f12552995e58723c7422c80a659da13ca382697de830c08"}, {file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8635c16bf3d99040fdf3ca3db669a7250ddf49c55dc4aa8fe0ae0fa8d6dcc1f0"}, {file = "pandas-2.2.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:40ae1dffb3967a52203105a077415a86044a2bea011b5f321c6aa64b379a3f51"}, @@ -2435,7 +2426,6 @@ files = [ {file = "pandas-2.2.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:43498c0bdb43d55cb162cdc8c06fac328ccb5d2eabe3cadeb3529ae6f0517c32"}, {file = "pandas-2.2.2-cp312-cp312-win_amd64.whl", hash = "sha256:d187d355ecec3629624fccb01d104da7d7f391db0311145817525281e2804d23"}, {file = "pandas-2.2.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:0ca6377b8fca51815f382bd0b697a0814c8bda55115678cbc94c30aacbb6eff2"}, - {file = "pandas-2.2.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9057e6aa78a584bc93a13f0a9bf7e753a5e9770a30b4d758b8d5f2a62a9433cd"}, {file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:001910ad31abc7bf06f49dcc903755d2f7f3a9186c0c040b827e522e9cef0863"}, {file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:66b479b0bd07204e37583c191535505410daa8df638fd8e75ae1b383851fe921"}, {file = "pandas-2.2.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:a77e9d1c386196879aa5eb712e77461aaee433e54c68cf253053a73b7e49c33a"}, @@ -3110,104 +3100,90 @@ files = [ [[package]] name = "regex" -version = "2024.4.16" +version = "2024.4.28" description = "Alternative regular expression module, to replace re." optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "regex-2024.4.16-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:fb83cc090eac63c006871fd24db5e30a1f282faa46328572661c0a24a2323a08"}, - {file = "regex-2024.4.16-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8c91e1763696c0eb66340c4df98623c2d4e77d0746b8f8f2bee2c6883fd1fe18"}, - {file = "regex-2024.4.16-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:10188fe732dec829c7acca7422cdd1bf57d853c7199d5a9e96bb4d40db239c73"}, - {file = "regex-2024.4.16-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:956b58d692f235cfbf5b4f3abd6d99bf102f161ccfe20d2fd0904f51c72c4c66"}, - {file = "regex-2024.4.16-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a70b51f55fd954d1f194271695821dd62054d949efd6368d8be64edd37f55c86"}, - {file = "regex-2024.4.16-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5c02fcd2bf45162280613d2e4a1ca3ac558ff921ae4e308ecb307650d3a6ee51"}, - {file = "regex-2024.4.16-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c4ed75ea6892a56896d78f11006161eea52c45a14994794bcfa1654430984b22"}, - {file = "regex-2024.4.16-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bd727ad276bb91928879f3aa6396c9a1d34e5e180dce40578421a691eeb77f47"}, - {file = "regex-2024.4.16-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:7cbc5d9e8a1781e7be17da67b92580d6ce4dcef5819c1b1b89f49d9678cc278c"}, - {file = "regex-2024.4.16-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:78fddb22b9ef810b63ef341c9fcf6455232d97cfe03938cbc29e2672c436670e"}, - {file = "regex-2024.4.16-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:445ca8d3c5a01309633a0c9db57150312a181146315693273e35d936472df912"}, - {file = "regex-2024.4.16-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:95399831a206211d6bc40224af1c635cb8790ddd5c7493e0bd03b85711076a53"}, - {file = "regex-2024.4.16-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:7731728b6568fc286d86745f27f07266de49603a6fdc4d19c87e8c247be452af"}, - {file = "regex-2024.4.16-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:4facc913e10bdba42ec0aee76d029aedda628161a7ce4116b16680a0413f658a"}, - {file = "regex-2024.4.16-cp310-cp310-win32.whl", hash = "sha256:911742856ce98d879acbea33fcc03c1d8dc1106234c5e7d068932c945db209c0"}, - {file = "regex-2024.4.16-cp310-cp310-win_amd64.whl", hash = "sha256:e0a2df336d1135a0b3a67f3bbf78a75f69562c1199ed9935372b82215cddd6e2"}, - {file = "regex-2024.4.16-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:1210365faba7c2150451eb78ec5687871c796b0f1fa701bfd2a4a25420482d26"}, - {file = "regex-2024.4.16-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:9ab40412f8cd6f615bfedea40c8bf0407d41bf83b96f6fc9ff34976d6b7037fd"}, - {file = "regex-2024.4.16-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:fd80d1280d473500d8086d104962a82d77bfbf2b118053824b7be28cd5a79ea5"}, - {file = "regex-2024.4.16-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7bb966fdd9217e53abf824f437a5a2d643a38d4fd5fd0ca711b9da683d452969"}, - {file = "regex-2024.4.16-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:20b7a68444f536365af42a75ccecb7ab41a896a04acf58432db9e206f4e525d6"}, - {file = "regex-2024.4.16-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b74586dd0b039c62416034f811d7ee62810174bb70dffcca6439f5236249eb09"}, - {file = "regex-2024.4.16-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0c8290b44d8b0af4e77048646c10c6e3aa583c1ca67f3b5ffb6e06cf0c6f0f89"}, - {file = "regex-2024.4.16-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f2d80a6749724b37853ece57988b39c4e79d2b5fe2869a86e8aeae3bbeef9eb0"}, - {file = "regex-2024.4.16-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:3a1018e97aeb24e4f939afcd88211ace472ba566efc5bdf53fd8fd7f41fa7170"}, - {file = "regex-2024.4.16-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:8d015604ee6204e76569d2f44e5a210728fa917115bef0d102f4107e622b08d5"}, - {file = "regex-2024.4.16-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:3d5ac5234fb5053850d79dd8eb1015cb0d7d9ed951fa37aa9e6249a19aa4f336"}, - {file = "regex-2024.4.16-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:0a38d151e2cdd66d16dab550c22f9521ba79761423b87c01dae0a6e9add79c0d"}, - {file = "regex-2024.4.16-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:159dc4e59a159cb8e4e8f8961eb1fa5d58f93cb1acd1701d8aff38d45e1a84a6"}, - {file = "regex-2024.4.16-cp311-cp311-win32.whl", hash = "sha256:ba2336d6548dee3117520545cfe44dc28a250aa091f8281d28804aa8d707d93d"}, - {file = "regex-2024.4.16-cp311-cp311-win_amd64.whl", hash = "sha256:8f83b6fd3dc3ba94d2b22717f9c8b8512354fd95221ac661784df2769ea9bba9"}, - {file = "regex-2024.4.16-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:80b696e8972b81edf0af2a259e1b2a4a661f818fae22e5fa4fa1a995fb4a40fd"}, - {file = "regex-2024.4.16-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:d61ae114d2a2311f61d90c2ef1358518e8f05eafda76eaf9c772a077e0b465ec"}, - {file = "regex-2024.4.16-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8ba6745440b9a27336443b0c285d705ce73adb9ec90e2f2004c64d95ab5a7598"}, - {file = "regex-2024.4.16-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6295004b2dd37b0835ea5c14a33e00e8cfa3c4add4d587b77287825f3418d310"}, - {file = "regex-2024.4.16-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4aba818dcc7263852aabb172ec27b71d2abca02a593b95fa79351b2774eb1d2b"}, - {file = "regex-2024.4.16-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d0800631e565c47520aaa04ae38b96abc5196fe8b4aa9bd864445bd2b5848a7a"}, - {file = "regex-2024.4.16-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:08dea89f859c3df48a440dbdcd7b7155bc675f2fa2ec8c521d02dc69e877db70"}, - {file = "regex-2024.4.16-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:eeaa0b5328b785abc344acc6241cffde50dc394a0644a968add75fcefe15b9d4"}, - {file = "regex-2024.4.16-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:4e819a806420bc010489f4e741b3036071aba209f2e0989d4750b08b12a9343f"}, - {file = "regex-2024.4.16-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:c2d0e7cbb6341e830adcbfa2479fdeebbfbb328f11edd6b5675674e7a1e37730"}, - {file = "regex-2024.4.16-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:91797b98f5e34b6a49f54be33f72e2fb658018ae532be2f79f7c63b4ae225145"}, - {file = "regex-2024.4.16-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:d2da13568eff02b30fd54fccd1e042a70fe920d816616fda4bf54ec705668d81"}, - {file = "regex-2024.4.16-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:370c68dc5570b394cbaadff50e64d705f64debed30573e5c313c360689b6aadc"}, - {file = "regex-2024.4.16-cp312-cp312-win32.whl", hash = "sha256:904c883cf10a975b02ab3478bce652f0f5346a2c28d0a8521d97bb23c323cc8b"}, - {file = "regex-2024.4.16-cp312-cp312-win_amd64.whl", hash = "sha256:785c071c982dce54d44ea0b79cd6dfafddeccdd98cfa5f7b86ef69b381b457d9"}, - {file = "regex-2024.4.16-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:e2f142b45c6fed48166faeb4303b4b58c9fcd827da63f4cf0a123c3480ae11fb"}, - {file = "regex-2024.4.16-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e87ab229332ceb127a165612d839ab87795972102cb9830e5f12b8c9a5c1b508"}, - {file = "regex-2024.4.16-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:81500ed5af2090b4a9157a59dbc89873a25c33db1bb9a8cf123837dcc9765047"}, - {file = "regex-2024.4.16-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b340cccad138ecb363324aa26893963dcabb02bb25e440ebdf42e30963f1a4e0"}, - {file = "regex-2024.4.16-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2c72608e70f053643437bd2be0608f7f1c46d4022e4104d76826f0839199347a"}, - {file = "regex-2024.4.16-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a01fe2305e6232ef3e8f40bfc0f0f3a04def9aab514910fa4203bafbc0bb4682"}, - {file = "regex-2024.4.16-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:03576e3a423d19dda13e55598f0fd507b5d660d42c51b02df4e0d97824fdcae3"}, - {file = "regex-2024.4.16-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:549c3584993772e25f02d0656ac48abdda73169fe347263948cf2b1cead622f3"}, - {file = "regex-2024.4.16-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:34422d5a69a60b7e9a07a690094e824b66f5ddc662a5fc600d65b7c174a05f04"}, - {file = "regex-2024.4.16-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:5f580c651a72b75c39e311343fe6875d6f58cf51c471a97f15a938d9fe4e0d37"}, - {file = "regex-2024.4.16-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:3399dd8a7495bbb2bacd59b84840eef9057826c664472e86c91d675d007137f5"}, - {file = "regex-2024.4.16-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:8d1f86f3f4e2388aa3310b50694ac44daefbd1681def26b4519bd050a398dc5a"}, - {file = "regex-2024.4.16-cp37-cp37m-win32.whl", hash = "sha256:dd5acc0a7d38fdc7a3a6fd3ad14c880819008ecb3379626e56b163165162cc46"}, - {file = "regex-2024.4.16-cp37-cp37m-win_amd64.whl", hash = "sha256:ba8122e3bb94ecda29a8de4cf889f600171424ea586847aa92c334772d200331"}, - {file = "regex-2024.4.16-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:743deffdf3b3481da32e8a96887e2aa945ec6685af1cfe2bcc292638c9ba2f48"}, - {file = "regex-2024.4.16-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:7571f19f4a3fd00af9341c7801d1ad1967fc9c3f5e62402683047e7166b9f2b4"}, - {file = "regex-2024.4.16-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:df79012ebf6f4efb8d307b1328226aef24ca446b3ff8d0e30202d7ebcb977a8c"}, - {file = "regex-2024.4.16-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e757d475953269fbf4b441207bb7dbdd1c43180711b6208e129b637792ac0b93"}, - {file = "regex-2024.4.16-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4313ab9bf6a81206c8ac28fdfcddc0435299dc88cad12cc6305fd0e78b81f9e4"}, - {file = "regex-2024.4.16-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d83c2bc678453646f1a18f8db1e927a2d3f4935031b9ad8a76e56760461105dd"}, - {file = "regex-2024.4.16-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9df1bfef97db938469ef0a7354b2d591a2d438bc497b2c489471bec0e6baf7c4"}, - {file = "regex-2024.4.16-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:62120ed0de69b3649cc68e2965376048793f466c5a6c4370fb27c16c1beac22d"}, - {file = "regex-2024.4.16-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:c2ef6f7990b6e8758fe48ad08f7e2f66c8f11dc66e24093304b87cae9037bb4a"}, - {file = "regex-2024.4.16-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:8fc6976a3395fe4d1fbeb984adaa8ec652a1e12f36b56ec8c236e5117b585427"}, - {file = "regex-2024.4.16-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:03e68f44340528111067cecf12721c3df4811c67268b897fbe695c95f860ac42"}, - {file = "regex-2024.4.16-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:ec7e0043b91115f427998febaa2beb82c82df708168b35ece3accb610b91fac1"}, - {file = "regex-2024.4.16-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:c21fc21a4c7480479d12fd8e679b699f744f76bb05f53a1d14182b31f55aac76"}, - {file = "regex-2024.4.16-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:12f6a3f2f58bb7344751919a1876ee1b976fe08b9ffccb4bbea66f26af6017b9"}, - {file = "regex-2024.4.16-cp38-cp38-win32.whl", hash = "sha256:479595a4fbe9ed8f8f72c59717e8cf222da2e4c07b6ae5b65411e6302af9708e"}, - {file = "regex-2024.4.16-cp38-cp38-win_amd64.whl", hash = "sha256:0534b034fba6101611968fae8e856c1698da97ce2efb5c2b895fc8b9e23a5834"}, - {file = "regex-2024.4.16-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:a7ccdd1c4a3472a7533b0a7aa9ee34c9a2bef859ba86deec07aff2ad7e0c3b94"}, - {file = "regex-2024.4.16-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6f2f017c5be19984fbbf55f8af6caba25e62c71293213f044da3ada7091a4455"}, - {file = "regex-2024.4.16-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:803b8905b52de78b173d3c1e83df0efb929621e7b7c5766c0843704d5332682f"}, - {file = "regex-2024.4.16-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:684008ec44ad275832a5a152f6e764bbe1914bea10968017b6feaecdad5736e0"}, - {file = "regex-2024.4.16-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:65436dce9fdc0aeeb0a0effe0839cb3d6a05f45aa45a4d9f9c60989beca78b9c"}, - {file = "regex-2024.4.16-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ea355eb43b11764cf799dda62c658c4d2fdb16af41f59bb1ccfec517b60bcb07"}, - {file = "regex-2024.4.16-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:98c1165f3809ce7774f05cb74e5408cd3aa93ee8573ae959a97a53db3ca3180d"}, - {file = "regex-2024.4.16-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cccc79a9be9b64c881f18305a7c715ba199e471a3973faeb7ba84172abb3f317"}, - {file = "regex-2024.4.16-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:00169caa125f35d1bca6045d65a662af0202704489fada95346cfa092ec23f39"}, - {file = "regex-2024.4.16-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:6cc38067209354e16c5609b66285af17a2863a47585bcf75285cab33d4c3b8df"}, - {file = "regex-2024.4.16-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:23cff1b267038501b179ccbbd74a821ac4a7192a1852d1d558e562b507d46013"}, - {file = "regex-2024.4.16-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:b9d320b3bf82a39f248769fc7f188e00f93526cc0fe739cfa197868633d44701"}, - {file = "regex-2024.4.16-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:89ec7f2c08937421bbbb8b48c54096fa4f88347946d4747021ad85f1b3021b3c"}, - {file = "regex-2024.4.16-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:4918fd5f8b43aa7ec031e0fef1ee02deb80b6afd49c85f0790be1dc4ce34cb50"}, - {file = "regex-2024.4.16-cp39-cp39-win32.whl", hash = "sha256:684e52023aec43bdf0250e843e1fdd6febbe831bd9d52da72333fa201aaa2335"}, - {file = "regex-2024.4.16-cp39-cp39-win_amd64.whl", hash = "sha256:e697e1c0238133589e00c244a8b676bc2cfc3ab4961318d902040d099fec7483"}, - {file = "regex-2024.4.16.tar.gz", hash = "sha256:fa454d26f2e87ad661c4f0c5a5fe4cf6aab1e307d1b94f16ffdfcb089ba685c0"}, + {file = "regex-2024.4.28-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:cd196d056b40af073d95a2879678585f0b74ad35190fac04ca67954c582c6b61"}, + {file = "regex-2024.4.28-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8bb381f777351bd534462f63e1c6afb10a7caa9fa2a421ae22c26e796fe31b1f"}, + {file = "regex-2024.4.28-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:47af45b6153522733aa6e92543938e97a70ce0900649ba626cf5aad290b737b6"}, + {file = "regex-2024.4.28-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:99d6a550425cc51c656331af0e2b1651e90eaaa23fb4acde577cf15068e2e20f"}, + {file = "regex-2024.4.28-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bf29304a8011feb58913c382902fde3395957a47645bf848eea695839aa101b7"}, + {file = "regex-2024.4.28-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:92da587eee39a52c91aebea8b850e4e4f095fe5928d415cb7ed656b3460ae79a"}, + {file = "regex-2024.4.28-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6277d426e2f31bdbacb377d17a7475e32b2d7d1f02faaecc48d8e370c6a3ff31"}, + {file = "regex-2024.4.28-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:28e1f28d07220c0f3da0e8fcd5a115bbb53f8b55cecf9bec0c946eb9a059a94c"}, + {file = "regex-2024.4.28-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:aaa179975a64790c1f2701ac562b5eeb733946eeb036b5bcca05c8d928a62f10"}, + {file = "regex-2024.4.28-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:6f435946b7bf7a1b438b4e6b149b947c837cb23c704e780c19ba3e6855dbbdd3"}, + {file = "regex-2024.4.28-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:19d6c11bf35a6ad077eb23852827f91c804eeb71ecb85db4ee1386825b9dc4db"}, + {file = "regex-2024.4.28-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:fdae0120cddc839eb8e3c15faa8ad541cc6d906d3eb24d82fb041cfe2807bc1e"}, + {file = "regex-2024.4.28-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:e672cf9caaf669053121f1766d659a8813bd547edef6e009205378faf45c67b8"}, + {file = "regex-2024.4.28-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:f57515750d07e14743db55d59759893fdb21d2668f39e549a7d6cad5d70f9fea"}, + {file = "regex-2024.4.28-cp310-cp310-win32.whl", hash = "sha256:a1409c4eccb6981c7baabc8888d3550df518add6e06fe74fa1d9312c1838652d"}, + {file = "regex-2024.4.28-cp310-cp310-win_amd64.whl", hash = "sha256:1f687a28640f763f23f8a9801fe9e1b37338bb1ca5d564ddd41619458f1f22d1"}, + {file = "regex-2024.4.28-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:84077821c85f222362b72fdc44f7a3a13587a013a45cf14534df1cbbdc9a6796"}, + {file = "regex-2024.4.28-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b45d4503de8f4f3dc02f1d28a9b039e5504a02cc18906cfe744c11def942e9eb"}, + {file = "regex-2024.4.28-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:457c2cd5a646dd4ed536c92b535d73548fb8e216ebee602aa9f48e068fc393f3"}, + {file = "regex-2024.4.28-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2b51739ddfd013c6f657b55a508de8b9ea78b56d22b236052c3a85a675102dc6"}, + {file = "regex-2024.4.28-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:459226445c7d7454981c4c0ce0ad1a72e1e751c3e417f305722bbcee6697e06a"}, + {file = "regex-2024.4.28-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:670fa596984b08a4a769491cbdf22350431970d0112e03d7e4eeaecaafcd0fec"}, + {file = "regex-2024.4.28-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fe00f4fe11c8a521b173e6324d862ee7ee3412bf7107570c9b564fe1119b56fb"}, + {file = "regex-2024.4.28-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:36f392dc7763fe7924575475736bddf9ab9f7a66b920932d0ea50c2ded2f5636"}, + {file = "regex-2024.4.28-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:23a412b7b1a7063f81a742463f38821097b6a37ce1e5b89dd8e871d14dbfd86b"}, + {file = "regex-2024.4.28-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:f1d6e4b7b2ae3a6a9df53efbf199e4bfcff0959dbdb5fd9ced34d4407348e39a"}, + {file = "regex-2024.4.28-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:499334ad139557de97cbc4347ee921c0e2b5e9c0f009859e74f3f77918339257"}, + {file = "regex-2024.4.28-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:0940038bec2fe9e26b203d636c44d31dd8766abc1fe66262da6484bd82461ccf"}, + {file = "regex-2024.4.28-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:66372c2a01782c5fe8e04bff4a2a0121a9897e19223d9eab30c54c50b2ebeb7f"}, + {file = "regex-2024.4.28-cp311-cp311-win32.whl", hash = "sha256:c77d10ec3c1cf328b2f501ca32583625987ea0f23a0c2a49b37a39ee5c4c4630"}, + {file = "regex-2024.4.28-cp311-cp311-win_amd64.whl", hash = "sha256:fc0916c4295c64d6890a46e02d4482bb5ccf33bf1a824c0eaa9e83b148291f90"}, + {file = "regex-2024.4.28-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:08a1749f04fee2811c7617fdd46d2e46d09106fa8f475c884b65c01326eb15c5"}, + {file = "regex-2024.4.28-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b8eb28995771c087a73338f695a08c9abfdf723d185e57b97f6175c5051ff1ae"}, + {file = "regex-2024.4.28-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:dd7ef715ccb8040954d44cfeff17e6b8e9f79c8019daae2fd30a8806ef5435c0"}, + {file = "regex-2024.4.28-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fb0315a2b26fde4005a7c401707c5352df274460f2f85b209cf6024271373013"}, + {file = "regex-2024.4.28-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f2fc053228a6bd3a17a9b0a3f15c3ab3cf95727b00557e92e1cfe094b88cc662"}, + {file = "regex-2024.4.28-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7fe9739a686dc44733d52d6e4f7b9c77b285e49edf8570754b322bca6b85b4cc"}, + {file = "regex-2024.4.28-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a74fcf77d979364f9b69fcf8200849ca29a374973dc193a7317698aa37d8b01c"}, + {file = "regex-2024.4.28-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:965fd0cf4694d76f6564896b422724ec7b959ef927a7cb187fc6b3f4e4f59833"}, + {file = "regex-2024.4.28-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:2fef0b38c34ae675fcbb1b5db760d40c3fc3612cfa186e9e50df5782cac02bcd"}, + {file = "regex-2024.4.28-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:bc365ce25f6c7c5ed70e4bc674f9137f52b7dd6a125037f9132a7be52b8a252f"}, + {file = "regex-2024.4.28-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:ac69b394764bb857429b031d29d9604842bc4cbfd964d764b1af1868eeebc4f0"}, + {file = "regex-2024.4.28-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:144a1fc54765f5c5c36d6d4b073299832aa1ec6a746a6452c3ee7b46b3d3b11d"}, + {file = "regex-2024.4.28-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:2630ca4e152c221072fd4a56d4622b5ada876f668ecd24d5ab62544ae6793ed6"}, + {file = "regex-2024.4.28-cp312-cp312-win32.whl", hash = "sha256:7f3502f03b4da52bbe8ba962621daa846f38489cae5c4a7b5d738f15f6443d17"}, + {file = "regex-2024.4.28-cp312-cp312-win_amd64.whl", hash = "sha256:0dd3f69098511e71880fb00f5815db9ed0ef62c05775395968299cb400aeab82"}, + {file = "regex-2024.4.28-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:374f690e1dd0dbdcddea4a5c9bdd97632cf656c69113f7cd6a361f2a67221cb6"}, + {file = "regex-2024.4.28-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:25f87ae6b96374db20f180eab083aafe419b194e96e4f282c40191e71980c666"}, + {file = "regex-2024.4.28-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:5dbc1bcc7413eebe5f18196e22804a3be1bfdfc7e2afd415e12c068624d48247"}, + {file = "regex-2024.4.28-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f85151ec5a232335f1be022b09fbbe459042ea1951d8a48fef251223fc67eee1"}, + {file = "regex-2024.4.28-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:57ba112e5530530fd175ed550373eb263db4ca98b5f00694d73b18b9a02e7185"}, + {file = "regex-2024.4.28-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:224803b74aab56aa7be313f92a8d9911dcade37e5f167db62a738d0c85fdac4b"}, + {file = "regex-2024.4.28-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0a54a047b607fd2d2d52a05e6ad294602f1e0dec2291152b745870afc47c1397"}, + {file = "regex-2024.4.28-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0a2a512d623f1f2d01d881513af9fc6a7c46e5cfffb7dc50c38ce959f9246c94"}, + {file = "regex-2024.4.28-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:c06bf3f38f0707592898428636cbb75d0a846651b053a1cf748763e3063a6925"}, + {file = "regex-2024.4.28-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:1031a5e7b048ee371ab3653aad3030ecfad6ee9ecdc85f0242c57751a05b0ac4"}, + {file = "regex-2024.4.28-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:d7a353ebfa7154c871a35caca7bfd8f9e18666829a1dc187115b80e35a29393e"}, + {file = "regex-2024.4.28-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:7e76b9cfbf5ced1aca15a0e5b6f229344d9b3123439ffce552b11faab0114a02"}, + {file = "regex-2024.4.28-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:5ce479ecc068bc2a74cb98dd8dba99e070d1b2f4a8371a7dfe631f85db70fe6e"}, + {file = "regex-2024.4.28-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:7d77b6f63f806578c604dca209280e4c54f0fa9a8128bb8d2cc5fb6f99da4150"}, + {file = "regex-2024.4.28-cp38-cp38-win32.whl", hash = "sha256:d84308f097d7a513359757c69707ad339da799e53b7393819ec2ea36bc4beb58"}, + {file = "regex-2024.4.28-cp38-cp38-win_amd64.whl", hash = "sha256:2cc1b87bba1dd1a898e664a31012725e48af826bf3971e786c53e32e02adae6c"}, + {file = "regex-2024.4.28-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:7413167c507a768eafb5424413c5b2f515c606be5bb4ef8c5dee43925aa5718b"}, + {file = "regex-2024.4.28-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:108e2dcf0b53a7c4ab8986842a8edcb8ab2e59919a74ff51c296772e8e74d0ae"}, + {file = "regex-2024.4.28-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:f1c5742c31ba7d72f2dedf7968998730664b45e38827637e0f04a2ac7de2f5f1"}, + {file = "regex-2024.4.28-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ecc6148228c9ae25ce403eade13a0961de1cb016bdb35c6eafd8e7b87ad028b1"}, + {file = "regex-2024.4.28-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b7d893c8cf0e2429b823ef1a1d360a25950ed11f0e2a9df2b5198821832e1947"}, + {file = "regex-2024.4.28-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4290035b169578ffbbfa50d904d26bec16a94526071ebec3dadbebf67a26b25e"}, + {file = "regex-2024.4.28-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:44a22ae1cfd82e4ffa2066eb3390777dc79468f866f0625261a93e44cdf6482b"}, + {file = "regex-2024.4.28-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fd24fd140b69f0b0bcc9165c397e9b2e89ecbeda83303abf2a072609f60239e2"}, + {file = "regex-2024.4.28-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:39fb166d2196413bead229cd64a2ffd6ec78ebab83fff7d2701103cf9f4dfd26"}, + {file = "regex-2024.4.28-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:9301cc6db4d83d2c0719f7fcda37229691745168bf6ae849bea2e85fc769175d"}, + {file = "regex-2024.4.28-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:7c3d389e8d76a49923683123730c33e9553063d9041658f23897f0b396b2386f"}, + {file = "regex-2024.4.28-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:99ef6289b62042500d581170d06e17f5353b111a15aa6b25b05b91c6886df8fc"}, + {file = "regex-2024.4.28-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:b91d529b47798c016d4b4c1d06cc826ac40d196da54f0de3c519f5a297c5076a"}, + {file = "regex-2024.4.28-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:43548ad74ea50456e1c68d3c67fff3de64c6edb85bcd511d1136f9b5376fc9d1"}, + {file = "regex-2024.4.28-cp39-cp39-win32.whl", hash = "sha256:05d9b6578a22db7dedb4df81451f360395828b04f4513980b6bd7a1412c679cc"}, + {file = "regex-2024.4.28-cp39-cp39-win_amd64.whl", hash = "sha256:3986217ec830c2109875be740531feb8ddafe0dfa49767cdcd072ed7e8927962"}, + {file = "regex-2024.4.28.tar.gz", hash = "sha256:83ab366777ea45d58f72593adf35d36ca911ea8bd838483c1823b883a121b0e4"}, ] [[package]] @@ -4014,13 +3990,13 @@ zstd = ["zstandard (>=0.18.0)"] [[package]] name = "virtualenv" -version = "20.26.0" +version = "20.26.1" description = "Virtual Python Environment builder" optional = true python-versions = ">=3.7" files = [ - {file = "virtualenv-20.26.0-py3-none-any.whl", hash = "sha256:0846377ea76e818daaa3e00a4365c018bc3ac9760cbb3544de542885aad61fb3"}, - {file = "virtualenv-20.26.0.tar.gz", hash = "sha256:ec25a9671a5102c8d2657f62792a27b48f016664c6873f6beed3800008577210"}, + {file = "virtualenv-20.26.1-py3-none-any.whl", hash = "sha256:7aa9982a728ae5892558bff6a2839c00b9ed145523ece2274fad6f414690ae75"}, + {file = "virtualenv-20.26.1.tar.gz", hash = "sha256:604bfdceaeece392802e6ae48e69cec49168b9c5f4a44e483963f9242eb0e78b"}, ] [package.dependencies] diff --git a/tests/test_available.py b/tests/test_available.py index 29f4f31e..b3d0cd78 100644 --- a/tests/test_available.py +++ b/tests/test_available.py @@ -4,9 +4,9 @@ import gymnasium as gym import pytest import lerobot -from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy +from lerobot.common.policies.act.modeling_act import ACTPolicy from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy -from lerobot.common.policies.tdmpc.policy import TDMPCPolicy +from lerobot.common.policies.tdmpc.modeling_tdmpc import TDMPCPolicy from tests.utils import require_env @@ -30,7 +30,7 @@ def test_available_policies(): consistent with those listed in `lerobot/__init__.py`. """ policy_classes = [ - ActionChunkingTransformerPolicy, + ACTPolicy, DiffusionPolicy, TDMPCPolicy, ] diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 0bef1f44..d44ad78a 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -35,7 +35,7 @@ def test_factory(env_name, repo_id, policy_name): DEFAULT_CONFIG_PATH, overrides=[ f"env={env_name}", - f"dataset.repo_id={repo_id}", + f"dataset_repo_id={repo_id}", f"policy={policy_name}", f"device={DEVICE}", ], diff --git a/tests/test_examples.py b/tests/test_examples.py index 9f86fd03..1fca45f5 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -39,14 +39,14 @@ def test_examples_3_and_2(): ("training_steps = 5000", "training_steps = 1"), ("num_workers=4", "num_workers=0"), ('device = torch.device("cuda")', 'device = torch.device("cpu")'), - ("batch_size=cfg.batch_size", "batch_size=1"), + ("batch_size=64", "batch_size=1"), ], ) # Pass empty globals to allow dictionary comprehension https://stackoverflow.com/a/32897127/4391249. exec(file_contents, {}) - for file_name in ["model.pt", "config.yaml"]: + for file_name in ["model.safetensors", "config.json", "config.yaml"]: assert Path(f"outputs/train/example_pusht_diffusion/{file_name}").exists() path = "examples/2_evaluate_pretrained_policy.py" @@ -58,15 +58,15 @@ def test_examples_3_and_2(): file_contents = _find_and_replace( file_contents, [ - ('"eval_episodes=10"', '"eval_episodes=1"'), - ('"rollout_batch_size=10"', '"rollout_batch_size=1"'), - ('"device=cuda"', '"device=cpu"'), + ('pretrained_policy_name = "lerobot/diffusion_policy_pusht_image"', ""), + ("pretrained_policy_path = Path(snapshot_download(pretrained_policy_name))", ""), ( - '# folder = Path("outputs/train/example_pusht_diffusion")', - 'folder = Path("outputs/train/example_pusht_diffusion")', + '# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")', + 'pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")', ), - ('hub_id = "lerobot/diffusion_policy_pusht_image"', ""), - ("folder = Path(snapshot_download(hub_id)", ""), + ('"eval.n_episodes=10"', '"eval.n_episodes=1"'), + ('"eval.batch_size=10"', '"eval.batch_size=1"'), + ('"device=cuda"', '"device=cpu"'), ], ) diff --git a/tests/test_policies.py b/tests/test_policies.py index e933ceaa..50f36a25 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -1,41 +1,50 @@ +import inspect + import pytest import torch +from huggingface_hub import PyTorchModelHubMixin +from lerobot import available_policies from lerobot.common.datasets.factory import make_dataset from lerobot.common.datasets.utils import cycle from lerobot.common.envs.factory import make_env from lerobot.common.envs.utils import postprocess_action, preprocess_observation -from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy -from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy -from lerobot.common.policies.factory import make_policy +from lerobot.common.policies.factory import get_policy_and_config_classes, make_policy from lerobot.common.policies.normalize import Normalize, Unnormalize from lerobot.common.policies.policy_protocol import Policy from lerobot.common.utils.utils import init_hydra_config from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_env +@pytest.mark.parametrize("policy_name", available_policies) +def test_get_policy_and_config_classes(policy_name: str): + """Check that the correct policy and config classes are returned.""" + policy_cls, config_cls = get_policy_and_config_classes(policy_name) + assert policy_cls.name == policy_name + assert issubclass(config_cls, inspect.signature(policy_cls.__init__).parameters["config"].annotation) + + # TODO(aliberts): refactor using lerobot/__init__.py variables @pytest.mark.parametrize( "env_name,policy_name,extra_overrides", [ - # ("xarm", "tdmpc", ["policy.mpc=true"]), - # ("pusht", "tdmpc", ["policy.mpc=false"]), + ("xarm", "tdmpc", ["policy.use_mpc=true", "dataset_repo_id=lerobot/xarm_lift_medium"]), ("pusht", "diffusion", []), - ("aloha", "act", ["env.task=AlohaInsertion-v0", "dataset.repo_id=lerobot/aloha_sim_insertion_human"]), + ("aloha", "act", ["env.task=AlohaInsertion-v0", "dataset_repo_id=lerobot/aloha_sim_insertion_human"]), ( "aloha", "act", - ["env.task=AlohaInsertion-v0", "dataset.repo_id=lerobot/aloha_sim_insertion_scripted"], + ["env.task=AlohaInsertion-v0", "dataset_repo_id=lerobot/aloha_sim_insertion_scripted"], ), ( "aloha", "act", - ["env.task=AlohaTransferCube-v0", "dataset.repo_id=lerobot/aloha_sim_transfer_cube_human"], + ["env.task=AlohaTransferCube-v0", "dataset_repo_id=lerobot/aloha_sim_transfer_cube_human"], ), ( "aloha", "act", - ["env.task=AlohaTransferCube-v0", "dataset.repo_id=lerobot/aloha_sim_transfer_cube_scripted"], + ["env.task=AlohaTransferCube-v0", "dataset_repo_id=lerobot/aloha_sim_transfer_cube_scripted"], ), ], ) @@ -44,7 +53,8 @@ def test_policy(env_name, policy_name, extra_overrides): """ Tests: - Making the policy object. - - Checking that the policy follows the correct protocol. + - Checking that the policy follows the correct protocol and subclasses nn.Module + and PyTorchModelHubMixin. - Updating the policy. - Using the policy to select actions at inference time. - Test the action can be applied to the policy @@ -61,11 +71,13 @@ def test_policy(env_name, policy_name, extra_overrides): # Check that we can make the policy object. dataset = make_dataset(cfg) - policy = make_policy(cfg, dataset_stats=dataset.stats) + policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats) # Check that the policy follows the required protocol. assert isinstance( policy, Policy ), f"The policy does not follow the required protocol. Please see {Policy.__module__}.{Policy.__name__}." + assert isinstance(policy, torch.nn.Module) + assert isinstance(policy, PyTorchModelHubMixin) # Check that we run select_actions and get the appropriate output. env = make_env(cfg, num_parallel_envs=2) @@ -86,7 +98,7 @@ def test_policy(env_name, policy_name, extra_overrides): batch[key] = batch[key].to(DEVICE, non_blocking=True) # Test updating the policy - policy.forward(batch, step=0) + policy.forward(batch) # reset the policy and environment policy.reset() @@ -100,7 +112,7 @@ def test_policy(env_name, policy_name, extra_overrides): # get the next action for the environment with torch.inference_mode(): - action = policy.select_action(observation, step=0) + action = policy.select_action(observation) # convert action to cpu numpy array action = postprocess_action(action) @@ -108,29 +120,25 @@ def test_policy(env_name, policy_name, extra_overrides): # Test step through policy env.step(action) - # Test load state_dict - if policy_name != "tdmpc": - # TODO(rcadene, alexander-soare): make it work for tdmpc - new_policy = make_policy(cfg) - new_policy.load_state_dict(policy.state_dict()) + +@pytest.mark.parametrize("policy_name", available_policies) +def test_policy_defaults(policy_name: str): + """Check that the policy can be instantiated with defaults.""" + policy_cls, _ = get_policy_and_config_classes(policy_name) + policy_cls() -@pytest.mark.parametrize("policy_cls", [DiffusionPolicy, ActionChunkingTransformerPolicy]) -def test_policy_defaults(policy_cls): - kwargs = {} - # TODO(alexander-soare): Remove this kwargs hack when we move the scheduler out of DP. - if policy_cls is DiffusionPolicy: - kwargs = {"lr_scheduler_num_training_steps": 1} - policy_cls(**kwargs) +@pytest.mark.parametrize("policy_name", available_policies) +def test_save_and_load_pretrained(policy_name: str): + policy_cls, _ = get_policy_and_config_classes(policy_name) + policy: Policy = policy_cls() + save_dir = "/tmp/test_save_and_load_pretrained_{policy_cls.__name__}" + policy.save_pretrained(save_dir) + policy_ = policy_cls.from_pretrained(save_dir) + assert all(torch.equal(p, p_) for p, p_ in zip(policy.parameters(), policy_.parameters(), strict=True)) -@pytest.mark.parametrize( - "insert_temporal_dim", - [ - False, - True, - ], -) +@pytest.mark.parametrize("insert_temporal_dim", [False, True]) def test_normalize(insert_temporal_dim): """ Test that normalize/unnormalize can run without exceptions when properly set up, and that they raise diff --git a/tests/test_visualize_dataset.py b/tests/test_visualize_dataset.py index 3ed22970..895ff9d9 100644 --- a/tests/test_visualize_dataset.py +++ b/tests/test_visualize_dataset.py @@ -20,7 +20,7 @@ def test_visualize_dataset(tmpdir, repo_id): overrides=[ "policy=act", "env=aloha", - f"dataset.repo_id={repo_id}", + f"dataset_repo_id={repo_id}", ], ) video_paths = visualize_dataset(cfg, out_dir=tmpdir)