677 lines
24 KiB
Python
677 lines
24 KiB
Python
#!/usr/bin/env python
|
|
|
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import logging
|
|
import os
|
|
import time
|
|
from functools import lru_cache
|
|
from queue import Empty
|
|
from statistics import mean, quantiles
|
|
|
|
# from lerobot.scripts.eval import eval_policy
|
|
import grpc
|
|
import torch
|
|
from torch import nn
|
|
from torch.multiprocessing import Event, Queue
|
|
|
|
# TODO: Remove the import of maniskill
|
|
from lerobot.common.policies.factory import make_policy
|
|
from lerobot.common.policies.sac.modeling_sac import SACPolicy
|
|
from lerobot.common.robot_devices.utils import busy_wait
|
|
from lerobot.common.utils.random_utils import set_seed
|
|
from lerobot.common.utils.utils import (
|
|
TimerManager,
|
|
get_safe_torch_device,
|
|
init_logging,
|
|
)
|
|
from lerobot.configs import parser
|
|
from lerobot.configs.train import TrainPipelineConfig
|
|
from lerobot.scripts.server import hilserl_pb2, hilserl_pb2_grpc, learner_service
|
|
from lerobot.scripts.server.buffer import (
|
|
Transition,
|
|
bytes_to_state_dict,
|
|
move_state_dict_to_device,
|
|
move_transition_to_device,
|
|
python_object_to_bytes,
|
|
transitions_to_bytes,
|
|
)
|
|
from lerobot.scripts.server.gym_manipulator import make_robot_env
|
|
from lerobot.scripts.server.network_utils import (
|
|
receive_bytes_in_chunks,
|
|
send_bytes_in_chunks,
|
|
)
|
|
from lerobot.scripts.server.utils import get_last_item_from_queue, setup_process_handlers
|
|
|
|
ACTOR_SHUTDOWN_TIMEOUT = 30
|
|
|
|
|
|
#################################################
|
|
# Main entry point #
|
|
#################################################
|
|
|
|
|
|
@parser.wrap()
|
|
def actor_cli(cfg: TrainPipelineConfig):
|
|
cfg.validate()
|
|
if not use_threads(cfg):
|
|
import torch.multiprocessing as mp
|
|
|
|
mp.set_start_method("spawn")
|
|
|
|
# Create logs directory to ensure it exists
|
|
log_dir = os.path.join(cfg.output_dir, "logs")
|
|
os.makedirs(log_dir, exist_ok=True)
|
|
log_file = os.path.join(log_dir, f"actor_{cfg.job_name}.log")
|
|
|
|
# Initialize logging with explicit log file
|
|
init_logging(log_file=log_file)
|
|
logging.info(f"Actor logging initialized, writing to {log_file}")
|
|
|
|
shutdown_event = setup_process_handlers(use_threads(cfg))
|
|
|
|
learner_client, grpc_channel = learner_service_client(
|
|
host=cfg.policy.actor_learner_config.learner_host,
|
|
port=cfg.policy.actor_learner_config.learner_port,
|
|
)
|
|
|
|
logging.info("[ACTOR] Establishing connection with Learner")
|
|
if not establish_learner_connection(learner_client, shutdown_event):
|
|
logging.error("[ACTOR] Failed to establish connection with Learner")
|
|
return
|
|
|
|
if not use_threads(cfg):
|
|
# If we use multithreading, we can reuse the channel
|
|
grpc_channel.close()
|
|
grpc_channel = None
|
|
|
|
logging.info("[ACTOR] Connection with Learner established")
|
|
|
|
parameters_queue = Queue()
|
|
transitions_queue = Queue()
|
|
interactions_queue = Queue()
|
|
|
|
concurrency_entity = None
|
|
if use_threads(cfg):
|
|
from threading import Thread
|
|
|
|
concurrency_entity = Thread
|
|
else:
|
|
from multiprocessing import Process
|
|
|
|
concurrency_entity = Process
|
|
|
|
receive_policy_process = concurrency_entity(
|
|
target=receive_policy,
|
|
args=(cfg, parameters_queue, shutdown_event, grpc_channel),
|
|
daemon=True,
|
|
)
|
|
|
|
transitions_process = concurrency_entity(
|
|
target=send_transitions,
|
|
args=(cfg, transitions_queue, shutdown_event, grpc_channel),
|
|
daemon=True,
|
|
)
|
|
|
|
interactions_process = concurrency_entity(
|
|
target=send_interactions,
|
|
args=(cfg, interactions_queue, shutdown_event, grpc_channel),
|
|
daemon=True,
|
|
)
|
|
|
|
transitions_process.start()
|
|
interactions_process.start()
|
|
receive_policy_process.start()
|
|
|
|
# HACK: FOR MANISKILL we do not have a reward classifier
|
|
# TODO: Remove this once we merge into main
|
|
reward_classifier = None
|
|
# if (
|
|
# cfg.env.reward_classifier["pretrained_path"] is not None
|
|
# and cfg.env.reward_classifier["config_path"] is not None
|
|
# ):
|
|
# reward_classifier = get_classifier(
|
|
# pretrained_path=cfg.env.reward_classifier["pretrained_path"],
|
|
# config_path=cfg.env.reward_classifier["config_path"],
|
|
# )
|
|
|
|
act_with_policy(
|
|
cfg=cfg,
|
|
reward_classifier=reward_classifier,
|
|
shutdown_event=shutdown_event,
|
|
parameters_queue=parameters_queue,
|
|
transitions_queue=transitions_queue,
|
|
interactions_queue=interactions_queue,
|
|
)
|
|
logging.info("[ACTOR] Policy process joined")
|
|
|
|
logging.info("[ACTOR] Closing queues")
|
|
transitions_queue.close()
|
|
interactions_queue.close()
|
|
parameters_queue.close()
|
|
|
|
transitions_process.join()
|
|
logging.info("[ACTOR] Transitions process joined")
|
|
interactions_process.join()
|
|
logging.info("[ACTOR] Interactions process joined")
|
|
receive_policy_process.join()
|
|
logging.info("[ACTOR] Receive policy process joined")
|
|
|
|
logging.info("[ACTOR] join queues")
|
|
transitions_queue.cancel_join_thread()
|
|
interactions_queue.cancel_join_thread()
|
|
parameters_queue.cancel_join_thread()
|
|
|
|
logging.info("[ACTOR] queues closed")
|
|
|
|
|
|
#################################################
|
|
# Core algorithm functions #
|
|
#################################################
|
|
|
|
|
|
def act_with_policy(
|
|
cfg: TrainPipelineConfig,
|
|
reward_classifier: nn.Module,
|
|
shutdown_event: any, # Event,
|
|
parameters_queue: Queue,
|
|
transitions_queue: Queue,
|
|
interactions_queue: Queue,
|
|
):
|
|
"""
|
|
Executes policy interaction within the environment.
|
|
|
|
This function rolls out the policy in the environment, collecting interaction data and pushing it to a queue for streaming to the learner.
|
|
Once an episode is completed, updated network parameters received from the learner are retrieved from a queue and loaded into the network.
|
|
|
|
Args:
|
|
cfg: Configuration settings for the interaction process.
|
|
reward_classifier: Reward classifier to use for the interaction process.
|
|
shutdown_event: Event to check if the process should shutdown.
|
|
parameters_queue: Queue to receive updated network parameters from the learner.
|
|
transitions_queue: Queue to send transitions to the learner.
|
|
interactions_queue: Queue to send interactions to the learner.
|
|
"""
|
|
# Initialize logging for multiprocessing
|
|
if not use_threads(cfg):
|
|
log_dir = os.path.join(cfg.output_dir, "logs")
|
|
os.makedirs(log_dir, exist_ok=True)
|
|
log_file = os.path.join(log_dir, f"actor_policy_{os.getpid()}.log")
|
|
init_logging(log_file=log_file)
|
|
logging.info("Actor policy process logging initialized")
|
|
|
|
logging.info("make_env online")
|
|
|
|
online_env = make_robot_env(cfg=cfg.env)
|
|
|
|
set_seed(cfg.seed)
|
|
device = get_safe_torch_device(cfg.policy.device, log=True)
|
|
|
|
torch.backends.cudnn.benchmark = True
|
|
torch.backends.cuda.matmul.allow_tf32 = True
|
|
|
|
logging.info("make_policy")
|
|
|
|
### Instantiate the policy in both the actor and learner processes
|
|
### To avoid sending a SACPolicy object through the port, we create a policy intance
|
|
### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters
|
|
# TODO: At some point we should just need make sac policy
|
|
policy: SACPolicy = make_policy(
|
|
cfg=cfg.policy,
|
|
env_cfg=cfg.env,
|
|
)
|
|
policy = policy.eval()
|
|
assert isinstance(policy, nn.Module)
|
|
|
|
obs, info = online_env.reset()
|
|
|
|
# NOTE: For the moment we will solely handle the case of a single environment
|
|
sum_reward_episode = 0
|
|
list_transition_to_send_to_learner = []
|
|
list_policy_time = []
|
|
episode_intervention = False
|
|
# Add counters for intervention rate calculation
|
|
episode_intervention_steps = 0
|
|
episode_total_steps = 0
|
|
|
|
for interaction_step in range(cfg.policy.online_steps):
|
|
start_time = time.perf_counter()
|
|
if shutdown_event.is_set():
|
|
logging.info("[ACTOR] Shutting down act_with_policy")
|
|
return
|
|
|
|
if interaction_step >= cfg.policy.online_step_before_learning:
|
|
# Time policy inference and check if it meets FPS requirement
|
|
with TimerManager(
|
|
elapsed_time_list=list_policy_time,
|
|
label="Policy inference time",
|
|
log=False,
|
|
) as timer: # noqa: F841
|
|
action = policy.select_action(batch=obs)
|
|
policy_fps = 1.0 / (list_policy_time[-1] + 1e-9)
|
|
|
|
log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
|
|
|
|
next_obs, reward, done, truncated, info = online_env.step(action.squeeze(dim=0).cpu().numpy())
|
|
else:
|
|
# TODO (azouitine): Make a custom space for torch tensor
|
|
action = online_env.action_space.sample()
|
|
next_obs, reward, done, truncated, info = online_env.step(action)
|
|
|
|
# HACK: We have only one env but we want to batch it, it will be resolved with the torch box
|
|
action = (
|
|
torch.from_numpy(action[0]).to(device, non_blocking=device.type == "cuda").unsqueeze(dim=0)
|
|
)
|
|
|
|
sum_reward_episode += float(reward)
|
|
# Increment total steps counter for intervention rate
|
|
episode_total_steps += 1
|
|
|
|
# NOTE: We overide the action if the intervention is True, because the action applied is the intervention action
|
|
if "is_intervention" in info and info["is_intervention"]:
|
|
# TODO: Check the shape
|
|
# NOTE: The action space for demonstration before hand is with the full action space
|
|
# but sometimes for example we want to deactivate the gripper
|
|
action = info["action_intervention"]
|
|
episode_intervention = True
|
|
# Increment intervention steps counter
|
|
episode_intervention_steps += 1
|
|
|
|
# Check for NaN values in observations
|
|
for key, tensor in obs.items():
|
|
if torch.isnan(tensor).any():
|
|
logging.error(f"[ACTOR] NaN values found in obs[{key}] at step {interaction_step}")
|
|
|
|
list_transition_to_send_to_learner.append(
|
|
Transition(
|
|
state=obs,
|
|
action=action,
|
|
reward=reward,
|
|
next_state=next_obs,
|
|
done=done,
|
|
truncated=truncated, # TODO: (azouitine) Handle truncation properly
|
|
complementary_info=info, # TODO Handle information for the transition, is_demonstraction: bool
|
|
)
|
|
)
|
|
# assign obs to the next obs and continue the rollout
|
|
obs = next_obs
|
|
|
|
# HACK: We have only one env but we want to batch it, it will be resolved with the torch box
|
|
# Because we are using a single environment we can index at zero
|
|
if done or truncated:
|
|
# TODO: Handle logging for episode information
|
|
logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}")
|
|
|
|
update_policy_parameters(policy=policy.actor, parameters_queue=parameters_queue, device=device)
|
|
|
|
if len(list_transition_to_send_to_learner) > 0:
|
|
push_transitions_to_transport_queue(
|
|
transitions=list_transition_to_send_to_learner,
|
|
transitions_queue=transitions_queue,
|
|
)
|
|
list_transition_to_send_to_learner = []
|
|
|
|
stats = get_frequency_stats(list_policy_time)
|
|
list_policy_time.clear()
|
|
|
|
# Calculate intervention rate
|
|
intervention_rate = 0.0
|
|
if episode_total_steps > 0:
|
|
intervention_rate = episode_intervention_steps / episode_total_steps
|
|
|
|
# Send episodic reward to the learner
|
|
interactions_queue.put(
|
|
python_object_to_bytes(
|
|
{
|
|
"Episodic reward": sum_reward_episode,
|
|
"Interaction step": interaction_step,
|
|
"Episode intervention": int(episode_intervention),
|
|
"Intervention rate": intervention_rate,
|
|
**stats,
|
|
}
|
|
)
|
|
)
|
|
sum_reward_episode = 0.0
|
|
episode_intervention = False
|
|
# Reset intervention counters
|
|
episode_intervention_steps = 0
|
|
episode_total_steps = 0
|
|
obs, info = online_env.reset()
|
|
|
|
if cfg.env.fps is not None:
|
|
dt_time = time.perf_counter() - start_time
|
|
busy_wait(1 / cfg.env.fps - dt_time)
|
|
|
|
|
|
#################################################
|
|
# Communication Functions - Group all gRPC/messaging functions #
|
|
#################################################
|
|
|
|
|
|
def establish_learner_connection(
|
|
stub,
|
|
shutdown_event: any, # Event,
|
|
attempts=30,
|
|
):
|
|
for _ in range(attempts):
|
|
if shutdown_event.is_set():
|
|
logging.info("[ACTOR] Shutting down establish_learner_connection")
|
|
return False
|
|
|
|
# Force a connection attempt and check state
|
|
try:
|
|
logging.info("[ACTOR] Send ready message to Learner")
|
|
if stub.Ready(hilserl_pb2.Empty()) == hilserl_pb2.Empty():
|
|
return True
|
|
except grpc.RpcError as e:
|
|
logging.error(f"[ACTOR] Waiting for Learner to be ready... {e}")
|
|
time.sleep(2)
|
|
return False
|
|
|
|
|
|
@lru_cache(maxsize=1)
|
|
def learner_service_client(
|
|
host="127.0.0.1", port=50051
|
|
) -> tuple[hilserl_pb2_grpc.LearnerServiceStub, grpc.Channel]:
|
|
import json
|
|
|
|
"""
|
|
Returns a client for the learner service.
|
|
|
|
GRPC uses HTTP/2, which is a binary protocol and multiplexes requests over a single connection.
|
|
So we need to create only one client and reuse it.
|
|
"""
|
|
|
|
service_config = {
|
|
"methodConfig": [
|
|
{
|
|
"name": [{}], # Applies to ALL methods in ALL services
|
|
"retryPolicy": {
|
|
"maxAttempts": 5, # Max retries (total attempts = 5)
|
|
"initialBackoff": "0.1s", # First retry after 0.1s
|
|
"maxBackoff": "2s", # Max wait time between retries
|
|
"backoffMultiplier": 2, # Exponential backoff factor
|
|
"retryableStatusCodes": [
|
|
"UNAVAILABLE",
|
|
"DEADLINE_EXCEEDED",
|
|
], # Retries on network failures
|
|
},
|
|
}
|
|
]
|
|
}
|
|
|
|
service_config_json = json.dumps(service_config)
|
|
|
|
channel = grpc.insecure_channel(
|
|
f"{host}:{port}",
|
|
options=[
|
|
("grpc.max_receive_message_length", learner_service.MAX_MESSAGE_SIZE),
|
|
("grpc.max_send_message_length", learner_service.MAX_MESSAGE_SIZE),
|
|
("grpc.enable_retries", 1),
|
|
("grpc.service_config", service_config_json),
|
|
],
|
|
)
|
|
stub = hilserl_pb2_grpc.LearnerServiceStub(channel)
|
|
logging.info("[ACTOR] Learner service client created")
|
|
return stub, channel
|
|
|
|
|
|
def receive_policy(
|
|
cfg: TrainPipelineConfig,
|
|
parameters_queue: Queue,
|
|
shutdown_event: any, # Event,
|
|
learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None,
|
|
grpc_channel: grpc.Channel | None = None,
|
|
):
|
|
logging.info("[ACTOR] Start receiving parameters from the Learner")
|
|
|
|
if not use_threads(cfg):
|
|
# Create a process-specific log file
|
|
log_dir = os.path.join(cfg.output_dir, "logs")
|
|
os.makedirs(log_dir, exist_ok=True)
|
|
log_file = os.path.join(log_dir, f"actor_receive_policy_{os.getpid()}.log")
|
|
|
|
# Initialize logging with explicit log file
|
|
init_logging(log_file=log_file)
|
|
logging.info("Actor receive policy process logging initialized")
|
|
|
|
# Setup process handlers to handle shutdown signal
|
|
# But use shutdown event from the main process
|
|
setup_process_handlers(use_threads=False)
|
|
|
|
if grpc_channel is None or learner_client is None:
|
|
learner_client, grpc_channel = learner_service_client(
|
|
host=cfg.policy.actor_learner_config.learner_host,
|
|
port=cfg.policy.actor_learner_config.learner_port,
|
|
)
|
|
|
|
try:
|
|
iterator = learner_client.StreamParameters(hilserl_pb2.Empty())
|
|
receive_bytes_in_chunks(
|
|
iterator,
|
|
parameters_queue,
|
|
shutdown_event,
|
|
log_prefix="[ACTOR] parameters",
|
|
)
|
|
|
|
except grpc.RpcError as e:
|
|
logging.error(f"[ACTOR] gRPC error: {e}")
|
|
|
|
if not use_threads(cfg):
|
|
grpc_channel.close()
|
|
logging.info("[ACTOR] Received policy loop stopped")
|
|
|
|
|
|
def send_transitions(
|
|
cfg: TrainPipelineConfig,
|
|
transitions_queue: Queue,
|
|
shutdown_event: any, # Event,
|
|
learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None,
|
|
grpc_channel: grpc.Channel | None = None,
|
|
) -> hilserl_pb2.Empty:
|
|
"""
|
|
Sends transitions to the learner.
|
|
|
|
This function continuously retrieves messages from the queue and processes:
|
|
|
|
- **Transition Data:**
|
|
- A batch of transitions (observation, action, reward, next observation) is collected.
|
|
- Transitions are moved to the CPU and serialized using PyTorch.
|
|
- The serialized data is wrapped in a `hilserl_pb2.Transition` message and sent to the learner.
|
|
"""
|
|
|
|
if not use_threads(cfg):
|
|
# Create a process-specific log file
|
|
log_dir = os.path.join(cfg.output_dir, "logs")
|
|
os.makedirs(log_dir, exist_ok=True)
|
|
log_file = os.path.join(log_dir, f"actor_transitions_{os.getpid()}.log")
|
|
|
|
# Initialize logging with explicit log file
|
|
init_logging(log_file=log_file)
|
|
logging.info("Actor transitions process logging initialized")
|
|
|
|
# Setup process handlers to handle shutdown signal
|
|
# But use shutdown event from the main process
|
|
setup_process_handlers(False)
|
|
|
|
if grpc_channel is None or learner_client is None:
|
|
learner_client, grpc_channel = learner_service_client(
|
|
host=cfg.policy.actor_learner_config.learner_host,
|
|
port=cfg.policy.actor_learner_config.learner_port,
|
|
)
|
|
|
|
try:
|
|
learner_client.SendTransitions(transitions_stream(shutdown_event, transitions_queue))
|
|
except grpc.RpcError as e:
|
|
logging.error(f"[ACTOR] gRPC error: {e}")
|
|
|
|
logging.info("[ACTOR] Finished streaming transitions")
|
|
|
|
if not use_threads(cfg):
|
|
grpc_channel.close()
|
|
logging.info("[ACTOR] Transitions process stopped")
|
|
|
|
|
|
def send_interactions(
|
|
cfg: TrainPipelineConfig,
|
|
interactions_queue: Queue,
|
|
shutdown_event: any, # Event,
|
|
learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None,
|
|
grpc_channel: grpc.Channel | None = None,
|
|
) -> hilserl_pb2.Empty:
|
|
"""
|
|
Sends interactions to the learner.
|
|
|
|
This function continuously retrieves messages from the queue and processes:
|
|
|
|
- **Interaction Messages:**
|
|
- Contains useful statistics about episodic rewards and policy timings.
|
|
- The message is serialized using `pickle` and sent to the learner.
|
|
"""
|
|
|
|
if not use_threads(cfg):
|
|
# Create a process-specific log file
|
|
log_dir = os.path.join(cfg.output_dir, "logs")
|
|
os.makedirs(log_dir, exist_ok=True)
|
|
log_file = os.path.join(log_dir, f"actor_interactions_{os.getpid()}.log")
|
|
|
|
# Initialize logging with explicit log file
|
|
init_logging(log_file=log_file)
|
|
logging.info("Actor interactions process logging initialized")
|
|
|
|
# Setup process handlers to handle shutdown signal
|
|
# But use shutdown event from the main process
|
|
setup_process_handlers(False)
|
|
|
|
if grpc_channel is None or learner_client is None:
|
|
learner_client, grpc_channel = learner_service_client(
|
|
host=cfg.policy.actor_learner_config.learner_host,
|
|
port=cfg.policy.actor_learner_config.learner_port,
|
|
)
|
|
|
|
try:
|
|
learner_client.SendInteractions(interactions_stream(shutdown_event, interactions_queue))
|
|
except grpc.RpcError as e:
|
|
logging.error(f"[ACTOR] gRPC error: {e}")
|
|
|
|
logging.info("[ACTOR] Finished streaming interactions")
|
|
|
|
if not use_threads(cfg):
|
|
grpc_channel.close()
|
|
logging.info("[ACTOR] Interactions process stopped")
|
|
|
|
|
|
def transitions_stream(shutdown_event: Event, transitions_queue: Queue) -> hilserl_pb2.Empty:
|
|
while not shutdown_event.is_set():
|
|
try:
|
|
message = transitions_queue.get(block=True, timeout=5)
|
|
except Empty:
|
|
logging.debug("[ACTOR] Transition queue is empty")
|
|
continue
|
|
|
|
yield from send_bytes_in_chunks(
|
|
message, hilserl_pb2.Transition, log_prefix="[ACTOR] Send transitions"
|
|
)
|
|
|
|
return hilserl_pb2.Empty()
|
|
|
|
|
|
def interactions_stream(
|
|
shutdown_event: any, # Event,
|
|
interactions_queue: Queue,
|
|
) -> hilserl_pb2.Empty:
|
|
while not shutdown_event.is_set():
|
|
try:
|
|
message = interactions_queue.get(block=True, timeout=5)
|
|
except Empty:
|
|
logging.debug("[ACTOR] Interaction queue is empty")
|
|
continue
|
|
|
|
yield from send_bytes_in_chunks(
|
|
message,
|
|
hilserl_pb2.InteractionMessage,
|
|
log_prefix="[ACTOR] Send interactions",
|
|
)
|
|
|
|
return hilserl_pb2.Empty()
|
|
|
|
|
|
#################################################
|
|
# Policy functions #
|
|
#################################################
|
|
|
|
|
|
def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device):
|
|
if not parameters_queue.empty():
|
|
logging.info("[ACTOR] Load new parameters from Learner.")
|
|
bytes_state_dict = get_last_item_from_queue(parameters_queue)
|
|
state_dict = bytes_to_state_dict(bytes_state_dict)
|
|
state_dict = move_state_dict_to_device(state_dict, device=device)
|
|
policy.load_state_dict(state_dict)
|
|
|
|
|
|
#################################################
|
|
# Utilities functions #
|
|
#################################################
|
|
|
|
|
|
def push_transitions_to_transport_queue(transitions: list, transitions_queue):
|
|
"""Send transitions to learner in smaller chunks to avoid network issues.
|
|
|
|
Args:
|
|
transitions: List of transitions to send
|
|
message_queue: Queue to send messages to learner
|
|
chunk_size: Size of each chunk to send
|
|
"""
|
|
transition_to_send_to_learner = []
|
|
for transition in transitions:
|
|
tr = move_transition_to_device(transition=transition, device="cpu")
|
|
for key, value in tr["state"].items():
|
|
if torch.isnan(value).any():
|
|
logging.warning(f"Found NaN values in transition {key}")
|
|
|
|
transition_to_send_to_learner.append(tr)
|
|
|
|
transitions_queue.put(transitions_to_bytes(transition_to_send_to_learner))
|
|
|
|
|
|
def get_frequency_stats(list_policy_time: list[float]) -> dict[str, float]:
|
|
stats = {}
|
|
list_policy_fps = [1.0 / t for t in list_policy_time]
|
|
if len(list_policy_fps) > 1:
|
|
policy_fps = mean(list_policy_fps)
|
|
quantiles_90 = quantiles(list_policy_fps, n=10)[-1]
|
|
logging.debug(f"[ACTOR] Average policy frame rate: {policy_fps}")
|
|
logging.debug(f"[ACTOR] Policy frame rate 90th percentile: {quantiles_90}")
|
|
stats = {
|
|
"Policy frequency [Hz]": policy_fps,
|
|
"Policy frequency 90th-p [Hz]": quantiles_90,
|
|
}
|
|
return stats
|
|
|
|
|
|
def log_policy_frequency_issue(policy_fps: float, cfg: TrainPipelineConfig, interaction_step: int):
|
|
if policy_fps < cfg.env.fps:
|
|
logging.warning(
|
|
f"[ACTOR] Policy FPS {policy_fps:.1f} below required {cfg.env.fps} at step {interaction_step}"
|
|
)
|
|
|
|
|
|
def use_threads(cfg: TrainPipelineConfig) -> bool:
|
|
return cfg.policy.concurrency.actor == "threads"
|
|
|
|
|
|
if __name__ == "__main__":
|
|
actor_cli()
|