Add extra info to dataset card, various fixes from Remi's review
This commit is contained in: parent 4d15861872, commit a91b7c6163
@@ -206,3 +206,95 @@ def aggregate_stats(ls_datasets) -> dict[str, torch.Tensor]:
             )
         )
     return stats
+
+
+# TODO(aliberts): refactor stats in save_episodes
+# import numpy as np
+# from lerobot.common.datasets.utils import load_image_as_numpy
+# def aggregate_stats_v2(stats_list: list) -> dict:
+#     """Aggregate stats from multiple compute_stats outputs into a single set of stats.
+
+#     The final stats will have the union of all data keys from each of the stats dicts.
+
+#     For instance:
+#     - new_min = min(min_dataset_0, min_dataset_1, ...)
+#     - new_max = max(max_dataset_0, max_dataset_1, ...)
+#     - new_mean = (mean of all data, weighted by counts)
+#     - new_std = (std of all data)
+#     """
+#     data_keys = set(key for stats in stats_list for key in stats.keys())
+#     aggregated_stats = {key: {} for key in data_keys}
+
+#     for key in data_keys:
+#         # Collect stats for the current key from all datasets where it exists
+#         stats_with_key = [stats[key] for stats in stats_list if key in stats]
+
+#         # Aggregate 'min' and 'max' using np.minimum and np.maximum
+#         aggregated_stats[key]['min'] = np.minimum.reduce([s['min'] for s in stats_with_key])
+#         aggregated_stats[key]['max'] = np.maximum.reduce([s['max'] for s in stats_with_key])
+
+#         # Extract means, variances (std^2), and counts
+#         means = np.array([s['mean'] for s in stats_with_key])
+#         variances = np.array([s['std']**2 for s in stats_with_key])
+#         counts = np.array([s['count'] for s in stats_with_key])
+
+#         # Ensure counts can broadcast with means/variances if they have additional dimensions
+#         counts = counts.reshape(-1, *[1]*(means.ndim - 1))
+
+#         # Compute total counts
+#         total_count = counts.sum(axis=0)
+
+#         # Compute the weighted mean
+#         weighted_means = means * counts
+#         total_mean = weighted_means.sum(axis=0) / total_count
+
+#         # Compute the variance using the parallel algorithm
+#         delta_means = means - total_mean
+#         weighted_variances = (variances + delta_means**2) * counts
+#         total_variance = weighted_variances.sum(axis=0) / total_count
+
+#         # Store the aggregated stats
+#         aggregated_stats[key]['mean'] = total_mean
+#         aggregated_stats[key]['std'] = np.sqrt(total_variance)
+#         aggregated_stats[key]['count'] = total_count
+
+#     return aggregated_stats
+
+
+# def compute_episode_stats(episode_buffer: dict, features: dict, episode_length: int, image_sampling: int = 10) -> dict:
+#     stats = {}
+#     for key, data in episode_buffer.items():
+#         if features[key]["dtype"] in ["image", "video"]:
+#             stats[key] = compute_image_stats(data, sampling=image_sampling)
+#         else:
+#             axes_to_reduce = 0  # Compute stats over the first axis
+#             stats[key] = {
+#                 "min": np.min(data, axis=axes_to_reduce),
+#                 "max": np.max(data, axis=axes_to_reduce),
+#                 "mean": np.mean(data, axis=axes_to_reduce),
+#                 "std": np.std(data, axis=axes_to_reduce),
+#                 "count": episode_length,
+#             }
+#     return stats
+
+
+# def compute_image_stats(image_paths: list[str], sampling: int = 10) -> dict:
+#     images = []
+#     samples = range(0, len(image_paths), sampling)
+#     for idx in samples:
+#         path = image_paths[idx]
+#         img = load_image_as_numpy(path, channel_first=True)
+#         images.append(img)
+
+#     images = np.stack(images)
+#     axes_to_reduce = (0, 2, 3)  # keep channel dim
+#     image_stats = {
+#         "min": np.min(images, axis=axes_to_reduce, keepdims=True),
+#         "max": np.max(images, axis=axes_to_reduce, keepdims=True),
+#         "mean": np.mean(images, axis=axes_to_reduce, keepdims=True),
+#         "std": np.std(images, axis=axes_to_reduce, keepdims=True)
+#     }
+#     for key in image_stats:  # squeeze batch dim
+#         image_stats[key] = np.squeeze(image_stats[key], axis=0)
+
+#     return image_stats
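Editorial note: the commented-out `aggregate_stats_v2` above combines per-dataset statistics with a weighted mean and the parallel-variance identity, Var = Σᵢ nᵢ(σᵢ² + (μᵢ − μ)²) / N. Below is a minimal, self-contained sketch (illustrative names, not part of the diff) that checks this aggregation numerically against direct computation over the concatenated data:

```python
import numpy as np

rng = np.random.default_rng(0)
chunks = [rng.normal(size=(n, 3)) for n in (50, 120, 30)]  # three per-dataset chunks

# Per-chunk stats, in the same form compute_stats would produce
means = np.array([c.mean(axis=0) for c in chunks])
variances = np.array([c.std(axis=0) ** 2 for c in chunks])
counts = np.array([len(c) for c in chunks]).reshape(-1, 1)  # broadcast against (3, 3)

total_count = counts.sum(axis=0)
total_mean = (means * counts).sum(axis=0) / total_count
# Parallel algorithm: within-chunk variance plus between-chunk mean shift, weighted by counts
total_var = ((variances + (means - total_mean) ** 2) * counts).sum(axis=0) / total_count

all_data = np.concatenate(chunks)
assert np.allclose(total_mean, all_data.mean(axis=0))
assert np.allclose(np.sqrt(total_var), all_data.std(axis=0))
```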
@@ -27,7 +27,7 @@ import PIL.Image
 import torch
 import torch.utils
 from datasets import load_dataset
-from huggingface_hub import snapshot_download, upload_folder
+from huggingface_hub import create_repo, snapshot_download, upload_folder

 from lerobot.common.datasets.compute_stats import aggregate_stats, compute_stats
 from lerobot.common.datasets.image_writer import AsyncImageWriter, write_image
@@ -44,6 +44,7 @@ from lerobot.common.datasets.utils import (
     check_version_compatibility,
     create_branch,
     create_empty_dataset_info,
+    create_lerobot_dataset_card,
     get_delta_indices,
     get_episode_data_index,
     get_features_from_robot,
@@ -54,9 +55,9 @@ from lerobot.common.datasets.utils import (
     load_info,
     load_stats,
     load_tasks,
+    serialize_dict,
     write_json,
     write_parquet,
-    write_stats,
 )
 from lerobot.common.datasets.video_utils import (
     VideoFrame,
@@ -75,11 +76,11 @@ class LeRobotDatasetMetadata:
     def __init__(
         self,
         repo_id: str,
-        root: Path | None = None,
+        root: str | Path | None = None,
         local_files_only: bool = False,
     ):
         self.repo_id = repo_id
-        self.root = root if root is not None else LEROBOT_HOME / repo_id
+        self.root = Path(root) if root is not None else LEROBOT_HOME / repo_id
         self.local_files_only = local_files_only

         # Load metadata
@@ -163,7 +164,7 @@ class LeRobotDatasetMetadata:
         return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]]

     @property
-    def names(self) -> dict[str, list[str]]:
+    def names(self) -> dict[str, list | dict]:
         """Names of the various dimensions of vector modalities."""
         return {key: ft["names"] for key, ft in self.features.items()}

@@ -209,7 +210,7 @@ class LeRobotDatasetMetadata:
         task_index = self.task_to_task_index.get(task, None)
         return task_index if task_index is not None else self.total_tasks

-    def add_episode(self, episode_index: int, episode_length: int, task: str, task_index: int) -> None:
+    def save_episode(self, episode_index: int, episode_length: int, task: str, task_index: int) -> None:
         self.info["total_episodes"] += 1
         self.info["total_frames"] += episode_length

@@ -238,24 +239,37 @@ class LeRobotDatasetMetadata:
         self.episodes.append(episode_dict)
         append_jsonlines(episode_dict, self.root / EPISODES_PATH)

+        # TODO(aliberts): refactor stats in save_episodes
+        # image_sampling = int(self.fps / 2)  # sample 2 img/s for the stats
+        # ep_stats = compute_episode_stats(episode_buffer, self.features, episode_length, image_sampling=image_sampling)
+        # ep_stats = serialize_dict(ep_stats)
+        # append_jsonlines(ep_stats, self.root / STATS_PATH)
+
     def write_video_info(self) -> None:
         """
         Warning: this function writes info from first episode videos, implicitly assuming that all videos have
         been encoded the same way. Also, this means it assumes the first episode exists.
         """
         for key in self.video_keys:
-            if key not in self.info["videos"]:
+            if not self.features[key].get("info", None):
                 video_path = self.root / self.get_video_file_path(ep_index=0, vid_key=key)
-                self.info["videos"][key] = get_video_info(video_path)
+                self.info["features"][key]["info"] = get_video_info(video_path)

         write_json(self.info, self.root / INFO_PATH)

+    def __repr__(self):
+        return (
+            f"{self.__class__.__name__}\n"
+            f"Repository ID: '{self.repo_id}',\n"
+            f"{json.dumps(self.info, indent=4)}\n"
+        )
+
     @classmethod
     def create(
         cls,
         repo_id: str,
         fps: int,
-        root: Path | None = None,
+        root: str | Path | None = None,
         robot: Robot | None = None,
         robot_type: str | None = None,
         features: dict | None = None,
@@ -264,7 +278,7 @@ class LeRobotDatasetMetadata:
         """Creates metadata for a LeRobotDataset."""
         obj = cls.__new__(cls)
         obj.repo_id = repo_id
-        obj.root = root if root is not None else LEROBOT_HOME / repo_id
+        obj.root = Path(root) if root is not None else LEROBOT_HOME / repo_id

         if robot is not None:
             features = get_features_from_robot(robot, use_videos)
@@ -294,7 +308,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
     def __init__(
         self,
         repo_id: str,
-        root: Path | None = None,
+        root: str | Path | None = None,
         episodes: list[int] | None = None,
         image_transforms: Callable | None = None,
         delta_timestamps: dict[list[float]] | None = None,
@@ -402,7 +416,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         """
         super().__init__()
         self.repo_id = repo_id
-        self.root = root if root is not None else LEROBOT_HOME / repo_id
+        self.root = Path(root) if root is not None else LEROBOT_HOME / repo_id
         self.image_transforms = image_transforms
         self.delta_timestamps = delta_timestamps
         self.episodes = episodes
@@ -437,22 +451,32 @@ class LeRobotDataset(torch.utils.data.Dataset):
         check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
         self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps)

-    def push_to_hub(self, push_videos: bool = True) -> None:
+    def push_to_hub(
+        self,
+        tags: list | None = None,
+        text: str | None = None,
+        license: str | None = "mit",
+        push_videos: bool = True,
+    ) -> None:
         if not self.consolidated:
             raise RuntimeError(
                 "You are trying to upload to the hub a LeRobotDataset that has not been consolidated yet."
                 "Please call the dataset 'consolidate()' method first."
             )

         ignore_patterns = ["images/"]
         if not push_videos:
             ignore_patterns.append("videos/")

+        create_repo(self.repo_id, repo_type="dataset", exist_ok=True)
         upload_folder(
             repo_id=self.repo_id,
             folder_path=self.root,
             repo_type="dataset",
             ignore_patterns=ignore_patterns,
         )
+        card = create_lerobot_dataset_card(tags=tags, text=text, info=self.meta.info, license=license)
+        card.push_to_hub(repo_id=self.repo_id, repo_type="dataset")
         create_branch(repo_id=self.repo_id, branch=CODEBASE_VERSION, repo_type="dataset")

     def pull_from_repo(
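With this change, `push_to_hub` also creates the repo, generates the dataset card, and pushes both before tagging the `CODEBASE_VERSION` branch. A hedged usage sketch of the new signature (the repo id is hypothetical, and the dataset is assumed to be recorded and consolidated already, since `push_to_hub` raises a `RuntimeError` otherwise):

```python
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

dataset = LeRobotDataset("user/my_dataset")  # hypothetical local dataset
dataset.push_to_hub(
    tags=["aloha", "tutorial"],  # appended to the default ["LeRobot"] tag
    text="Free-form text appended to the generated dataset card.",
    license="apache-2.0",        # defaults to "mit"
    push_videos=True,            # False skips uploading the videos/ folder
)
```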
@@ -501,8 +525,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
             files = [str(self.root / self.meta.get_data_file_path(ep_idx)) for ep_idx in self.episodes]
             hf_dataset = load_dataset("parquet", data_files=files, split="train")

-        # TODO(aliberts): hf_dataset.set_format("torch")
         hf_dataset.set_transform(hf_transform_to_torch)
+        # return hf_dataset.with_format("torch") TODO
+
         return hf_dataset

     @property
@@ -653,30 +678,33 @@ class LeRobotDataset(torch.utils.data.Dataset):
     def add_frame(self, frame: dict) -> None:
         """
         This function only adds the frame to the episode_buffer. Apart from images — which are written in a
-        temporary directory — nothing is written to disk. To save those frames, the 'add_episode()' method
+        temporary directory — nothing is written to disk. To save those frames, the 'save_episode()' method
         then needs to be called.
         """
         frame_index = self.episode_buffer["size"]
-        for key, ft in self.features.items():
-            if key == "frame_index":
-                self.episode_buffer[key].append(frame_index)
-            elif key == "timestamp":
-                self.episode_buffer[key].append(frame_index / self.fps)
-            elif key in frame and ft["dtype"] not in ["image", "video"]:
-                self.episode_buffer[key].append(frame[key])
-            elif key in frame and ft["dtype"] in ["image", "video"]:
+        timestamp = frame["timestamp"] if "timestamp" in frame else frame_index / self.fps
+        self.episode_buffer["frame_index"].append(frame_index)
+        self.episode_buffer["timestamp"].append(timestamp)
+
+        for key in frame:
+            if key not in self.features:
+                raise ValueError(key)
+
+            if self.features[key]["dtype"] not in ["image", "video"]:
+                item = frame[key].numpy() if isinstance(frame[key], torch.Tensor) else frame[key]
+                self.episode_buffer[key].append(item)
+            elif self.features[key]["dtype"] in ["image", "video"]:
                 img_path = self._get_image_file_path(
                     episode_index=self.episode_buffer["episode_index"], image_key=key, frame_index=frame_index
                 )
                 if frame_index == 0:
                     img_path.parent.mkdir(parents=True, exist_ok=True)
                 self._save_image(frame[key], img_path)
-                if ft["dtype"] == "image":
-                    self.episode_buffer[key].append(str(img_path))
+                self.episode_buffer[key].append(str(img_path))

         self.episode_buffer["size"] += 1

-    def add_episode(self, task: str, encode_videos: bool = False) -> None:
+    def save_episode(self, task: str, encode_videos: bool = True, episode_data: dict | None = None) -> None:
        """
        This will save to disk the current episode in self.episode_buffer. Note that since it affects files on
        disk, it sets self.consolidated to False to ensure proper consolidation later on before uploading to
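After this rewrite, `add_frame` iterates over the keys of the caller's frame dict (rejecting unknown keys) instead of over `self.features`, fills `frame_index` and `timestamp` automatically, and converts tensors to numpy before buffering. A hedged sketch of the resulting recording call, under the assumption that a dataset was created with a matching (hypothetical) features dict:

```python
import numpy as np
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

# Hypothetical minimal dataset with a single 1-D feature and no cameras.
features = {"action": {"dtype": "float32", "shape": (6,), "names": None}}
dataset = LeRobotDataset.create("user/scratch", fps=30, features=features)

dataset.add_frame({"action": np.zeros(6, dtype=np.float32)})  # buffered, not yet on disk
dataset.save_episode(task="noop", encode_videos=False)        # writes the episode parquet
```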
@@ -686,49 +714,56 @@ class LeRobotDataset(torch.utils.data.Dataset):
         you can do it later with dataset.consolidate(). This is to give more flexibility on when to spend
         time for video encoding.
         """
-        episode_length = self.episode_buffer.pop("size")
-        episode_index = self.episode_buffer["episode_index"]
+        if not episode_data:
+            episode_buffer = self.episode_buffer
+
+        episode_length = episode_buffer.pop("size")
+        episode_index = episode_buffer["episode_index"]
         if episode_index != self.meta.total_episodes:
             # TODO(aliberts): Add option to use existing episode_index
             raise NotImplementedError()

         task_index = self.meta.get_task_index(task)

-        if not set(self.episode_buffer.keys()) == set(self.features):
+        if not set(episode_buffer.keys()) == set(self.features):
             raise ValueError()

         for key, ft in self.features.items():
             if key == "index":
-                self.episode_buffer[key] = np.arange(
+                episode_buffer[key] = np.arange(
                     self.meta.total_frames, self.meta.total_frames + episode_length
                 )
             elif key == "episode_index":
-                self.episode_buffer[key] = np.full((episode_length,), episode_index)
+                episode_buffer[key] = np.full((episode_length,), episode_index)
             elif key == "task_index":
-                self.episode_buffer[key] = np.full((episode_length,), task_index)
+                episode_buffer[key] = np.full((episode_length,), task_index)
             elif ft["dtype"] in ["image", "video"]:
                 continue
-            elif ft["shape"][0] == 1:
-                self.episode_buffer[key] = torch.tensor(self.episode_buffer[key])
-            elif ft["shape"][0] > 1:
-                self.episode_buffer[key] = torch.stack(self.episode_buffer[key])
+            elif len(ft["shape"]) == 1 and ft["shape"][0] == 1:
+                episode_buffer[key] = np.array(episode_buffer[key], dtype=ft["dtype"])
+            elif len(ft["shape"]) == 1 and ft["shape"][0] > 1:
+                episode_buffer[key] = np.stack(episode_buffer[key])
             else:
-                raise ValueError()
+                raise ValueError(key)

-        self.meta.add_episode(episode_index, episode_length, task, task_index)
         self._wait_image_writer()
-        self._save_episode_table(episode_index)
+        self._save_episode_table(episode_buffer, episode_index)
+
+        self.meta.save_episode(episode_index, episode_length, task, task_index)

         if encode_videos and len(self.meta.video_keys) > 0:
-            self.encode_videos()
+            video_paths = self.encode_episode_videos(episode_index)
+            for key in self.meta.video_keys:
+                episode_buffer[key] = video_paths[key]

-        # Reset the buffer
-        self.episode_buffer = self._create_episode_buffer()
+        if not episode_data:  # Reset the buffer
+            self.episode_buffer = self._create_episode_buffer()
+
         self.consolidated = False

-    def _save_episode_table(self, episode_index: int) -> None:
-        ep_dataset = datasets.Dataset.from_dict(self.episode_buffer, features=self.hf_features, split="train")
+    def _save_episode_table(self, episode_buffer: dict, episode_index: int) -> None:
+        episode_dict = {key: episode_buffer[key] for key in self.hf_features}
+        ep_dataset = datasets.Dataset.from_dict(episode_dict, features=self.hf_features, split="train")
         ep_data_path = self.root / self.meta.get_data_file_path(ep_index=episode_index)
         ep_data_path.parent.mkdir(parents=True, exist_ok=True)
         write_parquet(ep_dataset, ep_data_path)
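The buffer conversion above now dispatches on each feature's declared shape (and uses numpy instead of torch). A small runnable demonstration of that branch logic, with hypothetical feature keys:

```python
import numpy as np

# How save_episode converts buffered per-frame values to arrays, based on
# each feature's declared 1-D shape (feature definitions are illustrative).
features = {
    "next.done": {"dtype": "bool", "shape": (1,)},
    "observation.state": {"dtype": "float32", "shape": (6,)},
}
episode_buffer = {
    "next.done": [False, False, True],
    "observation.state": [np.zeros(6, dtype=np.float32) for _ in range(3)],
}
for key, ft in features.items():
    if len(ft["shape"]) == 1 and ft["shape"][0] == 1:
        episode_buffer[key] = np.array(episode_buffer[key], dtype=ft["dtype"])  # scalar feature
    elif len(ft["shape"]) == 1 and ft["shape"][0] > 1:
        episode_buffer[key] = np.stack(episode_buffer[key])  # vector feature

assert episode_buffer["next.done"].shape == (3,)
assert episode_buffer["observation.state"].shape == (3, 6)
```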
@@ -777,16 +812,28 @@ class LeRobotDataset(torch.utils.data.Dataset):
         Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
         since video encoding with ffmpeg is already using multithreading.
         """
-        for episode_index in range(self.meta.total_episodes):
-            for key in self.meta.video_keys:
-                video_path = self.root / self.meta.get_video_file_path(episode_index, key)
-                if video_path.is_file():
-                    # Skip if video is already encoded. Could be the case when resuming data recording.
-                    continue
-                img_dir = self._get_image_file_path(
-                    episode_index=episode_index, image_key=key, frame_index=0
-                ).parent
-                encode_video_frames(img_dir, video_path, self.fps, overwrite=True)
+        for ep_idx in range(self.meta.total_episodes):
+            self.encode_episode_videos(ep_idx)
+
+    def encode_episode_videos(self, episode_index: int) -> dict:
+        """
+        Use ffmpeg to convert frames stored as png into mp4 videos.
+        Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
+        since video encoding with ffmpeg is already using multithreading.
+        """
+        video_paths = {}
+        for key in self.meta.video_keys:
+            video_path = self.root / self.meta.get_video_file_path(episode_index, key)
+            video_paths[key] = str(video_path)
+            if video_path.is_file():
+                # Skip if video is already encoded. Could be the case when resuming data recording.
+                continue
+            img_dir = self._get_image_file_path(
+                episode_index=episode_index, image_key=key, frame_index=0
+            ).parent
+            encode_video_frames(img_dir, video_path, self.fps, overwrite=True)
+
+        return video_paths

     def consolidate(self, run_compute_stats: bool = True, keep_image_files: bool = False) -> None:
         self.hf_dataset = self.load_hf_dataset()
@@ -810,27 +857,22 @@ class LeRobotDataset(torch.utils.data.Dataset):

         if run_compute_stats:
             self.stop_image_writer()
+            # TODO(aliberts): refactor stats in save_episodes
             self.meta.stats = compute_stats(self)
-            write_stats(self.meta.stats, self.root / STATS_PATH)
+            serialized_stats = serialize_dict(self.meta.stats)
+            write_json(serialized_stats, self.root / STATS_PATH)
             self.consolidated = True
         else:
             logging.warning(
                 "Skipping computation of the dataset statistics, dataset is not fully consolidated."
             )

-        # TODO(aliberts)
-        # - [X] add video info in info.json
-        # Sanity checks:
-        # - [X] number of files
-        # - [ ] shapes
-        # - [ ] ep_lenghts
-
     @classmethod
     def create(
         cls,
         repo_id: str,
         fps: int,
-        root: Path | None = None,
+        root: str | Path | None = None,
         robot: Robot | None = None,
         robot_type: str | None = None,
         features: dict | None = None,
@@ -22,6 +22,7 @@ from typing import Any

 import datasets
 import jsonlines
+import numpy as np
 import pyarrow.compute as pc
 import torch
 from datasets.table import embed_table_storage
@@ -91,6 +92,11 @@ def unflatten_dict(d: dict, sep: str = "/") -> dict:
     return outdict


+def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict:
+    serialized_dict = {key: value.tolist() for key, value in flatten_dict(stats).items()}
+    return unflatten_dict(serialized_dict)
+
+
 def write_parquet(dataset: datasets.Dataset, fpath: Path) -> None:
     # Embed image bytes into the table before saving to parquet
     format = dataset.format
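The new `serialize_dict` replaces the dedicated `write_stats` helper removed below: it flattens the nested stats dict, calls `.tolist()` on every tensor/array leaf, and re-nests the result so it can be written as JSON. A short hedged example (the stats content is made up):

```python
import numpy as np
import torch
from lerobot.common.datasets.utils import serialize_dict

stats = {
    "observation.state": {
        "mean": torch.tensor([0.1, 0.2]),
        "std": np.array([1.0, 2.0]),
    }
}
# Result is a nested dict of plain Python lists, ready for json.dump;
# e.g. "std" becomes [1.0, 2.0] (float32 tensors round-trip with float32 precision).
print(serialize_dict(stats))
```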
@@ -128,12 +134,6 @@ def append_jsonlines(data: dict, fpath: Path) -> None:
         writer.write(data)


-def write_stats(stats: dict[str, torch.Tensor | dict], fpath: Path) -> None:
-    serialized_stats = {key: value.tolist() for key, value in flatten_dict(stats).items()}
-    serialized_stats = unflatten_dict(serialized_stats)
-    write_json(serialized_stats, fpath)
-
-
 def load_info(local_dir: Path) -> dict:
     return load_json(local_dir / INFO_PATH)
@@ -153,6 +153,16 @@ def load_episodes(local_dir: Path) -> dict:
     return load_jsonlines(local_dir / EPISODES_PATH)


+def load_image_as_numpy(fpath: str | Path, dtype="float32", channel_first: bool = True) -> np.ndarray:
+    img = PILImage.open(fpath).convert("RGB")
+    img_array = np.array(img, dtype=dtype)
+    if channel_first:  # (H, W, C) -> (C, H, W)
+        img_array = np.transpose(img_array, (2, 0, 1))
+    if "float" in dtype:
+        img_array /= 255.0
+    return img_array
+
+
 def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
     """Get a transform function that convert items from Hugging Face dataset (pyarrow)
     to torch tensors. Importantly, images are converted from PIL, which corresponds to
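A quick check of the new helper's contract (the file path is hypothetical): float dtypes are scaled to [0, 1], and `channel_first` yields a (C, H, W) layout:

```python
import numpy as np
import PIL.Image
from lerobot.common.datasets.utils import load_image_as_numpy

PIL.Image.new("RGB", (4, 2), color=(255, 0, 0)).save("/tmp/red.png")  # width 4, height 2
arr = load_image_as_numpy("/tmp/red.png", dtype="float32", channel_first=True)
assert arr.shape == (3, 2, 4)  # (C, H, W): PIL's size (4, 2) is (W, H)
assert arr.max() <= 1.0 and np.isclose(arr[0].max(), 1.0)  # red channel scaled to 1.0
```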
@@ -331,7 +341,7 @@ def check_timestamps_sync(
     within_tolerance = torch.abs(diffs - 1 / fps) <= tolerance_s

     # We mask differences between the timestamp at the end of an episode
-    # and the one the start of the next episode since these are expected
+    # and the one at the start of the next episode since these are expected
     # to be outside tolerance.
     mask = torch.ones(len(diffs), dtype=torch.bool)
     ignored_diffs = episode_data_index["to"][:-1] - 1
@@ -433,7 +443,12 @@ def create_branch(repo_id, *, branch: str, repo_type: str | None = None) -> None


 def create_lerobot_dataset_card(
-    tags: list | None = None, text: str | None = None, info: dict | None = None
+    tags: list | None = None,
+    text: str | None = None,
+    info: dict | None = None,
+    license: str | None = None,
+    citation: str | None = None,
+    arxiv: str | None = None,
 ) -> DatasetCard:
     card = DatasetCard(DATASET_CARD_TEMPLATE)
     card.data.configs = [
@@ -444,11 +459,19 @@ def create_lerobot_dataset_card(
     ]
     card.data.task_categories = ["robotics"]
     card.data.tags = ["LeRobot"]
-    if tags is not None:
+    if license:
+        card.data.license = license
+    if tags:
         card.data.tags += tags
-    if text is not None:
+    if text:
         card.text += f"{text}\n"
-    if info is not None:
+    if info:
         card.text += "## Info\n"
         card.text += "[meta/info.json](meta/info.json)\n"
         card.text += f"```json\n{json.dumps(info, indent=4)}\n```"
+    if citation:
+        card.text += "## Citation\n"
+        card.text += f"```\n{citation}\n```\n"
+    if arxiv:
+        card.data.arxiv = arxiv
     return card
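This is the "extra info" from the commit title: the card helper now accepts `license`, `citation`, and `arxiv` fields. A hedged sketch of building a card with the extended signature (all values below are placeholders):

```python
from lerobot.common.datasets.utils import create_lerobot_dataset_card

card = create_lerobot_dataset_card(
    tags=["aloha"],
    text="Collected with a bimanual teleoperation rig.",
    info={"codebase_version": "v2.0", "fps": 30},  # normally meta/info.json content
    license="apache-2.0",
    citation="@misc{placeholder2024, title={Placeholder}}",
    arxiv="2402.00000",  # hypothetical identifier
)
# card.push_to_hub(repo_id="user/dataset", repo_type="dataset")  # requires hub auth
```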
@@ -213,8 +213,11 @@ def get_features_from_hf_dataset(dataset: Dataset, robot_config: dict | None = N
             assert isinstance(ft.feature, datasets.Value)
             dtype = ft.feature.dtype
             shape = (ft.length,)
-            names = robot_config["names"][key] if robot_config else [f"motor_{i}" for i in range(ft.length)]
-            assert len(names) == shape[0]
+            motor_names = (
+                robot_config["names"][key] if robot_config else [f"motor_{i}" for i in range(ft.length)]
+            )
+            assert len(motor_names) == shape[0]
+            names = {"motors": motor_names}
         elif isinstance(ft, datasets.Image):
             dtype = "image"
             image = dataset[0][key]  # Assuming first row
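Nesting the motor names under a "motors" key is what widens the `names` property's return type to `dict[str, list | dict]` earlier in this commit. A tiny runnable illustration (the feature definition is hypothetical):

```python
# A vector feature's "names" entry is now a dict keyed by dimension group:
feature = {
    "dtype": "float32",
    "shape": (6,),
    "names": {"motors": [f"motor_{i}" for i in range(6)]},
}
assert len(feature["names"]["motors"]) == feature["shape"][0]
```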
@@ -433,6 +436,9 @@ def convert_dataset(
     tasks_path: Path | None = None,
     tasks_col: Path | None = None,
     robot_config: dict | None = None,
+    license: str | None = None,
+    citation: str | None = None,
+    arxiv: str | None = None,
     test_branch: str | None = None,
 ):
     v1 = get_hub_safe_version(repo_id, V16, enforce_v2=False)
@@ -559,7 +565,9 @@ def convert_dataset(
     }
     write_json(metadata_v2_0, v20_dir / INFO_PATH)
     convert_stats_to_json(v1x_dir, v20_dir)
-    card = create_lerobot_dataset_card(tags=repo_tags, info=metadata_v2_0)
+    card = create_lerobot_dataset_card(
+        tags=repo_tags, info=metadata_v2_0, license=license, citation=citation, arxiv=arxiv
+    )

     with contextlib.suppress(EntryNotFoundError):
         hub_api.delete_folder(repo_id=repo_id, path_in_repo="data", repo_type="dataset", revision=branch)
@@ -634,6 +642,12 @@ def main():
         default=None,
         help="Local directory to store the dataset during conversion. Defaults to /tmp/lerobot_dataset_v2",
     )
+    parser.add_argument(
+        "--license",
+        type=str,
+        default="mit",
+        help="Repo license. Must be one of https://huggingface.co/docs/hub/repositories-licenses. Defaults to mit.",
+    )
     parser.add_argument(
         "--test-branch",
         type=str,
@@ -652,7 +666,4 @@ def main():


 if __name__ == "__main__":
-    from time import sleep
-
-    sleep(1)
     main()
@@ -301,7 +301,7 @@ def record(
                 dataset.clear_episode_buffer()
                 continue

-            dataset.add_episode(task)
+            dataset.save_episode(task)
             recorded_episodes += 1

             if events["stop_recording"]: