[HIL-SERL] Migrate threading to multiprocessing (#759)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
584cad808e
commit
700f00c014
|
@ -116,11 +116,11 @@ def seeded_context(seed: int) -> Generator[None, None, None]:
|
||||||
set_global_random_state(random_state_dict)
|
set_global_random_state(random_state_dict)
|
||||||
|
|
||||||
|
|
||||||
def init_logging():
|
def init_logging(log_file=None):
|
||||||
def custom_format(record):
|
def custom_format(record):
|
||||||
dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||||
fnameline = f"{record.pathname}:{record.lineno}"
|
fnameline = f"{record.pathname}:{record.lineno}"
|
||||||
message = f"{record.levelname} {dt} {fnameline[-15:]:>15} {record.msg}"
|
message = f"{record.levelname} [PID: {os.getpid()}] {dt} {fnameline[-15:]:>15} {record.msg}"
|
||||||
return message
|
return message
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
@ -134,6 +134,12 @@ def init_logging():
|
||||||
console_handler.setFormatter(formatter)
|
console_handler.setFormatter(formatter)
|
||||||
logging.getLogger().addHandler(console_handler)
|
logging.getLogger().addHandler(console_handler)
|
||||||
|
|
||||||
|
if log_file is not None:
|
||||||
|
# File handler
|
||||||
|
file_handler = logging.FileHandler(log_file)
|
||||||
|
file_handler.setFormatter(formatter)
|
||||||
|
logging.getLogger().addHandler(file_handler)
|
||||||
|
|
||||||
|
|
||||||
def format_big_number(num, precision=0):
|
def format_big_number(num, precision=0):
|
||||||
suffixes = ["", "K", "M", "B", "T", "Q"]
|
suffixes = ["", "K", "M", "B", "T", "Q"]
|
||||||
|
|
|
@ -22,3 +22,9 @@ env:
|
||||||
wrapper:
|
wrapper:
|
||||||
joint_masking_action_space: null
|
joint_masking_action_space: null
|
||||||
delta_action: null
|
delta_action: null
|
||||||
|
|
||||||
|
video_record:
|
||||||
|
enabled: false
|
||||||
|
record_dir: maniskill_videos
|
||||||
|
trajectory_name: trajectory
|
||||||
|
fps: ${fps}
|
||||||
|
|
|
@ -28,4 +28,3 @@ env:
|
||||||
reward_classifier:
|
reward_classifier:
|
||||||
pretrained_path: outputs/classifier/13-02-random-sample-resnet10-frozen/checkpoints/best/pretrained_model
|
pretrained_path: outputs/classifier/13-02-random-sample-resnet10-frozen/checkpoints/best/pretrained_model
|
||||||
config_path: lerobot/configs/policy/hilserl_classifier.yaml
|
config_path: lerobot/configs/policy/hilserl_classifier.yaml
|
||||||
|
|
||||||
|
|
|
@ -8,14 +8,12 @@
|
||||||
# env.gym.obs_type=environment_state_agent_pos \
|
# env.gym.obs_type=environment_state_agent_pos \
|
||||||
|
|
||||||
seed: 1
|
seed: 1
|
||||||
# dataset_repo_id: null
|
|
||||||
dataset_repo_id: "AdilZtn/Maniskill-Pushcube-demonstration-medium"
|
dataset_repo_id: "AdilZtn/Maniskill-Pushcube-demonstration-medium"
|
||||||
|
|
||||||
training:
|
training:
|
||||||
# Offline training dataloader
|
# Offline training dataloader
|
||||||
num_workers: 4
|
num_workers: 4
|
||||||
|
|
||||||
# batch_size: 256
|
|
||||||
batch_size: 512
|
batch_size: 512
|
||||||
grad_clip_norm: 10.0
|
grad_clip_norm: 10.0
|
||||||
lr: 3e-4
|
lr: 3e-4
|
||||||
|
@ -113,4 +111,7 @@ policy:
|
||||||
actor_learner_config:
|
actor_learner_config:
|
||||||
learner_host: "127.0.0.1"
|
learner_host: "127.0.0.1"
|
||||||
learner_port: 50051
|
learner_port: 50051
|
||||||
policy_parameters_push_frequency: 15
|
policy_parameters_push_frequency: 1
|
||||||
|
concurrency:
|
||||||
|
actor: 'processes'
|
||||||
|
learner: 'processes'
|
||||||
|
|
|
@ -13,22 +13,19 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import io
|
|
||||||
import logging
|
import logging
|
||||||
import pickle
|
|
||||||
import queue
|
|
||||||
from statistics import mean, quantiles
|
from statistics import mean, quantiles
|
||||||
import signal
|
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
|
from lerobot.scripts.server.utils import setup_process_handlers
|
||||||
|
|
||||||
# from lerobot.scripts.eval import eval_policy
|
# from lerobot.scripts.eval import eval_policy
|
||||||
from threading import Thread
|
|
||||||
|
|
||||||
import grpc
|
import grpc
|
||||||
import hydra
|
import hydra
|
||||||
import torch
|
import torch
|
||||||
from omegaconf import DictConfig
|
from omegaconf import DictConfig
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
import time
|
||||||
|
|
||||||
# TODO: Remove the import of maniskill
|
# TODO: Remove the import of maniskill
|
||||||
# from lerobot.common.envs.factory import make_maniskill_env
|
# from lerobot.common.envs.factory import make_maniskill_env
|
||||||
|
@ -47,157 +44,184 @@ from lerobot.scripts.server.buffer import (
|
||||||
Transition,
|
Transition,
|
||||||
move_state_dict_to_device,
|
move_state_dict_to_device,
|
||||||
move_transition_to_device,
|
move_transition_to_device,
|
||||||
bytes_buffer_size,
|
python_object_to_bytes,
|
||||||
|
transitions_to_bytes,
|
||||||
|
bytes_to_state_dict,
|
||||||
|
)
|
||||||
|
from lerobot.scripts.server.network_utils import (
|
||||||
|
receive_bytes_in_chunks,
|
||||||
|
send_bytes_in_chunks,
|
||||||
)
|
)
|
||||||
from lerobot.scripts.server.gym_manipulator import get_classifier, make_robot_env
|
from lerobot.scripts.server.gym_manipulator import get_classifier, make_robot_env
|
||||||
from lerobot.scripts.server import learner_service
|
from lerobot.scripts.server import learner_service
|
||||||
|
|
||||||
from threading import Event
|
from torch.multiprocessing import Queue, Event
|
||||||
|
from queue import Empty
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
from lerobot.common.utils.utils import init_logging
|
||||||
|
|
||||||
parameters_queue = queue.Queue(maxsize=1)
|
from lerobot.scripts.server.utils import get_last_item_from_queue
|
||||||
message_queue = queue.Queue(maxsize=1_000_000)
|
|
||||||
|
|
||||||
ACTOR_SHUTDOWN_TIMEOUT = 30
|
ACTOR_SHUTDOWN_TIMEOUT = 30
|
||||||
|
|
||||||
|
|
||||||
class ActorInformation:
|
|
||||||
"""
|
|
||||||
This helper class is used to differentiate between two types of messages that are placed in the same queue during streaming:
|
|
||||||
|
|
||||||
- **Transition Data:** Contains experience tuples (observation, action, reward, next observation) collected during interaction.
|
|
||||||
- **Interaction Messages:** Encapsulates statistics related to the interaction process.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
transition (Optional): Transition data to be sent to the learner.
|
|
||||||
interaction_message (Optional): Interaction message providing additional statistics for logging.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, transition=None, interaction_message=None):
|
|
||||||
self.transition = transition
|
|
||||||
self.interaction_message = interaction_message
|
|
||||||
|
|
||||||
|
|
||||||
def receive_policy(
|
def receive_policy(
|
||||||
learner_client: hilserl_pb2_grpc.LearnerServiceStub,
|
cfg: DictConfig,
|
||||||
shutdown_event: Event,
|
parameters_queue: Queue,
|
||||||
parameters_queue: queue.Queue,
|
shutdown_event: any, # Event,
|
||||||
|
learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None,
|
||||||
|
grpc_channel: grpc.Channel | None = None,
|
||||||
):
|
):
|
||||||
logging.info("[ACTOR] Start receiving parameters from the Learner")
|
logging.info("[ACTOR] Start receiving parameters from the Learner")
|
||||||
bytes_buffer = io.BytesIO()
|
|
||||||
step = 0
|
if not use_threads(cfg):
|
||||||
|
# Setup process handlers to handle shutdown signal
|
||||||
|
# But use shutdown event from the main process
|
||||||
|
setup_process_handlers(False)
|
||||||
|
|
||||||
|
if grpc_channel is None or learner_client is None:
|
||||||
|
learner_client, grpc_channel = learner_service_client(
|
||||||
|
host=cfg.actor_learner_config.learner_host,
|
||||||
|
port=cfg.actor_learner_config.learner_port,
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
for model_update in learner_client.StreamParameters(hilserl_pb2.Empty()):
|
iterator = learner_client.StreamParameters(hilserl_pb2.Empty())
|
||||||
if shutdown_event.is_set():
|
receive_bytes_in_chunks(
|
||||||
logging.info("[ACTOR] Shutting down policy streaming receiver")
|
iterator,
|
||||||
return hilserl_pb2.Empty()
|
parameters_queue,
|
||||||
|
shutdown_event,
|
||||||
if model_update.transfer_state == hilserl_pb2.TransferState.TRANSFER_BEGIN:
|
log_prefix="[ACTOR] parameters",
|
||||||
bytes_buffer.seek(0)
|
)
|
||||||
bytes_buffer.truncate(0)
|
|
||||||
bytes_buffer.write(model_update.parameter_bytes)
|
|
||||||
logging.info("Received model update at step 0")
|
|
||||||
step = 0
|
|
||||||
continue
|
|
||||||
elif (
|
|
||||||
model_update.transfer_state == hilserl_pb2.TransferState.TRANSFER_MIDDLE
|
|
||||||
):
|
|
||||||
bytes_buffer.write(model_update.parameter_bytes)
|
|
||||||
step += 1
|
|
||||||
logging.info(f"Received model update at step {step}")
|
|
||||||
elif model_update.transfer_state == hilserl_pb2.TransferState.TRANSFER_END:
|
|
||||||
bytes_buffer.write(model_update.parameter_bytes)
|
|
||||||
logging.info(
|
|
||||||
f"Received model update at step end size {bytes_buffer_size(bytes_buffer)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
state_dict = torch.load(bytes_buffer)
|
|
||||||
|
|
||||||
bytes_buffer.seek(0)
|
|
||||||
bytes_buffer.truncate(0)
|
|
||||||
step = 0
|
|
||||||
|
|
||||||
logging.info("Model updated")
|
|
||||||
|
|
||||||
parameters_queue.put(state_dict)
|
|
||||||
|
|
||||||
except grpc.RpcError as e:
|
except grpc.RpcError as e:
|
||||||
logging.error(f"[ACTOR] gRPC error: {e}")
|
logging.error(f"[ACTOR] gRPC error: {e}")
|
||||||
|
|
||||||
|
if not use_threads(cfg):
|
||||||
|
grpc_channel.close()
|
||||||
|
logging.info("[ACTOR] Received policy loop stopped")
|
||||||
|
|
||||||
|
|
||||||
|
def transitions_stream(
|
||||||
|
shutdown_event: Event, transitions_queue: Queue
|
||||||
|
) -> hilserl_pb2.Empty:
|
||||||
|
while not shutdown_event.is_set():
|
||||||
|
try:
|
||||||
|
message = transitions_queue.get(block=True, timeout=5)
|
||||||
|
except Empty:
|
||||||
|
logging.debug("[ACTOR] Transition queue is empty")
|
||||||
|
continue
|
||||||
|
|
||||||
|
yield from send_bytes_in_chunks(
|
||||||
|
message, hilserl_pb2.Transition, log_prefix="[ACTOR] Send transitions"
|
||||||
|
)
|
||||||
|
|
||||||
return hilserl_pb2.Empty()
|
return hilserl_pb2.Empty()
|
||||||
|
|
||||||
|
|
||||||
def transitions_stream(shutdown_event: Event, message_queue: queue.Queue):
|
def interactions_stream(
|
||||||
|
shutdown_event: any, # Event,
|
||||||
|
interactions_queue: Queue,
|
||||||
|
) -> hilserl_pb2.Empty:
|
||||||
while not shutdown_event.is_set():
|
while not shutdown_event.is_set():
|
||||||
try:
|
try:
|
||||||
message = message_queue.get(block=True, timeout=5)
|
message = interactions_queue.get(block=True, timeout=5)
|
||||||
except queue.Empty:
|
except Empty:
|
||||||
logging.debug("[ACTOR] Transition queue is empty")
|
logging.debug("[ACTOR] Interaction queue is empty")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if message.transition is not None:
|
yield from send_bytes_in_chunks(
|
||||||
transition_to_send_to_learner: list[Transition] = [
|
message,
|
||||||
move_transition_to_device(transition=T, device="cpu")
|
hilserl_pb2.InteractionMessage,
|
||||||
for T in message.transition
|
log_prefix="[ACTOR] Send interactions",
|
||||||
]
|
)
|
||||||
# Check for NaNs in transitions before sending to learner
|
|
||||||
for transition in transition_to_send_to_learner:
|
|
||||||
for key, value in transition["state"].items():
|
|
||||||
if torch.isnan(value).any():
|
|
||||||
logging.warning(f"Found NaN values in transition {key}")
|
|
||||||
buf = io.BytesIO()
|
|
||||||
torch.save(transition_to_send_to_learner, buf)
|
|
||||||
transition_bytes = buf.getvalue()
|
|
||||||
|
|
||||||
transition_message = hilserl_pb2.Transition(
|
|
||||||
transition_bytes=transition_bytes
|
|
||||||
)
|
|
||||||
|
|
||||||
response = hilserl_pb2.ActorInformation(transition=transition_message)
|
|
||||||
|
|
||||||
elif message.interaction_message is not None:
|
|
||||||
content = hilserl_pb2.InteractionMessage(
|
|
||||||
interaction_message_bytes=pickle.dumps(message.interaction_message)
|
|
||||||
)
|
|
||||||
response = hilserl_pb2.ActorInformation(interaction_message=content)
|
|
||||||
|
|
||||||
yield response
|
|
||||||
|
|
||||||
return hilserl_pb2.Empty()
|
return hilserl_pb2.Empty()
|
||||||
|
|
||||||
|
|
||||||
def send_transitions(
|
def send_transitions(
|
||||||
learner_client: hilserl_pb2_grpc.LearnerServiceStub,
|
cfg: DictConfig,
|
||||||
shutdown_event: Event,
|
transitions_queue: Queue,
|
||||||
message_queue: queue.Queue,
|
shutdown_event: any, # Event,
|
||||||
):
|
learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None,
|
||||||
|
grpc_channel: grpc.Channel | None = None,
|
||||||
|
) -> hilserl_pb2.Empty:
|
||||||
"""
|
"""
|
||||||
Streams data from the actor to the learner.
|
Sends transitions to the learner.
|
||||||
|
|
||||||
This function continuously retrieves messages from the queue and processes them based on their type:
|
This function continuously retrieves messages from the queue and forwards them to the learner:
|
||||||
|
|
||||||
- **Transition Data:**
|
- **Transition Data:**
|
||||||
- A batch of transitions (observation, action, reward, next observation) is collected.
|
- A batch of transitions (observation, action, reward, next observation) is collected.
|
||||||
- Transitions are moved to the CPU and serialized using PyTorch.
|
- Transitions are moved to the CPU and serialized using PyTorch.
|
||||||
- The serialized data is wrapped in a `hilserl_pb2.Transition` message and sent to the learner.
|
- The serialized data is wrapped in a `hilserl_pb2.Transition` message and sent to the learner.
|
||||||
|
|
||||||
- **Interaction Messages:**
|
|
||||||
- Contains useful statistics about episodic rewards and policy timings.
|
|
||||||
- The message is serialized using `pickle` and sent to the learner.
|
|
||||||
|
|
||||||
Yields:
|
|
||||||
hilserl_pb2.ActorInformation: The response message containing either transition data or an interaction message.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if not use_threads(cfg):
|
||||||
|
# Setup process handlers to handle shutdown signal
|
||||||
|
# But use shutdown event from the main process
|
||||||
|
setup_process_handlers(False)
|
||||||
|
|
||||||
|
if grpc_channel is None or learner_client is None:
|
||||||
|
learner_client, grpc_channel = learner_service_client(
|
||||||
|
host=cfg.actor_learner_config.learner_host,
|
||||||
|
port=cfg.actor_learner_config.learner_port,
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
learner_client.ReceiveTransitions(
|
learner_client.SendTransitions(
|
||||||
transitions_stream(shutdown_event, message_queue)
|
transitions_stream(shutdown_event, transitions_queue)
|
||||||
)
|
)
|
||||||
except grpc.RpcError as e:
|
except grpc.RpcError as e:
|
||||||
logging.error(f"[ACTOR] gRPC error: {e}")
|
logging.error(f"[ACTOR] gRPC error: {e}")
|
||||||
|
|
||||||
logging.info("[ACTOR] Finished streaming transitions")
|
logging.info("[ACTOR] Finished streaming transitions")
|
||||||
|
|
||||||
|
if not use_threads(cfg):
|
||||||
|
grpc_channel.close()
|
||||||
|
logging.info("[ACTOR] Transitions process stopped")
|
||||||
|
|
||||||
|
|
||||||
|
def send_interactions(
|
||||||
|
cfg: DictConfig,
|
||||||
|
interactions_queue: Queue,
|
||||||
|
shutdown_event: any, # Event,
|
||||||
|
learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None,
|
||||||
|
grpc_channel: grpc.Channel | None = None,
|
||||||
|
) -> hilserl_pb2.Empty:
|
||||||
|
"""
|
||||||
|
Sends interactions to the learner.
|
||||||
|
|
||||||
|
This function continuously retrieves messages from the queue and forwards them to the learner:
|
||||||
|
|
||||||
|
- **Interaction Messages:**
|
||||||
|
- Contains useful statistics about episodic rewards and policy timings.
|
||||||
|
- The message is serialized using `pickle` and sent to the learner.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not use_threads(cfg):
|
||||||
|
# Setup process handlers to handle shutdown signal
|
||||||
|
# But use shutdown event from the main process
|
||||||
|
setup_process_handlers(False)
|
||||||
|
|
||||||
|
if grpc_channel is None or learner_client is None:
|
||||||
|
learner_client, grpc_channel = learner_service_client(
|
||||||
|
host=cfg.actor_learner_config.learner_host,
|
||||||
|
port=cfg.actor_learner_config.learner_port,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
learner_client.SendInteractions(
|
||||||
|
interactions_stream(shutdown_event, interactions_queue)
|
||||||
|
)
|
||||||
|
except grpc.RpcError as e:
|
||||||
|
logging.error(f"[ACTOR] gRPC error: {e}")
|
||||||
|
|
||||||
|
logging.info("[ACTOR] Finished streaming interactions")
|
||||||
|
|
||||||
|
if not use_threads(cfg):
|
||||||
|
grpc_channel.close()
|
||||||
|
logging.info("[ACTOR] Interactions process stopped")
|
||||||
|
|
||||||
|
|
||||||
@lru_cache(maxsize=1)
|
@lru_cache(maxsize=1)
|
||||||
def learner_service_client(
|
def learner_service_client(
|
||||||
|
@ -217,7 +241,7 @@ def learner_service_client(
|
||||||
{
|
{
|
||||||
"name": [{}], # Applies to ALL methods in ALL services
|
"name": [{}], # Applies to ALL methods in ALL services
|
||||||
"retryPolicy": {
|
"retryPolicy": {
|
||||||
"maxAttempts": 7, # Max retries (total attempts = 5)
|
"maxAttempts": 5, # Max retries (total attempts = 5)
|
||||||
"initialBackoff": "0.1s", # First retry after 0.1s
|
"initialBackoff": "0.1s", # First retry after 0.1s
|
||||||
"maxBackoff": "2s", # Max wait time between retries
|
"maxBackoff": "2s", # Max wait time between retries
|
||||||
"backoffMultiplier": 2, # Exponential backoff factor
|
"backoffMultiplier": 2, # Exponential backoff factor
|
||||||
|
@ -242,20 +266,27 @@ def learner_service_client(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
stub = hilserl_pb2_grpc.LearnerServiceStub(channel)
|
stub = hilserl_pb2_grpc.LearnerServiceStub(channel)
|
||||||
logging.info("[LEARNER] Learner service client created")
|
logging.info("[ACTOR] Learner service client created")
|
||||||
return stub, channel
|
return stub, channel
|
||||||
|
|
||||||
|
|
||||||
def update_policy_parameters(policy: SACPolicy, parameters_queue: queue.Queue, device):
|
def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device):
|
||||||
if not parameters_queue.empty():
|
if not parameters_queue.empty():
|
||||||
logging.info("[ACTOR] Load new parameters from Learner.")
|
logging.info("[ACTOR] Load new parameters from Learner.")
|
||||||
state_dict = parameters_queue.get()
|
bytes_state_dict = get_last_item_from_queue(parameters_queue)
|
||||||
|
state_dict = bytes_to_state_dict(bytes_state_dict)
|
||||||
state_dict = move_state_dict_to_device(state_dict, device=device)
|
state_dict = move_state_dict_to_device(state_dict, device=device)
|
||||||
policy.load_state_dict(state_dict)
|
policy.load_state_dict(state_dict)
|
||||||
|
|
||||||
|
|
||||||
def act_with_policy(
|
def act_with_policy(
|
||||||
cfg: DictConfig, robot: Robot, reward_classifier: nn.Module, shutdown_event: Event
|
cfg: DictConfig,
|
||||||
|
robot: Robot,
|
||||||
|
reward_classifier: nn.Module,
|
||||||
|
shutdown_event: any, # Event,
|
||||||
|
parameters_queue: Queue,
|
||||||
|
transitions_queue: Queue,
|
||||||
|
interactions_queue: Queue,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Executes policy interaction within the environment.
|
Executes policy interaction within the environment.
|
||||||
|
@ -317,7 +348,7 @@ def act_with_policy(
|
||||||
|
|
||||||
for interaction_step in range(cfg.training.online_steps):
|
for interaction_step in range(cfg.training.online_steps):
|
||||||
if shutdown_event.is_set():
|
if shutdown_event.is_set():
|
||||||
logging.info("[ACTOR] Shutdown signal received. Exiting...")
|
logging.info("[ACTOR] Shutting down act_with_policy")
|
||||||
return
|
return
|
||||||
|
|
||||||
if interaction_step >= cfg.training.online_step_before_learning:
|
if interaction_step >= cfg.training.online_step_before_learning:
|
||||||
|
@ -394,10 +425,9 @@ def act_with_policy(
|
||||||
)
|
)
|
||||||
|
|
||||||
if len(list_transition_to_send_to_learner) > 0:
|
if len(list_transition_to_send_to_learner) > 0:
|
||||||
send_transitions_in_chunks(
|
push_transitions_to_transport_queue(
|
||||||
transitions=list_transition_to_send_to_learner,
|
transitions=list_transition_to_send_to_learner,
|
||||||
message_queue=message_queue,
|
transitions_queue=transitions_queue,
|
||||||
chunk_size=4,
|
|
||||||
)
|
)
|
||||||
list_transition_to_send_to_learner = []
|
list_transition_to_send_to_learner = []
|
||||||
|
|
||||||
|
@ -405,9 +435,9 @@ def act_with_policy(
|
||||||
list_policy_time.clear()
|
list_policy_time.clear()
|
||||||
|
|
||||||
# Send episodic reward to the learner
|
# Send episodic reward to the learner
|
||||||
message_queue.put(
|
interactions_queue.put(
|
||||||
ActorInformation(
|
python_object_to_bytes(
|
||||||
interaction_message={
|
{
|
||||||
"Episodic reward": sum_reward_episode,
|
"Episodic reward": sum_reward_episode,
|
||||||
"Interaction step": interaction_step,
|
"Interaction step": interaction_step,
|
||||||
"Episode intervention": int(episode_intervention),
|
"Episode intervention": int(episode_intervention),
|
||||||
|
@ -420,7 +450,7 @@ def act_with_policy(
|
||||||
obs, info = online_env.reset()
|
obs, info = online_env.reset()
|
||||||
|
|
||||||
|
|
||||||
def send_transitions_in_chunks(transitions: list, message_queue, chunk_size: int = 100):
|
def push_transitions_to_transport_queue(transitions: list, transitions_queue):
|
||||||
"""Send transitions to learner in smaller chunks to avoid network issues.
|
"""Send transitions to learner in smaller chunks to avoid network issues.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -428,10 +458,16 @@ def send_transitions_in_chunks(transitions: list, message_queue, chunk_size: int
|
||||||
message_queue: Queue to send messages to learner
|
message_queue: Queue to send messages to learner
|
||||||
chunk_size: Size of each chunk to send
|
chunk_size: Size of each chunk to send
|
||||||
"""
|
"""
|
||||||
for i in range(0, len(transitions), chunk_size):
|
transition_to_send_to_learner = []
|
||||||
chunk = transitions[i : i + chunk_size]
|
for transition in transitions:
|
||||||
logging.debug(f"[ACTOR] Sending chunk of {len(chunk)} transitions to Learner.")
|
tr = move_transition_to_device(transition=transition, device="cpu")
|
||||||
message_queue.put(ActorInformation(transition=chunk))
|
for key, value in tr["state"].items():
|
||||||
|
if torch.isnan(value).any():
|
||||||
|
logging.warning(f"Found NaN values in transition {key}")
|
||||||
|
|
||||||
|
transition_to_send_to_learner.append(tr)
|
||||||
|
|
||||||
|
transitions_queue.put(transitions_to_bytes(transition_to_send_to_learner))
|
||||||
|
|
||||||
|
|
||||||
def get_frequency_stats(list_policy_time: list[float]) -> dict[str, float]:
|
def get_frequency_stats(list_policy_time: list[float]) -> dict[str, float]:
|
||||||
|
@ -458,39 +494,96 @@ def log_policy_frequency_issue(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def establish_learner_connection(
|
||||||
|
stub,
|
||||||
|
shutdown_event: any, # Event,
|
||||||
|
attempts=30,
|
||||||
|
):
|
||||||
|
for _ in range(attempts):
|
||||||
|
if shutdown_event.is_set():
|
||||||
|
logging.info("[ACTOR] Shutting down establish_learner_connection")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Force a connection attempt and check state
|
||||||
|
try:
|
||||||
|
logging.info("[ACTOR] Send ready message to Learner")
|
||||||
|
if stub.Ready(hilserl_pb2.Empty()) == hilserl_pb2.Empty():
|
||||||
|
return True
|
||||||
|
except grpc.RpcError as e:
|
||||||
|
logging.error(f"[ACTOR] Waiting for Learner to be ready... {e}")
|
||||||
|
time.sleep(2)
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def use_threads(cfg: DictConfig) -> bool:
|
||||||
|
return cfg.actor_learner_config.concurrency.actor == "threads"
|
||||||
|
|
||||||
|
|
||||||
@hydra.main(version_base="1.2", config_name="default", config_path="../../configs")
|
@hydra.main(version_base="1.2", config_name="default", config_path="../../configs")
|
||||||
def actor_cli(cfg: dict):
|
def actor_cli(cfg: dict):
|
||||||
|
if not use_threads(cfg):
|
||||||
|
import torch.multiprocessing as mp
|
||||||
|
|
||||||
|
mp.set_start_method("spawn")
|
||||||
|
|
||||||
|
init_logging(log_file="actor.log")
|
||||||
robot = make_robot(cfg=cfg.robot)
|
robot = make_robot(cfg=cfg.robot)
|
||||||
|
|
||||||
shutdown_event = Event()
|
shutdown_event = setup_process_handlers(use_threads(cfg))
|
||||||
|
|
||||||
# Define signal handler
|
|
||||||
def signal_handler(signum, frame):
|
|
||||||
logging.info("Shutdown signal received. Cleaning up...")
|
|
||||||
shutdown_event.set()
|
|
||||||
|
|
||||||
signal.signal(signal.SIGINT, signal_handler) # Ctrl+C
|
|
||||||
signal.signal(signal.SIGTERM, signal_handler) # Termination request (kill)
|
|
||||||
signal.signal(signal.SIGHUP, signal_handler) # Terminal closed/Hangup
|
|
||||||
signal.signal(signal.SIGQUIT, signal_handler) # Ctrl+\
|
|
||||||
|
|
||||||
learner_client, grpc_channel = learner_service_client(
|
learner_client, grpc_channel = learner_service_client(
|
||||||
host=cfg.actor_learner_config.learner_host,
|
host=cfg.actor_learner_config.learner_host,
|
||||||
port=cfg.actor_learner_config.learner_port,
|
port=cfg.actor_learner_config.learner_port,
|
||||||
)
|
)
|
||||||
|
|
||||||
receive_policy_thread = Thread(
|
logging.info("[ACTOR] Establishing connection with Learner")
|
||||||
|
if not establish_learner_connection(learner_client, shutdown_event):
|
||||||
|
logging.error("[ACTOR] Failed to establish connection with Learner")
|
||||||
|
return
|
||||||
|
|
||||||
|
if not use_threads(cfg):
|
||||||
|
# With threads the channel can be reused; with processes each worker opens its own channel, so close this one
|
||||||
|
grpc_channel.close()
|
||||||
|
grpc_channel = None
|
||||||
|
|
||||||
|
logging.info("[ACTOR] Connection with Learner established")
|
||||||
|
|
||||||
|
parameters_queue = Queue()
|
||||||
|
transitions_queue = Queue()
|
||||||
|
interactions_queue = Queue()
|
||||||
|
|
||||||
|
concurrency_entity = None
|
||||||
|
if use_threads(cfg):
|
||||||
|
from threading import Thread
|
||||||
|
|
||||||
|
concurrency_entity = Thread
|
||||||
|
else:
|
||||||
|
from multiprocessing import Process
|
||||||
|
|
||||||
|
concurrency_entity = Process
|
||||||
|
|
||||||
|
receive_policy_process = concurrency_entity(
|
||||||
target=receive_policy,
|
target=receive_policy,
|
||||||
args=(learner_client, shutdown_event, parameters_queue),
|
args=(cfg, parameters_queue, shutdown_event, grpc_channel),
|
||||||
daemon=True,
|
daemon=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
transitions_thread = Thread(
|
transitions_process = concurrency_entity(
|
||||||
target=send_transitions,
|
target=send_transitions,
|
||||||
args=(learner_client, shutdown_event, message_queue),
|
args=(cfg, transitions_queue, shutdown_event, grpc_channel),
|
||||||
daemon=True,
|
daemon=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
interactions_process = concurrency_entity(
|
||||||
|
target=send_interactions,
|
||||||
|
args=(cfg, interactions_queue, shutdown_event, grpc_channel),
|
||||||
|
daemon=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
transitions_process.start()
|
||||||
|
interactions_process.start()
|
||||||
|
receive_policy_process.start()
|
||||||
|
|
||||||
# HACK: FOR MANISKILL we do not have a reward classifier
|
# HACK: FOR MANISKILL we do not have a reward classifier
|
||||||
# TODO: Remove this once we merge into main
|
# TODO: Remove this once we merge into main
|
||||||
reward_classifier = None
|
reward_classifier = None
|
||||||
|
@ -503,26 +596,35 @@ def actor_cli(cfg: dict):
|
||||||
config_path=cfg.env.reward_classifier.config_path,
|
config_path=cfg.env.reward_classifier.config_path,
|
||||||
)
|
)
|
||||||
|
|
||||||
policy_thread = Thread(
|
act_with_policy(
|
||||||
target=act_with_policy,
|
cfg,
|
||||||
daemon=True,
|
robot,
|
||||||
args=(cfg, robot, reward_classifier, shutdown_event),
|
reward_classifier,
|
||||||
|
shutdown_event,
|
||||||
|
parameters_queue,
|
||||||
|
transitions_queue,
|
||||||
|
interactions_queue,
|
||||||
)
|
)
|
||||||
|
logging.info("[ACTOR] Policy process joined")
|
||||||
|
|
||||||
transitions_thread.start()
|
logging.info("[ACTOR] Closing queues")
|
||||||
policy_thread.start()
|
transitions_queue.close()
|
||||||
receive_policy_thread.start()
|
interactions_queue.close()
|
||||||
|
parameters_queue.close()
|
||||||
|
|
||||||
shutdown_event.wait()
|
transitions_process.join()
|
||||||
logging.info("[ACTOR] Shutdown event received")
|
logging.info("[ACTOR] Transitions process joined")
|
||||||
grpc_channel.close()
|
interactions_process.join()
|
||||||
|
logging.info("[ACTOR] Interactions process joined")
|
||||||
|
receive_policy_process.join()
|
||||||
|
logging.info("[ACTOR] Receive policy process joined")
|
||||||
|
|
||||||
policy_thread.join()
|
logging.info("[ACTOR] join queues")
|
||||||
logging.info("[ACTOR] Policy thread joined")
|
transitions_queue.cancel_join_thread()
|
||||||
transitions_thread.join()
|
interactions_queue.cancel_join_thread()
|
||||||
logging.info("[ACTOR] Transitions thread joined")
|
parameters_queue.cancel_join_thread()
|
||||||
receive_policy_thread.join()
|
|
||||||
logging.info("[ACTOR] Receive policy thread joined")
|
logging.info("[ACTOR] queues closed")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
@ -23,6 +23,7 @@ from tqdm import tqdm
|
||||||
|
|
||||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||||
import os
|
import os
|
||||||
|
import pickle
|
||||||
|
|
||||||
|
|
||||||
class Transition(TypedDict):
|
class Transition(TypedDict):
|
||||||
|
@ -91,7 +92,7 @@ def move_transition_to_device(
|
||||||
return transition
|
return transition
|
||||||
|
|
||||||
|
|
||||||
def move_state_dict_to_device(state_dict, device):
|
def move_state_dict_to_device(state_dict, device="cpu"):
|
||||||
"""
|
"""
|
||||||
Recursively move all tensors in a (potentially) nested
|
Recursively move all tensors in a (potentially) nested
|
||||||
dict/list/tuple structure to the CPU.
|
dict/list/tuple structure to the CPU.
|
||||||
|
@ -111,20 +112,41 @@ def move_state_dict_to_device(state_dict, device):
|
||||||
return state_dict
|
return state_dict
|
||||||
|
|
||||||
|
|
||||||
def state_to_bytes(state_dict: dict[str, torch.Tensor]) -> io.BytesIO:
|
def state_to_bytes(state_dict: dict[str, torch.Tensor]) -> bytes:
|
||||||
"""Convert model state dict to flat array for transmission"""
|
"""Convert model state dict to flat array for transmission"""
|
||||||
buffer = io.BytesIO()
|
buffer = io.BytesIO()
|
||||||
|
|
||||||
torch.save(state_dict, buffer)
|
torch.save(state_dict, buffer)
|
||||||
|
|
||||||
return buffer
|
return buffer.getvalue()
|
||||||
|
|
||||||
|
|
||||||
def bytes_buffer_size(buffer: io.BytesIO) -> int:
|
def bytes_to_state_dict(buffer: bytes) -> dict[str, torch.Tensor]:
|
||||||
buffer.seek(0, io.SEEK_END)
|
buffer = io.BytesIO(buffer)
|
||||||
result = buffer.tell()
|
|
||||||
buffer.seek(0)
|
buffer.seek(0)
|
||||||
return result
|
return torch.load(buffer)
|
||||||
|
|
||||||
|
|
||||||
|
def python_object_to_bytes(python_object: Any) -> bytes:
|
||||||
|
return pickle.dumps(python_object)
|
||||||
|
|
||||||
|
|
||||||
|
def bytes_to_python_object(buffer: bytes) -> Any:
|
||||||
|
buffer = io.BytesIO(buffer)
|
||||||
|
buffer.seek(0)
|
||||||
|
return pickle.load(buffer)
|
||||||
|
|
||||||
|
|
||||||
|
def bytes_to_transitions(buffer: bytes) -> list[Transition]:
|
||||||
|
buffer = io.BytesIO(buffer)
|
||||||
|
buffer.seek(0)
|
||||||
|
return torch.load(buffer)
|
||||||
|
|
||||||
|
|
||||||
|
def transitions_to_bytes(transitions: list[Transition]) -> bytes:
|
||||||
|
buffer = io.BytesIO()
|
||||||
|
torch.save(transitions, buffer)
|
||||||
|
return buffer.getvalue()
|
||||||
|
|
||||||
|
|
||||||
def random_crop_vectorized(images: torch.Tensor, output_size: tuple) -> torch.Tensor:
|
def random_crop_vectorized(images: torch.Tensor, output_size: tuple) -> torch.Tensor:
|
||||||
|
|
|
@ -24,14 +24,9 @@ service LearnerService {
|
||||||
// Actor -> Learner to store transitions
|
// Actor -> Learner to store transitions
|
||||||
rpc SendInteractionMessage(InteractionMessage) returns (Empty);
|
rpc SendInteractionMessage(InteractionMessage) returns (Empty);
|
||||||
rpc StreamParameters(Empty) returns (stream Parameters);
|
rpc StreamParameters(Empty) returns (stream Parameters);
|
||||||
rpc ReceiveTransitions(stream ActorInformation) returns (Empty);
|
rpc SendTransitions(stream Transition) returns (Empty);
|
||||||
}
|
rpc SendInteractions(stream InteractionMessage) returns (Empty);
|
||||||
|
rpc Ready(Empty) returns (Empty);
|
||||||
message ActorInformation {
|
|
||||||
oneof data {
|
|
||||||
Transition transition = 1;
|
|
||||||
InteractionMessage interaction_message = 2;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
enum TransferState {
|
enum TransferState {
|
||||||
|
@ -43,16 +38,18 @@ enum TransferState {
|
||||||
|
|
||||||
// Messages
|
// Messages
|
||||||
message Transition {
|
message Transition {
|
||||||
bytes transition_bytes = 1;
|
TransferState transfer_state = 1;
|
||||||
|
bytes data = 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
message Parameters {
|
message Parameters {
|
||||||
TransferState transfer_state = 1;
|
TransferState transfer_state = 1;
|
||||||
bytes parameter_bytes = 2;
|
bytes data = 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
message InteractionMessage {
|
message InteractionMessage {
|
||||||
bytes interaction_message_bytes = 1;
|
TransferState transfer_state = 1;
|
||||||
|
bytes data = 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
message Empty {}
|
message Empty {}
|
||||||
|
|
|
@ -24,25 +24,23 @@ _sym_db = _symbol_database.Default()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rhilserl.proto\x12\x08hil_serl\"\x83\x01\n\x10\x41\x63torInformation\x12*\n\ntransition\x18\x01 \x01(\x0b\x32\x14.hil_serl.TransitionH\x00\x12;\n\x13interaction_message\x18\x02 \x01(\x0b\x32\x1c.hil_serl.InteractionMessageH\x00\x42\x06\n\x04\x64\x61ta\"&\n\nTransition\x12\x18\n\x10transition_bytes\x18\x01 \x01(\x0c\"V\n\nParameters\x12/\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x17.hil_serl.TransferState\x12\x17\n\x0fparameter_bytes\x18\x02 \x01(\x0c\"7\n\x12InteractionMessage\x12!\n\x19interaction_message_bytes\x18\x01 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\xdb\x01\n\x0eLearnerService\x12G\n\x16SendInteractionMessage\x12\x1c.hil_serl.InteractionMessage\x1a\x0f.hil_serl.Empty\x12;\n\x10StreamParameters\x12\x0f.hil_serl.Empty\x1a\x14.hil_serl.Parameters0\x01\x12\x43\n\x12ReceiveTransitions\x12\x1a.hil_serl.ActorInformation\x1a\x0f.hil_serl.Empty(\x01\x62\x06proto3')
|
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rhilserl.proto\x12\x08hil_serl\"K\n\nTransition\x12/\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x17.hil_serl.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"K\n\nParameters\x12/\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x17.hil_serl.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"S\n\x12InteractionMessage\x12/\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x17.hil_serl.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\xc2\x02\n\x0eLearnerService\x12G\n\x16SendInteractionMessage\x12\x1c.hil_serl.InteractionMessage\x1a\x0f.hil_serl.Empty\x12;\n\x10StreamParameters\x12\x0f.hil_serl.Empty\x1a\x14.hil_serl.Parameters0\x01\x12:\n\x0fSendTransitions\x12\x14.hil_serl.Transition\x1a\x0f.hil_serl.Empty(\x01\x12\x43\n\x10SendInteractions\x12\x1c.hil_serl.InteractionMessage\x1a\x0f.hil_serl.Empty(\x01\x12)\n\x05Ready\x12\x0f.hil_serl.Empty\x1a\x0f.hil_serl.Emptyb\x06proto3')
|
||||||
|
|
||||||
_globals = globals()
|
_globals = globals()
|
||||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
||||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'hilserl_pb2', _globals)
|
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'hilserl_pb2', _globals)
|
||||||
if not _descriptor._USE_C_DESCRIPTORS:
|
if not _descriptor._USE_C_DESCRIPTORS:
|
||||||
DESCRIPTOR._loaded_options = None
|
DESCRIPTOR._loaded_options = None
|
||||||
_globals['_TRANSFERSTATE']._serialized_start=355
|
_globals['_TRANSFERSTATE']._serialized_start=275
|
||||||
_globals['_TRANSFERSTATE']._serialized_end=451
|
_globals['_TRANSFERSTATE']._serialized_end=371
|
||||||
_globals['_ACTORINFORMATION']._serialized_start=28
|
_globals['_TRANSITION']._serialized_start=27
|
||||||
_globals['_ACTORINFORMATION']._serialized_end=159
|
_globals['_TRANSITION']._serialized_end=102
|
||||||
_globals['_TRANSITION']._serialized_start=161
|
_globals['_PARAMETERS']._serialized_start=104
|
||||||
_globals['_TRANSITION']._serialized_end=199
|
_globals['_PARAMETERS']._serialized_end=179
|
||||||
_globals['_PARAMETERS']._serialized_start=201
|
_globals['_INTERACTIONMESSAGE']._serialized_start=181
|
||||||
_globals['_PARAMETERS']._serialized_end=287
|
_globals['_INTERACTIONMESSAGE']._serialized_end=264
|
||||||
_globals['_INTERACTIONMESSAGE']._serialized_start=289
|
_globals['_EMPTY']._serialized_start=266
|
||||||
_globals['_INTERACTIONMESSAGE']._serialized_end=344
|
_globals['_EMPTY']._serialized_end=273
|
||||||
_globals['_EMPTY']._serialized_start=346
|
_globals['_LEARNERSERVICE']._serialized_start=374
|
||||||
_globals['_EMPTY']._serialized_end=353
|
_globals['_LEARNERSERVICE']._serialized_end=696
|
||||||
_globals['_LEARNERSERVICE']._serialized_start=454
|
|
||||||
_globals['_LEARNERSERVICE']._serialized_end=673
|
|
||||||
# @@protoc_insertion_point(module_scope)
|
# @@protoc_insertion_point(module_scope)
|
||||||
|
|
|
@ -46,9 +46,19 @@ class LearnerServiceStub(object):
|
||||||
request_serializer=hilserl__pb2.Empty.SerializeToString,
|
request_serializer=hilserl__pb2.Empty.SerializeToString,
|
||||||
response_deserializer=hilserl__pb2.Parameters.FromString,
|
response_deserializer=hilserl__pb2.Parameters.FromString,
|
||||||
_registered_method=True)
|
_registered_method=True)
|
||||||
self.ReceiveTransitions = channel.stream_unary(
|
self.SendTransitions = channel.stream_unary(
|
||||||
'/hil_serl.LearnerService/ReceiveTransitions',
|
'/hil_serl.LearnerService/SendTransitions',
|
||||||
request_serializer=hilserl__pb2.ActorInformation.SerializeToString,
|
request_serializer=hilserl__pb2.Transition.SerializeToString,
|
||||||
|
response_deserializer=hilserl__pb2.Empty.FromString,
|
||||||
|
_registered_method=True)
|
||||||
|
self.SendInteractions = channel.stream_unary(
|
||||||
|
'/hil_serl.LearnerService/SendInteractions',
|
||||||
|
request_serializer=hilserl__pb2.InteractionMessage.SerializeToString,
|
||||||
|
response_deserializer=hilserl__pb2.Empty.FromString,
|
||||||
|
_registered_method=True)
|
||||||
|
self.Ready = channel.unary_unary(
|
||||||
|
'/hil_serl.LearnerService/Ready',
|
||||||
|
request_serializer=hilserl__pb2.Empty.SerializeToString,
|
||||||
response_deserializer=hilserl__pb2.Empty.FromString,
|
response_deserializer=hilserl__pb2.Empty.FromString,
|
||||||
_registered_method=True)
|
_registered_method=True)
|
||||||
|
|
||||||
|
@ -71,7 +81,19 @@ class LearnerServiceServicer(object):
|
||||||
context.set_details('Method not implemented!')
|
context.set_details('Method not implemented!')
|
||||||
raise NotImplementedError('Method not implemented!')
|
raise NotImplementedError('Method not implemented!')
|
||||||
|
|
||||||
def ReceiveTransitions(self, request_iterator, context):
|
def SendTransitions(self, request_iterator, context):
|
||||||
|
"""Missing associated documentation comment in .proto file."""
|
||||||
|
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||||
|
context.set_details('Method not implemented!')
|
||||||
|
raise NotImplementedError('Method not implemented!')
|
||||||
|
|
||||||
|
def SendInteractions(self, request_iterator, context):
|
||||||
|
"""Missing associated documentation comment in .proto file."""
|
||||||
|
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||||
|
context.set_details('Method not implemented!')
|
||||||
|
raise NotImplementedError('Method not implemented!')
|
||||||
|
|
||||||
|
def Ready(self, request, context):
|
||||||
"""Missing associated documentation comment in .proto file."""
|
"""Missing associated documentation comment in .proto file."""
|
||||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||||
context.set_details('Method not implemented!')
|
context.set_details('Method not implemented!')
|
||||||
|
@ -90,9 +112,19 @@ def add_LearnerServiceServicer_to_server(servicer, server):
|
||||||
request_deserializer=hilserl__pb2.Empty.FromString,
|
request_deserializer=hilserl__pb2.Empty.FromString,
|
||||||
response_serializer=hilserl__pb2.Parameters.SerializeToString,
|
response_serializer=hilserl__pb2.Parameters.SerializeToString,
|
||||||
),
|
),
|
||||||
'ReceiveTransitions': grpc.stream_unary_rpc_method_handler(
|
'SendTransitions': grpc.stream_unary_rpc_method_handler(
|
||||||
servicer.ReceiveTransitions,
|
servicer.SendTransitions,
|
||||||
request_deserializer=hilserl__pb2.ActorInformation.FromString,
|
request_deserializer=hilserl__pb2.Transition.FromString,
|
||||||
|
response_serializer=hilserl__pb2.Empty.SerializeToString,
|
||||||
|
),
|
||||||
|
'SendInteractions': grpc.stream_unary_rpc_method_handler(
|
||||||
|
servicer.SendInteractions,
|
||||||
|
request_deserializer=hilserl__pb2.InteractionMessage.FromString,
|
||||||
|
response_serializer=hilserl__pb2.Empty.SerializeToString,
|
||||||
|
),
|
||||||
|
'Ready': grpc.unary_unary_rpc_method_handler(
|
||||||
|
servicer.Ready,
|
||||||
|
request_deserializer=hilserl__pb2.Empty.FromString,
|
||||||
response_serializer=hilserl__pb2.Empty.SerializeToString,
|
response_serializer=hilserl__pb2.Empty.SerializeToString,
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
@ -163,7 +195,7 @@ class LearnerService(object):
|
||||||
_registered_method=True)
|
_registered_method=True)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def ReceiveTransitions(request_iterator,
|
def SendTransitions(request_iterator,
|
||||||
target,
|
target,
|
||||||
options=(),
|
options=(),
|
||||||
channel_credentials=None,
|
channel_credentials=None,
|
||||||
|
@ -176,8 +208,62 @@ class LearnerService(object):
|
||||||
return grpc.experimental.stream_unary(
|
return grpc.experimental.stream_unary(
|
||||||
request_iterator,
|
request_iterator,
|
||||||
target,
|
target,
|
||||||
'/hil_serl.LearnerService/ReceiveTransitions',
|
'/hil_serl.LearnerService/SendTransitions',
|
||||||
hilserl__pb2.ActorInformation.SerializeToString,
|
hilserl__pb2.Transition.SerializeToString,
|
||||||
|
hilserl__pb2.Empty.FromString,
|
||||||
|
options,
|
||||||
|
channel_credentials,
|
||||||
|
insecure,
|
||||||
|
call_credentials,
|
||||||
|
compression,
|
||||||
|
wait_for_ready,
|
||||||
|
timeout,
|
||||||
|
metadata,
|
||||||
|
_registered_method=True)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def SendInteractions(request_iterator,
|
||||||
|
target,
|
||||||
|
options=(),
|
||||||
|
channel_credentials=None,
|
||||||
|
call_credentials=None,
|
||||||
|
insecure=False,
|
||||||
|
compression=None,
|
||||||
|
wait_for_ready=None,
|
||||||
|
timeout=None,
|
||||||
|
metadata=None):
|
||||||
|
return grpc.experimental.stream_unary(
|
||||||
|
request_iterator,
|
||||||
|
target,
|
||||||
|
'/hil_serl.LearnerService/SendInteractions',
|
||||||
|
hilserl__pb2.InteractionMessage.SerializeToString,
|
||||||
|
hilserl__pb2.Empty.FromString,
|
||||||
|
options,
|
||||||
|
channel_credentials,
|
||||||
|
insecure,
|
||||||
|
call_credentials,
|
||||||
|
compression,
|
||||||
|
wait_for_ready,
|
||||||
|
timeout,
|
||||||
|
metadata,
|
||||||
|
_registered_method=True)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def Ready(request,
|
||||||
|
target,
|
||||||
|
options=(),
|
||||||
|
channel_credentials=None,
|
||||||
|
call_credentials=None,
|
||||||
|
insecure=False,
|
||||||
|
compression=None,
|
||||||
|
wait_for_ready=None,
|
||||||
|
timeout=None,
|
||||||
|
metadata=None):
|
||||||
|
return grpc.experimental.unary_unary(
|
||||||
|
request,
|
||||||
|
target,
|
||||||
|
'/hil_serl.LearnerService/Ready',
|
||||||
|
hilserl__pb2.Empty.SerializeToString,
|
||||||
hilserl__pb2.Empty.FromString,
|
hilserl__pb2.Empty.FromString,
|
||||||
options,
|
options,
|
||||||
channel_credentials,
|
channel_credentials,
|
||||||
|
|
|
@ -15,15 +15,18 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
import queue
|
|
||||||
import shutil
|
import shutil
|
||||||
import time
|
import time
|
||||||
from pprint import pformat
|
from pprint import pformat
|
||||||
from threading import Lock, Thread
|
|
||||||
import signal
|
|
||||||
from threading import Event
|
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
|
||||||
|
# from torch.multiprocessing import Event, Queue, Process
|
||||||
|
# from threading import Event, Thread
|
||||||
|
# from torch.multiprocessing import Queue, Event
|
||||||
|
from torch.multiprocessing import Queue
|
||||||
|
|
||||||
|
from lerobot.scripts.server.utils import setup_process_handlers
|
||||||
|
|
||||||
import grpc
|
import grpc
|
||||||
|
|
||||||
# Import generated stubs
|
# Import generated stubs
|
||||||
|
@ -52,19 +55,19 @@ from lerobot.common.utils.utils import (
|
||||||
set_global_random_state,
|
set_global_random_state,
|
||||||
set_global_seed,
|
set_global_seed,
|
||||||
)
|
)
|
||||||
|
|
||||||
from lerobot.scripts.server.buffer import (
|
from lerobot.scripts.server.buffer import (
|
||||||
ReplayBuffer,
|
ReplayBuffer,
|
||||||
concatenate_batch_transitions,
|
concatenate_batch_transitions,
|
||||||
move_transition_to_device,
|
move_transition_to_device,
|
||||||
|
move_state_dict_to_device,
|
||||||
|
bytes_to_transitions,
|
||||||
|
state_to_bytes,
|
||||||
|
bytes_to_python_object,
|
||||||
)
|
)
|
||||||
|
|
||||||
from lerobot.scripts.server import learner_service
|
from lerobot.scripts.server import learner_service
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
|
|
||||||
transition_queue = queue.Queue()
|
|
||||||
interaction_message_queue = queue.Queue()
|
|
||||||
|
|
||||||
|
|
||||||
def handle_resume_logic(cfg: DictConfig, out_dir: str) -> DictConfig:
|
def handle_resume_logic(cfg: DictConfig, out_dir: str) -> DictConfig:
|
||||||
if not cfg.resume:
|
if not cfg.resume:
|
||||||
|
@ -195,67 +198,96 @@ def get_observation_features(
|
||||||
return observation_features, next_observation_features
|
return observation_features, next_observation_features
|
||||||
|
|
||||||
|
|
||||||
|
def use_threads(cfg: DictConfig) -> bool:
|
||||||
|
return cfg.actor_learner_config.concurrency.learner == "threads"
|
||||||
|
|
||||||
|
|
||||||
def start_learner_threads(
|
def start_learner_threads(
|
||||||
cfg: DictConfig,
|
cfg: DictConfig,
|
||||||
device: str,
|
|
||||||
replay_buffer: ReplayBuffer,
|
|
||||||
offline_replay_buffer: ReplayBuffer,
|
|
||||||
batch_size: int,
|
|
||||||
optimizers: dict,
|
|
||||||
policy: SACPolicy,
|
|
||||||
policy_lock: Lock,
|
|
||||||
logger: Logger,
|
logger: Logger,
|
||||||
resume_optimization_step: int | None = None,
|
out_dir: str,
|
||||||
resume_interaction_step: int | None = None,
|
shutdown_event: any, # Event,
|
||||||
shutdown_event: Event | None = None,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
host = cfg.actor_learner_config.learner_host
|
# Create multiprocessing queues
|
||||||
port = cfg.actor_learner_config.learner_port
|
transition_queue = Queue()
|
||||||
|
interaction_message_queue = Queue()
|
||||||
|
parameters_queue = Queue()
|
||||||
|
|
||||||
transition_thread = Thread(
|
concurrency_entity = None
|
||||||
target=add_actor_information_and_train,
|
|
||||||
daemon=True,
|
if use_threads(cfg):
|
||||||
|
from threading import Thread
|
||||||
|
|
||||||
|
concurrency_entity = Thread
|
||||||
|
else:
|
||||||
|
from torch.multiprocessing import Process
|
||||||
|
|
||||||
|
concurrency_entity = Process
|
||||||
|
|
||||||
|
communication_process = concurrency_entity(
|
||||||
|
target=start_learner_server,
|
||||||
args=(
|
args=(
|
||||||
cfg,
|
parameters_queue,
|
||||||
device,
|
transition_queue,
|
||||||
replay_buffer,
|
interaction_message_queue,
|
||||||
offline_replay_buffer,
|
|
||||||
batch_size,
|
|
||||||
optimizers,
|
|
||||||
policy,
|
|
||||||
policy_lock,
|
|
||||||
logger,
|
|
||||||
resume_optimization_step,
|
|
||||||
resume_interaction_step,
|
|
||||||
shutdown_event,
|
shutdown_event,
|
||||||
|
cfg,
|
||||||
),
|
),
|
||||||
|
daemon=True,
|
||||||
)
|
)
|
||||||
|
communication_process.start()
|
||||||
|
|
||||||
transition_thread.start()
|
add_actor_information_and_train(
|
||||||
|
cfg,
|
||||||
|
logger,
|
||||||
|
out_dir,
|
||||||
|
shutdown_event,
|
||||||
|
transition_queue,
|
||||||
|
interaction_message_queue,
|
||||||
|
parameters_queue,
|
||||||
|
)
|
||||||
|
logging.info("[LEARNER] Training process stopped")
|
||||||
|
|
||||||
|
logging.info("[LEARNER] Closing queues")
|
||||||
|
transition_queue.close()
|
||||||
|
interaction_message_queue.close()
|
||||||
|
parameters_queue.close()
|
||||||
|
|
||||||
|
communication_process.join()
|
||||||
|
logging.info("[LEARNER] Communication process joined")
|
||||||
|
|
||||||
|
logging.info("[LEARNER] join queues")
|
||||||
|
transition_queue.cancel_join_thread()
|
||||||
|
interaction_message_queue.cancel_join_thread()
|
||||||
|
parameters_queue.cancel_join_thread()
|
||||||
|
|
||||||
|
logging.info("[LEARNER] queues closed")
|
||||||
|
|
||||||
|
|
||||||
|
def start_learner_server(
|
||||||
|
parameters_queue: Queue,
|
||||||
|
transition_queue: Queue,
|
||||||
|
interaction_message_queue: Queue,
|
||||||
|
shutdown_event: any, # Event,
|
||||||
|
cfg: DictConfig,
|
||||||
|
):
|
||||||
|
if not use_threads(cfg):
|
||||||
|
# We need to init logging for MP separately
|
||||||
|
init_logging()
|
||||||
|
|
||||||
|
# Setup process handlers to handle shutdown signal
|
||||||
|
# But use shutdown event from the main process
|
||||||
|
# Return back for MP
|
||||||
|
setup_process_handlers(False)
|
||||||
|
|
||||||
service = learner_service.LearnerService(
|
service = learner_service.LearnerService(
|
||||||
shutdown_event,
|
shutdown_event,
|
||||||
policy,
|
parameters_queue,
|
||||||
policy_lock,
|
|
||||||
cfg.actor_learner_config.policy_parameters_push_frequency,
|
cfg.actor_learner_config.policy_parameters_push_frequency,
|
||||||
transition_queue,
|
transition_queue,
|
||||||
interaction_message_queue,
|
interaction_message_queue,
|
||||||
)
|
)
|
||||||
server = start_learner_server(service, host, port)
|
|
||||||
|
|
||||||
shutdown_event.wait()
|
|
||||||
server.stop(learner_service.STUTDOWN_TIMEOUT)
|
|
||||||
logging.info("[LEARNER] gRPC server stopped")
|
|
||||||
|
|
||||||
transition_thread.join()
|
|
||||||
logging.info("[LEARNER] Transition thread stopped")
|
|
||||||
|
|
||||||
|
|
||||||
def start_learner_server(
|
|
||||||
service: learner_service.LearnerService,
|
|
||||||
host="0.0.0.0",
|
|
||||||
port=50051,
|
|
||||||
) -> grpc.server:
|
|
||||||
server = grpc.server(
|
server = grpc.server(
|
||||||
ThreadPoolExecutor(max_workers=learner_service.MAX_WORKERS),
|
ThreadPoolExecutor(max_workers=learner_service.MAX_WORKERS),
|
||||||
options=[
|
options=[
|
||||||
|
@ -263,15 +295,23 @@ def start_learner_server(
|
||||||
("grpc.max_send_message_length", learner_service.MAX_MESSAGE_SIZE),
|
("grpc.max_send_message_length", learner_service.MAX_MESSAGE_SIZE),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
hilserl_pb2_grpc.add_LearnerServiceServicer_to_server(
|
hilserl_pb2_grpc.add_LearnerServiceServicer_to_server(
|
||||||
service,
|
service,
|
||||||
server,
|
server,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
host = cfg.actor_learner_config.learner_host
|
||||||
|
port = cfg.actor_learner_config.learner_port
|
||||||
|
|
||||||
server.add_insecure_port(f"{host}:{port}")
|
server.add_insecure_port(f"{host}:{port}")
|
||||||
server.start()
|
server.start()
|
||||||
logging.info("[LEARNER] gRPC server started")
|
logging.info("[LEARNER] gRPC server started")
|
||||||
|
|
||||||
return server
|
shutdown_event.wait()
|
||||||
|
logging.info("[LEARNER] Stopping gRPC server...")
|
||||||
|
server.stop(learner_service.STUTDOWN_TIMEOUT)
|
||||||
|
logging.info("[LEARNER] gRPC server stopped")
|
||||||
|
|
||||||
|
|
||||||
def check_nan_in_transition(
|
def check_nan_in_transition(
|
||||||
|
@ -287,19 +327,21 @@ def check_nan_in_transition(
|
||||||
logging.error("actions contains NaN values")
|
logging.error("actions contains NaN values")
|
||||||
|
|
||||||
|
|
||||||
|
def push_actor_policy_to_queue(parameters_queue: Queue, policy: nn.Module):
|
||||||
|
logging.debug("[LEARNER] Pushing actor policy to the queue")
|
||||||
|
state_dict = move_state_dict_to_device(policy.actor.state_dict(), device="cpu")
|
||||||
|
state_bytes = state_to_bytes(state_dict)
|
||||||
|
parameters_queue.put(state_bytes)
|
||||||
|
|
||||||
|
|
||||||
def add_actor_information_and_train(
|
def add_actor_information_and_train(
|
||||||
cfg,
|
cfg,
|
||||||
device: str,
|
|
||||||
replay_buffer: ReplayBuffer,
|
|
||||||
offline_replay_buffer: ReplayBuffer,
|
|
||||||
batch_size: int,
|
|
||||||
optimizers: dict[str, torch.optim.Optimizer],
|
|
||||||
policy: nn.Module,
|
|
||||||
policy_lock: Lock,
|
|
||||||
logger: Logger,
|
logger: Logger,
|
||||||
resume_optimization_step: int | None = None,
|
out_dir: str,
|
||||||
resume_interaction_step: int | None = None,
|
shutdown_event: any, # Event,
|
||||||
shutdown_event: Event | None = None,
|
transition_queue: Queue,
|
||||||
|
interaction_message_queue: Queue,
|
||||||
|
parameters_queue: Queue,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Handles data transfer from the actor to the learner, manages training updates,
|
Handles data transfer from the actor to the learner, manages training updates,
|
||||||
|
@ -322,17 +364,73 @@ def add_actor_information_and_train(
|
||||||
Args:
|
Args:
|
||||||
cfg: Configuration object containing hyperparameters.
|
cfg: Configuration object containing hyperparameters.
|
||||||
device (str): The computing device (`"cpu"` or `"cuda"`).
|
device (str): The computing device (`"cpu"` or `"cuda"`).
|
||||||
replay_buffer (ReplayBuffer): The primary replay buffer storing online transitions.
|
|
||||||
offline_replay_buffer (ReplayBuffer): An additional buffer for offline transitions.
|
|
||||||
batch_size (int): The number of transitions to sample per training step.
|
|
||||||
optimizers (Dict[str, torch.optim.Optimizer]): A dictionary of optimizers (`"actor"`, `"critic"`, `"temperature"`).
|
|
||||||
policy (nn.Module): The reinforcement learning policy with critic, actor, and temperature parameters.
|
|
||||||
policy_lock (Lock): A threading lock to ensure safe policy updates.
|
|
||||||
logger (Logger): Logger instance for tracking training progress.
|
logger (Logger): Logger instance for tracking training progress.
|
||||||
resume_optimization_step (int | None): In the case of resume training, start from the last optimization step reached.
|
out_dir (str): The output directory for storing training checkpoints and logs.
|
||||||
resume_interaction_step (int | None): In the case of resume training, shift the interaction step with the last saved step in order to not break logging.
|
shutdown_event (Event): Event to signal shutdown.
|
||||||
shutdown_event (Event | None): Event to signal shutdown.
|
transition_queue (Queue): Queue for receiving transitions from the actor.
|
||||||
|
interaction_message_queue (Queue): Queue for receiving interaction messages from the actor.
|
||||||
|
parameters_queue (Queue): Queue for sending policy parameters to the actor.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
device = get_safe_torch_device(cfg.device, log=True)
|
||||||
|
storage_device = get_safe_torch_device(cfg_device=cfg.training.storage_device)
|
||||||
|
|
||||||
|
logging.info("Initializing policy")
|
||||||
|
### Instantiate the policy in both the actor and learner processes
|
||||||
|
### To avoid sending a SACPolicy object through the port, we create a policy instance
|
||||||
|
### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters
|
||||||
|
# TODO: At some point we should just need make sac policy
|
||||||
|
|
||||||
|
policy: SACPolicy = make_policy(
|
||||||
|
hydra_cfg=cfg,
|
||||||
|
# dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
|
||||||
|
# Hack: But if we do online training, we do not need dataset_stats
|
||||||
|
dataset_stats=None,
|
||||||
|
pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir)
|
||||||
|
if cfg.resume
|
||||||
|
else None,
|
||||||
|
)
|
||||||
|
# compile policy
|
||||||
|
policy = torch.compile(policy)
|
||||||
|
assert isinstance(policy, nn.Module)
|
||||||
|
|
||||||
|
push_actor_policy_to_queue(parameters_queue, policy)
|
||||||
|
|
||||||
|
last_time_policy_pushed = time.time()
|
||||||
|
|
||||||
|
optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg, policy)
|
||||||
|
resume_optimization_step, resume_interaction_step = load_training_state(
|
||||||
|
cfg, logger, optimizers
|
||||||
|
)
|
||||||
|
|
||||||
|
log_training_info(cfg, out_dir, policy)
|
||||||
|
|
||||||
|
replay_buffer = initialize_replay_buffer(cfg, logger, device, storage_device)
|
||||||
|
batch_size = cfg.training.batch_size
|
||||||
|
offline_replay_buffer = None
|
||||||
|
|
||||||
|
if cfg.dataset_repo_id is not None:
|
||||||
|
logging.info("make_dataset offline buffer")
|
||||||
|
offline_dataset = make_dataset(cfg)
|
||||||
|
logging.info("Convertion to a offline replay buffer")
|
||||||
|
active_action_dims = None
|
||||||
|
if cfg.env.wrapper.joint_masking_action_space is not None:
|
||||||
|
active_action_dims = [
|
||||||
|
i
|
||||||
|
for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space)
|
||||||
|
if mask
|
||||||
|
]
|
||||||
|
offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
|
||||||
|
offline_dataset,
|
||||||
|
device=device,
|
||||||
|
state_keys=cfg.policy.input_shapes.keys(),
|
||||||
|
action_mask=active_action_dims,
|
||||||
|
action_delta=cfg.env.wrapper.delta_action,
|
||||||
|
storage_device=storage_device,
|
||||||
|
optimize_memory=True,
|
||||||
|
)
|
||||||
|
batch_size: int = batch_size // 2 # We will sample from both replay buffer
|
||||||
|
|
||||||
# NOTE: This function doesn't have a single responsibility, it should be split into multiple functions
|
# NOTE: This function doesn't have a single responsibility, it should be split into multiple functions
|
||||||
# in the future. The reason why we did that is the GIL in Python. It's super slow the performance
|
# in the future. The reason why we did that is the GIL in Python. It's super slow the performance
|
||||||
# are divided by 200. So we need to have a single thread that does all the work.
|
# are divided by 200. So we need to have a single thread that does all the work.
|
||||||
|
@ -345,33 +443,39 @@ def add_actor_information_and_train(
|
||||||
interaction_step_shift = (
|
interaction_step_shift = (
|
||||||
resume_interaction_step if resume_interaction_step is not None else 0
|
resume_interaction_step if resume_interaction_step is not None else 0
|
||||||
)
|
)
|
||||||
saved_data = False
|
|
||||||
while True:
|
while True:
|
||||||
if shutdown_event is not None and shutdown_event.is_set():
|
if shutdown_event is not None and shutdown_event.is_set():
|
||||||
logging.info("[LEARNER] Shutdown signal received. Exiting...")
|
logging.info("[LEARNER] Shutdown signal received. Exiting...")
|
||||||
break
|
break
|
||||||
|
|
||||||
while not transition_queue.empty():
|
logging.debug("[LEARNER] Waiting for transitions")
|
||||||
|
while not transition_queue.empty() and not shutdown_event.is_set():
|
||||||
transition_list = transition_queue.get()
|
transition_list = transition_queue.get()
|
||||||
|
transition_list = bytes_to_transitions(transition_list)
|
||||||
|
|
||||||
for transition in transition_list:
|
for transition in transition_list:
|
||||||
transition = move_transition_to_device(transition, device=device)
|
transition = move_transition_to_device(transition, device=device)
|
||||||
replay_buffer.add(**transition)
|
replay_buffer.add(**transition)
|
||||||
|
|
||||||
if transition.get("complementary_info", {}).get("is_intervention"):
|
if transition.get("complementary_info", {}).get("is_intervention"):
|
||||||
offline_replay_buffer.add(**transition)
|
offline_replay_buffer.add(**transition)
|
||||||
|
logging.debug("[LEARNER] Received transitions")
|
||||||
while not interaction_message_queue.empty():
|
logging.debug("[LEARNER] Waiting for interactions")
|
||||||
|
while not interaction_message_queue.empty() and not shutdown_event.is_set():
|
||||||
interaction_message = interaction_message_queue.get()
|
interaction_message = interaction_message_queue.get()
|
||||||
|
interaction_message = bytes_to_python_object(interaction_message)
|
||||||
# If cfg.resume, shift the interaction step with the last checkpointed step in order to not break the logging
|
# If cfg.resume, shift the interaction step with the last checkpointed step in order to not break the logging
|
||||||
interaction_message["Interaction step"] += interaction_step_shift
|
interaction_message["Interaction step"] += interaction_step_shift
|
||||||
logger.log_dict(
|
logger.log_dict(
|
||||||
interaction_message, mode="train", custom_step_key="Interaction step"
|
interaction_message, mode="train", custom_step_key="Interaction step"
|
||||||
)
|
)
|
||||||
# logging.info(f"Interaction message: {interaction_message}")
|
|
||||||
|
logging.debug("[LEARNER] Received interactions")
|
||||||
|
|
||||||
if len(replay_buffer) < cfg.training.online_step_before_learning:
|
if len(replay_buffer) < cfg.training.online_step_before_learning:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
logging.debug("[LEARNER] Starting optimization loop")
|
||||||
time_for_one_optimization_step = time.time()
|
time_for_one_optimization_step = time.time()
|
||||||
for _ in range(cfg.policy.utd_ratio - 1):
|
for _ in range(cfg.policy.utd_ratio - 1):
|
||||||
batch = replay_buffer.sample(batch_size)
|
batch = replay_buffer.sample(batch_size)
|
||||||
|
@ -392,19 +496,18 @@ def add_actor_information_and_train(
|
||||||
observation_features, next_observation_features = get_observation_features(
|
observation_features, next_observation_features = get_observation_features(
|
||||||
policy, observations, next_observations
|
policy, observations, next_observations
|
||||||
)
|
)
|
||||||
with policy_lock:
|
loss_critic = policy.compute_loss_critic(
|
||||||
loss_critic = policy.compute_loss_critic(
|
observations=observations,
|
||||||
observations=observations,
|
actions=actions,
|
||||||
actions=actions,
|
rewards=rewards,
|
||||||
rewards=rewards,
|
next_observations=next_observations,
|
||||||
next_observations=next_observations,
|
done=done,
|
||||||
done=done,
|
observation_features=observation_features,
|
||||||
observation_features=observation_features,
|
next_observation_features=next_observation_features,
|
||||||
next_observation_features=next_observation_features,
|
)
|
||||||
)
|
optimizers["critic"].zero_grad()
|
||||||
optimizers["critic"].zero_grad()
|
loss_critic.backward()
|
||||||
loss_critic.backward()
|
optimizers["critic"].step()
|
||||||
optimizers["critic"].step()
|
|
||||||
|
|
||||||
batch = replay_buffer.sample(batch_size)
|
batch = replay_buffer.sample(batch_size)
|
||||||
|
|
||||||
|
@ -427,46 +530,51 @@ def add_actor_information_and_train(
|
||||||
observation_features, next_observation_features = get_observation_features(
|
observation_features, next_observation_features = get_observation_features(
|
||||||
policy, observations, next_observations
|
policy, observations, next_observations
|
||||||
)
|
)
|
||||||
with policy_lock:
|
loss_critic = policy.compute_loss_critic(
|
||||||
loss_critic = policy.compute_loss_critic(
|
observations=observations,
|
||||||
observations=observations,
|
actions=actions,
|
||||||
actions=actions,
|
rewards=rewards,
|
||||||
rewards=rewards,
|
next_observations=next_observations,
|
||||||
next_observations=next_observations,
|
done=done,
|
||||||
done=done,
|
observation_features=observation_features,
|
||||||
observation_features=observation_features,
|
next_observation_features=next_observation_features,
|
||||||
next_observation_features=next_observation_features,
|
)
|
||||||
)
|
optimizers["critic"].zero_grad()
|
||||||
optimizers["critic"].zero_grad()
|
loss_critic.backward()
|
||||||
loss_critic.backward()
|
optimizers["critic"].step()
|
||||||
optimizers["critic"].step()
|
|
||||||
|
|
||||||
training_infos = {}
|
training_infos = {}
|
||||||
training_infos["loss_critic"] = loss_critic.item()
|
training_infos["loss_critic"] = loss_critic.item()
|
||||||
|
|
||||||
if optimization_step % cfg.training.policy_update_freq == 0:
|
if optimization_step % cfg.training.policy_update_freq == 0:
|
||||||
for _ in range(cfg.training.policy_update_freq):
|
for _ in range(cfg.training.policy_update_freq):
|
||||||
with policy_lock:
|
loss_actor = policy.compute_loss_actor(
|
||||||
loss_actor = policy.compute_loss_actor(
|
observations=observations,
|
||||||
observations=observations,
|
observation_features=observation_features,
|
||||||
observation_features=observation_features,
|
)
|
||||||
)
|
|
||||||
|
|
||||||
optimizers["actor"].zero_grad()
|
optimizers["actor"].zero_grad()
|
||||||
loss_actor.backward()
|
loss_actor.backward()
|
||||||
optimizers["actor"].step()
|
optimizers["actor"].step()
|
||||||
|
|
||||||
training_infos["loss_actor"] = loss_actor.item()
|
training_infos["loss_actor"] = loss_actor.item()
|
||||||
|
|
||||||
loss_temperature = policy.compute_loss_temperature(
|
loss_temperature = policy.compute_loss_temperature(
|
||||||
observations=observations,
|
observations=observations,
|
||||||
observation_features=observation_features,
|
observation_features=observation_features,
|
||||||
)
|
)
|
||||||
optimizers["temperature"].zero_grad()
|
optimizers["temperature"].zero_grad()
|
||||||
loss_temperature.backward()
|
loss_temperature.backward()
|
||||||
optimizers["temperature"].step()
|
optimizers["temperature"].step()
|
||||||
|
|
||||||
training_infos["loss_temperature"] = loss_temperature.item()
|
training_infos["loss_temperature"] = loss_temperature.item()
|
||||||
|
|
||||||
|
if (
|
||||||
|
time.time() - last_time_policy_pushed
|
||||||
|
> cfg.actor_learner_config.policy_parameters_push_frequency
|
||||||
|
):
|
||||||
|
push_actor_policy_to_queue(parameters_queue, policy)
|
||||||
|
last_time_policy_pushed = time.time()
|
||||||
|
|
||||||
policy.update_target_networks()
|
policy.update_target_networks()
|
||||||
if optimization_step % cfg.training.log_freq == 0:
|
if optimization_step % cfg.training.log_freq == 0:
|
||||||
|
@ -595,104 +703,36 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||||
|
|
||||||
set_global_seed(cfg.seed)
|
set_global_seed(cfg.seed)
|
||||||
|
|
||||||
device = get_safe_torch_device(cfg.device, log=True)
|
|
||||||
storage_device = get_safe_torch_device(cfg_device=cfg.training.storage_device)
|
|
||||||
|
|
||||||
torch.backends.cudnn.benchmark = True
|
torch.backends.cudnn.benchmark = True
|
||||||
torch.backends.cuda.matmul.allow_tf32 = True
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
|
|
||||||
logging.info("make_policy")
|
shutdown_event = setup_process_handlers(use_threads(cfg))
|
||||||
|
|
||||||
### Instantiate the policy in both the actor and learner processes
|
|
||||||
### To avoid sending a SACPolicy object through the port, we create a policy intance
|
|
||||||
### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters
|
|
||||||
# TODO: At some point we should just need make sac policy
|
|
||||||
|
|
||||||
policy_lock = Lock()
|
|
||||||
policy: SACPolicy = make_policy(
|
|
||||||
hydra_cfg=cfg,
|
|
||||||
# dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
|
|
||||||
# Hack: But if we do online traning, we do not need dataset_stats
|
|
||||||
dataset_stats=None,
|
|
||||||
pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir)
|
|
||||||
if cfg.resume
|
|
||||||
else None,
|
|
||||||
)
|
|
||||||
# compile policy
|
|
||||||
policy = torch.compile(policy)
|
|
||||||
assert isinstance(policy, nn.Module)
|
|
||||||
|
|
||||||
optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg, policy)
|
|
||||||
resume_optimization_step, resume_interaction_step = load_training_state(
|
|
||||||
cfg, logger, optimizers
|
|
||||||
)
|
|
||||||
|
|
||||||
log_training_info(cfg, out_dir, policy)
|
|
||||||
|
|
||||||
replay_buffer = initialize_replay_buffer(cfg, logger, device, storage_device)
|
|
||||||
batch_size = cfg.training.batch_size
|
|
||||||
offline_replay_buffer = None
|
|
||||||
|
|
||||||
if cfg.dataset_repo_id is not None:
|
|
||||||
logging.info("make_dataset offline buffer")
|
|
||||||
offline_dataset = make_dataset(cfg)
|
|
||||||
logging.info("Convertion to a offline replay buffer")
|
|
||||||
active_action_dims = None
|
|
||||||
if cfg.env.wrapper.joint_masking_action_space is not None:
|
|
||||||
active_action_dims = [
|
|
||||||
i
|
|
||||||
for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space)
|
|
||||||
if mask
|
|
||||||
]
|
|
||||||
offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
|
|
||||||
offline_dataset,
|
|
||||||
device=device,
|
|
||||||
state_keys=cfg.policy.input_shapes.keys(),
|
|
||||||
action_mask=active_action_dims,
|
|
||||||
action_delta=cfg.env.wrapper.delta_action,
|
|
||||||
storage_device=storage_device,
|
|
||||||
optimize_memory=True,
|
|
||||||
)
|
|
||||||
batch_size: int = batch_size // 2 # We will sample from both replay buffer
|
|
||||||
|
|
||||||
shutdown_event = Event()
|
|
||||||
|
|
||||||
def signal_handler(signum, frame):
|
|
||||||
print(
|
|
||||||
f"\nReceived signal {signal.Signals(signum).name}. Initiating learner shutdown..."
|
|
||||||
)
|
|
||||||
shutdown_event.set()
|
|
||||||
|
|
||||||
# Register signal handlers
|
|
||||||
signal.signal(signal.SIGINT, signal_handler) # Ctrl+C
|
|
||||||
signal.signal(signal.SIGTERM, signal_handler) # Termination request
|
|
||||||
signal.signal(signal.SIGHUP, signal_handler) # Terminal closed
|
|
||||||
signal.signal(signal.SIGQUIT, signal_handler) # Ctrl+\
|
|
||||||
|
|
||||||
start_learner_threads(
|
start_learner_threads(
|
||||||
cfg,
|
cfg,
|
||||||
device,
|
|
||||||
replay_buffer,
|
|
||||||
offline_replay_buffer,
|
|
||||||
batch_size,
|
|
||||||
optimizers,
|
|
||||||
policy,
|
|
||||||
policy_lock,
|
|
||||||
logger,
|
logger,
|
||||||
resume_optimization_step,
|
out_dir,
|
||||||
resume_interaction_step,
|
|
||||||
shutdown_event,
|
shutdown_event,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@hydra.main(version_base="1.2", config_name="default", config_path="../../configs")
|
@hydra.main(version_base="1.2", config_name="default", config_path="../../configs")
|
||||||
def train_cli(cfg: dict):
|
def train_cli(cfg: dict):
|
||||||
|
if not use_threads(cfg):
|
||||||
|
import torch.multiprocessing as mp
|
||||||
|
|
||||||
|
mp.set_start_method("spawn")
|
||||||
|
|
||||||
train(
|
train(
|
||||||
cfg,
|
cfg,
|
||||||
out_dir=hydra.core.hydra_config.HydraConfig.get().run.dir,
|
out_dir=hydra.core.hydra_config.HydraConfig.get().run.dir,
|
||||||
job_name=hydra.core.hydra_config.HydraConfig.get().job.name,
|
job_name=hydra.core.hydra_config.HydraConfig.get().job.name,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logging.info("[LEARNER] train_cli finished")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
train_cli()
|
train_cli()
|
||||||
|
|
||||||
|
logging.info("[LEARNER] main finished")
|
||||||
|
|
|
@ -1,23 +1,13 @@
|
||||||
import hilserl_pb2 # type: ignore
|
import hilserl_pb2 # type: ignore
|
||||||
import hilserl_pb2_grpc # type: ignore
|
import hilserl_pb2_grpc # type: ignore
|
||||||
import torch
|
|
||||||
from torch import nn
|
|
||||||
from threading import Lock, Event
|
|
||||||
import logging
|
import logging
|
||||||
import queue
|
from multiprocessing import Event, Queue
|
||||||
import io
|
|
||||||
import pickle
|
|
||||||
|
|
||||||
from lerobot.scripts.server.buffer import (
|
|
||||||
move_state_dict_to_device,
|
|
||||||
bytes_buffer_size,
|
|
||||||
state_to_bytes,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
from lerobot.scripts.server.network_utils import receive_bytes_in_chunks
|
||||||
|
from lerobot.scripts.server.network_utils import send_bytes_in_chunks
|
||||||
|
|
||||||
MAX_MESSAGE_SIZE = 4 * 1024 * 1024 # 4 MB
|
MAX_MESSAGE_SIZE = 4 * 1024 * 1024 # 4 MB
|
||||||
CHUNK_SIZE = 2 * 1024 * 1024 # 2 MB
|
MAX_WORKERS = 3 # Stream parameters, send transitions and interactions
|
||||||
MAX_WORKERS = 10
|
|
||||||
STUTDOWN_TIMEOUT = 10
|
STUTDOWN_TIMEOUT = 10
|
||||||
|
|
||||||
|
|
||||||
|
@ -25,89 +15,68 @@ class LearnerService(hilserl_pb2_grpc.LearnerServiceServicer):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
shutdown_event: Event,
|
shutdown_event: Event,
|
||||||
policy: nn.Module,
|
parameters_queue: Queue,
|
||||||
policy_lock: Lock,
|
|
||||||
seconds_between_pushes: float,
|
seconds_between_pushes: float,
|
||||||
transition_queue: queue.Queue,
|
transition_queue: Queue,
|
||||||
interaction_message_queue: queue.Queue,
|
interaction_message_queue: Queue,
|
||||||
):
|
):
|
||||||
self.shutdown_event = shutdown_event
|
self.shutdown_event = shutdown_event
|
||||||
self.policy = policy
|
self.parameters_queue = parameters_queue
|
||||||
self.policy_lock = policy_lock
|
|
||||||
self.seconds_between_pushes = seconds_between_pushes
|
self.seconds_between_pushes = seconds_between_pushes
|
||||||
self.transition_queue = transition_queue
|
self.transition_queue = transition_queue
|
||||||
self.interaction_message_queue = interaction_message_queue
|
self.interaction_message_queue = interaction_message_queue
|
||||||
|
|
||||||
def _get_policy_state(self):
|
|
||||||
with self.policy_lock:
|
|
||||||
params_dict = self.policy.actor.state_dict()
|
|
||||||
# if self.policy.config.vision_encoder_name is not None:
|
|
||||||
# if self.policy.config.freeze_vision_encoder:
|
|
||||||
# params_dict: dict[str, torch.Tensor] = {
|
|
||||||
# k: v
|
|
||||||
# for k, v in params_dict.items()
|
|
||||||
# if not k.startswith("encoder.")
|
|
||||||
# }
|
|
||||||
# else:
|
|
||||||
# raise NotImplementedError(
|
|
||||||
# "Vision encoder is not frozen, we need to send the full model over the network which requires chunking the model."
|
|
||||||
# )
|
|
||||||
|
|
||||||
return move_state_dict_to_device(params_dict, device="cpu")
|
|
||||||
|
|
||||||
def _send_bytes(self, buffer: bytes):
|
|
||||||
size_in_bytes = bytes_buffer_size(buffer)
|
|
||||||
|
|
||||||
sent_bytes = 0
|
|
||||||
|
|
||||||
logging.info(f"Model state size {size_in_bytes/1024/1024} MB with")
|
|
||||||
|
|
||||||
while sent_bytes < size_in_bytes:
|
|
||||||
transfer_state = hilserl_pb2.TransferState.TRANSFER_MIDDLE
|
|
||||||
|
|
||||||
if sent_bytes + CHUNK_SIZE >= size_in_bytes:
|
|
||||||
transfer_state = hilserl_pb2.TransferState.TRANSFER_END
|
|
||||||
elif sent_bytes == 0:
|
|
||||||
transfer_state = hilserl_pb2.TransferState.TRANSFER_BEGIN
|
|
||||||
|
|
||||||
size_to_read = min(CHUNK_SIZE, size_in_bytes - sent_bytes)
|
|
||||||
chunk = buffer.read(size_to_read)
|
|
||||||
|
|
||||||
yield hilserl_pb2.Parameters(
|
|
||||||
transfer_state=transfer_state, parameter_bytes=chunk
|
|
||||||
)
|
|
||||||
sent_bytes += size_to_read
|
|
||||||
logging.info(
|
|
||||||
f"[Learner] Sent {sent_bytes}/{size_in_bytes} bytes with state {transfer_state}"
|
|
||||||
)
|
|
||||||
|
|
||||||
logging.info(f"[LEARNER] Published {sent_bytes/1024/1024} MB to the Actor")
|
|
||||||
|
|
||||||
def StreamParameters(self, request, context):
|
def StreamParameters(self, request, context):
|
||||||
# TODO: authorize the request
|
# TODO: authorize the request
|
||||||
logging.info("[LEARNER] Received request to stream parameters from the Actor")
|
logging.info("[LEARNER] Received request to stream parameters from the Actor")
|
||||||
|
|
||||||
while not self.shutdown_event.is_set():
|
while not self.shutdown_event.is_set():
|
||||||
logging.debug("[LEARNER] Push parameters to the Actor")
|
logging.info("[LEARNER] Push parameters to the Actor")
|
||||||
state_dict = self._get_policy_state()
|
buffer = self.parameters_queue.get()
|
||||||
|
|
||||||
with state_to_bytes(state_dict) as buffer:
|
yield from send_bytes_in_chunks(
|
||||||
yield from self._send_bytes(buffer)
|
buffer,
|
||||||
|
hilserl_pb2.Parameters,
|
||||||
|
log_prefix="[LEARNER] Sending parameters",
|
||||||
|
silent=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
logging.info("[LEARNER] Parameters sent")
|
||||||
|
|
||||||
self.shutdown_event.wait(self.seconds_between_pushes)
|
self.shutdown_event.wait(self.seconds_between_pushes)
|
||||||
|
|
||||||
def ReceiveTransitions(self, request_iterator, context):
|
logging.info("[LEARNER] Stream parameters finished")
|
||||||
|
return hilserl_pb2.Empty()
|
||||||
|
|
||||||
|
def SendTransitions(self, request_iterator, _context):
|
||||||
# TODO: authorize the request
|
# TODO: authorize the request
|
||||||
logging.info("[LEARNER] Received request to receive transitions from the Actor")
|
logging.info("[LEARNER] Received request to receive transitions from the Actor")
|
||||||
|
|
||||||
for request in request_iterator:
|
receive_bytes_in_chunks(
|
||||||
logging.debug("[LEARNER] Received request")
|
request_iterator,
|
||||||
if request.HasField("transition"):
|
self.transition_queue,
|
||||||
buffer = io.BytesIO(request.transition.transition_bytes)
|
self.shutdown_event,
|
||||||
transition = torch.load(buffer)
|
log_prefix="[LEARNER] transitions",
|
||||||
self.transition_queue.put(transition)
|
)
|
||||||
if request.HasField("interaction_message"):
|
|
||||||
content = pickle.loads(
|
logging.debug("[LEARNER] Finished receiving transitions")
|
||||||
request.interaction_message.interaction_message_bytes
|
return hilserl_pb2.Empty()
|
||||||
)
|
|
||||||
self.interaction_message_queue.put(content)
|
def SendInteractions(self, request_iterator, _context):
    """Receive a stream of chunked interaction messages from the Actor.

    Every item of ``request_iterator`` is handed to
    ``receive_bytes_in_chunks``, which is given
    ``self.interaction_message_queue`` as its destination and
    ``self.shutdown_event`` as its stop signal.

    Args:
        request_iterator: Iterable of incoming chunk messages.
        _context: Unused (presumably the gRPC servicer context -- confirm).

    Returns:
        An ``hilserl_pb2.Empty`` message once the receiver has returned.
    """
    # TODO: authorize the request
    logging.info(
        "[LEARNER] Received request to receive interactions from the Actor"
    )

    receive_bytes_in_chunks(
        request_iterator,
        self.interaction_message_queue,
        self.shutdown_event,
        log_prefix="[LEARNER] interactions",
    )

    logging.debug("[LEARNER] Finished receiving interactions")
    return hilserl_pb2.Empty()
|
||||||
|
|
||||||
|
def Ready(self, request, context):
    """Reply with an empty message; neither argument is inspected.

    NOTE(review): presumably a readiness probe the Actor calls before it
    starts streaming -- confirm against the client side.
    """
    return hilserl_pb2.Empty()
|
||||||
|
|
|
@ -5,9 +5,8 @@ import torch
|
||||||
|
|
||||||
from omegaconf import DictConfig
|
from omegaconf import DictConfig
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
"""Make ManiSkill3 gym environment"""
|
|
||||||
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
|
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
|
||||||
|
from mani_skill.utils.wrappers.record import RecordEpisode
|
||||||
|
|
||||||
|
|
||||||
def preprocess_maniskill_observation(
|
def preprocess_maniskill_observation(
|
||||||
|
@ -143,6 +142,15 @@ def make_maniskill(
|
||||||
num_envs=n_envs,
|
num_envs=n_envs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if cfg.env.video_record.enabled:
|
||||||
|
env = RecordEpisode(
|
||||||
|
env,
|
||||||
|
output_dir=cfg.env.video_record.record_dir,
|
||||||
|
save_trajectory=True,
|
||||||
|
trajectory_name=cfg.env.video_record.trajectory_name,
|
||||||
|
save_video=True,
|
||||||
|
video_fps=30,
|
||||||
|
)
|
||||||
env = ManiSkillObservationWrapper(env, device=cfg.env.device)
|
env = ManiSkillObservationWrapper(env, device=cfg.env.device)
|
||||||
env = ManiSkillVectorEnv(env, ignore_terminations=True, auto_reset=False)
|
env = ManiSkillVectorEnv(env, ignore_terminations=True, auto_reset=False)
|
||||||
env._max_episode_steps = env.max_episode_steps = (
|
env._max_episode_steps = env.max_episode_steps = (
|
||||||
|
|
|
@ -0,0 +1,102 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from lerobot.scripts.server import hilserl_pb2
|
||||||
|
import logging
|
||||||
|
import io
|
||||||
|
from multiprocessing import Queue, Event
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
CHUNK_SIZE = 2 * 1024 * 1024 # 2 MB
|
||||||
|
|
||||||
|
|
||||||
|
def bytes_buffer_size(buffer: io.BytesIO) -> int:
    """Return the number of bytes stored in ``buffer``.

    Side effect: the stream cursor is rewound to offset 0 before returning,
    whatever its position was on entry.
    """
    # seek() reports the new absolute position, which at SEEK_END is the size.
    buffer.seek(0, io.SEEK_END)
    size = buffer.tell()
    # Rewind so the caller can read the buffer from the beginning.
    buffer.seek(0)
    return size
|
||||||
|
|
||||||
|
|
||||||
|
def send_bytes_in_chunks(
    buffer: bytes, message_class: Any, log_prefix: str = "", silent: bool = True
):
    """Yield ``buffer`` as a sequence of messages of at most ``CHUNK_SIZE`` bytes.

    Each message is built as ``message_class(transfer_state=..., data=chunk)``.
    The chunk that completes the payload is tagged ``TRANSFER_END``; otherwise
    the first chunk is tagged ``TRANSFER_BEGIN`` and every other chunk
    ``TRANSFER_MIDDLE``. A payload that fits in one chunk is therefore sent as
    a single ``TRANSFER_END`` message, and an empty payload yields nothing.

    Args:
        buffer: Raw payload to stream.
        message_class: Callable accepting ``transfer_state`` and ``data``
            keyword arguments.
        log_prefix: Text prepended to every log line.
        silent: When true, progress is logged at DEBUG level instead of INFO.
    """
    stream = io.BytesIO(buffer)
    size_in_bytes = bytes_buffer_size(stream)
    log = logging.debug if silent else logging.info

    log(f"{log_prefix} Buffer size {size_in_bytes/1024/1024} MB with")

    sent_bytes = 0
    while sent_bytes < size_in_bytes:
        remaining = size_in_bytes - sent_bytes

        # The END check comes first so a single-chunk payload is tagged END.
        if remaining <= CHUNK_SIZE:
            transfer_state = hilserl_pb2.TransferState.TRANSFER_END
        elif sent_bytes == 0:
            transfer_state = hilserl_pb2.TransferState.TRANSFER_BEGIN
        else:
            transfer_state = hilserl_pb2.TransferState.TRANSFER_MIDDLE

        size_to_read = CHUNK_SIZE if remaining > CHUNK_SIZE else remaining
        yield message_class(
            transfer_state=transfer_state, data=stream.read(size_to_read)
        )

        sent_bytes += size_to_read
        log(
            f"{log_prefix} Sent {sent_bytes}/{size_in_bytes} bytes with state {transfer_state}"
        )

    log(f"{log_prefix} Published {sent_bytes/1024/1024} MB")
|
||||||
|
|
||||||
|
|
||||||
|
def receive_bytes_in_chunks(
    iterator, queue: Queue, shutdown_event: Event, log_prefix: str = ""
):
    """Reassemble chunked messages from ``iterator`` and enqueue each payload.

    The ``data`` of every item is accumulated in an in-memory buffer. A
    ``TRANSFER_BEGIN`` item discards whatever was accumulated before it; a
    ``TRANSFER_END`` item completes the payload, which is put on ``queue`` as
    ``bytes`` before the buffer is emptied again. Items carrying any other
    transfer state are ignored.

    ``shutdown_event`` is checked once per received item; if it is set the
    function returns immediately and any partial payload is dropped.

    Args:
        iterator: Iterable of messages exposing ``transfer_state`` and ``data``.
        queue: Destination for completed payloads.
        shutdown_event: Event that aborts the receive loop when set.
        log_prefix: Text prepended to every log line.
    """
    bytes_buffer = io.BytesIO()
    step = 0

    logging.info(f"{log_prefix} Starting receiver")
    for item in iterator:
        logging.debug(f"{log_prefix} Received item")
        if shutdown_event.is_set():
            logging.info(f"{log_prefix} Shutting down receiver")
            return

        state = item.transfer_state

        if state == hilserl_pb2.TransferState.TRANSFER_BEGIN:
            # Start over with an empty buffer holding only this first chunk.
            bytes_buffer = io.BytesIO()
            bytes_buffer.write(item.data)
            logging.debug(f"{log_prefix} Received data at step 0")
            step = 0
        elif state == hilserl_pb2.TransferState.TRANSFER_MIDDLE:
            bytes_buffer.write(item.data)
            step += 1
            logging.debug(f"{log_prefix} Received data at step {step}")
        elif state == hilserl_pb2.TransferState.TRANSFER_END:
            bytes_buffer.write(item.data)
            logging.debug(
                f"{log_prefix} Received data at step end size {bytes_buffer_size(bytes_buffer)}"
            )

            # getvalue() returns the whole content regardless of cursor position.
            queue.put(bytes_buffer.getvalue())

            bytes_buffer = io.BytesIO()
            step = 0

            logging.debug(f"{log_prefix} Queue updated")
|
|
@ -0,0 +1,72 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import signal
|
||||||
|
import sys
|
||||||
|
from torch.multiprocessing import Queue
|
||||||
|
from queue import Empty
|
||||||
|
|
||||||
|
# Number of shutdown signals received by this process so far; a second signal
# forces an immediate exit.
shutdown_event_counter = 0


def setup_process_handlers(
    use_threads: bool,
) -> "threading.Event | multiprocessing.synchronize.Event":
    """Install shutdown signal handlers and return the event they set.

    The first SIGINT / SIGTERM / SIGHUP / SIGQUIT received sets the returned
    event so that workers can stop cooperatively. Any further signal calls
    ``sys.exit(1)`` to force termination.

    Args:
        use_threads: When true the event is a ``threading.Event``; otherwise
            a ``multiprocessing.Event`` is created.

    Returns:
        The shutdown event, initially unset.
    """
    if use_threads:
        from threading import Event
    else:
        from multiprocessing import Event

    shutdown_event = Event()

    def signal_handler(signum, frame):
        global shutdown_event_counter

        logging.info("Shutdown signal received. Cleaning up...")
        shutdown_event.set()
        shutdown_event_counter += 1

        if shutdown_event_counter > 1:
            logging.info("Force shutdown")
            sys.exit(1)

    # SIGHUP and SIGQUIT are not defined on every platform (e.g. Windows), so
    # look the signals up by name and register only those that exist.
    for signal_name in ("SIGINT", "SIGTERM", "SIGHUP", "SIGQUIT"):
        signum = getattr(signal, signal_name, None)
        if signum is not None:
            signal.signal(signum, signal_handler)

    return shutdown_event
|
||||||
|
|
||||||
|
|
||||||
|
def get_last_item_from_queue(queue: Queue):
    """Block until ``queue`` has an item, then return the newest one available.

    After the first (blocking) ``get`` the queue is drained with
    ``get_nowait`` until it reports ``Empty``; every item but the last one
    retrieved is discarded.
    """
    item = queue.get()
    counter = 1

    # Drain queue and keep only the most recent parameters
    while True:
        try:
            item = queue.get_nowait()
        except Empty:
            break
        counter += 1

    logging.debug(f"Drained {counter} items from queue")

    return item
|
Loading…
Reference in New Issue