move hf_datasets at the end of add_episodes_inplace

This commit is contained in:
Cadene 2024-04-18 09:25:13 +00:00
parent a7dffc8359
commit 77088ffd66
1 changed files with 3 additions and 3 deletions

View File

@ -132,10 +132,10 @@ def calculate_online_sample_weight(n_off: int, n_on: int, pc_on: float):
def add_episodes_inplace( def add_episodes_inplace(
hf_dataset: datasets.Dataset,
online_dataset: torch.utils.data.Dataset, online_dataset: torch.utils.data.Dataset,
concat_dataset: torch.utils.data.ConcatDataset, concat_dataset: torch.utils.data.ConcatDataset,
sampler: torch.utils.data.WeightedRandomSampler, sampler: torch.utils.data.WeightedRandomSampler,
hf_dataset: datasets.Dataset,
pc_online_samples: float, pc_online_samples: float,
): ):
""" """
@ -145,12 +145,12 @@ def add_episodes_inplace(
percentage of online samples. percentage of online samples.
Parameters: Parameters:
- hf_dataset (datasets.Dataset): A Hugging Face dataset containing the new episodes to be added.
- online_dataset (torch.utils.data.Dataset): The existing online dataset to be updated. - online_dataset (torch.utils.data.Dataset): The existing online dataset to be updated.
- concat_dataset (torch.utils.data.ConcatDataset): The concatenated dataset that combines - concat_dataset (torch.utils.data.ConcatDataset): The concatenated dataset that combines
offline and online datasets, used for sampling purposes. offline and online datasets, used for sampling purposes.
- sampler (torch.utils.data.WeightedRandomSampler): A sampler that will be updated to - sampler (torch.utils.data.WeightedRandomSampler): A sampler that will be updated to
reflect changes in the dataset sizes and specified sampling weights. reflect changes in the dataset sizes and specified sampling weights.
- hf_dataset (datasets.Dataset): A Hugging Face dataset containing the new episodes to be added.
- pc_online_samples (float): The target percentage of samples that should come from - pc_online_samples (float): The target percentage of samples that should come from
the online dataset during sampling operations. the online dataset during sampling operations.
@ -335,7 +335,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
online_pc_sampling = cfg.get("demo_schedule", 0.5) online_pc_sampling = cfg.get("demo_schedule", 0.5)
add_episodes_inplace( add_episodes_inplace(
eval_info["episodes"], online_dataset, concat_dataset, sampler, online_pc_sampling online_dataset, concat_dataset, sampler, eval_info["episodes"], online_pc_sampling
) )
for _ in range(cfg.policy.utd): for _ in range(cfg.policy.utd):