Fix image writer
This commit is contained in:
parent
df3d2ec5df
commit
51e87f6f97
|
@ -14,11 +14,12 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
from concurrent.futures import ThreadPoolExecutor, wait
|
import queue
|
||||||
|
import threading
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import tqdm
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
DEFAULT_IMAGE_PATH = "{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"
|
DEFAULT_IMAGE_PATH = "{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"
|
||||||
|
@ -39,8 +40,39 @@ def safe_stop_image_writer(func):
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def write_image(image_array: np.ndarray, fpath: Path):
    """Convert ``image_array`` to a PIL image and save it at ``fpath``.

    This is a best-effort write: any exception raised while building or
    saving the image is caught and reported with ``print`` rather than
    propagated to the caller.

    Args:
        image_array: Pixel data handed to ``Image.fromarray``.
        fpath: Destination path handed to ``Image.save``.
    """
    try:
        Image.fromarray(image_array).save(fpath)
    except Exception as e:
        # Deliberately broad: a failed write must not take down the caller.
        print(f"Error writing image {fpath}: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def worker_thread_process(queue: queue.Queue):
    """Consume image-writing jobs from ``queue`` until a ``None`` sentinel arrives.

    Each job is expected to be an ``(image_array, fpath)`` pair, which is
    forwarded to ``write_image``. ``task_done()`` is called exactly once for
    every item taken from the queue (sentinel included), even when the item is
    malformed or processing fails, so that a ``join()`` on the queue cannot
    hang because of a bad job and the worker keeps serving later jobs.

    Args:
        queue: Queue providing ``get()`` and ``task_done()``. Note that this
            parameter name shadows the ``queue`` module inside the function.
    """
    while True:
        item = queue.get()
        try:
            if item is None:
                # Sentinel: stop this worker (the ``finally`` still acknowledges it).
                break
            image_array, fpath = item
            write_image(image_array, fpath)
        except Exception as e:
            # Keep the worker alive; one bad job must not stop the consumer.
            print(f"Error processing image writer queue item: {e}")
        finally:
            queue.task_done()
|
||||||
|
|
||||||
|
|
||||||
|
def worker_process(queue: queue.Queue, num_threads: int):
    """Run ``num_threads`` daemon threads that each execute ``worker_thread_process``.

    All threads share the same ``queue``. The function blocks until every
    thread has returned.

    Args:
        queue: Queue passed unchanged to each ``worker_thread_process`` thread.
        num_threads: Number of threads to start.
    """
    workers = []
    for _ in range(num_threads):
        worker = threading.Thread(target=worker_thread_process, args=(queue,))
        worker.daemon = True
        worker.start()
        workers.append(worker)

    # Block until every worker thread has finished.
    for worker in workers:
        worker.join()
|
||||||
|
|
||||||
|
|
||||||
class ImageWriter:
|
class ImageWriter:
|
||||||
"""This class abstracts away the initialisation of processes and/or threads to
|
"""
|
||||||
|
This class abstracts away the initialisation of processes and/or threads to
|
||||||
save images on disk asynchronously, which is critical to control a robot and record data
|
save images on disk asynchronously, which is critical to control a robot and record data
|
||||||
at a high frame rate.
|
at a high frame rate.
|
||||||
|
|
||||||
|
@ -53,113 +85,66 @@ class ImageWriter:
|
||||||
the number of threads. If it is still not stable, try to use 1 subprocess, or more.
|
the number of threads. If it is still not stable, try to use 1 subprocess, or more.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, write_dir: Path, num_processes: int = 0, num_threads: int = 1, timeout: int = 10):
|
def __init__(self, write_dir: Path, num_processes: int = 0, num_threads: int = 1):
|
||||||
self.dir = write_dir
|
self.write_dir = write_dir
|
||||||
self.dir.mkdir(parents=True, exist_ok=True)
|
self.write_dir.mkdir(parents=True, exist_ok=True)
|
||||||
self.image_path = DEFAULT_IMAGE_PATH
|
self.image_path = DEFAULT_IMAGE_PATH
|
||||||
|
|
||||||
self.num_processes = num_processes
|
self.num_processes = num_processes
|
||||||
self.num_threads = num_threads
|
self.num_threads = num_threads
|
||||||
self.timeout = timeout
|
self.queue = None
|
||||||
|
self.threads = []
|
||||||
|
self.processes = []
|
||||||
|
|
||||||
if self.num_processes == 0 and self.num_threads == 0:
|
if self.num_processes == 0:
|
||||||
self.type = "synchronous"
|
# Use threading
|
||||||
elif self.num_processes == 0 and self.num_threads > 0:
|
self.queue = queue.Queue()
|
||||||
self.type = "threads"
|
for _ in range(self.num_threads):
|
||||||
self.threads = ThreadPoolExecutor(max_workers=self.num_threads)
|
t = threading.Thread(target=worker_thread_process, args=(self.queue,))
|
||||||
self.futures = []
|
t.daemon = True
|
||||||
|
t.start()
|
||||||
|
self.threads.append(t)
|
||||||
else:
|
else:
|
||||||
self.type = "processes"
|
# Use multiprocessing
|
||||||
self.main_event = multiprocessing.Event()
|
self.queue = multiprocessing.JoinableQueue()
|
||||||
self.image_queue = multiprocessing.Queue()
|
|
||||||
self.processes: list[multiprocessing.Process] = []
|
|
||||||
self.events: list[multiprocessing.Event] = []
|
|
||||||
for _ in range(self.num_processes):
|
for _ in range(self.num_processes):
|
||||||
event = multiprocessing.Event()
|
p = multiprocessing.Process(target=worker_process, args=(self.queue, self.num_threads))
|
||||||
process = multiprocessing.Process(target=self._loop_to_save_images_in_threads, args=(event,))
|
p.daemon = True
|
||||||
process.start()
|
p.start()
|
||||||
self.processes.append(process)
|
self.processes.append(p)
|
||||||
self.events.append(event)
|
|
||||||
|
|
||||||
def _loop_to_save_images_in_threads(self, event: multiprocessing.Event) -> None:
|
|
||||||
with ThreadPoolExecutor(max_workers=self.num_threads) as executor:
|
|
||||||
futures = []
|
|
||||||
while True:
|
|
||||||
frame_data = self.image_queue.get()
|
|
||||||
if frame_data is None:
|
|
||||||
self._wait_threads(self.futures, 10)
|
|
||||||
return
|
|
||||||
|
|
||||||
image, file_path = frame_data
|
|
||||||
futures.append(executor.submit(self._save_image, image, file_path))
|
|
||||||
|
|
||||||
if self.main_event.is_set():
|
|
||||||
self._wait_threads(self.futures, 10)
|
|
||||||
event.set()
|
|
||||||
|
|
||||||
def async_save_image(self, image: torch.Tensor, file_path: Path) -> None:
|
|
||||||
"""Save an image asynchronously using threads or processes."""
|
|
||||||
if self.type == "synchronous":
|
|
||||||
self._save_image(image, file_path)
|
|
||||||
elif self.type == "threads":
|
|
||||||
self.futures.append(self.threads.submit(self._save_image, image, file_path))
|
|
||||||
else:
|
|
||||||
self.image_queue.put((image, file_path))
|
|
||||||
|
|
||||||
def _save_image(self, image: torch.Tensor, file_path: Path) -> None:
|
|
||||||
img = Image.fromarray(image.numpy())
|
|
||||||
img.save(str(file_path), quality=100)
|
|
||||||
|
|
||||||
def get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path:
|
def get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path:
|
||||||
fpath = self.image_path.format(
|
fpath = self.image_path.format(
|
||||||
image_key=image_key, episode_index=episode_index, frame_index=frame_index
|
image_key=image_key, episode_index=episode_index, frame_index=frame_index
|
||||||
)
|
)
|
||||||
return self.dir / fpath
|
return self.write_dir / fpath
|
||||||
|
|
||||||
def get_episode_dir(self, episode_index: int, image_key: str) -> Path:
|
def get_episode_dir(self, episode_index: int, image_key: str) -> Path:
|
||||||
return self.get_image_file_path(
|
return self.get_image_file_path(
|
||||||
episode_index=episode_index, image_key=image_key, frame_index=0
|
episode_index=episode_index, image_key=image_key, frame_index=0
|
||||||
).parent
|
).parent
|
||||||
|
|
||||||
def wait(self) -> None:
|
def save_image(self, image_array: torch.Tensor | np.ndarray, fpath: Path):
|
||||||
"""Wait for the thread/processes to finish writing."""
|
if isinstance(image_array, torch.Tensor):
|
||||||
if self.type == "synchronous":
|
image_array = image_array.numpy()
|
||||||
return
|
self.queue.put((image_array, fpath))
|
||||||
elif self.type == "threads":
|
|
||||||
self._wait_threads(self.futures)
|
def wait_until_done(self):
|
||||||
|
self.queue.join()
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
if self.num_processes == 0:
|
||||||
|
# For threading
|
||||||
|
for _ in self.threads:
|
||||||
|
self.queue.put(None)
|
||||||
|
for t in self.threads:
|
||||||
|
t.join()
|
||||||
else:
|
else:
|
||||||
self._wait_processes()
|
# For multiprocessing
|
||||||
|
num_nones = self.num_processes * self.num_threads
|
||||||
def _wait_threads(self, futures) -> None:
|
for _ in range(num_nones):
|
||||||
with tqdm.tqdm(total=len(futures), desc="Writing images") as progress_bar:
|
self.queue.put(None)
|
||||||
wait(futures, timeout=self.timeout)
|
self.queue.close()
|
||||||
progress_bar.update(len(futures))
|
self.queue.join_thread()
|
||||||
|
for p in self.processes:
|
||||||
def _wait_processes(self) -> None:
|
p.join()
|
||||||
self.main_event.set()
|
|
||||||
for event in self.events:
|
|
||||||
event.wait()
|
|
||||||
|
|
||||||
self.main_event.clear()
|
|
||||||
|
|
||||||
def shutdown(self, timeout=20) -> None:
|
|
||||||
"""Stop the image writer, waiting for all processes or threads to finish."""
|
|
||||||
if self.type == "synchronous":
|
|
||||||
return
|
|
||||||
elif self.type == "threads":
|
|
||||||
self.threads.shutdown(wait=True)
|
|
||||||
else:
|
|
||||||
self._stop_processes(timeout)
|
|
||||||
|
|
||||||
def _stop_processes(self, timeout) -> None:
|
|
||||||
for _ in self.processes:
|
|
||||||
self.image_queue.put(None)
|
|
||||||
|
|
||||||
for process in self.processes:
|
|
||||||
process.join(timeout=timeout)
|
|
||||||
|
|
||||||
for process in self.processes:
|
|
||||||
if process.is_alive():
|
|
||||||
process.terminate()
|
|
||||||
|
|
||||||
self.image_queue.close()
|
|
||||||
self.image_queue.join_thread()
|
|
||||||
|
|
|
@ -25,7 +25,6 @@ import datasets
|
||||||
import torch
|
import torch
|
||||||
import torch.utils
|
import torch.utils
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
from datasets.table import embed_table_storage
|
|
||||||
from huggingface_hub import snapshot_download, upload_folder
|
from huggingface_hub import snapshot_download, upload_folder
|
||||||
|
|
||||||
from lerobot.common.datasets.compute_stats import aggregate_stats, compute_stats
|
from lerobot.common.datasets.compute_stats import aggregate_stats, compute_stats
|
||||||
|
@ -51,6 +50,7 @@ from lerobot.common.datasets.utils import (
|
||||||
load_stats,
|
load_stats,
|
||||||
load_tasks,
|
load_tasks,
|
||||||
write_json,
|
write_json,
|
||||||
|
write_parquet,
|
||||||
write_stats,
|
write_stats,
|
||||||
)
|
)
|
||||||
from lerobot.common.datasets.video_utils import (
|
from lerobot.common.datasets.video_utils import (
|
||||||
|
@ -354,7 +354,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
@property
|
@property
|
||||||
def num_samples(self) -> int:
|
def num_samples(self) -> int:
|
||||||
"""Number of samples/frames in selected episodes."""
|
"""Number of samples/frames in selected episodes."""
|
||||||
return len(self.hf_dataset)
|
return len(self.hf_dataset) if self.hf_dataset is not None else self.total_frames
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def num_episodes(self) -> int:
|
def num_episodes(self) -> int:
|
||||||
|
@ -584,9 +584,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
if frame_index == 0:
|
if frame_index == 0:
|
||||||
img_path.parent.mkdir(parents=True, exist_ok=True)
|
img_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
self.image_writer.async_save_image(
|
self.image_writer.save_image(
|
||||||
image=frame[cam_key],
|
image_array=frame[cam_key],
|
||||||
file_path=img_path,
|
fpath=img_path,
|
||||||
)
|
)
|
||||||
|
|
||||||
if cam_key in self.image_keys:
|
if cam_key in self.image_keys:
|
||||||
|
@ -640,14 +640,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
ep_dataset = datasets.Dataset.from_dict(self.episode_buffer, features=self._features, split="train")
|
ep_dataset = datasets.Dataset.from_dict(self.episode_buffer, features=self._features, split="train")
|
||||||
ep_data_path = self.root / self.get_data_file_path(ep_index=episode_index)
|
ep_data_path = self.root / self.get_data_file_path(ep_index=episode_index)
|
||||||
ep_data_path.parent.mkdir(parents=True, exist_ok=True)
|
ep_data_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
write_parquet(ep_dataset, ep_data_path)
|
||||||
# Embed image bytes into the table before saving to parquet
|
|
||||||
format = ep_dataset.format
|
|
||||||
ep_dataset = ep_dataset.with_format("arrow")
|
|
||||||
ep_dataset = ep_dataset.map(embed_table_storage, batched=False)
|
|
||||||
ep_dataset = ep_dataset.with_format(**format)
|
|
||||||
|
|
||||||
ep_dataset.to_parquet(ep_data_path)
|
|
||||||
|
|
||||||
def _save_episode_to_metadata(
|
def _save_episode_to_metadata(
|
||||||
self, episode_index: int, episode_length: int, task: str, task_index: int
|
self, episode_index: int, episode_length: int, task: str, task_index: int
|
||||||
|
@ -709,13 +702,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
remove the image_writer in order for the LeRobotDataset object to be pickleable and parallelized.
|
remove the image_writer in order for the LeRobotDataset object to be pickleable and parallelized.
|
||||||
"""
|
"""
|
||||||
if self.image_writer is not None:
|
if self.image_writer is not None:
|
||||||
self.image_writer.shutdown()
|
self.image_writer.stop()
|
||||||
self.image_writer = None
|
self.image_writer = None
|
||||||
|
|
||||||
def _wait_image_writer(self) -> None:
|
def _wait_image_writer(self) -> None:
|
||||||
"""Wait for asynchronous image writer to finish."""
|
"""Wait for asynchronous image writer to finish."""
|
||||||
if self.image_writer is not None:
|
if self.image_writer is not None:
|
||||||
self.image_writer.wait()
|
self.image_writer.wait_until_done()
|
||||||
|
|
||||||
def encode_videos(self) -> None:
|
def encode_videos(self) -> None:
|
||||||
# Use ffmpeg to convert frames stored as png into mp4 videos
|
# Use ffmpeg to convert frames stored as png into mp4 videos
|
||||||
|
@ -754,7 +747,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
self._write_video_info()
|
self._write_video_info()
|
||||||
|
|
||||||
if not keep_image_files and self.image_writer is not None:
|
if not keep_image_files and self.image_writer is not None:
|
||||||
shutil.rmtree(self.image_writer.dir)
|
shutil.rmtree(self.image_writer.write_dir)
|
||||||
|
|
||||||
if run_compute_stats:
|
if run_compute_stats:
|
||||||
self.stop_image_writer()
|
self.stop_image_writer()
|
||||||
|
|
|
@ -23,6 +23,7 @@ from typing import Any, Dict
|
||||||
import datasets
|
import datasets
|
||||||
import jsonlines
|
import jsonlines
|
||||||
import torch
|
import torch
|
||||||
|
from datasets.table import embed_table_storage
|
||||||
from huggingface_hub import DatasetCard, HfApi
|
from huggingface_hub import DatasetCard, HfApi
|
||||||
from PIL import Image as PILImage
|
from PIL import Image as PILImage
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
|
@ -80,6 +81,15 @@ def unflatten_dict(d: dict, sep: str = "/") -> dict:
|
||||||
return outdict
|
return outdict
|
||||||
|
|
||||||
|
|
||||||
|
def write_parquet(dataset: datasets.Dataset, fpath: Path) -> None:
    """Write ``dataset`` to a parquet file at ``fpath`` with image bytes embedded.

    The dataset is temporarily switched to the ``"arrow"`` format so that
    ``embed_table_storage`` can be mapped over it, then the format captured
    beforehand is re-applied before calling ``to_parquet``.

    Args:
        dataset: Dataset to serialise.
        fpath: Destination parquet file path.
    """
    # Embed image bytes into the table before saving to parquet
    original_format = dataset.format  # named to avoid shadowing the builtin ``format``
    dataset = dataset.with_format("arrow")
    dataset = dataset.map(embed_table_storage, batched=False)
    dataset = dataset.with_format(**original_format)
    dataset.to_parquet(fpath)
|
||||||
|
|
||||||
|
|
||||||
def load_json(fpath: Path) -> Any:
|
def load_json(fpath: Path) -> Any:
|
||||||
with open(fpath) as f:
|
with open(fpath) as f:
|
||||||
return json.load(f)
|
return json.load(f)
|
||||||
|
@ -114,6 +124,25 @@ def write_stats(stats: dict[str, torch.Tensor | dict], fpath: Path) -> None:
|
||||||
write_json(serialized_stats, fpath)
|
write_json(serialized_stats, fpath)
|
||||||
|
|
||||||
|
|
||||||
|
def load_info(local_dir: Path) -> dict:
    """Return the parsed JSON stored at ``INFO_PATH`` under ``local_dir``."""
    info_fpath = local_dir / INFO_PATH
    return load_json(info_fpath)
|
||||||
|
|
||||||
|
|
||||||
|
def load_stats(local_dir: Path) -> dict:
    """Load the JSON at ``STATS_PATH`` under ``local_dir`` with leaves as tensors.

    The loaded structure is flattened with ``flatten_dict``, every value is
    converted with ``torch.tensor``, and the result is rebuilt with
    ``unflatten_dict``.
    """
    flat_stats = flatten_dict(load_json(local_dir / STATS_PATH))
    tensor_stats = {}
    for key, value in flat_stats.items():
        tensor_stats[key] = torch.tensor(value)
    return unflatten_dict(tensor_stats)
|
||||||
|
|
||||||
|
|
||||||
|
def load_tasks(local_dir: Path) -> dict:
    """Return a ``task_index -> task`` mapping read from ``TASKS_PATH`` under ``local_dir``.

    Entries are inserted in ascending ``task_index`` order.
    """
    records = load_jsonlines(local_dir / TASKS_PATH)
    ordered_records = sorted(records, key=lambda x: x["task_index"])
    tasks_by_index = {}
    for record in ordered_records:
        tasks_by_index[record["task_index"]] = record["task"]
    return tasks_by_index
|
||||||
|
|
||||||
|
|
||||||
|
def load_episode_dicts(local_dir: Path) -> dict:
    """Return whatever ``load_jsonlines`` yields for ``EPISODES_PATH`` under ``local_dir``.

    NOTE(review): the return annotation says ``dict``; confirm against the
    actual return type of ``load_jsonlines``.
    """
    episodes_fpath = local_dir / EPISODES_PATH
    return load_jsonlines(episodes_fpath)
|
||||||
|
|
||||||
|
|
||||||
def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
|
def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
|
||||||
"""Get a transform function that convert items from Hugging Face dataset (pyarrow)
|
"""Get a transform function that convert items from Hugging Face dataset (pyarrow)
|
||||||
to torch tensors. Importantly, images are converted from PIL, which corresponds to
|
to torch tensors. Importantly, images are converted from PIL, which corresponds to
|
||||||
|
@ -185,25 +214,6 @@ def get_hub_safe_version(repo_id: str, version: str, enforce_v2: bool = True) ->
|
||||||
return version
|
return version
|
||||||
|
|
||||||
|
|
||||||
def load_info(local_dir: Path) -> dict:
|
|
||||||
return load_json(local_dir / INFO_PATH)
|
|
||||||
|
|
||||||
|
|
||||||
def load_stats(local_dir: Path) -> dict:
|
|
||||||
stats = load_json(local_dir / STATS_PATH)
|
|
||||||
stats = {key: torch.tensor(value) for key, value in flatten_dict(stats).items()}
|
|
||||||
return unflatten_dict(stats)
|
|
||||||
|
|
||||||
|
|
||||||
def load_tasks(local_dir: Path) -> dict:
|
|
||||||
tasks = load_jsonlines(local_dir / TASKS_PATH)
|
|
||||||
return {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
|
|
||||||
|
|
||||||
|
|
||||||
def load_episode_dicts(local_dir: Path) -> dict:
|
|
||||||
return load_jsonlines(local_dir / EPISODES_PATH)
|
|
||||||
|
|
||||||
|
|
||||||
def _get_info_from_robot(robot: Robot, use_videos: bool) -> tuple[list | dict]:
|
def _get_info_from_robot(robot: Robot, use_videos: bool) -> tuple[list | dict]:
|
||||||
shapes = {key: len(names) for key, names in robot.names.items()}
|
shapes = {key: len(names) for key, names in robot.names.items()}
|
||||||
camera_shapes = {}
|
camera_shapes = {}
|
||||||
|
|
Loading…
Reference in New Issue