diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py index 62f35ed5..904679e8 100644 --- a/lerobot/common/policies/sac/configuration_sac.py +++ b/lerobot/common/policies/sac/configuration_sac.py @@ -45,6 +45,14 @@ class SACConfig: "action": {"min": [-1, -1], "max": [1, 1]}, } ) + # TODO: Move it outside of the config + actor_learner_config: dict[str, str | int] = field( + default_factory=lambda: { + "actor_ip": "127.0.0.1", + "port": 50051, + "learner_ip": "127.0.0.1", + } + ) camera_number: int = 1 # Add type annotations for these fields: image_encoder_hidden_dim: int = 32 diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index 8fb46199..8567313d 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -17,8 +17,7 @@ # TODO: (1) better device management -from collections import deque -from typing import Callable, Optional, Sequence, Tuple, Union +from typing import Callable, Optional, Tuple import einops import numpy as np @@ -74,43 +73,42 @@ class SACPolicy( config.output_shapes, config.output_normalization_modes, dataset_stats ) + # NOTE: For images the encoder should be shared between the actor and critic if config.shared_encoder: encoder_critic = SACObservationEncoder(config) encoder_actor: SACObservationEncoder = encoder_critic else: encoder_critic = SACObservationEncoder(config) encoder_actor = SACObservationEncoder(config) - # Define networks - critic_nets = [] - for _ in range(config.num_critics): - critic_net = Critic( - encoder=encoder_critic, - network=MLP( - input_dim=encoder_critic.output_dim + config.output_shapes["action"][0], - **config.critic_network_kwargs, - ), - device=device, - ) - critic_nets.append(critic_net) - target_critic_nets = [] - for _ in range(config.num_critics): - target_critic_net = Critic( - encoder=encoder_critic, - network=MLP( - input_dim=encoder_critic.output_dim + config.output_shapes["action"][0], - **config.critic_network_kwargs, - ), - device=device, - ) - target_critic_nets.append(target_critic_net) + self.critic_ensemble = CriticEnsemble( + encoder=encoder_critic, + network_list=nn.ModuleList( + [ + MLP( + input_dim=encoder_critic.output_dim + config.output_shapes["action"][0], + **config.critic_network_kwargs, + ) + for _ in range(config.num_critics) + ] + ), + device=device, + ) - self.critic_ensemble = create_critic_ensemble( - critics=critic_nets, num_critics=config.num_critics, device=device - ) - self.critic_target = create_critic_ensemble( - critics=target_critic_nets, num_critics=config.num_critics, device=device + self.critic_target = CriticEnsemble( + encoder=encoder_critic, + network_list=nn.ModuleList( + [ + MLP( + input_dim=encoder_critic.output_dim + config.output_shapes["action"][0], + **config.critic_network_kwargs, + ) + for _ in range(config.num_critics) + ] + ), + device=device, ) + self.critic_target.load_state_dict(self.critic_ensemble.state_dict()) self.actor = Policy( @@ -123,7 +121,8 @@ class SACPolicy( ) if config.target_entropy is None: config.target_entropy = -np.prod(config.output_shapes["action"][0]) / 2 # (-dim(A)/2) - # TODO: Handle the case where the temparameter is a fixed + + # TODO (azouitine): Handle the case where the temparameter is a fixed self.log_alpha = torch.zeros(1, requires_grad=True, device=device) self.temperature = self.log_alpha.exp().item() @@ -152,18 +151,19 @@ class SACPolicy( Tensor of Q-values from all critics """ critics = 
self.critic_target if use_target else self.critic_ensemble - q_values = torch.stack([critic(observations, actions) for critic in critics]) + q_values = critics(observations, actions) return q_values def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]: ... def update_target_networks(self): """Update target networks with exponential moving average""" - for target_critic, critic in zip(self.critic_target, self.critic_ensemble, strict=False): - for target_param, param in zip(target_critic.parameters(), critic.parameters(), strict=False): - target_param.data.copy_( - param.data * self.config.critic_target_update_weight - + target_param.data * (1.0 - self.config.critic_target_update_weight) - ) + for target_param, param in zip( + self.critic_target.parameters(), self.critic_ensemble.parameters(), strict=False + ): + target_param.data.copy_( + param.data * self.config.critic_target_update_weight + + target_param.data * (1.0 - self.config.critic_target_update_weight) + ) def compute_loss_critic(self, observations, actions, rewards, next_observations, done) -> Tensor: temperature = self.log_alpha.exp().item() @@ -264,34 +264,83 @@ class MLP(nn.Module): return self.net(x) -class Critic(nn.Module): +class CriticEnsemble(nn.Module): + """ + ┌──────────────────┬─────────────────────────────────────────────────────────┐ + │ Critic Ensemble │ │ + ├──────────────────┘ │ + │ │ + │ ┌────┐ ┌────┐ ┌────┐ │ + │ │ Q1 │ │ Q2 │ │ Qn │ │ + │ └────┘ └────┘ └────┘ │ + │ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ │ + │ │ │ │ │ │ │ │ + │ │ MLP 1 │ │ MLP 2 │ │ MLP │ │ + │ │ │ │ │ ... │ num_critics │ │ + │ │ │ │ │ │ │ │ + │ └──────────────┘ └──────────────┘ └──────────────┘ │ + │ ▲ ▲ ▲ │ + │ └───────────────────┴───────┬────────────────────────────┘ │ + │ │ │ + │ │ │ + │ ┌───────────────────┐ │ + │ │ Embedding │ │ + │ │ │ │ + │ └───────────────────┘ │ + │ ▲ │ + │ │ │ + │ ┌─────────────┴────────────┐ │ + │ │ │ │ + │ │ SACObservationEncoder │ │ + │ │ │ │ + │ └──────────────────────────┘ │ + │ ▲ │ + │ │ │ + │ │ │ + │ │ │ + └───────────────────────────┬────────────────────┬───────────────────────────┘ + │ Observation │ + └────────────────────┘ + """ + def __init__( self, encoder: Optional[nn.Module], - network: nn.Module, + network_list: nn.Module, init_final: Optional[float] = None, device: str = "cpu", ): super().__init__() self.device = torch.device(device) self.encoder = encoder - self.network = network + self.network_list = network_list self.init_final = init_final + # for network in network_list: + # network.to(self.device) + # Find the last Linear layer's output dimension - for layer in reversed(network.net): + for layer in reversed(network_list[0].net): if isinstance(layer, nn.Linear): out_features = layer.out_features break # Output layer + self.output_layers = [] if init_final is not None: - self.output_layer = nn.Linear(out_features, 1) - nn.init.uniform_(self.output_layer.weight, -init_final, init_final) - nn.init.uniform_(self.output_layer.bias, -init_final, init_final) + for _ in network_list: + output_layer = nn.Linear(out_features, 1, device=device) + nn.init.uniform_(output_layer.weight, -init_final, init_final) + nn.init.uniform_(output_layer.bias, -init_final, init_final) + self.output_layers.append(output_layer) else: - self.output_layer = nn.Linear(out_features, 1) - orthogonal_init()(self.output_layer.weight) + self.output_layers = [] + for _ in network_list: + output_layer = nn.Linear(out_features, 1, device=device) + orthogonal_init()(output_layer.weight) + 
self.output_layers.append(output_layer) + + self.output_layers = nn.ModuleList(self.output_layers) self.to(self.device) @@ -307,9 +356,12 @@ class Critic(nn.Module): obs_enc = observations if self.encoder is None else self.encoder(observations) inputs = torch.cat([obs_enc, actions], dim=-1) - x = self.network(inputs) - value = self.output_layer(x) - return value.squeeze(-1) + list_q_values = [] + for network, output_layer in zip(self.network_list, self.output_layers, strict=False): + x = network(inputs) + value = output_layer(x) + list_q_values.append(value.squeeze(-1)) + return torch.stack(list_q_values) class Policy(nn.Module): @@ -416,9 +468,7 @@ class Policy(nn.Module): class SACObservationEncoder(nn.Module): - """Encode image and/or state vector observations. - TODO(ke-wang): The original work allows for (1) stacking multiple history frames and (2) using pretrained resnet encoders. - """ + """Encode image and/or state vector observations.""" def __init__(self, config: SACConfig): """ @@ -513,8 +563,7 @@ class SACObservationEncoder(nn.Module): feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"])) if "observation.state" in self.config.input_shapes: feat.append(self.state_enc_layers(obs_dict["observation.state"])) - # TODO(ke-wang): currently average over all features, concatenate all features maybe a better way - # return torch.stack(feat, dim=0).mean(0) + features = torch.cat(tensors=feat, dim=-1) features = self.aggregation_layer(features) @@ -530,12 +579,8 @@ def orthogonal_init(): return lambda x: torch.nn.init.orthogonal_(x, gain=1.0) -def create_critic_ensemble(critics: list[nn.Module], num_critics: int, device: str = "cpu") -> nn.ModuleList: - """Creates an ensemble of critic networks""" - assert len(critics) == num_critics, f"Expected {num_critics} critics, got {len(critics)}" - return nn.ModuleList(critics).to(device) - - +# TODO (azouitine): I think in our case this function is not usefull we should remove it +# after some investigation # borrowed from tdmpc def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor: """Helper to temporarily flatten extra dims at the start of the image tensor. diff --git a/lerobot/configs/policy/sac_manyskill.yaml b/lerobot/configs/policy/sac_manyskill.yaml index fc824da5..59f42247 100644 --- a/lerobot/configs/policy/sac_manyskill.yaml +++ b/lerobot/configs/policy/sac_manyskill.yaml @@ -8,8 +8,7 @@ # env.gym.obs_type=environment_state_agent_pos \ seed: 1 -dataset_repo_id: null - +dataset_repo_id: null training: # Offline training dataloader @@ -75,15 +74,18 @@ policy: # discount: 0.99 discount: 0.80 temperature_init: 1.0 - num_critics: 2 + num_critics: 2 #10 num_subsample_critics: null critic_lr: 3e-4 actor_lr: 3e-4 temperature_lr: 3e-4 # critic_target_update_weight: 0.005 critic_target_update_weight: 0.01 - utd_ratio: 2 + utd_ratio: 2 # 10 +actor_learner_config: + actor_ip: "127.0.0.1" + port: 50051 # # Loss coefficients. # reward_coeff: 0.5 diff --git a/lerobot/scripts/server/actor_server.py b/lerobot/scripts/server/actor_server.py index afa6a6e0..0d2a1a5e 100644 --- a/lerobot/scripts/server/actor_server.py +++ b/lerobot/scripts/server/actor_server.py @@ -13,117 +13,123 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import io
 import logging
-import functools
-from pprint import pformat
-import random
-from typing import Optional, Sequence, TypedDict, Callable
 import pickle
+import queue
+import time
+from concurrent import futures
+from statistics import mean, quantiles

-import hydra
-import torch
-import torch.nn.functional as F
-from torch import nn
-from tqdm import tqdm
-from deepdiff import DeepDiff
-from omegaconf import DictConfig, OmegaConf
-
-from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
-
-# TODO: Remove the import of maniskill
-from lerobot.common.datasets.factory import make_dataset
-from lerobot.common.envs.factory import make_env, make_maniskill_env
-from lerobot.common.envs.utils import preprocess_observation, preprocess_maniskill_observation
-from lerobot.common.logger import Logger, log_output_dir
-from lerobot.common.policies.factory import make_policy
-from lerobot.common.policies.sac.modeling_sac import SACPolicy
-from lerobot.common.policies.utils import get_device_from_parameters
-from lerobot.common.utils.utils import (
-    format_big_number,
-    get_safe_torch_device,
-    init_hydra_config,
-    init_logging,
-    set_global_seed,
-)
 # from lerobot.scripts.eval import eval_policy
 from threading import Thread
-import queue

 import grpc
-from lerobot.scripts.server import hilserl_pb2, hilserl_pb2_grpc
-import io
-import time
-import logging
-from concurrent import futures
-from threading import Thread
-from lerobot.scripts.server.buffer import move_state_dict_to_device, move_transition_to_device, Transition
+import hydra
+import torch
+from omegaconf import DictConfig
+from torch import nn

-import faulthandler
-import signal
+# TODO: Remove the import of maniskill
+from lerobot.common.envs.factory import make_maniskill_env
+from lerobot.common.envs.utils import preprocess_maniskill_observation
+from lerobot.common.policies.factory import make_policy
+from lerobot.common.policies.sac.modeling_sac import SACPolicy
+from lerobot.common.utils.utils import (
+    get_safe_torch_device,
+    set_global_seed,
+)
+from lerobot.scripts.server import hilserl_pb2, hilserl_pb2_grpc
+from lerobot.scripts.server.buffer import Transition, move_state_dict_to_device, move_transition_to_device

 logging.basicConfig(level=logging.INFO)

 parameters_queue = queue.Queue(maxsize=1)
 message_queue = queue.Queue(maxsize=1_000_000)

+
 class ActorInformation:
+    """
+    This helper class is used to differentiate between two types of messages that are placed in the same queue during streaming:
+
+    - **Transition Data:** Contains experience tuples (observation, action, reward, next observation) collected during interaction.
+    - **Interaction Messages:** Encapsulates statistics related to the interaction process.
+
+    Attributes:
+        transition (Optional): Transition data to be sent to the learner.
+        interaction_message (Optional): Interaction message providing additional statistics for logging.
+    """
+
     def __init__(self, transition=None, interaction_message=None):
         self.transition = transition
         self.interaction_message = interaction_message

-# 1) Implement ActorService so the Learner can send parameters to this Actor.
 class ActorServiceServicer(hilserl_pb2_grpc.ActorServiceServicer):
-    def StreamTransition(self, request, context):
+    """
+    gRPC service for actor-learner communication in reinforcement learning.
+
+    This service is responsible for:
+    1. Streaming batches of transition data and statistical metrics from the actor to the learner.
+    2. Receiving updated network parameters from the learner.
+    """
+
+    def StreamTransition(self, request, context):  # noqa: N802
+        """
+        Streams data from the actor to the learner.
+
+        This function continuously retrieves messages from the queue and processes them based on their type:
+
+        - **Transition Data:**
+          - A batch of transitions (observation, action, reward, next observation) is collected.
+          - Transitions are moved to the CPU and serialized using PyTorch.
+          - The serialized data is wrapped in a `hilserl_pb2.Transition` message and sent to the learner.
+
+        - **Interaction Messages:**
+          - Contains useful statistics about episodic rewards and policy timings.
+          - The message is serialized using `pickle` and sent to the learner.
+
+        Yields:
+            hilserl_pb2.ActorInformation: The response message containing either transition data or an interaction message.
+        """
         while True:
-            # logging.info(f"[ACTOR] before message.empty()")
-            # logging.info(f"[ACTOR] size transition queue {message_queue.qsize()}")
-            # time.sleep(0.01)
-            # if message_queue.empty():
-            #     continue
-            # logging.info(f"[ACTOR] after message.empty()")
-            start = time.time()
             message = message_queue.get(block=True)
-            # logging.info(f"[ACTOR] Message queue get time {time.time() - start}")

             if message.transition is not None:
-                # transition_to_send_to_learner = move_transition_to_device(message.transition, device="cpu")
-                transition_to_send_to_learner = [move_transition_to_device(T, device="cpu") for T in message.transition]
-                # logging.info(f"[ACTOR] Message queue get time {time.time() - start}")
+                transition_to_send_to_learner = [
+                    move_transition_to_device(T, device="cpu") for T in message.transition
+                ]

-                # Serialize it
                 buf = io.BytesIO()
                 torch.save(transition_to_send_to_learner, buf)
                 transition_bytes = buf.getvalue()
-
-                transition_message = hilserl_pb2.Transition(
-                    transition_bytes=transition_bytes
-                )
-                response = hilserl_pb2.ActorInformation(
-                    transition=transition_message
-                )
-                logging.info(f"[ACTOR] time to yield transition response {time.time() - start}")
-                logging.info(f"[ACTOR] size transition queue {message_queue.qsize()}")
-
+                transition_message = hilserl_pb2.Transition(transition_bytes=transition_bytes)
+
+                response = hilserl_pb2.ActorInformation(transition=transition_message)
+
             elif message.interaction_message is not None:
-                # Serialize it and send it to the Learner's server
                 content = hilserl_pb2.InteractionMessage(
                     interaction_message_bytes=pickle.dumps(message.interaction_message)
-                )
-                response = hilserl_pb2.ActorInformation(
-                    interaction_message=content
                 )
+                response = hilserl_pb2.ActorInformation(interaction_message=content)

-            # logging.info(f"[ACTOR] yield response before")
             yield response
-            # logging.info(f"[ACTOR] response yielded after")

-    def SendParameters(self, request, context):
+    def SendParameters(self, request, context):  # noqa: N802
         """
-        Learner calls this with updated Parameters -> Actor
+        Receives updated parameters from the learner and updates the actor.
+
+        The learner calls this method to send new model parameters. The received parameters are deserialized
+        and placed in a queue to be consumed by the actor.
+
+        Args:
+            request (hilserl_pb2.ParameterUpdate): The request containing serialized network parameters.
+            context (grpc.ServicerContext): The gRPC context.
+
+        Returns:
+            hilserl_pb2.Empty: An empty response to acknowledge receipt.
         """
-        # logging.info("[ACTOR] Received parameters from Learner.")
         buffer = io.BytesIO(request.parameter_bytes)
         params = torch.load(buffer)
         parameters_queue.put(params)
@@ -132,38 +138,38 @@ class ActorServiceServicer(hilserl_pb2_grpc.ActorServiceServicer):
 def serve_actor_service(port=50052):
     """
-    Runs a gRPC server so that the Learner can push parameters to the Actor.
+    Runs a gRPC server that streams data from the actor to the learner.
+    Through this server, the learner can also push parameters to the Actor.
     """
-    server = grpc.server(futures.ThreadPoolExecutor(max_workers=20),
-                         options=[('grpc.max_send_message_length', -1),
-                                  ('grpc.max_receive_message_length', -1)])
-    hilserl_pb2_grpc.add_ActorServiceServicer_to_server(
-        ActorServiceServicer(), server
+    server = grpc.server(
+        futures.ThreadPoolExecutor(max_workers=20),
+        options=[("grpc.max_send_message_length", -1), ("grpc.max_receive_message_length", -1)],
     )
-    server.add_insecure_port(f'[::]:{port}')
+    hilserl_pb2_grpc.add_ActorServiceServicer_to_server(ActorServiceServicer(), server)
+    server.add_insecure_port(f"[::]:{port}")
     server.start()
     logging.info(f"[ACTOR] gRPC server listening on port {port}")
     server.wait_for_termination()


-def act_with_policy(cfg: DictConfig,
-                    out_dir: str | None = None,
-                    job_name: str | None = None):
-    if out_dir is None:
-        raise NotImplementedError()
-    if job_name is None:
-        raise NotImplementedError()
+def act_with_policy(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
+    """
+    Executes policy interaction within the environment.
+
+    This function rolls out the policy in the environment, collecting interaction data and pushing it to a queue for streaming to the learner.
+    Once an episode is completed, updated network parameters received from the learner are retrieved from a queue and loaded into the network.
+
+    Args:
+        cfg (DictConfig): Configuration settings for the interaction process.
+        out_dir (Optional[str]): Directory to store output logs or results. Defaults to None.
+        job_name (Optional[str]): Name of the job for logging or tracking purposes. Defaults to None.
+ """ logging.info("make_env online") # online_env = make_env(cfg, n_envs=1) # TODO: Remove the import of maniskill and unifiy with make env online_env = make_maniskill_env(cfg, n_envs=1) - if cfg.training.eval_freq > 0: - logging.info("make_env eval") - # eval_env = make_env(cfg, n_envs=1) - # TODO: Remove the import of maniskill and unifiy with make env - eval_env = make_maniskill_env(cfg, n_envs=1) set_global_seed(cfg.seed) device = get_safe_torch_device(cfg.device, log=True) @@ -172,8 +178,7 @@ def act_with_policy(cfg: DictConfig, torch.backends.cuda.matmul.allow_tf32 = True logging.info("make_policy") - - + ### Instantiate the policy in both the actor and learner processes ### To avoid sending a SACPolicy object through the port, we create a policy intance ### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters @@ -181,7 +186,7 @@ def act_with_policy(cfg: DictConfig, policy: SACPolicy = make_policy( hydra_cfg=cfg, # dataset_stats=offline_dataset.meta.stats if not cfg.resume else None, - # Hack: But if we do online traning, we do not need dataset_stats + # Hack: But if we do online training, we do not need dataset_stats dataset_stats=None, # TODO: Handle resume training pretrained_policy_name_or_path=None, @@ -195,17 +200,22 @@ def act_with_policy(cfg: DictConfig, # obs = preprocess_observation(obs) obs = preprocess_maniskill_observation(obs) obs = {key: obs[key].to(device, non_blocking=True) for key in obs} - ### ACTOR ================== + # NOTE: For the moment we will solely handle the case of a single environment sum_reward_episode = 0 list_transition_to_send_to_learner = [] + list_policy_fps = [] for interaction_step in range(cfg.training.online_steps): - # NOTE: At some point we should use a wrapper to handle the observation - - # start = time.time() if interaction_step >= cfg.training.online_step_before_learning: + start = time.perf_counter() action = policy.select_action(batch=obs) + list_policy_fps.append(1.0 / (time.perf_counter() - start + 1e-9)) + if list_policy_fps[-1] < cfg.fps: + logging.warning( + f"[ACTOR] policy frame rate {list_policy_fps[-1]} during interaction step {interaction_step} is below the required control frame rate {cfg.fps}" + ) + next_obs, reward, done, truncated, info = online_env.step(action.cpu().numpy()) else: action = online_env.action_space.sample() @@ -213,70 +223,88 @@ def act_with_policy(cfg: DictConfig, # HACK action = torch.tensor(action, dtype=torch.float32).to(device, non_blocking=True) - # logging.info(f"[ACTOR] Time for env step {time.time() - start}") - # HACK: For maniskill # next_obs = preprocess_observation(next_obs) next_obs = preprocess_maniskill_observation(next_obs) next_obs = {key: next_obs[key].to(device, non_blocking=True) for key in obs} sum_reward_episode += float(reward[0]) - # Because we are using a single environment - # we can safely assume that the episode is done + + # Because we are using a single environment we can index at zero if done[0].item() or truncated[0].item(): # TODO: Handle logging for episode information logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}") if not parameters_queue.empty(): - logging.info("[ACTOR] Load new parameters from Learner.") - # Load new parameters from Learner + logging.debug("[ACTOR] Load new parameters from Learner.") state_dict = parameters_queue.get() state_dict = move_state_dict_to_device(state_dict, device=device) policy.actor.load_state_dict(state_dict) - + if 
len(list_transition_to_send_to_learner) > 0: - logging.info(f"[ACTOR] Sending {len(list_transition_to_send_to_learner)} transitions to Learner.") + logging.debug( + f"[ACTOR] Sending {len(list_transition_to_send_to_learner)} transitions to Learner." + ) message_queue.put(ActorInformation(transition=list_transition_to_send_to_learner)) list_transition_to_send_to_learner = [] + stats = {} + if len(list_policy_fps) > 0: + policy_fps = mean(list_policy_fps) + quantiles_90 = quantiles(list_policy_fps, n=10)[-1] + logging.debug(f"[ACTOR] Average policy frame rate: {policy_fps}") + logging.debug(f"[ACTOR] Policy frame rate 90th percentile: {quantiles_90}") + stats = {"Policy frequency [Hz]": policy_fps, "Policy frequency 90th-p [Hz]": quantiles_90} + list_policy_fps = [] + # Send episodic reward to the learner - message_queue.put(ActorInformation(interaction_message={"episodic_reward": sum_reward_episode,"interaction_step": interaction_step})) + message_queue.put( + ActorInformation( + interaction_message={ + "Episodic reward": sum_reward_episode, + "Interaction step": interaction_step, + **stats, + } + ) + ) sum_reward_episode = 0.0 - # ============================ - # Prepare transition to send - # ============================ - # Label the reward + # TODO (michel-aractingi): Label the reward # if config.label_reward_on_actor: # reward = reward_classifier(obs) - list_transition_to_send_to_learner.append(Transition( - # transition_to_send_to_learner = Transition( - state=obs, - action=action, - reward=reward, - next_state=next_obs, - done=done, - complementary_info=None, - ) + list_transition_to_send_to_learner.append( + Transition( + state=obs, + action=action, + reward=reward, + next_state=next_obs, + done=done, + complementary_info=None, + ) ) - # message_queue.put(ActorInformation(transition=transition_to_send_to_learner)) # assign obs to the next obs and continue the rollout obs = next_obs + @hydra.main(version_base="1.2", config_name="default", config_path="../../configs") def actor_cli(cfg: dict): - server_thread = Thread(target=serve_actor_service, args=(50051,), daemon=True) - server_thread.start() - policy_thread = Thread(target=act_with_policy, - daemon=True, - args=(cfg,hydra.core.hydra_config.HydraConfig.get().run.dir, hydra.core.hydra_config.HydraConfig.get().job.name)) - policy_thread.start() - policy_thread.join() - server_thread.join() + port = cfg.actor_learner_config.port + server_thread = Thread(target=serve_actor_service, args=(port,), daemon=True) + server_thread.start() + policy_thread = Thread( + target=act_with_policy, + daemon=True, + args=( + cfg, + hydra.core.hydra_config.HydraConfig.get().run.dir, + hydra.core.hydra_config.HydraConfig.get().job.name, + ), + ) + policy_thread.start() + policy_thread.join() + server_thread.join() + if __name__ == "__main__": - with open("traceback.log", "w") as f: - faulthandler.register(signal.SIGUSR1, file=f) - - actor_cli() \ No newline at end of file + actor_cli() diff --git a/lerobot/scripts/server/hilserl.proto b/lerobot/scripts/server/hilserl.proto index 41f85100..9fd8663f 100644 --- a/lerobot/scripts/server/hilserl.proto +++ b/lerobot/scripts/server/hilserl.proto @@ -1,3 +1,19 @@ +// !/usr/bin/env python + +// Copyright 2024 The HuggingFace Inc. team. +// All rights reserved. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. syntax = "proto3"; package hil_serl; diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py index bd15fc01..4c5c358c 100644 --- a/lerobot/scripts/server/learner_server.py +++ b/lerobot/scripts/server/learner_server.py @@ -1,97 +1,97 @@ -import grpc -from concurrent import futures -import functools -import logging -import queue -import pickle -import torch -import torch.nn.functional as F +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import io +import logging +import pickle +import queue import time - from pprint import pformat -import random -from typing import Optional, Sequence, TypedDict, Callable +from threading import Lock, Thread +import grpc + +# Import generated stubs +import hilserl_pb2 # type: ignore +import hilserl_pb2_grpc # type: ignore import hydra import torch -import torch.nn.functional as F -from torch import nn -from tqdm import tqdm -from deepdiff import DeepDiff from omegaconf import DictConfig, OmegaConf -from threading import Thread, Lock - -from lerobot.common.datasets.lerobot_dataset import LeRobotDataset +from torch import nn # TODO: Remove the import of maniskill from lerobot.common.datasets.factory import make_dataset from lerobot.common.logger import Logger, log_output_dir from lerobot.common.policies.factory import make_policy from lerobot.common.policies.sac.modeling_sac import SACPolicy -from lerobot.common.policies.utils import get_device_from_parameters from lerobot.common.utils.utils import ( format_big_number, get_safe_torch_device, - init_hydra_config, init_logging, set_global_seed, ) -from lerobot.scripts.server.buffer import ReplayBuffer, move_transition_to_device, concatenate_batch_transitions, move_state_dict_to_device, Transition - -# Import generated stubs -import hilserl_pb2 -import hilserl_pb2_grpc +from lerobot.scripts.server.buffer import ( + ReplayBuffer, + concatenate_batch_transitions, + move_state_dict_to_device, + move_transition_to_device, +) logging.basicConfig(level=logging.INFO) - - # TODO: Implement it in cleaner way maybe transition_queue = queue.Queue() interaction_message_queue = queue.Queue() -# 1) Implement the LearnerService so the Actor can send transitions here. -class LearnerServiceServicer(hilserl_pb2_grpc.LearnerServiceServicer): - # def SendTransition(self, request, context): - # """ - # Actor calls this method to push a Transition -> Learner. 
- # """ - # buffer = io.BytesIO(request.transition_bytes) - # transition = torch.load(buffer) - # transition_queue.put(transition) - # return hilserl_pb2.Empty() - def SendInteractionMessage(self, request, context): - """ - Actor calls this method to push a Transition -> Learner. - """ - content = pickle.loads(request.interaction_message_bytes) - interaction_message_queue.put(content) - return hilserl_pb2.Empty() - - - -def stream_transitions_from_actor(port=50051): +def stream_transitions_from_actor(host="127.0.0.1", port=50051): """ - Runs a gRPC server listening for transitions from the Actor. + Runs a gRPC client that listens for transition and interaction messages from an Actor service. + + This function establishes a gRPC connection with the given `host` and `port`, then continuously + streams transition data from the `ActorServiceStub`. The received transition data is deserialized + and stored in a queue (`transition_queue`). Similarly, interaction messages are also deserialized + and stored in a separate queue (`interaction_message_queue`). + + Args: + host (str, optional): The IP address or hostname of the gRPC server. Defaults to `"127.0.0.1"`. + port (int, optional): The port number on which the gRPC server is running. Defaults to `50051`. + """ + # NOTE: This is waiting for the handshake to be done + # In the future we will do it in a canonical way with a proper handshake time.sleep(10) - channel = grpc.insecure_channel(f'127.0.0.1:{port}', - options=[('grpc.max_send_message_length', -1), - ('grpc.max_receive_message_length', -1)]) + channel = grpc.insecure_channel( + f"{host}:{port}", + options=[("grpc.max_send_message_length", -1), ("grpc.max_receive_message_length", -1)], + ) stub = hilserl_pb2_grpc.ActorServiceStub(channel) for response in stub.StreamTransition(hilserl_pb2.Empty()): - if response.HasField('transition'): + if response.HasField("transition"): buffer = io.BytesIO(response.transition.transition_bytes) transition = torch.load(buffer) transition_queue.put(transition) - if response.HasField('interaction_message'): + if response.HasField("interaction_message"): content = pickle.loads(response.interaction_message.interaction_message_bytes) interaction_message_queue.put(content) # NOTE: Cool down the CPU, if you comment this line you will make a huge bottleneck + # TODO: LOOK TO REMOVE IT time.sleep(0.001) + def learner_push_parameters( policy: nn.Module, policy_lock: Lock, actor_host="127.0.0.1", actor_port=50052, seconds_between_pushes=5 ): @@ -100,10 +100,10 @@ def learner_push_parameters( and periodically push new parameters. """ time.sleep(10) - # The Actor's server is presumably listening on a different port, e.g. 
50052 - channel = grpc.insecure_channel(f"{actor_host}:{actor_port}", - options=[('grpc.max_send_message_length', -1), - ('grpc.max_receive_message_length', -1)]) + channel = grpc.insecure_channel( + f"{actor_host}:{actor_port}", + options=[("grpc.max_send_message_length", -1), ("grpc.max_receive_message_length", -1)], + ) actor_stub = hilserl_pb2_grpc.ActorServiceStub(channel) while True: @@ -116,20 +116,19 @@ def learner_push_parameters( params_bytes = buf.getvalue() # Push them to the Actor’s "SendParameters" method - logging.info(f"[LEARNER] Pushing parameters to the Actor") - response = actor_stub.SendParameters(hilserl_pb2.Parameters(parameter_bytes=params_bytes)) + logging.info("[LEARNER] Publishing parameters to the Actor") + response = actor_stub.SendParameters(hilserl_pb2.Parameters(parameter_bytes=params_bytes)) # noqa: F841 time.sleep(seconds_between_pushes) -# Checked -def add_actor_information( +def add_actor_information_and_train( cfg, - device, + device: str, replay_buffer: ReplayBuffer, offline_replay_buffer: ReplayBuffer, batch_size: int, - optimizers, - policy, + optimizers: dict[str, torch.optim.Optimizer], + policy: nn.Module, policy_lock: Lock, buffer_lock: Lock, offline_buffer_lock: Lock, @@ -137,34 +136,52 @@ def add_actor_information( logger: Logger, ): """ - In a real application, you might run your training loop here, - reading from the transition queue and doing gradient updates. + Handles data transfer from the actor to the learner, manages training updates, + and logs training progress in an online reinforcement learning setup. + + This function continuously: + - Transfers transitions from the actor to the replay buffer. + - Logs received interaction messages. + - Ensures training begins only when the replay buffer has a sufficient number of transitions. + - Samples batches from the replay buffer and performs multiple critic updates. + - Periodically updates the actor, critic, and temperature optimizers. + - Logs training statistics, including loss values and optimization frequency. + + **NOTE:** + - This function performs multiple responsibilities (data transfer, training, and logging). + It should ideally be split into smaller functions in the future. + - Due to Python's **Global Interpreter Lock (GIL)**, running separate threads for different tasks + significantly reduces performance. Instead, this function executes all operations in a single thread. + + Args: + cfg: Configuration object containing hyperparameters. + device (str): The computing device (`"cpu"` or `"cuda"`). + replay_buffer (ReplayBuffer): The primary replay buffer storing online transitions. + offline_replay_buffer (ReplayBuffer): An additional buffer for offline transitions. + batch_size (int): The number of transitions to sample per training step. + optimizers (Dict[str, torch.optim.Optimizer]): A dictionary of optimizers (`"actor"`, `"critic"`, `"temperature"`). + policy (nn.Module): The reinforcement learning policy with critic, actor, and temperature parameters. + policy_lock (Lock): A threading lock to ensure safe policy updates. + buffer_lock (Lock): A threading lock to safely access the online replay buffer. + offline_buffer_lock (Lock): A threading lock to safely access the offline replay buffer. + logger_lock (Lock): A threading lock to safely log training metrics. + logger (Logger): Logger instance for tracking training progress. """ # NOTE: This function doesn't have a single responsibility, it should be split into multiple functions # in the future. 
The reason why we did that is the GIL in Python. It's super slow the performance # are divided by 200. So we need to have a single thread that does all the work. - start = time.time() + time.time() optimization_step = 0 - timeout_for_adding_transitions = 1 while True: - time_for_adding_transitions = time.time() while not transition_queue.empty(): - transition_list = transition_queue.get() for transition in transition_list: transition = move_transition_to_device(transition, device=device) replay_buffer.add(**transition) - # logging.info(f"[LEARNER] size of replay buffer: {len(replay_buffer)}") - # logging.info(f"[LEARNER] size of transition queues: {transition_queue.qsize()}") - # logging.info(f"[LEARNER] size of replay buffer: {len(replay_buffer)}") - # logging.info(f"[LEARNER] size of transition queues: {transition }") - if len(replay_buffer) > cfg.training.online_step_before_learning: - logging.info(f"[LEARNER] size of replay buffer: {len(replay_buffer)}") while not interaction_message_queue.empty(): interaction_message = interaction_message_queue.get() - logger.log_dict(interaction_message,mode="train",custom_step_key="interaction_step") - # logging.info(f"[LEARNER] size of interaction message queue: {interaction_message_queue.qsize()}") + logger.log_dict(interaction_message, mode="train", custom_step_key="Interaction step") if len(replay_buffer) < cfg.training.online_step_before_learning: continue @@ -212,7 +229,7 @@ def add_actor_information( loss_critic = policy.compute_loss_critic( observations=observations, actions=actions, - rewards=rewards, + rewards=rewards, next_observations=next_observations, done=done, ) @@ -223,7 +240,6 @@ def add_actor_information( training_infos = {} training_infos["loss_critic"] = loss_critic.item() - if optimization_step % cfg.training.policy_update_freq == 0: for _ in range(cfg.training.policy_update_freq): with policy_lock: @@ -242,18 +258,52 @@ def add_actor_information( training_infos["loss_temperature"] = loss_temperature.item() + policy.update_target_networks() if optimization_step % cfg.training.log_freq == 0: logger.log_dict(training_infos, step=optimization_step, mode="train") - policy.update_target_networks() - optimization_step += 1 time_for_one_optimization_step = time.time() - time_for_one_optimization_step + frequency_for_one_optimization_step = 1 / (time_for_one_optimization_step + 1e-9) - logging.info(f"[LEARNER] Time for one optimization step: {time_for_one_optimization_step}") - logger.log_dict({"Time optimization step":time_for_one_optimization_step}, step=optimization_step, mode="train") + logging.debug(f"[LEARNER] Optimization frequency loop [Hz]: {frequency_for_one_optimization_step}") + + logger.log_dict( + {"Optimization frequency loop [Hz]": frequency_for_one_optimization_step}, + step=optimization_step, + mode="train", + ) + + optimization_step += 1 + if optimization_step % cfg.training.log_freq == 0: + logging.info(f"[LEARNER] Number of optimization step: {optimization_step}") -def make_optimizers_and_scheduler(cfg, policy): +def make_optimizers_and_scheduler(cfg, policy: nn.Module): + """ + Creates and returns optimizers for the actor, critic, and temperature components of a reinforcement learning policy. + + This function sets up Adam optimizers for: + - The **actor network**, ensuring that only relevant parameters are optimized. + - The **critic ensemble**, which evaluates the value function. + - The **temperature parameter**, which controls the entropy in soft actor-critic (SAC)-like methods. 
+ + It also initializes a learning rate scheduler, though currently, it is set to `None`. + + **NOTE:** + - If the encoder is shared, its parameters are excluded from the actor’s optimization process. + - The policy’s log temperature (`log_alpha`) is wrapped in a list to ensure proper optimization as a standalone tensor. + + Args: + cfg: Configuration object containing hyperparameters. + policy (nn.Module): The policy model containing the actor, critic, and temperature components. + + Returns: + Tuple[Dict[str, torch.optim.Optimizer], Optional[torch.optim.lr_scheduler._LRScheduler]]: + A tuple containing: + - `optimizers`: A dictionary mapping component names ("actor", "critic", "temperature") to their respective Adam optimizers. + - `lr_scheduler`: Currently set to `None` but can be extended to support learning rate scheduling. + + """ optimizer_actor = torch.optim.Adam( # NOTE: Handle the case of shared encoder where the encoder weights are not optimized with the gradient of the actor params=policy.actor.parameters_to_optimize, @@ -273,8 +323,6 @@ def make_optimizers_and_scheduler(cfg, policy): return optimizers, lr_scheduler - - def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None): if out_dir is None: raise NotImplementedError() @@ -332,6 +380,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No batch_size = cfg.training.batch_size offline_buffer_lock = None offline_replay_buffer = None + if cfg.dataset_repo_id is not None: logging.info("make_dataset offline buffer") offline_dataset = make_dataset(cfg) @@ -342,48 +391,48 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No offline_buffer_lock = Lock() batch_size: int = batch_size // 2 # We will sample from both replay buffer - server_thread = Thread(target=stream_transitions_from_actor, args=(50051,), daemon=True) + actor_ip = cfg.actor_learner_config.actor_ip + port = cfg.actor_learner_config.port + + server_thread = Thread( + target=stream_transitions_from_actor, + args=( + actor_ip, + port, + ), + daemon=True, + ) server_thread.start() - - # Start a background thread to process transitions from the queue transition_thread = Thread( - target=add_actor_information, + target=add_actor_information_and_train, daemon=True, - args=(cfg, - device, - replay_buffer, - offline_replay_buffer, - batch_size, - optimizers, - policy, - policy_lock, - buffer_lock, - offline_buffer_lock, - logger_lock, - logger), + args=( + cfg, + device, + replay_buffer, + offline_replay_buffer, + batch_size, + optimizers, + policy, + policy_lock, + buffer_lock, + offline_buffer_lock, + logger_lock, + logger, + ), ) transition_thread.start() param_push_thread = Thread( target=learner_push_parameters, - args=(policy, policy_lock, "127.0.0.1", 50051, 15), - # args=("127.0.0.1", 50052), + args=(policy, policy_lock, actor_ip, port, 15), daemon=True, ) param_push_thread.start() - # interaction_thread = Thread( - # target=add_message_interaction_to_wandb, - # daemon=True, - # args=(cfg, logger, logger_lock), - # ) - # interaction_thread.start() - transition_thread.join() - # param_push_thread.join() server_thread.join() - # interaction_thread.join() @hydra.main(version_base="1.2", config_name="default", config_path="../../configs")
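Note on the critic refactor: the diff above replaces the list of independent `Critic` modules with a single `CriticEnsemble` whose `forward()` returns `torch.stack`-ed per-head Q-values, initializes `critic_target` from `critic_ensemble.state_dict()`, and performs the EMA target update directly over the zipped `.parameters()` of the two ensembles. The snippet below is a minimal, self-contained sketch of those two patterns only; it is not the repository code, and the names (`SmallCriticEnsemble`, `soft_update`) and layer sizes are illustrative assumptions.

```python
# Illustrative sketch (not lerobot code): an ensemble critic returning stacked
# Q-values, plus a Polyak/EMA soft target update as in SACPolicy.update_target_networks().
import torch
from torch import nn


class SmallCriticEnsemble(nn.Module):
    def __init__(self, obs_dim: int, action_dim: int, num_critics: int = 2, hidden_dim: int = 64):
        super().__init__()
        # One small MLP head per critic; hypothetical sizes for illustration.
        self.networks = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear(obs_dim + action_dim, hidden_dim),
                    nn.ReLU(),
                    nn.Linear(hidden_dim, 1),
                )
                for _ in range(num_critics)
            ]
        )

    def forward(self, observations: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
        inputs = torch.cat([observations, actions], dim=-1)
        # Stack per-critic values -> shape (num_critics, batch), matching the diff's torch.stack(list_q_values).
        return torch.stack([net(inputs).squeeze(-1) for net in self.networks])


def soft_update(target: nn.Module, source: nn.Module, weight: float) -> None:
    # EMA rule used by update_target_networks(): target <- weight * source + (1 - weight) * target.
    # Zipping .parameters() of the two ensembles works because both modules have identical
    # structure, and the target was initialized from the source's state_dict.
    with torch.no_grad():
        for target_param, param in zip(target.parameters(), source.parameters()):
            target_param.data.copy_(weight * param.data + (1.0 - weight) * target_param.data)


if __name__ == "__main__":
    critic = SmallCriticEnsemble(obs_dim=8, action_dim=2, num_critics=2)
    target = SmallCriticEnsemble(obs_dim=8, action_dim=2, num_critics=2)
    target.load_state_dict(critic.state_dict())

    obs, act = torch.randn(4, 8), torch.randn(4, 2)
    q_values = critic(obs, act)  # shape (2, 4): one row of Q-values per critic head
    soft_update(target, critic, weight=0.01)
    print(q_values.shape)
```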