fix online training
This commit is contained in:
parent
06628ba059
commit
2a59825a00
|
@ -47,6 +47,7 @@ from PIL import Image as PILImage
|
||||||
from tqdm import trange
|
from tqdm import trange
|
||||||
|
|
||||||
from lerobot.common.datasets.factory import make_dataset
|
from lerobot.common.datasets.factory import make_dataset
|
||||||
|
from lerobot.common.datasets.utils import hf_transform_to_torch
|
||||||
from lerobot.common.envs.factory import make_env
|
from lerobot.common.envs.factory import make_env
|
||||||
from lerobot.common.envs.utils import postprocess_action, preprocess_observation
|
from lerobot.common.envs.utils import postprocess_action, preprocess_observation
|
||||||
from lerobot.common.logger import log_output_dir
|
from lerobot.common.logger import log_output_dir
|
||||||
|
@ -270,7 +271,8 @@ def eval_policy(
|
||||||
|
|
||||||
data_dict["index"] = torch.arange(0, total_frames, 1)
|
data_dict["index"] = torch.arange(0, total_frames, 1)
|
||||||
|
|
||||||
hf_dataset = Dataset.from_dict(data_dict).with_format("torch")
|
hf_dataset = Dataset.from_dict(data_dict)
|
||||||
|
hf_dataset.set_transform(hf_transform_to_torch)
|
||||||
|
|
||||||
if max_episodes_rendered > 0:
|
if max_episodes_rendered > 0:
|
||||||
batch_stacked_frames = np.stack(ep_frames, 1) # (b, t, *)
|
batch_stacked_frames = np.stack(ep_frames, 1) # (b, t, *)
|
||||||
|
|
Loading…
Reference in New Issue