diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index bd9ddeac..b1d63067 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -132,10 +132,10 @@ def calculate_online_sample_weight(n_off: int, n_on: int, pc_on: float): def add_episodes_inplace( - hf_dataset: datasets.Dataset, online_dataset: torch.utils.data.Dataset, concat_dataset: torch.utils.data.ConcatDataset, sampler: torch.utils.data.WeightedRandomSampler, + hf_dataset: datasets.Dataset, pc_online_samples: float, ): """ @@ -145,12 +145,12 @@ def add_episodes_inplace( percentage of online samples. Parameters: - - hf_dataset (datasets.Dataset): A Hugging Face dataset containing the new episodes to be added. - online_dataset (torch.utils.data.Dataset): The existing online dataset to be updated. - concat_dataset (torch.utils.data.ConcatDataset): The concatenated dataset that combines offline and online datasets, used for sampling purposes. - sampler (torch.utils.data.WeightedRandomSampler): A sampler that will be updated to reflect changes in the dataset sizes and specified sampling weights. + - hf_dataset (datasets.Dataset): A Hugging Face dataset containing the new episodes to be added. - pc_online_samples (float): The target percentage of samples that should come from the online dataset during sampling operations. @@ -335,7 +335,7 @@ def train(cfg: dict, out_dir=None, job_name=None): online_pc_sampling = cfg.get("demo_schedule", 0.5) add_episodes_inplace( - eval_info["episodes"], online_dataset, concat_dataset, sampler, online_pc_sampling + online_dataset, concat_dataset, sampler, eval_info["episodes"], online_pc_sampling ) for _ in range(cfg.policy.utd):