# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This script demonstrates how to evaluate a pretrained policy from the HuggingFace Hub or from your local
training outputs directory. In the latter case, you might want to run examples/3_train_policy.py first.

It requires the installation of the 'gym_pusht' simulation environment. Install it by running:
```bash
pip install --no-binary=av -e ".[pusht]"
```
"""

from pathlib import Path

import gym_pusht  # noqa: F401
import gymnasium as gym
import imageio
import numpy
import torch

from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy

# Create a directory to store the video of the evaluation
output_directory = Path("outputs/eval/example_pusht_diffusion")
output_directory.mkdir(parents=True, exist_ok=True)

# Select your device
device = "cuda"

# Provide the [hugging face repo id](https://huggingface.co/lerobot/diffusion_pusht):
pretrained_policy_path = "lerobot/diffusion_pusht"
# OR a path to a local outputs/train folder.
# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")

policy = DiffusionPolicy.from_pretrained(pretrained_policy_path)
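# Defensive addition (not in the original example): ensure the policy weights live on the
# selected device. If from_pretrained already restores them there, this is a no-op;
# otherwise it avoids a CPU/GPU mismatch with the observation tensors moved below.
policy.to(device)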

# Initialize evaluation environment to render two observation types:
# an image of the scene and state/position of the agent. The environment
# also automatically stops running after 300 interactions/steps.
env = gym.make(
    "gym_pusht/PushT-v0",
    obs_type="pixels_agent_pos",
    max_episode_steps=300,
)

# We can verify that the shapes of the features expected by the policy match the ones from the observations
# produced by the environment
print(policy.config.input_features)
print(env.observation_space)

# Similarly, we can check that the actions produced by the policy will match the actions expected by the
# environment
print(policy.config.output_features)
print(env.action_space)

# Reset the policy and environment to prepare for rollout
policy.reset()
numpy_observation, info = env.reset(seed=42)

# Prepare to collect every reward and all the frames of the episode,
# from initial state to final state.
rewards = []
frames = []

# Render frame of the initial state
frames.append(env.render())

step = 0
done = False
while not done:
    # Prepare observation for the policy running in PyTorch
    state = torch.from_numpy(numpy_observation["agent_pos"])
    image = torch.from_numpy(numpy_observation["pixels"])

    # Convert to float32, with the image going from channel-last in [0, 255]
    # to channel-first in [0, 1]
    state = state.to(torch.float32)
    image = image.to(torch.float32) / 255
    image = image.permute(2, 0, 1)

    # Send data tensors from CPU to GPU
    state = state.to(device, non_blocking=True)
    image = image.to(device, non_blocking=True)

    # Add extra (empty) batch dimension, required to forward the policy
    state = state.unsqueeze(0)
    image = image.unsqueeze(0)

    # Create the policy input dictionary
    observation = {
        "observation.state": state,
        "observation.image": image,
    }

    # Predict the next action with respect to the current observation
    with torch.inference_mode():
        action = policy.select_action(observation)

    # Prepare the action for the environment
    numpy_action = action.squeeze(0).to("cpu").numpy()

    # Step through the environment and receive a new observation
    numpy_observation, reward, terminated, truncated, info = env.step(numpy_action)
    print(f"{step=} {reward=} {terminated=}")

    # Keep track of all the rewards and frames
    rewards.append(reward)
    frames.append(env.render())

    # The rollout is considered done when the success state is reached (i.e. terminated is True),
    # or the maximum number of iterations is reached (i.e. truncated is True)
    done = terminated | truncated | done
    step += 1

if terminated:
    print("Success!")
else:
    print("Failure!")

# Get the speed of the environment (i.e. its number of frames per second).
fps = env.metadata["render_fps"]

# Encode all frames into an mp4 video.
video_path = output_directory / "rollout.mp4"
imageio.mimsave(str(video_path), numpy.stack(frames), fps=fps)

print(f"Video of the evaluation is available in '{video_path}'.")