diff --git a/tests/test_regression.py b/tests/test_regression.py index 480e0a41..5d3b8f65 100644 --- a/tests/test_regression.py +++ b/tests/test_regression.py @@ -3,7 +3,7 @@ import torch from lerobot.common.datasets.factory import make_dataset from lerobot.common.policies.factory import make_policy -from lerobot.common.utils.utils import init_hydra_config +from lerobot.common.utils.utils import init_hydra_config, set_global_seed from lerobot.scripts.train import make_optimizer from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_env @@ -14,68 +14,66 @@ from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_env # ("xarm", "tdmpc", ["policy.mpc=true"]), # ("pusht", "tdmpc", ["policy.mpc=false"]), ("pusht", "diffusion", []), - ("aloha", "act", ["env.task=AlohaInsertion-v0", "dataset.repo_id=lerobot/aloha_sim_insertion_human"]), - ( - "aloha", - "act", - ["env.task=AlohaInsertion-v0", "dataset.repo_id=lerobot/aloha_sim_insertion_scripted"], - ), - ( - "aloha", - "act", - ["env.task=AlohaTransferCube-v0", "dataset.repo_id=lerobot/aloha_sim_transfer_cube_human"], - ), - ( - "aloha", - "act", - ["env.task=AlohaTransferCube-v0", "dataset.repo_id=lerobot/aloha_sim_transfer_cube_scripted"], - ), + ("aloha", "act"), ], ) @require_env -def test_backward(env_name, policy_name, extra_overrides): +def test_backward_compatibility(env_name, policy_name): cfg = init_hydra_config( DEFAULT_CONFIG_PATH, overrides=[ f"env={env_name}", f"policy={policy_name}", f"device={DEVICE}", - ] - + extra_overrides, + ], ) + set_global_seed(1337) dataset = make_dataset(cfg) policy = make_policy(cfg, dataset_stats=dataset.stats) policy.train() policy.to(DEVICE) - optimizer, lr_scheduler = make_optimizer(cfg, policy) + optimizer, _ = make_optimizer(cfg, policy) dataloader = torch.utils.data.DataLoader( dataset, - num_workers=4, - batch_size=cfg.policy.batch_size, - shuffle=True, - pin_memory=torch.device("cpu") != DEVICE, - drop_last=True, + num_workers=0, + batch_size=cfg.training.batch_size, + shuffle=False, ) - step = 0 - done = False - training_steps = 1 - while not done: - for batch in dataloader: - batch = {k: v.to(DEVICE, non_blocking=True) for k, v in batch.items()} - output_dict = policy.forward(batch) - loss = output_dict["loss"] - loss.backward() - optimizer.step() - optimizer.zero_grad() - step += 1 - if step >= training_steps: - done = True - break + batch = next(iter(dataloader)) + output_dict = policy.forward(batch) + loss = output_dict["loss"] + # TODO Check output dict values + + loss.backward() + grad_stats = {} + for key, param in policy.named_parameters(): + if param.requires_grad: + grad_stats[f"{key}_mean"] = param.grad.mean() + grad_stats[f"{key}_std"] = param.grad.std() + + optimizer.step() + param_stats = {} + for key, param in policy.named_parameters(): + param_stats[f"{key}_mean"] = param.mean() + param_stats[f"{key}_std"] = param.std() + + optimizer.zero_grad() + policy.reset() + + dataset.delta_timestamps = None + batch = next(iter(dataloader)) + + # TODO(aliberts): refacor `select_action` methods so that it expects `obs` instead of `batch` + if policy_name == "diffusion": + batch = {k: batch[k] for k in ["observation.image", "observation.state"]} + + actions = {i: policy.select_action(batch) for i in range(cfg.policy.n_action_steps)} + + print(len(actions)) if __name__ == "__main__": - test_backward( - "aloha", "act", ["env.task=AlohaInsertion-v0", "dataset.repo_id=lerobot/aloha_sim_insertion_scripted"] - ) + # test_backward_compatibility("aloha", "act") + test_backward_compatibility("pusht", "diffusion")