backup wip

This commit is contained in:
Alexander Soare 2024-05-20 13:48:20 +01:00
parent 39b6fcbe1e
commit f40cedeed7
2 changed files with 36 additions and 3 deletions

View File

@ -20,6 +20,8 @@ build-gpu:
test-end-to-end:
${MAKE} test-act-ete-train
${MAKE} test-act-ete-eval
${MAKE} test-act-ete-train-amp
${MAKE} test-act-ete-eval-amp
${MAKE} test-diffusion-ete-train
${MAKE} test-diffusion-ete-eval
${MAKE} test-tdmpc-ete-train
@ -29,6 +31,7 @@ test-end-to-end:
test-act-ete-train:
python lerobot/scripts/train.py \
policy=act \
policy.dim_model=64 \
env=aloha \
wandb.enable=False \
training.offline_steps=2 \
@ -51,9 +54,40 @@ test-act-ete-eval:
env.episode_length=8 \
device=cpu \
# End-to-end smoke tests for ACT with automatic mixed precision (use_amp=true).
# Both targets are commands rather than files, so they are declared phony.
.PHONY: test-act-ete-train-amp test-act-ete-eval-amp

# Train ACT for two offline steps on CPU with use_amp=true and save a checkpoint
# at step 2. The run writes to its own directory (tests/outputs/act_amp/) rather
# than tests/outputs/act/, so it never overwrites or reuses the checkpoint
# produced by test-act-ete-train.
test-act-ete-train-amp:
	python lerobot/scripts/train.py \
		policy=act \
		policy.dim_model=64 \
		env=aloha \
		wandb.enable=False \
		training.offline_steps=2 \
		training.online_steps=0 \
		eval.n_episodes=1 \
		eval.batch_size=1 \
		device=cpu \
		training.save_model=true \
		training.save_freq=2 \
		policy.n_action_steps=20 \
		policy.chunk_size=20 \
		training.batch_size=2 \
		hydra.run.dir=tests/outputs/act_amp/ \
		use_amp=true

# Evaluate the step-2 checkpoint written by test-act-ete-train-amp, again with
# use_amp=true. Reading from tests/outputs/act_amp/ means this fails if the AMP
# training run did not save a checkpoint, instead of silently evaluating the
# non-AMP one.
test-act-ete-eval-amp:
	python lerobot/scripts/eval.py \
		-p tests/outputs/act_amp/checkpoints/000002 \
		eval.n_episodes=1 \
		eval.batch_size=1 \
		env.episode_length=8 \
		device=cpu \
		use_amp=true
test-diffusion-ete-train:
python lerobot/scripts/train.py \
policy=diffusion \
policy.down_dims=\[64,128,256\] \
policy.diffusion_step_embed_dim=32 \
policy.num_inference_steps=10 \
env=pusht \
wandb.enable=False \
training.offline_steps=2 \
@ -100,7 +134,6 @@ test-tdmpc-ete-eval:
env.episode_length=8 \
device=cpu \
test-default-ete-eval:
python lerobot/scripts/eval.py \
--config lerobot/configs/default.yaml \

View File

@ -100,7 +100,7 @@ def update_policy(
use_amp: bool = False,
):
"""Returns a dictionary of items for logging."""
start_time = time.time()
start_time = time.perf_counter()
device = get_device_from_parameters(policy)
policy.train()
with torch.autocast(device_type=device.type) if use_amp else nullcontext():
@ -137,7 +137,7 @@ def update_policy(
"loss": loss.item(),
"grad_norm": float(grad_norm),
"lr": optimizer.param_groups[0]["lr"],
"update_s": time.time() - start_time,
"update_s": time.perf_counter() - start_time,
**{k: v for k, v in output_dict.items() if k != "loss"},
}