diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index 1528612d..7b3c6dd3 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -47,6 +47,7 @@ from PIL import Image as PILImage from tqdm import trange from lerobot.common.datasets.factory import make_dataset +from lerobot.common.datasets.utils import hf_transform_to_torch from lerobot.common.envs.factory import make_env from lerobot.common.envs.utils import postprocess_action, preprocess_observation from lerobot.common.logger import log_output_dir @@ -270,7 +271,8 @@ def eval_policy( data_dict["index"] = torch.arange(0, total_frames, 1) - hf_dataset = Dataset.from_dict(data_dict).with_format("torch") + hf_dataset = Dataset.from_dict(data_dict) + hf_dataset.set_transform(hf_transform_to_torch) if max_episodes_rendered > 0: batch_stacked_frames = np.stack(ep_frames, 1) # (b, t, *)