Add regression tests (#119)
- Add `tests/scripts/save_policy_to_safetensor.py` to generate test artifacts
- Add `test_backward_compatibility` to check generated outputs from the policies against those artifacts
commit c77633c38c (parent 19812ca470)
@@ -80,7 +80,8 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
         self.config = config
         self.model = TDMPCTOLD(config)
         self.model_target = deepcopy(self.model)
-        self.model_target.eval()
+        for param in self.model_target.parameters():
+            param.requires_grad = False

         if config.input_normalization_modes is not None:
             self.normalize_inputs = Normalize(
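A note on the hunk above: `eval()` and gradient freezing are easy to conflate, but `eval()` only toggles train/eval mode (dropout, batch-norm statistics) and does not stop autograd from tracking the target network's parameters. Setting `requires_grad = False` is what actually excludes them from gradient computation, and hence from the gradient statistics the new regression artifacts record. A minimal self-contained sketch of the distinction, with `nn.Linear` as a stand-in for the real TOLD model:

from copy import deepcopy

import torch.nn as nn

model = nn.Linear(4, 4)  # stand-in for TDMPCTOLD(config)
target = deepcopy(model)

target.eval()  # mode flag only; autograd still tracks these parameters
assert all(p.requires_grad for p in target.parameters())

for param in target.parameters():
    param.requires_grad = False  # what the new code does: freeze the target
assert not any(p.requires_grad for p in target.parameters())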
@@ -1,6 +1,7 @@
 # @package _global_

 seed: 1
+dataset_repo_id: lerobot/xarm_lift_medium_replay

 training:
   offline_steps: 25000
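This config hunk pins the Hub dataset used for offline TD-MPC training. A hedged sketch of how the key is consumed (the `env=xarm`/`policy=tdmpc` override names are borrowed from the commented-out entry in the artifact script below, and the exact lookup inside `make_dataset` is an assumption):

from lerobot.common.datasets.factory import make_dataset
from lerobot.common.utils.utils import init_hydra_config

cfg = init_hydra_config(
    "lerobot/configs/default.yaml",  # DEFAULT_CONFIG_PATH in tests/utils.py
    overrides=["env=xarm", "policy=tdmpc", "device=cpu"],
)
dataset = make_dataset(cfg)  # assumed to resolve cfg.dataset_repo_id, now lerobot/xarm_lift_medium_replay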
lerobot/scripts/train.py:

@@ -25,6 +25,51 @@ from lerobot.common.utils.utils import (
 from lerobot.scripts.eval import eval_policy


+def make_optimizer_and_scheduler(cfg, policy):
+    if cfg.policy.name == "act":
+        optimizer_params_dicts = [
+            {
+                "params": [
+                    p
+                    for n, p in policy.named_parameters()
+                    if not n.startswith("backbone") and p.requires_grad
+                ]
+            },
+            {
+                "params": [
+                    p for n, p in policy.named_parameters() if n.startswith("backbone") and p.requires_grad
+                ],
+                "lr": cfg.training.lr_backbone,
+            },
+        ]
+        optimizer = torch.optim.AdamW(
+            optimizer_params_dicts, lr=cfg.training.lr, weight_decay=cfg.training.weight_decay
+        )
+        lr_scheduler = None
+    elif cfg.policy.name == "diffusion":
+        optimizer = torch.optim.Adam(
+            policy.diffusion.parameters(),
+            cfg.training.lr,
+            cfg.training.adam_betas,
+            cfg.training.adam_eps,
+            cfg.training.adam_weight_decay,
+        )
+        assert cfg.training.online_steps == 0, "Diffusion Policy does not handle online training."
+        lr_scheduler = get_scheduler(
+            cfg.training.lr_scheduler,
+            optimizer=optimizer,
+            num_warmup_steps=cfg.training.lr_warmup_steps,
+            num_training_steps=cfg.training.offline_steps,
+        )
+    elif policy.name == "tdmpc":
+        optimizer = torch.optim.Adam(policy.parameters(), cfg.training.lr)
+        lr_scheduler = None
+    else:
+        raise NotImplementedError()
+
+    return optimizer, lr_scheduler
+
+
 def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
     start_time = time.time()
     policy.train()
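In the ACT branch above, two parameter groups let the vision backbone train at its own learning rate (`cfg.training.lr_backbone`). A self-contained sanity sketch, not part of the commit (the `nn.ModuleDict` is a hypothetical stand-in for the policy), showing that the two comprehensions partition the trainable parameters with no overlap or gap:

import torch.nn as nn

# Hypothetical stand-in: any module with parameters named "backbone.*" plus others.
policy = nn.ModuleDict({"backbone": nn.Linear(8, 8), "head": nn.Linear(8, 2)})

backbone = [p for n, p in policy.named_parameters() if n.startswith("backbone") and p.requires_grad]
rest = [p for n, p in policy.named_parameters() if not n.startswith("backbone") and p.requires_grad]
trainable = [p for p in policy.parameters() if p.requires_grad]
assert len(backbone) + len(rest) == len(trainable)  # every trainable param lands in exactly one group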
lerobot/scripts/train.py (same file, the call site):

@@ -276,46 +321,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
     # Create optimizer and scheduler
     # Temporary hack to move optimizer out of policy
-    if cfg.policy.name == "act":
-        optimizer_params_dicts = [
-            {
-                "params": [
-                    p
-                    for n, p in policy.named_parameters()
-                    if not n.startswith("backbone") and p.requires_grad
-                ]
-            },
-            {
-                "params": [
-                    p for n, p in policy.named_parameters() if n.startswith("backbone") and p.requires_grad
-                ],
-                "lr": cfg.training.lr_backbone,
-            },
-        ]
-        optimizer = torch.optim.AdamW(
-            optimizer_params_dicts, lr=cfg.training.lr, weight_decay=cfg.training.weight_decay
-        )
-        lr_scheduler = None
-    elif cfg.policy.name == "diffusion":
-        optimizer = torch.optim.Adam(
-            policy.diffusion.parameters(),
-            cfg.training.lr,
-            cfg.training.adam_betas,
-            cfg.training.adam_eps,
-            cfg.training.adam_weight_decay,
-        )
-        assert cfg.training.online_steps == 0, "Diffusion Policy does not handle online training."
-        lr_scheduler = get_scheduler(
-            cfg.training.lr_scheduler,
-            optimizer=optimizer,
-            num_warmup_steps=cfg.training.lr_warmup_steps,
-            num_training_steps=cfg.training.offline_steps,
-        )
-    elif policy.name == "tdmpc":
-        optimizer = torch.optim.Adam(policy.parameters(), cfg.training.lr)
-        lr_scheduler = None
-    else:
-        raise NotImplementedError()
+    optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)

     num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
     num_total_params = sum(p.numel() for p in policy.parameters())
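Factoring the optimizer construction out of `train()` is what lets the artifact script below import the exact training optimizer (`from lerobot.scripts.train import make_optimizer_and_scheduler`), so the post-`optimizer.step()` parameter statistics it records match what a real training step would produce.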
Eight binary files not shown (the generated `.safetensors` artifacts under `tests/data/save_policy_to_safetensors/`).
tests/scripts/save_policy_to_safetensor.py (new file):

@@ -0,0 +1,101 @@
+import shutil
+from pathlib import Path
+
+import torch
+from safetensors.torch import save_file
+
+from lerobot.common.datasets.factory import make_dataset
+from lerobot.common.policies.factory import make_policy
+from lerobot.common.utils.utils import init_hydra_config, set_global_seed
+from lerobot.scripts.train import make_optimizer_and_scheduler
+from tests.utils import DEFAULT_CONFIG_PATH
+
+
+def get_policy_stats(env_name, policy_name, extra_overrides=None):
+    cfg = init_hydra_config(
+        DEFAULT_CONFIG_PATH,
+        overrides=[
+            f"env={env_name}",
+            f"policy={policy_name}",
+            "device=cpu",
+        ]
+        + extra_overrides,
+    )
+    set_global_seed(1337)
+    dataset = make_dataset(cfg)
+    policy = make_policy(cfg, dataset_stats=dataset.stats)
+    policy.train()
+    optimizer, _ = make_optimizer_and_scheduler(cfg, policy)
+
+    dataloader = torch.utils.data.DataLoader(
+        dataset,
+        num_workers=0,
+        batch_size=cfg.training.batch_size,
+        shuffle=False,
+    )
+
+    batch = next(iter(dataloader))
+    output_dict = policy.forward(batch)
+    output_dict = {k: v for k, v in output_dict.items() if isinstance(v, torch.Tensor)}
+    loss = output_dict["loss"]
+
+    loss.backward()
+    grad_stats = {}
+    for key, param in policy.named_parameters():
+        if param.requires_grad:
+            grad_stats[f"{key}_mean"] = param.grad.mean()
+            grad_stats[f"{key}_std"] = (
+                param.grad.std() if param.grad.numel() > 1 else torch.tensor(float(0.0))
+            )
+
+    optimizer.step()
+    param_stats = {}
+    for key, param in policy.named_parameters():
+        param_stats[f"{key}_mean"] = param.mean()
+        param_stats[f"{key}_std"] = param.std() if param.numel() > 1 else torch.tensor(float(0.0))
+
+    optimizer.zero_grad()
+    policy.reset()
+
+    # HACK: We reload a batch with no delta_timestamps as `select_action` won't expect a timestamps dimension
+    dataset.delta_timestamps = None
+    batch = next(iter(dataloader))
+    obs = {
+        k: batch[k]
+        for k in batch
+        if k in ["observation.image", "observation.images.top", "observation.state"]
+    }
+
+    actions_queue = (
+        cfg.policy.n_action_steps if "n_action_steps" in cfg.policy else cfg.policy.n_action_repeats
+    )
+    actions = {str(i): policy.select_action(obs).contiguous() for i in range(actions_queue)}
+    return output_dict, grad_stats, param_stats, actions
+
+
+def save_policy_to_safetensors(output_dir, env_name, policy_name, extra_overrides):
+    env_policy_dir = Path(output_dir) / f"{env_name}_{policy_name}"
+
+    if env_policy_dir.exists():
+        shutil.rmtree(env_policy_dir)
+
+    env_policy_dir.mkdir(parents=True, exist_ok=True)
+    output_dict, grad_stats, param_stats, actions = get_policy_stats(env_name, policy_name, extra_overrides)
+    save_file(output_dict, env_policy_dir / "output_dict.safetensors")
+    save_file(grad_stats, env_policy_dir / "grad_stats.safetensors")
+    save_file(param_stats, env_policy_dir / "param_stats.safetensors")
+    save_file(actions, env_policy_dir / "actions.safetensors")
+
+
+if __name__ == "__main__":
+    env_policies = [
+        # ("xarm", "tdmpc", ["policy.n_action_repeats=2"]),
+        (
+            "pusht",
+            "diffusion",
+            ["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
+        ),
+        ("aloha", "act", ["policy.n_action_steps=10"]),
+    ]
+    for env, policy, extra_overrides in env_policies:
+        save_policy_to_safetensors("tests/data/save_policy_to_safetensors", env, policy, extra_overrides)
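Executed directly, the script's `__main__` block regenerates every artifact set. Grounded in the code above, a single set can also be rebuilt programmatically (run from the repository root so the relative output path resolves):

from tests.scripts.save_policy_to_safetensor import save_policy_to_safetensors

# Rebuild only the aloha/act artifacts, mirroring one entry of the __main__ loop.
save_policy_to_safetensors(
    "tests/data/save_policy_to_safetensors", "aloha", "act", ["policy.n_action_steps=10"]
)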
@@ -265,7 +265,7 @@ def test_backward_compatibility(repo_id):
     for key in new_frame:
         assert torch.isclose(
-            new_frame[key], old_frame[key], rtol=1e-05, atol=1e-08
+            new_frame[key], old_frame[key]
         ).all(), f"{key=} for index={i} does not contain the same value"

     # test2 first frames of first episode
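This change is behavior-preserving: `rtol=1e-05` and `atol=1e-08` are exactly the defaults of `torch.isclose`, which checks `|input - other| <= atol + rtol * |other|` elementwise, so dropping them only removes noise. A one-line verification:

import torch

a, b = torch.tensor([1.0]), torch.tensor([1.0 + 1e-6])
assert torch.equal(torch.isclose(a, b), torch.isclose(a, b, rtol=1e-05, atol=1e-08))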
@@ -1,8 +1,10 @@
 import inspect
+from pathlib import Path

 import pytest
 import torch
 from huggingface_hub import PyTorchModelHubMixin
+from safetensors.torch import load_file

 from lerobot import available_policies
 from lerobot.common.datasets.factory import make_dataset

@@ -13,7 +15,8 @@ from lerobot.common.policies.factory import get_policy_and_config_classes, make_
 from lerobot.common.policies.normalize import Normalize, Unnormalize
 from lerobot.common.policies.policy_protocol import Policy
 from lerobot.common.utils.utils import init_hydra_config
-from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_env
+from tests.scripts.save_policy_to_safetensor import get_policy_stats
+from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_env, require_x86_64_kernel


 @pytest.mark.parametrize("policy_name", available_policies)
@@ -228,3 +231,37 @@ def test_normalize(insert_temporal_dim):
     new_unnormalize = Unnormalize(output_shapes, unnormalize_output_modes, stats=None)
     new_unnormalize.load_state_dict(unnormalize.state_dict())
     unnormalize(output_batch)
+
+
+@pytest.mark.parametrize(
+    "env_name, policy_name, extra_overrides",
+    [
+        # ("xarm", "tdmpc", ["policy.n_action_repeats=2"]),
+        (
+            "pusht",
+            "diffusion",
+            ["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
+        ),
+        ("aloha", "act", ["policy.n_action_steps=10"]),
+    ],
+)
+# As artifacts have been generated on an x86_64 kernel, this test won't
+# pass if it's run on another platform due to floating point errors.
+@require_x86_64_kernel
+def test_backward_compatibility(env_name, policy_name, extra_overrides):
+    env_policy_dir = Path("tests/data/save_policy_to_safetensors") / f"{env_name}_{policy_name}"
+    saved_output_dict = load_file(env_policy_dir / "output_dict.safetensors")
+    saved_grad_stats = load_file(env_policy_dir / "grad_stats.safetensors")
+    saved_param_stats = load_file(env_policy_dir / "param_stats.safetensors")
+    saved_actions = load_file(env_policy_dir / "actions.safetensors")
+
+    output_dict, grad_stats, param_stats, actions = get_policy_stats(env_name, policy_name, extra_overrides)
+
+    for key in saved_output_dict:
+        assert torch.isclose(output_dict[key], saved_output_dict[key], rtol=0.1, atol=1e-7).all()
+    for key in saved_grad_stats:
+        assert torch.isclose(grad_stats[key], saved_grad_stats[key], rtol=0.1, atol=1e-7).all()
+    for key in saved_param_stats:
+        assert torch.isclose(param_stats[key], saved_param_stats[key], rtol=50, atol=1e-7).all()
+    for key in saved_actions:
+        assert torch.isclose(actions[key], saved_actions[key], rtol=0.1, atol=1e-7).all()
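A note on the tolerances, which the commit leaves unexplained: under the `torch.isclose` criterion above, a value passes when `|new - saved| <= atol + rtol * |saved|`. The `rtol=0.1` on outputs, gradients, and actions, and the very loose `rtol=50` on post-step parameter statistics, presumably absorb cross-platform floating-point drift that the optimizer step amplifies (many parameter means sit near zero, where relative error blows up), while `atol=1e-7` keeps exact-zero values from failing spuriously.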
tests/utils.py:

@@ -1,3 +1,5 @@
+import platform
+
 import pytest
 import torch


@@ -9,6 +11,51 @@ DEFAULT_CONFIG_PATH = "lerobot/configs/default.yaml"
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


+def require_x86_64_kernel(func):
+    """
+    Decorator that skips the test if the platform is not an x86_64 CPU.
+    """
+    from functools import wraps
+
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        if platform.machine() != "x86_64":
+            pytest.skip("requires x86_64 platform")
+        return func(*args, **kwargs)
+
+    return wrapper
+
+
+def require_cpu(func):
+    """
+    Decorator that skips the test if device is not cpu.
+    """
+    from functools import wraps
+
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        if DEVICE != "cpu":
+            pytest.skip("requires cpu")
+        return func(*args, **kwargs)
+
+    return wrapper
+
+
+def require_cuda(func):
+    """
+    Decorator that skips the test if cuda is not available.
+    """
+    from functools import wraps
+
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        if not torch.cuda.is_available():
+            pytest.skip("requires cuda")
+        return func(*args, **kwargs)
+
+    return wrapper
+
+
 def require_env(func):
     """
     Decorator that skips the test if the required environment package is not installed.
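For reference, a hedged sketch of how these decorators compose on a test (the test name is hypothetical); `pytest.skip` inside the wrapper marks the test as skipped rather than failed on non-matching machines, e.g. arm64 runners:

from tests.utils import require_cpu, require_x86_64_kernel


@require_x86_64_kernel  # skipped unless platform.machine() == "x86_64"
@require_cpu  # skipped unless DEVICE resolved to "cpu"
def test_hypothetical_artifact_match():
    assert True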