Fix online training (#94)
This commit is contained in:
parent
1030ea0070
commit
c1bcf857c5
|
@ -193,8 +193,9 @@ jobs:
|
||||||
env=xarm \
|
env=xarm \
|
||||||
wandb.enable=False \
|
wandb.enable=False \
|
||||||
offline_steps=1 \
|
offline_steps=1 \
|
||||||
online_steps=1 \
|
online_steps=2 \
|
||||||
eval_episodes=1 \
|
eval_episodes=1 \
|
||||||
|
env.episode_length=2 \
|
||||||
device=cpu \
|
device=cpu \
|
||||||
save_model=true \
|
save_model=true \
|
||||||
save_freq=2 \
|
save_freq=2 \
|
||||||
|
|
|
@ -41,7 +41,7 @@ import gymnasium as gym
|
||||||
import imageio
|
import imageio
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from datasets import Dataset
|
from datasets import Dataset, Features, Image, Sequence, Value
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
from PIL import Image as PILImage
|
from PIL import Image as PILImage
|
||||||
from tqdm import trange
|
from tqdm import trange
|
||||||
|
@ -270,8 +270,34 @@ def eval_policy(
|
||||||
data_dict[key].append(img)
|
data_dict[key].append(img)
|
||||||
|
|
||||||
data_dict["index"] = torch.arange(0, total_frames, 1)
|
data_dict["index"] = torch.arange(0, total_frames, 1)
|
||||||
|
episode_data_index["from"] = torch.tensor(episode_data_index["from"])
|
||||||
|
episode_data_index["to"] = torch.tensor(episode_data_index["to"])
|
||||||
|
|
||||||
hf_dataset = Dataset.from_dict(data_dict)
|
# TODO(rcadene): clean this
|
||||||
|
features = {}
|
||||||
|
for key in observations:
|
||||||
|
if "image" in key:
|
||||||
|
features[key] = Image()
|
||||||
|
else:
|
||||||
|
features[key] = Sequence(
|
||||||
|
length=data_dict[key].shape[1], feature=Value(dtype="float32", id=None)
|
||||||
|
)
|
||||||
|
features.update(
|
||||||
|
{
|
||||||
|
"action": Sequence(
|
||||||
|
length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
|
||||||
|
),
|
||||||
|
"episode_index": Value(dtype="int64", id=None),
|
||||||
|
"frame_index": Value(dtype="int64", id=None),
|
||||||
|
"timestamp": Value(dtype="float32", id=None),
|
||||||
|
"next.reward": Value(dtype="float32", id=None),
|
||||||
|
"next.done": Value(dtype="bool", id=None),
|
||||||
|
#'next.success': Value(dtype='bool', id=None),
|
||||||
|
"index": Value(dtype="int64", id=None),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
features = Features(features)
|
||||||
|
hf_dataset = Dataset.from_dict(data_dict, features=features)
|
||||||
hf_dataset.set_transform(hf_transform_to_torch)
|
hf_dataset.set_transform(hf_transform_to_torch)
|
||||||
|
|
||||||
if max_episodes_rendered > 0:
|
if max_episodes_rendered > 0:
|
||||||
|
|
|
@ -160,27 +160,32 @@ def add_episodes_inplace(
|
||||||
Raises:
|
Raises:
|
||||||
- AssertionError: If the first episode_id or index in hf_dataset is not 0
|
- AssertionError: If the first episode_id or index in hf_dataset is not 0
|
||||||
"""
|
"""
|
||||||
first_episode_id = hf_dataset.select_columns("episode_index")[0]["episode_index"].item()
|
first_episode_idx = hf_dataset.select_columns("episode_index")[0]["episode_index"].item()
|
||||||
|
last_episode_idx = hf_dataset.select_columns("episode_index")[-1]["episode_index"].item()
|
||||||
first_index = hf_dataset.select_columns("index")[0]["index"].item()
|
first_index = hf_dataset.select_columns("index")[0]["index"].item()
|
||||||
assert first_episode_id == 0, f"We expect the first episode_id to be 0 and not {first_episode_id}"
|
last_index = hf_dataset.select_columns("index")[-1]["index"].item()
|
||||||
assert first_index == 0, f"We expect the first first_index to be 0 and not {first_index}"
|
# sanity check
|
||||||
|
assert first_episode_idx == 0, f"{first_episode_idx=} is not 0"
|
||||||
|
assert first_index == 0, f"{first_index=} is not 0"
|
||||||
|
assert first_index == episode_data_index["from"][first_episode_idx].item()
|
||||||
|
assert last_index == episode_data_index["to"][last_episode_idx].item() - 1
|
||||||
|
|
||||||
if len(online_dataset) == 0:
|
if len(online_dataset) == 0:
|
||||||
# initialize online dataset
|
# initialize online dataset
|
||||||
online_dataset.hf_dataset = hf_dataset
|
online_dataset.hf_dataset = hf_dataset
|
||||||
|
online_dataset.episode_data_index = episode_data_index
|
||||||
else:
|
else:
|
||||||
# find episode index and data frame indices according to previous episode in online_dataset
|
# get the starting indices of the new episodes and frames to be added
|
||||||
start_episode = online_dataset.select_columns("episode_index")[-1]["episode_index"].item() + 1
|
start_episode_idx = last_episode_idx + 1
|
||||||
start_index = online_dataset.select_columns("index")[-1]["index"].item() + 1
|
start_index = last_index + 1
|
||||||
|
|
||||||
def shift_indices(example):
|
def shift_indices(episode_index, index):
|
||||||
# note: we dont shift "frame_index" since it represents the index of the frame in the episode it belongs to
|
# note: we dont shift "frame_index" since it represents the index of the frame in the episode it belongs to
|
||||||
example["episode_index"] += start_episode
|
example = {"episode_index": episode_index + start_episode_idx, "index": index + start_index}
|
||||||
example["index"] += start_index
|
|
||||||
return example
|
return example
|
||||||
|
|
||||||
disable_progress_bars() # map has a tqdm progress bar
|
disable_progress_bars() # map has a tqdm progress bar
|
||||||
hf_dataset = hf_dataset.map(shift_indices)
|
hf_dataset = hf_dataset.map(shift_indices, input_columns=["episode_index", "index"])
|
||||||
enable_progress_bars()
|
enable_progress_bars()
|
||||||
|
|
||||||
episode_data_index["from"] += start_index
|
episode_data_index["from"] += start_index
|
||||||
|
@ -306,6 +311,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||||
# create an empty online dataset similar to offline dataset
|
# create an empty online dataset similar to offline dataset
|
||||||
online_dataset = deepcopy(offline_dataset)
|
online_dataset = deepcopy(offline_dataset)
|
||||||
online_dataset.hf_dataset = {}
|
online_dataset.hf_dataset = {}
|
||||||
|
online_dataset.episode_data_index = {}
|
||||||
|
|
||||||
# create dataloader for online training
|
# create dataloader for online training
|
||||||
concat_dataset = torch.utils.data.ConcatDataset([offline_dataset, online_dataset])
|
concat_dataset = torch.utils.data.ConcatDataset([offline_dataset, online_dataset])
|
||||||
|
|
Loading…
Reference in New Issue