"""
This script demonstrates how to evaluate a pretrained policy from the HuggingFace Hub or from your local
training outputs directory. In the latter case, you might want to run examples/3_train_policy.py first.
"""

from pathlib import Path

import gym_pusht  # noqa: F401
import gymnasium as gym
import imageio
import numpy
import torch
from huggingface_hub import snapshot_download

from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy

# Create a directory to store the video of the evaluation
output_directory = Path("outputs/eval/example_pusht_diffusion")
output_directory.mkdir(parents=True, exist_ok=True)

# Download the diffusion policy for the pusht environment
pretrained_policy_path = Path(snapshot_download("lerobot/diffusion_pusht"))
# OR uncomment the following to evaluate a policy from the local outputs/train folder.
# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")

policy = DiffusionPolicy.from_pretrained(pretrained_policy_path)
policy.eval()

# Check if GPU is available
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("GPU is available. Device set to:", device)
else:
    device = torch.device("cpu")
    print(f"GPU is not available. Device set to: {device}. Inference will be slower than on GPU.")
    # Decrease the number of reverse-diffusion steps (trades off a bit of quality for 10x speed)
    policy.diffusion.num_inference_steps = 10

policy.to(device)

# Initialize evaluation environment to render two observation types:
# an image of the scene and state/position of the agent. The environment
# also automatically stops running after 300 interactions/steps.
env = gym.make(
    "gym_pusht/PushT-v0",
    obs_type="pixels_agent_pos",
    max_episode_steps=300,
)
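
# Optional (not part of the original example): inspect the environment's spaces to see
# what the policy receives as observations and what it is expected to output as actions.
print(env.observation_space)
print(env.action_space)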

# Reset the policy and environment to prepare for rollout
policy.reset()
numpy_observation, info = env.reset(seed=42)

# Prepare to collect all the rewards and all the frames of the episode,
# from the initial state to the final state.
rewards = []
frames = []

# Render the frame of the initial state
frames.append(env.render())

step = 0
done = False

while not done:
    # Prepare observation for the policy running in PyTorch
    state = torch.from_numpy(numpy_observation["agent_pos"])
    image = torch.from_numpy(numpy_observation["pixels"])

    # Convert to float32, with the image going from channel-last in [0, 255]
    # to channel-first in [0, 1]
    state = state.to(torch.float32)
    image = image.to(torch.float32) / 255
    image = image.permute(2, 0, 1)

    # Send data tensors from CPU to GPU
    state = state.to(device, non_blocking=True)
    image = image.to(device, non_blocking=True)

    # Add an extra (empty) batch dimension, required by the policy
    state = state.unsqueeze(0)
    image = image.unsqueeze(0)

    # Create the policy input dictionary
    observation = {
        "observation.state": state,
        "observation.image": image,
    }

    # Predict the next action with respect to the current observation
    with torch.inference_mode():
        action = policy.select_action(observation)

    # Prepare the action for the environment
    numpy_action = action.squeeze(0).to("cpu").numpy()

    # Step through the environment and receive a new observation
    numpy_observation, reward, terminated, truncated, info = env.step(numpy_action)
    print(f"{step=} {reward=} {terminated=}")

    # Keep track of all the rewards and frames
    rewards.append(reward)
    frames.append(env.render())

    # The rollout is considered done when the success state is reached (i.e. terminated is True),
    # or the maximum number of iterations is reached (i.e. truncated is True)
    done = terminated | truncated | done
    step += 1

if terminated:
    print("Success!")
else:
    print("Failure!")

# Get the speed of the environment (i.e. its number of frames per second).
fps = env.metadata["render_fps"]

# Encode all frames into an mp4 video.
video_path = output_directory / "rollout.mp4"
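# Note (added, not in the original script): encoding mp4 with imageio typically relies on
# an ffmpeg backend such as the imageio-ffmpeg plugin; install it if the call below fails.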
imageio.mimsave(str(video_path), numpy.stack(frames), fps=fps)
print(f"Video of the evaluation is available in '{video_path}'.")
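
# Optional summary (not part of the original example): report the accumulated reward
# so the rollout also yields a quick quantitative score in addition to the video.
print(f"Accumulated reward over {step} steps: {sum(rewards):.3f} (max step reward: {max(rewards):.3f})")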