Fix online training (#94)

Remi, 2024-04-23 18:54:55 +02:00, committed by GitHub
parent 1030ea0070
commit c1bcf857c5
3 changed files with 46 additions and 13 deletions


```diff
@@ -193,8 +193,9 @@ jobs:
       env=xarm \
       wandb.enable=False \
       offline_steps=1 \
-      online_steps=1 \
+      online_steps=2 \
       eval_episodes=1 \
+      env.episode_length=2 \
       device=cpu \
       save_model=true \
       save_freq=2 \
```
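
Bumping `online_steps` to 2 and shortening episodes via `env.episode_length=2` keeps the CI job fast while, presumably, letting the test reach the second online iteration, i.e. the code path that appends new rollouts to an already-populated online dataset, which is the path this commit repairs. A toy schematic of why a single step would not cover that branch (hypothetical names, not the repository's actual training loop):

```python
# Toy schematic: step 0 only initializes the empty buffer; step 1 is the
# first time the append-with-index-shift branch runs, so online_steps=1
# never exercised it. (Hypothetical names, not lerobot's actual loop.)
online_buffer: list[dict] = []

for step in range(2):  # online_steps=2 in the CI config
    # one rollout: eval_episodes=1 episode of env.episode_length=2 frames
    rollout = [{"episode_index": 0, "index": i} for i in range(2)]
    if not online_buffer:
        online_buffer.extend(rollout)  # first fill
    else:
        shift = online_buffer[-1]["index"] + 1  # shift new frames past existing ones
        online_buffer.extend(
            {"episode_index": f["episode_index"] + 1, "index": f["index"] + shift}
            for f in rollout
        )

print([f["index"] for f in online_buffer])  # [0, 1, 2, 3]
```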


```diff
@@ -41,7 +41,7 @@ import gymnasium as gym
 import imageio
 import numpy as np
 import torch
-from datasets import Dataset
+from datasets import Dataset, Features, Image, Sequence, Value
 from huggingface_hub import snapshot_download
 from PIL import Image as PILImage
 from tqdm import trange
@@ -270,8 +270,34 @@ def eval_policy(
             data_dict[key].append(img)
 
     data_dict["index"] = torch.arange(0, total_frames, 1)
+    episode_data_index["from"] = torch.tensor(episode_data_index["from"])
+    episode_data_index["to"] = torch.tensor(episode_data_index["to"])
 
-    hf_dataset = Dataset.from_dict(data_dict)
+    # TODO(rcadene): clean this
+    features = {}
+    for key in observations:
+        if "image" in key:
+            features[key] = Image()
+        else:
+            features[key] = Sequence(
+                length=data_dict[key].shape[1], feature=Value(dtype="float32", id=None)
+            )
+    features.update(
+        {
+            "action": Sequence(
+                length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
+            ),
+            "episode_index": Value(dtype="int64", id=None),
+            "frame_index": Value(dtype="int64", id=None),
+            "timestamp": Value(dtype="float32", id=None),
+            "next.reward": Value(dtype="float32", id=None),
+            "next.done": Value(dtype="bool", id=None),
+            # 'next.success': Value(dtype='bool', id=None),
+            "index": Value(dtype="int64", id=None),
+        }
+    )
+    features = Features(features)
+    hf_dataset = Dataset.from_dict(data_dict, features=features)
     hf_dataset.set_transform(hf_transform_to_torch)
 
     if max_episodes_rendered > 0:
```
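
The main change above replaces schema inference with an explicit `datasets` schema: `Image()` for image columns, fixed-length float32 `Sequence`s for vector columns, and concrete dtypes for scalars, plausibly so the online rollout dataset's features line up exactly with the offline dataset it gets concatenated with. A minimal self-contained sketch of the same pattern (toy column names and values, not the eval script itself):

```python
import numpy as np
from datasets import Dataset, Features, Image, Sequence, Value
from PIL import Image as PILImage

# Toy stand-ins for one 2-frame episode (hypothetical values).
data_dict = {
    "observation.image": [PILImage.new("RGB", (8, 8)) for _ in range(2)],
    "action": [[0.0, 1.0], [1.0, 0.0]],
    "episode_index": [0, 0],
    "frame_index": [0, 1],
    "timestamp": [0.0, 0.1],
    "next.reward": [0.0, 1.0],
    "next.done": [False, True],
    "index": [0, 1],
}

# Explicit schema, mirroring the diff: images stay images, vectors become
# fixed-length float32 sequences, scalars get concrete dtypes.
features = Features(
    {
        "observation.image": Image(),
        "action": Sequence(length=2, feature=Value(dtype="float32")),
        "episode_index": Value(dtype="int64"),
        "frame_index": Value(dtype="int64"),
        "timestamp": Value(dtype="float32"),
        "next.reward": Value(dtype="float32"),
        "next.done": Value(dtype="bool"),
        "index": Value(dtype="int64"),
    }
)

hf_dataset = Dataset.from_dict(data_dict, features=features)
print(hf_dataset.features["action"])  # fixed-length float32 sequence
```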


```diff
@@ -160,27 +160,32 @@ def add_episodes_inplace(
     Raises:
     - AssertionError: If the first episode_id or index in hf_dataset is not 0
     """
-    first_episode_id = hf_dataset.select_columns("episode_index")[0]["episode_index"].item()
+    first_episode_idx = hf_dataset.select_columns("episode_index")[0]["episode_index"].item()
+    last_episode_idx = hf_dataset.select_columns("episode_index")[-1]["episode_index"].item()
     first_index = hf_dataset.select_columns("index")[0]["index"].item()
-    assert first_episode_id == 0, f"We expect the first episode_id to be 0 and not {first_episode_id}"
-    assert first_index == 0, f"We expect the first first_index to be 0 and not {first_index}"
+    last_index = hf_dataset.select_columns("index")[-1]["index"].item()
+    # sanity check
+    assert first_episode_idx == 0, f"{first_episode_idx=} is not 0"
+    assert first_index == 0, f"{first_index=} is not 0"
+    assert first_index == episode_data_index["from"][first_episode_idx].item()
+    assert last_index == episode_data_index["to"][last_episode_idx].item() - 1
 
     if len(online_dataset) == 0:
         # initialize online dataset
         online_dataset.hf_dataset = hf_dataset
+        online_dataset.episode_data_index = episode_data_index
     else:
-        # find episode index and data frame indices according to previous episode in online_dataset
-        start_episode = online_dataset.select_columns("episode_index")[-1]["episode_index"].item() + 1
-        start_index = online_dataset.select_columns("index")[-1]["index"].item() + 1
+        # get the starting indices of the new episodes and frames to be added
+        start_episode_idx = last_episode_idx + 1
+        start_index = last_index + 1
 
-        def shift_indices(example):
+        def shift_indices(episode_index, index):
             # note: we dont shift "frame_index" since it represents the index of the frame in the episode it belongs to
-            example["episode_index"] += start_episode
-            example["index"] += start_index
+            example = {"episode_index": episode_index + start_episode_idx, "index": index + start_index}
             return example
 
         disable_progress_bars()  # map has a tqdm progress bar
-        hf_dataset = hf_dataset.map(shift_indices)
+        hf_dataset = hf_dataset.map(shift_indices, input_columns=["episode_index", "index"])
         enable_progress_bars()
 
         episode_data_index["from"] += start_index
```
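
The reworked `shift_indices` relies on `datasets.Dataset.map` with `input_columns`: the mapped function then receives just those column values as positional arguments, and the returned dict is merged back into each example, which avoids passing (and copying) heavy columns such as images. A small standalone demonstration (toy dataset, hypothetical values):

```python
from datasets import Dataset

# Toy dataset standing in for a new 2-frame rollout episode.
ds = Dataset.from_dict({"episode_index": [0, 0], "index": [0, 1], "frame_index": [0, 1]})

start_episode_idx = 3  # pretend the online dataset already holds episodes 0..2
start_index = 10       # ...and frames 0..9

def shift_indices(episode_index, index):
    # With input_columns, map passes the selected columns positionally;
    # the returned dict overwrites those columns. "frame_index" is untouched,
    # since it is relative to the episode the frame belongs to.
    return {"episode_index": episode_index + start_episode_idx, "index": index + start_index}

ds = ds.map(shift_indices, input_columns=["episode_index", "index"])
print(ds["episode_index"], ds["index"], ds["frame_index"])
# [3, 3] [10, 11] [0, 1]
```
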
```diff
@@ -306,6 +311,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
     # create an empty online dataset similar to offline dataset
     online_dataset = deepcopy(offline_dataset)
     online_dataset.hf_dataset = {}
+    online_dataset.episode_data_index = {}
 
     # create dataloader for online training
     concat_dataset = torch.utils.data.ConcatDataset([offline_dataset, online_dataset])
```
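
Finally, on the `episode_data_index` bookkeeping that the new asserts check: it appears to map each episode to a half-open frame range, with `"from"` inclusive and `"to"` exclusive, which is why the last global frame index must equal `to[last_episode] - 1`. A quick sketch of that invariant (toy values, not the repository's code):

```python
import torch

# Toy data: two episodes of two frames each, so frames 0-1 belong to
# episode 0 and frames 2-3 to episode 1 (hypothetical values).
episode_data_index = {"from": torch.tensor([0, 2]), "to": torch.tensor([2, 4])}
index = torch.arange(4)                     # global frame indices
episode_index = torch.tensor([0, 0, 1, 1])  # episode of each frame

# The commit's sanity checks, restated on the toy data:
assert index[0].item() == episode_data_index["from"][episode_index[0]].item()
assert index[-1].item() == episode_data_index["to"][episode_index[-1]].item() - 1

# Slicing one episode's frames follows the same convention ("to" exclusive):
start = episode_data_index["from"][1].item()
end = episode_data_index["to"][1].item()
print(index[start:end])  # tensor([2, 3])
```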