fix bug in example 2 (#361)
This commit is contained in:
parent
fab037f78d
commit
8c4643687c
|
@ -18,6 +18,14 @@ from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||||
output_directory = Path("outputs/eval/example_pusht_diffusion")
|
output_directory = Path("outputs/eval/example_pusht_diffusion")
|
||||||
output_directory.mkdir(parents=True, exist_ok=True)
|
output_directory.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Download the diffusion policy for pusht environment
|
||||||
|
pretrained_policy_path = Path(snapshot_download("lerobot/diffusion_pusht"))
|
||||||
|
# OR uncomment the following to evaluate a policy from the local outputs/train folder.
|
||||||
|
# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")
|
||||||
|
|
||||||
|
policy = DiffusionPolicy.from_pretrained(pretrained_policy_path)
|
||||||
|
policy.eval()
|
||||||
|
|
||||||
# Check if GPU is available
|
# Check if GPU is available
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = torch.device("cuda")
|
device = torch.device("cuda")
|
||||||
|
@ -28,13 +36,6 @@ else:
|
||||||
# Decrease the number of reverse-diffusion steps (trades off a bit of quality for 10x speed)
|
# Decrease the number of reverse-diffusion steps (trades off a bit of quality for 10x speed)
|
||||||
policy.diffusion.num_inference_steps = 10
|
policy.diffusion.num_inference_steps = 10
|
||||||
|
|
||||||
# Download the diffusion policy for pusht environment
|
|
||||||
pretrained_policy_path = Path(snapshot_download("lerobot/diffusion_pusht"))
|
|
||||||
# OR uncomment the following to evaluate a policy from the local outputs/train folder.
|
|
||||||
# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")
|
|
||||||
|
|
||||||
policy = DiffusionPolicy.from_pretrained(pretrained_policy_path)
|
|
||||||
policy.eval()
|
|
||||||
policy.to(device)
|
policy.to(device)
|
||||||
|
|
||||||
# Initialize evaluation environment to render two observation types:
|
# Initialize evaluation environment to render two observation types:
|
||||||
|
|
Loading…
Reference in New Issue