fix online training

This commit is contained in:
Cadene 2024-04-16 16:07:39 +00:00
parent 4a3eac4743
commit e09d25267e
2 changed files with 14 additions and 3 deletions

View File

@ -14,7 +14,7 @@ def preprocess_observation(observation, transform=None):
imgs = {"observation.image": observation["pixels"]}
for imgkey, img in imgs.items():
img = torch.from_numpy(img).float()
img = torch.from_numpy(img)
# convert to (b c h w) torch format
img = einops.rearrange(img, "b h w c -> b c h w")
obs[imgkey] = img

View File

@ -43,6 +43,7 @@ import numpy as np
import torch
from datasets import Dataset
from huggingface_hub import snapshot_download
from PIL import Image as PILImage
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.envs.factory import make_env
@ -218,7 +219,7 @@ def eval_policy(
"episode_data_id_to": torch.tensor([idx_from + num_frames - 1] * num_frames),
}
for key in observations:
ep_dict[key] = observations[key][ep_id, :num_frames]
ep_dict[key] = observations[key][ep_id][:num_frames]
ep_dicts.append(ep_dict)
idx_from += num_frames
@ -227,7 +228,17 @@ def eval_policy(
data_dict = {}
keys = ep_dicts[0].keys()
for key in keys:
data_dict[key] = torch.cat([x[key] for x in ep_dicts])
if "image" not in key:
data_dict[key] = torch.cat([x[key] for x in ep_dicts])
else:
if key not in data_dict:
data_dict[key] = []
for ep_dict in ep_dicts:
for x in ep_dict[key]:
# c h w -> h w c
img = PILImage.fromarray(x.permute(1, 2, 0).numpy())
data_dict[key].append(img)
data_dict["index"] = torch.arange(0, total_frames, 1)
data_dict = Dataset.from_dict(data_dict).with_format("torch")