Move normalize/unnormalize transforms to policy for act and diffusion

This commit is contained in:
Cadene 2024-04-20 21:08:14 +00:00
parent c1bcf857c5
commit 42ed7bb670
19 changed files with 145 additions and 195 deletions

View File

@ -263,15 +263,13 @@ Secondly, assuming you have trained a policy, you need:
- `config.yaml` which you can get from the `.hydra` directory of your training output folder.
- `model.pt` which should be one of the saved models in the `models` directory of your training output folder (they won't be named `model.pt` but you will need to choose one).
- `stats.pth` which should point to the same file in the dataset directory (found in `data/{dataset_name}`).
To upload these to the hub, prepare a folder with the following structure (you can use symlinks rather than copying):
```
to_upload
├── config.yaml
├── model.pt
└── stats.pth
└── model.pt
```
With the folder prepared, run the following with a desired revision ID.

View File

@ -19,7 +19,6 @@ folder = Path(snapshot_download(hub_id))
config_path = folder / "config.yaml"
weights_path = folder / "model.pt"
stats_path = folder / "stats.pth" # normalization stats
# Override some config parameters to do with evaluation.
overrides = [
@ -36,5 +35,4 @@ cfg = init_hydra_config(config_path, overrides)
eval(
cfg,
out_dir=f"outputs/eval/example_{cfg.env.name}_{cfg.policy.name}",
stats_path=stats_path,
)

View File

@ -62,7 +62,6 @@ while not done:
done = True
break
# Save the policy, configuration, and normalization stats for later use.
# Save the policy and configuration for later use.
policy.save(output_directory / "model.pt")
OmegaConf.save(hydra_cfg, output_directory / "config.yaml")
torch.save(dataset.transform.transforms[-1].stats, output_directory / "stats.pth")

View File

@ -2,18 +2,12 @@ import os
from pathlib import Path
import torch
from torchvision.transforms import v2
from lerobot.common.transforms import NormalizeTransform
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
def make_dataset(
cfg,
# set normalize=False to remove all transformations and keep images unnormalized in [0,255]
normalize=True,
stats_path=None,
split="train",
):
if cfg.env.name == "xarm":
@ -33,58 +27,23 @@ def make_dataset(
else:
raise ValueError(cfg.env.name)
transforms = None
if normalize:
# TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max,
# min_max_from_spec
# TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std
normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
if cfg.policy.name == "diffusion" and cfg.env.name == "pusht":
stats = {}
# TODO(rcadene): we overwrite stats to have the same as pretrained model, but we should remove this
stats["observation.state"] = {}
stats["observation.state"]["min"] = torch.tensor([13.456424, 32.938293], dtype=torch.float32)
stats["observation.state"]["max"] = torch.tensor([496.14618, 510.9579], dtype=torch.float32)
stats["action"] = {}
stats["action"]["min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
stats["action"]["max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
elif stats_path is None:
# load a first dataset to access precomputed stats
stats_dataset = clsfunc(
dataset_id=cfg.dataset_id,
split="train",
root=DATA_DIR,
)
stats = stats_dataset.stats
else:
stats = torch.load(stats_path)
transforms = v2.Compose(
[
NormalizeTransform(
stats,
in_keys=[
"observation.state",
"action",
],
mode=normalization_mode,
),
]
)
delta_timestamps = cfg.policy.get("delta_timestamps")
if delta_timestamps is not None:
for key in delta_timestamps:
if isinstance(delta_timestamps[key], str):
delta_timestamps[key] = eval(delta_timestamps[key])
# TODO(rcadene): add data augmentations
dataset = clsfunc(
dataset_id=cfg.dataset_id,
split=split,
root=DATA_DIR,
delta_timestamps=delta_timestamps,
transform=transforms,
)
if cfg.get("override_dataset_stats"):
for key, val in cfg.override_dataset_stats.items():
dataset.stats[key] = torch.tensor(val)
return dataset

View File

@ -1,10 +1,8 @@
import einops
import torch
from lerobot.common.transforms import apply_inverse_transform
def preprocess_observation(observation, transform=None):
def preprocess_observation(observation):
# map to expected inputs for the policy
obs = {}
@ -33,19 +31,11 @@ def preprocess_observation(observation, transform=None):
# TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing requirement for "agent_pos"
obs["observation.state"] = torch.from_numpy(observation["agent_pos"]).float()
# apply same transforms as in training
if transform is not None:
for key in obs:
obs[key] = torch.stack([transform({key: item})[key] for item in obs[key]])
return obs
def postprocess_action(action, transform=None):
action = action.to("cpu")
# action is a batch (num_env,action_dim) instead of an item (action_dim),
# we assume applying inverse transform on a batch works the same
action = apply_inverse_transform({"action": action}, transform)["action"].numpy()
def postprocess_action(action):
action = action.to("cpu").numpy()
assert (
action.ndim == 2
), "we assume dimensions are respectively the number of parallel envs, action dimensions"

View File

@ -1,4 +1,4 @@
from dataclasses import dataclass, field
from dataclasses import dataclass
@dataclass
@ -60,12 +60,14 @@ class ActionChunkingTransformerConfig:
chunk_size: int = 100
n_action_steps: int = 100
# Vision preprocessing.
image_normalization_mean: tuple[float, float, float] = field(
default_factory=lambda: [0.485, 0.456, 0.406]
)
image_normalization_std: tuple[float, float, float] = field(default_factory=lambda: [0.229, 0.224, 0.225])
# Normalization / Unnormalization
normalize_input_modes: dict[str, str] = {
"observation.image": "mean_std",
"observation.state": "mean_std",
}
unnormalize_output_modes: dict[str, str] = {
"action": "mean_std",
}
# Architecture.
# Vision backbone.
vision_backbone: str = "resnet18"

View File

@ -15,12 +15,15 @@ import numpy as np
import torch
import torch.nn.functional as F # noqa: N812
import torchvision
import torchvision.transforms as transforms
from torch import Tensor, nn
from torchvision.models._utils import IntermediateLayerGetter
from torchvision.ops.misc import FrozenBatchNorm2d
from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
from lerobot.common.policies.utils import (
normalize_inputs,
unnormalize_outputs,
)
class ActionChunkingTransformerPolicy(nn.Module):
@ -62,7 +65,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
name = "act"
def __init__(self, cfg: ActionChunkingTransformerConfig | None = None):
def __init__(self, cfg: ActionChunkingTransformerConfig | None = None, dataset_stats=None):
"""
Args:
cfg: Policy configuration class instance or None, in which case the default instantiation of the
@ -72,6 +75,9 @@ class ActionChunkingTransformerPolicy(nn.Module):
if cfg is None:
cfg = ActionChunkingTransformerConfig()
self.cfg = cfg
self.register_buffer("dataset_stats", dataset_stats)
self.normalize_input_modes = cfg.normalize_input_modes
self.unnormalize_output_modes = cfg.unnormalize_output_modes
# BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence].
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
@ -93,9 +99,6 @@ class ActionChunkingTransformerPolicy(nn.Module):
)
# Backbone for image feature extraction.
self.image_normalizer = transforms.Normalize(
mean=cfg.image_normalization_mean, std=cfg.image_normalization_std
)
backbone_model = getattr(torchvision.models, cfg.vision_backbone)(
replace_stride_with_dilation=[False, False, cfg.replace_final_stride_with_dilation],
pretrained=cfg.use_pretrained_backbone,
@ -169,10 +172,15 @@ class ActionChunkingTransformerPolicy(nn.Module):
queue is empty.
"""
self.eval()
batch = normalize_inputs(batch, self.dataset_stats, self.normalize_input_modes)
if len(self._action_queue) == 0:
# `_forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue effectively
# has shape (n_action_steps, batch_size, *), hence the transpose.
self._action_queue.extend(self._forward(batch)[0][: self.cfg.n_action_steps].transpose(0, 1))
actions = self._forward(batch)[0][: self.cfg.n_action_steps]
actions = unnormalize_outputs(actions, self.dataset_stats, self.unnormalize_output_modes)
self._action_queue.extend(actions.transpose(0, 1))
return self._action_queue.popleft()
def forward(self, batch, **_) -> dict[str, Tensor]:
@ -203,7 +211,10 @@ class ActionChunkingTransformerPolicy(nn.Module):
"""Run the model in train mode, compute the loss, and do an optimization step."""
start_time = time.time()
self.train()
batch = normalize_inputs(batch, self.dataset_stats, self.normalize_input_modes)
loss_dict = self.forward(batch)
# TODO(rcadene): unnormalize_outputs(actions, self.dataset_stats, self.unnormalize_output_modes)
loss = loss_dict["loss"]
loss.backward()
@ -309,7 +320,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
# Camera observation features and positional embeddings.
all_cam_features = []
all_cam_pos_embeds = []
images = self.image_normalizer(batch["observation.images"])
images = batch["observation.images"]
for cam_index in range(len(self.cfg.camera_names)):
cam_features = self.backbone(images[:, cam_index])["feature_map"]
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)

View File

@ -69,9 +69,14 @@ class DiffusionConfig:
horizon: int = 16
n_action_steps: int = 8
# Vision preprocessing.
image_normalization_mean: tuple[float, float, float] = (0.5, 0.5, 0.5)
image_normalization_std: tuple[float, float, float] = (0.5, 0.5, 0.5)
# Normalization / Unnormalization
normalize_input_modes: dict[str, str] = {
"observation.image": "mean_std",
"observation.state": "min_max",
}
unnormalize_output_modes: dict[str, str] = {
"action": "min_max",
}
# Architecture / modeling.
# Vision backbone.

View File

@ -13,7 +13,6 @@ import logging
import math
import time
from collections import deque
from itertools import chain
from typing import Callable
import einops
@ -30,7 +29,9 @@ from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionC
from lerobot.common.policies.utils import (
get_device_from_parameters,
get_dtype_from_parameters,
normalize_inputs,
populate_queues,
unnormalize_outputs,
)
@ -42,7 +43,9 @@ class DiffusionPolicy(nn.Module):
name = "diffusion"
def __init__(self, cfg: DiffusionConfig | None, lr_scheduler_num_training_steps: int = 0):
def __init__(
self, cfg: DiffusionConfig | None, lr_scheduler_num_training_steps: int = 0, dataset_stats=None
):
"""
Args:
cfg: Policy configuration class instance or None, in which case the default instantiation of the
@ -54,6 +57,9 @@ class DiffusionPolicy(nn.Module):
if cfg is None:
cfg = DiffusionConfig()
self.cfg = cfg
self.register_buffer("dataset_stats", dataset_stats)
self.normalize_input_modes = cfg.normalize_input_modes
self.unnormalize_output_modes = cfg.unnormalize_output_modes
# queues are populated during rollout of the policy, they contain the n latest observations and actions
self._queues = None
@ -126,6 +132,8 @@ class DiffusionPolicy(nn.Module):
assert "observation.state" in batch
assert len(batch) == 2
batch = normalize_inputs(batch, self.dataset_stats, self.normalize_input_modes)
self._queues = populate_queues(self._queues, batch)
if len(self._queues["action"]) == 0:
@ -135,6 +143,8 @@ class DiffusionPolicy(nn.Module):
actions = self.ema_diffusion.generate_actions(batch)
else:
actions = self.diffusion.generate_actions(batch)
actions = unnormalize_outputs(actions, self.dataset_stats, self.unnormalize_output_modes)
self._queues["action"].extend(actions.transpose(0, 1))
action = self._queues["action"].popleft()
@ -151,9 +161,13 @@ class DiffusionPolicy(nn.Module):
self.diffusion.train()
batch = normalize_inputs(batch, self.dataset_stats, self.normalize_input_modes)
loss = self.forward(batch)["loss"]
loss.backward()
# TODO(rcadene): unnormalize_outputs(actions, self.dataset_stats, self.unnormalize_output_modes)
grad_norm = torch.nn.utils.clip_grad_norm_(
self.diffusion.parameters(),
self.cfg.grad_clip_norm,
@ -346,12 +360,6 @@ class _RgbEncoder(nn.Module):
def __init__(self, cfg: DiffusionConfig):
super().__init__()
# Set up optional preprocessing.
if all(v == 1.0 for v in chain(cfg.image_normalization_mean, cfg.image_normalization_std)):
self.normalizer = nn.Identity()
else:
self.normalizer = torchvision.transforms.Normalize(
mean=cfg.image_normalization_mean, std=cfg.image_normalization_std
)
if cfg.crop_shape is not None:
self.do_crop = True
# Always use center crop for eval
@ -397,8 +405,7 @@ class _RgbEncoder(nn.Module):
Returns:
(B, D) image feature.
"""
# Preprocess: normalize and maybe crop (if it was set up in the __init__).
x = self.normalizer(x)
# Preprocess: maybe crop (if it was set up in the __init__).
if self.do_crop:
if self.training: # noqa: SIM108
x = self.maybe_random_crop(x)

View File

@ -20,7 +20,7 @@ def _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg):
return policy_cfg
def make_policy(hydra_cfg: DictConfig):
def make_policy(hydra_cfg: DictConfig, dataset_stats=None):
if hydra_cfg.policy.name == "tdmpc":
from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
@ -35,14 +35,14 @@ def make_policy(hydra_cfg: DictConfig):
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
policy_cfg = _policy_cfg_from_hydra_cfg(DiffusionConfig, hydra_cfg)
policy = DiffusionPolicy(policy_cfg, hydra_cfg.offline_steps)
policy = DiffusionPolicy(policy_cfg, hydra_cfg.offline_steps, dataset_stats)
policy.to(get_safe_torch_device(hydra_cfg.device))
elif hydra_cfg.policy.name == "act":
from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy
policy_cfg = _policy_cfg_from_hydra_cfg(ActionChunkingTransformerConfig, hydra_cfg)
policy = ActionChunkingTransformerPolicy(policy_cfg)
policy = ActionChunkingTransformerPolicy(policy_cfg, dataset_stats)
policy.to(get_safe_torch_device(hydra_cfg.device))
else:
raise ValueError(hydra_cfg.policy.name)

View File

@ -28,3 +28,41 @@ def get_dtype_from_parameters(module: nn.Module) -> torch.dtype:
Note: assumes that all parameters have the same dtype.
"""
return next(iter(module.parameters())).dtype
def normalize_inputs(batch, stats, normalize_input_modes):
if normalize_input_modes is None:
return batch
for key, mode in normalize_input_modes.items():
if mode == "mean_std":
mean = stats[key]["mean"].unsqueeze(0)
std = stats[key]["std"].unsqueeze(0)
batch[key] = (batch[key] - mean) / (std + 1e-8)
elif mode == "min_max":
min = stats[key]["min"].unsqueeze(0)
max = stats[key]["max"].unsqueeze(0)
# normalize to [0,1]
batch[key] = (batch[key] - min) / (max - min)
# normalize to [-1, 1]
batch[key] = batch[key] * 2 - 1
else:
raise ValueError(mode)
return batch
def unnormalize_outputs(batch, stats, unnormalize_output_modes):
if unnormalize_output_modes is None:
return batch
for key, mode in unnormalize_output_modes.items():
if mode == "mean_std":
mean = stats[key]["mean"].unsqueeze(0)
std = stats[key]["std"].unsqueeze(0)
batch[key] = batch[key] * std + mean
elif mode == "min_max":
min = stats[key]["min"].unsqueeze(0)
max = stats[key]["max"].unsqueeze(0)
batch[key] = (batch[key] + 1) / 2
batch[key] = batch[key] * (max - min) + min
else:
raise ValueError(mode)
return batch

View File

@ -1,65 +0,0 @@
from torchvision.transforms.v2 import Compose, Transform
def apply_inverse_transform(item, transform):
transforms = transform.transforms if isinstance(transform, Compose) else [transform]
for tf in transforms[::-1]:
if tf.invertible:
item = tf.inverse_transform(item)
else:
raise ValueError(f"Inverse transform called on a non invertible transform ({tf}).")
return item
class NormalizeTransform(Transform):
invertible = True
def __init__(
self,
stats: dict,
in_keys: list[str] = None,
out_keys: list[str] | None = None,
in_keys_inv: list[str] | None = None,
out_keys_inv: list[str] | None = None,
mode="mean_std",
):
super().__init__()
self.in_keys = in_keys
self.out_keys = in_keys if out_keys is None else out_keys
self.in_keys_inv = self.out_keys if in_keys_inv is None else in_keys_inv
self.out_keys_inv = self.in_keys if out_keys_inv is None else out_keys_inv
self.stats = stats
assert mode in ["mean_std", "min_max"]
self.mode = mode
def forward(self, item):
for inkey, outkey in zip(self.in_keys, self.out_keys, strict=False):
if inkey not in item:
continue
if self.mode == "mean_std":
mean = self.stats[inkey]["mean"]
std = self.stats[inkey]["std"]
item[outkey] = (item[inkey] - mean) / (std + 1e-8)
else:
min = self.stats[inkey]["min"]
max = self.stats[inkey]["max"]
# normalize to [0,1]
item[outkey] = (item[inkey] - min) / (max - min)
# normalize to [-1, 1]
item[outkey] = item[outkey] * 2 - 1
return item
def inverse_transform(self, item):
for inkey, outkey in zip(self.in_keys_inv, self.out_keys_inv, strict=False):
if inkey not in item:
continue
if self.mode == "mean_std":
mean = self.stats[inkey]["mean"]
std = self.stats[inkey]["std"]
item[outkey] = item[inkey] * std + mean
else:
min = self.stats[inkey]["min"]
max = self.stats[inkey]["max"]
item[outkey] = (item[inkey] + 1) / 2
item[outkey] = item[outkey] * (max - min) + min
return item

View File

@ -11,6 +11,11 @@ log_freq: 250
n_obs_steps: 1
# when temporal_agg=False, n_action_steps=horizon
override_dataset_stats:
observation.image:
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
# See `configuration_act.py` for more details.
policy:
name: act
@ -28,9 +33,12 @@ policy:
chunk_size: 100 # chunk_size
n_action_steps: 100
# Vision preprocessing.
image_normalization_mean: [0.485, 0.456, 0.406]
image_normalization_std: [0.229, 0.224, 0.225]
# Normalization / Unnormalization
normalize_input_modes:
observation.image: mean_std
observation.state: mean_std
unnormalize_output_modes:
action: mean_std
# Architecture.
# Vision backbone.

View File

@ -18,6 +18,17 @@ online_steps: 0
offline_prioritized_sampler: true
override_dataset_stats:
observation.image:
mean: [[[0.5]], [[0.5]], [[0.5]]] # (c,1,1)
std: [[[0.5]], [[0.5]], [[0.5]]] # (c,1,1)
observation.state:
min: [13.456424, 32.938293]
max: [496.14618, 510.9579]
action:
min: [12.0, 25.0]
max: [511.0, 511.0]
policy:
name: diffusion
@ -36,9 +47,12 @@ policy:
horizon: ${horizon}
n_action_steps: ${n_action_steps}
# Vision preprocessing.
image_normalization_mean: [0.5, 0.5, 0.5]
image_normalization_std: [0.5, 0.5, 0.5]
# Normalization / Unnormalization
normalize_input_modes:
observation.image: mean_std
observation.state: min_max
unnormalize_output_modes:
action: min_max
# Architecture / modeling.
# Vision backbone.

View File

@ -46,7 +46,6 @@ from huggingface_hub import snapshot_download
from PIL import Image as PILImage
from tqdm import trange
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.utils import hf_transform_to_torch
from lerobot.common.envs.factory import make_env
from lerobot.common.envs.utils import postprocess_action, preprocess_observation
@ -64,8 +63,6 @@ def eval_policy(
policy: torch.nn.Module,
max_episodes_rendered: int = 0,
video_dir: Path = None,
# TODO(rcadene): make it possible to overwrite fps? we should use env.fps
transform: callable = None,
return_episode_data: bool = False,
seed=None,
):
@ -132,10 +129,6 @@ def eval_policy(
if return_episode_data:
observations.append(deepcopy(observation))
# apply transform to normalize the observations
for key in observation:
observation[key] = torch.stack([transform({key: item})[key] for item in observation[key]])
# send observation to device/gpu
observation = {key: observation[key].to(device, non_blocking=True) for key in observation}
@ -143,8 +136,8 @@ def eval_policy(
with torch.inference_mode():
action = policy.select_action(observation, step=step)
# apply inverse transform to unnormalize the action
action = postprocess_action(action, transform)
# convert to cpu numpy
action = postprocess_action(action)
# apply the next action
observation, reward, terminated, truncated, info = env.step(action)
@ -360,7 +353,7 @@ def eval_policy(
return info
def eval(cfg: dict, out_dir=None, stats_path=None):
def eval(cfg: dict, out_dir=None):
if out_dir is None:
raise NotImplementedError()
@ -375,10 +368,6 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
log_output_dir(out_dir)
logging.info("Making transforms.")
# TODO(alexander-soare): Completely decouple datasets from evaluation.
transform = make_dataset(cfg, stats_path=stats_path).transform
logging.info("Making environment.")
env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)
@ -390,7 +379,6 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
policy,
max_episodes_rendered=10,
video_dir=Path(out_dir) / "eval",
transform=transform,
return_episode_data=False,
seed=cfg.seed,
)
@ -423,17 +411,13 @@ if __name__ == "__main__":
if args.config is not None:
# Note: For the config_path, Hydra wants a path relative to this script file.
cfg = init_hydra_config(args.config, args.overrides)
# TODO(alexander-soare): Save and load stats in trained model directory.
stats_path = None
elif args.hub_id is not None:
folder = Path(snapshot_download(args.hub_id, revision=args.revision))
cfg = init_hydra_config(
folder / "config.yaml", [f"policy.pretrained_model_path={folder / 'model.pt'}", *args.overrides]
)
stats_path = folder / "stats.pth"
eval(
cfg,
out_dir=f"outputs/eval/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{cfg.env.name}_{cfg.policy.name}",
stats_path=stats_path,
)

View File

@ -232,7 +232,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)
logging.info("make_policy")
policy = make_policy(cfg)
policy = make_policy(cfg, dataset_stats=offline_dataset.stats)
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
num_total_params = sum(p.numel() for p in policy.parameters())

View File

@ -42,7 +42,7 @@ def test_factory(env_name):
env = make_env(cfg, num_parallel_envs=1)
obs, _ = env.reset()
obs = preprocess_observation(obs, transform=dataset.transform)
obs = preprocess_observation(obs)
for key in dataset.image_keys:
img = obs[key]
assert img.dtype == torch.float32

View File

@ -51,7 +51,7 @@ def test_examples_4_and_3():
# Pass empty globals to allow dictionary comprehension https://stackoverflow.com/a/32897127/4391249.
exec(file_contents, {})
for file_name in ["model.pt", "stats.pth", "config.yaml"]:
for file_name in ["model.pt", "config.yaml"]:
assert Path(f"outputs/train/example_pusht_diffusion/{file_name}").exists()
path = "examples/3_evaluate_pretrained_policy.py"

View File

@ -44,14 +44,16 @@ def test_policy(env_name, policy_name, extra_overrides):
]
+ extra_overrides,
)
# Check that we can make the policy object.
policy = make_policy(cfg)
dataset = make_dataset(cfg)
policy = make_policy(cfg, dataset_stats=dataset.stats)
# Check that the policy follows the required protocol.
assert isinstance(
policy, Policy
), f"The policy does not follow the required protocol. Please see {Policy.__module__}.{Policy.__name__}."
# Check that we run select_actions and get the appropriate output.
dataset = make_dataset(cfg)
env = make_env(cfg, num_parallel_envs=2)
dataloader = torch.utils.data.DataLoader(
@ -77,7 +79,7 @@ def test_policy(env_name, policy_name, extra_overrides):
observation, _ = env.reset(seed=cfg.seed)
# apply transform to normalize the observations
observation = preprocess_observation(observation, dataset.transform)
observation = preprocess_observation(observation)
# send observation to device/gpu
observation = {key: observation[key].to(DEVICE, non_blocking=True) for key in observation}
@ -86,8 +88,8 @@ def test_policy(env_name, policy_name, extra_overrides):
with torch.inference_mode():
action = policy.select_action(observation, step=0)
# apply inverse transform to unnormalize the action
action = postprocess_action(action, dataset.transform)
# convert action to cpu numpy array
action = postprocess_action(action)
# Test step through policy
env.step(action)