From 700f00c01457151161368f7e7f42b3a4711d1a38 Mon Sep 17 00:00:00 2001 From: Eugene Mironov Date: Wed, 5 Mar 2025 17:19:31 +0700 Subject: [PATCH] [HIL-SERL] Migrate threading to multiprocessing (#759) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- lerobot/common/utils/utils.py | 10 +- lerobot/configs/env/maniskill_example.yaml | 6 + lerobot/configs/env/so100_real.yaml | 1 - lerobot/configs/policy/sac_maniskill.yaml | 7 +- lerobot/scripts/server/actor_server.py | 422 ++++++++++------- lerobot/scripts/server/buffer.py | 36 +- lerobot/scripts/server/hilserl.proto | 19 +- lerobot/scripts/server/hilserl_pb2.py | 28 +- lerobot/scripts/server/hilserl_pb2_grpc.py | 106 ++++- lerobot/scripts/server/learner_server.py | 440 ++++++++++-------- lerobot/scripts/server/learner_service.py | 131 ++---- .../scripts/server/maniskill_manipulator.py | 12 +- lerobot/scripts/server/network_utils.py | 102 ++++ lerobot/scripts/server/utils.py | 72 +++ 14 files changed, 900 insertions(+), 492 deletions(-) create mode 100644 lerobot/scripts/server/network_utils.py create mode 100644 lerobot/scripts/server/utils.py diff --git a/lerobot/common/utils/utils.py b/lerobot/common/utils/utils.py index 2bf19738..fecf88f9 100644 --- a/lerobot/common/utils/utils.py +++ b/lerobot/common/utils/utils.py @@ -116,11 +116,11 @@ def seeded_context(seed: int) -> Generator[None, None, None]: set_global_random_state(random_state_dict) -def init_logging(): +def init_logging(log_file=None): def custom_format(record): dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S") fnameline = f"{record.pathname}:{record.lineno}" - message = f"{record.levelname} {dt} {fnameline[-15:]:>15} {record.msg}" + message = f"{record.levelname} [PID: {os.getpid()}] {dt} {fnameline[-15:]:>15} {record.msg}" return message logging.basicConfig(level=logging.INFO) @@ -134,6 +134,12 @@ def init_logging(): console_handler.setFormatter(formatter) logging.getLogger().addHandler(console_handler) + if log_file is not None: + # File handler + file_handler = logging.FileHandler(log_file) + file_handler.setFormatter(formatter) + logging.getLogger().addHandler(file_handler) + def format_big_number(num, precision=0): suffixes = ["", "K", "M", "B", "T", "Q"] diff --git a/lerobot/configs/env/maniskill_example.yaml b/lerobot/configs/env/maniskill_example.yaml index 9098bcbe..3df23b2e 100644 --- a/lerobot/configs/env/maniskill_example.yaml +++ b/lerobot/configs/env/maniskill_example.yaml @@ -22,3 +22,9 @@ env: wrapper: joint_masking_action_space: null delta_action: null + + video_record: + enabled: false + record_dir: maniskill_videos + trajectory_name: trajectory + fps: ${fps} diff --git a/lerobot/configs/env/so100_real.yaml b/lerobot/configs/env/so100_real.yaml index 1bd5cd83..dc30224c 100644 --- a/lerobot/configs/env/so100_real.yaml +++ b/lerobot/configs/env/so100_real.yaml @@ -28,4 +28,3 @@ env: reward_classifier: pretrained_path: outputs/classifier/13-02-random-sample-resnet10-frozen/checkpoints/best/pretrained_model config_path: lerobot/configs/policy/hilserl_classifier.yaml - diff --git a/lerobot/configs/policy/sac_maniskill.yaml b/lerobot/configs/policy/sac_maniskill.yaml index c954b1ea..c9bbca44 100644 --- a/lerobot/configs/policy/sac_maniskill.yaml +++ b/lerobot/configs/policy/sac_maniskill.yaml @@ -8,14 +8,12 @@ # env.gym.obs_type=environment_state_agent_pos \ seed: 1 -# dataset_repo_id: null dataset_repo_id: "AdilZtn/Maniskill-Pushcube-demonstration-medium" training: # Offline training dataloader num_workers: 4 - # 
batch_size: 256 batch_size: 512 grad_clip_norm: 10.0 lr: 3e-4 @@ -113,4 +111,7 @@ policy: actor_learner_config: learner_host: "127.0.0.1" learner_port: 50051 - policy_parameters_push_frequency: 15 + policy_parameters_push_frequency: 1 + concurrency: + actor: 'processes' + learner: 'processes' diff --git a/lerobot/scripts/server/actor_server.py b/lerobot/scripts/server/actor_server.py index c70417cf..24d8356d 100644 --- a/lerobot/scripts/server/actor_server.py +++ b/lerobot/scripts/server/actor_server.py @@ -13,22 +13,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import io import logging -import pickle -import queue from statistics import mean, quantiles -import signal from functools import lru_cache +from lerobot.scripts.server.utils import setup_process_handlers # from lerobot.scripts.eval import eval_policy -from threading import Thread import grpc import hydra import torch from omegaconf import DictConfig from torch import nn +import time # TODO: Remove the import of maniskill # from lerobot.common.envs.factory import make_maniskill_env @@ -47,157 +44,184 @@ from lerobot.scripts.server.buffer import ( Transition, move_state_dict_to_device, move_transition_to_device, - bytes_buffer_size, + python_object_to_bytes, + transitions_to_bytes, + bytes_to_state_dict, +) +from lerobot.scripts.server.network_utils import ( + receive_bytes_in_chunks, + send_bytes_in_chunks, ) from lerobot.scripts.server.gym_manipulator import get_classifier, make_robot_env from lerobot.scripts.server import learner_service -from threading import Event +from torch.multiprocessing import Queue, Event +from queue import Empty -logging.basicConfig(level=logging.INFO) +from lerobot.common.utils.utils import init_logging -parameters_queue = queue.Queue(maxsize=1) -message_queue = queue.Queue(maxsize=1_000_000) +from lerobot.scripts.server.utils import get_last_item_from_queue ACTOR_SHUTDOWN_TIMEOUT = 30 -class ActorInformation: - """ - This helper class is used to differentiate between two types of messages that are placed in the same queue during streaming: - - - **Transition Data:** Contains experience tuples (observation, action, reward, next observation) collected during interaction. - - **Interaction Messages:** Encapsulates statistics related to the interaction process. - - Attributes: - transition (Optional): Transition data to be sent to the learner. - interaction_message (Optional): Iteraction message providing additional statistics for logging. 
- """ - - def __init__(self, transition=None, interaction_message=None): - self.transition = transition - self.interaction_message = interaction_message - - def receive_policy( - learner_client: hilserl_pb2_grpc.LearnerServiceStub, - shutdown_event: Event, - parameters_queue: queue.Queue, + cfg: DictConfig, + parameters_queue: Queue, + shutdown_event: any, # Event, + learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None, + grpc_channel: grpc.Channel | None = None, ): logging.info("[ACTOR] Start receiving parameters from the Learner") - bytes_buffer = io.BytesIO() - step = 0 + + if not use_threads(cfg): + # Setup process handlers to handle shutdown signal + # But use shutdown event from the main process + setup_process_handlers(False) + + if grpc_channel is None or learner_client is None: + learner_client, grpc_channel = learner_service_client( + host=cfg.actor_learner_config.learner_host, + port=cfg.actor_learner_config.learner_port, + ) + try: - for model_update in learner_client.StreamParameters(hilserl_pb2.Empty()): - if shutdown_event.is_set(): - logging.info("[ACTOR] Shutting down policy streaming receiver") - return hilserl_pb2.Empty() - - if model_update.transfer_state == hilserl_pb2.TransferState.TRANSFER_BEGIN: - bytes_buffer.seek(0) - bytes_buffer.truncate(0) - bytes_buffer.write(model_update.parameter_bytes) - logging.info("Received model update at step 0") - step = 0 - continue - elif ( - model_update.transfer_state == hilserl_pb2.TransferState.TRANSFER_MIDDLE - ): - bytes_buffer.write(model_update.parameter_bytes) - step += 1 - logging.info(f"Received model update at step {step}") - elif model_update.transfer_state == hilserl_pb2.TransferState.TRANSFER_END: - bytes_buffer.write(model_update.parameter_bytes) - logging.info( - f"Received model update at step end size {bytes_buffer_size(bytes_buffer)}" - ) - - state_dict = torch.load(bytes_buffer) - - bytes_buffer.seek(0) - bytes_buffer.truncate(0) - step = 0 - - logging.info("Model updated") - - parameters_queue.put(state_dict) - + iterator = learner_client.StreamParameters(hilserl_pb2.Empty()) + receive_bytes_in_chunks( + iterator, + parameters_queue, + shutdown_event, + log_prefix="[ACTOR] parameters", + ) except grpc.RpcError as e: logging.error(f"[ACTOR] gRPC error: {e}") + if not use_threads(cfg): + grpc_channel.close() + logging.info("[ACTOR] Received policy loop stopped") + + +def transitions_stream( + shutdown_event: Event, transitions_queue: Queue +) -> hilserl_pb2.Empty: + while not shutdown_event.is_set(): + try: + message = transitions_queue.get(block=True, timeout=5) + except Empty: + logging.debug("[ACTOR] Transition queue is empty") + continue + + yield from send_bytes_in_chunks( + message, hilserl_pb2.Transition, log_prefix="[ACTOR] Send transitions" + ) + return hilserl_pb2.Empty() -def transitions_stream(shutdown_event: Event, message_queue: queue.Queue): +def interactions_stream( + shutdown_event: any, # Event, + interactions_queue: Queue, +) -> hilserl_pb2.Empty: while not shutdown_event.is_set(): try: - message = message_queue.get(block=True, timeout=5) - except queue.Empty: - logging.debug("[ACTOR] Transition queue is empty") + message = interactions_queue.get(block=True, timeout=5) + except Empty: + logging.debug("[ACTOR] Interaction queue is empty") continue - if message.transition is not None: - transition_to_send_to_learner: list[Transition] = [ - move_transition_to_device(transition=T, device="cpu") - for T in message.transition - ] - # Check for NaNs in transitions before sending to 
learner - for transition in transition_to_send_to_learner: - for key, value in transition["state"].items(): - if torch.isnan(value).any(): - logging.warning(f"Found NaN values in transition {key}") - buf = io.BytesIO() - torch.save(transition_to_send_to_learner, buf) - transition_bytes = buf.getvalue() - - transition_message = hilserl_pb2.Transition( - transition_bytes=transition_bytes - ) - - response = hilserl_pb2.ActorInformation(transition=transition_message) - - elif message.interaction_message is not None: - content = hilserl_pb2.InteractionMessage( - interaction_message_bytes=pickle.dumps(message.interaction_message) - ) - response = hilserl_pb2.ActorInformation(interaction_message=content) - - yield response + yield from send_bytes_in_chunks( + message, + hilserl_pb2.InteractionMessage, + log_prefix="[ACTOR] Send interactions", + ) return hilserl_pb2.Empty() def send_transitions( - learner_client: hilserl_pb2_grpc.LearnerServiceStub, - shutdown_event: Event, - message_queue: queue.Queue, -): + cfg: DictConfig, + transitions_queue: Queue, + shutdown_event: any, # Event, + learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None, + grpc_channel: grpc.Channel | None = None, +) -> hilserl_pb2.Empty: """ - Streams data from the actor to the learner. + Sends transitions to the learner. - This function continuously retrieves messages from the queue and processes them based on their type: + This function continuously retrieves messages from the queue and processes: - **Transition Data:** - A batch of transitions (observation, action, reward, next observation) is collected. - Transitions are moved to the CPU and serialized using PyTorch. - The serialized data is wrapped in a `hilserl_pb2.Transition` message and sent to the learner. - - - **Interaction Messages:** - - Contains useful statistics about episodic rewards and policy timings. - - The message is serialized using `pickle` and sent to the learner. - - Yields: - hilserl_pb2.ActorInformation: The response message containing either transition data or an interaction message. """ + + if not use_threads(cfg): + # Setup process handlers to handle shutdown signal + # But use shutdown event from the main process + setup_process_handlers(False) + + if grpc_channel is None or learner_client is None: + learner_client, grpc_channel = learner_service_client( + host=cfg.actor_learner_config.learner_host, + port=cfg.actor_learner_config.learner_port, + ) + try: - learner_client.ReceiveTransitions( - transitions_stream(shutdown_event, message_queue) + learner_client.SendTransitions( + transitions_stream(shutdown_event, transitions_queue) ) except grpc.RpcError as e: logging.error(f"[ACTOR] gRPC error: {e}") logging.info("[ACTOR] Finished streaming transitions") + if not use_threads(cfg): + grpc_channel.close() + logging.info("[ACTOR] Transitions process stopped") + + +def send_interactions( + cfg: DictConfig, + interactions_queue: Queue, + shutdown_event: any, # Event, + learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None, + grpc_channel: grpc.Channel | None = None, +) -> hilserl_pb2.Empty: + """ + Sends interactions to the learner. + + This function continuously retrieves messages from the queue and processes: + + - **Interaction Messages:** + - Contains useful statistics about episodic rewards and policy timings. + - The message is serialized using `pickle` and sent to the learner. 
+ """ + + if not use_threads(cfg): + # Setup process handlers to handle shutdown signal + # But use shutdown event from the main process + setup_process_handlers(False) + + if grpc_channel is None or learner_client is None: + learner_client, grpc_channel = learner_service_client( + host=cfg.actor_learner_config.learner_host, + port=cfg.actor_learner_config.learner_port, + ) + + try: + learner_client.SendInteractions( + interactions_stream(shutdown_event, interactions_queue) + ) + except grpc.RpcError as e: + logging.error(f"[ACTOR] gRPC error: {e}") + + logging.info("[ACTOR] Finished streaming interactions") + + if not use_threads(cfg): + grpc_channel.close() + logging.info("[ACTOR] Interactions process stopped") + @lru_cache(maxsize=1) def learner_service_client( @@ -217,7 +241,7 @@ def learner_service_client( { "name": [{}], # Applies to ALL methods in ALL services "retryPolicy": { - "maxAttempts": 7, # Max retries (total attempts = 5) + "maxAttempts": 5, # Max retries (total attempts = 5) "initialBackoff": "0.1s", # First retry after 0.1s "maxBackoff": "2s", # Max wait time between retries "backoffMultiplier": 2, # Exponential backoff factor @@ -242,20 +266,27 @@ def learner_service_client( ], ) stub = hilserl_pb2_grpc.LearnerServiceStub(channel) - logging.info("[LEARNER] Learner service client created") + logging.info("[ACTOR] Learner service client created") return stub, channel -def update_policy_parameters(policy: SACPolicy, parameters_queue: queue.Queue, device): +def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device): if not parameters_queue.empty(): logging.info("[ACTOR] Load new parameters from Learner.") - state_dict = parameters_queue.get() + bytes_state_dict = get_last_item_from_queue(parameters_queue) + state_dict = bytes_to_state_dict(bytes_state_dict) state_dict = move_state_dict_to_device(state_dict, device=device) policy.load_state_dict(state_dict) def act_with_policy( - cfg: DictConfig, robot: Robot, reward_classifier: nn.Module, shutdown_event: Event + cfg: DictConfig, + robot: Robot, + reward_classifier: nn.Module, + shutdown_event: any, # Event, + parameters_queue: Queue, + transitions_queue: Queue, + interactions_queue: Queue, ): """ Executes policy interaction within the environment. @@ -317,7 +348,7 @@ def act_with_policy( for interaction_step in range(cfg.training.online_steps): if shutdown_event.is_set(): - logging.info("[ACTOR] Shutdown signal received. 
Exiting...") + logging.info("[ACTOR] Shutting down act_with_policy") return if interaction_step >= cfg.training.online_step_before_learning: @@ -394,10 +425,9 @@ def act_with_policy( ) if len(list_transition_to_send_to_learner) > 0: - send_transitions_in_chunks( + push_transitions_to_transport_queue( transitions=list_transition_to_send_to_learner, - message_queue=message_queue, - chunk_size=4, + transitions_queue=transitions_queue, ) list_transition_to_send_to_learner = [] @@ -405,9 +435,9 @@ def act_with_policy( list_policy_time.clear() # Send episodic reward to the learner - message_queue.put( - ActorInformation( - interaction_message={ + interactions_queue.put( + python_object_to_bytes( + { "Episodic reward": sum_reward_episode, "Interaction step": interaction_step, "Episode intervention": int(episode_intervention), @@ -420,7 +450,7 @@ def act_with_policy( obs, info = online_env.reset() -def send_transitions_in_chunks(transitions: list, message_queue, chunk_size: int = 100): +def push_transitions_to_transport_queue(transitions: list, transitions_queue): """Send transitions to learner in smaller chunks to avoid network issues. Args: @@ -428,10 +458,16 @@ def send_transitions_in_chunks(transitions: list, message_queue, chunk_size: int message_queue: Queue to send messages to learner chunk_size: Size of each chunk to send """ - for i in range(0, len(transitions), chunk_size): - chunk = transitions[i : i + chunk_size] - logging.debug(f"[ACTOR] Sending chunk of {len(chunk)} transitions to Learner.") - message_queue.put(ActorInformation(transition=chunk)) + transition_to_send_to_learner = [] + for transition in transitions: + tr = move_transition_to_device(transition=transition, device="cpu") + for key, value in tr["state"].items(): + if torch.isnan(value).any(): + logging.warning(f"Found NaN values in transition {key}") + + transition_to_send_to_learner.append(tr) + + transitions_queue.put(transitions_to_bytes(transition_to_send_to_learner)) def get_frequency_stats(list_policy_time: list[float]) -> dict[str, float]: @@ -458,39 +494,96 @@ def log_policy_frequency_issue( ) +def establish_learner_connection( + stub, + shutdown_event: any, # Event, + attempts=30, +): + for _ in range(attempts): + if shutdown_event.is_set(): + logging.info("[ACTOR] Shutting down establish_learner_connection") + return False + + # Force a connection attempt and check state + try: + logging.info("[ACTOR] Send ready message to Learner") + if stub.Ready(hilserl_pb2.Empty()) == hilserl_pb2.Empty(): + return True + except grpc.RpcError as e: + logging.error(f"[ACTOR] Waiting for Learner to be ready... {e}") + time.sleep(2) + return False + + +def use_threads(cfg: DictConfig) -> bool: + return cfg.actor_learner_config.concurrency.actor == "threads" + + @hydra.main(version_base="1.2", config_name="default", config_path="../../configs") def actor_cli(cfg: dict): + if not use_threads(cfg): + import torch.multiprocessing as mp + + mp.set_start_method("spawn") + + init_logging(log_file="actor.log") robot = make_robot(cfg=cfg.robot) - shutdown_event = Event() - - # Define signal handler - def signal_handler(signum, frame): - logging.info("Shutdown signal received. 
Cleaning up...") - shutdown_event.set() - - signal.signal(signal.SIGINT, signal_handler) # Ctrl+C - signal.signal(signal.SIGTERM, signal_handler) # Termination request (kill) - signal.signal(signal.SIGHUP, signal_handler) # Terminal closed/Hangup - signal.signal(signal.SIGQUIT, signal_handler) # Ctrl+\ + shutdown_event = setup_process_handlers(use_threads(cfg)) learner_client, grpc_channel = learner_service_client( host=cfg.actor_learner_config.learner_host, port=cfg.actor_learner_config.learner_port, ) - receive_policy_thread = Thread( + logging.info("[ACTOR] Establishing connection with Learner") + if not establish_learner_connection(learner_client, shutdown_event): + logging.error("[ACTOR] Failed to establish connection with Learner") + return + + if not use_threads(cfg): + # If we use multithreading, we can reuse the channel + grpc_channel.close() + grpc_channel = None + + logging.info("[ACTOR] Connection with Learner established") + + parameters_queue = Queue() + transitions_queue = Queue() + interactions_queue = Queue() + + concurrency_entity = None + if use_threads(cfg): + from threading import Thread + + concurrency_entity = Thread + else: + from multiprocessing import Process + + concurrency_entity = Process + + receive_policy_process = concurrency_entity( target=receive_policy, - args=(learner_client, shutdown_event, parameters_queue), + args=(cfg, parameters_queue, shutdown_event, grpc_channel), daemon=True, ) - transitions_thread = Thread( + transitions_process = concurrency_entity( target=send_transitions, - args=(learner_client, shutdown_event, message_queue), + args=(cfg, transitions_queue, shutdown_event, grpc_channel), daemon=True, ) + interactions_process = concurrency_entity( + target=send_interactions, + args=(cfg, interactions_queue, shutdown_event, grpc_channel), + daemon=True, + ) + + transitions_process.start() + interactions_process.start() + receive_policy_process.start() + # HACK: FOR MANISKILL we do not have a reward classifier # TODO: Remove this once we merge into main reward_classifier = None @@ -503,26 +596,35 @@ def actor_cli(cfg: dict): config_path=cfg.env.reward_classifier.config_path, ) - policy_thread = Thread( - target=act_with_policy, - daemon=True, - args=(cfg, robot, reward_classifier, shutdown_event), + act_with_policy( + cfg, + robot, + reward_classifier, + shutdown_event, + parameters_queue, + transitions_queue, + interactions_queue, ) + logging.info("[ACTOR] Policy process joined") - transitions_thread.start() - policy_thread.start() - receive_policy_thread.start() + logging.info("[ACTOR] Closing queues") + transitions_queue.close() + interactions_queue.close() + parameters_queue.close() - shutdown_event.wait() - logging.info("[ACTOR] Shutdown event received") - grpc_channel.close() + transitions_process.join() + logging.info("[ACTOR] Transitions process joined") + interactions_process.join() + logging.info("[ACTOR] Interactions process joined") + receive_policy_process.join() + logging.info("[ACTOR] Receive policy process joined") - policy_thread.join() - logging.info("[ACTOR] Policy thread joined") - transitions_thread.join() - logging.info("[ACTOR] Transitions thread joined") - receive_policy_thread.join() - logging.info("[ACTOR] Receive policy thread joined") + logging.info("[ACTOR] join queues") + transitions_queue.cancel_join_thread() + interactions_queue.cancel_join_thread() + parameters_queue.cancel_join_thread() + + logging.info("[ACTOR] queues closed") if __name__ == "__main__": diff --git a/lerobot/scripts/server/buffer.py 
b/lerobot/scripts/server/buffer.py index f93b40ca..80834eac 100644 --- a/lerobot/scripts/server/buffer.py +++ b/lerobot/scripts/server/buffer.py @@ -23,6 +23,7 @@ from tqdm import tqdm from lerobot.common.datasets.lerobot_dataset import LeRobotDataset import os +import pickle class Transition(TypedDict): @@ -91,7 +92,7 @@ def move_transition_to_device( return transition -def move_state_dict_to_device(state_dict, device): +def move_state_dict_to_device(state_dict, device="cpu"): """ Recursively move all tensors in a (potentially) nested dict/list/tuple structure to the CPU. @@ -111,20 +112,41 @@ def move_state_dict_to_device(state_dict, device): return state_dict -def state_to_bytes(state_dict: dict[str, torch.Tensor]) -> io.BytesIO: +def state_to_bytes(state_dict: dict[str, torch.Tensor]) -> bytes: """Convert model state dict to flat array for transmission""" buffer = io.BytesIO() torch.save(state_dict, buffer) - return buffer + return buffer.getvalue() -def bytes_buffer_size(buffer: io.BytesIO) -> int: - buffer.seek(0, io.SEEK_END) - result = buffer.tell() +def bytes_to_state_dict(buffer: bytes) -> dict[str, torch.Tensor]: + buffer = io.BytesIO(buffer) buffer.seek(0) - return result + return torch.load(buffer) + + +def python_object_to_bytes(python_object: Any) -> bytes: + return pickle.dumps(python_object) + + +def bytes_to_python_object(buffer: bytes) -> Any: + buffer = io.BytesIO(buffer) + buffer.seek(0) + return pickle.load(buffer) + + +def bytes_to_transitions(buffer: bytes) -> list[Transition]: + buffer = io.BytesIO(buffer) + buffer.seek(0) + return torch.load(buffer) + + +def transitions_to_bytes(transitions: list[Transition]) -> bytes: + buffer = io.BytesIO() + torch.save(transitions, buffer) + return buffer.getvalue() def random_crop_vectorized(images: torch.Tensor, output_size: tuple) -> torch.Tensor: diff --git a/lerobot/scripts/server/hilserl.proto b/lerobot/scripts/server/hilserl.proto index 6aa46e0e..dec2117b 100644 --- a/lerobot/scripts/server/hilserl.proto +++ b/lerobot/scripts/server/hilserl.proto @@ -24,14 +24,9 @@ service LearnerService { // Actor -> Learner to store transitions rpc SendInteractionMessage(InteractionMessage) returns (Empty); rpc StreamParameters(Empty) returns (stream Parameters); - rpc ReceiveTransitions(stream ActorInformation) returns (Empty); -} - -message ActorInformation { - oneof data { - Transition transition = 1; - InteractionMessage interaction_message = 2; - } + rpc SendTransitions(stream Transition) returns (Empty); + rpc SendInteractions(stream InteractionMessage) returns (Empty); + rpc Ready(Empty) returns (Empty); } enum TransferState { @@ -43,16 +38,18 @@ enum TransferState { // Messages message Transition { - bytes transition_bytes = 1; + TransferState transfer_state = 1; + bytes data = 2; } message Parameters { TransferState transfer_state = 1; - bytes parameter_bytes = 2; + bytes data = 2; } message InteractionMessage { - bytes interaction_message_bytes = 1; + TransferState transfer_state = 1; + bytes data = 2; } message Empty {} diff --git a/lerobot/scripts/server/hilserl_pb2.py b/lerobot/scripts/server/hilserl_pb2.py index d5eb8d4c..4a4cbea7 100644 --- a/lerobot/scripts/server/hilserl_pb2.py +++ b/lerobot/scripts/server/hilserl_pb2.py @@ -24,25 +24,23 @@ _sym_db = _symbol_database.Default() -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rhilserl.proto\x12\x08hil_serl\"\x83\x01\n\x10\x41\x63torInformation\x12*\n\ntransition\x18\x01 \x01(\x0b\x32\x14.hil_serl.TransitionH\x00\x12;\n\x13interaction_message\x18\x02 
\x01(\x0b\x32\x1c.hil_serl.InteractionMessageH\x00\x42\x06\n\x04\x64\x61ta\"&\n\nTransition\x12\x18\n\x10transition_bytes\x18\x01 \x01(\x0c\"V\n\nParameters\x12/\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x17.hil_serl.TransferState\x12\x17\n\x0fparameter_bytes\x18\x02 \x01(\x0c\"7\n\x12InteractionMessage\x12!\n\x19interaction_message_bytes\x18\x01 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\xdb\x01\n\x0eLearnerService\x12G\n\x16SendInteractionMessage\x12\x1c.hil_serl.InteractionMessage\x1a\x0f.hil_serl.Empty\x12;\n\x10StreamParameters\x12\x0f.hil_serl.Empty\x1a\x14.hil_serl.Parameters0\x01\x12\x43\n\x12ReceiveTransitions\x12\x1a.hil_serl.ActorInformation\x1a\x0f.hil_serl.Empty(\x01\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rhilserl.proto\x12\x08hil_serl\"K\n\nTransition\x12/\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x17.hil_serl.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"K\n\nParameters\x12/\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x17.hil_serl.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"S\n\x12InteractionMessage\x12/\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x17.hil_serl.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\xc2\x02\n\x0eLearnerService\x12G\n\x16SendInteractionMessage\x12\x1c.hil_serl.InteractionMessage\x1a\x0f.hil_serl.Empty\x12;\n\x10StreamParameters\x12\x0f.hil_serl.Empty\x1a\x14.hil_serl.Parameters0\x01\x12:\n\x0fSendTransitions\x12\x14.hil_serl.Transition\x1a\x0f.hil_serl.Empty(\x01\x12\x43\n\x10SendInteractions\x12\x1c.hil_serl.InteractionMessage\x1a\x0f.hil_serl.Empty(\x01\x12)\n\x05Ready\x12\x0f.hil_serl.Empty\x1a\x0f.hil_serl.Emptyb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'hilserl_pb2', _globals) if not _descriptor._USE_C_DESCRIPTORS: DESCRIPTOR._loaded_options = None - _globals['_TRANSFERSTATE']._serialized_start=355 - _globals['_TRANSFERSTATE']._serialized_end=451 - _globals['_ACTORINFORMATION']._serialized_start=28 - _globals['_ACTORINFORMATION']._serialized_end=159 - _globals['_TRANSITION']._serialized_start=161 - _globals['_TRANSITION']._serialized_end=199 - _globals['_PARAMETERS']._serialized_start=201 - _globals['_PARAMETERS']._serialized_end=287 - _globals['_INTERACTIONMESSAGE']._serialized_start=289 - _globals['_INTERACTIONMESSAGE']._serialized_end=344 - _globals['_EMPTY']._serialized_start=346 - _globals['_EMPTY']._serialized_end=353 - _globals['_LEARNERSERVICE']._serialized_start=454 - _globals['_LEARNERSERVICE']._serialized_end=673 + _globals['_TRANSFERSTATE']._serialized_start=275 + _globals['_TRANSFERSTATE']._serialized_end=371 + _globals['_TRANSITION']._serialized_start=27 + _globals['_TRANSITION']._serialized_end=102 + _globals['_PARAMETERS']._serialized_start=104 + _globals['_PARAMETERS']._serialized_end=179 + _globals['_INTERACTIONMESSAGE']._serialized_start=181 + _globals['_INTERACTIONMESSAGE']._serialized_end=264 + _globals['_EMPTY']._serialized_start=266 + _globals['_EMPTY']._serialized_end=273 + _globals['_LEARNERSERVICE']._serialized_start=374 + _globals['_LEARNERSERVICE']._serialized_end=696 # @@protoc_insertion_point(module_scope) 
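Editor's note on the regenerated protobuf module above: with the new `hilserl.proto`, `Transition`, `Parameters`, and `InteractionMessage` all carry the same field pair (`transfer_state`, `data`), so any large payload can be streamed as fixed-size chunks over the three streaming RPCs. The sketch below illustrates the chunking/reassembly convention these messages imply; it is a simplified illustration only, and `chunk_payload` / `reassemble` are hypothetical names used for this example — the actual helpers added by this patch are `send_bytes_in_chunks` / `receive_bytes_in_chunks` in `lerobot/scripts/server/network_utils.py` further down.

# Minimal sketch (illustrative only) of the chunked transfer convention implied
# by the new proto. `chunk_payload` and `reassemble` are hypothetical helper
# names for this example; the patch's real helpers live in network_utils.py.
import io

from lerobot.scripts.server import hilserl_pb2

CHUNK_SIZE = 2 * 1024 * 1024  # 2 MB, matching network_utils.CHUNK_SIZE


def chunk_payload(payload: bytes, message_class):
    """Yield protobuf messages carrying `payload` in CHUNK_SIZE pieces."""
    sent = 0
    total = len(payload)
    while sent < total:
        chunk = payload[sent : sent + CHUNK_SIZE]
        # Mirror the sender's state logic: the last chunk wins over the first.
        state = hilserl_pb2.TransferState.TRANSFER_MIDDLE
        if sent + CHUNK_SIZE >= total:
            state = hilserl_pb2.TransferState.TRANSFER_END
        elif sent == 0:
            state = hilserl_pb2.TransferState.TRANSFER_BEGIN
        yield message_class(transfer_state=state, data=chunk)
        sent += len(chunk)


def reassemble(messages) -> bytes:
    """Rebuild the original payload from a stream of chunked messages."""
    buffer = io.BytesIO()
    for msg in messages:
        if msg.transfer_state == hilserl_pb2.TransferState.TRANSFER_BEGIN:
            # A new payload starts: drop whatever was buffered before.
            buffer.seek(0)
            buffer.truncate(0)
        buffer.write(msg.data)
        if msg.transfer_state == hilserl_pb2.TransferState.TRANSFER_END:
            break
    return buffer.getvalue()


# Round-trip example, e.g. for actor -> learner transitions:
# payload = transitions_to_bytes(transition_list)
# assert reassemble(chunk_payload(payload, hilserl_pb2.Transition)) == payload

Tagging every chunk with a TransferState lets the receiver detect the start of a new payload (and discard a partially received one) without a length prefix, which is what allows the same generator-based streaming to be reused for parameters, transitions, and interaction messages.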
diff --git a/lerobot/scripts/server/hilserl_pb2_grpc.py b/lerobot/scripts/server/hilserl_pb2_grpc.py index 42d4674e..1fa96e81 100644 --- a/lerobot/scripts/server/hilserl_pb2_grpc.py +++ b/lerobot/scripts/server/hilserl_pb2_grpc.py @@ -46,9 +46,19 @@ class LearnerServiceStub(object): request_serializer=hilserl__pb2.Empty.SerializeToString, response_deserializer=hilserl__pb2.Parameters.FromString, _registered_method=True) - self.ReceiveTransitions = channel.stream_unary( - '/hil_serl.LearnerService/ReceiveTransitions', - request_serializer=hilserl__pb2.ActorInformation.SerializeToString, + self.SendTransitions = channel.stream_unary( + '/hil_serl.LearnerService/SendTransitions', + request_serializer=hilserl__pb2.Transition.SerializeToString, + response_deserializer=hilserl__pb2.Empty.FromString, + _registered_method=True) + self.SendInteractions = channel.stream_unary( + '/hil_serl.LearnerService/SendInteractions', + request_serializer=hilserl__pb2.InteractionMessage.SerializeToString, + response_deserializer=hilserl__pb2.Empty.FromString, + _registered_method=True) + self.Ready = channel.unary_unary( + '/hil_serl.LearnerService/Ready', + request_serializer=hilserl__pb2.Empty.SerializeToString, response_deserializer=hilserl__pb2.Empty.FromString, _registered_method=True) @@ -71,7 +81,19 @@ class LearnerServiceServicer(object): context.set_details('Method not implemented!') raise NotImplementedError('Method not implemented!') - def ReceiveTransitions(self, request_iterator, context): + def SendTransitions(self, request_iterator, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def SendInteractions(self, request_iterator, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def Ready(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) context.set_details('Method not implemented!') @@ -90,9 +112,19 @@ def add_LearnerServiceServicer_to_server(servicer, server): request_deserializer=hilserl__pb2.Empty.FromString, response_serializer=hilserl__pb2.Parameters.SerializeToString, ), - 'ReceiveTransitions': grpc.stream_unary_rpc_method_handler( - servicer.ReceiveTransitions, - request_deserializer=hilserl__pb2.ActorInformation.FromString, + 'SendTransitions': grpc.stream_unary_rpc_method_handler( + servicer.SendTransitions, + request_deserializer=hilserl__pb2.Transition.FromString, + response_serializer=hilserl__pb2.Empty.SerializeToString, + ), + 'SendInteractions': grpc.stream_unary_rpc_method_handler( + servicer.SendInteractions, + request_deserializer=hilserl__pb2.InteractionMessage.FromString, + response_serializer=hilserl__pb2.Empty.SerializeToString, + ), + 'Ready': grpc.unary_unary_rpc_method_handler( + servicer.Ready, + request_deserializer=hilserl__pb2.Empty.FromString, response_serializer=hilserl__pb2.Empty.SerializeToString, ), } @@ -163,7 +195,7 @@ class LearnerService(object): _registered_method=True) @staticmethod - def ReceiveTransitions(request_iterator, + def SendTransitions(request_iterator, target, options=(), channel_credentials=None, @@ -176,8 +208,62 @@ class LearnerService(object): return grpc.experimental.stream_unary( request_iterator, 
target, - '/hil_serl.LearnerService/ReceiveTransitions', - hilserl__pb2.ActorInformation.SerializeToString, + '/hil_serl.LearnerService/SendTransitions', + hilserl__pb2.Transition.SerializeToString, + hilserl__pb2.Empty.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def SendInteractions(request_iterator, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.stream_unary( + request_iterator, + target, + '/hil_serl.LearnerService/SendInteractions', + hilserl__pb2.InteractionMessage.SerializeToString, + hilserl__pb2.Empty.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def Ready(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/hil_serl.LearnerService/Ready', + hilserl__pb2.Empty.SerializeToString, hilserl__pb2.Empty.FromString, options, channel_credentials, diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py index 4bab9ac2..7bd4aee0 100644 --- a/lerobot/scripts/server/learner_server.py +++ b/lerobot/scripts/server/learner_server.py @@ -15,15 +15,18 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -import queue import shutil import time from pprint import pformat -from threading import Lock, Thread -import signal -from threading import Event from concurrent.futures import ThreadPoolExecutor +# from torch.multiprocessing import Event, Queue, Process +# from threading import Event, Thread +# from torch.multiprocessing import Queue, Event +from torch.multiprocessing import Queue + +from lerobot.scripts.server.utils import setup_process_handlers + import grpc # Import generated stubs @@ -52,19 +55,19 @@ from lerobot.common.utils.utils import ( set_global_random_state, set_global_seed, ) + from lerobot.scripts.server.buffer import ( ReplayBuffer, concatenate_batch_transitions, move_transition_to_device, + move_state_dict_to_device, + bytes_to_transitions, + state_to_bytes, + bytes_to_python_object, ) from lerobot.scripts.server import learner_service -logging.basicConfig(level=logging.INFO) - -transition_queue = queue.Queue() -interaction_message_queue = queue.Queue() - def handle_resume_logic(cfg: DictConfig, out_dir: str) -> DictConfig: if not cfg.resume: @@ -195,67 +198,96 @@ def get_observation_features( return observation_features, next_observation_features +def use_threads(cfg: DictConfig) -> bool: + return cfg.actor_learner_config.concurrency.learner == "threads" + + def start_learner_threads( cfg: DictConfig, - device: str, - replay_buffer: ReplayBuffer, - offline_replay_buffer: ReplayBuffer, - batch_size: int, - optimizers: dict, - policy: SACPolicy, - policy_lock: Lock, logger: Logger, - resume_optimization_step: int | None = None, - resume_interaction_step: int | None = None, - shutdown_event: Event | None = None, + out_dir: str, + shutdown_event: any, # Event, ) -> None: - host = cfg.actor_learner_config.learner_host - port = cfg.actor_learner_config.learner_port + # Create 
multiprocessing queues + transition_queue = Queue() + interaction_message_queue = Queue() + parameters_queue = Queue() - transition_thread = Thread( - target=add_actor_information_and_train, - daemon=True, + concurrency_entity = None + + if use_threads(cfg): + from threading import Thread + + concurrency_entity = Thread + else: + from torch.multiprocessing import Process + + concurrency_entity = Process + + communication_process = concurrency_entity( + target=start_learner_server, args=( - cfg, - device, - replay_buffer, - offline_replay_buffer, - batch_size, - optimizers, - policy, - policy_lock, - logger, - resume_optimization_step, - resume_interaction_step, + parameters_queue, + transition_queue, + interaction_message_queue, shutdown_event, + cfg, ), + daemon=True, ) + communication_process.start() - transition_thread.start() + add_actor_information_and_train( + cfg, + logger, + out_dir, + shutdown_event, + transition_queue, + interaction_message_queue, + parameters_queue, + ) + logging.info("[LEARNER] Training process stopped") + + logging.info("[LEARNER] Closing queues") + transition_queue.close() + interaction_message_queue.close() + parameters_queue.close() + + communication_process.join() + logging.info("[LEARNER] Communication process joined") + + logging.info("[LEARNER] join queues") + transition_queue.cancel_join_thread() + interaction_message_queue.cancel_join_thread() + parameters_queue.cancel_join_thread() + + logging.info("[LEARNER] queues closed") + + +def start_learner_server( + parameters_queue: Queue, + transition_queue: Queue, + interaction_message_queue: Queue, + shutdown_event: any, # Event, + cfg: DictConfig, +): + if not use_threads(cfg): + # We need init logging for MP separataly + init_logging() + + # Setup process handlers to handle shutdown signal + # But use shutdown event from the main process + # Return back for MP + setup_process_handlers(False) service = learner_service.LearnerService( shutdown_event, - policy, - policy_lock, + parameters_queue, cfg.actor_learner_config.policy_parameters_push_frequency, transition_queue, interaction_message_queue, ) - server = start_learner_server(service, host, port) - shutdown_event.wait() - server.stop(learner_service.STUTDOWN_TIMEOUT) - logging.info("[LEARNER] gRPC server stopped") - - transition_thread.join() - logging.info("[LEARNER] Transition thread stopped") - - -def start_learner_server( - service: learner_service.LearnerService, - host="0.0.0.0", - port=50051, -) -> grpc.server: server = grpc.server( ThreadPoolExecutor(max_workers=learner_service.MAX_WORKERS), options=[ @@ -263,15 +295,23 @@ def start_learner_server( ("grpc.max_send_message_length", learner_service.MAX_MESSAGE_SIZE), ], ) + hilserl_pb2_grpc.add_LearnerServiceServicer_to_server( service, server, ) + + host = cfg.actor_learner_config.learner_host + port = cfg.actor_learner_config.learner_port + server.add_insecure_port(f"{host}:{port}") server.start() logging.info("[LEARNER] gRPC server started") - return server + shutdown_event.wait() + logging.info("[LEARNER] Stopping gRPC server...") + server.stop(learner_service.STUTDOWN_TIMEOUT) + logging.info("[LEARNER] gRPC server stopped") def check_nan_in_transition( @@ -287,19 +327,21 @@ def check_nan_in_transition( logging.error("actions contains NaN values") +def push_actor_policy_to_queue(parameters_queue: Queue, policy: nn.Module): + logging.debug("[LEARNER] Pushing actor policy to the queue") + state_dict = move_state_dict_to_device(policy.actor.state_dict(), device="cpu") + state_bytes = 
state_to_bytes(state_dict) + parameters_queue.put(state_bytes) + + def add_actor_information_and_train( cfg, - device: str, - replay_buffer: ReplayBuffer, - offline_replay_buffer: ReplayBuffer, - batch_size: int, - optimizers: dict[str, torch.optim.Optimizer], - policy: nn.Module, - policy_lock: Lock, logger: Logger, - resume_optimization_step: int | None = None, - resume_interaction_step: int | None = None, - shutdown_event: Event | None = None, + out_dir: str, + shutdown_event: any, # Event, + transition_queue: Queue, + interaction_message_queue: Queue, + parameters_queue: Queue, ): """ Handles data transfer from the actor to the learner, manages training updates, @@ -322,17 +364,73 @@ def add_actor_information_and_train( Args: cfg: Configuration object containing hyperparameters. device (str): The computing device (`"cpu"` or `"cuda"`). - replay_buffer (ReplayBuffer): The primary replay buffer storing online transitions. - offline_replay_buffer (ReplayBuffer): An additional buffer for offline transitions. - batch_size (int): The number of transitions to sample per training step. - optimizers (Dict[str, torch.optim.Optimizer]): A dictionary of optimizers (`"actor"`, `"critic"`, `"temperature"`). - policy (nn.Module): The reinforcement learning policy with critic, actor, and temperature parameters. - policy_lock (Lock): A threading lock to ensure safe policy updates. logger (Logger): Logger instance for tracking training progress. - resume_optimization_step (int | None): In the case of resume training, start from the last optimization step reached. - resume_interaction_step (int | None): In the case of resume training, shift the interaction step with the last saved step in order to not break logging. - shutdown_event (Event | None): Event to signal shutdown. + out_dir (str): The output directory for storing training checkpoints and logs. + shutdown_event (Event): Event to signal shutdown. + transition_queue (Queue): Queue for receiving transitions from the actor. + interaction_message_queue (Queue): Queue for receiving interaction messages from the actor. + parameters_queue (Queue): Queue for sending policy parameters to the actor. 
""" + + device = get_safe_torch_device(cfg.device, log=True) + storage_device = get_safe_torch_device(cfg_device=cfg.training.storage_device) + + logging.info("Initializing policy") + ### Instantiate the policy in both the actor and learner processes + ### To avoid sending a SACPolicy object through the port, we create a policy intance + ### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters + # TODO: At some point we should just need make sac policy + + policy: SACPolicy = make_policy( + hydra_cfg=cfg, + # dataset_stats=offline_dataset.meta.stats if not cfg.resume else None, + # Hack: But if we do online traning, we do not need dataset_stats + dataset_stats=None, + pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) + if cfg.resume + else None, + ) + # compile policy + policy = torch.compile(policy) + assert isinstance(policy, nn.Module) + + push_actor_policy_to_queue(parameters_queue, policy) + + last_time_policy_pushed = time.time() + + optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg, policy) + resume_optimization_step, resume_interaction_step = load_training_state( + cfg, logger, optimizers + ) + + log_training_info(cfg, out_dir, policy) + + replay_buffer = initialize_replay_buffer(cfg, logger, device, storage_device) + batch_size = cfg.training.batch_size + offline_replay_buffer = None + + if cfg.dataset_repo_id is not None: + logging.info("make_dataset offline buffer") + offline_dataset = make_dataset(cfg) + logging.info("Convertion to a offline replay buffer") + active_action_dims = None + if cfg.env.wrapper.joint_masking_action_space is not None: + active_action_dims = [ + i + for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space) + if mask + ] + offline_replay_buffer = ReplayBuffer.from_lerobot_dataset( + offline_dataset, + device=device, + state_keys=cfg.policy.input_shapes.keys(), + action_mask=active_action_dims, + action_delta=cfg.env.wrapper.delta_action, + storage_device=storage_device, + optimize_memory=True, + ) + batch_size: int = batch_size // 2 # We will sample from both replay buffer + # NOTE: This function doesn't have a single responsibility, it should be split into multiple functions # in the future. The reason why we did that is the GIL in Python. It's super slow the performance # are divided by 200. So we need to have a single thread that does all the work. @@ -345,33 +443,39 @@ def add_actor_information_and_train( interaction_step_shift = ( resume_interaction_step if resume_interaction_step is not None else 0 ) - saved_data = False + while True: if shutdown_event is not None and shutdown_event.is_set(): logging.info("[LEARNER] Shutdown signal received. 
Exiting...") break - while not transition_queue.empty(): + logging.debug("[LEARNER] Waiting for transitions") + while not transition_queue.empty() and not shutdown_event.is_set(): transition_list = transition_queue.get() + transition_list = bytes_to_transitions(transition_list) + for transition in transition_list: transition = move_transition_to_device(transition, device=device) replay_buffer.add(**transition) - if transition.get("complementary_info", {}).get("is_intervention"): offline_replay_buffer.add(**transition) - - while not interaction_message_queue.empty(): + logging.debug("[LEARNER] Received transitions") + logging.debug("[LEARNER] Waiting for interactions") + while not interaction_message_queue.empty() and not shutdown_event.is_set(): interaction_message = interaction_message_queue.get() + interaction_message = bytes_to_python_object(interaction_message) # If cfg.resume, shift the interaction step with the last checkpointed step in order to not break the logging interaction_message["Interaction step"] += interaction_step_shift logger.log_dict( interaction_message, mode="train", custom_step_key="Interaction step" ) - # logging.info(f"Interaction message: {interaction_message}") + + logging.debug("[LEARNER] Received interactions") if len(replay_buffer) < cfg.training.online_step_before_learning: continue + logging.debug("[LEARNER] Starting optimization loop") time_for_one_optimization_step = time.time() for _ in range(cfg.policy.utd_ratio - 1): batch = replay_buffer.sample(batch_size) @@ -392,19 +496,18 @@ def add_actor_information_and_train( observation_features, next_observation_features = get_observation_features( policy, observations, next_observations ) - with policy_lock: - loss_critic = policy.compute_loss_critic( - observations=observations, - actions=actions, - rewards=rewards, - next_observations=next_observations, - done=done, - observation_features=observation_features, - next_observation_features=next_observation_features, - ) - optimizers["critic"].zero_grad() - loss_critic.backward() - optimizers["critic"].step() + loss_critic = policy.compute_loss_critic( + observations=observations, + actions=actions, + rewards=rewards, + next_observations=next_observations, + done=done, + observation_features=observation_features, + next_observation_features=next_observation_features, + ) + optimizers["critic"].zero_grad() + loss_critic.backward() + optimizers["critic"].step() batch = replay_buffer.sample(batch_size) @@ -427,46 +530,51 @@ def add_actor_information_and_train( observation_features, next_observation_features = get_observation_features( policy, observations, next_observations ) - with policy_lock: - loss_critic = policy.compute_loss_critic( - observations=observations, - actions=actions, - rewards=rewards, - next_observations=next_observations, - done=done, - observation_features=observation_features, - next_observation_features=next_observation_features, - ) - optimizers["critic"].zero_grad() - loss_critic.backward() - optimizers["critic"].step() + loss_critic = policy.compute_loss_critic( + observations=observations, + actions=actions, + rewards=rewards, + next_observations=next_observations, + done=done, + observation_features=observation_features, + next_observation_features=next_observation_features, + ) + optimizers["critic"].zero_grad() + loss_critic.backward() + optimizers["critic"].step() training_infos = {} training_infos["loss_critic"] = loss_critic.item() if optimization_step % cfg.training.policy_update_freq == 0: for _ in 
range(cfg.training.policy_update_freq): - with policy_lock: - loss_actor = policy.compute_loss_actor( - observations=observations, - observation_features=observation_features, - ) + loss_actor = policy.compute_loss_actor( + observations=observations, + observation_features=observation_features, + ) - optimizers["actor"].zero_grad() - loss_actor.backward() - optimizers["actor"].step() + optimizers["actor"].zero_grad() + loss_actor.backward() + optimizers["actor"].step() - training_infos["loss_actor"] = loss_actor.item() + training_infos["loss_actor"] = loss_actor.item() - loss_temperature = policy.compute_loss_temperature( - observations=observations, - observation_features=observation_features, - ) - optimizers["temperature"].zero_grad() - loss_temperature.backward() - optimizers["temperature"].step() + loss_temperature = policy.compute_loss_temperature( + observations=observations, + observation_features=observation_features, + ) + optimizers["temperature"].zero_grad() + loss_temperature.backward() + optimizers["temperature"].step() - training_infos["loss_temperature"] = loss_temperature.item() + training_infos["loss_temperature"] = loss_temperature.item() + + if ( + time.time() - last_time_policy_pushed + > cfg.actor_learner_config.policy_parameters_push_frequency + ): + push_actor_policy_to_queue(parameters_queue, policy) + last_time_policy_pushed = time.time() policy.update_target_networks() if optimization_step % cfg.training.log_freq == 0: @@ -595,104 +703,36 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No set_global_seed(cfg.seed) - device = get_safe_torch_device(cfg.device, log=True) - storage_device = get_safe_torch_device(cfg_device=cfg.training.storage_device) - torch.backends.cudnn.benchmark = True torch.backends.cuda.matmul.allow_tf32 = True - logging.info("make_policy") - - ### Instantiate the policy in both the actor and learner processes - ### To avoid sending a SACPolicy object through the port, we create a policy intance - ### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters - # TODO: At some point we should just need make sac policy - - policy_lock = Lock() - policy: SACPolicy = make_policy( - hydra_cfg=cfg, - # dataset_stats=offline_dataset.meta.stats if not cfg.resume else None, - # Hack: But if we do online traning, we do not need dataset_stats - dataset_stats=None, - pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) - if cfg.resume - else None, - ) - # compile policy - policy = torch.compile(policy) - assert isinstance(policy, nn.Module) - - optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg, policy) - resume_optimization_step, resume_interaction_step = load_training_state( - cfg, logger, optimizers - ) - - log_training_info(cfg, out_dir, policy) - - replay_buffer = initialize_replay_buffer(cfg, logger, device, storage_device) - batch_size = cfg.training.batch_size - offline_replay_buffer = None - - if cfg.dataset_repo_id is not None: - logging.info("make_dataset offline buffer") - offline_dataset = make_dataset(cfg) - logging.info("Convertion to a offline replay buffer") - active_action_dims = None - if cfg.env.wrapper.joint_masking_action_space is not None: - active_action_dims = [ - i - for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space) - if mask - ] - offline_replay_buffer = ReplayBuffer.from_lerobot_dataset( - offline_dataset, - device=device, - state_keys=cfg.policy.input_shapes.keys(), - action_mask=active_action_dims, - 
action_delta=cfg.env.wrapper.delta_action, - storage_device=storage_device, - optimize_memory=True, - ) - batch_size: int = batch_size // 2 # We will sample from both replay buffer - - shutdown_event = Event() - - def signal_handler(signum, frame): - print( - f"\nReceived signal {signal.Signals(signum).name}. Initiating learner shutdown..." - ) - shutdown_event.set() - - # Register signal handlers - signal.signal(signal.SIGINT, signal_handler) # Ctrl+C - signal.signal(signal.SIGTERM, signal_handler) # Termination request - signal.signal(signal.SIGHUP, signal_handler) # Terminal closed - signal.signal(signal.SIGQUIT, signal_handler) # Ctrl+\ + shutdown_event = setup_process_handlers(use_threads(cfg)) start_learner_threads( cfg, - device, - replay_buffer, - offline_replay_buffer, - batch_size, - optimizers, - policy, - policy_lock, logger, - resume_optimization_step, - resume_interaction_step, + out_dir, shutdown_event, ) @hydra.main(version_base="1.2", config_name="default", config_path="../../configs") def train_cli(cfg: dict): + if not use_threads(cfg): + import torch.multiprocessing as mp + + mp.set_start_method("spawn") + train( cfg, out_dir=hydra.core.hydra_config.HydraConfig.get().run.dir, job_name=hydra.core.hydra_config.HydraConfig.get().job.name, ) + logging.info("[LEARNER] train_cli finished") + if __name__ == "__main__": train_cli() + + logging.info("[LEARNER] main finished") diff --git a/lerobot/scripts/server/learner_service.py b/lerobot/scripts/server/learner_service.py index d6e6b5b7..b1f91cdc 100644 --- a/lerobot/scripts/server/learner_service.py +++ b/lerobot/scripts/server/learner_service.py @@ -1,23 +1,13 @@ import hilserl_pb2 # type: ignore import hilserl_pb2_grpc # type: ignore -import torch -from torch import nn -from threading import Lock, Event import logging -import queue -import io -import pickle - -from lerobot.scripts.server.buffer import ( - move_state_dict_to_device, - bytes_buffer_size, - state_to_bytes, -) +from multiprocessing import Event, Queue +from lerobot.scripts.server.network_utils import receive_bytes_in_chunks +from lerobot.scripts.server.network_utils import send_bytes_in_chunks MAX_MESSAGE_SIZE = 4 * 1024 * 1024 # 4 MB -CHUNK_SIZE = 2 * 1024 * 1024 # 2 MB -MAX_WORKERS = 10 +MAX_WORKERS = 3 # Stream parameters, send transitions and interactions STUTDOWN_TIMEOUT = 10 @@ -25,89 +15,68 @@ class LearnerService(hilserl_pb2_grpc.LearnerServiceServicer): def __init__( self, shutdown_event: Event, - policy: nn.Module, - policy_lock: Lock, + parameters_queue: Queue, seconds_between_pushes: float, - transition_queue: queue.Queue, - interaction_message_queue: queue.Queue, + transition_queue: Queue, + interaction_message_queue: Queue, ): self.shutdown_event = shutdown_event - self.policy = policy - self.policy_lock = policy_lock + self.parameters_queue = parameters_queue self.seconds_between_pushes = seconds_between_pushes self.transition_queue = transition_queue self.interaction_message_queue = interaction_message_queue - def _get_policy_state(self): - with self.policy_lock: - params_dict = self.policy.actor.state_dict() - # if self.policy.config.vision_encoder_name is not None: - # if self.policy.config.freeze_vision_encoder: - # params_dict: dict[str, torch.Tensor] = { - # k: v - # for k, v in params_dict.items() - # if not k.startswith("encoder.") - # } - # else: - # raise NotImplementedError( - # "Vision encoder is not frozen, we need to send the full model over the network which requires chunking the model." 
- # ) - - return move_state_dict_to_device(params_dict, device="cpu") - - def _send_bytes(self, buffer: bytes): - size_in_bytes = bytes_buffer_size(buffer) - - sent_bytes = 0 - - logging.info(f"Model state size {size_in_bytes/1024/1024} MB with") - - while sent_bytes < size_in_bytes: - transfer_state = hilserl_pb2.TransferState.TRANSFER_MIDDLE - - if sent_bytes + CHUNK_SIZE >= size_in_bytes: - transfer_state = hilserl_pb2.TransferState.TRANSFER_END - elif sent_bytes == 0: - transfer_state = hilserl_pb2.TransferState.TRANSFER_BEGIN - - size_to_read = min(CHUNK_SIZE, size_in_bytes - sent_bytes) - chunk = buffer.read(size_to_read) - - yield hilserl_pb2.Parameters( - transfer_state=transfer_state, parameter_bytes=chunk - ) - sent_bytes += size_to_read - logging.info( - f"[Learner] Sent {sent_bytes}/{size_in_bytes} bytes with state {transfer_state}" - ) - - logging.info(f"[LEARNER] Published {sent_bytes/1024/1024} MB to the Actor") - def StreamParameters(self, request, context): # TODO: authorize the request logging.info("[LEARNER] Received request to stream parameters from the Actor") while not self.shutdown_event.is_set(): - logging.debug("[LEARNER] Push parameters to the Actor") - state_dict = self._get_policy_state() + logging.info("[LEARNER] Push parameters to the Actor") + buffer = self.parameters_queue.get() - with state_to_bytes(state_dict) as buffer: - yield from self._send_bytes(buffer) + yield from send_bytes_in_chunks( + buffer, + hilserl_pb2.Parameters, + log_prefix="[LEARNER] Sending parameters", + silent=True, + ) + + logging.info("[LEARNER] Parameters sent") self.shutdown_event.wait(self.seconds_between_pushes) - def ReceiveTransitions(self, request_iterator, context): + logging.info("[LEARNER] Stream parameters finished") + return hilserl_pb2.Empty() + + def SendTransitions(self, request_iterator, _context): # TODO: authorize the request logging.info("[LEARNER] Received request to receive transitions from the Actor") - for request in request_iterator: - logging.debug("[LEARNER] Received request") - if request.HasField("transition"): - buffer = io.BytesIO(request.transition.transition_bytes) - transition = torch.load(buffer) - self.transition_queue.put(transition) - if request.HasField("interaction_message"): - content = pickle.loads( - request.interaction_message.interaction_message_bytes - ) - self.interaction_message_queue.put(content) + receive_bytes_in_chunks( + request_iterator, + self.transition_queue, + self.shutdown_event, + log_prefix="[LEARNER] transitions", + ) + + logging.debug("[LEARNER] Finished receiving transitions") + return hilserl_pb2.Empty() + + def SendInteractions(self, request_iterator, _context): + # TODO: authorize the request + logging.info( + "[LEARNER] Received request to receive interactions from the Actor" + ) + + receive_bytes_in_chunks( + request_iterator, + self.interaction_message_queue, + self.shutdown_event, + log_prefix="[LEARNER] interactions", + ) + + logging.debug("[LEARNER] Finished receiving interactions") + return hilserl_pb2.Empty() + + def Ready(self, request, context): + return hilserl_pb2.Empty() diff --git a/lerobot/scripts/server/maniskill_manipulator.py b/lerobot/scripts/server/maniskill_manipulator.py index b9c9d216..e4d55955 100644 --- a/lerobot/scripts/server/maniskill_manipulator.py +++ b/lerobot/scripts/server/maniskill_manipulator.py @@ -5,9 +5,8 @@ import torch from omegaconf import DictConfig from typing import Any - -"""Make ManiSkill3 gym environment""" from mani_skill.vector.wrappers.gymnasium import 
ManiSkillVectorEnv +from mani_skill.utils.wrappers.record import RecordEpisode def preprocess_maniskill_observation( @@ -143,6 +142,15 @@ def make_maniskill( num_envs=n_envs, ) + if cfg.env.video_record.enabled: + env = RecordEpisode( + env, + output_dir=cfg.env.video_record.record_dir, + save_trajectory=True, + trajectory_name=cfg.env.video_record.trajectory_name, + save_video=True, + video_fps=30, + ) env = ManiSkillObservationWrapper(env, device=cfg.env.device) env = ManiSkillVectorEnv(env, ignore_terminations=True, auto_reset=False) env._max_episode_steps = env.max_episode_steps = ( diff --git a/lerobot/scripts/server/network_utils.py b/lerobot/scripts/server/network_utils.py new file mode 100644 index 00000000..f5e8973b --- /dev/null +++ b/lerobot/scripts/server/network_utils.py @@ -0,0 +1,102 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from lerobot.scripts.server import hilserl_pb2 +import logging +import io +from multiprocessing import Queue, Event +from typing import Any + +CHUNK_SIZE = 2 * 1024 * 1024 # 2 MB + + +def bytes_buffer_size(buffer: io.BytesIO) -> int: + buffer.seek(0, io.SEEK_END) + result = buffer.tell() + buffer.seek(0) + return result + + +def send_bytes_in_chunks( + buffer: bytes, message_class: Any, log_prefix: str = "", silent: bool = True +): + buffer = io.BytesIO(buffer) + size_in_bytes = bytes_buffer_size(buffer) + + sent_bytes = 0 + + logging_method = logging.info if not silent else logging.debug + + logging_method(f"{log_prefix} Buffer size {size_in_bytes/1024/1024} MB with") + + while sent_bytes < size_in_bytes: + transfer_state = hilserl_pb2.TransferState.TRANSFER_MIDDLE + + if sent_bytes + CHUNK_SIZE >= size_in_bytes: + transfer_state = hilserl_pb2.TransferState.TRANSFER_END + elif sent_bytes == 0: + transfer_state = hilserl_pb2.TransferState.TRANSFER_BEGIN + + size_to_read = min(CHUNK_SIZE, size_in_bytes - sent_bytes) + chunk = buffer.read(size_to_read) + + yield message_class(transfer_state=transfer_state, data=chunk) + sent_bytes += size_to_read + logging_method( + f"{log_prefix} Sent {sent_bytes}/{size_in_bytes} bytes with state {transfer_state}" + ) + + logging_method(f"{log_prefix} Published {sent_bytes/1024/1024} MB") + + +def receive_bytes_in_chunks( + iterator, queue: Queue, shutdown_event: Event, log_prefix: str = "" +): + bytes_buffer = io.BytesIO() + step = 0 + + logging.info(f"{log_prefix} Starting receiver") + for item in iterator: + logging.debug(f"{log_prefix} Received item") + if shutdown_event.is_set(): + logging.info(f"{log_prefix} Shutting down receiver") + return + + if item.transfer_state == hilserl_pb2.TransferState.TRANSFER_BEGIN: + bytes_buffer.seek(0) + bytes_buffer.truncate(0) + bytes_buffer.write(item.data) + logging.debug(f"{log_prefix} Received data at step 0") + step = 0 + continue + elif item.transfer_state == hilserl_pb2.TransferState.TRANSFER_MIDDLE: + bytes_buffer.write(item.data) + step += 1 + logging.debug(f"{log_prefix} 
Received data at step {step}") + elif item.transfer_state == hilserl_pb2.TransferState.TRANSFER_END: + bytes_buffer.write(item.data) + logging.debug( + f"{log_prefix} Received data at step end size {bytes_buffer_size(bytes_buffer)}" + ) + + queue.put(bytes_buffer.getvalue()) + + bytes_buffer.seek(0) + bytes_buffer.truncate(0) + step = 0 + + logging.debug(f"{log_prefix} Queue updated") diff --git a/lerobot/scripts/server/utils.py b/lerobot/scripts/server/utils.py new file mode 100644 index 00000000..699717e4 --- /dev/null +++ b/lerobot/scripts/server/utils.py @@ -0,0 +1,72 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import signal +import sys +from torch.multiprocessing import Queue +from queue import Empty + +shutdown_event_counter = 0 + + +def setup_process_handlers(use_threads: bool) -> any: + if use_threads: + from threading import Event + else: + from multiprocessing import Event + + shutdown_event = Event() + + # Define signal handler + def signal_handler(signum, frame): + logging.info("Shutdown signal received. Cleaning up...") + shutdown_event.set() + global shutdown_event_counter + shutdown_event_counter += 1 + + if shutdown_event_counter > 1: + logging.info("Force shutdown") + sys.exit(1) + + signal.signal(signal.SIGINT, signal_handler) # Ctrl+C + signal.signal(signal.SIGTERM, signal_handler) # Termination request (kill) + signal.signal(signal.SIGHUP, signal_handler) # Terminal closed/Hangup + signal.signal(signal.SIGQUIT, signal_handler) # Ctrl+\ + + def signal_handler(signum, frame): + logging.info("Shutdown signal received. Cleaning up...") + shutdown_event.set() + + return shutdown_event + + +def get_last_item_from_queue(queue: Queue): + item = queue.get() + counter = 1 + + # Drain queue and keep only the most recent parameters + try: + while True: + item = queue.get_nowait() + counter += 1 + except Empty: + pass + + logging.debug(f"Drained {counter} items from queue") + + return item