fix bug in example 2 (#361)

This commit is contained in:
Alexander Soare 2024-08-15 13:59:47 +01:00 committed by GitHub
parent fab037f78d
commit 8c4643687c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 8 additions and 7 deletions

View File

@ -18,6 +18,14 @@ from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
output_directory = Path("outputs/eval/example_pusht_diffusion")
output_directory.mkdir(parents=True, exist_ok=True)
# Download the diffusion policy for pusht environment
pretrained_policy_path = Path(snapshot_download("lerobot/diffusion_pusht"))
# OR uncomment the following to evaluate a policy from the local outputs/train folder.
# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")
policy = DiffusionPolicy.from_pretrained(pretrained_policy_path)
policy.eval()
# Check if GPU is available
if torch.cuda.is_available():
device = torch.device("cuda")
@ -28,13 +36,6 @@ else:
# Decrease the number of reverse-diffusion steps (trades off a bit of quality for 10x speed)
policy.diffusion.num_inference_steps = 10
# Download the diffusion policy for pusht environment
pretrained_policy_path = Path(snapshot_download("lerobot/diffusion_pusht"))
# OR uncomment the following to evaluate a policy from the local outputs/train folder.
# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")
policy = DiffusionPolicy.from_pretrained(pretrained_policy_path)
policy.eval()
policy.to(device)
# Initialize evaluation environment to render two observation types: