diff --git a/lerobot/common/policies/tdmpc/modeling_tdmpc.py b/lerobot/common/policies/tdmpc/modeling_tdmpc.py
index eab0f94e..1fba43d0 100644
--- a/lerobot/common/policies/tdmpc/modeling_tdmpc.py
+++ b/lerobot/common/policies/tdmpc/modeling_tdmpc.py
@@ -80,7 +80,8 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
         self.config = config
         self.model = TDMPCTOLD(config)
         self.model_target = deepcopy(self.model)
-        self.model_target.eval()
+        for param in self.model_target.parameters():
+            param.requires_grad = False
 
         if config.input_normalization_modes is not None:
             self.normalize_inputs = Normalize(
diff --git a/lerobot/configs/policy/tdmpc.yaml b/lerobot/configs/policy/tdmpc.yaml
index 71dfa9c9..eb89033b 100644
--- a/lerobot/configs/policy/tdmpc.yaml
+++ b/lerobot/configs/policy/tdmpc.yaml
@@ -1,6 +1,7 @@
 # @package _global_
 
 seed: 1
+dataset_repo_id: lerobot/xarm_lift_medium_replay
 
 training:
   offline_steps: 25000
diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index 268185a3..f58dbd06 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -25,6 +25,51 @@ from lerobot.common.utils.utils import (
 from lerobot.scripts.eval import eval_policy
 
 
+def make_optimizer_and_scheduler(cfg, policy):
+    if cfg.policy.name == "act":
+        optimizer_params_dicts = [
+            {
+                "params": [
+                    p
+                    for n, p in policy.named_parameters()
+                    if not n.startswith("backbone") and p.requires_grad
+                ]
+            },
+            {
+                "params": [
+                    p for n, p in policy.named_parameters() if n.startswith("backbone") and p.requires_grad
+                ],
+                "lr": cfg.training.lr_backbone,
+            },
+        ]
+        optimizer = torch.optim.AdamW(
+            optimizer_params_dicts, lr=cfg.training.lr, weight_decay=cfg.training.weight_decay
+        )
+        lr_scheduler = None
+    elif cfg.policy.name == "diffusion":
+        optimizer = torch.optim.Adam(
+            policy.diffusion.parameters(),
+            cfg.training.lr,
+            cfg.training.adam_betas,
+            cfg.training.adam_eps,
+            cfg.training.adam_weight_decay,
+        )
+        assert cfg.training.online_steps == 0, "Diffusion Policy does not handle online training."
+        lr_scheduler = get_scheduler(
+            cfg.training.lr_scheduler,
+            optimizer=optimizer,
+            num_warmup_steps=cfg.training.lr_warmup_steps,
+            num_training_steps=cfg.training.offline_steps,
+        )
+    elif policy.name == "tdmpc":
+        optimizer = torch.optim.Adam(policy.parameters(), cfg.training.lr)
+        lr_scheduler = None
+    else:
+        raise NotImplementedError()
+
+    return optimizer, lr_scheduler
+
+
 def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
     start_time = time.time()
     policy.train()
@@ -276,46 +321,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
 
     # Create optimizer and scheduler
     # Temporary hack to move optimizer out of policy
-    if cfg.policy.name == "act":
-        optimizer_params_dicts = [
-            {
-                "params": [
-                    p
-                    for n, p in policy.named_parameters()
-                    if not n.startswith("backbone") and p.requires_grad
-                ]
-            },
-            {
-                "params": [
-                    p for n, p in policy.named_parameters() if n.startswith("backbone") and p.requires_grad
-                ],
-                "lr": cfg.training.lr_backbone,
-            },
-        ]
-        optimizer = torch.optim.AdamW(
-            optimizer_params_dicts, lr=cfg.training.lr, weight_decay=cfg.training.weight_decay
-        )
-        lr_scheduler = None
-    elif cfg.policy.name == "diffusion":
-        optimizer = torch.optim.Adam(
-            policy.diffusion.parameters(),
-            cfg.training.lr,
-            cfg.training.adam_betas,
-            cfg.training.adam_eps,
-            cfg.training.adam_weight_decay,
-        )
-        assert cfg.training.online_steps == 0, "Diffusion Policy does not handle online training."
-        lr_scheduler = get_scheduler(
-            cfg.training.lr_scheduler,
-            optimizer=optimizer,
-            num_warmup_steps=cfg.training.lr_warmup_steps,
-            num_training_steps=cfg.training.offline_steps,
-        )
-    elif policy.name == "tdmpc":
-        optimizer = torch.optim.Adam(policy.parameters(), cfg.training.lr)
-        lr_scheduler = None
-    else:
-        raise NotImplementedError()
+    optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
 
     num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
     num_total_params = sum(p.numel() for p in policy.parameters())
diff --git a/tests/data/save_policy_to_safetensors/aloha_act/actions.safetensors b/tests/data/save_policy_to_safetensors/aloha_act/actions.safetensors
new file mode 100644
index 00000000..70c9b6d8
Binary files /dev/null and b/tests/data/save_policy_to_safetensors/aloha_act/actions.safetensors differ
diff --git a/tests/data/save_policy_to_safetensors/aloha_act/grad_stats.safetensors b/tests/data/save_policy_to_safetensors/aloha_act/grad_stats.safetensors
new file mode 100644
index 00000000..2e845189
Binary files /dev/null and b/tests/data/save_policy_to_safetensors/aloha_act/grad_stats.safetensors differ
diff --git a/tests/data/save_policy_to_safetensors/aloha_act/output_dict.safetensors b/tests/data/save_policy_to_safetensors/aloha_act/output_dict.safetensors
new file mode 100644
index 00000000..e8d537c8
Binary files /dev/null and b/tests/data/save_policy_to_safetensors/aloha_act/output_dict.safetensors differ
diff --git a/tests/data/save_policy_to_safetensors/aloha_act/param_stats.safetensors b/tests/data/save_policy_to_safetensors/aloha_act/param_stats.safetensors
new file mode 100644
index 00000000..6e33879f
Binary files /dev/null and b/tests/data/save_policy_to_safetensors/aloha_act/param_stats.safetensors differ
diff --git a/tests/data/save_policy_to_safetensors/pusht_diffusion/actions.safetensors b/tests/data/save_policy_to_safetensors/pusht_diffusion/actions.safetensors
new file mode 100644
index 00000000..d9b20317
Binary files /dev/null and b/tests/data/save_policy_to_safetensors/pusht_diffusion/actions.safetensors differ
diff --git a/tests/data/save_policy_to_safetensors/pusht_diffusion/grad_stats.safetensors b/tests/data/save_policy_to_safetensors/pusht_diffusion/grad_stats.safetensors
new file mode 100644
index 00000000..4eed4aaf
Binary files /dev/null and b/tests/data/save_policy_to_safetensors/pusht_diffusion/grad_stats.safetensors differ
diff --git a/tests/data/save_policy_to_safetensors/pusht_diffusion/output_dict.safetensors b/tests/data/save_policy_to_safetensors/pusht_diffusion/output_dict.safetensors
new file mode 100644
index 00000000..77472bb5
Binary files /dev/null and b/tests/data/save_policy_to_safetensors/pusht_diffusion/output_dict.safetensors differ
diff --git a/tests/data/save_policy_to_safetensors/pusht_diffusion/param_stats.safetensors b/tests/data/save_policy_to_safetensors/pusht_diffusion/param_stats.safetensors
new file mode 100644
index 00000000..a26868db
Binary files /dev/null and b/tests/data/save_policy_to_safetensors/pusht_diffusion/param_stats.safetensors differ
diff --git a/tests/scripts/save_policy_to_safetensor.py b/tests/scripts/save_policy_to_safetensor.py
new file mode 100644
index 00000000..70337c17
--- /dev/null
+++ b/tests/scripts/save_policy_to_safetensor.py
@@ -0,0 +1,101 @@
+import shutil
+from pathlib import Path
+
+import torch
+from safetensors.torch import save_file
+
+from lerobot.common.datasets.factory import make_dataset
+from lerobot.common.policies.factory import make_policy
+from lerobot.common.utils.utils import init_hydra_config, set_global_seed
+from lerobot.scripts.train import make_optimizer_and_scheduler
+from tests.utils import DEFAULT_CONFIG_PATH
+
+
+def get_policy_stats(env_name, policy_name, extra_overrides=None):
+    cfg = init_hydra_config(
+        DEFAULT_CONFIG_PATH,
+        overrides=[
+            f"env={env_name}",
+            f"policy={policy_name}",
+            "device=cpu",
+        ]
+        + extra_overrides,
+    )
+    set_global_seed(1337)
+    dataset = make_dataset(cfg)
+    policy = make_policy(cfg, dataset_stats=dataset.stats)
+    policy.train()
+    optimizer, _ = make_optimizer_and_scheduler(cfg, policy)
+
+    dataloader = torch.utils.data.DataLoader(
+        dataset,
+        num_workers=0,
+        batch_size=cfg.training.batch_size,
+        shuffle=False,
+    )
+
+    batch = next(iter(dataloader))
+    output_dict = policy.forward(batch)
+    output_dict = {k: v for k, v in output_dict.items() if isinstance(v, torch.Tensor)}
+    loss = output_dict["loss"]
+
+    loss.backward()
+    grad_stats = {}
+    for key, param in policy.named_parameters():
+        if param.requires_grad:
+            grad_stats[f"{key}_mean"] = param.grad.mean()
+            grad_stats[f"{key}_std"] = (
+                param.grad.std() if param.grad.numel() > 1 else torch.tensor(float(0.0))
+            )
+
+    optimizer.step()
+    param_stats = {}
+    for key, param in policy.named_parameters():
+        param_stats[f"{key}_mean"] = param.mean()
+        param_stats[f"{key}_std"] = param.std() if param.numel() > 1 else torch.tensor(float(0.0))
+
+    optimizer.zero_grad()
+    policy.reset()
+
+    # HACK: We reload a batch with no delta_timestamps as `select_action` won't expect a timestamps dimension
+    dataset.delta_timestamps = None
+    batch = next(iter(dataloader))
+    obs = {
+        k: batch[k]
+        for k in batch
+        if k in ["observation.image", "observation.images.top", "observation.state"]
+    }
+
+    actions_queue = (
+        cfg.policy.n_action_steps if "n_action_steps" in cfg.policy else cfg.policy.n_action_repeats
+    )
+    actions = {str(i): policy.select_action(obs).contiguous() for i in range(actions_queue)}
+    return output_dict, grad_stats, param_stats, actions
+
+
+def save_policy_to_safetensors(output_dir, env_name, policy_name, extra_overrides):
+    env_policy_dir = Path(output_dir) / f"{env_name}_{policy_name}"
+
+    if env_policy_dir.exists():
+        shutil.rmtree(env_policy_dir)
+
+    env_policy_dir.mkdir(parents=True, exist_ok=True)
+    output_dict, grad_stats, param_stats, actions = get_policy_stats(env_name, policy_name, extra_overrides)
+    save_file(output_dict, env_policy_dir / "output_dict.safetensors")
+    save_file(grad_stats, env_policy_dir / "grad_stats.safetensors")
+    save_file(param_stats, env_policy_dir / "param_stats.safetensors")
+    save_file(actions, env_policy_dir / "actions.safetensors")
+
+
+if __name__ == "__main__":
+    env_policies = [
+        # ("xarm", "tdmpc", ["policy.n_action_repeats=2"]),
+        (
+            "pusht",
+            "diffusion",
+            ["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
+        ),
+        ("aloha", "act", ["policy.n_action_steps=10"]),
+    ]
+    for env, policy, extra_overrides in env_policies:
+        save_policy_to_safetensors("tests/data/save_policy_to_safetensors", env, policy, extra_overrides)
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index e50d4108..22b271be 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -265,7 +265,7 @@ def test_backward_compatibility(repo_id):
 
         for key in new_frame:
             assert torch.isclose(
-                new_frame[key], old_frame[key], rtol=1e-05, atol=1e-08
+                new_frame[key], old_frame[key]
             ).all(), f"{key=} for index={i} does not contain the same value"
 
     # test2 first frames of first episode
diff --git a/tests/test_policies.py b/tests/test_policies.py
index 51cdb93e..7d2f19ba 100644
--- a/tests/test_policies.py
+++ b/tests/test_policies.py
@@ -1,8 +1,10 @@
 import inspect
+from pathlib import Path
 
 import pytest
 import torch
 from huggingface_hub import PyTorchModelHubMixin
+from safetensors.torch import load_file
 
 from lerobot import available_policies
 from lerobot.common.datasets.factory import make_dataset
@@ -13,7 +15,8 @@ from lerobot.common.policies.factory import get_policy_and_config_classes, make_policy
 from lerobot.common.policies.normalize import Normalize, Unnormalize
 from lerobot.common.policies.policy_protocol import Policy
 from lerobot.common.utils.utils import init_hydra_config
-from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_env
+from tests.scripts.save_policy_to_safetensor import get_policy_stats
+from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_env, require_x86_64_kernel
 
 
 @pytest.mark.parametrize("policy_name", available_policies)
@@ -228,3 +231,37 @@ def test_normalize(insert_temporal_dim):
     new_unnormalize = Unnormalize(output_shapes, unnormalize_output_modes, stats=None)
     new_unnormalize.load_state_dict(unnormalize.state_dict())
     unnormalize(output_batch)
+
+
+@pytest.mark.parametrize(
+    "env_name, policy_name, extra_overrides",
+    [
+        # ("xarm", "tdmpc", ["policy.n_action_repeats=2"]),
+        (
+            "pusht",
+            "diffusion",
+            ["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
+        ),
+        ("aloha", "act", ["policy.n_action_steps=10"]),
+    ],
+)
+# As artifacts have been generated on an x86_64 kernel, this test won't
+# pass if it's run on another platform due to floating point errors
+@require_x86_64_kernel
+def test_backward_compatibility(env_name, policy_name, extra_overrides):
+    env_policy_dir = Path("tests/data/save_policy_to_safetensors") / f"{env_name}_{policy_name}"
+    saved_output_dict = load_file(env_policy_dir / "output_dict.safetensors")
+    saved_grad_stats = load_file(env_policy_dir / "grad_stats.safetensors")
+    saved_param_stats = load_file(env_policy_dir / "param_stats.safetensors")
+    saved_actions = load_file(env_policy_dir / "actions.safetensors")
+
+    output_dict, grad_stats, param_stats, actions = get_policy_stats(env_name, policy_name, extra_overrides)
+
+    for key in saved_output_dict:
+        assert torch.isclose(output_dict[key], saved_output_dict[key], rtol=0.1, atol=1e-7).all()
+    for key in saved_grad_stats:
+        assert torch.isclose(grad_stats[key], saved_grad_stats[key], rtol=0.1, atol=1e-7).all()
+    for key in saved_param_stats:
+        assert torch.isclose(param_stats[key], saved_param_stats[key], rtol=50, atol=1e-7).all()
+    for key in saved_actions:
+        assert torch.isclose(actions[key], saved_actions[key], rtol=0.1, atol=1e-7).all()
diff --git a/tests/utils.py b/tests/utils.py
index f3fe5790..6a706694 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -1,3 +1,5 @@
+import platform
+
 import pytest
 import torch
 
@@ -9,6 +11,51 @@ DEFAULT_CONFIG_PATH = "lerobot/configs/default.yaml"
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
 
+def require_x86_64_kernel(func):
+    """
+    Decorator that skips the test if the platform is not an x86_64 CPU.
+    """
+    from functools import wraps
+
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        if platform.machine() != "x86_64":
+            pytest.skip("requires x86_64 platform")
+        return func(*args, **kwargs)
+
+    return wrapper
+
+
+def require_cpu(func):
+    """
+    Decorator that skips the test if the device is not cpu.
+    """
+    from functools import wraps
+
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        if DEVICE != "cpu":
+            pytest.skip("requires cpu")
+        return func(*args, **kwargs)
+
+    return wrapper
+
+
+def require_cuda(func):
+    """
+    Decorator that skips the test if cuda is not available.
+    """
+    from functools import wraps
+
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        if not torch.cuda.is_available():
+            pytest.skip("requires cuda")
+        return func(*args, **kwargs)
+
+    return wrapper
+
+
 def require_env(func):
     """
     Decorator that skips the test if the required environment package is not installed.