use datasets.Dataset type
parent 9e0e0ab5cc
commit ae2a643cf2

@@ -1,6 +1,7 @@
 from copy import deepcopy
 from math import ceil
 
+import datasets
 import einops
 import torch
 import tqdm
@@ -8,7 +9,7 @@ import tqdm
 
 def load_previous_and_future_frames(
     item: dict[str, torch.Tensor],
-    hf_dataset: dict[str, torch.Tensor],
+    hf_dataset: dict[str, datasets.Dataset],
     delta_timestamps: dict[str, list[float]],
     tol: float = 0.04,
 ) -> dict[torch.Tensor]:
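
For context, a minimal sketch (not part of this commit) of the kind of object the new `datasets.Dataset` annotation refers to: a Hugging Face dataset whose columns come back as torch tensors once the torch format is set. The column names below are hypothetical.

# Illustration only, not from the diff; column names are made up.
import datasets
import torch

hf_dataset = datasets.Dataset.from_dict(
    {
        "timestamp": [0.0, 0.04, 0.08],
        "observation.state": [[0.0, 0.1], [0.2, 0.3], [0.4, 0.5]],
    }
).with_format("torch")

row = hf_dataset[0]
assert isinstance(hf_dataset, datasets.Dataset)
assert isinstance(row["timestamp"], torch.Tensor)  # torch format yields tensors
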
@@ -2,9 +2,10 @@ import logging
 from copy import deepcopy
 from pathlib import Path
 
+import datasets
 import hydra
 import torch
-from datasets import Dataset, concatenate_datasets
+from datasets import concatenate_datasets
 from datasets.utils import disable_progress_bars, enable_progress_bars
 
 from lerobot.common.datasets.factory import make_dataset
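
Note (a property of the `datasets` library rather than something stated in the diff): the name removed from the import line is the same class as the module-qualified one, so annotations spelled `datasets.Dataset` are equivalent to the old `Dataset` ones.

# Sanity check of the assumption above.
import datasets
from datasets import Dataset

assert Dataset is datasets.Dataset
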
@@ -131,7 +132,7 @@ def calculate_online_sample_weight(n_off: int, n_on: int, pc_on: float):
 
 
 def add_episodes_inplace(
-    hf_dataset: Dataset,
+    hf_dataset: datasets.Dataset,
     online_dataset: torch.utils.data.Dataset,
     concat_dataset: torch.utils.data.ConcatDataset,
     sampler: torch.utils.data.WeightedRandomSampler,
@@ -144,7 +145,7 @@ def add_episodes_inplace(
     percentage of online samples.
 
     Parameters:
-    - hf_dataset (Dataset): A Hugging Face dataset containing the new episodes to be added.
+    - hf_dataset (datasets.Dataset): A Hugging Face dataset containing the new episodes to be added.
     - online_dataset (torch.utils.data.Dataset): The existing online dataset to be updated.
     - concat_dataset (torch.utils.data.ConcatDataset): The concatenated dataset that combines
       offline and online datasets, used for sampling purposes.
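
As a rough illustration of the parameter types documented above, here is a sketch (not code from the repository; shapes, column names, and weights are invented) of objects matching the annotations of `add_episodes_inplace`:

# Hypothetical setup mirroring the annotated parameter types.
import datasets
import torch
import torch.utils.data

hf_dataset: datasets.Dataset = datasets.Dataset.from_dict(
    {"episode_index": [0, 0, 1], "timestamp": [0.0, 0.04, 0.0]}
)

offline_dataset = torch.utils.data.TensorDataset(torch.zeros(8, 3))
online_dataset = torch.utils.data.TensorDataset(torch.zeros(4, 3))
concat_dataset = torch.utils.data.ConcatDataset([offline_dataset, online_dataset])

# One weight per sample in the concatenated dataset; in the script these
# presumably come from the offline/online weighting logic referenced above.
weights = torch.ones(len(concat_dataset), dtype=torch.double)
sampler = torch.utils.data.WeightedRandomSampler(
    weights, num_samples=len(weights), replacement=True
)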