From e09d25267eaa51351c6b72808277ad9a854621c2 Mon Sep 17 00:00:00 2001 From: Cadene Date: Tue, 16 Apr 2024 16:07:39 +0000 Subject: [PATCH] fix online training --- lerobot/common/envs/utils.py | 2 +- lerobot/scripts/eval.py | 15 +++++++++++++-- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/lerobot/common/envs/utils.py b/lerobot/common/envs/utils.py index 4d31ddb2..7f5216cd 100644 --- a/lerobot/common/envs/utils.py +++ b/lerobot/common/envs/utils.py @@ -14,7 +14,7 @@ def preprocess_observation(observation, transform=None): imgs = {"observation.image": observation["pixels"]} for imgkey, img in imgs.items(): - img = torch.from_numpy(img).float() + img = torch.from_numpy(img) # convert to (b c h w) torch format img = einops.rearrange(img, "b h w c -> b c h w") obs[imgkey] = img diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index 0f4a8399..d8c697c2 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -43,6 +43,7 @@ import numpy as np import torch from datasets import Dataset from huggingface_hub import snapshot_download +from PIL import Image as PILImage from lerobot.common.datasets.factory import make_dataset from lerobot.common.envs.factory import make_env @@ -218,7 +219,7 @@ def eval_policy( "episode_data_id_to": torch.tensor([idx_from + num_frames - 1] * num_frames), } for key in observations: - ep_dict[key] = observations[key][ep_id, :num_frames] + ep_dict[key] = observations[key][ep_id][:num_frames] ep_dicts.append(ep_dict) idx_from += num_frames @@ -227,7 +228,17 @@ def eval_policy( data_dict = {} keys = ep_dicts[0].keys() for key in keys: - data_dict[key] = torch.cat([x[key] for x in ep_dicts]) + if "image" not in key: + data_dict[key] = torch.cat([x[key] for x in ep_dicts]) + else: + if key not in data_dict: + data_dict[key] = [] + for ep_dict in ep_dicts: + for x in ep_dict[key]: + # c h w -> h w c + img = PILImage.fromarray(x.permute(1, 2, 0).numpy()) + data_dict[key].append(img) + data_dict["index"] = torch.arange(0, total_frames, 1) data_dict = Dataset.from_dict(data_dict).with_format("torch")