Add regression tests (#119)

- Add `tests/scripts/save_policy_to_safetensor.py` to generate test artifacts
- Add `test_backward_compatibility` to test generated outputs from the policies against those artifacts
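
A sketch of the intended workflow (assuming the repo root as working directory): regenerate the reference artifacts with the new script, then run the regression test against them.

    # Regenerate reference artifacts (the script's __main__ block does this for all entries):
    from tests.scripts.save_policy_to_safetensor import save_policy_to_safetensors

    save_policy_to_safetensors(
        "tests/data/save_policy_to_safetensors", "aloha", "act", ["policy.n_action_steps=10"]
    )
    # Then compare fresh outputs against the saved ones:
    #   pytest tests/test_policies.py -k backward_compatibility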
Simon Alibert 2024-05-04 16:20:30 +02:00 committed by GitHub
parent 19812ca470
commit c77633c38c
15 changed files with 236 additions and 43 deletions

View File

@@ -80,7 +80,8 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
        self.config = config
        self.model = TDMPCTOLD(config)
        self.model_target = deepcopy(self.model)
        self.model_target.eval()
        for param in self.model_target.parameters():
            param.requires_grad = False
        if config.input_normalization_modes is not None:
            self.normalize_inputs = Normalize(
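
Note: `eval()` plus `requires_grad = False` is the standard recipe for a frozen target network: the former disables train-time behaviors such as dropout, the latter keeps the parameters out of autograd and out of any optimizer that filters on `requires_grad`. A minimal illustration of the same pattern on a toy module (not the commit's code):

    from copy import deepcopy
    import torch

    model = torch.nn.Linear(4, 2)
    model_target = deepcopy(model)
    model_target.eval()  # deterministic forward passes
    for param in model_target.parameters():
        param.requires_grad = False  # excluded from gradient computation
    assert not any(p.requires_grad for p in model_target.parameters())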

View File

@@ -1,6 +1,7 @@
# @package _global_

seed: 1
dataset_repo_id: lerobot/xarm_lift_medium_replay

training:
  offline_steps: 25000

View File

@@ -25,6 +25,51 @@ from lerobot.common.utils.utils import (
from lerobot.scripts.eval import eval_policy


def make_optimizer_and_scheduler(cfg, policy):
    if cfg.policy.name == "act":
        optimizer_params_dicts = [
            {
                "params": [
                    p
                    for n, p in policy.named_parameters()
                    if not n.startswith("backbone") and p.requires_grad
                ]
            },
            {
                "params": [
                    p for n, p in policy.named_parameters() if n.startswith("backbone") and p.requires_grad
                ],
                "lr": cfg.training.lr_backbone,
            },
        ]
        optimizer = torch.optim.AdamW(
            optimizer_params_dicts, lr=cfg.training.lr, weight_decay=cfg.training.weight_decay
        )
        lr_scheduler = None
    elif cfg.policy.name == "diffusion":
        optimizer = torch.optim.Adam(
            policy.diffusion.parameters(),
            cfg.training.lr,
            cfg.training.adam_betas,
            cfg.training.adam_eps,
            cfg.training.adam_weight_decay,
        )
        assert cfg.training.online_steps == 0, "Diffusion Policy does not handle online training."
        lr_scheduler = get_scheduler(
            cfg.training.lr_scheduler,
            optimizer=optimizer,
            num_warmup_steps=cfg.training.lr_warmup_steps,
            num_training_steps=cfg.training.offline_steps,
        )
    elif policy.name == "tdmpc":
        optimizer = torch.optim.Adam(policy.parameters(), cfg.training.lr)
        lr_scheduler = None
    else:
        raise NotImplementedError()

    return optimizer, lr_scheduler


def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
    start_time = time.time()
    policy.train()
@@ -276,46 +321,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
    # Create optimizer and scheduler
    # Temporary hack to move optimizer out of policy
    if cfg.policy.name == "act":
        optimizer_params_dicts = [
            {
                "params": [
                    p
                    for n, p in policy.named_parameters()
                    if not n.startswith("backbone") and p.requires_grad
                ]
            },
            {
                "params": [
                    p for n, p in policy.named_parameters() if n.startswith("backbone") and p.requires_grad
                ],
                "lr": cfg.training.lr_backbone,
            },
        ]
        optimizer = torch.optim.AdamW(
            optimizer_params_dicts, lr=cfg.training.lr, weight_decay=cfg.training.weight_decay
        )
        lr_scheduler = None
    elif cfg.policy.name == "diffusion":
        optimizer = torch.optim.Adam(
            policy.diffusion.parameters(),
            cfg.training.lr,
            cfg.training.adam_betas,
            cfg.training.adam_eps,
            cfg.training.adam_weight_decay,
        )
        assert cfg.training.online_steps == 0, "Diffusion Policy does not handle online training."
        lr_scheduler = get_scheduler(
            cfg.training.lr_scheduler,
            optimizer=optimizer,
            num_warmup_steps=cfg.training.lr_warmup_steps,
            num_training_steps=cfg.training.offline_steps,
        )
    elif policy.name == "tdmpc":
        optimizer = torch.optim.Adam(policy.parameters(), cfg.training.lr)
        lr_scheduler = None
    else:
        raise NotImplementedError()
    optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)

    num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
    num_total_params = sum(p.numel() for p in policy.parameters())
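
The ACT branch of `make_optimizer_and_scheduler` relies on `torch.optim.AdamW` accepting a list of parameter groups, where a group-level key such as `"lr"` overrides the constructor-level default. A minimal sketch of that mechanism with toy modules (illustrative names only):

    import torch

    backbone = torch.nn.Linear(8, 8)
    head = torch.nn.Linear(8, 2)
    optimizer = torch.optim.AdamW(
        [
            {"params": head.parameters()},  # falls back to lr=1e-4 below
            {"params": backbone.parameters(), "lr": 1e-5},  # per-group override
        ],
        lr=1e-4,
        weight_decay=1e-2,
    )
    assert optimizer.param_groups[0]["lr"] == 1e-4
    assert optimizer.param_groups[1]["lr"] == 1e-5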

View File

@@ -0,0 +1,101 @@
import shutil
from pathlib import Path

import torch
from safetensors.torch import save_file

from lerobot.common.datasets.factory import make_dataset
from lerobot.common.policies.factory import make_policy
from lerobot.common.utils.utils import init_hydra_config, set_global_seed
from lerobot.scripts.train import make_optimizer_and_scheduler
from tests.utils import DEFAULT_CONFIG_PATH


def get_policy_stats(env_name, policy_name, extra_overrides=None):
    cfg = init_hydra_config(
        DEFAULT_CONFIG_PATH,
        overrides=[
            f"env={env_name}",
            f"policy={policy_name}",
            "device=cpu",
        ]
        + extra_overrides,
    )
    set_global_seed(1337)
    dataset = make_dataset(cfg)
    policy = make_policy(cfg, dataset_stats=dataset.stats)
    policy.train()
    optimizer, _ = make_optimizer_and_scheduler(cfg, policy)

    dataloader = torch.utils.data.DataLoader(
        dataset,
        num_workers=0,
        batch_size=cfg.training.batch_size,
        shuffle=False,
    )

    batch = next(iter(dataloader))
    output_dict = policy.forward(batch)
    output_dict = {k: v for k, v in output_dict.items() if isinstance(v, torch.Tensor)}
    loss = output_dict["loss"]

    loss.backward()
    grad_stats = {}
    for key, param in policy.named_parameters():
        if param.requires_grad:
            grad_stats[f"{key}_mean"] = param.grad.mean()
            grad_stats[f"{key}_std"] = (
                param.grad.std() if param.grad.numel() > 1 else torch.tensor(float(0.0))
            )

    optimizer.step()
    param_stats = {}
    for key, param in policy.named_parameters():
        param_stats[f"{key}_mean"] = param.mean()
        param_stats[f"{key}_std"] = param.std() if param.numel() > 1 else torch.tensor(float(0.0))

    optimizer.zero_grad()
    policy.reset()

    # HACK: We reload a batch with no delta_timestamps as `select_action` won't expect a timestamps dimension
    dataset.delta_timestamps = None
    batch = next(iter(dataloader))
    obs = {
        k: batch[k]
        for k in batch
        if k in ["observation.image", "observation.images.top", "observation.state"]
    }

    actions_queue = (
        cfg.policy.n_action_steps if "n_action_steps" in cfg.policy else cfg.policy.n_action_repeats
    )
    actions = {str(i): policy.select_action(obs).contiguous() for i in range(actions_queue)}
    return output_dict, grad_stats, param_stats, actions


def save_policy_to_safetensors(output_dir, env_name, policy_name, extra_overrides):
    env_policy_dir = Path(output_dir) / f"{env_name}_{policy_name}"

    if env_policy_dir.exists():
        shutil.rmtree(env_policy_dir)

    env_policy_dir.mkdir(parents=True, exist_ok=True)
    output_dict, grad_stats, param_stats, actions = get_policy_stats(env_name, policy_name, extra_overrides)
    save_file(output_dict, env_policy_dir / "output_dict.safetensors")
    save_file(grad_stats, env_policy_dir / "grad_stats.safetensors")
    save_file(param_stats, env_policy_dir / "param_stats.safetensors")
    save_file(actions, env_policy_dir / "actions.safetensors")


if __name__ == "__main__":
    env_policies = [
        # ("xarm", "tdmpc", ["policy.n_action_repeats=2"]),
        (
            "pusht",
            "diffusion",
            ["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
        ),
        ("aloha", "act", ["policy.n_action_steps=10"]),
    ]
    for env, policy, extra_overrides in env_policies:
        save_policy_to_safetensors("tests/data/save_policy_to_safetensors", env, policy, extra_overrides)

View File

@@ -265,7 +265,7 @@ def test_backward_compatibility(repo_id):
    for key in new_frame:
        assert torch.isclose(
            new_frame[key], old_frame[key], rtol=1e-05, atol=1e-08
            new_frame[key], old_frame[key]
        ).all(), f"{key=} for index={i} does not contain the same value"

    # test2 first frames of first episode
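
(Dropping the explicit tolerances is cosmetic: `rtol=1e-05` and `atol=1e-08` are `torch.isclose`'s defaults, so the comparison is unchanged.)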

View File

@@ -1,8 +1,10 @@
import inspect
from pathlib import Path

import pytest
import torch
from huggingface_hub import PyTorchModelHubMixin
from safetensors.torch import load_file

from lerobot import available_policies
from lerobot.common.datasets.factory import make_dataset
@@ -13,7 +15,8 @@ from lerobot.common.policies.factory import get_policy_and_config_classes, make_
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.policy_protocol import Policy
from lerobot.common.utils.utils import init_hydra_config
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_env
from tests.scripts.save_policy_to_safetensor import get_policy_stats
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_env, require_x86_64_kernel


@pytest.mark.parametrize("policy_name", available_policies)
@@ -228,3 +231,37 @@ def test_normalize(insert_temporal_dim):
    new_unnormalize = Unnormalize(output_shapes, unnormalize_output_modes, stats=None)
    new_unnormalize.load_state_dict(unnormalize.state_dict())
    unnormalize(output_batch)


@pytest.mark.parametrize(
    "env_name, policy_name, extra_overrides",
    [
        # ("xarm", "tdmpc", ["policy.n_action_repeats=2"]),
        (
            "pusht",
            "diffusion",
            ["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
        ),
        ("aloha", "act", ["policy.n_action_steps=10"]),
    ],
)
# As artifacts have been generated on an x86_64 kernel, this test won't
# pass if it's run on another platform due to floating point errors
@require_x86_64_kernel
def test_backward_compatibility(env_name, policy_name, extra_overrides):
    env_policy_dir = Path("tests/data/save_policy_to_safetensors") / f"{env_name}_{policy_name}"
    saved_output_dict = load_file(env_policy_dir / "output_dict.safetensors")
    saved_grad_stats = load_file(env_policy_dir / "grad_stats.safetensors")
    saved_param_stats = load_file(env_policy_dir / "param_stats.safetensors")
    saved_actions = load_file(env_policy_dir / "actions.safetensors")

    output_dict, grad_stats, param_stats, actions = get_policy_stats(env_name, policy_name, extra_overrides)

    for key in saved_output_dict:
        assert torch.isclose(output_dict[key], saved_output_dict[key], rtol=0.1, atol=1e-7).all()
    for key in saved_grad_stats:
        assert torch.isclose(grad_stats[key], saved_grad_stats[key], rtol=0.1, atol=1e-7).all()
    for key in saved_param_stats:
        assert torch.isclose(param_stats[key], saved_param_stats[key], rtol=50, atol=1e-7).all()
    for key in saved_actions:
        assert torch.isclose(actions[key], saved_actions[key], rtol=0.1, atol=1e-7).all()
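
For context on the tolerances: `torch.isclose(a, b, rtol, atol)` checks `|a - b| <= atol + rtol * |b|`, so `rtol=0.1` tolerates roughly 10% relative drift, and `rtol=50` on the parameter stats is presumably meant as a loose order-of-magnitude check. A small illustration:

    import torch

    a, b = torch.tensor(1.05), torch.tensor(1.0)
    assert torch.isclose(a, b, rtol=0.1, atol=1e-7)       # within 10% of b
    assert not torch.isclose(a, b, rtol=1e-5, atol=1e-8)  # default tolerances reject it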

View File

@@ -1,3 +1,5 @@
import platform

import pytest
import torch
@@ -9,6 +11,51 @@ DEFAULT_CONFIG_PATH = "lerobot/configs/default.yaml"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def require_x86_64_kernel(func):
    """
    Decorator that skips the test if the platform is not an x86_64 CPU.
    """
    from functools import wraps

    @wraps(func)
    def wrapper(*args, **kwargs):
        if platform.machine() != "x86_64":
            pytest.skip("requires x86_64 platform")
        return func(*args, **kwargs)

    return wrapper


def require_cpu(func):
    """
    Decorator that skips the test if device is not cpu.
    """
    from functools import wraps

    @wraps(func)
    def wrapper(*args, **kwargs):
        if DEVICE != "cpu":
            pytest.skip("requires cpu")
        return func(*args, **kwargs)

    return wrapper


def require_cuda(func):
    """
    Decorator that skips the test if cuda is not available.
    """
    from functools import wraps

    @wraps(func)
    def wrapper(*args, **kwargs):
        if not torch.cuda.is_available():
            pytest.skip("requires cuda")
        return func(*args, **kwargs)

    return wrapper
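
Usage is the usual pytest decorator stacking; the skip happens at call time, so test collection still works on any platform. A hypothetical example:

    @require_x86_64_kernel
    @require_cpu
    def test_deterministic_forward():  # hypothetical test
        ...  # runs only on x86_64 CPU machines, skipped elsewhere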
def require_env(func):
    """
    Decorator that skips the test if the required environment package is not installed.