fix online training
This commit is contained in:
parent
4a3eac4743
commit
e09d25267e
|
@ -14,7 +14,7 @@ def preprocess_observation(observation, transform=None):
|
|||
imgs = {"observation.image": observation["pixels"]}
|
||||
|
||||
for imgkey, img in imgs.items():
|
||||
img = torch.from_numpy(img).float()
|
||||
img = torch.from_numpy(img)
|
||||
# convert to (b c h w) torch format
|
||||
img = einops.rearrange(img, "b h w c -> b c h w")
|
||||
obs[imgkey] = img
|
||||
|
|
|
@ -43,6 +43,7 @@ import numpy as np
|
|||
import torch
|
||||
from datasets import Dataset
|
||||
from huggingface_hub import snapshot_download
|
||||
from PIL import Image as PILImage
|
||||
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.envs.factory import make_env
|
||||
|
@ -218,7 +219,7 @@ def eval_policy(
|
|||
"episode_data_id_to": torch.tensor([idx_from + num_frames - 1] * num_frames),
|
||||
}
|
||||
for key in observations:
|
||||
ep_dict[key] = observations[key][ep_id, :num_frames]
|
||||
ep_dict[key] = observations[key][ep_id][:num_frames]
|
||||
ep_dicts.append(ep_dict)
|
||||
|
||||
idx_from += num_frames
|
||||
|
@ -227,7 +228,17 @@ def eval_policy(
|
|||
data_dict = {}
|
||||
keys = ep_dicts[0].keys()
|
||||
for key in keys:
|
||||
data_dict[key] = torch.cat([x[key] for x in ep_dicts])
|
||||
if "image" not in key:
|
||||
data_dict[key] = torch.cat([x[key] for x in ep_dicts])
|
||||
else:
|
||||
if key not in data_dict:
|
||||
data_dict[key] = []
|
||||
for ep_dict in ep_dicts:
|
||||
for x in ep_dict[key]:
|
||||
# c h w -> h w c
|
||||
img = PILImage.fromarray(x.permute(1, 2, 0).numpy())
|
||||
data_dict[key].append(img)
|
||||
|
||||
data_dict["index"] = torch.arange(0, total_frames, 1)
|
||||
|
||||
data_dict = Dataset.from_dict(data_dict).with_format("torch")
|
||||
|
|
Loading…
Reference in New Issue