diff --git a/Makefile b/Makefile
index a0163f94..c561deb0 100644
--- a/Makefile
+++ b/Makefile
@@ -20,6 +20,8 @@ build-gpu:
 test-end-to-end:
 	${MAKE} test-act-ete-train
 	${MAKE} test-act-ete-eval
+	${MAKE} test-act-ete-train-amp
+	${MAKE} test-act-ete-eval-amp
 	${MAKE} test-diffusion-ete-train
 	${MAKE} test-diffusion-ete-eval
 	${MAKE} test-tdmpc-ete-train
@@ -29,6 +31,7 @@ test-end-to-end:
 test-act-ete-train:
 	python lerobot/scripts/train.py \
 		policy=act \
+		policy.dim_model=64 \
 		env=aloha \
 		wandb.enable=False \
 		training.offline_steps=2 \
@@ -51,9 +54,40 @@ test-act-ete-eval:
 		env.episode_length=8 \
 		device=cpu \
 
+test-act-ete-train-amp:
+	python lerobot/scripts/train.py \
+		policy=act \
+		policy.dim_model=64 \
+		env=aloha \
+		wandb.enable=False \
+		training.offline_steps=2 \
+		training.online_steps=0 \
+		eval.n_episodes=1 \
+		eval.batch_size=1 \
+		device=cpu \
+		training.save_model=true \
+		training.save_freq=2 \
+		policy.n_action_steps=20 \
+		policy.chunk_size=20 \
+		training.batch_size=2 \
+		hydra.run.dir=tests/outputs/act/ \
+		use_amp=true
+
+test-act-ete-eval-amp:
+	python lerobot/scripts/eval.py \
+		-p tests/outputs/act/checkpoints/000002 \
+		eval.n_episodes=1 \
+		eval.batch_size=1 \
+		env.episode_length=8 \
+		device=cpu \
+		use_amp=true
+
 test-diffusion-ete-train:
 	python lerobot/scripts/train.py \
 		policy=diffusion \
+		policy.down_dims=\[64,128,256\] \
+		policy.diffusion_step_embed_dim=32 \
+		policy.num_inference_steps=10 \
 		env=pusht \
 		wandb.enable=False \
 		training.offline_steps=2 \
@@ -100,7 +134,6 @@ test-tdmpc-ete-eval:
 		env.episode_length=8 \
 		device=cpu \
 
-
 test-default-ete-eval:
 	python lerobot/scripts/eval.py \
 		--config lerobot/configs/default.yaml \
diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index eea3b650..83159465 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -100,7 +100,7 @@ def update_policy(
     use_amp: bool = False,
 ):
     """Returns a dictionary of items for logging."""
-    start_time = time.time()
+    start_time = time.perf_counter()
     device = get_device_from_parameters(policy)
     policy.train()
     with torch.autocast(device_type=device.type) if use_amp else nullcontext():
@@ -137,7 +137,7 @@ def update_policy(
         "loss": loss.item(),
         "grad_norm": float(grad_norm),
         "lr": optimizer.param_groups[0]["lr"],
-        "update_s": time.time() - start_time,
+        "update_s": time.perf_counter() - start_time,
         **{k: v for k, v in output_dict.items() if k != "loss"},
     }
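
The use_amp flag surfaced in the update_policy hunk toggles automatic mixed precision by switching between torch.autocast and contextlib.nullcontext around the forward pass. Below is a minimal, self-contained sketch of that pattern only; the forward_loss helper, the tiny torch.nn.Linear model, and the random batch are hypothetical placeholders for illustration and are not part of the lerobot code touched by this diff.

# Minimal sketch of the conditional-AMP idiom from update_policy(). The model,
# batch, and forward_loss helper are illustrative placeholders; only the
# "torch.autocast(...) if use_amp else nullcontext()" line mirrors the diff.
from contextlib import nullcontext

import torch


def forward_loss(model: torch.nn.Module, batch: torch.Tensor, use_amp: bool = False) -> torch.Tensor:
    device = next(model.parameters()).device
    # Enable autocast only when AMP is requested; otherwise use a no-op context.
    with torch.autocast(device_type=device.type) if use_amp else nullcontext():
        prediction = model(batch)
        return (prediction ** 2).mean()


if __name__ == "__main__":
    model = torch.nn.Linear(8, 4)
    batch = torch.randn(2, 8)
    loss = forward_loss(model, batch, use_amp=True)  # autocast defaults to bf16 on CPU, fp16 on CUDA
    loss.backward()
    print(f"loss={loss.item():.4f}")

With the Makefile changes above, make test-act-ete-train-amp followed by make test-act-ete-eval-amp exercises this code path end to end on CPU with use_amp=true. The switch from time.time() to time.perf_counter() is independent of AMP: perf_counter is a monotonic, higher-resolution clock better suited to timing the update step reported as update_s.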