fix online training
This commit is contained in:
parent
4a3eac4743
commit
e09d25267e
|
@ -14,7 +14,7 @@ def preprocess_observation(observation, transform=None):
|
||||||
imgs = {"observation.image": observation["pixels"]}
|
imgs = {"observation.image": observation["pixels"]}
|
||||||
|
|
||||||
for imgkey, img in imgs.items():
|
for imgkey, img in imgs.items():
|
||||||
img = torch.from_numpy(img).float()
|
img = torch.from_numpy(img)
|
||||||
# convert to (b c h w) torch format
|
# convert to (b c h w) torch format
|
||||||
img = einops.rearrange(img, "b h w c -> b c h w")
|
img = einops.rearrange(img, "b h w c -> b c h w")
|
||||||
obs[imgkey] = img
|
obs[imgkey] = img
|
||||||
|
|
|
@ -43,6 +43,7 @@ import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
|
from PIL import Image as PILImage
|
||||||
|
|
||||||
from lerobot.common.datasets.factory import make_dataset
|
from lerobot.common.datasets.factory import make_dataset
|
||||||
from lerobot.common.envs.factory import make_env
|
from lerobot.common.envs.factory import make_env
|
||||||
|
@ -218,7 +219,7 @@ def eval_policy(
|
||||||
"episode_data_id_to": torch.tensor([idx_from + num_frames - 1] * num_frames),
|
"episode_data_id_to": torch.tensor([idx_from + num_frames - 1] * num_frames),
|
||||||
}
|
}
|
||||||
for key in observations:
|
for key in observations:
|
||||||
ep_dict[key] = observations[key][ep_id, :num_frames]
|
ep_dict[key] = observations[key][ep_id][:num_frames]
|
||||||
ep_dicts.append(ep_dict)
|
ep_dicts.append(ep_dict)
|
||||||
|
|
||||||
idx_from += num_frames
|
idx_from += num_frames
|
||||||
|
@ -227,7 +228,17 @@ def eval_policy(
|
||||||
data_dict = {}
|
data_dict = {}
|
||||||
keys = ep_dicts[0].keys()
|
keys = ep_dicts[0].keys()
|
||||||
for key in keys:
|
for key in keys:
|
||||||
data_dict[key] = torch.cat([x[key] for x in ep_dicts])
|
if "image" not in key:
|
||||||
|
data_dict[key] = torch.cat([x[key] for x in ep_dicts])
|
||||||
|
else:
|
||||||
|
if key not in data_dict:
|
||||||
|
data_dict[key] = []
|
||||||
|
for ep_dict in ep_dicts:
|
||||||
|
for x in ep_dict[key]:
|
||||||
|
# c h w -> h w c
|
||||||
|
img = PILImage.fromarray(x.permute(1, 2, 0).numpy())
|
||||||
|
data_dict[key].append(img)
|
||||||
|
|
||||||
data_dict["index"] = torch.arange(0, total_frames, 1)
|
data_dict["index"] = torch.arange(0, total_frames, 1)
|
||||||
|
|
||||||
data_dict = Dataset.from_dict(data_dict).with_format("torch")
|
data_dict = Dataset.from_dict(data_dict).with_format("torch")
|
||||||
|
|
Loading…
Reference in New Issue