fix online training

This commit is contained in:
Cadene 2024-04-20 00:12:34 +00:00
parent 06628ba059
commit 2a59825a00
1 changed files with 3 additions and 1 deletions

View File

@ -47,6 +47,7 @@ from PIL import Image as PILImage
from tqdm import trange
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.utils import hf_transform_to_torch
from lerobot.common.envs.factory import make_env
from lerobot.common.envs.utils import postprocess_action, preprocess_observation
from lerobot.common.logger import log_output_dir
@ -270,7 +271,8 @@ def eval_policy(
data_dict["index"] = torch.arange(0, total_frames, 1)
hf_dataset = Dataset.from_dict(data_dict).with_format("torch")
hf_dataset = Dataset.from_dict(data_dict)
hf_dataset.set_transform(hf_transform_to_torch)
if max_episodes_rendered > 0:
batch_stacked_frames = np.stack(ep_frames, 1) # (b, t, *)