Fix online training (#94)

This commit is contained in:
Remi 2024-04-23 18:54:55 +02:00 committed by GitHub
parent 1030ea0070
commit c1bcf857c5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 46 additions and 13 deletions

View File

@ -193,8 +193,9 @@ jobs:
env=xarm \
wandb.enable=False \
offline_steps=1 \
online_steps=1 \
online_steps=2 \
eval_episodes=1 \
env.episode_length=2 \
device=cpu \
save_model=true \
save_freq=2 \

View File

@ -41,7 +41,7 @@ import gymnasium as gym
import imageio
import numpy as np
import torch
from datasets import Dataset
from datasets import Dataset, Features, Image, Sequence, Value
from huggingface_hub import snapshot_download
from PIL import Image as PILImage
from tqdm import trange
@ -270,8 +270,34 @@ def eval_policy(
data_dict[key].append(img)
data_dict["index"] = torch.arange(0, total_frames, 1)
episode_data_index["from"] = torch.tensor(episode_data_index["from"])
episode_data_index["to"] = torch.tensor(episode_data_index["to"])
hf_dataset = Dataset.from_dict(data_dict)
# TODO(rcadene): clean this
features = {}
for key in observations:
if "image" in key:
features[key] = Image()
else:
features[key] = Sequence(
length=data_dict[key].shape[1], feature=Value(dtype="float32", id=None)
)
features.update(
{
"action": Sequence(
length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
),
"episode_index": Value(dtype="int64", id=None),
"frame_index": Value(dtype="int64", id=None),
"timestamp": Value(dtype="float32", id=None),
"next.reward": Value(dtype="float32", id=None),
"next.done": Value(dtype="bool", id=None),
#'next.success': Value(dtype='bool', id=None),
"index": Value(dtype="int64", id=None),
}
)
features = Features(features)
hf_dataset = Dataset.from_dict(data_dict, features=features)
hf_dataset.set_transform(hf_transform_to_torch)
if max_episodes_rendered > 0:

View File

@ -160,27 +160,32 @@ def add_episodes_inplace(
Raises:
- AssertionError: If the first episode_id or index in hf_dataset is not 0
"""
first_episode_id = hf_dataset.select_columns("episode_index")[0]["episode_index"].item()
first_episode_idx = hf_dataset.select_columns("episode_index")[0]["episode_index"].item()
last_episode_idx = hf_dataset.select_columns("episode_index")[-1]["episode_index"].item()
first_index = hf_dataset.select_columns("index")[0]["index"].item()
assert first_episode_id == 0, f"We expect the first episode_id to be 0 and not {first_episode_id}"
assert first_index == 0, f"We expect the first first_index to be 0 and not {first_index}"
last_index = hf_dataset.select_columns("index")[-1]["index"].item()
# sanity check
assert first_episode_idx == 0, f"{first_episode_idx=} is not 0"
assert first_index == 0, f"{first_index=} is not 0"
assert first_index == episode_data_index["from"][first_episode_idx].item()
assert last_index == episode_data_index["to"][last_episode_idx].item() - 1
if len(online_dataset) == 0:
# initialize online dataset
online_dataset.hf_dataset = hf_dataset
online_dataset.episode_data_index = episode_data_index
else:
# find episode index and data frame indices according to previous episode in online_dataset
start_episode = online_dataset.select_columns("episode_index")[-1]["episode_index"].item() + 1
start_index = online_dataset.select_columns("index")[-1]["index"].item() + 1
# get the starting indices of the new episodes and frames to be added
start_episode_idx = last_episode_idx + 1
start_index = last_index + 1
def shift_indices(example):
def shift_indices(episode_index, index):
# note: we dont shift "frame_index" since it represents the index of the frame in the episode it belongs to
example["episode_index"] += start_episode
example["index"] += start_index
example = {"episode_index": episode_index + start_episode_idx, "index": index + start_index}
return example
disable_progress_bars() # map has a tqdm progress bar
hf_dataset = hf_dataset.map(shift_indices)
hf_dataset = hf_dataset.map(shift_indices, input_columns=["episode_index", "index"])
enable_progress_bars()
episode_data_index["from"] += start_index
@ -306,6 +311,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
# create an empty online dataset similar to offline dataset
online_dataset = deepcopy(offline_dataset)
online_dataset.hf_dataset = {}
online_dataset.episode_data_index = {}
# create dataloader for online training
concat_dataset = torch.utils.data.ConcatDataset([offline_dataset, online_dataset])