Fix image writer
This commit is contained in:
parent
df3d2ec5df
commit
51e87f6f97
lerobot/common/datasets
|
@ -14,11 +14,12 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import multiprocessing
|
||||
from concurrent.futures import ThreadPoolExecutor, wait
|
||||
import queue
|
||||
import threading
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import tqdm
|
||||
from PIL import Image
|
||||
|
||||
DEFAULT_IMAGE_PATH = "{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"
|
||||
|
@ -39,8 +40,39 @@ def safe_stop_image_writer(func):
|
|||
return wrapper
|
||||
|
||||
|
||||
def write_image(image_array: np.ndarray, fpath: Path):
|
||||
try:
|
||||
image = Image.fromarray(image_array)
|
||||
image.save(fpath)
|
||||
except Exception as e:
|
||||
print(f"Error writing image {fpath}: {e}")
|
||||
|
||||
|
||||
def worker_thread_process(queue: queue.Queue):
|
||||
while True:
|
||||
item = queue.get()
|
||||
if item is None:
|
||||
queue.task_done()
|
||||
break
|
||||
image_array, fpath = item
|
||||
write_image(image_array, fpath)
|
||||
queue.task_done()
|
||||
|
||||
|
||||
def worker_process(queue: queue.Queue, num_threads: int):
|
||||
threads = []
|
||||
for _ in range(num_threads):
|
||||
t = threading.Thread(target=worker_thread_process, args=(queue,))
|
||||
t.daemon = True
|
||||
t.start()
|
||||
threads.append(t)
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
|
||||
class ImageWriter:
|
||||
"""This class abstract away the initialisation of processes or/and threads to
|
||||
"""
|
||||
This class abstract away the initialisation of processes or/and threads to
|
||||
save images on disk asynchrounously, which is critical to control a robot and record data
|
||||
at a high frame rate.
|
||||
|
||||
|
@ -53,113 +85,66 @@ class ImageWriter:
|
|||
the number of threads. If it is still not stable, try to use 1 subprocess, or more.
|
||||
"""
|
||||
|
||||
def __init__(self, write_dir: Path, num_processes: int = 0, num_threads: int = 1, timeout: int = 10):
|
||||
self.dir = write_dir
|
||||
self.dir.mkdir(parents=True, exist_ok=True)
|
||||
def __init__(self, write_dir: Path, num_processes: int = 0, num_threads: int = 1):
|
||||
self.write_dir = write_dir
|
||||
self.write_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.image_path = DEFAULT_IMAGE_PATH
|
||||
|
||||
self.num_processes = num_processes
|
||||
self.num_threads = num_threads
|
||||
self.timeout = timeout
|
||||
self.queue = None
|
||||
self.threads = []
|
||||
self.processes = []
|
||||
|
||||
if self.num_processes == 0 and self.num_threads == 0:
|
||||
self.type = "synchronous"
|
||||
elif self.num_processes == 0 and self.num_threads > 0:
|
||||
self.type = "threads"
|
||||
self.threads = ThreadPoolExecutor(max_workers=self.num_threads)
|
||||
self.futures = []
|
||||
if self.num_processes == 0:
|
||||
# Use threading
|
||||
self.queue = queue.Queue()
|
||||
for _ in range(self.num_threads):
|
||||
t = threading.Thread(target=worker_thread_process, args=(self.queue,))
|
||||
t.daemon = True
|
||||
t.start()
|
||||
self.threads.append(t)
|
||||
else:
|
||||
self.type = "processes"
|
||||
self.main_event = multiprocessing.Event()
|
||||
self.image_queue = multiprocessing.Queue()
|
||||
self.processes: list[multiprocessing.Process] = []
|
||||
self.events: list[multiprocessing.Event] = []
|
||||
# Use multiprocessing
|
||||
self.queue = multiprocessing.JoinableQueue()
|
||||
for _ in range(self.num_processes):
|
||||
event = multiprocessing.Event()
|
||||
process = multiprocessing.Process(target=self._loop_to_save_images_in_threads, args=(event,))
|
||||
process.start()
|
||||
self.processes.append(process)
|
||||
self.events.append(event)
|
||||
|
||||
def _loop_to_save_images_in_threads(self, event: multiprocessing.Event) -> None:
|
||||
with ThreadPoolExecutor(max_workers=self.num_threads) as executor:
|
||||
futures = []
|
||||
while True:
|
||||
frame_data = self.image_queue.get()
|
||||
if frame_data is None:
|
||||
self._wait_threads(self.futures, 10)
|
||||
return
|
||||
|
||||
image, file_path = frame_data
|
||||
futures.append(executor.submit(self._save_image, image, file_path))
|
||||
|
||||
if self.main_event.is_set():
|
||||
self._wait_threads(self.futures, 10)
|
||||
event.set()
|
||||
|
||||
def async_save_image(self, image: torch.Tensor, file_path: Path) -> None:
|
||||
"""Save an image asynchronously using threads or processes."""
|
||||
if self.type == "synchronous":
|
||||
self._save_image(image, file_path)
|
||||
elif self.type == "threads":
|
||||
self.futures.append(self.threads.submit(self._save_image, image, file_path))
|
||||
else:
|
||||
self.image_queue.put((image, file_path))
|
||||
|
||||
def _save_image(self, image: torch.Tensor, file_path: Path) -> None:
|
||||
img = Image.fromarray(image.numpy())
|
||||
img.save(str(file_path), quality=100)
|
||||
p = multiprocessing.Process(target=worker_process, args=(self.queue, self.num_threads))
|
||||
p.daemon = True
|
||||
p.start()
|
||||
self.processes.append(p)
|
||||
|
||||
def get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path:
|
||||
fpath = self.image_path.format(
|
||||
image_key=image_key, episode_index=episode_index, frame_index=frame_index
|
||||
)
|
||||
return self.dir / fpath
|
||||
return self.write_dir / fpath
|
||||
|
||||
def get_episode_dir(self, episode_index: int, image_key: str) -> Path:
|
||||
return self.get_image_file_path(
|
||||
episode_index=episode_index, image_key=image_key, frame_index=0
|
||||
).parent
|
||||
|
||||
def wait(self) -> None:
|
||||
"""Wait for the thread/processes to finish writing."""
|
||||
if self.type == "synchronous":
|
||||
return
|
||||
elif self.type == "threads":
|
||||
self._wait_threads(self.futures)
|
||||
def save_image(self, image_array: torch.Tensor | np.ndarray, fpath: Path):
|
||||
if isinstance(image_array, torch.Tensor):
|
||||
image_array = image_array.numpy()
|
||||
self.queue.put((image_array, fpath))
|
||||
|
||||
def wait_until_done(self):
|
||||
self.queue.join()
|
||||
|
||||
def stop(self):
|
||||
if self.num_processes == 0:
|
||||
# For threading
|
||||
for _ in self.threads:
|
||||
self.queue.put(None)
|
||||
for t in self.threads:
|
||||
t.join()
|
||||
else:
|
||||
self._wait_processes()
|
||||
|
||||
def _wait_threads(self, futures) -> None:
|
||||
with tqdm.tqdm(total=len(futures), desc="Writing images") as progress_bar:
|
||||
wait(futures, timeout=self.timeout)
|
||||
progress_bar.update(len(futures))
|
||||
|
||||
def _wait_processes(self) -> None:
|
||||
self.main_event.set()
|
||||
for event in self.events:
|
||||
event.wait()
|
||||
|
||||
self.main_event.clear()
|
||||
|
||||
def shutdown(self, timeout=20) -> None:
|
||||
"""Stop the image writer, waiting for all processes or threads to finish."""
|
||||
if self.type == "synchronous":
|
||||
return
|
||||
elif self.type == "threads":
|
||||
self.threads.shutdown(wait=True)
|
||||
else:
|
||||
self._stop_processes(timeout)
|
||||
|
||||
def _stop_processes(self, timeout) -> None:
|
||||
for _ in self.processes:
|
||||
self.image_queue.put(None)
|
||||
|
||||
for process in self.processes:
|
||||
process.join(timeout=timeout)
|
||||
|
||||
for process in self.processes:
|
||||
if process.is_alive():
|
||||
process.terminate()
|
||||
|
||||
self.image_queue.close()
|
||||
self.image_queue.join_thread()
|
||||
# For multiprocessing
|
||||
num_nones = self.num_processes * self.num_threads
|
||||
for _ in range(num_nones):
|
||||
self.queue.put(None)
|
||||
self.queue.close()
|
||||
self.queue.join_thread()
|
||||
for p in self.processes:
|
||||
p.join()
|
||||
|
|
|
@ -25,7 +25,6 @@ import datasets
|
|||
import torch
|
||||
import torch.utils
|
||||
from datasets import load_dataset
|
||||
from datasets.table import embed_table_storage
|
||||
from huggingface_hub import snapshot_download, upload_folder
|
||||
|
||||
from lerobot.common.datasets.compute_stats import aggregate_stats, compute_stats
|
||||
|
@ -51,6 +50,7 @@ from lerobot.common.datasets.utils import (
|
|||
load_stats,
|
||||
load_tasks,
|
||||
write_json,
|
||||
write_parquet,
|
||||
write_stats,
|
||||
)
|
||||
from lerobot.common.datasets.video_utils import (
|
||||
|
@ -354,7 +354,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
@property
|
||||
def num_samples(self) -> int:
|
||||
"""Number of samples/frames in selected episodes."""
|
||||
return len(self.hf_dataset)
|
||||
return len(self.hf_dataset) if self.hf_dataset is not None else self.total_frames
|
||||
|
||||
@property
|
||||
def num_episodes(self) -> int:
|
||||
|
@ -584,9 +584,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
if frame_index == 0:
|
||||
img_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.image_writer.async_save_image(
|
||||
image=frame[cam_key],
|
||||
file_path=img_path,
|
||||
self.image_writer.save_image(
|
||||
image_array=frame[cam_key],
|
||||
fpath=img_path,
|
||||
)
|
||||
|
||||
if cam_key in self.image_keys:
|
||||
|
@ -640,14 +640,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
ep_dataset = datasets.Dataset.from_dict(self.episode_buffer, features=self._features, split="train")
|
||||
ep_data_path = self.root / self.get_data_file_path(ep_index=episode_index)
|
||||
ep_data_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Embed image bytes into the table before saving to parquet
|
||||
format = ep_dataset.format
|
||||
ep_dataset = ep_dataset.with_format("arrow")
|
||||
ep_dataset = ep_dataset.map(embed_table_storage, batched=False)
|
||||
ep_dataset = ep_dataset.with_format(**format)
|
||||
|
||||
ep_dataset.to_parquet(ep_data_path)
|
||||
write_parquet(ep_dataset, ep_data_path)
|
||||
|
||||
def _save_episode_to_metadata(
|
||||
self, episode_index: int, episode_length: int, task: str, task_index: int
|
||||
|
@ -709,13 +702,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
remove the image_write in order for the LeRobotDataset object to be pickleable and parallelized.
|
||||
"""
|
||||
if self.image_writer is not None:
|
||||
self.image_writer.shutdown()
|
||||
self.image_writer.stop()
|
||||
self.image_writer = None
|
||||
|
||||
def _wait_image_writer(self) -> None:
|
||||
"""Wait for asynchronous image writer to finish."""
|
||||
if self.image_writer is not None:
|
||||
self.image_writer.wait()
|
||||
self.image_writer.wait_until_done()
|
||||
|
||||
def encode_videos(self) -> None:
|
||||
# Use ffmpeg to convert frames stored as png into mp4 videos
|
||||
|
@ -754,7 +747,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
self._write_video_info()
|
||||
|
||||
if not keep_image_files and self.image_writer is not None:
|
||||
shutil.rmtree(self.image_writer.dir)
|
||||
shutil.rmtree(self.image_writer.write_dir)
|
||||
|
||||
if run_compute_stats:
|
||||
self.stop_image_writer()
|
||||
|
|
|
@ -23,6 +23,7 @@ from typing import Any, Dict
|
|||
import datasets
|
||||
import jsonlines
|
||||
import torch
|
||||
from datasets.table import embed_table_storage
|
||||
from huggingface_hub import DatasetCard, HfApi
|
||||
from PIL import Image as PILImage
|
||||
from torchvision import transforms
|
||||
|
@ -80,6 +81,15 @@ def unflatten_dict(d: dict, sep: str = "/") -> dict:
|
|||
return outdict
|
||||
|
||||
|
||||
def write_parquet(dataset: datasets.Dataset, fpath: Path) -> None:
|
||||
# Embed image bytes into the table before saving to parquet
|
||||
format = dataset.format
|
||||
dataset = dataset.with_format("arrow")
|
||||
dataset = dataset.map(embed_table_storage, batched=False)
|
||||
dataset = dataset.with_format(**format)
|
||||
dataset.to_parquet(fpath)
|
||||
|
||||
|
||||
def load_json(fpath: Path) -> Any:
|
||||
with open(fpath) as f:
|
||||
return json.load(f)
|
||||
|
@ -114,6 +124,25 @@ def write_stats(stats: dict[str, torch.Tensor | dict], fpath: Path) -> None:
|
|||
write_json(serialized_stats, fpath)
|
||||
|
||||
|
||||
def load_info(local_dir: Path) -> dict:
|
||||
return load_json(local_dir / INFO_PATH)
|
||||
|
||||
|
||||
def load_stats(local_dir: Path) -> dict:
|
||||
stats = load_json(local_dir / STATS_PATH)
|
||||
stats = {key: torch.tensor(value) for key, value in flatten_dict(stats).items()}
|
||||
return unflatten_dict(stats)
|
||||
|
||||
|
||||
def load_tasks(local_dir: Path) -> dict:
|
||||
tasks = load_jsonlines(local_dir / TASKS_PATH)
|
||||
return {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
|
||||
|
||||
|
||||
def load_episode_dicts(local_dir: Path) -> dict:
|
||||
return load_jsonlines(local_dir / EPISODES_PATH)
|
||||
|
||||
|
||||
def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
|
||||
"""Get a transform function that convert items from Hugging Face dataset (pyarrow)
|
||||
to torch tensors. Importantly, images are converted from PIL, which corresponds to
|
||||
|
@ -185,25 +214,6 @@ def get_hub_safe_version(repo_id: str, version: str, enforce_v2: bool = True) ->
|
|||
return version
|
||||
|
||||
|
||||
def load_info(local_dir: Path) -> dict:
|
||||
return load_json(local_dir / INFO_PATH)
|
||||
|
||||
|
||||
def load_stats(local_dir: Path) -> dict:
|
||||
stats = load_json(local_dir / STATS_PATH)
|
||||
stats = {key: torch.tensor(value) for key, value in flatten_dict(stats).items()}
|
||||
return unflatten_dict(stats)
|
||||
|
||||
|
||||
def load_tasks(local_dir: Path) -> dict:
|
||||
tasks = load_jsonlines(local_dir / TASKS_PATH)
|
||||
return {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
|
||||
|
||||
|
||||
def load_episode_dicts(local_dir: Path) -> dict:
|
||||
return load_jsonlines(local_dir / EPISODES_PATH)
|
||||
|
||||
|
||||
def _get_info_from_robot(robot: Robot, use_videos: bool) -> tuple[list | dict]:
|
||||
shapes = {key: len(names) for key, names in robot.names.items()}
|
||||
camera_shapes = {}
|
||||
|
|
Loading…
Reference in New Issue