[HIL-SERL] Migrate threading to multiprocessing (#759)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 584cad808e
commit 700f00c014
@@ -116,11 +116,11 @@ def seeded_context(seed: int) -> Generator[None, None, None]:
    set_global_random_state(random_state_dict)


def init_logging():
def init_logging(log_file=None):
    def custom_format(record):
        dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        fnameline = f"{record.pathname}:{record.lineno}"
        message = f"{record.levelname} {dt} {fnameline[-15:]:>15} {record.msg}"
        message = f"{record.levelname} [PID: {os.getpid()}] {dt} {fnameline[-15:]:>15} {record.msg}"
        return message

    logging.basicConfig(level=logging.INFO)
@@ -134,6 +134,12 @@ def init_logging():
    console_handler.setFormatter(formatter)
    logging.getLogger().addHandler(console_handler)

    if log_file is not None:
        # File handler
        file_handler = logging.FileHandler(log_file)
        file_handler.setFormatter(formatter)
        logging.getLogger().addHandler(file_handler)


def format_big_number(num, precision=0):
    suffixes = ["", "K", "M", "B", "T", "Q"]
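The change above lets each process write to its own log file and tags every record with the emitting PID, which keeps interleaved multi-process logs readable. A minimal, self-contained sketch of the same idea (the repository's helper wires its formatter slightly differently; this is illustrative only):

import logging
import os
from datetime import datetime


def init_logging(log_file=None):
    # Tag every record with the PID so logs from several processes stay distinguishable.
    class PidFormatter(logging.Formatter):
        def format(self, record):
            dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            fnameline = f"{record.pathname}:{record.lineno}"
            return f"{record.levelname} [PID: {os.getpid()}] {dt} {fnameline[-15:]:>15} {record.getMessage()}"

    root = logging.getLogger()
    root.setLevel(logging.INFO)

    console_handler = logging.StreamHandler()
    console_handler.setFormatter(PidFormatter())
    root.addHandler(console_handler)

    if log_file is not None:
        file_handler = logging.FileHandler(log_file)  # one file per process
        file_handler.setFormatter(PidFormatter())
        root.addHandler(file_handler)


if __name__ == "__main__":
    init_logging(log_file="actor.log")
    logging.info("hello from the actor process")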
@@ -22,3 +22,9 @@ env:
  wrapper:
    joint_masking_action_space: null
    delta_action: null

  video_record:
    enabled: false
    record_dir: maniskill_videos
    trajectory_name: trajectory
    fps: ${fps}
@@ -28,4 +28,3 @@ env:
  reward_classifier:
    pretrained_path: outputs/classifier/13-02-random-sample-resnet10-frozen/checkpoints/best/pretrained_model
    config_path: lerobot/configs/policy/hilserl_classifier.yaml
@@ -8,14 +8,12 @@
# env.gym.obs_type=environment_state_agent_pos \

seed: 1
# dataset_repo_id: null
dataset_repo_id: "AdilZtn/Maniskill-Pushcube-demonstration-medium"

training:
  # Offline training dataloader
  num_workers: 4

  # batch_size: 256
  batch_size: 512
  grad_clip_norm: 10.0
  lr: 3e-4
@@ -113,4 +111,7 @@ policy:
  actor_learner_config:
    learner_host: "127.0.0.1"
    learner_port: 50051
    policy_parameters_push_frequency: 15
    policy_parameters_push_frequency: 1
    concurrency:
      actor: 'processes'
      learner: 'processes'
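The new `concurrency` block is what the actor and learner scripts later in this diff read to decide between threads and processes. A hedged sketch of that selection (helper name and wiring are illustrative; the real scripts use `use_threads()` and a `concurrency_entity` variable):

from omegaconf import DictConfig


def use_threads(cfg: DictConfig) -> bool:
    # Mirrors the check performed in actor.py/learner.py in this diff.
    return cfg.actor_learner_config.concurrency.actor == "threads"


def make_worker(cfg: DictConfig, target, args):
    if use_threads(cfg):
        from threading import Thread as Entity
    else:
        # torch.multiprocessing mirrors the multiprocessing API and is safer
        # for code that passes CUDA tensors between workers.
        from torch.multiprocessing import Process as Entity
    return Entity(target=target, args=args, daemon=True)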
@@ -13,22 +13,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import logging
import pickle
import queue
from statistics import mean, quantiles
import signal
from functools import lru_cache
from lerobot.scripts.server.utils import setup_process_handlers

# from lerobot.scripts.eval import eval_policy
from threading import Thread

import grpc
import hydra
import torch
from omegaconf import DictConfig
from torch import nn
import time

# TODO: Remove the import of maniskill
# from lerobot.common.envs.factory import make_maniskill_env
@@ -47,157 +44,184 @@ from lerobot.scripts.server.buffer import (
    Transition,
    move_state_dict_to_device,
    move_transition_to_device,
    bytes_buffer_size,
    python_object_to_bytes,
    transitions_to_bytes,
    bytes_to_state_dict,
)
from lerobot.scripts.server.network_utils import (
    receive_bytes_in_chunks,
    send_bytes_in_chunks,
)
from lerobot.scripts.server.gym_manipulator import get_classifier, make_robot_env
from lerobot.scripts.server import learner_service

from threading import Event
from torch.multiprocessing import Queue, Event
from queue import Empty

logging.basicConfig(level=logging.INFO)
from lerobot.common.utils.utils import init_logging

parameters_queue = queue.Queue(maxsize=1)
message_queue = queue.Queue(maxsize=1_000_000)
from lerobot.scripts.server.utils import get_last_item_from_queue

ACTOR_SHUTDOWN_TIMEOUT = 30
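The `send_bytes_in_chunks` / `receive_bytes_in_chunks` helpers imported here keep individual gRPC messages under the configured size limit. A hedged sketch of the chunking side (the real helper's signature may differ; the 2 MB chunk size matches `learner_service.CHUNK_SIZE` in this diff, and the `transfer_state`/`data` fields follow the proto changes further down):

import logging

CHUNK_SIZE = 2 * 1024 * 1024  # 2 MB


def send_bytes_in_chunks(payload: bytes, message_cls, log_prefix: str = ""):
    # TransferState values as declared in hilserl.proto: BEGIN=1, MIDDLE=2, END=3.
    total = len(payload)
    for start in range(0, total, CHUNK_SIZE):
        chunk = payload[start : start + CHUNK_SIZE]
        if start + len(chunk) >= total:
            state = 3  # TRANSFER_END closes the payload (also used for single-chunk payloads here)
        elif start == 0:
            state = 1  # TRANSFER_BEGIN tells the receiver to reset its buffer
        else:
            state = 2  # TRANSFER_MIDDLE
        logging.debug(f"{log_prefix} sending bytes {start}-{start + len(chunk)} of {total}")
        yield message_cls(transfer_state=state, data=chunk)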
class ActorInformation:
    """
    This helper class is used to differentiate between two types of messages that are placed in the same queue during streaming:

    - **Transition Data:** Contains experience tuples (observation, action, reward, next observation) collected during interaction.
    - **Interaction Messages:** Encapsulates statistics related to the interaction process.

    Attributes:
        transition (Optional): Transition data to be sent to the learner.
        interaction_message (Optional): Interaction message providing additional statistics for logging.
    """

    def __init__(self, transition=None, interaction_message=None):
        self.transition = transition
        self.interaction_message = interaction_message
def receive_policy(
    learner_client: hilserl_pb2_grpc.LearnerServiceStub,
    shutdown_event: Event,
    parameters_queue: queue.Queue,
    cfg: DictConfig,
    parameters_queue: Queue,
    shutdown_event: any,  # Event,
    learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None,
    grpc_channel: grpc.Channel | None = None,
):
    logging.info("[ACTOR] Start receiving parameters from the Learner")
    bytes_buffer = io.BytesIO()
    step = 0

    if not use_threads(cfg):
        # Setup process handlers to handle shutdown signal
        # But use shutdown event from the main process
        setup_process_handlers(False)

    if grpc_channel is None or learner_client is None:
        learner_client, grpc_channel = learner_service_client(
            host=cfg.actor_learner_config.learner_host,
            port=cfg.actor_learner_config.learner_port,
        )

    try:
        for model_update in learner_client.StreamParameters(hilserl_pb2.Empty()):
            if shutdown_event.is_set():
                logging.info("[ACTOR] Shutting down policy streaming receiver")
                return hilserl_pb2.Empty()

            if model_update.transfer_state == hilserl_pb2.TransferState.TRANSFER_BEGIN:
                bytes_buffer.seek(0)
                bytes_buffer.truncate(0)
                bytes_buffer.write(model_update.parameter_bytes)
                logging.info("Received model update at step 0")
                step = 0
                continue
            elif (
                model_update.transfer_state == hilserl_pb2.TransferState.TRANSFER_MIDDLE
            ):
                bytes_buffer.write(model_update.parameter_bytes)
                step += 1
                logging.info(f"Received model update at step {step}")
            elif model_update.transfer_state == hilserl_pb2.TransferState.TRANSFER_END:
                bytes_buffer.write(model_update.parameter_bytes)
                logging.info(
                    f"Received model update at step end size {bytes_buffer_size(bytes_buffer)}"
                )

                state_dict = torch.load(bytes_buffer)

                bytes_buffer.seek(0)
                bytes_buffer.truncate(0)
                step = 0

                logging.info("Model updated")

                parameters_queue.put(state_dict)

        iterator = learner_client.StreamParameters(hilserl_pb2.Empty())
        receive_bytes_in_chunks(
            iterator,
            parameters_queue,
            shutdown_event,
            log_prefix="[ACTOR] parameters",
        )
    except grpc.RpcError as e:
        logging.error(f"[ACTOR] gRPC error: {e}")

    if not use_threads(cfg):
        grpc_channel.close()
    logging.info("[ACTOR] Received policy loop stopped")
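The removed body above reassembled a parameter payload from the TRANSFER_BEGIN/MIDDLE/END stream inline; the new code delegates that to `receive_bytes_in_chunks`. A condensed, hedged sketch of the reassembly state machine that helper likely implements (the call site in this diff passes the iterator, the output queue, the shutdown event, and a log prefix in that order; exact behavior of the real helper may differ):

import io
import logging


def receive_bytes_in_chunks(iterator, output_queue, shutdown_event, log_prefix: str = ""):
    """Rebuild complete payloads from a stream of (transfer_state, data) messages."""
    buffer = io.BytesIO()
    for message in iterator:
        if shutdown_event.is_set():
            return
        if message.transfer_state == 1:    # TRANSFER_BEGIN: start a fresh payload
            buffer.seek(0)
            buffer.truncate(0)
            buffer.write(message.data)
        elif message.transfer_state == 2:  # TRANSFER_MIDDLE: keep appending
            buffer.write(message.data)
        elif message.transfer_state == 3:  # TRANSFER_END: payload complete
            buffer.write(message.data)
            logging.debug(f"{log_prefix} received {buffer.tell()} bytes")
            output_queue.put(buffer.getvalue())  # hand raw bytes to the consumer
            buffer.seek(0)
            buffer.truncate(0)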
def transitions_stream(
    shutdown_event: Event, transitions_queue: Queue
) -> hilserl_pb2.Empty:
    while not shutdown_event.is_set():
        try:
            message = transitions_queue.get(block=True, timeout=5)
        except Empty:
            logging.debug("[ACTOR] Transition queue is empty")
            continue

        yield from send_bytes_in_chunks(
            message, hilserl_pb2.Transition, log_prefix="[ACTOR] Send transitions"
        )

    return hilserl_pb2.Empty()


def transitions_stream(shutdown_event: Event, message_queue: queue.Queue):
def interactions_stream(
    shutdown_event: any,  # Event,
    interactions_queue: Queue,
) -> hilserl_pb2.Empty:
    while not shutdown_event.is_set():
        try:
            message = message_queue.get(block=True, timeout=5)
        except queue.Empty:
            logging.debug("[ACTOR] Transition queue is empty")
            message = interactions_queue.get(block=True, timeout=5)
        except Empty:
            logging.debug("[ACTOR] Interaction queue is empty")
            continue

        if message.transition is not None:
            transition_to_send_to_learner: list[Transition] = [
                move_transition_to_device(transition=T, device="cpu")
                for T in message.transition
            ]
            # Check for NaNs in transitions before sending to learner
            for transition in transition_to_send_to_learner:
                for key, value in transition["state"].items():
                    if torch.isnan(value).any():
                        logging.warning(f"Found NaN values in transition {key}")
            buf = io.BytesIO()
            torch.save(transition_to_send_to_learner, buf)
            transition_bytes = buf.getvalue()

            transition_message = hilserl_pb2.Transition(
                transition_bytes=transition_bytes
            )

            response = hilserl_pb2.ActorInformation(transition=transition_message)

        elif message.interaction_message is not None:
            content = hilserl_pb2.InteractionMessage(
                interaction_message_bytes=pickle.dumps(message.interaction_message)
            )
            response = hilserl_pb2.ActorInformation(interaction_message=content)

        yield response
        yield from send_bytes_in_chunks(
            message,
            hilserl_pb2.InteractionMessage,
            log_prefix="[ACTOR] Send interactions",
        )

    return hilserl_pb2.Empty()
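Both streams follow the same pattern: a generator drains a queue of pre-serialized payloads and yields protobuf chunks, and the generator itself is handed to a client-streaming RPC. A simplified, self-contained usage sketch (transfer_state bookkeeping is omitted here; names of the stub and queues follow the diff, the helper name is illustrative):

from queue import Empty


def stream_queue(shutdown_event, source_queue, message_cls, chunk_size=2 * 1024 * 1024):
    # Lazily pull serialized payloads from the queue and emit them as protobuf chunks;
    # gRPC consumes this generator as the request iterator of a stream->unary RPC.
    while not shutdown_event.is_set():
        try:
            payload = source_queue.get(block=True, timeout=5)
        except Empty:
            continue
        for start in range(0, len(payload), chunk_size):
            yield message_cls(data=payload[start : start + chunk_size])


# Illustrative call site, mirroring send_transitions() below:
# learner_client.SendTransitions(stream_queue(shutdown_event, transitions_queue, hilserl_pb2.Transition))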
def send_transitions(
|
||||
learner_client: hilserl_pb2_grpc.LearnerServiceStub,
|
||||
shutdown_event: Event,
|
||||
message_queue: queue.Queue,
|
||||
):
|
||||
cfg: DictConfig,
|
||||
transitions_queue: Queue,
|
||||
shutdown_event: any, # Event,
|
||||
learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None,
|
||||
grpc_channel: grpc.Channel | None = None,
|
||||
) -> hilserl_pb2.Empty:
|
||||
"""
|
||||
Streams data from the actor to the learner.
|
||||
Sends transitions to the learner.
|
||||
|
||||
This function continuously retrieves messages from the queue and processes them based on their type:
|
||||
This function continuously retrieves messages from the queue and processes:
|
||||
|
||||
- **Transition Data:**
|
||||
- A batch of transitions (observation, action, reward, next observation) is collected.
|
||||
- Transitions are moved to the CPU and serialized using PyTorch.
|
||||
- The serialized data is wrapped in a `hilserl_pb2.Transition` message and sent to the learner.
|
||||
|
||||
- **Interaction Messages:**
|
||||
- Contains useful statistics about episodic rewards and policy timings.
|
||||
- The message is serialized using `pickle` and sent to the learner.
|
||||
|
||||
Yields:
|
||||
hilserl_pb2.ActorInformation: The response message containing either transition data or an interaction message.
|
||||
"""
|
||||
|
||||
if not use_threads(cfg):
|
||||
# Setup process handlers to handle shutdown signal
|
||||
# But use shutdown event from the main process
|
||||
setup_process_handlers(False)
|
||||
|
||||
if grpc_channel is None or learner_client is None:
|
||||
learner_client, grpc_channel = learner_service_client(
|
||||
host=cfg.actor_learner_config.learner_host,
|
||||
port=cfg.actor_learner_config.learner_port,
|
||||
)
|
||||
|
||||
try:
|
||||
learner_client.ReceiveTransitions(
|
||||
transitions_stream(shutdown_event, message_queue)
|
||||
learner_client.SendTransitions(
|
||||
transitions_stream(shutdown_event, transitions_queue)
|
||||
)
|
||||
except grpc.RpcError as e:
|
||||
logging.error(f"[ACTOR] gRPC error: {e}")
|
||||
|
||||
logging.info("[ACTOR] Finished streaming transitions")
|
||||
|
||||
if not use_threads(cfg):
|
||||
grpc_channel.close()
|
||||
logging.info("[ACTOR] Transitions process stopped")
|
||||
|
||||
|
||||
def send_interactions(
|
||||
cfg: DictConfig,
|
||||
interactions_queue: Queue,
|
||||
shutdown_event: any, # Event,
|
||||
learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None,
|
||||
grpc_channel: grpc.Channel | None = None,
|
||||
) -> hilserl_pb2.Empty:
|
||||
"""
|
||||
Sends interactions to the learner.
|
||||
|
||||
This function continuously retrieves messages from the queue and processes:
|
||||
|
||||
- **Interaction Messages:**
|
||||
- Contains useful statistics about episodic rewards and policy timings.
|
||||
- The message is serialized using `pickle` and sent to the learner.
|
||||
"""
|
||||
|
||||
if not use_threads(cfg):
|
||||
# Setup process handlers to handle shutdown signal
|
||||
# But use shutdown event from the main process
|
||||
setup_process_handlers(False)
|
||||
|
||||
if grpc_channel is None or learner_client is None:
|
||||
learner_client, grpc_channel = learner_service_client(
|
||||
host=cfg.actor_learner_config.learner_host,
|
||||
port=cfg.actor_learner_config.learner_port,
|
||||
)
|
||||
|
||||
try:
|
||||
learner_client.SendInteractions(
|
||||
interactions_stream(shutdown_event, interactions_queue)
|
||||
)
|
||||
except grpc.RpcError as e:
|
||||
logging.error(f"[ACTOR] gRPC error: {e}")
|
||||
|
||||
logging.info("[ACTOR] Finished streaming interactions")
|
||||
|
||||
if not use_threads(cfg):
|
||||
grpc_channel.close()
|
||||
logging.info("[ACTOR] Interactions process stopped")
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def learner_service_client(
|
||||
|
@@ -217,7 +241,7 @@ def learner_service_client(
            {
                "name": [{}],  # Applies to ALL methods in ALL services
                "retryPolicy": {
                    "maxAttempts": 7,  # Max retries (total attempts = 5)
                    "maxAttempts": 5,  # Max retries (total attempts = 5)
                    "initialBackoff": "0.1s",  # First retry after 0.1s
                    "maxBackoff": "2s",  # Max wait time between retries
                    "backoffMultiplier": 2,  # Exponential backoff factor
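For context, this retry policy is part of a gRPC service config passed as a channel option. A hedged sketch of how such a channel can be created (the option keys and JSON layout follow gRPC's documented service-config format; the address is illustrative):

import json

import grpc

service_config = {
    "methodConfig": [
        {
            "name": [{}],  # applies to all methods of all services
            "retryPolicy": {
                "maxAttempts": 5,
                "initialBackoff": "0.1s",
                "maxBackoff": "2s",
                "backoffMultiplier": 2,
                "retryableStatusCodes": ["UNAVAILABLE"],
            },
        }
    ]
}

channel = grpc.insecure_channel(
    "127.0.0.1:50051",
    options=[
        ("grpc.enable_retries", 1),
        ("grpc.service_config", json.dumps(service_config)),
    ],
)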
@@ -242,20 +266,27 @@ def learner_service_client(
        ],
    )
    stub = hilserl_pb2_grpc.LearnerServiceStub(channel)
    logging.info("[LEARNER] Learner service client created")
    logging.info("[ACTOR] Learner service client created")
    return stub, channel


def update_policy_parameters(policy: SACPolicy, parameters_queue: queue.Queue, device):
def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device):
    if not parameters_queue.empty():
        logging.info("[ACTOR] Load new parameters from Learner.")
        state_dict = parameters_queue.get()
        bytes_state_dict = get_last_item_from_queue(parameters_queue)
        state_dict = bytes_to_state_dict(bytes_state_dict)
        state_dict = move_state_dict_to_device(state_dict, device=device)
        policy.load_state_dict(state_dict)


def act_with_policy(
    cfg: DictConfig, robot: Robot, reward_classifier: nn.Module, shutdown_event: Event
    cfg: DictConfig,
    robot: Robot,
    reward_classifier: nn.Module,
    shutdown_event: any,  # Event,
    parameters_queue: Queue,
    transitions_queue: Queue,
    interactions_queue: Queue,
):
    """
    Executes policy interaction within the environment.
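`get_last_item_from_queue` (imported from `lerobot.scripts.server.utils`) is used above so the actor always loads the freshest parameters instead of working through a stale backlog. Its implementation is not shown in this diff; a plausible sketch under that assumption:

from queue import Empty


def get_last_item_from_queue(q):
    """Block for the first item, then drain anything newer and keep only the last one."""
    item = q.get()  # wait for at least one payload
    try:
        while True:
            item = q.get_nowait()  # discard older payloads in favor of newer ones
    except Empty:
        pass
    return item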
@ -317,7 +348,7 @@ def act_with_policy(
|
|||
|
||||
for interaction_step in range(cfg.training.online_steps):
|
||||
if shutdown_event.is_set():
|
||||
logging.info("[ACTOR] Shutdown signal received. Exiting...")
|
||||
logging.info("[ACTOR] Shutting down act_with_policy")
|
||||
return
|
||||
|
||||
if interaction_step >= cfg.training.online_step_before_learning:
|
||||
|
@ -394,10 +425,9 @@ def act_with_policy(
|
|||
)
|
||||
|
||||
if len(list_transition_to_send_to_learner) > 0:
|
||||
send_transitions_in_chunks(
|
||||
push_transitions_to_transport_queue(
|
||||
transitions=list_transition_to_send_to_learner,
|
||||
message_queue=message_queue,
|
||||
chunk_size=4,
|
||||
transitions_queue=transitions_queue,
|
||||
)
|
||||
list_transition_to_send_to_learner = []
|
||||
|
||||
|
@ -405,9 +435,9 @@ def act_with_policy(
|
|||
list_policy_time.clear()
|
||||
|
||||
# Send episodic reward to the learner
|
||||
message_queue.put(
|
||||
ActorInformation(
|
||||
interaction_message={
|
||||
interactions_queue.put(
|
||||
python_object_to_bytes(
|
||||
{
|
||||
"Episodic reward": sum_reward_episode,
|
||||
"Interaction step": interaction_step,
|
||||
"Episode intervention": int(episode_intervention),
|
||||
|
@@ -420,7 +450,7 @@ def act_with_policy(
            obs, info = online_env.reset()


def send_transitions_in_chunks(transitions: list, message_queue, chunk_size: int = 100):
def push_transitions_to_transport_queue(transitions: list, transitions_queue):
    """Send transitions to learner in smaller chunks to avoid network issues.

    Args:
@@ -428,10 +458,16 @@ def send_transitions_in_chunks(transitions: list, message_queue, chunk_size: int
        message_queue: Queue to send messages to learner
        chunk_size: Size of each chunk to send
    """
    for i in range(0, len(transitions), chunk_size):
        chunk = transitions[i : i + chunk_size]
        logging.debug(f"[ACTOR] Sending chunk of {len(chunk)} transitions to Learner.")
        message_queue.put(ActorInformation(transition=chunk))
    transition_to_send_to_learner = []
    for transition in transitions:
        tr = move_transition_to_device(transition=transition, device="cpu")
        for key, value in tr["state"].items():
            if torch.isnan(value).any():
                logging.warning(f"Found NaN values in transition {key}")

        transition_to_send_to_learner.append(tr)

    transitions_queue.put(transitions_to_bytes(transition_to_send_to_learner))


def get_frequency_stats(list_policy_time: list[float]) -> dict[str, float]:
@@ -458,39 +494,96 @@ def log_policy_frequency_issue(
    )


def establish_learner_connection(
    stub,
    shutdown_event: any,  # Event,
    attempts=30,
):
    for _ in range(attempts):
        if shutdown_event.is_set():
            logging.info("[ACTOR] Shutting down establish_learner_connection")
            return False

        # Force a connection attempt and check state
        try:
            logging.info("[ACTOR] Send ready message to Learner")
            if stub.Ready(hilserl_pb2.Empty()) == hilserl_pb2.Empty():
                return True
        except grpc.RpcError as e:
            logging.error(f"[ACTOR] Waiting for Learner to be ready... {e}")
            time.sleep(2)
    return False


def use_threads(cfg: DictConfig) -> bool:
    return cfg.actor_learner_config.concurrency.actor == "threads"

@hydra.main(version_base="1.2", config_name="default", config_path="../../configs")
def actor_cli(cfg: dict):
    if not use_threads(cfg):
        import torch.multiprocessing as mp

        mp.set_start_method("spawn")

    init_logging(log_file="actor.log")
    robot = make_robot(cfg=cfg.robot)

    shutdown_event = Event()

    # Define signal handler
    def signal_handler(signum, frame):
        logging.info("Shutdown signal received. Cleaning up...")
        shutdown_event.set()

    signal.signal(signal.SIGINT, signal_handler)  # Ctrl+C
    signal.signal(signal.SIGTERM, signal_handler)  # Termination request (kill)
    signal.signal(signal.SIGHUP, signal_handler)  # Terminal closed/Hangup
    signal.signal(signal.SIGQUIT, signal_handler)  # Ctrl+\
    shutdown_event = setup_process_handlers(use_threads(cfg))

    learner_client, grpc_channel = learner_service_client(
        host=cfg.actor_learner_config.learner_host,
        port=cfg.actor_learner_config.learner_port,
    )

    receive_policy_thread = Thread(
    logging.info("[ACTOR] Establishing connection with Learner")
    if not establish_learner_connection(learner_client, shutdown_event):
        logging.error("[ACTOR] Failed to establish connection with Learner")
        return

    if not use_threads(cfg):
        # If we use multithreading, we can reuse the channel
        grpc_channel.close()
        grpc_channel = None

    logging.info("[ACTOR] Connection with Learner established")

    parameters_queue = Queue()
    transitions_queue = Queue()
    interactions_queue = Queue()

    concurrency_entity = None
    if use_threads(cfg):
        from threading import Thread

        concurrency_entity = Thread
    else:
        from multiprocessing import Process

        concurrency_entity = Process

    receive_policy_process = concurrency_entity(
        target=receive_policy,
        args=(learner_client, shutdown_event, parameters_queue),
        args=(cfg, parameters_queue, shutdown_event, grpc_channel),
        daemon=True,
    )

    transitions_thread = Thread(
    transitions_process = concurrency_entity(
        target=send_transitions,
        args=(learner_client, shutdown_event, message_queue),
        args=(cfg, transitions_queue, shutdown_event, grpc_channel),
        daemon=True,
    )

    interactions_process = concurrency_entity(
        target=send_interactions,
        args=(cfg, interactions_queue, shutdown_event, grpc_channel),
        daemon=True,
    )

    transitions_process.start()
    interactions_process.start()
    receive_policy_process.start()

    # HACK: FOR MANISKILL we do not have a reward classifier
    # TODO: Remove this once we merge into main
    reward_classifier = None
@@ -503,26 +596,35 @@ def actor_cli(cfg: dict):
            config_path=cfg.env.reward_classifier.config_path,
        )

    policy_thread = Thread(
        target=act_with_policy,
        daemon=True,
        args=(cfg, robot, reward_classifier, shutdown_event),
    act_with_policy(
        cfg,
        robot,
        reward_classifier,
        shutdown_event,
        parameters_queue,
        transitions_queue,
        interactions_queue,
    )
    logging.info("[ACTOR] Policy process joined")

    transitions_thread.start()
    policy_thread.start()
    receive_policy_thread.start()
    logging.info("[ACTOR] Closing queues")
    transitions_queue.close()
    interactions_queue.close()
    parameters_queue.close()

    shutdown_event.wait()
    logging.info("[ACTOR] Shutdown event received")
    grpc_channel.close()
    transitions_process.join()
    logging.info("[ACTOR] Transitions process joined")
    interactions_process.join()
    logging.info("[ACTOR] Interactions process joined")
    receive_policy_process.join()
    logging.info("[ACTOR] Receive policy process joined")

    policy_thread.join()
    logging.info("[ACTOR] Policy thread joined")
    transitions_thread.join()
    logging.info("[ACTOR] Transitions thread joined")
    receive_policy_thread.join()
    logging.info("[ACTOR] Receive policy thread joined")
    logging.info("[ACTOR] join queues")
    transitions_queue.cancel_join_thread()
    interactions_queue.cancel_join_thread()
    parameters_queue.cancel_join_thread()

    logging.info("[ACTOR] queues closed")


if __name__ == "__main__":
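The actor entry point above now runs the policy loop in the main process and delegates networking to daemon workers; with the `spawn` start method, the queues must be created in the parent and closed (plus `cancel_join_thread()`) before exit so shutdown does not hang on queue feeder threads. A condensed, hedged sketch of that lifecycle (worker body is illustrative):

import logging

import torch.multiprocessing as mp


def worker(shutdown_event, out_queue):
    # Daemon workers poll the shared event instead of installing their own signal handlers.
    while not shutdown_event.is_set():
        out_queue.put(b"payload")
        shutdown_event.wait(timeout=1.0)


if __name__ == "__main__":
    mp.set_start_method("spawn")  # required for CUDA-safe child processes
    shutdown_event = mp.Event()
    out_queue = mp.Queue()

    proc = mp.Process(target=worker, args=(shutdown_event, out_queue), daemon=True)
    proc.start()

    # ... main-process work (the policy loop in actor.py) happens here ...
    shutdown_event.set()

    out_queue.close()
    out_queue.cancel_join_thread()  # do not block process exit on the queue's feeder thread
    proc.join()
    logging.info("clean shutdown")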
@@ -23,6 +23,7 @@ from tqdm import tqdm

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
import os
import pickle


class Transition(TypedDict):

@@ -91,7 +92,7 @@ def move_transition_to_device(
    return transition


def move_state_dict_to_device(state_dict, device):
def move_state_dict_to_device(state_dict, device="cpu"):
    """
    Recursively move all tensors in a (potentially) nested
    dict/list/tuple structure to the CPU.

@@ -111,20 +112,41 @@ def move_state_dict_to_device(state_dict, device):
    return state_dict


def state_to_bytes(state_dict: dict[str, torch.Tensor]) -> io.BytesIO:
def state_to_bytes(state_dict: dict[str, torch.Tensor]) -> bytes:
    """Convert model state dict to flat array for transmission"""
    buffer = io.BytesIO()

    torch.save(state_dict, buffer)

    return buffer
    return buffer.getvalue()


def bytes_buffer_size(buffer: io.BytesIO) -> int:
    buffer.seek(0, io.SEEK_END)
    result = buffer.tell()
def bytes_to_state_dict(buffer: bytes) -> dict[str, torch.Tensor]:
    buffer = io.BytesIO(buffer)
    buffer.seek(0)
    return result
    return torch.load(buffer)


def python_object_to_bytes(python_object: Any) -> bytes:
    return pickle.dumps(python_object)


def bytes_to_python_object(buffer: bytes) -> Any:
    buffer = io.BytesIO(buffer)
    buffer.seek(0)
    return pickle.load(buffer)


def bytes_to_transitions(buffer: bytes) -> list[Transition]:
    buffer = io.BytesIO(buffer)
    buffer.seek(0)
    return torch.load(buffer)


def transitions_to_bytes(transitions: list[Transition]) -> bytes:
    buffer = io.BytesIO()
    torch.save(transitions, buffer)
    return buffer.getvalue()


def random_crop_vectorized(images: torch.Tensor, output_size: tuple) -> torch.Tensor:
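These helpers standardize how data crosses process and network boundaries: `torch.save`/`torch.load` on an in-memory buffer for tensor structures, `pickle` for plain Python objects. A small usage sketch of the round trips (tensor shapes and values are illustrative):

import io
import pickle

import torch

# Tensor structures: serialize with torch.save into an in-memory buffer.
state_dict = {"actor.weight": torch.randn(4, 4)}
buffer = io.BytesIO()
torch.save(state_dict, buffer)
payload = buffer.getvalue()                 # what state_to_bytes() returns

restored = torch.load(io.BytesIO(payload))  # what bytes_to_state_dict() does
assert torch.equal(restored["actor.weight"], state_dict["actor.weight"])

# Plain Python objects (e.g. interaction statistics): round-trip through pickle.
stats = {"Episodic reward": 12.5, "Interaction step": 1000}
assert pickle.loads(pickle.dumps(stats)) == stats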
@@ -24,14 +24,9 @@ service LearnerService {
  // Actor -> Learner to store transitions
  rpc SendInteractionMessage(InteractionMessage) returns (Empty);
  rpc StreamParameters(Empty) returns (stream Parameters);
  rpc ReceiveTransitions(stream ActorInformation) returns (Empty);
}

message ActorInformation {
  oneof data {
    Transition transition = 1;
    InteractionMessage interaction_message = 2;
  }
  rpc SendTransitions(stream Transition) returns (Empty);
  rpc SendInteractions(stream InteractionMessage) returns (Empty);
  rpc Ready(Empty) returns (Empty);
}

enum TransferState {

@@ -43,16 +38,18 @@ enum TransferState {

// Messages
message Transition {
  bytes transition_bytes = 1;
  TransferState transfer_state = 1;
  bytes data = 2;
}

message Parameters {
  TransferState transfer_state = 1;
  bytes parameter_bytes = 2;
  bytes data = 2;
}

message InteractionMessage {
  bytes interaction_message_bytes = 1;
  TransferState transfer_state = 1;
  bytes data = 2;
}

message Empty {}
@ -24,25 +24,23 @@ _sym_db = _symbol_database.Default()
|
|||
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rhilserl.proto\x12\x08hil_serl\"\x83\x01\n\x10\x41\x63torInformation\x12*\n\ntransition\x18\x01 \x01(\x0b\x32\x14.hil_serl.TransitionH\x00\x12;\n\x13interaction_message\x18\x02 \x01(\x0b\x32\x1c.hil_serl.InteractionMessageH\x00\x42\x06\n\x04\x64\x61ta\"&\n\nTransition\x12\x18\n\x10transition_bytes\x18\x01 \x01(\x0c\"V\n\nParameters\x12/\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x17.hil_serl.TransferState\x12\x17\n\x0fparameter_bytes\x18\x02 \x01(\x0c\"7\n\x12InteractionMessage\x12!\n\x19interaction_message_bytes\x18\x01 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\xdb\x01\n\x0eLearnerService\x12G\n\x16SendInteractionMessage\x12\x1c.hil_serl.InteractionMessage\x1a\x0f.hil_serl.Empty\x12;\n\x10StreamParameters\x12\x0f.hil_serl.Empty\x1a\x14.hil_serl.Parameters0\x01\x12\x43\n\x12ReceiveTransitions\x12\x1a.hil_serl.ActorInformation\x1a\x0f.hil_serl.Empty(\x01\x62\x06proto3')
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rhilserl.proto\x12\x08hil_serl\"K\n\nTransition\x12/\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x17.hil_serl.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"K\n\nParameters\x12/\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x17.hil_serl.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"S\n\x12InteractionMessage\x12/\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x17.hil_serl.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\xc2\x02\n\x0eLearnerService\x12G\n\x16SendInteractionMessage\x12\x1c.hil_serl.InteractionMessage\x1a\x0f.hil_serl.Empty\x12;\n\x10StreamParameters\x12\x0f.hil_serl.Empty\x1a\x14.hil_serl.Parameters0\x01\x12:\n\x0fSendTransitions\x12\x14.hil_serl.Transition\x1a\x0f.hil_serl.Empty(\x01\x12\x43\n\x10SendInteractions\x12\x1c.hil_serl.InteractionMessage\x1a\x0f.hil_serl.Empty(\x01\x12)\n\x05Ready\x12\x0f.hil_serl.Empty\x1a\x0f.hil_serl.Emptyb\x06proto3')
|
||||
|
||||
_globals = globals()
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'hilserl_pb2', _globals)
|
||||
if not _descriptor._USE_C_DESCRIPTORS:
|
||||
DESCRIPTOR._loaded_options = None
|
||||
_globals['_TRANSFERSTATE']._serialized_start=355
|
||||
_globals['_TRANSFERSTATE']._serialized_end=451
|
||||
_globals['_ACTORINFORMATION']._serialized_start=28
|
||||
_globals['_ACTORINFORMATION']._serialized_end=159
|
||||
_globals['_TRANSITION']._serialized_start=161
|
||||
_globals['_TRANSITION']._serialized_end=199
|
||||
_globals['_PARAMETERS']._serialized_start=201
|
||||
_globals['_PARAMETERS']._serialized_end=287
|
||||
_globals['_INTERACTIONMESSAGE']._serialized_start=289
|
||||
_globals['_INTERACTIONMESSAGE']._serialized_end=344
|
||||
_globals['_EMPTY']._serialized_start=346
|
||||
_globals['_EMPTY']._serialized_end=353
|
||||
_globals['_LEARNERSERVICE']._serialized_start=454
|
||||
_globals['_LEARNERSERVICE']._serialized_end=673
|
||||
_globals['_TRANSFERSTATE']._serialized_start=275
|
||||
_globals['_TRANSFERSTATE']._serialized_end=371
|
||||
_globals['_TRANSITION']._serialized_start=27
|
||||
_globals['_TRANSITION']._serialized_end=102
|
||||
_globals['_PARAMETERS']._serialized_start=104
|
||||
_globals['_PARAMETERS']._serialized_end=179
|
||||
_globals['_INTERACTIONMESSAGE']._serialized_start=181
|
||||
_globals['_INTERACTIONMESSAGE']._serialized_end=264
|
||||
_globals['_EMPTY']._serialized_start=266
|
||||
_globals['_EMPTY']._serialized_end=273
|
||||
_globals['_LEARNERSERVICE']._serialized_start=374
|
||||
_globals['_LEARNERSERVICE']._serialized_end=696
|
||||
# @@protoc_insertion_point(module_scope)
|
||||
|
|
|
@ -46,9 +46,19 @@ class LearnerServiceStub(object):
|
|||
request_serializer=hilserl__pb2.Empty.SerializeToString,
|
||||
response_deserializer=hilserl__pb2.Parameters.FromString,
|
||||
_registered_method=True)
|
||||
self.ReceiveTransitions = channel.stream_unary(
|
||||
'/hil_serl.LearnerService/ReceiveTransitions',
|
||||
request_serializer=hilserl__pb2.ActorInformation.SerializeToString,
|
||||
self.SendTransitions = channel.stream_unary(
|
||||
'/hil_serl.LearnerService/SendTransitions',
|
||||
request_serializer=hilserl__pb2.Transition.SerializeToString,
|
||||
response_deserializer=hilserl__pb2.Empty.FromString,
|
||||
_registered_method=True)
|
||||
self.SendInteractions = channel.stream_unary(
|
||||
'/hil_serl.LearnerService/SendInteractions',
|
||||
request_serializer=hilserl__pb2.InteractionMessage.SerializeToString,
|
||||
response_deserializer=hilserl__pb2.Empty.FromString,
|
||||
_registered_method=True)
|
||||
self.Ready = channel.unary_unary(
|
||||
'/hil_serl.LearnerService/Ready',
|
||||
request_serializer=hilserl__pb2.Empty.SerializeToString,
|
||||
response_deserializer=hilserl__pb2.Empty.FromString,
|
||||
_registered_method=True)
|
||||
|
||||
|
@ -71,7 +81,19 @@ class LearnerServiceServicer(object):
|
|||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def ReceiveTransitions(self, request_iterator, context):
|
||||
def SendTransitions(self, request_iterator, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def SendInteractions(self, request_iterator, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def Ready(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
|
@ -90,9 +112,19 @@ def add_LearnerServiceServicer_to_server(servicer, server):
|
|||
request_deserializer=hilserl__pb2.Empty.FromString,
|
||||
response_serializer=hilserl__pb2.Parameters.SerializeToString,
|
||||
),
|
||||
'ReceiveTransitions': grpc.stream_unary_rpc_method_handler(
|
||||
servicer.ReceiveTransitions,
|
||||
request_deserializer=hilserl__pb2.ActorInformation.FromString,
|
||||
'SendTransitions': grpc.stream_unary_rpc_method_handler(
|
||||
servicer.SendTransitions,
|
||||
request_deserializer=hilserl__pb2.Transition.FromString,
|
||||
response_serializer=hilserl__pb2.Empty.SerializeToString,
|
||||
),
|
||||
'SendInteractions': grpc.stream_unary_rpc_method_handler(
|
||||
servicer.SendInteractions,
|
||||
request_deserializer=hilserl__pb2.InteractionMessage.FromString,
|
||||
response_serializer=hilserl__pb2.Empty.SerializeToString,
|
||||
),
|
||||
'Ready': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.Ready,
|
||||
request_deserializer=hilserl__pb2.Empty.FromString,
|
||||
response_serializer=hilserl__pb2.Empty.SerializeToString,
|
||||
),
|
||||
}
|
||||
|
@ -163,7 +195,7 @@ class LearnerService(object):
|
|||
_registered_method=True)
|
||||
|
||||
@staticmethod
|
||||
def ReceiveTransitions(request_iterator,
|
||||
def SendTransitions(request_iterator,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
|
@ -176,8 +208,62 @@ class LearnerService(object):
|
|||
return grpc.experimental.stream_unary(
|
||||
request_iterator,
|
||||
target,
|
||||
'/hil_serl.LearnerService/ReceiveTransitions',
|
||||
hilserl__pb2.ActorInformation.SerializeToString,
|
||||
'/hil_serl.LearnerService/SendTransitions',
|
||||
hilserl__pb2.Transition.SerializeToString,
|
||||
hilserl__pb2.Empty.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
|
||||
@staticmethod
|
||||
def SendInteractions(request_iterator,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.stream_unary(
|
||||
request_iterator,
|
||||
target,
|
||||
'/hil_serl.LearnerService/SendInteractions',
|
||||
hilserl__pb2.InteractionMessage.SerializeToString,
|
||||
hilserl__pb2.Empty.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
|
||||
@staticmethod
|
||||
def Ready(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(
|
||||
request,
|
||||
target,
|
||||
'/hil_serl.LearnerService/Ready',
|
||||
hilserl__pb2.Empty.SerializeToString,
|
||||
hilserl__pb2.Empty.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
|
|
|
@@ -15,15 +15,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import queue
import shutil
import time
from pprint import pformat
from threading import Lock, Thread
import signal
from threading import Event
from concurrent.futures import ThreadPoolExecutor

# from torch.multiprocessing import Event, Queue, Process
# from threading import Event, Thread
# from torch.multiprocessing import Queue, Event
from torch.multiprocessing import Queue

from lerobot.scripts.server.utils import setup_process_handlers

import grpc

# Import generated stubs
@@ -52,19 +55,19 @@ from lerobot.common.utils.utils import (
    set_global_random_state,
    set_global_seed,
)

from lerobot.scripts.server.buffer import (
    ReplayBuffer,
    concatenate_batch_transitions,
    move_transition_to_device,
    move_state_dict_to_device,
    bytes_to_transitions,
    state_to_bytes,
    bytes_to_python_object,
)

from lerobot.scripts.server import learner_service

logging.basicConfig(level=logging.INFO)

transition_queue = queue.Queue()
interaction_message_queue = queue.Queue()


def handle_resume_logic(cfg: DictConfig, out_dir: str) -> DictConfig:
    if not cfg.resume:
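The module-level `queue.Queue` instances removed here only work when producer and consumer share one process; after the migration, the learner creates `torch.multiprocessing.Queue` objects and passes them explicitly to each worker. A minimal sketch of the difference (names are illustrative):

import torch.multiprocessing as mp

# Before: a module-level queue.Queue is invisible to spawned children, because each
# child re-imports the module and gets its own empty copy of the queue.
#
# After: create the queues in the parent and hand them to the worker explicitly.


def consumer(transition_queue):
    payload = transition_queue.get()  # bytes produced in another process
    print(f"received {len(payload)} bytes")


if __name__ == "__main__":
    mp.set_start_method("spawn")
    transition_queue = mp.Queue()
    proc = mp.Process(target=consumer, args=(transition_queue,), daemon=True)
    proc.start()
    transition_queue.put(b"serialized transitions")
    proc.join()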
@ -195,67 +198,96 @@ def get_observation_features(
|
|||
return observation_features, next_observation_features
|
||||
|
||||
|
||||
def use_threads(cfg: DictConfig) -> bool:
|
||||
return cfg.actor_learner_config.concurrency.learner == "threads"
|
||||
|
||||
|
||||
def start_learner_threads(
|
||||
cfg: DictConfig,
|
||||
device: str,
|
||||
replay_buffer: ReplayBuffer,
|
||||
offline_replay_buffer: ReplayBuffer,
|
||||
batch_size: int,
|
||||
optimizers: dict,
|
||||
policy: SACPolicy,
|
||||
policy_lock: Lock,
|
||||
logger: Logger,
|
||||
resume_optimization_step: int | None = None,
|
||||
resume_interaction_step: int | None = None,
|
||||
shutdown_event: Event | None = None,
|
||||
out_dir: str,
|
||||
shutdown_event: any, # Event,
|
||||
) -> None:
|
||||
host = cfg.actor_learner_config.learner_host
|
||||
port = cfg.actor_learner_config.learner_port
|
||||
# Create multiprocessing queues
|
||||
transition_queue = Queue()
|
||||
interaction_message_queue = Queue()
|
||||
parameters_queue = Queue()
|
||||
|
||||
transition_thread = Thread(
|
||||
target=add_actor_information_and_train,
|
||||
daemon=True,
|
||||
concurrency_entity = None
|
||||
|
||||
if use_threads(cfg):
|
||||
from threading import Thread
|
||||
|
||||
concurrency_entity = Thread
|
||||
else:
|
||||
from torch.multiprocessing import Process
|
||||
|
||||
concurrency_entity = Process
|
||||
|
||||
communication_process = concurrency_entity(
|
||||
target=start_learner_server,
|
||||
args=(
|
||||
cfg,
|
||||
device,
|
||||
replay_buffer,
|
||||
offline_replay_buffer,
|
||||
batch_size,
|
||||
optimizers,
|
||||
policy,
|
||||
policy_lock,
|
||||
logger,
|
||||
resume_optimization_step,
|
||||
resume_interaction_step,
|
||||
parameters_queue,
|
||||
transition_queue,
|
||||
interaction_message_queue,
|
||||
shutdown_event,
|
||||
cfg,
|
||||
),
|
||||
daemon=True,
|
||||
)
|
||||
communication_process.start()
|
||||
|
||||
transition_thread.start()
|
||||
add_actor_information_and_train(
|
||||
cfg,
|
||||
logger,
|
||||
out_dir,
|
||||
shutdown_event,
|
||||
transition_queue,
|
||||
interaction_message_queue,
|
||||
parameters_queue,
|
||||
)
|
||||
logging.info("[LEARNER] Training process stopped")
|
||||
|
||||
logging.info("[LEARNER] Closing queues")
|
||||
transition_queue.close()
|
||||
interaction_message_queue.close()
|
||||
parameters_queue.close()
|
||||
|
||||
communication_process.join()
|
||||
logging.info("[LEARNER] Communication process joined")
|
||||
|
||||
logging.info("[LEARNER] join queues")
|
||||
transition_queue.cancel_join_thread()
|
||||
interaction_message_queue.cancel_join_thread()
|
||||
parameters_queue.cancel_join_thread()
|
||||
|
||||
logging.info("[LEARNER] queues closed")
|
||||
|
||||
|
||||
def start_learner_server(
|
||||
parameters_queue: Queue,
|
||||
transition_queue: Queue,
|
||||
interaction_message_queue: Queue,
|
||||
shutdown_event: any, # Event,
|
||||
cfg: DictConfig,
|
||||
):
|
||||
if not use_threads(cfg):
|
||||
# We need init logging for MP separataly
|
||||
init_logging()
|
||||
|
||||
# Setup process handlers to handle shutdown signal
|
||||
# But use shutdown event from the main process
|
||||
# Return back for MP
|
||||
setup_process_handlers(False)
|
||||
|
||||
service = learner_service.LearnerService(
|
||||
shutdown_event,
|
||||
policy,
|
||||
policy_lock,
|
||||
parameters_queue,
|
||||
cfg.actor_learner_config.policy_parameters_push_frequency,
|
||||
transition_queue,
|
||||
interaction_message_queue,
|
||||
)
|
||||
server = start_learner_server(service, host, port)
|
||||
|
||||
shutdown_event.wait()
|
||||
server.stop(learner_service.STUTDOWN_TIMEOUT)
|
||||
logging.info("[LEARNER] gRPC server stopped")
|
||||
|
||||
transition_thread.join()
|
||||
logging.info("[LEARNER] Transition thread stopped")
|
||||
|
||||
|
||||
def start_learner_server(
|
||||
service: learner_service.LearnerService,
|
||||
host="0.0.0.0",
|
||||
port=50051,
|
||||
) -> grpc.server:
|
||||
server = grpc.server(
|
||||
ThreadPoolExecutor(max_workers=learner_service.MAX_WORKERS),
|
||||
options=[
|
||||
|
@ -263,15 +295,23 @@ def start_learner_server(
|
|||
("grpc.max_send_message_length", learner_service.MAX_MESSAGE_SIZE),
|
||||
],
|
||||
)
|
||||
|
||||
hilserl_pb2_grpc.add_LearnerServiceServicer_to_server(
|
||||
service,
|
||||
server,
|
||||
)
|
||||
|
||||
host = cfg.actor_learner_config.learner_host
|
||||
port = cfg.actor_learner_config.learner_port
|
||||
|
||||
server.add_insecure_port(f"{host}:{port}")
|
||||
server.start()
|
||||
logging.info("[LEARNER] gRPC server started")
|
||||
|
||||
return server
|
||||
shutdown_event.wait()
|
||||
logging.info("[LEARNER] Stopping gRPC server...")
|
||||
server.stop(learner_service.STUTDOWN_TIMEOUT)
|
||||
logging.info("[LEARNER] gRPC server stopped")
|
||||
|
||||
|
||||
def check_nan_in_transition(
|
||||
|
@@ -287,19 +327,21 @@ def check_nan_in_transition(
        logging.error("actions contains NaN values")


def push_actor_policy_to_queue(parameters_queue: Queue, policy: nn.Module):
    logging.debug("[LEARNER] Pushing actor policy to the queue")
    state_dict = move_state_dict_to_device(policy.actor.state_dict(), device="cpu")
    state_bytes = state_to_bytes(state_dict)
    parameters_queue.put(state_bytes)


def add_actor_information_and_train(
    cfg,
    device: str,
    replay_buffer: ReplayBuffer,
    offline_replay_buffer: ReplayBuffer,
    batch_size: int,
    optimizers: dict[str, torch.optim.Optimizer],
    policy: nn.Module,
    policy_lock: Lock,
    logger: Logger,
    resume_optimization_step: int | None = None,
    resume_interaction_step: int | None = None,
    shutdown_event: Event | None = None,
    out_dir: str,
    shutdown_event: any,  # Event,
    transition_queue: Queue,
    interaction_message_queue: Queue,
    parameters_queue: Queue,
):
    """
    Handles data transfer from the actor to the learner, manages training updates,
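Parameters now travel through the queue as CPU bytes rather than as a live state dict, and later in this diff the learner gates pushes on `policy_parameters_push_frequency` (seconds). A hedged sketch of that pattern (the function name and return-value convention are illustrative; the real code uses `push_actor_policy_to_queue` and a `last_time_policy_pushed` variable):

import io
import time

import torch


def maybe_push_policy(policy, parameters_queue, last_push_time, push_frequency_s):
    """Push a CPU copy of the actor's weights at most once per push_frequency_s seconds."""
    if time.time() - last_push_time < push_frequency_s:
        return last_push_time
    state_dict = {k: v.detach().cpu() for k, v in policy.actor.state_dict().items()}
    buffer = io.BytesIO()
    torch.save(state_dict, buffer)          # serialize before queueing, as state_to_bytes() does
    parameters_queue.put(buffer.getvalue())
    return time.time()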
@ -322,17 +364,73 @@ def add_actor_information_and_train(
|
|||
Args:
|
||||
cfg: Configuration object containing hyperparameters.
|
||||
device (str): The computing device (`"cpu"` or `"cuda"`).
|
||||
replay_buffer (ReplayBuffer): The primary replay buffer storing online transitions.
|
||||
offline_replay_buffer (ReplayBuffer): An additional buffer for offline transitions.
|
||||
batch_size (int): The number of transitions to sample per training step.
|
||||
optimizers (Dict[str, torch.optim.Optimizer]): A dictionary of optimizers (`"actor"`, `"critic"`, `"temperature"`).
|
||||
policy (nn.Module): The reinforcement learning policy with critic, actor, and temperature parameters.
|
||||
policy_lock (Lock): A threading lock to ensure safe policy updates.
|
||||
logger (Logger): Logger instance for tracking training progress.
|
||||
resume_optimization_step (int | None): In the case of resume training, start from the last optimization step reached.
|
||||
resume_interaction_step (int | None): In the case of resume training, shift the interaction step with the last saved step in order to not break logging.
|
||||
shutdown_event (Event | None): Event to signal shutdown.
|
||||
out_dir (str): The output directory for storing training checkpoints and logs.
|
||||
shutdown_event (Event): Event to signal shutdown.
|
||||
transition_queue (Queue): Queue for receiving transitions from the actor.
|
||||
interaction_message_queue (Queue): Queue for receiving interaction messages from the actor.
|
||||
parameters_queue (Queue): Queue for sending policy parameters to the actor.
|
||||
"""
|
||||
|
||||
device = get_safe_torch_device(cfg.device, log=True)
|
||||
storage_device = get_safe_torch_device(cfg_device=cfg.training.storage_device)
|
||||
|
||||
logging.info("Initializing policy")
|
||||
### Instantiate the policy in both the actor and learner processes
|
||||
### To avoid sending a SACPolicy object through the port, we create a policy intance
|
||||
### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters
|
||||
# TODO: At some point we should just need make sac policy
|
||||
|
||||
policy: SACPolicy = make_policy(
|
||||
hydra_cfg=cfg,
|
||||
# dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
|
||||
# Hack: But if we do online traning, we do not need dataset_stats
|
||||
dataset_stats=None,
|
||||
pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir)
|
||||
if cfg.resume
|
||||
else None,
|
||||
)
|
||||
# compile policy
|
||||
policy = torch.compile(policy)
|
||||
assert isinstance(policy, nn.Module)
|
||||
|
||||
push_actor_policy_to_queue(parameters_queue, policy)
|
||||
|
||||
last_time_policy_pushed = time.time()
|
||||
|
||||
optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg, policy)
|
||||
resume_optimization_step, resume_interaction_step = load_training_state(
|
||||
cfg, logger, optimizers
|
||||
)
|
||||
|
||||
log_training_info(cfg, out_dir, policy)
|
||||
|
||||
replay_buffer = initialize_replay_buffer(cfg, logger, device, storage_device)
|
||||
batch_size = cfg.training.batch_size
|
||||
offline_replay_buffer = None
|
||||
|
||||
if cfg.dataset_repo_id is not None:
|
||||
logging.info("make_dataset offline buffer")
|
||||
offline_dataset = make_dataset(cfg)
|
||||
logging.info("Convertion to a offline replay buffer")
|
||||
active_action_dims = None
|
||||
if cfg.env.wrapper.joint_masking_action_space is not None:
|
||||
active_action_dims = [
|
||||
i
|
||||
for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space)
|
||||
if mask
|
||||
]
|
||||
offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
|
||||
offline_dataset,
|
||||
device=device,
|
||||
state_keys=cfg.policy.input_shapes.keys(),
|
||||
action_mask=active_action_dims,
|
||||
action_delta=cfg.env.wrapper.delta_action,
|
||||
storage_device=storage_device,
|
||||
optimize_memory=True,
|
||||
)
|
||||
batch_size: int = batch_size // 2 # We will sample from both replay buffer
|
||||
|
||||
# NOTE: This function doesn't have a single responsibility, it should be split into multiple functions
|
||||
# in the future. The reason why we did that is the GIL in Python. It's super slow the performance
|
||||
# are divided by 200. So we need to have a single thread that does all the work.
|
||||
|
@ -345,33 +443,39 @@ def add_actor_information_and_train(
|
|||
interaction_step_shift = (
|
||||
resume_interaction_step if resume_interaction_step is not None else 0
|
||||
)
|
||||
saved_data = False
|
||||
|
||||
while True:
|
||||
if shutdown_event is not None and shutdown_event.is_set():
|
||||
logging.info("[LEARNER] Shutdown signal received. Exiting...")
|
||||
break
|
||||
|
||||
while not transition_queue.empty():
|
||||
logging.debug("[LEARNER] Waiting for transitions")
|
||||
while not transition_queue.empty() and not shutdown_event.is_set():
|
||||
transition_list = transition_queue.get()
|
||||
transition_list = bytes_to_transitions(transition_list)
|
||||
|
||||
for transition in transition_list:
|
||||
transition = move_transition_to_device(transition, device=device)
|
||||
replay_buffer.add(**transition)
|
||||
|
||||
if transition.get("complementary_info", {}).get("is_intervention"):
|
||||
offline_replay_buffer.add(**transition)
|
||||
|
||||
while not interaction_message_queue.empty():
|
||||
logging.debug("[LEARNER] Received transitions")
|
||||
logging.debug("[LEARNER] Waiting for interactions")
|
||||
while not interaction_message_queue.empty() and not shutdown_event.is_set():
|
||||
interaction_message = interaction_message_queue.get()
|
||||
interaction_message = bytes_to_python_object(interaction_message)
|
||||
# If cfg.resume, shift the interaction step with the last checkpointed step in order to not break the logging
|
||||
interaction_message["Interaction step"] += interaction_step_shift
|
||||
logger.log_dict(
|
||||
interaction_message, mode="train", custom_step_key="Interaction step"
|
||||
)
|
||||
# logging.info(f"Interaction message: {interaction_message}")
|
||||
|
||||
logging.debug("[LEARNER] Received interactions")
|
||||
|
||||
if len(replay_buffer) < cfg.training.online_step_before_learning:
|
||||
continue
|
||||
|
||||
logging.debug("[LEARNER] Starting optimization loop")
|
||||
time_for_one_optimization_step = time.time()
|
||||
for _ in range(cfg.policy.utd_ratio - 1):
|
||||
batch = replay_buffer.sample(batch_size)
|
||||
|
@ -392,19 +496,18 @@ def add_actor_information_and_train(
|
|||
observation_features, next_observation_features = get_observation_features(
|
||||
policy, observations, next_observations
|
||||
)
|
||||
with policy_lock:
|
||||
loss_critic = policy.compute_loss_critic(
|
||||
observations=observations,
|
||||
actions=actions,
|
||||
rewards=rewards,
|
||||
next_observations=next_observations,
|
||||
done=done,
|
||||
observation_features=observation_features,
|
||||
next_observation_features=next_observation_features,
|
||||
)
|
||||
optimizers["critic"].zero_grad()
|
||||
loss_critic.backward()
|
||||
optimizers["critic"].step()
|
||||
loss_critic = policy.compute_loss_critic(
|
||||
observations=observations,
|
||||
actions=actions,
|
||||
rewards=rewards,
|
||||
next_observations=next_observations,
|
||||
done=done,
|
||||
observation_features=observation_features,
|
||||
next_observation_features=next_observation_features,
|
||||
)
|
||||
optimizers["critic"].zero_grad()
|
||||
loss_critic.backward()
|
||||
optimizers["critic"].step()
|
||||
|
||||
batch = replay_buffer.sample(batch_size)
|
||||
|
||||
|
@ -427,46 +530,51 @@ def add_actor_information_and_train(
|
|||
observation_features, next_observation_features = get_observation_features(
|
||||
policy, observations, next_observations
|
||||
)
|
||||
with policy_lock:
|
||||
loss_critic = policy.compute_loss_critic(
|
||||
observations=observations,
|
||||
actions=actions,
|
||||
rewards=rewards,
|
||||
next_observations=next_observations,
|
||||
done=done,
|
||||
observation_features=observation_features,
|
||||
next_observation_features=next_observation_features,
|
||||
)
|
||||
optimizers["critic"].zero_grad()
|
||||
loss_critic.backward()
|
||||
optimizers["critic"].step()
|
||||
loss_critic = policy.compute_loss_critic(
|
||||
observations=observations,
|
||||
actions=actions,
|
||||
rewards=rewards,
|
||||
next_observations=next_observations,
|
||||
done=done,
|
||||
observation_features=observation_features,
|
||||
next_observation_features=next_observation_features,
|
||||
)
|
||||
optimizers["critic"].zero_grad()
|
||||
loss_critic.backward()
|
||||
optimizers["critic"].step()
|
||||
|
||||
training_infos = {}
|
||||
training_infos["loss_critic"] = loss_critic.item()
|
||||
|
||||
if optimization_step % cfg.training.policy_update_freq == 0:
|
||||
for _ in range(cfg.training.policy_update_freq):
|
||||
with policy_lock:
|
||||
loss_actor = policy.compute_loss_actor(
|
||||
observations=observations,
|
||||
observation_features=observation_features,
|
||||
)
|
||||
loss_actor = policy.compute_loss_actor(
|
||||
observations=observations,
|
||||
observation_features=observation_features,
|
||||
)
|
||||
|
||||
optimizers["actor"].zero_grad()
|
||||
loss_actor.backward()
|
||||
optimizers["actor"].step()
|
||||
optimizers["actor"].zero_grad()
|
||||
loss_actor.backward()
|
||||
optimizers["actor"].step()
|
||||
|
||||
training_infos["loss_actor"] = loss_actor.item()
|
||||
training_infos["loss_actor"] = loss_actor.item()
|
||||
|
||||
loss_temperature = policy.compute_loss_temperature(
|
||||
observations=observations,
|
||||
observation_features=observation_features,
|
||||
)
|
||||
optimizers["temperature"].zero_grad()
|
||||
loss_temperature.backward()
|
||||
optimizers["temperature"].step()
|
||||
loss_temperature = policy.compute_loss_temperature(
|
||||
observations=observations,
|
||||
observation_features=observation_features,
|
||||
)
|
||||
optimizers["temperature"].zero_grad()
|
||||
loss_temperature.backward()
|
||||
optimizers["temperature"].step()
|
||||
|
||||
training_infos["loss_temperature"] = loss_temperature.item()
|
||||
training_infos["loss_temperature"] = loss_temperature.item()
|
||||
|
||||
if (
|
||||
time.time() - last_time_policy_pushed
|
||||
> cfg.actor_learner_config.policy_parameters_push_frequency
|
||||
):
|
||||
push_actor_policy_to_queue(parameters_queue, policy)
|
||||
last_time_policy_pushed = time.time()
|
||||
|
||||
policy.update_target_networks()
|
||||
if optimization_step % cfg.training.log_freq == 0:
|
||||
|
@@ -595,104 +703,36 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None

    set_global_seed(cfg.seed)

    device = get_safe_torch_device(cfg.device, log=True)
    storage_device = get_safe_torch_device(cfg_device=cfg.training.storage_device)

    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True

    logging.info("make_policy")

    ### Instantiate the policy in both the actor and learner processes
    ### To avoid sending a SACPolicy object over the network, we create a policy instance
    ### on both sides; the learner sends the updated parameters every n steps to refresh the actor's copy
    # TODO: At some point we should only need to make the SAC policy
    policy_lock = Lock()
    policy: SACPolicy = make_policy(
        hydra_cfg=cfg,
        # dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
        # Hack: if we do online training, we do not need dataset_stats
        dataset_stats=None,
        pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir)
        if cfg.resume
        else None,
    )
    # compile policy
    policy = torch.compile(policy)
    assert isinstance(policy, nn.Module)

    optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg, policy)
    resume_optimization_step, resume_interaction_step = load_training_state(
        cfg, logger, optimizers
    )

    log_training_info(cfg, out_dir, policy)

    replay_buffer = initialize_replay_buffer(cfg, logger, device, storage_device)
    batch_size = cfg.training.batch_size
    offline_replay_buffer = None

    if cfg.dataset_repo_id is not None:
        logging.info("make_dataset offline buffer")
        offline_dataset = make_dataset(cfg)
        logging.info("Conversion to an offline replay buffer")
        active_action_dims = None
        if cfg.env.wrapper.joint_masking_action_space is not None:
            active_action_dims = [
                i
                for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space)
                if mask
            ]
        offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
            offline_dataset,
            device=device,
            state_keys=cfg.policy.input_shapes.keys(),
            action_mask=active_action_dims,
            action_delta=cfg.env.wrapper.delta_action,
            storage_device=storage_device,
            optimize_memory=True,
        )
        batch_size: int = batch_size // 2  # We will sample from both replay buffers
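        # Assumption (not shown in this hunk): with both buffers present, each optimization
        # step presumably samples batch_size transitions from the online buffer and
        # batch_size from the offline (demonstration) buffer, giving a 50/50 mix per update.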

    shutdown_event = Event()

    def signal_handler(signum, frame):
        print(
            f"\nReceived signal {signal.Signals(signum).name}. Initiating learner shutdown..."
        )
        shutdown_event.set()

    # Register signal handlers
    signal.signal(signal.SIGINT, signal_handler)  # Ctrl+C
    signal.signal(signal.SIGTERM, signal_handler)  # Termination request
    signal.signal(signal.SIGHUP, signal_handler)  # Terminal closed
    signal.signal(signal.SIGQUIT, signal_handler)  # Ctrl+\
    shutdown_event = setup_process_handlers(use_threads(cfg))

    start_learner_threads(
        cfg,
        device,
        replay_buffer,
        offline_replay_buffer,
        batch_size,
        optimizers,
        policy,
        policy_lock,
        logger,
        resume_optimization_step,
        resume_interaction_step,
        out_dir,
        shutdown_event,
    )


@hydra.main(version_base="1.2", config_name="default", config_path="../../configs")
def train_cli(cfg: dict):
    if not use_threads(cfg):
        import torch.multiprocessing as mp

        mp.set_start_method("spawn")
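        # Rationale (assumed, not stated in the diff): "spawn" is used instead of the
        # default "fork" because the learner initializes CUDA; forking a CUDA-initialized
        # process is unsafe, while "spawn" starts each child with a fresh interpreter.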

    train(
        cfg,
        out_dir=hydra.core.hydra_config.HydraConfig.get().run.dir,
        job_name=hydra.core.hydra_config.HydraConfig.get().job.name,
    )

    logging.info("[LEARNER] train_cli finished")


if __name__ == "__main__":
    train_cli()

    logging.info("[LEARNER] main finished")
@@ -1,23 +1,13 @@
import hilserl_pb2  # type: ignore
import hilserl_pb2_grpc  # type: ignore
import torch
from torch import nn
from threading import Lock, Event
import logging
import queue
import io
import pickle

from lerobot.scripts.server.buffer import (
    move_state_dict_to_device,
    bytes_buffer_size,
    state_to_bytes,
)
from multiprocessing import Event, Queue

from lerobot.scripts.server.network_utils import receive_bytes_in_chunks
from lerobot.scripts.server.network_utils import send_bytes_in_chunks

MAX_MESSAGE_SIZE = 4 * 1024 * 1024  # 4 MB
CHUNK_SIZE = 2 * 1024 * 1024  # 2 MB
MAX_WORKERS = 10
MAX_WORKERS = 3  # Stream parameters, send transitions and interactions
STUTDOWN_TIMEOUT = 10
@@ -25,89 +15,68 @@ class LearnerService(hilserl_pb2_grpc.LearnerServiceServicer):
    def __init__(
        self,
        shutdown_event: Event,
        policy: nn.Module,
        policy_lock: Lock,
        parameters_queue: Queue,
        seconds_between_pushes: float,
        transition_queue: queue.Queue,
        interaction_message_queue: queue.Queue,
        transition_queue: Queue,
        interaction_message_queue: Queue,
    ):
        self.shutdown_event = shutdown_event
        self.policy = policy
        self.policy_lock = policy_lock
        self.parameters_queue = parameters_queue
        self.seconds_between_pushes = seconds_between_pushes
        self.transition_queue = transition_queue
        self.interaction_message_queue = interaction_message_queue

    def _get_policy_state(self):
        with self.policy_lock:
            params_dict = self.policy.actor.state_dict()
            # if self.policy.config.vision_encoder_name is not None:
            #     if self.policy.config.freeze_vision_encoder:
            #         params_dict: dict[str, torch.Tensor] = {
            #             k: v
            #             for k, v in params_dict.items()
            #             if not k.startswith("encoder.")
            #         }
            #     else:
            #         raise NotImplementedError(
            #             "Vision encoder is not frozen, we need to send the full model over the network which requires chunking the model."
            #         )

        return move_state_dict_to_device(params_dict, device="cpu")

    def _send_bytes(self, buffer: bytes):
        size_in_bytes = bytes_buffer_size(buffer)

        sent_bytes = 0

        logging.info(f"Model state size {size_in_bytes/1024/1024} MB")

        while sent_bytes < size_in_bytes:
            transfer_state = hilserl_pb2.TransferState.TRANSFER_MIDDLE

            if sent_bytes + CHUNK_SIZE >= size_in_bytes:
                transfer_state = hilserl_pb2.TransferState.TRANSFER_END
            elif sent_bytes == 0:
                transfer_state = hilserl_pb2.TransferState.TRANSFER_BEGIN

            size_to_read = min(CHUNK_SIZE, size_in_bytes - sent_bytes)
            chunk = buffer.read(size_to_read)

            yield hilserl_pb2.Parameters(
                transfer_state=transfer_state, parameter_bytes=chunk
            )
            sent_bytes += size_to_read
            logging.info(
                f"[Learner] Sent {sent_bytes}/{size_in_bytes} bytes with state {transfer_state}"
            )

        logging.info(f"[LEARNER] Published {sent_bytes/1024/1024} MB to the Actor")

    def StreamParameters(self, request, context):
        # TODO: authorize the request
        logging.info("[LEARNER] Received request to stream parameters from the Actor")

        while not self.shutdown_event.is_set():
            logging.debug("[LEARNER] Push parameters to the Actor")
            state_dict = self._get_policy_state()
            logging.info("[LEARNER] Push parameters to the Actor")
            buffer = self.parameters_queue.get()

            with state_to_bytes(state_dict) as buffer:
                yield from self._send_bytes(buffer)
            yield from send_bytes_in_chunks(
                buffer,
                hilserl_pb2.Parameters,
                log_prefix="[LEARNER] Sending parameters",
                silent=True,
            )

            logging.info("[LEARNER] Parameters sent")

            self.shutdown_event.wait(self.seconds_between_pushes)

    def ReceiveTransitions(self, request_iterator, context):
        logging.info("[LEARNER] Stream parameters finished")
        return hilserl_pb2.Empty()
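    # Assumed flow (summary of the surrounding code, not part of the commit): the learner's
    # training loop serializes the actor's parameters and puts the bytes on parameters_queue;
    # StreamParameters pops the latest payload, relays it to the actor in CHUNK_SIZE pieces
    # via send_bytes_in_chunks, then throttles itself with seconds_between_pushes.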

    def SendTransitions(self, request_iterator, _context):
        # TODO: authorize the request
        logging.info("[LEARNER] Received request to receive transitions from the Actor")

        for request in request_iterator:
            logging.debug("[LEARNER] Received request")
            if request.HasField("transition"):
                buffer = io.BytesIO(request.transition.transition_bytes)
                transition = torch.load(buffer)
                self.transition_queue.put(transition)
            if request.HasField("interaction_message"):
                content = pickle.loads(
                    request.interaction_message.interaction_message_bytes
                )
                self.interaction_message_queue.put(content)
        receive_bytes_in_chunks(
            request_iterator,
            self.transition_queue,
            self.shutdown_event,
            log_prefix="[LEARNER] transitions",
        )

        logging.debug("[LEARNER] Finished receiving transitions")
        return hilserl_pb2.Empty()
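    # Assumption (not shown in this hunk): transition_queue now carries raw serialized bytes,
    # so the learner-side consumer that drains it is expected to deserialize each payload back
    # into Transition objects (e.g. torch.load on a BytesIO buffer) before pushing them into
    # the replay buffer.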

    def SendInteractions(self, request_iterator, _context):
        # TODO: authorize the request
        logging.info(
            "[LEARNER] Received request to receive interactions from the Actor"
        )

        receive_bytes_in_chunks(
            request_iterator,
            self.interaction_message_queue,
            self.shutdown_event,
            log_prefix="[LEARNER] interactions",
        )

        logging.debug("[LEARNER] Finished receiving interactions")
        return hilserl_pb2.Empty()

    def Ready(self, request, context):
        return hilserl_pb2.Empty()
@@ -5,9 +5,8 @@ import torch

from omegaconf import DictConfig
from typing import Any

"""Make ManiSkill3 gym environment"""
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
from mani_skill.utils.wrappers.record import RecordEpisode


def preprocess_maniskill_observation(

@@ -143,6 +142,15 @@ def make_maniskill(
        num_envs=n_envs,
    )

    if cfg.env.video_record.enabled:
        env = RecordEpisode(
            env,
            output_dir=cfg.env.video_record.record_dir,
            save_trajectory=True,
            trajectory_name=cfg.env.video_record.trajectory_name,
            save_video=True,
            video_fps=30,
        )
    env = ManiSkillObservationWrapper(env, device=cfg.env.device)
    env = ManiSkillVectorEnv(env, ignore_terminations=True, auto_reset=False)
    env._max_episode_steps = env.max_episode_steps = (
@@ -0,0 +1,102 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from lerobot.scripts.server import hilserl_pb2
import logging
import io
from multiprocessing import Queue, Event
from typing import Any

CHUNK_SIZE = 2 * 1024 * 1024  # 2 MB


def bytes_buffer_size(buffer: io.BytesIO) -> int:
    buffer.seek(0, io.SEEK_END)
    result = buffer.tell()
    buffer.seek(0)
    return result


def send_bytes_in_chunks(
    buffer: bytes, message_class: Any, log_prefix: str = "", silent: bool = True
):
    buffer = io.BytesIO(buffer)
    size_in_bytes = bytes_buffer_size(buffer)

    sent_bytes = 0

    logging_method = logging.info if not silent else logging.debug

    logging_method(f"{log_prefix} Buffer size {size_in_bytes/1024/1024} MB")

    while sent_bytes < size_in_bytes:
        transfer_state = hilserl_pb2.TransferState.TRANSFER_MIDDLE

        if sent_bytes + CHUNK_SIZE >= size_in_bytes:
            transfer_state = hilserl_pb2.TransferState.TRANSFER_END
        elif sent_bytes == 0:
            transfer_state = hilserl_pb2.TransferState.TRANSFER_BEGIN

        size_to_read = min(CHUNK_SIZE, size_in_bytes - sent_bytes)
        chunk = buffer.read(size_to_read)

        yield message_class(transfer_state=transfer_state, data=chunk)
        sent_bytes += size_to_read
        logging_method(
            f"{log_prefix} Sent {sent_bytes}/{size_in_bytes} bytes with state {transfer_state}"
        )

    logging_method(f"{log_prefix} Published {sent_bytes/1024/1024} MB")


def receive_bytes_in_chunks(
    iterator, queue: Queue, shutdown_event: Event, log_prefix: str = ""
):
    bytes_buffer = io.BytesIO()
    step = 0

    logging.info(f"{log_prefix} Starting receiver")
    for item in iterator:
        logging.debug(f"{log_prefix} Received item")
        if shutdown_event.is_set():
            logging.info(f"{log_prefix} Shutting down receiver")
            return

        if item.transfer_state == hilserl_pb2.TransferState.TRANSFER_BEGIN:
            bytes_buffer.seek(0)
            bytes_buffer.truncate(0)
            bytes_buffer.write(item.data)
            logging.debug(f"{log_prefix} Received data at step 0")
            step = 0
            continue
        elif item.transfer_state == hilserl_pb2.TransferState.TRANSFER_MIDDLE:
            bytes_buffer.write(item.data)
            step += 1
            logging.debug(f"{log_prefix} Received data at step {step}")
        elif item.transfer_state == hilserl_pb2.TransferState.TRANSFER_END:
            bytes_buffer.write(item.data)
            logging.debug(
                f"{log_prefix} Received data at step end size {bytes_buffer_size(bytes_buffer)}"
            )

            queue.put(bytes_buffer.getvalue())

            bytes_buffer.seek(0)
            bytes_buffer.truncate(0)
            step = 0

            logging.debug(f"{log_prefix} Queue updated")
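# Illustrative pairing of the two helpers above (a sketch, not part of the commit;
# the servicer and queue names are hypothetical):
#
#     # learner side: stream one serialized parameters payload to the actor
#     def StreamParameters(self, request, context):
#         payload = self.parameters_queue.get()  # bytes produced by the learner loop
#         yield from send_bytes_in_chunks(
#             payload, hilserl_pb2.Parameters, log_prefix="[LEARNER] params"
#         )
#
#     # actor side: reassemble the chunked stream back into a single bytes object
#     receive_bytes_in_chunks(
#         response_iterator, params_queue, shutdown_event, log_prefix="[ACTOR] params"
#     )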
@@ -0,0 +1,72 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import signal
import sys
from torch.multiprocessing import Queue
from queue import Empty

shutdown_event_counter = 0


def setup_process_handlers(use_threads: bool) -> any:
    if use_threads:
        from threading import Event
    else:
        from multiprocessing import Event

    shutdown_event = Event()

    # Define signal handler
    def signal_handler(signum, frame):
        logging.info("Shutdown signal received. Cleaning up...")
        shutdown_event.set()
        global shutdown_event_counter
        shutdown_event_counter += 1

        if shutdown_event_counter > 1:
            logging.info("Force shutdown")
            sys.exit(1)

    signal.signal(signal.SIGINT, signal_handler)  # Ctrl+C
    signal.signal(signal.SIGTERM, signal_handler)  # Termination request (kill)
    signal.signal(signal.SIGHUP, signal_handler)  # Terminal closed/Hangup
    signal.signal(signal.SIGQUIT, signal_handler)  # Ctrl+\

    def signal_handler(signum, frame):
        logging.info("Shutdown signal received. Cleaning up...")
        shutdown_event.set()

    return shutdown_event


def get_last_item_from_queue(queue: Queue):
    item = queue.get()
    counter = 1

    # Drain queue and keep only the most recent parameters
    try:
        while True:
            item = queue.get_nowait()
            counter += 1
    except Empty:
        pass

    logging.debug(f"Drained {counter} items from queue")

    return item
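# Usage sketch (an assumption, not part of the commit): on the actor side several
# serialized policies can pile up while the actor is busy stepping the environment,
# so the consumer drains the queue and only loads the freshest payload:
#
#     bytes_state_dict = get_last_item_from_queue(parameters_queue)
#     state_dict = torch.load(io.BytesIO(bytes_state_dict))  # hypothetical deserialization
#     policy.actor.load_state_dict(state_dict)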