move hf_datasets at the end of add_episodes_inplace
This commit is contained in:
parent
a7dffc8359
commit
77088ffd66
|
@ -132,10 +132,10 @@ def calculate_online_sample_weight(n_off: int, n_on: int, pc_on: float):
|
|||
|
||||
|
||||
def add_episodes_inplace(
|
||||
hf_dataset: datasets.Dataset,
|
||||
online_dataset: torch.utils.data.Dataset,
|
||||
concat_dataset: torch.utils.data.ConcatDataset,
|
||||
sampler: torch.utils.data.WeightedRandomSampler,
|
||||
hf_dataset: datasets.Dataset,
|
||||
pc_online_samples: float,
|
||||
):
|
||||
"""
|
||||
|
@ -145,12 +145,12 @@ def add_episodes_inplace(
|
|||
percentage of online samples.
|
||||
|
||||
Parameters:
|
||||
- hf_dataset (datasets.Dataset): A Hugging Face dataset containing the new episodes to be added.
|
||||
- online_dataset (torch.utils.data.Dataset): The existing online dataset to be updated.
|
||||
- concat_dataset (torch.utils.data.ConcatDataset): The concatenated dataset that combines
|
||||
offline and online datasets, used for sampling purposes.
|
||||
- sampler (torch.utils.data.WeightedRandomSampler): A sampler that will be updated to
|
||||
reflect changes in the dataset sizes and specified sampling weights.
|
||||
- hf_dataset (datasets.Dataset): A Hugging Face dataset containing the new episodes to be added.
|
||||
- pc_online_samples (float): The target percentage of samples that should come from
|
||||
the online dataset during sampling operations.
|
||||
|
||||
|
@ -335,7 +335,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
|||
|
||||
online_pc_sampling = cfg.get("demo_schedule", 0.5)
|
||||
add_episodes_inplace(
|
||||
eval_info["episodes"], online_dataset, concat_dataset, sampler, online_pc_sampling
|
||||
online_dataset, concat_dataset, sampler, eval_info["episodes"], online_pc_sampling
|
||||
)
|
||||
|
||||
for _ in range(cfg.policy.utd):
|
||||
|
|
Loading…
Reference in New Issue