Faster self.meta.episodes[...]

- Switch back to set_transform instead of set_format.
- Add video_files_size_in_mb (alongside data_files_size_in_mb).
- Run pre-commit run --all-files.
commit d518b036d0 (parent 601b5fdbfe)
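The speedup in the title comes from how meta.episodes (a Hugging Face datasets.Dataset) is indexed: the old code pulled an entire column and then picked one entry, while the new code reads a single row as a dict. A minimal sketch of the two access patterns, using a small made-up table rather than the real episodes metadata:

from datasets import Dataset

# Made-up stand-in for meta.episodes with only two reference columns.
episodes = Dataset.from_dict({
    "data/chunk_index": list(range(1000)),
    "data/file_index": list(range(1000)),
})

# Old pattern: episodes["data/chunk_index"] materializes the whole column
# before a single value is picked out, and does so again for every column.
chunk_idx = episodes["data/chunk_index"][42]
file_idx = episodes["data/file_index"][42]

# New pattern: episodes[42] decodes one row and returns a plain dict, so each
# per-episode lookup touches only that row.
ep = episodes[42]
chunk_idx, file_idx = ep["data/chunk_index"], ep["data/file_index"]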
@@ -56,6 +56,7 @@ from lerobot.common.datasets.utils import (
     get_safe_version,
     get_video_duration_in_s,
     get_video_size_in_mb,
+    hf_transform_to_torch,
     is_valid_version,
     load_episodes,
     load_info,
@@ -136,14 +137,16 @@ class LeRobotDatasetMetadata:
         return packaging.version.parse(self.info["codebase_version"])
 
     def get_data_file_path(self, ep_index: int) -> Path:
-        chunk_idx = self.episodes["data/chunk_index"][ep_index]
-        file_idx = self.episodes["data/file_index"][ep_index]
+        ep = self.episodes[ep_index]
+        chunk_idx = ep["data/chunk_index"]
+        file_idx = ep["data/file_index"]
         fpath = self.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
         return Path(fpath)
 
     def get_video_file_path(self, ep_index: int, vid_key: str) -> Path:
-        chunk_idx = self.episodes[f"videos/{vid_key}/chunk_index"][ep_index]
-        file_idx = self.episodes[f"videos/{vid_key}/file_index"][ep_index]
+        ep = self.episodes[ep_index]
+        chunk_idx = ep[f"videos/{vid_key}/chunk_index"]
+        file_idx = ep[f"videos/{vid_key}/file_index"]
         fpath = self.video_path.format(video_key=vid_key, chunk_index=chunk_idx, file_index=file_idx)
         return Path(fpath)
 
@@ -218,9 +221,14 @@ class LeRobotDatasetMetadata:
         return self.info["chunks_size"]
 
     @property
-    def files_size_in_mb(self) -> int:
-        """Max size of file in mega bytes."""
-        return self.info["files_size_in_mb"]
+    def data_files_size_in_mb(self) -> int:
+        """Max size of data file in mega bytes."""
+        return self.info["data_files_size_in_mb"]
 
+    @property
+    def video_files_size_in_mb(self) -> int:
+        """Max size of video file in mega bytes."""
+        return self.info["video_files_size_in_mb"]
+
     def get_task_index(self, task: str) -> int | None:
         """
@@ -278,23 +286,14 @@ class LeRobotDatasetMetadata:
             df["dataset_to_index"] = [len(df)]
         else:
             # Retrieve information from the latest parquet file
-            latest_ep = self.episodes.with_format(
-                columns=[
-                    "meta/episodes/chunk_index",
-                    "meta/episodes/file_index",
-                    "dataset_from_index",
-                    "dataset_to_index",
-                ]
-            )[-1]
-            chunk_idx, file_idx = (
-                latest_ep["meta/episodes/chunk_index"],
-                latest_ep["meta/episodes/file_index"],
-            )
+            latest_ep = self.episodes[-1]
+            chunk_idx = latest_ep["meta/episodes/chunk_index"]
+            file_idx = latest_ep["meta/episodes/file_index"]
 
             latest_path = self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
             latest_size_in_mb = get_parquet_file_size_in_mb(latest_path)
 
-            if latest_size_in_mb + ep_size_in_mb >= self.files_size_in_mb:
+            if latest_size_in_mb + ep_size_in_mb >= self.data_files_size_in_mb:
                 # Size limit is reached, prepare new parquet file
                 chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.chunks_size)
 
@@ -304,7 +303,7 @@ class LeRobotDatasetMetadata:
             df["dataset_from_index"] = [latest_ep["dataset_to_index"]]
             df["dataset_to_index"] = [latest_ep["dataset_to_index"] + len(df)]
 
-            if latest_size_in_mb + ep_size_in_mb < self.files_size_in_mb:
+            if latest_size_in_mb + ep_size_in_mb < self.data_files_size_in_mb:
                 # Size limit wasnt reached, concatenate latest dataframe with new one
                 latest_df = pd.read_parquet(latest_path)
                 df = pd.concat([latest_df, df], ignore_index=True)
@@ -339,6 +338,7 @@ class LeRobotDatasetMetadata:
         # Update info
         self.info["total_episodes"] += 1
         self.info["total_frames"] += episode_length
+        self.info["total_tasks"] = len(self.tasks)
         self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"}
         if len(self.video_keys) > 0:
             self.update_video_info()
@@ -674,14 +674,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
     def load_hf_dataset(self) -> datasets.Dataset:
         """hf_dataset contains all the observations, states, actions, rewards, etc."""
         hf_dataset = load_nested_dataset(self.root / "data")
-        hf_dataset.set_format("torch")
+        hf_dataset.set_transform(hf_transform_to_torch)
         return hf_dataset
 
     def create_hf_dataset(self) -> datasets.Dataset:
         features = get_hf_features_from_features(self.features)
         ft_dict = {col: [] for col in features}
         hf_dataset = datasets.Dataset.from_dict(ft_dict, features=features, split="train")
-        hf_dataset.set_format("torch")
+        hf_dataset.set_transform(hf_transform_to_torch)
         return hf_dataset
 
     @property
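On the set_transform change above: set_format("torch") uses the library's built-in torch formatter, while set_transform(fn) runs a custom function lazily on each accessed batch, which is how hf_transform_to_torch gets to control the conversion. A rough sketch of the difference, with a stand-in transform instead of the real hf_transform_to_torch:

import torch
from datasets import Dataset

ds = Dataset.from_dict({"action": [[0.1, 0.2], [0.3, 0.4]]})

# Built-in formatter: accessed values come back as torch tensors.
ds.set_format("torch")

# Custom transform applied on access instead; this stand-in only mimics what
# hf_transform_to_torch does for numeric columns (convert each value to a tensor).
def to_torch(batch: dict) -> dict:
    return {key: [torch.as_tensor(value) for value in values] for key, values in batch.items()}

ds.set_transform(to_torch)
print(ds[0]["action"])  # a torch tensor for the first row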
@@ -712,8 +712,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
         return get_hf_features_from_features(self.features)
 
     def _get_query_indices(self, idx: int, ep_idx: int) -> tuple[dict[str, list[int | bool]]]:
-        ep_start = self.meta.episodes["dataset_from_index"][ep_idx]
-        ep_end = self.meta.episodes["dataset_to_index"][ep_idx]
+        ep = self.meta.episodes[ep_idx]
+        ep_start = ep["dataset_from_index"]
+        ep_end = ep["dataset_to_index"]
         query_indices = {
             key: [max(ep_start, min(ep_end - 1, idx + delta)) for delta in delta_idx]
             for key, delta_idx in self.delta_indices.items()
@@ -754,12 +755,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
         Segmentation Fault. This probably happens because a memory reference to the video loader is created in
         the main process and a subprocess fails to access it.
         """
+        ep = self.meta.episodes[ep_idx]
         item = {}
         for vid_key, query_ts in query_timestamps.items():
             # Episodes are stored sequentially on a single mp4 to reduce the number of files.
             # Thus we load the start timestamp of the episode on this mp4 and
             # shift the query timestamp accordingly.
-            from_timestamp = self.meta.episodes[f"videos/{vid_key}/from_timestamp"][ep_idx]
+            from_timestamp = ep[f"videos/{vid_key}/from_timestamp"]
             shifted_query_ts = [from_timestamp + ts for ts in query_ts]
 
             video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key)
@@ -984,15 +986,16 @@ class LeRobotDataset(torch.utils.data.Dataset):
             latest_num_frames = 0
         else:
             # Retrieve information from the latest parquet file
-            latest_ep = self.meta.episodes.with_format(columns=["data/chunk_index", "data/file_index"])[-1]
-            chunk_idx, file_idx = latest_ep["data/chunk_index"], latest_ep["data/file_index"]
+            latest_ep = self.meta.episodes[-1]
+            chunk_idx = latest_ep["data/chunk_index"]
+            file_idx = latest_ep["data/file_index"]
 
             latest_path = self.root / self.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
             latest_size_in_mb = get_parquet_file_size_in_mb(latest_path)
             latest_num_frames = get_parquet_num_frames(latest_path)
 
             # Determine if a new parquet file is needed
-            if latest_size_in_mb + ep_size_in_mb >= self.meta.files_size_in_mb:
+            if latest_size_in_mb + ep_size_in_mb >= self.meta.data_files_size_in_mb:
                 # Size limit is reached, prepare new parquet file
                 chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size)
                 latest_num_frames = 0
@@ -1039,13 +1042,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
             shutil.move(str(ep_path), str(new_path))
         else:
             # Retrieve information from the latest video file
-            latest_ep = self.meta.episodes.with_format(
-                columns=[f"videos/{video_key}/chunk_index", f"videos/{video_key}/file_index"]
-            )[-1]
-            chunk_idx, file_idx = (
-                latest_ep[f"videos/{video_key}/chunk_index"],
-                latest_ep[f"videos/{video_key}/file_index"],
-            )
+            latest_ep = self.meta.episodes[-1]
+            chunk_idx = latest_ep[f"videos/{video_key}/chunk_index"]
+            file_idx = latest_ep[f"videos/{video_key}/file_index"]
 
             latest_path = self.root / self.meta.video_path.format(
                 video_key=video_key, chunk_index=chunk_idx, file_index=file_idx
@@ -1053,7 +1052,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
             latest_size_in_mb = get_video_size_in_mb(latest_path)
             latest_duration_in_s = get_video_duration_in_s(latest_path)
 
-            if latest_size_in_mb + ep_size_in_mb >= self.meta.files_size_in_mb:
+            if latest_size_in_mb + ep_size_in_mb >= self.meta.video_files_size_in_mb:
                 # Move temporary episode video to a new video file in the dataset
                 chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size)
                 new_path = self.root / self.meta.video_path.format(
@@ -1115,16 +1114,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
         if self.image_writer is not None:
             self.image_writer.wait_until_done()
 
-    # TODO(rcadene): this method is currently not used
-    # def encode_videos(self) -> None:
-    #     """
-    #     Use ffmpeg to convert frames stored as png into mp4 videos.
-    #     Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
-    #     since video encoding with ffmpeg is already using multithreading.
-    #     """
-    #     for ep_idx in range(self.meta.total_episodes):
-    #         self.encode_episode_videos(ep_idx)
-
     def _encode_temporary_episode_video(self, video_key: str, episode_index: int) -> dict:
         """
         Use ffmpeg to convert frames stored as png into mp4 videos.
@@ -50,7 +50,8 @@ from lerobot.common.utils.utils import is_valid_numpy_dtype_string
 from lerobot.configs.types import FeatureType, PolicyFeature
 
 DEFAULT_CHUNK_SIZE = 1000  # Max number of files per chunk
-DEFAULT_FILE_SIZE_IN_MB = 100.0  # Max size per file
+DEFAULT_DATA_FILE_SIZE_IN_MB = 100  # Max size per file
+DEFAULT_VIDEO_FILE_SIZE_IN_MB = 500  # Max size per file
 
 INFO_PATH = "meta/info.json"
 STATS_PATH = "meta/stats.json"
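The split into DEFAULT_DATA_FILE_SIZE_IN_MB and DEFAULT_VIDEO_FILE_SIZE_IN_MB above lets parquet data files and concatenated mp4 files roll over at different thresholds. A simplified, self-contained sketch of the rollover test used throughout this commit (the helper name below is illustrative, not from the codebase):

DEFAULT_DATA_FILE_SIZE_IN_MB = 100
DEFAULT_VIDEO_FILE_SIZE_IN_MB = 500

def needs_new_file(latest_size_in_mb: float, ep_size_in_mb: float, limit_in_mb: float) -> bool:
    # Start a new file once appending the episode would reach the size limit.
    return latest_size_in_mb + ep_size_in_mb >= limit_in_mb

# With the same 95 MB latest file and a 10 MB episode, a data parquet file
# rolls over while a video file keeps growing.
assert needs_new_file(95.0, 10.0, DEFAULT_DATA_FILE_SIZE_IN_MB)
assert not needs_new_file(95.0, 10.0, DEFAULT_VIDEO_FILE_SIZE_IN_MB)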
@@ -142,6 +143,7 @@ def get_video_size_in_mb(mp4_path: Path):
 
 
 def concat_video_files(paths_to_cat: list[Path], root: Path, video_key: str, chunk_idx: int, file_idx: int):
+    # TODO(rcadene): move to video_utils.py
     # TODO(rcadene): add docstring
     tmp_dir = Path(tempfile.mkdtemp(dir=root))
     # Create a text file with the list of files to concatenate
@@ -175,6 +177,7 @@ def concat_video_files(paths_to_cat: list[Path], root: Path, video_key: str, chunk_idx: int, file_idx: int):
 
 
 def get_video_duration_in_s(mp4_file: Path):
+    # TODO(rcadene): move to video_utils.py
     command = [
         "ffprobe",
         "-v",
@@ -290,7 +293,7 @@ def load_stats(local_dir: Path) -> dict[str, dict[str, np.ndarray]]:
 
 
 def write_hf_dataset(hf_dataset: Dataset, local_dir: Path):
-    if get_hf_dataset_size_in_mb(hf_dataset) > DEFAULT_FILE_SIZE_IN_MB:
+    if get_hf_dataset_size_in_mb(hf_dataset) > DEFAULT_DATA_FILE_SIZE_IN_MB:
         raise NotImplementedError("Contact a maintainer.")
 
     path = local_dir / DEFAULT_DATA_PATH.format(chunk_index=0, file_index=0)
@@ -310,7 +313,7 @@ def load_tasks(local_dir: Path):
 
 
 def write_episodes(episodes: Dataset, local_dir: Path):
-    if get_hf_dataset_size_in_mb(episodes) > DEFAULT_FILE_SIZE_IN_MB:
+    if get_hf_dataset_size_in_mb(episodes) > DEFAULT_DATA_FILE_SIZE_IN_MB:
         raise NotImplementedError("Contact a maintainer.")
 
     fpath = local_dir / DEFAULT_EPISODES_PATH.format(chunk_index=0, file_index=0)
@@ -318,9 +321,13 @@ def write_episodes(episodes: Dataset, local_dir: Path):
     episodes.to_parquet(fpath)
 
 
-def load_episodes(local_dir: Path):
-    hf_dataset = load_nested_dataset(local_dir / EPISODES_DIR)
-    return hf_dataset
+def load_episodes(local_dir: Path) -> datasets.Dataset:
+    episodes = load_nested_dataset(local_dir / EPISODES_DIR)
+    # Select episode features/columns containing references to episode data and videos
+    # (e.g. tasks, dataset_from_index, dataset_to_index, data/chunk_index, data/file_index, etc.)
+    # This is to speedup access to these data, instead of having to load episode stats.
+    episodes = episodes.select_columns([key for key in episodes.features if not key.startswith("stats/")])
+    return episodes
 
 
 def backward_compatible_episodes_stats(
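The select_columns call in load_episodes above keeps only the lightweight reference columns and drops the per-episode stats/... columns, so row lookups like episodes[ep_index] stay cheap. A small sketch of that filtering on a made-up table (column names are illustrative):

from datasets import Dataset

# Made-up episodes table: two reference columns plus a heavy stats column.
episodes = Dataset.from_dict({
    "dataset_from_index": [0, 100],
    "data/chunk_index": [0, 0],
    "stats/observation.state/mean": [[0.1], [0.2]],
})

# Same filtering as load_episodes: keep everything that is not a stats column.
episodes = episodes.select_columns(
    [key for key in episodes.features if not key.startswith("stats/")]
)
print(episodes.column_names)  # ['dataset_from_index', 'data/chunk_index']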
@@ -528,9 +535,9 @@ def create_empty_dataset_info(
         "total_episodes": 0,
         "total_frames": 0,
         "total_tasks": 0,
-        "total_videos": 0,
         "chunks_size": DEFAULT_CHUNK_SIZE,
-        "files_size_in_mb": DEFAULT_FILE_SIZE_IN_MB,
+        "data_files_size_in_mb": DEFAULT_DATA_FILE_SIZE_IN_MB,
+        "video_files_size_in_mb": DEFAULT_VIDEO_FILE_SIZE_IN_MB,
         "fps": fps,
         "splits": {},
         "data_path": DEFAULT_DATA_PATH,
@@ -34,8 +34,9 @@ from lerobot.common.datasets.compute_stats import aggregate_stats
 from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
 from lerobot.common.datasets.utils import (
     DEFAULT_CHUNK_SIZE,
+    DEFAULT_DATA_FILE_SIZE_IN_MB,
     DEFAULT_DATA_PATH,
-    DEFAULT_FILE_SIZE_IN_MB,
+    DEFAULT_VIDEO_FILE_SIZE_IN_MB,
     DEFAULT_VIDEO_PATH,
     cast_stats_to_numpy,
     concat_video_files,
@@ -174,7 +175,7 @@ def convert_data(root, new_root):
         episodes_metadata.append(ep_metadata)
         ep_idx += 1
 
-        if size_in_mb < DEFAULT_FILE_SIZE_IN_MB:
+        if size_in_mb < DEFAULT_DATA_FILE_SIZE_IN_MB:
             paths_to_cat.append(ep_path)
             continue
 
@@ -263,7 +264,7 @@ def convert_videos_of_camera(root: Path, new_root: Path, video_key):
         episodes_metadata.append(ep_metadata)
         ep_idx += 1
 
-        if size_in_mb < DEFAULT_FILE_SIZE_IN_MB:
+        if size_in_mb < DEFAULT_VIDEO_FILE_SIZE_IN_MB:
             paths_to_cat.append(ep_path)
             continue
 
@@ -337,8 +338,8 @@ def convert_info(root, new_root):
     info["codebase_version"] = "v3.0"
     del info["total_chunks"]
     del info["total_videos"]
-    info["files_size_in_mb"] = DEFAULT_FILE_SIZE_IN_MB
-    # TODO(rcadene): chunk- or chunk_ or file- or file_
+    info["data_files_size_in_mb"] = DEFAULT_DATA_FILE_SIZE_IN_MB
+    info["video_files_size_in_mb"] = DEFAULT_VIDEO_FILE_SIZE_IN_MB
     info["data_path"] = DEFAULT_DATA_PATH
     info["video_path"] = DEFAULT_VIDEO_PATH
     info["fps"] = float(info["fps"])
@@ -155,6 +155,7 @@ def decode_video_frames_torchvision(
     )
 
     # get closest frames to the query timestamps
+    # TODO(rcadene): remove torch.stack
     closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
     closest_ts = loaded_ts[argmin_]
 
@@ -28,12 +28,14 @@ from datasets import Dataset
 from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset, LeRobotDatasetMetadata
 from lerobot.common.datasets.utils import (
     DEFAULT_CHUNK_SIZE,
+    DEFAULT_DATA_FILE_SIZE_IN_MB,
     DEFAULT_DATA_PATH,
     DEFAULT_FEATURES,
-    DEFAULT_FILE_SIZE_IN_MB,
+    DEFAULT_VIDEO_FILE_SIZE_IN_MB,
     DEFAULT_VIDEO_PATH,
     flatten_dict,
     get_hf_features_from_features,
+    hf_transform_to_torch,
 )
 from tests.fixtures.constants import (
     DEFAULT_FPS,
@@ -121,7 +123,8 @@ def info_factory(features_factory):
         total_tasks: int = 0,
         total_videos: int = 0,
         chunks_size: int = DEFAULT_CHUNK_SIZE,
-        files_size_in_mb: float = DEFAULT_FILE_SIZE_IN_MB,
+        data_files_size_in_mb: float = DEFAULT_DATA_FILE_SIZE_IN_MB,
+        video_files_size_in_mb: float = DEFAULT_VIDEO_FILE_SIZE_IN_MB,
         data_path: str = DEFAULT_DATA_PATH,
         video_path: str = DEFAULT_VIDEO_PATH,
         motor_features: dict = DUMMY_MOTOR_FEATURES,
@@ -137,7 +140,8 @@ def info_factory(features_factory):
             "total_tasks": total_tasks,
             "total_videos": total_videos,
             "chunks_size": chunks_size,
-            "files_size_in_mb": files_size_in_mb,
+            "data_files_size_in_mb": data_files_size_in_mb,
+            "video_files_size_in_mb": video_files_size_in_mb,
             "fps": fps,
             "splits": {},
             "data_path": data_path,
@@ -352,7 +356,7 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar
             },
             features=hf_features,
         )
-        dataset.set_format("torch")
+        dataset.set_transform(hf_transform_to_torch)
         return dataset
 
     return _create_hf_dataset