From ae2a643cf249a031833baf39c376833b64915a9a Mon Sep 17 00:00:00 2001 From: Cadene Date: Thu, 18 Apr 2024 09:21:29 +0000 Subject: [PATCH] use datasets.Dataset type --- lerobot/common/datasets/utils.py | 3 ++- lerobot/scripts/train.py | 7 ++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py index aad74375..6b1a5b19 100644 --- a/lerobot/common/datasets/utils.py +++ b/lerobot/common/datasets/utils.py @@ -1,6 +1,7 @@ from copy import deepcopy from math import ceil +import datasets import einops import torch import tqdm @@ -8,7 +9,7 @@ import tqdm def load_previous_and_future_frames( item: dict[str, torch.Tensor], - hf_dataset: dict[str, torch.Tensor], + hf_dataset: dict[str, datasets.Dataset], delta_timestamps: dict[str, list[float]], tol: float = 0.04, ) -> dict[torch.Tensor]: diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 6cd02f18..bd9ddeac 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -2,9 +2,10 @@ import logging from copy import deepcopy from pathlib import Path +import datasets import hydra import torch -from datasets import Dataset, concatenate_datasets +from datasets import concatenate_datasets from datasets.utils import disable_progress_bars, enable_progress_bars from lerobot.common.datasets.factory import make_dataset @@ -131,7 +132,7 @@ def calculate_online_sample_weight(n_off: int, n_on: int, pc_on: float): def add_episodes_inplace( - hf_dataset: Dataset, + hf_dataset: datasets.Dataset, online_dataset: torch.utils.data.Dataset, concat_dataset: torch.utils.data.ConcatDataset, sampler: torch.utils.data.WeightedRandomSampler, @@ -144,7 +145,7 @@ def add_episodes_inplace( percentage of online samples. Parameters: - - hf_dataset (Dataset): A Hugging Face dataset containing the new episodes to be added. + - hf_dataset (datasets.Dataset): A Hugging Face dataset containing the new episodes to be added. - online_dataset (torch.utils.data.Dataset): The existing online dataset to be updated. - concat_dataset (torch.utils.data.ConcatDataset): The concatenated dataset that combines offline and online datasets, used for sampling purposes.