Fix convert v30 with image datasets

Remi Cadene 2025-04-24 18:51:53 +02:00
parent 71715c3914
commit 253c649507
1 changed file with 25 additions and 5 deletions


@@ -24,8 +24,9 @@ from typing import Any
 import jsonlines
 import pandas as pd
+import pyarrow as pa
 import tqdm
-from datasets import Dataset
+from datasets import Dataset, Features, Image
 from huggingface_hub import HfApi, snapshot_download
 from requests import HTTPError
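For context on why the new imports matter: the Hugging Face `Image` feature is backed by a fixed Arrow struct type, so declaring a column as `Image()` changes the schema that ends up in the parquet file. A minimal sketch (the column name is invented for illustration):

    from datasets import Features, Image

    features = Features({"observation.image": Image()})
    print(features.arrow_schema)
    # observation.image: struct<bytes: binary, path: string>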
@@ -138,7 +139,7 @@ def convert_tasks(root, new_root):
     write_tasks(df_tasks, new_root)
 
 
-def concat_data_files(paths_to_cat, new_root, chunk_idx, file_idx):
+def concat_data_files(paths_to_cat, new_root, chunk_idx, file_idx, image_keys):
     # TODO(rcadene): to save RAM use Dataset.from_parquet(file) and concatenate_datasets
     dataframes = [pd.read_parquet(file) for file in paths_to_cat]
 
     # Concatenate all DataFrames along rows
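The TODO above points at a lower-RAM alternative: `datasets` memory-maps parquet-backed Arrow tables, so concatenating `Dataset` objects avoids materializing every episode as a pandas DataFrame at once. A sketch of that alternative (not what this commit implements):

    from datasets import Dataset, concatenate_datasets

    # each file is loaded as a memory-mapped Arrow table rather than into RAM
    ds = concatenate_datasets([Dataset.from_parquet(str(p)) for p in paths_to_cat])
    ds.to_parquet(path)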
@@ -146,13 +147,25 @@ def concat_data_files(paths_to_cat, new_root, chunk_idx, file_idx):
     path = new_root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
     path.parent.mkdir(parents=True, exist_ok=True)
-    concatenated_df.to_parquet(path, index=False)
+
+    if len(image_keys) > 0:
+        schema = pa.Schema.from_pandas(concatenated_df)
+        features = Features.from_arrow_schema(schema)
+        for key in image_keys:
+            features[key] = Image()
+        schema = features.arrow_schema
+    else:
+        schema = None
+
+    concatenated_df.to_parquet(path, index=False, schema=schema)
 
 
 def convert_data(root, new_root):
     data_dir = root / "data"
     ep_paths = sorted(data_dir.glob("*/*.parquet"))
+    image_keys = get_image_keys(root)
 
     ep_idx = 0
     chunk_idx = 0
     file_idx = 0
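The core of the fix is in the hunk above: letting pandas infer the parquet schema loses the fact that image columns hold `{bytes, path}` image structs, so downstream readers see plain structs instead of decodable images. Rebuilding the schema through `Features` and re-typing those columns as `Image()` writes the Hugging Face feature types into the parquet metadata (`features.arrow_schema` carries a "huggingface" metadata key). A standalone sketch of the same round trip, with an invented column name and placeholder image bytes:

    import pandas as pd
    import pyarrow as pa
    from datasets import Features, Image

    df = pd.DataFrame({"observation.image": [{"bytes": b"\x89PNG...", "path": "frame_0.png"}]})

    schema = pa.Schema.from_pandas(df)             # plain struct<bytes, path>, no image semantics
    features = Features.from_arrow_schema(schema)  # lift the Arrow types into HF features
    features["observation.image"] = Image()        # re-declare the column as an image
    df.to_parquet("data.parquet", index=False, schema=features.arrow_schema)

With the pyarrow engine, `DataFrame.to_parquet` forwards the `schema` keyword to `pyarrow.Table.from_pandas`, which is why passing `schema=None` in the no-image branch is equivalent to the old behavior.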
@@ -179,7 +192,7 @@ def convert_data(root, new_root):
             paths_to_cat.append(ep_path)
             continue
 
-        concat_data_files(paths_to_cat, new_root, chunk_idx, file_idx)
+        concat_data_files(paths_to_cat, new_root, chunk_idx, file_idx, image_keys)
 
         # Reset for the next file
         size_in_mb = ep_size_in_mb
@@ -190,7 +203,7 @@ def convert_data(root, new_root):
     # Write remaining data if any
     if paths_to_cat:
-        concat_data_files(paths_to_cat, new_root, chunk_idx, file_idx)
+        concat_data_files(paths_to_cat, new_root, chunk_idx, file_idx, image_keys)
 
     return episodes_metadata
@@ -202,6 +215,13 @@ def get_video_keys(root):
     return video_keys
 
 
+def get_image_keys(root):
+    info = load_info(root)
+    features = info["features"]
+    image_keys = [key for key, ft in features.items() if ft["dtype"] == "image"]
+    return image_keys
+
+
 def convert_videos(root: Path, new_root: Path):
     video_keys = get_video_keys(root)
     if len(video_keys) == 0:
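The new `get_image_keys` mirrors the existing `get_video_keys`: both scan the dataset's `info` metadata for features of a given dtype. A sketch of the feature layout this relies on (keys and shapes invented for illustration):

    info = {
        "features": {
            "observation.image": {"dtype": "image", "shape": [3, 480, 640]},
            "observation.state": {"dtype": "float32", "shape": [14]},
            "action": {"dtype": "float32", "shape": [14]},
        }
    }
    image_keys = [key for key, ft in info["features"].items() if ft["dtype"] == "image"]
    assert image_keys == ["observation.image"]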