ready for review

Alexander Soare 2024-05-09 12:22:43 +01:00
parent 7bb5b15f4c
commit d767fdf958
1 changed file with 55 additions and 53 deletions

@@ -10,6 +10,7 @@ from datasets import concatenate_datasets
from datasets.utils import disable_progress_bars, enable_progress_bars
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.utils import cycle
from lerobot.common.envs.factory import make_env
from lerobot.common.logger import Logger, log_output_dir
@@ -214,77 +215,78 @@ def calculate_online_sample_weight(n_off: int, n_on: int, pc_on: float):
def add_episodes_inplace(
online_dataset: torch.utils.data.Dataset,
online_dataset: LeRobotDataset,
concat_dataset: torch.utils.data.ConcatDataset,
sampler: torch.utils.data.WeightedRandomSampler,
hf_dataset: datasets.Dataset,
episode_data_index: dict[str, torch.Tensor],
pc_online_samples: float,
new_hf_dataset: datasets.Dataset,
new_episode_data_index: dict[str, torch.Tensor],
online_sampling_ratio: float,
):
"""
Modifies the online_dataset, concat_dataset, and sampler in place by integrating
new episodes from hf_dataset into the online_dataset, updating the concatenated
new episodes from new_hf_dataset into the online_dataset, updating the concatenated
dataset's structure and adjusting the sampling strategy based on the specified
percentage of online samples.
Parameters:
- online_dataset (torch.utils.data.Dataset): The existing online dataset to be updated.
- concat_dataset (torch.utils.data.ConcatDataset): The concatenated dataset that combines
offline and online datasets, used for sampling purposes.
- sampler (torch.utils.data.WeightedRandomSampler): A sampler that will be updated to
reflect changes in the dataset sizes and specified sampling weights.
- hf_dataset (datasets.Dataset): A Hugging Face dataset containing the new episodes to be added.
- episode_data_index (dict): A dictionary containing two keys ("from" and "to") associated to dataset indices.
They indicate the start index and end index of each episode in the dataset.
- pc_online_samples (float): The target percentage of samples that should come from
the online dataset during sampling operations.
Raises:
- AssertionError: If the first episode_id or index in hf_dataset is not 0
Args:
online_dataset: The existing online dataset to be updated.
concat_dataset: The concatenated dataset that combines offline and online datasets (in that order),
used for sampling purposes.
sampler: A sampler that will be updated to reflect changes in the dataset sizes and specified sampling
weights.
new_hf_dataset: A Hugging Face dataset containing the new episodes to be added.
new_episode_data_index: A dictionary containing two keys ("from" and "to") associated to dataset
indices. They indicate the start index and end index of each episode in the dataset.
online_sampling_ratio: The target percentage of samples that should come from the online dataset
during sampling operations.
"""
first_episode_idx = hf_dataset.select_columns("episode_index")[0]["episode_index"].item()
last_episode_idx = hf_dataset.select_columns("episode_index")[-1]["episode_index"].item()
first_index = hf_dataset.select_columns("index")[0]["index"].item()
last_index = hf_dataset.select_columns("index")[-1]["index"].item()
# sanity check
assert first_episode_idx == 0, f"{first_episode_idx=} is not 0"
assert first_index == 0, f"{first_index=} is not 0"
assert first_index == episode_data_index["from"][first_episode_idx].item()
assert last_index == episode_data_index["to"][last_episode_idx].item() - 1
# Sanity check to make sure that new_hf_dataset starts from 0.
assert new_hf_dataset["episode_index"][0].item() == 0
assert new_hf_dataset["index"][0].item() == 0
# Sanity check to make sure that new_episode_data_index is aligned with new_hf_dataset.
assert new_episode_data_index["from"][0].item() == 0
assert new_episode_data_index["to"] - 1 == new_hf_dataset["index"][-1].item()
if len(online_dataset) == 0:
# initialize online dataset
online_dataset.hf_dataset = hf_dataset
online_dataset.episode_data_index = episode_data_index
# Initialize online dataset.
online_dataset.hf_dataset = new_hf_dataset
online_dataset.episode_data_index = new_episode_data_index
else:
# get the starting indices of the new episodes and frames to be added
start_episode_idx = last_episode_idx + 1
start_index = last_index + 1
# Get the indices required to continue where the data in concat_dataset finishes.
start_episode_idx = concat_dataset.datasets[-1].hf_dataset["episode_index"][-1].item() + 1
start_index = concat_dataset.datasets[-1].hf_dataset["index"][-1].item() + 1
def shift_indices(episode_index, index):
# note: we dont shift "frame_index" since it represents the index of the frame in the episode it belongs to
example = {"episode_index": episode_index + start_episode_idx, "index": index + start_index}
return example
disable_progress_bars() # map has a tqdm progress bar
hf_dataset = hf_dataset.map(shift_indices, input_columns=["episode_index", "index"])
# Shift the indices of new_hf_dataset.
disable_progress_bars() # Dataset.map has a tqdm progress bar
# note: we dont shift "frame_index" since it represents the index of the frame in the episode it
# belongs to
new_hf_dataset = new_hf_dataset.map(
lambda episode_index, data_index: {
"episode_index": episode_index + start_episode_idx,
"index": data_index + start_index,
},
input_columns=["episode_index", "index"],
)
enable_progress_bars()
episode_data_index["from"] += start_index
episode_data_index["to"] += start_index
# extend online dataset
online_dataset.hf_dataset = concatenate_datasets([online_dataset.hf_dataset, hf_dataset])
# Extend the online dataset with the new data.
online_dataset.hf_dataset = concatenate_datasets([online_dataset.hf_dataset, new_hf_dataset])
online_dataset.episode_data_index = {
k: torch.cat([online_dataset.episode_data_index[k], new_episode_data_index[k] + start_index])
for k in ["from", "to"]
}
# update the concatenated dataset length used during sampling
concat_dataset.cumulative_sizes = concat_dataset.cumsum(concat_dataset.datasets)
# update the sampling weights for each frame so that online frames get sampled a certain percentage of times
# update the sampling weights for each frame so that online frames get sampled a certain percentage of
# times
len_online = len(online_dataset)
len_offline = len(concat_dataset) - len_online
weight_offline = 1.0
weight_online = calculate_online_sample_weight(len_offline, len_online, pc_online_samples)
sampler.weights = torch.tensor([weight_offline] * len_offline + [weight_online] * len(online_dataset))
sampler.weights = torch.tensor(
[(1 - online_sampling_ratio) / len_offline] * len_offline
+ [online_sampling_ratio / len_online] * len_online
)
# update the total number of samples used during sampling
sampler.num_samples = len(concat_dataset)
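
Note (not part of the commit): the rewritten sampler weights drop calculate_online_sample_weight in favour of a direct per-frame probability split, where offline frames share (1 - online_sampling_ratio) of the probability mass and online frames share the rest. A minimal sketch of why this hits the target ratio, using hypothetical sizes and ratio:

import torch
from torch.utils.data import WeightedRandomSampler

len_offline, len_online = 1000, 50   # hypothetical dataset sizes
online_sampling_ratio = 0.5          # stand-in for cfg.training.online_sampling_ratio

# Offline frames split (1 - ratio) of the total weight; online frames split the rest,
# so each draw lands on an online frame with probability online_sampling_ratio.
weights = torch.tensor(
    [(1 - online_sampling_ratio) / len_offline] * len_offline
    + [online_sampling_ratio / len_online] * len_online
)
sampler = WeightedRandomSampler(weights, num_samples=len_offline + len_online, replacement=True)

indices = list(sampler)
frac_online = sum(i >= len_offline for i in indices) / len(indices)
print(f"observed online fraction ~ {frac_online:.2f} (target {online_sampling_ratio})")
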
@@ -444,9 +446,9 @@ def train(cfg: dict, out_dir=None, job_name=None):
online_dataset,
concat_dataset,
sampler,
hf_dataset=eval_info["episodes"]["hf_dataset"],
episode_data_index=eval_info["episodes"]["episode_data_index"],
pc_online_samples=cfg.training.online_sampling_ratio,
new_hf_dataset=eval_info["episodes"]["hf_dataset"],
new_episode_data_index=eval_info["episodes"]["episode_data_index"],
online_sampling_ratio=cfg.training.online_sampling_ratio,
)
policy.train()
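
Note (not part of the commit): the index-shifting step in the first hunk relies on datasets.Dataset.map with input_columns to bump "episode_index" and "index" while leaving "frame_index" untouched, then concatenate_datasets to append the shifted episodes. A small self-contained sketch with toy data (the two datasets below are hypothetical, not taken from the repository):

import datasets
from datasets import concatenate_datasets

# Toy stand-ins for the existing online dataset and the newly collected episodes.
existing = datasets.Dataset.from_dict(
    {"episode_index": [0, 0, 1], "index": [0, 1, 2], "frame_index": [0, 1, 0]}
)
new_hf_dataset = datasets.Dataset.from_dict(
    {"episode_index": [0, 1, 1], "index": [0, 1, 2], "frame_index": [0, 0, 1]}
)

# Continue numbering where the existing data finishes.
start_episode_idx = existing["episode_index"][-1] + 1
start_index = existing["index"][-1] + 1

# "frame_index" is deliberately left alone: it is relative to the episode it belongs to.
new_hf_dataset = new_hf_dataset.map(
    lambda episode_index, data_index: {
        "episode_index": episode_index + start_episode_idx,
        "index": data_index + start_index,
    },
    input_columns=["episode_index", "index"],
)

merged = concatenate_datasets([existing, new_hf_dataset])
print(merged["episode_index"])  # [0, 0, 1, 2, 3, 3]
print(merged["index"])          # [0, 1, 2, 3, 4, 5]
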