[HIL-SERL] Migrate threading to multiprocessing (#759)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Eugene Mironov 2025-03-05 17:19:31 +07:00 committed by GitHub
parent 584cad808e
commit 700f00c014
14 changed files with 900 additions and 492 deletions

View File

@ -116,11 +116,11 @@ def seeded_context(seed: int) -> Generator[None, None, None]:
set_global_random_state(random_state_dict)
def init_logging():
def init_logging(log_file=None):
def custom_format(record):
dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
fnameline = f"{record.pathname}:{record.lineno}"
message = f"{record.levelname} {dt} {fnameline[-15:]:>15} {record.msg}"
message = f"{record.levelname} [PID: {os.getpid()}] {dt} {fnameline[-15:]:>15} {record.msg}"
return message
logging.basicConfig(level=logging.INFO)
@ -134,6 +134,12 @@ def init_logging():
console_handler.setFormatter(formatter)
logging.getLogger().addHandler(console_handler)
if log_file is not None:
# File handler
file_handler = logging.FileHandler(log_file)
file_handler.setFormatter(formatter)
logging.getLogger().addHandler(file_handler)
def format_big_number(num, precision=0):
suffixes = ["", "K", "M", "B", "T", "Q"]
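A minimal usage sketch for the new log_file argument of init_logging (the file names below are placeholders, not part of this change): each process can write to its own log file, and the PID added to the format string disambiguates interleaved records when several workers log at once.
import logging
import multiprocessing as mp
from lerobot.common.utils.utils import init_logging
def worker():
    # Hypothetical per-process log file; the [PID: ...] prefix in the format
    # string identifies which process emitted each record.
    init_logging(log_file="actor_worker.log")
    logging.info("worker started")
if __name__ == "__main__":
    init_logging(log_file="main.log")  # hypothetical file name
    p = mp.Process(target=worker)
    p.start()
    p.join()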

View File

@ -22,3 +22,9 @@ env:
wrapper:
joint_masking_action_space: null
delta_action: null
video_record:
enabled: false
record_dir: maniskill_videos
trajectory_name: trajectory
fps: ${fps}

View File

@ -28,4 +28,3 @@ env:
reward_classifier:
pretrained_path: outputs/classifier/13-02-random-sample-resnet10-frozen/checkpoints/best/pretrained_model
config_path: lerobot/configs/policy/hilserl_classifier.yaml

View File

@ -8,14 +8,12 @@
# env.gym.obs_type=environment_state_agent_pos \
seed: 1
# dataset_repo_id: null
dataset_repo_id: "AdilZtn/Maniskill-Pushcube-demonstration-medium"
training:
# Offline training dataloader
num_workers: 4
# batch_size: 256
batch_size: 512
grad_clip_norm: 10.0
lr: 3e-4
@ -113,4 +111,7 @@ policy:
actor_learner_config:
learner_host: "127.0.0.1"
learner_port: 50051
policy_parameters_push_frequency: 15
policy_parameters_push_frequency: 1
concurrency:
actor: 'processes'
learner: 'processes'
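A rough sketch of how these concurrency flags are consumed (mirroring the use_threads helpers and the Thread/Process selection in the actor and learner diffs below); pick_worker_class is an illustrative helper, not part of this change.
from threading import Thread
from torch.multiprocessing import Process
def pick_worker_class(cfg, role: str):
    # role is "actor" or "learner"; cfg follows the layout shown above.
    use_threads = cfg.actor_learner_config.concurrency[role] == "threads"
    return Thread if use_threads else Process
# e.g. worker_cls = pick_worker_class(cfg, "actor"); worker_cls(target=..., daemon=True).start()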

View File

@ -13,22 +13,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import logging
import pickle
import queue
from statistics import mean, quantiles
import signal
from functools import lru_cache
from lerobot.scripts.server.utils import setup_process_handlers
# from lerobot.scripts.eval import eval_policy
from threading import Thread
import grpc
import hydra
import torch
from omegaconf import DictConfig
from torch import nn
import time
# TODO: Remove the import of maniskill
# from lerobot.common.envs.factory import make_maniskill_env
@ -47,157 +44,184 @@ from lerobot.scripts.server.buffer import (
Transition,
move_state_dict_to_device,
move_transition_to_device,
bytes_buffer_size,
python_object_to_bytes,
transitions_to_bytes,
bytes_to_state_dict,
)
from lerobot.scripts.server.network_utils import (
receive_bytes_in_chunks,
send_bytes_in_chunks,
)
from lerobot.scripts.server.gym_manipulator import get_classifier, make_robot_env
from lerobot.scripts.server import learner_service
from threading import Event
from torch.multiprocessing import Queue, Event
from queue import Empty
logging.basicConfig(level=logging.INFO)
from lerobot.common.utils.utils import init_logging
parameters_queue = queue.Queue(maxsize=1)
message_queue = queue.Queue(maxsize=1_000_000)
from lerobot.scripts.server.utils import get_last_item_from_queue
ACTOR_SHUTDOWN_TIMEOUT = 30
class ActorInformation:
"""
This helper class is used to differentiate between two types of messages that are placed in the same queue during streaming:
- **Transition Data:** Contains experience tuples (observation, action, reward, next observation) collected during interaction.
- **Interaction Messages:** Encapsulates statistics related to the interaction process.
Attributes:
transition (Optional): Transition data to be sent to the learner.
interaction_message (Optional): Interaction message providing additional statistics for logging.
"""
def __init__(self, transition=None, interaction_message=None):
self.transition = transition
self.interaction_message = interaction_message
def receive_policy(
learner_client: hilserl_pb2_grpc.LearnerServiceStub,
shutdown_event: Event,
parameters_queue: queue.Queue,
cfg: DictConfig,
parameters_queue: Queue,
shutdown_event: any, # Event,
learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None,
grpc_channel: grpc.Channel | None = None,
):
logging.info("[ACTOR] Start receiving parameters from the Learner")
bytes_buffer = io.BytesIO()
step = 0
if not use_threads(cfg):
# Setup process handlers to handle shutdown signal
# But use shutdown event from the main process
setup_process_handlers(False)
if grpc_channel is None or learner_client is None:
learner_client, grpc_channel = learner_service_client(
host=cfg.actor_learner_config.learner_host,
port=cfg.actor_learner_config.learner_port,
)
try:
for model_update in learner_client.StreamParameters(hilserl_pb2.Empty()):
if shutdown_event.is_set():
logging.info("[ACTOR] Shutting down policy streaming receiver")
return hilserl_pb2.Empty()
if model_update.transfer_state == hilserl_pb2.TransferState.TRANSFER_BEGIN:
bytes_buffer.seek(0)
bytes_buffer.truncate(0)
bytes_buffer.write(model_update.parameter_bytes)
logging.info("Received model update at step 0")
step = 0
continue
elif (
model_update.transfer_state == hilserl_pb2.TransferState.TRANSFER_MIDDLE
):
bytes_buffer.write(model_update.parameter_bytes)
step += 1
logging.info(f"Received model update at step {step}")
elif model_update.transfer_state == hilserl_pb2.TransferState.TRANSFER_END:
bytes_buffer.write(model_update.parameter_bytes)
logging.info(
f"Received model update at step end size {bytes_buffer_size(bytes_buffer)}"
)
state_dict = torch.load(bytes_buffer)
bytes_buffer.seek(0)
bytes_buffer.truncate(0)
step = 0
logging.info("Model updated")
parameters_queue.put(state_dict)
iterator = learner_client.StreamParameters(hilserl_pb2.Empty())
receive_bytes_in_chunks(
iterator,
parameters_queue,
shutdown_event,
log_prefix="[ACTOR] parameters",
)
except grpc.RpcError as e:
logging.error(f"[ACTOR] gRPC error: {e}")
if not use_threads(cfg):
grpc_channel.close()
logging.info("[ACTOR] Received policy loop stopped")
def transitions_stream(
shutdown_event: Event, transitions_queue: Queue
) -> hilserl_pb2.Empty:
while not shutdown_event.is_set():
try:
message = transitions_queue.get(block=True, timeout=5)
except Empty:
logging.debug("[ACTOR] Transition queue is empty")
continue
yield from send_bytes_in_chunks(
message, hilserl_pb2.Transition, log_prefix="[ACTOR] Send transitions"
)
return hilserl_pb2.Empty()
def transitions_stream(shutdown_event: Event, message_queue: queue.Queue):
def interactions_stream(
shutdown_event: any, # Event,
interactions_queue: Queue,
) -> hilserl_pb2.Empty:
while not shutdown_event.is_set():
try:
message = message_queue.get(block=True, timeout=5)
except queue.Empty:
logging.debug("[ACTOR] Transition queue is empty")
message = interactions_queue.get(block=True, timeout=5)
except Empty:
logging.debug("[ACTOR] Interaction queue is empty")
continue
if message.transition is not None:
transition_to_send_to_learner: list[Transition] = [
move_transition_to_device(transition=T, device="cpu")
for T in message.transition
]
# Check for NaNs in transitions before sending to learner
for transition in transition_to_send_to_learner:
for key, value in transition["state"].items():
if torch.isnan(value).any():
logging.warning(f"Found NaN values in transition {key}")
buf = io.BytesIO()
torch.save(transition_to_send_to_learner, buf)
transition_bytes = buf.getvalue()
transition_message = hilserl_pb2.Transition(
transition_bytes=transition_bytes
)
response = hilserl_pb2.ActorInformation(transition=transition_message)
elif message.interaction_message is not None:
content = hilserl_pb2.InteractionMessage(
interaction_message_bytes=pickle.dumps(message.interaction_message)
)
response = hilserl_pb2.ActorInformation(interaction_message=content)
yield response
yield from send_bytes_in_chunks(
message,
hilserl_pb2.InteractionMessage,
log_prefix="[ACTOR] Send interactions",
)
return hilserl_pb2.Empty()
def send_transitions(
learner_client: hilserl_pb2_grpc.LearnerServiceStub,
shutdown_event: Event,
message_queue: queue.Queue,
):
cfg: DictConfig,
transitions_queue: Queue,
shutdown_event: any, # Event,
learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None,
grpc_channel: grpc.Channel | None = None,
) -> hilserl_pb2.Empty:
"""
Streams data from the actor to the learner.
Sends transitions to the learner.
This function continuously retrieves messages from the queue and processes them based on their type:
This function continuously retrieves messages from the queue and processes:
- **Transition Data:**
- A batch of transitions (observation, action, reward, next observation) is collected.
- Transitions are moved to the CPU and serialized using PyTorch.
- The serialized data is wrapped in a `hilserl_pb2.Transition` message and sent to the learner.
- **Interaction Messages:**
- Contains useful statistics about episodic rewards and policy timings.
- The message is serialized using `pickle` and sent to the learner.
Yields:
hilserl_pb2.ActorInformation: The response message containing either transition data or an interaction message.
"""
if not use_threads(cfg):
# Setup process handlers to handle shutdown signal
# But use shutdown event from the main process
setup_process_handlers(False)
if grpc_channel is None or learner_client is None:
learner_client, grpc_channel = learner_service_client(
host=cfg.actor_learner_config.learner_host,
port=cfg.actor_learner_config.learner_port,
)
try:
learner_client.ReceiveTransitions(
transitions_stream(shutdown_event, message_queue)
learner_client.SendTransitions(
transitions_stream(shutdown_event, transitions_queue)
)
except grpc.RpcError as e:
logging.error(f"[ACTOR] gRPC error: {e}")
logging.info("[ACTOR] Finished streaming transitions")
if not use_threads(cfg):
grpc_channel.close()
logging.info("[ACTOR] Transitions process stopped")
def send_interactions(
cfg: DictConfig,
interactions_queue: Queue,
shutdown_event: any, # Event,
learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None,
grpc_channel: grpc.Channel | None = None,
) -> hilserl_pb2.Empty:
"""
Sends interactions to the learner.
This function continuously retrieves messages from the queue and processes:
- **Interaction Messages:**
- Contains useful statistics about episodic rewards and policy timings.
- The message is serialized using `pickle` and sent to the learner.
"""
if not use_threads(cfg):
# Setup process handlers to handle shutdown signal
# But use shutdown event from the main process
setup_process_handlers(False)
if grpc_channel is None or learner_client is None:
learner_client, grpc_channel = learner_service_client(
host=cfg.actor_learner_config.learner_host,
port=cfg.actor_learner_config.learner_port,
)
try:
learner_client.SendInteractions(
interactions_stream(shutdown_event, interactions_queue)
)
except grpc.RpcError as e:
logging.error(f"[ACTOR] gRPC error: {e}")
logging.info("[ACTOR] Finished streaming interactions")
if not use_threads(cfg):
grpc_channel.close()
logging.info("[ACTOR] Interactions process stopped")
@lru_cache(maxsize=1)
def learner_service_client(
@ -217,7 +241,7 @@ def learner_service_client(
{
"name": [{}], # Applies to ALL methods in ALL services
"retryPolicy": {
"maxAttempts": 7, # Max retries (total attempts = 5)
"maxAttempts": 5, # Max retries (total attempts = 5)
"initialBackoff": "0.1s", # First retry after 0.1s
"maxBackoff": "2s", # Max wait time between retries
"backoffMultiplier": 2, # Exponential backoff factor
@ -242,20 +266,27 @@ def learner_service_client(
],
)
stub = hilserl_pb2_grpc.LearnerServiceStub(channel)
logging.info("[LEARNER] Learner service client created")
logging.info("[ACTOR] Learner service client created")
return stub, channel
def update_policy_parameters(policy: SACPolicy, parameters_queue: queue.Queue, device):
def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device):
if not parameters_queue.empty():
logging.info("[ACTOR] Load new parameters from Learner.")
state_dict = parameters_queue.get()
bytes_state_dict = get_last_item_from_queue(parameters_queue)
state_dict = bytes_to_state_dict(bytes_state_dict)
state_dict = move_state_dict_to_device(state_dict, device=device)
policy.load_state_dict(state_dict)
def act_with_policy(
cfg: DictConfig, robot: Robot, reward_classifier: nn.Module, shutdown_event: Event
cfg: DictConfig,
robot: Robot,
reward_classifier: nn.Module,
shutdown_event: any, # Event,
parameters_queue: Queue,
transitions_queue: Queue,
interactions_queue: Queue,
):
"""
Executes policy interaction within the environment.
@ -317,7 +348,7 @@ def act_with_policy(
for interaction_step in range(cfg.training.online_steps):
if shutdown_event.is_set():
logging.info("[ACTOR] Shutdown signal received. Exiting...")
logging.info("[ACTOR] Shutting down act_with_policy")
return
if interaction_step >= cfg.training.online_step_before_learning:
@ -394,10 +425,9 @@ def act_with_policy(
)
if len(list_transition_to_send_to_learner) > 0:
send_transitions_in_chunks(
push_transitions_to_transport_queue(
transitions=list_transition_to_send_to_learner,
message_queue=message_queue,
chunk_size=4,
transitions_queue=transitions_queue,
)
list_transition_to_send_to_learner = []
@ -405,9 +435,9 @@ def act_with_policy(
list_policy_time.clear()
# Send episodic reward to the learner
message_queue.put(
ActorInformation(
interaction_message={
interactions_queue.put(
python_object_to_bytes(
{
"Episodic reward": sum_reward_episode,
"Interaction step": interaction_step,
"Episode intervention": int(episode_intervention),
@ -420,7 +450,7 @@ def act_with_policy(
obs, info = online_env.reset()
def send_transitions_in_chunks(transitions: list, message_queue, chunk_size: int = 100):
def push_transitions_to_transport_queue(transitions: list, transitions_queue):
"""Send transitions to learner in smaller chunks to avoid network issues.
Args:
@ -428,10 +458,16 @@ def send_transitions_in_chunks(transitions: list, message_queue, chunk_size: int
message_queue: Queue to send messages to learner
chunk_size: Size of each chunk to send
"""
for i in range(0, len(transitions), chunk_size):
chunk = transitions[i : i + chunk_size]
logging.debug(f"[ACTOR] Sending chunk of {len(chunk)} transitions to Learner.")
message_queue.put(ActorInformation(transition=chunk))
transition_to_send_to_learner = []
for transition in transitions:
tr = move_transition_to_device(transition=transition, device="cpu")
for key, value in tr["state"].items():
if torch.isnan(value).any():
logging.warning(f"Found NaN values in transition {key}")
transition_to_send_to_learner.append(tr)
transitions_queue.put(transitions_to_bytes(transition_to_send_to_learner))
def get_frequency_stats(list_policy_time: list[float]) -> dict[str, float]:
@ -458,39 +494,96 @@ def log_policy_frequency_issue(
)
def establish_learner_connection(
stub,
shutdown_event: any, # Event,
attempts=30,
):
for _ in range(attempts):
if shutdown_event.is_set():
logging.info("[ACTOR] Shutting down establish_learner_connection")
return False
# Force a connection attempt and check state
try:
logging.info("[ACTOR] Send ready message to Learner")
if stub.Ready(hilserl_pb2.Empty()) == hilserl_pb2.Empty():
return True
except grpc.RpcError as e:
logging.error(f"[ACTOR] Waiting for Learner to be ready... {e}")
time.sleep(2)
return False
def use_threads(cfg: DictConfig) -> bool:
return cfg.actor_learner_config.concurrency.actor == "threads"
@hydra.main(version_base="1.2", config_name="default", config_path="../../configs")
def actor_cli(cfg: dict):
if not use_threads(cfg):
import torch.multiprocessing as mp
mp.set_start_method("spawn")
init_logging(log_file="actor.log")
robot = make_robot(cfg=cfg.robot)
shutdown_event = Event()
# Define signal handler
def signal_handler(signum, frame):
logging.info("Shutdown signal received. Cleaning up...")
shutdown_event.set()
signal.signal(signal.SIGINT, signal_handler) # Ctrl+C
signal.signal(signal.SIGTERM, signal_handler) # Termination request (kill)
signal.signal(signal.SIGHUP, signal_handler) # Terminal closed/Hangup
signal.signal(signal.SIGQUIT, signal_handler) # Ctrl+\
shutdown_event = setup_process_handlers(use_threads(cfg))
learner_client, grpc_channel = learner_service_client(
host=cfg.actor_learner_config.learner_host,
port=cfg.actor_learner_config.learner_port,
)
receive_policy_thread = Thread(
logging.info("[ACTOR] Establishing connection with Learner")
if not establish_learner_connection(learner_client, shutdown_event):
logging.error("[ACTOR] Failed to establish connection with Learner")
return
if not use_threads(cfg):
# Only threads can reuse this channel; with processes, close it so each worker opens its own
grpc_channel.close()
grpc_channel = None
logging.info("[ACTOR] Connection with Learner established")
parameters_queue = Queue()
transitions_queue = Queue()
interactions_queue = Queue()
concurrency_entity = None
if use_threads(cfg):
from threading import Thread
concurrency_entity = Thread
else:
from multiprocessing import Process
concurrency_entity = Process
receive_policy_process = concurrency_entity(
target=receive_policy,
args=(learner_client, shutdown_event, parameters_queue),
args=(cfg, parameters_queue, shutdown_event, grpc_channel),
daemon=True,
)
transitions_thread = Thread(
transitions_process = concurrency_entity(
target=send_transitions,
args=(learner_client, shutdown_event, message_queue),
args=(cfg, transitions_queue, shutdown_event, grpc_channel),
daemon=True,
)
interactions_process = concurrency_entity(
target=send_interactions,
args=(cfg, interactions_queue, shutdown_event, grpc_channel),
daemon=True,
)
transitions_process.start()
interactions_process.start()
receive_policy_process.start()
# HACK: FOR MANISKILL we do not have a reward classifier
# TODO: Remove this once we merge into main
reward_classifier = None
@ -503,26 +596,35 @@ def actor_cli(cfg: dict):
config_path=cfg.env.reward_classifier.config_path,
)
policy_thread = Thread(
target=act_with_policy,
daemon=True,
args=(cfg, robot, reward_classifier, shutdown_event),
act_with_policy(
cfg,
robot,
reward_classifier,
shutdown_event,
parameters_queue,
transitions_queue,
interactions_queue,
)
logging.info("[ACTOR] Policy process joined")
transitions_thread.start()
policy_thread.start()
receive_policy_thread.start()
logging.info("[ACTOR] Closing queues")
transitions_queue.close()
interactions_queue.close()
parameters_queue.close()
shutdown_event.wait()
logging.info("[ACTOR] Shutdown event received")
grpc_channel.close()
transitions_process.join()
logging.info("[ACTOR] Transitions process joined")
interactions_process.join()
logging.info("[ACTOR] Interactions process joined")
receive_policy_process.join()
logging.info("[ACTOR] Receive policy process joined")
policy_thread.join()
logging.info("[ACTOR] Policy thread joined")
transitions_thread.join()
logging.info("[ACTOR] Transitions thread joined")
receive_policy_thread.join()
logging.info("[ACTOR] Receive policy thread joined")
logging.info("[ACTOR] join queues")
transitions_queue.cancel_join_thread()
interactions_queue.cancel_join_thread()
parameters_queue.cancel_join_thread()
logging.info("[ACTOR] queues closed")
if __name__ == "__main__":

View File

@ -23,6 +23,7 @@ from tqdm import tqdm
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
import os
import pickle
class Transition(TypedDict):
@ -91,7 +92,7 @@ def move_transition_to_device(
return transition
def move_state_dict_to_device(state_dict, device):
def move_state_dict_to_device(state_dict, device="cpu"):
"""
Recursively move all tensors in a (potentially) nested
dict/list/tuple structure to the CPU.
@ -111,20 +112,41 @@ def move_state_dict_to_device(state_dict, device):
return state_dict
def state_to_bytes(state_dict: dict[str, torch.Tensor]) -> io.BytesIO:
def state_to_bytes(state_dict: dict[str, torch.Tensor]) -> bytes:
"""Convert model state dict to flat array for transmission"""
buffer = io.BytesIO()
torch.save(state_dict, buffer)
return buffer
return buffer.getvalue()
def bytes_buffer_size(buffer: io.BytesIO) -> int:
buffer.seek(0, io.SEEK_END)
result = buffer.tell()
def bytes_to_state_dict(buffer: bytes) -> dict[str, torch.Tensor]:
buffer = io.BytesIO(buffer)
buffer.seek(0)
return result
return torch.load(buffer)
def python_object_to_bytes(python_object: Any) -> bytes:
return pickle.dumps(python_object)
def bytes_to_python_object(buffer: bytes) -> Any:
buffer = io.BytesIO(buffer)
buffer.seek(0)
return pickle.load(buffer)
def bytes_to_transitions(buffer: bytes) -> list[Transition]:
buffer = io.BytesIO(buffer)
buffer.seek(0)
return torch.load(buffer)
def transitions_to_bytes(transitions: list[Transition]) -> bytes:
buffer = io.BytesIO()
torch.save(transitions, buffer)
return buffer.getvalue()
def random_crop_vectorized(images: torch.Tensor, output_size: tuple) -> torch.Tensor:
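A small round-trip sketch for the serialization helpers added above (the tensors and stats dict are placeholder data): state dicts and transitions travel as raw bytes so they can cross multiprocessing queues and gRPC streams, while interaction statistics use pickle.
import torch
from lerobot.scripts.server.buffer import (
    bytes_to_python_object,
    bytes_to_state_dict,
    python_object_to_bytes,
    state_to_bytes,
)
# Model parameters: torch.save/torch.load round trip through raw bytes.
state = {"weight": torch.randn(4, 4), "bias": torch.zeros(4)}
restored = bytes_to_state_dict(state_to_bytes(state))
assert torch.equal(state["weight"], restored["weight"])
# Interaction statistics: pickle round trip.
stats = {"Episodic reward": 12.5, "Interaction step": 1000}
assert bytes_to_python_object(python_object_to_bytes(stats)) == stats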

View File

@ -24,14 +24,9 @@ service LearnerService {
// Actor -> Learner to store transitions
rpc SendInteractionMessage(InteractionMessage) returns (Empty);
rpc StreamParameters(Empty) returns (stream Parameters);
rpc ReceiveTransitions(stream ActorInformation) returns (Empty);
}
message ActorInformation {
oneof data {
Transition transition = 1;
InteractionMessage interaction_message = 2;
}
rpc SendTransitions(stream Transition) returns (Empty);
rpc SendInteractions(stream InteractionMessage) returns (Empty);
rpc Ready(Empty) returns (Empty);
}
enum TransferState {
@ -43,16 +38,18 @@ enum TransferState {
// Messages
message Transition {
bytes transition_bytes = 1;
TransferState transfer_state = 1;
bytes data = 2;
}
message Parameters {
TransferState transfer_state = 1;
bytes parameter_bytes = 2;
bytes data = 2;
}
message InteractionMessage {
bytes interaction_message_bytes = 1;
TransferState transfer_state = 1;
bytes data = 2;
}
message Empty {}

View File

@ -24,25 +24,23 @@ _sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rhilserl.proto\x12\x08hil_serl\"\x83\x01\n\x10\x41\x63torInformation\x12*\n\ntransition\x18\x01 \x01(\x0b\x32\x14.hil_serl.TransitionH\x00\x12;\n\x13interaction_message\x18\x02 \x01(\x0b\x32\x1c.hil_serl.InteractionMessageH\x00\x42\x06\n\x04\x64\x61ta\"&\n\nTransition\x12\x18\n\x10transition_bytes\x18\x01 \x01(\x0c\"V\n\nParameters\x12/\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x17.hil_serl.TransferState\x12\x17\n\x0fparameter_bytes\x18\x02 \x01(\x0c\"7\n\x12InteractionMessage\x12!\n\x19interaction_message_bytes\x18\x01 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\xdb\x01\n\x0eLearnerService\x12G\n\x16SendInteractionMessage\x12\x1c.hil_serl.InteractionMessage\x1a\x0f.hil_serl.Empty\x12;\n\x10StreamParameters\x12\x0f.hil_serl.Empty\x1a\x14.hil_serl.Parameters0\x01\x12\x43\n\x12ReceiveTransitions\x12\x1a.hil_serl.ActorInformation\x1a\x0f.hil_serl.Empty(\x01\x62\x06proto3')
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rhilserl.proto\x12\x08hil_serl\"K\n\nTransition\x12/\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x17.hil_serl.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"K\n\nParameters\x12/\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x17.hil_serl.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"S\n\x12InteractionMessage\x12/\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x17.hil_serl.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\xc2\x02\n\x0eLearnerService\x12G\n\x16SendInteractionMessage\x12\x1c.hil_serl.InteractionMessage\x1a\x0f.hil_serl.Empty\x12;\n\x10StreamParameters\x12\x0f.hil_serl.Empty\x1a\x14.hil_serl.Parameters0\x01\x12:\n\x0fSendTransitions\x12\x14.hil_serl.Transition\x1a\x0f.hil_serl.Empty(\x01\x12\x43\n\x10SendInteractions\x12\x1c.hil_serl.InteractionMessage\x1a\x0f.hil_serl.Empty(\x01\x12)\n\x05Ready\x12\x0f.hil_serl.Empty\x1a\x0f.hil_serl.Emptyb\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'hilserl_pb2', _globals)
if not _descriptor._USE_C_DESCRIPTORS:
DESCRIPTOR._loaded_options = None
_globals['_TRANSFERSTATE']._serialized_start=355
_globals['_TRANSFERSTATE']._serialized_end=451
_globals['_ACTORINFORMATION']._serialized_start=28
_globals['_ACTORINFORMATION']._serialized_end=159
_globals['_TRANSITION']._serialized_start=161
_globals['_TRANSITION']._serialized_end=199
_globals['_PARAMETERS']._serialized_start=201
_globals['_PARAMETERS']._serialized_end=287
_globals['_INTERACTIONMESSAGE']._serialized_start=289
_globals['_INTERACTIONMESSAGE']._serialized_end=344
_globals['_EMPTY']._serialized_start=346
_globals['_EMPTY']._serialized_end=353
_globals['_LEARNERSERVICE']._serialized_start=454
_globals['_LEARNERSERVICE']._serialized_end=673
_globals['_TRANSFERSTATE']._serialized_start=275
_globals['_TRANSFERSTATE']._serialized_end=371
_globals['_TRANSITION']._serialized_start=27
_globals['_TRANSITION']._serialized_end=102
_globals['_PARAMETERS']._serialized_start=104
_globals['_PARAMETERS']._serialized_end=179
_globals['_INTERACTIONMESSAGE']._serialized_start=181
_globals['_INTERACTIONMESSAGE']._serialized_end=264
_globals['_EMPTY']._serialized_start=266
_globals['_EMPTY']._serialized_end=273
_globals['_LEARNERSERVICE']._serialized_start=374
_globals['_LEARNERSERVICE']._serialized_end=696
# @@protoc_insertion_point(module_scope)

View File

@ -46,9 +46,19 @@ class LearnerServiceStub(object):
request_serializer=hilserl__pb2.Empty.SerializeToString,
response_deserializer=hilserl__pb2.Parameters.FromString,
_registered_method=True)
self.ReceiveTransitions = channel.stream_unary(
'/hil_serl.LearnerService/ReceiveTransitions',
request_serializer=hilserl__pb2.ActorInformation.SerializeToString,
self.SendTransitions = channel.stream_unary(
'/hil_serl.LearnerService/SendTransitions',
request_serializer=hilserl__pb2.Transition.SerializeToString,
response_deserializer=hilserl__pb2.Empty.FromString,
_registered_method=True)
self.SendInteractions = channel.stream_unary(
'/hil_serl.LearnerService/SendInteractions',
request_serializer=hilserl__pb2.InteractionMessage.SerializeToString,
response_deserializer=hilserl__pb2.Empty.FromString,
_registered_method=True)
self.Ready = channel.unary_unary(
'/hil_serl.LearnerService/Ready',
request_serializer=hilserl__pb2.Empty.SerializeToString,
response_deserializer=hilserl__pb2.Empty.FromString,
_registered_method=True)
@ -71,7 +81,19 @@ class LearnerServiceServicer(object):
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def ReceiveTransitions(self, request_iterator, context):
def SendTransitions(self, request_iterator, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def SendInteractions(self, request_iterator, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def Ready(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
@ -90,9 +112,19 @@ def add_LearnerServiceServicer_to_server(servicer, server):
request_deserializer=hilserl__pb2.Empty.FromString,
response_serializer=hilserl__pb2.Parameters.SerializeToString,
),
'ReceiveTransitions': grpc.stream_unary_rpc_method_handler(
servicer.ReceiveTransitions,
request_deserializer=hilserl__pb2.ActorInformation.FromString,
'SendTransitions': grpc.stream_unary_rpc_method_handler(
servicer.SendTransitions,
request_deserializer=hilserl__pb2.Transition.FromString,
response_serializer=hilserl__pb2.Empty.SerializeToString,
),
'SendInteractions': grpc.stream_unary_rpc_method_handler(
servicer.SendInteractions,
request_deserializer=hilserl__pb2.InteractionMessage.FromString,
response_serializer=hilserl__pb2.Empty.SerializeToString,
),
'Ready': grpc.unary_unary_rpc_method_handler(
servicer.Ready,
request_deserializer=hilserl__pb2.Empty.FromString,
response_serializer=hilserl__pb2.Empty.SerializeToString,
),
}
@ -163,7 +195,7 @@ class LearnerService(object):
_registered_method=True)
@staticmethod
def ReceiveTransitions(request_iterator,
def SendTransitions(request_iterator,
target,
options=(),
channel_credentials=None,
@ -176,8 +208,62 @@ class LearnerService(object):
return grpc.experimental.stream_unary(
request_iterator,
target,
'/hil_serl.LearnerService/ReceiveTransitions',
hilserl__pb2.ActorInformation.SerializeToString,
'/hil_serl.LearnerService/SendTransitions',
hilserl__pb2.Transition.SerializeToString,
hilserl__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def SendInteractions(request_iterator,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.stream_unary(
request_iterator,
target,
'/hil_serl.LearnerService/SendInteractions',
hilserl__pb2.InteractionMessage.SerializeToString,
hilserl__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def Ready(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/hil_serl.LearnerService/Ready',
hilserl__pb2.Empty.SerializeToString,
hilserl__pb2.Empty.FromString,
options,
channel_credentials,

View File

@ -15,15 +15,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import queue
import shutil
import time
from pprint import pformat
from threading import Lock, Thread
import signal
from threading import Event
from concurrent.futures import ThreadPoolExecutor
# from torch.multiprocessing import Event, Queue, Process
# from threading import Event, Thread
# from torch.multiprocessing import Queue, Event
from torch.multiprocessing import Queue
from lerobot.scripts.server.utils import setup_process_handlers
import grpc
# Import generated stubs
@ -52,19 +55,19 @@ from lerobot.common.utils.utils import (
set_global_random_state,
set_global_seed,
)
from lerobot.scripts.server.buffer import (
ReplayBuffer,
concatenate_batch_transitions,
move_transition_to_device,
move_state_dict_to_device,
bytes_to_transitions,
state_to_bytes,
bytes_to_python_object,
)
from lerobot.scripts.server import learner_service
logging.basicConfig(level=logging.INFO)
transition_queue = queue.Queue()
interaction_message_queue = queue.Queue()
def handle_resume_logic(cfg: DictConfig, out_dir: str) -> DictConfig:
if not cfg.resume:
@ -195,67 +198,96 @@ def get_observation_features(
return observation_features, next_observation_features
def use_threads(cfg: DictConfig) -> bool:
return cfg.actor_learner_config.concurrency.learner == "threads"
def start_learner_threads(
cfg: DictConfig,
device: str,
replay_buffer: ReplayBuffer,
offline_replay_buffer: ReplayBuffer,
batch_size: int,
optimizers: dict,
policy: SACPolicy,
policy_lock: Lock,
logger: Logger,
resume_optimization_step: int | None = None,
resume_interaction_step: int | None = None,
shutdown_event: Event | None = None,
out_dir: str,
shutdown_event: any, # Event,
) -> None:
host = cfg.actor_learner_config.learner_host
port = cfg.actor_learner_config.learner_port
# Create multiprocessing queues
transition_queue = Queue()
interaction_message_queue = Queue()
parameters_queue = Queue()
transition_thread = Thread(
target=add_actor_information_and_train,
daemon=True,
concurrency_entity = None
if use_threads(cfg):
from threading import Thread
concurrency_entity = Thread
else:
from torch.multiprocessing import Process
concurrency_entity = Process
communication_process = concurrency_entity(
target=start_learner_server,
args=(
cfg,
device,
replay_buffer,
offline_replay_buffer,
batch_size,
optimizers,
policy,
policy_lock,
logger,
resume_optimization_step,
resume_interaction_step,
parameters_queue,
transition_queue,
interaction_message_queue,
shutdown_event,
cfg,
),
daemon=True,
)
communication_process.start()
transition_thread.start()
add_actor_information_and_train(
cfg,
logger,
out_dir,
shutdown_event,
transition_queue,
interaction_message_queue,
parameters_queue,
)
logging.info("[LEARNER] Training process stopped")
logging.info("[LEARNER] Closing queues")
transition_queue.close()
interaction_message_queue.close()
parameters_queue.close()
communication_process.join()
logging.info("[LEARNER] Communication process joined")
logging.info("[LEARNER] join queues")
transition_queue.cancel_join_thread()
interaction_message_queue.cancel_join_thread()
parameters_queue.cancel_join_thread()
logging.info("[LEARNER] queues closed")
def start_learner_server(
parameters_queue: Queue,
transition_queue: Queue,
interaction_message_queue: Queue,
shutdown_event: any, # Event,
cfg: DictConfig,
):
if not use_threads(cfg):
# We need to initialize logging separately for each process
init_logging()
# Setup process handlers to handle shutdown signal
# But use shutdown event from the main process
# Return back for MP
setup_process_handlers(False)
service = learner_service.LearnerService(
shutdown_event,
policy,
policy_lock,
parameters_queue,
cfg.actor_learner_config.policy_parameters_push_frequency,
transition_queue,
interaction_message_queue,
)
server = start_learner_server(service, host, port)
shutdown_event.wait()
server.stop(learner_service.STUTDOWN_TIMEOUT)
logging.info("[LEARNER] gRPC server stopped")
transition_thread.join()
logging.info("[LEARNER] Transition thread stopped")
def start_learner_server(
service: learner_service.LearnerService,
host="0.0.0.0",
port=50051,
) -> grpc.server:
server = grpc.server(
ThreadPoolExecutor(max_workers=learner_service.MAX_WORKERS),
options=[
@ -263,15 +295,23 @@ def start_learner_server(
("grpc.max_send_message_length", learner_service.MAX_MESSAGE_SIZE),
],
)
hilserl_pb2_grpc.add_LearnerServiceServicer_to_server(
service,
server,
)
host = cfg.actor_learner_config.learner_host
port = cfg.actor_learner_config.learner_port
server.add_insecure_port(f"{host}:{port}")
server.start()
logging.info("[LEARNER] gRPC server started")
return server
shutdown_event.wait()
logging.info("[LEARNER] Stopping gRPC server...")
server.stop(learner_service.STUTDOWN_TIMEOUT)
logging.info("[LEARNER] gRPC server stopped")
def check_nan_in_transition(
@ -287,19 +327,21 @@ def check_nan_in_transition(
logging.error("actions contains NaN values")
def push_actor_policy_to_queue(parameters_queue: Queue, policy: nn.Module):
logging.debug("[LEARNER] Pushing actor policy to the queue")
state_dict = move_state_dict_to_device(policy.actor.state_dict(), device="cpu")
state_bytes = state_to_bytes(state_dict)
parameters_queue.put(state_bytes)
def add_actor_information_and_train(
cfg,
device: str,
replay_buffer: ReplayBuffer,
offline_replay_buffer: ReplayBuffer,
batch_size: int,
optimizers: dict[str, torch.optim.Optimizer],
policy: nn.Module,
policy_lock: Lock,
logger: Logger,
resume_optimization_step: int | None = None,
resume_interaction_step: int | None = None,
shutdown_event: Event | None = None,
out_dir: str,
shutdown_event: any, # Event,
transition_queue: Queue,
interaction_message_queue: Queue,
parameters_queue: Queue,
):
"""
Handles data transfer from the actor to the learner, manages training updates,
@ -322,17 +364,73 @@ def add_actor_information_and_train(
Args:
cfg: Configuration object containing hyperparameters.
device (str): The computing device (`"cpu"` or `"cuda"`).
replay_buffer (ReplayBuffer): The primary replay buffer storing online transitions.
offline_replay_buffer (ReplayBuffer): An additional buffer for offline transitions.
batch_size (int): The number of transitions to sample per training step.
optimizers (Dict[str, torch.optim.Optimizer]): A dictionary of optimizers (`"actor"`, `"critic"`, `"temperature"`).
policy (nn.Module): The reinforcement learning policy with critic, actor, and temperature parameters.
policy_lock (Lock): A threading lock to ensure safe policy updates.
logger (Logger): Logger instance for tracking training progress.
resume_optimization_step (int | None): In the case of resume training, start from the last optimization step reached.
resume_interaction_step (int | None): In the case of resume training, shift the interaction step with the last saved step in order to not break logging.
shutdown_event (Event | None): Event to signal shutdown.
out_dir (str): The output directory for storing training checkpoints and logs.
shutdown_event (Event): Event to signal shutdown.
transition_queue (Queue): Queue for receiving transitions from the actor.
interaction_message_queue (Queue): Queue for receiving interaction messages from the actor.
parameters_queue (Queue): Queue for sending policy parameters to the actor.
"""
device = get_safe_torch_device(cfg.device, log=True)
storage_device = get_safe_torch_device(cfg_device=cfg.training.storage_device)
logging.info("Initializing policy")
### Instantiate the policy in both the actor and learner processes
### To avoid sending a SACPolicy object through the port, we create a policy instance
### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters
# TODO: At some point we should just need make sac policy
policy: SACPolicy = make_policy(
hydra_cfg=cfg,
# dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
# Hack: But if we do online training, we do not need dataset_stats
dataset_stats=None,
pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir)
if cfg.resume
else None,
)
# compile policy
policy = torch.compile(policy)
assert isinstance(policy, nn.Module)
push_actor_policy_to_queue(parameters_queue, policy)
last_time_policy_pushed = time.time()
optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg, policy)
resume_optimization_step, resume_interaction_step = load_training_state(
cfg, logger, optimizers
)
log_training_info(cfg, out_dir, policy)
replay_buffer = initialize_replay_buffer(cfg, logger, device, storage_device)
batch_size = cfg.training.batch_size
offline_replay_buffer = None
if cfg.dataset_repo_id is not None:
logging.info("make_dataset offline buffer")
offline_dataset = make_dataset(cfg)
logging.info("Conversion to an offline replay buffer")
active_action_dims = None
if cfg.env.wrapper.joint_masking_action_space is not None:
active_action_dims = [
i
for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space)
if mask
]
offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
offline_dataset,
device=device,
state_keys=cfg.policy.input_shapes.keys(),
action_mask=active_action_dims,
action_delta=cfg.env.wrapper.delta_action,
storage_device=storage_device,
optimize_memory=True,
)
batch_size: int = batch_size // 2  # We will sample from both replay buffers
# NOTE: This function doesn't have a single responsibility, it should be split into multiple functions
# in the future. The reason we did it this way is Python's GIL: with multiple threads the
# performance drops by roughly 200x, so a single thread has to do all the work.
@ -345,33 +443,39 @@ def add_actor_information_and_train(
interaction_step_shift = (
resume_interaction_step if resume_interaction_step is not None else 0
)
saved_data = False
while True:
if shutdown_event is not None and shutdown_event.is_set():
logging.info("[LEARNER] Shutdown signal received. Exiting...")
break
while not transition_queue.empty():
logging.debug("[LEARNER] Waiting for transitions")
while not transition_queue.empty() and not shutdown_event.is_set():
transition_list = transition_queue.get()
transition_list = bytes_to_transitions(transition_list)
for transition in transition_list:
transition = move_transition_to_device(transition, device=device)
replay_buffer.add(**transition)
if transition.get("complementary_info", {}).get("is_intervention"):
offline_replay_buffer.add(**transition)
while not interaction_message_queue.empty():
logging.debug("[LEARNER] Received transitions")
logging.debug("[LEARNER] Waiting for interactions")
while not interaction_message_queue.empty() and not shutdown_event.is_set():
interaction_message = interaction_message_queue.get()
interaction_message = bytes_to_python_object(interaction_message)
# If cfg.resume, shift the interaction step with the last checkpointed step in order to not break the logging
interaction_message["Interaction step"] += interaction_step_shift
logger.log_dict(
interaction_message, mode="train", custom_step_key="Interaction step"
)
# logging.info(f"Interaction message: {interaction_message}")
logging.debug("[LEARNER] Received interactions")
if len(replay_buffer) < cfg.training.online_step_before_learning:
continue
logging.debug("[LEARNER] Starting optimization loop")
time_for_one_optimization_step = time.time()
for _ in range(cfg.policy.utd_ratio - 1):
batch = replay_buffer.sample(batch_size)
@ -392,19 +496,18 @@ def add_actor_information_and_train(
observation_features, next_observation_features = get_observation_features(
policy, observations, next_observations
)
with policy_lock:
loss_critic = policy.compute_loss_critic(
observations=observations,
actions=actions,
rewards=rewards,
next_observations=next_observations,
done=done,
observation_features=observation_features,
next_observation_features=next_observation_features,
)
optimizers["critic"].zero_grad()
loss_critic.backward()
optimizers["critic"].step()
loss_critic = policy.compute_loss_critic(
observations=observations,
actions=actions,
rewards=rewards,
next_observations=next_observations,
done=done,
observation_features=observation_features,
next_observation_features=next_observation_features,
)
optimizers["critic"].zero_grad()
loss_critic.backward()
optimizers["critic"].step()
batch = replay_buffer.sample(batch_size)
@ -427,46 +530,51 @@ def add_actor_information_and_train(
observation_features, next_observation_features = get_observation_features(
policy, observations, next_observations
)
with policy_lock:
loss_critic = policy.compute_loss_critic(
observations=observations,
actions=actions,
rewards=rewards,
next_observations=next_observations,
done=done,
observation_features=observation_features,
next_observation_features=next_observation_features,
)
optimizers["critic"].zero_grad()
loss_critic.backward()
optimizers["critic"].step()
loss_critic = policy.compute_loss_critic(
observations=observations,
actions=actions,
rewards=rewards,
next_observations=next_observations,
done=done,
observation_features=observation_features,
next_observation_features=next_observation_features,
)
optimizers["critic"].zero_grad()
loss_critic.backward()
optimizers["critic"].step()
training_infos = {}
training_infos["loss_critic"] = loss_critic.item()
if optimization_step % cfg.training.policy_update_freq == 0:
for _ in range(cfg.training.policy_update_freq):
with policy_lock:
loss_actor = policy.compute_loss_actor(
observations=observations,
observation_features=observation_features,
)
loss_actor = policy.compute_loss_actor(
observations=observations,
observation_features=observation_features,
)
optimizers["actor"].zero_grad()
loss_actor.backward()
optimizers["actor"].step()
optimizers["actor"].zero_grad()
loss_actor.backward()
optimizers["actor"].step()
training_infos["loss_actor"] = loss_actor.item()
training_infos["loss_actor"] = loss_actor.item()
loss_temperature = policy.compute_loss_temperature(
observations=observations,
observation_features=observation_features,
)
optimizers["temperature"].zero_grad()
loss_temperature.backward()
optimizers["temperature"].step()
loss_temperature = policy.compute_loss_temperature(
observations=observations,
observation_features=observation_features,
)
optimizers["temperature"].zero_grad()
loss_temperature.backward()
optimizers["temperature"].step()
training_infos["loss_temperature"] = loss_temperature.item()
training_infos["loss_temperature"] = loss_temperature.item()
if (
time.time() - last_time_policy_pushed
> cfg.actor_learner_config.policy_parameters_push_frequency
):
push_actor_policy_to_queue(parameters_queue, policy)
last_time_policy_pushed = time.time()
policy.update_target_networks()
if optimization_step % cfg.training.log_freq == 0:
@ -595,104 +703,36 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
set_global_seed(cfg.seed)
device = get_safe_torch_device(cfg.device, log=True)
storage_device = get_safe_torch_device(cfg_device=cfg.training.storage_device)
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
logging.info("make_policy")
### Instantiate the policy in both the actor and learner processes
### To avoid sending a SACPolicy object through the port, we create a policy intance
### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters
# TODO: At some point we should just need make sac policy
policy_lock = Lock()
policy: SACPolicy = make_policy(
hydra_cfg=cfg,
# dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
# Hack: But if we do online traning, we do not need dataset_stats
dataset_stats=None,
pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir)
if cfg.resume
else None,
)
# compile policy
policy = torch.compile(policy)
assert isinstance(policy, nn.Module)
optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg, policy)
resume_optimization_step, resume_interaction_step = load_training_state(
cfg, logger, optimizers
)
log_training_info(cfg, out_dir, policy)
replay_buffer = initialize_replay_buffer(cfg, logger, device, storage_device)
batch_size = cfg.training.batch_size
offline_replay_buffer = None
if cfg.dataset_repo_id is not None:
logging.info("make_dataset offline buffer")
offline_dataset = make_dataset(cfg)
logging.info("Convertion to a offline replay buffer")
active_action_dims = None
if cfg.env.wrapper.joint_masking_action_space is not None:
active_action_dims = [
i
for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space)
if mask
]
offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
offline_dataset,
device=device,
state_keys=cfg.policy.input_shapes.keys(),
action_mask=active_action_dims,
action_delta=cfg.env.wrapper.delta_action,
storage_device=storage_device,
optimize_memory=True,
)
batch_size: int = batch_size // 2 # We will sample from both replay buffer
shutdown_event = Event()
def signal_handler(signum, frame):
print(
f"\nReceived signal {signal.Signals(signum).name}. Initiating learner shutdown..."
)
shutdown_event.set()
# Register signal handlers
signal.signal(signal.SIGINT, signal_handler) # Ctrl+C
signal.signal(signal.SIGTERM, signal_handler) # Termination request
signal.signal(signal.SIGHUP, signal_handler) # Terminal closed
signal.signal(signal.SIGQUIT, signal_handler) # Ctrl+\
shutdown_event = setup_process_handlers(use_threads(cfg))
start_learner_threads(
cfg,
device,
replay_buffer,
offline_replay_buffer,
batch_size,
optimizers,
policy,
policy_lock,
logger,
resume_optimization_step,
resume_interaction_step,
out_dir,
shutdown_event,
)
@hydra.main(version_base="1.2", config_name="default", config_path="../../configs")
def train_cli(cfg: dict):
if not use_threads(cfg):
import torch.multiprocessing as mp
mp.set_start_method("spawn")
train(
cfg,
out_dir=hydra.core.hydra_config.HydraConfig.get().run.dir,
job_name=hydra.core.hydra_config.HydraConfig.get().job.name,
)
logging.info("[LEARNER] train_cli finished")
if __name__ == "__main__":
train_cli()
logging.info("[LEARNER] main finished")

View File

@ -1,23 +1,13 @@
import hilserl_pb2 # type: ignore
import hilserl_pb2_grpc # type: ignore
import torch
from torch import nn
from threading import Lock, Event
import logging
import queue
import io
import pickle
from lerobot.scripts.server.buffer import (
move_state_dict_to_device,
bytes_buffer_size,
state_to_bytes,
)
from multiprocessing import Event, Queue
from lerobot.scripts.server.network_utils import receive_bytes_in_chunks
from lerobot.scripts.server.network_utils import send_bytes_in_chunks
MAX_MESSAGE_SIZE = 4 * 1024 * 1024 # 4 MB
CHUNK_SIZE = 2 * 1024 * 1024 # 2 MB
MAX_WORKERS = 10
MAX_WORKERS = 3 # Stream parameters, send transitions and interactions
STUTDOWN_TIMEOUT = 10
@ -25,89 +15,68 @@ class LearnerService(hilserl_pb2_grpc.LearnerServiceServicer):
def __init__(
self,
shutdown_event: Event,
policy: nn.Module,
policy_lock: Lock,
parameters_queue: Queue,
seconds_between_pushes: float,
transition_queue: queue.Queue,
interaction_message_queue: queue.Queue,
transition_queue: Queue,
interaction_message_queue: Queue,
):
self.shutdown_event = shutdown_event
self.policy = policy
self.policy_lock = policy_lock
self.parameters_queue = parameters_queue
self.seconds_between_pushes = seconds_between_pushes
self.transition_queue = transition_queue
self.interaction_message_queue = interaction_message_queue
def _get_policy_state(self):
with self.policy_lock:
params_dict = self.policy.actor.state_dict()
# if self.policy.config.vision_encoder_name is not None:
# if self.policy.config.freeze_vision_encoder:
# params_dict: dict[str, torch.Tensor] = {
# k: v
# for k, v in params_dict.items()
# if not k.startswith("encoder.")
# }
# else:
# raise NotImplementedError(
# "Vision encoder is not frozen, we need to send the full model over the network which requires chunking the model."
# )
return move_state_dict_to_device(params_dict, device="cpu")
def _send_bytes(self, buffer: bytes):
size_in_bytes = bytes_buffer_size(buffer)
sent_bytes = 0
logging.info(f"Model state size {size_in_bytes/1024/1024} MB with")
while sent_bytes < size_in_bytes:
transfer_state = hilserl_pb2.TransferState.TRANSFER_MIDDLE
if sent_bytes + CHUNK_SIZE >= size_in_bytes:
transfer_state = hilserl_pb2.TransferState.TRANSFER_END
elif sent_bytes == 0:
transfer_state = hilserl_pb2.TransferState.TRANSFER_BEGIN
size_to_read = min(CHUNK_SIZE, size_in_bytes - sent_bytes)
chunk = buffer.read(size_to_read)
yield hilserl_pb2.Parameters(
transfer_state=transfer_state, parameter_bytes=chunk
)
sent_bytes += size_to_read
logging.info(
f"[Learner] Sent {sent_bytes}/{size_in_bytes} bytes with state {transfer_state}"
)
logging.info(f"[LEARNER] Published {sent_bytes/1024/1024} MB to the Actor")
def StreamParameters(self, request, context):
# TODO: authorize the request
logging.info("[LEARNER] Received request to stream parameters from the Actor")
while not self.shutdown_event.is_set():
logging.debug("[LEARNER] Push parameters to the Actor")
state_dict = self._get_policy_state()
logging.info("[LEARNER] Push parameters to the Actor")
buffer = self.parameters_queue.get()
with state_to_bytes(state_dict) as buffer:
yield from self._send_bytes(buffer)
yield from send_bytes_in_chunks(
buffer,
hilserl_pb2.Parameters,
log_prefix="[LEARNER] Sending parameters",
silent=True,
)
logging.info("[LEARNER] Parameters sent")
self.shutdown_event.wait(self.seconds_between_pushes)
def ReceiveTransitions(self, request_iterator, context):
logging.info("[LEARNER] Stream parameters finished")
return hilserl_pb2.Empty()
def SendTransitions(self, request_iterator, _context):
# TODO: authorize the request
logging.info("[LEARNER] Received request to receive transitions from the Actor")
for request in request_iterator:
logging.debug("[LEARNER] Received request")
if request.HasField("transition"):
buffer = io.BytesIO(request.transition.transition_bytes)
transition = torch.load(buffer)
self.transition_queue.put(transition)
if request.HasField("interaction_message"):
content = pickle.loads(
request.interaction_message.interaction_message_bytes
)
self.interaction_message_queue.put(content)
receive_bytes_in_chunks(
request_iterator,
self.transition_queue,
self.shutdown_event,
log_prefix="[LEARNER] transitions",
)
logging.debug("[LEARNER] Finished receiving transitions")
return hilserl_pb2.Empty()
def SendInteractions(self, request_iterator, _context):
# TODO: authorize the request
logging.info(
"[LEARNER] Received request to receive interactions from the Actor"
)
receive_bytes_in_chunks(
request_iterator,
self.interaction_message_queue,
self.shutdown_event,
log_prefix="[LEARNER] interactions",
)
logging.debug("[LEARNER] Finished receiving interactions")
return hilserl_pb2.Empty()
def Ready(self, request, context):
return hilserl_pb2.Empty()

View File

@ -5,9 +5,8 @@ import torch
from omegaconf import DictConfig
from typing import Any
"""Make ManiSkill3 gym environment"""
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
from mani_skill.utils.wrappers.record import RecordEpisode
def preprocess_maniskill_observation(
@ -143,6 +142,15 @@ def make_maniskill(
num_envs=n_envs,
)
if cfg.env.video_record.enabled:
env = RecordEpisode(
env,
output_dir=cfg.env.video_record.record_dir,
save_trajectory=True,
trajectory_name=cfg.env.video_record.trajectory_name,
save_video=True,
video_fps=30,
)
env = ManiSkillObservationWrapper(env, device=cfg.env.device)
env = ManiSkillVectorEnv(env, ignore_terminations=True, auto_reset=False)
env._max_episode_steps = env.max_episode_steps = (

View File

@ -0,0 +1,102 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from lerobot.scripts.server import hilserl_pb2
import logging
import io
from multiprocessing import Queue, Event
from typing import Any
CHUNK_SIZE = 2 * 1024 * 1024 # 2 MB
def bytes_buffer_size(buffer: io.BytesIO) -> int:
buffer.seek(0, io.SEEK_END)
result = buffer.tell()
buffer.seek(0)
return result
def send_bytes_in_chunks(
buffer: bytes, message_class: Any, log_prefix: str = "", silent: bool = True
):
buffer = io.BytesIO(buffer)
size_in_bytes = bytes_buffer_size(buffer)
sent_bytes = 0
logging_method = logging.info if not silent else logging.debug
logging_method(f"{log_prefix} Buffer size {size_in_bytes/1024/1024} MB")
while sent_bytes < size_in_bytes:
transfer_state = hilserl_pb2.TransferState.TRANSFER_MIDDLE
if sent_bytes + CHUNK_SIZE >= size_in_bytes:
transfer_state = hilserl_pb2.TransferState.TRANSFER_END
elif sent_bytes == 0:
transfer_state = hilserl_pb2.TransferState.TRANSFER_BEGIN
size_to_read = min(CHUNK_SIZE, size_in_bytes - sent_bytes)
chunk = buffer.read(size_to_read)
yield message_class(transfer_state=transfer_state, data=chunk)
sent_bytes += size_to_read
logging_method(
f"{log_prefix} Sent {sent_bytes}/{size_in_bytes} bytes with state {transfer_state}"
)
logging_method(f"{log_prefix} Published {sent_bytes/1024/1024} MB")
def receive_bytes_in_chunks(
iterator, queue: Queue, shutdown_event: Event, log_prefix: str = ""
):
bytes_buffer = io.BytesIO()
step = 0
logging.info(f"{log_prefix} Starting receiver")
for item in iterator:
logging.debug(f"{log_prefix} Received item")
if shutdown_event.is_set():
logging.info(f"{log_prefix} Shutting down receiver")
return
if item.transfer_state == hilserl_pb2.TransferState.TRANSFER_BEGIN:
bytes_buffer.seek(0)
bytes_buffer.truncate(0)
bytes_buffer.write(item.data)
logging.debug(f"{log_prefix} Received data at step 0")
step = 0
continue
elif item.transfer_state == hilserl_pb2.TransferState.TRANSFER_MIDDLE:
bytes_buffer.write(item.data)
step += 1
logging.debug(f"{log_prefix} Received data at step {step}")
elif item.transfer_state == hilserl_pb2.TransferState.TRANSFER_END:
bytes_buffer.write(item.data)
logging.debug(
f"{log_prefix} Received data at step end size {bytes_buffer_size(bytes_buffer)}"
)
queue.put(bytes_buffer.getvalue())
bytes_buffer.seek(0)
bytes_buffer.truncate(0)
step = 0
logging.debug(f"{log_prefix} Queue updated")
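An in-memory sketch of the chunking protocol above, with the generator standing in for the gRPC stream (no channel involved): the sender yields Parameters messages tagged TRANSFER_BEGIN/MIDDLE/END, and the receiver reassembles them into a single payload that it pushes onto the queue.
from multiprocessing import Event, Queue
from lerobot.scripts.server import hilserl_pb2
from lerobot.scripts.server.network_utils import (
    receive_bytes_in_chunks,
    send_bytes_in_chunks,
)
payload = b"x" * (5 * 1024 * 1024)  # larger than CHUNK_SIZE, so it spans several chunks
out_queue = Queue()
shutdown_event = Event()
chunks = send_bytes_in_chunks(payload, hilserl_pb2.Parameters, log_prefix="[DEMO]")
receive_bytes_in_chunks(chunks, out_queue, shutdown_event, log_prefix="[DEMO]")
assert out_queue.get() == payload  # the payload arrives reassembled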

View File

@ -0,0 +1,72 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import signal
import sys
from torch.multiprocessing import Queue
from queue import Empty
shutdown_event_counter = 0
def setup_process_handlers(use_threads: bool) -> any:
if use_threads:
from threading import Event
else:
from multiprocessing import Event
shutdown_event = Event()
# Define signal handler
def signal_handler(signum, frame):
logging.info("Shutdown signal received. Cleaning up...")
shutdown_event.set()
global shutdown_event_counter
shutdown_event_counter += 1
if shutdown_event_counter > 1:
logging.info("Force shutdown")
sys.exit(1)
signal.signal(signal.SIGINT, signal_handler) # Ctrl+C
signal.signal(signal.SIGTERM, signal_handler) # Termination request (kill)
signal.signal(signal.SIGHUP, signal_handler) # Terminal closed/Hangup
signal.signal(signal.SIGQUIT, signal_handler) # Ctrl+\
def signal_handler(signum, frame):
logging.info("Shutdown signal received. Cleaning up...")
shutdown_event.set()
return shutdown_event
def get_last_item_from_queue(queue: Queue):
item = queue.get()
counter = 1
# Drain queue and keep only the most recent parameters
try:
while True:
item = queue.get_nowait()
counter += 1
except Empty:
pass
logging.debug(f"Drained {counter} items from queue")
return item
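A short sketch combining the two helpers above the way the actor does: setup_process_handlers installs the signal handlers and returns the shared shutdown event, and get_last_item_from_queue drains the parameters queue so only the freshest weights are kept. The queue contents here are placeholder strings rather than serialized state dicts.
import time
from torch.multiprocessing import Queue
from lerobot.scripts.server.utils import get_last_item_from_queue, setup_process_handlers
shutdown_event = setup_process_handlers(False)  # False -> multiprocessing Event
parameters_queue = Queue()
for step in range(3):
    parameters_queue.put(f"weights@{step}")  # stand-ins for serialized state dicts
time.sleep(0.5)  # demo only: let the queue's feeder thread flush the items
if not parameters_queue.empty() and not shutdown_event.is_set():
    latest = get_last_item_from_queue(parameters_queue)
    print(latest)  # most recent item, "weights@2"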