Apply suggestions from code review
commit 23f6c875b5 (parent f56d769dfb)
@@ -77,7 +77,7 @@ print(dataset.hf_dataset)

 # LeRobot datasets also subclasses PyTorch datasets so you can do everything you know and love from working
 # with the latter, like iterating through the dataset.
-# The __get_item__ iterates over the frames of the dataset. Since our datasets are also structured by
+# The __getitem__ iterates over the frames of the dataset. Since our datasets are also structured by
 # episodes, you can access the frame indices of any episode using the episode_data_index. Here, we access
 # frame indices associated to the first episode:
 episode_index = 0
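As an aside (not part of the diff), the episode_data_index mentioned in the comments above is typically used like this to slice one episode's frames; "lerobot/pusht" is just an example repo id:

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

dataset = LeRobotDataset("lerobot/pusht")  # example repo id
episode_index = 0
# episode_data_index maps each episode to its first ("from") and last ("to") frame index.
from_idx = dataset.episode_data_index["from"][episode_index].item()
to_idx = dataset.episode_data_index["to"][episode_index].item()
frames = [dataset[i] for i in range(from_idx, to_idx)]  # __getitem__ returns one frame dict per index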
@@ -1,7 +1,7 @@
 """
 This script demonstrates how to use torchvision's image transformation with LeRobotDataset for data
 augmentation purposes. The transformations are passed to the dataset as an argument upon creation, and
-transforms are applied to the observation images before they are returned in the dataset's __get_item__.
+transforms are applied to the observation images before they are returned in the dataset's __getitem__.
 """

 from pathlib import Path
@@ -8,7 +8,6 @@ especially in the context of imitation learning. The most reliable approach is t
 on the target environment, whether that be in simulation or the real world.
 """

-# TODO(aliberts, rcadene): Update this script with the new v2 api
 import math
 from pathlib import Path

@@ -170,25 +170,28 @@ def aggregate_stats(ls_datasets) -> dict[str, torch.Tensor]:
     """
     data_keys = set()
     for dataset in ls_datasets:
-        data_keys.update(dataset.stats.keys())
+        data_keys.update(dataset.meta.stats.keys())
     stats = {k: {} for k in data_keys}
     for data_key in data_keys:
         for stat_key in ["min", "max"]:
             # compute `max(dataset_0["max"], dataset_1["max"], ...)`
             stats[data_key][stat_key] = einops.reduce(
-                torch.stack([d.stats[data_key][stat_key] for d in ls_datasets if data_key in d.stats], dim=0),
+                torch.stack(
+                    [ds.meta.stats[data_key][stat_key] for ds in ls_datasets if data_key in ds.meta.stats],
+                    dim=0,
+                ),
                 "n ... -> ...",
                 stat_key,
             )
-        total_samples = sum(d.num_frames for d in ls_datasets if data_key in d.stats)
+        total_samples = sum(d.num_frames for d in ls_datasets if data_key in d.meta.stats)
         # Compute the "sum" statistic by multiplying each mean by the number of samples in the respective
         # dataset, then divide by total_samples to get the overall "mean".
         # NOTE: the brackets around (d.num_frames / total_samples) are needed tor minimize the risk of
         # numerical overflow!
         stats[data_key]["mean"] = sum(
-            d.stats[data_key]["mean"] * (d.num_frames / total_samples)
+            d.meta.stats[data_key]["mean"] * (d.num_frames / total_samples)
             for d in ls_datasets
-            if data_key in d.stats
+            if data_key in d.meta.stats
         )
         # The derivation for standard deviation is a little more involved but is much in the same spirit as
         # the computation of the mean.
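For intuition (not from the commit), the weighted-mean aggregation in the hunk above boils down to this, with made-up per-dataset stats and frame counts:

import torch

# Hypothetical per-dataset stats: two datasets with different sizes and means.
stats_a = {"mean": torch.tensor([1.0, 2.0])}
stats_b = {"mean": torch.tensor([3.0, 4.0])}
num_frames_a, num_frames_b = 100, 300
total = num_frames_a + num_frames_b

# Each mean is weighted by the fraction of frames its dataset contributes.
agg_mean = stats_a["mean"] * (num_frames_a / total) + stats_b["mean"] * (num_frames_b / total)
print(agg_mean)  # tensor([2.5000, 3.5000])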
@@ -199,102 +202,13 @@ def aggregate_stats(ls_datasets) -> dict[str, torch.Tensor]:
         # numerical overflow!
         stats[data_key]["std"] = torch.sqrt(
             sum(
-                (d.stats[data_key]["std"] ** 2 + (d.stats[data_key]["mean"] - stats[data_key]["mean"]) ** 2)
+                (
+                    d.meta.stats[data_key]["std"] ** 2
+                    + (d.meta.stats[data_key]["mean"] - stats[data_key]["mean"]) ** 2
+                )
                 * (d.num_frames / total_samples)
                 for d in ls_datasets
-                if data_key in d.stats
+                if data_key in d.meta.stats
             )
         )
     return stats
-
-
-# TODO(aliberts): refactor stats in save_episodes
-# import numpy as np
-# from lerobot.common.datasets.utils import load_image_as_numpy
-# def aggregate_stats_v2(stats_list: list) -> dict:
-#     """Aggregate stats from multiple compute_stats outputs into a single set of stats.
-
-#     The final stats will have the union of all data keys from each of the stats dicts.
-
-#     For instance:
-#     - new_min = min(min_dataset_0, min_dataset_1, ...)
-#     - new_max = max(max_dataset_0, max_dataset_1, ...)
-#     - new_mean = (mean of all data, weighted by counts)
-#     - new_std = (std of all data)
-#     """
-#     data_keys = set(key for stats in stats_list for key in stats.keys())
-#     aggregated_stats = {key: {} for key in data_keys}
-
-#     for key in data_keys:
-#         # Collect stats for the current key from all datasets where it exists
-#         stats_with_key = [stats[key] for stats in stats_list if key in stats]
-
-#         # Aggregate 'min' and 'max' using np.minimum and np.maximum
-#         aggregated_stats[key]['min'] = np.minimum.reduce([s['min'] for s in stats_with_key])
-#         aggregated_stats[key]['max'] = np.maximum.reduce([s['max'] for s in stats_with_key])
-
-#         # Extract means, variances (std^2), and counts
-#         means = np.array([s['mean'] for s in stats_with_key])
-#         variances = np.array([s['std']**2 for s in stats_with_key])
-#         counts = np.array([s['count'] for s in stats_with_key])
-
-#         # Ensure counts can broadcast with means/variances if they have additional dimensions
-#         counts = counts.reshape(-1, *[1]*(means.ndim - 1))
-
-#         # Compute total counts
-#         total_count = counts.sum(axis=0)
-
-#         # Compute the weighted mean
-#         weighted_means = means * counts
-#         total_mean = weighted_means.sum(axis=0) / total_count
-
-#         # Compute the variance using the parallel algorithm
-#         delta_means = means - total_mean
-#         weighted_variances = (variances + delta_means**2) * counts
-#         total_variance = weighted_variances.sum(axis=0) / total_count
-
-#         # Store the aggregated stats
-#         aggregated_stats[key]['mean'] = total_mean
-#         aggregated_stats[key]['std'] = np.sqrt(total_variance)
-#         aggregated_stats[key]['count'] = total_count
-
-#     return aggregated_stats
-
-
-# def compute_episode_stats(episode_buffer: dict, features: dict, episode_length: int, image_sampling: int = 10) -> dict:
-#     stats = {}
-#     for key, data in episode_buffer.items():
-#         if features[key]["dtype"] in ["image", "video"]:
-#             stats[key] = compute_image_stats(data, sampling=image_sampling)
-#         else:
-#             axes_to_reduce = 0  # Compute stats over the first axis
-#             stats[key] = {
-#                 "min": np.min(data, axis=axes_to_reduce),
-#                 "max": np.max(data, axis=axes_to_reduce),
-#                 "mean": np.mean(data, axis=axes_to_reduce),
-#                 "std": np.std(data, axis=axes_to_reduce),
-#                 "count": episode_length,
-#             }
-#     return stats
-
-
-# def compute_image_stats(image_paths: list[str], sampling: int = 10) -> dict:
-#     images = []
-#     samples = range(0, len(image_paths), sampling)
-#     for idx in samples:
-#         path = image_paths[idx]
-#         img = load_image_as_numpy(path, channel_first=True)
-#         images.append(img)
-
-#     images = np.stack(images)
-#     axes_to_reduce = (0, 2, 3)  # keep channel dim
-#     image_stats = {
-#         "min": np.min(images, axis=axes_to_reduce, keepdims=True),
-#         "max": np.max(images, axis=axes_to_reduce, keepdims=True),
-#         "mean": np.mean(images, axis=axes_to_reduce, keepdims=True),
-#         "std": np.std(images, axis=axes_to_reduce, keepdims=True)
-#     }
-#     for key in image_stats:  # squeeze batch dim
-#         image_stats[key] = np.squeeze(image_stats[key], axis=0)
-
-#     return image_stats
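Likewise for the standard deviation: the expression above is the pooled population variance, i.e. each dataset contributes its own variance plus the squared offset of its mean from the aggregate mean, weighted by its share of frames. A small self-contained sanity check (not from the commit) with synthetic data:

import torch

a = torch.randn(100)
b = torch.randn(300) + 2.0
n_a, n_b = a.numel(), b.numel()
total = n_a + n_b

pooled_mean = a.mean() * (n_a / total) + b.mean() * (n_b / total)
# Law of total variance: weighted within-dataset variance plus weighted squared mean offsets.
pooled_var = (a.var(unbiased=False) + (a.mean() - pooled_mean) ** 2) * (n_a / total) + (
    b.var(unbiased=False) + (b.mean() - pooled_mean) ** 2
) * (n_b / total)
pooled_std = pooled_var.sqrt()

# Matches the std computed directly over the concatenated data.
assert torch.allclose(pooled_std, torch.cat([a, b]).std(unbiased=False), atol=1e-5)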
@@ -63,7 +63,7 @@ def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path):
         print(f"Error writing image {fpath}: {e}")


-def worker_thread_process(queue: queue.Queue):
+def worker_thread_loop(queue: queue.Queue):
     while True:
         item = queue.get()
         if item is None:
@@ -77,7 +77,7 @@ def worker_thread_process(queue: queue.Queue):
 def worker_process(queue: queue.Queue, num_threads: int):
     threads = []
     for _ in range(num_threads):
-        t = threading.Thread(target=worker_thread_process, args=(queue,))
+        t = threading.Thread(target=worker_thread_loop, args=(queue,))
         t.daemon = True
         t.start()
         threads.append(t)
@@ -115,7 +115,7 @@ class AsyncImageWriter:
             # Use threading
             self.queue = queue.Queue()
             for _ in range(self.num_threads):
-                t = threading.Thread(target=worker_thread_process, args=(self.queue,))
+                t = threading.Thread(target=worker_thread_loop, args=(self.queue,))
                 t.daemon = True
                 t.start()
                 self.threads.append(t)
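The renamed worker_thread_loop above is the usual queue-plus-sentinel worker pattern; here is a standalone sketch of that pattern (independent of lerobot's actual image-writing code):

import queue
import threading

def worker_loop(q: queue.Queue) -> None:
    while True:
        item = q.get()
        if item is None:  # sentinel: tells the worker to shut down
            q.task_done()
            break
        try:
            print("processing", item)  # stand-in for the real work (e.g. writing an image to disk)
        finally:
            q.task_done()

q = queue.Queue()
t = threading.Thread(target=worker_loop, args=(q,), daemon=True)
t.start()
for i in range(3):
    q.put(i)
q.put(None)  # ask the worker to stop
t.join()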
@@ -427,12 +427,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
         """
         super().__init__()
         self.repo_id = repo_id
-        self.root = Path(root) if root is not None else LEROBOT_HOME / repo_id
+        self.root = Path(root) if root else LEROBOT_HOME / repo_id
         self.image_transforms = image_transforms
         self.delta_timestamps = delta_timestamps
         self.episodes = episodes
         self.tolerance_s = tolerance_s
-        self.video_backend = video_backend if video_backend is not None else "pyav"
+        self.video_backend = video_backend if video_backend else "pyav"
         self.delta_indices = None
         self.local_files_only = local_files_only

@@ -473,10 +473,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
         **card_kwargs,
     ) -> None:
         if not self.consolidated:
-            raise RuntimeError(
-                "You are trying to upload to the hub a LeRobotDataset that has not been consolidated yet."
-                "Please call the dataset 'consolidate()' method first."
+            logging.warning(
+                "You are trying to upload to the hub a LeRobotDataset that has not been consolidated yet. "
+                "Consolidating first."
             )
+            self.consolidate()

         ignore_patterns = ["images/"]
         if not push_videos:
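Caller-side, the change above means push_to_hub() no longer raises on an unconsolidated dataset; it warns and consolidates itself. A sketch of the assumed before/after usage:

# Before this change (assumed usage):
#     dataset.consolidate()
#     dataset.push_to_hub()
# After: one call is enough; a skipped consolidation is logged and performed automatically.
dataset.push_to_hub()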
@@ -750,7 +751,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
         episode_index = episode_buffer["episode_index"]
         if episode_index != self.meta.total_episodes:
             # TODO(aliberts): Add option to use existing episode_index
-            raise NotImplementedError()
+            raise NotImplementedError(
+                "You might have manually provided the episode_buffer with an episode_index that doesn't "
+                "match the total number of episodes in the dataset. This is not supported for now."
+            )

         if episode_length == 0:
             raise ValueError(
@@ -818,7 +822,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
     def start_image_writer(self, num_processes: int = 0, num_threads: int = 4) -> None:
         if isinstance(self.image_writer, AsyncImageWriter):
             logging.warning(
-                "You are starting a new AsyncImageWriter that is replacing an already exising one in the dataset."
+                "You are starting a new AsyncImageWriter that is replacing an already existing one in the dataset."
             )

         self.image_writer = AsyncImageWriter(
@@ -965,56 +969,56 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
     def __init__(
         self,
         repo_ids: list[str],
-        root: Path | None = None,
+        root: str | Path | None = None,
         episodes: dict | None = None,
         image_transforms: Callable | None = None,
         delta_timestamps: dict[list[float]] | None = None,
+        tolerances_s: dict | None = None,
+        download_videos: bool = True,
+        local_files_only: bool = False,
         video_backend: str | None = None,
     ):
         super().__init__()
         self.repo_ids = repo_ids
+        self.root = Path(root) if root else LEROBOT_HOME
+        self.tolerances_s = tolerances_s if tolerances_s else {repo_id: 1e-4 for repo_id in repo_ids}
         # Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which
         # are handled by this class.
         self._datasets = [
             LeRobotDataset(
                 repo_id,
-                root=root / repo_id if root is not None else None,
-                episodes=episodes[repo_id] if episodes is not None else None,
-                delta_timestamps=delta_timestamps,
+                root=self.root / repo_id,
+                episodes=episodes[repo_id] if episodes else None,
                 image_transforms=image_transforms,
+                delta_timestamps=delta_timestamps,
+                tolerance_s=self.tolerances_s[repo_id],
+                download_videos=download_videos,
+                local_files_only=local_files_only,
                 video_backend=video_backend,
             )
             for repo_id in repo_ids
         ]
-        # Check that some properties are consistent across datasets. Note: We may relax some of these
-        # consistency requirements in future iterations of this class.
-        for repo_id, dataset in zip(self.repo_ids, self._datasets, strict=True):
-            if dataset.meta.info != self._datasets[0].meta.info:
-                raise ValueError(
-                    f"Detected a mismatch in dataset info between {self.repo_ids[0]} and {repo_id}. This is "
-                    "not yet supported."
-                )
         # Disable any data keys that are not common across all of the datasets. Note: we may relax this
         # restriction in future iterations of this class. For now, this is necessary at least for being able
         # to use PyTorch's default DataLoader collate function.
-        self.disabled_data_keys = set()
-        intersection_data_keys = set(self._datasets[0].hf_dataset.features)
-        for dataset in self._datasets:
-            intersection_data_keys.intersection_update(dataset.hf_dataset.features)
-        if len(intersection_data_keys) == 0:
+        self.disabled_features = set()
+        intersection_features = set(self._datasets[0].features)
+        for ds in self._datasets:
+            intersection_features.intersection_update(ds.features)
+        if len(intersection_features) == 0:
             raise RuntimeError(
-                "Multiple datasets were provided but they had no keys common to all of them. The "
-                "multi-dataset functionality currently only keeps common keys."
+                "Multiple datasets were provided but they had no keys common to all of them. "
+                "The multi-dataset functionality currently only keeps common keys."
             )
-        for repo_id, dataset in zip(self.repo_ids, self._datasets, strict=True):
-            extra_keys = set(dataset.hf_dataset.features).difference(intersection_data_keys)
+        for repo_id, ds in zip(self.repo_ids, self._datasets, strict=True):
+            extra_keys = set(ds.features).difference(intersection_features)
             logging.warning(
                 f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
                 "other datasets."
             )
-            self.disabled_data_keys.update(extra_keys)
+            self.disabled_features.update(extra_keys)

-        self.root = root
         self.image_transforms = image_transforms
         self.delta_timestamps = delta_timestamps
         self.stats = aggregate_stats(self._datasets)
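A hedged usage sketch of the extended MultiLeRobotDataset signature shown above; the repo ids are placeholders and the keyword values mirror the new defaults:

from lerobot.common.datasets.lerobot_dataset import MultiLeRobotDataset

repo_ids = ["lerobot/aloha_static_coffee", "lerobot/pusht"]  # placeholder repo ids
dataset = MultiLeRobotDataset(
    repo_ids,
    # One timestamp tolerance (in seconds) per repo; defaults to 1e-4 for every repo when omitted.
    tolerances_s={repo_id: 1e-4 for repo_id in repo_ids},
    download_videos=True,
    local_files_only=False,
)
item = dataset[0]
print(item["dataset_index"])  # tells you which underlying dataset the frame came from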
@@ -1054,9 +1058,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
     def features(self) -> datasets.Features:
         features = {}
         for dataset in self._datasets:
-            features.update(
-                {k: v for k, v in dataset.hf_features.items() if k not in self.disabled_data_keys}
-            )
+            features.update({k: v for k, v in dataset.hf_features.items() if k not in self.disabled_features})
         return features

     @property
@@ -1120,7 +1122,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
             raise AssertionError("We expect the loop to break out as long as the index is within bounds.")
         item = self._datasets[dataset_idx][idx - start_idx]
         item["dataset_index"] = torch.tensor(dataset_idx)
-        for data_key in self.disabled_data_keys:
+        for data_key in self.disabled_features:
             if data_key in item:
                 del item[data_key]

@@ -14,8 +14,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import json
+import logging
 import textwrap
-import warnings
 from itertools import accumulate
 from pathlib import Path
 from pprint import pformat
@@ -212,8 +212,8 @@ class BackwardCompatibilityError(Exception):
            "Open the top cabinet, store the pot inside it then close the cabinet.", "Push the T-shaped block onto the T-shaped target.",
            "Grab the spray paint on the shelf and place it in the bin on top of the robot dog.", "Fold the sweatshirt.", ...

-           If you encounter a problem, contact LeRobot maintainers on Discord ('https://discord.com/invite/s3KuuzsPFb')
-           or open an issue on GitHub.
+           If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb)
+           or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
            """)
        super().__init__(message)

@@ -226,12 +226,11 @@ def check_version_compatibility(
     if major_to_check < current_major and enforce_breaking_major:
         raise BackwardCompatibilityError(repo_id, version_to_check)
     elif float(version_to_check.strip("v")) < float(current_version.strip("v")):
-        warnings.warn(
+        logging.warning(
             f"""The dataset you requested ({repo_id}) was created with a previous version ({version_to_check}) of the
             codebase. The current codebase version is {current_version}. You should be fine since
             backward compatibility is maintained. If you encounter a problem, contact LeRobot maintainers on
             Discord ('https://discord.com/invite/s3KuuzsPFb') or open an issue on github.""",
-            stacklevel=1,
         )


@@ -245,13 +244,12 @@ def get_hub_safe_version(repo_id: str, version: str) -> str:
         if num_version >= 2.0 and all(v < 2.0 for v in hub_num_versions):
             raise BackwardCompatibilityError(repo_id, version)

-        warnings.warn(
+        logging.warning(
             f"""You are trying to load a dataset from {repo_id} created with a previous version of the
             codebase. The following versions are available: {branches}.
             The requested version ('{version}') is not found. You should be fine since
             backward compatibility is maintained. If you encounter a problem, contact LeRobot maintainers on
             Discord ('https://discord.com/invite/s3KuuzsPFb') or open an issue on github.""",
-            stacklevel=1,
         )
         if "main" not in branches:
             raise ValueError(f"Version 'main' not found on {repo_id}")
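On the warnings.warn to logging.warning swap in the last two hunks: the message now goes through the logging handlers on every call instead of the warnings filter (which only shows a warning once per call site under the default filter), and stacklevel=1 was already the warnings.warn default, so dropping that argument changes nothing. A minimal illustration (not from the commit):

import logging
import warnings

logging.basicConfig(level=logging.WARNING)

for _ in range(3):
    warnings.warn("warnings.warn: shown once per call site under the default filter")
    logging.warning("logging.warning: shown on every call, routed through logging handlers")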
@@ -15,6 +15,8 @@
 # limitations under the License.

 """
+This script is for internal use to convert all datasets under the 'lerobot' hub user account to v2.
+
 Note: Since the original Aloha datasets don't use shadow motors, you need to comment those out in
 lerobot/configs/robot/aloha.yaml before running this script.
 """
@@ -103,11 +103,11 @@ import argparse
 import contextlib
 import filecmp
 import json
+import logging
 import math
 import shutil
 import subprocess
 import tempfile
-import warnings
 from pathlib import Path

 import datasets
@@ -461,9 +461,8 @@ def convert_dataset(
     video_keys = [key for key, ft in features.items() if ft["dtype"] == "video"]

     if single_task and "language_instruction" in dataset.column_names:
-        warnings.warn(
+        logging.warning(
             "'single_task' provided but 'language_instruction' tasks_col found. Using 'language_instruction'.",
-            stacklevel=1,
         )
         single_task = None
         tasks_col = "language_instruction"
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--license",
|
"--license",
|
||||||
type=str,
|
type=str,
|
||||||
default="mit",
|
default="apache-2.0",
|
||||||
help="Repo license. Must be one of https://huggingface.co/docs/hub/repositories-licenses. Defaults to mit.",
|
help="Repo license. Must be one of https://huggingface.co/docs/hub/repositories-licenses. Defaults to mit.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
|
|
@@ -17,7 +17,7 @@ from lerobot.common.datasets.utils import (
     get_hf_features_from_features,
     hf_transform_to_torch,
 )
-from tests.fixtures.defaults import (
+from tests.fixtures.constants import (
     DEFAULT_FPS,
     DUMMY_CAMERA_FEATURES,
     DUMMY_MOTOR_FEATURES,
@@ -5,7 +5,7 @@ import pytest
 from huggingface_hub.utils import filter_repo_objects

 from lerobot.common.datasets.utils import EPISODES_PATH, INFO_PATH, STATS_PATH, TASKS_PATH
-from tests.fixtures.defaults import LEROBOT_TEST_DIR
+from tests.fixtures.constants import LEROBOT_TEST_DIR


 @pytest.fixture(scope="session")
@@ -44,7 +44,7 @@ from lerobot.common.datasets.utils import (
     unflatten_dict,
 )
 from lerobot.common.utils.utils import init_hydra_config, seeded_context
-from tests.fixtures.defaults import DUMMY_REPO_ID
+from tests.fixtures.constants import DUMMY_REPO_ID
 from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, make_robot


@@ -9,7 +9,7 @@ from lerobot.common.datasets.utils import (
     get_delta_indices,
     hf_transform_to_torch,
 )
-from tests.fixtures.defaults import DUMMY_MOTOR_FEATURES
+from tests.fixtures.constants import DUMMY_MOTOR_FEATURES


 @pytest.fixture(scope="module")
@@ -21,7 +21,7 @@ from pathlib import Path

 import pytest

-from tests.fixtures.defaults import DUMMY_REPO_ID
+from tests.fixtures.constants import DUMMY_REPO_ID
 from tests.utils import require_package

