From 8c4643687c97ffafeecd78a7cde5504febe3427e Mon Sep 17 00:00:00 2001 From: Alexander Soare <alexander.soare159@gmail.com> Date: Thu, 15 Aug 2024 13:59:47 +0100 Subject: [PATCH] fix bug in example 2 (#361) --- examples/2_evaluate_pretrained_policy.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/examples/2_evaluate_pretrained_policy.py b/examples/2_evaluate_pretrained_policy.py index 5c1932de..b2fe1dba 100644 --- a/examples/2_evaluate_pretrained_policy.py +++ b/examples/2_evaluate_pretrained_policy.py @@ -18,6 +18,14 @@ from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy output_directory = Path("outputs/eval/example_pusht_diffusion") output_directory.mkdir(parents=True, exist_ok=True) +# Download the diffusion policy for pusht environment +pretrained_policy_path = Path(snapshot_download("lerobot/diffusion_pusht")) +# OR uncomment the following to evaluate a policy from the local outputs/train folder. +# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion") + +policy = DiffusionPolicy.from_pretrained(pretrained_policy_path) +policy.eval() + # Check if GPU is available if torch.cuda.is_available(): device = torch.device("cuda") @@ -28,13 +36,6 @@ else: # Decrease the number of reverse-diffusion steps (trades off a bit of quality for 10x speed) policy.diffusion.num_inference_steps = 10 -# Download the diffusion policy for pusht environment -pretrained_policy_path = Path(snapshot_download("lerobot/diffusion_pusht")) -# OR uncomment the following to evaluate a policy from the local outputs/train folder. -# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion") - -policy = DiffusionPolicy.from_pretrained(pretrained_policy_path) -policy.eval() policy.to(device) # Initialize evaluation environment to render two observation types: