feat(visualization): implement rerun
This commit is contained in:
parent
2c22f7d76d
commit
a7d7f1258d
|
@ -17,12 +17,19 @@
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import datetime as dt
|
import datetime as dt
|
||||||
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
|
import keyboard
|
||||||
|
import rerun as rr
|
||||||
|
|
||||||
|
|
||||||
def display_and_save_video_stream(output_dir: Path, fps: int, width: int, height: int):
|
def display_and_save_video_stream(output_dir: Path, fps: int, width: int, height: int):
|
||||||
|
rr.init("lerobot_capture_camera_feed")
|
||||||
|
memory_limit = os.getenv("LEROBOT_RERUN_MEMORY_LIMIT", "5%")
|
||||||
|
rr.spawn(memory_limit=memory_limit)
|
||||||
|
|
||||||
now = dt.datetime.now()
|
now = dt.datetime.now()
|
||||||
capture_dir = output_dir / f"{now:%Y-%m-%d}" / f"{now:%H-%M-%S}"
|
capture_dir = output_dir / f"{now:%Y-%m-%d}" / f"{now:%H-%M-%S}"
|
||||||
if not capture_dir.exists():
|
if not capture_dir.exists():
|
||||||
|
@ -45,18 +52,18 @@ def display_and_save_video_stream(output_dir: Path, fps: int, width: int, height
|
||||||
if not ret:
|
if not ret:
|
||||||
print("Error: Could not read frame.")
|
print("Error: Could not read frame.")
|
||||||
break
|
break
|
||||||
|
rr.log("video/stream", rr.Image(frame.numpy()), static=True)
|
||||||
cv2.imshow("Video Stream", frame)
|
|
||||||
cv2.imwrite(str(capture_dir / f"frame_{frame_index:06d}.png"), frame)
|
cv2.imwrite(str(capture_dir / f"frame_{frame_index:06d}.png"), frame)
|
||||||
frame_index += 1
|
frame_index += 1
|
||||||
|
|
||||||
# Break the loop on 'q' key press
|
# Break the loop on 'q' key press
|
||||||
if cv2.waitKey(1) & 0xFF == ord("q"):
|
if keyboard.is_pressed("q"):
|
||||||
break
|
break
|
||||||
|
|
||||||
# Release the capture and destroy all windows
|
# Release the capture and destroy all windows
|
||||||
cap.release()
|
cap.release()
|
||||||
cv2.destroyAllWindows()
|
# TODO(Steven): Find a way to close visualizer: https://github.com/rerun-io/rerun/pull/9400
|
||||||
|
# cv2.destroyAllWindows()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
@ -41,7 +41,7 @@ class TeleoperateControlConfig(ControlConfig):
|
||||||
fps: int | None = None
|
fps: int | None = None
|
||||||
teleop_time_s: float | None = None
|
teleop_time_s: float | None = None
|
||||||
# Display all cameras on screen
|
# Display all cameras on screen
|
||||||
display_cameras: bool = True
|
display_data: bool = False
|
||||||
|
|
||||||
|
|
||||||
@ControlConfig.register_subclass("record")
|
@ControlConfig.register_subclass("record")
|
||||||
|
@ -82,7 +82,7 @@ class RecordControlConfig(ControlConfig):
|
||||||
# Not enough threads might cause low camera fps.
|
# Not enough threads might cause low camera fps.
|
||||||
num_image_writer_threads_per_camera: int = 4
|
num_image_writer_threads_per_camera: int = 4
|
||||||
# Display all cameras on screen
|
# Display all cameras on screen
|
||||||
display_cameras: bool = True
|
display_data: bool = False
|
||||||
# Use vocal synthesis to read events.
|
# Use vocal synthesis to read events.
|
||||||
play_sounds: bool = True
|
play_sounds: bool = True
|
||||||
# Resume recording on an existing dataset.
|
# Resume recording on an existing dataset.
|
||||||
|
|
|
@ -24,7 +24,7 @@ from contextlib import nullcontext
|
||||||
from copy import copy
|
from copy import copy
|
||||||
from functools import cache
|
from functools import cache
|
||||||
|
|
||||||
import cv2
|
import rerun as rr
|
||||||
import torch
|
import torch
|
||||||
from deepdiff import DeepDiff
|
from deepdiff import DeepDiff
|
||||||
from termcolor import colored
|
from termcolor import colored
|
||||||
|
@ -174,13 +174,13 @@ def warmup_record(
|
||||||
events,
|
events,
|
||||||
enable_teleoperation,
|
enable_teleoperation,
|
||||||
warmup_time_s,
|
warmup_time_s,
|
||||||
display_cameras,
|
display_data,
|
||||||
fps,
|
fps,
|
||||||
):
|
):
|
||||||
control_loop(
|
control_loop(
|
||||||
robot=robot,
|
robot=robot,
|
||||||
control_time_s=warmup_time_s,
|
control_time_s=warmup_time_s,
|
||||||
display_cameras=display_cameras,
|
display_data=display_data,
|
||||||
events=events,
|
events=events,
|
||||||
fps=fps,
|
fps=fps,
|
||||||
teleoperate=enable_teleoperation,
|
teleoperate=enable_teleoperation,
|
||||||
|
@ -192,7 +192,7 @@ def record_episode(
|
||||||
dataset,
|
dataset,
|
||||||
events,
|
events,
|
||||||
episode_time_s,
|
episode_time_s,
|
||||||
display_cameras,
|
display_data,
|
||||||
policy,
|
policy,
|
||||||
fps,
|
fps,
|
||||||
single_task,
|
single_task,
|
||||||
|
@ -200,7 +200,7 @@ def record_episode(
|
||||||
control_loop(
|
control_loop(
|
||||||
robot=robot,
|
robot=robot,
|
||||||
control_time_s=episode_time_s,
|
control_time_s=episode_time_s,
|
||||||
display_cameras=display_cameras,
|
display_data=display_data,
|
||||||
dataset=dataset,
|
dataset=dataset,
|
||||||
events=events,
|
events=events,
|
||||||
policy=policy,
|
policy=policy,
|
||||||
|
@ -215,7 +215,7 @@ def control_loop(
|
||||||
robot,
|
robot,
|
||||||
control_time_s=None,
|
control_time_s=None,
|
||||||
teleoperate=False,
|
teleoperate=False,
|
||||||
display_cameras=False,
|
display_data=False,
|
||||||
dataset: LeRobotDataset | None = None,
|
dataset: LeRobotDataset | None = None,
|
||||||
events=None,
|
events=None,
|
||||||
policy: PreTrainedPolicy = None,
|
policy: PreTrainedPolicy = None,
|
||||||
|
@ -264,11 +264,14 @@ def control_loop(
|
||||||
frame = {**observation, **action, "task": single_task}
|
frame = {**observation, **action, "task": single_task}
|
||||||
dataset.add_frame(frame)
|
dataset.add_frame(frame)
|
||||||
|
|
||||||
if display_cameras and not is_headless():
|
if display_data and not is_headless():
|
||||||
|
for k, v in action.items():
|
||||||
|
for i, vv in enumerate(v):
|
||||||
|
rr.log(f"sent_{k}_{i}", rr.Scalar(vv.numpy()))
|
||||||
|
|
||||||
image_keys = [key for key in observation if "image" in key]
|
image_keys = [key for key in observation if "image" in key]
|
||||||
for key in image_keys:
|
for key in image_keys:
|
||||||
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
|
rr.log(key, rr.Image(observation[key].numpy()), static=True)
|
||||||
cv2.waitKey(1)
|
|
||||||
|
|
||||||
if fps is not None:
|
if fps is not None:
|
||||||
dt_s = time.perf_counter() - start_loop_t
|
dt_s = time.perf_counter() - start_loop_t
|
||||||
|
@ -297,15 +300,15 @@ def reset_environment(robot, events, reset_time_s, fps):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def stop_recording(robot, listener, display_cameras):
|
def stop_recording(robot, listener, display_data):
|
||||||
robot.disconnect()
|
robot.disconnect()
|
||||||
|
|
||||||
if not is_headless():
|
if not is_headless() and listener is not None:
|
||||||
if listener is not None:
|
|
||||||
listener.stop()
|
listener.stop()
|
||||||
|
|
||||||
if display_cameras:
|
# TODO(Steven): Find a way to close visualizer: https://github.com/rerun-io/rerun/pull/9400
|
||||||
cv2.destroyAllWindows()
|
# if display_data:
|
||||||
|
# cv2.destroyAllWindows()
|
||||||
|
|
||||||
|
|
||||||
def sanity_check_dataset_name(repo_id, policy_cfg):
|
def sanity_check_dataset_name(repo_id, policy_cfg):
|
||||||
|
|
|
@ -443,7 +443,7 @@ class So100RobotConfig(ManipulatorRobotConfig):
|
||||||
leader_arms: dict[str, MotorsBusConfig] = field(
|
leader_arms: dict[str, MotorsBusConfig] = field(
|
||||||
default_factory=lambda: {
|
default_factory=lambda: {
|
||||||
"main": FeetechMotorsBusConfig(
|
"main": FeetechMotorsBusConfig(
|
||||||
port="/dev/tty.usbmodem58760431091",
|
port="/dev/tty.usbmodem58760434171",
|
||||||
motors={
|
motors={
|
||||||
# name: (index, model)
|
# name: (index, model)
|
||||||
"shoulder_pan": [1, "sts3215"],
|
"shoulder_pan": [1, "sts3215"],
|
||||||
|
@ -460,7 +460,7 @@ class So100RobotConfig(ManipulatorRobotConfig):
|
||||||
follower_arms: dict[str, MotorsBusConfig] = field(
|
follower_arms: dict[str, MotorsBusConfig] = field(
|
||||||
default_factory=lambda: {
|
default_factory=lambda: {
|
||||||
"main": FeetechMotorsBusConfig(
|
"main": FeetechMotorsBusConfig(
|
||||||
port="/dev/tty.usbmodem585A0076891",
|
port="/dev/tty.usbmodem58760430031",
|
||||||
motors={
|
motors={
|
||||||
# name: (index, model)
|
# name: (index, model)
|
||||||
"shoulder_pan": [1, "sts3215"],
|
"shoulder_pan": [1, "sts3215"],
|
||||||
|
@ -476,17 +476,11 @@ class So100RobotConfig(ManipulatorRobotConfig):
|
||||||
|
|
||||||
cameras: dict[str, CameraConfig] = field(
|
cameras: dict[str, CameraConfig] = field(
|
||||||
default_factory=lambda: {
|
default_factory=lambda: {
|
||||||
"laptop": OpenCVCameraConfig(
|
"cam": OpenCVCameraConfig(
|
||||||
camera_index=0,
|
camera_index=0,
|
||||||
fps=30,
|
fps=30,
|
||||||
width=640,
|
width=1280,
|
||||||
height=480,
|
height=720,
|
||||||
),
|
|
||||||
"phone": OpenCVCameraConfig(
|
|
||||||
camera_index=1,
|
|
||||||
fps=30,
|
|
||||||
width=640,
|
|
||||||
height=480,
|
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
|
@ -135,15 +135,19 @@ python lerobot/scripts/control_robot.py \
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
import time
|
import time
|
||||||
from dataclasses import asdict
|
from dataclasses import asdict
|
||||||
from pprint import pformat
|
from pprint import pformat
|
||||||
|
|
||||||
|
import rerun as rr
|
||||||
|
|
||||||
# from safetensors.torch import load_file, save_file
|
# from safetensors.torch import load_file, save_file
|
||||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||||
from lerobot.common.policies.factory import make_policy
|
from lerobot.common.policies.factory import make_policy
|
||||||
from lerobot.common.robot_devices.control_configs import (
|
from lerobot.common.robot_devices.control_configs import (
|
||||||
CalibrateControlConfig,
|
CalibrateControlConfig,
|
||||||
|
ControlConfig,
|
||||||
ControlPipelineConfig,
|
ControlPipelineConfig,
|
||||||
RecordControlConfig,
|
RecordControlConfig,
|
||||||
RemoteRobotConfig,
|
RemoteRobotConfig,
|
||||||
|
@ -153,6 +157,7 @@ from lerobot.common.robot_devices.control_configs import (
|
||||||
from lerobot.common.robot_devices.control_utils import (
|
from lerobot.common.robot_devices.control_utils import (
|
||||||
control_loop,
|
control_loop,
|
||||||
init_keyboard_listener,
|
init_keyboard_listener,
|
||||||
|
is_headless,
|
||||||
log_control_info,
|
log_control_info,
|
||||||
record_episode,
|
record_episode,
|
||||||
reset_environment,
|
reset_environment,
|
||||||
|
@ -232,7 +237,7 @@ def teleoperate(robot: Robot, cfg: TeleoperateControlConfig):
|
||||||
control_time_s=cfg.teleop_time_s,
|
control_time_s=cfg.teleop_time_s,
|
||||||
fps=cfg.fps,
|
fps=cfg.fps,
|
||||||
teleoperate=True,
|
teleoperate=True,
|
||||||
display_cameras=cfg.display_cameras,
|
display_data=cfg.display_data,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -280,7 +285,7 @@ def record(
|
||||||
# 3. place the cameras windows on screen
|
# 3. place the cameras windows on screen
|
||||||
enable_teleoperation = policy is None
|
enable_teleoperation = policy is None
|
||||||
log_say("Warmup record", cfg.play_sounds)
|
log_say("Warmup record", cfg.play_sounds)
|
||||||
warmup_record(robot, events, enable_teleoperation, cfg.warmup_time_s, cfg.display_cameras, cfg.fps)
|
warmup_record(robot, events, enable_teleoperation, cfg.warmup_time_s, cfg.display_data, cfg.fps)
|
||||||
|
|
||||||
if has_method(robot, "teleop_safety_stop"):
|
if has_method(robot, "teleop_safety_stop"):
|
||||||
robot.teleop_safety_stop()
|
robot.teleop_safety_stop()
|
||||||
|
@ -296,7 +301,7 @@ def record(
|
||||||
dataset=dataset,
|
dataset=dataset,
|
||||||
events=events,
|
events=events,
|
||||||
episode_time_s=cfg.episode_time_s,
|
episode_time_s=cfg.episode_time_s,
|
||||||
display_cameras=cfg.display_cameras,
|
display_data=cfg.display_data,
|
||||||
policy=policy,
|
policy=policy,
|
||||||
fps=cfg.fps,
|
fps=cfg.fps,
|
||||||
single_task=cfg.single_task,
|
single_task=cfg.single_task,
|
||||||
|
@ -326,7 +331,7 @@ def record(
|
||||||
break
|
break
|
||||||
|
|
||||||
log_say("Stop recording", cfg.play_sounds, blocking=True)
|
log_say("Stop recording", cfg.play_sounds, blocking=True)
|
||||||
stop_recording(robot, listener, cfg.display_cameras)
|
stop_recording(robot, listener, cfg.display_data)
|
||||||
|
|
||||||
if cfg.push_to_hub:
|
if cfg.push_to_hub:
|
||||||
dataset.push_to_hub(tags=cfg.tags, private=cfg.private)
|
dataset.push_to_hub(tags=cfg.tags, private=cfg.private)
|
||||||
|
@ -363,6 +368,40 @@ def replay(
|
||||||
log_control_info(robot, dt_s, fps=cfg.fps)
|
log_control_info(robot, dt_s, fps=cfg.fps)
|
||||||
|
|
||||||
|
|
||||||
|
def _init_rerun(control_config: ControlConfig, session_name: str = "lerobot_control_loop") -> None:
|
||||||
|
"""Initializes the Rerun SDK for visualizing the control loop.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
control_config: Configuration determining data display and robot type.
|
||||||
|
session_name: Rerun session name. Defaults to "lerobot_control_loop".
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If viewer IP is missing for non-remote configurations with display enabled.
|
||||||
|
"""
|
||||||
|
if control_config.display_data and not is_headless():
|
||||||
|
# Configure Rerun flush batch size
|
||||||
|
batch_size = os.getenv("RERUN_FLUSH_NUM_BYTES", "8000")
|
||||||
|
os.environ["RERUN_FLUSH_NUM_BYTES"] = batch_size
|
||||||
|
|
||||||
|
# Get memory limit and viewer connection parameters
|
||||||
|
memory_limit = os.getenv("LEROBOT_RERUN_MEMORY_LIMIT", "5%")
|
||||||
|
viewer_ip = os.getenv("LEROBOT_VIEWER_IP")
|
||||||
|
viewer_port = os.getenv("LEROBOT_VIEWER_PORT", "9876")
|
||||||
|
|
||||||
|
# Initialize Rerun based on configuration
|
||||||
|
rr.init(session_name)
|
||||||
|
if isinstance(control_config, RemoteRobotConfig):
|
||||||
|
if not viewer_ip:
|
||||||
|
raise ValueError(
|
||||||
|
"Viewer IP required for remote config. Set LEROBOT_VIEWER_IP "
|
||||||
|
"or disable control_config.display_data."
|
||||||
|
)
|
||||||
|
logging.info(f"Connecting to viewer at {viewer_ip}:{viewer_port}")
|
||||||
|
rr.connect_tcp(f"{viewer_ip}:{viewer_port}")
|
||||||
|
else:
|
||||||
|
rr.spawn(memory_limit=memory_limit)
|
||||||
|
|
||||||
|
|
||||||
@parser.wrap()
|
@parser.wrap()
|
||||||
def control_robot(cfg: ControlPipelineConfig):
|
def control_robot(cfg: ControlPipelineConfig):
|
||||||
init_logging()
|
init_logging()
|
||||||
|
@ -370,17 +409,22 @@ def control_robot(cfg: ControlPipelineConfig):
|
||||||
|
|
||||||
robot = make_robot_from_config(cfg.robot)
|
robot = make_robot_from_config(cfg.robot)
|
||||||
|
|
||||||
|
# TODO(Steven): Blueprint for fixed window size
|
||||||
|
|
||||||
if isinstance(cfg.control, CalibrateControlConfig):
|
if isinstance(cfg.control, CalibrateControlConfig):
|
||||||
calibrate(robot, cfg.control)
|
calibrate(robot, cfg.control)
|
||||||
elif isinstance(cfg.control, TeleoperateControlConfig):
|
elif isinstance(cfg.control, TeleoperateControlConfig):
|
||||||
|
_init_rerun(control_config=cfg.control, session_name="lerobot_control_loop_teleop")
|
||||||
teleoperate(robot, cfg.control)
|
teleoperate(robot, cfg.control)
|
||||||
elif isinstance(cfg.control, RecordControlConfig):
|
elif isinstance(cfg.control, RecordControlConfig):
|
||||||
|
_init_rerun(control_config=cfg.control, session_name="lerobot_control_loop_record")
|
||||||
record(robot, cfg.control)
|
record(robot, cfg.control)
|
||||||
elif isinstance(cfg.control, ReplayControlConfig):
|
elif isinstance(cfg.control, ReplayControlConfig):
|
||||||
replay(robot, cfg.control)
|
replay(robot, cfg.control)
|
||||||
elif isinstance(cfg.control, RemoteRobotConfig):
|
elif isinstance(cfg.control, RemoteRobotConfig):
|
||||||
from lerobot.common.robot_devices.robots.lekiwi_remote import run_lekiwi
|
from lerobot.common.robot_devices.robots.lekiwi_remote import run_lekiwi
|
||||||
|
|
||||||
|
_init_rerun(control_config=cfg.control, session_name="lerobot_control_loop_remote")
|
||||||
run_lekiwi(cfg.robot)
|
run_lekiwi(cfg.robot)
|
||||||
|
|
||||||
if robot.is_connected:
|
if robot.is_connected:
|
||||||
|
|
|
@ -60,9 +60,9 @@ dependencies = [
|
||||||
"jsonlines>=4.0.0",
|
"jsonlines>=4.0.0",
|
||||||
"numba>=0.59.0",
|
"numba>=0.59.0",
|
||||||
"omegaconf>=2.3.0",
|
"omegaconf>=2.3.0",
|
||||||
"opencv-python>=4.9.0",
|
"opencv-python-headless>=4.9.0",
|
||||||
|
"keyboard>=0.13.5",
|
||||||
"packaging>=24.2",
|
"packaging>=24.2",
|
||||||
"av>=12.0.5,<13.0.0",
|
|
||||||
"pymunk>=6.6.0",
|
"pymunk>=6.6.0",
|
||||||
"pynput>=1.7.7",
|
"pynput>=1.7.7",
|
||||||
"pyzmq>=26.2.1",
|
"pyzmq>=26.2.1",
|
||||||
|
|
|
@ -172,8 +172,7 @@ def test_record_and_replay_and_policy(tmp_path, request, robot_type, mock):
|
||||||
push_to_hub=False,
|
push_to_hub=False,
|
||||||
# TODO(rcadene, aliberts): test video=True
|
# TODO(rcadene, aliberts): test video=True
|
||||||
video=False,
|
video=False,
|
||||||
# TODO(rcadene): display cameras through cv2 sometimes crashes on mac
|
display_data=False,
|
||||||
display_cameras=False,
|
|
||||||
play_sounds=False,
|
play_sounds=False,
|
||||||
)
|
)
|
||||||
dataset = record(robot, rec_cfg)
|
dataset = record(robot, rec_cfg)
|
||||||
|
@ -226,7 +225,7 @@ def test_record_and_replay_and_policy(tmp_path, request, robot_type, mock):
|
||||||
num_episodes=2,
|
num_episodes=2,
|
||||||
push_to_hub=False,
|
push_to_hub=False,
|
||||||
video=False,
|
video=False,
|
||||||
display_cameras=False,
|
display_data=False,
|
||||||
play_sounds=False,
|
play_sounds=False,
|
||||||
num_image_writer_processes=num_image_writer_processes,
|
num_image_writer_processes=num_image_writer_processes,
|
||||||
)
|
)
|
||||||
|
@ -273,7 +272,7 @@ def test_resume_record(tmp_path, request, robot_type, mock):
|
||||||
episode_time_s=1,
|
episode_time_s=1,
|
||||||
push_to_hub=False,
|
push_to_hub=False,
|
||||||
video=False,
|
video=False,
|
||||||
display_cameras=False,
|
display_data=False,
|
||||||
play_sounds=False,
|
play_sounds=False,
|
||||||
num_episodes=1,
|
num_episodes=1,
|
||||||
)
|
)
|
||||||
|
@ -330,7 +329,7 @@ def test_record_with_event_rerecord_episode(tmp_path, request, robot_type, mock)
|
||||||
num_episodes=1,
|
num_episodes=1,
|
||||||
push_to_hub=False,
|
push_to_hub=False,
|
||||||
video=False,
|
video=False,
|
||||||
display_cameras=False,
|
display_data=False,
|
||||||
play_sounds=False,
|
play_sounds=False,
|
||||||
)
|
)
|
||||||
dataset = record(robot, rec_cfg)
|
dataset = record(robot, rec_cfg)
|
||||||
|
@ -380,7 +379,7 @@ def test_record_with_event_exit_early(tmp_path, request, robot_type, mock):
|
||||||
num_episodes=1,
|
num_episodes=1,
|
||||||
push_to_hub=False,
|
push_to_hub=False,
|
||||||
video=False,
|
video=False,
|
||||||
display_cameras=False,
|
display_data=False,
|
||||||
play_sounds=False,
|
play_sounds=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -433,7 +432,7 @@ def test_record_with_event_stop_recording(tmp_path, request, robot_type, mock, n
|
||||||
num_episodes=2,
|
num_episodes=2,
|
||||||
push_to_hub=False,
|
push_to_hub=False,
|
||||||
video=False,
|
video=False,
|
||||||
display_cameras=False,
|
display_data=False,
|
||||||
play_sounds=False,
|
play_sounds=False,
|
||||||
num_image_writer_processes=num_image_writer_processes,
|
num_image_writer_processes=num_image_writer_processes,
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in New Issue