Merge branch 'main' into user/rcadene/2024_04_21_load_from_video
This commit is contained in:
commit
ae44510d50
66
Makefile
66
Makefile
|
@ -22,74 +22,82 @@ test-end-to-end:
|
|||
${MAKE} test-act-ete-eval
|
||||
${MAKE} test-diffusion-ete-train
|
||||
${MAKE} test-diffusion-ete-eval
|
||||
# ${MAKE} test-tdmpc-ete-train
|
||||
# ${MAKE} test-tdmpc-ete-eval
|
||||
${MAKE} test-tdmpc-ete-train
|
||||
${MAKE} test-tdmpc-ete-eval
|
||||
${MAKE} test-default-ete-eval
|
||||
|
||||
test-act-ete-train:
|
||||
python lerobot/scripts/train.py \
|
||||
policy=act \
|
||||
env=aloha \
|
||||
wandb.enable=False \
|
||||
offline_steps=2 \
|
||||
online_steps=0 \
|
||||
eval_episodes=1 \
|
||||
training.offline_steps=2 \
|
||||
training.online_steps=0 \
|
||||
eval.n_episodes=1 \
|
||||
device=cpu \
|
||||
save_model=true \
|
||||
save_freq=2 \
|
||||
training.save_model=true \
|
||||
training.save_freq=2 \
|
||||
policy.n_action_steps=20 \
|
||||
policy.chunk_size=20 \
|
||||
policy.batch_size=2 \
|
||||
training.batch_size=2 \
|
||||
hydra.run.dir=tests/outputs/act/
|
||||
|
||||
test-act-ete-eval:
|
||||
python lerobot/scripts/eval.py \
|
||||
--config tests/outputs/act/.hydra/config.yaml \
|
||||
eval_episodes=1 \
|
||||
-p tests/outputs/act/checkpoints/000002 \
|
||||
eval.n_episodes=1 \
|
||||
env.episode_length=8 \
|
||||
device=cpu \
|
||||
policy.pretrained_model_path=tests/outputs/act/models/2.pt
|
||||
|
||||
test-diffusion-ete-train:
|
||||
python lerobot/scripts/train.py \
|
||||
policy=diffusion \
|
||||
env=pusht \
|
||||
wandb.enable=False \
|
||||
offline_steps=2 \
|
||||
online_steps=0 \
|
||||
eval_episodes=1 \
|
||||
training.offline_steps=2 \
|
||||
training.online_steps=0 \
|
||||
eval.n_episodes=1 \
|
||||
device=cpu \
|
||||
save_model=true \
|
||||
save_freq=2 \
|
||||
policy.batch_size=2 \
|
||||
training.save_model=true \
|
||||
training.save_freq=2 \
|
||||
training.batch_size=2 \
|
||||
hydra.run.dir=tests/outputs/diffusion/
|
||||
|
||||
test-diffusion-ete-eval:
|
||||
python lerobot/scripts/eval.py \
|
||||
--config tests/outputs/diffusion/.hydra/config.yaml \
|
||||
eval_episodes=1 \
|
||||
-p tests/outputs/diffusion/checkpoints/000002 \
|
||||
eval.n_episodes=1 \
|
||||
env.episode_length=8 \
|
||||
device=cpu \
|
||||
policy.pretrained_model_path=tests/outputs/diffusion/models/2.pt
|
||||
|
||||
test-tdmpc-ete-train:
|
||||
python lerobot/scripts/train.py \
|
||||
policy=tdmpc \
|
||||
env=xarm \
|
||||
env.task=XarmLift-v0 \
|
||||
dataset_repo_id=lerobot/xarm_lift_medium_replay \
|
||||
wandb.enable=False \
|
||||
offline_steps=1 \
|
||||
online_steps=2 \
|
||||
eval_episodes=1 \
|
||||
training.offline_steps=2 \
|
||||
training.online_steps=2 \
|
||||
eval.n_episodes=1 \
|
||||
env.episode_length=2 \
|
||||
device=cpu \
|
||||
save_model=true \
|
||||
save_freq=2 \
|
||||
policy.batch_size=2 \
|
||||
training.save_model=true \
|
||||
training.save_freq=2 \
|
||||
training.batch_size=2 \
|
||||
hydra.run.dir=tests/outputs/tdmpc/
|
||||
|
||||
test-tdmpc-ete-eval:
|
||||
python lerobot/scripts/eval.py \
|
||||
--config tests/outputs/tdmpc/.hydra/config.yaml \
|
||||
eval_episodes=1 \
|
||||
-p tests/outputs/tdmpc/checkpoints/000002 \
|
||||
eval.n_episodes=1 \
|
||||
env.episode_length=8 \
|
||||
device=cpu \
|
||||
|
||||
|
||||
test-default-ete-eval:
|
||||
python lerobot/scripts/eval.py \
|
||||
--config lerobot/configs/default.yaml \
|
||||
eval.n_episodes=1 \
|
||||
env.episode_length=8 \
|
||||
device=cpu \
|
||||
policy.pretrained_model_path=tests/outputs/tdmpc/models/2.pt
|
||||
|
|
27
README.md
27
README.md
|
@ -135,16 +135,16 @@ Check out [examples](./examples) to see how you can load a pretrained policy fro
|
|||
Or you can achieve the same result by executing our script from the command line:
|
||||
```bash
|
||||
python lerobot/scripts/eval.py \
|
||||
--hub-id lerobot/diffusion_policy_pusht_image \
|
||||
-p lerobot/diffusion_policy_pusht_image \
|
||||
eval_episodes=10 \
|
||||
hydra.run.dir=outputs/eval/example_hub
|
||||
```
|
||||
|
||||
After training your own policy, you can also re-evaluate the checkpoints with:
|
||||
|
||||
```bash
|
||||
python lerobot/scripts/eval.py \
|
||||
--config PATH/TO/FOLDER/config.yaml \
|
||||
policy.pretrained_model_path=PATH/TO/FOLDER/weights.pth \
|
||||
-p PATH/TO/TRAIN/OUTPUT/FOLDER \
|
||||
eval_episodes=10 \
|
||||
hydra.run.dir=outputs/eval/example_dir
|
||||
```
|
||||
|
@ -246,29 +246,22 @@ Once you have trained a policy you may upload it to the HuggingFace hub.
|
|||
|
||||
Firstly, make sure you have a model repository set up on the hub. The hub ID looks like HF_USER/REPO_NAME.
|
||||
|
||||
Secondly, assuming you have trained a policy, you need:
|
||||
Secondly, assuming you have trained a policy, you need the following (which should all be in any of the subdirectories of `checkpoints` in your training output folder, if you've used the LeRobot training script):
|
||||
|
||||
- `config.yaml` which you can get from the `.hydra` directory of your training output folder.
|
||||
- `model.pt` which should be one of the saved models in the `models` directory of your training output folder (they won't be named `model.pt` but you will need to choose one).
|
||||
- `config.json`: A serialized version of the policy configuration (following the policy's dataclass config).
|
||||
- `model.safetensors`: The `torch.nn.Module` parameters saved in [Hugging Face Safetensors](https://huggingface.co/docs/safetensors/index) format.
|
||||
- `config.yaml`: This is the consolidated Hydra training configuration containing the policy, environment, and dataset configs. The policy configuration should match `config.json` exactly. The environment config is useful for anyone who wants to evaluate your policy. The dataset config just serves as a paper trail for reproducibility.
|
||||
|
||||
To upload these to the hub, prepare a folder with the following structure (you can use symlinks rather than copying):
|
||||
|
||||
```
|
||||
to_upload
|
||||
├── config.yaml
|
||||
└── model.pt
|
||||
```
|
||||
|
||||
With the folder prepared, run the following with a desired revision ID.
|
||||
To upload these to the hub, run the following with a desired revision ID.
|
||||
|
||||
```bash
|
||||
huggingface-cli upload $HUB_ID to_upload --revision $REVISION_ID
|
||||
huggingface-cli upload $HUB_ID PATH/TO/OUTPUT/DIR --revision $REVISION_ID
|
||||
```
|
||||
|
||||
If you want this to be the default revision also run the following (don't worry, it won't upload the files again; it will just adjust the file pointers):
|
||||
|
||||
```bash
|
||||
huggingface-cli upload $HUB_ID to_upload
|
||||
huggingface-cli upload $HUB_ID PATH/TO/OUTPUT/DIR
|
||||
```
|
||||
|
||||
See `eval.py` for an example of how a user may use your policy.
|
||||
|
|
|
@ -7,32 +7,21 @@ from pathlib import Path
|
|||
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from lerobot.common.utils.utils import init_hydra_config
|
||||
from lerobot.scripts.eval import eval
|
||||
|
||||
# Get a pretrained policy from the hub.
|
||||
# TODO(alexander-soare): This no longer works until we upload a new model that uses the current configs.
|
||||
hub_id = "lerobot/diffusion_policy_pusht_image"
|
||||
folder = Path(snapshot_download(hub_id))
|
||||
pretrained_policy_name = "lerobot/diffusion_policy_pusht_image"
|
||||
pretrained_policy_path = Path(snapshot_download(pretrained_policy_name))
|
||||
# OR uncomment the following to evaluate a policy from the local outputs/train folder.
|
||||
# folder = Path("outputs/train/example_pusht_diffusion")
|
||||
|
||||
config_path = folder / "config.yaml"
|
||||
weights_path = folder / "model.pt"
|
||||
# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")
|
||||
|
||||
# Override some config parameters to do with evaluation.
|
||||
overrides = [
|
||||
f"policy.pretrained_model_path={weights_path}",
|
||||
"eval_episodes=10",
|
||||
"rollout_batch_size=10",
|
||||
"eval.n_episodes=10",
|
||||
"eval.batch_size=10",
|
||||
"device=cuda",
|
||||
]
|
||||
|
||||
# Create a Hydra config.
|
||||
cfg = init_hydra_config(config_path, overrides)
|
||||
|
||||
# Evaluate the policy and save the outputs including metrics and videos.
|
||||
eval(
|
||||
cfg,
|
||||
out_dir=f"outputs/eval/example_{cfg.env.name}_{cfg.policy.name}",
|
||||
)
|
||||
# TODO(rcadene, alexander-soare): dont call eval, but add the minimal code snippet to rollout
|
||||
eval(pretrained_policy_path=pretrained_policy_path)
|
||||
|
|
|
@ -34,19 +34,17 @@ dataset = make_dataset(hydra_cfg)
|
|||
# If you're doing something different, you will likely need to change at least some of the defaults.
|
||||
cfg = DiffusionConfig()
|
||||
# TODO(alexander-soare): Remove LR scheduler from the policy.
|
||||
policy = DiffusionPolicy(cfg, lr_scheduler_num_training_steps=training_steps, dataset_stats=dataset.stats)
|
||||
policy = DiffusionPolicy(cfg, dataset_stats=dataset.stats)
|
||||
policy.train()
|
||||
policy.to(device)
|
||||
|
||||
optimizer = torch.optim.Adam(
|
||||
policy.diffusion.parameters(), cfg.lr, cfg.adam_betas, cfg.adam_eps, cfg.adam_weight_decay
|
||||
)
|
||||
optimizer = torch.optim.Adam(policy.parameters(), lr=1e-4)
|
||||
|
||||
# Create dataloader for offline training.
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=4,
|
||||
batch_size=cfg.batch_size,
|
||||
batch_size=64,
|
||||
shuffle=True,
|
||||
pin_memory=device != torch.device("cpu"),
|
||||
drop_last=True,
|
||||
|
@ -71,6 +69,7 @@ while not done:
|
|||
done = True
|
||||
break
|
||||
|
||||
# Save the policy and configuration for later use.
|
||||
policy.save(output_directory / "model.pt")
|
||||
# Save the policy.
|
||||
policy.save_pretrained(output_directory)
|
||||
# Save the Hydra configuration so we have the environment configuration for eval.
|
||||
OmegaConf.save(hydra_cfg, output_directory / "config.yaml")
|
||||
|
|
|
@ -14,12 +14,13 @@ def make_dataset(
|
|||
cfg,
|
||||
split="train",
|
||||
):
|
||||
if cfg.env.name not in cfg.dataset.repo_id:
|
||||
if cfg.env.name not in cfg.dataset_repo_id:
|
||||
logging.warning(
|
||||
f"There might be a mismatch between your training dataset ({cfg.dataset.repo_id=}) and your environment ({cfg.env.name=})."
|
||||
f"There might be a mismatch between your training dataset ({cfg.dataset_repo_id=}) and your "
|
||||
f"environment ({cfg.env.name=})."
|
||||
)
|
||||
|
||||
delta_timestamps = cfg.policy.get("delta_timestamps")
|
||||
delta_timestamps = cfg.training.get("delta_timestamps")
|
||||
if delta_timestamps is not None:
|
||||
for key in delta_timestamps:
|
||||
if isinstance(delta_timestamps[key], str):
|
||||
|
@ -28,7 +29,7 @@ def make_dataset(
|
|||
# TODO(rcadene): add data augmentations
|
||||
|
||||
dataset = LeRobotDataset(
|
||||
cfg.dataset.repo_id,
|
||||
cfg.dataset_repo_id,
|
||||
split=split,
|
||||
root=DATA_DIR,
|
||||
delta_timestamps=delta_timestamps,
|
||||
|
|
|
@ -5,9 +5,12 @@ import logging
|
|||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
|
||||
from omegaconf import OmegaConf
|
||||
from termcolor import colored
|
||||
|
||||
from lerobot.common.policies.policy_protocol import Policy
|
||||
|
||||
|
||||
def log_output_dir(out_dir):
|
||||
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {out_dir}")
|
||||
|
@ -30,11 +33,11 @@ class Logger:
|
|||
self._log_dir = Path(log_dir)
|
||||
self._log_dir.mkdir(parents=True, exist_ok=True)
|
||||
self._job_name = job_name
|
||||
self._model_dir = self._log_dir / "models"
|
||||
self._model_dir = self._log_dir / "checkpoints"
|
||||
self._buffer_dir = self._log_dir / "buffers"
|
||||
self._save_model = cfg.save_model
|
||||
self._save_model = cfg.training.save_model
|
||||
self._disable_wandb_artifact = cfg.wandb.disable_artifact
|
||||
self._save_buffer = cfg.save_buffer
|
||||
self._save_buffer = cfg.training.get("save_buffer", False)
|
||||
self._group = cfg_to_group(cfg)
|
||||
self._seed = cfg.seed
|
||||
self._cfg = cfg
|
||||
|
@ -70,18 +73,20 @@ class Logger:
|
|||
logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
|
||||
self._wandb = wandb
|
||||
|
||||
def save_model(self, policy, identifier):
|
||||
def save_model(self, policy: Policy, identifier):
|
||||
if self._save_model:
|
||||
self._model_dir.mkdir(parents=True, exist_ok=True)
|
||||
fp = self._model_dir / f"{str(identifier)}.pt"
|
||||
policy.save(fp)
|
||||
save_dir = self._model_dir / str(identifier)
|
||||
policy.save_pretrained(save_dir)
|
||||
# Also save the full Hydra config for the env configuration.
|
||||
OmegaConf.save(self._cfg, save_dir / "config.yaml")
|
||||
if self._wandb and not self._disable_wandb_artifact:
|
||||
# note wandb artifact does not accept ":" in its name
|
||||
artifact = self._wandb.Artifact(
|
||||
self._group.replace(":", "_") + "-" + str(self._seed) + "-" + str(identifier),
|
||||
type="model",
|
||||
)
|
||||
artifact.add_file(fp)
|
||||
artifact.add_file(save_dir / SAFETENSORS_SINGLE_FILE)
|
||||
self._wandb.log_artifact(artifact)
|
||||
|
||||
def save_buffer(self, buffer, identifier):
|
||||
|
|
|
@ -2,7 +2,7 @@ from dataclasses import dataclass, field
|
|||
|
||||
|
||||
@dataclass
|
||||
class ActionChunkingTransformerConfig:
|
||||
class ACTConfig:
|
||||
"""Configuration class for the Action Chunking Transformers policy.
|
||||
|
||||
Defaults are configured for training on bimanual Aloha tasks like "insertion" or "transfer".
|
||||
|
@ -22,23 +22,24 @@ class ActionChunkingTransformerConfig:
|
|||
The key represents the input data name, and the value is a list indicating the dimensions
|
||||
of the corresponding data. For example, "observation.images.top" refers to an input from the
|
||||
"top" camera with dimensions [3, 96, 96], indicating it has three color channels and 96x96 resolution.
|
||||
Importantly, shapes doesnt include batch dimension or temporal dimension.
|
||||
Importantly, shapes doesn't include batch dimension or temporal dimension.
|
||||
output_shapes: A dictionary defining the shapes of the output data for the policy.
|
||||
The key represents the output data name, and the value is a list indicating the dimensions
|
||||
of the corresponding data. For example, "action" refers to an output shape of [14], indicating
|
||||
14-dimensional actions. Importantly, shapes doesnt include batch dimension or temporal dimension.
|
||||
normalize_input_modes: A dictionary with key represents the modality (e.g. "observation.state"),
|
||||
and the value specifies the normalization mode to apply. The two availables
|
||||
modes are "mean_std" which substracts the mean and divide by the standard
|
||||
deviation and "min_max" which rescale in a [-1, 1] range.
|
||||
unnormalize_output_modes: Similar dictionary as `normalize_input_modes`, but to unormalize in original scale.
|
||||
14-dimensional actions. Importantly, shapes doesn't include batch dimension or temporal dimension.
|
||||
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
|
||||
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
|
||||
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
|
||||
[-1, 1] range.
|
||||
output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the
|
||||
original scale. Note that this is also used for normalizing the training targets.
|
||||
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
|
||||
pretrained_backbone_weights: Pretrained weights from torchvision to initalize the backbone.
|
||||
`None` means no pretrained weights.
|
||||
replace_final_stride_with_dilation: Whether to replace the ResNet's final 2x2 stride with a dilated
|
||||
convolution.
|
||||
pre_norm: Whether to use "pre-norm" in the transformer blocks.
|
||||
d_model: The transformer blocks' main hidden dimension.
|
||||
dim_model: The transformer blocks' main hidden dimension.
|
||||
n_heads: The number of heads to use in the transformer blocks' multi-head attention.
|
||||
dim_feedforward: The dimension to expand the transformer's hidden dimension to in the feed-forward
|
||||
layers.
|
||||
|
@ -62,13 +63,13 @@ class ActionChunkingTransformerConfig:
|
|||
chunk_size: int = 100
|
||||
n_action_steps: int = 100
|
||||
|
||||
input_shapes: dict[str, list[str]] = field(
|
||||
input_shapes: dict[str, list[int]] = field(
|
||||
default_factory=lambda: {
|
||||
"observation.images.top": [3, 480, 640],
|
||||
"observation.state": [14],
|
||||
}
|
||||
)
|
||||
output_shapes: dict[str, list[str]] = field(
|
||||
output_shapes: dict[str, list[int]] = field(
|
||||
default_factory=lambda: {
|
||||
"action": [14],
|
||||
}
|
||||
|
@ -94,7 +95,7 @@ class ActionChunkingTransformerConfig:
|
|||
replace_final_stride_with_dilation: int = False
|
||||
# Transformer layers.
|
||||
pre_norm: bool = False
|
||||
d_model: int = 512
|
||||
dim_model: int = 512
|
||||
n_heads: int = 8
|
||||
dim_feedforward: int = 3200
|
||||
feedforward_activation: str = "relu"
|
||||
|
@ -112,15 +113,6 @@ class ActionChunkingTransformerConfig:
|
|||
dropout: float = 0.1
|
||||
kl_weight: float = 10.0
|
||||
|
||||
# ---
|
||||
# TODO(alexander-soare): Remove these from the policy config.
|
||||
batch_size: int = 8
|
||||
lr: float = 1e-5
|
||||
lr_backbone: float = 1e-5
|
||||
weight_decay: float = 1e-4
|
||||
grad_clip_norm: float = 10
|
||||
utd: int = 1
|
||||
|
||||
def __post_init__(self):
|
||||
"""Input validation (not exhaustive)."""
|
||||
if not self.vision_backbone.startswith("resnet"):
|
||||
|
|
|
@ -14,18 +14,124 @@ import numpy as np
|
|||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
import torchvision
|
||||
from huggingface_hub import PyTorchModelHubMixin
|
||||
from torch import Tensor, nn
|
||||
from torchvision.models._utils import IntermediateLayerGetter
|
||||
from torchvision.ops.misc import FrozenBatchNorm2d
|
||||
|
||||
from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
|
||||
from lerobot.common.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
|
||||
|
||||
class ActionChunkingTransformerPolicy(nn.Module):
|
||||
class ACTPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
"""
|
||||
Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost
|
||||
Hardware (paper: https://arxiv.org/abs/2304.13705, code: https://github.com/tonyzhaozh/act)
|
||||
"""
|
||||
|
||||
name = "act"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: ACTConfig | None = None,
|
||||
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
config: Policy configuration class instance or None, in which case the default instantiation of
|
||||
the configuration class is used.
|
||||
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
|
||||
that they will be passed with a call to `load_state_dict` before the policy is used.
|
||||
"""
|
||||
super().__init__()
|
||||
if config is None:
|
||||
config = ACTConfig()
|
||||
self.config = config
|
||||
self.normalize_inputs = Normalize(
|
||||
config.input_shapes, config.input_normalization_modes, dataset_stats
|
||||
)
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_shapes, config.output_normalization_modes, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_shapes, config.output_normalization_modes, dataset_stats
|
||||
)
|
||||
self.model = ACT(config)
|
||||
|
||||
def reset(self):
|
||||
"""This should be called whenever the environment is reset."""
|
||||
if self.config.n_action_steps is not None:
|
||||
self._action_queue = deque([], maxlen=self.config.n_action_steps)
|
||||
|
||||
@torch.no_grad
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Select a single action given environment observations.
|
||||
|
||||
This method wraps `select_actions` in order to return one action at a time for execution in the
|
||||
environment. It works by managing the actions in a queue and only calling `select_actions` when the
|
||||
queue is empty.
|
||||
"""
|
||||
self.eval()
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
self._stack_images(batch)
|
||||
|
||||
if len(self._action_queue) == 0:
|
||||
# `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
|
||||
# effectively has shape (n_action_steps, batch_size, *), hence the transpose.
|
||||
actions = self.model(batch)[0][: self.config.n_action_steps]
|
||||
|
||||
# TODO(rcadene): make _forward return output dictionary?
|
||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||
|
||||
self._action_queue.extend(actions.transpose(0, 1))
|
||||
return self._action_queue.popleft()
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = self.normalize_targets(batch)
|
||||
self._stack_images(batch)
|
||||
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
|
||||
|
||||
l1_loss = (
|
||||
F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
|
||||
).mean()
|
||||
|
||||
loss_dict = {"l1_loss": l1_loss}
|
||||
if self.config.use_vae:
|
||||
# Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
|
||||
# each dimension independently, we sum over the latent dimension to get the total
|
||||
# KL-divergence per batch element, then take the mean over the batch.
|
||||
# (See App. B of https://arxiv.org/abs/1312.6114 for more details).
|
||||
mean_kld = (
|
||||
(-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
|
||||
)
|
||||
loss_dict["kld_loss"] = mean_kld
|
||||
loss_dict["loss"] = l1_loss + mean_kld * self.config.kl_weight
|
||||
else:
|
||||
loss_dict["loss"] = l1_loss
|
||||
|
||||
return loss_dict
|
||||
|
||||
def _stack_images(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
"""Stacks all the images in a batch and puts them in a new key: "observation.images".
|
||||
|
||||
This function expects `batch` to have (at least):
|
||||
{
|
||||
"observation.state": (B, state_dim) batch of robot states.
|
||||
"observation.images.{name}": (B, C, H, W) tensor of images.
|
||||
}
|
||||
"""
|
||||
# Stack images in the order dictated by input_shapes.
|
||||
batch["observation.images"] = torch.stack(
|
||||
[batch[k] for k in self.config.input_shapes if k.startswith("observation.images.")],
|
||||
dim=-4,
|
||||
)
|
||||
|
||||
|
||||
class ACT(nn.Module):
|
||||
"""Action Chunking Transformer: The underlying neural network for ACTPolicy.
|
||||
|
||||
Note: In this code we use the terms `vae_encoder`, 'encoder', `decoder`. The meanings are as follows.
|
||||
- The `vae_encoder` is, as per the literature around variational auto-encoders (VAE), the part of the
|
||||
|
@ -59,51 +165,36 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
|||
└───────────────────────┘
|
||||
"""
|
||||
|
||||
name = "act"
|
||||
|
||||
def __init__(self, cfg: ActionChunkingTransformerConfig | None = None, dataset_stats=None):
|
||||
"""
|
||||
Args:
|
||||
cfg: Policy configuration class instance or None, in which case the default instantiation of the
|
||||
configuration class is used.
|
||||
"""
|
||||
def __init__(self, config: ACTConfig):
|
||||
super().__init__()
|
||||
if cfg is None:
|
||||
cfg = ActionChunkingTransformerConfig()
|
||||
self.cfg = cfg
|
||||
self.normalize_inputs = Normalize(cfg.input_shapes, cfg.input_normalization_modes, dataset_stats)
|
||||
self.normalize_targets = Normalize(cfg.output_shapes, cfg.output_normalization_modes, dataset_stats)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
cfg.output_shapes, cfg.output_normalization_modes, dataset_stats
|
||||
)
|
||||
|
||||
self.config = config
|
||||
# BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence].
|
||||
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
|
||||
if self.cfg.use_vae:
|
||||
self.vae_encoder = _TransformerEncoder(cfg)
|
||||
self.vae_encoder_cls_embed = nn.Embedding(1, cfg.d_model)
|
||||
if self.config.use_vae:
|
||||
self.vae_encoder = ACTEncoder(config)
|
||||
self.vae_encoder_cls_embed = nn.Embedding(1, config.dim_model)
|
||||
# Projection layer for joint-space configuration to hidden dimension.
|
||||
self.vae_encoder_robot_state_input_proj = nn.Linear(
|
||||
cfg.input_shapes["observation.state"][0], cfg.d_model
|
||||
config.input_shapes["observation.state"][0], config.dim_model
|
||||
)
|
||||
# Projection layer for action (joint-space target) to hidden dimension.
|
||||
self.vae_encoder_action_input_proj = nn.Linear(
|
||||
cfg.input_shapes["observation.state"][0], cfg.d_model
|
||||
config.input_shapes["observation.state"][0], config.dim_model
|
||||
)
|
||||
self.latent_dim = cfg.latent_dim
|
||||
self.latent_dim = config.latent_dim
|
||||
# Projection layer from the VAE encoder's output to the latent distribution's parameter space.
|
||||
self.vae_encoder_latent_output_proj = nn.Linear(cfg.d_model, self.latent_dim * 2)
|
||||
self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, self.latent_dim * 2)
|
||||
# Fixed sinusoidal positional embedding the whole input to the VAE encoder. Unsqueeze for batch
|
||||
# dimension.
|
||||
self.register_buffer(
|
||||
"vae_encoder_pos_enc",
|
||||
_create_sinusoidal_position_embedding(1 + 1 + cfg.chunk_size, cfg.d_model).unsqueeze(0),
|
||||
create_sinusoidal_pos_embedding(1 + 1 + config.chunk_size, config.dim_model).unsqueeze(0),
|
||||
)
|
||||
|
||||
# Backbone for image feature extraction.
|
||||
backbone_model = getattr(torchvision.models, cfg.vision_backbone)(
|
||||
replace_stride_with_dilation=[False, False, cfg.replace_final_stride_with_dilation],
|
||||
weights=cfg.pretrained_backbone_weights,
|
||||
backbone_model = getattr(torchvision.models, config.vision_backbone)(
|
||||
replace_stride_with_dilation=[False, False, config.replace_final_stride_with_dilation],
|
||||
weights=config.pretrained_backbone_weights,
|
||||
norm_layer=FrozenBatchNorm2d,
|
||||
)
|
||||
# Note: The assumption here is that we are using a ResNet model (and hence layer4 is the final feature
|
||||
|
@ -112,26 +203,28 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
|||
self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"})
|
||||
|
||||
# Transformer (acts as VAE decoder when training with the variational objective).
|
||||
self.encoder = _TransformerEncoder(cfg)
|
||||
self.decoder = _TransformerDecoder(cfg)
|
||||
self.encoder = ACTEncoder(config)
|
||||
self.decoder = ACTDecoder(config)
|
||||
|
||||
# Transformer encoder input projections. The tokens will be structured like
|
||||
# [latent, robot_state, image_feature_map_pixels].
|
||||
self.encoder_robot_state_input_proj = nn.Linear(cfg.input_shapes["observation.state"][0], cfg.d_model)
|
||||
self.encoder_latent_input_proj = nn.Linear(self.latent_dim, cfg.d_model)
|
||||
self.encoder_robot_state_input_proj = nn.Linear(
|
||||
config.input_shapes["observation.state"][0], config.dim_model
|
||||
)
|
||||
self.encoder_latent_input_proj = nn.Linear(self.latent_dim, config.dim_model)
|
||||
self.encoder_img_feat_input_proj = nn.Conv2d(
|
||||
backbone_model.fc.in_features, cfg.d_model, kernel_size=1
|
||||
backbone_model.fc.in_features, config.dim_model, kernel_size=1
|
||||
)
|
||||
# Transformer encoder positional embeddings.
|
||||
self.encoder_robot_and_latent_pos_embed = nn.Embedding(2, cfg.d_model)
|
||||
self.encoder_cam_feat_pos_embed = _SinusoidalPositionEmbedding2D(cfg.d_model // 2)
|
||||
self.encoder_robot_and_latent_pos_embed = nn.Embedding(2, config.dim_model)
|
||||
self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)
|
||||
|
||||
# Transformer decoder.
|
||||
# Learnable positional embedding for the transformer's decoder (in the style of DETR object queries).
|
||||
self.decoder_pos_embed = nn.Embedding(cfg.chunk_size, cfg.d_model)
|
||||
self.decoder_pos_embed = nn.Embedding(config.chunk_size, config.dim_model)
|
||||
|
||||
# Final action regression head on the output of the transformer's decoder.
|
||||
self.action_head = nn.Linear(cfg.d_model, cfg.output_shapes["action"][0])
|
||||
self.action_head = nn.Linear(config.dim_model, config.output_shapes["action"][0])
|
||||
|
||||
self._reset_parameters()
|
||||
|
||||
|
@ -141,76 +234,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
|||
if p.dim() > 1:
|
||||
nn.init.xavier_uniform_(p)
|
||||
|
||||
def reset(self):
|
||||
"""This should be called whenever the environment is reset."""
|
||||
if self.cfg.n_action_steps is not None:
|
||||
self._action_queue = deque([], maxlen=self.cfg.n_action_steps)
|
||||
|
||||
@torch.no_grad
|
||||
def select_action(self, batch: dict[str, Tensor], **_) -> Tensor:
|
||||
"""Select a single action given environment observations.
|
||||
|
||||
This method wraps `select_actions` in order to return one action at a time for execution in the
|
||||
environment. It works by managing the actions in a queue and only calling `select_actions` when the
|
||||
queue is empty.
|
||||
"""
|
||||
self.eval()
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
|
||||
if len(self._action_queue) == 0:
|
||||
# `_forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue effectively
|
||||
# has shape (n_action_steps, batch_size, *), hence the transpose.
|
||||
actions = self._forward(batch)[0][: self.cfg.n_action_steps]
|
||||
|
||||
# TODO(rcadene): make _forward return output dictionary?
|
||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||
|
||||
self._action_queue.extend(actions.transpose(0, 1))
|
||||
return self._action_queue.popleft()
|
||||
|
||||
def forward(self, batch, **_) -> dict[str, Tensor]:
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = self.normalize_targets(batch)
|
||||
actions_hat, (mu_hat, log_sigma_x2_hat) = self._forward(batch)
|
||||
|
||||
l1_loss = (
|
||||
F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
|
||||
).mean()
|
||||
|
||||
loss_dict = {"l1_loss": l1_loss}
|
||||
if self.cfg.use_vae:
|
||||
# Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
|
||||
# each dimension independently, we sum over the latent dimension to get the total
|
||||
# KL-divergence per batch element, then take the mean over the batch.
|
||||
# (See App. B of https://arxiv.org/abs/1312.6114 for more details).
|
||||
mean_kld = (
|
||||
(-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
|
||||
)
|
||||
loss_dict["kld_loss"] = mean_kld
|
||||
loss_dict["loss"] = l1_loss + mean_kld * self.cfg.kl_weight
|
||||
else:
|
||||
loss_dict["loss"] = l1_loss
|
||||
|
||||
return loss_dict
|
||||
|
||||
def _stack_images(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
"""Stacks all the images in a batch and puts them in a new key: "observation.images".
|
||||
|
||||
This function expects `batch` to have (at least):
|
||||
{
|
||||
"observation.state": (B, state_dim) batch of robot states.
|
||||
"observation.images.{name}": (B, C, H, W) tensor of images.
|
||||
}
|
||||
"""
|
||||
# Stack images in the order dictated by input_shapes.
|
||||
batch["observation.images"] = torch.stack(
|
||||
[batch[k] for k in self.cfg.input_shapes if k.startswith("observation.images.")],
|
||||
dim=-4,
|
||||
)
|
||||
|
||||
def _forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tensor] | tuple[None, None]]:
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tensor] | tuple[None, None]]:
|
||||
"""A forward pass through the Action Chunking Transformer (with optional VAE encoder).
|
||||
|
||||
`batch` should have the following structure:
|
||||
|
@ -226,17 +250,15 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
|||
Tuple containing the latent PDF's parameters (mean, log(σ²)) both as (B, L) tensors where L is the
|
||||
latent dimension.
|
||||
"""
|
||||
if self.cfg.use_vae and self.training:
|
||||
if self.config.use_vae and self.training:
|
||||
assert (
|
||||
"action" in batch
|
||||
), "actions must be provided when using the variational objective in training mode."
|
||||
|
||||
self._stack_images(batch)
|
||||
|
||||
batch_size = batch["observation.state"].shape[0]
|
||||
|
||||
# Prepare the latent for input to the transformer encoder.
|
||||
if self.cfg.use_vae and "action" in batch:
|
||||
if self.config.use_vae and "action" in batch:
|
||||
# Prepare the input to the VAE encoder: [cls, *joint_space_configuration, *action_sequence].
|
||||
cls_embed = einops.repeat(
|
||||
self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
|
||||
|
@ -306,7 +328,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
|||
# Forward pass through the transformer modules.
|
||||
encoder_out = self.encoder(encoder_in, pos_embed=pos_embed)
|
||||
decoder_in = torch.zeros(
|
||||
(self.cfg.chunk_size, batch_size, self.cfg.d_model),
|
||||
(self.config.chunk_size, batch_size, self.config.dim_model),
|
||||
dtype=pos_embed.dtype,
|
||||
device=pos_embed.device,
|
||||
)
|
||||
|
@ -324,21 +346,14 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
|||
|
||||
return actions, (mu, log_sigma_x2)
|
||||
|
||||
def save(self, fp):
|
||||
torch.save(self.state_dict(), fp)
|
||||
|
||||
def load(self, fp):
|
||||
d = torch.load(fp)
|
||||
self.load_state_dict(d)
|
||||
|
||||
|
||||
class _TransformerEncoder(nn.Module):
|
||||
class ACTEncoder(nn.Module):
|
||||
"""Convenience module for running multiple encoder layers, maybe followed by normalization."""
|
||||
|
||||
def __init__(self, cfg: ActionChunkingTransformerConfig):
|
||||
def __init__(self, config: ACTConfig):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([_TransformerEncoderLayer(cfg) for _ in range(cfg.n_encoder_layers)])
|
||||
self.norm = nn.LayerNorm(cfg.d_model) if cfg.pre_norm else nn.Identity()
|
||||
self.layers = nn.ModuleList([ACTEncoderLayer(config) for _ in range(config.n_encoder_layers)])
|
||||
self.norm = nn.LayerNorm(config.dim_model) if config.pre_norm else nn.Identity()
|
||||
|
||||
def forward(self, x: Tensor, pos_embed: Tensor | None = None) -> Tensor:
|
||||
for layer in self.layers:
|
||||
|
@ -347,23 +362,23 @@ class _TransformerEncoder(nn.Module):
|
|||
return x
|
||||
|
||||
|
||||
class _TransformerEncoderLayer(nn.Module):
|
||||
def __init__(self, cfg: ActionChunkingTransformerConfig):
|
||||
class ACTEncoderLayer(nn.Module):
|
||||
def __init__(self, config: ACTConfig):
|
||||
super().__init__()
|
||||
self.self_attn = nn.MultiheadAttention(cfg.d_model, cfg.n_heads, dropout=cfg.dropout)
|
||||
self.self_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout)
|
||||
|
||||
# Feed forward layers.
|
||||
self.linear1 = nn.Linear(cfg.d_model, cfg.dim_feedforward)
|
||||
self.dropout = nn.Dropout(cfg.dropout)
|
||||
self.linear2 = nn.Linear(cfg.dim_feedforward, cfg.d_model)
|
||||
self.linear1 = nn.Linear(config.dim_model, config.dim_feedforward)
|
||||
self.dropout = nn.Dropout(config.dropout)
|
||||
self.linear2 = nn.Linear(config.dim_feedforward, config.dim_model)
|
||||
|
||||
self.norm1 = nn.LayerNorm(cfg.d_model)
|
||||
self.norm2 = nn.LayerNorm(cfg.d_model)
|
||||
self.dropout1 = nn.Dropout(cfg.dropout)
|
||||
self.dropout2 = nn.Dropout(cfg.dropout)
|
||||
self.norm1 = nn.LayerNorm(config.dim_model)
|
||||
self.norm2 = nn.LayerNorm(config.dim_model)
|
||||
self.dropout1 = nn.Dropout(config.dropout)
|
||||
self.dropout2 = nn.Dropout(config.dropout)
|
||||
|
||||
self.activation = _get_activation_fn(cfg.feedforward_activation)
|
||||
self.pre_norm = cfg.pre_norm
|
||||
self.activation = get_activation_fn(config.feedforward_activation)
|
||||
self.pre_norm = config.pre_norm
|
||||
|
||||
def forward(self, x, pos_embed: Tensor | None = None) -> Tensor:
|
||||
skip = x
|
||||
|
@ -385,12 +400,12 @@ class _TransformerEncoderLayer(nn.Module):
|
|||
return x
|
||||
|
||||
|
||||
class _TransformerDecoder(nn.Module):
|
||||
def __init__(self, cfg: ActionChunkingTransformerConfig):
|
||||
class ACTDecoder(nn.Module):
|
||||
def __init__(self, config: ACTConfig):
|
||||
"""Convenience module for running multiple decoder layers followed by normalization."""
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([_TransformerDecoderLayer(cfg) for _ in range(cfg.n_decoder_layers)])
|
||||
self.norm = nn.LayerNorm(cfg.d_model)
|
||||
self.layers = nn.ModuleList([ACTDecoderLayer(config) for _ in range(config.n_decoder_layers)])
|
||||
self.norm = nn.LayerNorm(config.dim_model)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
@ -408,26 +423,26 @@ class _TransformerDecoder(nn.Module):
|
|||
return x
|
||||
|
||||
|
||||
class _TransformerDecoderLayer(nn.Module):
|
||||
def __init__(self, cfg: ActionChunkingTransformerConfig):
|
||||
class ACTDecoderLayer(nn.Module):
|
||||
def __init__(self, config: ACTConfig):
|
||||
super().__init__()
|
||||
self.self_attn = nn.MultiheadAttention(cfg.d_model, cfg.n_heads, dropout=cfg.dropout)
|
||||
self.multihead_attn = nn.MultiheadAttention(cfg.d_model, cfg.n_heads, dropout=cfg.dropout)
|
||||
self.self_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout)
|
||||
self.multihead_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout)
|
||||
|
||||
# Feed forward layers.
|
||||
self.linear1 = nn.Linear(cfg.d_model, cfg.dim_feedforward)
|
||||
self.dropout = nn.Dropout(cfg.dropout)
|
||||
self.linear2 = nn.Linear(cfg.dim_feedforward, cfg.d_model)
|
||||
self.linear1 = nn.Linear(config.dim_model, config.dim_feedforward)
|
||||
self.dropout = nn.Dropout(config.dropout)
|
||||
self.linear2 = nn.Linear(config.dim_feedforward, config.dim_model)
|
||||
|
||||
self.norm1 = nn.LayerNorm(cfg.d_model)
|
||||
self.norm2 = nn.LayerNorm(cfg.d_model)
|
||||
self.norm3 = nn.LayerNorm(cfg.d_model)
|
||||
self.dropout1 = nn.Dropout(cfg.dropout)
|
||||
self.dropout2 = nn.Dropout(cfg.dropout)
|
||||
self.dropout3 = nn.Dropout(cfg.dropout)
|
||||
self.norm1 = nn.LayerNorm(config.dim_model)
|
||||
self.norm2 = nn.LayerNorm(config.dim_model)
|
||||
self.norm3 = nn.LayerNorm(config.dim_model)
|
||||
self.dropout1 = nn.Dropout(config.dropout)
|
||||
self.dropout2 = nn.Dropout(config.dropout)
|
||||
self.dropout3 = nn.Dropout(config.dropout)
|
||||
|
||||
self.activation = _get_activation_fn(cfg.feedforward_activation)
|
||||
self.pre_norm = cfg.pre_norm
|
||||
self.activation = get_activation_fn(config.feedforward_activation)
|
||||
self.pre_norm = config.pre_norm
|
||||
|
||||
def maybe_add_pos_embed(self, tensor: Tensor, pos_embed: Tensor | None) -> Tensor:
|
||||
return tensor if pos_embed is None else tensor + pos_embed
|
||||
|
@ -480,7 +495,7 @@ class _TransformerDecoderLayer(nn.Module):
|
|||
return x
|
||||
|
||||
|
||||
def _create_sinusoidal_position_embedding(num_positions: int, dimension: int) -> Tensor:
|
||||
def create_sinusoidal_pos_embedding(num_positions: int, dimension: int) -> Tensor:
|
||||
"""1D sinusoidal positional embeddings as in Attention is All You Need.
|
||||
|
||||
Args:
|
||||
|
@ -498,7 +513,7 @@ def _create_sinusoidal_position_embedding(num_positions: int, dimension: int) ->
|
|||
return torch.from_numpy(sinusoid_table).float()
|
||||
|
||||
|
||||
class _SinusoidalPositionEmbedding2D(nn.Module):
|
||||
class ACTSinusoidalPositionEmbedding2d(nn.Module):
|
||||
"""2D sinusoidal positional embeddings similar to what's presented in Attention Is All You Need.
|
||||
|
||||
The variation is that the position indices are normalized in [0, 2π] (not quite: the lower bound is 1/H
|
||||
|
@ -552,7 +567,7 @@ class _SinusoidalPositionEmbedding2D(nn.Module):
|
|||
return pos_embed
|
||||
|
||||
|
||||
def _get_activation_fn(activation: str) -> Callable:
|
||||
def get_activation_fn(activation: str) -> Callable:
|
||||
"""Return an activation function given a string."""
|
||||
if activation == "relu":
|
||||
return F.relu
|
||||
|
|
|
@ -3,7 +3,7 @@ from dataclasses import dataclass, field
|
|||
|
||||
@dataclass
|
||||
class DiffusionConfig:
|
||||
"""Configuration class for Diffusion Policy.
|
||||
"""Configuration class for DiffusionPolicy.
|
||||
|
||||
Defaults are configured for training with PushT providing proprioceptive and single camera observations.
|
||||
|
||||
|
@ -25,11 +25,12 @@ class DiffusionConfig:
|
|||
The key represents the output data name, and the value is a list indicating the dimensions
|
||||
of the corresponding data. For example, "action" refers to an output shape of [14], indicating
|
||||
14-dimensional actions. Importantly, shapes doesnt include batch dimension or temporal dimension.
|
||||
normalize_input_modes: A dictionary with key represents the modality (e.g. "observation.state"),
|
||||
and the value specifies the normalization mode to apply. The two availables
|
||||
modes are "mean_std" which substracts the mean and divide by the standard
|
||||
deviation and "min_max" which rescale in a [-1, 1] range.
|
||||
unnormalize_output_modes: Similar dictionary as `normalize_input_modes`, but to unormalize in original scale.
|
||||
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
|
||||
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
|
||||
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
|
||||
[-1, 1] range.
|
||||
output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the
|
||||
original scale. Note that this is also used for normalizing the training targets.
|
||||
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
|
||||
crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit
|
||||
within the image size. If None, no cropping is done.
|
||||
|
@ -70,13 +71,13 @@ class DiffusionConfig:
|
|||
horizon: int = 16
|
||||
n_action_steps: int = 8
|
||||
|
||||
input_shapes: dict[str, list[str]] = field(
|
||||
input_shapes: dict[str, list[int]] = field(
|
||||
default_factory=lambda: {
|
||||
"observation.image": [3, 96, 96],
|
||||
"observation.state": [2],
|
||||
}
|
||||
)
|
||||
output_shapes: dict[str, list[str]] = field(
|
||||
output_shapes: dict[str, list[int]] = field(
|
||||
default_factory=lambda: {
|
||||
"action": [2],
|
||||
}
|
||||
|
@ -119,15 +120,6 @@ class DiffusionConfig:
|
|||
|
||||
# ---
|
||||
# TODO(alexander-soare): Remove these from the policy config.
|
||||
batch_size: int = 64
|
||||
grad_clip_norm: int = 10
|
||||
lr: float = 1.0e-4
|
||||
lr_scheduler: str = "cosine"
|
||||
lr_warmup_steps: int = 500
|
||||
adam_betas: tuple[float, float] = (0.95, 0.999)
|
||||
adam_eps: float = 1.0e-8
|
||||
adam_weight_decay: float = 1.0e-6
|
||||
utd: int = 1
|
||||
use_ema: bool = True
|
||||
ema_update_after_step: int = 0
|
||||
ema_min_alpha: float = 0.0
|
||||
|
|
|
@ -9,7 +9,6 @@ TODO(alexander-soare):
|
|||
"""
|
||||
|
||||
import copy
|
||||
import logging
|
||||
import math
|
||||
from collections import deque
|
||||
from typing import Callable
|
||||
|
@ -19,6 +18,7 @@ import torch
|
|||
import torch.nn.functional as F # noqa: N812
|
||||
import torchvision
|
||||
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
||||
from huggingface_hub import PyTorchModelHubMixin
|
||||
from robomimic.models.base_nets import SpatialSoftmax
|
||||
from torch import Tensor, nn
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
|
@ -32,7 +32,7 @@ from lerobot.common.policies.utils import (
|
|||
)
|
||||
|
||||
|
||||
class DiffusionPolicy(nn.Module):
|
||||
class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
"""
|
||||
Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion"
|
||||
(paper: https://arxiv.org/abs/2303.04137, code: https://github.com/real-stanford/diffusion_policy).
|
||||
|
@ -41,49 +41,55 @@ class DiffusionPolicy(nn.Module):
|
|||
name = "diffusion"
|
||||
|
||||
def __init__(
|
||||
self, cfg: DiffusionConfig | None = None, lr_scheduler_num_training_steps: int = 0, dataset_stats=None
|
||||
self,
|
||||
config: DiffusionConfig | None = None,
|
||||
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
cfg: Policy configuration class instance or None, in which case the default instantiation of the
|
||||
configuration class is used.
|
||||
config: Policy configuration class instance or None, in which case the default instantiation of
|
||||
the configuration class is used.
|
||||
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
|
||||
that they will be passed with a call to `load_state_dict` before the policy is used.
|
||||
"""
|
||||
super().__init__()
|
||||
# TODO(alexander-soare): LR scheduler will be removed.
|
||||
assert lr_scheduler_num_training_steps > 0
|
||||
if cfg is None:
|
||||
cfg = DiffusionConfig()
|
||||
self.cfg = cfg
|
||||
self.normalize_inputs = Normalize(cfg.input_shapes, cfg.input_normalization_modes, dataset_stats)
|
||||
self.normalize_targets = Normalize(cfg.output_shapes, cfg.output_normalization_modes, dataset_stats)
|
||||
if config is None:
|
||||
config = DiffusionConfig()
|
||||
self.config = config
|
||||
self.normalize_inputs = Normalize(
|
||||
config.input_shapes, config.input_normalization_modes, dataset_stats
|
||||
)
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_shapes, config.output_normalization_modes, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
cfg.output_shapes, cfg.output_normalization_modes, dataset_stats
|
||||
config.output_shapes, config.output_normalization_modes, dataset_stats
|
||||
)
|
||||
|
||||
# queues are populated during rollout of the policy, they contain the n latest observations and actions
|
||||
self._queues = None
|
||||
|
||||
self.diffusion = _DiffusionUnetImagePolicy(cfg)
|
||||
self.diffusion = DiffusionModel(config)
|
||||
|
||||
# TODO(alexander-soare): This should probably be managed outside of the policy class.
|
||||
self.ema_diffusion = None
|
||||
self.ema = None
|
||||
if self.cfg.use_ema:
|
||||
if self.config.use_ema:
|
||||
self.ema_diffusion = copy.deepcopy(self.diffusion)
|
||||
self.ema = _EMA(cfg, model=self.ema_diffusion)
|
||||
self.ema = DiffusionEMA(config, model=self.ema_diffusion)
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Clear observation and action queues. Should be called on `env.reset()`
|
||||
"""
|
||||
self._queues = {
|
||||
"observation.image": deque(maxlen=self.cfg.n_obs_steps),
|
||||
"observation.state": deque(maxlen=self.cfg.n_obs_steps),
|
||||
"action": deque(maxlen=self.cfg.n_action_steps),
|
||||
"observation.image": deque(maxlen=self.config.n_obs_steps),
|
||||
"observation.state": deque(maxlen=self.config.n_obs_steps),
|
||||
"action": deque(maxlen=self.config.n_action_steps),
|
||||
}
|
||||
|
||||
@torch.no_grad
|
||||
def select_action(self, batch: dict[str, Tensor], **_) -> Tensor:
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Select a single action given environment observations.
|
||||
|
||||
This method handles caching a history of observations and an action trajectory generated by the
|
||||
|
@ -131,53 +137,41 @@ class DiffusionPolicy(nn.Module):
|
|||
action = self._queues["action"].popleft()
|
||||
return action
|
||||
|
||||
def forward(self, batch: dict[str, Tensor], **_) -> dict[str, Tensor]:
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = self.normalize_targets(batch)
|
||||
loss = self.diffusion.compute_loss(batch)
|
||||
return {"loss": loss}
|
||||
|
||||
def save(self, fp):
|
||||
torch.save(self.state_dict(), fp)
|
||||
|
||||
def load(self, fp):
|
||||
d = torch.load(fp)
|
||||
missing_keys, unexpected_keys = self.load_state_dict(d, strict=False)
|
||||
if len(missing_keys) > 0:
|
||||
assert all(k.startswith("ema_diffusion.") for k in missing_keys)
|
||||
logging.warning(
|
||||
"DiffusionPolicy.load expected ema parameters in loaded state dict but none were found."
|
||||
)
|
||||
assert len(unexpected_keys) == 0
|
||||
|
||||
|
||||
class _DiffusionUnetImagePolicy(nn.Module):
|
||||
def __init__(self, cfg: DiffusionConfig):
|
||||
class DiffusionModel(nn.Module):
|
||||
def __init__(self, config: DiffusionConfig):
|
||||
super().__init__()
|
||||
self.cfg = cfg
|
||||
self.config = config
|
||||
|
||||
self.rgb_encoder = _RgbEncoder(cfg)
|
||||
self.unet = _ConditionalUnet1D(
|
||||
cfg,
|
||||
global_cond_dim=(cfg.output_shapes["action"][0] + self.rgb_encoder.feature_dim) * cfg.n_obs_steps,
|
||||
self.rgb_encoder = DiffusionRgbEncoder(config)
|
||||
self.unet = DiffusionConditionalUnet1d(
|
||||
config,
|
||||
global_cond_dim=(config.output_shapes["action"][0] + self.rgb_encoder.feature_dim)
|
||||
* config.n_obs_steps,
|
||||
)
|
||||
|
||||
self.noise_scheduler = DDPMScheduler(
|
||||
num_train_timesteps=cfg.num_train_timesteps,
|
||||
beta_start=cfg.beta_start,
|
||||
beta_end=cfg.beta_end,
|
||||
beta_schedule=cfg.beta_schedule,
|
||||
num_train_timesteps=config.num_train_timesteps,
|
||||
beta_start=config.beta_start,
|
||||
beta_end=config.beta_end,
|
||||
beta_schedule=config.beta_schedule,
|
||||
variance_type="fixed_small",
|
||||
clip_sample=cfg.clip_sample,
|
||||
clip_sample_range=cfg.clip_sample_range,
|
||||
prediction_type=cfg.prediction_type,
|
||||
clip_sample=config.clip_sample,
|
||||
clip_sample_range=config.clip_sample_range,
|
||||
prediction_type=config.prediction_type,
|
||||
)
|
||||
|
||||
if cfg.num_inference_steps is None:
|
||||
if config.num_inference_steps is None:
|
||||
self.num_inference_steps = self.noise_scheduler.config.num_train_timesteps
|
||||
else:
|
||||
self.num_inference_steps = cfg.num_inference_steps
|
||||
self.num_inference_steps = config.num_inference_steps
|
||||
|
||||
# ========= inference ============
|
||||
def conditional_sample(
|
||||
|
@ -188,7 +182,7 @@ class _DiffusionUnetImagePolicy(nn.Module):
|
|||
|
||||
# Sample prior.
|
||||
sample = torch.randn(
|
||||
size=(batch_size, self.cfg.horizon, self.cfg.output_shapes["action"][0]),
|
||||
size=(batch_size, self.config.horizon, self.config.output_shapes["action"][0]),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
generator=generator,
|
||||
|
@ -218,7 +212,7 @@ class _DiffusionUnetImagePolicy(nn.Module):
|
|||
"""
|
||||
assert set(batch).issuperset({"observation.state", "observation.image"})
|
||||
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
|
||||
assert n_obs_steps == self.cfg.n_obs_steps
|
||||
assert n_obs_steps == self.config.n_obs_steps
|
||||
|
||||
# Extract image feature (first combine batch and sequence dims).
|
||||
img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ..."))
|
||||
|
@ -231,10 +225,10 @@ class _DiffusionUnetImagePolicy(nn.Module):
|
|||
sample = self.conditional_sample(batch_size, global_cond=global_cond)
|
||||
|
||||
# `horizon` steps worth of actions (from the first observation).
|
||||
actions = sample[..., : self.cfg.output_shapes["action"][0]]
|
||||
actions = sample[..., : self.config.output_shapes["action"][0]]
|
||||
# Extract `n_action_steps` steps worth of actions (from the current observation).
|
||||
start = n_obs_steps - 1
|
||||
end = start + self.cfg.n_action_steps
|
||||
end = start + self.config.n_action_steps
|
||||
actions = actions[:, start:end]
|
||||
|
||||
return actions
|
||||
|
@ -253,8 +247,8 @@ class _DiffusionUnetImagePolicy(nn.Module):
|
|||
assert set(batch).issuperset({"observation.state", "observation.image", "action", "action_is_pad"})
|
||||
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
|
||||
horizon = batch["action"].shape[1]
|
||||
assert horizon == self.cfg.horizon
|
||||
assert n_obs_steps == self.cfg.n_obs_steps
|
||||
assert horizon == self.config.horizon
|
||||
assert n_obs_steps == self.config.n_obs_steps
|
||||
|
||||
# Extract image feature (first combine batch and sequence dims).
|
||||
img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ..."))
|
||||
|
@ -283,12 +277,12 @@ class _DiffusionUnetImagePolicy(nn.Module):
|
|||
|
||||
# Compute the loss.
|
||||
# The target is either the original trajectory, or the noise.
|
||||
if self.cfg.prediction_type == "epsilon":
|
||||
if self.config.prediction_type == "epsilon":
|
||||
target = eps
|
||||
elif self.cfg.prediction_type == "sample":
|
||||
elif self.config.prediction_type == "sample":
|
||||
target = batch["action"]
|
||||
else:
|
||||
raise ValueError(f"Unsupported prediction type {self.cfg.prediction_type}")
|
||||
raise ValueError(f"Unsupported prediction type {self.config.prediction_type}")
|
||||
|
||||
loss = F.mse_loss(pred, target, reduction="none")
|
||||
|
||||
|
@ -300,35 +294,35 @@ class _DiffusionUnetImagePolicy(nn.Module):
|
|||
return loss.mean()
|
||||
|
||||
|
||||
class _RgbEncoder(nn.Module):
|
||||
class DiffusionRgbEncoder(nn.Module):
|
||||
"""Encoder an RGB image into a 1D feature vector.
|
||||
|
||||
Includes the ability to normalize and crop the image first.
|
||||
"""
|
||||
|
||||
def __init__(self, cfg: DiffusionConfig):
|
||||
def __init__(self, config: DiffusionConfig):
|
||||
super().__init__()
|
||||
# Set up optional preprocessing.
|
||||
if cfg.crop_shape is not None:
|
||||
if config.crop_shape is not None:
|
||||
self.do_crop = True
|
||||
# Always use center crop for eval
|
||||
self.center_crop = torchvision.transforms.CenterCrop(cfg.crop_shape)
|
||||
if cfg.crop_is_random:
|
||||
self.maybe_random_crop = torchvision.transforms.RandomCrop(cfg.crop_shape)
|
||||
self.center_crop = torchvision.transforms.CenterCrop(config.crop_shape)
|
||||
if config.crop_is_random:
|
||||
self.maybe_random_crop = torchvision.transforms.RandomCrop(config.crop_shape)
|
||||
else:
|
||||
self.maybe_random_crop = self.center_crop
|
||||
else:
|
||||
self.do_crop = False
|
||||
|
||||
# Set up backbone.
|
||||
backbone_model = getattr(torchvision.models, cfg.vision_backbone)(
|
||||
weights=cfg.pretrained_backbone_weights
|
||||
backbone_model = getattr(torchvision.models, config.vision_backbone)(
|
||||
weights=config.pretrained_backbone_weights
|
||||
)
|
||||
# Note: This assumes that the layer4 feature map is children()[-3]
|
||||
# TODO(alexander-soare): Use a safer alternative.
|
||||
self.backbone = nn.Sequential(*(list(backbone_model.children())[:-2]))
|
||||
if cfg.use_group_norm:
|
||||
if cfg.pretrained_backbone_weights:
|
||||
if config.use_group_norm:
|
||||
if config.pretrained_backbone_weights:
|
||||
raise ValueError(
|
||||
"You can't replace BatchNorm in a pretrained model without ruining the weights!"
|
||||
)
|
||||
|
@ -342,11 +336,11 @@ class _RgbEncoder(nn.Module):
|
|||
# Use a dry run to get the feature map shape.
|
||||
with torch.inference_mode():
|
||||
feat_map_shape = tuple(
|
||||
self.backbone(torch.zeros(size=(1, *cfg.input_shapes["observation.image"]))).shape[1:]
|
||||
self.backbone(torch.zeros(size=(1, *config.input_shapes["observation.image"]))).shape[1:]
|
||||
)
|
||||
self.pool = SpatialSoftmax(feat_map_shape, num_kp=cfg.spatial_softmax_num_keypoints)
|
||||
self.feature_dim = cfg.spatial_softmax_num_keypoints * 2
|
||||
self.out = nn.Linear(cfg.spatial_softmax_num_keypoints * 2, self.feature_dim)
|
||||
self.pool = SpatialSoftmax(feat_map_shape, num_kp=config.spatial_softmax_num_keypoints)
|
||||
self.feature_dim = config.spatial_softmax_num_keypoints * 2
|
||||
self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim)
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
|
@ -403,7 +397,7 @@ def _replace_submodules(
|
|||
return root_module
|
||||
|
||||
|
||||
class _SinusoidalPosEmb(nn.Module):
|
||||
class DiffusionSinusoidalPosEmb(nn.Module):
|
||||
"""1D sinusoidal positional embeddings as in Attention is All You Need."""
|
||||
|
||||
def __init__(self, dim: int):
|
||||
|
@ -420,7 +414,7 @@ class _SinusoidalPosEmb(nn.Module):
|
|||
return emb
|
||||
|
||||
|
||||
class _Conv1dBlock(nn.Module):
|
||||
class DiffusionConv1dBlock(nn.Module):
|
||||
"""Conv1d --> GroupNorm --> Mish"""
|
||||
|
||||
def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
|
||||
|
@ -436,40 +430,40 @@ class _Conv1dBlock(nn.Module):
|
|||
return self.block(x)
|
||||
|
||||
|
||||
class _ConditionalUnet1D(nn.Module):
|
||||
class DiffusionConditionalUnet1d(nn.Module):
|
||||
"""A 1D convolutional UNet with FiLM modulation for conditioning.
|
||||
|
||||
Note: this removes local conditioning as compared to the original diffusion policy code.
|
||||
"""
|
||||
|
||||
def __init__(self, cfg: DiffusionConfig, global_cond_dim: int):
|
||||
def __init__(self, config: DiffusionConfig, global_cond_dim: int):
|
||||
super().__init__()
|
||||
|
||||
self.cfg = cfg
|
||||
self.config = config
|
||||
|
||||
# Encoder for the diffusion timestep.
|
||||
self.diffusion_step_encoder = nn.Sequential(
|
||||
_SinusoidalPosEmb(cfg.diffusion_step_embed_dim),
|
||||
nn.Linear(cfg.diffusion_step_embed_dim, cfg.diffusion_step_embed_dim * 4),
|
||||
DiffusionSinusoidalPosEmb(config.diffusion_step_embed_dim),
|
||||
nn.Linear(config.diffusion_step_embed_dim, config.diffusion_step_embed_dim * 4),
|
||||
nn.Mish(),
|
||||
nn.Linear(cfg.diffusion_step_embed_dim * 4, cfg.diffusion_step_embed_dim),
|
||||
nn.Linear(config.diffusion_step_embed_dim * 4, config.diffusion_step_embed_dim),
|
||||
)
|
||||
|
||||
# The FiLM conditioning dimension.
|
||||
cond_dim = cfg.diffusion_step_embed_dim + global_cond_dim
|
||||
cond_dim = config.diffusion_step_embed_dim + global_cond_dim
|
||||
|
||||
# In channels / out channels for each downsampling block in the Unet's encoder. For the decoder, we
|
||||
# just reverse these.
|
||||
in_out = [(cfg.output_shapes["action"][0], cfg.down_dims[0])] + list(
|
||||
zip(cfg.down_dims[:-1], cfg.down_dims[1:], strict=True)
|
||||
in_out = [(config.output_shapes["action"][0], config.down_dims[0])] + list(
|
||||
zip(config.down_dims[:-1], config.down_dims[1:], strict=True)
|
||||
)
|
||||
|
||||
# Unet encoder.
|
||||
common_res_block_kwargs = {
|
||||
"cond_dim": cond_dim,
|
||||
"kernel_size": cfg.kernel_size,
|
||||
"n_groups": cfg.n_groups,
|
||||
"use_film_scale_modulation": cfg.use_film_scale_modulation,
|
||||
"kernel_size": config.kernel_size,
|
||||
"n_groups": config.n_groups,
|
||||
"use_film_scale_modulation": config.use_film_scale_modulation,
|
||||
}
|
||||
self.down_modules = nn.ModuleList([])
|
||||
for ind, (dim_in, dim_out) in enumerate(in_out):
|
||||
|
@ -477,8 +471,8 @@ class _ConditionalUnet1D(nn.Module):
|
|||
self.down_modules.append(
|
||||
nn.ModuleList(
|
||||
[
|
||||
_ConditionalResidualBlock1D(dim_in, dim_out, **common_res_block_kwargs),
|
||||
_ConditionalResidualBlock1D(dim_out, dim_out, **common_res_block_kwargs),
|
||||
DiffusionConditionalResidualBlock1d(dim_in, dim_out, **common_res_block_kwargs),
|
||||
DiffusionConditionalResidualBlock1d(dim_out, dim_out, **common_res_block_kwargs),
|
||||
# Downsample as long as it is not the last block.
|
||||
nn.Conv1d(dim_out, dim_out, 3, 2, 1) if not is_last else nn.Identity(),
|
||||
]
|
||||
|
@ -488,8 +482,12 @@ class _ConditionalUnet1D(nn.Module):
|
|||
# Processing in the middle of the auto-encoder.
|
||||
self.mid_modules = nn.ModuleList(
|
||||
[
|
||||
_ConditionalResidualBlock1D(cfg.down_dims[-1], cfg.down_dims[-1], **common_res_block_kwargs),
|
||||
_ConditionalResidualBlock1D(cfg.down_dims[-1], cfg.down_dims[-1], **common_res_block_kwargs),
|
||||
DiffusionConditionalResidualBlock1d(
|
||||
config.down_dims[-1], config.down_dims[-1], **common_res_block_kwargs
|
||||
),
|
||||
DiffusionConditionalResidualBlock1d(
|
||||
config.down_dims[-1], config.down_dims[-1], **common_res_block_kwargs
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
|
@ -501,8 +499,8 @@ class _ConditionalUnet1D(nn.Module):
|
|||
nn.ModuleList(
|
||||
[
|
||||
# dim_in * 2, because it takes the encoder's skip connection as well
|
||||
_ConditionalResidualBlock1D(dim_in * 2, dim_out, **common_res_block_kwargs),
|
||||
_ConditionalResidualBlock1D(dim_out, dim_out, **common_res_block_kwargs),
|
||||
DiffusionConditionalResidualBlock1d(dim_in * 2, dim_out, **common_res_block_kwargs),
|
||||
DiffusionConditionalResidualBlock1d(dim_out, dim_out, **common_res_block_kwargs),
|
||||
# Upsample as long as it is not the last block.
|
||||
nn.ConvTranspose1d(dim_out, dim_out, 4, 2, 1) if not is_last else nn.Identity(),
|
||||
]
|
||||
|
@ -510,8 +508,8 @@ class _ConditionalUnet1D(nn.Module):
|
|||
)
|
||||
|
||||
self.final_conv = nn.Sequential(
|
||||
_Conv1dBlock(cfg.down_dims[0], cfg.down_dims[0], kernel_size=cfg.kernel_size),
|
||||
nn.Conv1d(cfg.down_dims[0], cfg.output_shapes["action"][0], 1),
|
||||
DiffusionConv1dBlock(config.down_dims[0], config.down_dims[0], kernel_size=config.kernel_size),
|
||||
nn.Conv1d(config.down_dims[0], config.output_shapes["action"][0], 1),
|
||||
)
|
||||
|
||||
def forward(self, x: Tensor, timestep: Tensor | int, global_cond=None) -> Tensor:
|
||||
|
@ -559,7 +557,7 @@ class _ConditionalUnet1D(nn.Module):
|
|||
return x
|
||||
|
||||
|
||||
class _ConditionalResidualBlock1D(nn.Module):
|
||||
class DiffusionConditionalResidualBlock1d(nn.Module):
|
||||
"""ResNet style 1D convolutional block with FiLM modulation for conditioning."""
|
||||
|
||||
def __init__(
|
||||
|
@ -578,13 +576,13 @@ class _ConditionalResidualBlock1D(nn.Module):
|
|||
self.use_film_scale_modulation = use_film_scale_modulation
|
||||
self.out_channels = out_channels
|
||||
|
||||
self.conv1 = _Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups)
|
||||
self.conv1 = DiffusionConv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups)
|
||||
|
||||
# FiLM modulation (https://arxiv.org/abs/1709.07871) outputs per-channel bias and (maybe) scale.
|
||||
cond_channels = out_channels * 2 if use_film_scale_modulation else out_channels
|
||||
self.cond_encoder = nn.Sequential(nn.Mish(), nn.Linear(cond_dim, cond_channels))
|
||||
|
||||
self.conv2 = _Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups)
|
||||
self.conv2 = DiffusionConv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups)
|
||||
|
||||
# A final convolution for dimension matching the residual (if needed).
|
||||
self.residual_conv = (
|
||||
|
@ -617,18 +615,18 @@ class _ConditionalResidualBlock1D(nn.Module):
|
|||
return out
|
||||
|
||||
|
||||
class _EMA:
|
||||
class DiffusionEMA:
|
||||
"""
|
||||
Exponential Moving Average of models weights
|
||||
"""
|
||||
|
||||
def __init__(self, cfg: DiffusionConfig, model: nn.Module):
|
||||
def __init__(self, config: DiffusionConfig, model: nn.Module):
|
||||
"""
|
||||
@crowsonkb's notes on EMA Warmup:
|
||||
If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
|
||||
to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
|
||||
gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
|
||||
at 215.4k steps).
|
||||
If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models
|
||||
you plan to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999
|
||||
at 1M steps), gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999
|
||||
at 10K steps, 0.9999 at 215.4k steps).
|
||||
Args:
|
||||
inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
|
||||
power (float): Exponential factor of EMA warmup. Default: 2/3.
|
||||
|
@ -639,11 +637,11 @@ class _EMA:
|
|||
self.averaged_model.eval()
|
||||
self.averaged_model.requires_grad_(False)
|
||||
|
||||
self.update_after_step = cfg.ema_update_after_step
|
||||
self.inv_gamma = cfg.ema_inv_gamma
|
||||
self.power = cfg.ema_power
|
||||
self.min_alpha = cfg.ema_min_alpha
|
||||
self.max_alpha = cfg.ema_max_alpha
|
||||
self.update_after_step = config.ema_update_after_step
|
||||
self.inv_gamma = config.ema_inv_gamma
|
||||
self.power = config.ema_power
|
||||
self.min_alpha = config.ema_min_alpha
|
||||
self.max_alpha = config.ema_max_alpha
|
||||
|
||||
self.alpha = 0.0
|
||||
self.optimization_step = 0
|
||||
|
|
|
@ -2,6 +2,7 @@ import inspect
|
|||
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
|
||||
from lerobot.common.policies.policy_protocol import Policy
|
||||
from lerobot.common.utils.utils import get_safe_torch_device
|
||||
|
||||
|
||||
|
@ -20,42 +21,53 @@ def _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg):
|
|||
return policy_cfg
|
||||
|
||||
|
||||
def make_policy(hydra_cfg: DictConfig, dataset_stats=None):
|
||||
if hydra_cfg.policy.name == "tdmpc":
|
||||
from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
|
||||
def get_policy_and_config_classes(name: str) -> tuple[Policy, object]:
|
||||
"""Get the policy's class and config class given a name (matching the policy class' `name` attribute)."""
|
||||
if name == "tdmpc":
|
||||
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from lerobot.common.policies.tdmpc.modeling_tdmpc import TDMPCPolicy
|
||||
|
||||
policy = TDMPCPolicy(
|
||||
hydra_cfg.policy,
|
||||
n_obs_steps=hydra_cfg.n_obs_steps,
|
||||
n_action_steps=hydra_cfg.n_action_steps,
|
||||
device=hydra_cfg.device,
|
||||
)
|
||||
elif hydra_cfg.policy.name == "diffusion":
|
||||
return TDMPCPolicy, TDMPCConfig
|
||||
elif name == "diffusion":
|
||||
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||
|
||||
policy_cfg = _policy_cfg_from_hydra_cfg(DiffusionConfig, hydra_cfg)
|
||||
policy = DiffusionPolicy(policy_cfg, hydra_cfg.offline_steps, dataset_stats)
|
||||
policy.to(get_safe_torch_device(hydra_cfg.device))
|
||||
elif hydra_cfg.policy.name == "act":
|
||||
from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
|
||||
from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy
|
||||
return DiffusionPolicy, DiffusionConfig
|
||||
elif name == "act":
|
||||
from lerobot.common.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.common.policies.act.modeling_act import ACTPolicy
|
||||
|
||||
policy_cfg = _policy_cfg_from_hydra_cfg(ActionChunkingTransformerConfig, hydra_cfg)
|
||||
policy = ActionChunkingTransformerPolicy(policy_cfg, dataset_stats)
|
||||
policy.to(get_safe_torch_device(hydra_cfg.device))
|
||||
return ACTPolicy, ACTConfig
|
||||
else:
|
||||
raise ValueError(hydra_cfg.policy.name)
|
||||
raise NotImplementedError(f"Policy with name {name} is not implemented.")
|
||||
|
||||
if hydra_cfg.policy.pretrained_model_path:
|
||||
# TODO(rcadene): hack for old pretrained models from fowm
|
||||
if hydra_cfg.policy.name == "tdmpc" and "fowm" in hydra_cfg.policy.pretrained_model_path:
|
||||
if "offline" in hydra_cfg.policy.pretrained_model_path:
|
||||
policy.step[0] = 25000
|
||||
elif "final" in hydra_cfg.policy.pretrained_model_path:
|
||||
policy.step[0] = 100000
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
policy.load(hydra_cfg.policy.pretrained_model_path)
|
||||
|
||||
def make_policy(
|
||||
hydra_cfg: DictConfig, pretrained_policy_name_or_path: str | None = None, dataset_stats=None
|
||||
) -> Policy:
|
||||
"""Make an instance of a policy class.
|
||||
|
||||
Args:
|
||||
hydra_cfg: A parsed Hydra configuration (see scripts). If `pretrained_policy_name_or_path` is
|
||||
provided, only `hydra_cfg.policy.name` is used while everything else is ignored.
|
||||
pretrained_policy_name_or_path: Either the repo ID of a model hosted on the Hub or a path to a
|
||||
directory containing weights saved using `Policy.save_pretrained`. Note that providing this
|
||||
argument overrides everything in `hydra_cfg.policy` apart from `hydra_cfg.policy.name`.
|
||||
dataset_stats: Dataset statistics to use for (un)normalization of inputs/outputs in the policy. Must
|
||||
be provided when initializing a new policy, and must not be provided when loading a pretrained
|
||||
policy. Therefore, this argument is mutually exclusive with `pretrained_policy_name_or_path`.
|
||||
"""
|
||||
if not (pretrained_policy_name_or_path is None) ^ (dataset_stats is None):
|
||||
raise ValueError("Only one of `pretrained_policy_name_or_path` and `dataset_stats` may be provided.")
|
||||
|
||||
policy_cls, policy_cfg_class = get_policy_and_config_classes(hydra_cfg.policy.name)
|
||||
|
||||
if pretrained_policy_name_or_path is None:
|
||||
policy_cfg = _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg)
|
||||
policy = policy_cls(policy_cfg, dataset_stats)
|
||||
else:
|
||||
policy = policy_cls.from_pretrained(pretrained_policy_name_or_path)
|
||||
|
||||
policy.to(get_safe_torch_device(hydra_cfg.device))
|
||||
|
||||
return policy
|
||||
|
|
|
@ -57,17 +57,28 @@ def create_stats_buffers(
|
|||
)
|
||||
|
||||
if stats is not None:
|
||||
# Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
|
||||
# tensors anywhere (for example, when we use the same stats for normalization and
|
||||
# unnormalization). See the logic here
|
||||
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
|
||||
if mode == "mean_std":
|
||||
buffer["mean"].data = stats[key]["mean"]
|
||||
buffer["std"].data = stats[key]["std"]
|
||||
buffer["mean"].data = stats[key]["mean"].clone()
|
||||
buffer["std"].data = stats[key]["std"].clone()
|
||||
elif mode == "min_max":
|
||||
buffer["min"].data = stats[key]["min"]
|
||||
buffer["max"].data = stats[key]["max"]
|
||||
buffer["min"].data = stats[key]["min"].clone()
|
||||
buffer["max"].data = stats[key]["max"].clone()
|
||||
|
||||
stats_buffers[key] = buffer
|
||||
return stats_buffers
|
||||
|
||||
|
||||
def _no_stats_error_str(name: str) -> str:
|
||||
return (
|
||||
f"`{name}` is infinity. You should either initialize with `stats` as an argument, or use a "
|
||||
"pretrained model."
|
||||
)
|
||||
|
||||
|
||||
class Normalize(nn.Module):
|
||||
"""Normalizes data (e.g. "observation.image") for more stable and faster convergence during training."""
|
||||
|
||||
|
@ -99,7 +110,6 @@ class Normalize(nn.Module):
|
|||
self.shapes = shapes
|
||||
self.modes = modes
|
||||
self.stats = stats
|
||||
# `self.buffer_observation_state["mean"]` contains `torch.tensor(state_dim)`
|
||||
stats_buffers = create_stats_buffers(shapes, modes, stats)
|
||||
for key, buffer in stats_buffers.items():
|
||||
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
|
||||
|
@ -113,26 +123,14 @@ class Normalize(nn.Module):
|
|||
if mode == "mean_std":
|
||||
mean = buffer["mean"]
|
||||
std = buffer["std"]
|
||||
assert not torch.isinf(mean).any(), (
|
||||
"`mean` is infinity. You forgot to initialize with `stats` as argument, or called "
|
||||
"`policy.load_state_dict`."
|
||||
)
|
||||
assert not torch.isinf(std).any(), (
|
||||
"`std` is infinity. You forgot to initialize with `stats` as argument, or called "
|
||||
"`policy.load_state_dict`."
|
||||
)
|
||||
assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
|
||||
assert not torch.isinf(std).any(), _no_stats_error_str("std")
|
||||
batch[key] = (batch[key] - mean) / (std + 1e-8)
|
||||
elif mode == "min_max":
|
||||
min = buffer["min"]
|
||||
max = buffer["max"]
|
||||
assert not torch.isinf(min).any(), (
|
||||
"`min` is infinity. You forgot to initialize with `stats` as argument, or called "
|
||||
"`policy.load_state_dict`."
|
||||
)
|
||||
assert not torch.isinf(max).any(), (
|
||||
"`max` is infinity. You forgot to initialize with `stats` as argument, or called "
|
||||
"`policy.load_state_dict`."
|
||||
)
|
||||
assert not torch.isinf(min).any(), _no_stats_error_str("min")
|
||||
assert not torch.isinf(max).any(), _no_stats_error_str("max")
|
||||
# normalize to [0,1]
|
||||
batch[key] = (batch[key] - min) / (max - min)
|
||||
# normalize to [-1, 1]
|
||||
|
@ -190,26 +188,14 @@ class Unnormalize(nn.Module):
|
|||
if mode == "mean_std":
|
||||
mean = buffer["mean"]
|
||||
std = buffer["std"]
|
||||
assert not torch.isinf(mean).any(), (
|
||||
"`mean` is infinity. You forgot to initialize with `stats` as argument, or called "
|
||||
"`policy.load_state_dict`."
|
||||
)
|
||||
assert not torch.isinf(std).any(), (
|
||||
"`std` is infinity. You forgot to initialize with `stats` as argument, or called "
|
||||
"`policy.load_state_dict`."
|
||||
)
|
||||
assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
|
||||
assert not torch.isinf(std).any(), _no_stats_error_str("std")
|
||||
batch[key] = batch[key] * std + mean
|
||||
elif mode == "min_max":
|
||||
min = buffer["min"]
|
||||
max = buffer["max"]
|
||||
assert not torch.isinf(min).any(), (
|
||||
"`min` is infinity. You forgot to initialize with `stats` as argument, or called "
|
||||
"`policy.load_state_dict`."
|
||||
)
|
||||
assert not torch.isinf(max).any(), (
|
||||
"`max` is infinity. You forgot to initialize with `stats` as argument, or called "
|
||||
"`policy.load_state_dict`."
|
||||
)
|
||||
assert not torch.isinf(min).any(), _no_stats_error_str("min")
|
||||
assert not torch.isinf(max).any(), _no_stats_error_str("max")
|
||||
batch[key] = (batch[key] + 1) / 2
|
||||
batch[key] = batch[key] * (max - min) + min
|
||||
else:
|
||||
|
|
|
@ -14,10 +14,21 @@ from torch import Tensor
|
|||
|
||||
@runtime_checkable
|
||||
class Policy(Protocol):
|
||||
"""The required interface for implementing a policy."""
|
||||
"""The required interface for implementing a policy.
|
||||
|
||||
We also expect all policies to subclass torch.nn.Module and PyTorchModelHubMixin.
|
||||
"""
|
||||
|
||||
name: str
|
||||
|
||||
def __init__(self, cfg, dataset_stats: dict[str, dict[str, Tensor]] | None = None):
|
||||
"""
|
||||
Args:
|
||||
cfg: Policy configuration class instance or None, in which case the default instantiation of the
|
||||
configuration class is used.
|
||||
dataset_stats: Dataset statistics to be used for normalization.
|
||||
"""
|
||||
|
||||
def reset(self):
|
||||
"""To be called whenever the environment is reset.
|
||||
|
||||
|
@ -36,3 +47,13 @@ class Policy(Protocol):
|
|||
When the model uses a history of observations, or outputs a sequence of actions, this method deals
|
||||
with caching.
|
||||
"""
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class PolicyWithUpdate(Policy, Protocol):
|
||||
def update(self):
|
||||
"""An update method that is to be called after a training optimization step.
|
||||
|
||||
Implements an additional updates the model parameters may need (for example, doing an EMA step for a
|
||||
target model, or incrementing an internal buffer).
|
||||
"""
|
||||
|
|
|
@ -0,0 +1,150 @@
|
|||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass
|
||||
class TDMPCConfig:
|
||||
"""Configuration class for TDMPCPolicy.
|
||||
|
||||
Defaults are configured for training with xarm_lift_medium_replay providing proprioceptive and single
|
||||
camera observations.
|
||||
|
||||
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
|
||||
Those are: `input_shapes`, `output_shapes`, and perhaps `max_random_shift`.
|
||||
|
||||
Args:
|
||||
n_action_repeats: The number of times to repeat the action returned by the planning. (hint: Google
|
||||
action repeats in Q-learning or ask your favorite chatbot)
|
||||
horizon: Horizon for model predictive control.
|
||||
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
|
||||
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
|
||||
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
|
||||
[-1, 1] range. Note that here this defaults to None meaning inputs are not normalized. This is to
|
||||
match the original implementation.
|
||||
output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the
|
||||
original scale. Note that this is also used for normalizing the training targets. NOTE: Clipping
|
||||
to [-1, +1] is used during MPPI/CEM. Therefore, it is recommended that you stick with "min_max"
|
||||
normalization mode here.
|
||||
image_encoder_hidden_dim: Number of channels for the convolutional layers used for image encoding.
|
||||
state_encoder_hidden_dim: Hidden dimension for MLP used for state vector encoding.
|
||||
latent_dim: Observation's latent embedding dimension.
|
||||
q_ensemble_size: Number of Q function estimators to use in an ensemble for uncertainty estimation.
|
||||
mlp_dim: Hidden dimension of MLPs used for modelling the dynamics encoder, reward function, policy
|
||||
(π), Q ensemble, and V.
|
||||
discount: Discount factor (γ) to use for the reinforcement learning formalism.
|
||||
use_mpc: Whether to use model predictive control. The alternative is to just sample the policy model
|
||||
(π) for each step.
|
||||
cem_iterations: Number of iterations for the MPPI/CEM loop in MPC.
|
||||
max_std: Maximum standard deviation for actions sampled from the gaussian PDF in CEM.
|
||||
min_std: Minimum standard deviation for noise applied to actions sampled from the policy model (π).
|
||||
Doubles up as the minimum standard deviation for actions sampled from the gaussian PDF in CEM.
|
||||
n_gaussian_samples: Number of samples to draw from the gaussian distribution every CEM iteration. Must
|
||||
be non-zero.
|
||||
n_pi_samples: Number of samples to draw from the policy / world model rollout every CEM iteration. Can
|
||||
be zero.
|
||||
uncertainty_regularizer_coeff: Coefficient for the uncertainty regularization used when estimating
|
||||
trajectory values (this is the λ coeffiecient in eqn 4 of FOWM).
|
||||
n_elites: The number of elite samples to use for updating the gaussian parameters every CEM iteration.
|
||||
elite_weighting_temperature: The temperature to use for softmax weighting (by trajectory value) of the
|
||||
elites, when updating the gaussian parameters for CEM.
|
||||
gaussian_mean_momentum: Momentum (α) used for EMA updates of the mean parameter μ of the gaussian
|
||||
paramters optimized in CEM. Updates are calculated as μ⁻ ← αμ⁻ + (1-α)μ.
|
||||
max_random_shift_ratio: Maximum random shift (as a proportion of the image size) to apply to the
|
||||
image(s) (in units of pixels) for training-time augmentation. If set to 0, no such augmentation
|
||||
is applied. Note that the input images are assumed to be square for this augmentation.
|
||||
reward_coeff: Loss weighting coefficient for the reward regression loss.
|
||||
expectile_weight: Weighting (τ) used in expectile regression for the state value function (V).
|
||||
v_pred < v_target is weighted by τ and v_pred >= v_target is weighted by (1-τ). τ is expected to
|
||||
be in [0, 1]. Setting τ closer to 1 results in a more "optimistic" V. This is sensible to do
|
||||
because v_target is obtained by evaluating the learned state-action value functions (Q) with
|
||||
in-sample actions that may not be always optimal.
|
||||
value_coeff: Loss weighting coefficient for both the state-action value (Q) TD loss, and the state
|
||||
value (V) expectile regression loss.
|
||||
consistency_coeff: Loss weighting coefficient for the consistency loss.
|
||||
advantage_scaling: A factor by which the advantages are scaled prior to exponentiation for advantage
|
||||
weighted regression of the policy (π) estimator parameters. Note that the exponentiated advantages
|
||||
are clamped at 100.0.
|
||||
pi_coeff: Loss weighting coefficient for the action regression loss.
|
||||
temporal_decay_coeff: Exponential decay coefficient for decaying the loss coefficient for future time-
|
||||
steps. Hint: each loss computation involves `horizon` steps worth of actions starting from the
|
||||
current time step.
|
||||
target_model_momentum: Momentum (α) used for EMA updates of the target models. Updates are calculated
|
||||
as ϕ ← αϕ + (1-α)θ where ϕ are the parameters of the target model and θ are the parameters of the
|
||||
model being trained.
|
||||
"""
|
||||
|
||||
# Input / output structure.
|
||||
n_action_repeats: int = 2
|
||||
horizon: int = 5
|
||||
|
||||
input_shapes: dict[str, list[int]] = field(
|
||||
default_factory=lambda: {
|
||||
"observation.image": [3, 84, 84],
|
||||
"observation.state": [4],
|
||||
}
|
||||
)
|
||||
output_shapes: dict[str, list[int]] = field(
|
||||
default_factory=lambda: {
|
||||
"action": [4],
|
||||
}
|
||||
)
|
||||
|
||||
# Normalization / Unnormalization
|
||||
input_normalization_modes: dict[str, str] | None = None
|
||||
output_normalization_modes: dict[str, str] = field(
|
||||
default_factory=lambda: {"action": "min_max"},
|
||||
)
|
||||
|
||||
# Architecture / modeling.
|
||||
# Neural networks.
|
||||
image_encoder_hidden_dim: int = 32
|
||||
state_encoder_hidden_dim: int = 256
|
||||
latent_dim: int = 50
|
||||
q_ensemble_size: int = 5
|
||||
mlp_dim: int = 512
|
||||
# Reinforcement learning.
|
||||
discount: float = 0.9
|
||||
|
||||
# Inference.
|
||||
use_mpc: bool = True
|
||||
cem_iterations: int = 6
|
||||
max_std: float = 2.0
|
||||
min_std: float = 0.05
|
||||
n_gaussian_samples: int = 512
|
||||
n_pi_samples: int = 51
|
||||
uncertainty_regularizer_coeff: float = 1.0
|
||||
n_elites: int = 50
|
||||
elite_weighting_temperature: float = 0.5
|
||||
gaussian_mean_momentum: float = 0.1
|
||||
|
||||
# Training and loss computation.
|
||||
max_random_shift_ratio: float = 0.0476
|
||||
# Loss coefficients.
|
||||
reward_coeff: float = 0.5
|
||||
expectile_weight: float = 0.9
|
||||
value_coeff: float = 0.1
|
||||
consistency_coeff: float = 20.0
|
||||
advantage_scaling: float = 3.0
|
||||
pi_coeff: float = 0.5
|
||||
temporal_decay_coeff: float = 0.5
|
||||
# Target model.
|
||||
target_model_momentum: float = 0.995
|
||||
|
||||
def __post_init__(self):
|
||||
"""Input validation (not exhaustive)."""
|
||||
if self.input_shapes["observation.image"][-2] != self.input_shapes["observation.image"][-1]:
|
||||
# TODO(alexander-soare): This limitation is solely because of code in the random shift
|
||||
# augmentation. It should be able to be removed.
|
||||
raise ValueError(
|
||||
"Only square images are handled now. Got image shape "
|
||||
f"{self.input_shapes['observation.image']}."
|
||||
)
|
||||
if self.n_gaussian_samples <= 0:
|
||||
raise ValueError(
|
||||
f"The number of guassian samples for CEM should be non-zero. Got `{self.n_gaussian_samples=}`"
|
||||
)
|
||||
if self.output_normalization_modes != {"action": "min_max"}:
|
||||
raise ValueError(
|
||||
"TD-MPC assumes the action space dimensions to all be in [-1, 1]. Therefore it is strongly "
|
||||
f"advised that you stick with the default. See {self.__class__.__name__} docstring for more "
|
||||
"information."
|
||||
)
|
|
@ -1,576 +0,0 @@
|
|||
import os
|
||||
import pickle
|
||||
import re
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import distributions as pyd
|
||||
from torch.distributions.utils import _standard_normal
|
||||
|
||||
DEFAULT_ACT_FN = nn.Mish()
|
||||
|
||||
|
||||
def __REDUCE__(b): # noqa: N802, N807
|
||||
return "mean" if b else "none"
|
||||
|
||||
|
||||
def l1(pred, target, reduce=False):
|
||||
"""Computes the L1-loss between predictions and targets."""
|
||||
return F.l1_loss(pred, target, reduction=__REDUCE__(reduce))
|
||||
|
||||
|
||||
def mse(pred, target, reduce=False):
|
||||
"""Computes the MSE loss between predictions and targets."""
|
||||
return F.mse_loss(pred, target, reduction=__REDUCE__(reduce))
|
||||
|
||||
|
||||
def l2_expectile(diff, expectile=0.7, reduce=False):
|
||||
weight = torch.where(diff > 0, expectile, (1 - expectile))
|
||||
loss = weight * (diff**2)
|
||||
reduction = __REDUCE__(reduce)
|
||||
if reduction == "mean":
|
||||
return torch.mean(loss)
|
||||
elif reduction == "sum":
|
||||
return torch.sum(loss)
|
||||
return loss
|
||||
|
||||
|
||||
def _get_out_shape(in_shape, layers):
|
||||
"""Utility function. Returns the output shape of a network for a given input shape."""
|
||||
x = torch.randn(*in_shape).unsqueeze(0)
|
||||
return (nn.Sequential(*layers) if isinstance(layers, list) else layers)(x).squeeze(0).shape
|
||||
|
||||
|
||||
def gaussian_logprob(eps, log_std):
|
||||
"""Compute Gaussian log probability."""
|
||||
residual = (-0.5 * eps.pow(2) - log_std).sum(-1, keepdim=True)
|
||||
return residual - 0.5 * np.log(2 * np.pi) * eps.size(-1)
|
||||
|
||||
|
||||
def squash(mu, pi, log_pi):
|
||||
"""Apply squashing function."""
|
||||
mu = torch.tanh(mu)
|
||||
pi = torch.tanh(pi)
|
||||
log_pi -= torch.log(F.relu(1 - pi.pow(2)) + 1e-6).sum(-1, keepdim=True)
|
||||
return mu, pi, log_pi
|
||||
|
||||
|
||||
def orthogonal_init(m):
|
||||
"""Orthogonal layer initialization."""
|
||||
if isinstance(m, nn.Linear):
|
||||
nn.init.orthogonal_(m.weight.data)
|
||||
if m.bias is not None:
|
||||
nn.init.zeros_(m.bias)
|
||||
elif isinstance(m, nn.Conv2d):
|
||||
gain = nn.init.calculate_gain("relu")
|
||||
nn.init.orthogonal_(m.weight.data, gain)
|
||||
if m.bias is not None:
|
||||
nn.init.zeros_(m.bias)
|
||||
|
||||
|
||||
def ema(m, m_target, tau):
|
||||
"""Update slow-moving average of online network (target network) at rate tau."""
|
||||
with torch.no_grad():
|
||||
# TODO(rcadene, aliberts): issue with strict=False
|
||||
# for p, p_target in zip(m.parameters(), m_target.parameters(), strict=False):
|
||||
# p_target.data.lerp_(p.data, tau)
|
||||
m_params_iter = iter(m.parameters())
|
||||
m_target_params_iter = iter(m_target.parameters())
|
||||
|
||||
while True:
|
||||
try:
|
||||
p = next(m_params_iter)
|
||||
p_target = next(m_target_params_iter)
|
||||
p_target.data.lerp_(p.data, tau)
|
||||
except StopIteration:
|
||||
# If any iterator is exhausted, exit the loop
|
||||
break
|
||||
|
||||
|
||||
def set_requires_grad(net, value):
|
||||
"""Enable/disable gradients for a given (sub)network."""
|
||||
for param in net.parameters():
|
||||
param.requires_grad_(value)
|
||||
|
||||
|
||||
class TruncatedNormal(pyd.Normal):
|
||||
"""Utility class implementing the truncated normal distribution."""
|
||||
|
||||
default_sample_shape = torch.Size()
|
||||
|
||||
def __init__(self, loc, scale, low=-1.0, high=1.0, eps=1e-6):
|
||||
super().__init__(loc, scale, validate_args=False)
|
||||
self.low = low
|
||||
self.high = high
|
||||
self.eps = eps
|
||||
|
||||
def _clamp(self, x):
|
||||
clamped_x = torch.clamp(x, self.low + self.eps, self.high - self.eps)
|
||||
x = x - x.detach() + clamped_x.detach()
|
||||
return x
|
||||
|
||||
def sample(self, clip=None, sample_shape=default_sample_shape):
|
||||
shape = self._extended_shape(sample_shape)
|
||||
eps = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
|
||||
eps *= self.scale
|
||||
if clip is not None:
|
||||
eps = torch.clamp(eps, -clip, clip)
|
||||
x = self.loc + eps
|
||||
return self._clamp(x)
|
||||
|
||||
|
||||
class NormalizeImg(nn.Module):
|
||||
"""Normalizes pixel observations to [0,1) range."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x):
|
||||
return x.div(255.0)
|
||||
|
||||
|
||||
class Flatten(nn.Module):
|
||||
"""Flattens its input to a (batched) vector."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x):
|
||||
return x.view(x.size(0), -1)
|
||||
|
||||
|
||||
def enc(cfg):
|
||||
obs_shape = {
|
||||
"rgb": (3, cfg.img_size, cfg.img_size),
|
||||
"state": (cfg.state_dim,),
|
||||
}
|
||||
|
||||
"""Returns a TOLD encoder."""
|
||||
pixels_enc_layers, state_enc_layers = None, None
|
||||
if cfg.modality in {"pixels", "all"}:
|
||||
C = int(3 * cfg.frame_stack) # noqa: N806
|
||||
pixels_enc_layers = [
|
||||
NormalizeImg(),
|
||||
nn.Conv2d(C, cfg.num_channels, 7, stride=2),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(cfg.num_channels, cfg.num_channels, 5, stride=2),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(cfg.num_channels, cfg.num_channels, 3, stride=2),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(cfg.num_channels, cfg.num_channels, 3, stride=2),
|
||||
nn.ReLU(),
|
||||
]
|
||||
out_shape = _get_out_shape((C, cfg.img_size, cfg.img_size), pixels_enc_layers)
|
||||
pixels_enc_layers.extend(
|
||||
[
|
||||
Flatten(),
|
||||
nn.Linear(np.prod(out_shape), cfg.latent_dim),
|
||||
nn.LayerNorm(cfg.latent_dim),
|
||||
nn.Sigmoid(),
|
||||
]
|
||||
)
|
||||
if cfg.modality == "pixels":
|
||||
return ConvExt(nn.Sequential(*pixels_enc_layers))
|
||||
if cfg.modality in {"state", "all"}:
|
||||
state_dim = obs_shape[0] if cfg.modality == "state" else obs_shape["state"][0]
|
||||
state_enc_layers = [
|
||||
nn.Linear(state_dim, cfg.enc_dim),
|
||||
nn.ELU(),
|
||||
nn.Linear(cfg.enc_dim, cfg.latent_dim),
|
||||
nn.LayerNorm(cfg.latent_dim),
|
||||
nn.Sigmoid(),
|
||||
]
|
||||
if cfg.modality == "state":
|
||||
return nn.Sequential(*state_enc_layers)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
encoders = {}
|
||||
for k in obs_shape:
|
||||
if k == "state":
|
||||
encoders[k] = nn.Sequential(*state_enc_layers)
|
||||
elif k.endswith("rgb"):
|
||||
encoders[k] = ConvExt(nn.Sequential(*pixels_enc_layers))
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return Multiplexer(nn.ModuleDict(encoders))
|
||||
|
||||
|
||||
def mlp(in_dim, mlp_dim, out_dim, act_fn=DEFAULT_ACT_FN):
|
||||
"""Returns an MLP."""
|
||||
if isinstance(mlp_dim, int):
|
||||
mlp_dim = [mlp_dim, mlp_dim]
|
||||
return nn.Sequential(
|
||||
nn.Linear(in_dim, mlp_dim[0]),
|
||||
nn.LayerNorm(mlp_dim[0]),
|
||||
act_fn,
|
||||
nn.Linear(mlp_dim[0], mlp_dim[1]),
|
||||
nn.LayerNorm(mlp_dim[1]),
|
||||
act_fn,
|
||||
nn.Linear(mlp_dim[1], out_dim),
|
||||
)
|
||||
|
||||
|
||||
def dynamics(in_dim, mlp_dim, out_dim, act_fn=DEFAULT_ACT_FN):
|
||||
"""Returns a dynamics network."""
|
||||
return nn.Sequential(
|
||||
mlp(in_dim, mlp_dim, out_dim, act_fn),
|
||||
nn.LayerNorm(out_dim),
|
||||
nn.Sigmoid(),
|
||||
)
|
||||
|
||||
|
||||
def q(cfg):
|
||||
action_dim = cfg.action_dim
|
||||
"""Returns a Q-function that uses Layer Normalization."""
|
||||
return nn.Sequential(
|
||||
nn.Linear(cfg.latent_dim + action_dim, cfg.mlp_dim),
|
||||
nn.LayerNorm(cfg.mlp_dim),
|
||||
nn.Tanh(),
|
||||
nn.Linear(cfg.mlp_dim, cfg.mlp_dim),
|
||||
nn.ELU(),
|
||||
nn.Linear(cfg.mlp_dim, 1),
|
||||
)
|
||||
|
||||
|
||||
def v(cfg):
|
||||
"""Returns a state value function that uses Layer Normalization."""
|
||||
return nn.Sequential(
|
||||
nn.Linear(cfg.latent_dim, cfg.mlp_dim),
|
||||
nn.LayerNorm(cfg.mlp_dim),
|
||||
nn.Tanh(),
|
||||
nn.Linear(cfg.mlp_dim, cfg.mlp_dim),
|
||||
nn.ELU(),
|
||||
nn.Linear(cfg.mlp_dim, 1),
|
||||
)
|
||||
|
||||
|
||||
def aug(cfg):
|
||||
obs_shape = {
|
||||
"rgb": (3, cfg.img_size, cfg.img_size),
|
||||
"state": (4,),
|
||||
}
|
||||
|
||||
"""Multiplex augmentation"""
|
||||
if cfg.modality == "state":
|
||||
return nn.Identity()
|
||||
elif cfg.modality == "pixels":
|
||||
return RandomShiftsAug(cfg)
|
||||
else:
|
||||
augs = {}
|
||||
for k in obs_shape:
|
||||
if k == "state":
|
||||
augs[k] = nn.Identity()
|
||||
elif k.endswith("rgb"):
|
||||
augs[k] = RandomShiftsAug(cfg)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return Multiplexer(nn.ModuleDict(augs))
|
||||
|
||||
|
||||
class ConvExt(nn.Module):
|
||||
"""Auxiliary conv net accommodating high-dim input"""
|
||||
|
||||
def __init__(self, conv):
|
||||
super().__init__()
|
||||
self.conv = conv
|
||||
|
||||
def forward(self, x):
|
||||
if x.ndim > 4:
|
||||
batch_shape = x.shape[:-3]
|
||||
out = self.conv(x.view(-1, *x.shape[-3:]))
|
||||
out = out.view(*batch_shape, *out.shape[1:])
|
||||
else:
|
||||
out = self.conv(x)
|
||||
return out
|
||||
|
||||
|
||||
class Multiplexer(nn.Module):
|
||||
"""Model multiplexer"""
|
||||
|
||||
def __init__(self, choices):
|
||||
super().__init__()
|
||||
self.choices = choices
|
||||
|
||||
def forward(self, x, key=None):
|
||||
if isinstance(x, dict):
|
||||
if key is not None:
|
||||
return self.choices[key](x)
|
||||
return {k: self.choices[k](_x) for k, _x in x.items()}
|
||||
return self.choices(x)
|
||||
|
||||
|
||||
class RandomShiftsAug(nn.Module):
|
||||
"""
|
||||
Random shift image augmentation.
|
||||
Adapted from https://github.com/facebookresearch/drqv2
|
||||
"""
|
||||
|
||||
def __init__(self, cfg):
|
||||
super().__init__()
|
||||
assert cfg.modality in {"pixels", "all"}
|
||||
self.pad = int(cfg.img_size / 21)
|
||||
|
||||
def forward(self, x):
|
||||
n, c, h, w = x.size()
|
||||
assert h == w
|
||||
padding = tuple([self.pad] * 4)
|
||||
x = F.pad(x, padding, "replicate")
|
||||
eps = 1.0 / (h + 2 * self.pad)
|
||||
arange = torch.linspace(
|
||||
-1.0 + eps,
|
||||
1.0 - eps,
|
||||
h + 2 * self.pad,
|
||||
device=x.device,
|
||||
dtype=torch.float32,
|
||||
)[:h]
|
||||
arange = arange.unsqueeze(0).repeat(h, 1).unsqueeze(2)
|
||||
base_grid = torch.cat([arange, arange.transpose(1, 0)], dim=2)
|
||||
base_grid = base_grid.unsqueeze(0).repeat(n, 1, 1, 1)
|
||||
shift = torch.randint(
|
||||
0,
|
||||
2 * self.pad + 1,
|
||||
size=(n, 1, 1, 2),
|
||||
device=x.device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
shift *= 2.0 / (h + 2 * self.pad)
|
||||
grid = base_grid + shift
|
||||
return F.grid_sample(x, grid, padding_mode="zeros", align_corners=False)
|
||||
|
||||
|
||||
# TODO(aliberts): remove class
|
||||
# class Episode:
|
||||
# """Storage object for a single episode."""
|
||||
|
||||
# def __init__(self, cfg, init_obs):
|
||||
# action_dim = cfg.action_dim
|
||||
|
||||
# self.cfg = cfg
|
||||
# self.device = torch.device(cfg.buffer_device)
|
||||
# if cfg.modality in {"pixels", "state"}:
|
||||
# dtype = torch.float32 if cfg.modality == "state" else torch.uint8
|
||||
# self.obses = torch.empty(
|
||||
# (cfg.episode_length + 1, *init_obs.shape),
|
||||
# dtype=dtype,
|
||||
# device=self.device,
|
||||
# )
|
||||
# self.obses[0] = torch.tensor(init_obs, dtype=dtype, device=self.device)
|
||||
# elif cfg.modality == "all":
|
||||
# self.obses = {}
|
||||
# for k, v in init_obs.items():
|
||||
# assert k in {"rgb", "state"}
|
||||
# dtype = torch.float32 if k == "state" else torch.uint8
|
||||
# self.obses[k] = torch.empty(
|
||||
# (cfg.episode_length + 1, *v.shape), dtype=dtype, device=self.device
|
||||
# )
|
||||
# self.obses[k][0] = torch.tensor(v, dtype=dtype, device=self.device)
|
||||
# else:
|
||||
# raise ValueError
|
||||
# self.actions = torch.empty((cfg.episode_length, action_dim), dtype=torch.float32, device=self.device)
|
||||
# self.rewards = torch.empty((cfg.episode_length,), dtype=torch.float32, device=self.device)
|
||||
# self.dones = torch.empty((cfg.episode_length,), dtype=torch.bool, device=self.device)
|
||||
# self.masks = torch.empty((cfg.episode_length,), dtype=torch.float32, device=self.device)
|
||||
# self.cumulative_reward = 0
|
||||
# self.done = False
|
||||
# self.success = False
|
||||
# self._idx = 0
|
||||
|
||||
# def __len__(self):
|
||||
# return self._idx
|
||||
|
||||
# @classmethod
|
||||
# def from_trajectory(cls, cfg, obses, actions, rewards, dones=None, masks=None):
|
||||
# """Constructs an episode from a trajectory."""
|
||||
|
||||
# if cfg.modality in {"pixels", "state"}:
|
||||
# episode = cls(cfg, obses[0])
|
||||
# episode.obses[1:] = torch.tensor(obses[1:], dtype=episode.obses.dtype, device=episode.device)
|
||||
# elif cfg.modality == "all":
|
||||
# episode = cls(cfg, {k: v[0] for k, v in obses.items()})
|
||||
# for k in obses:
|
||||
# episode.obses[k][1:] = torch.tensor(
|
||||
# obses[k][1:], dtype=episode.obses[k].dtype, device=episode.device
|
||||
# )
|
||||
# else:
|
||||
# raise NotImplementedError
|
||||
# episode.actions = torch.tensor(actions, dtype=episode.actions.dtype, device=episode.device)
|
||||
# episode.rewards = torch.tensor(rewards, dtype=episode.rewards.dtype, device=episode.device)
|
||||
# episode.dones = (
|
||||
# torch.tensor(dones, dtype=episode.dones.dtype, device=episode.device)
|
||||
# if dones is not None
|
||||
# else torch.zeros_like(episode.dones)
|
||||
# )
|
||||
# episode.masks = (
|
||||
# torch.tensor(masks, dtype=episode.masks.dtype, device=episode.device)
|
||||
# if masks is not None
|
||||
# else torch.ones_like(episode.masks)
|
||||
# )
|
||||
# episode.cumulative_reward = torch.sum(episode.rewards)
|
||||
# episode.done = True
|
||||
# episode._idx = cfg.episode_length
|
||||
# return episode
|
||||
|
||||
# @property
|
||||
# def first(self):
|
||||
# return len(self) == 0
|
||||
|
||||
# def __add__(self, transition):
|
||||
# self.add(*transition)
|
||||
# return self
|
||||
|
||||
# def add(self, obs, action, reward, done, mask=1.0, success=False):
|
||||
# """Add a transition into the episode."""
|
||||
# if isinstance(obs, dict):
|
||||
# for k, v in obs.items():
|
||||
# self.obses[k][self._idx + 1] = torch.tensor(
|
||||
# v, dtype=self.obses[k].dtype, device=self.obses[k].device
|
||||
# )
|
||||
# else:
|
||||
# self.obses[self._idx + 1] = torch.tensor(obs, dtype=self.obses.dtype, device=self.obses.device)
|
||||
# self.actions[self._idx] = action
|
||||
# self.rewards[self._idx] = reward
|
||||
# self.dones[self._idx] = done
|
||||
# self.masks[self._idx] = mask
|
||||
# self.cumulative_reward += reward
|
||||
# self.done = done
|
||||
# self.success = self.success or success
|
||||
# self._idx += 1
|
||||
|
||||
|
||||
def get_dataset_dict(cfg, env, return_reward_normalizer=False):
|
||||
"""Construct a dataset for env"""
|
||||
required_keys = [
|
||||
"observations",
|
||||
"next_observations",
|
||||
"actions",
|
||||
"rewards",
|
||||
"dones",
|
||||
"masks",
|
||||
]
|
||||
|
||||
if cfg.task.startswith("xarm"):
|
||||
dataset_path = os.path.join(cfg.dataset_dir, "buffer.pkl")
|
||||
print(f"Using offline dataset '{dataset_path}'")
|
||||
with open(dataset_path, "rb") as f:
|
||||
dataset_dict = pickle.load(f)
|
||||
for k in required_keys:
|
||||
if k not in dataset_dict and k[:-1] in dataset_dict:
|
||||
dataset_dict[k] = dataset_dict.pop(k[:-1])
|
||||
elif cfg.task.startswith("legged"):
|
||||
dataset_path = os.path.join(cfg.dataset_dir, "buffer.pkl")
|
||||
print(f"Using offline dataset '{dataset_path}'")
|
||||
with open(dataset_path, "rb") as f:
|
||||
dataset_dict = pickle.load(f)
|
||||
dataset_dict["actions"] /= env.unwrapped.clip_actions
|
||||
print(f"clip_actions={env.unwrapped.clip_actions}")
|
||||
else:
|
||||
import d4rl
|
||||
|
||||
dataset_dict = d4rl.qlearning_dataset(env)
|
||||
dones = np.full_like(dataset_dict["rewards"], False, dtype=bool)
|
||||
|
||||
for i in range(len(dones) - 1):
|
||||
if (
|
||||
np.linalg.norm(dataset_dict["observations"][i + 1] - dataset_dict["next_observations"][i])
|
||||
> 1e-6
|
||||
or dataset_dict["terminals"][i] == 1.0
|
||||
):
|
||||
dones[i] = True
|
||||
|
||||
dones[-1] = True
|
||||
|
||||
dataset_dict["masks"] = 1.0 - dataset_dict["terminals"]
|
||||
del dataset_dict["terminals"]
|
||||
|
||||
for k, v in dataset_dict.items():
|
||||
dataset_dict[k] = v.astype(np.float32)
|
||||
|
||||
dataset_dict["dones"] = dones
|
||||
|
||||
if cfg.is_data_clip:
|
||||
lim = 1 - cfg.data_clip_eps
|
||||
dataset_dict["actions"] = np.clip(dataset_dict["actions"], -lim, lim)
|
||||
reward_normalizer = get_reward_normalizer(cfg, dataset_dict)
|
||||
dataset_dict["rewards"] = reward_normalizer(dataset_dict["rewards"])
|
||||
|
||||
for key in required_keys:
|
||||
assert key in dataset_dict, f"Missing `{key}` in dataset."
|
||||
|
||||
if return_reward_normalizer:
|
||||
return dataset_dict, reward_normalizer
|
||||
return dataset_dict
|
||||
|
||||
|
||||
def get_trajectory_boundaries_and_returns(dataset):
|
||||
"""
|
||||
Split dataset into trajectories and compute returns
|
||||
"""
|
||||
episode_starts = [0]
|
||||
episode_ends = []
|
||||
|
||||
episode_return = 0
|
||||
episode_returns = []
|
||||
|
||||
n_transitions = len(dataset["rewards"])
|
||||
|
||||
for i in range(n_transitions):
|
||||
episode_return += dataset["rewards"][i]
|
||||
|
||||
if dataset["dones"][i]:
|
||||
episode_returns.append(episode_return)
|
||||
episode_ends.append(i + 1)
|
||||
if i + 1 < n_transitions:
|
||||
episode_starts.append(i + 1)
|
||||
episode_return = 0.0
|
||||
|
||||
return episode_starts, episode_ends, episode_returns
|
||||
|
||||
|
||||
def normalize_returns(dataset, scaling=1000):
|
||||
"""
|
||||
Normalize returns in the dataset
|
||||
"""
|
||||
(_, _, episode_returns) = get_trajectory_boundaries_and_returns(dataset)
|
||||
dataset["rewards"] /= np.max(episode_returns) - np.min(episode_returns)
|
||||
dataset["rewards"] *= scaling
|
||||
return dataset
|
||||
|
||||
|
||||
def get_reward_normalizer(cfg, dataset):
|
||||
"""
|
||||
Get a reward normalizer for the dataset
|
||||
"""
|
||||
if cfg.task.startswith("xarm"):
|
||||
return lambda x: x
|
||||
elif "maze" in cfg.task:
|
||||
return lambda x: x - 1.0
|
||||
elif cfg.task.split("-")[0] in ["hopper", "halfcheetah", "walker2d"]:
|
||||
(_, _, episode_returns) = get_trajectory_boundaries_and_returns(dataset)
|
||||
return lambda x: x / (np.max(episode_returns) - np.min(episode_returns)) * 1000.0
|
||||
elif hasattr(cfg, "reward_scale"):
|
||||
return lambda x: x * cfg.reward_scale
|
||||
return lambda x: x
|
||||
|
||||
|
||||
def linear_schedule(schdl, step):
|
||||
"""
|
||||
Outputs values following a linear decay schedule.
|
||||
Adapted from https://github.com/facebookresearch/drqv2
|
||||
"""
|
||||
try:
|
||||
return float(schdl)
|
||||
except ValueError:
|
||||
match = re.match(r"linear\((.+),(.+),(.+),(.+)\)", schdl)
|
||||
if match:
|
||||
init, final, start, end = (float(g) for g in match.groups())
|
||||
mix = np.clip((step - start) / (end - start), 0.0, 1.0)
|
||||
return (1.0 - mix) * init + mix * final
|
||||
match = re.match(r"linear\((.+),(.+),(.+)\)", schdl)
|
||||
if match:
|
||||
init, final, duration = (float(g) for g in match.groups())
|
||||
mix = np.clip(step / duration, 0.0, 1.0)
|
||||
return (1.0 - mix) * init + mix * final
|
||||
raise NotImplementedError(schdl)
|
|
@ -0,0 +1,798 @@
|
|||
"""Implementation of Finetuning Offline World Models in the Real World.
|
||||
|
||||
The comments in this code may sometimes refer to these references:
|
||||
TD-MPC paper: Temporal Difference Learning for Model Predictive Control (https://arxiv.org/abs/2203.04955)
|
||||
FOWM paper: Finetuning Offline World Models in the Real World (https://arxiv.org/abs/2310.16029)
|
||||
|
||||
TODO(alexander-soare): Make rollout work for batch sizes larger than 1.
|
||||
TODO(alexander-soare): Use batch-first throughout.
|
||||
"""
|
||||
|
||||
# ruff: noqa: N806
|
||||
|
||||
import logging
|
||||
from collections import deque
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
from typing import Callable
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from huggingface_hub import PyTorchModelHubMixin
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from lerobot.common.policies.utils import get_device_from_parameters, populate_queues
|
||||
|
||||
|
||||
class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
"""Implementation of TD-MPC learning + inference.
|
||||
|
||||
Please note several warnings for this policy.
|
||||
- Evaluation of pretrained weights created with the original FOWM code
|
||||
(https://github.com/fyhMer/fowm) works as expected. To be precise: we trained and evaluated a
|
||||
model with the FOWM code for the xarm_lift_medium_replay dataset. We ported the weights across
|
||||
to LeRobot, and were able to evaluate with the same success metric. BUT, we had to use inter-
|
||||
process communication to use the xarm environment from FOWM. This is because our xarm
|
||||
environment uses newer dependencies and does not match the environment in FOWM. See
|
||||
https://github.com/huggingface/lerobot/pull/103 for implementation details.
|
||||
- We have NOT checked that training on LeRobot reproduces SOTA results. This is a TODO.
|
||||
- Our current xarm datasets were generated using the environment from FOWM. Therefore they do not
|
||||
match our xarm environment.
|
||||
"""
|
||||
|
||||
name = "tdmpc"
|
||||
|
||||
def __init__(
|
||||
self, config: TDMPCConfig | None = None, dataset_stats: dict[str, dict[str, Tensor]] | None = None
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
config: Policy configuration class instance or None, in which case the default instantiation of
|
||||
the configuration class is used.
|
||||
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
|
||||
that they will be passed with a call to `load_state_dict` before the policy is used.
|
||||
"""
|
||||
super().__init__()
|
||||
logging.warning(
|
||||
"""
|
||||
Please note several warnings for this policy.
|
||||
|
||||
- Evaluation of pretrained weights created with the original FOWM code
|
||||
(https://github.com/fyhMer/fowm) works as expected. To be precise: we trained and evaluated a
|
||||
model with the FOWM code for the xarm_lift_medium_replay dataset. We ported the weights across
|
||||
to LeRobot, and were able to evaluate with the same success metric. BUT, we had to use inter-
|
||||
process communication to use the xarm environment from FOWM. This is because our xarm
|
||||
environment uses newer dependencies and does not match the environment in FOWM. See
|
||||
https://github.com/huggingface/lerobot/pull/103 for implementation details.
|
||||
- We have NOT checked that training on LeRobot reproduces SOTA results. This is a TODO.
|
||||
- Our current xarm datasets were generated using the environment from FOWM. Therefore they do not
|
||||
match our xarm environment.
|
||||
"""
|
||||
)
|
||||
|
||||
if config is None:
|
||||
config = TDMPCConfig()
|
||||
self.config = config
|
||||
self.model = TDMPCTOLD(config)
|
||||
self.model_target = deepcopy(self.model)
|
||||
self.model_target.eval()
|
||||
|
||||
if config.input_normalization_modes is not None:
|
||||
self.normalize_inputs = Normalize(
|
||||
config.input_shapes, config.input_normalization_modes, dataset_stats
|
||||
)
|
||||
else:
|
||||
self.normalize_inputs = nn.Identity()
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_shapes, config.output_normalization_modes, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_shapes, config.output_normalization_modes, dataset_stats
|
||||
)
|
||||
|
||||
def save(self, fp):
|
||||
"""Save state dict of TOLD model to filepath."""
|
||||
torch.save(self.state_dict(), fp)
|
||||
|
||||
def load(self, fp):
|
||||
"""Load a saved state dict from filepath into current agent."""
|
||||
self.load_state_dict(torch.load(fp))
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Clear observation and action queues. Clear previous means for warm starting of MPPI/CEM. Should be
|
||||
called on `env.reset()`
|
||||
"""
|
||||
self._queues = {
|
||||
"observation.image": deque(maxlen=1),
|
||||
"observation.state": deque(maxlen=1),
|
||||
"action": deque(maxlen=self.config.n_action_repeats),
|
||||
}
|
||||
# Previous mean obtained from the cross-entropy method (CEM) used during MPC. It is used to warm start
|
||||
# CEM for the next step.
|
||||
self._prev_mean: torch.Tensor | None = None
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor]):
|
||||
"""Select a single action given environment observations."""
|
||||
assert "observation.image" in batch
|
||||
assert "observation.state" in batch
|
||||
assert len(batch) == 2
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
|
||||
self._queues = populate_queues(self._queues, batch)
|
||||
|
||||
# When the action queue is depleted, populate it again by querying the policy.
|
||||
if len(self._queues["action"]) == 0:
|
||||
batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
|
||||
|
||||
# Remove the time dimensions as it is not handled yet.
|
||||
for key in batch:
|
||||
assert batch[key].shape[1] == 1
|
||||
batch[key] = batch[key][:, 0]
|
||||
|
||||
# NOTE: Order of observations matters here.
|
||||
z = self.model.encode({k: batch[k] for k in ["observation.image", "observation.state"]})
|
||||
if self.config.use_mpc:
|
||||
batch_size = batch["observation.image"].shape[0]
|
||||
# Batch processing is not handled in MPC mode, so process the batch in a loop.
|
||||
action = [] # will be a batch of actions for one step
|
||||
for i in range(batch_size):
|
||||
# Note: self.plan does not handle batches, hence the squeeze.
|
||||
action.append(self.plan(z[i]))
|
||||
action = torch.stack(action)
|
||||
else:
|
||||
# Plan with the policy (π) alone.
|
||||
action = self.model.pi(z)
|
||||
|
||||
self.unnormalize_outputs({"action": action})["action"]
|
||||
|
||||
for _ in range(self.config.n_action_repeats):
|
||||
self._queues["action"].append(action)
|
||||
|
||||
action = self._queues["action"].popleft()
|
||||
return torch.clamp(action, -1, 1)
|
||||
|
||||
@torch.no_grad()
|
||||
def plan(self, z: Tensor) -> Tensor:
|
||||
"""Plan next action using TD-MPC inference.
|
||||
|
||||
Args:
|
||||
z: (latent_dim,) tensor for the initial state.
|
||||
Returns:
|
||||
(action_dim,) tensor for the next action.
|
||||
|
||||
TODO(alexander-soare) Extend this to be able to work with batches.
|
||||
"""
|
||||
device = get_device_from_parameters(self)
|
||||
|
||||
# Sample Nπ trajectories from the policy.
|
||||
pi_actions = torch.empty(
|
||||
self.config.horizon,
|
||||
self.config.n_pi_samples,
|
||||
self.config.output_shapes["action"][0],
|
||||
device=device,
|
||||
)
|
||||
if self.config.n_pi_samples > 0:
|
||||
_z = einops.repeat(z, "d -> n d", n=self.config.n_pi_samples)
|
||||
for t in range(self.config.horizon):
|
||||
# Note: Adding a small amount of noise here doesn't hurt during inference and may even be
|
||||
# helpful for CEM.
|
||||
pi_actions[t] = self.model.pi(_z, self.config.min_std)
|
||||
_z = self.model.latent_dynamics(_z, pi_actions[t])
|
||||
|
||||
# In the CEM loop we will need this for a call to estimate_value with the gaussian sampled
|
||||
# trajectories.
|
||||
z = einops.repeat(z, "d -> n d", n=self.config.n_gaussian_samples + self.config.n_pi_samples)
|
||||
|
||||
# Model Predictive Path Integral (MPPI) with the cross-entropy method (CEM) as the optimization
|
||||
# algorithm.
|
||||
# The initial mean and standard deviation for the cross-entropy method (CEM).
|
||||
mean = torch.zeros(self.config.horizon, self.config.output_shapes["action"][0], device=device)
|
||||
# Maybe warm start CEM with the mean from the previous step.
|
||||
if self._prev_mean is not None:
|
||||
mean[:-1] = self._prev_mean[1:]
|
||||
std = self.config.max_std * torch.ones_like(mean)
|
||||
|
||||
for _ in range(self.config.cem_iterations):
|
||||
# Randomly sample action trajectories for the gaussian distribution.
|
||||
std_normal_noise = torch.randn(
|
||||
self.config.horizon,
|
||||
self.config.n_gaussian_samples,
|
||||
self.config.output_shapes["action"][0],
|
||||
device=std.device,
|
||||
)
|
||||
gaussian_actions = torch.clamp(mean.unsqueeze(1) + std.unsqueeze(1) * std_normal_noise, -1, 1)
|
||||
|
||||
# Compute elite actions.
|
||||
actions = torch.cat([gaussian_actions, pi_actions], dim=1)
|
||||
value = self.estimate_value(z, actions).nan_to_num_(0)
|
||||
elite_idxs = torch.topk(value, self.config.n_elites, dim=0).indices
|
||||
elite_value, elite_actions = value[elite_idxs], actions[:, elite_idxs]
|
||||
|
||||
# Update guassian PDF parameters to be the (weighted) mean and standard deviation of the elites.
|
||||
max_value = elite_value.max(0)[0]
|
||||
# The weighting is a softmax over trajectory values. Note that this is not the same as the usage
|
||||
# of Ω in eqn 4 of the TD-MPC paper. Instead it is the normalized version of it: s = Ω/ΣΩ. This
|
||||
# makes the equations: μ = Σ(s⋅Γ), σ = Σ(s⋅(Γ-μ)²).
|
||||
score = torch.exp(self.config.elite_weighting_temperature * (elite_value - max_value))
|
||||
score /= score.sum()
|
||||
_mean = torch.sum(einops.rearrange(score, "n -> n 1") * elite_actions, dim=1)
|
||||
_std = torch.sqrt(
|
||||
torch.sum(
|
||||
einops.rearrange(score, "n -> n 1")
|
||||
* (elite_actions - einops.rearrange(_mean, "h d -> h 1 d")) ** 2,
|
||||
dim=1,
|
||||
)
|
||||
)
|
||||
# Update mean with an exponential moving average, and std with a direct replacement.
|
||||
mean = (
|
||||
self.config.gaussian_mean_momentum * mean + (1 - self.config.gaussian_mean_momentum) * _mean
|
||||
)
|
||||
std = _std.clamp_(self.config.min_std, self.config.max_std)
|
||||
|
||||
# Keep track of the mean for warm-starting subsequent steps.
|
||||
self._prev_mean = mean
|
||||
|
||||
# Randomly select one of the elite actions from the last iteration of MPPI/CEM using the softmax
|
||||
# scores from the last iteration.
|
||||
actions = elite_actions[:, torch.multinomial(score, 1).item()]
|
||||
|
||||
# Select only the first action
|
||||
action = actions[0]
|
||||
return action
|
||||
|
||||
@torch.no_grad()
|
||||
def estimate_value(self, z: Tensor, actions: Tensor):
|
||||
"""Estimates the value of a trajectory as per eqn 4 of the FOWM paper.
|
||||
|
||||
Args:
|
||||
z: (batch, latent_dim) tensor of initial latent states.
|
||||
actions: (horizon, batch, action_dim) tensor of action trajectories.
|
||||
Returns:
|
||||
(batch,) tensor of values.
|
||||
"""
|
||||
# Initialize return and running discount factor.
|
||||
G, running_discount = 0, 1
|
||||
# Iterate over the actions in the trajectory to simulate the trajectory using the latent dynamics
|
||||
# model. Keep track of return.
|
||||
for t in range(actions.shape[0]):
|
||||
# We will compute the reward in a moment. First compute the uncertainty regularizer from eqn 4
|
||||
# of the FOWM paper.
|
||||
if self.config.uncertainty_regularizer_coeff > 0:
|
||||
regularization = -(
|
||||
self.config.uncertainty_regularizer_coeff * self.model.Qs(z, actions[t]).std(0)
|
||||
)
|
||||
else:
|
||||
regularization = 0
|
||||
# Estimate the next state (latent) and reward.
|
||||
z, reward = self.model.latent_dynamics_and_reward(z, actions[t])
|
||||
# Update the return and running discount.
|
||||
G += running_discount * (reward + regularization)
|
||||
running_discount *= self.config.discount
|
||||
# Add the estimated value of the final state (using the minimum for a conservative estimate).
|
||||
# Do so by predicting the next action, then taking a minimum over the ensemble of state-action value
|
||||
# estimators.
|
||||
# Note: This small amount of added noise seems to help a bit at inference time as observed by success
|
||||
# metrics over 50 episodes of xarm_lift_medium_replay.
|
||||
next_action = self.model.pi(z, self.config.min_std) # (batch, action_dim)
|
||||
terminal_values = self.model.Qs(z, next_action) # (ensemble, batch)
|
||||
# Randomly choose 2 of the Qs for terminal value estimation (as in App C. of the FOWM paper).
|
||||
if self.config.q_ensemble_size > 2:
|
||||
G += (
|
||||
running_discount
|
||||
* torch.min(terminal_values[torch.randint(0, self.config.q_ensemble_size, size=(2,))], dim=0)[
|
||||
0
|
||||
]
|
||||
)
|
||||
else:
|
||||
G += running_discount * torch.min(terminal_values, dim=0)[0]
|
||||
# Finally, also regularize the terminal value.
|
||||
if self.config.uncertainty_regularizer_coeff > 0:
|
||||
G -= running_discount * self.config.uncertainty_regularizer_coeff * terminal_values.std(0)
|
||||
return G
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
"""Run the batch through the model and compute the loss."""
|
||||
device = get_device_from_parameters(self)
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = self.normalize_targets(batch)
|
||||
|
||||
info = {}
|
||||
|
||||
# TODO(alexander-soare): Refactor TDMPC and make it comply with the policy interface documentation.
|
||||
batch_size = batch["index"].shape[0]
|
||||
|
||||
# (b, t) -> (t, b)
|
||||
for key in batch:
|
||||
if batch[key].ndim > 1:
|
||||
batch[key] = batch[key].transpose(1, 0)
|
||||
|
||||
action = batch["action"] # (t, b)
|
||||
reward = batch["next.reward"] # (t,)
|
||||
observations = {k: v for k, v in batch.items() if k.startswith("observation.")}
|
||||
|
||||
# Apply random image augmentations.
|
||||
if self.config.max_random_shift_ratio > 0:
|
||||
observations["observation.image"] = flatten_forward_unflatten(
|
||||
partial(random_shifts_aug, max_random_shift_ratio=self.config.max_random_shift_ratio),
|
||||
observations["observation.image"],
|
||||
)
|
||||
|
||||
# Get the current observation for predicting trajectories, and all future observations for use in
|
||||
# the latent consistency loss and TD loss.
|
||||
current_observation, next_observations = {}, {}
|
||||
for k in observations:
|
||||
current_observation[k] = observations[k][0]
|
||||
next_observations[k] = observations[k][1:]
|
||||
horizon = next_observations["observation.image"].shape[0]
|
||||
|
||||
# Run latent rollout using the latent dynamics model and policy model.
|
||||
# Note this has shape `horizon+1` because there are `horizon` actions and a current `z`. Each action
|
||||
# gives us a next `z`.
|
||||
z_preds = torch.empty(horizon + 1, batch_size, self.config.latent_dim, device=device)
|
||||
z_preds[0] = self.model.encode(current_observation)
|
||||
reward_preds = torch.empty_like(reward, device=device)
|
||||
for t in range(horizon):
|
||||
z_preds[t + 1], reward_preds[t] = self.model.latent_dynamics_and_reward(z_preds[t], action[t])
|
||||
|
||||
# Compute Q and V value predictions based on the latent rollout.
|
||||
q_preds_ensemble = self.model.Qs(z_preds[:-1], action) # (ensemble, horizon, batch)
|
||||
v_preds = self.model.V(z_preds[:-1])
|
||||
info.update({"Q": q_preds_ensemble.mean().item(), "V": v_preds.mean().item()})
|
||||
|
||||
# Compute various targets with stopgrad.
|
||||
with torch.no_grad():
|
||||
# Latent state consistency targets.
|
||||
z_targets = self.model_target.encode(next_observations)
|
||||
# State-action value targets (or TD targets) as in eqn 3 of the FOWM. Unlike TD-MPC which uses the
|
||||
# learned state-action value function in conjunction with the learned policy: Q(z, π(z)), FOWM
|
||||
# uses a learned state value function: V(z). This means the TD targets only depend on in-sample
|
||||
# actions (not actions estimated by π).
|
||||
# Note: Here we do not use self.model_target, but self.model. This is to follow the original code
|
||||
# and the FOWM paper.
|
||||
q_targets = reward + self.config.discount * self.model.V(self.model.encode(next_observations))
|
||||
# From eqn 3 of FOWM. These appear as Q(z, a). Here we call them v_targets to emphasize that we
|
||||
# are using them to compute loss for V.
|
||||
v_targets = self.model_target.Qs(z_preds[:-1].detach(), action, return_min=True)
|
||||
|
||||
# Compute losses.
|
||||
# Exponentially decay the loss weight with respect to the timestep. Steps that are more distant in the
|
||||
# future have less impact on the loss. Note: unsqueeze will let us broadcast to (seq, batch).
|
||||
temporal_loss_coeffs = torch.pow(
|
||||
self.config.temporal_decay_coeff, torch.arange(horizon, device=device)
|
||||
).unsqueeze(-1)
|
||||
# Compute consistency loss as MSE loss between latents predicted from the rollout and latents
|
||||
# predicted from the (target model's) observation encoder.
|
||||
consistency_loss = (
|
||||
(
|
||||
temporal_loss_coeffs
|
||||
* F.mse_loss(z_preds[1:], z_targets, reduction="none").mean(dim=-1)
|
||||
# `z_preds` depends on the current observation and the actions.
|
||||
* ~batch["observation.state_is_pad"][0]
|
||||
* ~batch["action_is_pad"]
|
||||
# `z_targets` depends on the next observation.
|
||||
* ~batch["observation.state_is_pad"][1:]
|
||||
)
|
||||
.sum(0)
|
||||
.mean()
|
||||
)
|
||||
# Compute the reward loss as MSE loss between rewards predicted from the rollout and the dataset
|
||||
# rewards.
|
||||
reward_loss = (
|
||||
(
|
||||
temporal_loss_coeffs
|
||||
* F.mse_loss(reward_preds, reward, reduction="none")
|
||||
* ~batch["next.reward_is_pad"]
|
||||
# `reward_preds` depends on the current observation and the actions.
|
||||
* ~batch["observation.state_is_pad"][0]
|
||||
* ~batch["action_is_pad"]
|
||||
)
|
||||
.sum(0)
|
||||
.mean()
|
||||
)
|
||||
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
|
||||
q_value_loss = (
|
||||
(
|
||||
F.mse_loss(
|
||||
q_preds_ensemble,
|
||||
einops.repeat(q_targets, "t b -> e t b", e=q_preds_ensemble.shape[0]),
|
||||
reduction="none",
|
||||
).sum(0) # sum over ensemble
|
||||
# `q_preds_ensemble` depends on the first observation and the actions.
|
||||
* ~batch["observation.state_is_pad"][0]
|
||||
* ~batch["action_is_pad"]
|
||||
# q_targets depends on the reward and the next observations.
|
||||
* ~batch["next.reward_is_pad"]
|
||||
* ~batch["observation.state_is_pad"][1:]
|
||||
)
|
||||
.sum(0)
|
||||
.mean()
|
||||
)
|
||||
# Compute state value loss as in eqn 3 of FOWM.
|
||||
diff = v_targets - v_preds
|
||||
# Expectile loss penalizes:
|
||||
# - `v_preds < v_targets` with weighting `expectile_weight`
|
||||
# - `v_preds >= v_targets` with weighting `1 - expectile_weight`
|
||||
raw_v_value_loss = torch.where(
|
||||
diff > 0, self.config.expectile_weight, (1 - self.config.expectile_weight)
|
||||
) * (diff**2)
|
||||
v_value_loss = (
|
||||
(
|
||||
temporal_loss_coeffs
|
||||
* raw_v_value_loss
|
||||
# `v_targets` depends on the first observation and the actions, as does `v_preds`.
|
||||
* ~batch["observation.state_is_pad"][0]
|
||||
* ~batch["action_is_pad"]
|
||||
)
|
||||
.sum(0)
|
||||
.mean()
|
||||
)
|
||||
|
||||
# Calculate the advantage weighted regression loss for π as detailed in FOWM 3.1.
|
||||
# We won't need these gradients again so detach.
|
||||
z_preds = z_preds.detach()
|
||||
# Use stopgrad for the advantage calculation.
|
||||
with torch.no_grad():
|
||||
advantage = self.model_target.Qs(z_preds[:-1], action, return_min=True) - self.model.V(
|
||||
z_preds[:-1]
|
||||
)
|
||||
info["advantage"] = advantage[0]
|
||||
# (t, b)
|
||||
exp_advantage = torch.clamp(torch.exp(advantage * self.config.advantage_scaling), max=100.0)
|
||||
action_preds = self.model.pi(z_preds[:-1]) # (t, b, a)
|
||||
# Calculate the MSE between the actions and the action predictions.
|
||||
# Note: FOWM's original code calculates the log probability (wrt to a unit standard deviation
|
||||
# gaussian) and sums over the action dimension. Computing the log probability amounts to multiplying
|
||||
# the MSE by 0.5 and adding a constant offset (the log(2*pi) term) . Here we drop the constant offset
|
||||
# as it doesn't change the optimization step, and we drop the 0.5 as we instead make a configuration
|
||||
# parameter for it (see below where we compute the total loss).
|
||||
mse = F.mse_loss(action_preds, action, reduction="none").sum(-1) # (t, b)
|
||||
# NOTE: The original implementation does not take the sum over the temporal dimension like with the
|
||||
# other losses.
|
||||
# TODO(alexander-soare): Take the sum over the temporal dimension and check that training still works
|
||||
# as well as expected.
|
||||
pi_loss = (
|
||||
exp_advantage
|
||||
* mse
|
||||
* temporal_loss_coeffs
|
||||
# `action_preds` depends on the first observation and the actions.
|
||||
* ~batch["observation.state_is_pad"][0]
|
||||
* ~batch["action_is_pad"]
|
||||
).mean()
|
||||
|
||||
loss = (
|
||||
self.config.consistency_coeff * consistency_loss
|
||||
+ self.config.reward_coeff * reward_loss
|
||||
+ self.config.value_coeff * q_value_loss
|
||||
+ self.config.value_coeff * v_value_loss
|
||||
+ self.config.pi_coeff * pi_loss
|
||||
)
|
||||
|
||||
info.update(
|
||||
{
|
||||
"consistency_loss": consistency_loss.item(),
|
||||
"reward_loss": reward_loss.item(),
|
||||
"Q_value_loss": q_value_loss.item(),
|
||||
"V_value_loss": v_value_loss.item(),
|
||||
"pi_loss": pi_loss.item(),
|
||||
"loss": loss,
|
||||
"sum_loss": loss.item() * self.config.horizon,
|
||||
}
|
||||
)
|
||||
|
||||
# Undo (b, t) -> (t, b).
|
||||
for key in batch:
|
||||
if batch[key].ndim > 1:
|
||||
batch[key] = batch[key].transpose(1, 0)
|
||||
|
||||
return info
|
||||
|
||||
def update(self):
|
||||
"""Update the target model's parameters with an EMA step."""
|
||||
# Note a minor variation with respect to the original FOWM code. Here they do this based on an EMA
|
||||
# update frequency parameter which is set to 2 (every 2 steps an update is done). To simplify the code
|
||||
# we update every step and adjust the decay parameter `alpha` accordingly (0.99 -> 0.995)
|
||||
update_ema_parameters(self.model_target, self.model, self.config.target_model_momentum)
|
||||
|
||||
|
||||
class TDMPCTOLD(nn.Module):
|
||||
"""Task-Oriented Latent Dynamics (TOLD) model used in TD-MPC."""
|
||||
|
||||
def __init__(self, config: TDMPCConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self._encoder = TDMPCObservationEncoder(config)
|
||||
self._dynamics = nn.Sequential(
|
||||
nn.Linear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim),
|
||||
nn.LayerNorm(config.mlp_dim),
|
||||
nn.Mish(),
|
||||
nn.Linear(config.mlp_dim, config.mlp_dim),
|
||||
nn.LayerNorm(config.mlp_dim),
|
||||
nn.Mish(),
|
||||
nn.Linear(config.mlp_dim, config.latent_dim),
|
||||
nn.LayerNorm(config.latent_dim),
|
||||
nn.Sigmoid(),
|
||||
)
|
||||
self._reward = nn.Sequential(
|
||||
nn.Linear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim),
|
||||
nn.LayerNorm(config.mlp_dim),
|
||||
nn.Mish(),
|
||||
nn.Linear(config.mlp_dim, config.mlp_dim),
|
||||
nn.LayerNorm(config.mlp_dim),
|
||||
nn.Mish(),
|
||||
nn.Linear(config.mlp_dim, 1),
|
||||
)
|
||||
self._pi = nn.Sequential(
|
||||
nn.Linear(config.latent_dim, config.mlp_dim),
|
||||
nn.LayerNorm(config.mlp_dim),
|
||||
nn.Mish(),
|
||||
nn.Linear(config.mlp_dim, config.mlp_dim),
|
||||
nn.LayerNorm(config.mlp_dim),
|
||||
nn.Mish(),
|
||||
nn.Linear(config.mlp_dim, config.output_shapes["action"][0]),
|
||||
)
|
||||
self._Qs = nn.ModuleList(
|
||||
[
|
||||
nn.Sequential(
|
||||
nn.Linear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim),
|
||||
nn.LayerNorm(config.mlp_dim),
|
||||
nn.Tanh(),
|
||||
nn.Linear(config.mlp_dim, config.mlp_dim),
|
||||
nn.ELU(),
|
||||
nn.Linear(config.mlp_dim, 1),
|
||||
)
|
||||
for _ in range(config.q_ensemble_size)
|
||||
]
|
||||
)
|
||||
self._V = nn.Sequential(
|
||||
nn.Linear(config.latent_dim, config.mlp_dim),
|
||||
nn.LayerNorm(config.mlp_dim),
|
||||
nn.Tanh(),
|
||||
nn.Linear(config.mlp_dim, config.mlp_dim),
|
||||
nn.ELU(),
|
||||
nn.Linear(config.mlp_dim, 1),
|
||||
)
|
||||
self._init_weights()
|
||||
|
||||
def _init_weights(self):
|
||||
"""Initialize model weights.
|
||||
|
||||
Orthogonal initialization for all linear and convolutional layers' weights (apart from final layers
|
||||
of reward network and Q networks which get zero initialization).
|
||||
Zero initialization for all linear and convolutional layers' biases.
|
||||
"""
|
||||
|
||||
def _apply_fn(m):
|
||||
if isinstance(m, nn.Linear):
|
||||
nn.init.orthogonal_(m.weight.data)
|
||||
if m.bias is not None:
|
||||
nn.init.zeros_(m.bias)
|
||||
elif isinstance(m, nn.Conv2d):
|
||||
gain = nn.init.calculate_gain("relu")
|
||||
nn.init.orthogonal_(m.weight.data, gain)
|
||||
if m.bias is not None:
|
||||
nn.init.zeros_(m.bias)
|
||||
|
||||
self.apply(_apply_fn)
|
||||
for m in [self._reward, *self._Qs]:
|
||||
assert isinstance(
|
||||
m[-1], nn.Linear
|
||||
), "Sanity check. The last linear layer needs 0 initialization on weights."
|
||||
nn.init.zeros_(m[-1].weight)
|
||||
nn.init.zeros_(m[-1].bias) # this has already been done, but keep this line here for good measure
|
||||
|
||||
def encode(self, obs: dict[str, Tensor]) -> Tensor:
|
||||
"""Encodes an observation into its latent representation."""
|
||||
return self._encoder(obs)
|
||||
|
||||
def latent_dynamics_and_reward(self, z: Tensor, a: Tensor) -> tuple[Tensor, Tensor]:
|
||||
"""Predict the next state's latent representation and the reward given a current latent and action.
|
||||
|
||||
Args:
|
||||
z: (*, latent_dim) tensor for the current state's latent representation.
|
||||
a: (*, action_dim) tensor for the action to be applied.
|
||||
Returns:
|
||||
A tuple containing:
|
||||
- (*, latent_dim) tensor for the next state's latent representation.
|
||||
- (*,) tensor for the estimated reward.
|
||||
"""
|
||||
x = torch.cat([z, a], dim=-1)
|
||||
return self._dynamics(x), self._reward(x).squeeze(-1)
|
||||
|
||||
def latent_dynamics(self, z: Tensor, a: Tensor) -> Tensor:
|
||||
"""Predict the next state's latent representation given a current latent and action.
|
||||
|
||||
Args:
|
||||
z: (*, latent_dim) tensor for the current state's latent representation.
|
||||
a: (*, action_dim) tensor for the action to be applied.
|
||||
Returns:
|
||||
(*, latent_dim) tensor for the next state's latent representation.
|
||||
"""
|
||||
x = torch.cat([z, a], dim=-1)
|
||||
return self._dynamics(x)
|
||||
|
||||
def pi(self, z: Tensor, std: float = 0.0) -> Tensor:
|
||||
"""Samples an action from the learned policy.
|
||||
|
||||
The policy can also have added (truncated) Gaussian noise injected for encouraging exploration when
|
||||
generating rollouts for online training.
|
||||
|
||||
Args:
|
||||
z: (*, latent_dim) tensor for the current state's latent representation.
|
||||
std: The standard deviation of the injected noise.
|
||||
Returns:
|
||||
(*, action_dim) tensor for the sampled action.
|
||||
"""
|
||||
action = torch.tanh(self._pi(z))
|
||||
if std > 0:
|
||||
std = torch.ones_like(action) * std
|
||||
action += torch.randn_like(action) * std
|
||||
return action
|
||||
|
||||
def V(self, z: Tensor) -> Tensor: # noqa: N802
|
||||
"""Predict state value (V).
|
||||
|
||||
Args:
|
||||
z: (*, latent_dim) tensor for the current state's latent representation.
|
||||
Returns:
|
||||
(*,) tensor of estimated state values.
|
||||
"""
|
||||
return self._V(z).squeeze(-1)
|
||||
|
||||
def Qs(self, z: Tensor, a: Tensor, return_min: bool = False) -> Tensor: # noqa: N802
|
||||
"""Predict state-action value for all of the learned Q functions.
|
||||
|
||||
Args:
|
||||
z: (*, latent_dim) tensor for the current state's latent representation.
|
||||
a: (*, action_dim) tensor for the action to be applied.
|
||||
return_min: Set to true for implementing the detail in App. C of the FOWM paper: randomly select
|
||||
2 of the Qs and return the minimum
|
||||
Returns:
|
||||
(q_ensemble, *) tensor for the value predictions of each learned Q function in the ensemble OR
|
||||
(*,) tensor if return_min=True.
|
||||
"""
|
||||
x = torch.cat([z, a], dim=-1)
|
||||
if not return_min:
|
||||
return torch.stack([q(x).squeeze(-1) for q in self._Qs], dim=0)
|
||||
else:
|
||||
if len(self._Qs) > 2: # noqa: SIM108
|
||||
Qs = [self._Qs[i] for i in np.random.choice(len(self._Qs), size=2)]
|
||||
else:
|
||||
Qs = self._Qs
|
||||
return torch.stack([q(x).squeeze(-1) for q in Qs], dim=0).min(dim=0)[0]
|
||||
|
||||
|
||||
class TDMPCObservationEncoder(nn.Module):
|
||||
"""Encode image and/or state vector observations."""
|
||||
|
||||
def __init__(self, config: TDMPCConfig):
|
||||
"""
|
||||
Creates encoders for pixel and/or state modalities.
|
||||
TODO(alexander-soare): The original work allows for multiple images by concatenating them along the
|
||||
channel dimension. Re-implement this capability.
|
||||
"""
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
if "observation.image" in config.input_shapes:
|
||||
self.image_enc_layers = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
config.input_shapes["observation.image"][0], config.image_encoder_hidden_dim, 7, stride=2
|
||||
),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 5, stride=2),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2),
|
||||
nn.ReLU(),
|
||||
)
|
||||
dummy_batch = torch.zeros(1, *config.input_shapes["observation.image"])
|
||||
with torch.inference_mode():
|
||||
out_shape = self.image_enc_layers(dummy_batch).shape[1:]
|
||||
self.image_enc_layers.extend(
|
||||
nn.Sequential(
|
||||
nn.Flatten(),
|
||||
nn.Linear(np.prod(out_shape), config.latent_dim),
|
||||
nn.LayerNorm(config.latent_dim),
|
||||
nn.Sigmoid(),
|
||||
)
|
||||
)
|
||||
if "observation.state" in config.input_shapes:
|
||||
self.state_enc_layers = nn.Sequential(
|
||||
nn.Linear(config.input_shapes["observation.state"][0], config.state_encoder_hidden_dim),
|
||||
nn.ELU(),
|
||||
nn.Linear(config.state_encoder_hidden_dim, config.latent_dim),
|
||||
nn.LayerNorm(config.latent_dim),
|
||||
nn.Sigmoid(),
|
||||
)
|
||||
|
||||
def forward(self, obs_dict: dict[str, Tensor]) -> Tensor:
|
||||
"""Encode the image and/or state vector.
|
||||
|
||||
Each modality is encoded into a feature vector of size (latent_dim,) and then a uniform mean is taken
|
||||
over all features.
|
||||
"""
|
||||
feat = []
|
||||
if "observation.image" in self.config.input_shapes:
|
||||
feat.append(flatten_forward_unflatten(self.image_enc_layers, obs_dict["observation.image"]))
|
||||
if "observation.state" in self.config.input_shapes:
|
||||
feat.append(self.state_enc_layers(obs_dict["observation.state"]))
|
||||
return torch.stack(feat, dim=0).mean(0)
|
||||
|
||||
|
||||
def random_shifts_aug(x: Tensor, max_random_shift_ratio: float) -> Tensor:
|
||||
"""Randomly shifts images horizontally and vertically.
|
||||
|
||||
Adapted from https://github.com/facebookresearch/drqv2
|
||||
"""
|
||||
b, _, h, w = x.size()
|
||||
assert h == w, "non-square images not handled yet"
|
||||
pad = int(round(max_random_shift_ratio * h))
|
||||
x = F.pad(x, tuple([pad] * 4), "replicate")
|
||||
eps = 1.0 / (h + 2 * pad)
|
||||
arange = torch.linspace(
|
||||
-1.0 + eps,
|
||||
1.0 - eps,
|
||||
h + 2 * pad,
|
||||
device=x.device,
|
||||
dtype=torch.float32,
|
||||
)[:h]
|
||||
arange = einops.repeat(arange, "w -> h w 1", h=h)
|
||||
base_grid = torch.cat([arange, arange.transpose(1, 0)], dim=2)
|
||||
base_grid = einops.repeat(base_grid, "h w c -> b h w c", b=b)
|
||||
# A random shift in units of pixels and within the boundaries of the padding.
|
||||
shift = torch.randint(
|
||||
0,
|
||||
2 * pad + 1,
|
||||
size=(b, 1, 1, 2),
|
||||
device=x.device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
shift *= 2.0 / (h + 2 * pad)
|
||||
grid = base_grid + shift
|
||||
return F.grid_sample(x, grid, padding_mode="zeros", align_corners=False)
|
||||
|
||||
|
||||
def update_ema_parameters(ema_net: nn.Module, net: nn.Module, alpha: float):
|
||||
"""Update EMA parameters in place with ema_param <- alpha * ema_param + (1 - alpha) * param."""
|
||||
for ema_module, module in zip(ema_net.modules(), net.modules(), strict=True):
|
||||
for (n_p_ema, p_ema), (n_p, p) in zip(
|
||||
ema_module.named_parameters(recurse=False), module.named_parameters(recurse=False), strict=True
|
||||
):
|
||||
assert n_p_ema == n_p, "Parameter names don't match for EMA model update"
|
||||
if isinstance(p, dict):
|
||||
raise RuntimeError("Dict parameter not supported")
|
||||
if isinstance(module, nn.modules.batchnorm._BatchNorm) or not p.requires_grad:
|
||||
# Copy BatchNorm parameters, and non-trainable parameters directly.
|
||||
p_ema.copy_(p.to(dtype=p_ema.dtype).data)
|
||||
with torch.no_grad():
|
||||
p_ema.mul_(alpha)
|
||||
p_ema.add_(p.to(dtype=p_ema.dtype).data, alpha=1 - alpha)
|
||||
|
||||
|
||||
def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor:
|
||||
"""Helper to temporarily flatten extra dims at the start of the image tensor.
|
||||
|
||||
Args:
|
||||
fn: Callable that the image tensor will be passed to. It should accept (B, C, H, W) and return
|
||||
(B, *), where * is any number of dimensions.
|
||||
image_tensor: An image tensor of shape (**, C, H, W), where ** is any number of dimensions, generally
|
||||
different from *.
|
||||
Returns:
|
||||
A return value from the callable reshaped to (**, *).
|
||||
"""
|
||||
if image_tensor.ndim == 4:
|
||||
return fn(image_tensor)
|
||||
start_dims = image_tensor.shape[:-3]
|
||||
inp = torch.flatten(image_tensor, end_dim=-4)
|
||||
flat_out = fn(inp)
|
||||
return torch.reshape(flat_out, (*start_dims, *flat_out.shape[1:]))
|
|
@ -1,495 +0,0 @@
|
|||
# ruff: noqa: N806
|
||||
|
||||
import time
|
||||
from collections import deque
|
||||
from copy import deepcopy
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
import lerobot.common.policies.tdmpc.helper as h
|
||||
from lerobot.common.policies.utils import populate_queues
|
||||
from lerobot.common.utils.utils import get_safe_torch_device
|
||||
|
||||
FIRST_FRAME = 0
|
||||
|
||||
|
||||
class TOLD(nn.Module):
|
||||
"""Task-Oriented Latent Dynamics (TOLD) model used in TD-MPC."""
|
||||
|
||||
def __init__(self, cfg):
|
||||
super().__init__()
|
||||
action_dim = cfg.action_dim
|
||||
|
||||
self.cfg = cfg
|
||||
self._encoder = h.enc(cfg)
|
||||
self._dynamics = h.dynamics(cfg.latent_dim + action_dim, cfg.mlp_dim, cfg.latent_dim)
|
||||
self._reward = h.mlp(cfg.latent_dim + action_dim, cfg.mlp_dim, 1)
|
||||
self._pi = h.mlp(cfg.latent_dim, cfg.mlp_dim, action_dim)
|
||||
self._Qs = nn.ModuleList([h.q(cfg) for _ in range(cfg.num_q)])
|
||||
self._V = h.v(cfg)
|
||||
self.apply(h.orthogonal_init)
|
||||
for m in [self._reward, *self._Qs]:
|
||||
m[-1].weight.data.fill_(0)
|
||||
m[-1].bias.data.fill_(0)
|
||||
|
||||
def track_q_grad(self, enable=True):
|
||||
"""Utility function. Enables/disables gradient tracking of Q-networks."""
|
||||
for m in self._Qs:
|
||||
h.set_requires_grad(m, enable)
|
||||
|
||||
def track_v_grad(self, enable=True):
|
||||
"""Utility function. Enables/disables gradient tracking of Q-networks."""
|
||||
if hasattr(self, "_V"):
|
||||
h.set_requires_grad(self._V, enable)
|
||||
|
||||
def encode(self, obs):
|
||||
"""Encodes an observation into its latent representation."""
|
||||
out = self._encoder(obs)
|
||||
if isinstance(obs, dict):
|
||||
# fusion
|
||||
out = torch.stack([v for k, v in out.items()]).mean(dim=0)
|
||||
return out
|
||||
|
||||
def next(self, z, a):
|
||||
"""Predicts next latent state (d) and single-step reward (R)."""
|
||||
x = torch.cat([z, a], dim=-1)
|
||||
return self._dynamics(x), self._reward(x)
|
||||
|
||||
def next_dynamics(self, z, a):
|
||||
"""Predicts next latent state (d)."""
|
||||
x = torch.cat([z, a], dim=-1)
|
||||
return self._dynamics(x)
|
||||
|
||||
def pi(self, z, std=0):
|
||||
"""Samples an action from the learned policy (pi)."""
|
||||
mu = torch.tanh(self._pi(z))
|
||||
if std > 0:
|
||||
std = torch.ones_like(mu) * std
|
||||
return h.TruncatedNormal(mu, std).sample(clip=0.3)
|
||||
return mu
|
||||
|
||||
def V(self, z): # noqa: N802
|
||||
"""Predict state value (V)."""
|
||||
return self._V(z)
|
||||
|
||||
def Q(self, z, a, return_type): # noqa: N802
|
||||
"""Predict state-action value (Q)."""
|
||||
assert return_type in {"min", "avg", "all"}
|
||||
x = torch.cat([z, a], dim=-1)
|
||||
|
||||
if return_type == "all":
|
||||
return torch.stack([q(x) for q in self._Qs], dim=0)
|
||||
|
||||
idxs = np.random.choice(self.cfg.num_q, 2, replace=False)
|
||||
Q1, Q2 = self._Qs[idxs[0]](x), self._Qs[idxs[1]](x)
|
||||
return torch.min(Q1, Q2) if return_type == "min" else (Q1 + Q2) / 2
|
||||
|
||||
|
||||
class TDMPCPolicy(nn.Module):
|
||||
"""Implementation of TD-MPC learning + inference."""
|
||||
|
||||
name = "tdmpc"
|
||||
|
||||
def __init__(self, cfg, n_obs_steps, n_action_steps, device):
|
||||
super().__init__()
|
||||
self.action_dim = cfg.action_dim
|
||||
|
||||
self.cfg = cfg
|
||||
self.n_obs_steps = n_obs_steps
|
||||
self.n_action_steps = n_action_steps
|
||||
self.device = get_safe_torch_device(device)
|
||||
self.std = h.linear_schedule(cfg.std_schedule, 0)
|
||||
self.model = TOLD(cfg)
|
||||
self.model.to(self.device)
|
||||
self.model_target = deepcopy(self.model)
|
||||
self.optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr)
|
||||
self.pi_optim = torch.optim.Adam(self.model._pi.parameters(), lr=self.cfg.lr)
|
||||
# self.bc_optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr)
|
||||
self.model.eval()
|
||||
self.model_target.eval()
|
||||
|
||||
self.register_buffer("step", torch.zeros(1))
|
||||
|
||||
def state_dict(self):
|
||||
"""Retrieve state dict of TOLD model, including slow-moving target network."""
|
||||
return {
|
||||
"model": self.model.state_dict(),
|
||||
"model_target": self.model_target.state_dict(),
|
||||
}
|
||||
|
||||
def save(self, fp):
|
||||
"""Save state dict of TOLD model to filepath."""
|
||||
torch.save(self.state_dict(), fp)
|
||||
|
||||
def load(self, fp):
|
||||
"""Load a saved state dict from filepath into current agent."""
|
||||
d = torch.load(fp)
|
||||
self.model.load_state_dict(d["model"])
|
||||
self.model_target.load_state_dict(d["model_target"])
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Clear observation and action queues. Should be called on `env.reset()`
|
||||
"""
|
||||
self._queues = {
|
||||
"observation.image": deque(maxlen=self.n_obs_steps),
|
||||
"observation.state": deque(maxlen=self.n_obs_steps),
|
||||
"action": deque(maxlen=self.n_action_steps),
|
||||
}
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch, step):
|
||||
assert "observation.image" in batch
|
||||
assert "observation.state" in batch
|
||||
assert len(batch) == 2
|
||||
|
||||
self._queues = populate_queues(self._queues, batch)
|
||||
|
||||
t0 = step == 0
|
||||
|
||||
self.eval()
|
||||
|
||||
if len(self._queues["action"]) == 0:
|
||||
batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
|
||||
|
||||
if self.n_obs_steps == 1:
|
||||
# hack to remove the time dimension
|
||||
for key in batch:
|
||||
assert batch[key].shape[1] == 1
|
||||
batch[key] = batch[key][:, 0]
|
||||
|
||||
actions = []
|
||||
batch_size = batch["observation.image"].shape[0]
|
||||
for i in range(batch_size):
|
||||
obs = {
|
||||
"rgb": batch["observation.image"][[i]],
|
||||
"state": batch["observation.state"][[i]],
|
||||
}
|
||||
# Note: unsqueeze needed because `act` still uses non-batch logic.
|
||||
action = self.act(obs, t0=t0, step=self.step)
|
||||
actions.append(action)
|
||||
action = torch.stack(actions)
|
||||
|
||||
# tdmpc returns an action for 1 timestep only, so we copy it over `n_action_steps` time
|
||||
if i in range(self.n_action_steps):
|
||||
self._queues["action"].append(action)
|
||||
|
||||
action = self._queues["action"].popleft()
|
||||
return action
|
||||
|
||||
@torch.no_grad()
|
||||
def act(self, obs, t0=False, step=None):
|
||||
"""Take an action. Uses either MPC or the learned policy, depending on the self.cfg.mpc flag."""
|
||||
obs = {k: o.detach() for k, o in obs.items()} if isinstance(obs, dict) else obs.detach()
|
||||
z = self.model.encode(obs)
|
||||
if self.cfg.mpc:
|
||||
a = self.plan(z, t0=t0, step=step)
|
||||
else:
|
||||
a = self.model.pi(z, self.cfg.min_std * self.model.training).squeeze(0)
|
||||
return a
|
||||
|
||||
@torch.no_grad()
|
||||
def estimate_value(self, z, actions, horizon):
|
||||
"""Estimate value of a trajectory starting at latent state z and executing given actions."""
|
||||
G, discount = 0, 1
|
||||
for t in range(horizon):
|
||||
if self.cfg.uncertainty_cost > 0:
|
||||
G -= (
|
||||
discount
|
||||
* self.cfg.uncertainty_cost
|
||||
* self.model.Q(z, actions[t], return_type="all").std(dim=0)
|
||||
)
|
||||
z, reward = self.model.next(z, actions[t])
|
||||
G += discount * reward
|
||||
discount *= self.cfg.discount
|
||||
pi = self.model.pi(z, self.cfg.min_std)
|
||||
G += discount * self.model.Q(z, pi, return_type="min")
|
||||
if self.cfg.uncertainty_cost > 0:
|
||||
G -= discount * self.cfg.uncertainty_cost * self.model.Q(z, pi, return_type="all").std(dim=0)
|
||||
return G
|
||||
|
||||
@torch.no_grad()
|
||||
def plan(self, z, step=None, t0=True):
|
||||
"""
|
||||
Plan next action using TD-MPC inference.
|
||||
z: latent state.
|
||||
step: current time step. determines e.g. planning horizon.
|
||||
t0: whether current step is the first step of an episode.
|
||||
"""
|
||||
# during eval: eval_mode: uniform sampling and action noise is disabled during evaluation.
|
||||
|
||||
assert step is not None
|
||||
# Seed steps
|
||||
if step < self.cfg.seed_steps and self.model.training:
|
||||
return torch.empty(self.action_dim, dtype=torch.float32, device=self.device).uniform_(-1, 1)
|
||||
|
||||
# Sample policy trajectories
|
||||
horizon = int(min(self.cfg.horizon, h.linear_schedule(self.cfg.horizon_schedule, step)))
|
||||
num_pi_trajs = int(self.cfg.mixture_coef * self.cfg.num_samples)
|
||||
if num_pi_trajs > 0:
|
||||
pi_actions = torch.empty(horizon, num_pi_trajs, self.action_dim, device=self.device)
|
||||
_z = z.repeat(num_pi_trajs, 1)
|
||||
for t in range(horizon):
|
||||
pi_actions[t] = self.model.pi(_z, self.cfg.min_std)
|
||||
_z = self.model.next_dynamics(_z, pi_actions[t])
|
||||
|
||||
# Initialize state and parameters
|
||||
z = z.repeat(self.cfg.num_samples + num_pi_trajs, 1)
|
||||
mean = torch.zeros(horizon, self.action_dim, device=self.device)
|
||||
std = self.cfg.max_std * torch.ones(horizon, self.action_dim, device=self.device)
|
||||
if not t0 and hasattr(self, "_prev_mean"):
|
||||
mean[:-1] = self._prev_mean[1:]
|
||||
|
||||
# Iterate CEM
|
||||
for _ in range(self.cfg.iterations):
|
||||
actions = torch.clamp(
|
||||
mean.unsqueeze(1)
|
||||
+ std.unsqueeze(1)
|
||||
* torch.randn(horizon, self.cfg.num_samples, self.action_dim, device=std.device),
|
||||
-1,
|
||||
1,
|
||||
)
|
||||
if num_pi_trajs > 0:
|
||||
actions = torch.cat([actions, pi_actions], dim=1)
|
||||
|
||||
# Compute elite actions
|
||||
value = self.estimate_value(z, actions, horizon).nan_to_num_(0)
|
||||
elite_idxs = torch.topk(value.squeeze(1), self.cfg.num_elites, dim=0).indices
|
||||
elite_value, elite_actions = value[elite_idxs], actions[:, elite_idxs]
|
||||
|
||||
# Update parameters
|
||||
max_value = elite_value.max(0)[0]
|
||||
score = torch.exp(self.cfg.temperature * (elite_value - max_value))
|
||||
score /= score.sum(0)
|
||||
_mean = torch.sum(score.unsqueeze(0) * elite_actions, dim=1) / (score.sum(0) + 1e-9)
|
||||
_std = torch.sqrt(
|
||||
torch.sum(
|
||||
score.unsqueeze(0) * (elite_actions - _mean.unsqueeze(1)) ** 2,
|
||||
dim=1,
|
||||
)
|
||||
/ (score.sum(0) + 1e-9)
|
||||
)
|
||||
_std = _std.clamp_(self.std, self.cfg.max_std)
|
||||
mean, std = self.cfg.momentum * mean + (1 - self.cfg.momentum) * _mean, _std
|
||||
|
||||
# Outputs
|
||||
# TODO(rcadene): remove numpy with
|
||||
# # Convert score tensor to probabilities using softmax
|
||||
# probabilities = torch.softmax(score, dim=0)
|
||||
# # Generate a random sample index based on the probabilities
|
||||
# sample_index = torch.multinomial(probabilities, 1).item()
|
||||
score = score.squeeze(1).cpu().numpy()
|
||||
actions = elite_actions[:, np.random.choice(np.arange(score.shape[0]), p=score)]
|
||||
self._prev_mean = mean
|
||||
mean, std = actions[0], _std[0]
|
||||
a = mean
|
||||
if self.model.training:
|
||||
a += std * torch.randn(self.action_dim, device=std.device)
|
||||
return torch.clamp(a, -1, 1)
|
||||
|
||||
def update_pi(self, zs, acts=None):
|
||||
"""Update policy using a sequence of latent states."""
|
||||
self.pi_optim.zero_grad(set_to_none=True)
|
||||
self.model.track_q_grad(False)
|
||||
self.model.track_v_grad(False)
|
||||
|
||||
info = {}
|
||||
# Advantage Weighted Regression
|
||||
assert acts is not None
|
||||
vs = self.model.V(zs)
|
||||
qs = self.model_target.Q(zs, acts, return_type="min")
|
||||
adv = qs - vs
|
||||
exp_a = torch.exp(adv * self.cfg.A_scaling)
|
||||
exp_a = torch.clamp(exp_a, max=100.0)
|
||||
log_probs = h.gaussian_logprob(self.model.pi(zs) - acts, 0)
|
||||
rho = torch.pow(self.cfg.rho, torch.arange(len(qs), device=self.device))
|
||||
pi_loss = -((exp_a * log_probs).mean(dim=(1, 2)) * rho).mean()
|
||||
info["adv"] = adv[0]
|
||||
|
||||
pi_loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(
|
||||
self.model._pi.parameters(),
|
||||
self.cfg.grad_clip_norm,
|
||||
error_if_nonfinite=False,
|
||||
)
|
||||
self.pi_optim.step()
|
||||
self.model.track_q_grad(True)
|
||||
self.model.track_v_grad(True)
|
||||
|
||||
info["pi_loss"] = pi_loss.item()
|
||||
return pi_loss.item(), info
|
||||
|
||||
@torch.no_grad()
|
||||
def _td_target(self, next_z, reward, mask):
|
||||
"""Compute the TD-target from a reward and the observation at the following time step."""
|
||||
next_v = self.model.V(next_z)
|
||||
td_target = reward + self.cfg.discount * mask * next_v.squeeze(2)
|
||||
return td_target
|
||||
|
||||
def forward(self, batch, step):
|
||||
# TODO(alexander-soare): Refactor TDMPC and make it comply with the policy interface documentation.
|
||||
raise NotImplementedError()
|
||||
|
||||
def update(self, batch, step):
|
||||
"""Main update function. Corresponds to one iteration of the model learning."""
|
||||
start_time = time.time()
|
||||
|
||||
batch_size = batch["index"].shape[0]
|
||||
|
||||
# TODO(rcadene): convert tdmpc with (batch size, time/horizon, channels)
|
||||
# instead of currently (time/horizon, batch size, channels) which is not the pytorch convention
|
||||
# batch size b = 256, time/horizon t = 5
|
||||
# b t ... -> t b ...
|
||||
for key in batch:
|
||||
if batch[key].ndim > 1:
|
||||
batch[key] = batch[key].transpose(1, 0)
|
||||
|
||||
action = batch["action"]
|
||||
reward = batch["next.reward"]
|
||||
# idxs = batch["index"] # TODO(rcadene): use idxs to update sampling weights
|
||||
done = torch.zeros_like(reward, dtype=torch.bool, device=reward.device)
|
||||
mask = torch.ones_like(reward, dtype=torch.bool, device=reward.device)
|
||||
weights = torch.ones(batch_size, dtype=torch.bool, device=reward.device)
|
||||
|
||||
obses = {
|
||||
"rgb": batch["observation.image"],
|
||||
"state": batch["observation.state"],
|
||||
}
|
||||
|
||||
shapes = {}
|
||||
for k in obses:
|
||||
shapes[k] = obses[k].shape
|
||||
obses[k] = einops.rearrange(obses[k], "t b ... -> (t b) ... ")
|
||||
|
||||
# Apply augmentations
|
||||
aug_tf = h.aug(self.cfg)
|
||||
obses = aug_tf(obses)
|
||||
|
||||
for k in obses:
|
||||
t, b = shapes[k][:2]
|
||||
obses[k] = einops.rearrange(obses[k], "(t b) ... -> t b ... ", b=b, t=t)
|
||||
|
||||
obs, next_obses = {}, {}
|
||||
for k in obses:
|
||||
obs[k] = obses[k][0]
|
||||
next_obses[k] = obses[k][1:].clone()
|
||||
|
||||
horizon = next_obses["rgb"].shape[0]
|
||||
loss_mask = torch.ones_like(mask, device=self.device)
|
||||
for t in range(1, horizon):
|
||||
loss_mask[t] = loss_mask[t - 1] * (~done[t - 1])
|
||||
|
||||
self.optim.zero_grad(set_to_none=True)
|
||||
self.std = h.linear_schedule(self.cfg.std_schedule, step)
|
||||
self.model.train()
|
||||
|
||||
data_s = time.time() - start_time
|
||||
|
||||
# Compute targets
|
||||
with torch.no_grad():
|
||||
next_z = self.model.encode(next_obses)
|
||||
z_targets = self.model_target.encode(next_obses)
|
||||
td_targets = self._td_target(next_z, reward, mask)
|
||||
|
||||
# Latent rollout
|
||||
zs = torch.empty(horizon + 1, batch_size, self.cfg.latent_dim, device=self.device)
|
||||
reward_preds = torch.empty_like(reward, device=self.device)
|
||||
assert reward.shape[0] == horizon
|
||||
z = self.model.encode(obs)
|
||||
zs[0] = z
|
||||
value_info = {"Q": 0.0, "V": 0.0}
|
||||
for t in range(horizon):
|
||||
z, reward_pred = self.model.next(z, action[t])
|
||||
zs[t + 1] = z
|
||||
reward_preds[t] = reward_pred.squeeze(1)
|
||||
|
||||
with torch.no_grad():
|
||||
v_target = self.model_target.Q(zs[:-1].detach(), action, return_type="min")
|
||||
|
||||
# Predictions
|
||||
qs = self.model.Q(zs[:-1], action, return_type="all")
|
||||
qs = qs.squeeze(3)
|
||||
value_info["Q"] = qs.mean().item()
|
||||
v = self.model.V(zs[:-1])
|
||||
value_info["V"] = v.mean().item()
|
||||
|
||||
# Losses
|
||||
rho = torch.pow(self.cfg.rho, torch.arange(horizon, device=self.device)).view(-1, 1)
|
||||
consistency_loss = (rho * torch.mean(h.mse(zs[1:], z_targets), dim=2) * loss_mask).sum(dim=0)
|
||||
reward_loss = (rho * h.mse(reward_preds, reward) * loss_mask).sum(dim=0)
|
||||
q_value_loss, priority_loss = 0, 0
|
||||
for q in range(self.cfg.num_q):
|
||||
q_value_loss += (rho * h.mse(qs[q], td_targets) * loss_mask).sum(dim=0)
|
||||
priority_loss += (rho * h.l1(qs[q], td_targets) * loss_mask).sum(dim=0)
|
||||
|
||||
expectile = h.linear_schedule(self.cfg.expectile, step)
|
||||
v_value_loss = (rho * h.l2_expectile(v_target - v, expectile=expectile).squeeze(2) * loss_mask).sum(
|
||||
dim=0
|
||||
)
|
||||
|
||||
total_loss = (
|
||||
self.cfg.consistency_coef * consistency_loss
|
||||
+ self.cfg.reward_coef * reward_loss
|
||||
+ self.cfg.value_coef * q_value_loss
|
||||
+ self.cfg.value_coef * v_value_loss
|
||||
)
|
||||
|
||||
weighted_loss = (total_loss * weights).mean()
|
||||
weighted_loss.register_hook(lambda grad: grad * (1 / self.cfg.horizon))
|
||||
has_nan = torch.isnan(weighted_loss).item()
|
||||
if has_nan:
|
||||
print(f"weighted_loss has nan: {total_loss=} {weights=}")
|
||||
else:
|
||||
weighted_loss.backward()
|
||||
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
self.model.parameters(), self.cfg.grad_clip_norm, error_if_nonfinite=False
|
||||
)
|
||||
self.optim.step()
|
||||
|
||||
# TODO(rcadene): implement PrioritizedSampling by modifying sampler.weights with priorities computed by a criterion
|
||||
# if self.cfg.per:
|
||||
# # Update priorities
|
||||
# priorities = priority_loss.clamp(max=1e4).detach()
|
||||
# has_nan = torch.isnan(priorities).any().item()
|
||||
# if has_nan:
|
||||
# print(f"priorities has nan: {priorities=}")
|
||||
# else:
|
||||
# replay_buffer.update_priority(
|
||||
# idxs[:num_slices],
|
||||
# priorities[:num_slices],
|
||||
# )
|
||||
# if demo_batch_size > 0:
|
||||
# demo_buffer.update_priority(demo_idxs, priorities[num_slices:])
|
||||
|
||||
# Update policy + target network
|
||||
_, pi_update_info = self.update_pi(zs[:-1].detach(), acts=action)
|
||||
|
||||
if step % self.cfg.update_freq == 0:
|
||||
h.ema(self.model._encoder, self.model_target._encoder, self.cfg.tau)
|
||||
h.ema(self.model._Qs, self.model_target._Qs, self.cfg.tau)
|
||||
|
||||
self.model.eval()
|
||||
|
||||
info = {
|
||||
"consistency_loss": float(consistency_loss.mean().item()),
|
||||
"reward_loss": float(reward_loss.mean().item()),
|
||||
"Q_value_loss": float(q_value_loss.mean().item()),
|
||||
"V_value_loss": float(v_value_loss.mean().item()),
|
||||
"sum_loss": float(total_loss.mean().item()),
|
||||
"loss": float(weighted_loss.mean().item()),
|
||||
"grad_norm": float(grad_norm),
|
||||
"lr": self.cfg.lr,
|
||||
"data_s": data_s,
|
||||
"update_s": time.time() - start_time,
|
||||
}
|
||||
# info["demo_batch_size"] = demo_batch_size
|
||||
info["expectile"] = expectile
|
||||
info.update(value_info)
|
||||
info.update(pi_update_info)
|
||||
|
||||
self.step[0] = step
|
||||
return info
|
|
@ -9,31 +9,24 @@ hydra:
|
|||
job:
|
||||
name: default
|
||||
|
||||
seed: 1337
|
||||
# batch size for TorchRL SerialEnv. Each underlying env will get the seed = seed + env_index
|
||||
# NOTE: only diffusion policy supports rollout_batch_size > 1
|
||||
rollout_batch_size: 1
|
||||
device: cuda # cpu
|
||||
prefetch: 4
|
||||
eval_freq: ???
|
||||
save_freq: ???
|
||||
eval_episodes: ???
|
||||
save_video: false
|
||||
save_model: false
|
||||
save_buffer: false
|
||||
train_steps: ???
|
||||
fps: ???
|
||||
seed: ???
|
||||
dataset_repo_id: lerobot/pusht
|
||||
|
||||
offline_prioritized_sampler: true
|
||||
training:
|
||||
offline_steps: ???
|
||||
online_steps: ???
|
||||
online_steps_between_rollouts: ???
|
||||
online_sampling_ratio: 0.5
|
||||
eval_freq: ???
|
||||
save_freq: ???
|
||||
log_freq: 250
|
||||
save_model: false
|
||||
|
||||
dataset:
|
||||
repo_id: ???
|
||||
|
||||
n_action_steps: ???
|
||||
n_obs_steps: ???
|
||||
env: ???
|
||||
|
||||
policy: ???
|
||||
eval:
|
||||
n_episodes: 1
|
||||
# TODO(alexander-soare): Right now this does not work. Reinstate this.
|
||||
batch_size: 1
|
||||
|
||||
wandb:
|
||||
enable: true
|
||||
|
|
|
@ -1,18 +1,7 @@
|
|||
# @package _global_
|
||||
|
||||
eval_episodes: 50
|
||||
eval_freq: 7500
|
||||
save_freq: 75000
|
||||
log_freq: 250
|
||||
# TODO: same as xarm, need to adjust
|
||||
offline_steps: 25000
|
||||
online_steps: 25000
|
||||
|
||||
fps: 50
|
||||
|
||||
dataset:
|
||||
repo_id: lerobot/aloha_sim_insertion_human
|
||||
|
||||
env:
|
||||
name: aloha
|
||||
task: AlohaInsertion-v0
|
||||
|
|
|
@ -1,18 +1,7 @@
|
|||
# @package _global_
|
||||
|
||||
eval_episodes: 50
|
||||
eval_freq: 7500
|
||||
save_freq: 75000
|
||||
log_freq: 250
|
||||
# TODO: same as xarm, need to adjust
|
||||
offline_steps: 25000
|
||||
online_steps: 25000
|
||||
|
||||
fps: 10
|
||||
|
||||
dataset:
|
||||
repo_id: lerobot/pusht
|
||||
|
||||
env:
|
||||
name: pusht
|
||||
task: PushT-v0
|
||||
|
|
|
@ -1,17 +1,7 @@
|
|||
# @package _global_
|
||||
|
||||
eval_episodes: 20
|
||||
eval_freq: 1000
|
||||
save_freq: 10000
|
||||
log_freq: 50
|
||||
offline_steps: 25000
|
||||
online_steps: 25000
|
||||
|
||||
fps: 15
|
||||
|
||||
dataset:
|
||||
repo_id: lerobot/xarm_lift_medium
|
||||
|
||||
env:
|
||||
name: xarm
|
||||
task: XarmLift-v0
|
||||
|
|
|
@ -1,30 +1,41 @@
|
|||
# @package _global_
|
||||
|
||||
offline_steps: 80000
|
||||
online_steps: 0
|
||||
seed: 1000
|
||||
dataset_repo_id: lerobot/aloha_sim_insertion_human
|
||||
|
||||
eval_episodes: 1
|
||||
eval_freq: 10000
|
||||
save_freq: 100000
|
||||
log_freq: 250
|
||||
training:
|
||||
offline_steps: 80000
|
||||
online_steps: 0
|
||||
eval_freq: 10000
|
||||
save_freq: 100000
|
||||
log_freq: 250
|
||||
save_model: true
|
||||
|
||||
n_obs_steps: 1
|
||||
# when temporal_agg=False, n_action_steps=horizon
|
||||
batch_size: 8
|
||||
lr: 1e-5
|
||||
lr_backbone: 1e-5
|
||||
weight_decay: 1e-4
|
||||
grad_clip_norm: 10
|
||||
online_steps_between_rollouts: 1
|
||||
|
||||
override_dataset_stats:
|
||||
observation.images.top:
|
||||
# stats from imagenet, since we use a pretrained vision model
|
||||
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||
override_dataset_stats:
|
||||
observation.images.top:
|
||||
# stats from imagenet, since we use a pretrained vision model
|
||||
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||
|
||||
delta_timestamps:
|
||||
action: "[i / ${fps} for i in range(${policy.chunk_size})]"
|
||||
|
||||
eval:
|
||||
n_episodes:: 50
|
||||
|
||||
# See `configuration_act.py` for more details.
|
||||
policy:
|
||||
name: act
|
||||
|
||||
pretrained_model_path:
|
||||
|
||||
# Input / output structure.
|
||||
n_obs_steps: ${n_obs_steps}
|
||||
n_obs_steps: 1
|
||||
chunk_size: 100 # chunk_size
|
||||
n_action_steps: 100
|
||||
|
||||
|
@ -49,7 +60,7 @@ policy:
|
|||
replace_final_stride_with_dilation: false
|
||||
# Transformer layers.
|
||||
pre_norm: false
|
||||
d_model: 512
|
||||
dim_model: 512
|
||||
n_heads: 8
|
||||
dim_feedforward: 3200
|
||||
feedforward_activation: relu
|
||||
|
@ -66,15 +77,3 @@ policy:
|
|||
# Training and loss computation.
|
||||
dropout: 0.1
|
||||
kl_weight: 10.0
|
||||
|
||||
# ---
|
||||
# TODO(alexander-soare): Remove these from the policy config.
|
||||
batch_size: 8
|
||||
lr: 1e-5
|
||||
lr_backbone: 1e-5
|
||||
weight_decay: 1e-4
|
||||
grad_clip_norm: 10
|
||||
utd: 1
|
||||
|
||||
delta_timestamps:
|
||||
action: "[i / ${fps} for i in range(${policy.chunk_size})]"
|
||||
|
|
|
@ -1,22 +1,33 @@
|
|||
# @package _global_
|
||||
|
||||
seed: 100000
|
||||
horizon: 16
|
||||
n_obs_steps: 2
|
||||
n_action_steps: 8
|
||||
dataset_obs_steps: ${n_obs_steps}
|
||||
past_action_visible: False
|
||||
keypoint_visible_rate: 1.0
|
||||
dataset_repo_id: lerobot/pusht
|
||||
|
||||
eval_episodes: 50
|
||||
eval_freq: 5000
|
||||
save_freq: 5000
|
||||
log_freq: 250
|
||||
training:
|
||||
offline_steps: 200000
|
||||
online_steps: 0
|
||||
eval_freq: 5000
|
||||
save_freq: 5000
|
||||
log_freq: 250
|
||||
save_model: true
|
||||
|
||||
offline_steps: 200000
|
||||
online_steps: 0
|
||||
batch_size: 64
|
||||
grad_clip_norm: 10
|
||||
lr: 1.0e-4
|
||||
lr_scheduler: cosine
|
||||
lr_warmup_steps: 500
|
||||
adam_betas: [0.95, 0.999]
|
||||
adam_eps: 1.0e-8
|
||||
adam_weight_decay: 1.0e-6
|
||||
online_steps_between_rollouts: 1
|
||||
|
||||
offline_prioritized_sampler: true
|
||||
delta_timestamps:
|
||||
observation.image: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
|
||||
observation.state: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
|
||||
action: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1 - ${policy.n_obs_steps} + ${policy.horizon})]"
|
||||
|
||||
eval:
|
||||
n_episodes: 50
|
||||
|
||||
override_dataset_stats:
|
||||
# TODO(rcadene, alexander-soare): should we remove image stats as well? do we use a pretrained vision model?
|
||||
|
@ -35,12 +46,10 @@ override_dataset_stats:
|
|||
policy:
|
||||
name: diffusion
|
||||
|
||||
pretrained_model_path:
|
||||
|
||||
# Input / output structure.
|
||||
n_obs_steps: ${n_obs_steps}
|
||||
horizon: ${horizon}
|
||||
n_action_steps: ${n_action_steps}
|
||||
n_obs_steps: 2
|
||||
horizon: 16
|
||||
n_action_steps: 8
|
||||
|
||||
input_shapes:
|
||||
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
|
||||
|
@ -84,23 +93,9 @@ policy:
|
|||
|
||||
# ---
|
||||
# TODO(alexander-soare): Remove these from the policy config.
|
||||
batch_size: 64
|
||||
grad_clip_norm: 10
|
||||
lr: 1.0e-4
|
||||
lr_scheduler: cosine
|
||||
lr_warmup_steps: 500
|
||||
adam_betas: [0.95, 0.999]
|
||||
adam_eps: 1.0e-8
|
||||
adam_weight_decay: 1.0e-6
|
||||
utd: 1
|
||||
use_ema: true
|
||||
ema_update_after_step: 0
|
||||
ema_min_alpha: 0.0
|
||||
ema_max_alpha: 0.9999
|
||||
ema_inv_gamma: 1.0
|
||||
ema_power: 0.75
|
||||
|
||||
delta_timestamps:
|
||||
observation.image: "[i / ${fps} for i in range(1 - ${n_obs_steps}, 1)]"
|
||||
observation.state: "[i / ${fps} for i in range(1 - ${n_obs_steps}, 1)]"
|
||||
action: "[i / ${fps} for i in range(1 - ${n_obs_steps}, 1 - ${n_obs_steps} + ${policy.horizon})]"
|
||||
|
|
|
@ -1,85 +1,76 @@
|
|||
# @package _global_
|
||||
|
||||
n_action_steps: 2
|
||||
n_obs_steps: 1
|
||||
seed: 1
|
||||
|
||||
training:
|
||||
offline_steps: 25000
|
||||
online_steps: 25000
|
||||
eval_freq: 5000
|
||||
online_steps_between_rollouts: 1
|
||||
online_sampling_ratio: 0.5
|
||||
|
||||
batch_size: 256
|
||||
grad_clip_norm: 10.0
|
||||
lr: 3e-4
|
||||
|
||||
delta_timestamps:
|
||||
observation.image: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
|
||||
observation.state: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
|
||||
action: "[i / ${fps} for i in range(${policy.horizon})]"
|
||||
next.reward: "[i / ${fps} for i in range(${policy.horizon})]"
|
||||
|
||||
policy:
|
||||
name: tdmpc
|
||||
|
||||
reward_scale: 1.0
|
||||
pretrained_model_path:
|
||||
|
||||
episode_length: ${env.episode_length}
|
||||
discount: 0.9
|
||||
modality: 'all'
|
||||
|
||||
# pixels
|
||||
frame_stack: 1
|
||||
num_channels: 32
|
||||
img_size: ${env.image_size}
|
||||
state_dim: ${env.action_dim}
|
||||
action_dim: ${env.action_dim}
|
||||
|
||||
# planning
|
||||
mpc: true
|
||||
iterations: 6
|
||||
num_samples: 512
|
||||
num_elites: 50
|
||||
mixture_coef: 0.1
|
||||
min_std: 0.05
|
||||
max_std: 2.0
|
||||
temperature: 0.5
|
||||
momentum: 0.1
|
||||
uncertainty_cost: 1
|
||||
|
||||
# actor
|
||||
log_std_min: -10
|
||||
log_std_max: 2
|
||||
|
||||
# learning
|
||||
batch_size: 256
|
||||
max_buffer_size: 10000
|
||||
# Input / output structure.
|
||||
n_action_repeats: 2
|
||||
horizon: 5
|
||||
reward_coef: 0.5
|
||||
value_coef: 0.1
|
||||
consistency_coef: 20
|
||||
rho: 0.5
|
||||
kappa: 0.1
|
||||
lr: 3e-4
|
||||
std_schedule: ${policy.min_std}
|
||||
horizon_schedule: ${policy.horizon}
|
||||
per: true
|
||||
per_alpha: 0.6
|
||||
per_beta: 0.4
|
||||
grad_clip_norm: 10
|
||||
seed_steps: 0
|
||||
update_freq: 2
|
||||
tau: 0.01
|
||||
utd: 1
|
||||
|
||||
# offline rl
|
||||
# dataset_dir: ???
|
||||
data_first_percent: 1.0
|
||||
is_data_clip: true
|
||||
data_clip_eps: 1e-5
|
||||
expectile: 0.9
|
||||
A_scaling: 3.0
|
||||
input_shapes:
|
||||
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
|
||||
observation.image: [3, 84, 84]
|
||||
observation.state: ["${env.state_dim}"]
|
||||
output_shapes:
|
||||
action: ["${env.action_dim}"]
|
||||
|
||||
# offline->online
|
||||
offline_steps: ${offline_steps}
|
||||
pretrained_model_path: ""
|
||||
# pretrained_model_path: "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/offline.pt"
|
||||
# pretrained_model_path: "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/final.pt"
|
||||
balanced_sampling: true
|
||||
demo_schedule: 0.5
|
||||
# Normalization / Unnormalization
|
||||
input_normalization_modes: null
|
||||
output_normalization_modes:
|
||||
action: min_max
|
||||
|
||||
# architecture
|
||||
enc_dim: 256
|
||||
num_q: 5
|
||||
mlp_dim: 512
|
||||
# Architecture / modeling.
|
||||
# Neural networks.
|
||||
image_encoder_hidden_dim: 32
|
||||
state_encoder_hidden_dim: 256
|
||||
latent_dim: 50
|
||||
q_ensemble_size: 5
|
||||
mlp_dim: 512
|
||||
# Reinforcement learning.
|
||||
discount: 0.9
|
||||
|
||||
delta_timestamps:
|
||||
observation.image: "[i / ${fps} for i in range(6)]"
|
||||
observation.state: "[i / ${fps} for i in range(6)]"
|
||||
action: "[i / ${fps} for i in range(5)]"
|
||||
next.reward: "[i / ${fps} for i in range(5)]"
|
||||
# Inference.
|
||||
use_mpc: false
|
||||
cem_iterations: 6
|
||||
max_std: 2.0
|
||||
min_std: 0.05
|
||||
n_gaussian_samples: 512
|
||||
n_pi_samples: 51
|
||||
uncertainty_regularizer_coeff: 1.0
|
||||
n_elites: 50
|
||||
elite_weighting_temperature: 0.5
|
||||
gaussian_mean_momentum: 0.1
|
||||
|
||||
# Training and loss computation.
|
||||
max_random_shift_ratio: 0.0476
|
||||
# Loss coefficients.
|
||||
reward_coeff: 0.5
|
||||
expectile_weight: 0.9
|
||||
value_coeff: 0.1
|
||||
consistency_coeff: 20.0
|
||||
advantage_scaling: 3.0
|
||||
pi_coeff: 0.5
|
||||
temporal_decay_coeff: 0.5
|
||||
# Target model.
|
||||
target_model_momentum: 0.995
|
||||
|
|
|
@ -1,30 +1,29 @@
|
|||
"""Evaluate a policy on an environment by running rollouts and computing metrics.
|
||||
|
||||
The script may be run in one of two ways:
|
||||
Usage examples:
|
||||
|
||||
1. By providing the path to a config file with the --config argument.
|
||||
2. By providing a HuggingFace Hub ID with the --hub-id argument. You may also provide a revision number with the
|
||||
--revision argument.
|
||||
You want to evaluate a model from the hub (eg: https://huggingface.co/lerobot/diffusion_policy_pusht_image)
|
||||
for 10 episodes.
|
||||
|
||||
In either case, it is possible to override config arguments by adding a list of config.key=value arguments.
|
||||
```
|
||||
python lerobot/scripts/eval.py -p lerobot/diffusion_policy_pusht_image eval.n_episodes=10
|
||||
```
|
||||
|
||||
Examples:
|
||||
|
||||
You have a specific config file to go with trained model weights, and want to run 10 episodes.
|
||||
OR, you want to evaluate a model checkpoint from the LeRobot training script for 10 episodes.
|
||||
|
||||
```
|
||||
python lerobot/scripts/eval.py \
|
||||
--config PATH/TO/FOLDER/config.yaml \
|
||||
policy.pretrained_model_path=PATH/TO/FOLDER/weights.pth \
|
||||
eval_episodes=10
|
||||
-p outputs/train/diffusion_policy_pusht_image/checkpoints/005000 \
|
||||
eval.n_episodes=10
|
||||
```
|
||||
|
||||
You have a HuggingFace Hub ID, you know which revision you want, and want to run 10 episodes (note that in this case,
|
||||
you don't need to specify which weights to use):
|
||||
Note that in both examples, the repo/folder should contain at least `config.json`, `config.yaml` and
|
||||
`model.safetensors`.
|
||||
|
||||
```
|
||||
python lerobot/scripts/eval.py --hub-id HUB/ID --revision v1.0 eval_episodes=10
|
||||
```
|
||||
Note the formatting for providing the number of episodes. Generally, you may provide any number of arguments
|
||||
with `qualified.parameter.name=value`. In this case, the parameter eval.n_episodes appears as `n_episodes`
|
||||
nested under `eval` in the `config.yaml` found at
|
||||
https://huggingface.co/lerobot/diffusion_policy_pusht_image/tree/main.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
@ -42,9 +41,12 @@ import numpy as np
|
|||
import torch
|
||||
from datasets import Dataset, Features, Image, Sequence, Value
|
||||
from huggingface_hub import snapshot_download
|
||||
from huggingface_hub.utils._errors import RepositoryNotFoundError
|
||||
from huggingface_hub.utils._validators import HFValidationError
|
||||
from PIL import Image as PILImage
|
||||
from tqdm import trange
|
||||
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.datasets.utils import hf_transform_to_torch
|
||||
from lerobot.common.envs.factory import make_env
|
||||
from lerobot.common.envs.utils import postprocess_action, preprocess_observation
|
||||
|
@ -65,10 +67,10 @@ def eval_policy(
|
|||
"""
|
||||
set `return_episode_data` to return a Hugging Face dataset object in an "episodes" key of the return dict.
|
||||
"""
|
||||
policy.eval()
|
||||
|
||||
fps = env.unwrapped.metadata["render_fps"]
|
||||
|
||||
if policy is not None:
|
||||
policy.eval()
|
||||
device = "cpu" if policy is None else next(policy.parameters()).device
|
||||
|
||||
start = time.time()
|
||||
|
@ -130,7 +132,7 @@ def eval_policy(
|
|||
|
||||
# get the next action for the environment
|
||||
with torch.inference_mode():
|
||||
action = policy.select_action(observation, step=step)
|
||||
action = policy.select_action(observation)
|
||||
|
||||
# convert to cpu numpy
|
||||
action = postprocess_action(action)
|
||||
|
@ -349,26 +351,42 @@ def eval_policy(
|
|||
return info
|
||||
|
||||
|
||||
def eval(cfg: dict, out_dir=None):
|
||||
def eval(
|
||||
pretrained_policy_path: str | None = None,
|
||||
hydra_cfg_path: str | None = None,
|
||||
config_overrides: list[str] | None = None,
|
||||
):
|
||||
assert (pretrained_policy_path is None) ^ (hydra_cfg_path is None)
|
||||
if hydra_cfg_path is None:
|
||||
hydra_cfg = init_hydra_config(pretrained_policy_path / "config.yaml", config_overrides)
|
||||
else:
|
||||
hydra_cfg = init_hydra_config(hydra_cfg_path, config_overrides)
|
||||
out_dir = (
|
||||
f"outputs/eval/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{hydra_cfg.env.name}_{hydra_cfg.policy.name}"
|
||||
)
|
||||
|
||||
if out_dir is None:
|
||||
raise NotImplementedError()
|
||||
|
||||
init_logging()
|
||||
|
||||
# Check device is available
|
||||
get_safe_torch_device(cfg.device, log=True)
|
||||
get_safe_torch_device(hydra_cfg.device, log=True)
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
set_global_seed(cfg.seed)
|
||||
set_global_seed(hydra_cfg.seed)
|
||||
|
||||
log_output_dir(out_dir)
|
||||
|
||||
logging.info("Making environment.")
|
||||
env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)
|
||||
env = make_env(hydra_cfg, num_parallel_envs=hydra_cfg.eval.n_episodes)
|
||||
|
||||
logging.info("Making policy.")
|
||||
policy = make_policy(cfg)
|
||||
if hydra_cfg_path is None:
|
||||
policy = make_policy(hydra_cfg=hydra_cfg, pretrained_policy_name_or_path=pretrained_policy_path)
|
||||
else:
|
||||
# Note: We need the dataset stats to pass to the policy's normalization modules.
|
||||
policy = make_policy(hydra_cfg=hydra_cfg, dataset_stats=make_dataset(hydra_cfg).stats)
|
||||
policy.eval()
|
||||
|
||||
info = eval_policy(
|
||||
env,
|
||||
|
@ -376,7 +394,7 @@ def eval(cfg: dict, out_dir=None):
|
|||
max_episodes_rendered=10,
|
||||
video_dir=Path(out_dir) / "eval",
|
||||
return_episode_data=False,
|
||||
seed=cfg.seed,
|
||||
seed=hydra_cfg.seed,
|
||||
)
|
||||
print(info["aggregated"])
|
||||
|
||||
|
@ -390,13 +408,29 @@ def eval(cfg: dict, out_dir=None):
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
init_logging()
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
|
||||
)
|
||||
group = parser.add_mutually_exclusive_group(required=True)
|
||||
group.add_argument("--config", help="Path to a specific yaml config you want to use.")
|
||||
group.add_argument("--hub-id", help="HuggingFace Hub ID for a pretrained model.")
|
||||
parser.add_argument("--revision", help="Optionally provide the HuggingFace Hub revision ID.")
|
||||
group.add_argument(
|
||||
"-p",
|
||||
"--pretrained-policy-name-or-path",
|
||||
help=(
|
||||
"Either the repo ID of a model hosted on the Hub or a path to a directory containing weights "
|
||||
"saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch "
|
||||
"(useful for debugging). This argument is mutually exclusive with `--config`."
|
||||
),
|
||||
)
|
||||
group.add_argument(
|
||||
"--config",
|
||||
help=(
|
||||
"Path to a yaml config you want to use for initializing a policy from scratch (useful for "
|
||||
"debugging). This argument is mutually exclusive with `--pretrained-policy-name-or-path` (`-p`)."
|
||||
),
|
||||
)
|
||||
parser.add_argument("--revision", help="Optionally provide the Hugging Face Hub revision ID.")
|
||||
parser.add_argument(
|
||||
"overrides",
|
||||
nargs="*",
|
||||
|
@ -404,16 +438,28 @@ if __name__ == "__main__":
|
|||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.config is not None:
|
||||
# Note: For the config_path, Hydra wants a path relative to this script file.
|
||||
cfg = init_hydra_config(args.config, args.overrides)
|
||||
elif args.hub_id is not None:
|
||||
folder = Path(snapshot_download(args.hub_id, revision=args.revision))
|
||||
cfg = init_hydra_config(
|
||||
folder / "config.yaml", [f"policy.pretrained_model_path={folder / 'model.pt'}", *args.overrides]
|
||||
)
|
||||
if args.pretrained_policy_name_or_path is None:
|
||||
eval(hydra_cfg_path=args.config, config_overrides=args.overrides)
|
||||
else:
|
||||
try:
|
||||
pretrained_policy_path = Path(
|
||||
snapshot_download(args.pretrained_policy_name_or_path, revision=args.revision)
|
||||
)
|
||||
except HFValidationError:
|
||||
logging.warning(
|
||||
"The provided pretrained_policy_name_or_path is not a valid Hugging Face Hub repo ID. "
|
||||
"Treating it as a local directory."
|
||||
)
|
||||
except RepositoryNotFoundError:
|
||||
logging.warning(
|
||||
"The provided pretrained_policy_name_or_path was not found on the Hugging Face Hub. Treating "
|
||||
"it as a local directory."
|
||||
)
|
||||
pretrained_policy_path = Path(args.pretrained_policy_name_or_path)
|
||||
if not pretrained_policy_path.is_dir() or not pretrained_policy_path.exists():
|
||||
raise ValueError(
|
||||
"The provided pretrained_policy_name_or_path is not a valid/existing Hugging Face Hub "
|
||||
"repo ID, nor is it an existing local directory."
|
||||
)
|
||||
|
||||
eval(
|
||||
cfg,
|
||||
out_dir=f"outputs/eval/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{cfg.env.name}_{cfg.policy.name}",
|
||||
)
|
||||
eval(pretrained_policy_path=pretrained_policy_path, config_overrides=args.overrides)
|
||||
|
|
|
@ -15,6 +15,7 @@ from lerobot.common.datasets.utils import cycle
|
|||
from lerobot.common.envs.factory import make_env
|
||||
from lerobot.common.logger import Logger, log_output_dir
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.policies.policy_protocol import PolicyWithUpdate
|
||||
from lerobot.common.utils.utils import (
|
||||
format_big_number,
|
||||
get_safe_torch_device,
|
||||
|
@ -39,12 +40,17 @@ def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
|
|||
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
if lr_scheduler is not None:
|
||||
lr_scheduler.step()
|
||||
|
||||
if hasattr(policy, "ema") and policy.ema is not None:
|
||||
policy.ema.step(policy.diffusion)
|
||||
|
||||
if isinstance(policy, PolicyWithUpdate):
|
||||
# To possibly update an internal buffer (for instance an Exponential Moving Average like in TDMPC).
|
||||
policy.update()
|
||||
|
||||
info = {
|
||||
"loss": loss.item(),
|
||||
"grad_norm": float(grad_norm),
|
||||
|
@ -81,7 +87,7 @@ def log_train_info(logger, info, step, cfg, dataset, is_offline):
|
|||
|
||||
# A sample is an (observation,action) pair, where observation and action
|
||||
# can be on multiple timestamps. In a batch, we have `batch_size`` number of samples.
|
||||
num_samples = (step + 1) * cfg.policy.batch_size
|
||||
num_samples = (step + 1) * cfg.training.batch_size
|
||||
avg_samples_per_ep = dataset.num_samples / dataset.num_episodes
|
||||
num_episodes = num_samples / avg_samples_per_ep
|
||||
num_epochs = num_samples / dataset.num_samples
|
||||
|
@ -117,7 +123,7 @@ def log_eval_info(logger, info, step, cfg, dataset, is_offline):
|
|||
|
||||
# A sample is an (observation,action) pair, where observation and action
|
||||
# can be on multiple timestamps. In a batch, we have `batch_size`` number of samples.
|
||||
num_samples = (step + 1) * cfg.policy.batch_size
|
||||
num_samples = (step + 1) * cfg.training.batch_size
|
||||
avg_samples_per_ep = dataset.num_samples / dataset.num_episodes
|
||||
num_episodes = num_samples / avg_samples_per_ep
|
||||
num_epochs = num_samples / dataset.num_samples
|
||||
|
@ -246,11 +252,12 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
|||
raise NotImplementedError()
|
||||
if job_name is None:
|
||||
raise NotImplementedError()
|
||||
if cfg.online_steps > 0:
|
||||
assert cfg.rollout_batch_size == 1, "rollout_batch_size > 1 not supported for online training steps"
|
||||
|
||||
init_logging()
|
||||
|
||||
if cfg.training.online_steps > 0 and cfg.eval.batch_size > 1:
|
||||
logging.warning("eval.batch_size > 1 not supported for online training steps")
|
||||
|
||||
# Check device is available
|
||||
get_safe_torch_device(cfg.device, log=True)
|
||||
|
||||
|
@ -262,10 +269,10 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
|||
offline_dataset = make_dataset(cfg)
|
||||
|
||||
logging.info("make_env")
|
||||
env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)
|
||||
env = make_env(cfg, num_parallel_envs=cfg.eval.n_episodes)
|
||||
|
||||
logging.info("make_policy")
|
||||
policy = make_policy(cfg, dataset_stats=offline_dataset.stats)
|
||||
policy = make_policy(hydra_cfg=cfg, dataset_stats=offline_dataset.stats)
|
||||
|
||||
# Create optimizer and scheduler
|
||||
# Temporary hack to move optimizer out of policy
|
||||
|
@ -282,34 +289,33 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
|||
"params": [
|
||||
p for n, p in policy.named_parameters() if n.startswith("backbone") and p.requires_grad
|
||||
],
|
||||
"lr": cfg.policy.lr_backbone,
|
||||
"lr": cfg.training.lr_backbone,
|
||||
},
|
||||
]
|
||||
optimizer = torch.optim.AdamW(
|
||||
optimizer_params_dicts, lr=cfg.policy.lr, weight_decay=cfg.policy.weight_decay
|
||||
optimizer_params_dicts, lr=cfg.training.lr, weight_decay=cfg.training.weight_decay
|
||||
)
|
||||
lr_scheduler = None
|
||||
elif cfg.policy.name == "diffusion":
|
||||
optimizer = torch.optim.Adam(
|
||||
policy.diffusion.parameters(),
|
||||
cfg.policy.lr,
|
||||
cfg.policy.adam_betas,
|
||||
cfg.policy.adam_eps,
|
||||
cfg.policy.adam_weight_decay,
|
||||
cfg.training.lr,
|
||||
cfg.training.adam_betas,
|
||||
cfg.training.adam_eps,
|
||||
cfg.training.adam_weight_decay,
|
||||
)
|
||||
# TODO(rcadene): modify lr scheduler so that it doesn't depend on epochs but steps
|
||||
# configure lr scheduler
|
||||
assert cfg.training.online_steps == 0, "Diffusion Policy does not handle online training."
|
||||
lr_scheduler = get_scheduler(
|
||||
cfg.policy.lr_scheduler,
|
||||
cfg.training.lr_scheduler,
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=cfg.policy.lr_warmup_steps,
|
||||
num_training_steps=cfg.offline_steps,
|
||||
# pytorch assumes stepping LRScheduler every epoch
|
||||
# however huggingface diffusers steps it every batch
|
||||
last_epoch=-1,
|
||||
num_warmup_steps=cfg.training.lr_warmup_steps,
|
||||
num_training_steps=cfg.training.offline_steps,
|
||||
)
|
||||
elif policy.name == "tdmpc":
|
||||
raise NotImplementedError("TD-MPC not implemented yet.")
|
||||
optimizer = torch.optim.Adam(policy.parameters(), cfg.training.lr)
|
||||
lr_scheduler = None
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
|
||||
num_total_params = sum(p.numel() for p in policy.parameters())
|
||||
|
@ -319,8 +325,8 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
|||
|
||||
log_output_dir(out_dir)
|
||||
logging.info(f"{cfg.env.task=}")
|
||||
logging.info(f"{cfg.offline_steps=} ({format_big_number(cfg.offline_steps)})")
|
||||
logging.info(f"{cfg.online_steps=}")
|
||||
logging.info(f"{cfg.training.offline_steps=} ({format_big_number(cfg.training.offline_steps)})")
|
||||
logging.info(f"{cfg.training.online_steps=}")
|
||||
logging.info(f"{offline_dataset.num_samples=} ({format_big_number(offline_dataset.num_samples)})")
|
||||
logging.info(f"{offline_dataset.num_episodes=}")
|
||||
logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
|
||||
|
@ -328,7 +334,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
|||
|
||||
# Note: this helper will be used in offline and online training loops.
|
||||
def _maybe_eval_and_maybe_save(step):
|
||||
if step % cfg.eval_freq == 0:
|
||||
if step % cfg.training.eval_freq == 0:
|
||||
logging.info(f"Eval policy at step {step}")
|
||||
eval_info = eval_policy(
|
||||
env,
|
||||
|
@ -342,37 +348,44 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
|||
logger.log_video(eval_info["videos"][0], step, mode="eval")
|
||||
logging.info("Resume training")
|
||||
|
||||
if cfg.save_model and step % cfg.save_freq == 0:
|
||||
if cfg.training.save_model and step % cfg.training.save_freq == 0:
|
||||
logging.info(f"Checkpoint policy after step {step}")
|
||||
logger.save_model(policy, identifier=step)
|
||||
# Note: Save with step as the identifier, and format it to have at least 6 digits but more if
|
||||
# needed (choose 6 as a minimum for consistency without being overkill).
|
||||
logger.save_model(
|
||||
policy,
|
||||
identifier=str(step).zfill(
|
||||
max(6, len(str(cfg.training.offline_steps + cfg.training.online_steps)))
|
||||
),
|
||||
)
|
||||
logging.info("Resume training")
|
||||
|
||||
# create dataloader for offline training
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
offline_dataset,
|
||||
num_workers=8,
|
||||
batch_size=cfg.policy.batch_size,
|
||||
num_workers=4,
|
||||
batch_size=cfg.training.batch_size,
|
||||
shuffle=True,
|
||||
pin_memory=cfg.device != "cpu",
|
||||
drop_last=False,
|
||||
)
|
||||
dl_iter = cycle(dataloader)
|
||||
|
||||
policy.train()
|
||||
step = 0 # number of policy update (forward + backward + optim)
|
||||
is_offline = True
|
||||
for offline_step in range(cfg.offline_steps):
|
||||
for offline_step in range(cfg.training.offline_steps):
|
||||
if offline_step == 0:
|
||||
logging.info("Start offline training on a fixed dataset")
|
||||
policy.train()
|
||||
batch = next(dl_iter)
|
||||
|
||||
for key in batch:
|
||||
batch[key] = batch[key].to(cfg.device, non_blocking=True)
|
||||
|
||||
train_info = update_policy(policy, batch, optimizer, cfg.policy.grad_clip_norm, lr_scheduler)
|
||||
train_info = update_policy(policy, batch, optimizer, cfg.training.grad_clip_norm, lr_scheduler)
|
||||
|
||||
# TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
|
||||
if step % cfg.log_freq == 0:
|
||||
if step % cfg.training.log_freq == 0:
|
||||
log_train_info(logger, train_info, step, cfg, offline_dataset, is_offline)
|
||||
|
||||
# Note: _maybe_eval_and_maybe_save happens **after** the `step`th training update has completed, so we pass in
|
||||
|
@ -398,7 +411,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
|||
dataloader = torch.utils.data.DataLoader(
|
||||
concat_dataset,
|
||||
num_workers=4,
|
||||
batch_size=cfg.policy.batch_size,
|
||||
batch_size=cfg.training.batch_size,
|
||||
sampler=sampler,
|
||||
pin_memory=cfg.device != "cpu",
|
||||
drop_last=False,
|
||||
|
@ -407,10 +420,11 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
|||
|
||||
online_step = 0
|
||||
is_offline = False
|
||||
for env_step in range(cfg.online_steps):
|
||||
for env_step in range(cfg.training.online_steps):
|
||||
if env_step == 0:
|
||||
logging.info("Start online training by interacting with environment")
|
||||
|
||||
policy.eval()
|
||||
with torch.no_grad():
|
||||
eval_info = eval_policy(
|
||||
rollout_env,
|
||||
|
@ -419,25 +433,25 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
|||
seed=cfg.seed,
|
||||
)
|
||||
|
||||
add_episodes_inplace(
|
||||
online_dataset,
|
||||
concat_dataset,
|
||||
sampler,
|
||||
hf_dataset=eval_info["episodes"]["hf_dataset"],
|
||||
episode_data_index=eval_info["episodes"]["episode_data_index"],
|
||||
pc_online_samples=cfg.get("demo_schedule", 0.5),
|
||||
)
|
||||
add_episodes_inplace(
|
||||
online_dataset,
|
||||
concat_dataset,
|
||||
sampler,
|
||||
hf_dataset=eval_info["episodes"]["hf_dataset"],
|
||||
episode_data_index=eval_info["episodes"]["episode_data_index"],
|
||||
pc_online_samples=cfg.training.online_sampling_ratio,
|
||||
)
|
||||
|
||||
for _ in range(cfg.policy.utd):
|
||||
policy.train()
|
||||
policy.train()
|
||||
for _ in range(cfg.training.online_steps_between_rollouts):
|
||||
batch = next(dl_iter)
|
||||
|
||||
for key in batch:
|
||||
batch[key] = batch[key].to(cfg.device, non_blocking=True)
|
||||
|
||||
train_info = update_policy(policy, batch, optimizer, cfg.policy.grad_clip_norm, lr_scheduler)
|
||||
train_info = update_policy(policy, batch, optimizer, cfg.training.grad_clip_norm, lr_scheduler)
|
||||
|
||||
if step % cfg.log_freq == 0:
|
||||
if step % cfg.training.log_freq == 0:
|
||||
log_train_info(logger, train_info, step, cfg, online_dataset, is_offline)
|
||||
|
||||
# Note: _maybe_eval_and_maybe_save happens **after** the `step`th training update has completed, so we pass
|
||||
|
|
|
@ -837,13 +837,13 @@ files = [
|
|||
|
||||
[[package]]
|
||||
name = "filelock"
|
||||
version = "3.13.4"
|
||||
version = "3.14.0"
|
||||
description = "A platform independent file lock."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "filelock-3.13.4-py3-none-any.whl", hash = "sha256:404e5e9253aa60ad457cae1be07c0f0ca90a63931200a47d9b6a6af84fd7b45f"},
|
||||
{file = "filelock-3.13.4.tar.gz", hash = "sha256:d13f466618bfde72bd2c18255e269f72542c6e70e7bac83a0232d6b1cc5c8cf4"},
|
||||
{file = "filelock-3.14.0-py3-none-any.whl", hash = "sha256:43339835842f110ca7ae60f1e1c160714c5a6afd15a2873419ab185334975c0f"},
|
||||
{file = "filelock-3.14.0.tar.gz", hash = "sha256:6ea72da3be9b8c82afd3edcf99f2fffbb5076335a5ae4d03248bb5b6c3eae78a"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
|
@ -1050,69 +1050,61 @@ preview = ["glfw-preview"]
|
|||
|
||||
[[package]]
|
||||
name = "grpcio"
|
||||
version = "1.62.2"
|
||||
version = "1.63.0"
|
||||
description = "HTTP/2-based RPC framework"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "grpcio-1.62.2-cp310-cp310-linux_armv7l.whl", hash = "sha256:66344ea741124c38588a664237ac2fa16dfd226964cca23ddc96bd4accccbde5"},
|
||||
{file = "grpcio-1.62.2-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:5dab7ac2c1e7cb6179c6bfad6b63174851102cbe0682294e6b1d6f0981ad7138"},
|
||||
{file = "grpcio-1.62.2-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:3ad00f3f0718894749d5a8bb0fa125a7980a2f49523731a9b1fabf2b3522aa43"},
|
||||
{file = "grpcio-1.62.2-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2e72ddfee62430ea80133d2cbe788e0d06b12f865765cb24a40009668bd8ea05"},
|
||||
{file = "grpcio-1.62.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:53d3a59a10af4c2558a8e563aed9f256259d2992ae0d3037817b2155f0341de1"},
|
||||
{file = "grpcio-1.62.2-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:a1511a303f8074f67af4119275b4f954189e8313541da7b88b1b3a71425cdb10"},
|
||||
{file = "grpcio-1.62.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:b94d41b7412ef149743fbc3178e59d95228a7064c5ab4760ae82b562bdffb199"},
|
||||
{file = "grpcio-1.62.2-cp310-cp310-win32.whl", hash = "sha256:a75af2fc7cb1fe25785be7bed1ab18cef959a376cdae7c6870184307614caa3f"},
|
||||
{file = "grpcio-1.62.2-cp310-cp310-win_amd64.whl", hash = "sha256:80407bc007754f108dc2061e37480238b0dc1952c855e86a4fc283501ee6bb5d"},
|
||||
{file = "grpcio-1.62.2-cp311-cp311-linux_armv7l.whl", hash = "sha256:c1624aa686d4b36790ed1c2e2306cc3498778dffaf7b8dd47066cf819028c3ad"},
|
||||
{file = "grpcio-1.62.2-cp311-cp311-macosx_10_10_universal2.whl", hash = "sha256:1c1bb80299bdef33309dff03932264636450c8fdb142ea39f47e06a7153d3063"},
|
||||
{file = "grpcio-1.62.2-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:db068bbc9b1fa16479a82e1ecf172a93874540cb84be69f0b9cb9b7ac3c82670"},
|
||||
{file = "grpcio-1.62.2-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e2cc8a308780edbe2c4913d6a49dbdb5befacdf72d489a368566be44cadaef1a"},
|
||||
{file = "grpcio-1.62.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d0695ae31a89f1a8fc8256050329a91a9995b549a88619263a594ca31b76d756"},
|
||||
{file = "grpcio-1.62.2-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:88b4f9ee77191dcdd8810241e89340a12cbe050be3e0d5f2f091c15571cd3930"},
|
||||
{file = "grpcio-1.62.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:2a0204532aa2f1afd467024b02b4069246320405bc18abec7babab03e2644e75"},
|
||||
{file = "grpcio-1.62.2-cp311-cp311-win32.whl", hash = "sha256:6e784f60e575a0de554ef9251cbc2ceb8790914fe324f11e28450047f264ee6f"},
|
||||
{file = "grpcio-1.62.2-cp311-cp311-win_amd64.whl", hash = "sha256:112eaa7865dd9e6d7c0556c8b04ae3c3a2dc35d62ad3373ab7f6a562d8199200"},
|
||||
{file = "grpcio-1.62.2-cp312-cp312-linux_armv7l.whl", hash = "sha256:65034473fc09628a02fb85f26e73885cf1ed39ebd9cf270247b38689ff5942c5"},
|
||||
{file = "grpcio-1.62.2-cp312-cp312-macosx_10_10_universal2.whl", hash = "sha256:d2c1771d0ee3cf72d69bb5e82c6a82f27fbd504c8c782575eddb7839729fbaad"},
|
||||
{file = "grpcio-1.62.2-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:3abe6838196da518863b5d549938ce3159d809218936851b395b09cad9b5d64a"},
|
||||
{file = "grpcio-1.62.2-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c5ffeb269f10cedb4f33142b89a061acda9f672fd1357331dbfd043422c94e9e"},
|
||||
{file = "grpcio-1.62.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:404d3b4b6b142b99ba1cff0b2177d26b623101ea2ce51c25ef6e53d9d0d87bcc"},
|
||||
{file = "grpcio-1.62.2-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:262cda97efdabb20853d3b5a4c546a535347c14b64c017f628ca0cc7fa780cc6"},
|
||||
{file = "grpcio-1.62.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:17708db5b11b966373e21519c4c73e5a750555f02fde82276ea2a267077c68ad"},
|
||||
{file = "grpcio-1.62.2-cp312-cp312-win32.whl", hash = "sha256:b7ec9e2f8ffc8436f6b642a10019fc513722858f295f7efc28de135d336ac189"},
|
||||
{file = "grpcio-1.62.2-cp312-cp312-win_amd64.whl", hash = "sha256:aa787b83a3cd5e482e5c79be030e2b4a122ecc6c5c6c4c42a023a2b581fdf17b"},
|
||||
{file = "grpcio-1.62.2-cp37-cp37m-linux_armv7l.whl", hash = "sha256:cfd23ad29bfa13fd4188433b0e250f84ec2c8ba66b14a9877e8bce05b524cf54"},
|
||||
{file = "grpcio-1.62.2-cp37-cp37m-macosx_10_10_universal2.whl", hash = "sha256:af15e9efa4d776dfcecd1d083f3ccfb04f876d613e90ef8432432efbeeac689d"},
|
||||
{file = "grpcio-1.62.2-cp37-cp37m-manylinux_2_17_aarch64.whl", hash = "sha256:f4aa94361bb5141a45ca9187464ae81a92a2a135ce2800b2203134f7a1a1d479"},
|
||||
{file = "grpcio-1.62.2-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:82af3613a219512a28ee5c95578eb38d44dd03bca02fd918aa05603c41018051"},
|
||||
{file = "grpcio-1.62.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:55ddaf53474e8caeb29eb03e3202f9d827ad3110475a21245f3c7712022882a9"},
|
||||
{file = "grpcio-1.62.2-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:c79b518c56dddeec79e5500a53d8a4db90da995dfe1738c3ac57fe46348be049"},
|
||||
{file = "grpcio-1.62.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:a5eb4844e5e60bf2c446ef38c5b40d7752c6effdee882f716eb57ae87255d20a"},
|
||||
{file = "grpcio-1.62.2-cp37-cp37m-win_amd64.whl", hash = "sha256:aaae70364a2d1fb238afd6cc9fcb10442b66e397fd559d3f0968d28cc3ac929c"},
|
||||
{file = "grpcio-1.62.2-cp38-cp38-linux_armv7l.whl", hash = "sha256:1bcfe5070e4406f489e39325b76caeadab28c32bf9252d3ae960c79935a4cc36"},
|
||||
{file = "grpcio-1.62.2-cp38-cp38-macosx_10_10_universal2.whl", hash = "sha256:da6a7b6b938c15fa0f0568e482efaae9c3af31963eec2da4ff13a6d8ec2888e4"},
|
||||
{file = "grpcio-1.62.2-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:41955b641c34db7d84db8d306937b72bc4968eef1c401bea73081a8d6c3d8033"},
|
||||
{file = "grpcio-1.62.2-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c772f225483905f675cb36a025969eef9712f4698364ecd3a63093760deea1bc"},
|
||||
{file = "grpcio-1.62.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:07ce1f775d37ca18c7a141300e5b71539690efa1f51fe17f812ca85b5e73262f"},
|
||||
{file = "grpcio-1.62.2-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:26f415f40f4a93579fd648f48dca1c13dfacdfd0290f4a30f9b9aeb745026811"},
|
||||
{file = "grpcio-1.62.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:db707e3685ff16fc1eccad68527d072ac8bdd2e390f6daa97bc394ea7de4acea"},
|
||||
{file = "grpcio-1.62.2-cp38-cp38-win32.whl", hash = "sha256:589ea8e75de5fd6df387de53af6c9189c5231e212b9aa306b6b0d4f07520fbb9"},
|
||||
{file = "grpcio-1.62.2-cp38-cp38-win_amd64.whl", hash = "sha256:3c3ed41f4d7a3aabf0f01ecc70d6b5d00ce1800d4af652a549de3f7cf35c4abd"},
|
||||
{file = "grpcio-1.62.2-cp39-cp39-linux_armv7l.whl", hash = "sha256:162ccf61499c893831b8437120600290a99c0bc1ce7b51f2c8d21ec87ff6af8b"},
|
||||
{file = "grpcio-1.62.2-cp39-cp39-macosx_10_10_universal2.whl", hash = "sha256:f27246d7da7d7e3bd8612f63785a7b0c39a244cf14b8dd9dd2f2fab939f2d7f1"},
|
||||
{file = "grpcio-1.62.2-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:2507006c8a478f19e99b6fe36a2464696b89d40d88f34e4b709abe57e1337467"},
|
||||
{file = "grpcio-1.62.2-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a90ac47a8ce934e2c8d71e317d2f9e7e6aaceb2d199de940ce2c2eb611b8c0f4"},
|
||||
{file = "grpcio-1.62.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:99701979bcaaa7de8d5f60476487c5df8f27483624f1f7e300ff4669ee44d1f2"},
|
||||
{file = "grpcio-1.62.2-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:af7dc3f7a44f10863b1b0ecab4078f0a00f561aae1edbd01fd03ad4dcf61c9e9"},
|
||||
{file = "grpcio-1.62.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:fa63245271920786f4cb44dcada4983a3516be8f470924528cf658731864c14b"},
|
||||
{file = "grpcio-1.62.2-cp39-cp39-win32.whl", hash = "sha256:c6ad9c39704256ed91a1cffc1379d63f7d0278d6a0bad06b0330f5d30291e3a3"},
|
||||
{file = "grpcio-1.62.2-cp39-cp39-win_amd64.whl", hash = "sha256:16da954692fd61aa4941fbeda405a756cd96b97b5d95ca58a92547bba2c1624f"},
|
||||
{file = "grpcio-1.62.2.tar.gz", hash = "sha256:c77618071d96b7a8be2c10701a98537823b9c65ba256c0b9067e0594cdbd954d"},
|
||||
{file = "grpcio-1.63.0-cp310-cp310-linux_armv7l.whl", hash = "sha256:2e93aca840c29d4ab5db93f94ed0a0ca899e241f2e8aec6334ab3575dc46125c"},
|
||||
{file = "grpcio-1.63.0-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:91b73d3f1340fefa1e1716c8c1ec9930c676d6b10a3513ab6c26004cb02d8b3f"},
|
||||
{file = "grpcio-1.63.0-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:b3afbd9d6827fa6f475a4f91db55e441113f6d3eb9b7ebb8fb806e5bb6d6bd0d"},
|
||||
{file = "grpcio-1.63.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8f3f6883ce54a7a5f47db43289a0a4c776487912de1a0e2cc83fdaec9685cc9f"},
|
||||
{file = "grpcio-1.63.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cf8dae9cc0412cb86c8de5a8f3be395c5119a370f3ce2e69c8b7d46bb9872c8d"},
|
||||
{file = "grpcio-1.63.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:08e1559fd3b3b4468486b26b0af64a3904a8dbc78d8d936af9c1cf9636eb3e8b"},
|
||||
{file = "grpcio-1.63.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:5c039ef01516039fa39da8a8a43a95b64e288f79f42a17e6c2904a02a319b357"},
|
||||
{file = "grpcio-1.63.0-cp310-cp310-win32.whl", hash = "sha256:ad2ac8903b2eae071055a927ef74121ed52d69468e91d9bcbd028bd0e554be6d"},
|
||||
{file = "grpcio-1.63.0-cp310-cp310-win_amd64.whl", hash = "sha256:b2e44f59316716532a993ca2966636df6fbe7be4ab6f099de6815570ebe4383a"},
|
||||
{file = "grpcio-1.63.0-cp311-cp311-linux_armv7l.whl", hash = "sha256:f28f8b2db7b86c77916829d64ab21ff49a9d8289ea1564a2b2a3a8ed9ffcccd3"},
|
||||
{file = "grpcio-1.63.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:65bf975639a1f93bee63ca60d2e4951f1b543f498d581869922910a476ead2f5"},
|
||||
{file = "grpcio-1.63.0-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:b5194775fec7dc3dbd6a935102bb156cd2c35efe1685b0a46c67b927c74f0cfb"},
|
||||
{file = "grpcio-1.63.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e4cbb2100ee46d024c45920d16e888ee5d3cf47c66e316210bc236d5bebc42b3"},
|
||||
{file = "grpcio-1.63.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1ff737cf29b5b801619f10e59b581869e32f400159e8b12d7a97e7e3bdeee6a2"},
|
||||
{file = "grpcio-1.63.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:cd1e68776262dd44dedd7381b1a0ad09d9930ffb405f737d64f505eb7f77d6c7"},
|
||||
{file = "grpcio-1.63.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:93f45f27f516548e23e4ec3fbab21b060416007dbe768a111fc4611464cc773f"},
|
||||
{file = "grpcio-1.63.0-cp311-cp311-win32.whl", hash = "sha256:878b1d88d0137df60e6b09b74cdb73db123f9579232c8456f53e9abc4f62eb3c"},
|
||||
{file = "grpcio-1.63.0-cp311-cp311-win_amd64.whl", hash = "sha256:756fed02dacd24e8f488f295a913f250b56b98fb793f41d5b2de6c44fb762434"},
|
||||
{file = "grpcio-1.63.0-cp312-cp312-linux_armv7l.whl", hash = "sha256:93a46794cc96c3a674cdfb59ef9ce84d46185fe9421baf2268ccb556f8f81f57"},
|
||||
{file = "grpcio-1.63.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:a7b19dfc74d0be7032ca1eda0ed545e582ee46cd65c162f9e9fc6b26ef827dc6"},
|
||||
{file = "grpcio-1.63.0-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:8064d986d3a64ba21e498b9a376cbc5d6ab2e8ab0e288d39f266f0fca169b90d"},
|
||||
{file = "grpcio-1.63.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:219bb1848cd2c90348c79ed0a6b0ea51866bc7e72fa6e205e459fedab5770172"},
|
||||
{file = "grpcio-1.63.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a2d60cd1d58817bc5985fae6168d8b5655c4981d448d0f5b6194bbcc038090d2"},
|
||||
{file = "grpcio-1.63.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:9e350cb096e5c67832e9b6e018cf8a0d2a53b2a958f6251615173165269a91b0"},
|
||||
{file = "grpcio-1.63.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:56cdf96ff82e3cc90dbe8bac260352993f23e8e256e063c327b6cf9c88daf7a9"},
|
||||
{file = "grpcio-1.63.0-cp312-cp312-win32.whl", hash = "sha256:3a6d1f9ea965e750db7b4ee6f9fdef5fdf135abe8a249e75d84b0a3e0c668a1b"},
|
||||
{file = "grpcio-1.63.0-cp312-cp312-win_amd64.whl", hash = "sha256:d2497769895bb03efe3187fb1888fc20e98a5f18b3d14b606167dacda5789434"},
|
||||
{file = "grpcio-1.63.0-cp38-cp38-linux_armv7l.whl", hash = "sha256:fdf348ae69c6ff484402cfdb14e18c1b0054ac2420079d575c53a60b9b2853ae"},
|
||||
{file = "grpcio-1.63.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:a3abfe0b0f6798dedd2e9e92e881d9acd0fdb62ae27dcbbfa7654a57e24060c0"},
|
||||
{file = "grpcio-1.63.0-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:6ef0ad92873672a2a3767cb827b64741c363ebaa27e7f21659e4e31f4d750280"},
|
||||
{file = "grpcio-1.63.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b416252ac5588d9dfb8a30a191451adbf534e9ce5f56bb02cd193f12d8845b7f"},
|
||||
{file = "grpcio-1.63.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e3b77eaefc74d7eb861d3ffbdf91b50a1bb1639514ebe764c47773b833fa2d91"},
|
||||
{file = "grpcio-1.63.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:b005292369d9c1f80bf70c1db1c17c6c342da7576f1c689e8eee4fb0c256af85"},
|
||||
{file = "grpcio-1.63.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:cdcda1156dcc41e042d1e899ba1f5c2e9f3cd7625b3d6ebfa619806a4c1aadda"},
|
||||
{file = "grpcio-1.63.0-cp38-cp38-win32.whl", hash = "sha256:01799e8649f9e94ba7db1aeb3452188048b0019dc37696b0f5ce212c87c560c3"},
|
||||
{file = "grpcio-1.63.0-cp38-cp38-win_amd64.whl", hash = "sha256:6a1a3642d76f887aa4009d92f71eb37809abceb3b7b5a1eec9c554a246f20e3a"},
|
||||
{file = "grpcio-1.63.0-cp39-cp39-linux_armv7l.whl", hash = "sha256:75f701ff645858a2b16bc8c9fc68af215a8bb2d5a9b647448129de6e85d52bce"},
|
||||
{file = "grpcio-1.63.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:cacdef0348a08e475a721967f48206a2254a1b26ee7637638d9e081761a5ba86"},
|
||||
{file = "grpcio-1.63.0-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:0697563d1d84d6985e40ec5ec596ff41b52abb3fd91ec240e8cb44a63b895094"},
|
||||
{file = "grpcio-1.63.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6426e1fb92d006e47476d42b8f240c1d916a6d4423c5258ccc5b105e43438f61"},
|
||||
{file = "grpcio-1.63.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e48cee31bc5f5a31fb2f3b573764bd563aaa5472342860edcc7039525b53e46a"},
|
||||
{file = "grpcio-1.63.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:50344663068041b34a992c19c600236e7abb42d6ec32567916b87b4c8b8833b3"},
|
||||
{file = "grpcio-1.63.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:259e11932230d70ef24a21b9fb5bb947eb4703f57865a404054400ee92f42f5d"},
|
||||
{file = "grpcio-1.63.0-cp39-cp39-win32.whl", hash = "sha256:a44624aad77bf8ca198c55af811fd28f2b3eaf0a50ec5b57b06c034416ef2d0a"},
|
||||
{file = "grpcio-1.63.0-cp39-cp39-win_amd64.whl", hash = "sha256:166e5c460e5d7d4656ff9e63b13e1f6029b122104c1633d5f37eaea348d7356d"},
|
||||
{file = "grpcio-1.63.0.tar.gz", hash = "sha256:f3023e14805c61bc439fb40ca545ac3d5740ce66120a678a3c6c2c55b70343d1"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
protobuf = ["grpcio-tools (>=1.62.2)"]
|
||||
protobuf = ["grpcio-tools (>=1.63.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "gym-aloha"
|
||||
|
@ -2414,7 +2406,6 @@ optional = false
|
|||
python-versions = ">=3.9"
|
||||
files = [
|
||||
{file = "pandas-2.2.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:90c6fca2acf139569e74e8781709dccb6fe25940488755716d1d354d6bc58bce"},
|
||||
{file = "pandas-2.2.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c7adfc142dac335d8c1e0dcbd37eb8617eac386596eb9e1a1b77791cf2498238"},
|
||||
{file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4abfe0be0d7221be4f12552995e58723c7422c80a659da13ca382697de830c08"},
|
||||
{file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8635c16bf3d99040fdf3ca3db669a7250ddf49c55dc4aa8fe0ae0fa8d6dcc1f0"},
|
||||
{file = "pandas-2.2.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:40ae1dffb3967a52203105a077415a86044a2bea011b5f321c6aa64b379a3f51"},
|
||||
|
@ -2435,7 +2426,6 @@ files = [
|
|||
{file = "pandas-2.2.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:43498c0bdb43d55cb162cdc8c06fac328ccb5d2eabe3cadeb3529ae6f0517c32"},
|
||||
{file = "pandas-2.2.2-cp312-cp312-win_amd64.whl", hash = "sha256:d187d355ecec3629624fccb01d104da7d7f391db0311145817525281e2804d23"},
|
||||
{file = "pandas-2.2.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:0ca6377b8fca51815f382bd0b697a0814c8bda55115678cbc94c30aacbb6eff2"},
|
||||
{file = "pandas-2.2.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9057e6aa78a584bc93a13f0a9bf7e753a5e9770a30b4d758b8d5f2a62a9433cd"},
|
||||
{file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:001910ad31abc7bf06f49dcc903755d2f7f3a9186c0c040b827e522e9cef0863"},
|
||||
{file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:66b479b0bd07204e37583c191535505410daa8df638fd8e75ae1b383851fe921"},
|
||||
{file = "pandas-2.2.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:a77e9d1c386196879aa5eb712e77461aaee433e54c68cf253053a73b7e49c33a"},
|
||||
|
@ -3110,104 +3100,90 @@ files = [
|
|||
|
||||
[[package]]
|
||||
name = "regex"
|
||||
version = "2024.4.16"
|
||||
version = "2024.4.28"
|
||||
description = "Alternative regular expression module, to replace re."
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "regex-2024.4.16-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:fb83cc090eac63c006871fd24db5e30a1f282faa46328572661c0a24a2323a08"},
|
||||
{file = "regex-2024.4.16-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8c91e1763696c0eb66340c4df98623c2d4e77d0746b8f8f2bee2c6883fd1fe18"},
|
||||
{file = "regex-2024.4.16-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:10188fe732dec829c7acca7422cdd1bf57d853c7199d5a9e96bb4d40db239c73"},
|
||||
{file = "regex-2024.4.16-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:956b58d692f235cfbf5b4f3abd6d99bf102f161ccfe20d2fd0904f51c72c4c66"},
|
||||
{file = "regex-2024.4.16-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a70b51f55fd954d1f194271695821dd62054d949efd6368d8be64edd37f55c86"},
|
||||
{file = "regex-2024.4.16-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5c02fcd2bf45162280613d2e4a1ca3ac558ff921ae4e308ecb307650d3a6ee51"},
|
||||
{file = "regex-2024.4.16-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c4ed75ea6892a56896d78f11006161eea52c45a14994794bcfa1654430984b22"},
|
||||
{file = "regex-2024.4.16-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bd727ad276bb91928879f3aa6396c9a1d34e5e180dce40578421a691eeb77f47"},
|
||||
{file = "regex-2024.4.16-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:7cbc5d9e8a1781e7be17da67b92580d6ce4dcef5819c1b1b89f49d9678cc278c"},
|
||||
{file = "regex-2024.4.16-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:78fddb22b9ef810b63ef341c9fcf6455232d97cfe03938cbc29e2672c436670e"},
|
||||
{file = "regex-2024.4.16-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:445ca8d3c5a01309633a0c9db57150312a181146315693273e35d936472df912"},
|
||||
{file = "regex-2024.4.16-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:95399831a206211d6bc40224af1c635cb8790ddd5c7493e0bd03b85711076a53"},
|
||||
{file = "regex-2024.4.16-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:7731728b6568fc286d86745f27f07266de49603a6fdc4d19c87e8c247be452af"},
|
||||
{file = "regex-2024.4.16-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:4facc913e10bdba42ec0aee76d029aedda628161a7ce4116b16680a0413f658a"},
|
||||
{file = "regex-2024.4.16-cp310-cp310-win32.whl", hash = "sha256:911742856ce98d879acbea33fcc03c1d8dc1106234c5e7d068932c945db209c0"},
|
||||
{file = "regex-2024.4.16-cp310-cp310-win_amd64.whl", hash = "sha256:e0a2df336d1135a0b3a67f3bbf78a75f69562c1199ed9935372b82215cddd6e2"},
|
||||
{file = "regex-2024.4.16-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:1210365faba7c2150451eb78ec5687871c796b0f1fa701bfd2a4a25420482d26"},
|
||||
{file = "regex-2024.4.16-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:9ab40412f8cd6f615bfedea40c8bf0407d41bf83b96f6fc9ff34976d6b7037fd"},
|
||||
{file = "regex-2024.4.16-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:fd80d1280d473500d8086d104962a82d77bfbf2b118053824b7be28cd5a79ea5"},
|
||||
{file = "regex-2024.4.16-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7bb966fdd9217e53abf824f437a5a2d643a38d4fd5fd0ca711b9da683d452969"},
|
||||
{file = "regex-2024.4.16-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:20b7a68444f536365af42a75ccecb7ab41a896a04acf58432db9e206f4e525d6"},
|
||||
{file = "regex-2024.4.16-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b74586dd0b039c62416034f811d7ee62810174bb70dffcca6439f5236249eb09"},
|
||||
{file = "regex-2024.4.16-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0c8290b44d8b0af4e77048646c10c6e3aa583c1ca67f3b5ffb6e06cf0c6f0f89"},
|
||||
{file = "regex-2024.4.16-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f2d80a6749724b37853ece57988b39c4e79d2b5fe2869a86e8aeae3bbeef9eb0"},
|
||||
{file = "regex-2024.4.16-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:3a1018e97aeb24e4f939afcd88211ace472ba566efc5bdf53fd8fd7f41fa7170"},
|
||||
{file = "regex-2024.4.16-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:8d015604ee6204e76569d2f44e5a210728fa917115bef0d102f4107e622b08d5"},
|
||||
{file = "regex-2024.4.16-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:3d5ac5234fb5053850d79dd8eb1015cb0d7d9ed951fa37aa9e6249a19aa4f336"},
|
||||
{file = "regex-2024.4.16-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:0a38d151e2cdd66d16dab550c22f9521ba79761423b87c01dae0a6e9add79c0d"},
|
||||
{file = "regex-2024.4.16-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:159dc4e59a159cb8e4e8f8961eb1fa5d58f93cb1acd1701d8aff38d45e1a84a6"},
|
||||
{file = "regex-2024.4.16-cp311-cp311-win32.whl", hash = "sha256:ba2336d6548dee3117520545cfe44dc28a250aa091f8281d28804aa8d707d93d"},
|
||||
{file = "regex-2024.4.16-cp311-cp311-win_amd64.whl", hash = "sha256:8f83b6fd3dc3ba94d2b22717f9c8b8512354fd95221ac661784df2769ea9bba9"},
|
||||
{file = "regex-2024.4.16-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:80b696e8972b81edf0af2a259e1b2a4a661f818fae22e5fa4fa1a995fb4a40fd"},
|
||||
{file = "regex-2024.4.16-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:d61ae114d2a2311f61d90c2ef1358518e8f05eafda76eaf9c772a077e0b465ec"},
|
||||
{file = "regex-2024.4.16-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8ba6745440b9a27336443b0c285d705ce73adb9ec90e2f2004c64d95ab5a7598"},
|
||||
{file = "regex-2024.4.16-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6295004b2dd37b0835ea5c14a33e00e8cfa3c4add4d587b77287825f3418d310"},
|
||||
{file = "regex-2024.4.16-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4aba818dcc7263852aabb172ec27b71d2abca02a593b95fa79351b2774eb1d2b"},
|
||||
{file = "regex-2024.4.16-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d0800631e565c47520aaa04ae38b96abc5196fe8b4aa9bd864445bd2b5848a7a"},
|
||||
{file = "regex-2024.4.16-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:08dea89f859c3df48a440dbdcd7b7155bc675f2fa2ec8c521d02dc69e877db70"},
|
||||
{file = "regex-2024.4.16-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:eeaa0b5328b785abc344acc6241cffde50dc394a0644a968add75fcefe15b9d4"},
|
||||
{file = "regex-2024.4.16-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:4e819a806420bc010489f4e741b3036071aba209f2e0989d4750b08b12a9343f"},
|
||||
{file = "regex-2024.4.16-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:c2d0e7cbb6341e830adcbfa2479fdeebbfbb328f11edd6b5675674e7a1e37730"},
|
||||
{file = "regex-2024.4.16-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:91797b98f5e34b6a49f54be33f72e2fb658018ae532be2f79f7c63b4ae225145"},
|
||||
{file = "regex-2024.4.16-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:d2da13568eff02b30fd54fccd1e042a70fe920d816616fda4bf54ec705668d81"},
|
||||
{file = "regex-2024.4.16-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:370c68dc5570b394cbaadff50e64d705f64debed30573e5c313c360689b6aadc"},
|
||||
{file = "regex-2024.4.16-cp312-cp312-win32.whl", hash = "sha256:904c883cf10a975b02ab3478bce652f0f5346a2c28d0a8521d97bb23c323cc8b"},
|
||||
{file = "regex-2024.4.16-cp312-cp312-win_amd64.whl", hash = "sha256:785c071c982dce54d44ea0b79cd6dfafddeccdd98cfa5f7b86ef69b381b457d9"},
|
||||
{file = "regex-2024.4.16-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:e2f142b45c6fed48166faeb4303b4b58c9fcd827da63f4cf0a123c3480ae11fb"},
|
||||
{file = "regex-2024.4.16-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e87ab229332ceb127a165612d839ab87795972102cb9830e5f12b8c9a5c1b508"},
|
||||
{file = "regex-2024.4.16-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:81500ed5af2090b4a9157a59dbc89873a25c33db1bb9a8cf123837dcc9765047"},
|
||||
{file = "regex-2024.4.16-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b340cccad138ecb363324aa26893963dcabb02bb25e440ebdf42e30963f1a4e0"},
|
||||
{file = "regex-2024.4.16-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2c72608e70f053643437bd2be0608f7f1c46d4022e4104d76826f0839199347a"},
|
||||
{file = "regex-2024.4.16-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a01fe2305e6232ef3e8f40bfc0f0f3a04def9aab514910fa4203bafbc0bb4682"},
|
||||
{file = "regex-2024.4.16-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:03576e3a423d19dda13e55598f0fd507b5d660d42c51b02df4e0d97824fdcae3"},
|
||||
{file = "regex-2024.4.16-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:549c3584993772e25f02d0656ac48abdda73169fe347263948cf2b1cead622f3"},
|
||||
{file = "regex-2024.4.16-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:34422d5a69a60b7e9a07a690094e824b66f5ddc662a5fc600d65b7c174a05f04"},
|
||||
{file = "regex-2024.4.16-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:5f580c651a72b75c39e311343fe6875d6f58cf51c471a97f15a938d9fe4e0d37"},
|
||||
{file = "regex-2024.4.16-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:3399dd8a7495bbb2bacd59b84840eef9057826c664472e86c91d675d007137f5"},
|
||||
{file = "regex-2024.4.16-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:8d1f86f3f4e2388aa3310b50694ac44daefbd1681def26b4519bd050a398dc5a"},
|
||||
{file = "regex-2024.4.16-cp37-cp37m-win32.whl", hash = "sha256:dd5acc0a7d38fdc7a3a6fd3ad14c880819008ecb3379626e56b163165162cc46"},
|
||||
{file = "regex-2024.4.16-cp37-cp37m-win_amd64.whl", hash = "sha256:ba8122e3bb94ecda29a8de4cf889f600171424ea586847aa92c334772d200331"},
|
||||
{file = "regex-2024.4.16-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:743deffdf3b3481da32e8a96887e2aa945ec6685af1cfe2bcc292638c9ba2f48"},
|
||||
{file = "regex-2024.4.16-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:7571f19f4a3fd00af9341c7801d1ad1967fc9c3f5e62402683047e7166b9f2b4"},
|
||||
{file = "regex-2024.4.16-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:df79012ebf6f4efb8d307b1328226aef24ca446b3ff8d0e30202d7ebcb977a8c"},
|
||||
{file = "regex-2024.4.16-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e757d475953269fbf4b441207bb7dbdd1c43180711b6208e129b637792ac0b93"},
|
||||
{file = "regex-2024.4.16-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4313ab9bf6a81206c8ac28fdfcddc0435299dc88cad12cc6305fd0e78b81f9e4"},
|
||||
{file = "regex-2024.4.16-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d83c2bc678453646f1a18f8db1e927a2d3f4935031b9ad8a76e56760461105dd"},
|
||||
{file = "regex-2024.4.16-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9df1bfef97db938469ef0a7354b2d591a2d438bc497b2c489471bec0e6baf7c4"},
|
||||
{file = "regex-2024.4.16-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:62120ed0de69b3649cc68e2965376048793f466c5a6c4370fb27c16c1beac22d"},
|
||||
{file = "regex-2024.4.16-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:c2ef6f7990b6e8758fe48ad08f7e2f66c8f11dc66e24093304b87cae9037bb4a"},
|
||||
{file = "regex-2024.4.16-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:8fc6976a3395fe4d1fbeb984adaa8ec652a1e12f36b56ec8c236e5117b585427"},
|
||||
{file = "regex-2024.4.16-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:03e68f44340528111067cecf12721c3df4811c67268b897fbe695c95f860ac42"},
|
||||
{file = "regex-2024.4.16-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:ec7e0043b91115f427998febaa2beb82c82df708168b35ece3accb610b91fac1"},
|
||||
{file = "regex-2024.4.16-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:c21fc21a4c7480479d12fd8e679b699f744f76bb05f53a1d14182b31f55aac76"},
|
||||
{file = "regex-2024.4.16-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:12f6a3f2f58bb7344751919a1876ee1b976fe08b9ffccb4bbea66f26af6017b9"},
|
||||
{file = "regex-2024.4.16-cp38-cp38-win32.whl", hash = "sha256:479595a4fbe9ed8f8f72c59717e8cf222da2e4c07b6ae5b65411e6302af9708e"},
|
||||
{file = "regex-2024.4.16-cp38-cp38-win_amd64.whl", hash = "sha256:0534b034fba6101611968fae8e856c1698da97ce2efb5c2b895fc8b9e23a5834"},
|
||||
{file = "regex-2024.4.16-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:a7ccdd1c4a3472a7533b0a7aa9ee34c9a2bef859ba86deec07aff2ad7e0c3b94"},
|
||||
{file = "regex-2024.4.16-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6f2f017c5be19984fbbf55f8af6caba25e62c71293213f044da3ada7091a4455"},
|
||||
{file = "regex-2024.4.16-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:803b8905b52de78b173d3c1e83df0efb929621e7b7c5766c0843704d5332682f"},
|
||||
{file = "regex-2024.4.16-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:684008ec44ad275832a5a152f6e764bbe1914bea10968017b6feaecdad5736e0"},
|
||||
{file = "regex-2024.4.16-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:65436dce9fdc0aeeb0a0effe0839cb3d6a05f45aa45a4d9f9c60989beca78b9c"},
|
||||
{file = "regex-2024.4.16-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ea355eb43b11764cf799dda62c658c4d2fdb16af41f59bb1ccfec517b60bcb07"},
|
||||
{file = "regex-2024.4.16-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:98c1165f3809ce7774f05cb74e5408cd3aa93ee8573ae959a97a53db3ca3180d"},
|
||||
{file = "regex-2024.4.16-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cccc79a9be9b64c881f18305a7c715ba199e471a3973faeb7ba84172abb3f317"},
|
||||
{file = "regex-2024.4.16-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:00169caa125f35d1bca6045d65a662af0202704489fada95346cfa092ec23f39"},
|
||||
{file = "regex-2024.4.16-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:6cc38067209354e16c5609b66285af17a2863a47585bcf75285cab33d4c3b8df"},
|
||||
{file = "regex-2024.4.16-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:23cff1b267038501b179ccbbd74a821ac4a7192a1852d1d558e562b507d46013"},
|
||||
{file = "regex-2024.4.16-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:b9d320b3bf82a39f248769fc7f188e00f93526cc0fe739cfa197868633d44701"},
|
||||
{file = "regex-2024.4.16-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:89ec7f2c08937421bbbb8b48c54096fa4f88347946d4747021ad85f1b3021b3c"},
|
||||
{file = "regex-2024.4.16-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:4918fd5f8b43aa7ec031e0fef1ee02deb80b6afd49c85f0790be1dc4ce34cb50"},
|
||||
{file = "regex-2024.4.16-cp39-cp39-win32.whl", hash = "sha256:684e52023aec43bdf0250e843e1fdd6febbe831bd9d52da72333fa201aaa2335"},
|
||||
{file = "regex-2024.4.16-cp39-cp39-win_amd64.whl", hash = "sha256:e697e1c0238133589e00c244a8b676bc2cfc3ab4961318d902040d099fec7483"},
|
||||
{file = "regex-2024.4.16.tar.gz", hash = "sha256:fa454d26f2e87ad661c4f0c5a5fe4cf6aab1e307d1b94f16ffdfcb089ba685c0"},
|
||||
{file = "regex-2024.4.28-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:cd196d056b40af073d95a2879678585f0b74ad35190fac04ca67954c582c6b61"},
|
||||
{file = "regex-2024.4.28-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8bb381f777351bd534462f63e1c6afb10a7caa9fa2a421ae22c26e796fe31b1f"},
|
||||
{file = "regex-2024.4.28-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:47af45b6153522733aa6e92543938e97a70ce0900649ba626cf5aad290b737b6"},
|
||||
{file = "regex-2024.4.28-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:99d6a550425cc51c656331af0e2b1651e90eaaa23fb4acde577cf15068e2e20f"},
|
||||
{file = "regex-2024.4.28-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bf29304a8011feb58913c382902fde3395957a47645bf848eea695839aa101b7"},
|
||||
{file = "regex-2024.4.28-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:92da587eee39a52c91aebea8b850e4e4f095fe5928d415cb7ed656b3460ae79a"},
|
||||
{file = "regex-2024.4.28-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6277d426e2f31bdbacb377d17a7475e32b2d7d1f02faaecc48d8e370c6a3ff31"},
|
||||
{file = "regex-2024.4.28-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:28e1f28d07220c0f3da0e8fcd5a115bbb53f8b55cecf9bec0c946eb9a059a94c"},
|
||||
{file = "regex-2024.4.28-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:aaa179975a64790c1f2701ac562b5eeb733946eeb036b5bcca05c8d928a62f10"},
|
||||
{file = "regex-2024.4.28-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:6f435946b7bf7a1b438b4e6b149b947c837cb23c704e780c19ba3e6855dbbdd3"},
|
||||
{file = "regex-2024.4.28-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:19d6c11bf35a6ad077eb23852827f91c804eeb71ecb85db4ee1386825b9dc4db"},
|
||||
{file = "regex-2024.4.28-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:fdae0120cddc839eb8e3c15faa8ad541cc6d906d3eb24d82fb041cfe2807bc1e"},
|
||||
{file = "regex-2024.4.28-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:e672cf9caaf669053121f1766d659a8813bd547edef6e009205378faf45c67b8"},
|
||||
{file = "regex-2024.4.28-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:f57515750d07e14743db55d59759893fdb21d2668f39e549a7d6cad5d70f9fea"},
|
||||
{file = "regex-2024.4.28-cp310-cp310-win32.whl", hash = "sha256:a1409c4eccb6981c7baabc8888d3550df518add6e06fe74fa1d9312c1838652d"},
|
||||
{file = "regex-2024.4.28-cp310-cp310-win_amd64.whl", hash = "sha256:1f687a28640f763f23f8a9801fe9e1b37338bb1ca5d564ddd41619458f1f22d1"},
|
||||
{file = "regex-2024.4.28-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:84077821c85f222362b72fdc44f7a3a13587a013a45cf14534df1cbbdc9a6796"},
|
||||
{file = "regex-2024.4.28-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b45d4503de8f4f3dc02f1d28a9b039e5504a02cc18906cfe744c11def942e9eb"},
|
||||
{file = "regex-2024.4.28-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:457c2cd5a646dd4ed536c92b535d73548fb8e216ebee602aa9f48e068fc393f3"},
|
||||
{file = "regex-2024.4.28-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2b51739ddfd013c6f657b55a508de8b9ea78b56d22b236052c3a85a675102dc6"},
|
||||
{file = "regex-2024.4.28-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:459226445c7d7454981c4c0ce0ad1a72e1e751c3e417f305722bbcee6697e06a"},
|
||||
{file = "regex-2024.4.28-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:670fa596984b08a4a769491cbdf22350431970d0112e03d7e4eeaecaafcd0fec"},
|
||||
{file = "regex-2024.4.28-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fe00f4fe11c8a521b173e6324d862ee7ee3412bf7107570c9b564fe1119b56fb"},
|
||||
{file = "regex-2024.4.28-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:36f392dc7763fe7924575475736bddf9ab9f7a66b920932d0ea50c2ded2f5636"},
|
||||
{file = "regex-2024.4.28-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:23a412b7b1a7063f81a742463f38821097b6a37ce1e5b89dd8e871d14dbfd86b"},
|
||||
{file = "regex-2024.4.28-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:f1d6e4b7b2ae3a6a9df53efbf199e4bfcff0959dbdb5fd9ced34d4407348e39a"},
|
||||
{file = "regex-2024.4.28-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:499334ad139557de97cbc4347ee921c0e2b5e9c0f009859e74f3f77918339257"},
|
||||
{file = "regex-2024.4.28-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:0940038bec2fe9e26b203d636c44d31dd8766abc1fe66262da6484bd82461ccf"},
|
||||
{file = "regex-2024.4.28-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:66372c2a01782c5fe8e04bff4a2a0121a9897e19223d9eab30c54c50b2ebeb7f"},
|
||||
{file = "regex-2024.4.28-cp311-cp311-win32.whl", hash = "sha256:c77d10ec3c1cf328b2f501ca32583625987ea0f23a0c2a49b37a39ee5c4c4630"},
|
||||
{file = "regex-2024.4.28-cp311-cp311-win_amd64.whl", hash = "sha256:fc0916c4295c64d6890a46e02d4482bb5ccf33bf1a824c0eaa9e83b148291f90"},
|
||||
{file = "regex-2024.4.28-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:08a1749f04fee2811c7617fdd46d2e46d09106fa8f475c884b65c01326eb15c5"},
|
||||
{file = "regex-2024.4.28-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b8eb28995771c087a73338f695a08c9abfdf723d185e57b97f6175c5051ff1ae"},
|
||||
{file = "regex-2024.4.28-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:dd7ef715ccb8040954d44cfeff17e6b8e9f79c8019daae2fd30a8806ef5435c0"},
|
||||
{file = "regex-2024.4.28-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fb0315a2b26fde4005a7c401707c5352df274460f2f85b209cf6024271373013"},
|
||||
{file = "regex-2024.4.28-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f2fc053228a6bd3a17a9b0a3f15c3ab3cf95727b00557e92e1cfe094b88cc662"},
|
||||
{file = "regex-2024.4.28-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7fe9739a686dc44733d52d6e4f7b9c77b285e49edf8570754b322bca6b85b4cc"},
|
||||
{file = "regex-2024.4.28-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a74fcf77d979364f9b69fcf8200849ca29a374973dc193a7317698aa37d8b01c"},
|
||||
{file = "regex-2024.4.28-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:965fd0cf4694d76f6564896b422724ec7b959ef927a7cb187fc6b3f4e4f59833"},
|
||||
{file = "regex-2024.4.28-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:2fef0b38c34ae675fcbb1b5db760d40c3fc3612cfa186e9e50df5782cac02bcd"},
|
||||
{file = "regex-2024.4.28-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:bc365ce25f6c7c5ed70e4bc674f9137f52b7dd6a125037f9132a7be52b8a252f"},
|
||||
{file = "regex-2024.4.28-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:ac69b394764bb857429b031d29d9604842bc4cbfd964d764b1af1868eeebc4f0"},
|
||||
{file = "regex-2024.4.28-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:144a1fc54765f5c5c36d6d4b073299832aa1ec6a746a6452c3ee7b46b3d3b11d"},
|
||||
{file = "regex-2024.4.28-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:2630ca4e152c221072fd4a56d4622b5ada876f668ecd24d5ab62544ae6793ed6"},
|
||||
{file = "regex-2024.4.28-cp312-cp312-win32.whl", hash = "sha256:7f3502f03b4da52bbe8ba962621daa846f38489cae5c4a7b5d738f15f6443d17"},
|
||||
{file = "regex-2024.4.28-cp312-cp312-win_amd64.whl", hash = "sha256:0dd3f69098511e71880fb00f5815db9ed0ef62c05775395968299cb400aeab82"},
|
||||
{file = "regex-2024.4.28-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:374f690e1dd0dbdcddea4a5c9bdd97632cf656c69113f7cd6a361f2a67221cb6"},
|
||||
{file = "regex-2024.4.28-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:25f87ae6b96374db20f180eab083aafe419b194e96e4f282c40191e71980c666"},
|
||||
{file = "regex-2024.4.28-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:5dbc1bcc7413eebe5f18196e22804a3be1bfdfc7e2afd415e12c068624d48247"},
|
||||
{file = "regex-2024.4.28-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f85151ec5a232335f1be022b09fbbe459042ea1951d8a48fef251223fc67eee1"},
|
||||
{file = "regex-2024.4.28-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:57ba112e5530530fd175ed550373eb263db4ca98b5f00694d73b18b9a02e7185"},
|
||||
{file = "regex-2024.4.28-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:224803b74aab56aa7be313f92a8d9911dcade37e5f167db62a738d0c85fdac4b"},
|
||||
{file = "regex-2024.4.28-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0a54a047b607fd2d2d52a05e6ad294602f1e0dec2291152b745870afc47c1397"},
|
||||
{file = "regex-2024.4.28-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0a2a512d623f1f2d01d881513af9fc6a7c46e5cfffb7dc50c38ce959f9246c94"},
|
||||
{file = "regex-2024.4.28-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:c06bf3f38f0707592898428636cbb75d0a846651b053a1cf748763e3063a6925"},
|
||||
{file = "regex-2024.4.28-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:1031a5e7b048ee371ab3653aad3030ecfad6ee9ecdc85f0242c57751a05b0ac4"},
|
||||
{file = "regex-2024.4.28-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:d7a353ebfa7154c871a35caca7bfd8f9e18666829a1dc187115b80e35a29393e"},
|
||||
{file = "regex-2024.4.28-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:7e76b9cfbf5ced1aca15a0e5b6f229344d9b3123439ffce552b11faab0114a02"},
|
||||
{file = "regex-2024.4.28-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:5ce479ecc068bc2a74cb98dd8dba99e070d1b2f4a8371a7dfe631f85db70fe6e"},
|
||||
{file = "regex-2024.4.28-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:7d77b6f63f806578c604dca209280e4c54f0fa9a8128bb8d2cc5fb6f99da4150"},
|
||||
{file = "regex-2024.4.28-cp38-cp38-win32.whl", hash = "sha256:d84308f097d7a513359757c69707ad339da799e53b7393819ec2ea36bc4beb58"},
|
||||
{file = "regex-2024.4.28-cp38-cp38-win_amd64.whl", hash = "sha256:2cc1b87bba1dd1a898e664a31012725e48af826bf3971e786c53e32e02adae6c"},
|
||||
{file = "regex-2024.4.28-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:7413167c507a768eafb5424413c5b2f515c606be5bb4ef8c5dee43925aa5718b"},
|
||||
{file = "regex-2024.4.28-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:108e2dcf0b53a7c4ab8986842a8edcb8ab2e59919a74ff51c296772e8e74d0ae"},
|
||||
{file = "regex-2024.4.28-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:f1c5742c31ba7d72f2dedf7968998730664b45e38827637e0f04a2ac7de2f5f1"},
|
||||
{file = "regex-2024.4.28-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ecc6148228c9ae25ce403eade13a0961de1cb016bdb35c6eafd8e7b87ad028b1"},
|
||||
{file = "regex-2024.4.28-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b7d893c8cf0e2429b823ef1a1d360a25950ed11f0e2a9df2b5198821832e1947"},
|
||||
{file = "regex-2024.4.28-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4290035b169578ffbbfa50d904d26bec16a94526071ebec3dadbebf67a26b25e"},
|
||||
{file = "regex-2024.4.28-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:44a22ae1cfd82e4ffa2066eb3390777dc79468f866f0625261a93e44cdf6482b"},
|
||||
{file = "regex-2024.4.28-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fd24fd140b69f0b0bcc9165c397e9b2e89ecbeda83303abf2a072609f60239e2"},
|
||||
{file = "regex-2024.4.28-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:39fb166d2196413bead229cd64a2ffd6ec78ebab83fff7d2701103cf9f4dfd26"},
|
||||
{file = "regex-2024.4.28-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:9301cc6db4d83d2c0719f7fcda37229691745168bf6ae849bea2e85fc769175d"},
|
||||
{file = "regex-2024.4.28-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:7c3d389e8d76a49923683123730c33e9553063d9041658f23897f0b396b2386f"},
|
||||
{file = "regex-2024.4.28-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:99ef6289b62042500d581170d06e17f5353b111a15aa6b25b05b91c6886df8fc"},
|
||||
{file = "regex-2024.4.28-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:b91d529b47798c016d4b4c1d06cc826ac40d196da54f0de3c519f5a297c5076a"},
|
||||
{file = "regex-2024.4.28-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:43548ad74ea50456e1c68d3c67fff3de64c6edb85bcd511d1136f9b5376fc9d1"},
|
||||
{file = "regex-2024.4.28-cp39-cp39-win32.whl", hash = "sha256:05d9b6578a22db7dedb4df81451f360395828b04f4513980b6bd7a1412c679cc"},
|
||||
{file = "regex-2024.4.28-cp39-cp39-win_amd64.whl", hash = "sha256:3986217ec830c2109875be740531feb8ddafe0dfa49767cdcd072ed7e8927962"},
|
||||
{file = "regex-2024.4.28.tar.gz", hash = "sha256:83ab366777ea45d58f72593adf35d36ca911ea8bd838483c1823b883a121b0e4"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -4014,13 +3990,13 @@ zstd = ["zstandard (>=0.18.0)"]
|
|||
|
||||
[[package]]
|
||||
name = "virtualenv"
|
||||
version = "20.26.0"
|
||||
version = "20.26.1"
|
||||
description = "Virtual Python Environment builder"
|
||||
optional = true
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "virtualenv-20.26.0-py3-none-any.whl", hash = "sha256:0846377ea76e818daaa3e00a4365c018bc3ac9760cbb3544de542885aad61fb3"},
|
||||
{file = "virtualenv-20.26.0.tar.gz", hash = "sha256:ec25a9671a5102c8d2657f62792a27b48f016664c6873f6beed3800008577210"},
|
||||
{file = "virtualenv-20.26.1-py3-none-any.whl", hash = "sha256:7aa9982a728ae5892558bff6a2839c00b9ed145523ece2274fad6f414690ae75"},
|
||||
{file = "virtualenv-20.26.1.tar.gz", hash = "sha256:604bfdceaeece392802e6ae48e69cec49168b9c5f4a44e483963f9242eb0e78b"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
|
|
|
@ -4,9 +4,9 @@ import gymnasium as gym
|
|||
import pytest
|
||||
|
||||
import lerobot
|
||||
from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy
|
||||
from lerobot.common.policies.act.modeling_act import ACTPolicy
|
||||
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||
from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
|
||||
from lerobot.common.policies.tdmpc.modeling_tdmpc import TDMPCPolicy
|
||||
from tests.utils import require_env
|
||||
|
||||
|
||||
|
@ -30,7 +30,7 @@ def test_available_policies():
|
|||
consistent with those listed in `lerobot/__init__.py`.
|
||||
"""
|
||||
policy_classes = [
|
||||
ActionChunkingTransformerPolicy,
|
||||
ACTPolicy,
|
||||
DiffusionPolicy,
|
||||
TDMPCPolicy,
|
||||
]
|
||||
|
|
|
@ -35,7 +35,7 @@ def test_factory(env_name, repo_id, policy_name):
|
|||
DEFAULT_CONFIG_PATH,
|
||||
overrides=[
|
||||
f"env={env_name}",
|
||||
f"dataset.repo_id={repo_id}",
|
||||
f"dataset_repo_id={repo_id}",
|
||||
f"policy={policy_name}",
|
||||
f"device={DEVICE}",
|
||||
],
|
||||
|
|
|
@ -39,14 +39,14 @@ def test_examples_3_and_2():
|
|||
("training_steps = 5000", "training_steps = 1"),
|
||||
("num_workers=4", "num_workers=0"),
|
||||
('device = torch.device("cuda")', 'device = torch.device("cpu")'),
|
||||
("batch_size=cfg.batch_size", "batch_size=1"),
|
||||
("batch_size=64", "batch_size=1"),
|
||||
],
|
||||
)
|
||||
|
||||
# Pass empty globals to allow dictionary comprehension https://stackoverflow.com/a/32897127/4391249.
|
||||
exec(file_contents, {})
|
||||
|
||||
for file_name in ["model.pt", "config.yaml"]:
|
||||
for file_name in ["model.safetensors", "config.json", "config.yaml"]:
|
||||
assert Path(f"outputs/train/example_pusht_diffusion/{file_name}").exists()
|
||||
|
||||
path = "examples/2_evaluate_pretrained_policy.py"
|
||||
|
@ -58,15 +58,15 @@ def test_examples_3_and_2():
|
|||
file_contents = _find_and_replace(
|
||||
file_contents,
|
||||
[
|
||||
('"eval_episodes=10"', '"eval_episodes=1"'),
|
||||
('"rollout_batch_size=10"', '"rollout_batch_size=1"'),
|
||||
('"device=cuda"', '"device=cpu"'),
|
||||
('pretrained_policy_name = "lerobot/diffusion_policy_pusht_image"', ""),
|
||||
("pretrained_policy_path = Path(snapshot_download(pretrained_policy_name))", ""),
|
||||
(
|
||||
'# folder = Path("outputs/train/example_pusht_diffusion")',
|
||||
'folder = Path("outputs/train/example_pusht_diffusion")',
|
||||
'# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")',
|
||||
'pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")',
|
||||
),
|
||||
('hub_id = "lerobot/diffusion_policy_pusht_image"', ""),
|
||||
("folder = Path(snapshot_download(hub_id)", ""),
|
||||
('"eval.n_episodes=10"', '"eval.n_episodes=1"'),
|
||||
('"eval.batch_size=10"', '"eval.batch_size=1"'),
|
||||
('"device=cuda"', '"device=cpu"'),
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
@ -1,41 +1,50 @@
|
|||
import inspect
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from huggingface_hub import PyTorchModelHubMixin
|
||||
|
||||
from lerobot import available_policies
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.datasets.utils import cycle
|
||||
from lerobot.common.envs.factory import make_env
|
||||
from lerobot.common.envs.utils import postprocess_action, preprocess_observation
|
||||
from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy
|
||||
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.policies.factory import get_policy_and_config_classes, make_policy
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.policy_protocol import Policy
|
||||
from lerobot.common.utils.utils import init_hydra_config
|
||||
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_env
|
||||
|
||||
|
||||
@pytest.mark.parametrize("policy_name", available_policies)
|
||||
def test_get_policy_and_config_classes(policy_name: str):
|
||||
"""Check that the correct policy and config classes are returned."""
|
||||
policy_cls, config_cls = get_policy_and_config_classes(policy_name)
|
||||
assert policy_cls.name == policy_name
|
||||
assert issubclass(config_cls, inspect.signature(policy_cls.__init__).parameters["config"].annotation)
|
||||
|
||||
|
||||
# TODO(aliberts): refactor using lerobot/__init__.py variables
|
||||
@pytest.mark.parametrize(
|
||||
"env_name,policy_name,extra_overrides",
|
||||
[
|
||||
# ("xarm", "tdmpc", ["policy.mpc=true"]),
|
||||
# ("pusht", "tdmpc", ["policy.mpc=false"]),
|
||||
("xarm", "tdmpc", ["policy.use_mpc=true", "dataset_repo_id=lerobot/xarm_lift_medium"]),
|
||||
("pusht", "diffusion", []),
|
||||
("aloha", "act", ["env.task=AlohaInsertion-v0", "dataset.repo_id=lerobot/aloha_sim_insertion_human"]),
|
||||
("aloha", "act", ["env.task=AlohaInsertion-v0", "dataset_repo_id=lerobot/aloha_sim_insertion_human"]),
|
||||
(
|
||||
"aloha",
|
||||
"act",
|
||||
["env.task=AlohaInsertion-v0", "dataset.repo_id=lerobot/aloha_sim_insertion_scripted"],
|
||||
["env.task=AlohaInsertion-v0", "dataset_repo_id=lerobot/aloha_sim_insertion_scripted"],
|
||||
),
|
||||
(
|
||||
"aloha",
|
||||
"act",
|
||||
["env.task=AlohaTransferCube-v0", "dataset.repo_id=lerobot/aloha_sim_transfer_cube_human"],
|
||||
["env.task=AlohaTransferCube-v0", "dataset_repo_id=lerobot/aloha_sim_transfer_cube_human"],
|
||||
),
|
||||
(
|
||||
"aloha",
|
||||
"act",
|
||||
["env.task=AlohaTransferCube-v0", "dataset.repo_id=lerobot/aloha_sim_transfer_cube_scripted"],
|
||||
["env.task=AlohaTransferCube-v0", "dataset_repo_id=lerobot/aloha_sim_transfer_cube_scripted"],
|
||||
),
|
||||
],
|
||||
)
|
||||
|
@ -44,7 +53,8 @@ def test_policy(env_name, policy_name, extra_overrides):
|
|||
"""
|
||||
Tests:
|
||||
- Making the policy object.
|
||||
- Checking that the policy follows the correct protocol.
|
||||
- Checking that the policy follows the correct protocol and subclasses nn.Module
|
||||
and PyTorchModelHubMixin.
|
||||
- Updating the policy.
|
||||
- Using the policy to select actions at inference time.
|
||||
- Test the action can be applied to the policy
|
||||
|
@ -61,11 +71,13 @@ def test_policy(env_name, policy_name, extra_overrides):
|
|||
|
||||
# Check that we can make the policy object.
|
||||
dataset = make_dataset(cfg)
|
||||
policy = make_policy(cfg, dataset_stats=dataset.stats)
|
||||
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)
|
||||
# Check that the policy follows the required protocol.
|
||||
assert isinstance(
|
||||
policy, Policy
|
||||
), f"The policy does not follow the required protocol. Please see {Policy.__module__}.{Policy.__name__}."
|
||||
assert isinstance(policy, torch.nn.Module)
|
||||
assert isinstance(policy, PyTorchModelHubMixin)
|
||||
|
||||
# Check that we run select_actions and get the appropriate output.
|
||||
env = make_env(cfg, num_parallel_envs=2)
|
||||
|
@ -86,7 +98,7 @@ def test_policy(env_name, policy_name, extra_overrides):
|
|||
batch[key] = batch[key].to(DEVICE, non_blocking=True)
|
||||
|
||||
# Test updating the policy
|
||||
policy.forward(batch, step=0)
|
||||
policy.forward(batch)
|
||||
|
||||
# reset the policy and environment
|
||||
policy.reset()
|
||||
|
@ -100,7 +112,7 @@ def test_policy(env_name, policy_name, extra_overrides):
|
|||
|
||||
# get the next action for the environment
|
||||
with torch.inference_mode():
|
||||
action = policy.select_action(observation, step=0)
|
||||
action = policy.select_action(observation)
|
||||
|
||||
# convert action to cpu numpy array
|
||||
action = postprocess_action(action)
|
||||
|
@ -108,29 +120,25 @@ def test_policy(env_name, policy_name, extra_overrides):
|
|||
# Test step through policy
|
||||
env.step(action)
|
||||
|
||||
# Test load state_dict
|
||||
if policy_name != "tdmpc":
|
||||
# TODO(rcadene, alexander-soare): make it work for tdmpc
|
||||
new_policy = make_policy(cfg)
|
||||
new_policy.load_state_dict(policy.state_dict())
|
||||
|
||||
@pytest.mark.parametrize("policy_name", available_policies)
|
||||
def test_policy_defaults(policy_name: str):
|
||||
"""Check that the policy can be instantiated with defaults."""
|
||||
policy_cls, _ = get_policy_and_config_classes(policy_name)
|
||||
policy_cls()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("policy_cls", [DiffusionPolicy, ActionChunkingTransformerPolicy])
|
||||
def test_policy_defaults(policy_cls):
|
||||
kwargs = {}
|
||||
# TODO(alexander-soare): Remove this kwargs hack when we move the scheduler out of DP.
|
||||
if policy_cls is DiffusionPolicy:
|
||||
kwargs = {"lr_scheduler_num_training_steps": 1}
|
||||
policy_cls(**kwargs)
|
||||
@pytest.mark.parametrize("policy_name", available_policies)
|
||||
def test_save_and_load_pretrained(policy_name: str):
|
||||
policy_cls, _ = get_policy_and_config_classes(policy_name)
|
||||
policy: Policy = policy_cls()
|
||||
save_dir = "/tmp/test_save_and_load_pretrained_{policy_cls.__name__}"
|
||||
policy.save_pretrained(save_dir)
|
||||
policy_ = policy_cls.from_pretrained(save_dir)
|
||||
assert all(torch.equal(p, p_) for p, p_ in zip(policy.parameters(), policy_.parameters(), strict=True))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"insert_temporal_dim",
|
||||
[
|
||||
False,
|
||||
True,
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("insert_temporal_dim", [False, True])
|
||||
def test_normalize(insert_temporal_dim):
|
||||
"""
|
||||
Test that normalize/unnormalize can run without exceptions when properly set up, and that they raise
|
||||
|
|
|
@ -20,7 +20,7 @@ def test_visualize_dataset(tmpdir, repo_id):
|
|||
overrides=[
|
||||
"policy=act",
|
||||
"env=aloha",
|
||||
f"dataset.repo_id={repo_id}",
|
||||
f"dataset_repo_id={repo_id}",
|
||||
],
|
||||
)
|
||||
video_paths = visualize_dataset(cfg, out_dir=tmpdir)
|
||||
|
|
Loading…
Reference in New Issue