move hf_datasets at the end of add_episodes_inplace
This commit is contained in:
parent
a7dffc8359
commit
77088ffd66
|
@ -132,10 +132,10 @@ def calculate_online_sample_weight(n_off: int, n_on: int, pc_on: float):
|
||||||
|
|
||||||
|
|
||||||
def add_episodes_inplace(
|
def add_episodes_inplace(
|
||||||
hf_dataset: datasets.Dataset,
|
|
||||||
online_dataset: torch.utils.data.Dataset,
|
online_dataset: torch.utils.data.Dataset,
|
||||||
concat_dataset: torch.utils.data.ConcatDataset,
|
concat_dataset: torch.utils.data.ConcatDataset,
|
||||||
sampler: torch.utils.data.WeightedRandomSampler,
|
sampler: torch.utils.data.WeightedRandomSampler,
|
||||||
|
hf_dataset: datasets.Dataset,
|
||||||
pc_online_samples: float,
|
pc_online_samples: float,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
|
@ -145,12 +145,12 @@ def add_episodes_inplace(
|
||||||
percentage of online samples.
|
percentage of online samples.
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
- hf_dataset (datasets.Dataset): A Hugging Face dataset containing the new episodes to be added.
|
|
||||||
- online_dataset (torch.utils.data.Dataset): The existing online dataset to be updated.
|
- online_dataset (torch.utils.data.Dataset): The existing online dataset to be updated.
|
||||||
- concat_dataset (torch.utils.data.ConcatDataset): The concatenated dataset that combines
|
- concat_dataset (torch.utils.data.ConcatDataset): The concatenated dataset that combines
|
||||||
offline and online datasets, used for sampling purposes.
|
offline and online datasets, used for sampling purposes.
|
||||||
- sampler (torch.utils.data.WeightedRandomSampler): A sampler that will be updated to
|
- sampler (torch.utils.data.WeightedRandomSampler): A sampler that will be updated to
|
||||||
reflect changes in the dataset sizes and specified sampling weights.
|
reflect changes in the dataset sizes and specified sampling weights.
|
||||||
|
- hf_dataset (datasets.Dataset): A Hugging Face dataset containing the new episodes to be added.
|
||||||
- pc_online_samples (float): The target percentage of samples that should come from
|
- pc_online_samples (float): The target percentage of samples that should come from
|
||||||
the online dataset during sampling operations.
|
the online dataset during sampling operations.
|
||||||
|
|
||||||
|
@ -335,7 +335,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||||
|
|
||||||
online_pc_sampling = cfg.get("demo_schedule", 0.5)
|
online_pc_sampling = cfg.get("demo_schedule", 0.5)
|
||||||
add_episodes_inplace(
|
add_episodes_inplace(
|
||||||
eval_info["episodes"], online_dataset, concat_dataset, sampler, online_pc_sampling
|
online_dataset, concat_dataset, sampler, eval_info["episodes"], online_pc_sampling
|
||||||
)
|
)
|
||||||
|
|
||||||
for _ in range(cfg.policy.utd):
|
for _ in range(cfg.policy.utd):
|
||||||
|
|
Loading…
Reference in New Issue