Merge branch 'main' into user/rcadene/2024_04_21_load_from_video

This commit is contained in:
Remi 2024-05-02 20:42:25 +02:00 committed by GitHub
commit ae44510d50
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
32 changed files with 1886 additions and 2008 deletions

View File

@ -22,74 +22,82 @@ test-end-to-end:
${MAKE} test-act-ete-eval ${MAKE} test-act-ete-eval
${MAKE} test-diffusion-ete-train ${MAKE} test-diffusion-ete-train
${MAKE} test-diffusion-ete-eval ${MAKE} test-diffusion-ete-eval
# ${MAKE} test-tdmpc-ete-train ${MAKE} test-tdmpc-ete-train
# ${MAKE} test-tdmpc-ete-eval ${MAKE} test-tdmpc-ete-eval
${MAKE} test-default-ete-eval
test-act-ete-train: test-act-ete-train:
python lerobot/scripts/train.py \ python lerobot/scripts/train.py \
policy=act \ policy=act \
env=aloha \ env=aloha \
wandb.enable=False \ wandb.enable=False \
offline_steps=2 \ training.offline_steps=2 \
online_steps=0 \ training.online_steps=0 \
eval_episodes=1 \ eval.n_episodes=1 \
device=cpu \ device=cpu \
save_model=true \ training.save_model=true \
save_freq=2 \ training.save_freq=2 \
policy.n_action_steps=20 \ policy.n_action_steps=20 \
policy.chunk_size=20 \ policy.chunk_size=20 \
policy.batch_size=2 \ training.batch_size=2 \
hydra.run.dir=tests/outputs/act/ hydra.run.dir=tests/outputs/act/
test-act-ete-eval: test-act-ete-eval:
python lerobot/scripts/eval.py \ python lerobot/scripts/eval.py \
--config tests/outputs/act/.hydra/config.yaml \ -p tests/outputs/act/checkpoints/000002 \
eval_episodes=1 \ eval.n_episodes=1 \
env.episode_length=8 \ env.episode_length=8 \
device=cpu \ device=cpu \
policy.pretrained_model_path=tests/outputs/act/models/2.pt
test-diffusion-ete-train: test-diffusion-ete-train:
python lerobot/scripts/train.py \ python lerobot/scripts/train.py \
policy=diffusion \ policy=diffusion \
env=pusht \ env=pusht \
wandb.enable=False \ wandb.enable=False \
offline_steps=2 \ training.offline_steps=2 \
online_steps=0 \ training.online_steps=0 \
eval_episodes=1 \ eval.n_episodes=1 \
device=cpu \ device=cpu \
save_model=true \ training.save_model=true \
save_freq=2 \ training.save_freq=2 \
policy.batch_size=2 \ training.batch_size=2 \
hydra.run.dir=tests/outputs/diffusion/ hydra.run.dir=tests/outputs/diffusion/
test-diffusion-ete-eval: test-diffusion-ete-eval:
python lerobot/scripts/eval.py \ python lerobot/scripts/eval.py \
--config tests/outputs/diffusion/.hydra/config.yaml \ -p tests/outputs/diffusion/checkpoints/000002 \
eval_episodes=1 \ eval.n_episodes=1 \
env.episode_length=8 \ env.episode_length=8 \
device=cpu \ device=cpu \
policy.pretrained_model_path=tests/outputs/diffusion/models/2.pt
test-tdmpc-ete-train: test-tdmpc-ete-train:
python lerobot/scripts/train.py \ python lerobot/scripts/train.py \
policy=tdmpc \ policy=tdmpc \
env=xarm \ env=xarm \
env.task=XarmLift-v0 \
dataset_repo_id=lerobot/xarm_lift_medium_replay \
wandb.enable=False \ wandb.enable=False \
offline_steps=1 \ training.offline_steps=2 \
online_steps=2 \ training.online_steps=2 \
eval_episodes=1 \ eval.n_episodes=1 \
env.episode_length=2 \ env.episode_length=2 \
device=cpu \ device=cpu \
save_model=true \ training.save_model=true \
save_freq=2 \ training.save_freq=2 \
policy.batch_size=2 \ training.batch_size=2 \
hydra.run.dir=tests/outputs/tdmpc/ hydra.run.dir=tests/outputs/tdmpc/
test-tdmpc-ete-eval: test-tdmpc-ete-eval:
python lerobot/scripts/eval.py \ python lerobot/scripts/eval.py \
--config tests/outputs/tdmpc/.hydra/config.yaml \ -p tests/outputs/tdmpc/checkpoints/000002 \
eval_episodes=1 \ eval.n_episodes=1 \
env.episode_length=8 \
device=cpu \
test-default-ete-eval:
python lerobot/scripts/eval.py \
--config lerobot/configs/default.yaml \
eval.n_episodes=1 \
env.episode_length=8 \ env.episode_length=8 \
device=cpu \ device=cpu \
policy.pretrained_model_path=tests/outputs/tdmpc/models/2.pt

View File

@ -135,16 +135,16 @@ Check out [examples](./examples) to see how you can load a pretrained policy fro
Or you can achieve the same result by executing our script from the command line: Or you can achieve the same result by executing our script from the command line:
```bash ```bash
python lerobot/scripts/eval.py \ python lerobot/scripts/eval.py \
--hub-id lerobot/diffusion_policy_pusht_image \ -p lerobot/diffusion_policy_pusht_image \
eval_episodes=10 \ eval_episodes=10 \
hydra.run.dir=outputs/eval/example_hub hydra.run.dir=outputs/eval/example_hub
``` ```
After training your own policy, you can also re-evaluate the checkpoints with: After training your own policy, you can also re-evaluate the checkpoints with:
```bash ```bash
python lerobot/scripts/eval.py \ python lerobot/scripts/eval.py \
--config PATH/TO/FOLDER/config.yaml \ -p PATH/TO/TRAIN/OUTPUT/FOLDER \
policy.pretrained_model_path=PATH/TO/FOLDER/weights.pth \
eval_episodes=10 \ eval_episodes=10 \
hydra.run.dir=outputs/eval/example_dir hydra.run.dir=outputs/eval/example_dir
``` ```
@ -246,29 +246,22 @@ Once you have trained a policy you may upload it to the HuggingFace hub.
Firstly, make sure you have a model repository set up on the hub. The hub ID looks like HF_USER/REPO_NAME. Firstly, make sure you have a model repository set up on the hub. The hub ID looks like HF_USER/REPO_NAME.
Secondly, assuming you have trained a policy, you need: Secondly, assuming you have trained a policy, you need the following (which should all be in any of the subdirectories of `checkpoints` in your training output folder, if you've used the LeRobot training script):
- `config.yaml` which you can get from the `.hydra` directory of your training output folder. - `config.json`: A serialized version of the policy configuration (following the policy's dataclass config).
- `model.pt` which should be one of the saved models in the `models` directory of your training output folder (they won't be named `model.pt` but you will need to choose one). - `model.safetensors`: The `torch.nn.Module` parameters saved in [Hugging Face Safetensors](https://huggingface.co/docs/safetensors/index) format.
- `config.yaml`: This is the consolidated Hydra training configuration containing the policy, environment, and dataset configs. The policy configuration should match `config.json` exactly. The environment config is useful for anyone who wants to evaluate your policy. The dataset config just serves as a paper trail for reproducibility.
To upload these to the hub, prepare a folder with the following structure (you can use symlinks rather than copying): To upload these to the hub, run the following with a desired revision ID.
```
to_upload
├── config.yaml
└── model.pt
```
With the folder prepared, run the following with a desired revision ID.
```bash ```bash
huggingface-cli upload $HUB_ID to_upload --revision $REVISION_ID huggingface-cli upload $HUB_ID PATH/TO/OUTPUT/DIR --revision $REVISION_ID
``` ```
If you want this to be the default revision also run the following (don't worry, it won't upload the files again; it will just adjust the file pointers): If you want this to be the default revision also run the following (don't worry, it won't upload the files again; it will just adjust the file pointers):
```bash ```bash
huggingface-cli upload $HUB_ID to_upload huggingface-cli upload $HUB_ID PATH/TO/OUTPUT/DIR
``` ```
See `eval.py` for an example of how a user may use your policy. See `eval.py` for an example of how a user may use your policy.

View File

@ -7,32 +7,21 @@ from pathlib import Path
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from lerobot.common.utils.utils import init_hydra_config
from lerobot.scripts.eval import eval from lerobot.scripts.eval import eval
# Get a pretrained policy from the hub. # Get a pretrained policy from the hub.
# TODO(alexander-soare): This no longer works until we upload a new model that uses the current configs. pretrained_policy_name = "lerobot/diffusion_policy_pusht_image"
hub_id = "lerobot/diffusion_policy_pusht_image" pretrained_policy_path = Path(snapshot_download(pretrained_policy_name))
folder = Path(snapshot_download(hub_id))
# OR uncomment the following to evaluate a policy from the local outputs/train folder. # OR uncomment the following to evaluate a policy from the local outputs/train folder.
# folder = Path("outputs/train/example_pusht_diffusion") # pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")
config_path = folder / "config.yaml"
weights_path = folder / "model.pt"
# Override some config parameters to do with evaluation. # Override some config parameters to do with evaluation.
overrides = [ overrides = [
f"policy.pretrained_model_path={weights_path}", "eval.n_episodes=10",
"eval_episodes=10", "eval.batch_size=10",
"rollout_batch_size=10",
"device=cuda", "device=cuda",
] ]
# Create a Hydra config.
cfg = init_hydra_config(config_path, overrides)
# Evaluate the policy and save the outputs including metrics and videos. # Evaluate the policy and save the outputs including metrics and videos.
eval( # TODO(rcadene, alexander-soare): dont call eval, but add the minimal code snippet to rollout
cfg, eval(pretrained_policy_path=pretrained_policy_path)
out_dir=f"outputs/eval/example_{cfg.env.name}_{cfg.policy.name}",
)

View File

@ -34,19 +34,17 @@ dataset = make_dataset(hydra_cfg)
# If you're doing something different, you will likely need to change at least some of the defaults. # If you're doing something different, you will likely need to change at least some of the defaults.
cfg = DiffusionConfig() cfg = DiffusionConfig()
# TODO(alexander-soare): Remove LR scheduler from the policy. # TODO(alexander-soare): Remove LR scheduler from the policy.
policy = DiffusionPolicy(cfg, lr_scheduler_num_training_steps=training_steps, dataset_stats=dataset.stats) policy = DiffusionPolicy(cfg, dataset_stats=dataset.stats)
policy.train() policy.train()
policy.to(device) policy.to(device)
optimizer = torch.optim.Adam( optimizer = torch.optim.Adam(policy.parameters(), lr=1e-4)
policy.diffusion.parameters(), cfg.lr, cfg.adam_betas, cfg.adam_eps, cfg.adam_weight_decay
)
# Create dataloader for offline training. # Create dataloader for offline training.
dataloader = torch.utils.data.DataLoader( dataloader = torch.utils.data.DataLoader(
dataset, dataset,
num_workers=4, num_workers=4,
batch_size=cfg.batch_size, batch_size=64,
shuffle=True, shuffle=True,
pin_memory=device != torch.device("cpu"), pin_memory=device != torch.device("cpu"),
drop_last=True, drop_last=True,
@ -71,6 +69,7 @@ while not done:
done = True done = True
break break
# Save the policy and configuration for later use. # Save the policy.
policy.save(output_directory / "model.pt") policy.save_pretrained(output_directory)
# Save the Hydra configuration so we have the environment configuration for eval.
OmegaConf.save(hydra_cfg, output_directory / "config.yaml") OmegaConf.save(hydra_cfg, output_directory / "config.yaml")

View File

@ -14,12 +14,13 @@ def make_dataset(
cfg, cfg,
split="train", split="train",
): ):
if cfg.env.name not in cfg.dataset.repo_id: if cfg.env.name not in cfg.dataset_repo_id:
logging.warning( logging.warning(
f"There might be a mismatch between your training dataset ({cfg.dataset.repo_id=}) and your environment ({cfg.env.name=})." f"There might be a mismatch between your training dataset ({cfg.dataset_repo_id=}) and your "
f"environment ({cfg.env.name=})."
) )
delta_timestamps = cfg.policy.get("delta_timestamps") delta_timestamps = cfg.training.get("delta_timestamps")
if delta_timestamps is not None: if delta_timestamps is not None:
for key in delta_timestamps: for key in delta_timestamps:
if isinstance(delta_timestamps[key], str): if isinstance(delta_timestamps[key], str):
@ -28,7 +29,7 @@ def make_dataset(
# TODO(rcadene): add data augmentations # TODO(rcadene): add data augmentations
dataset = LeRobotDataset( dataset = LeRobotDataset(
cfg.dataset.repo_id, cfg.dataset_repo_id,
split=split, split=split,
root=DATA_DIR, root=DATA_DIR,
delta_timestamps=delta_timestamps, delta_timestamps=delta_timestamps,

View File

@ -5,9 +5,12 @@ import logging
import os import os
from pathlib import Path from pathlib import Path
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from omegaconf import OmegaConf from omegaconf import OmegaConf
from termcolor import colored from termcolor import colored
from lerobot.common.policies.policy_protocol import Policy
def log_output_dir(out_dir): def log_output_dir(out_dir):
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {out_dir}") logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {out_dir}")
@ -30,11 +33,11 @@ class Logger:
self._log_dir = Path(log_dir) self._log_dir = Path(log_dir)
self._log_dir.mkdir(parents=True, exist_ok=True) self._log_dir.mkdir(parents=True, exist_ok=True)
self._job_name = job_name self._job_name = job_name
self._model_dir = self._log_dir / "models" self._model_dir = self._log_dir / "checkpoints"
self._buffer_dir = self._log_dir / "buffers" self._buffer_dir = self._log_dir / "buffers"
self._save_model = cfg.save_model self._save_model = cfg.training.save_model
self._disable_wandb_artifact = cfg.wandb.disable_artifact self._disable_wandb_artifact = cfg.wandb.disable_artifact
self._save_buffer = cfg.save_buffer self._save_buffer = cfg.training.get("save_buffer", False)
self._group = cfg_to_group(cfg) self._group = cfg_to_group(cfg)
self._seed = cfg.seed self._seed = cfg.seed
self._cfg = cfg self._cfg = cfg
@ -70,18 +73,20 @@ class Logger:
logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}") logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
self._wandb = wandb self._wandb = wandb
def save_model(self, policy, identifier): def save_model(self, policy: Policy, identifier):
if self._save_model: if self._save_model:
self._model_dir.mkdir(parents=True, exist_ok=True) self._model_dir.mkdir(parents=True, exist_ok=True)
fp = self._model_dir / f"{str(identifier)}.pt" save_dir = self._model_dir / str(identifier)
policy.save(fp) policy.save_pretrained(save_dir)
# Also save the full Hydra config for the env configuration.
OmegaConf.save(self._cfg, save_dir / "config.yaml")
if self._wandb and not self._disable_wandb_artifact: if self._wandb and not self._disable_wandb_artifact:
# note wandb artifact does not accept ":" in its name # note wandb artifact does not accept ":" in its name
artifact = self._wandb.Artifact( artifact = self._wandb.Artifact(
self._group.replace(":", "_") + "-" + str(self._seed) + "-" + str(identifier), self._group.replace(":", "_") + "-" + str(self._seed) + "-" + str(identifier),
type="model", type="model",
) )
artifact.add_file(fp) artifact.add_file(save_dir / SAFETENSORS_SINGLE_FILE)
self._wandb.log_artifact(artifact) self._wandb.log_artifact(artifact)
def save_buffer(self, buffer, identifier): def save_buffer(self, buffer, identifier):

View File

@ -2,7 +2,7 @@ from dataclasses import dataclass, field
@dataclass @dataclass
class ActionChunkingTransformerConfig: class ACTConfig:
"""Configuration class for the Action Chunking Transformers policy. """Configuration class for the Action Chunking Transformers policy.
Defaults are configured for training on bimanual Aloha tasks like "insertion" or "transfer". Defaults are configured for training on bimanual Aloha tasks like "insertion" or "transfer".
@ -22,23 +22,24 @@ class ActionChunkingTransformerConfig:
The key represents the input data name, and the value is a list indicating the dimensions The key represents the input data name, and the value is a list indicating the dimensions
of the corresponding data. For example, "observation.images.top" refers to an input from the of the corresponding data. For example, "observation.images.top" refers to an input from the
"top" camera with dimensions [3, 96, 96], indicating it has three color channels and 96x96 resolution. "top" camera with dimensions [3, 96, 96], indicating it has three color channels and 96x96 resolution.
Importantly, shapes doesnt include batch dimension or temporal dimension. Importantly, shapes doesn't include batch dimension or temporal dimension.
output_shapes: A dictionary defining the shapes of the output data for the policy. output_shapes: A dictionary defining the shapes of the output data for the policy.
The key represents the output data name, and the value is a list indicating the dimensions The key represents the output data name, and the value is a list indicating the dimensions
of the corresponding data. For example, "action" refers to an output shape of [14], indicating of the corresponding data. For example, "action" refers to an output shape of [14], indicating
14-dimensional actions. Importantly, shapes doesnt include batch dimension or temporal dimension. 14-dimensional actions. Importantly, shapes doesn't include batch dimension or temporal dimension.
normalize_input_modes: A dictionary with key represents the modality (e.g. "observation.state"), input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
and the value specifies the normalization mode to apply. The two availables and the value specifies the normalization mode to apply. The two available modes are "mean_std"
modes are "mean_std" which substracts the mean and divide by the standard which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
deviation and "min_max" which rescale in a [-1, 1] range. [-1, 1] range.
unnormalize_output_modes: Similar dictionary as `normalize_input_modes`, but to unormalize in original scale. output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the
original scale. Note that this is also used for normalizing the training targets.
vision_backbone: Name of the torchvision resnet backbone to use for encoding images. vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
pretrained_backbone_weights: Pretrained weights from torchvision to initalize the backbone. pretrained_backbone_weights: Pretrained weights from torchvision to initalize the backbone.
`None` means no pretrained weights. `None` means no pretrained weights.
replace_final_stride_with_dilation: Whether to replace the ResNet's final 2x2 stride with a dilated replace_final_stride_with_dilation: Whether to replace the ResNet's final 2x2 stride with a dilated
convolution. convolution.
pre_norm: Whether to use "pre-norm" in the transformer blocks. pre_norm: Whether to use "pre-norm" in the transformer blocks.
d_model: The transformer blocks' main hidden dimension. dim_model: The transformer blocks' main hidden dimension.
n_heads: The number of heads to use in the transformer blocks' multi-head attention. n_heads: The number of heads to use in the transformer blocks' multi-head attention.
dim_feedforward: The dimension to expand the transformer's hidden dimension to in the feed-forward dim_feedforward: The dimension to expand the transformer's hidden dimension to in the feed-forward
layers. layers.
@ -62,13 +63,13 @@ class ActionChunkingTransformerConfig:
chunk_size: int = 100 chunk_size: int = 100
n_action_steps: int = 100 n_action_steps: int = 100
input_shapes: dict[str, list[str]] = field( input_shapes: dict[str, list[int]] = field(
default_factory=lambda: { default_factory=lambda: {
"observation.images.top": [3, 480, 640], "observation.images.top": [3, 480, 640],
"observation.state": [14], "observation.state": [14],
} }
) )
output_shapes: dict[str, list[str]] = field( output_shapes: dict[str, list[int]] = field(
default_factory=lambda: { default_factory=lambda: {
"action": [14], "action": [14],
} }
@ -94,7 +95,7 @@ class ActionChunkingTransformerConfig:
replace_final_stride_with_dilation: int = False replace_final_stride_with_dilation: int = False
# Transformer layers. # Transformer layers.
pre_norm: bool = False pre_norm: bool = False
d_model: int = 512 dim_model: int = 512
n_heads: int = 8 n_heads: int = 8
dim_feedforward: int = 3200 dim_feedforward: int = 3200
feedforward_activation: str = "relu" feedforward_activation: str = "relu"
@ -112,15 +113,6 @@ class ActionChunkingTransformerConfig:
dropout: float = 0.1 dropout: float = 0.1
kl_weight: float = 10.0 kl_weight: float = 10.0
# ---
# TODO(alexander-soare): Remove these from the policy config.
batch_size: int = 8
lr: float = 1e-5
lr_backbone: float = 1e-5
weight_decay: float = 1e-4
grad_clip_norm: float = 10
utd: int = 1
def __post_init__(self): def __post_init__(self):
"""Input validation (not exhaustive).""" """Input validation (not exhaustive)."""
if not self.vision_backbone.startswith("resnet"): if not self.vision_backbone.startswith("resnet"):

View File

@ -14,18 +14,124 @@ import numpy as np
import torch import torch
import torch.nn.functional as F # noqa: N812 import torch.nn.functional as F # noqa: N812
import torchvision import torchvision
from huggingface_hub import PyTorchModelHubMixin
from torch import Tensor, nn from torch import Tensor, nn
from torchvision.models._utils import IntermediateLayerGetter from torchvision.models._utils import IntermediateLayerGetter
from torchvision.ops.misc import FrozenBatchNorm2d from torchvision.ops.misc import FrozenBatchNorm2d
from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig from lerobot.common.policies.act.configuration_act import ACTConfig
from lerobot.common.policies.normalize import Normalize, Unnormalize from lerobot.common.policies.normalize import Normalize, Unnormalize
class ActionChunkingTransformerPolicy(nn.Module): class ACTPolicy(nn.Module, PyTorchModelHubMixin):
""" """
Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost
Hardware (paper: https://arxiv.org/abs/2304.13705, code: https://github.com/tonyzhaozh/act) Hardware (paper: https://arxiv.org/abs/2304.13705, code: https://github.com/tonyzhaozh/act)
"""
name = "act"
def __init__(
self,
config: ACTConfig | None = None,
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
):
"""
Args:
config: Policy configuration class instance or None, in which case the default instantiation of
the configuration class is used.
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
that they will be passed with a call to `load_state_dict` before the policy is used.
"""
super().__init__()
if config is None:
config = ACTConfig()
self.config = config
self.normalize_inputs = Normalize(
config.input_shapes, config.input_normalization_modes, dataset_stats
)
self.normalize_targets = Normalize(
config.output_shapes, config.output_normalization_modes, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_shapes, config.output_normalization_modes, dataset_stats
)
self.model = ACT(config)
def reset(self):
"""This should be called whenever the environment is reset."""
if self.config.n_action_steps is not None:
self._action_queue = deque([], maxlen=self.config.n_action_steps)
@torch.no_grad
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a single action given environment observations.
This method wraps `select_actions` in order to return one action at a time for execution in the
environment. It works by managing the actions in a queue and only calling `select_actions` when the
queue is empty.
"""
self.eval()
batch = self.normalize_inputs(batch)
self._stack_images(batch)
if len(self._action_queue) == 0:
# `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
# effectively has shape (n_action_steps, batch_size, *), hence the transpose.
actions = self.model(batch)[0][: self.config.n_action_steps]
# TODO(rcadene): make _forward return output dictionary?
actions = self.unnormalize_outputs({"action": actions})["action"]
self._action_queue.extend(actions.transpose(0, 1))
return self._action_queue.popleft()
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
"""Run the batch through the model and compute the loss for training or validation."""
batch = self.normalize_inputs(batch)
batch = self.normalize_targets(batch)
self._stack_images(batch)
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
l1_loss = (
F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
).mean()
loss_dict = {"l1_loss": l1_loss}
if self.config.use_vae:
# Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
# each dimension independently, we sum over the latent dimension to get the total
# KL-divergence per batch element, then take the mean over the batch.
# (See App. B of https://arxiv.org/abs/1312.6114 for more details).
mean_kld = (
(-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
)
loss_dict["kld_loss"] = mean_kld
loss_dict["loss"] = l1_loss + mean_kld * self.config.kl_weight
else:
loss_dict["loss"] = l1_loss
return loss_dict
def _stack_images(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
"""Stacks all the images in a batch and puts them in a new key: "observation.images".
This function expects `batch` to have (at least):
{
"observation.state": (B, state_dim) batch of robot states.
"observation.images.{name}": (B, C, H, W) tensor of images.
}
"""
# Stack images in the order dictated by input_shapes.
batch["observation.images"] = torch.stack(
[batch[k] for k in self.config.input_shapes if k.startswith("observation.images.")],
dim=-4,
)
class ACT(nn.Module):
"""Action Chunking Transformer: The underlying neural network for ACTPolicy.
Note: In this code we use the terms `vae_encoder`, 'encoder', `decoder`. The meanings are as follows. Note: In this code we use the terms `vae_encoder`, 'encoder', `decoder`. The meanings are as follows.
- The `vae_encoder` is, as per the literature around variational auto-encoders (VAE), the part of the - The `vae_encoder` is, as per the literature around variational auto-encoders (VAE), the part of the
@ -59,51 +165,36 @@ class ActionChunkingTransformerPolicy(nn.Module):
""" """
name = "act" def __init__(self, config: ACTConfig):
def __init__(self, cfg: ActionChunkingTransformerConfig | None = None, dataset_stats=None):
"""
Args:
cfg: Policy configuration class instance or None, in which case the default instantiation of the
configuration class is used.
"""
super().__init__() super().__init__()
if cfg is None: self.config = config
cfg = ActionChunkingTransformerConfig()
self.cfg = cfg
self.normalize_inputs = Normalize(cfg.input_shapes, cfg.input_normalization_modes, dataset_stats)
self.normalize_targets = Normalize(cfg.output_shapes, cfg.output_normalization_modes, dataset_stats)
self.unnormalize_outputs = Unnormalize(
cfg.output_shapes, cfg.output_normalization_modes, dataset_stats
)
# BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence]. # BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence].
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]). # The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
if self.cfg.use_vae: if self.config.use_vae:
self.vae_encoder = _TransformerEncoder(cfg) self.vae_encoder = ACTEncoder(config)
self.vae_encoder_cls_embed = nn.Embedding(1, cfg.d_model) self.vae_encoder_cls_embed = nn.Embedding(1, config.dim_model)
# Projection layer for joint-space configuration to hidden dimension. # Projection layer for joint-space configuration to hidden dimension.
self.vae_encoder_robot_state_input_proj = nn.Linear( self.vae_encoder_robot_state_input_proj = nn.Linear(
cfg.input_shapes["observation.state"][0], cfg.d_model config.input_shapes["observation.state"][0], config.dim_model
) )
# Projection layer for action (joint-space target) to hidden dimension. # Projection layer for action (joint-space target) to hidden dimension.
self.vae_encoder_action_input_proj = nn.Linear( self.vae_encoder_action_input_proj = nn.Linear(
cfg.input_shapes["observation.state"][0], cfg.d_model config.input_shapes["observation.state"][0], config.dim_model
) )
self.latent_dim = cfg.latent_dim self.latent_dim = config.latent_dim
# Projection layer from the VAE encoder's output to the latent distribution's parameter space. # Projection layer from the VAE encoder's output to the latent distribution's parameter space.
self.vae_encoder_latent_output_proj = nn.Linear(cfg.d_model, self.latent_dim * 2) self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, self.latent_dim * 2)
# Fixed sinusoidal positional embedding the whole input to the VAE encoder. Unsqueeze for batch # Fixed sinusoidal positional embedding the whole input to the VAE encoder. Unsqueeze for batch
# dimension. # dimension.
self.register_buffer( self.register_buffer(
"vae_encoder_pos_enc", "vae_encoder_pos_enc",
_create_sinusoidal_position_embedding(1 + 1 + cfg.chunk_size, cfg.d_model).unsqueeze(0), create_sinusoidal_pos_embedding(1 + 1 + config.chunk_size, config.dim_model).unsqueeze(0),
) )
# Backbone for image feature extraction. # Backbone for image feature extraction.
backbone_model = getattr(torchvision.models, cfg.vision_backbone)( backbone_model = getattr(torchvision.models, config.vision_backbone)(
replace_stride_with_dilation=[False, False, cfg.replace_final_stride_with_dilation], replace_stride_with_dilation=[False, False, config.replace_final_stride_with_dilation],
weights=cfg.pretrained_backbone_weights, weights=config.pretrained_backbone_weights,
norm_layer=FrozenBatchNorm2d, norm_layer=FrozenBatchNorm2d,
) )
# Note: The assumption here is that we are using a ResNet model (and hence layer4 is the final feature # Note: The assumption here is that we are using a ResNet model (and hence layer4 is the final feature
@ -112,26 +203,28 @@ class ActionChunkingTransformerPolicy(nn.Module):
self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"}) self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"})
# Transformer (acts as VAE decoder when training with the variational objective). # Transformer (acts as VAE decoder when training with the variational objective).
self.encoder = _TransformerEncoder(cfg) self.encoder = ACTEncoder(config)
self.decoder = _TransformerDecoder(cfg) self.decoder = ACTDecoder(config)
# Transformer encoder input projections. The tokens will be structured like # Transformer encoder input projections. The tokens will be structured like
# [latent, robot_state, image_feature_map_pixels]. # [latent, robot_state, image_feature_map_pixels].
self.encoder_robot_state_input_proj = nn.Linear(cfg.input_shapes["observation.state"][0], cfg.d_model) self.encoder_robot_state_input_proj = nn.Linear(
self.encoder_latent_input_proj = nn.Linear(self.latent_dim, cfg.d_model) config.input_shapes["observation.state"][0], config.dim_model
)
self.encoder_latent_input_proj = nn.Linear(self.latent_dim, config.dim_model)
self.encoder_img_feat_input_proj = nn.Conv2d( self.encoder_img_feat_input_proj = nn.Conv2d(
backbone_model.fc.in_features, cfg.d_model, kernel_size=1 backbone_model.fc.in_features, config.dim_model, kernel_size=1
) )
# Transformer encoder positional embeddings. # Transformer encoder positional embeddings.
self.encoder_robot_and_latent_pos_embed = nn.Embedding(2, cfg.d_model) self.encoder_robot_and_latent_pos_embed = nn.Embedding(2, config.dim_model)
self.encoder_cam_feat_pos_embed = _SinusoidalPositionEmbedding2D(cfg.d_model // 2) self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)
# Transformer decoder. # Transformer decoder.
# Learnable positional embedding for the transformer's decoder (in the style of DETR object queries). # Learnable positional embedding for the transformer's decoder (in the style of DETR object queries).
self.decoder_pos_embed = nn.Embedding(cfg.chunk_size, cfg.d_model) self.decoder_pos_embed = nn.Embedding(config.chunk_size, config.dim_model)
# Final action regression head on the output of the transformer's decoder. # Final action regression head on the output of the transformer's decoder.
self.action_head = nn.Linear(cfg.d_model, cfg.output_shapes["action"][0]) self.action_head = nn.Linear(config.dim_model, config.output_shapes["action"][0])
self._reset_parameters() self._reset_parameters()
@ -141,76 +234,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
if p.dim() > 1: if p.dim() > 1:
nn.init.xavier_uniform_(p) nn.init.xavier_uniform_(p)
def reset(self): def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tensor] | tuple[None, None]]:
"""This should be called whenever the environment is reset."""
if self.cfg.n_action_steps is not None:
self._action_queue = deque([], maxlen=self.cfg.n_action_steps)
@torch.no_grad
def select_action(self, batch: dict[str, Tensor], **_) -> Tensor:
"""Select a single action given environment observations.
This method wraps `select_actions` in order to return one action at a time for execution in the
environment. It works by managing the actions in a queue and only calling `select_actions` when the
queue is empty.
"""
self.eval()
batch = self.normalize_inputs(batch)
if len(self._action_queue) == 0:
# `_forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue effectively
# has shape (n_action_steps, batch_size, *), hence the transpose.
actions = self._forward(batch)[0][: self.cfg.n_action_steps]
# TODO(rcadene): make _forward return output dictionary?
actions = self.unnormalize_outputs({"action": actions})["action"]
self._action_queue.extend(actions.transpose(0, 1))
return self._action_queue.popleft()
def forward(self, batch, **_) -> dict[str, Tensor]:
"""Run the batch through the model and compute the loss for training or validation."""
batch = self.normalize_inputs(batch)
batch = self.normalize_targets(batch)
actions_hat, (mu_hat, log_sigma_x2_hat) = self._forward(batch)
l1_loss = (
F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
).mean()
loss_dict = {"l1_loss": l1_loss}
if self.cfg.use_vae:
# Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
# each dimension independently, we sum over the latent dimension to get the total
# KL-divergence per batch element, then take the mean over the batch.
# (See App. B of https://arxiv.org/abs/1312.6114 for more details).
mean_kld = (
(-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
)
loss_dict["kld_loss"] = mean_kld
loss_dict["loss"] = l1_loss + mean_kld * self.cfg.kl_weight
else:
loss_dict["loss"] = l1_loss
return loss_dict
def _stack_images(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
"""Stacks all the images in a batch and puts them in a new key: "observation.images".
This function expects `batch` to have (at least):
{
"observation.state": (B, state_dim) batch of robot states.
"observation.images.{name}": (B, C, H, W) tensor of images.
}
"""
# Stack images in the order dictated by input_shapes.
batch["observation.images"] = torch.stack(
[batch[k] for k in self.cfg.input_shapes if k.startswith("observation.images.")],
dim=-4,
)
def _forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tensor] | tuple[None, None]]:
"""A forward pass through the Action Chunking Transformer (with optional VAE encoder). """A forward pass through the Action Chunking Transformer (with optional VAE encoder).
`batch` should have the following structure: `batch` should have the following structure:
@ -226,17 +250,15 @@ class ActionChunkingTransformerPolicy(nn.Module):
Tuple containing the latent PDF's parameters (mean, log(σ²)) both as (B, L) tensors where L is the Tuple containing the latent PDF's parameters (mean, log(σ²)) both as (B, L) tensors where L is the
latent dimension. latent dimension.
""" """
if self.cfg.use_vae and self.training: if self.config.use_vae and self.training:
assert ( assert (
"action" in batch "action" in batch
), "actions must be provided when using the variational objective in training mode." ), "actions must be provided when using the variational objective in training mode."
self._stack_images(batch)
batch_size = batch["observation.state"].shape[0] batch_size = batch["observation.state"].shape[0]
# Prepare the latent for input to the transformer encoder. # Prepare the latent for input to the transformer encoder.
if self.cfg.use_vae and "action" in batch: if self.config.use_vae and "action" in batch:
# Prepare the input to the VAE encoder: [cls, *joint_space_configuration, *action_sequence]. # Prepare the input to the VAE encoder: [cls, *joint_space_configuration, *action_sequence].
cls_embed = einops.repeat( cls_embed = einops.repeat(
self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
@ -306,7 +328,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
# Forward pass through the transformer modules. # Forward pass through the transformer modules.
encoder_out = self.encoder(encoder_in, pos_embed=pos_embed) encoder_out = self.encoder(encoder_in, pos_embed=pos_embed)
decoder_in = torch.zeros( decoder_in = torch.zeros(
(self.cfg.chunk_size, batch_size, self.cfg.d_model), (self.config.chunk_size, batch_size, self.config.dim_model),
dtype=pos_embed.dtype, dtype=pos_embed.dtype,
device=pos_embed.device, device=pos_embed.device,
) )
@ -324,21 +346,14 @@ class ActionChunkingTransformerPolicy(nn.Module):
return actions, (mu, log_sigma_x2) return actions, (mu, log_sigma_x2)
def save(self, fp):
torch.save(self.state_dict(), fp)
def load(self, fp): class ACTEncoder(nn.Module):
d = torch.load(fp)
self.load_state_dict(d)
class _TransformerEncoder(nn.Module):
"""Convenience module for running multiple encoder layers, maybe followed by normalization.""" """Convenience module for running multiple encoder layers, maybe followed by normalization."""
def __init__(self, cfg: ActionChunkingTransformerConfig): def __init__(self, config: ACTConfig):
super().__init__() super().__init__()
self.layers = nn.ModuleList([_TransformerEncoderLayer(cfg) for _ in range(cfg.n_encoder_layers)]) self.layers = nn.ModuleList([ACTEncoderLayer(config) for _ in range(config.n_encoder_layers)])
self.norm = nn.LayerNorm(cfg.d_model) if cfg.pre_norm else nn.Identity() self.norm = nn.LayerNorm(config.dim_model) if config.pre_norm else nn.Identity()
def forward(self, x: Tensor, pos_embed: Tensor | None = None) -> Tensor: def forward(self, x: Tensor, pos_embed: Tensor | None = None) -> Tensor:
for layer in self.layers: for layer in self.layers:
@ -347,23 +362,23 @@ class _TransformerEncoder(nn.Module):
return x return x
class _TransformerEncoderLayer(nn.Module): class ACTEncoderLayer(nn.Module):
def __init__(self, cfg: ActionChunkingTransformerConfig): def __init__(self, config: ACTConfig):
super().__init__() super().__init__()
self.self_attn = nn.MultiheadAttention(cfg.d_model, cfg.n_heads, dropout=cfg.dropout) self.self_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout)
# Feed forward layers. # Feed forward layers.
self.linear1 = nn.Linear(cfg.d_model, cfg.dim_feedforward) self.linear1 = nn.Linear(config.dim_model, config.dim_feedforward)
self.dropout = nn.Dropout(cfg.dropout) self.dropout = nn.Dropout(config.dropout)
self.linear2 = nn.Linear(cfg.dim_feedforward, cfg.d_model) self.linear2 = nn.Linear(config.dim_feedforward, config.dim_model)
self.norm1 = nn.LayerNorm(cfg.d_model) self.norm1 = nn.LayerNorm(config.dim_model)
self.norm2 = nn.LayerNorm(cfg.d_model) self.norm2 = nn.LayerNorm(config.dim_model)
self.dropout1 = nn.Dropout(cfg.dropout) self.dropout1 = nn.Dropout(config.dropout)
self.dropout2 = nn.Dropout(cfg.dropout) self.dropout2 = nn.Dropout(config.dropout)
self.activation = _get_activation_fn(cfg.feedforward_activation) self.activation = get_activation_fn(config.feedforward_activation)
self.pre_norm = cfg.pre_norm self.pre_norm = config.pre_norm
def forward(self, x, pos_embed: Tensor | None = None) -> Tensor: def forward(self, x, pos_embed: Tensor | None = None) -> Tensor:
skip = x skip = x
@ -385,12 +400,12 @@ class _TransformerEncoderLayer(nn.Module):
return x return x
class _TransformerDecoder(nn.Module): class ACTDecoder(nn.Module):
def __init__(self, cfg: ActionChunkingTransformerConfig): def __init__(self, config: ACTConfig):
"""Convenience module for running multiple decoder layers followed by normalization.""" """Convenience module for running multiple decoder layers followed by normalization."""
super().__init__() super().__init__()
self.layers = nn.ModuleList([_TransformerDecoderLayer(cfg) for _ in range(cfg.n_decoder_layers)]) self.layers = nn.ModuleList([ACTDecoderLayer(config) for _ in range(config.n_decoder_layers)])
self.norm = nn.LayerNorm(cfg.d_model) self.norm = nn.LayerNorm(config.dim_model)
def forward( def forward(
self, self,
@ -408,26 +423,26 @@ class _TransformerDecoder(nn.Module):
return x return x
class _TransformerDecoderLayer(nn.Module): class ACTDecoderLayer(nn.Module):
def __init__(self, cfg: ActionChunkingTransformerConfig): def __init__(self, config: ACTConfig):
super().__init__() super().__init__()
self.self_attn = nn.MultiheadAttention(cfg.d_model, cfg.n_heads, dropout=cfg.dropout) self.self_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout)
self.multihead_attn = nn.MultiheadAttention(cfg.d_model, cfg.n_heads, dropout=cfg.dropout) self.multihead_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout)
# Feed forward layers. # Feed forward layers.
self.linear1 = nn.Linear(cfg.d_model, cfg.dim_feedforward) self.linear1 = nn.Linear(config.dim_model, config.dim_feedforward)
self.dropout = nn.Dropout(cfg.dropout) self.dropout = nn.Dropout(config.dropout)
self.linear2 = nn.Linear(cfg.dim_feedforward, cfg.d_model) self.linear2 = nn.Linear(config.dim_feedforward, config.dim_model)
self.norm1 = nn.LayerNorm(cfg.d_model) self.norm1 = nn.LayerNorm(config.dim_model)
self.norm2 = nn.LayerNorm(cfg.d_model) self.norm2 = nn.LayerNorm(config.dim_model)
self.norm3 = nn.LayerNorm(cfg.d_model) self.norm3 = nn.LayerNorm(config.dim_model)
self.dropout1 = nn.Dropout(cfg.dropout) self.dropout1 = nn.Dropout(config.dropout)
self.dropout2 = nn.Dropout(cfg.dropout) self.dropout2 = nn.Dropout(config.dropout)
self.dropout3 = nn.Dropout(cfg.dropout) self.dropout3 = nn.Dropout(config.dropout)
self.activation = _get_activation_fn(cfg.feedforward_activation) self.activation = get_activation_fn(config.feedforward_activation)
self.pre_norm = cfg.pre_norm self.pre_norm = config.pre_norm
def maybe_add_pos_embed(self, tensor: Tensor, pos_embed: Tensor | None) -> Tensor: def maybe_add_pos_embed(self, tensor: Tensor, pos_embed: Tensor | None) -> Tensor:
return tensor if pos_embed is None else tensor + pos_embed return tensor if pos_embed is None else tensor + pos_embed
@ -480,7 +495,7 @@ class _TransformerDecoderLayer(nn.Module):
return x return x
def _create_sinusoidal_position_embedding(num_positions: int, dimension: int) -> Tensor: def create_sinusoidal_pos_embedding(num_positions: int, dimension: int) -> Tensor:
"""1D sinusoidal positional embeddings as in Attention is All You Need. """1D sinusoidal positional embeddings as in Attention is All You Need.
Args: Args:
@ -498,7 +513,7 @@ def _create_sinusoidal_position_embedding(num_positions: int, dimension: int) ->
return torch.from_numpy(sinusoid_table).float() return torch.from_numpy(sinusoid_table).float()
class _SinusoidalPositionEmbedding2D(nn.Module): class ACTSinusoidalPositionEmbedding2d(nn.Module):
"""2D sinusoidal positional embeddings similar to what's presented in Attention Is All You Need. """2D sinusoidal positional embeddings similar to what's presented in Attention Is All You Need.
The variation is that the position indices are normalized in [0, 2π] (not quite: the lower bound is 1/H The variation is that the position indices are normalized in [0, 2π] (not quite: the lower bound is 1/H
@ -552,7 +567,7 @@ class _SinusoidalPositionEmbedding2D(nn.Module):
return pos_embed return pos_embed
def _get_activation_fn(activation: str) -> Callable: def get_activation_fn(activation: str) -> Callable:
"""Return an activation function given a string.""" """Return an activation function given a string."""
if activation == "relu": if activation == "relu":
return F.relu return F.relu

View File

@ -3,7 +3,7 @@ from dataclasses import dataclass, field
@dataclass @dataclass
class DiffusionConfig: class DiffusionConfig:
"""Configuration class for Diffusion Policy. """Configuration class for DiffusionPolicy.
Defaults are configured for training with PushT providing proprioceptive and single camera observations. Defaults are configured for training with PushT providing proprioceptive and single camera observations.
@ -25,11 +25,12 @@ class DiffusionConfig:
The key represents the output data name, and the value is a list indicating the dimensions The key represents the output data name, and the value is a list indicating the dimensions
of the corresponding data. For example, "action" refers to an output shape of [14], indicating of the corresponding data. For example, "action" refers to an output shape of [14], indicating
14-dimensional actions. Importantly, shapes doesnt include batch dimension or temporal dimension. 14-dimensional actions. Importantly, shapes doesnt include batch dimension or temporal dimension.
normalize_input_modes: A dictionary with key represents the modality (e.g. "observation.state"), input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
and the value specifies the normalization mode to apply. The two availables and the value specifies the normalization mode to apply. The two available modes are "mean_std"
modes are "mean_std" which substracts the mean and divide by the standard which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
deviation and "min_max" which rescale in a [-1, 1] range. [-1, 1] range.
unnormalize_output_modes: Similar dictionary as `normalize_input_modes`, but to unormalize in original scale. output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the
original scale. Note that this is also used for normalizing the training targets.
vision_backbone: Name of the torchvision resnet backbone to use for encoding images. vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit
within the image size. If None, no cropping is done. within the image size. If None, no cropping is done.
@ -70,13 +71,13 @@ class DiffusionConfig:
horizon: int = 16 horizon: int = 16
n_action_steps: int = 8 n_action_steps: int = 8
input_shapes: dict[str, list[str]] = field( input_shapes: dict[str, list[int]] = field(
default_factory=lambda: { default_factory=lambda: {
"observation.image": [3, 96, 96], "observation.image": [3, 96, 96],
"observation.state": [2], "observation.state": [2],
} }
) )
output_shapes: dict[str, list[str]] = field( output_shapes: dict[str, list[int]] = field(
default_factory=lambda: { default_factory=lambda: {
"action": [2], "action": [2],
} }
@ -119,15 +120,6 @@ class DiffusionConfig:
# --- # ---
# TODO(alexander-soare): Remove these from the policy config. # TODO(alexander-soare): Remove these from the policy config.
batch_size: int = 64
grad_clip_norm: int = 10
lr: float = 1.0e-4
lr_scheduler: str = "cosine"
lr_warmup_steps: int = 500
adam_betas: tuple[float, float] = (0.95, 0.999)
adam_eps: float = 1.0e-8
adam_weight_decay: float = 1.0e-6
utd: int = 1
use_ema: bool = True use_ema: bool = True
ema_update_after_step: int = 0 ema_update_after_step: int = 0
ema_min_alpha: float = 0.0 ema_min_alpha: float = 0.0

View File

@ -9,7 +9,6 @@ TODO(alexander-soare):
""" """
import copy import copy
import logging
import math import math
from collections import deque from collections import deque
from typing import Callable from typing import Callable
@ -19,6 +18,7 @@ import torch
import torch.nn.functional as F # noqa: N812 import torch.nn.functional as F # noqa: N812
import torchvision import torchvision
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from huggingface_hub import PyTorchModelHubMixin
from robomimic.models.base_nets import SpatialSoftmax from robomimic.models.base_nets import SpatialSoftmax
from torch import Tensor, nn from torch import Tensor, nn
from torch.nn.modules.batchnorm import _BatchNorm from torch.nn.modules.batchnorm import _BatchNorm
@ -32,7 +32,7 @@ from lerobot.common.policies.utils import (
) )
class DiffusionPolicy(nn.Module): class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
""" """
Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion" Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion"
(paper: https://arxiv.org/abs/2303.04137, code: https://github.com/real-stanford/diffusion_policy). (paper: https://arxiv.org/abs/2303.04137, code: https://github.com/real-stanford/diffusion_policy).
@ -41,49 +41,55 @@ class DiffusionPolicy(nn.Module):
name = "diffusion" name = "diffusion"
def __init__( def __init__(
self, cfg: DiffusionConfig | None = None, lr_scheduler_num_training_steps: int = 0, dataset_stats=None self,
config: DiffusionConfig | None = None,
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
): ):
""" """
Args: Args:
cfg: Policy configuration class instance or None, in which case the default instantiation of the config: Policy configuration class instance or None, in which case the default instantiation of
configuration class is used. the configuration class is used.
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
that they will be passed with a call to `load_state_dict` before the policy is used.
""" """
super().__init__() super().__init__()
# TODO(alexander-soare): LR scheduler will be removed. if config is None:
assert lr_scheduler_num_training_steps > 0 config = DiffusionConfig()
if cfg is None: self.config = config
cfg = DiffusionConfig() self.normalize_inputs = Normalize(
self.cfg = cfg config.input_shapes, config.input_normalization_modes, dataset_stats
self.normalize_inputs = Normalize(cfg.input_shapes, cfg.input_normalization_modes, dataset_stats) )
self.normalize_targets = Normalize(cfg.output_shapes, cfg.output_normalization_modes, dataset_stats) self.normalize_targets = Normalize(
config.output_shapes, config.output_normalization_modes, dataset_stats
)
self.unnormalize_outputs = Unnormalize( self.unnormalize_outputs = Unnormalize(
cfg.output_shapes, cfg.output_normalization_modes, dataset_stats config.output_shapes, config.output_normalization_modes, dataset_stats
) )
# queues are populated during rollout of the policy, they contain the n latest observations and actions # queues are populated during rollout of the policy, they contain the n latest observations and actions
self._queues = None self._queues = None
self.diffusion = _DiffusionUnetImagePolicy(cfg) self.diffusion = DiffusionModel(config)
# TODO(alexander-soare): This should probably be managed outside of the policy class. # TODO(alexander-soare): This should probably be managed outside of the policy class.
self.ema_diffusion = None self.ema_diffusion = None
self.ema = None self.ema = None
if self.cfg.use_ema: if self.config.use_ema:
self.ema_diffusion = copy.deepcopy(self.diffusion) self.ema_diffusion = copy.deepcopy(self.diffusion)
self.ema = _EMA(cfg, model=self.ema_diffusion) self.ema = DiffusionEMA(config, model=self.ema_diffusion)
def reset(self): def reset(self):
""" """
Clear observation and action queues. Should be called on `env.reset()` Clear observation and action queues. Should be called on `env.reset()`
""" """
self._queues = { self._queues = {
"observation.image": deque(maxlen=self.cfg.n_obs_steps), "observation.image": deque(maxlen=self.config.n_obs_steps),
"observation.state": deque(maxlen=self.cfg.n_obs_steps), "observation.state": deque(maxlen=self.config.n_obs_steps),
"action": deque(maxlen=self.cfg.n_action_steps), "action": deque(maxlen=self.config.n_action_steps),
} }
@torch.no_grad @torch.no_grad
def select_action(self, batch: dict[str, Tensor], **_) -> Tensor: def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a single action given environment observations. """Select a single action given environment observations.
This method handles caching a history of observations and an action trajectory generated by the This method handles caching a history of observations and an action trajectory generated by the
@ -131,53 +137,41 @@ class DiffusionPolicy(nn.Module):
action = self._queues["action"].popleft() action = self._queues["action"].popleft()
return action return action
def forward(self, batch: dict[str, Tensor], **_) -> dict[str, Tensor]: def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
"""Run the batch through the model and compute the loss for training or validation.""" """Run the batch through the model and compute the loss for training or validation."""
batch = self.normalize_inputs(batch) batch = self.normalize_inputs(batch)
batch = self.normalize_targets(batch) batch = self.normalize_targets(batch)
loss = self.diffusion.compute_loss(batch) loss = self.diffusion.compute_loss(batch)
return {"loss": loss} return {"loss": loss}
def save(self, fp):
torch.save(self.state_dict(), fp)
def load(self, fp): class DiffusionModel(nn.Module):
d = torch.load(fp) def __init__(self, config: DiffusionConfig):
missing_keys, unexpected_keys = self.load_state_dict(d, strict=False)
if len(missing_keys) > 0:
assert all(k.startswith("ema_diffusion.") for k in missing_keys)
logging.warning(
"DiffusionPolicy.load expected ema parameters in loaded state dict but none were found."
)
assert len(unexpected_keys) == 0
class _DiffusionUnetImagePolicy(nn.Module):
def __init__(self, cfg: DiffusionConfig):
super().__init__() super().__init__()
self.cfg = cfg self.config = config
self.rgb_encoder = _RgbEncoder(cfg) self.rgb_encoder = DiffusionRgbEncoder(config)
self.unet = _ConditionalUnet1D( self.unet = DiffusionConditionalUnet1d(
cfg, config,
global_cond_dim=(cfg.output_shapes["action"][0] + self.rgb_encoder.feature_dim) * cfg.n_obs_steps, global_cond_dim=(config.output_shapes["action"][0] + self.rgb_encoder.feature_dim)
* config.n_obs_steps,
) )
self.noise_scheduler = DDPMScheduler( self.noise_scheduler = DDPMScheduler(
num_train_timesteps=cfg.num_train_timesteps, num_train_timesteps=config.num_train_timesteps,
beta_start=cfg.beta_start, beta_start=config.beta_start,
beta_end=cfg.beta_end, beta_end=config.beta_end,
beta_schedule=cfg.beta_schedule, beta_schedule=config.beta_schedule,
variance_type="fixed_small", variance_type="fixed_small",
clip_sample=cfg.clip_sample, clip_sample=config.clip_sample,
clip_sample_range=cfg.clip_sample_range, clip_sample_range=config.clip_sample_range,
prediction_type=cfg.prediction_type, prediction_type=config.prediction_type,
) )
if cfg.num_inference_steps is None: if config.num_inference_steps is None:
self.num_inference_steps = self.noise_scheduler.config.num_train_timesteps self.num_inference_steps = self.noise_scheduler.config.num_train_timesteps
else: else:
self.num_inference_steps = cfg.num_inference_steps self.num_inference_steps = config.num_inference_steps
# ========= inference ============ # ========= inference ============
def conditional_sample( def conditional_sample(
@ -188,7 +182,7 @@ class _DiffusionUnetImagePolicy(nn.Module):
# Sample prior. # Sample prior.
sample = torch.randn( sample = torch.randn(
size=(batch_size, self.cfg.horizon, self.cfg.output_shapes["action"][0]), size=(batch_size, self.config.horizon, self.config.output_shapes["action"][0]),
dtype=dtype, dtype=dtype,
device=device, device=device,
generator=generator, generator=generator,
@ -218,7 +212,7 @@ class _DiffusionUnetImagePolicy(nn.Module):
""" """
assert set(batch).issuperset({"observation.state", "observation.image"}) assert set(batch).issuperset({"observation.state", "observation.image"})
batch_size, n_obs_steps = batch["observation.state"].shape[:2] batch_size, n_obs_steps = batch["observation.state"].shape[:2]
assert n_obs_steps == self.cfg.n_obs_steps assert n_obs_steps == self.config.n_obs_steps
# Extract image feature (first combine batch and sequence dims). # Extract image feature (first combine batch and sequence dims).
img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ...")) img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ..."))
@ -231,10 +225,10 @@ class _DiffusionUnetImagePolicy(nn.Module):
sample = self.conditional_sample(batch_size, global_cond=global_cond) sample = self.conditional_sample(batch_size, global_cond=global_cond)
# `horizon` steps worth of actions (from the first observation). # `horizon` steps worth of actions (from the first observation).
actions = sample[..., : self.cfg.output_shapes["action"][0]] actions = sample[..., : self.config.output_shapes["action"][0]]
# Extract `n_action_steps` steps worth of actions (from the current observation). # Extract `n_action_steps` steps worth of actions (from the current observation).
start = n_obs_steps - 1 start = n_obs_steps - 1
end = start + self.cfg.n_action_steps end = start + self.config.n_action_steps
actions = actions[:, start:end] actions = actions[:, start:end]
return actions return actions
@ -253,8 +247,8 @@ class _DiffusionUnetImagePolicy(nn.Module):
assert set(batch).issuperset({"observation.state", "observation.image", "action", "action_is_pad"}) assert set(batch).issuperset({"observation.state", "observation.image", "action", "action_is_pad"})
batch_size, n_obs_steps = batch["observation.state"].shape[:2] batch_size, n_obs_steps = batch["observation.state"].shape[:2]
horizon = batch["action"].shape[1] horizon = batch["action"].shape[1]
assert horizon == self.cfg.horizon assert horizon == self.config.horizon
assert n_obs_steps == self.cfg.n_obs_steps assert n_obs_steps == self.config.n_obs_steps
# Extract image feature (first combine batch and sequence dims). # Extract image feature (first combine batch and sequence dims).
img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ...")) img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ..."))
@ -283,12 +277,12 @@ class _DiffusionUnetImagePolicy(nn.Module):
# Compute the loss. # Compute the loss.
# The target is either the original trajectory, or the noise. # The target is either the original trajectory, or the noise.
if self.cfg.prediction_type == "epsilon": if self.config.prediction_type == "epsilon":
target = eps target = eps
elif self.cfg.prediction_type == "sample": elif self.config.prediction_type == "sample":
target = batch["action"] target = batch["action"]
else: else:
raise ValueError(f"Unsupported prediction type {self.cfg.prediction_type}") raise ValueError(f"Unsupported prediction type {self.config.prediction_type}")
loss = F.mse_loss(pred, target, reduction="none") loss = F.mse_loss(pred, target, reduction="none")
@ -300,35 +294,35 @@ class _DiffusionUnetImagePolicy(nn.Module):
return loss.mean() return loss.mean()
class _RgbEncoder(nn.Module): class DiffusionRgbEncoder(nn.Module):
"""Encoder an RGB image into a 1D feature vector. """Encoder an RGB image into a 1D feature vector.
Includes the ability to normalize and crop the image first. Includes the ability to normalize and crop the image first.
""" """
def __init__(self, cfg: DiffusionConfig): def __init__(self, config: DiffusionConfig):
super().__init__() super().__init__()
# Set up optional preprocessing. # Set up optional preprocessing.
if cfg.crop_shape is not None: if config.crop_shape is not None:
self.do_crop = True self.do_crop = True
# Always use center crop for eval # Always use center crop for eval
self.center_crop = torchvision.transforms.CenterCrop(cfg.crop_shape) self.center_crop = torchvision.transforms.CenterCrop(config.crop_shape)
if cfg.crop_is_random: if config.crop_is_random:
self.maybe_random_crop = torchvision.transforms.RandomCrop(cfg.crop_shape) self.maybe_random_crop = torchvision.transforms.RandomCrop(config.crop_shape)
else: else:
self.maybe_random_crop = self.center_crop self.maybe_random_crop = self.center_crop
else: else:
self.do_crop = False self.do_crop = False
# Set up backbone. # Set up backbone.
backbone_model = getattr(torchvision.models, cfg.vision_backbone)( backbone_model = getattr(torchvision.models, config.vision_backbone)(
weights=cfg.pretrained_backbone_weights weights=config.pretrained_backbone_weights
) )
# Note: This assumes that the layer4 feature map is children()[-3] # Note: This assumes that the layer4 feature map is children()[-3]
# TODO(alexander-soare): Use a safer alternative. # TODO(alexander-soare): Use a safer alternative.
self.backbone = nn.Sequential(*(list(backbone_model.children())[:-2])) self.backbone = nn.Sequential(*(list(backbone_model.children())[:-2]))
if cfg.use_group_norm: if config.use_group_norm:
if cfg.pretrained_backbone_weights: if config.pretrained_backbone_weights:
raise ValueError( raise ValueError(
"You can't replace BatchNorm in a pretrained model without ruining the weights!" "You can't replace BatchNorm in a pretrained model without ruining the weights!"
) )
@ -342,11 +336,11 @@ class _RgbEncoder(nn.Module):
# Use a dry run to get the feature map shape. # Use a dry run to get the feature map shape.
with torch.inference_mode(): with torch.inference_mode():
feat_map_shape = tuple( feat_map_shape = tuple(
self.backbone(torch.zeros(size=(1, *cfg.input_shapes["observation.image"]))).shape[1:] self.backbone(torch.zeros(size=(1, *config.input_shapes["observation.image"]))).shape[1:]
) )
self.pool = SpatialSoftmax(feat_map_shape, num_kp=cfg.spatial_softmax_num_keypoints) self.pool = SpatialSoftmax(feat_map_shape, num_kp=config.spatial_softmax_num_keypoints)
self.feature_dim = cfg.spatial_softmax_num_keypoints * 2 self.feature_dim = config.spatial_softmax_num_keypoints * 2
self.out = nn.Linear(cfg.spatial_softmax_num_keypoints * 2, self.feature_dim) self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim)
self.relu = nn.ReLU() self.relu = nn.ReLU()
def forward(self, x: Tensor) -> Tensor: def forward(self, x: Tensor) -> Tensor:
@ -403,7 +397,7 @@ def _replace_submodules(
return root_module return root_module
class _SinusoidalPosEmb(nn.Module): class DiffusionSinusoidalPosEmb(nn.Module):
"""1D sinusoidal positional embeddings as in Attention is All You Need.""" """1D sinusoidal positional embeddings as in Attention is All You Need."""
def __init__(self, dim: int): def __init__(self, dim: int):
@ -420,7 +414,7 @@ class _SinusoidalPosEmb(nn.Module):
return emb return emb
class _Conv1dBlock(nn.Module): class DiffusionConv1dBlock(nn.Module):
"""Conv1d --> GroupNorm --> Mish""" """Conv1d --> GroupNorm --> Mish"""
def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8): def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
@ -436,40 +430,40 @@ class _Conv1dBlock(nn.Module):
return self.block(x) return self.block(x)
class _ConditionalUnet1D(nn.Module): class DiffusionConditionalUnet1d(nn.Module):
"""A 1D convolutional UNet with FiLM modulation for conditioning. """A 1D convolutional UNet with FiLM modulation for conditioning.
Note: this removes local conditioning as compared to the original diffusion policy code. Note: this removes local conditioning as compared to the original diffusion policy code.
""" """
def __init__(self, cfg: DiffusionConfig, global_cond_dim: int): def __init__(self, config: DiffusionConfig, global_cond_dim: int):
super().__init__() super().__init__()
self.cfg = cfg self.config = config
# Encoder for the diffusion timestep. # Encoder for the diffusion timestep.
self.diffusion_step_encoder = nn.Sequential( self.diffusion_step_encoder = nn.Sequential(
_SinusoidalPosEmb(cfg.diffusion_step_embed_dim), DiffusionSinusoidalPosEmb(config.diffusion_step_embed_dim),
nn.Linear(cfg.diffusion_step_embed_dim, cfg.diffusion_step_embed_dim * 4), nn.Linear(config.diffusion_step_embed_dim, config.diffusion_step_embed_dim * 4),
nn.Mish(), nn.Mish(),
nn.Linear(cfg.diffusion_step_embed_dim * 4, cfg.diffusion_step_embed_dim), nn.Linear(config.diffusion_step_embed_dim * 4, config.diffusion_step_embed_dim),
) )
# The FiLM conditioning dimension. # The FiLM conditioning dimension.
cond_dim = cfg.diffusion_step_embed_dim + global_cond_dim cond_dim = config.diffusion_step_embed_dim + global_cond_dim
# In channels / out channels for each downsampling block in the Unet's encoder. For the decoder, we # In channels / out channels for each downsampling block in the Unet's encoder. For the decoder, we
# just reverse these. # just reverse these.
in_out = [(cfg.output_shapes["action"][0], cfg.down_dims[0])] + list( in_out = [(config.output_shapes["action"][0], config.down_dims[0])] + list(
zip(cfg.down_dims[:-1], cfg.down_dims[1:], strict=True) zip(config.down_dims[:-1], config.down_dims[1:], strict=True)
) )
# Unet encoder. # Unet encoder.
common_res_block_kwargs = { common_res_block_kwargs = {
"cond_dim": cond_dim, "cond_dim": cond_dim,
"kernel_size": cfg.kernel_size, "kernel_size": config.kernel_size,
"n_groups": cfg.n_groups, "n_groups": config.n_groups,
"use_film_scale_modulation": cfg.use_film_scale_modulation, "use_film_scale_modulation": config.use_film_scale_modulation,
} }
self.down_modules = nn.ModuleList([]) self.down_modules = nn.ModuleList([])
for ind, (dim_in, dim_out) in enumerate(in_out): for ind, (dim_in, dim_out) in enumerate(in_out):
@ -477,8 +471,8 @@ class _ConditionalUnet1D(nn.Module):
self.down_modules.append( self.down_modules.append(
nn.ModuleList( nn.ModuleList(
[ [
_ConditionalResidualBlock1D(dim_in, dim_out, **common_res_block_kwargs), DiffusionConditionalResidualBlock1d(dim_in, dim_out, **common_res_block_kwargs),
_ConditionalResidualBlock1D(dim_out, dim_out, **common_res_block_kwargs), DiffusionConditionalResidualBlock1d(dim_out, dim_out, **common_res_block_kwargs),
# Downsample as long as it is not the last block. # Downsample as long as it is not the last block.
nn.Conv1d(dim_out, dim_out, 3, 2, 1) if not is_last else nn.Identity(), nn.Conv1d(dim_out, dim_out, 3, 2, 1) if not is_last else nn.Identity(),
] ]
@ -488,8 +482,12 @@ class _ConditionalUnet1D(nn.Module):
# Processing in the middle of the auto-encoder. # Processing in the middle of the auto-encoder.
self.mid_modules = nn.ModuleList( self.mid_modules = nn.ModuleList(
[ [
_ConditionalResidualBlock1D(cfg.down_dims[-1], cfg.down_dims[-1], **common_res_block_kwargs), DiffusionConditionalResidualBlock1d(
_ConditionalResidualBlock1D(cfg.down_dims[-1], cfg.down_dims[-1], **common_res_block_kwargs), config.down_dims[-1], config.down_dims[-1], **common_res_block_kwargs
),
DiffusionConditionalResidualBlock1d(
config.down_dims[-1], config.down_dims[-1], **common_res_block_kwargs
),
] ]
) )
@ -501,8 +499,8 @@ class _ConditionalUnet1D(nn.Module):
nn.ModuleList( nn.ModuleList(
[ [
# dim_in * 2, because it takes the encoder's skip connection as well # dim_in * 2, because it takes the encoder's skip connection as well
_ConditionalResidualBlock1D(dim_in * 2, dim_out, **common_res_block_kwargs), DiffusionConditionalResidualBlock1d(dim_in * 2, dim_out, **common_res_block_kwargs),
_ConditionalResidualBlock1D(dim_out, dim_out, **common_res_block_kwargs), DiffusionConditionalResidualBlock1d(dim_out, dim_out, **common_res_block_kwargs),
# Upsample as long as it is not the last block. # Upsample as long as it is not the last block.
nn.ConvTranspose1d(dim_out, dim_out, 4, 2, 1) if not is_last else nn.Identity(), nn.ConvTranspose1d(dim_out, dim_out, 4, 2, 1) if not is_last else nn.Identity(),
] ]
@ -510,8 +508,8 @@ class _ConditionalUnet1D(nn.Module):
) )
self.final_conv = nn.Sequential( self.final_conv = nn.Sequential(
_Conv1dBlock(cfg.down_dims[0], cfg.down_dims[0], kernel_size=cfg.kernel_size), DiffusionConv1dBlock(config.down_dims[0], config.down_dims[0], kernel_size=config.kernel_size),
nn.Conv1d(cfg.down_dims[0], cfg.output_shapes["action"][0], 1), nn.Conv1d(config.down_dims[0], config.output_shapes["action"][0], 1),
) )
def forward(self, x: Tensor, timestep: Tensor | int, global_cond=None) -> Tensor: def forward(self, x: Tensor, timestep: Tensor | int, global_cond=None) -> Tensor:
@ -559,7 +557,7 @@ class _ConditionalUnet1D(nn.Module):
return x return x
class _ConditionalResidualBlock1D(nn.Module): class DiffusionConditionalResidualBlock1d(nn.Module):
"""ResNet style 1D convolutional block with FiLM modulation for conditioning.""" """ResNet style 1D convolutional block with FiLM modulation for conditioning."""
def __init__( def __init__(
@ -578,13 +576,13 @@ class _ConditionalResidualBlock1D(nn.Module):
self.use_film_scale_modulation = use_film_scale_modulation self.use_film_scale_modulation = use_film_scale_modulation
self.out_channels = out_channels self.out_channels = out_channels
self.conv1 = _Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups) self.conv1 = DiffusionConv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups)
# FiLM modulation (https://arxiv.org/abs/1709.07871) outputs per-channel bias and (maybe) scale. # FiLM modulation (https://arxiv.org/abs/1709.07871) outputs per-channel bias and (maybe) scale.
cond_channels = out_channels * 2 if use_film_scale_modulation else out_channels cond_channels = out_channels * 2 if use_film_scale_modulation else out_channels
self.cond_encoder = nn.Sequential(nn.Mish(), nn.Linear(cond_dim, cond_channels)) self.cond_encoder = nn.Sequential(nn.Mish(), nn.Linear(cond_dim, cond_channels))
self.conv2 = _Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups) self.conv2 = DiffusionConv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups)
# A final convolution for dimension matching the residual (if needed). # A final convolution for dimension matching the residual (if needed).
self.residual_conv = ( self.residual_conv = (
@ -617,18 +615,18 @@ class _ConditionalResidualBlock1D(nn.Module):
return out return out
class _EMA: class DiffusionEMA:
""" """
Exponential Moving Average of models weights Exponential Moving Average of models weights
""" """
def __init__(self, cfg: DiffusionConfig, model: nn.Module): def __init__(self, config: DiffusionConfig, model: nn.Module):
""" """
@crowsonkb's notes on EMA Warmup: @crowsonkb's notes on EMA Warmup:
If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models
to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps), you plan to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999
gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at 1M steps), gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999
at 215.4k steps). at 10K steps, 0.9999 at 215.4k steps).
Args: Args:
inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1. inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
power (float): Exponential factor of EMA warmup. Default: 2/3. power (float): Exponential factor of EMA warmup. Default: 2/3.
@ -639,11 +637,11 @@ class _EMA:
self.averaged_model.eval() self.averaged_model.eval()
self.averaged_model.requires_grad_(False) self.averaged_model.requires_grad_(False)
self.update_after_step = cfg.ema_update_after_step self.update_after_step = config.ema_update_after_step
self.inv_gamma = cfg.ema_inv_gamma self.inv_gamma = config.ema_inv_gamma
self.power = cfg.ema_power self.power = config.ema_power
self.min_alpha = cfg.ema_min_alpha self.min_alpha = config.ema_min_alpha
self.max_alpha = cfg.ema_max_alpha self.max_alpha = config.ema_max_alpha
self.alpha = 0.0 self.alpha = 0.0
self.optimization_step = 0 self.optimization_step = 0

View File

@ -2,6 +2,7 @@ import inspect
from omegaconf import DictConfig, OmegaConf from omegaconf import DictConfig, OmegaConf
from lerobot.common.policies.policy_protocol import Policy
from lerobot.common.utils.utils import get_safe_torch_device from lerobot.common.utils.utils import get_safe_torch_device
@ -20,42 +21,53 @@ def _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg):
return policy_cfg return policy_cfg
def make_policy(hydra_cfg: DictConfig, dataset_stats=None): def get_policy_and_config_classes(name: str) -> tuple[Policy, object]:
if hydra_cfg.policy.name == "tdmpc": """Get the policy's class and config class given a name (matching the policy class' `name` attribute)."""
from lerobot.common.policies.tdmpc.policy import TDMPCPolicy if name == "tdmpc":
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.common.policies.tdmpc.modeling_tdmpc import TDMPCPolicy
policy = TDMPCPolicy( return TDMPCPolicy, TDMPCConfig
hydra_cfg.policy, elif name == "diffusion":
n_obs_steps=hydra_cfg.n_obs_steps,
n_action_steps=hydra_cfg.n_action_steps,
device=hydra_cfg.device,
)
elif hydra_cfg.policy.name == "diffusion":
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
policy_cfg = _policy_cfg_from_hydra_cfg(DiffusionConfig, hydra_cfg) return DiffusionPolicy, DiffusionConfig
policy = DiffusionPolicy(policy_cfg, hydra_cfg.offline_steps, dataset_stats) elif name == "act":
policy.to(get_safe_torch_device(hydra_cfg.device)) from lerobot.common.policies.act.configuration_act import ACTConfig
elif hydra_cfg.policy.name == "act": from lerobot.common.policies.act.modeling_act import ACTPolicy
from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy
policy_cfg = _policy_cfg_from_hydra_cfg(ActionChunkingTransformerConfig, hydra_cfg) return ACTPolicy, ACTConfig
policy = ActionChunkingTransformerPolicy(policy_cfg, dataset_stats)
policy.to(get_safe_torch_device(hydra_cfg.device))
else: else:
raise ValueError(hydra_cfg.policy.name) raise NotImplementedError(f"Policy with name {name} is not implemented.")
if hydra_cfg.policy.pretrained_model_path:
# TODO(rcadene): hack for old pretrained models from fowm def make_policy(
if hydra_cfg.policy.name == "tdmpc" and "fowm" in hydra_cfg.policy.pretrained_model_path: hydra_cfg: DictConfig, pretrained_policy_name_or_path: str | None = None, dataset_stats=None
if "offline" in hydra_cfg.policy.pretrained_model_path: ) -> Policy:
policy.step[0] = 25000 """Make an instance of a policy class.
elif "final" in hydra_cfg.policy.pretrained_model_path:
policy.step[0] = 100000 Args:
hydra_cfg: A parsed Hydra configuration (see scripts). If `pretrained_policy_name_or_path` is
provided, only `hydra_cfg.policy.name` is used while everything else is ignored.
pretrained_policy_name_or_path: Either the repo ID of a model hosted on the Hub or a path to a
directory containing weights saved using `Policy.save_pretrained`. Note that providing this
argument overrides everything in `hydra_cfg.policy` apart from `hydra_cfg.policy.name`.
dataset_stats: Dataset statistics to use for (un)normalization of inputs/outputs in the policy. Must
be provided when initializing a new policy, and must not be provided when loading a pretrained
policy. Therefore, this argument is mutually exclusive with `pretrained_policy_name_or_path`.
"""
if not (pretrained_policy_name_or_path is None) ^ (dataset_stats is None):
raise ValueError("Only one of `pretrained_policy_name_or_path` and `dataset_stats` may be provided.")
policy_cls, policy_cfg_class = get_policy_and_config_classes(hydra_cfg.policy.name)
if pretrained_policy_name_or_path is None:
policy_cfg = _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg)
policy = policy_cls(policy_cfg, dataset_stats)
else: else:
raise NotImplementedError() policy = policy_cls.from_pretrained(pretrained_policy_name_or_path)
policy.load(hydra_cfg.policy.pretrained_model_path)
policy.to(get_safe_torch_device(hydra_cfg.device))
return policy return policy

View File

@ -57,17 +57,28 @@ def create_stats_buffers(
) )
if stats is not None: if stats is not None:
# Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
# tensors anywhere (for example, when we use the same stats for normalization and
# unnormalization). See the logic here
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
if mode == "mean_std": if mode == "mean_std":
buffer["mean"].data = stats[key]["mean"] buffer["mean"].data = stats[key]["mean"].clone()
buffer["std"].data = stats[key]["std"] buffer["std"].data = stats[key]["std"].clone()
elif mode == "min_max": elif mode == "min_max":
buffer["min"].data = stats[key]["min"] buffer["min"].data = stats[key]["min"].clone()
buffer["max"].data = stats[key]["max"] buffer["max"].data = stats[key]["max"].clone()
stats_buffers[key] = buffer stats_buffers[key] = buffer
return stats_buffers return stats_buffers
def _no_stats_error_str(name: str) -> str:
return (
f"`{name}` is infinity. You should either initialize with `stats` as an argument, or use a "
"pretrained model."
)
class Normalize(nn.Module): class Normalize(nn.Module):
"""Normalizes data (e.g. "observation.image") for more stable and faster convergence during training.""" """Normalizes data (e.g. "observation.image") for more stable and faster convergence during training."""
@ -99,7 +110,6 @@ class Normalize(nn.Module):
self.shapes = shapes self.shapes = shapes
self.modes = modes self.modes = modes
self.stats = stats self.stats = stats
# `self.buffer_observation_state["mean"]` contains `torch.tensor(state_dim)`
stats_buffers = create_stats_buffers(shapes, modes, stats) stats_buffers = create_stats_buffers(shapes, modes, stats)
for key, buffer in stats_buffers.items(): for key, buffer in stats_buffers.items():
setattr(self, "buffer_" + key.replace(".", "_"), buffer) setattr(self, "buffer_" + key.replace(".", "_"), buffer)
@ -113,26 +123,14 @@ class Normalize(nn.Module):
if mode == "mean_std": if mode == "mean_std":
mean = buffer["mean"] mean = buffer["mean"]
std = buffer["std"] std = buffer["std"]
assert not torch.isinf(mean).any(), ( assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
"`mean` is infinity. You forgot to initialize with `stats` as argument, or called " assert not torch.isinf(std).any(), _no_stats_error_str("std")
"`policy.load_state_dict`."
)
assert not torch.isinf(std).any(), (
"`std` is infinity. You forgot to initialize with `stats` as argument, or called "
"`policy.load_state_dict`."
)
batch[key] = (batch[key] - mean) / (std + 1e-8) batch[key] = (batch[key] - mean) / (std + 1e-8)
elif mode == "min_max": elif mode == "min_max":
min = buffer["min"] min = buffer["min"]
max = buffer["max"] max = buffer["max"]
assert not torch.isinf(min).any(), ( assert not torch.isinf(min).any(), _no_stats_error_str("min")
"`min` is infinity. You forgot to initialize with `stats` as argument, or called " assert not torch.isinf(max).any(), _no_stats_error_str("max")
"`policy.load_state_dict`."
)
assert not torch.isinf(max).any(), (
"`max` is infinity. You forgot to initialize with `stats` as argument, or called "
"`policy.load_state_dict`."
)
# normalize to [0,1] # normalize to [0,1]
batch[key] = (batch[key] - min) / (max - min) batch[key] = (batch[key] - min) / (max - min)
# normalize to [-1, 1] # normalize to [-1, 1]
@ -190,26 +188,14 @@ class Unnormalize(nn.Module):
if mode == "mean_std": if mode == "mean_std":
mean = buffer["mean"] mean = buffer["mean"]
std = buffer["std"] std = buffer["std"]
assert not torch.isinf(mean).any(), ( assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
"`mean` is infinity. You forgot to initialize with `stats` as argument, or called " assert not torch.isinf(std).any(), _no_stats_error_str("std")
"`policy.load_state_dict`."
)
assert not torch.isinf(std).any(), (
"`std` is infinity. You forgot to initialize with `stats` as argument, or called "
"`policy.load_state_dict`."
)
batch[key] = batch[key] * std + mean batch[key] = batch[key] * std + mean
elif mode == "min_max": elif mode == "min_max":
min = buffer["min"] min = buffer["min"]
max = buffer["max"] max = buffer["max"]
assert not torch.isinf(min).any(), ( assert not torch.isinf(min).any(), _no_stats_error_str("min")
"`min` is infinity. You forgot to initialize with `stats` as argument, or called " assert not torch.isinf(max).any(), _no_stats_error_str("max")
"`policy.load_state_dict`."
)
assert not torch.isinf(max).any(), (
"`max` is infinity. You forgot to initialize with `stats` as argument, or called "
"`policy.load_state_dict`."
)
batch[key] = (batch[key] + 1) / 2 batch[key] = (batch[key] + 1) / 2
batch[key] = batch[key] * (max - min) + min batch[key] = batch[key] * (max - min) + min
else: else:

View File

@ -14,10 +14,21 @@ from torch import Tensor
@runtime_checkable @runtime_checkable
class Policy(Protocol): class Policy(Protocol):
"""The required interface for implementing a policy.""" """The required interface for implementing a policy.
We also expect all policies to subclass torch.nn.Module and PyTorchModelHubMixin.
"""
name: str name: str
def __init__(self, cfg, dataset_stats: dict[str, dict[str, Tensor]] | None = None):
"""
Args:
cfg: Policy configuration class instance or None, in which case the default instantiation of the
configuration class is used.
dataset_stats: Dataset statistics to be used for normalization.
"""
def reset(self): def reset(self):
"""To be called whenever the environment is reset. """To be called whenever the environment is reset.
@ -36,3 +47,13 @@ class Policy(Protocol):
When the model uses a history of observations, or outputs a sequence of actions, this method deals When the model uses a history of observations, or outputs a sequence of actions, this method deals
with caching. with caching.
""" """
@runtime_checkable
class PolicyWithUpdate(Policy, Protocol):
def update(self):
"""An update method that is to be called after a training optimization step.
Implements an additional updates the model parameters may need (for example, doing an EMA step for a
target model, or incrementing an internal buffer).
"""

View File

@ -0,0 +1,150 @@
from dataclasses import dataclass, field
@dataclass
class TDMPCConfig:
"""Configuration class for TDMPCPolicy.
Defaults are configured for training with xarm_lift_medium_replay providing proprioceptive and single
camera observations.
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
Those are: `input_shapes`, `output_shapes`, and perhaps `max_random_shift`.
Args:
n_action_repeats: The number of times to repeat the action returned by the planning. (hint: Google
action repeats in Q-learning or ask your favorite chatbot)
horizon: Horizon for model predictive control.
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
[-1, 1] range. Note that here this defaults to None meaning inputs are not normalized. This is to
match the original implementation.
output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the
original scale. Note that this is also used for normalizing the training targets. NOTE: Clipping
to [-1, +1] is used during MPPI/CEM. Therefore, it is recommended that you stick with "min_max"
normalization mode here.
image_encoder_hidden_dim: Number of channels for the convolutional layers used for image encoding.
state_encoder_hidden_dim: Hidden dimension for MLP used for state vector encoding.
latent_dim: Observation's latent embedding dimension.
q_ensemble_size: Number of Q function estimators to use in an ensemble for uncertainty estimation.
mlp_dim: Hidden dimension of MLPs used for modelling the dynamics encoder, reward function, policy
(π), Q ensemble, and V.
discount: Discount factor (γ) to use for the reinforcement learning formalism.
use_mpc: Whether to use model predictive control. The alternative is to just sample the policy model
(π) for each step.
cem_iterations: Number of iterations for the MPPI/CEM loop in MPC.
max_std: Maximum standard deviation for actions sampled from the gaussian PDF in CEM.
min_std: Minimum standard deviation for noise applied to actions sampled from the policy model (π).
Doubles up as the minimum standard deviation for actions sampled from the gaussian PDF in CEM.
n_gaussian_samples: Number of samples to draw from the gaussian distribution every CEM iteration. Must
be non-zero.
n_pi_samples: Number of samples to draw from the policy / world model rollout every CEM iteration. Can
be zero.
uncertainty_regularizer_coeff: Coefficient for the uncertainty regularization used when estimating
trajectory values (this is the λ coeffiecient in eqn 4 of FOWM).
n_elites: The number of elite samples to use for updating the gaussian parameters every CEM iteration.
elite_weighting_temperature: The temperature to use for softmax weighting (by trajectory value) of the
elites, when updating the gaussian parameters for CEM.
gaussian_mean_momentum: Momentum (α) used for EMA updates of the mean parameter μ of the gaussian
paramters optimized in CEM. Updates are calculated as μ αμ + (1-α)μ.
max_random_shift_ratio: Maximum random shift (as a proportion of the image size) to apply to the
image(s) (in units of pixels) for training-time augmentation. If set to 0, no such augmentation
is applied. Note that the input images are assumed to be square for this augmentation.
reward_coeff: Loss weighting coefficient for the reward regression loss.
expectile_weight: Weighting (τ) used in expectile regression for the state value function (V).
v_pred < v_target is weighted by τ and v_pred >= v_target is weighted by (1-τ). τ is expected to
be in [0, 1]. Setting τ closer to 1 results in a more "optimistic" V. This is sensible to do
because v_target is obtained by evaluating the learned state-action value functions (Q) with
in-sample actions that may not be always optimal.
value_coeff: Loss weighting coefficient for both the state-action value (Q) TD loss, and the state
value (V) expectile regression loss.
consistency_coeff: Loss weighting coefficient for the consistency loss.
advantage_scaling: A factor by which the advantages are scaled prior to exponentiation for advantage
weighted regression of the policy (π) estimator parameters. Note that the exponentiated advantages
are clamped at 100.0.
pi_coeff: Loss weighting coefficient for the action regression loss.
temporal_decay_coeff: Exponential decay coefficient for decaying the loss coefficient for future time-
steps. Hint: each loss computation involves `horizon` steps worth of actions starting from the
current time step.
target_model_momentum: Momentum (α) used for EMA updates of the target models. Updates are calculated
as ϕ αϕ + (1-α)θ where ϕ are the parameters of the target model and θ are the parameters of the
model being trained.
"""
# Input / output structure.
n_action_repeats: int = 2
horizon: int = 5
input_shapes: dict[str, list[int]] = field(
default_factory=lambda: {
"observation.image": [3, 84, 84],
"observation.state": [4],
}
)
output_shapes: dict[str, list[int]] = field(
default_factory=lambda: {
"action": [4],
}
)
# Normalization / Unnormalization
input_normalization_modes: dict[str, str] | None = None
output_normalization_modes: dict[str, str] = field(
default_factory=lambda: {"action": "min_max"},
)
# Architecture / modeling.
# Neural networks.
image_encoder_hidden_dim: int = 32
state_encoder_hidden_dim: int = 256
latent_dim: int = 50
q_ensemble_size: int = 5
mlp_dim: int = 512
# Reinforcement learning.
discount: float = 0.9
# Inference.
use_mpc: bool = True
cem_iterations: int = 6
max_std: float = 2.0
min_std: float = 0.05
n_gaussian_samples: int = 512
n_pi_samples: int = 51
uncertainty_regularizer_coeff: float = 1.0
n_elites: int = 50
elite_weighting_temperature: float = 0.5
gaussian_mean_momentum: float = 0.1
# Training and loss computation.
max_random_shift_ratio: float = 0.0476
# Loss coefficients.
reward_coeff: float = 0.5
expectile_weight: float = 0.9
value_coeff: float = 0.1
consistency_coeff: float = 20.0
advantage_scaling: float = 3.0
pi_coeff: float = 0.5
temporal_decay_coeff: float = 0.5
# Target model.
target_model_momentum: float = 0.995
def __post_init__(self):
"""Input validation (not exhaustive)."""
if self.input_shapes["observation.image"][-2] != self.input_shapes["observation.image"][-1]:
# TODO(alexander-soare): This limitation is solely because of code in the random shift
# augmentation. It should be able to be removed.
raise ValueError(
"Only square images are handled now. Got image shape "
f"{self.input_shapes['observation.image']}."
)
if self.n_gaussian_samples <= 0:
raise ValueError(
f"The number of guassian samples for CEM should be non-zero. Got `{self.n_gaussian_samples=}`"
)
if self.output_normalization_modes != {"action": "min_max"}:
raise ValueError(
"TD-MPC assumes the action space dimensions to all be in [-1, 1]. Therefore it is strongly "
f"advised that you stick with the default. See {self.__class__.__name__} docstring for more "
"information."
)

View File

@ -1,576 +0,0 @@
import os
import pickle
import re
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F # noqa: N812
from torch import distributions as pyd
from torch.distributions.utils import _standard_normal
DEFAULT_ACT_FN = nn.Mish()
def __REDUCE__(b): # noqa: N802, N807
return "mean" if b else "none"
def l1(pred, target, reduce=False):
"""Computes the L1-loss between predictions and targets."""
return F.l1_loss(pred, target, reduction=__REDUCE__(reduce))
def mse(pred, target, reduce=False):
"""Computes the MSE loss between predictions and targets."""
return F.mse_loss(pred, target, reduction=__REDUCE__(reduce))
def l2_expectile(diff, expectile=0.7, reduce=False):
weight = torch.where(diff > 0, expectile, (1 - expectile))
loss = weight * (diff**2)
reduction = __REDUCE__(reduce)
if reduction == "mean":
return torch.mean(loss)
elif reduction == "sum":
return torch.sum(loss)
return loss
def _get_out_shape(in_shape, layers):
"""Utility function. Returns the output shape of a network for a given input shape."""
x = torch.randn(*in_shape).unsqueeze(0)
return (nn.Sequential(*layers) if isinstance(layers, list) else layers)(x).squeeze(0).shape
def gaussian_logprob(eps, log_std):
"""Compute Gaussian log probability."""
residual = (-0.5 * eps.pow(2) - log_std).sum(-1, keepdim=True)
return residual - 0.5 * np.log(2 * np.pi) * eps.size(-1)
def squash(mu, pi, log_pi):
"""Apply squashing function."""
mu = torch.tanh(mu)
pi = torch.tanh(pi)
log_pi -= torch.log(F.relu(1 - pi.pow(2)) + 1e-6).sum(-1, keepdim=True)
return mu, pi, log_pi
def orthogonal_init(m):
"""Orthogonal layer initialization."""
if isinstance(m, nn.Linear):
nn.init.orthogonal_(m.weight.data)
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Conv2d):
gain = nn.init.calculate_gain("relu")
nn.init.orthogonal_(m.weight.data, gain)
if m.bias is not None:
nn.init.zeros_(m.bias)
def ema(m, m_target, tau):
"""Update slow-moving average of online network (target network) at rate tau."""
with torch.no_grad():
# TODO(rcadene, aliberts): issue with strict=False
# for p, p_target in zip(m.parameters(), m_target.parameters(), strict=False):
# p_target.data.lerp_(p.data, tau)
m_params_iter = iter(m.parameters())
m_target_params_iter = iter(m_target.parameters())
while True:
try:
p = next(m_params_iter)
p_target = next(m_target_params_iter)
p_target.data.lerp_(p.data, tau)
except StopIteration:
# If any iterator is exhausted, exit the loop
break
def set_requires_grad(net, value):
"""Enable/disable gradients for a given (sub)network."""
for param in net.parameters():
param.requires_grad_(value)
class TruncatedNormal(pyd.Normal):
"""Utility class implementing the truncated normal distribution."""
default_sample_shape = torch.Size()
def __init__(self, loc, scale, low=-1.0, high=1.0, eps=1e-6):
super().__init__(loc, scale, validate_args=False)
self.low = low
self.high = high
self.eps = eps
def _clamp(self, x):
clamped_x = torch.clamp(x, self.low + self.eps, self.high - self.eps)
x = x - x.detach() + clamped_x.detach()
return x
def sample(self, clip=None, sample_shape=default_sample_shape):
shape = self._extended_shape(sample_shape)
eps = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
eps *= self.scale
if clip is not None:
eps = torch.clamp(eps, -clip, clip)
x = self.loc + eps
return self._clamp(x)
class NormalizeImg(nn.Module):
"""Normalizes pixel observations to [0,1) range."""
def __init__(self):
super().__init__()
def forward(self, x):
return x.div(255.0)
class Flatten(nn.Module):
"""Flattens its input to a (batched) vector."""
def __init__(self):
super().__init__()
def forward(self, x):
return x.view(x.size(0), -1)
def enc(cfg):
obs_shape = {
"rgb": (3, cfg.img_size, cfg.img_size),
"state": (cfg.state_dim,),
}
"""Returns a TOLD encoder."""
pixels_enc_layers, state_enc_layers = None, None
if cfg.modality in {"pixels", "all"}:
C = int(3 * cfg.frame_stack) # noqa: N806
pixels_enc_layers = [
NormalizeImg(),
nn.Conv2d(C, cfg.num_channels, 7, stride=2),
nn.ReLU(),
nn.Conv2d(cfg.num_channels, cfg.num_channels, 5, stride=2),
nn.ReLU(),
nn.Conv2d(cfg.num_channels, cfg.num_channels, 3, stride=2),
nn.ReLU(),
nn.Conv2d(cfg.num_channels, cfg.num_channels, 3, stride=2),
nn.ReLU(),
]
out_shape = _get_out_shape((C, cfg.img_size, cfg.img_size), pixels_enc_layers)
pixels_enc_layers.extend(
[
Flatten(),
nn.Linear(np.prod(out_shape), cfg.latent_dim),
nn.LayerNorm(cfg.latent_dim),
nn.Sigmoid(),
]
)
if cfg.modality == "pixels":
return ConvExt(nn.Sequential(*pixels_enc_layers))
if cfg.modality in {"state", "all"}:
state_dim = obs_shape[0] if cfg.modality == "state" else obs_shape["state"][0]
state_enc_layers = [
nn.Linear(state_dim, cfg.enc_dim),
nn.ELU(),
nn.Linear(cfg.enc_dim, cfg.latent_dim),
nn.LayerNorm(cfg.latent_dim),
nn.Sigmoid(),
]
if cfg.modality == "state":
return nn.Sequential(*state_enc_layers)
else:
raise NotImplementedError
encoders = {}
for k in obs_shape:
if k == "state":
encoders[k] = nn.Sequential(*state_enc_layers)
elif k.endswith("rgb"):
encoders[k] = ConvExt(nn.Sequential(*pixels_enc_layers))
else:
raise NotImplementedError
return Multiplexer(nn.ModuleDict(encoders))
def mlp(in_dim, mlp_dim, out_dim, act_fn=DEFAULT_ACT_FN):
"""Returns an MLP."""
if isinstance(mlp_dim, int):
mlp_dim = [mlp_dim, mlp_dim]
return nn.Sequential(
nn.Linear(in_dim, mlp_dim[0]),
nn.LayerNorm(mlp_dim[0]),
act_fn,
nn.Linear(mlp_dim[0], mlp_dim[1]),
nn.LayerNorm(mlp_dim[1]),
act_fn,
nn.Linear(mlp_dim[1], out_dim),
)
def dynamics(in_dim, mlp_dim, out_dim, act_fn=DEFAULT_ACT_FN):
"""Returns a dynamics network."""
return nn.Sequential(
mlp(in_dim, mlp_dim, out_dim, act_fn),
nn.LayerNorm(out_dim),
nn.Sigmoid(),
)
def q(cfg):
action_dim = cfg.action_dim
"""Returns a Q-function that uses Layer Normalization."""
return nn.Sequential(
nn.Linear(cfg.latent_dim + action_dim, cfg.mlp_dim),
nn.LayerNorm(cfg.mlp_dim),
nn.Tanh(),
nn.Linear(cfg.mlp_dim, cfg.mlp_dim),
nn.ELU(),
nn.Linear(cfg.mlp_dim, 1),
)
def v(cfg):
"""Returns a state value function that uses Layer Normalization."""
return nn.Sequential(
nn.Linear(cfg.latent_dim, cfg.mlp_dim),
nn.LayerNorm(cfg.mlp_dim),
nn.Tanh(),
nn.Linear(cfg.mlp_dim, cfg.mlp_dim),
nn.ELU(),
nn.Linear(cfg.mlp_dim, 1),
)
def aug(cfg):
obs_shape = {
"rgb": (3, cfg.img_size, cfg.img_size),
"state": (4,),
}
"""Multiplex augmentation"""
if cfg.modality == "state":
return nn.Identity()
elif cfg.modality == "pixels":
return RandomShiftsAug(cfg)
else:
augs = {}
for k in obs_shape:
if k == "state":
augs[k] = nn.Identity()
elif k.endswith("rgb"):
augs[k] = RandomShiftsAug(cfg)
else:
raise NotImplementedError
return Multiplexer(nn.ModuleDict(augs))
class ConvExt(nn.Module):
"""Auxiliary conv net accommodating high-dim input"""
def __init__(self, conv):
super().__init__()
self.conv = conv
def forward(self, x):
if x.ndim > 4:
batch_shape = x.shape[:-3]
out = self.conv(x.view(-1, *x.shape[-3:]))
out = out.view(*batch_shape, *out.shape[1:])
else:
out = self.conv(x)
return out
class Multiplexer(nn.Module):
"""Model multiplexer"""
def __init__(self, choices):
super().__init__()
self.choices = choices
def forward(self, x, key=None):
if isinstance(x, dict):
if key is not None:
return self.choices[key](x)
return {k: self.choices[k](_x) for k, _x in x.items()}
return self.choices(x)
class RandomShiftsAug(nn.Module):
"""
Random shift image augmentation.
Adapted from https://github.com/facebookresearch/drqv2
"""
def __init__(self, cfg):
super().__init__()
assert cfg.modality in {"pixels", "all"}
self.pad = int(cfg.img_size / 21)
def forward(self, x):
n, c, h, w = x.size()
assert h == w
padding = tuple([self.pad] * 4)
x = F.pad(x, padding, "replicate")
eps = 1.0 / (h + 2 * self.pad)
arange = torch.linspace(
-1.0 + eps,
1.0 - eps,
h + 2 * self.pad,
device=x.device,
dtype=torch.float32,
)[:h]
arange = arange.unsqueeze(0).repeat(h, 1).unsqueeze(2)
base_grid = torch.cat([arange, arange.transpose(1, 0)], dim=2)
base_grid = base_grid.unsqueeze(0).repeat(n, 1, 1, 1)
shift = torch.randint(
0,
2 * self.pad + 1,
size=(n, 1, 1, 2),
device=x.device,
dtype=torch.float32,
)
shift *= 2.0 / (h + 2 * self.pad)
grid = base_grid + shift
return F.grid_sample(x, grid, padding_mode="zeros", align_corners=False)
# TODO(aliberts): remove class
# class Episode:
# """Storage object for a single episode."""
# def __init__(self, cfg, init_obs):
# action_dim = cfg.action_dim
# self.cfg = cfg
# self.device = torch.device(cfg.buffer_device)
# if cfg.modality in {"pixels", "state"}:
# dtype = torch.float32 if cfg.modality == "state" else torch.uint8
# self.obses = torch.empty(
# (cfg.episode_length + 1, *init_obs.shape),
# dtype=dtype,
# device=self.device,
# )
# self.obses[0] = torch.tensor(init_obs, dtype=dtype, device=self.device)
# elif cfg.modality == "all":
# self.obses = {}
# for k, v in init_obs.items():
# assert k in {"rgb", "state"}
# dtype = torch.float32 if k == "state" else torch.uint8
# self.obses[k] = torch.empty(
# (cfg.episode_length + 1, *v.shape), dtype=dtype, device=self.device
# )
# self.obses[k][0] = torch.tensor(v, dtype=dtype, device=self.device)
# else:
# raise ValueError
# self.actions = torch.empty((cfg.episode_length, action_dim), dtype=torch.float32, device=self.device)
# self.rewards = torch.empty((cfg.episode_length,), dtype=torch.float32, device=self.device)
# self.dones = torch.empty((cfg.episode_length,), dtype=torch.bool, device=self.device)
# self.masks = torch.empty((cfg.episode_length,), dtype=torch.float32, device=self.device)
# self.cumulative_reward = 0
# self.done = False
# self.success = False
# self._idx = 0
# def __len__(self):
# return self._idx
# @classmethod
# def from_trajectory(cls, cfg, obses, actions, rewards, dones=None, masks=None):
# """Constructs an episode from a trajectory."""
# if cfg.modality in {"pixels", "state"}:
# episode = cls(cfg, obses[0])
# episode.obses[1:] = torch.tensor(obses[1:], dtype=episode.obses.dtype, device=episode.device)
# elif cfg.modality == "all":
# episode = cls(cfg, {k: v[0] for k, v in obses.items()})
# for k in obses:
# episode.obses[k][1:] = torch.tensor(
# obses[k][1:], dtype=episode.obses[k].dtype, device=episode.device
# )
# else:
# raise NotImplementedError
# episode.actions = torch.tensor(actions, dtype=episode.actions.dtype, device=episode.device)
# episode.rewards = torch.tensor(rewards, dtype=episode.rewards.dtype, device=episode.device)
# episode.dones = (
# torch.tensor(dones, dtype=episode.dones.dtype, device=episode.device)
# if dones is not None
# else torch.zeros_like(episode.dones)
# )
# episode.masks = (
# torch.tensor(masks, dtype=episode.masks.dtype, device=episode.device)
# if masks is not None
# else torch.ones_like(episode.masks)
# )
# episode.cumulative_reward = torch.sum(episode.rewards)
# episode.done = True
# episode._idx = cfg.episode_length
# return episode
# @property
# def first(self):
# return len(self) == 0
# def __add__(self, transition):
# self.add(*transition)
# return self
# def add(self, obs, action, reward, done, mask=1.0, success=False):
# """Add a transition into the episode."""
# if isinstance(obs, dict):
# for k, v in obs.items():
# self.obses[k][self._idx + 1] = torch.tensor(
# v, dtype=self.obses[k].dtype, device=self.obses[k].device
# )
# else:
# self.obses[self._idx + 1] = torch.tensor(obs, dtype=self.obses.dtype, device=self.obses.device)
# self.actions[self._idx] = action
# self.rewards[self._idx] = reward
# self.dones[self._idx] = done
# self.masks[self._idx] = mask
# self.cumulative_reward += reward
# self.done = done
# self.success = self.success or success
# self._idx += 1
def get_dataset_dict(cfg, env, return_reward_normalizer=False):
"""Construct a dataset for env"""
required_keys = [
"observations",
"next_observations",
"actions",
"rewards",
"dones",
"masks",
]
if cfg.task.startswith("xarm"):
dataset_path = os.path.join(cfg.dataset_dir, "buffer.pkl")
print(f"Using offline dataset '{dataset_path}'")
with open(dataset_path, "rb") as f:
dataset_dict = pickle.load(f)
for k in required_keys:
if k not in dataset_dict and k[:-1] in dataset_dict:
dataset_dict[k] = dataset_dict.pop(k[:-1])
elif cfg.task.startswith("legged"):
dataset_path = os.path.join(cfg.dataset_dir, "buffer.pkl")
print(f"Using offline dataset '{dataset_path}'")
with open(dataset_path, "rb") as f:
dataset_dict = pickle.load(f)
dataset_dict["actions"] /= env.unwrapped.clip_actions
print(f"clip_actions={env.unwrapped.clip_actions}")
else:
import d4rl
dataset_dict = d4rl.qlearning_dataset(env)
dones = np.full_like(dataset_dict["rewards"], False, dtype=bool)
for i in range(len(dones) - 1):
if (
np.linalg.norm(dataset_dict["observations"][i + 1] - dataset_dict["next_observations"][i])
> 1e-6
or dataset_dict["terminals"][i] == 1.0
):
dones[i] = True
dones[-1] = True
dataset_dict["masks"] = 1.0 - dataset_dict["terminals"]
del dataset_dict["terminals"]
for k, v in dataset_dict.items():
dataset_dict[k] = v.astype(np.float32)
dataset_dict["dones"] = dones
if cfg.is_data_clip:
lim = 1 - cfg.data_clip_eps
dataset_dict["actions"] = np.clip(dataset_dict["actions"], -lim, lim)
reward_normalizer = get_reward_normalizer(cfg, dataset_dict)
dataset_dict["rewards"] = reward_normalizer(dataset_dict["rewards"])
for key in required_keys:
assert key in dataset_dict, f"Missing `{key}` in dataset."
if return_reward_normalizer:
return dataset_dict, reward_normalizer
return dataset_dict
def get_trajectory_boundaries_and_returns(dataset):
"""
Split dataset into trajectories and compute returns
"""
episode_starts = [0]
episode_ends = []
episode_return = 0
episode_returns = []
n_transitions = len(dataset["rewards"])
for i in range(n_transitions):
episode_return += dataset["rewards"][i]
if dataset["dones"][i]:
episode_returns.append(episode_return)
episode_ends.append(i + 1)
if i + 1 < n_transitions:
episode_starts.append(i + 1)
episode_return = 0.0
return episode_starts, episode_ends, episode_returns
def normalize_returns(dataset, scaling=1000):
"""
Normalize returns in the dataset
"""
(_, _, episode_returns) = get_trajectory_boundaries_and_returns(dataset)
dataset["rewards"] /= np.max(episode_returns) - np.min(episode_returns)
dataset["rewards"] *= scaling
return dataset
def get_reward_normalizer(cfg, dataset):
"""
Get a reward normalizer for the dataset
"""
if cfg.task.startswith("xarm"):
return lambda x: x
elif "maze" in cfg.task:
return lambda x: x - 1.0
elif cfg.task.split("-")[0] in ["hopper", "halfcheetah", "walker2d"]:
(_, _, episode_returns) = get_trajectory_boundaries_and_returns(dataset)
return lambda x: x / (np.max(episode_returns) - np.min(episode_returns)) * 1000.0
elif hasattr(cfg, "reward_scale"):
return lambda x: x * cfg.reward_scale
return lambda x: x
def linear_schedule(schdl, step):
"""
Outputs values following a linear decay schedule.
Adapted from https://github.com/facebookresearch/drqv2
"""
try:
return float(schdl)
except ValueError:
match = re.match(r"linear\((.+),(.+),(.+),(.+)\)", schdl)
if match:
init, final, start, end = (float(g) for g in match.groups())
mix = np.clip((step - start) / (end - start), 0.0, 1.0)
return (1.0 - mix) * init + mix * final
match = re.match(r"linear\((.+),(.+),(.+)\)", schdl)
if match:
init, final, duration = (float(g) for g in match.groups())
mix = np.clip(step / duration, 0.0, 1.0)
return (1.0 - mix) * init + mix * final
raise NotImplementedError(schdl)

View File

@ -0,0 +1,798 @@
"""Implementation of Finetuning Offline World Models in the Real World.
The comments in this code may sometimes refer to these references:
TD-MPC paper: Temporal Difference Learning for Model Predictive Control (https://arxiv.org/abs/2203.04955)
FOWM paper: Finetuning Offline World Models in the Real World (https://arxiv.org/abs/2310.16029)
TODO(alexander-soare): Make rollout work for batch sizes larger than 1.
TODO(alexander-soare): Use batch-first throughout.
"""
# ruff: noqa: N806
import logging
from collections import deque
from copy import deepcopy
from functools import partial
from typing import Callable
import einops
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F # noqa: N812
from huggingface_hub import PyTorchModelHubMixin
from torch import Tensor
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.common.policies.utils import get_device_from_parameters, populate_queues
class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
"""Implementation of TD-MPC learning + inference.
Please note several warnings for this policy.
- Evaluation of pretrained weights created with the original FOWM code
(https://github.com/fyhMer/fowm) works as expected. To be precise: we trained and evaluated a
model with the FOWM code for the xarm_lift_medium_replay dataset. We ported the weights across
to LeRobot, and were able to evaluate with the same success metric. BUT, we had to use inter-
process communication to use the xarm environment from FOWM. This is because our xarm
environment uses newer dependencies and does not match the environment in FOWM. See
https://github.com/huggingface/lerobot/pull/103 for implementation details.
- We have NOT checked that training on LeRobot reproduces SOTA results. This is a TODO.
- Our current xarm datasets were generated using the environment from FOWM. Therefore they do not
match our xarm environment.
"""
name = "tdmpc"
def __init__(
self, config: TDMPCConfig | None = None, dataset_stats: dict[str, dict[str, Tensor]] | None = None
):
"""
Args:
config: Policy configuration class instance or None, in which case the default instantiation of
the configuration class is used.
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
that they will be passed with a call to `load_state_dict` before the policy is used.
"""
super().__init__()
logging.warning(
"""
Please note several warnings for this policy.
- Evaluation of pretrained weights created with the original FOWM code
(https://github.com/fyhMer/fowm) works as expected. To be precise: we trained and evaluated a
model with the FOWM code for the xarm_lift_medium_replay dataset. We ported the weights across
to LeRobot, and were able to evaluate with the same success metric. BUT, we had to use inter-
process communication to use the xarm environment from FOWM. This is because our xarm
environment uses newer dependencies and does not match the environment in FOWM. See
https://github.com/huggingface/lerobot/pull/103 for implementation details.
- We have NOT checked that training on LeRobot reproduces SOTA results. This is a TODO.
- Our current xarm datasets were generated using the environment from FOWM. Therefore they do not
match our xarm environment.
"""
)
if config is None:
config = TDMPCConfig()
self.config = config
self.model = TDMPCTOLD(config)
self.model_target = deepcopy(self.model)
self.model_target.eval()
if config.input_normalization_modes is not None:
self.normalize_inputs = Normalize(
config.input_shapes, config.input_normalization_modes, dataset_stats
)
else:
self.normalize_inputs = nn.Identity()
self.normalize_targets = Normalize(
config.output_shapes, config.output_normalization_modes, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_shapes, config.output_normalization_modes, dataset_stats
)
def save(self, fp):
"""Save state dict of TOLD model to filepath."""
torch.save(self.state_dict(), fp)
def load(self, fp):
"""Load a saved state dict from filepath into current agent."""
self.load_state_dict(torch.load(fp))
def reset(self):
"""
Clear observation and action queues. Clear previous means for warm starting of MPPI/CEM. Should be
called on `env.reset()`
"""
self._queues = {
"observation.image": deque(maxlen=1),
"observation.state": deque(maxlen=1),
"action": deque(maxlen=self.config.n_action_repeats),
}
# Previous mean obtained from the cross-entropy method (CEM) used during MPC. It is used to warm start
# CEM for the next step.
self._prev_mean: torch.Tensor | None = None
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]):
"""Select a single action given environment observations."""
assert "observation.image" in batch
assert "observation.state" in batch
assert len(batch) == 2
batch = self.normalize_inputs(batch)
self._queues = populate_queues(self._queues, batch)
# When the action queue is depleted, populate it again by querying the policy.
if len(self._queues["action"]) == 0:
batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
# Remove the time dimensions as it is not handled yet.
for key in batch:
assert batch[key].shape[1] == 1
batch[key] = batch[key][:, 0]
# NOTE: Order of observations matters here.
z = self.model.encode({k: batch[k] for k in ["observation.image", "observation.state"]})
if self.config.use_mpc:
batch_size = batch["observation.image"].shape[0]
# Batch processing is not handled in MPC mode, so process the batch in a loop.
action = [] # will be a batch of actions for one step
for i in range(batch_size):
# Note: self.plan does not handle batches, hence the squeeze.
action.append(self.plan(z[i]))
action = torch.stack(action)
else:
# Plan with the policy (π) alone.
action = self.model.pi(z)
self.unnormalize_outputs({"action": action})["action"]
for _ in range(self.config.n_action_repeats):
self._queues["action"].append(action)
action = self._queues["action"].popleft()
return torch.clamp(action, -1, 1)
@torch.no_grad()
def plan(self, z: Tensor) -> Tensor:
"""Plan next action using TD-MPC inference.
Args:
z: (latent_dim,) tensor for the initial state.
Returns:
(action_dim,) tensor for the next action.
TODO(alexander-soare) Extend this to be able to work with batches.
"""
device = get_device_from_parameters(self)
# Sample Nπ trajectories from the policy.
pi_actions = torch.empty(
self.config.horizon,
self.config.n_pi_samples,
self.config.output_shapes["action"][0],
device=device,
)
if self.config.n_pi_samples > 0:
_z = einops.repeat(z, "d -> n d", n=self.config.n_pi_samples)
for t in range(self.config.horizon):
# Note: Adding a small amount of noise here doesn't hurt during inference and may even be
# helpful for CEM.
pi_actions[t] = self.model.pi(_z, self.config.min_std)
_z = self.model.latent_dynamics(_z, pi_actions[t])
# In the CEM loop we will need this for a call to estimate_value with the gaussian sampled
# trajectories.
z = einops.repeat(z, "d -> n d", n=self.config.n_gaussian_samples + self.config.n_pi_samples)
# Model Predictive Path Integral (MPPI) with the cross-entropy method (CEM) as the optimization
# algorithm.
# The initial mean and standard deviation for the cross-entropy method (CEM).
mean = torch.zeros(self.config.horizon, self.config.output_shapes["action"][0], device=device)
# Maybe warm start CEM with the mean from the previous step.
if self._prev_mean is not None:
mean[:-1] = self._prev_mean[1:]
std = self.config.max_std * torch.ones_like(mean)
for _ in range(self.config.cem_iterations):
# Randomly sample action trajectories for the gaussian distribution.
std_normal_noise = torch.randn(
self.config.horizon,
self.config.n_gaussian_samples,
self.config.output_shapes["action"][0],
device=std.device,
)
gaussian_actions = torch.clamp(mean.unsqueeze(1) + std.unsqueeze(1) * std_normal_noise, -1, 1)
# Compute elite actions.
actions = torch.cat([gaussian_actions, pi_actions], dim=1)
value = self.estimate_value(z, actions).nan_to_num_(0)
elite_idxs = torch.topk(value, self.config.n_elites, dim=0).indices
elite_value, elite_actions = value[elite_idxs], actions[:, elite_idxs]
# Update guassian PDF parameters to be the (weighted) mean and standard deviation of the elites.
max_value = elite_value.max(0)[0]
# The weighting is a softmax over trajectory values. Note that this is not the same as the usage
# of Ω in eqn 4 of the TD-MPC paper. Instead it is the normalized version of it: s = Ω/ΣΩ. This
# makes the equations: μ = Σ(s⋅Γ), σ = Σ(s⋅(Γ-μ)²).
score = torch.exp(self.config.elite_weighting_temperature * (elite_value - max_value))
score /= score.sum()
_mean = torch.sum(einops.rearrange(score, "n -> n 1") * elite_actions, dim=1)
_std = torch.sqrt(
torch.sum(
einops.rearrange(score, "n -> n 1")
* (elite_actions - einops.rearrange(_mean, "h d -> h 1 d")) ** 2,
dim=1,
)
)
# Update mean with an exponential moving average, and std with a direct replacement.
mean = (
self.config.gaussian_mean_momentum * mean + (1 - self.config.gaussian_mean_momentum) * _mean
)
std = _std.clamp_(self.config.min_std, self.config.max_std)
# Keep track of the mean for warm-starting subsequent steps.
self._prev_mean = mean
# Randomly select one of the elite actions from the last iteration of MPPI/CEM using the softmax
# scores from the last iteration.
actions = elite_actions[:, torch.multinomial(score, 1).item()]
# Select only the first action
action = actions[0]
return action
@torch.no_grad()
def estimate_value(self, z: Tensor, actions: Tensor):
"""Estimates the value of a trajectory as per eqn 4 of the FOWM paper.
Args:
z: (batch, latent_dim) tensor of initial latent states.
actions: (horizon, batch, action_dim) tensor of action trajectories.
Returns:
(batch,) tensor of values.
"""
# Initialize return and running discount factor.
G, running_discount = 0, 1
# Iterate over the actions in the trajectory to simulate the trajectory using the latent dynamics
# model. Keep track of return.
for t in range(actions.shape[0]):
# We will compute the reward in a moment. First compute the uncertainty regularizer from eqn 4
# of the FOWM paper.
if self.config.uncertainty_regularizer_coeff > 0:
regularization = -(
self.config.uncertainty_regularizer_coeff * self.model.Qs(z, actions[t]).std(0)
)
else:
regularization = 0
# Estimate the next state (latent) and reward.
z, reward = self.model.latent_dynamics_and_reward(z, actions[t])
# Update the return and running discount.
G += running_discount * (reward + regularization)
running_discount *= self.config.discount
# Add the estimated value of the final state (using the minimum for a conservative estimate).
# Do so by predicting the next action, then taking a minimum over the ensemble of state-action value
# estimators.
# Note: This small amount of added noise seems to help a bit at inference time as observed by success
# metrics over 50 episodes of xarm_lift_medium_replay.
next_action = self.model.pi(z, self.config.min_std) # (batch, action_dim)
terminal_values = self.model.Qs(z, next_action) # (ensemble, batch)
# Randomly choose 2 of the Qs for terminal value estimation (as in App C. of the FOWM paper).
if self.config.q_ensemble_size > 2:
G += (
running_discount
* torch.min(terminal_values[torch.randint(0, self.config.q_ensemble_size, size=(2,))], dim=0)[
0
]
)
else:
G += running_discount * torch.min(terminal_values, dim=0)[0]
# Finally, also regularize the terminal value.
if self.config.uncertainty_regularizer_coeff > 0:
G -= running_discount * self.config.uncertainty_regularizer_coeff * terminal_values.std(0)
return G
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
"""Run the batch through the model and compute the loss."""
device = get_device_from_parameters(self)
batch = self.normalize_inputs(batch)
batch = self.normalize_targets(batch)
info = {}
# TODO(alexander-soare): Refactor TDMPC and make it comply with the policy interface documentation.
batch_size = batch["index"].shape[0]
# (b, t) -> (t, b)
for key in batch:
if batch[key].ndim > 1:
batch[key] = batch[key].transpose(1, 0)
action = batch["action"] # (t, b)
reward = batch["next.reward"] # (t,)
observations = {k: v for k, v in batch.items() if k.startswith("observation.")}
# Apply random image augmentations.
if self.config.max_random_shift_ratio > 0:
observations["observation.image"] = flatten_forward_unflatten(
partial(random_shifts_aug, max_random_shift_ratio=self.config.max_random_shift_ratio),
observations["observation.image"],
)
# Get the current observation for predicting trajectories, and all future observations for use in
# the latent consistency loss and TD loss.
current_observation, next_observations = {}, {}
for k in observations:
current_observation[k] = observations[k][0]
next_observations[k] = observations[k][1:]
horizon = next_observations["observation.image"].shape[0]
# Run latent rollout using the latent dynamics model and policy model.
# Note this has shape `horizon+1` because there are `horizon` actions and a current `z`. Each action
# gives us a next `z`.
z_preds = torch.empty(horizon + 1, batch_size, self.config.latent_dim, device=device)
z_preds[0] = self.model.encode(current_observation)
reward_preds = torch.empty_like(reward, device=device)
for t in range(horizon):
z_preds[t + 1], reward_preds[t] = self.model.latent_dynamics_and_reward(z_preds[t], action[t])
# Compute Q and V value predictions based on the latent rollout.
q_preds_ensemble = self.model.Qs(z_preds[:-1], action) # (ensemble, horizon, batch)
v_preds = self.model.V(z_preds[:-1])
info.update({"Q": q_preds_ensemble.mean().item(), "V": v_preds.mean().item()})
# Compute various targets with stopgrad.
with torch.no_grad():
# Latent state consistency targets.
z_targets = self.model_target.encode(next_observations)
# State-action value targets (or TD targets) as in eqn 3 of the FOWM. Unlike TD-MPC which uses the
# learned state-action value function in conjunction with the learned policy: Q(z, π(z)), FOWM
# uses a learned state value function: V(z). This means the TD targets only depend on in-sample
# actions (not actions estimated by π).
# Note: Here we do not use self.model_target, but self.model. This is to follow the original code
# and the FOWM paper.
q_targets = reward + self.config.discount * self.model.V(self.model.encode(next_observations))
# From eqn 3 of FOWM. These appear as Q(z, a). Here we call them v_targets to emphasize that we
# are using them to compute loss for V.
v_targets = self.model_target.Qs(z_preds[:-1].detach(), action, return_min=True)
# Compute losses.
# Exponentially decay the loss weight with respect to the timestep. Steps that are more distant in the
# future have less impact on the loss. Note: unsqueeze will let us broadcast to (seq, batch).
temporal_loss_coeffs = torch.pow(
self.config.temporal_decay_coeff, torch.arange(horizon, device=device)
).unsqueeze(-1)
# Compute consistency loss as MSE loss between latents predicted from the rollout and latents
# predicted from the (target model's) observation encoder.
consistency_loss = (
(
temporal_loss_coeffs
* F.mse_loss(z_preds[1:], z_targets, reduction="none").mean(dim=-1)
# `z_preds` depends on the current observation and the actions.
* ~batch["observation.state_is_pad"][0]
* ~batch["action_is_pad"]
# `z_targets` depends on the next observation.
* ~batch["observation.state_is_pad"][1:]
)
.sum(0)
.mean()
)
# Compute the reward loss as MSE loss between rewards predicted from the rollout and the dataset
# rewards.
reward_loss = (
(
temporal_loss_coeffs
* F.mse_loss(reward_preds, reward, reduction="none")
* ~batch["next.reward_is_pad"]
# `reward_preds` depends on the current observation and the actions.
* ~batch["observation.state_is_pad"][0]
* ~batch["action_is_pad"]
)
.sum(0)
.mean()
)
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
q_value_loss = (
(
F.mse_loss(
q_preds_ensemble,
einops.repeat(q_targets, "t b -> e t b", e=q_preds_ensemble.shape[0]),
reduction="none",
).sum(0) # sum over ensemble
# `q_preds_ensemble` depends on the first observation and the actions.
* ~batch["observation.state_is_pad"][0]
* ~batch["action_is_pad"]
# q_targets depends on the reward and the next observations.
* ~batch["next.reward_is_pad"]
* ~batch["observation.state_is_pad"][1:]
)
.sum(0)
.mean()
)
# Compute state value loss as in eqn 3 of FOWM.
diff = v_targets - v_preds
# Expectile loss penalizes:
# - `v_preds < v_targets` with weighting `expectile_weight`
# - `v_preds >= v_targets` with weighting `1 - expectile_weight`
raw_v_value_loss = torch.where(
diff > 0, self.config.expectile_weight, (1 - self.config.expectile_weight)
) * (diff**2)
v_value_loss = (
(
temporal_loss_coeffs
* raw_v_value_loss
# `v_targets` depends on the first observation and the actions, as does `v_preds`.
* ~batch["observation.state_is_pad"][0]
* ~batch["action_is_pad"]
)
.sum(0)
.mean()
)
# Calculate the advantage weighted regression loss for π as detailed in FOWM 3.1.
# We won't need these gradients again so detach.
z_preds = z_preds.detach()
# Use stopgrad for the advantage calculation.
with torch.no_grad():
advantage = self.model_target.Qs(z_preds[:-1], action, return_min=True) - self.model.V(
z_preds[:-1]
)
info["advantage"] = advantage[0]
# (t, b)
exp_advantage = torch.clamp(torch.exp(advantage * self.config.advantage_scaling), max=100.0)
action_preds = self.model.pi(z_preds[:-1]) # (t, b, a)
# Calculate the MSE between the actions and the action predictions.
# Note: FOWM's original code calculates the log probability (wrt to a unit standard deviation
# gaussian) and sums over the action dimension. Computing the log probability amounts to multiplying
# the MSE by 0.5 and adding a constant offset (the log(2*pi) term) . Here we drop the constant offset
# as it doesn't change the optimization step, and we drop the 0.5 as we instead make a configuration
# parameter for it (see below where we compute the total loss).
mse = F.mse_loss(action_preds, action, reduction="none").sum(-1) # (t, b)
# NOTE: The original implementation does not take the sum over the temporal dimension like with the
# other losses.
# TODO(alexander-soare): Take the sum over the temporal dimension and check that training still works
# as well as expected.
pi_loss = (
exp_advantage
* mse
* temporal_loss_coeffs
# `action_preds` depends on the first observation and the actions.
* ~batch["observation.state_is_pad"][0]
* ~batch["action_is_pad"]
).mean()
loss = (
self.config.consistency_coeff * consistency_loss
+ self.config.reward_coeff * reward_loss
+ self.config.value_coeff * q_value_loss
+ self.config.value_coeff * v_value_loss
+ self.config.pi_coeff * pi_loss
)
info.update(
{
"consistency_loss": consistency_loss.item(),
"reward_loss": reward_loss.item(),
"Q_value_loss": q_value_loss.item(),
"V_value_loss": v_value_loss.item(),
"pi_loss": pi_loss.item(),
"loss": loss,
"sum_loss": loss.item() * self.config.horizon,
}
)
# Undo (b, t) -> (t, b).
for key in batch:
if batch[key].ndim > 1:
batch[key] = batch[key].transpose(1, 0)
return info
def update(self):
"""Update the target model's parameters with an EMA step."""
# Note a minor variation with respect to the original FOWM code. Here they do this based on an EMA
# update frequency parameter which is set to 2 (every 2 steps an update is done). To simplify the code
# we update every step and adjust the decay parameter `alpha` accordingly (0.99 -> 0.995)
update_ema_parameters(self.model_target, self.model, self.config.target_model_momentum)
class TDMPCTOLD(nn.Module):
"""Task-Oriented Latent Dynamics (TOLD) model used in TD-MPC."""
def __init__(self, config: TDMPCConfig):
super().__init__()
self.config = config
self._encoder = TDMPCObservationEncoder(config)
self._dynamics = nn.Sequential(
nn.Linear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim),
nn.LayerNorm(config.mlp_dim),
nn.Mish(),
nn.Linear(config.mlp_dim, config.mlp_dim),
nn.LayerNorm(config.mlp_dim),
nn.Mish(),
nn.Linear(config.mlp_dim, config.latent_dim),
nn.LayerNorm(config.latent_dim),
nn.Sigmoid(),
)
self._reward = nn.Sequential(
nn.Linear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim),
nn.LayerNorm(config.mlp_dim),
nn.Mish(),
nn.Linear(config.mlp_dim, config.mlp_dim),
nn.LayerNorm(config.mlp_dim),
nn.Mish(),
nn.Linear(config.mlp_dim, 1),
)
self._pi = nn.Sequential(
nn.Linear(config.latent_dim, config.mlp_dim),
nn.LayerNorm(config.mlp_dim),
nn.Mish(),
nn.Linear(config.mlp_dim, config.mlp_dim),
nn.LayerNorm(config.mlp_dim),
nn.Mish(),
nn.Linear(config.mlp_dim, config.output_shapes["action"][0]),
)
self._Qs = nn.ModuleList(
[
nn.Sequential(
nn.Linear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim),
nn.LayerNorm(config.mlp_dim),
nn.Tanh(),
nn.Linear(config.mlp_dim, config.mlp_dim),
nn.ELU(),
nn.Linear(config.mlp_dim, 1),
)
for _ in range(config.q_ensemble_size)
]
)
self._V = nn.Sequential(
nn.Linear(config.latent_dim, config.mlp_dim),
nn.LayerNorm(config.mlp_dim),
nn.Tanh(),
nn.Linear(config.mlp_dim, config.mlp_dim),
nn.ELU(),
nn.Linear(config.mlp_dim, 1),
)
self._init_weights()
def _init_weights(self):
"""Initialize model weights.
Orthogonal initialization for all linear and convolutional layers' weights (apart from final layers
of reward network and Q networks which get zero initialization).
Zero initialization for all linear and convolutional layers' biases.
"""
def _apply_fn(m):
if isinstance(m, nn.Linear):
nn.init.orthogonal_(m.weight.data)
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Conv2d):
gain = nn.init.calculate_gain("relu")
nn.init.orthogonal_(m.weight.data, gain)
if m.bias is not None:
nn.init.zeros_(m.bias)
self.apply(_apply_fn)
for m in [self._reward, *self._Qs]:
assert isinstance(
m[-1], nn.Linear
), "Sanity check. The last linear layer needs 0 initialization on weights."
nn.init.zeros_(m[-1].weight)
nn.init.zeros_(m[-1].bias) # this has already been done, but keep this line here for good measure
def encode(self, obs: dict[str, Tensor]) -> Tensor:
"""Encodes an observation into its latent representation."""
return self._encoder(obs)
def latent_dynamics_and_reward(self, z: Tensor, a: Tensor) -> tuple[Tensor, Tensor]:
"""Predict the next state's latent representation and the reward given a current latent and action.
Args:
z: (*, latent_dim) tensor for the current state's latent representation.
a: (*, action_dim) tensor for the action to be applied.
Returns:
A tuple containing:
- (*, latent_dim) tensor for the next state's latent representation.
- (*,) tensor for the estimated reward.
"""
x = torch.cat([z, a], dim=-1)
return self._dynamics(x), self._reward(x).squeeze(-1)
def latent_dynamics(self, z: Tensor, a: Tensor) -> Tensor:
"""Predict the next state's latent representation given a current latent and action.
Args:
z: (*, latent_dim) tensor for the current state's latent representation.
a: (*, action_dim) tensor for the action to be applied.
Returns:
(*, latent_dim) tensor for the next state's latent representation.
"""
x = torch.cat([z, a], dim=-1)
return self._dynamics(x)
def pi(self, z: Tensor, std: float = 0.0) -> Tensor:
"""Samples an action from the learned policy.
The policy can also have added (truncated) Gaussian noise injected for encouraging exploration when
generating rollouts for online training.
Args:
z: (*, latent_dim) tensor for the current state's latent representation.
std: The standard deviation of the injected noise.
Returns:
(*, action_dim) tensor for the sampled action.
"""
action = torch.tanh(self._pi(z))
if std > 0:
std = torch.ones_like(action) * std
action += torch.randn_like(action) * std
return action
def V(self, z: Tensor) -> Tensor: # noqa: N802
"""Predict state value (V).
Args:
z: (*, latent_dim) tensor for the current state's latent representation.
Returns:
(*,) tensor of estimated state values.
"""
return self._V(z).squeeze(-1)
def Qs(self, z: Tensor, a: Tensor, return_min: bool = False) -> Tensor: # noqa: N802
"""Predict state-action value for all of the learned Q functions.
Args:
z: (*, latent_dim) tensor for the current state's latent representation.
a: (*, action_dim) tensor for the action to be applied.
return_min: Set to true for implementing the detail in App. C of the FOWM paper: randomly select
2 of the Qs and return the minimum
Returns:
(q_ensemble, *) tensor for the value predictions of each learned Q function in the ensemble OR
(*,) tensor if return_min=True.
"""
x = torch.cat([z, a], dim=-1)
if not return_min:
return torch.stack([q(x).squeeze(-1) for q in self._Qs], dim=0)
else:
if len(self._Qs) > 2: # noqa: SIM108
Qs = [self._Qs[i] for i in np.random.choice(len(self._Qs), size=2)]
else:
Qs = self._Qs
return torch.stack([q(x).squeeze(-1) for q in Qs], dim=0).min(dim=0)[0]
class TDMPCObservationEncoder(nn.Module):
"""Encode image and/or state vector observations."""
def __init__(self, config: TDMPCConfig):
"""
Creates encoders for pixel and/or state modalities.
TODO(alexander-soare): The original work allows for multiple images by concatenating them along the
channel dimension. Re-implement this capability.
"""
super().__init__()
self.config = config
if "observation.image" in config.input_shapes:
self.image_enc_layers = nn.Sequential(
nn.Conv2d(
config.input_shapes["observation.image"][0], config.image_encoder_hidden_dim, 7, stride=2
),
nn.ReLU(),
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 5, stride=2),
nn.ReLU(),
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2),
nn.ReLU(),
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2),
nn.ReLU(),
)
dummy_batch = torch.zeros(1, *config.input_shapes["observation.image"])
with torch.inference_mode():
out_shape = self.image_enc_layers(dummy_batch).shape[1:]
self.image_enc_layers.extend(
nn.Sequential(
nn.Flatten(),
nn.Linear(np.prod(out_shape), config.latent_dim),
nn.LayerNorm(config.latent_dim),
nn.Sigmoid(),
)
)
if "observation.state" in config.input_shapes:
self.state_enc_layers = nn.Sequential(
nn.Linear(config.input_shapes["observation.state"][0], config.state_encoder_hidden_dim),
nn.ELU(),
nn.Linear(config.state_encoder_hidden_dim, config.latent_dim),
nn.LayerNorm(config.latent_dim),
nn.Sigmoid(),
)
def forward(self, obs_dict: dict[str, Tensor]) -> Tensor:
"""Encode the image and/or state vector.
Each modality is encoded into a feature vector of size (latent_dim,) and then a uniform mean is taken
over all features.
"""
feat = []
if "observation.image" in self.config.input_shapes:
feat.append(flatten_forward_unflatten(self.image_enc_layers, obs_dict["observation.image"]))
if "observation.state" in self.config.input_shapes:
feat.append(self.state_enc_layers(obs_dict["observation.state"]))
return torch.stack(feat, dim=0).mean(0)
def random_shifts_aug(x: Tensor, max_random_shift_ratio: float) -> Tensor:
"""Randomly shifts images horizontally and vertically.
Adapted from https://github.com/facebookresearch/drqv2
"""
b, _, h, w = x.size()
assert h == w, "non-square images not handled yet"
pad = int(round(max_random_shift_ratio * h))
x = F.pad(x, tuple([pad] * 4), "replicate")
eps = 1.0 / (h + 2 * pad)
arange = torch.linspace(
-1.0 + eps,
1.0 - eps,
h + 2 * pad,
device=x.device,
dtype=torch.float32,
)[:h]
arange = einops.repeat(arange, "w -> h w 1", h=h)
base_grid = torch.cat([arange, arange.transpose(1, 0)], dim=2)
base_grid = einops.repeat(base_grid, "h w c -> b h w c", b=b)
# A random shift in units of pixels and within the boundaries of the padding.
shift = torch.randint(
0,
2 * pad + 1,
size=(b, 1, 1, 2),
device=x.device,
dtype=torch.float32,
)
shift *= 2.0 / (h + 2 * pad)
grid = base_grid + shift
return F.grid_sample(x, grid, padding_mode="zeros", align_corners=False)
def update_ema_parameters(ema_net: nn.Module, net: nn.Module, alpha: float):
"""Update EMA parameters in place with ema_param <- alpha * ema_param + (1 - alpha) * param."""
for ema_module, module in zip(ema_net.modules(), net.modules(), strict=True):
for (n_p_ema, p_ema), (n_p, p) in zip(
ema_module.named_parameters(recurse=False), module.named_parameters(recurse=False), strict=True
):
assert n_p_ema == n_p, "Parameter names don't match for EMA model update"
if isinstance(p, dict):
raise RuntimeError("Dict parameter not supported")
if isinstance(module, nn.modules.batchnorm._BatchNorm) or not p.requires_grad:
# Copy BatchNorm parameters, and non-trainable parameters directly.
p_ema.copy_(p.to(dtype=p_ema.dtype).data)
with torch.no_grad():
p_ema.mul_(alpha)
p_ema.add_(p.to(dtype=p_ema.dtype).data, alpha=1 - alpha)
def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor:
"""Helper to temporarily flatten extra dims at the start of the image tensor.
Args:
fn: Callable that the image tensor will be passed to. It should accept (B, C, H, W) and return
(B, *), where * is any number of dimensions.
image_tensor: An image tensor of shape (**, C, H, W), where ** is any number of dimensions, generally
different from *.
Returns:
A return value from the callable reshaped to (**, *).
"""
if image_tensor.ndim == 4:
return fn(image_tensor)
start_dims = image_tensor.shape[:-3]
inp = torch.flatten(image_tensor, end_dim=-4)
flat_out = fn(inp)
return torch.reshape(flat_out, (*start_dims, *flat_out.shape[1:]))

View File

@ -1,495 +0,0 @@
# ruff: noqa: N806
import time
from collections import deque
from copy import deepcopy
import einops
import numpy as np
import torch
import torch.nn as nn
import lerobot.common.policies.tdmpc.helper as h
from lerobot.common.policies.utils import populate_queues
from lerobot.common.utils.utils import get_safe_torch_device
FIRST_FRAME = 0
class TOLD(nn.Module):
"""Task-Oriented Latent Dynamics (TOLD) model used in TD-MPC."""
def __init__(self, cfg):
super().__init__()
action_dim = cfg.action_dim
self.cfg = cfg
self._encoder = h.enc(cfg)
self._dynamics = h.dynamics(cfg.latent_dim + action_dim, cfg.mlp_dim, cfg.latent_dim)
self._reward = h.mlp(cfg.latent_dim + action_dim, cfg.mlp_dim, 1)
self._pi = h.mlp(cfg.latent_dim, cfg.mlp_dim, action_dim)
self._Qs = nn.ModuleList([h.q(cfg) for _ in range(cfg.num_q)])
self._V = h.v(cfg)
self.apply(h.orthogonal_init)
for m in [self._reward, *self._Qs]:
m[-1].weight.data.fill_(0)
m[-1].bias.data.fill_(0)
def track_q_grad(self, enable=True):
"""Utility function. Enables/disables gradient tracking of Q-networks."""
for m in self._Qs:
h.set_requires_grad(m, enable)
def track_v_grad(self, enable=True):
"""Utility function. Enables/disables gradient tracking of Q-networks."""
if hasattr(self, "_V"):
h.set_requires_grad(self._V, enable)
def encode(self, obs):
"""Encodes an observation into its latent representation."""
out = self._encoder(obs)
if isinstance(obs, dict):
# fusion
out = torch.stack([v for k, v in out.items()]).mean(dim=0)
return out
def next(self, z, a):
"""Predicts next latent state (d) and single-step reward (R)."""
x = torch.cat([z, a], dim=-1)
return self._dynamics(x), self._reward(x)
def next_dynamics(self, z, a):
"""Predicts next latent state (d)."""
x = torch.cat([z, a], dim=-1)
return self._dynamics(x)
def pi(self, z, std=0):
"""Samples an action from the learned policy (pi)."""
mu = torch.tanh(self._pi(z))
if std > 0:
std = torch.ones_like(mu) * std
return h.TruncatedNormal(mu, std).sample(clip=0.3)
return mu
def V(self, z): # noqa: N802
"""Predict state value (V)."""
return self._V(z)
def Q(self, z, a, return_type): # noqa: N802
"""Predict state-action value (Q)."""
assert return_type in {"min", "avg", "all"}
x = torch.cat([z, a], dim=-1)
if return_type == "all":
return torch.stack([q(x) for q in self._Qs], dim=0)
idxs = np.random.choice(self.cfg.num_q, 2, replace=False)
Q1, Q2 = self._Qs[idxs[0]](x), self._Qs[idxs[1]](x)
return torch.min(Q1, Q2) if return_type == "min" else (Q1 + Q2) / 2
class TDMPCPolicy(nn.Module):
"""Implementation of TD-MPC learning + inference."""
name = "tdmpc"
def __init__(self, cfg, n_obs_steps, n_action_steps, device):
super().__init__()
self.action_dim = cfg.action_dim
self.cfg = cfg
self.n_obs_steps = n_obs_steps
self.n_action_steps = n_action_steps
self.device = get_safe_torch_device(device)
self.std = h.linear_schedule(cfg.std_schedule, 0)
self.model = TOLD(cfg)
self.model.to(self.device)
self.model_target = deepcopy(self.model)
self.optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr)
self.pi_optim = torch.optim.Adam(self.model._pi.parameters(), lr=self.cfg.lr)
# self.bc_optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr)
self.model.eval()
self.model_target.eval()
self.register_buffer("step", torch.zeros(1))
def state_dict(self):
"""Retrieve state dict of TOLD model, including slow-moving target network."""
return {
"model": self.model.state_dict(),
"model_target": self.model_target.state_dict(),
}
def save(self, fp):
"""Save state dict of TOLD model to filepath."""
torch.save(self.state_dict(), fp)
def load(self, fp):
"""Load a saved state dict from filepath into current agent."""
d = torch.load(fp)
self.model.load_state_dict(d["model"])
self.model_target.load_state_dict(d["model_target"])
def reset(self):
"""
Clear observation and action queues. Should be called on `env.reset()`
"""
self._queues = {
"observation.image": deque(maxlen=self.n_obs_steps),
"observation.state": deque(maxlen=self.n_obs_steps),
"action": deque(maxlen=self.n_action_steps),
}
@torch.no_grad()
def select_action(self, batch, step):
assert "observation.image" in batch
assert "observation.state" in batch
assert len(batch) == 2
self._queues = populate_queues(self._queues, batch)
t0 = step == 0
self.eval()
if len(self._queues["action"]) == 0:
batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
if self.n_obs_steps == 1:
# hack to remove the time dimension
for key in batch:
assert batch[key].shape[1] == 1
batch[key] = batch[key][:, 0]
actions = []
batch_size = batch["observation.image"].shape[0]
for i in range(batch_size):
obs = {
"rgb": batch["observation.image"][[i]],
"state": batch["observation.state"][[i]],
}
# Note: unsqueeze needed because `act` still uses non-batch logic.
action = self.act(obs, t0=t0, step=self.step)
actions.append(action)
action = torch.stack(actions)
# tdmpc returns an action for 1 timestep only, so we copy it over `n_action_steps` time
if i in range(self.n_action_steps):
self._queues["action"].append(action)
action = self._queues["action"].popleft()
return action
@torch.no_grad()
def act(self, obs, t0=False, step=None):
"""Take an action. Uses either MPC or the learned policy, depending on the self.cfg.mpc flag."""
obs = {k: o.detach() for k, o in obs.items()} if isinstance(obs, dict) else obs.detach()
z = self.model.encode(obs)
if self.cfg.mpc:
a = self.plan(z, t0=t0, step=step)
else:
a = self.model.pi(z, self.cfg.min_std * self.model.training).squeeze(0)
return a
@torch.no_grad()
def estimate_value(self, z, actions, horizon):
"""Estimate value of a trajectory starting at latent state z and executing given actions."""
G, discount = 0, 1
for t in range(horizon):
if self.cfg.uncertainty_cost > 0:
G -= (
discount
* self.cfg.uncertainty_cost
* self.model.Q(z, actions[t], return_type="all").std(dim=0)
)
z, reward = self.model.next(z, actions[t])
G += discount * reward
discount *= self.cfg.discount
pi = self.model.pi(z, self.cfg.min_std)
G += discount * self.model.Q(z, pi, return_type="min")
if self.cfg.uncertainty_cost > 0:
G -= discount * self.cfg.uncertainty_cost * self.model.Q(z, pi, return_type="all").std(dim=0)
return G
@torch.no_grad()
def plan(self, z, step=None, t0=True):
"""
Plan next action using TD-MPC inference.
z: latent state.
step: current time step. determines e.g. planning horizon.
t0: whether current step is the first step of an episode.
"""
# during eval: eval_mode: uniform sampling and action noise is disabled during evaluation.
assert step is not None
# Seed steps
if step < self.cfg.seed_steps and self.model.training:
return torch.empty(self.action_dim, dtype=torch.float32, device=self.device).uniform_(-1, 1)
# Sample policy trajectories
horizon = int(min(self.cfg.horizon, h.linear_schedule(self.cfg.horizon_schedule, step)))
num_pi_trajs = int(self.cfg.mixture_coef * self.cfg.num_samples)
if num_pi_trajs > 0:
pi_actions = torch.empty(horizon, num_pi_trajs, self.action_dim, device=self.device)
_z = z.repeat(num_pi_trajs, 1)
for t in range(horizon):
pi_actions[t] = self.model.pi(_z, self.cfg.min_std)
_z = self.model.next_dynamics(_z, pi_actions[t])
# Initialize state and parameters
z = z.repeat(self.cfg.num_samples + num_pi_trajs, 1)
mean = torch.zeros(horizon, self.action_dim, device=self.device)
std = self.cfg.max_std * torch.ones(horizon, self.action_dim, device=self.device)
if not t0 and hasattr(self, "_prev_mean"):
mean[:-1] = self._prev_mean[1:]
# Iterate CEM
for _ in range(self.cfg.iterations):
actions = torch.clamp(
mean.unsqueeze(1)
+ std.unsqueeze(1)
* torch.randn(horizon, self.cfg.num_samples, self.action_dim, device=std.device),
-1,
1,
)
if num_pi_trajs > 0:
actions = torch.cat([actions, pi_actions], dim=1)
# Compute elite actions
value = self.estimate_value(z, actions, horizon).nan_to_num_(0)
elite_idxs = torch.topk(value.squeeze(1), self.cfg.num_elites, dim=0).indices
elite_value, elite_actions = value[elite_idxs], actions[:, elite_idxs]
# Update parameters
max_value = elite_value.max(0)[0]
score = torch.exp(self.cfg.temperature * (elite_value - max_value))
score /= score.sum(0)
_mean = torch.sum(score.unsqueeze(0) * elite_actions, dim=1) / (score.sum(0) + 1e-9)
_std = torch.sqrt(
torch.sum(
score.unsqueeze(0) * (elite_actions - _mean.unsqueeze(1)) ** 2,
dim=1,
)
/ (score.sum(0) + 1e-9)
)
_std = _std.clamp_(self.std, self.cfg.max_std)
mean, std = self.cfg.momentum * mean + (1 - self.cfg.momentum) * _mean, _std
# Outputs
# TODO(rcadene): remove numpy with
# # Convert score tensor to probabilities using softmax
# probabilities = torch.softmax(score, dim=0)
# # Generate a random sample index based on the probabilities
# sample_index = torch.multinomial(probabilities, 1).item()
score = score.squeeze(1).cpu().numpy()
actions = elite_actions[:, np.random.choice(np.arange(score.shape[0]), p=score)]
self._prev_mean = mean
mean, std = actions[0], _std[0]
a = mean
if self.model.training:
a += std * torch.randn(self.action_dim, device=std.device)
return torch.clamp(a, -1, 1)
def update_pi(self, zs, acts=None):
"""Update policy using a sequence of latent states."""
self.pi_optim.zero_grad(set_to_none=True)
self.model.track_q_grad(False)
self.model.track_v_grad(False)
info = {}
# Advantage Weighted Regression
assert acts is not None
vs = self.model.V(zs)
qs = self.model_target.Q(zs, acts, return_type="min")
adv = qs - vs
exp_a = torch.exp(adv * self.cfg.A_scaling)
exp_a = torch.clamp(exp_a, max=100.0)
log_probs = h.gaussian_logprob(self.model.pi(zs) - acts, 0)
rho = torch.pow(self.cfg.rho, torch.arange(len(qs), device=self.device))
pi_loss = -((exp_a * log_probs).mean(dim=(1, 2)) * rho).mean()
info["adv"] = adv[0]
pi_loss.backward()
torch.nn.utils.clip_grad_norm_(
self.model._pi.parameters(),
self.cfg.grad_clip_norm,
error_if_nonfinite=False,
)
self.pi_optim.step()
self.model.track_q_grad(True)
self.model.track_v_grad(True)
info["pi_loss"] = pi_loss.item()
return pi_loss.item(), info
@torch.no_grad()
def _td_target(self, next_z, reward, mask):
"""Compute the TD-target from a reward and the observation at the following time step."""
next_v = self.model.V(next_z)
td_target = reward + self.cfg.discount * mask * next_v.squeeze(2)
return td_target
def forward(self, batch, step):
# TODO(alexander-soare): Refactor TDMPC and make it comply with the policy interface documentation.
raise NotImplementedError()
def update(self, batch, step):
"""Main update function. Corresponds to one iteration of the model learning."""
start_time = time.time()
batch_size = batch["index"].shape[0]
# TODO(rcadene): convert tdmpc with (batch size, time/horizon, channels)
# instead of currently (time/horizon, batch size, channels) which is not the pytorch convention
# batch size b = 256, time/horizon t = 5
# b t ... -> t b ...
for key in batch:
if batch[key].ndim > 1:
batch[key] = batch[key].transpose(1, 0)
action = batch["action"]
reward = batch["next.reward"]
# idxs = batch["index"] # TODO(rcadene): use idxs to update sampling weights
done = torch.zeros_like(reward, dtype=torch.bool, device=reward.device)
mask = torch.ones_like(reward, dtype=torch.bool, device=reward.device)
weights = torch.ones(batch_size, dtype=torch.bool, device=reward.device)
obses = {
"rgb": batch["observation.image"],
"state": batch["observation.state"],
}
shapes = {}
for k in obses:
shapes[k] = obses[k].shape
obses[k] = einops.rearrange(obses[k], "t b ... -> (t b) ... ")
# Apply augmentations
aug_tf = h.aug(self.cfg)
obses = aug_tf(obses)
for k in obses:
t, b = shapes[k][:2]
obses[k] = einops.rearrange(obses[k], "(t b) ... -> t b ... ", b=b, t=t)
obs, next_obses = {}, {}
for k in obses:
obs[k] = obses[k][0]
next_obses[k] = obses[k][1:].clone()
horizon = next_obses["rgb"].shape[0]
loss_mask = torch.ones_like(mask, device=self.device)
for t in range(1, horizon):
loss_mask[t] = loss_mask[t - 1] * (~done[t - 1])
self.optim.zero_grad(set_to_none=True)
self.std = h.linear_schedule(self.cfg.std_schedule, step)
self.model.train()
data_s = time.time() - start_time
# Compute targets
with torch.no_grad():
next_z = self.model.encode(next_obses)
z_targets = self.model_target.encode(next_obses)
td_targets = self._td_target(next_z, reward, mask)
# Latent rollout
zs = torch.empty(horizon + 1, batch_size, self.cfg.latent_dim, device=self.device)
reward_preds = torch.empty_like(reward, device=self.device)
assert reward.shape[0] == horizon
z = self.model.encode(obs)
zs[0] = z
value_info = {"Q": 0.0, "V": 0.0}
for t in range(horizon):
z, reward_pred = self.model.next(z, action[t])
zs[t + 1] = z
reward_preds[t] = reward_pred.squeeze(1)
with torch.no_grad():
v_target = self.model_target.Q(zs[:-1].detach(), action, return_type="min")
# Predictions
qs = self.model.Q(zs[:-1], action, return_type="all")
qs = qs.squeeze(3)
value_info["Q"] = qs.mean().item()
v = self.model.V(zs[:-1])
value_info["V"] = v.mean().item()
# Losses
rho = torch.pow(self.cfg.rho, torch.arange(horizon, device=self.device)).view(-1, 1)
consistency_loss = (rho * torch.mean(h.mse(zs[1:], z_targets), dim=2) * loss_mask).sum(dim=0)
reward_loss = (rho * h.mse(reward_preds, reward) * loss_mask).sum(dim=0)
q_value_loss, priority_loss = 0, 0
for q in range(self.cfg.num_q):
q_value_loss += (rho * h.mse(qs[q], td_targets) * loss_mask).sum(dim=0)
priority_loss += (rho * h.l1(qs[q], td_targets) * loss_mask).sum(dim=0)
expectile = h.linear_schedule(self.cfg.expectile, step)
v_value_loss = (rho * h.l2_expectile(v_target - v, expectile=expectile).squeeze(2) * loss_mask).sum(
dim=0
)
total_loss = (
self.cfg.consistency_coef * consistency_loss
+ self.cfg.reward_coef * reward_loss
+ self.cfg.value_coef * q_value_loss
+ self.cfg.value_coef * v_value_loss
)
weighted_loss = (total_loss * weights).mean()
weighted_loss.register_hook(lambda grad: grad * (1 / self.cfg.horizon))
has_nan = torch.isnan(weighted_loss).item()
if has_nan:
print(f"weighted_loss has nan: {total_loss=} {weights=}")
else:
weighted_loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(
self.model.parameters(), self.cfg.grad_clip_norm, error_if_nonfinite=False
)
self.optim.step()
# TODO(rcadene): implement PrioritizedSampling by modifying sampler.weights with priorities computed by a criterion
# if self.cfg.per:
# # Update priorities
# priorities = priority_loss.clamp(max=1e4).detach()
# has_nan = torch.isnan(priorities).any().item()
# if has_nan:
# print(f"priorities has nan: {priorities=}")
# else:
# replay_buffer.update_priority(
# idxs[:num_slices],
# priorities[:num_slices],
# )
# if demo_batch_size > 0:
# demo_buffer.update_priority(demo_idxs, priorities[num_slices:])
# Update policy + target network
_, pi_update_info = self.update_pi(zs[:-1].detach(), acts=action)
if step % self.cfg.update_freq == 0:
h.ema(self.model._encoder, self.model_target._encoder, self.cfg.tau)
h.ema(self.model._Qs, self.model_target._Qs, self.cfg.tau)
self.model.eval()
info = {
"consistency_loss": float(consistency_loss.mean().item()),
"reward_loss": float(reward_loss.mean().item()),
"Q_value_loss": float(q_value_loss.mean().item()),
"V_value_loss": float(v_value_loss.mean().item()),
"sum_loss": float(total_loss.mean().item()),
"loss": float(weighted_loss.mean().item()),
"grad_norm": float(grad_norm),
"lr": self.cfg.lr,
"data_s": data_s,
"update_s": time.time() - start_time,
}
# info["demo_batch_size"] = demo_batch_size
info["expectile"] = expectile
info.update(value_info)
info.update(pi_update_info)
self.step[0] = step
return info

View File

@ -9,31 +9,24 @@ hydra:
job: job:
name: default name: default
seed: 1337
# batch size for TorchRL SerialEnv. Each underlying env will get the seed = seed + env_index
# NOTE: only diffusion policy supports rollout_batch_size > 1
rollout_batch_size: 1
device: cuda # cpu device: cuda # cpu
prefetch: 4 seed: ???
eval_freq: ??? dataset_repo_id: lerobot/pusht
save_freq: ???
eval_episodes: ???
save_video: false
save_model: false
save_buffer: false
train_steps: ???
fps: ???
offline_prioritized_sampler: true training:
offline_steps: ???
online_steps: ???
online_steps_between_rollouts: ???
online_sampling_ratio: 0.5
eval_freq: ???
save_freq: ???
log_freq: 250
save_model: false
dataset: eval:
repo_id: ??? n_episodes: 1
# TODO(alexander-soare): Right now this does not work. Reinstate this.
n_action_steps: ??? batch_size: 1
n_obs_steps: ???
env: ???
policy: ???
wandb: wandb:
enable: true enable: true

View File

@ -1,18 +1,7 @@
# @package _global_ # @package _global_
eval_episodes: 50
eval_freq: 7500
save_freq: 75000
log_freq: 250
# TODO: same as xarm, need to adjust
offline_steps: 25000
online_steps: 25000
fps: 50 fps: 50
dataset:
repo_id: lerobot/aloha_sim_insertion_human
env: env:
name: aloha name: aloha
task: AlohaInsertion-v0 task: AlohaInsertion-v0

View File

@ -1,18 +1,7 @@
# @package _global_ # @package _global_
eval_episodes: 50
eval_freq: 7500
save_freq: 75000
log_freq: 250
# TODO: same as xarm, need to adjust
offline_steps: 25000
online_steps: 25000
fps: 10 fps: 10
dataset:
repo_id: lerobot/pusht
env: env:
name: pusht name: pusht
task: PushT-v0 task: PushT-v0

View File

@ -1,17 +1,7 @@
# @package _global_ # @package _global_
eval_episodes: 20
eval_freq: 1000
save_freq: 10000
log_freq: 50
offline_steps: 25000
online_steps: 25000
fps: 15 fps: 15
dataset:
repo_id: lerobot/xarm_lift_medium
env: env:
name: xarm name: xarm
task: XarmLift-v0 task: XarmLift-v0

View File

@ -1,30 +1,41 @@
# @package _global_ # @package _global_
offline_steps: 80000 seed: 1000
online_steps: 0 dataset_repo_id: lerobot/aloha_sim_insertion_human
eval_episodes: 1 training:
eval_freq: 10000 offline_steps: 80000
save_freq: 100000 online_steps: 0
log_freq: 250 eval_freq: 10000
save_freq: 100000
log_freq: 250
save_model: true
n_obs_steps: 1 batch_size: 8
# when temporal_agg=False, n_action_steps=horizon lr: 1e-5
lr_backbone: 1e-5
weight_decay: 1e-4
grad_clip_norm: 10
online_steps_between_rollouts: 1
override_dataset_stats: override_dataset_stats:
observation.images.top: observation.images.top:
# stats from imagenet, since we use a pretrained vision model # stats from imagenet, since we use a pretrained vision model
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1) mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1) std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
delta_timestamps:
action: "[i / ${fps} for i in range(${policy.chunk_size})]"
eval:
n_episodes:: 50
# See `configuration_act.py` for more details. # See `configuration_act.py` for more details.
policy: policy:
name: act name: act
pretrained_model_path:
# Input / output structure. # Input / output structure.
n_obs_steps: ${n_obs_steps} n_obs_steps: 1
chunk_size: 100 # chunk_size chunk_size: 100 # chunk_size
n_action_steps: 100 n_action_steps: 100
@ -49,7 +60,7 @@ policy:
replace_final_stride_with_dilation: false replace_final_stride_with_dilation: false
# Transformer layers. # Transformer layers.
pre_norm: false pre_norm: false
d_model: 512 dim_model: 512
n_heads: 8 n_heads: 8
dim_feedforward: 3200 dim_feedforward: 3200
feedforward_activation: relu feedforward_activation: relu
@ -66,15 +77,3 @@ policy:
# Training and loss computation. # Training and loss computation.
dropout: 0.1 dropout: 0.1
kl_weight: 10.0 kl_weight: 10.0
# ---
# TODO(alexander-soare): Remove these from the policy config.
batch_size: 8
lr: 1e-5
lr_backbone: 1e-5
weight_decay: 1e-4
grad_clip_norm: 10
utd: 1
delta_timestamps:
action: "[i / ${fps} for i in range(${policy.chunk_size})]"

View File

@ -1,22 +1,33 @@
# @package _global_ # @package _global_
seed: 100000 seed: 100000
horizon: 16 dataset_repo_id: lerobot/pusht
n_obs_steps: 2
n_action_steps: 8
dataset_obs_steps: ${n_obs_steps}
past_action_visible: False
keypoint_visible_rate: 1.0
eval_episodes: 50 training:
eval_freq: 5000 offline_steps: 200000
save_freq: 5000 online_steps: 0
log_freq: 250 eval_freq: 5000
save_freq: 5000
log_freq: 250
save_model: true
offline_steps: 200000 batch_size: 64
online_steps: 0 grad_clip_norm: 10
lr: 1.0e-4
lr_scheduler: cosine
lr_warmup_steps: 500
adam_betas: [0.95, 0.999]
adam_eps: 1.0e-8
adam_weight_decay: 1.0e-6
online_steps_between_rollouts: 1
offline_prioritized_sampler: true delta_timestamps:
observation.image: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
observation.state: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
action: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1 - ${policy.n_obs_steps} + ${policy.horizon})]"
eval:
n_episodes: 50
override_dataset_stats: override_dataset_stats:
# TODO(rcadene, alexander-soare): should we remove image stats as well? do we use a pretrained vision model? # TODO(rcadene, alexander-soare): should we remove image stats as well? do we use a pretrained vision model?
@ -35,12 +46,10 @@ override_dataset_stats:
policy: policy:
name: diffusion name: diffusion
pretrained_model_path:
# Input / output structure. # Input / output structure.
n_obs_steps: ${n_obs_steps} n_obs_steps: 2
horizon: ${horizon} horizon: 16
n_action_steps: ${n_action_steps} n_action_steps: 8
input_shapes: input_shapes:
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env? # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
@ -84,23 +93,9 @@ policy:
# --- # ---
# TODO(alexander-soare): Remove these from the policy config. # TODO(alexander-soare): Remove these from the policy config.
batch_size: 64
grad_clip_norm: 10
lr: 1.0e-4
lr_scheduler: cosine
lr_warmup_steps: 500
adam_betas: [0.95, 0.999]
adam_eps: 1.0e-8
adam_weight_decay: 1.0e-6
utd: 1
use_ema: true use_ema: true
ema_update_after_step: 0 ema_update_after_step: 0
ema_min_alpha: 0.0 ema_min_alpha: 0.0
ema_max_alpha: 0.9999 ema_max_alpha: 0.9999
ema_inv_gamma: 1.0 ema_inv_gamma: 1.0
ema_power: 0.75 ema_power: 0.75
delta_timestamps:
observation.image: "[i / ${fps} for i in range(1 - ${n_obs_steps}, 1)]"
observation.state: "[i / ${fps} for i in range(1 - ${n_obs_steps}, 1)]"
action: "[i / ${fps} for i in range(1 - ${n_obs_steps}, 1 - ${n_obs_steps} + ${policy.horizon})]"

View File

@ -1,85 +1,76 @@
# @package _global_ # @package _global_
n_action_steps: 2 seed: 1
n_obs_steps: 1
training:
offline_steps: 25000
online_steps: 25000
eval_freq: 5000
online_steps_between_rollouts: 1
online_sampling_ratio: 0.5
batch_size: 256
grad_clip_norm: 10.0
lr: 3e-4
delta_timestamps:
observation.image: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
observation.state: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
action: "[i / ${fps} for i in range(${policy.horizon})]"
next.reward: "[i / ${fps} for i in range(${policy.horizon})]"
policy: policy:
name: tdmpc name: tdmpc
reward_scale: 1.0 pretrained_model_path:
episode_length: ${env.episode_length} # Input / output structure.
discount: 0.9 n_action_repeats: 2
modality: 'all'
# pixels
frame_stack: 1
num_channels: 32
img_size: ${env.image_size}
state_dim: ${env.action_dim}
action_dim: ${env.action_dim}
# planning
mpc: true
iterations: 6
num_samples: 512
num_elites: 50
mixture_coef: 0.1
min_std: 0.05
max_std: 2.0
temperature: 0.5
momentum: 0.1
uncertainty_cost: 1
# actor
log_std_min: -10
log_std_max: 2
# learning
batch_size: 256
max_buffer_size: 10000
horizon: 5 horizon: 5
reward_coef: 0.5
value_coef: 0.1
consistency_coef: 20
rho: 0.5
kappa: 0.1
lr: 3e-4
std_schedule: ${policy.min_std}
horizon_schedule: ${policy.horizon}
per: true
per_alpha: 0.6
per_beta: 0.4
grad_clip_norm: 10
seed_steps: 0
update_freq: 2
tau: 0.01
utd: 1
# offline rl input_shapes:
# dataset_dir: ??? # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
data_first_percent: 1.0 observation.image: [3, 84, 84]
is_data_clip: true observation.state: ["${env.state_dim}"]
data_clip_eps: 1e-5 output_shapes:
expectile: 0.9 action: ["${env.action_dim}"]
A_scaling: 3.0
# offline->online # Normalization / Unnormalization
offline_steps: ${offline_steps} input_normalization_modes: null
pretrained_model_path: "" output_normalization_modes:
# pretrained_model_path: "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/offline.pt" action: min_max
# pretrained_model_path: "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/final.pt"
balanced_sampling: true
demo_schedule: 0.5
# architecture # Architecture / modeling.
enc_dim: 256 # Neural networks.
num_q: 5 image_encoder_hidden_dim: 32
mlp_dim: 512 state_encoder_hidden_dim: 256
latent_dim: 50 latent_dim: 50
q_ensemble_size: 5
mlp_dim: 512
# Reinforcement learning.
discount: 0.9
delta_timestamps: # Inference.
observation.image: "[i / ${fps} for i in range(6)]" use_mpc: false
observation.state: "[i / ${fps} for i in range(6)]" cem_iterations: 6
action: "[i / ${fps} for i in range(5)]" max_std: 2.0
next.reward: "[i / ${fps} for i in range(5)]" min_std: 0.05
n_gaussian_samples: 512
n_pi_samples: 51
uncertainty_regularizer_coeff: 1.0
n_elites: 50
elite_weighting_temperature: 0.5
gaussian_mean_momentum: 0.1
# Training and loss computation.
max_random_shift_ratio: 0.0476
# Loss coefficients.
reward_coeff: 0.5
expectile_weight: 0.9
value_coeff: 0.1
consistency_coeff: 20.0
advantage_scaling: 3.0
pi_coeff: 0.5
temporal_decay_coeff: 0.5
# Target model.
target_model_momentum: 0.995

View File

@ -1,30 +1,29 @@
"""Evaluate a policy on an environment by running rollouts and computing metrics. """Evaluate a policy on an environment by running rollouts and computing metrics.
The script may be run in one of two ways: Usage examples:
1. By providing the path to a config file with the --config argument. You want to evaluate a model from the hub (eg: https://huggingface.co/lerobot/diffusion_policy_pusht_image)
2. By providing a HuggingFace Hub ID with the --hub-id argument. You may also provide a revision number with the for 10 episodes.
--revision argument.
In either case, it is possible to override config arguments by adding a list of config.key=value arguments. ```
python lerobot/scripts/eval.py -p lerobot/diffusion_policy_pusht_image eval.n_episodes=10
```
Examples: OR, you want to evaluate a model checkpoint from the LeRobot training script for 10 episodes.
You have a specific config file to go with trained model weights, and want to run 10 episodes.
``` ```
python lerobot/scripts/eval.py \ python lerobot/scripts/eval.py \
--config PATH/TO/FOLDER/config.yaml \ -p outputs/train/diffusion_policy_pusht_image/checkpoints/005000 \
policy.pretrained_model_path=PATH/TO/FOLDER/weights.pth \ eval.n_episodes=10
eval_episodes=10
``` ```
You have a HuggingFace Hub ID, you know which revision you want, and want to run 10 episodes (note that in this case, Note that in both examples, the repo/folder should contain at least `config.json`, `config.yaml` and
you don't need to specify which weights to use): `model.safetensors`.
``` Note the formatting for providing the number of episodes. Generally, you may provide any number of arguments
python lerobot/scripts/eval.py --hub-id HUB/ID --revision v1.0 eval_episodes=10 with `qualified.parameter.name=value`. In this case, the parameter eval.n_episodes appears as `n_episodes`
``` nested under `eval` in the `config.yaml` found at
https://huggingface.co/lerobot/diffusion_policy_pusht_image/tree/main.
""" """
import argparse import argparse
@ -42,9 +41,12 @@ import numpy as np
import torch import torch
from datasets import Dataset, Features, Image, Sequence, Value from datasets import Dataset, Features, Image, Sequence, Value
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from huggingface_hub.utils._errors import RepositoryNotFoundError
from huggingface_hub.utils._validators import HFValidationError
from PIL import Image as PILImage from PIL import Image as PILImage
from tqdm import trange from tqdm import trange
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.utils import hf_transform_to_torch from lerobot.common.datasets.utils import hf_transform_to_torch
from lerobot.common.envs.factory import make_env from lerobot.common.envs.factory import make_env
from lerobot.common.envs.utils import postprocess_action, preprocess_observation from lerobot.common.envs.utils import postprocess_action, preprocess_observation
@ -65,10 +67,10 @@ def eval_policy(
""" """
set `return_episode_data` to return a Hugging Face dataset object in an "episodes" key of the return dict. set `return_episode_data` to return a Hugging Face dataset object in an "episodes" key of the return dict.
""" """
policy.eval()
fps = env.unwrapped.metadata["render_fps"] fps = env.unwrapped.metadata["render_fps"]
if policy is not None:
policy.eval()
device = "cpu" if policy is None else next(policy.parameters()).device device = "cpu" if policy is None else next(policy.parameters()).device
start = time.time() start = time.time()
@ -130,7 +132,7 @@ def eval_policy(
# get the next action for the environment # get the next action for the environment
with torch.inference_mode(): with torch.inference_mode():
action = policy.select_action(observation, step=step) action = policy.select_action(observation)
# convert to cpu numpy # convert to cpu numpy
action = postprocess_action(action) action = postprocess_action(action)
@ -349,26 +351,42 @@ def eval_policy(
return info return info
def eval(cfg: dict, out_dir=None): def eval(
pretrained_policy_path: str | None = None,
hydra_cfg_path: str | None = None,
config_overrides: list[str] | None = None,
):
assert (pretrained_policy_path is None) ^ (hydra_cfg_path is None)
if hydra_cfg_path is None:
hydra_cfg = init_hydra_config(pretrained_policy_path / "config.yaml", config_overrides)
else:
hydra_cfg = init_hydra_config(hydra_cfg_path, config_overrides)
out_dir = (
f"outputs/eval/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{hydra_cfg.env.name}_{hydra_cfg.policy.name}"
)
if out_dir is None: if out_dir is None:
raise NotImplementedError() raise NotImplementedError()
init_logging()
# Check device is available # Check device is available
get_safe_torch_device(cfg.device, log=True) get_safe_torch_device(hydra_cfg.device, log=True)
torch.backends.cudnn.benchmark = True torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cuda.matmul.allow_tf32 = True
set_global_seed(cfg.seed) set_global_seed(hydra_cfg.seed)
log_output_dir(out_dir) log_output_dir(out_dir)
logging.info("Making environment.") logging.info("Making environment.")
env = make_env(cfg, num_parallel_envs=cfg.eval_episodes) env = make_env(hydra_cfg, num_parallel_envs=hydra_cfg.eval.n_episodes)
logging.info("Making policy.") logging.info("Making policy.")
policy = make_policy(cfg) if hydra_cfg_path is None:
policy = make_policy(hydra_cfg=hydra_cfg, pretrained_policy_name_or_path=pretrained_policy_path)
else:
# Note: We need the dataset stats to pass to the policy's normalization modules.
policy = make_policy(hydra_cfg=hydra_cfg, dataset_stats=make_dataset(hydra_cfg).stats)
policy.eval()
info = eval_policy( info = eval_policy(
env, env,
@ -376,7 +394,7 @@ def eval(cfg: dict, out_dir=None):
max_episodes_rendered=10, max_episodes_rendered=10,
video_dir=Path(out_dir) / "eval", video_dir=Path(out_dir) / "eval",
return_episode_data=False, return_episode_data=False,
seed=cfg.seed, seed=hydra_cfg.seed,
) )
print(info["aggregated"]) print(info["aggregated"])
@ -390,13 +408,29 @@ def eval(cfg: dict, out_dir=None):
if __name__ == "__main__": if __name__ == "__main__":
init_logging()
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
) )
group = parser.add_mutually_exclusive_group(required=True) group = parser.add_mutually_exclusive_group(required=True)
group.add_argument("--config", help="Path to a specific yaml config you want to use.") group.add_argument(
group.add_argument("--hub-id", help="HuggingFace Hub ID for a pretrained model.") "-p",
parser.add_argument("--revision", help="Optionally provide the HuggingFace Hub revision ID.") "--pretrained-policy-name-or-path",
help=(
"Either the repo ID of a model hosted on the Hub or a path to a directory containing weights "
"saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch "
"(useful for debugging). This argument is mutually exclusive with `--config`."
),
)
group.add_argument(
"--config",
help=(
"Path to a yaml config you want to use for initializing a policy from scratch (useful for "
"debugging). This argument is mutually exclusive with `--pretrained-policy-name-or-path` (`-p`)."
),
)
parser.add_argument("--revision", help="Optionally provide the Hugging Face Hub revision ID.")
parser.add_argument( parser.add_argument(
"overrides", "overrides",
nargs="*", nargs="*",
@ -404,16 +438,28 @@ if __name__ == "__main__":
) )
args = parser.parse_args() args = parser.parse_args()
if args.config is not None: if args.pretrained_policy_name_or_path is None:
# Note: For the config_path, Hydra wants a path relative to this script file. eval(hydra_cfg_path=args.config, config_overrides=args.overrides)
cfg = init_hydra_config(args.config, args.overrides) else:
elif args.hub_id is not None: try:
folder = Path(snapshot_download(args.hub_id, revision=args.revision)) pretrained_policy_path = Path(
cfg = init_hydra_config( snapshot_download(args.pretrained_policy_name_or_path, revision=args.revision)
folder / "config.yaml", [f"policy.pretrained_model_path={folder / 'model.pt'}", *args.overrides] )
except HFValidationError:
logging.warning(
"The provided pretrained_policy_name_or_path is not a valid Hugging Face Hub repo ID. "
"Treating it as a local directory."
)
except RepositoryNotFoundError:
logging.warning(
"The provided pretrained_policy_name_or_path was not found on the Hugging Face Hub. Treating "
"it as a local directory."
)
pretrained_policy_path = Path(args.pretrained_policy_name_or_path)
if not pretrained_policy_path.is_dir() or not pretrained_policy_path.exists():
raise ValueError(
"The provided pretrained_policy_name_or_path is not a valid/existing Hugging Face Hub "
"repo ID, nor is it an existing local directory."
) )
eval( eval(pretrained_policy_path=pretrained_policy_path, config_overrides=args.overrides)
cfg,
out_dir=f"outputs/eval/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{cfg.env.name}_{cfg.policy.name}",
)

View File

@ -15,6 +15,7 @@ from lerobot.common.datasets.utils import cycle
from lerobot.common.envs.factory import make_env from lerobot.common.envs.factory import make_env
from lerobot.common.logger import Logger, log_output_dir from lerobot.common.logger import Logger, log_output_dir
from lerobot.common.policies.factory import make_policy from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.policy_protocol import PolicyWithUpdate
from lerobot.common.utils.utils import ( from lerobot.common.utils.utils import (
format_big_number, format_big_number,
get_safe_torch_device, get_safe_torch_device,
@ -39,12 +40,17 @@ def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
optimizer.step() optimizer.step()
optimizer.zero_grad() optimizer.zero_grad()
if lr_scheduler is not None: if lr_scheduler is not None:
lr_scheduler.step() lr_scheduler.step()
if hasattr(policy, "ema") and policy.ema is not None: if hasattr(policy, "ema") and policy.ema is not None:
policy.ema.step(policy.diffusion) policy.ema.step(policy.diffusion)
if isinstance(policy, PolicyWithUpdate):
# To possibly update an internal buffer (for instance an Exponential Moving Average like in TDMPC).
policy.update()
info = { info = {
"loss": loss.item(), "loss": loss.item(),
"grad_norm": float(grad_norm), "grad_norm": float(grad_norm),
@ -81,7 +87,7 @@ def log_train_info(logger, info, step, cfg, dataset, is_offline):
# A sample is an (observation,action) pair, where observation and action # A sample is an (observation,action) pair, where observation and action
# can be on multiple timestamps. In a batch, we have `batch_size`` number of samples. # can be on multiple timestamps. In a batch, we have `batch_size`` number of samples.
num_samples = (step + 1) * cfg.policy.batch_size num_samples = (step + 1) * cfg.training.batch_size
avg_samples_per_ep = dataset.num_samples / dataset.num_episodes avg_samples_per_ep = dataset.num_samples / dataset.num_episodes
num_episodes = num_samples / avg_samples_per_ep num_episodes = num_samples / avg_samples_per_ep
num_epochs = num_samples / dataset.num_samples num_epochs = num_samples / dataset.num_samples
@ -117,7 +123,7 @@ def log_eval_info(logger, info, step, cfg, dataset, is_offline):
# A sample is an (observation,action) pair, where observation and action # A sample is an (observation,action) pair, where observation and action
# can be on multiple timestamps. In a batch, we have `batch_size`` number of samples. # can be on multiple timestamps. In a batch, we have `batch_size`` number of samples.
num_samples = (step + 1) * cfg.policy.batch_size num_samples = (step + 1) * cfg.training.batch_size
avg_samples_per_ep = dataset.num_samples / dataset.num_episodes avg_samples_per_ep = dataset.num_samples / dataset.num_episodes
num_episodes = num_samples / avg_samples_per_ep num_episodes = num_samples / avg_samples_per_ep
num_epochs = num_samples / dataset.num_samples num_epochs = num_samples / dataset.num_samples
@ -246,11 +252,12 @@ def train(cfg: dict, out_dir=None, job_name=None):
raise NotImplementedError() raise NotImplementedError()
if job_name is None: if job_name is None:
raise NotImplementedError() raise NotImplementedError()
if cfg.online_steps > 0:
assert cfg.rollout_batch_size == 1, "rollout_batch_size > 1 not supported for online training steps"
init_logging() init_logging()
if cfg.training.online_steps > 0 and cfg.eval.batch_size > 1:
logging.warning("eval.batch_size > 1 not supported for online training steps")
# Check device is available # Check device is available
get_safe_torch_device(cfg.device, log=True) get_safe_torch_device(cfg.device, log=True)
@ -262,10 +269,10 @@ def train(cfg: dict, out_dir=None, job_name=None):
offline_dataset = make_dataset(cfg) offline_dataset = make_dataset(cfg)
logging.info("make_env") logging.info("make_env")
env = make_env(cfg, num_parallel_envs=cfg.eval_episodes) env = make_env(cfg, num_parallel_envs=cfg.eval.n_episodes)
logging.info("make_policy") logging.info("make_policy")
policy = make_policy(cfg, dataset_stats=offline_dataset.stats) policy = make_policy(hydra_cfg=cfg, dataset_stats=offline_dataset.stats)
# Create optimizer and scheduler # Create optimizer and scheduler
# Temporary hack to move optimizer out of policy # Temporary hack to move optimizer out of policy
@ -282,34 +289,33 @@ def train(cfg: dict, out_dir=None, job_name=None):
"params": [ "params": [
p for n, p in policy.named_parameters() if n.startswith("backbone") and p.requires_grad p for n, p in policy.named_parameters() if n.startswith("backbone") and p.requires_grad
], ],
"lr": cfg.policy.lr_backbone, "lr": cfg.training.lr_backbone,
}, },
] ]
optimizer = torch.optim.AdamW( optimizer = torch.optim.AdamW(
optimizer_params_dicts, lr=cfg.policy.lr, weight_decay=cfg.policy.weight_decay optimizer_params_dicts, lr=cfg.training.lr, weight_decay=cfg.training.weight_decay
) )
lr_scheduler = None lr_scheduler = None
elif cfg.policy.name == "diffusion": elif cfg.policy.name == "diffusion":
optimizer = torch.optim.Adam( optimizer = torch.optim.Adam(
policy.diffusion.parameters(), policy.diffusion.parameters(),
cfg.policy.lr, cfg.training.lr,
cfg.policy.adam_betas, cfg.training.adam_betas,
cfg.policy.adam_eps, cfg.training.adam_eps,
cfg.policy.adam_weight_decay, cfg.training.adam_weight_decay,
) )
# TODO(rcadene): modify lr scheduler so that it doesn't depend on epochs but steps assert cfg.training.online_steps == 0, "Diffusion Policy does not handle online training."
# configure lr scheduler
lr_scheduler = get_scheduler( lr_scheduler = get_scheduler(
cfg.policy.lr_scheduler, cfg.training.lr_scheduler,
optimizer=optimizer, optimizer=optimizer,
num_warmup_steps=cfg.policy.lr_warmup_steps, num_warmup_steps=cfg.training.lr_warmup_steps,
num_training_steps=cfg.offline_steps, num_training_steps=cfg.training.offline_steps,
# pytorch assumes stepping LRScheduler every epoch
# however huggingface diffusers steps it every batch
last_epoch=-1,
) )
elif policy.name == "tdmpc": elif policy.name == "tdmpc":
raise NotImplementedError("TD-MPC not implemented yet.") optimizer = torch.optim.Adam(policy.parameters(), cfg.training.lr)
lr_scheduler = None
else:
raise NotImplementedError()
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad) num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
num_total_params = sum(p.numel() for p in policy.parameters()) num_total_params = sum(p.numel() for p in policy.parameters())
@ -319,8 +325,8 @@ def train(cfg: dict, out_dir=None, job_name=None):
log_output_dir(out_dir) log_output_dir(out_dir)
logging.info(f"{cfg.env.task=}") logging.info(f"{cfg.env.task=}")
logging.info(f"{cfg.offline_steps=} ({format_big_number(cfg.offline_steps)})") logging.info(f"{cfg.training.offline_steps=} ({format_big_number(cfg.training.offline_steps)})")
logging.info(f"{cfg.online_steps=}") logging.info(f"{cfg.training.online_steps=}")
logging.info(f"{offline_dataset.num_samples=} ({format_big_number(offline_dataset.num_samples)})") logging.info(f"{offline_dataset.num_samples=} ({format_big_number(offline_dataset.num_samples)})")
logging.info(f"{offline_dataset.num_episodes=}") logging.info(f"{offline_dataset.num_episodes=}")
logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})") logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
@ -328,7 +334,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
# Note: this helper will be used in offline and online training loops. # Note: this helper will be used in offline and online training loops.
def _maybe_eval_and_maybe_save(step): def _maybe_eval_and_maybe_save(step):
if step % cfg.eval_freq == 0: if step % cfg.training.eval_freq == 0:
logging.info(f"Eval policy at step {step}") logging.info(f"Eval policy at step {step}")
eval_info = eval_policy( eval_info = eval_policy(
env, env,
@ -342,37 +348,44 @@ def train(cfg: dict, out_dir=None, job_name=None):
logger.log_video(eval_info["videos"][0], step, mode="eval") logger.log_video(eval_info["videos"][0], step, mode="eval")
logging.info("Resume training") logging.info("Resume training")
if cfg.save_model and step % cfg.save_freq == 0: if cfg.training.save_model and step % cfg.training.save_freq == 0:
logging.info(f"Checkpoint policy after step {step}") logging.info(f"Checkpoint policy after step {step}")
logger.save_model(policy, identifier=step) # Note: Save with step as the identifier, and format it to have at least 6 digits but more if
# needed (choose 6 as a minimum for consistency without being overkill).
logger.save_model(
policy,
identifier=str(step).zfill(
max(6, len(str(cfg.training.offline_steps + cfg.training.online_steps)))
),
)
logging.info("Resume training") logging.info("Resume training")
# create dataloader for offline training # create dataloader for offline training
dataloader = torch.utils.data.DataLoader( dataloader = torch.utils.data.DataLoader(
offline_dataset, offline_dataset,
num_workers=8, num_workers=4,
batch_size=cfg.policy.batch_size, batch_size=cfg.training.batch_size,
shuffle=True, shuffle=True,
pin_memory=cfg.device != "cpu", pin_memory=cfg.device != "cpu",
drop_last=False, drop_last=False,
) )
dl_iter = cycle(dataloader) dl_iter = cycle(dataloader)
policy.train()
step = 0 # number of policy update (forward + backward + optim) step = 0 # number of policy update (forward + backward + optim)
is_offline = True is_offline = True
for offline_step in range(cfg.offline_steps): for offline_step in range(cfg.training.offline_steps):
if offline_step == 0: if offline_step == 0:
logging.info("Start offline training on a fixed dataset") logging.info("Start offline training on a fixed dataset")
policy.train()
batch = next(dl_iter) batch = next(dl_iter)
for key in batch: for key in batch:
batch[key] = batch[key].to(cfg.device, non_blocking=True) batch[key] = batch[key].to(cfg.device, non_blocking=True)
train_info = update_policy(policy, batch, optimizer, cfg.policy.grad_clip_norm, lr_scheduler) train_info = update_policy(policy, batch, optimizer, cfg.training.grad_clip_norm, lr_scheduler)
# TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done? # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
if step % cfg.log_freq == 0: if step % cfg.training.log_freq == 0:
log_train_info(logger, train_info, step, cfg, offline_dataset, is_offline) log_train_info(logger, train_info, step, cfg, offline_dataset, is_offline)
# Note: _maybe_eval_and_maybe_save happens **after** the `step`th training update has completed, so we pass in # Note: _maybe_eval_and_maybe_save happens **after** the `step`th training update has completed, so we pass in
@ -398,7 +411,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
dataloader = torch.utils.data.DataLoader( dataloader = torch.utils.data.DataLoader(
concat_dataset, concat_dataset,
num_workers=4, num_workers=4,
batch_size=cfg.policy.batch_size, batch_size=cfg.training.batch_size,
sampler=sampler, sampler=sampler,
pin_memory=cfg.device != "cpu", pin_memory=cfg.device != "cpu",
drop_last=False, drop_last=False,
@ -407,10 +420,11 @@ def train(cfg: dict, out_dir=None, job_name=None):
online_step = 0 online_step = 0
is_offline = False is_offline = False
for env_step in range(cfg.online_steps): for env_step in range(cfg.training.online_steps):
if env_step == 0: if env_step == 0:
logging.info("Start online training by interacting with environment") logging.info("Start online training by interacting with environment")
policy.eval()
with torch.no_grad(): with torch.no_grad():
eval_info = eval_policy( eval_info = eval_policy(
rollout_env, rollout_env,
@ -425,19 +439,19 @@ def train(cfg: dict, out_dir=None, job_name=None):
sampler, sampler,
hf_dataset=eval_info["episodes"]["hf_dataset"], hf_dataset=eval_info["episodes"]["hf_dataset"],
episode_data_index=eval_info["episodes"]["episode_data_index"], episode_data_index=eval_info["episodes"]["episode_data_index"],
pc_online_samples=cfg.get("demo_schedule", 0.5), pc_online_samples=cfg.training.online_sampling_ratio,
) )
for _ in range(cfg.policy.utd):
policy.train() policy.train()
for _ in range(cfg.training.online_steps_between_rollouts):
batch = next(dl_iter) batch = next(dl_iter)
for key in batch: for key in batch:
batch[key] = batch[key].to(cfg.device, non_blocking=True) batch[key] = batch[key].to(cfg.device, non_blocking=True)
train_info = update_policy(policy, batch, optimizer, cfg.policy.grad_clip_norm, lr_scheduler) train_info = update_policy(policy, batch, optimizer, cfg.training.grad_clip_norm, lr_scheduler)
if step % cfg.log_freq == 0: if step % cfg.training.log_freq == 0:
log_train_info(logger, train_info, step, cfg, online_dataset, is_offline) log_train_info(logger, train_info, step, cfg, online_dataset, is_offline)
# Note: _maybe_eval_and_maybe_save happens **after** the `step`th training update has completed, so we pass # Note: _maybe_eval_and_maybe_save happens **after** the `step`th training update has completed, so we pass

296
poetry.lock generated
View File

@ -837,13 +837,13 @@ files = [
[[package]] [[package]]
name = "filelock" name = "filelock"
version = "3.13.4" version = "3.14.0"
description = "A platform independent file lock." description = "A platform independent file lock."
optional = false optional = false
python-versions = ">=3.8" python-versions = ">=3.8"
files = [ files = [
{file = "filelock-3.13.4-py3-none-any.whl", hash = "sha256:404e5e9253aa60ad457cae1be07c0f0ca90a63931200a47d9b6a6af84fd7b45f"}, {file = "filelock-3.14.0-py3-none-any.whl", hash = "sha256:43339835842f110ca7ae60f1e1c160714c5a6afd15a2873419ab185334975c0f"},
{file = "filelock-3.13.4.tar.gz", hash = "sha256:d13f466618bfde72bd2c18255e269f72542c6e70e7bac83a0232d6b1cc5c8cf4"}, {file = "filelock-3.14.0.tar.gz", hash = "sha256:6ea72da3be9b8c82afd3edcf99f2fffbb5076335a5ae4d03248bb5b6c3eae78a"},
] ]
[package.extras] [package.extras]
@ -1050,69 +1050,61 @@ preview = ["glfw-preview"]
[[package]] [[package]]
name = "grpcio" name = "grpcio"
version = "1.62.2" version = "1.63.0"
description = "HTTP/2-based RPC framework" description = "HTTP/2-based RPC framework"
optional = false optional = false
python-versions = ">=3.7" python-versions = ">=3.8"
files = [ files = [
{file = "grpcio-1.62.2-cp310-cp310-linux_armv7l.whl", hash = "sha256:66344ea741124c38588a664237ac2fa16dfd226964cca23ddc96bd4accccbde5"}, {file = "grpcio-1.63.0-cp310-cp310-linux_armv7l.whl", hash = "sha256:2e93aca840c29d4ab5db93f94ed0a0ca899e241f2e8aec6334ab3575dc46125c"},
{file = "grpcio-1.62.2-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:5dab7ac2c1e7cb6179c6bfad6b63174851102cbe0682294e6b1d6f0981ad7138"}, {file = "grpcio-1.63.0-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:91b73d3f1340fefa1e1716c8c1ec9930c676d6b10a3513ab6c26004cb02d8b3f"},
{file = "grpcio-1.62.2-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:3ad00f3f0718894749d5a8bb0fa125a7980a2f49523731a9b1fabf2b3522aa43"}, {file = "grpcio-1.63.0-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:b3afbd9d6827fa6f475a4f91db55e441113f6d3eb9b7ebb8fb806e5bb6d6bd0d"},
{file = "grpcio-1.62.2-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2e72ddfee62430ea80133d2cbe788e0d06b12f865765cb24a40009668bd8ea05"}, {file = "grpcio-1.63.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8f3f6883ce54a7a5f47db43289a0a4c776487912de1a0e2cc83fdaec9685cc9f"},
{file = "grpcio-1.62.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:53d3a59a10af4c2558a8e563aed9f256259d2992ae0d3037817b2155f0341de1"}, {file = "grpcio-1.63.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cf8dae9cc0412cb86c8de5a8f3be395c5119a370f3ce2e69c8b7d46bb9872c8d"},
{file = "grpcio-1.62.2-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:a1511a303f8074f67af4119275b4f954189e8313541da7b88b1b3a71425cdb10"}, {file = "grpcio-1.63.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:08e1559fd3b3b4468486b26b0af64a3904a8dbc78d8d936af9c1cf9636eb3e8b"},
{file = "grpcio-1.62.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:b94d41b7412ef149743fbc3178e59d95228a7064c5ab4760ae82b562bdffb199"}, {file = "grpcio-1.63.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:5c039ef01516039fa39da8a8a43a95b64e288f79f42a17e6c2904a02a319b357"},
{file = "grpcio-1.62.2-cp310-cp310-win32.whl", hash = "sha256:a75af2fc7cb1fe25785be7bed1ab18cef959a376cdae7c6870184307614caa3f"}, {file = "grpcio-1.63.0-cp310-cp310-win32.whl", hash = "sha256:ad2ac8903b2eae071055a927ef74121ed52d69468e91d9bcbd028bd0e554be6d"},
{file = "grpcio-1.62.2-cp310-cp310-win_amd64.whl", hash = "sha256:80407bc007754f108dc2061e37480238b0dc1952c855e86a4fc283501ee6bb5d"}, {file = "grpcio-1.63.0-cp310-cp310-win_amd64.whl", hash = "sha256:b2e44f59316716532a993ca2966636df6fbe7be4ab6f099de6815570ebe4383a"},
{file = "grpcio-1.62.2-cp311-cp311-linux_armv7l.whl", hash = "sha256:c1624aa686d4b36790ed1c2e2306cc3498778dffaf7b8dd47066cf819028c3ad"}, {file = "grpcio-1.63.0-cp311-cp311-linux_armv7l.whl", hash = "sha256:f28f8b2db7b86c77916829d64ab21ff49a9d8289ea1564a2b2a3a8ed9ffcccd3"},
{file = "grpcio-1.62.2-cp311-cp311-macosx_10_10_universal2.whl", hash = "sha256:1c1bb80299bdef33309dff03932264636450c8fdb142ea39f47e06a7153d3063"}, {file = "grpcio-1.63.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:65bf975639a1f93bee63ca60d2e4951f1b543f498d581869922910a476ead2f5"},
{file = "grpcio-1.62.2-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:db068bbc9b1fa16479a82e1ecf172a93874540cb84be69f0b9cb9b7ac3c82670"}, {file = "grpcio-1.63.0-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:b5194775fec7dc3dbd6a935102bb156cd2c35efe1685b0a46c67b927c74f0cfb"},
{file = "grpcio-1.62.2-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e2cc8a308780edbe2c4913d6a49dbdb5befacdf72d489a368566be44cadaef1a"}, {file = "grpcio-1.63.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e4cbb2100ee46d024c45920d16e888ee5d3cf47c66e316210bc236d5bebc42b3"},
{file = "grpcio-1.62.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d0695ae31a89f1a8fc8256050329a91a9995b549a88619263a594ca31b76d756"}, {file = "grpcio-1.63.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1ff737cf29b5b801619f10e59b581869e32f400159e8b12d7a97e7e3bdeee6a2"},
{file = "grpcio-1.62.2-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:88b4f9ee77191dcdd8810241e89340a12cbe050be3e0d5f2f091c15571cd3930"}, {file = "grpcio-1.63.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:cd1e68776262dd44dedd7381b1a0ad09d9930ffb405f737d64f505eb7f77d6c7"},
{file = "grpcio-1.62.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:2a0204532aa2f1afd467024b02b4069246320405bc18abec7babab03e2644e75"}, {file = "grpcio-1.63.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:93f45f27f516548e23e4ec3fbab21b060416007dbe768a111fc4611464cc773f"},
{file = "grpcio-1.62.2-cp311-cp311-win32.whl", hash = "sha256:6e784f60e575a0de554ef9251cbc2ceb8790914fe324f11e28450047f264ee6f"}, {file = "grpcio-1.63.0-cp311-cp311-win32.whl", hash = "sha256:878b1d88d0137df60e6b09b74cdb73db123f9579232c8456f53e9abc4f62eb3c"},
{file = "grpcio-1.62.2-cp311-cp311-win_amd64.whl", hash = "sha256:112eaa7865dd9e6d7c0556c8b04ae3c3a2dc35d62ad3373ab7f6a562d8199200"}, {file = "grpcio-1.63.0-cp311-cp311-win_amd64.whl", hash = "sha256:756fed02dacd24e8f488f295a913f250b56b98fb793f41d5b2de6c44fb762434"},
{file = "grpcio-1.62.2-cp312-cp312-linux_armv7l.whl", hash = "sha256:65034473fc09628a02fb85f26e73885cf1ed39ebd9cf270247b38689ff5942c5"}, {file = "grpcio-1.63.0-cp312-cp312-linux_armv7l.whl", hash = "sha256:93a46794cc96c3a674cdfb59ef9ce84d46185fe9421baf2268ccb556f8f81f57"},
{file = "grpcio-1.62.2-cp312-cp312-macosx_10_10_universal2.whl", hash = "sha256:d2c1771d0ee3cf72d69bb5e82c6a82f27fbd504c8c782575eddb7839729fbaad"}, {file = "grpcio-1.63.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:a7b19dfc74d0be7032ca1eda0ed545e582ee46cd65c162f9e9fc6b26ef827dc6"},
{file = "grpcio-1.62.2-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:3abe6838196da518863b5d549938ce3159d809218936851b395b09cad9b5d64a"}, {file = "grpcio-1.63.0-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:8064d986d3a64ba21e498b9a376cbc5d6ab2e8ab0e288d39f266f0fca169b90d"},
{file = "grpcio-1.62.2-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c5ffeb269f10cedb4f33142b89a061acda9f672fd1357331dbfd043422c94e9e"}, {file = "grpcio-1.63.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:219bb1848cd2c90348c79ed0a6b0ea51866bc7e72fa6e205e459fedab5770172"},
{file = "grpcio-1.62.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:404d3b4b6b142b99ba1cff0b2177d26b623101ea2ce51c25ef6e53d9d0d87bcc"}, {file = "grpcio-1.63.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a2d60cd1d58817bc5985fae6168d8b5655c4981d448d0f5b6194bbcc038090d2"},
{file = "grpcio-1.62.2-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:262cda97efdabb20853d3b5a4c546a535347c14b64c017f628ca0cc7fa780cc6"}, {file = "grpcio-1.63.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:9e350cb096e5c67832e9b6e018cf8a0d2a53b2a958f6251615173165269a91b0"},
{file = "grpcio-1.62.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:17708db5b11b966373e21519c4c73e5a750555f02fde82276ea2a267077c68ad"}, {file = "grpcio-1.63.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:56cdf96ff82e3cc90dbe8bac260352993f23e8e256e063c327b6cf9c88daf7a9"},
{file = "grpcio-1.62.2-cp312-cp312-win32.whl", hash = "sha256:b7ec9e2f8ffc8436f6b642a10019fc513722858f295f7efc28de135d336ac189"}, {file = "grpcio-1.63.0-cp312-cp312-win32.whl", hash = "sha256:3a6d1f9ea965e750db7b4ee6f9fdef5fdf135abe8a249e75d84b0a3e0c668a1b"},
{file = "grpcio-1.62.2-cp312-cp312-win_amd64.whl", hash = "sha256:aa787b83a3cd5e482e5c79be030e2b4a122ecc6c5c6c4c42a023a2b581fdf17b"}, {file = "grpcio-1.63.0-cp312-cp312-win_amd64.whl", hash = "sha256:d2497769895bb03efe3187fb1888fc20e98a5f18b3d14b606167dacda5789434"},
{file = "grpcio-1.62.2-cp37-cp37m-linux_armv7l.whl", hash = "sha256:cfd23ad29bfa13fd4188433b0e250f84ec2c8ba66b14a9877e8bce05b524cf54"}, {file = "grpcio-1.63.0-cp38-cp38-linux_armv7l.whl", hash = "sha256:fdf348ae69c6ff484402cfdb14e18c1b0054ac2420079d575c53a60b9b2853ae"},
{file = "grpcio-1.62.2-cp37-cp37m-macosx_10_10_universal2.whl", hash = "sha256:af15e9efa4d776dfcecd1d083f3ccfb04f876d613e90ef8432432efbeeac689d"}, {file = "grpcio-1.63.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:a3abfe0b0f6798dedd2e9e92e881d9acd0fdb62ae27dcbbfa7654a57e24060c0"},
{file = "grpcio-1.62.2-cp37-cp37m-manylinux_2_17_aarch64.whl", hash = "sha256:f4aa94361bb5141a45ca9187464ae81a92a2a135ce2800b2203134f7a1a1d479"}, {file = "grpcio-1.63.0-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:6ef0ad92873672a2a3767cb827b64741c363ebaa27e7f21659e4e31f4d750280"},
{file = "grpcio-1.62.2-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:82af3613a219512a28ee5c95578eb38d44dd03bca02fd918aa05603c41018051"}, {file = "grpcio-1.63.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b416252ac5588d9dfb8a30a191451adbf534e9ce5f56bb02cd193f12d8845b7f"},
{file = "grpcio-1.62.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:55ddaf53474e8caeb29eb03e3202f9d827ad3110475a21245f3c7712022882a9"}, {file = "grpcio-1.63.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e3b77eaefc74d7eb861d3ffbdf91b50a1bb1639514ebe764c47773b833fa2d91"},
{file = "grpcio-1.62.2-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:c79b518c56dddeec79e5500a53d8a4db90da995dfe1738c3ac57fe46348be049"}, {file = "grpcio-1.63.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:b005292369d9c1f80bf70c1db1c17c6c342da7576f1c689e8eee4fb0c256af85"},
{file = "grpcio-1.62.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:a5eb4844e5e60bf2c446ef38c5b40d7752c6effdee882f716eb57ae87255d20a"}, {file = "grpcio-1.63.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:cdcda1156dcc41e042d1e899ba1f5c2e9f3cd7625b3d6ebfa619806a4c1aadda"},
{file = "grpcio-1.62.2-cp37-cp37m-win_amd64.whl", hash = "sha256:aaae70364a2d1fb238afd6cc9fcb10442b66e397fd559d3f0968d28cc3ac929c"}, {file = "grpcio-1.63.0-cp38-cp38-win32.whl", hash = "sha256:01799e8649f9e94ba7db1aeb3452188048b0019dc37696b0f5ce212c87c560c3"},
{file = "grpcio-1.62.2-cp38-cp38-linux_armv7l.whl", hash = "sha256:1bcfe5070e4406f489e39325b76caeadab28c32bf9252d3ae960c79935a4cc36"}, {file = "grpcio-1.63.0-cp38-cp38-win_amd64.whl", hash = "sha256:6a1a3642d76f887aa4009d92f71eb37809abceb3b7b5a1eec9c554a246f20e3a"},
{file = "grpcio-1.62.2-cp38-cp38-macosx_10_10_universal2.whl", hash = "sha256:da6a7b6b938c15fa0f0568e482efaae9c3af31963eec2da4ff13a6d8ec2888e4"}, {file = "grpcio-1.63.0-cp39-cp39-linux_armv7l.whl", hash = "sha256:75f701ff645858a2b16bc8c9fc68af215a8bb2d5a9b647448129de6e85d52bce"},
{file = "grpcio-1.62.2-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:41955b641c34db7d84db8d306937b72bc4968eef1c401bea73081a8d6c3d8033"}, {file = "grpcio-1.63.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:cacdef0348a08e475a721967f48206a2254a1b26ee7637638d9e081761a5ba86"},
{file = "grpcio-1.62.2-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c772f225483905f675cb36a025969eef9712f4698364ecd3a63093760deea1bc"}, {file = "grpcio-1.63.0-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:0697563d1d84d6985e40ec5ec596ff41b52abb3fd91ec240e8cb44a63b895094"},
{file = "grpcio-1.62.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:07ce1f775d37ca18c7a141300e5b71539690efa1f51fe17f812ca85b5e73262f"}, {file = "grpcio-1.63.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6426e1fb92d006e47476d42b8f240c1d916a6d4423c5258ccc5b105e43438f61"},
{file = "grpcio-1.62.2-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:26f415f40f4a93579fd648f48dca1c13dfacdfd0290f4a30f9b9aeb745026811"}, {file = "grpcio-1.63.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e48cee31bc5f5a31fb2f3b573764bd563aaa5472342860edcc7039525b53e46a"},
{file = "grpcio-1.62.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:db707e3685ff16fc1eccad68527d072ac8bdd2e390f6daa97bc394ea7de4acea"}, {file = "grpcio-1.63.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:50344663068041b34a992c19c600236e7abb42d6ec32567916b87b4c8b8833b3"},
{file = "grpcio-1.62.2-cp38-cp38-win32.whl", hash = "sha256:589ea8e75de5fd6df387de53af6c9189c5231e212b9aa306b6b0d4f07520fbb9"}, {file = "grpcio-1.63.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:259e11932230d70ef24a21b9fb5bb947eb4703f57865a404054400ee92f42f5d"},
{file = "grpcio-1.62.2-cp38-cp38-win_amd64.whl", hash = "sha256:3c3ed41f4d7a3aabf0f01ecc70d6b5d00ce1800d4af652a549de3f7cf35c4abd"}, {file = "grpcio-1.63.0-cp39-cp39-win32.whl", hash = "sha256:a44624aad77bf8ca198c55af811fd28f2b3eaf0a50ec5b57b06c034416ef2d0a"},
{file = "grpcio-1.62.2-cp39-cp39-linux_armv7l.whl", hash = "sha256:162ccf61499c893831b8437120600290a99c0bc1ce7b51f2c8d21ec87ff6af8b"}, {file = "grpcio-1.63.0-cp39-cp39-win_amd64.whl", hash = "sha256:166e5c460e5d7d4656ff9e63b13e1f6029b122104c1633d5f37eaea348d7356d"},
{file = "grpcio-1.62.2-cp39-cp39-macosx_10_10_universal2.whl", hash = "sha256:f27246d7da7d7e3bd8612f63785a7b0c39a244cf14b8dd9dd2f2fab939f2d7f1"}, {file = "grpcio-1.63.0.tar.gz", hash = "sha256:f3023e14805c61bc439fb40ca545ac3d5740ce66120a678a3c6c2c55b70343d1"},
{file = "grpcio-1.62.2-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:2507006c8a478f19e99b6fe36a2464696b89d40d88f34e4b709abe57e1337467"},
{file = "grpcio-1.62.2-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a90ac47a8ce934e2c8d71e317d2f9e7e6aaceb2d199de940ce2c2eb611b8c0f4"},
{file = "grpcio-1.62.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:99701979bcaaa7de8d5f60476487c5df8f27483624f1f7e300ff4669ee44d1f2"},
{file = "grpcio-1.62.2-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:af7dc3f7a44f10863b1b0ecab4078f0a00f561aae1edbd01fd03ad4dcf61c9e9"},
{file = "grpcio-1.62.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:fa63245271920786f4cb44dcada4983a3516be8f470924528cf658731864c14b"},
{file = "grpcio-1.62.2-cp39-cp39-win32.whl", hash = "sha256:c6ad9c39704256ed91a1cffc1379d63f7d0278d6a0bad06b0330f5d30291e3a3"},
{file = "grpcio-1.62.2-cp39-cp39-win_amd64.whl", hash = "sha256:16da954692fd61aa4941fbeda405a756cd96b97b5d95ca58a92547bba2c1624f"},
{file = "grpcio-1.62.2.tar.gz", hash = "sha256:c77618071d96b7a8be2c10701a98537823b9c65ba256c0b9067e0594cdbd954d"},
] ]
[package.extras] [package.extras]
protobuf = ["grpcio-tools (>=1.62.2)"] protobuf = ["grpcio-tools (>=1.63.0)"]
[[package]] [[package]]
name = "gym-aloha" name = "gym-aloha"
@ -2414,7 +2406,6 @@ optional = false
python-versions = ">=3.9" python-versions = ">=3.9"
files = [ files = [
{file = "pandas-2.2.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:90c6fca2acf139569e74e8781709dccb6fe25940488755716d1d354d6bc58bce"}, {file = "pandas-2.2.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:90c6fca2acf139569e74e8781709dccb6fe25940488755716d1d354d6bc58bce"},
{file = "pandas-2.2.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c7adfc142dac335d8c1e0dcbd37eb8617eac386596eb9e1a1b77791cf2498238"},
{file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4abfe0be0d7221be4f12552995e58723c7422c80a659da13ca382697de830c08"}, {file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4abfe0be0d7221be4f12552995e58723c7422c80a659da13ca382697de830c08"},
{file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8635c16bf3d99040fdf3ca3db669a7250ddf49c55dc4aa8fe0ae0fa8d6dcc1f0"}, {file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8635c16bf3d99040fdf3ca3db669a7250ddf49c55dc4aa8fe0ae0fa8d6dcc1f0"},
{file = "pandas-2.2.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:40ae1dffb3967a52203105a077415a86044a2bea011b5f321c6aa64b379a3f51"}, {file = "pandas-2.2.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:40ae1dffb3967a52203105a077415a86044a2bea011b5f321c6aa64b379a3f51"},
@ -2435,7 +2426,6 @@ files = [
{file = "pandas-2.2.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:43498c0bdb43d55cb162cdc8c06fac328ccb5d2eabe3cadeb3529ae6f0517c32"}, {file = "pandas-2.2.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:43498c0bdb43d55cb162cdc8c06fac328ccb5d2eabe3cadeb3529ae6f0517c32"},
{file = "pandas-2.2.2-cp312-cp312-win_amd64.whl", hash = "sha256:d187d355ecec3629624fccb01d104da7d7f391db0311145817525281e2804d23"}, {file = "pandas-2.2.2-cp312-cp312-win_amd64.whl", hash = "sha256:d187d355ecec3629624fccb01d104da7d7f391db0311145817525281e2804d23"},
{file = "pandas-2.2.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:0ca6377b8fca51815f382bd0b697a0814c8bda55115678cbc94c30aacbb6eff2"}, {file = "pandas-2.2.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:0ca6377b8fca51815f382bd0b697a0814c8bda55115678cbc94c30aacbb6eff2"},
{file = "pandas-2.2.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9057e6aa78a584bc93a13f0a9bf7e753a5e9770a30b4d758b8d5f2a62a9433cd"},
{file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:001910ad31abc7bf06f49dcc903755d2f7f3a9186c0c040b827e522e9cef0863"}, {file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:001910ad31abc7bf06f49dcc903755d2f7f3a9186c0c040b827e522e9cef0863"},
{file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:66b479b0bd07204e37583c191535505410daa8df638fd8e75ae1b383851fe921"}, {file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:66b479b0bd07204e37583c191535505410daa8df638fd8e75ae1b383851fe921"},
{file = "pandas-2.2.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:a77e9d1c386196879aa5eb712e77461aaee433e54c68cf253053a73b7e49c33a"}, {file = "pandas-2.2.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:a77e9d1c386196879aa5eb712e77461aaee433e54c68cf253053a73b7e49c33a"},
@ -3110,104 +3100,90 @@ files = [
[[package]] [[package]]
name = "regex" name = "regex"
version = "2024.4.16" version = "2024.4.28"
description = "Alternative regular expression module, to replace re." description = "Alternative regular expression module, to replace re."
optional = false optional = false
python-versions = ">=3.7" python-versions = ">=3.8"
files = [ files = [
{file = "regex-2024.4.16-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:fb83cc090eac63c006871fd24db5e30a1f282faa46328572661c0a24a2323a08"}, {file = "regex-2024.4.28-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:cd196d056b40af073d95a2879678585f0b74ad35190fac04ca67954c582c6b61"},
{file = "regex-2024.4.16-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8c91e1763696c0eb66340c4df98623c2d4e77d0746b8f8f2bee2c6883fd1fe18"}, {file = "regex-2024.4.28-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8bb381f777351bd534462f63e1c6afb10a7caa9fa2a421ae22c26e796fe31b1f"},
{file = "regex-2024.4.16-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:10188fe732dec829c7acca7422cdd1bf57d853c7199d5a9e96bb4d40db239c73"}, {file = "regex-2024.4.28-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:47af45b6153522733aa6e92543938e97a70ce0900649ba626cf5aad290b737b6"},
{file = "regex-2024.4.16-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:956b58d692f235cfbf5b4f3abd6d99bf102f161ccfe20d2fd0904f51c72c4c66"}, {file = "regex-2024.4.28-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:99d6a550425cc51c656331af0e2b1651e90eaaa23fb4acde577cf15068e2e20f"},
{file = "regex-2024.4.16-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a70b51f55fd954d1f194271695821dd62054d949efd6368d8be64edd37f55c86"}, {file = "regex-2024.4.28-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bf29304a8011feb58913c382902fde3395957a47645bf848eea695839aa101b7"},
{file = "regex-2024.4.16-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5c02fcd2bf45162280613d2e4a1ca3ac558ff921ae4e308ecb307650d3a6ee51"}, {file = "regex-2024.4.28-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:92da587eee39a52c91aebea8b850e4e4f095fe5928d415cb7ed656b3460ae79a"},
{file = "regex-2024.4.16-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c4ed75ea6892a56896d78f11006161eea52c45a14994794bcfa1654430984b22"}, {file = "regex-2024.4.28-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6277d426e2f31bdbacb377d17a7475e32b2d7d1f02faaecc48d8e370c6a3ff31"},
{file = "regex-2024.4.16-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bd727ad276bb91928879f3aa6396c9a1d34e5e180dce40578421a691eeb77f47"}, {file = "regex-2024.4.28-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:28e1f28d07220c0f3da0e8fcd5a115bbb53f8b55cecf9bec0c946eb9a059a94c"},
{file = "regex-2024.4.16-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:7cbc5d9e8a1781e7be17da67b92580d6ce4dcef5819c1b1b89f49d9678cc278c"}, {file = "regex-2024.4.28-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:aaa179975a64790c1f2701ac562b5eeb733946eeb036b5bcca05c8d928a62f10"},
{file = "regex-2024.4.16-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:78fddb22b9ef810b63ef341c9fcf6455232d97cfe03938cbc29e2672c436670e"}, {file = "regex-2024.4.28-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:6f435946b7bf7a1b438b4e6b149b947c837cb23c704e780c19ba3e6855dbbdd3"},
{file = "regex-2024.4.16-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:445ca8d3c5a01309633a0c9db57150312a181146315693273e35d936472df912"}, {file = "regex-2024.4.28-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:19d6c11bf35a6ad077eb23852827f91c804eeb71ecb85db4ee1386825b9dc4db"},
{file = "regex-2024.4.16-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:95399831a206211d6bc40224af1c635cb8790ddd5c7493e0bd03b85711076a53"}, {file = "regex-2024.4.28-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:fdae0120cddc839eb8e3c15faa8ad541cc6d906d3eb24d82fb041cfe2807bc1e"},
{file = "regex-2024.4.16-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:7731728b6568fc286d86745f27f07266de49603a6fdc4d19c87e8c247be452af"}, {file = "regex-2024.4.28-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:e672cf9caaf669053121f1766d659a8813bd547edef6e009205378faf45c67b8"},
{file = "regex-2024.4.16-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:4facc913e10bdba42ec0aee76d029aedda628161a7ce4116b16680a0413f658a"}, {file = "regex-2024.4.28-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:f57515750d07e14743db55d59759893fdb21d2668f39e549a7d6cad5d70f9fea"},
{file = "regex-2024.4.16-cp310-cp310-win32.whl", hash = "sha256:911742856ce98d879acbea33fcc03c1d8dc1106234c5e7d068932c945db209c0"}, {file = "regex-2024.4.28-cp310-cp310-win32.whl", hash = "sha256:a1409c4eccb6981c7baabc8888d3550df518add6e06fe74fa1d9312c1838652d"},
{file = "regex-2024.4.16-cp310-cp310-win_amd64.whl", hash = "sha256:e0a2df336d1135a0b3a67f3bbf78a75f69562c1199ed9935372b82215cddd6e2"}, {file = "regex-2024.4.28-cp310-cp310-win_amd64.whl", hash = "sha256:1f687a28640f763f23f8a9801fe9e1b37338bb1ca5d564ddd41619458f1f22d1"},
{file = "regex-2024.4.16-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:1210365faba7c2150451eb78ec5687871c796b0f1fa701bfd2a4a25420482d26"}, {file = "regex-2024.4.28-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:84077821c85f222362b72fdc44f7a3a13587a013a45cf14534df1cbbdc9a6796"},
{file = "regex-2024.4.16-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:9ab40412f8cd6f615bfedea40c8bf0407d41bf83b96f6fc9ff34976d6b7037fd"}, {file = "regex-2024.4.28-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b45d4503de8f4f3dc02f1d28a9b039e5504a02cc18906cfe744c11def942e9eb"},
{file = "regex-2024.4.16-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:fd80d1280d473500d8086d104962a82d77bfbf2b118053824b7be28cd5a79ea5"}, {file = "regex-2024.4.28-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:457c2cd5a646dd4ed536c92b535d73548fb8e216ebee602aa9f48e068fc393f3"},
{file = "regex-2024.4.16-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7bb966fdd9217e53abf824f437a5a2d643a38d4fd5fd0ca711b9da683d452969"}, {file = "regex-2024.4.28-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2b51739ddfd013c6f657b55a508de8b9ea78b56d22b236052c3a85a675102dc6"},
{file = "regex-2024.4.16-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:20b7a68444f536365af42a75ccecb7ab41a896a04acf58432db9e206f4e525d6"}, {file = "regex-2024.4.28-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:459226445c7d7454981c4c0ce0ad1a72e1e751c3e417f305722bbcee6697e06a"},
{file = "regex-2024.4.16-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b74586dd0b039c62416034f811d7ee62810174bb70dffcca6439f5236249eb09"}, {file = "regex-2024.4.28-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:670fa596984b08a4a769491cbdf22350431970d0112e03d7e4eeaecaafcd0fec"},
{file = "regex-2024.4.16-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0c8290b44d8b0af4e77048646c10c6e3aa583c1ca67f3b5ffb6e06cf0c6f0f89"}, {file = "regex-2024.4.28-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fe00f4fe11c8a521b173e6324d862ee7ee3412bf7107570c9b564fe1119b56fb"},
{file = "regex-2024.4.16-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f2d80a6749724b37853ece57988b39c4e79d2b5fe2869a86e8aeae3bbeef9eb0"}, {file = "regex-2024.4.28-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:36f392dc7763fe7924575475736bddf9ab9f7a66b920932d0ea50c2ded2f5636"},
{file = "regex-2024.4.16-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:3a1018e97aeb24e4f939afcd88211ace472ba566efc5bdf53fd8fd7f41fa7170"}, {file = "regex-2024.4.28-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:23a412b7b1a7063f81a742463f38821097b6a37ce1e5b89dd8e871d14dbfd86b"},
{file = "regex-2024.4.16-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:8d015604ee6204e76569d2f44e5a210728fa917115bef0d102f4107e622b08d5"}, {file = "regex-2024.4.28-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:f1d6e4b7b2ae3a6a9df53efbf199e4bfcff0959dbdb5fd9ced34d4407348e39a"},
{file = "regex-2024.4.16-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:3d5ac5234fb5053850d79dd8eb1015cb0d7d9ed951fa37aa9e6249a19aa4f336"}, {file = "regex-2024.4.28-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:499334ad139557de97cbc4347ee921c0e2b5e9c0f009859e74f3f77918339257"},
{file = "regex-2024.4.16-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:0a38d151e2cdd66d16dab550c22f9521ba79761423b87c01dae0a6e9add79c0d"}, {file = "regex-2024.4.28-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:0940038bec2fe9e26b203d636c44d31dd8766abc1fe66262da6484bd82461ccf"},
{file = "regex-2024.4.16-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:159dc4e59a159cb8e4e8f8961eb1fa5d58f93cb1acd1701d8aff38d45e1a84a6"}, {file = "regex-2024.4.28-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:66372c2a01782c5fe8e04bff4a2a0121a9897e19223d9eab30c54c50b2ebeb7f"},
{file = "regex-2024.4.16-cp311-cp311-win32.whl", hash = "sha256:ba2336d6548dee3117520545cfe44dc28a250aa091f8281d28804aa8d707d93d"}, {file = "regex-2024.4.28-cp311-cp311-win32.whl", hash = "sha256:c77d10ec3c1cf328b2f501ca32583625987ea0f23a0c2a49b37a39ee5c4c4630"},
{file = "regex-2024.4.16-cp311-cp311-win_amd64.whl", hash = "sha256:8f83b6fd3dc3ba94d2b22717f9c8b8512354fd95221ac661784df2769ea9bba9"}, {file = "regex-2024.4.28-cp311-cp311-win_amd64.whl", hash = "sha256:fc0916c4295c64d6890a46e02d4482bb5ccf33bf1a824c0eaa9e83b148291f90"},
{file = "regex-2024.4.16-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:80b696e8972b81edf0af2a259e1b2a4a661f818fae22e5fa4fa1a995fb4a40fd"}, {file = "regex-2024.4.28-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:08a1749f04fee2811c7617fdd46d2e46d09106fa8f475c884b65c01326eb15c5"},
{file = "regex-2024.4.16-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:d61ae114d2a2311f61d90c2ef1358518e8f05eafda76eaf9c772a077e0b465ec"}, {file = "regex-2024.4.28-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b8eb28995771c087a73338f695a08c9abfdf723d185e57b97f6175c5051ff1ae"},
{file = "regex-2024.4.16-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8ba6745440b9a27336443b0c285d705ce73adb9ec90e2f2004c64d95ab5a7598"}, {file = "regex-2024.4.28-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:dd7ef715ccb8040954d44cfeff17e6b8e9f79c8019daae2fd30a8806ef5435c0"},
{file = "regex-2024.4.16-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6295004b2dd37b0835ea5c14a33e00e8cfa3c4add4d587b77287825f3418d310"}, {file = "regex-2024.4.28-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fb0315a2b26fde4005a7c401707c5352df274460f2f85b209cf6024271373013"},
{file = "regex-2024.4.16-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4aba818dcc7263852aabb172ec27b71d2abca02a593b95fa79351b2774eb1d2b"}, {file = "regex-2024.4.28-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f2fc053228a6bd3a17a9b0a3f15c3ab3cf95727b00557e92e1cfe094b88cc662"},
{file = "regex-2024.4.16-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d0800631e565c47520aaa04ae38b96abc5196fe8b4aa9bd864445bd2b5848a7a"}, {file = "regex-2024.4.28-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7fe9739a686dc44733d52d6e4f7b9c77b285e49edf8570754b322bca6b85b4cc"},
{file = "regex-2024.4.16-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:08dea89f859c3df48a440dbdcd7b7155bc675f2fa2ec8c521d02dc69e877db70"}, {file = "regex-2024.4.28-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a74fcf77d979364f9b69fcf8200849ca29a374973dc193a7317698aa37d8b01c"},
{file = "regex-2024.4.16-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:eeaa0b5328b785abc344acc6241cffde50dc394a0644a968add75fcefe15b9d4"}, {file = "regex-2024.4.28-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:965fd0cf4694d76f6564896b422724ec7b959ef927a7cb187fc6b3f4e4f59833"},
{file = "regex-2024.4.16-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:4e819a806420bc010489f4e741b3036071aba209f2e0989d4750b08b12a9343f"}, {file = "regex-2024.4.28-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:2fef0b38c34ae675fcbb1b5db760d40c3fc3612cfa186e9e50df5782cac02bcd"},
{file = "regex-2024.4.16-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:c2d0e7cbb6341e830adcbfa2479fdeebbfbb328f11edd6b5675674e7a1e37730"}, {file = "regex-2024.4.28-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:bc365ce25f6c7c5ed70e4bc674f9137f52b7dd6a125037f9132a7be52b8a252f"},
{file = "regex-2024.4.16-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:91797b98f5e34b6a49f54be33f72e2fb658018ae532be2f79f7c63b4ae225145"}, {file = "regex-2024.4.28-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:ac69b394764bb857429b031d29d9604842bc4cbfd964d764b1af1868eeebc4f0"},
{file = "regex-2024.4.16-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:d2da13568eff02b30fd54fccd1e042a70fe920d816616fda4bf54ec705668d81"}, {file = "regex-2024.4.28-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:144a1fc54765f5c5c36d6d4b073299832aa1ec6a746a6452c3ee7b46b3d3b11d"},
{file = "regex-2024.4.16-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:370c68dc5570b394cbaadff50e64d705f64debed30573e5c313c360689b6aadc"}, {file = "regex-2024.4.28-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:2630ca4e152c221072fd4a56d4622b5ada876f668ecd24d5ab62544ae6793ed6"},
{file = "regex-2024.4.16-cp312-cp312-win32.whl", hash = "sha256:904c883cf10a975b02ab3478bce652f0f5346a2c28d0a8521d97bb23c323cc8b"}, {file = "regex-2024.4.28-cp312-cp312-win32.whl", hash = "sha256:7f3502f03b4da52bbe8ba962621daa846f38489cae5c4a7b5d738f15f6443d17"},
{file = "regex-2024.4.16-cp312-cp312-win_amd64.whl", hash = "sha256:785c071c982dce54d44ea0b79cd6dfafddeccdd98cfa5f7b86ef69b381b457d9"}, {file = "regex-2024.4.28-cp312-cp312-win_amd64.whl", hash = "sha256:0dd3f69098511e71880fb00f5815db9ed0ef62c05775395968299cb400aeab82"},
{file = "regex-2024.4.16-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:e2f142b45c6fed48166faeb4303b4b58c9fcd827da63f4cf0a123c3480ae11fb"}, {file = "regex-2024.4.28-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:374f690e1dd0dbdcddea4a5c9bdd97632cf656c69113f7cd6a361f2a67221cb6"},
{file = "regex-2024.4.16-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e87ab229332ceb127a165612d839ab87795972102cb9830e5f12b8c9a5c1b508"}, {file = "regex-2024.4.28-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:25f87ae6b96374db20f180eab083aafe419b194e96e4f282c40191e71980c666"},
{file = "regex-2024.4.16-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:81500ed5af2090b4a9157a59dbc89873a25c33db1bb9a8cf123837dcc9765047"}, {file = "regex-2024.4.28-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:5dbc1bcc7413eebe5f18196e22804a3be1bfdfc7e2afd415e12c068624d48247"},
{file = "regex-2024.4.16-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b340cccad138ecb363324aa26893963dcabb02bb25e440ebdf42e30963f1a4e0"}, {file = "regex-2024.4.28-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f85151ec5a232335f1be022b09fbbe459042ea1951d8a48fef251223fc67eee1"},
{file = "regex-2024.4.16-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2c72608e70f053643437bd2be0608f7f1c46d4022e4104d76826f0839199347a"}, {file = "regex-2024.4.28-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:57ba112e5530530fd175ed550373eb263db4ca98b5f00694d73b18b9a02e7185"},
{file = "regex-2024.4.16-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a01fe2305e6232ef3e8f40bfc0f0f3a04def9aab514910fa4203bafbc0bb4682"}, {file = "regex-2024.4.28-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:224803b74aab56aa7be313f92a8d9911dcade37e5f167db62a738d0c85fdac4b"},
{file = "regex-2024.4.16-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:03576e3a423d19dda13e55598f0fd507b5d660d42c51b02df4e0d97824fdcae3"}, {file = "regex-2024.4.28-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0a54a047b607fd2d2d52a05e6ad294602f1e0dec2291152b745870afc47c1397"},
{file = "regex-2024.4.16-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:549c3584993772e25f02d0656ac48abdda73169fe347263948cf2b1cead622f3"}, {file = "regex-2024.4.28-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0a2a512d623f1f2d01d881513af9fc6a7c46e5cfffb7dc50c38ce959f9246c94"},
{file = "regex-2024.4.16-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:34422d5a69a60b7e9a07a690094e824b66f5ddc662a5fc600d65b7c174a05f04"}, {file = "regex-2024.4.28-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:c06bf3f38f0707592898428636cbb75d0a846651b053a1cf748763e3063a6925"},
{file = "regex-2024.4.16-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:5f580c651a72b75c39e311343fe6875d6f58cf51c471a97f15a938d9fe4e0d37"}, {file = "regex-2024.4.28-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:1031a5e7b048ee371ab3653aad3030ecfad6ee9ecdc85f0242c57751a05b0ac4"},
{file = "regex-2024.4.16-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:3399dd8a7495bbb2bacd59b84840eef9057826c664472e86c91d675d007137f5"}, {file = "regex-2024.4.28-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:d7a353ebfa7154c871a35caca7bfd8f9e18666829a1dc187115b80e35a29393e"},
{file = "regex-2024.4.16-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:8d1f86f3f4e2388aa3310b50694ac44daefbd1681def26b4519bd050a398dc5a"}, {file = "regex-2024.4.28-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:7e76b9cfbf5ced1aca15a0e5b6f229344d9b3123439ffce552b11faab0114a02"},
{file = "regex-2024.4.16-cp37-cp37m-win32.whl", hash = "sha256:dd5acc0a7d38fdc7a3a6fd3ad14c880819008ecb3379626e56b163165162cc46"}, {file = "regex-2024.4.28-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:5ce479ecc068bc2a74cb98dd8dba99e070d1b2f4a8371a7dfe631f85db70fe6e"},
{file = "regex-2024.4.16-cp37-cp37m-win_amd64.whl", hash = "sha256:ba8122e3bb94ecda29a8de4cf889f600171424ea586847aa92c334772d200331"}, {file = "regex-2024.4.28-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:7d77b6f63f806578c604dca209280e4c54f0fa9a8128bb8d2cc5fb6f99da4150"},
{file = "regex-2024.4.16-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:743deffdf3b3481da32e8a96887e2aa945ec6685af1cfe2bcc292638c9ba2f48"}, {file = "regex-2024.4.28-cp38-cp38-win32.whl", hash = "sha256:d84308f097d7a513359757c69707ad339da799e53b7393819ec2ea36bc4beb58"},
{file = "regex-2024.4.16-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:7571f19f4a3fd00af9341c7801d1ad1967fc9c3f5e62402683047e7166b9f2b4"}, {file = "regex-2024.4.28-cp38-cp38-win_amd64.whl", hash = "sha256:2cc1b87bba1dd1a898e664a31012725e48af826bf3971e786c53e32e02adae6c"},
{file = "regex-2024.4.16-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:df79012ebf6f4efb8d307b1328226aef24ca446b3ff8d0e30202d7ebcb977a8c"}, {file = "regex-2024.4.28-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:7413167c507a768eafb5424413c5b2f515c606be5bb4ef8c5dee43925aa5718b"},
{file = "regex-2024.4.16-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e757d475953269fbf4b441207bb7dbdd1c43180711b6208e129b637792ac0b93"}, {file = "regex-2024.4.28-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:108e2dcf0b53a7c4ab8986842a8edcb8ab2e59919a74ff51c296772e8e74d0ae"},
{file = "regex-2024.4.16-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4313ab9bf6a81206c8ac28fdfcddc0435299dc88cad12cc6305fd0e78b81f9e4"}, {file = "regex-2024.4.28-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:f1c5742c31ba7d72f2dedf7968998730664b45e38827637e0f04a2ac7de2f5f1"},
{file = "regex-2024.4.16-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d83c2bc678453646f1a18f8db1e927a2d3f4935031b9ad8a76e56760461105dd"}, {file = "regex-2024.4.28-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ecc6148228c9ae25ce403eade13a0961de1cb016bdb35c6eafd8e7b87ad028b1"},
{file = "regex-2024.4.16-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9df1bfef97db938469ef0a7354b2d591a2d438bc497b2c489471bec0e6baf7c4"}, {file = "regex-2024.4.28-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b7d893c8cf0e2429b823ef1a1d360a25950ed11f0e2a9df2b5198821832e1947"},
{file = "regex-2024.4.16-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:62120ed0de69b3649cc68e2965376048793f466c5a6c4370fb27c16c1beac22d"}, {file = "regex-2024.4.28-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4290035b169578ffbbfa50d904d26bec16a94526071ebec3dadbebf67a26b25e"},
{file = "regex-2024.4.16-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:c2ef6f7990b6e8758fe48ad08f7e2f66c8f11dc66e24093304b87cae9037bb4a"}, {file = "regex-2024.4.28-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:44a22ae1cfd82e4ffa2066eb3390777dc79468f866f0625261a93e44cdf6482b"},
{file = "regex-2024.4.16-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:8fc6976a3395fe4d1fbeb984adaa8ec652a1e12f36b56ec8c236e5117b585427"}, {file = "regex-2024.4.28-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fd24fd140b69f0b0bcc9165c397e9b2e89ecbeda83303abf2a072609f60239e2"},
{file = "regex-2024.4.16-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:03e68f44340528111067cecf12721c3df4811c67268b897fbe695c95f860ac42"}, {file = "regex-2024.4.28-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:39fb166d2196413bead229cd64a2ffd6ec78ebab83fff7d2701103cf9f4dfd26"},
{file = "regex-2024.4.16-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:ec7e0043b91115f427998febaa2beb82c82df708168b35ece3accb610b91fac1"}, {file = "regex-2024.4.28-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:9301cc6db4d83d2c0719f7fcda37229691745168bf6ae849bea2e85fc769175d"},
{file = "regex-2024.4.16-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:c21fc21a4c7480479d12fd8e679b699f744f76bb05f53a1d14182b31f55aac76"}, {file = "regex-2024.4.28-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:7c3d389e8d76a49923683123730c33e9553063d9041658f23897f0b396b2386f"},
{file = "regex-2024.4.16-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:12f6a3f2f58bb7344751919a1876ee1b976fe08b9ffccb4bbea66f26af6017b9"}, {file = "regex-2024.4.28-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:99ef6289b62042500d581170d06e17f5353b111a15aa6b25b05b91c6886df8fc"},
{file = "regex-2024.4.16-cp38-cp38-win32.whl", hash = "sha256:479595a4fbe9ed8f8f72c59717e8cf222da2e4c07b6ae5b65411e6302af9708e"}, {file = "regex-2024.4.28-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:b91d529b47798c016d4b4c1d06cc826ac40d196da54f0de3c519f5a297c5076a"},
{file = "regex-2024.4.16-cp38-cp38-win_amd64.whl", hash = "sha256:0534b034fba6101611968fae8e856c1698da97ce2efb5c2b895fc8b9e23a5834"}, {file = "regex-2024.4.28-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:43548ad74ea50456e1c68d3c67fff3de64c6edb85bcd511d1136f9b5376fc9d1"},
{file = "regex-2024.4.16-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:a7ccdd1c4a3472a7533b0a7aa9ee34c9a2bef859ba86deec07aff2ad7e0c3b94"}, {file = "regex-2024.4.28-cp39-cp39-win32.whl", hash = "sha256:05d9b6578a22db7dedb4df81451f360395828b04f4513980b6bd7a1412c679cc"},
{file = "regex-2024.4.16-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6f2f017c5be19984fbbf55f8af6caba25e62c71293213f044da3ada7091a4455"}, {file = "regex-2024.4.28-cp39-cp39-win_amd64.whl", hash = "sha256:3986217ec830c2109875be740531feb8ddafe0dfa49767cdcd072ed7e8927962"},
{file = "regex-2024.4.16-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:803b8905b52de78b173d3c1e83df0efb929621e7b7c5766c0843704d5332682f"}, {file = "regex-2024.4.28.tar.gz", hash = "sha256:83ab366777ea45d58f72593adf35d36ca911ea8bd838483c1823b883a121b0e4"},
{file = "regex-2024.4.16-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:684008ec44ad275832a5a152f6e764bbe1914bea10968017b6feaecdad5736e0"},
{file = "regex-2024.4.16-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:65436dce9fdc0aeeb0a0effe0839cb3d6a05f45aa45a4d9f9c60989beca78b9c"},
{file = "regex-2024.4.16-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ea355eb43b11764cf799dda62c658c4d2fdb16af41f59bb1ccfec517b60bcb07"},
{file = "regex-2024.4.16-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:98c1165f3809ce7774f05cb74e5408cd3aa93ee8573ae959a97a53db3ca3180d"},
{file = "regex-2024.4.16-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cccc79a9be9b64c881f18305a7c715ba199e471a3973faeb7ba84172abb3f317"},
{file = "regex-2024.4.16-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:00169caa125f35d1bca6045d65a662af0202704489fada95346cfa092ec23f39"},
{file = "regex-2024.4.16-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:6cc38067209354e16c5609b66285af17a2863a47585bcf75285cab33d4c3b8df"},
{file = "regex-2024.4.16-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:23cff1b267038501b179ccbbd74a821ac4a7192a1852d1d558e562b507d46013"},
{file = "regex-2024.4.16-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:b9d320b3bf82a39f248769fc7f188e00f93526cc0fe739cfa197868633d44701"},
{file = "regex-2024.4.16-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:89ec7f2c08937421bbbb8b48c54096fa4f88347946d4747021ad85f1b3021b3c"},
{file = "regex-2024.4.16-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:4918fd5f8b43aa7ec031e0fef1ee02deb80b6afd49c85f0790be1dc4ce34cb50"},
{file = "regex-2024.4.16-cp39-cp39-win32.whl", hash = "sha256:684e52023aec43bdf0250e843e1fdd6febbe831bd9d52da72333fa201aaa2335"},
{file = "regex-2024.4.16-cp39-cp39-win_amd64.whl", hash = "sha256:e697e1c0238133589e00c244a8b676bc2cfc3ab4961318d902040d099fec7483"},
{file = "regex-2024.4.16.tar.gz", hash = "sha256:fa454d26f2e87ad661c4f0c5a5fe4cf6aab1e307d1b94f16ffdfcb089ba685c0"},
] ]
[[package]] [[package]]
@ -4014,13 +3990,13 @@ zstd = ["zstandard (>=0.18.0)"]
[[package]] [[package]]
name = "virtualenv" name = "virtualenv"
version = "20.26.0" version = "20.26.1"
description = "Virtual Python Environment builder" description = "Virtual Python Environment builder"
optional = true optional = true
python-versions = ">=3.7" python-versions = ">=3.7"
files = [ files = [
{file = "virtualenv-20.26.0-py3-none-any.whl", hash = "sha256:0846377ea76e818daaa3e00a4365c018bc3ac9760cbb3544de542885aad61fb3"}, {file = "virtualenv-20.26.1-py3-none-any.whl", hash = "sha256:7aa9982a728ae5892558bff6a2839c00b9ed145523ece2274fad6f414690ae75"},
{file = "virtualenv-20.26.0.tar.gz", hash = "sha256:ec25a9671a5102c8d2657f62792a27b48f016664c6873f6beed3800008577210"}, {file = "virtualenv-20.26.1.tar.gz", hash = "sha256:604bfdceaeece392802e6ae48e69cec49168b9c5f4a44e483963f9242eb0e78b"},
] ]
[package.dependencies] [package.dependencies]

View File

@ -4,9 +4,9 @@ import gymnasium as gym
import pytest import pytest
import lerobot import lerobot
from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy from lerobot.common.policies.act.modeling_act import ACTPolicy
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
from lerobot.common.policies.tdmpc.policy import TDMPCPolicy from lerobot.common.policies.tdmpc.modeling_tdmpc import TDMPCPolicy
from tests.utils import require_env from tests.utils import require_env
@ -30,7 +30,7 @@ def test_available_policies():
consistent with those listed in `lerobot/__init__.py`. consistent with those listed in `lerobot/__init__.py`.
""" """
policy_classes = [ policy_classes = [
ActionChunkingTransformerPolicy, ACTPolicy,
DiffusionPolicy, DiffusionPolicy,
TDMPCPolicy, TDMPCPolicy,
] ]

View File

@ -35,7 +35,7 @@ def test_factory(env_name, repo_id, policy_name):
DEFAULT_CONFIG_PATH, DEFAULT_CONFIG_PATH,
overrides=[ overrides=[
f"env={env_name}", f"env={env_name}",
f"dataset.repo_id={repo_id}", f"dataset_repo_id={repo_id}",
f"policy={policy_name}", f"policy={policy_name}",
f"device={DEVICE}", f"device={DEVICE}",
], ],

View File

@ -39,14 +39,14 @@ def test_examples_3_and_2():
("training_steps = 5000", "training_steps = 1"), ("training_steps = 5000", "training_steps = 1"),
("num_workers=4", "num_workers=0"), ("num_workers=4", "num_workers=0"),
('device = torch.device("cuda")', 'device = torch.device("cpu")'), ('device = torch.device("cuda")', 'device = torch.device("cpu")'),
("batch_size=cfg.batch_size", "batch_size=1"), ("batch_size=64", "batch_size=1"),
], ],
) )
# Pass empty globals to allow dictionary comprehension https://stackoverflow.com/a/32897127/4391249. # Pass empty globals to allow dictionary comprehension https://stackoverflow.com/a/32897127/4391249.
exec(file_contents, {}) exec(file_contents, {})
for file_name in ["model.pt", "config.yaml"]: for file_name in ["model.safetensors", "config.json", "config.yaml"]:
assert Path(f"outputs/train/example_pusht_diffusion/{file_name}").exists() assert Path(f"outputs/train/example_pusht_diffusion/{file_name}").exists()
path = "examples/2_evaluate_pretrained_policy.py" path = "examples/2_evaluate_pretrained_policy.py"
@ -58,15 +58,15 @@ def test_examples_3_and_2():
file_contents = _find_and_replace( file_contents = _find_and_replace(
file_contents, file_contents,
[ [
('"eval_episodes=10"', '"eval_episodes=1"'), ('pretrained_policy_name = "lerobot/diffusion_policy_pusht_image"', ""),
('"rollout_batch_size=10"', '"rollout_batch_size=1"'), ("pretrained_policy_path = Path(snapshot_download(pretrained_policy_name))", ""),
('"device=cuda"', '"device=cpu"'),
( (
'# folder = Path("outputs/train/example_pusht_diffusion")', '# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")',
'folder = Path("outputs/train/example_pusht_diffusion")', 'pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")',
), ),
('hub_id = "lerobot/diffusion_policy_pusht_image"', ""), ('"eval.n_episodes=10"', '"eval.n_episodes=1"'),
("folder = Path(snapshot_download(hub_id)", ""), ('"eval.batch_size=10"', '"eval.batch_size=1"'),
('"device=cuda"', '"device=cpu"'),
], ],
) )

View File

@ -1,41 +1,50 @@
import inspect
import pytest import pytest
import torch import torch
from huggingface_hub import PyTorchModelHubMixin
from lerobot import available_policies
from lerobot.common.datasets.factory import make_dataset from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.utils import cycle from lerobot.common.datasets.utils import cycle
from lerobot.common.envs.factory import make_env from lerobot.common.envs.factory import make_env
from lerobot.common.envs.utils import postprocess_action, preprocess_observation from lerobot.common.envs.utils import postprocess_action, preprocess_observation
from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy from lerobot.common.policies.factory import get_policy_and_config_classes, make_policy
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.normalize import Normalize, Unnormalize from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.policy_protocol import Policy from lerobot.common.policies.policy_protocol import Policy
from lerobot.common.utils.utils import init_hydra_config from lerobot.common.utils.utils import init_hydra_config
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_env from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_env
@pytest.mark.parametrize("policy_name", available_policies)
def test_get_policy_and_config_classes(policy_name: str):
"""Check that the correct policy and config classes are returned."""
policy_cls, config_cls = get_policy_and_config_classes(policy_name)
assert policy_cls.name == policy_name
assert issubclass(config_cls, inspect.signature(policy_cls.__init__).parameters["config"].annotation)
# TODO(aliberts): refactor using lerobot/__init__.py variables # TODO(aliberts): refactor using lerobot/__init__.py variables
@pytest.mark.parametrize( @pytest.mark.parametrize(
"env_name,policy_name,extra_overrides", "env_name,policy_name,extra_overrides",
[ [
# ("xarm", "tdmpc", ["policy.mpc=true"]), ("xarm", "tdmpc", ["policy.use_mpc=true", "dataset_repo_id=lerobot/xarm_lift_medium"]),
# ("pusht", "tdmpc", ["policy.mpc=false"]),
("pusht", "diffusion", []), ("pusht", "diffusion", []),
("aloha", "act", ["env.task=AlohaInsertion-v0", "dataset.repo_id=lerobot/aloha_sim_insertion_human"]), ("aloha", "act", ["env.task=AlohaInsertion-v0", "dataset_repo_id=lerobot/aloha_sim_insertion_human"]),
( (
"aloha", "aloha",
"act", "act",
["env.task=AlohaInsertion-v0", "dataset.repo_id=lerobot/aloha_sim_insertion_scripted"], ["env.task=AlohaInsertion-v0", "dataset_repo_id=lerobot/aloha_sim_insertion_scripted"],
), ),
( (
"aloha", "aloha",
"act", "act",
["env.task=AlohaTransferCube-v0", "dataset.repo_id=lerobot/aloha_sim_transfer_cube_human"], ["env.task=AlohaTransferCube-v0", "dataset_repo_id=lerobot/aloha_sim_transfer_cube_human"],
), ),
( (
"aloha", "aloha",
"act", "act",
["env.task=AlohaTransferCube-v0", "dataset.repo_id=lerobot/aloha_sim_transfer_cube_scripted"], ["env.task=AlohaTransferCube-v0", "dataset_repo_id=lerobot/aloha_sim_transfer_cube_scripted"],
), ),
], ],
) )
@ -44,7 +53,8 @@ def test_policy(env_name, policy_name, extra_overrides):
""" """
Tests: Tests:
- Making the policy object. - Making the policy object.
- Checking that the policy follows the correct protocol. - Checking that the policy follows the correct protocol and subclasses nn.Module
and PyTorchModelHubMixin.
- Updating the policy. - Updating the policy.
- Using the policy to select actions at inference time. - Using the policy to select actions at inference time.
- Test the action can be applied to the policy - Test the action can be applied to the policy
@ -61,11 +71,13 @@ def test_policy(env_name, policy_name, extra_overrides):
# Check that we can make the policy object. # Check that we can make the policy object.
dataset = make_dataset(cfg) dataset = make_dataset(cfg)
policy = make_policy(cfg, dataset_stats=dataset.stats) policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)
# Check that the policy follows the required protocol. # Check that the policy follows the required protocol.
assert isinstance( assert isinstance(
policy, Policy policy, Policy
), f"The policy does not follow the required protocol. Please see {Policy.__module__}.{Policy.__name__}." ), f"The policy does not follow the required protocol. Please see {Policy.__module__}.{Policy.__name__}."
assert isinstance(policy, torch.nn.Module)
assert isinstance(policy, PyTorchModelHubMixin)
# Check that we run select_actions and get the appropriate output. # Check that we run select_actions and get the appropriate output.
env = make_env(cfg, num_parallel_envs=2) env = make_env(cfg, num_parallel_envs=2)
@ -86,7 +98,7 @@ def test_policy(env_name, policy_name, extra_overrides):
batch[key] = batch[key].to(DEVICE, non_blocking=True) batch[key] = batch[key].to(DEVICE, non_blocking=True)
# Test updating the policy # Test updating the policy
policy.forward(batch, step=0) policy.forward(batch)
# reset the policy and environment # reset the policy and environment
policy.reset() policy.reset()
@ -100,7 +112,7 @@ def test_policy(env_name, policy_name, extra_overrides):
# get the next action for the environment # get the next action for the environment
with torch.inference_mode(): with torch.inference_mode():
action = policy.select_action(observation, step=0) action = policy.select_action(observation)
# convert action to cpu numpy array # convert action to cpu numpy array
action = postprocess_action(action) action = postprocess_action(action)
@ -108,29 +120,25 @@ def test_policy(env_name, policy_name, extra_overrides):
# Test step through policy # Test step through policy
env.step(action) env.step(action)
# Test load state_dict
if policy_name != "tdmpc": @pytest.mark.parametrize("policy_name", available_policies)
# TODO(rcadene, alexander-soare): make it work for tdmpc def test_policy_defaults(policy_name: str):
new_policy = make_policy(cfg) """Check that the policy can be instantiated with defaults."""
new_policy.load_state_dict(policy.state_dict()) policy_cls, _ = get_policy_and_config_classes(policy_name)
policy_cls()
@pytest.mark.parametrize("policy_cls", [DiffusionPolicy, ActionChunkingTransformerPolicy]) @pytest.mark.parametrize("policy_name", available_policies)
def test_policy_defaults(policy_cls): def test_save_and_load_pretrained(policy_name: str):
kwargs = {} policy_cls, _ = get_policy_and_config_classes(policy_name)
# TODO(alexander-soare): Remove this kwargs hack when we move the scheduler out of DP. policy: Policy = policy_cls()
if policy_cls is DiffusionPolicy: save_dir = "/tmp/test_save_and_load_pretrained_{policy_cls.__name__}"
kwargs = {"lr_scheduler_num_training_steps": 1} policy.save_pretrained(save_dir)
policy_cls(**kwargs) policy_ = policy_cls.from_pretrained(save_dir)
assert all(torch.equal(p, p_) for p, p_ in zip(policy.parameters(), policy_.parameters(), strict=True))
@pytest.mark.parametrize( @pytest.mark.parametrize("insert_temporal_dim", [False, True])
"insert_temporal_dim",
[
False,
True,
],
)
def test_normalize(insert_temporal_dim): def test_normalize(insert_temporal_dim):
""" """
Test that normalize/unnormalize can run without exceptions when properly set up, and that they raise Test that normalize/unnormalize can run without exceptions when properly set up, and that they raise

View File

@ -20,7 +20,7 @@ def test_visualize_dataset(tmpdir, repo_id):
overrides=[ overrides=[
"policy=act", "policy=act",
"env=aloha", "env=aloha",
f"dataset.repo_id={repo_id}", f"dataset_repo_id={repo_id}",
], ],
) )
video_paths = visualize_dataset(cfg, out_dir=tmpdir) video_paths = visualize_dataset(cfg, out_dir=tmpdir)