From db78fee9de0d6fbc29ed17316af6ccd20c70df8a Mon Sep 17 00:00:00 2001
From: Eugene Mironov <helper2424@gmail.com>
Date: Wed, 5 Mar 2025 17:19:31 +0700
Subject: [PATCH] [HIL-SERL] Migrate threading to multiprocessing (#759)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 lerobot/common/utils/utils.py                 |  10 +-
 lerobot/configs/env/maniskill_example.yaml    |   6 +
 lerobot/configs/env/so100_real.yaml           |   1 -
 lerobot/configs/policy/sac_maniskill.yaml     |   7 +-
 lerobot/scripts/server/actor_server.py        | 422 ++++++++++-------
 lerobot/scripts/server/buffer.py              |  36 +-
 lerobot/scripts/server/hilserl.proto          |  19 +-
 lerobot/scripts/server/hilserl_pb2.py         |  28 +-
 lerobot/scripts/server/hilserl_pb2_grpc.py    | 106 ++++-
 lerobot/scripts/server/learner_server.py      | 440 ++++++++++--------
 lerobot/scripts/server/learner_service.py     | 131 ++----
 .../scripts/server/maniskill_manipulator.py   |  12 +-
 lerobot/scripts/server/network_utils.py       | 102 ++++
 lerobot/scripts/server/utils.py               |  72 +++
 14 files changed, 900 insertions(+), 492 deletions(-)
 create mode 100644 lerobot/scripts/server/network_utils.py
 create mode 100644 lerobot/scripts/server/utils.py

diff --git a/lerobot/common/utils/utils.py b/lerobot/common/utils/utils.py
index 56e151c9..272fa614 100644
--- a/lerobot/common/utils/utils.py
+++ b/lerobot/common/utils/utils.py
@@ -108,11 +108,11 @@ def is_amp_available(device: str):
         raise ValueError(f"Unknown device '{device}.")
 
 
-def init_logging():
+def init_logging(log_file=None):
     def custom_format(record):
         dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
         fnameline = f"{record.pathname}:{record.lineno}"
-        message = f"{record.levelname} {dt} {fnameline[-15:]:>15} {record.msg}"
+        message = f"{record.levelname} [PID: {os.getpid()}] {dt} {fnameline[-15:]:>15} {record.msg}"
         return message
 
     logging.basicConfig(level=logging.INFO)
@@ -126,6 +126,12 @@ def init_logging():
     console_handler.setFormatter(formatter)
     logging.getLogger().addHandler(console_handler)
 
+    if log_file is not None:
+        # File handler
+        file_handler = logging.FileHandler(log_file)
+        file_handler.setFormatter(formatter)
+        logging.getLogger().addHandler(file_handler)
+
 
 def format_big_number(num, precision=0):
     suffixes = ["", "K", "M", "B", "T", "Q"]
diff --git a/lerobot/configs/env/maniskill_example.yaml b/lerobot/configs/env/maniskill_example.yaml
index 9098bcbe..3df23b2e 100644
--- a/lerobot/configs/env/maniskill_example.yaml
+++ b/lerobot/configs/env/maniskill_example.yaml
@@ -22,3 +22,9 @@ env:
   wrapper:
     joint_masking_action_space: null
     delta_action: null
+
+  video_record:
+    enabled: false
+    record_dir: maniskill_videos
+    trajectory_name: trajectory
+    fps: ${fps}
diff --git a/lerobot/configs/env/so100_real.yaml b/lerobot/configs/env/so100_real.yaml
index 1bd5cd83..dc30224c 100644
--- a/lerobot/configs/env/so100_real.yaml
+++ b/lerobot/configs/env/so100_real.yaml
@@ -28,4 +28,3 @@ env:
   reward_classifier:
     pretrained_path:  outputs/classifier/13-02-random-sample-resnet10-frozen/checkpoints/best/pretrained_model
     config_path: lerobot/configs/policy/hilserl_classifier.yaml
-
diff --git a/lerobot/configs/policy/sac_maniskill.yaml b/lerobot/configs/policy/sac_maniskill.yaml
index c954b1ea..c9bbca44 100644
--- a/lerobot/configs/policy/sac_maniskill.yaml
+++ b/lerobot/configs/policy/sac_maniskill.yaml
@@ -8,14 +8,12 @@
 #   env.gym.obs_type=environment_state_agent_pos \
 
 seed: 1
-# dataset_repo_id: null
 dataset_repo_id: "AdilZtn/Maniskill-Pushcube-demonstration-medium"
 
 training:
   # Offline training dataloader
   num_workers: 4
 
-  # batch_size: 256
   batch_size: 512
   grad_clip_norm: 10.0
   lr: 3e-4
@@ -113,4 +111,7 @@ policy:
 actor_learner_config:
   learner_host: "127.0.0.1"
   learner_port: 50051
-  policy_parameters_push_frequency: 15
+  policy_parameters_push_frequency: 1
+  concurrency:
+    actor: 'processes'
+    learner: 'processes'
diff --git a/lerobot/scripts/server/actor_server.py b/lerobot/scripts/server/actor_server.py
index c70417cf..24d8356d 100644
--- a/lerobot/scripts/server/actor_server.py
+++ b/lerobot/scripts/server/actor_server.py
@@ -13,22 +13,19 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import io
 import logging
-import pickle
-import queue
 from statistics import mean, quantiles
-import signal
 from functools import lru_cache
+from lerobot.scripts.server.utils import setup_process_handlers
 
 # from lerobot.scripts.eval import eval_policy
-from threading import Thread
 
 import grpc
 import hydra
 import torch
 from omegaconf import DictConfig
 from torch import nn
+import time
 
 # TODO: Remove the import of maniskill
 # from lerobot.common.envs.factory import make_maniskill_env
@@ -47,157 +44,184 @@ from lerobot.scripts.server.buffer import (
     Transition,
     move_state_dict_to_device,
     move_transition_to_device,
-    bytes_buffer_size,
+    python_object_to_bytes,
+    transitions_to_bytes,
+    bytes_to_state_dict,
+)
+from lerobot.scripts.server.network_utils import (
+    receive_bytes_in_chunks,
+    send_bytes_in_chunks,
 )
 from lerobot.scripts.server.gym_manipulator import get_classifier, make_robot_env
 from lerobot.scripts.server import learner_service
 
-from threading import Event
+from torch.multiprocessing import Queue, Event
+from queue import Empty
 
-logging.basicConfig(level=logging.INFO)
+from lerobot.common.utils.utils import init_logging
 
-parameters_queue = queue.Queue(maxsize=1)
-message_queue = queue.Queue(maxsize=1_000_000)
+from lerobot.scripts.server.utils import get_last_item_from_queue
 
 ACTOR_SHUTDOWN_TIMEOUT = 30
 
 
-class ActorInformation:
-    """
-    This helper class is used to differentiate between two types of messages that are placed in the same queue during streaming:
-
-    - **Transition Data:** Contains experience tuples (observation, action, reward, next observation) collected during interaction.
-    - **Interaction Messages:** Encapsulates statistics related to the interaction process.
-
-    Attributes:
-        transition (Optional): Transition data to be sent to the learner.
-        interaction_message (Optional): Iteraction message providing additional statistics for logging.
-    """
-
-    def __init__(self, transition=None, interaction_message=None):
-        self.transition = transition
-        self.interaction_message = interaction_message
-
-
 def receive_policy(
-    learner_client: hilserl_pb2_grpc.LearnerServiceStub,
-    shutdown_event: Event,
-    parameters_queue: queue.Queue,
+    cfg: DictConfig,
+    parameters_queue: Queue,
+    shutdown_event: any,  # Event,
+    learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None,
+    grpc_channel: grpc.Channel | None = None,
 ):
     logging.info("[ACTOR] Start receiving parameters from the Learner")
-    bytes_buffer = io.BytesIO()
-    step = 0
+
+    if not use_threads(cfg):
+        # Setup process handlers to handle shutdown signal
+        # But use shutdown event from the main process
+        setup_process_handlers(False)
+
+    if grpc_channel is None or learner_client is None:
+        learner_client, grpc_channel = learner_service_client(
+            host=cfg.actor_learner_config.learner_host,
+            port=cfg.actor_learner_config.learner_port,
+        )
+
     try:
-        for model_update in learner_client.StreamParameters(hilserl_pb2.Empty()):
-            if shutdown_event.is_set():
-                logging.info("[ACTOR] Shutting down policy streaming receiver")
-                return hilserl_pb2.Empty()
-
-            if model_update.transfer_state == hilserl_pb2.TransferState.TRANSFER_BEGIN:
-                bytes_buffer.seek(0)
-                bytes_buffer.truncate(0)
-                bytes_buffer.write(model_update.parameter_bytes)
-                logging.info("Received model update at step 0")
-                step = 0
-                continue
-            elif (
-                model_update.transfer_state == hilserl_pb2.TransferState.TRANSFER_MIDDLE
-            ):
-                bytes_buffer.write(model_update.parameter_bytes)
-                step += 1
-                logging.info(f"Received model update at step {step}")
-            elif model_update.transfer_state == hilserl_pb2.TransferState.TRANSFER_END:
-                bytes_buffer.write(model_update.parameter_bytes)
-                logging.info(
-                    f"Received model update at step end size {bytes_buffer_size(bytes_buffer)}"
-                )
-
-                state_dict = torch.load(bytes_buffer)
-
-                bytes_buffer.seek(0)
-                bytes_buffer.truncate(0)
-                step = 0
-
-                logging.info("Model updated")
-
-                parameters_queue.put(state_dict)
-
+        iterator = learner_client.StreamParameters(hilserl_pb2.Empty())
+        receive_bytes_in_chunks(
+            iterator,
+            parameters_queue,
+            shutdown_event,
+            log_prefix="[ACTOR] parameters",
+        )
     except grpc.RpcError as e:
         logging.error(f"[ACTOR] gRPC error: {e}")
 
+    if not use_threads(cfg):
+        grpc_channel.close()
+    logging.info("[ACTOR] Received policy loop stopped")
+
+
+def transitions_stream(
+    shutdown_event: Event, transitions_queue: Queue
+) -> hilserl_pb2.Empty:
+    while not shutdown_event.is_set():
+        try:
+            message = transitions_queue.get(block=True, timeout=5)
+        except Empty:
+            logging.debug("[ACTOR] Transition queue is empty")
+            continue
+
+        yield from send_bytes_in_chunks(
+            message, hilserl_pb2.Transition, log_prefix="[ACTOR] Send transitions"
+        )
+
     return hilserl_pb2.Empty()
 
 
-def transitions_stream(shutdown_event: Event, message_queue: queue.Queue):
+def interactions_stream(
+    shutdown_event: any,  # Event,
+    interactions_queue: Queue,
+) -> hilserl_pb2.Empty:
     while not shutdown_event.is_set():
         try:
-            message = message_queue.get(block=True, timeout=5)
-        except queue.Empty:
-            logging.debug("[ACTOR] Transition queue is empty")
+            message = interactions_queue.get(block=True, timeout=5)
+        except Empty:
+            logging.debug("[ACTOR] Interaction queue is empty")
             continue
 
-        if message.transition is not None:
-            transition_to_send_to_learner: list[Transition] = [
-                move_transition_to_device(transition=T, device="cpu")
-                for T in message.transition
-            ]
-            # Check for NaNs in transitions before sending to learner
-            for transition in transition_to_send_to_learner:
-                for key, value in transition["state"].items():
-                    if torch.isnan(value).any():
-                        logging.warning(f"Found NaN values in transition {key}")
-            buf = io.BytesIO()
-            torch.save(transition_to_send_to_learner, buf)
-            transition_bytes = buf.getvalue()
-
-            transition_message = hilserl_pb2.Transition(
-                transition_bytes=transition_bytes
-            )
-
-            response = hilserl_pb2.ActorInformation(transition=transition_message)
-
-        elif message.interaction_message is not None:
-            content = hilserl_pb2.InteractionMessage(
-                interaction_message_bytes=pickle.dumps(message.interaction_message)
-            )
-            response = hilserl_pb2.ActorInformation(interaction_message=content)
-
-        yield response
+        yield from send_bytes_in_chunks(
+            message,
+            hilserl_pb2.InteractionMessage,
+            log_prefix="[ACTOR] Send interactions",
+        )
 
     return hilserl_pb2.Empty()
 
 
 def send_transitions(
-    learner_client: hilserl_pb2_grpc.LearnerServiceStub,
-    shutdown_event: Event,
-    message_queue: queue.Queue,
-):
+    cfg: DictConfig,
+    transitions_queue: Queue,
+    shutdown_event: any,  # Event,
+    learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None,
+    grpc_channel: grpc.Channel | None = None,
+) -> hilserl_pb2.Empty:
     """
-    Streams data from the actor to the learner.
+    Sends transitions to the learner.
 
-    This function continuously retrieves messages from the queue and processes them based on their type:
+    This function continuously retrieves messages from the queue and processes:
 
     - **Transition Data:**
         - A batch of transitions (observation, action, reward, next observation) is collected.
         - Transitions are moved to the CPU and serialized using PyTorch.
         - The serialized data is wrapped in a `hilserl_pb2.Transition` message and sent to the learner.
-
-    - **Interaction Messages:**
-        - Contains useful statistics about episodic rewards and policy timings.
-        - The message is serialized using `pickle` and sent to the learner.
-
-    Yields:
-        hilserl_pb2.ActorInformation: The response message containing either transition data or an interaction message.
     """
+
+    if not use_threads(cfg):
+        # Setup process handlers to handle shutdown signal
+        # But use shutdown event from the main process
+        setup_process_handlers(False)
+
+    if grpc_channel is None or learner_client is None:
+        learner_client, grpc_channel = learner_service_client(
+            host=cfg.actor_learner_config.learner_host,
+            port=cfg.actor_learner_config.learner_port,
+        )
+
     try:
-        learner_client.ReceiveTransitions(
-            transitions_stream(shutdown_event, message_queue)
+        learner_client.SendTransitions(
+            transitions_stream(shutdown_event, transitions_queue)
         )
     except grpc.RpcError as e:
         logging.error(f"[ACTOR] gRPC error: {e}")
 
     logging.info("[ACTOR] Finished streaming transitions")
 
+    if not use_threads(cfg):
+        grpc_channel.close()
+    logging.info("[ACTOR] Transitions process stopped")
+
+
+def send_interactions(
+    cfg: DictConfig,
+    interactions_queue: Queue,
+    shutdown_event: any,  # Event,
+    learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None,
+    grpc_channel: grpc.Channel | None = None,
+) -> hilserl_pb2.Empty:
+    """
+    Sends interactions to the learner.
+
+    This function continuously retrieves messages from the queue and processes:
+
+    - **Interaction Messages:**
+        - Contains useful statistics about episodic rewards and policy timings.
+        - The message is serialized using `pickle` and sent to the learner.
+    """
+
+    if not use_threads(cfg):
+        # Setup process handlers to handle shutdown signal
+        # But use shutdown event from the main process
+        setup_process_handlers(False)
+
+    if grpc_channel is None or learner_client is None:
+        learner_client, grpc_channel = learner_service_client(
+            host=cfg.actor_learner_config.learner_host,
+            port=cfg.actor_learner_config.learner_port,
+        )
+
+    try:
+        learner_client.SendInteractions(
+            interactions_stream(shutdown_event, interactions_queue)
+        )
+    except grpc.RpcError as e:
+        logging.error(f"[ACTOR] gRPC error: {e}")
+
+    logging.info("[ACTOR] Finished streaming interactions")
+
+    if not use_threads(cfg):
+        grpc_channel.close()
+    logging.info("[ACTOR] Interactions process stopped")
+
 
 @lru_cache(maxsize=1)
 def learner_service_client(
@@ -217,7 +241,7 @@ def learner_service_client(
             {
                 "name": [{}],  # Applies to ALL methods in ALL services
                 "retryPolicy": {
-                    "maxAttempts": 7,  # Max retries (total attempts = 5)
+                    "maxAttempts": 5,  # Max retries (total attempts = 5)
                     "initialBackoff": "0.1s",  # First retry after 0.1s
                     "maxBackoff": "2s",  # Max wait time between retries
                     "backoffMultiplier": 2,  # Exponential backoff factor
@@ -242,20 +266,27 @@ def learner_service_client(
         ],
     )
     stub = hilserl_pb2_grpc.LearnerServiceStub(channel)
-    logging.info("[LEARNER] Learner service client created")
+    logging.info("[ACTOR] Learner service client created")
     return stub, channel
 
 
-def update_policy_parameters(policy: SACPolicy, parameters_queue: queue.Queue, device):
+def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device):
     if not parameters_queue.empty():
         logging.info("[ACTOR] Load new parameters from Learner.")
-        state_dict = parameters_queue.get()
+        bytes_state_dict = get_last_item_from_queue(parameters_queue)
+        state_dict = bytes_to_state_dict(bytes_state_dict)
         state_dict = move_state_dict_to_device(state_dict, device=device)
         policy.load_state_dict(state_dict)
 
 
 def act_with_policy(
-    cfg: DictConfig, robot: Robot, reward_classifier: nn.Module, shutdown_event: Event
+    cfg: DictConfig,
+    robot: Robot,
+    reward_classifier: nn.Module,
+    shutdown_event: any,  # Event,
+    parameters_queue: Queue,
+    transitions_queue: Queue,
+    interactions_queue: Queue,
 ):
     """
     Executes policy interaction within the environment.
@@ -317,7 +348,7 @@ def act_with_policy(
 
     for interaction_step in range(cfg.training.online_steps):
         if shutdown_event.is_set():
-            logging.info("[ACTOR] Shutdown signal received. Exiting...")
+            logging.info("[ACTOR] Shutting down act_with_policy")
             return
 
         if interaction_step >= cfg.training.online_step_before_learning:
@@ -394,10 +425,9 @@ def act_with_policy(
             )
 
             if len(list_transition_to_send_to_learner) > 0:
-                send_transitions_in_chunks(
+                push_transitions_to_transport_queue(
                     transitions=list_transition_to_send_to_learner,
-                    message_queue=message_queue,
-                    chunk_size=4,
+                    transitions_queue=transitions_queue,
                 )
                 list_transition_to_send_to_learner = []
 
@@ -405,9 +435,9 @@ def act_with_policy(
             list_policy_time.clear()
 
             # Send episodic reward to the learner
-            message_queue.put(
-                ActorInformation(
-                    interaction_message={
+            interactions_queue.put(
+                python_object_to_bytes(
+                    {
                         "Episodic reward": sum_reward_episode,
                         "Interaction step": interaction_step,
                         "Episode intervention": int(episode_intervention),
@@ -420,7 +450,7 @@ def act_with_policy(
             obs, info = online_env.reset()
 
 
-def send_transitions_in_chunks(transitions: list, message_queue, chunk_size: int = 100):
+def push_transitions_to_transport_queue(transitions: list, transitions_queue):
     """Send transitions to learner in smaller chunks to avoid network issues.
 
     Args:
@@ -428,10 +458,16 @@ def send_transitions_in_chunks(transitions: list, message_queue, chunk_size: int
         message_queue: Queue to send messages to learner
         chunk_size: Size of each chunk to send
     """
-    for i in range(0, len(transitions), chunk_size):
-        chunk = transitions[i : i + chunk_size]
-        logging.debug(f"[ACTOR] Sending chunk of {len(chunk)} transitions to Learner.")
-        message_queue.put(ActorInformation(transition=chunk))
+    transition_to_send_to_learner = []
+    for transition in transitions:
+        tr = move_transition_to_device(transition=transition, device="cpu")
+        for key, value in tr["state"].items():
+            if torch.isnan(value).any():
+                logging.warning(f"Found NaN values in transition {key}")
+
+        transition_to_send_to_learner.append(tr)
+
+    transitions_queue.put(transitions_to_bytes(transition_to_send_to_learner))
 
 
 def get_frequency_stats(list_policy_time: list[float]) -> dict[str, float]:
@@ -458,39 +494,96 @@ def log_policy_frequency_issue(
         )
 
 
+def establish_learner_connection(
+    stub,
+    shutdown_event: any,  # Event,
+    attempts=30,
+):
+    for _ in range(attempts):
+        if shutdown_event.is_set():
+            logging.info("[ACTOR] Shutting down establish_learner_connection")
+            return False
+
+        # Force a connection attempt and check state
+        try:
+            logging.info("[ACTOR] Send ready message to Learner")
+            if stub.Ready(hilserl_pb2.Empty()) == hilserl_pb2.Empty():
+                return True
+        except grpc.RpcError as e:
+            logging.error(f"[ACTOR] Waiting for Learner to be ready... {e}")
+            time.sleep(2)
+    return False
+
+
+def use_threads(cfg: DictConfig) -> bool:
+    return cfg.actor_learner_config.concurrency.actor == "threads"
+
+
 @hydra.main(version_base="1.2", config_name="default", config_path="../../configs")
 def actor_cli(cfg: dict):
+    if not use_threads(cfg):
+        import torch.multiprocessing as mp
+
+        mp.set_start_method("spawn")
+
+    init_logging(log_file="actor.log")
     robot = make_robot(cfg=cfg.robot)
 
-    shutdown_event = Event()
-
-    # Define signal handler
-    def signal_handler(signum, frame):
-        logging.info("Shutdown signal received. Cleaning up...")
-        shutdown_event.set()
-
-    signal.signal(signal.SIGINT, signal_handler)  # Ctrl+C
-    signal.signal(signal.SIGTERM, signal_handler)  # Termination request (kill)
-    signal.signal(signal.SIGHUP, signal_handler)  # Terminal closed/Hangup
-    signal.signal(signal.SIGQUIT, signal_handler)  # Ctrl+\
+    shutdown_event = setup_process_handlers(use_threads(cfg))
 
     learner_client, grpc_channel = learner_service_client(
         host=cfg.actor_learner_config.learner_host,
         port=cfg.actor_learner_config.learner_port,
     )
 
-    receive_policy_thread = Thread(
+    logging.info("[ACTOR] Establishing connection with Learner")
+    if not establish_learner_connection(learner_client, shutdown_event):
+        logging.error("[ACTOR] Failed to establish connection with Learner")
+        return
+
+    if not use_threads(cfg):
+        # If we use multithreading, we can reuse the channel
+        grpc_channel.close()
+        grpc_channel = None
+
+    logging.info("[ACTOR] Connection with Learner established")
+
+    parameters_queue = Queue()
+    transitions_queue = Queue()
+    interactions_queue = Queue()
+
+    concurrency_entity = None
+    if use_threads(cfg):
+        from threading import Thread
+
+        concurrency_entity = Thread
+    else:
+        from multiprocessing import Process
+
+        concurrency_entity = Process
+
+    receive_policy_process = concurrency_entity(
         target=receive_policy,
-        args=(learner_client, shutdown_event, parameters_queue),
+        args=(cfg, parameters_queue, shutdown_event, grpc_channel),
         daemon=True,
     )
 
-    transitions_thread = Thread(
+    transitions_process = concurrency_entity(
         target=send_transitions,
-        args=(learner_client, shutdown_event, message_queue),
+        args=(cfg, transitions_queue, shutdown_event, grpc_channel),
         daemon=True,
     )
 
+    interactions_process = concurrency_entity(
+        target=send_interactions,
+        args=(cfg, interactions_queue, shutdown_event, grpc_channel),
+        daemon=True,
+    )
+
+    transitions_process.start()
+    interactions_process.start()
+    receive_policy_process.start()
+
     # HACK: FOR MANISKILL we do not have a reward classifier
     # TODO: Remove this once we merge into main
     reward_classifier = None
@@ -503,26 +596,35 @@ def actor_cli(cfg: dict):
             config_path=cfg.env.reward_classifier.config_path,
         )
 
-    policy_thread = Thread(
-        target=act_with_policy,
-        daemon=True,
-        args=(cfg, robot, reward_classifier, shutdown_event),
+    act_with_policy(
+        cfg,
+        robot,
+        reward_classifier,
+        shutdown_event,
+        parameters_queue,
+        transitions_queue,
+        interactions_queue,
     )
+    logging.info("[ACTOR] Policy process joined")
 
-    transitions_thread.start()
-    policy_thread.start()
-    receive_policy_thread.start()
+    logging.info("[ACTOR] Closing queues")
+    transitions_queue.close()
+    interactions_queue.close()
+    parameters_queue.close()
 
-    shutdown_event.wait()
-    logging.info("[ACTOR] Shutdown event received")
-    grpc_channel.close()
+    transitions_process.join()
+    logging.info("[ACTOR] Transitions process joined")
+    interactions_process.join()
+    logging.info("[ACTOR] Interactions process joined")
+    receive_policy_process.join()
+    logging.info("[ACTOR] Receive policy process joined")
 
-    policy_thread.join()
-    logging.info("[ACTOR] Policy thread joined")
-    transitions_thread.join()
-    logging.info("[ACTOR] Transitions thread joined")
-    receive_policy_thread.join()
-    logging.info("[ACTOR] Receive policy thread joined")
+    logging.info("[ACTOR] join queues")
+    transitions_queue.cancel_join_thread()
+    interactions_queue.cancel_join_thread()
+    parameters_queue.cancel_join_thread()
+
+    logging.info("[ACTOR] queues closed")
 
 
 if __name__ == "__main__":
diff --git a/lerobot/scripts/server/buffer.py b/lerobot/scripts/server/buffer.py
index f93b40ca..80834eac 100644
--- a/lerobot/scripts/server/buffer.py
+++ b/lerobot/scripts/server/buffer.py
@@ -23,6 +23,7 @@ from tqdm import tqdm
 
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
 import os
+import pickle
 
 
 class Transition(TypedDict):
@@ -91,7 +92,7 @@ def move_transition_to_device(
     return transition
 
 
-def move_state_dict_to_device(state_dict, device):
+def move_state_dict_to_device(state_dict, device="cpu"):
     """
     Recursively move all tensors in a (potentially) nested
     dict/list/tuple structure to the CPU.
@@ -111,20 +112,41 @@ def move_state_dict_to_device(state_dict, device):
         return state_dict
 
 
-def state_to_bytes(state_dict: dict[str, torch.Tensor]) -> io.BytesIO:
+def state_to_bytes(state_dict: dict[str, torch.Tensor]) -> bytes:
     """Convert model state dict to flat array for transmission"""
     buffer = io.BytesIO()
 
     torch.save(state_dict, buffer)
 
-    return buffer
+    return buffer.getvalue()
 
 
-def bytes_buffer_size(buffer: io.BytesIO) -> int:
-    buffer.seek(0, io.SEEK_END)
-    result = buffer.tell()
+def bytes_to_state_dict(buffer: bytes) -> dict[str, torch.Tensor]:
+    buffer = io.BytesIO(buffer)
     buffer.seek(0)
-    return result
+    return torch.load(buffer)
+
+
+def python_object_to_bytes(python_object: Any) -> bytes:
+    return pickle.dumps(python_object)
+
+
+def bytes_to_python_object(buffer: bytes) -> Any:
+    buffer = io.BytesIO(buffer)
+    buffer.seek(0)
+    return pickle.load(buffer)
+
+
+def bytes_to_transitions(buffer: bytes) -> list[Transition]:
+    buffer = io.BytesIO(buffer)
+    buffer.seek(0)
+    return torch.load(buffer)
+
+
+def transitions_to_bytes(transitions: list[Transition]) -> bytes:
+    buffer = io.BytesIO()
+    torch.save(transitions, buffer)
+    return buffer.getvalue()
 
 
 def random_crop_vectorized(images: torch.Tensor, output_size: tuple) -> torch.Tensor:
diff --git a/lerobot/scripts/server/hilserl.proto b/lerobot/scripts/server/hilserl.proto
index 6aa46e0e..dec2117b 100644
--- a/lerobot/scripts/server/hilserl.proto
+++ b/lerobot/scripts/server/hilserl.proto
@@ -24,14 +24,9 @@ service LearnerService {
   // Actor -> Learner to store transitions
   rpc SendInteractionMessage(InteractionMessage) returns (Empty);
   rpc StreamParameters(Empty) returns (stream Parameters);
-  rpc ReceiveTransitions(stream ActorInformation) returns (Empty);
-}
-
-message ActorInformation {
-    oneof data {
-        Transition transition = 1;
-        InteractionMessage interaction_message = 2;
-    }
+  rpc SendTransitions(stream Transition) returns (Empty);
+  rpc SendInteractions(stream InteractionMessage) returns (Empty);
+  rpc Ready(Empty) returns (Empty);
 }
 
 enum TransferState {
@@ -43,16 +38,18 @@ enum TransferState {
 
 // Messages
 message Transition {
-  bytes transition_bytes = 1;
+  TransferState transfer_state = 1;
+  bytes data = 2;
 }
 
 message Parameters {
   TransferState transfer_state = 1;
-  bytes parameter_bytes = 2;
+  bytes data = 2;
 }
 
 message InteractionMessage {
-  bytes interaction_message_bytes = 1;
+  TransferState transfer_state = 1;
+  bytes data = 2;
 }
 
 message Empty {}
diff --git a/lerobot/scripts/server/hilserl_pb2.py b/lerobot/scripts/server/hilserl_pb2.py
index d5eb8d4c..4a4cbea7 100644
--- a/lerobot/scripts/server/hilserl_pb2.py
+++ b/lerobot/scripts/server/hilserl_pb2.py
@@ -24,25 +24,23 @@ _sym_db = _symbol_database.Default()
 
 
 
-DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rhilserl.proto\x12\x08hil_serl\"\x83\x01\n\x10\x41\x63torInformation\x12*\n\ntransition\x18\x01 \x01(\x0b\x32\x14.hil_serl.TransitionH\x00\x12;\n\x13interaction_message\x18\x02 \x01(\x0b\x32\x1c.hil_serl.InteractionMessageH\x00\x42\x06\n\x04\x64\x61ta\"&\n\nTransition\x12\x18\n\x10transition_bytes\x18\x01 \x01(\x0c\"V\n\nParameters\x12/\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x17.hil_serl.TransferState\x12\x17\n\x0fparameter_bytes\x18\x02 \x01(\x0c\"7\n\x12InteractionMessage\x12!\n\x19interaction_message_bytes\x18\x01 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\xdb\x01\n\x0eLearnerService\x12G\n\x16SendInteractionMessage\x12\x1c.hil_serl.InteractionMessage\x1a\x0f.hil_serl.Empty\x12;\n\x10StreamParameters\x12\x0f.hil_serl.Empty\x1a\x14.hil_serl.Parameters0\x01\x12\x43\n\x12ReceiveTransitions\x12\x1a.hil_serl.ActorInformation\x1a\x0f.hil_serl.Empty(\x01\x62\x06proto3')
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rhilserl.proto\x12\x08hil_serl\"K\n\nTransition\x12/\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x17.hil_serl.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"K\n\nParameters\x12/\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x17.hil_serl.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"S\n\x12InteractionMessage\x12/\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x17.hil_serl.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\xc2\x02\n\x0eLearnerService\x12G\n\x16SendInteractionMessage\x12\x1c.hil_serl.InteractionMessage\x1a\x0f.hil_serl.Empty\x12;\n\x10StreamParameters\x12\x0f.hil_serl.Empty\x1a\x14.hil_serl.Parameters0\x01\x12:\n\x0fSendTransitions\x12\x14.hil_serl.Transition\x1a\x0f.hil_serl.Empty(\x01\x12\x43\n\x10SendInteractions\x12\x1c.hil_serl.InteractionMessage\x1a\x0f.hil_serl.Empty(\x01\x12)\n\x05Ready\x12\x0f.hil_serl.Empty\x1a\x0f.hil_serl.Emptyb\x06proto3')
 
 _globals = globals()
 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
 _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'hilserl_pb2', _globals)
 if not _descriptor._USE_C_DESCRIPTORS:
   DESCRIPTOR._loaded_options = None
-  _globals['_TRANSFERSTATE']._serialized_start=355
-  _globals['_TRANSFERSTATE']._serialized_end=451
-  _globals['_ACTORINFORMATION']._serialized_start=28
-  _globals['_ACTORINFORMATION']._serialized_end=159
-  _globals['_TRANSITION']._serialized_start=161
-  _globals['_TRANSITION']._serialized_end=199
-  _globals['_PARAMETERS']._serialized_start=201
-  _globals['_PARAMETERS']._serialized_end=287
-  _globals['_INTERACTIONMESSAGE']._serialized_start=289
-  _globals['_INTERACTIONMESSAGE']._serialized_end=344
-  _globals['_EMPTY']._serialized_start=346
-  _globals['_EMPTY']._serialized_end=353
-  _globals['_LEARNERSERVICE']._serialized_start=454
-  _globals['_LEARNERSERVICE']._serialized_end=673
+  _globals['_TRANSFERSTATE']._serialized_start=275
+  _globals['_TRANSFERSTATE']._serialized_end=371
+  _globals['_TRANSITION']._serialized_start=27
+  _globals['_TRANSITION']._serialized_end=102
+  _globals['_PARAMETERS']._serialized_start=104
+  _globals['_PARAMETERS']._serialized_end=179
+  _globals['_INTERACTIONMESSAGE']._serialized_start=181
+  _globals['_INTERACTIONMESSAGE']._serialized_end=264
+  _globals['_EMPTY']._serialized_start=266
+  _globals['_EMPTY']._serialized_end=273
+  _globals['_LEARNERSERVICE']._serialized_start=374
+  _globals['_LEARNERSERVICE']._serialized_end=696
 # @@protoc_insertion_point(module_scope)
diff --git a/lerobot/scripts/server/hilserl_pb2_grpc.py b/lerobot/scripts/server/hilserl_pb2_grpc.py
index 42d4674e..1fa96e81 100644
--- a/lerobot/scripts/server/hilserl_pb2_grpc.py
+++ b/lerobot/scripts/server/hilserl_pb2_grpc.py
@@ -46,9 +46,19 @@ class LearnerServiceStub(object):
                 request_serializer=hilserl__pb2.Empty.SerializeToString,
                 response_deserializer=hilserl__pb2.Parameters.FromString,
                 _registered_method=True)
-        self.ReceiveTransitions = channel.stream_unary(
-                '/hil_serl.LearnerService/ReceiveTransitions',
-                request_serializer=hilserl__pb2.ActorInformation.SerializeToString,
+        self.SendTransitions = channel.stream_unary(
+                '/hil_serl.LearnerService/SendTransitions',
+                request_serializer=hilserl__pb2.Transition.SerializeToString,
+                response_deserializer=hilserl__pb2.Empty.FromString,
+                _registered_method=True)
+        self.SendInteractions = channel.stream_unary(
+                '/hil_serl.LearnerService/SendInteractions',
+                request_serializer=hilserl__pb2.InteractionMessage.SerializeToString,
+                response_deserializer=hilserl__pb2.Empty.FromString,
+                _registered_method=True)
+        self.Ready = channel.unary_unary(
+                '/hil_serl.LearnerService/Ready',
+                request_serializer=hilserl__pb2.Empty.SerializeToString,
                 response_deserializer=hilserl__pb2.Empty.FromString,
                 _registered_method=True)
 
@@ -71,7 +81,19 @@ class LearnerServiceServicer(object):
         context.set_details('Method not implemented!')
         raise NotImplementedError('Method not implemented!')
 
-    def ReceiveTransitions(self, request_iterator, context):
+    def SendTransitions(self, request_iterator, context):
+        """Missing associated documentation comment in .proto file."""
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details('Method not implemented!')
+        raise NotImplementedError('Method not implemented!')
+
+    def SendInteractions(self, request_iterator, context):
+        """Missing associated documentation comment in .proto file."""
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details('Method not implemented!')
+        raise NotImplementedError('Method not implemented!')
+
+    def Ready(self, request, context):
         """Missing associated documentation comment in .proto file."""
         context.set_code(grpc.StatusCode.UNIMPLEMENTED)
         context.set_details('Method not implemented!')
@@ -90,9 +112,19 @@ def add_LearnerServiceServicer_to_server(servicer, server):
                     request_deserializer=hilserl__pb2.Empty.FromString,
                     response_serializer=hilserl__pb2.Parameters.SerializeToString,
             ),
-            'ReceiveTransitions': grpc.stream_unary_rpc_method_handler(
-                    servicer.ReceiveTransitions,
-                    request_deserializer=hilserl__pb2.ActorInformation.FromString,
+            'SendTransitions': grpc.stream_unary_rpc_method_handler(
+                    servicer.SendTransitions,
+                    request_deserializer=hilserl__pb2.Transition.FromString,
+                    response_serializer=hilserl__pb2.Empty.SerializeToString,
+            ),
+            'SendInteractions': grpc.stream_unary_rpc_method_handler(
+                    servicer.SendInteractions,
+                    request_deserializer=hilserl__pb2.InteractionMessage.FromString,
+                    response_serializer=hilserl__pb2.Empty.SerializeToString,
+            ),
+            'Ready': grpc.unary_unary_rpc_method_handler(
+                    servicer.Ready,
+                    request_deserializer=hilserl__pb2.Empty.FromString,
                     response_serializer=hilserl__pb2.Empty.SerializeToString,
             ),
     }
@@ -163,7 +195,7 @@ class LearnerService(object):
             _registered_method=True)
 
     @staticmethod
-    def ReceiveTransitions(request_iterator,
+    def SendTransitions(request_iterator,
             target,
             options=(),
             channel_credentials=None,
@@ -176,8 +208,62 @@ class LearnerService(object):
         return grpc.experimental.stream_unary(
             request_iterator,
             target,
-            '/hil_serl.LearnerService/ReceiveTransitions',
-            hilserl__pb2.ActorInformation.SerializeToString,
+            '/hil_serl.LearnerService/SendTransitions',
+            hilserl__pb2.Transition.SerializeToString,
+            hilserl__pb2.Empty.FromString,
+            options,
+            channel_credentials,
+            insecure,
+            call_credentials,
+            compression,
+            wait_for_ready,
+            timeout,
+            metadata,
+            _registered_method=True)
+
+    @staticmethod
+    def SendInteractions(request_iterator,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.stream_unary(
+            request_iterator,
+            target,
+            '/hil_serl.LearnerService/SendInteractions',
+            hilserl__pb2.InteractionMessage.SerializeToString,
+            hilserl__pb2.Empty.FromString,
+            options,
+            channel_credentials,
+            insecure,
+            call_credentials,
+            compression,
+            wait_for_ready,
+            timeout,
+            metadata,
+            _registered_method=True)
+
+    @staticmethod
+    def Ready(request,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.unary_unary(
+            request,
+            target,
+            '/hil_serl.LearnerService/Ready',
+            hilserl__pb2.Empty.SerializeToString,
             hilserl__pb2.Empty.FromString,
             options,
             channel_credentials,
diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py
index 4bab9ac2..7bd4aee0 100644
--- a/lerobot/scripts/server/learner_server.py
+++ b/lerobot/scripts/server/learner_server.py
@@ -15,15 +15,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-import queue
 import shutil
 import time
 from pprint import pformat
-from threading import Lock, Thread
-import signal
-from threading import Event
 from concurrent.futures import ThreadPoolExecutor
 
+# from torch.multiprocessing import Event, Queue, Process
+# from threading import Event, Thread
+# from torch.multiprocessing import Queue, Event
+from torch.multiprocessing import Queue
+
+from lerobot.scripts.server.utils import setup_process_handlers
+
 import grpc
 
 # Import generated stubs
@@ -52,19 +55,19 @@ from lerobot.common.utils.utils import (
     set_global_random_state,
     set_global_seed,
 )
+
 from lerobot.scripts.server.buffer import (
     ReplayBuffer,
     concatenate_batch_transitions,
     move_transition_to_device,
+    move_state_dict_to_device,
+    bytes_to_transitions,
+    state_to_bytes,
+    bytes_to_python_object,
 )
 
 from lerobot.scripts.server import learner_service
 
-logging.basicConfig(level=logging.INFO)
-
-transition_queue = queue.Queue()
-interaction_message_queue = queue.Queue()
-
 
 def handle_resume_logic(cfg: DictConfig, out_dir: str) -> DictConfig:
     if not cfg.resume:
@@ -195,67 +198,96 @@ def get_observation_features(
     return observation_features, next_observation_features
 
 
+def use_threads(cfg: DictConfig) -> bool:
+    return cfg.actor_learner_config.concurrency.learner == "threads"
+
+
 def start_learner_threads(
     cfg: DictConfig,
-    device: str,
-    replay_buffer: ReplayBuffer,
-    offline_replay_buffer: ReplayBuffer,
-    batch_size: int,
-    optimizers: dict,
-    policy: SACPolicy,
-    policy_lock: Lock,
     logger: Logger,
-    resume_optimization_step: int | None = None,
-    resume_interaction_step: int | None = None,
-    shutdown_event: Event | None = None,
+    out_dir: str,
+    shutdown_event: any,  # Event,
 ) -> None:
-    host = cfg.actor_learner_config.learner_host
-    port = cfg.actor_learner_config.learner_port
+    # Create multiprocessing queues
+    transition_queue = Queue()
+    interaction_message_queue = Queue()
+    parameters_queue = Queue()
 
-    transition_thread = Thread(
-        target=add_actor_information_and_train,
-        daemon=True,
+    concurrency_entity = None
+
+    if use_threads(cfg):
+        from threading import Thread
+
+        concurrency_entity = Thread
+    else:
+        from torch.multiprocessing import Process
+
+        concurrency_entity = Process
+
+    communication_process = concurrency_entity(
+        target=start_learner_server,
         args=(
-            cfg,
-            device,
-            replay_buffer,
-            offline_replay_buffer,
-            batch_size,
-            optimizers,
-            policy,
-            policy_lock,
-            logger,
-            resume_optimization_step,
-            resume_interaction_step,
+            parameters_queue,
+            transition_queue,
+            interaction_message_queue,
             shutdown_event,
+            cfg,
         ),
+        daemon=True,
     )
+    communication_process.start()
 
-    transition_thread.start()
+    add_actor_information_and_train(
+        cfg,
+        logger,
+        out_dir,
+        shutdown_event,
+        transition_queue,
+        interaction_message_queue,
+        parameters_queue,
+    )
+    logging.info("[LEARNER] Training process stopped")
+
+    logging.info("[LEARNER] Closing queues")
+    transition_queue.close()
+    interaction_message_queue.close()
+    parameters_queue.close()
+
+    communication_process.join()
+    logging.info("[LEARNER] Communication process joined")
+
+    logging.info("[LEARNER] join queues")
+    transition_queue.cancel_join_thread()
+    interaction_message_queue.cancel_join_thread()
+    parameters_queue.cancel_join_thread()
+
+    logging.info("[LEARNER] queues closed")
+
+
+def start_learner_server(
+    parameters_queue: Queue,
+    transition_queue: Queue,
+    interaction_message_queue: Queue,
+    shutdown_event: any,  # Event,
+    cfg: DictConfig,
+):
+    if not use_threads(cfg):
+        # We need init logging for MP separataly
+        init_logging()
+
+        # Setup process handlers to handle shutdown signal
+        # But use shutdown event from the main process
+        # Return back for MP
+        setup_process_handlers(False)
 
     service = learner_service.LearnerService(
         shutdown_event,
-        policy,
-        policy_lock,
+        parameters_queue,
         cfg.actor_learner_config.policy_parameters_push_frequency,
         transition_queue,
         interaction_message_queue,
     )
-    server = start_learner_server(service, host, port)
 
-    shutdown_event.wait()
-    server.stop(learner_service.STUTDOWN_TIMEOUT)
-    logging.info("[LEARNER] gRPC server stopped")
-
-    transition_thread.join()
-    logging.info("[LEARNER] Transition thread stopped")
-
-
-def start_learner_server(
-    service: learner_service.LearnerService,
-    host="0.0.0.0",
-    port=50051,
-) -> grpc.server:
     server = grpc.server(
         ThreadPoolExecutor(max_workers=learner_service.MAX_WORKERS),
         options=[
@@ -263,15 +295,23 @@ def start_learner_server(
             ("grpc.max_send_message_length", learner_service.MAX_MESSAGE_SIZE),
         ],
     )
+
     hilserl_pb2_grpc.add_LearnerServiceServicer_to_server(
         service,
         server,
     )
+
+    host = cfg.actor_learner_config.learner_host
+    port = cfg.actor_learner_config.learner_port
+
     server.add_insecure_port(f"{host}:{port}")
     server.start()
     logging.info("[LEARNER] gRPC server started")
 
-    return server
+    shutdown_event.wait()
+    logging.info("[LEARNER] Stopping gRPC server...")
+    server.stop(learner_service.STUTDOWN_TIMEOUT)
+    logging.info("[LEARNER] gRPC server stopped")
 
 
 def check_nan_in_transition(
@@ -287,19 +327,21 @@ def check_nan_in_transition(
         logging.error("actions contains NaN values")
 
 
+def push_actor_policy_to_queue(parameters_queue: Queue, policy: nn.Module):
+    logging.debug("[LEARNER] Pushing actor policy to the queue")
+    state_dict = move_state_dict_to_device(policy.actor.state_dict(), device="cpu")
+    state_bytes = state_to_bytes(state_dict)
+    parameters_queue.put(state_bytes)
+
+
 def add_actor_information_and_train(
     cfg,
-    device: str,
-    replay_buffer: ReplayBuffer,
-    offline_replay_buffer: ReplayBuffer,
-    batch_size: int,
-    optimizers: dict[str, torch.optim.Optimizer],
-    policy: nn.Module,
-    policy_lock: Lock,
     logger: Logger,
-    resume_optimization_step: int | None = None,
-    resume_interaction_step: int | None = None,
-    shutdown_event: Event | None = None,
+    out_dir: str,
+    shutdown_event: any,  # Event,
+    transition_queue: Queue,
+    interaction_message_queue: Queue,
+    parameters_queue: Queue,
 ):
     """
     Handles data transfer from the actor to the learner, manages training updates,
@@ -322,17 +364,73 @@ def add_actor_information_and_train(
     Args:
         cfg: Configuration object containing hyperparameters.
         device (str): The computing device (`"cpu"` or `"cuda"`).
-        replay_buffer (ReplayBuffer): The primary replay buffer storing online transitions.
-        offline_replay_buffer (ReplayBuffer): An additional buffer for offline transitions.
-        batch_size (int): The number of transitions to sample per training step.
-        optimizers (Dict[str, torch.optim.Optimizer]): A dictionary of optimizers (`"actor"`, `"critic"`, `"temperature"`).
-        policy (nn.Module): The reinforcement learning policy with critic, actor, and temperature parameters.
-        policy_lock (Lock): A threading lock to ensure safe policy updates.
         logger (Logger): Logger instance for tracking training progress.
-        resume_optimization_step (int | None): In the case of resume training, start from the last optimization step reached.
-        resume_interaction_step (int | None): In the case of resume training, shift the interaction step with the last saved step in order to not break logging.
-        shutdown_event (Event | None): Event to signal shutdown.
+        out_dir (str): The output directory for storing training checkpoints and logs.
+        shutdown_event (Event): Event to signal shutdown.
+        transition_queue (Queue): Queue for receiving transitions from the actor.
+        interaction_message_queue (Queue): Queue for receiving interaction messages from the actor.
+        parameters_queue (Queue): Queue for sending policy parameters to the actor.
     """
+
+    device = get_safe_torch_device(cfg.device, log=True)
+    storage_device = get_safe_torch_device(cfg_device=cfg.training.storage_device)
+
+    logging.info("Initializing policy")
+    ### Instantiate the policy in both the actor and learner processes
+    ### To avoid sending a SACPolicy object through the port, we create a policy intance
+    ### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters
+    # TODO: At some point we should just need make sac policy
+
+    policy: SACPolicy = make_policy(
+        hydra_cfg=cfg,
+        # dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
+        # Hack: But if we do online traning, we do not need dataset_stats
+        dataset_stats=None,
+        pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir)
+        if cfg.resume
+        else None,
+    )
+    # compile policy
+    policy = torch.compile(policy)
+    assert isinstance(policy, nn.Module)
+
+    push_actor_policy_to_queue(parameters_queue, policy)
+
+    last_time_policy_pushed = time.time()
+
+    optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg, policy)
+    resume_optimization_step, resume_interaction_step = load_training_state(
+        cfg, logger, optimizers
+    )
+
+    log_training_info(cfg, out_dir, policy)
+
+    replay_buffer = initialize_replay_buffer(cfg, logger, device, storage_device)
+    batch_size = cfg.training.batch_size
+    offline_replay_buffer = None
+
+    if cfg.dataset_repo_id is not None:
+        logging.info("make_dataset offline buffer")
+        offline_dataset = make_dataset(cfg)
+        logging.info("Convertion to a offline replay buffer")
+        active_action_dims = None
+        if cfg.env.wrapper.joint_masking_action_space is not None:
+            active_action_dims = [
+                i
+                for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space)
+                if mask
+            ]
+        offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
+            offline_dataset,
+            device=device,
+            state_keys=cfg.policy.input_shapes.keys(),
+            action_mask=active_action_dims,
+            action_delta=cfg.env.wrapper.delta_action,
+            storage_device=storage_device,
+            optimize_memory=True,
+        )
+        batch_size: int = batch_size // 2  # We will sample from both replay buffer
+
     # NOTE: This function doesn't have a single responsibility, it should be split into multiple functions
     # in the future. The reason why we did that is the  GIL in Python. It's super slow the performance
     # are divided by 200. So we need to have a single thread that does all the work.
@@ -345,33 +443,39 @@ def add_actor_information_and_train(
     interaction_step_shift = (
         resume_interaction_step if resume_interaction_step is not None else 0
     )
-    saved_data = False
+
     while True:
         if shutdown_event is not None and shutdown_event.is_set():
             logging.info("[LEARNER] Shutdown signal received. Exiting...")
             break
 
-        while not transition_queue.empty():
+        logging.debug("[LEARNER] Waiting for transitions")
+        while not transition_queue.empty() and not shutdown_event.is_set():
             transition_list = transition_queue.get()
+            transition_list = bytes_to_transitions(transition_list)
+
             for transition in transition_list:
                 transition = move_transition_to_device(transition, device=device)
                 replay_buffer.add(**transition)
-
                 if transition.get("complementary_info", {}).get("is_intervention"):
                     offline_replay_buffer.add(**transition)
-
-        while not interaction_message_queue.empty():
+        logging.debug("[LEARNER] Received transitions")
+        logging.debug("[LEARNER] Waiting for interactions")
+        while not interaction_message_queue.empty() and not shutdown_event.is_set():
             interaction_message = interaction_message_queue.get()
+            interaction_message = bytes_to_python_object(interaction_message)
             # If cfg.resume, shift the interaction step with the last checkpointed step in order to not break the logging
             interaction_message["Interaction step"] += interaction_step_shift
             logger.log_dict(
                 interaction_message, mode="train", custom_step_key="Interaction step"
             )
-            # logging.info(f"Interaction message: {interaction_message}")
+
+        logging.debug("[LEARNER] Received interactions")
 
         if len(replay_buffer) < cfg.training.online_step_before_learning:
             continue
 
+        logging.debug("[LEARNER] Starting optimization loop")
         time_for_one_optimization_step = time.time()
         for _ in range(cfg.policy.utd_ratio - 1):
             batch = replay_buffer.sample(batch_size)
@@ -392,19 +496,18 @@ def add_actor_information_and_train(
             observation_features, next_observation_features = get_observation_features(
                 policy, observations, next_observations
             )
-            with policy_lock:
-                loss_critic = policy.compute_loss_critic(
-                    observations=observations,
-                    actions=actions,
-                    rewards=rewards,
-                    next_observations=next_observations,
-                    done=done,
-                    observation_features=observation_features,
-                    next_observation_features=next_observation_features,
-                )
-                optimizers["critic"].zero_grad()
-                loss_critic.backward()
-                optimizers["critic"].step()
+            loss_critic = policy.compute_loss_critic(
+                observations=observations,
+                actions=actions,
+                rewards=rewards,
+                next_observations=next_observations,
+                done=done,
+                observation_features=observation_features,
+                next_observation_features=next_observation_features,
+            )
+            optimizers["critic"].zero_grad()
+            loss_critic.backward()
+            optimizers["critic"].step()
 
         batch = replay_buffer.sample(batch_size)
 
@@ -427,46 +530,51 @@ def add_actor_information_and_train(
         observation_features, next_observation_features = get_observation_features(
             policy, observations, next_observations
         )
-        with policy_lock:
-            loss_critic = policy.compute_loss_critic(
-                observations=observations,
-                actions=actions,
-                rewards=rewards,
-                next_observations=next_observations,
-                done=done,
-                observation_features=observation_features,
-                next_observation_features=next_observation_features,
-            )
-            optimizers["critic"].zero_grad()
-            loss_critic.backward()
-            optimizers["critic"].step()
+        loss_critic = policy.compute_loss_critic(
+            observations=observations,
+            actions=actions,
+            rewards=rewards,
+            next_observations=next_observations,
+            done=done,
+            observation_features=observation_features,
+            next_observation_features=next_observation_features,
+        )
+        optimizers["critic"].zero_grad()
+        loss_critic.backward()
+        optimizers["critic"].step()
 
         training_infos = {}
         training_infos["loss_critic"] = loss_critic.item()
 
         if optimization_step % cfg.training.policy_update_freq == 0:
             for _ in range(cfg.training.policy_update_freq):
-                with policy_lock:
-                    loss_actor = policy.compute_loss_actor(
-                        observations=observations,
-                        observation_features=observation_features,
-                    )
+                loss_actor = policy.compute_loss_actor(
+                    observations=observations,
+                    observation_features=observation_features,
+                )
 
-                    optimizers["actor"].zero_grad()
-                    loss_actor.backward()
-                    optimizers["actor"].step()
+                optimizers["actor"].zero_grad()
+                loss_actor.backward()
+                optimizers["actor"].step()
 
-                    training_infos["loss_actor"] = loss_actor.item()
+                training_infos["loss_actor"] = loss_actor.item()
 
-                    loss_temperature = policy.compute_loss_temperature(
-                        observations=observations,
-                        observation_features=observation_features,
-                    )
-                    optimizers["temperature"].zero_grad()
-                    loss_temperature.backward()
-                    optimizers["temperature"].step()
+                loss_temperature = policy.compute_loss_temperature(
+                    observations=observations,
+                    observation_features=observation_features,
+                )
+                optimizers["temperature"].zero_grad()
+                loss_temperature.backward()
+                optimizers["temperature"].step()
 
-                    training_infos["loss_temperature"] = loss_temperature.item()
+                training_infos["loss_temperature"] = loss_temperature.item()
+
+        if (
+            time.time() - last_time_policy_pushed
+            > cfg.actor_learner_config.policy_parameters_push_frequency
+        ):
+            push_actor_policy_to_queue(parameters_queue, policy)
+            last_time_policy_pushed = time.time()
 
         policy.update_target_networks()
         if optimization_step % cfg.training.log_freq == 0:
@@ -595,104 +703,36 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
 
     set_global_seed(cfg.seed)
 
-    device = get_safe_torch_device(cfg.device, log=True)
-    storage_device = get_safe_torch_device(cfg_device=cfg.training.storage_device)
-
     torch.backends.cudnn.benchmark = True
     torch.backends.cuda.matmul.allow_tf32 = True
 
-    logging.info("make_policy")
-
-    ### Instantiate the policy in both the actor and learner processes
-    ### To avoid sending a SACPolicy object through the port, we create a policy intance
-    ### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters
-    # TODO: At some point we should just need make sac policy
-
-    policy_lock = Lock()
-    policy: SACPolicy = make_policy(
-        hydra_cfg=cfg,
-        # dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
-        # Hack: But if we do online traning, we do not need dataset_stats
-        dataset_stats=None,
-        pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir)
-        if cfg.resume
-        else None,
-    )
-    # compile policy
-    policy = torch.compile(policy)
-    assert isinstance(policy, nn.Module)
-
-    optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg, policy)
-    resume_optimization_step, resume_interaction_step = load_training_state(
-        cfg, logger, optimizers
-    )
-
-    log_training_info(cfg, out_dir, policy)
-
-    replay_buffer = initialize_replay_buffer(cfg, logger, device, storage_device)
-    batch_size = cfg.training.batch_size
-    offline_replay_buffer = None
-
-    if cfg.dataset_repo_id is not None:
-        logging.info("make_dataset offline buffer")
-        offline_dataset = make_dataset(cfg)
-        logging.info("Convertion to a offline replay buffer")
-        active_action_dims = None
-        if cfg.env.wrapper.joint_masking_action_space is not None:
-            active_action_dims = [
-                i
-                for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space)
-                if mask
-            ]
-        offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
-            offline_dataset,
-            device=device,
-            state_keys=cfg.policy.input_shapes.keys(),
-            action_mask=active_action_dims,
-            action_delta=cfg.env.wrapper.delta_action,
-            storage_device=storage_device,
-            optimize_memory=True,
-        )
-        batch_size: int = batch_size // 2  # We will sample from both replay buffer
-
-    shutdown_event = Event()
-
-    def signal_handler(signum, frame):
-        print(
-            f"\nReceived signal {signal.Signals(signum).name}. Initiating learner shutdown..."
-        )
-        shutdown_event.set()
-
-    # Register signal handlers
-    signal.signal(signal.SIGINT, signal_handler)  # Ctrl+C
-    signal.signal(signal.SIGTERM, signal_handler)  # Termination request
-    signal.signal(signal.SIGHUP, signal_handler)  # Terminal closed
-    signal.signal(signal.SIGQUIT, signal_handler)  # Ctrl+\
+    shutdown_event = setup_process_handlers(use_threads(cfg))
 
     start_learner_threads(
         cfg,
-        device,
-        replay_buffer,
-        offline_replay_buffer,
-        batch_size,
-        optimizers,
-        policy,
-        policy_lock,
         logger,
-        resume_optimization_step,
-        resume_interaction_step,
+        out_dir,
         shutdown_event,
     )
 
 
 @hydra.main(version_base="1.2", config_name="default", config_path="../../configs")
 def train_cli(cfg: dict):
+    if not use_threads(cfg):
+        import torch.multiprocessing as mp
+
+        mp.set_start_method("spawn")
+
     train(
         cfg,
         out_dir=hydra.core.hydra_config.HydraConfig.get().run.dir,
         job_name=hydra.core.hydra_config.HydraConfig.get().job.name,
     )
 
+    logging.info("[LEARNER] train_cli finished")
+
 
 if __name__ == "__main__":
     train_cli()
+
+    logging.info("[LEARNER] main finished")
diff --git a/lerobot/scripts/server/learner_service.py b/lerobot/scripts/server/learner_service.py
index d6e6b5b7..b1f91cdc 100644
--- a/lerobot/scripts/server/learner_service.py
+++ b/lerobot/scripts/server/learner_service.py
@@ -1,23 +1,13 @@
 import hilserl_pb2  # type: ignore
 import hilserl_pb2_grpc  # type: ignore
-import torch
-from torch import nn
-from threading import Lock, Event
 import logging
-import queue
-import io
-import pickle
-
-from lerobot.scripts.server.buffer import (
-    move_state_dict_to_device,
-    bytes_buffer_size,
-    state_to_bytes,
-)
+from multiprocessing import Event, Queue
 
+from lerobot.scripts.server.network_utils import receive_bytes_in_chunks
+from lerobot.scripts.server.network_utils import send_bytes_in_chunks
 
 MAX_MESSAGE_SIZE = 4 * 1024 * 1024  # 4 MB
-CHUNK_SIZE = 2 * 1024 * 1024  # 2 MB
-MAX_WORKERS = 10
+MAX_WORKERS = 3  # Stream parameters, send transitions and interactions
 STUTDOWN_TIMEOUT = 10
 
 
@@ -25,89 +15,68 @@ class LearnerService(hilserl_pb2_grpc.LearnerServiceServicer):
     def __init__(
         self,
         shutdown_event: Event,
-        policy: nn.Module,
-        policy_lock: Lock,
+        parameters_queue: Queue,
         seconds_between_pushes: float,
-        transition_queue: queue.Queue,
-        interaction_message_queue: queue.Queue,
+        transition_queue: Queue,
+        interaction_message_queue: Queue,
     ):
         self.shutdown_event = shutdown_event
-        self.policy = policy
-        self.policy_lock = policy_lock
+        self.parameters_queue = parameters_queue
         self.seconds_between_pushes = seconds_between_pushes
         self.transition_queue = transition_queue
         self.interaction_message_queue = interaction_message_queue
 
-    def _get_policy_state(self):
-        with self.policy_lock:
-            params_dict = self.policy.actor.state_dict()
-            # if self.policy.config.vision_encoder_name is not None:
-            #     if self.policy.config.freeze_vision_encoder:
-            #         params_dict: dict[str, torch.Tensor] = {
-            #             k: v
-            #             for k, v in params_dict.items()
-            #             if not k.startswith("encoder.")
-            #         }
-            #     else:
-            #         raise NotImplementedError(
-            #             "Vision encoder is not frozen, we need to send the full model over the network which requires chunking the model."
-            #         )
-
-        return move_state_dict_to_device(params_dict, device="cpu")
-
-    def _send_bytes(self, buffer: bytes):
-        size_in_bytes = bytes_buffer_size(buffer)
-
-        sent_bytes = 0
-
-        logging.info(f"Model state size {size_in_bytes/1024/1024} MB with")
-
-        while sent_bytes < size_in_bytes:
-            transfer_state = hilserl_pb2.TransferState.TRANSFER_MIDDLE
-
-            if sent_bytes + CHUNK_SIZE >= size_in_bytes:
-                transfer_state = hilserl_pb2.TransferState.TRANSFER_END
-            elif sent_bytes == 0:
-                transfer_state = hilserl_pb2.TransferState.TRANSFER_BEGIN
-
-            size_to_read = min(CHUNK_SIZE, size_in_bytes - sent_bytes)
-            chunk = buffer.read(size_to_read)
-
-            yield hilserl_pb2.Parameters(
-                transfer_state=transfer_state, parameter_bytes=chunk
-            )
-            sent_bytes += size_to_read
-            logging.info(
-                f"[Learner] Sent {sent_bytes}/{size_in_bytes} bytes with state {transfer_state}"
-            )
-
-        logging.info(f"[LEARNER] Published {sent_bytes/1024/1024} MB to the Actor")
-
     def StreamParameters(self, request, context):
         # TODO: authorize the request
         logging.info("[LEARNER] Received request to stream parameters from the Actor")
 
         while not self.shutdown_event.is_set():
-            logging.debug("[LEARNER] Push parameters to the Actor")
-            state_dict = self._get_policy_state()
+            logging.info("[LEARNER] Push parameters to the Actor")
+            buffer = self.parameters_queue.get()
 
-            with state_to_bytes(state_dict) as buffer:
-                yield from self._send_bytes(buffer)
+            yield from send_bytes_in_chunks(
+                buffer,
+                hilserl_pb2.Parameters,
+                log_prefix="[LEARNER] Sending parameters",
+                silent=True,
+            )
+
+            logging.info("[LEARNER] Parameters sent")
 
             self.shutdown_event.wait(self.seconds_between_pushes)
 
-    def ReceiveTransitions(self, request_iterator, context):
+        logging.info("[LEARNER] Stream parameters finished")
+        return hilserl_pb2.Empty()
+
+    def SendTransitions(self, request_iterator, _context):
         # TODO: authorize the request
         logging.info("[LEARNER] Received request to receive transitions from the Actor")
 
-        for request in request_iterator:
-            logging.debug("[LEARNER] Received request")
-            if request.HasField("transition"):
-                buffer = io.BytesIO(request.transition.transition_bytes)
-                transition = torch.load(buffer)
-                self.transition_queue.put(transition)
-            if request.HasField("interaction_message"):
-                content = pickle.loads(
-                    request.interaction_message.interaction_message_bytes
-                )
-                self.interaction_message_queue.put(content)
+        receive_bytes_in_chunks(
+            request_iterator,
+            self.transition_queue,
+            self.shutdown_event,
+            log_prefix="[LEARNER] transitions",
+        )
+
+        logging.debug("[LEARNER] Finished receiving transitions")
+        return hilserl_pb2.Empty()
+
+    def SendInteractions(self, request_iterator, _context):
+        # TODO: authorize the request
+        logging.info(
+            "[LEARNER] Received request to receive interactions from the Actor"
+        )
+
+        receive_bytes_in_chunks(
+            request_iterator,
+            self.interaction_message_queue,
+            self.shutdown_event,
+            log_prefix="[LEARNER] interactions",
+        )
+
+        logging.debug("[LEARNER] Finished receiving interactions")
+        return hilserl_pb2.Empty()
+
+    def Ready(self, request, context):
+        return hilserl_pb2.Empty()
diff --git a/lerobot/scripts/server/maniskill_manipulator.py b/lerobot/scripts/server/maniskill_manipulator.py
index b9c9d216..e4d55955 100644
--- a/lerobot/scripts/server/maniskill_manipulator.py
+++ b/lerobot/scripts/server/maniskill_manipulator.py
@@ -5,9 +5,8 @@ import torch
 
 from omegaconf import DictConfig
 from typing import Any
-
-"""Make ManiSkill3 gym environment"""
 from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
+from mani_skill.utils.wrappers.record import RecordEpisode
 
 
 def preprocess_maniskill_observation(
@@ -143,6 +142,15 @@ def make_maniskill(
         num_envs=n_envs,
     )
 
+    if cfg.env.video_record.enabled:
+        env = RecordEpisode(
+            env,
+            output_dir=cfg.env.video_record.record_dir,
+            save_trajectory=True,
+            trajectory_name=cfg.env.video_record.trajectory_name,
+            save_video=True,
+            video_fps=30,
+        )
     env = ManiSkillObservationWrapper(env, device=cfg.env.device)
     env = ManiSkillVectorEnv(env, ignore_terminations=True, auto_reset=False)
     env._max_episode_steps = env.max_episode_steps = (
diff --git a/lerobot/scripts/server/network_utils.py b/lerobot/scripts/server/network_utils.py
new file mode 100644
index 00000000..f5e8973b
--- /dev/null
+++ b/lerobot/scripts/server/network_utils.py
@@ -0,0 +1,102 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from lerobot.scripts.server import hilserl_pb2
+import logging
+import io
+from multiprocessing import Queue, Event
+from typing import Any
+
+CHUNK_SIZE = 2 * 1024 * 1024  # 2 MB
+
+
+def bytes_buffer_size(buffer: io.BytesIO) -> int:
+    buffer.seek(0, io.SEEK_END)
+    result = buffer.tell()
+    buffer.seek(0)
+    return result
+
+
+def send_bytes_in_chunks(
+    buffer: bytes, message_class: Any, log_prefix: str = "", silent: bool = True
+):
+    buffer = io.BytesIO(buffer)
+    size_in_bytes = bytes_buffer_size(buffer)
+
+    sent_bytes = 0
+
+    logging_method = logging.info if not silent else logging.debug
+
+    logging_method(f"{log_prefix} Buffer size {size_in_bytes/1024/1024} MB with")
+
+    while sent_bytes < size_in_bytes:
+        transfer_state = hilserl_pb2.TransferState.TRANSFER_MIDDLE
+
+        if sent_bytes + CHUNK_SIZE >= size_in_bytes:
+            transfer_state = hilserl_pb2.TransferState.TRANSFER_END
+        elif sent_bytes == 0:
+            transfer_state = hilserl_pb2.TransferState.TRANSFER_BEGIN
+
+        size_to_read = min(CHUNK_SIZE, size_in_bytes - sent_bytes)
+        chunk = buffer.read(size_to_read)
+
+        yield message_class(transfer_state=transfer_state, data=chunk)
+        sent_bytes += size_to_read
+        logging_method(
+            f"{log_prefix} Sent {sent_bytes}/{size_in_bytes} bytes with state {transfer_state}"
+        )
+
+    logging_method(f"{log_prefix} Published {sent_bytes/1024/1024} MB")
+
+
+def receive_bytes_in_chunks(
+    iterator, queue: Queue, shutdown_event: Event, log_prefix: str = ""
+):
+    bytes_buffer = io.BytesIO()
+    step = 0
+
+    logging.info(f"{log_prefix} Starting receiver")
+    for item in iterator:
+        logging.debug(f"{log_prefix} Received item")
+        if shutdown_event.is_set():
+            logging.info(f"{log_prefix} Shutting down receiver")
+            return
+
+        if item.transfer_state == hilserl_pb2.TransferState.TRANSFER_BEGIN:
+            bytes_buffer.seek(0)
+            bytes_buffer.truncate(0)
+            bytes_buffer.write(item.data)
+            logging.debug(f"{log_prefix} Received data at step 0")
+            step = 0
+            continue
+        elif item.transfer_state == hilserl_pb2.TransferState.TRANSFER_MIDDLE:
+            bytes_buffer.write(item.data)
+            step += 1
+            logging.debug(f"{log_prefix} Received data at step {step}")
+        elif item.transfer_state == hilserl_pb2.TransferState.TRANSFER_END:
+            bytes_buffer.write(item.data)
+            logging.debug(
+                f"{log_prefix} Received data at step end size {bytes_buffer_size(bytes_buffer)}"
+            )
+
+            queue.put(bytes_buffer.getvalue())
+
+            bytes_buffer.seek(0)
+            bytes_buffer.truncate(0)
+            step = 0
+
+            logging.debug(f"{log_prefix} Queue updated")
diff --git a/lerobot/scripts/server/utils.py b/lerobot/scripts/server/utils.py
new file mode 100644
index 00000000..699717e4
--- /dev/null
+++ b/lerobot/scripts/server/utils.py
@@ -0,0 +1,72 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import signal
+import sys
+from torch.multiprocessing import Queue
+from queue import Empty
+
+shutdown_event_counter = 0
+
+
+def setup_process_handlers(use_threads: bool) -> any:
+    if use_threads:
+        from threading import Event
+    else:
+        from multiprocessing import Event
+
+    shutdown_event = Event()
+
+    # Define signal handler
+    def signal_handler(signum, frame):
+        logging.info("Shutdown signal received. Cleaning up...")
+        shutdown_event.set()
+        global shutdown_event_counter
+        shutdown_event_counter += 1
+
+        if shutdown_event_counter > 1:
+            logging.info("Force shutdown")
+            sys.exit(1)
+
+    signal.signal(signal.SIGINT, signal_handler)  # Ctrl+C
+    signal.signal(signal.SIGTERM, signal_handler)  # Termination request (kill)
+    signal.signal(signal.SIGHUP, signal_handler)  # Terminal closed/Hangup
+    signal.signal(signal.SIGQUIT, signal_handler)  # Ctrl+\
+
+    def signal_handler(signum, frame):
+        logging.info("Shutdown signal received. Cleaning up...")
+        shutdown_event.set()
+
+    return shutdown_event
+
+
+def get_last_item_from_queue(queue: Queue):
+    item = queue.get()
+    counter = 1
+
+    # Drain queue and keep only the most recent parameters
+    try:
+        while True:
+            item = queue.get_nowait()
+            counter += 1
+    except Empty:
+        pass
+
+    logging.debug(f"Drained {counter} items from queue")
+
+    return item