use datasets.Dataset type

This commit is contained in:
Cadene 2024-04-18 09:21:29 +00:00
parent 9e0e0ab5cc
commit ae2a643cf2
2 changed files with 6 additions and 4 deletions

View File

@@ -1,6 +1,7 @@
from copy import deepcopy from copy import deepcopy
from math import ceil from math import ceil
import datasets
import einops import einops
import torch import torch
import tqdm import tqdm
@@ -8,7 +9,7 @@ import tqdm
def load_previous_and_future_frames( def load_previous_and_future_frames(
item: dict[str, torch.Tensor], item: dict[str, torch.Tensor],
hf_dataset: dict[str, torch.Tensor], hf_dataset: dict[str, datasets.Dataset],
delta_timestamps: dict[str, list[float]], delta_timestamps: dict[str, list[float]],
tol: float = 0.04, tol: float = 0.04,
) -> dict[torch.Tensor]: ) -> dict[torch.Tensor]:

View File

@@ -2,9 +2,10 @@ import logging
from copy import deepcopy from copy import deepcopy
from pathlib import Path from pathlib import Path
import datasets
import hydra import hydra
import torch import torch
from datasets import Dataset, concatenate_datasets from datasets import concatenate_datasets
from datasets.utils import disable_progress_bars, enable_progress_bars from datasets.utils import disable_progress_bars, enable_progress_bars
from lerobot.common.datasets.factory import make_dataset from lerobot.common.datasets.factory import make_dataset
@@ -131,7 +132,7 @@ def calculate_online_sample_weight(n_off: int, n_on: int, pc_on: float):
def add_episodes_inplace( def add_episodes_inplace(
hf_dataset: Dataset, hf_dataset: datasets.Dataset,
online_dataset: torch.utils.data.Dataset, online_dataset: torch.utils.data.Dataset,
concat_dataset: torch.utils.data.ConcatDataset, concat_dataset: torch.utils.data.ConcatDataset,
sampler: torch.utils.data.WeightedRandomSampler, sampler: torch.utils.data.WeightedRandomSampler,
@@ -144,7 +145,7 @@ def add_episodes_inplace(
percentage of online samples. percentage of online samples.
Parameters: Parameters:
- hf_dataset (Dataset): A Hugging Face dataset containing the new episodes to be added. - hf_dataset (datasets.Dataset): A Hugging Face dataset containing the new episodes to be added.
- online_dataset (torch.utils.data.Dataset): The existing online dataset to be updated. - online_dataset (torch.utils.data.Dataset): The existing online dataset to be updated.
- concat_dataset (torch.utils.data.ConcatDataset): The concatenated dataset that combines - concat_dataset (torch.utils.data.ConcatDataset): The concatenated dataset that combines
offline and online datasets, used for sampling purposes. offline and online datasets, used for sampling purposes.