Move normalize/unnormalize transforms to policy for act and diffusion

commit 42ed7bb670, parent c1bcf857c5
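In short: this commit removes the dataset-side `NormalizeTransform` and the separate `stats.pth` artifact. Each policy now receives `dataset_stats` at construction, registers them as a buffer (so they are serialized inside `model.pt`), normalizes incoming batches itself, and unnormalizes its predicted actions. A minimal sketch of the resulting wiring, assuming a Hydra `cfg` is already in scope (variable names here are illustrative, not part of the diff):

```python
import torch

from lerobot.common.datasets.factory import make_dataset
from lerobot.common.policies.factory import make_policy

# Stats come from the dataset (or from config overrides) and go into the policy.
dataset = make_dataset(cfg)  # the normalize=/stats_path= arguments are gone
policy = make_policy(cfg, dataset_stats=dataset.stats)

# At rollout time the policy normalizes observations and unnormalizes actions
# internally, so callers pass raw observations and get environment-scale actions.
with torch.inference_mode():
    action = policy.select_action(raw_observation, step=0)  # raw_observation: assumed obs dict
```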

@@ -263,15 +263,13 @@ Secondly, assuming you have trained a policy, you need:

 - `config.yaml` which you can get from the `.hydra` directory of your training output folder.
 - `model.pt` which should be one of the saved models in the `models` directory of your training output folder (they won't be named `model.pt` but you will need to choose one).
-- `stats.pth` which should point to the same file in the dataset directory (found in `data/{dataset_name}`).

 To upload these to the hub, prepare a folder with the following structure (you can use symlinks rather than copying):

 ```
 to_upload
 ├── config.yaml
-├── model.pt
-└── stats.pth
+└── model.pt
 ```

 With the folder prepared, run the following with a desired revision ID.
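The upload command itself lies outside this hunk. For reference, a minimal sketch using `huggingface_hub`; the repo id and revision below are placeholders, and the revision branch is assumed to exist:

```python
from huggingface_hub import HfApi

api = HfApi()
api.create_repo("your-name/your-policy", exist_ok=True)  # placeholder repo id
api.upload_folder(
    folder_path="to_upload",
    repo_id="your-name/your-policy",
    revision="v1.0",  # the desired revision ID (an existing branch)
)
```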

@@ -19,7 +19,6 @@ folder = Path(snapshot_download(hub_id))

 config_path = folder / "config.yaml"
 weights_path = folder / "model.pt"
-stats_path = folder / "stats.pth"  # normalization stats

 # Override some config parameters to do with evaluation.
 overrides = [
@@ -36,5 +35,4 @@ cfg = init_hydra_config(config_path, overrides)
 eval(
     cfg,
     out_dir=f"outputs/eval/example_{cfg.env.name}_{cfg.policy.name}",
-    stats_path=stats_path,
 )

@@ -62,7 +62,6 @@ while not done:
        done = True
        break

-# Save the policy, configuration, and normalization stats for later use.
+# Save the policy and configuration for later use.
 policy.save(output_directory / "model.pt")
 OmegaConf.save(hydra_cfg, output_directory / "config.yaml")
-torch.save(dataset.transform.transforms[-1].stats, output_directory / "stats.pth")

@@ -2,18 +2,12 @@ import os
 from pathlib import Path

 import torch
-from torchvision.transforms import v2
-
-from lerobot.common.transforms import NormalizeTransform

 DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None


 def make_dataset(
     cfg,
-    # set normalize=False to remove all transformations and keep images unnormalized in [0,255]
-    normalize=True,
-    stats_path=None,
     split="train",
 ):
     if cfg.env.name == "xarm":
@@ -33,58 +27,23 @@ def make_dataset(
     else:
         raise ValueError(cfg.env.name)

-    transforms = None
-    if normalize:
-        # TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max,
-        # min_max_from_spec
-        # TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std
-        normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
-
-        if cfg.policy.name == "diffusion" and cfg.env.name == "pusht":
-            stats = {}
-            # TODO(rcadene): we overwrite stats to have the same as pretrained model, but we should remove this
-            stats["observation.state"] = {}
-            stats["observation.state"]["min"] = torch.tensor([13.456424, 32.938293], dtype=torch.float32)
-            stats["observation.state"]["max"] = torch.tensor([496.14618, 510.9579], dtype=torch.float32)
-            stats["action"] = {}
-            stats["action"]["min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
-            stats["action"]["max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
-        elif stats_path is None:
-            # load a first dataset to access precomputed stats
-            stats_dataset = clsfunc(
-                dataset_id=cfg.dataset_id,
-                split="train",
-                root=DATA_DIR,
-            )
-            stats = stats_dataset.stats
-        else:
-            stats = torch.load(stats_path)
-
-        transforms = v2.Compose(
-            [
-                NormalizeTransform(
-                    stats,
-                    in_keys=[
-                        "observation.state",
-                        "action",
-                    ],
-                    mode=normalization_mode,
-                ),
-            ]
-        )

     delta_timestamps = cfg.policy.get("delta_timestamps")
     if delta_timestamps is not None:
         for key in delta_timestamps:
             if isinstance(delta_timestamps[key], str):
                 delta_timestamps[key] = eval(delta_timestamps[key])

+    # TODO(rcadene): add data augmentations
+
     dataset = clsfunc(
         dataset_id=cfg.dataset_id,
         split=split,
         root=DATA_DIR,
         delta_timestamps=delta_timestamps,
-        transform=transforms,
     )

+    if cfg.get("override_dataset_stats"):
+        for key, val in cfg.override_dataset_stats.items():
+            dataset.stats[key] = torch.tensor(val)
+
     return dataset
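For reference, `dataset.stats` is a nested dict keyed first by data key and then by statistic name; the sketch below shows the shape that `override_dataset_stats` patches into, reusing the pusht values the deleted factory code previously hard-coded (the structure around them is illustrative):

```python
import torch

# Illustrative structure only; real entries are precomputed per dataset.
stats = {
    "observation.state": {
        "min": torch.tensor([13.456424, 32.938293]),
        "max": torch.tensor([496.14618, 510.9579]),
    },
    "action": {
        "min": torch.tensor([12.0, 25.0]),
        "max": torch.tensor([511.0, 511.0]),
    },
}
print(stats["observation.state"]["min"])  # tensor([13.4564, 32.9383])
```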

@@ -1,10 +1,8 @@
 import einops
 import torch

-from lerobot.common.transforms import apply_inverse_transform
-

-def preprocess_observation(observation, transform=None):
+def preprocess_observation(observation):
     # map to expected inputs for the policy
     obs = {}

@@ -33,19 +31,11 @@ def preprocess_observation(observation, transform=None):
     # TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing requirement for "agent_pos"
     obs["observation.state"] = torch.from_numpy(observation["agent_pos"]).float()

-    # apply same transforms as in training
-    if transform is not None:
-        for key in obs:
-            obs[key] = torch.stack([transform({key: item})[key] for item in obs[key]])
-
     return obs


-def postprocess_action(action, transform=None):
-    action = action.to("cpu")
-    # action is a batch (num_env,action_dim) instead of an item (action_dim),
-    # we assume applying inverse transform on a batch works the same
-    action = apply_inverse_transform({"action": action}, transform)["action"].numpy()
+def postprocess_action(action):
+    action = action.to("cpu").numpy()
     assert (
         action.ndim == 2
     ), "we assume dimensions are respectively the number of parallel envs, action dimensions"

@@ -1,4 +1,4 @@
-from dataclasses import dataclass, field
+from dataclasses import dataclass


 @dataclass
@@ -60,12 +60,14 @@ class ActionChunkingTransformerConfig:
     chunk_size: int = 100
     n_action_steps: int = 100

-    # Vision preprocessing.
-    image_normalization_mean: tuple[float, float, float] = field(
-        default_factory=lambda: [0.485, 0.456, 0.406]
-    )
-    image_normalization_std: tuple[float, float, float] = field(default_factory=lambda: [0.229, 0.224, 0.225])
-
+    # Normalization / Unnormalization
+    normalize_input_modes: dict[str, str] = {
+        "observation.image": "mean_std",
+        "observation.state": "mean_std",
+    }
+    unnormalize_output_modes: dict[str, str] = {
+        "action": "mean_std",
+    }
     # Architecture.
     # Vision backbone.
     vision_backbone: str = "resnet18"
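A caveat on the config fields above: Python dataclasses reject mutable defaults such as plain `dict` literals (`ValueError: mutable default`), so fields like these need a `field(default_factory=...)` form to actually import. A corrected sketch (editorial, not part of the recorded patch):

```python
from dataclasses import dataclass, field


@dataclass
class ActionChunkingTransformerConfig:
    # Normalization / Unnormalization
    normalize_input_modes: dict[str, str] = field(
        default_factory=lambda: {
            "observation.image": "mean_std",
            "observation.state": "mean_std",
        }
    )
    unnormalize_output_modes: dict[str, str] = field(
        default_factory=lambda: {"action": "mean_std"}
    )
```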

@@ -15,12 +15,15 @@ import numpy as np
 import torch
 import torch.nn.functional as F  # noqa: N812
 import torchvision
-import torchvision.transforms as transforms
 from torch import Tensor, nn
 from torchvision.models._utils import IntermediateLayerGetter
 from torchvision.ops.misc import FrozenBatchNorm2d

 from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
+from lerobot.common.policies.utils import (
+    normalize_inputs,
+    unnormalize_outputs,
+)


 class ActionChunkingTransformerPolicy(nn.Module):
@@ -62,7 +65,7 @@ class ActionChunkingTransformerPolicy(nn.Module):

     name = "act"

-    def __init__(self, cfg: ActionChunkingTransformerConfig | None = None):
+    def __init__(self, cfg: ActionChunkingTransformerConfig | None = None, dataset_stats=None):
         """
         Args:
             cfg: Policy configuration class instance or None, in which case the default instantiation of the
@@ -72,6 +75,9 @@ class ActionChunkingTransformerPolicy(nn.Module):
         if cfg is None:
             cfg = ActionChunkingTransformerConfig()
         self.cfg = cfg
+        self.register_buffer("dataset_stats", dataset_stats)
+        self.normalize_input_modes = cfg.normalize_input_modes
+        self.unnormalize_output_modes = cfg.unnormalize_output_modes

         # BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence].
         # The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
@@ -93,9 +99,6 @@ class ActionChunkingTransformerPolicy(nn.Module):
         )

         # Backbone for image feature extraction.
-        self.image_normalizer = transforms.Normalize(
-            mean=cfg.image_normalization_mean, std=cfg.image_normalization_std
-        )
         backbone_model = getattr(torchvision.models, cfg.vision_backbone)(
             replace_stride_with_dilation=[False, False, cfg.replace_final_stride_with_dilation],
             pretrained=cfg.use_pretrained_backbone,
@@ -169,10 +172,15 @@ class ActionChunkingTransformerPolicy(nn.Module):
         queue is empty.
         """
         self.eval()
+
+        batch = normalize_inputs(batch, self.dataset_stats, self.normalize_input_modes)
+
         if len(self._action_queue) == 0:
             # `_forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue effectively
             # has shape (n_action_steps, batch_size, *), hence the transpose.
-            self._action_queue.extend(self._forward(batch)[0][: self.cfg.n_action_steps].transpose(0, 1))
+            actions = self._forward(batch)[0][: self.cfg.n_action_steps]
+            actions = unnormalize_outputs(actions, self.dataset_stats, self.unnormalize_output_modes)
+            self._action_queue.extend(actions.transpose(0, 1))
         return self._action_queue.popleft()

     def forward(self, batch, **_) -> dict[str, Tensor]:
@@ -203,7 +211,10 @@ class ActionChunkingTransformerPolicy(nn.Module):
         """Run the model in train mode, compute the loss, and do an optimization step."""
         start_time = time.time()
         self.train()

+        batch = normalize_inputs(batch, self.dataset_stats, self.normalize_input_modes)
+
         loss_dict = self.forward(batch)
+        # TODO(rcadene): unnormalize_outputs(actions, self.dataset_stats, self.unnormalize_output_modes)
         loss = loss_dict["loss"]
         loss.backward()
@@ -309,7 +320,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
         # Camera observation features and positional embeddings.
         all_cam_features = []
         all_cam_pos_embeds = []
-        images = self.image_normalizer(batch["observation.images"])
+        images = batch["observation.images"]
         for cam_index in range(len(self.cfg.camera_names)):
             cam_features = self.backbone(images[:, cam_index])["feature_map"]
             cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)

@@ -69,9 +69,14 @@ class DiffusionConfig:
     horizon: int = 16
     n_action_steps: int = 8

-    # Vision preprocessing.
-    image_normalization_mean: tuple[float, float, float] = (0.5, 0.5, 0.5)
-    image_normalization_std: tuple[float, float, float] = (0.5, 0.5, 0.5)
+    # Normalization / Unnormalization
+    normalize_input_modes: dict[str, str] = {
+        "observation.image": "mean_std",
+        "observation.state": "min_max",
+    }
+    unnormalize_output_modes: dict[str, str] = {
+        "action": "min_max",
+    }

     # Architecture / modeling.
     # Vision backbone.

@@ -13,7 +13,6 @@ import logging
 import math
 import time
 from collections import deque
-from itertools import chain
 from typing import Callable

 import einops
@@ -30,7 +29,9 @@ from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
 from lerobot.common.policies.utils import (
     get_device_from_parameters,
     get_dtype_from_parameters,
+    normalize_inputs,
     populate_queues,
+    unnormalize_outputs,
 )


@@ -42,7 +43,9 @@ class DiffusionPolicy(nn.Module):

     name = "diffusion"

-    def __init__(self, cfg: DiffusionConfig | None, lr_scheduler_num_training_steps: int = 0):
+    def __init__(
+        self, cfg: DiffusionConfig | None, lr_scheduler_num_training_steps: int = 0, dataset_stats=None
+    ):
         """
         Args:
             cfg: Policy configuration class instance or None, in which case the default instantiation of the
@@ -54,6 +57,9 @@ class DiffusionPolicy(nn.Module):
         if cfg is None:
             cfg = DiffusionConfig()
         self.cfg = cfg
+        self.register_buffer("dataset_stats", dataset_stats)
+        self.normalize_input_modes = cfg.normalize_input_modes
+        self.unnormalize_output_modes = cfg.unnormalize_output_modes

         # queues are populated during rollout of the policy, they contain the n latest observations and actions
         self._queues = None
@@ -126,6 +132,8 @@ class DiffusionPolicy(nn.Module):
         assert "observation.state" in batch
         assert len(batch) == 2

+        batch = normalize_inputs(batch, self.dataset_stats, self.normalize_input_modes)
+
         self._queues = populate_queues(self._queues, batch)

         if len(self._queues["action"]) == 0:
@@ -135,6 +143,8 @@ class DiffusionPolicy(nn.Module):
                 actions = self.ema_diffusion.generate_actions(batch)
             else:
                 actions = self.diffusion.generate_actions(batch)
+
+            actions = unnormalize_outputs(actions, self.dataset_stats, self.unnormalize_output_modes)
             self._queues["action"].extend(actions.transpose(0, 1))

         action = self._queues["action"].popleft()
@@ -151,9 +161,13 @@ class DiffusionPolicy(nn.Module):

         self.diffusion.train()

+        batch = normalize_inputs(batch, self.dataset_stats, self.normalize_input_modes)
+
         loss = self.forward(batch)["loss"]
         loss.backward()

+        # TODO(rcadene): unnormalize_outputs(actions, self.dataset_stats, self.unnormalize_output_modes)
+
         grad_norm = torch.nn.utils.clip_grad_norm_(
             self.diffusion.parameters(),
             self.cfg.grad_clip_norm,
@@ -346,12 +360,6 @@ class _RgbEncoder(nn.Module):
     def __init__(self, cfg: DiffusionConfig):
         super().__init__()
         # Set up optional preprocessing.
-        if all(v == 1.0 for v in chain(cfg.image_normalization_mean, cfg.image_normalization_std)):
-            self.normalizer = nn.Identity()
-        else:
-            self.normalizer = torchvision.transforms.Normalize(
-                mean=cfg.image_normalization_mean, std=cfg.image_normalization_std
-            )
         if cfg.crop_shape is not None:
             self.do_crop = True
             # Always use center crop for eval
@@ -397,8 +405,7 @@ class _RgbEncoder(nn.Module):
         Returns:
             (B, D) image feature.
         """
-        # Preprocess: normalize and maybe crop (if it was set up in the __init__).
-        x = self.normalizer(x)
+        # Preprocess: maybe crop (if it was set up in the __init__).
         if self.do_crop:
             if self.training:  # noqa: SIM108
                 x = self.maybe_random_crop(x)

@@ -20,7 +20,7 @@ def _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg):
     return policy_cfg


-def make_policy(hydra_cfg: DictConfig):
+def make_policy(hydra_cfg: DictConfig, dataset_stats=None):
     if hydra_cfg.policy.name == "tdmpc":
         from lerobot.common.policies.tdmpc.policy import TDMPCPolicy

@@ -35,14 +35,14 @@ def make_policy(hydra_cfg: DictConfig):
         from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy

         policy_cfg = _policy_cfg_from_hydra_cfg(DiffusionConfig, hydra_cfg)
-        policy = DiffusionPolicy(policy_cfg, hydra_cfg.offline_steps)
+        policy = DiffusionPolicy(policy_cfg, hydra_cfg.offline_steps, dataset_stats)
         policy.to(get_safe_torch_device(hydra_cfg.device))
     elif hydra_cfg.policy.name == "act":
         from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
         from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy

         policy_cfg = _policy_cfg_from_hydra_cfg(ActionChunkingTransformerConfig, hydra_cfg)
-        policy = ActionChunkingTransformerPolicy(policy_cfg)
+        policy = ActionChunkingTransformerPolicy(policy_cfg, dataset_stats)
         policy.to(get_safe_torch_device(hydra_cfg.device))
     else:
         raise ValueError(hydra_cfg.policy.name)

@@ -28,3 +28,41 @@ def get_dtype_from_parameters(module: nn.Module) -> torch.dtype:
     Note: assumes that all parameters have the same dtype.
     """
     return next(iter(module.parameters())).dtype
+
+
+def normalize_inputs(batch, stats, normalize_input_modes):
+    if normalize_input_modes is None:
+        return batch
+    for key, mode in normalize_input_modes.items():
+        if mode == "mean_std":
+            mean = stats[key]["mean"].unsqueeze(0)
+            std = stats[key]["std"].unsqueeze(0)
+            batch[key] = (batch[key] - mean) / (std + 1e-8)
+        elif mode == "min_max":
+            min = stats[key]["min"].unsqueeze(0)
+            max = stats[key]["max"].unsqueeze(0)
+            # normalize to [0,1]
+            batch[key] = (batch[key] - min) / (max - min)
+            # normalize to [-1, 1]
+            batch[key] = batch[key] * 2 - 1
+        else:
+            raise ValueError(mode)
+    return batch
+
+
+def unnormalize_outputs(batch, stats, unnormalize_output_modes):
+    if unnormalize_output_modes is None:
+        return batch
+    for key, mode in unnormalize_output_modes.items():
+        if mode == "mean_std":
+            mean = stats[key]["mean"].unsqueeze(0)
+            std = stats[key]["std"].unsqueeze(0)
+            batch[key] = batch[key] * std + mean
+        elif mode == "min_max":
+            min = stats[key]["min"].unsqueeze(0)
+            max = stats[key]["max"].unsqueeze(0)
+            batch[key] = (batch[key] + 1) / 2
+            batch[key] = batch[key] * (max - min) + min
+        else:
+            raise ValueError(mode)
+    return batch
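A quick round-trip check of the helpers above: with `min_max`, a value x is first mapped to [0,1] via (x - min) / (max - min), then to [-1,1] via 2x - 1, and `unnormalize_outputs` inverts both steps. The batch values below are illustrative:

```python
import torch

from lerobot.common.policies.utils import normalize_inputs, unnormalize_outputs

stats = {"action": {"min": torch.tensor([12.0, 25.0]), "max": torch.tensor([511.0, 511.0])}}
batch = {"action": torch.tensor([[261.5, 268.0]])}  # (batch, action_dim)

normalized = normalize_inputs(dict(batch), stats, {"action": "min_max"})
# (261.5 - 12) / (511 - 12) = 0.5 -> 2 * 0.5 - 1 = 0.0, likewise for the second dim
print(normalized["action"])  # tensor([[0., 0.]])

restored = unnormalize_outputs(normalized, stats, {"action": "min_max"})
print(restored["action"])  # tensor([[261.5000, 268.0000]])
```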

@@ -1,65 +0,0 @@
-from torchvision.transforms.v2 import Compose, Transform
-
-
-def apply_inverse_transform(item, transform):
-    transforms = transform.transforms if isinstance(transform, Compose) else [transform]
-    for tf in transforms[::-1]:
-        if tf.invertible:
-            item = tf.inverse_transform(item)
-        else:
-            raise ValueError(f"Inverse transform called on a non invertible transform ({tf}).")
-    return item
-
-
-class NormalizeTransform(Transform):
-    invertible = True
-
-    def __init__(
-        self,
-        stats: dict,
-        in_keys: list[str] = None,
-        out_keys: list[str] | None = None,
-        in_keys_inv: list[str] | None = None,
-        out_keys_inv: list[str] | None = None,
-        mode="mean_std",
-    ):
-        super().__init__()
-        self.in_keys = in_keys
-        self.out_keys = in_keys if out_keys is None else out_keys
-        self.in_keys_inv = self.out_keys if in_keys_inv is None else in_keys_inv
-        self.out_keys_inv = self.in_keys if out_keys_inv is None else out_keys_inv
-        self.stats = stats
-        assert mode in ["mean_std", "min_max"]
-        self.mode = mode
-
-    def forward(self, item):
-        for inkey, outkey in zip(self.in_keys, self.out_keys, strict=False):
-            if inkey not in item:
-                continue
-            if self.mode == "mean_std":
-                mean = self.stats[inkey]["mean"]
-                std = self.stats[inkey]["std"]
-                item[outkey] = (item[inkey] - mean) / (std + 1e-8)
-            else:
-                min = self.stats[inkey]["min"]
-                max = self.stats[inkey]["max"]
-                # normalize to [0,1]
-                item[outkey] = (item[inkey] - min) / (max - min)
-                # normalize to [-1, 1]
-                item[outkey] = item[outkey] * 2 - 1
-        return item
-
-    def inverse_transform(self, item):
-        for inkey, outkey in zip(self.in_keys_inv, self.out_keys_inv, strict=False):
-            if inkey not in item:
-                continue
-            if self.mode == "mean_std":
-                mean = self.stats[inkey]["mean"]
-                std = self.stats[inkey]["std"]
-                item[outkey] = item[inkey] * std + mean
-            else:
-                min = self.stats[inkey]["min"]
-                max = self.stats[inkey]["max"]
-                item[outkey] = (item[inkey] + 1) / 2
-                item[outkey] = item[outkey] * (max - min) + min
-        return item

@@ -11,6 +11,11 @@ log_freq: 250
 n_obs_steps: 1
 # when temporal_agg=False, n_action_steps=horizon

+override_dataset_stats:
+  observation.image:
+    mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
+    std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
+
 # See `configuration_act.py` for more details.
 policy:
   name: act
@@ -28,9 +33,12 @@ policy:
   chunk_size: 100 # chunk_size
   n_action_steps: 100

-  # Vision preprocessing.
-  image_normalization_mean: [0.485, 0.456, 0.406]
-  image_normalization_std: [0.229, 0.224, 0.225]
+  # Normalization / Unnormalization
+  normalize_input_modes:
+    observation.image: mean_std
+    observation.state: mean_std
+  unnormalize_output_modes:
+    action: mean_std

   # Architecture.
   # Vision backbone.

@@ -18,6 +18,17 @@ online_steps: 0

 offline_prioritized_sampler: true

+override_dataset_stats:
+  observation.image:
+    mean: [[[0.5]], [[0.5]], [[0.5]]] # (c,1,1)
+    std: [[[0.5]], [[0.5]], [[0.5]]] # (c,1,1)
+  observation.state:
+    min: [13.456424, 32.938293]
+    max: [496.14618, 510.9579]
+  action:
+    min: [12.0, 25.0]
+    max: [511.0, 511.0]
+
 policy:
   name: diffusion

@@ -36,9 +47,12 @@ policy:
   horizon: ${horizon}
   n_action_steps: ${n_action_steps}

-  # Vision preprocessing.
-  image_normalization_mean: [0.5, 0.5, 0.5]
-  image_normalization_std: [0.5, 0.5, 0.5]
+  # Normalization / Unnormalization
+  normalize_input_modes:
+    observation.image: mean_std
+    observation.state: min_max
+  unnormalize_output_modes:
+    action: min_max

   # Architecture / modeling.
   # Vision backbone.
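On the (c,1,1) shapes above: storing image stats as (3,1,1) tensors lets the (1,3,1,1) result of `unsqueeze(0)` in `normalize_inputs` broadcast across a (batch, 3, height, width) image batch. A small sketch with assumed image sizes:

```python
import torch

mean = torch.tensor([[[0.5]], [[0.5]], [[0.5]]])  # (c,1,1), as in the YAML above
std = torch.tensor([[[0.5]], [[0.5]], [[0.5]]])

images = torch.rand(8, 3, 96, 96)  # (b,c,h,w), values in [0,1]
normalized = (images - mean.unsqueeze(0)) / (std.unsqueeze(0) + 1e-8)  # broadcasts to (8,3,96,96)
print(normalized.shape, normalized.min() >= -1.0, normalized.max() <= 1.0)
```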

@@ -46,7 +46,6 @@ from huggingface_hub import snapshot_download
 from PIL import Image as PILImage
 from tqdm import trange

-from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.datasets.utils import hf_transform_to_torch
 from lerobot.common.envs.factory import make_env
 from lerobot.common.envs.utils import postprocess_action, preprocess_observation
@@ -64,8 +63,6 @@ def eval_policy(
     policy: torch.nn.Module,
     max_episodes_rendered: int = 0,
     video_dir: Path = None,
-    # TODO(rcadene): make it possible to overwrite fps? we should use env.fps
-    transform: callable = None,
     return_episode_data: bool = False,
     seed=None,
 ):
@@ -132,10 +129,6 @@ def eval_policy(
         if return_episode_data:
             observations.append(deepcopy(observation))

-        # apply transform to normalize the observations
-        for key in observation:
-            observation[key] = torch.stack([transform({key: item})[key] for item in observation[key]])
-
         # send observation to device/gpu
         observation = {key: observation[key].to(device, non_blocking=True) for key in observation}

@@ -143,8 +136,8 @@ def eval_policy(
         with torch.inference_mode():
             action = policy.select_action(observation, step=step)

-        # apply inverse transform to unnormalize the action
-        action = postprocess_action(action, transform)
+        # convert to cpu numpy
+        action = postprocess_action(action)

         # apply the next action
         observation, reward, terminated, truncated, info = env.step(action)
@@ -360,7 +353,7 @@ def eval_policy(
     return info


-def eval(cfg: dict, out_dir=None, stats_path=None):
+def eval(cfg: dict, out_dir=None):
     if out_dir is None:
         raise NotImplementedError()

@@ -375,10 +368,6 @@ def eval(cfg: dict, out_dir=None, stats_path=None):

     log_output_dir(out_dir)

-    logging.info("Making transforms.")
-    # TODO(alexander-soare): Completely decouple datasets from evaluation.
-    transform = make_dataset(cfg, stats_path=stats_path).transform
-
     logging.info("Making environment.")
     env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)

@@ -390,7 +379,6 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
         policy,
         max_episodes_rendered=10,
         video_dir=Path(out_dir) / "eval",
-        transform=transform,
         return_episode_data=False,
         seed=cfg.seed,
     )
@@ -423,17 +411,13 @@ if __name__ == "__main__":
     if args.config is not None:
         # Note: For the config_path, Hydra wants a path relative to this script file.
         cfg = init_hydra_config(args.config, args.overrides)
-        # TODO(alexander-soare): Save and load stats in trained model directory.
-        stats_path = None
     elif args.hub_id is not None:
         folder = Path(snapshot_download(args.hub_id, revision=args.revision))
         cfg = init_hydra_config(
             folder / "config.yaml", [f"policy.pretrained_model_path={folder / 'model.pt'}", *args.overrides]
         )
-        stats_path = folder / "stats.pth"

     eval(
         cfg,
         out_dir=f"outputs/eval/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{cfg.env.name}_{cfg.policy.name}",
-        stats_path=stats_path,
     )

@@ -232,7 +232,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
     env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)

     logging.info("make_policy")
-    policy = make_policy(cfg)
+    policy = make_policy(cfg, dataset_stats=offline_dataset.stats)

     num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
     num_total_params = sum(p.numel() for p in policy.parameters())

@@ -42,7 +42,7 @@ def test_factory(env_name):

     env = make_env(cfg, num_parallel_envs=1)
     obs, _ = env.reset()
-    obs = preprocess_observation(obs, transform=dataset.transform)
+    obs = preprocess_observation(obs)
     for key in dataset.image_keys:
         img = obs[key]
         assert img.dtype == torch.float32

@@ -51,7 +51,7 @@ def test_examples_4_and_3():
     # Pass empty globals to allow dictionary comprehension https://stackoverflow.com/a/32897127/4391249.
     exec(file_contents, {})

-    for file_name in ["model.pt", "stats.pth", "config.yaml"]:
+    for file_name in ["model.pt", "config.yaml"]:
         assert Path(f"outputs/train/example_pusht_diffusion/{file_name}").exists()

     path = "examples/3_evaluate_pretrained_policy.py"

@@ -44,14 +44,16 @@ def test_policy(env_name, policy_name, extra_overrides):
         ]
         + extra_overrides,
     )

     # Check that we can make the policy object.
-    policy = make_policy(cfg)
+    dataset = make_dataset(cfg)
+    policy = make_policy(cfg, dataset_stats=dataset.stats)
     # Check that the policy follows the required protocol.
     assert isinstance(
         policy, Policy
     ), f"The policy does not follow the required protocol. Please see {Policy.__module__}.{Policy.__name__}."

     # Check that we run select_actions and get the appropriate output.
-    dataset = make_dataset(cfg)
     env = make_env(cfg, num_parallel_envs=2)

     dataloader = torch.utils.data.DataLoader(
@@ -77,7 +79,7 @@ def test_policy(env_name, policy_name, extra_overrides):
     observation, _ = env.reset(seed=cfg.seed)

     # apply transform to normalize the observations
-    observation = preprocess_observation(observation, dataset.transform)
+    observation = preprocess_observation(observation)

     # send observation to device/gpu
     observation = {key: observation[key].to(DEVICE, non_blocking=True) for key in observation}
@@ -86,8 +88,8 @@ def test_policy(env_name, policy_name, extra_overrides):
     with torch.inference_mode():
         action = policy.select_action(observation, step=0)

-    # apply inverse transform to unnormalize the action
-    action = postprocess_action(action, dataset.transform)
+    # convert action to cpu numpy array
+    action = postprocess_action(action)

     # Test step through policy
     env.step(action)