fix online training
This commit is contained in:
parent
06628ba059
commit
2a59825a00
|
@ -47,6 +47,7 @@ from PIL import Image as PILImage
|
|||
from tqdm import trange
|
||||
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.datasets.utils import hf_transform_to_torch
|
||||
from lerobot.common.envs.factory import make_env
|
||||
from lerobot.common.envs.utils import postprocess_action, preprocess_observation
|
||||
from lerobot.common.logger import log_output_dir
|
||||
|
@ -270,7 +271,8 @@ def eval_policy(
|
|||
|
||||
data_dict["index"] = torch.arange(0, total_frames, 1)
|
||||
|
||||
hf_dataset = Dataset.from_dict(data_dict).with_format("torch")
|
||||
hf_dataset = Dataset.from_dict(data_dict)
|
||||
hf_dataset.set_transform(hf_transform_to_torch)
|
||||
|
||||
if max_episodes_rendered > 0:
|
||||
batch_stacked_frames = np.stack(ep_frames, 1) # (b, t, *)
|
||||
|
|
Loading…
Reference in New Issue