use datasets.Dataset type

This commit is contained in:
Cadene 2024-04-18 09:21:29 +00:00
parent 9e0e0ab5cc
commit ae2a643cf2
2 changed files with 6 additions and 4 deletions

View File

@ -1,6 +1,7 @@
from copy import deepcopy
from math import ceil
import datasets
import einops
import torch
import tqdm
@ -8,7 +9,7 @@ import tqdm
def load_previous_and_future_frames(
item: dict[str, torch.Tensor],
hf_dataset: dict[str, torch.Tensor],
hf_dataset: dict[str, datasets.Dataset],
delta_timestamps: dict[str, list[float]],
tol: float = 0.04,
) -> dict[torch.Tensor]:

View File

@ -2,9 +2,10 @@ import logging
from copy import deepcopy
from pathlib import Path
import datasets
import hydra
import torch
from datasets import Dataset, concatenate_datasets
from datasets import concatenate_datasets
from datasets.utils import disable_progress_bars, enable_progress_bars
from lerobot.common.datasets.factory import make_dataset
@ -131,7 +132,7 @@ def calculate_online_sample_weight(n_off: int, n_on: int, pc_on: float):
def add_episodes_inplace(
hf_dataset: Dataset,
hf_dataset: datasets.Dataset,
online_dataset: torch.utils.data.Dataset,
concat_dataset: torch.utils.data.ConcatDataset,
sampler: torch.utils.data.WeightedRandomSampler,
@ -144,7 +145,7 @@ def add_episodes_inplace(
percentage of online samples.
Parameters:
- hf_dataset (Dataset): A Hugging Face dataset containing the new episodes to be added.
- hf_dataset (datasets.Dataset): A Hugging Face dataset containing the new episodes to be added.
- online_dataset (torch.utils.data.Dataset): The existing online dataset to be updated.
- concat_dataset (torch.utils.data.ConcatDataset): The concatenated dataset that combines
offline and online datasets, used for sampling purposes.