Move normalize/unnormalize transforms to policy for act and diffusion
This commit is contained in:
parent
c1bcf857c5
commit
42ed7bb670
|
@ -263,15 +263,13 @@ Secondly, assuming you have trained a policy, you need:
|
|||
|
||||
- `config.yaml` which you can get from the `.hydra` directory of your training output folder.
|
||||
- `model.pt` which should be one of the saved models in the `models` directory of your training output folder (they won't be named `model.pt` but you will need to choose one).
|
||||
- `stats.pth` which should point to the same file in the dataset directory (found in `data/{dataset_name}`).
|
||||
|
||||
To upload these to the hub, prepare a folder with the following structure (you can use symlinks rather than copying):
|
||||
|
||||
```
|
||||
to_upload
|
||||
├── config.yaml
|
||||
├── model.pt
|
||||
└── stats.pth
|
||||
└── model.pt
|
||||
```
|
||||
|
||||
With the folder prepared, run the following with a desired revision ID.
|
||||
|
|
|
@ -19,7 +19,6 @@ folder = Path(snapshot_download(hub_id))
|
|||
|
||||
config_path = folder / "config.yaml"
|
||||
weights_path = folder / "model.pt"
|
||||
stats_path = folder / "stats.pth" # normalization stats
|
||||
|
||||
# Override some config parameters to do with evaluation.
|
||||
overrides = [
|
||||
|
@ -36,5 +35,4 @@ cfg = init_hydra_config(config_path, overrides)
|
|||
eval(
|
||||
cfg,
|
||||
out_dir=f"outputs/eval/example_{cfg.env.name}_{cfg.policy.name}",
|
||||
stats_path=stats_path,
|
||||
)
|
||||
|
|
|
@ -62,7 +62,6 @@ while not done:
|
|||
done = True
|
||||
break
|
||||
|
||||
# Save the policy, configuration, and normalization stats for later use.
|
||||
# Save the policy and configuration for later use.
|
||||
policy.save(output_directory / "model.pt")
|
||||
OmegaConf.save(hydra_cfg, output_directory / "config.yaml")
|
||||
torch.save(dataset.transform.transforms[-1].stats, output_directory / "stats.pth")
|
||||
|
|
|
@ -2,18 +2,12 @@ import os
|
|||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from torchvision.transforms import v2
|
||||
|
||||
from lerobot.common.transforms import NormalizeTransform
|
||||
|
||||
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
|
||||
|
||||
|
||||
def make_dataset(
|
||||
cfg,
|
||||
# set normalize=False to remove all transformations and keep images unnormalized in [0,255]
|
||||
normalize=True,
|
||||
stats_path=None,
|
||||
split="train",
|
||||
):
|
||||
if cfg.env.name == "xarm":
|
||||
|
@ -33,58 +27,23 @@ def make_dataset(
|
|||
else:
|
||||
raise ValueError(cfg.env.name)
|
||||
|
||||
transforms = None
|
||||
if normalize:
|
||||
# TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max,
|
||||
# min_max_from_spec
|
||||
# TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std
|
||||
normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
|
||||
|
||||
if cfg.policy.name == "diffusion" and cfg.env.name == "pusht":
|
||||
stats = {}
|
||||
# TODO(rcadene): we overwrite stats to have the same as pretrained model, but we should remove this
|
||||
stats["observation.state"] = {}
|
||||
stats["observation.state"]["min"] = torch.tensor([13.456424, 32.938293], dtype=torch.float32)
|
||||
stats["observation.state"]["max"] = torch.tensor([496.14618, 510.9579], dtype=torch.float32)
|
||||
stats["action"] = {}
|
||||
stats["action"]["min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
|
||||
stats["action"]["max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
|
||||
elif stats_path is None:
|
||||
# load a first dataset to access precomputed stats
|
||||
stats_dataset = clsfunc(
|
||||
dataset_id=cfg.dataset_id,
|
||||
split="train",
|
||||
root=DATA_DIR,
|
||||
)
|
||||
stats = stats_dataset.stats
|
||||
else:
|
||||
stats = torch.load(stats_path)
|
||||
|
||||
transforms = v2.Compose(
|
||||
[
|
||||
NormalizeTransform(
|
||||
stats,
|
||||
in_keys=[
|
||||
"observation.state",
|
||||
"action",
|
||||
],
|
||||
mode=normalization_mode,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
delta_timestamps = cfg.policy.get("delta_timestamps")
|
||||
if delta_timestamps is not None:
|
||||
for key in delta_timestamps:
|
||||
if isinstance(delta_timestamps[key], str):
|
||||
delta_timestamps[key] = eval(delta_timestamps[key])
|
||||
|
||||
# TODO(rcadene): add data augmentations
|
||||
|
||||
dataset = clsfunc(
|
||||
dataset_id=cfg.dataset_id,
|
||||
split=split,
|
||||
root=DATA_DIR,
|
||||
delta_timestamps=delta_timestamps,
|
||||
transform=transforms,
|
||||
)
|
||||
|
||||
if cfg.get("override_dataset_stats"):
|
||||
for key, val in cfg.override_dataset_stats.items():
|
||||
dataset.stats[key] = torch.tensor(val)
|
||||
|
||||
return dataset
|
||||
|
|
|
@ -1,10 +1,8 @@
|
|||
import einops
|
||||
import torch
|
||||
|
||||
from lerobot.common.transforms import apply_inverse_transform
|
||||
|
||||
|
||||
def preprocess_observation(observation, transform=None):
|
||||
def preprocess_observation(observation):
|
||||
# map to expected inputs for the policy
|
||||
obs = {}
|
||||
|
||||
|
@ -33,19 +31,11 @@ def preprocess_observation(observation, transform=None):
|
|||
# TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing requirement for "agent_pos"
|
||||
obs["observation.state"] = torch.from_numpy(observation["agent_pos"]).float()
|
||||
|
||||
# apply same transforms as in training
|
||||
if transform is not None:
|
||||
for key in obs:
|
||||
obs[key] = torch.stack([transform({key: item})[key] for item in obs[key]])
|
||||
|
||||
return obs
|
||||
|
||||
|
||||
def postprocess_action(action, transform=None):
|
||||
action = action.to("cpu")
|
||||
# action is a batch (num_env,action_dim) instead of an item (action_dim),
|
||||
# we assume applying inverse transform on a batch works the same
|
||||
action = apply_inverse_transform({"action": action}, transform)["action"].numpy()
|
||||
def postprocess_action(action):
|
||||
action = action.to("cpu").numpy()
|
||||
assert (
|
||||
action.ndim == 2
|
||||
), "we assume dimensions are respectively the number of parallel envs, action dimensions"
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from dataclasses import dataclass, field
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -60,12 +60,14 @@ class ActionChunkingTransformerConfig:
|
|||
chunk_size: int = 100
|
||||
n_action_steps: int = 100
|
||||
|
||||
# Vision preprocessing.
|
||||
image_normalization_mean: tuple[float, float, float] = field(
|
||||
default_factory=lambda: [0.485, 0.456, 0.406]
|
||||
)
|
||||
image_normalization_std: tuple[float, float, float] = field(default_factory=lambda: [0.229, 0.224, 0.225])
|
||||
|
||||
# Normalization / Unnormalization
|
||||
normalize_input_modes: dict[str, str] = {
|
||||
"observation.image": "mean_std",
|
||||
"observation.state": "mean_std",
|
||||
}
|
||||
unnormalize_output_modes: dict[str, str] = {
|
||||
"action": "mean_std",
|
||||
}
|
||||
# Architecture.
|
||||
# Vision backbone.
|
||||
vision_backbone: str = "resnet18"
|
||||
|
|
|
@ -15,12 +15,15 @@ import numpy as np
|
|||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
import torchvision
|
||||
import torchvision.transforms as transforms
|
||||
from torch import Tensor, nn
|
||||
from torchvision.models._utils import IntermediateLayerGetter
|
||||
from torchvision.ops.misc import FrozenBatchNorm2d
|
||||
|
||||
from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
|
||||
from lerobot.common.policies.utils import (
|
||||
normalize_inputs,
|
||||
unnormalize_outputs,
|
||||
)
|
||||
|
||||
|
||||
class ActionChunkingTransformerPolicy(nn.Module):
|
||||
|
@ -62,7 +65,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
|||
|
||||
name = "act"
|
||||
|
||||
def __init__(self, cfg: ActionChunkingTransformerConfig | None = None):
|
||||
def __init__(self, cfg: ActionChunkingTransformerConfig | None = None, dataset_stats=None):
|
||||
"""
|
||||
Args:
|
||||
cfg: Policy configuration class instance or None, in which case the default instantiation of the
|
||||
|
@ -72,6 +75,9 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
|||
if cfg is None:
|
||||
cfg = ActionChunkingTransformerConfig()
|
||||
self.cfg = cfg
|
||||
self.register_buffer("dataset_stats", dataset_stats)
|
||||
self.normalize_input_modes = cfg.normalize_input_modes
|
||||
self.unnormalize_output_modes = cfg.unnormalize_output_modes
|
||||
|
||||
# BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence].
|
||||
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
|
||||
|
@ -93,9 +99,6 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
|||
)
|
||||
|
||||
# Backbone for image feature extraction.
|
||||
self.image_normalizer = transforms.Normalize(
|
||||
mean=cfg.image_normalization_mean, std=cfg.image_normalization_std
|
||||
)
|
||||
backbone_model = getattr(torchvision.models, cfg.vision_backbone)(
|
||||
replace_stride_with_dilation=[False, False, cfg.replace_final_stride_with_dilation],
|
||||
pretrained=cfg.use_pretrained_backbone,
|
||||
|
@ -169,10 +172,15 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
|||
queue is empty.
|
||||
"""
|
||||
self.eval()
|
||||
|
||||
batch = normalize_inputs(batch, self.dataset_stats, self.normalize_input_modes)
|
||||
|
||||
if len(self._action_queue) == 0:
|
||||
# `_forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue effectively
|
||||
# has shape (n_action_steps, batch_size, *), hence the transpose.
|
||||
self._action_queue.extend(self._forward(batch)[0][: self.cfg.n_action_steps].transpose(0, 1))
|
||||
actions = self._forward(batch)[0][: self.cfg.n_action_steps]
|
||||
actions = unnormalize_outputs(actions, self.dataset_stats, self.unnormalize_output_modes)
|
||||
self._action_queue.extend(actions.transpose(0, 1))
|
||||
return self._action_queue.popleft()
|
||||
|
||||
def forward(self, batch, **_) -> dict[str, Tensor]:
|
||||
|
@ -203,7 +211,10 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
|||
"""Run the model in train mode, compute the loss, and do an optimization step."""
|
||||
start_time = time.time()
|
||||
self.train()
|
||||
|
||||
batch = normalize_inputs(batch, self.dataset_stats, self.normalize_input_modes)
|
||||
loss_dict = self.forward(batch)
|
||||
# TODO(rcadene): unnormalize_outputs(actions, self.dataset_stats, self.unnormalize_output_modes)
|
||||
loss = loss_dict["loss"]
|
||||
loss.backward()
|
||||
|
||||
|
@ -309,7 +320,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
|||
# Camera observation features and positional embeddings.
|
||||
all_cam_features = []
|
||||
all_cam_pos_embeds = []
|
||||
images = self.image_normalizer(batch["observation.images"])
|
||||
images = batch["observation.images"]
|
||||
for cam_index in range(len(self.cfg.camera_names)):
|
||||
cam_features = self.backbone(images[:, cam_index])["feature_map"]
|
||||
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
|
||||
|
|
|
@ -69,9 +69,14 @@ class DiffusionConfig:
|
|||
horizon: int = 16
|
||||
n_action_steps: int = 8
|
||||
|
||||
# Vision preprocessing.
|
||||
image_normalization_mean: tuple[float, float, float] = (0.5, 0.5, 0.5)
|
||||
image_normalization_std: tuple[float, float, float] = (0.5, 0.5, 0.5)
|
||||
# Normalization / Unnormalization
|
||||
normalize_input_modes: dict[str, str] = {
|
||||
"observation.image": "mean_std",
|
||||
"observation.state": "min_max",
|
||||
}
|
||||
unnormalize_output_modes: dict[str, str] = {
|
||||
"action": "min_max",
|
||||
}
|
||||
|
||||
# Architecture / modeling.
|
||||
# Vision backbone.
|
||||
|
|
|
@ -13,7 +13,6 @@ import logging
|
|||
import math
|
||||
import time
|
||||
from collections import deque
|
||||
from itertools import chain
|
||||
from typing import Callable
|
||||
|
||||
import einops
|
||||
|
@ -30,7 +29,9 @@ from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionC
|
|||
from lerobot.common.policies.utils import (
|
||||
get_device_from_parameters,
|
||||
get_dtype_from_parameters,
|
||||
normalize_inputs,
|
||||
populate_queues,
|
||||
unnormalize_outputs,
|
||||
)
|
||||
|
||||
|
||||
|
@ -42,7 +43,9 @@ class DiffusionPolicy(nn.Module):
|
|||
|
||||
name = "diffusion"
|
||||
|
||||
def __init__(self, cfg: DiffusionConfig | None, lr_scheduler_num_training_steps: int = 0):
|
||||
def __init__(
|
||||
self, cfg: DiffusionConfig | None, lr_scheduler_num_training_steps: int = 0, dataset_stats=None
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
cfg: Policy configuration class instance or None, in which case the default instantiation of the
|
||||
|
@ -54,6 +57,9 @@ class DiffusionPolicy(nn.Module):
|
|||
if cfg is None:
|
||||
cfg = DiffusionConfig()
|
||||
self.cfg = cfg
|
||||
self.register_buffer("dataset_stats", dataset_stats)
|
||||
self.normalize_input_modes = cfg.normalize_input_modes
|
||||
self.unnormalize_output_modes = cfg.unnormalize_output_modes
|
||||
|
||||
# queues are populated during rollout of the policy, they contain the n latest observations and actions
|
||||
self._queues = None
|
||||
|
@ -126,6 +132,8 @@ class DiffusionPolicy(nn.Module):
|
|||
assert "observation.state" in batch
|
||||
assert len(batch) == 2
|
||||
|
||||
batch = normalize_inputs(batch, self.dataset_stats, self.normalize_input_modes)
|
||||
|
||||
self._queues = populate_queues(self._queues, batch)
|
||||
|
||||
if len(self._queues["action"]) == 0:
|
||||
|
@ -135,6 +143,8 @@ class DiffusionPolicy(nn.Module):
|
|||
actions = self.ema_diffusion.generate_actions(batch)
|
||||
else:
|
||||
actions = self.diffusion.generate_actions(batch)
|
||||
|
||||
actions = unnormalize_outputs(actions, self.dataset_stats, self.unnormalize_output_modes)
|
||||
self._queues["action"].extend(actions.transpose(0, 1))
|
||||
|
||||
action = self._queues["action"].popleft()
|
||||
|
@ -151,9 +161,13 @@ class DiffusionPolicy(nn.Module):
|
|||
|
||||
self.diffusion.train()
|
||||
|
||||
batch = normalize_inputs(batch, self.dataset_stats, self.normalize_input_modes)
|
||||
|
||||
loss = self.forward(batch)["loss"]
|
||||
loss.backward()
|
||||
|
||||
# TODO(rcadene): unnormalize_outputs(actions, self.dataset_stats, self.unnormalize_output_modes)
|
||||
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
self.diffusion.parameters(),
|
||||
self.cfg.grad_clip_norm,
|
||||
|
@ -346,12 +360,6 @@ class _RgbEncoder(nn.Module):
|
|||
def __init__(self, cfg: DiffusionConfig):
|
||||
super().__init__()
|
||||
# Set up optional preprocessing.
|
||||
if all(v == 1.0 for v in chain(cfg.image_normalization_mean, cfg.image_normalization_std)):
|
||||
self.normalizer = nn.Identity()
|
||||
else:
|
||||
self.normalizer = torchvision.transforms.Normalize(
|
||||
mean=cfg.image_normalization_mean, std=cfg.image_normalization_std
|
||||
)
|
||||
if cfg.crop_shape is not None:
|
||||
self.do_crop = True
|
||||
# Always use center crop for eval
|
||||
|
@ -397,8 +405,7 @@ class _RgbEncoder(nn.Module):
|
|||
Returns:
|
||||
(B, D) image feature.
|
||||
"""
|
||||
# Preprocess: normalize and maybe crop (if it was set up in the __init__).
|
||||
x = self.normalizer(x)
|
||||
# Preprocess: maybe crop (if it was set up in the __init__).
|
||||
if self.do_crop:
|
||||
if self.training: # noqa: SIM108
|
||||
x = self.maybe_random_crop(x)
|
||||
|
|
|
@ -20,7 +20,7 @@ def _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg):
|
|||
return policy_cfg
|
||||
|
||||
|
||||
def make_policy(hydra_cfg: DictConfig):
|
||||
def make_policy(hydra_cfg: DictConfig, dataset_stats=None):
|
||||
if hydra_cfg.policy.name == "tdmpc":
|
||||
from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
|
||||
|
||||
|
@ -35,14 +35,14 @@ def make_policy(hydra_cfg: DictConfig):
|
|||
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||
|
||||
policy_cfg = _policy_cfg_from_hydra_cfg(DiffusionConfig, hydra_cfg)
|
||||
policy = DiffusionPolicy(policy_cfg, hydra_cfg.offline_steps)
|
||||
policy = DiffusionPolicy(policy_cfg, hydra_cfg.offline_steps, dataset_stats)
|
||||
policy.to(get_safe_torch_device(hydra_cfg.device))
|
||||
elif hydra_cfg.policy.name == "act":
|
||||
from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
|
||||
from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy
|
||||
|
||||
policy_cfg = _policy_cfg_from_hydra_cfg(ActionChunkingTransformerConfig, hydra_cfg)
|
||||
policy = ActionChunkingTransformerPolicy(policy_cfg)
|
||||
policy = ActionChunkingTransformerPolicy(policy_cfg, dataset_stats)
|
||||
policy.to(get_safe_torch_device(hydra_cfg.device))
|
||||
else:
|
||||
raise ValueError(hydra_cfg.policy.name)
|
||||
|
|
|
@ -28,3 +28,41 @@ def get_dtype_from_parameters(module: nn.Module) -> torch.dtype:
|
|||
Note: assumes that all parameters have the same dtype.
|
||||
"""
|
||||
return next(iter(module.parameters())).dtype
|
||||
|
||||
|
||||
def normalize_inputs(batch, stats, normalize_input_modes):
|
||||
if normalize_input_modes is None:
|
||||
return batch
|
||||
for key, mode in normalize_input_modes.items():
|
||||
if mode == "mean_std":
|
||||
mean = stats[key]["mean"].unsqueeze(0)
|
||||
std = stats[key]["std"].unsqueeze(0)
|
||||
batch[key] = (batch[key] - mean) / (std + 1e-8)
|
||||
elif mode == "min_max":
|
||||
min = stats[key]["min"].unsqueeze(0)
|
||||
max = stats[key]["max"].unsqueeze(0)
|
||||
# normalize to [0,1]
|
||||
batch[key] = (batch[key] - min) / (max - min)
|
||||
# normalize to [-1, 1]
|
||||
batch[key] = batch[key] * 2 - 1
|
||||
else:
|
||||
raise ValueError(mode)
|
||||
return batch
|
||||
|
||||
|
||||
def unnormalize_outputs(batch, stats, unnormalize_output_modes):
|
||||
if unnormalize_output_modes is None:
|
||||
return batch
|
||||
for key, mode in unnormalize_output_modes.items():
|
||||
if mode == "mean_std":
|
||||
mean = stats[key]["mean"].unsqueeze(0)
|
||||
std = stats[key]["std"].unsqueeze(0)
|
||||
batch[key] = batch[key] * std + mean
|
||||
elif mode == "min_max":
|
||||
min = stats[key]["min"].unsqueeze(0)
|
||||
max = stats[key]["max"].unsqueeze(0)
|
||||
batch[key] = (batch[key] + 1) / 2
|
||||
batch[key] = batch[key] * (max - min) + min
|
||||
else:
|
||||
raise ValueError(mode)
|
||||
return batch
|
||||
|
|
|
@ -1,65 +0,0 @@
|
|||
from torchvision.transforms.v2 import Compose, Transform
|
||||
|
||||
|
||||
def apply_inverse_transform(item, transform):
|
||||
transforms = transform.transforms if isinstance(transform, Compose) else [transform]
|
||||
for tf in transforms[::-1]:
|
||||
if tf.invertible:
|
||||
item = tf.inverse_transform(item)
|
||||
else:
|
||||
raise ValueError(f"Inverse transform called on a non invertible transform ({tf}).")
|
||||
return item
|
||||
|
||||
|
||||
class NormalizeTransform(Transform):
|
||||
invertible = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
stats: dict,
|
||||
in_keys: list[str] = None,
|
||||
out_keys: list[str] | None = None,
|
||||
in_keys_inv: list[str] | None = None,
|
||||
out_keys_inv: list[str] | None = None,
|
||||
mode="mean_std",
|
||||
):
|
||||
super().__init__()
|
||||
self.in_keys = in_keys
|
||||
self.out_keys = in_keys if out_keys is None else out_keys
|
||||
self.in_keys_inv = self.out_keys if in_keys_inv is None else in_keys_inv
|
||||
self.out_keys_inv = self.in_keys if out_keys_inv is None else out_keys_inv
|
||||
self.stats = stats
|
||||
assert mode in ["mean_std", "min_max"]
|
||||
self.mode = mode
|
||||
|
||||
def forward(self, item):
|
||||
for inkey, outkey in zip(self.in_keys, self.out_keys, strict=False):
|
||||
if inkey not in item:
|
||||
continue
|
||||
if self.mode == "mean_std":
|
||||
mean = self.stats[inkey]["mean"]
|
||||
std = self.stats[inkey]["std"]
|
||||
item[outkey] = (item[inkey] - mean) / (std + 1e-8)
|
||||
else:
|
||||
min = self.stats[inkey]["min"]
|
||||
max = self.stats[inkey]["max"]
|
||||
# normalize to [0,1]
|
||||
item[outkey] = (item[inkey] - min) / (max - min)
|
||||
# normalize to [-1, 1]
|
||||
item[outkey] = item[outkey] * 2 - 1
|
||||
return item
|
||||
|
||||
def inverse_transform(self, item):
|
||||
for inkey, outkey in zip(self.in_keys_inv, self.out_keys_inv, strict=False):
|
||||
if inkey not in item:
|
||||
continue
|
||||
if self.mode == "mean_std":
|
||||
mean = self.stats[inkey]["mean"]
|
||||
std = self.stats[inkey]["std"]
|
||||
item[outkey] = item[inkey] * std + mean
|
||||
else:
|
||||
min = self.stats[inkey]["min"]
|
||||
max = self.stats[inkey]["max"]
|
||||
item[outkey] = (item[inkey] + 1) / 2
|
||||
item[outkey] = item[outkey] * (max - min) + min
|
||||
return item
|
|
@ -11,6 +11,11 @@ log_freq: 250
|
|||
n_obs_steps: 1
|
||||
# when temporal_agg=False, n_action_steps=horizon
|
||||
|
||||
override_dataset_stats:
|
||||
observation.image:
|
||||
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||
|
||||
# See `configuration_act.py` for more details.
|
||||
policy:
|
||||
name: act
|
||||
|
@ -28,9 +33,12 @@ policy:
|
|||
chunk_size: 100 # chunk_size
|
||||
n_action_steps: 100
|
||||
|
||||
# Vision preprocessing.
|
||||
image_normalization_mean: [0.485, 0.456, 0.406]
|
||||
image_normalization_std: [0.229, 0.224, 0.225]
|
||||
# Normalization / Unnormalization
|
||||
normalize_input_modes:
|
||||
observation.image: mean_std
|
||||
observation.state: mean_std
|
||||
unnormalize_output_modes:
|
||||
action: mean_std
|
||||
|
||||
# Architecture.
|
||||
# Vision backbone.
|
||||
|
|
|
@ -18,6 +18,17 @@ online_steps: 0
|
|||
|
||||
offline_prioritized_sampler: true
|
||||
|
||||
override_dataset_stats:
|
||||
observation.image:
|
||||
mean: [[[0.5]], [[0.5]], [[0.5]]] # (c,1,1)
|
||||
std: [[[0.5]], [[0.5]], [[0.5]]] # (c,1,1)
|
||||
observation.state:
|
||||
min: [13.456424, 32.938293]
|
||||
max: [496.14618, 510.9579]
|
||||
action:
|
||||
min: [12.0, 25.0]
|
||||
max: [511.0, 511.0]
|
||||
|
||||
policy:
|
||||
name: diffusion
|
||||
|
||||
|
@ -36,9 +47,12 @@ policy:
|
|||
horizon: ${horizon}
|
||||
n_action_steps: ${n_action_steps}
|
||||
|
||||
# Vision preprocessing.
|
||||
image_normalization_mean: [0.5, 0.5, 0.5]
|
||||
image_normalization_std: [0.5, 0.5, 0.5]
|
||||
# Normalization / Unnormalization
|
||||
normalize_input_modes:
|
||||
observation.image: mean_std
|
||||
observation.state: min_max
|
||||
unnormalize_output_modes:
|
||||
action: min_max
|
||||
|
||||
# Architecture / modeling.
|
||||
# Vision backbone.
|
||||
|
|
|
@ -46,7 +46,6 @@ from huggingface_hub import snapshot_download
|
|||
from PIL import Image as PILImage
|
||||
from tqdm import trange
|
||||
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.datasets.utils import hf_transform_to_torch
|
||||
from lerobot.common.envs.factory import make_env
|
||||
from lerobot.common.envs.utils import postprocess_action, preprocess_observation
|
||||
|
@ -64,8 +63,6 @@ def eval_policy(
|
|||
policy: torch.nn.Module,
|
||||
max_episodes_rendered: int = 0,
|
||||
video_dir: Path = None,
|
||||
# TODO(rcadene): make it possible to overwrite fps? we should use env.fps
|
||||
transform: callable = None,
|
||||
return_episode_data: bool = False,
|
||||
seed=None,
|
||||
):
|
||||
|
@ -132,10 +129,6 @@ def eval_policy(
|
|||
if return_episode_data:
|
||||
observations.append(deepcopy(observation))
|
||||
|
||||
# apply transform to normalize the observations
|
||||
for key in observation:
|
||||
observation[key] = torch.stack([transform({key: item})[key] for item in observation[key]])
|
||||
|
||||
# send observation to device/gpu
|
||||
observation = {key: observation[key].to(device, non_blocking=True) for key in observation}
|
||||
|
||||
|
@ -143,8 +136,8 @@ def eval_policy(
|
|||
with torch.inference_mode():
|
||||
action = policy.select_action(observation, step=step)
|
||||
|
||||
# apply inverse transform to unnormalize the action
|
||||
action = postprocess_action(action, transform)
|
||||
# convert to cpu numpy
|
||||
action = postprocess_action(action)
|
||||
|
||||
# apply the next action
|
||||
observation, reward, terminated, truncated, info = env.step(action)
|
||||
|
@ -360,7 +353,7 @@ def eval_policy(
|
|||
return info
|
||||
|
||||
|
||||
def eval(cfg: dict, out_dir=None, stats_path=None):
|
||||
def eval(cfg: dict, out_dir=None):
|
||||
if out_dir is None:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
@ -375,10 +368,6 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
|
|||
|
||||
log_output_dir(out_dir)
|
||||
|
||||
logging.info("Making transforms.")
|
||||
# TODO(alexander-soare): Completely decouple datasets from evaluation.
|
||||
transform = make_dataset(cfg, stats_path=stats_path).transform
|
||||
|
||||
logging.info("Making environment.")
|
||||
env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)
|
||||
|
||||
|
@ -390,7 +379,6 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
|
|||
policy,
|
||||
max_episodes_rendered=10,
|
||||
video_dir=Path(out_dir) / "eval",
|
||||
transform=transform,
|
||||
return_episode_data=False,
|
||||
seed=cfg.seed,
|
||||
)
|
||||
|
@ -423,17 +411,13 @@ if __name__ == "__main__":
|
|||
if args.config is not None:
|
||||
# Note: For the config_path, Hydra wants a path relative to this script file.
|
||||
cfg = init_hydra_config(args.config, args.overrides)
|
||||
# TODO(alexander-soare): Save and load stats in trained model directory.
|
||||
stats_path = None
|
||||
elif args.hub_id is not None:
|
||||
folder = Path(snapshot_download(args.hub_id, revision=args.revision))
|
||||
cfg = init_hydra_config(
|
||||
folder / "config.yaml", [f"policy.pretrained_model_path={folder / 'model.pt'}", *args.overrides]
|
||||
)
|
||||
stats_path = folder / "stats.pth"
|
||||
|
||||
eval(
|
||||
cfg,
|
||||
out_dir=f"outputs/eval/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{cfg.env.name}_{cfg.policy.name}",
|
||||
stats_path=stats_path,
|
||||
)
|
||||
|
|
|
@ -232,7 +232,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
|||
env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)
|
||||
|
||||
logging.info("make_policy")
|
||||
policy = make_policy(cfg)
|
||||
policy = make_policy(cfg, dataset_stats=offline_dataset.stats)
|
||||
|
||||
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
|
||||
num_total_params = sum(p.numel() for p in policy.parameters())
|
||||
|
|
|
@ -42,7 +42,7 @@ def test_factory(env_name):
|
|||
|
||||
env = make_env(cfg, num_parallel_envs=1)
|
||||
obs, _ = env.reset()
|
||||
obs = preprocess_observation(obs, transform=dataset.transform)
|
||||
obs = preprocess_observation(obs)
|
||||
for key in dataset.image_keys:
|
||||
img = obs[key]
|
||||
assert img.dtype == torch.float32
|
||||
|
|
|
@ -51,7 +51,7 @@ def test_examples_4_and_3():
|
|||
# Pass empty globals to allow dictionary comprehension https://stackoverflow.com/a/32897127/4391249.
|
||||
exec(file_contents, {})
|
||||
|
||||
for file_name in ["model.pt", "stats.pth", "config.yaml"]:
|
||||
for file_name in ["model.pt", "config.yaml"]:
|
||||
assert Path(f"outputs/train/example_pusht_diffusion/{file_name}").exists()
|
||||
|
||||
path = "examples/3_evaluate_pretrained_policy.py"
|
||||
|
|
|
@ -44,14 +44,16 @@ def test_policy(env_name, policy_name, extra_overrides):
|
|||
]
|
||||
+ extra_overrides,
|
||||
)
|
||||
|
||||
# Check that we can make the policy object.
|
||||
policy = make_policy(cfg)
|
||||
dataset = make_dataset(cfg)
|
||||
policy = make_policy(cfg, dataset_stats=dataset.stats)
|
||||
# Check that the policy follows the required protocol.
|
||||
assert isinstance(
|
||||
policy, Policy
|
||||
), f"The policy does not follow the required protocol. Please see {Policy.__module__}.{Policy.__name__}."
|
||||
|
||||
# Check that we run select_actions and get the appropriate output.
|
||||
dataset = make_dataset(cfg)
|
||||
env = make_env(cfg, num_parallel_envs=2)
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
|
@ -77,7 +79,7 @@ def test_policy(env_name, policy_name, extra_overrides):
|
|||
observation, _ = env.reset(seed=cfg.seed)
|
||||
|
||||
# apply transform to normalize the observations
|
||||
observation = preprocess_observation(observation, dataset.transform)
|
||||
observation = preprocess_observation(observation)
|
||||
|
||||
# send observation to device/gpu
|
||||
observation = {key: observation[key].to(DEVICE, non_blocking=True) for key in observation}
|
||||
|
@ -86,8 +88,8 @@ def test_policy(env_name, policy_name, extra_overrides):
|
|||
with torch.inference_mode():
|
||||
action = policy.select_action(observation, step=0)
|
||||
|
||||
# apply inverse transform to unnormalize the action
|
||||
action = postprocess_action(action, dataset.transform)
|
||||
# convert action to cpu numpy array
|
||||
action = postprocess_action(action)
|
||||
|
||||
# Test step through policy
|
||||
env.step(action)
|
||||
|
|
Loading…
Reference in New Issue