From 27035d6cfb9eee5e58f1db32c4f47313265e97f2 Mon Sep 17 00:00:00 2001 From: Cadene Date: Tue, 23 Apr 2024 15:08:18 +0000 Subject: [PATCH] fix online training --- .github/workflows/test.yml | 3 ++- lerobot/scripts/eval.py | 30 ++++++++++++++++++++++++++++-- lerobot/scripts/train.py | 27 +++++++++++++++++---------- 3 files changed, 47 insertions(+), 13 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index f0a7e78c..e56ceeaf 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -193,8 +193,9 @@ jobs: env=xarm \ wandb.enable=False \ offline_steps=1 \ - online_steps=1 \ + online_steps=2 \ eval_episodes=1 \ + env.episode_length=2 \ device=cpu \ save_model=true \ save_freq=2 \ diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index 7b3c6dd3..32b7e26b 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -41,7 +41,7 @@ import gymnasium as gym import imageio import numpy as np import torch -from datasets import Dataset +from datasets import Dataset, Features, Image, Sequence, Value from huggingface_hub import snapshot_download from PIL import Image as PILImage from tqdm import trange @@ -270,8 +270,34 @@ def eval_policy( data_dict[key].append(img) data_dict["index"] = torch.arange(0, total_frames, 1) + episode_data_index["from"] = torch.tensor(episode_data_index["from"]) + episode_data_index["to"] = torch.tensor(episode_data_index["to"]) - hf_dataset = Dataset.from_dict(data_dict) + # TODO(rcadene): clean this + features = {} + for key in observations: + if "image" in key: + features[key] = Image() + else: + features[key] = Sequence( + length=data_dict[key].shape[1], feature=Value(dtype="float32", id=None) + ) + features.update( + { + "action": Sequence( + length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None) + ), + "episode_index": Value(dtype="int64", id=None), + "frame_index": Value(dtype="int64", id=None), + "timestamp": Value(dtype="float32", id=None), + "next.reward": Value(dtype="float32", id=None), + "next.done": Value(dtype="bool", id=None), + #'next.success': Value(dtype='bool', id=None), + "index": Value(dtype="int64", id=None), + } + ) + features = Features(features) + hf_dataset = Dataset.from_dict(data_dict, features=features) hf_dataset.set_transform(hf_transform_to_torch) if max_episodes_rendered > 0: diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 8a70a214..d5134600 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -160,27 +160,33 @@ def add_episodes_inplace( Raises: - AssertionError: If the first episode_id or index in hf_dataset is not 0 """ - first_episode_id = hf_dataset.select_columns("episode_index")[0]["episode_index"].item() + first_episode_idx = hf_dataset.select_columns("episode_index")[0]["episode_index"].item() + last_episode_idx = hf_dataset.select_columns("episode_index")[-1]["episode_index"].item() first_index = hf_dataset.select_columns("index")[0]["index"].item() - assert first_episode_id == 0, f"We expect the first episode_id to be 0 and not {first_episode_id}" - assert first_index == 0, f"We expect the first first_index to be 0 and not {first_index}" + last_index = hf_dataset.select_columns("index")[-1]["index"].item() + # sanity check + assert first_episode_idx == 0, f"{first_episode_idx=} is not 0" + assert first_index == 0, f"{first_index=} is not 0" + assert first_index == episode_data_index["from"][first_episode_idx].item() + assert last_index == episode_data_index["to"][last_episode_idx].item() - 1 + print(1, hf_dataset.features["observation.image"]) if len(online_dataset) == 0: # initialize online dataset online_dataset.hf_dataset = hf_dataset + online_dataset.episode_data_index = episode_data_index else: - # find episode index and data frame indices according to previous episode in online_dataset - start_episode = online_dataset.select_columns("episode_index")[-1]["episode_index"].item() + 1 - start_index = online_dataset.select_columns("index")[-1]["index"].item() + 1 + # get the starting indices of the new episodes and frames to be added + start_episode_idx = last_episode_idx + 1 + start_index = last_index + 1 - def shift_indices(example): + def shift_indices(episode_index, index): # note: we dont shift "frame_index" since it represents the index of the frame in the episode it belongs to - example["episode_index"] += start_episode - example["index"] += start_index + example = {"episode_index": episode_index + start_episode_idx, "index": index + start_index} return example disable_progress_bars() # map has a tqdm progress bar - hf_dataset = hf_dataset.map(shift_indices) + hf_dataset = hf_dataset.map(shift_indices, input_columns=["episode_index", "index"]) enable_progress_bars() episode_data_index["from"] += start_index @@ -306,6 +312,7 @@ def train(cfg: dict, out_dir=None, job_name=None): # create an empty online dataset similar to offline dataset online_dataset = deepcopy(offline_dataset) online_dataset.hf_dataset = {} + online_dataset.episode_data_index = {} # create dataloader for online training concat_dataset = torch.utils.data.ConcatDataset([offline_dataset, online_dataset])