- Added wandb logging of the timings of the policy loop and the optimization loop.

- Optimized the critic design (a single CriticEnsemble module), which improves the performance of the learner loop by a factor of 2.
- Cleaned the code and fixed style issues.

- Completed the config with an actor_learner_config field that contains the host IP and port elements necessary for the actor-learner servers.
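As a rough, self-contained sketch of the timing instrumentation behind the first point (the metric names match the actor code further down in this diff; timed_call and the dummy workload are illustrative, not part of the commit):

import time
from statistics import mean, quantiles

def timed_call(fn, *args, **kwargs):
    """Return (result, achieved frequency in Hz) for a single call."""
    start = time.perf_counter()
    out = fn(*args, **kwargs)
    return out, 1.0 / (time.perf_counter() - start + 1e-9)

# Collect per-step frequencies of the policy loop, then aggregate them the way
# the actor does before shipping them to the learner for wandb logging.
policy_fps_samples = []
for _ in range(20):
    _, fps = timed_call(lambda: sum(range(10_000)))  # stand-in for policy.select_action(batch=obs)
    policy_fps_samples.append(fps)

stats = {
    "Policy frequency [Hz]": mean(policy_fps_samples),
    "Policy frequency 90th-p [Hz]": quantiles(policy_fps_samples, n=10)[-1],
}
print(stats)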

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
Michel Aractingi 2025-01-29 15:50:46 +00:00 committed by AdilZouitine
parent a0a81c0c12
commit 18207d995e
6 changed files with 461 additions and 313 deletions

View File

@ -45,6 +45,14 @@ class SACConfig:
"action": {"min": [-1, -1], "max": [1, 1]}, "action": {"min": [-1, -1], "max": [1, 1]},
} }
) )
# TODO: Move it outside of the config
actor_learner_config: dict[str, str | int] = field(
default_factory=lambda: {
"actor_ip": "127.0.0.1",
"port": 50051,
"learner_ip": "127.0.0.1",
}
)
camera_number: int = 1
# Add type annotations for these fields:
image_encoder_hidden_dim: int = 32
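A short sketch of how the new field is read on both sides of the connection, following the usages that appear later in this diff (cfg.actor_learner_config.actor_ip / .port); the OmegaConf.create call below only stands in for the Hydra-loaded config:

from omegaconf import OmegaConf

# Illustrative stand-in for the Hydra config loaded by the actor and learner scripts.
cfg = OmegaConf.create(
    {"actor_learner_config": {"actor_ip": "127.0.0.1", "port": 50051, "learner_ip": "127.0.0.1"}}
)

actor_ip = cfg.actor_learner_config.actor_ip
port = cfg.actor_learner_config.port

# The learner opens a gRPC channel to the actor at this address,
# while the actor binds its server to the same port.
learner_side_target = f"{actor_ip}:{port}"
actor_side_bind = f"[::]:{port}"
print(learner_side_target, actor_side_bind)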

View File

@ -17,8 +17,7 @@
# TODO: (1) better device management # TODO: (1) better device management
from collections import deque from typing import Callable, Optional, Tuple
from typing import Callable, Optional, Sequence, Tuple, Union
import einops import einops
import numpy as np import numpy as np
@ -74,43 +73,42 @@ class SACPolicy(
config.output_shapes, config.output_normalization_modes, dataset_stats config.output_shapes, config.output_normalization_modes, dataset_stats
) )
# NOTE: For images the encoder should be shared between the actor and critic
if config.shared_encoder: if config.shared_encoder:
encoder_critic = SACObservationEncoder(config) encoder_critic = SACObservationEncoder(config)
encoder_actor: SACObservationEncoder = encoder_critic encoder_actor: SACObservationEncoder = encoder_critic
else: else:
encoder_critic = SACObservationEncoder(config) encoder_critic = SACObservationEncoder(config)
encoder_actor = SACObservationEncoder(config) encoder_actor = SACObservationEncoder(config)
# Define networks
critic_nets = [] self.critic_ensemble = CriticEnsemble(
for _ in range(config.num_critics):
critic_net = Critic(
encoder=encoder_critic, encoder=encoder_critic,
network=MLP( network_list=nn.ModuleList(
[
MLP(
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0], input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
**config.critic_network_kwargs, **config.critic_network_kwargs,
)
for _ in range(config.num_critics)
]
), ),
device=device, device=device,
) )
critic_nets.append(critic_net)
target_critic_nets = [] self.critic_target = CriticEnsemble(
for _ in range(config.num_critics):
target_critic_net = Critic(
encoder=encoder_critic, encoder=encoder_critic,
network=MLP( network_list=nn.ModuleList(
[
MLP(
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0], input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
**config.critic_network_kwargs, **config.critic_network_kwargs,
)
for _ in range(config.num_critics)
]
), ),
device=device, device=device,
) )
target_critic_nets.append(target_critic_net)
self.critic_ensemble = create_critic_ensemble(
critics=critic_nets, num_critics=config.num_critics, device=device
)
self.critic_target = create_critic_ensemble(
critics=target_critic_nets, num_critics=config.num_critics, device=device
)
self.critic_target.load_state_dict(self.critic_ensemble.state_dict()) self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
self.actor = Policy( self.actor = Policy(
@ -123,7 +121,8 @@ class SACPolicy(
) )
if config.target_entropy is None: if config.target_entropy is None:
config.target_entropy = -np.prod(config.output_shapes["action"][0]) / 2 # (-dim(A)/2) config.target_entropy = -np.prod(config.output_shapes["action"][0]) / 2 # (-dim(A)/2)
# TODO: Handle the case where the temperature parameter is fixed
# TODO (azouitine): Handle the case where the temperature parameter is fixed
self.log_alpha = torch.zeros(1, requires_grad=True, device=device) self.log_alpha = torch.zeros(1, requires_grad=True, device=device)
self.temperature = self.log_alpha.exp().item() self.temperature = self.log_alpha.exp().item()
@ -152,14 +151,15 @@ class SACPolicy(
Tensor of Q-values from all critics Tensor of Q-values from all critics
""" """
critics = self.critic_target if use_target else self.critic_ensemble critics = self.critic_target if use_target else self.critic_ensemble
q_values = torch.stack([critic(observations, actions) for critic in critics]) q_values = critics(observations, actions)
return q_values return q_values
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]: ... def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]: ...
def update_target_networks(self): def update_target_networks(self):
"""Update target networks with exponential moving average""" """Update target networks with exponential moving average"""
for target_critic, critic in zip(self.critic_target, self.critic_ensemble, strict=False): for target_param, param in zip(
for target_param, param in zip(target_critic.parameters(), critic.parameters(), strict=False): self.critic_target.parameters(), self.critic_ensemble.parameters(), strict=False
):
target_param.data.copy_( target_param.data.copy_(
param.data * self.config.critic_target_update_weight param.data * self.config.critic_target_update_weight
+ target_param.data * (1.0 - self.config.critic_target_update_weight) + target_param.data * (1.0 - self.config.critic_target_update_weight)
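For reference, the loop above is the standard Polyak (exponential moving average) update of the target critic: target_param <- w * param + (1 - w) * target_param, with w = critic_target_update_weight (0.01 in the yaml config below). A small w makes the target network track the online critic slowly, which keeps the TD targets stable.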
@ -264,34 +264,83 @@ class MLP(nn.Module):
return self.net(x) return self.net(x)
class Critic(nn.Module): class CriticEnsemble(nn.Module):
"""
Critic Ensemble
Q1 Q2 Qn
MLP 1 MLP 2 MLP
... num_critics
Embedding
SACObservationEncoder
Observation
"""
def __init__( def __init__(
self, self,
encoder: Optional[nn.Module], encoder: Optional[nn.Module],
network: nn.Module, network_list: nn.Module,
init_final: Optional[float] = None, init_final: Optional[float] = None,
device: str = "cpu", device: str = "cpu",
): ):
super().__init__() super().__init__()
self.device = torch.device(device) self.device = torch.device(device)
self.encoder = encoder self.encoder = encoder
self.network = network self.network_list = network_list
self.init_final = init_final self.init_final = init_final
# for network in network_list:
# network.to(self.device)
# Find the last Linear layer's output dimension # Find the last Linear layer's output dimension
for layer in reversed(network.net): for layer in reversed(network_list[0].net):
if isinstance(layer, nn.Linear): if isinstance(layer, nn.Linear):
out_features = layer.out_features out_features = layer.out_features
break break
# Output layer # Output layer
self.output_layers = []
if init_final is not None: if init_final is not None:
self.output_layer = nn.Linear(out_features, 1) for _ in network_list:
nn.init.uniform_(self.output_layer.weight, -init_final, init_final) output_layer = nn.Linear(out_features, 1, device=device)
nn.init.uniform_(self.output_layer.bias, -init_final, init_final) nn.init.uniform_(output_layer.weight, -init_final, init_final)
nn.init.uniform_(output_layer.bias, -init_final, init_final)
self.output_layers.append(output_layer)
else: else:
self.output_layer = nn.Linear(out_features, 1) self.output_layers = []
orthogonal_init()(self.output_layer.weight) for _ in network_list:
output_layer = nn.Linear(out_features, 1, device=device)
orthogonal_init()(output_layer.weight)
self.output_layers.append(output_layer)
self.output_layers = nn.ModuleList(self.output_layers)
self.to(self.device) self.to(self.device)
@ -307,9 +356,12 @@ class Critic(nn.Module):
obs_enc = observations if self.encoder is None else self.encoder(observations) obs_enc = observations if self.encoder is None else self.encoder(observations)
inputs = torch.cat([obs_enc, actions], dim=-1) inputs = torch.cat([obs_enc, actions], dim=-1)
x = self.network(inputs) list_q_values = []
value = self.output_layer(x) for network, output_layer in zip(self.network_list, self.output_layers, strict=False):
return value.squeeze(-1) x = network(inputs)
value = output_layer(x)
list_q_values.append(value.squeeze(-1))
return torch.stack(list_q_values)
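The ensemble forward above returns a stacked tensor of shape (num_critics, batch_size). As a minimal, self-contained sketch of how such an output is typically reduced when forming a SAC-style TD target (the reduction itself is not part of this hunk; taking the minimum over critics is the usual clipped double-Q choice):

import torch

num_critics, batch_size = 2, 4
# Stand-in for CriticEnsemble(observations, actions): one row of Q-values per critic.
q_values = torch.randn(num_critics, batch_size)

# Clipped double-Q style reduction: take the per-sample minimum across critics.
min_q, _ = q_values.min(dim=0)  # shape: (batch_size,)
assert min_q.shape == (batch_size,)
print(min_q)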
class Policy(nn.Module): class Policy(nn.Module):
@ -416,9 +468,7 @@ class Policy(nn.Module):
class SACObservationEncoder(nn.Module): class SACObservationEncoder(nn.Module):
"""Encode image and/or state vector observations. """Encode image and/or state vector observations."""
TODO(ke-wang): The original work allows for (1) stacking multiple history frames and (2) using pretrained resnet encoders.
"""
def __init__(self, config: SACConfig): def __init__(self, config: SACConfig):
""" """
@ -513,8 +563,7 @@ class SACObservationEncoder(nn.Module):
feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"])) feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
if "observation.state" in self.config.input_shapes: if "observation.state" in self.config.input_shapes:
feat.append(self.state_enc_layers(obs_dict["observation.state"])) feat.append(self.state_enc_layers(obs_dict["observation.state"]))
# TODO(ke-wang): currently average over all features, concatenate all features maybe a better way
# return torch.stack(feat, dim=0).mean(0)
features = torch.cat(tensors=feat, dim=-1) features = torch.cat(tensors=feat, dim=-1)
features = self.aggregation_layer(features) features = self.aggregation_layer(features)
@ -530,12 +579,8 @@ def orthogonal_init():
return lambda x: torch.nn.init.orthogonal_(x, gain=1.0) return lambda x: torch.nn.init.orthogonal_(x, gain=1.0)
def create_critic_ensemble(critics: list[nn.Module], num_critics: int, device: str = "cpu") -> nn.ModuleList: # TODO (azouitine): I think in our case this function is not useful, we should remove it
"""Creates an ensemble of critic networks""" # after some investigation
assert len(critics) == num_critics, f"Expected {num_critics} critics, got {len(critics)}"
return nn.ModuleList(critics).to(device)
# borrowed from tdmpc # borrowed from tdmpc
def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor: def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor:
"""Helper to temporarily flatten extra dims at the start of the image tensor. """Helper to temporarily flatten extra dims at the start of the image tensor.

View File

@ -10,7 +10,6 @@
seed: 1 seed: 1
dataset_repo_id: null dataset_repo_id: null
training: training:
# Offline training dataloader # Offline training dataloader
num_workers: 4 num_workers: 4
@ -75,15 +74,18 @@ policy:
# discount: 0.99 # discount: 0.99
discount: 0.80 discount: 0.80
temperature_init: 1.0 temperature_init: 1.0
num_critics: 2 num_critics: 2 #10
num_subsample_critics: null num_subsample_critics: null
critic_lr: 3e-4 critic_lr: 3e-4
actor_lr: 3e-4 actor_lr: 3e-4
temperature_lr: 3e-4 temperature_lr: 3e-4
# critic_target_update_weight: 0.005 # critic_target_update_weight: 0.005
critic_target_update_weight: 0.01 critic_target_update_weight: 0.01
utd_ratio: 2 utd_ratio: 2 # 10
actor_learner_config:
actor_ip: "127.0.0.1"
port: 50051
# # Loss coefficients. # # Loss coefficients.
# reward_coeff: 0.5 # reward_coeff: 0.5

View File

@ -13,117 +13,123 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import io
import logging import logging
import functools
from pprint import pformat
import random
from typing import Optional, Sequence, TypedDict, Callable
import pickle import pickle
import queue
import time
from concurrent import futures
from statistics import mean, quantiles
import hydra
import torch
import torch.nn.functional as F
from torch import nn
from tqdm import tqdm
from deepdiff import DeepDiff
from omegaconf import DictConfig, OmegaConf
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
# TODO: Remove the import of maniskill
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.envs.factory import make_env, make_maniskill_env
from lerobot.common.envs.utils import preprocess_observation, preprocess_maniskill_observation
from lerobot.common.logger import Logger, log_output_dir
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.sac.modeling_sac import SACPolicy
from lerobot.common.policies.utils import get_device_from_parameters
from lerobot.common.utils.utils import (
format_big_number,
get_safe_torch_device,
init_hydra_config,
init_logging,
set_global_seed,
)
# from lerobot.scripts.eval import eval_policy # from lerobot.scripts.eval import eval_policy
from threading import Thread from threading import Thread
import queue
import grpc import grpc
from lerobot.scripts.server import hilserl_pb2, hilserl_pb2_grpc import hydra
import io import torch
import time from omegaconf import DictConfig
import logging from torch import nn
from concurrent import futures
from threading import Thread
from lerobot.scripts.server.buffer import move_state_dict_to_device, move_transition_to_device, Transition
import faulthandler # TODO: Remove the import of maniskill
import signal from lerobot.common.envs.factory import make_maniskill_env
from lerobot.common.envs.utils import preprocess_maniskill_observation
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.sac.modeling_sac import SACPolicy
from lerobot.common.utils.utils import (
get_safe_torch_device,
set_global_seed,
)
from lerobot.scripts.server import hilserl_pb2, hilserl_pb2_grpc
from lerobot.scripts.server.buffer import Transition, move_state_dict_to_device, move_transition_to_device
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
parameters_queue = queue.Queue(maxsize=1) parameters_queue = queue.Queue(maxsize=1)
message_queue = queue.Queue(maxsize=1_000_000) message_queue = queue.Queue(maxsize=1_000_000)
class ActorInformation: class ActorInformation:
"""
This helper class is used to differentiate between two types of messages that are placed in the same queue during streaming:
- **Transition Data:** Contains experience tuples (observation, action, reward, next observation) collected during interaction.
- **Interaction Messages:** Encapsulates statistics related to the interaction process.
Attributes:
transition (Optional): Transition data to be sent to the learner.
interaction_message (Optional): Interaction message providing additional statistics for logging.
"""
def __init__(self, transition=None, interaction_message=None): def __init__(self, transition=None, interaction_message=None):
self.transition = transition self.transition = transition
self.interaction_message = interaction_message self.interaction_message = interaction_message
# 1) Implement ActorService so the Learner can send parameters to this Actor.
class ActorServiceServicer(hilserl_pb2_grpc.ActorServiceServicer): class ActorServiceServicer(hilserl_pb2_grpc.ActorServiceServicer):
def StreamTransition(self, request, context): """
gRPC service for actor-learner communication in reinforcement learning.
This service is responsible for:
1. Streaming batches of transition data and statistical metrics from the actor to the learner.
2. Receiving updated network parameters from the learner.
"""
def StreamTransition(self, request, context): # noqa: N802
"""
Streams data from the actor to the learner.
This function continuously retrieves messages from the queue and processes them based on their type:
- **Transition Data:**
- A batch of transitions (observation, action, reward, next observation) is collected.
- Transitions are moved to the CPU and serialized using PyTorch.
- The serialized data is wrapped in a `hilserl_pb2.Transition` message and sent to the learner.
- **Interaction Messages:**
- Contains useful statistics about episodic rewards and policy timings.
- The message is serialized using `pickle` and sent to the learner.
Yields:
hilserl_pb2.ActorInformation: The response message containing either transition data or an interaction message.
"""
while True: while True:
# logging.info(f"[ACTOR] before message.empty()")
# logging.info(f"[ACTOR] size transition queue {message_queue.qsize()}")
# time.sleep(0.01)
# if message_queue.empty():
# continue
# logging.info(f"[ACTOR] after message.empty()")
start = time.time()
message = message_queue.get(block=True) message = message_queue.get(block=True)
# logging.info(f"[ACTOR] Message queue get time {time.time() - start}")
if message.transition is not None: if message.transition is not None:
# transition_to_send_to_learner = move_transition_to_device(message.transition, device="cpu") transition_to_send_to_learner = [
transition_to_send_to_learner = [move_transition_to_device(T, device="cpu") for T in message.transition] move_transition_to_device(T, device="cpu") for T in message.transition
# logging.info(f"[ACTOR] Message queue get time {time.time() - start}") ]
# Serialize it
buf = io.BytesIO() buf = io.BytesIO()
torch.save(transition_to_send_to_learner, buf) torch.save(transition_to_send_to_learner, buf)
transition_bytes = buf.getvalue() transition_bytes = buf.getvalue()
transition_message = hilserl_pb2.Transition( transition_message = hilserl_pb2.Transition(transition_bytes=transition_bytes)
transition_bytes=transition_bytes
)
response = hilserl_pb2.ActorInformation( response = hilserl_pb2.ActorInformation(transition=transition_message)
transition=transition_message
)
logging.info(f"[ACTOR] time to yield transition response {time.time() - start}")
logging.info(f"[ACTOR] size transition queue {message_queue.qsize()}")
elif message.interaction_message is not None: elif message.interaction_message is not None:
# Serialize it and send it to the Learner's server
content = hilserl_pb2.InteractionMessage( content = hilserl_pb2.InteractionMessage(
interaction_message_bytes=pickle.dumps(message.interaction_message) interaction_message_bytes=pickle.dumps(message.interaction_message)
) )
response = hilserl_pb2.ActorInformation( response = hilserl_pb2.ActorInformation(interaction_message=content)
interaction_message=content
)
# logging.info(f"[ACTOR] yield response before")
yield response yield response
# logging.info(f"[ACTOR] response yielded after")
def SendParameters(self, request, context): def SendParameters(self, request, context): # noqa: N802
""" """
Learner calls this with updated Parameters -> Actor Receives updated parameters from the learner and updates the actor.
The learner calls this method to send new model parameters. The received parameters are deserialized
and placed in a queue to be consumed by the actor.
Args:
request (hilserl_pb2.ParameterUpdate): The request containing serialized network parameters.
context (grpc.ServicerContext): The gRPC context.
Returns:
hilserl_pb2.Empty: An empty response to acknowledge receipt.
""" """
# logging.info("[ACTOR] Received parameters from Learner.")
buffer = io.BytesIO(request.parameter_bytes) buffer = io.BytesIO(request.parameter_bytes)
params = torch.load(buffer) params = torch.load(buffer)
parameters_queue.put(params) parameters_queue.put(params)
@ -132,38 +138,38 @@ class ActorServiceServicer(hilserl_pb2_grpc.ActorServiceServicer):
def serve_actor_service(port=50052): def serve_actor_service(port=50052):
""" """
Runs a gRPC server so that the Learner can push parameters to the Actor. Runs a gRPC server to start streaming the data from the actor to the learner.
Through this server the learner can also push parameters to the Actor.
""" """
server = grpc.server(futures.ThreadPoolExecutor(max_workers=20), server = grpc.server(
options=[('grpc.max_send_message_length', -1), futures.ThreadPoolExecutor(max_workers=20),
('grpc.max_receive_message_length', -1)]) options=[("grpc.max_send_message_length", -1), ("grpc.max_receive_message_length", -1)],
hilserl_pb2_grpc.add_ActorServiceServicer_to_server(
ActorServiceServicer(), server
) )
server.add_insecure_port(f'[::]:{port}') hilserl_pb2_grpc.add_ActorServiceServicer_to_server(ActorServiceServicer(), server)
server.add_insecure_port(f"[::]:{port}")
server.start() server.start()
logging.info(f"[ACTOR] gRPC server listening on port {port}") logging.info(f"[ACTOR] gRPC server listening on port {port}")
server.wait_for_termination() server.wait_for_termination()
def act_with_policy(cfg: DictConfig,
out_dir: str | None = None,
job_name: str | None = None):
if out_dir is None: def act_with_policy(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
raise NotImplementedError() """
if job_name is None: Executes policy interaction within the environment.
raise NotImplementedError()
This function rolls out the policy in the environment, collecting interaction data and pushing it to a queue for streaming to the learner.
Once an episode is completed, updated network parameters received from the learner are retrieved from a queue and loaded into the network.
Args:
cfg (DictConfig): Configuration settings for the interaction process.
out_dir (Optional[str]): Directory to store output logs or results. Defaults to None.
job_name (Optional[str]): Name of the job for logging or tracking purposes. Defaults to None.
"""
logging.info("make_env online") logging.info("make_env online")
# online_env = make_env(cfg, n_envs=1) # online_env = make_env(cfg, n_envs=1)
# TODO: Remove the import of maniskill and unify with make_env
online_env = make_maniskill_env(cfg, n_envs=1) online_env = make_maniskill_env(cfg, n_envs=1)
if cfg.training.eval_freq > 0:
logging.info("make_env eval")
# eval_env = make_env(cfg, n_envs=1)
# TODO: Remove the import of maniskill and unifiy with make env
eval_env = make_maniskill_env(cfg, n_envs=1)
set_global_seed(cfg.seed) set_global_seed(cfg.seed)
device = get_safe_torch_device(cfg.device, log=True) device = get_safe_torch_device(cfg.device, log=True)
@ -173,7 +179,6 @@ def act_with_policy(cfg: DictConfig,
logging.info("make_policy") logging.info("make_policy")
### Instantiate the policy in both the actor and learner processes ### Instantiate the policy in both the actor and learner processes
### To avoid sending a SACPolicy object through the port, we create a policy instance
### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters ### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters
@ -181,7 +186,7 @@ def act_with_policy(cfg: DictConfig,
policy: SACPolicy = make_policy( policy: SACPolicy = make_policy(
hydra_cfg=cfg, hydra_cfg=cfg,
# dataset_stats=offline_dataset.meta.stats if not cfg.resume else None, # dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
# Hack: But if we do online traning, we do not need dataset_stats # Hack: But if we do online training, we do not need dataset_stats
dataset_stats=None, dataset_stats=None,
# TODO: Handle resume training # TODO: Handle resume training
pretrained_policy_name_or_path=None, pretrained_policy_name_or_path=None,
@ -195,17 +200,22 @@ def act_with_policy(cfg: DictConfig,
# obs = preprocess_observation(obs) # obs = preprocess_observation(obs)
obs = preprocess_maniskill_observation(obs) obs = preprocess_maniskill_observation(obs)
obs = {key: obs[key].to(device, non_blocking=True) for key in obs} obs = {key: obs[key].to(device, non_blocking=True) for key in obs}
### ACTOR ==================
# NOTE: For the moment we will solely handle the case of a single environment # NOTE: For the moment we will solely handle the case of a single environment
sum_reward_episode = 0 sum_reward_episode = 0
list_transition_to_send_to_learner = [] list_transition_to_send_to_learner = []
list_policy_fps = []
for interaction_step in range(cfg.training.online_steps): for interaction_step in range(cfg.training.online_steps):
# NOTE: At some point we should use a wrapper to handle the observation
# start = time.time()
if interaction_step >= cfg.training.online_step_before_learning: if interaction_step >= cfg.training.online_step_before_learning:
start = time.perf_counter()
action = policy.select_action(batch=obs) action = policy.select_action(batch=obs)
list_policy_fps.append(1.0 / (time.perf_counter() - start + 1e-9))
if list_policy_fps[-1] < cfg.fps:
logging.warning(
f"[ACTOR] policy frame rate {list_policy_fps[-1]} during interaction step {interaction_step} is below the required control frame rate {cfg.fps}"
)
next_obs, reward, done, truncated, info = online_env.step(action.cpu().numpy()) next_obs, reward, done, truncated, info = online_env.step(action.cpu().numpy())
else: else:
action = online_env.action_space.sample() action = online_env.action_space.sample()
@ -213,44 +223,57 @@ def act_with_policy(cfg: DictConfig,
# HACK # HACK
action = torch.tensor(action, dtype=torch.float32).to(device, non_blocking=True) action = torch.tensor(action, dtype=torch.float32).to(device, non_blocking=True)
# logging.info(f"[ACTOR] Time for env step {time.time() - start}")
# HACK: For maniskill # HACK: For maniskill
# next_obs = preprocess_observation(next_obs) # next_obs = preprocess_observation(next_obs)
next_obs = preprocess_maniskill_observation(next_obs) next_obs = preprocess_maniskill_observation(next_obs)
next_obs = {key: next_obs[key].to(device, non_blocking=True) for key in obs} next_obs = {key: next_obs[key].to(device, non_blocking=True) for key in obs}
sum_reward_episode += float(reward[0]) sum_reward_episode += float(reward[0])
# Because we are using a single environment
# we can safely assume that the episode is done # Because we are using a single environment we can index at zero
if done[0].item() or truncated[0].item(): if done[0].item() or truncated[0].item():
# TODO: Handle logging for episode information # TODO: Handle logging for episode information
logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}") logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}")
if not parameters_queue.empty(): if not parameters_queue.empty():
logging.info("[ACTOR] Load new parameters from Learner.") logging.debug("[ACTOR] Load new parameters from Learner.")
# Load new parameters from Learner
state_dict = parameters_queue.get() state_dict = parameters_queue.get()
state_dict = move_state_dict_to_device(state_dict, device=device) state_dict = move_state_dict_to_device(state_dict, device=device)
policy.actor.load_state_dict(state_dict) policy.actor.load_state_dict(state_dict)
if len(list_transition_to_send_to_learner) > 0: if len(list_transition_to_send_to_learner) > 0:
logging.info(f"[ACTOR] Sending {len(list_transition_to_send_to_learner)} transitions to Learner.") logging.debug(
f"[ACTOR] Sending {len(list_transition_to_send_to_learner)} transitions to Learner."
)
message_queue.put(ActorInformation(transition=list_transition_to_send_to_learner)) message_queue.put(ActorInformation(transition=list_transition_to_send_to_learner))
list_transition_to_send_to_learner = [] list_transition_to_send_to_learner = []
stats = {}
if len(list_policy_fps) > 0:
policy_fps = mean(list_policy_fps)
quantiles_90 = quantiles(list_policy_fps, n=10)[-1]
logging.debug(f"[ACTOR] Average policy frame rate: {policy_fps}")
logging.debug(f"[ACTOR] Policy frame rate 90th percentile: {quantiles_90}")
stats = {"Policy frequency [Hz]": policy_fps, "Policy frequency 90th-p [Hz]": quantiles_90}
list_policy_fps = []
# Send episodic reward to the learner # Send episodic reward to the learner
message_queue.put(ActorInformation(interaction_message={"episodic_reward": sum_reward_episode,"interaction_step": interaction_step})) message_queue.put(
ActorInformation(
interaction_message={
"Episodic reward": sum_reward_episode,
"Interaction step": interaction_step,
**stats,
}
)
)
sum_reward_episode = 0.0 sum_reward_episode = 0.0
# ============================ # TODO (michel-aractingi): Label the reward
# Prepare transition to send
# ============================
# Label the reward
# if config.label_reward_on_actor: # if config.label_reward_on_actor:
# reward = reward_classifier(obs) # reward = reward_classifier(obs)
list_transition_to_send_to_learner.append(Transition( list_transition_to_send_to_learner.append(
# transition_to_send_to_learner = Transition( Transition(
state=obs, state=obs,
action=action, action=action,
reward=reward, reward=reward,
@ -259,24 +282,29 @@ def act_with_policy(cfg: DictConfig,
complementary_info=None, complementary_info=None,
) )
) )
# message_queue.put(ActorInformation(transition=transition_to_send_to_learner))
# assign obs to the next obs and continue the rollout # assign obs to the next obs and continue the rollout
obs = next_obs obs = next_obs
@hydra.main(version_base="1.2", config_name="default", config_path="../../configs") @hydra.main(version_base="1.2", config_name="default", config_path="../../configs")
def actor_cli(cfg: dict): def actor_cli(cfg: dict):
server_thread = Thread(target=serve_actor_service, args=(50051,), daemon=True) port = cfg.actor_learner_config.port
server_thread = Thread(target=serve_actor_service, args=(port,), daemon=True)
server_thread.start() server_thread.start()
policy_thread = Thread(target=act_with_policy, policy_thread = Thread(
target=act_with_policy,
daemon=True, daemon=True,
args=(cfg,hydra.core.hydra_config.HydraConfig.get().run.dir, hydra.core.hydra_config.HydraConfig.get().job.name)) args=(
cfg,
hydra.core.hydra_config.HydraConfig.get().run.dir,
hydra.core.hydra_config.HydraConfig.get().job.name,
),
)
policy_thread.start() policy_thread.start()
policy_thread.join() policy_thread.join()
server_thread.join() server_thread.join()
if __name__ == "__main__":
with open("traceback.log", "w") as f:
faulthandler.register(signal.SIGUSR1, file=f)
if __name__ == "__main__":
actor_cli() actor_cli()
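To make the streaming path above concrete, here is a small self-contained sketch of the serialization round-trip performed on transitions: the actor saves a list of transitions into an in-memory buffer with torch.save before wrapping the bytes in a hilserl_pb2.Transition message, and the learner loads them back with torch.load. The transition dictionaries below are simplified stand-ins for the real Transition tuples:

import io
import torch

# Simplified stand-in for a batch of transitions collected by the actor.
transitions = [
    {"state": torch.zeros(3), "action": torch.tensor([0.1, -0.2]), "reward": torch.tensor(1.0)},
    {"state": torch.ones(3), "action": torch.tensor([0.0, 0.3]), "reward": torch.tensor(0.0)},
]

# Actor side: serialize to bytes before sending them over gRPC.
buf = io.BytesIO()
torch.save(transitions, buf)
transition_bytes = buf.getvalue()

# Learner side: deserialize the received bytes back into tensors.
received = torch.load(io.BytesIO(transition_bytes))
assert torch.equal(received[0]["reward"], torch.tensor(1.0))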

View File

@ -1,3 +1,19 @@
// !/usr/bin/env python
// Copyright 2024 The HuggingFace Inc. team.
// All rights reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto3"; syntax = "proto3";
package hil_serl; package hil_serl;

View File

@ -1,97 +1,97 @@
import grpc #!/usr/bin/env python
from concurrent import futures
import functools # Copyright 2024 The HuggingFace Inc. team.
import logging # All rights reserved.
import queue #
import pickle # Licensed under the Apache License, Version 2.0 (the "License");
import torch # you may not use this file except in compliance with the License.
import torch.nn.functional as F # You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io import io
import logging
import pickle
import queue
import time import time
from pprint import pformat from pprint import pformat
import random from threading import Lock, Thread
from typing import Optional, Sequence, TypedDict, Callable
import grpc
# Import generated stubs
import hilserl_pb2 # type: ignore
import hilserl_pb2_grpc # type: ignore
import hydra import hydra
import torch import torch
import torch.nn.functional as F
from torch import nn
from tqdm import tqdm
from deepdiff import DeepDiff
from omegaconf import DictConfig, OmegaConf from omegaconf import DictConfig, OmegaConf
from threading import Thread, Lock from torch import nn
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
# TODO: Remove the import of maniskill # TODO: Remove the import of maniskill
from lerobot.common.datasets.factory import make_dataset from lerobot.common.datasets.factory import make_dataset
from lerobot.common.logger import Logger, log_output_dir from lerobot.common.logger import Logger, log_output_dir
from lerobot.common.policies.factory import make_policy from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.sac.modeling_sac import SACPolicy from lerobot.common.policies.sac.modeling_sac import SACPolicy
from lerobot.common.policies.utils import get_device_from_parameters
from lerobot.common.utils.utils import ( from lerobot.common.utils.utils import (
format_big_number, format_big_number,
get_safe_torch_device, get_safe_torch_device,
init_hydra_config,
init_logging, init_logging,
set_global_seed, set_global_seed,
) )
from lerobot.scripts.server.buffer import ReplayBuffer, move_transition_to_device, concatenate_batch_transitions, move_state_dict_to_device, Transition from lerobot.scripts.server.buffer import (
ReplayBuffer,
# Import generated stubs concatenate_batch_transitions,
import hilserl_pb2 move_state_dict_to_device,
import hilserl_pb2_grpc move_transition_to_device,
)
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
# TODO: Implement this in a cleaner way, maybe
transition_queue = queue.Queue() transition_queue = queue.Queue()
interaction_message_queue = queue.Queue() interaction_message_queue = queue.Queue()
# 1) Implement the LearnerService so the Actor can send transitions here. def stream_transitions_from_actor(host="127.0.0.1", port=50051):
class LearnerServiceServicer(hilserl_pb2_grpc.LearnerServiceServicer):
# def SendTransition(self, request, context):
# """
# Actor calls this method to push a Transition -> Learner.
# """
# buffer = io.BytesIO(request.transition_bytes)
# transition = torch.load(buffer)
# transition_queue.put(transition)
# return hilserl_pb2.Empty()
def SendInteractionMessage(self, request, context):
""" """
Actor calls this method to push a Transition -> Learner. Runs a gRPC client that listens for transition and interaction messages from an Actor service.
"""
content = pickle.loads(request.interaction_message_bytes)
interaction_message_queue.put(content)
return hilserl_pb2.Empty()
This function establishes a gRPC connection with the given `host` and `port`, then continuously
streams transition data from the `ActorServiceStub`. The received transition data is deserialized
and stored in a queue (`transition_queue`). Similarly, interaction messages are also deserialized
and stored in a separate queue (`interaction_message_queue`).
Args:
host (str, optional): The IP address or hostname of the gRPC server. Defaults to `"127.0.0.1"`.
port (int, optional): The port number on which the gRPC server is running. Defaults to `50051`.
def stream_transitions_from_actor(port=50051):
"""
Runs a gRPC server listening for transitions from the Actor.
""" """
# NOTE: This is waiting for the handshake to be done
# In the future we will do it in a canonical way with a proper handshake
time.sleep(10) time.sleep(10)
channel = grpc.insecure_channel(f'127.0.0.1:{port}', channel = grpc.insecure_channel(
options=[('grpc.max_send_message_length', -1), f"{host}:{port}",
('grpc.max_receive_message_length', -1)]) options=[("grpc.max_send_message_length", -1), ("grpc.max_receive_message_length", -1)],
)
stub = hilserl_pb2_grpc.ActorServiceStub(channel) stub = hilserl_pb2_grpc.ActorServiceStub(channel)
for response in stub.StreamTransition(hilserl_pb2.Empty()): for response in stub.StreamTransition(hilserl_pb2.Empty()):
if response.HasField('transition'): if response.HasField("transition"):
buffer = io.BytesIO(response.transition.transition_bytes) buffer = io.BytesIO(response.transition.transition_bytes)
transition = torch.load(buffer) transition = torch.load(buffer)
transition_queue.put(transition) transition_queue.put(transition)
if response.HasField('interaction_message'): if response.HasField("interaction_message"):
content = pickle.loads(response.interaction_message.interaction_message_bytes) content = pickle.loads(response.interaction_message.interaction_message_bytes)
interaction_message_queue.put(content) interaction_message_queue.put(content)
# NOTE: Cool down the CPU; if you comment this line out you will create a huge bottleneck
# TODO: LOOK TO REMOVE IT
time.sleep(0.001) time.sleep(0.001)
def learner_push_parameters( def learner_push_parameters(
policy: nn.Module, policy_lock: Lock, actor_host="127.0.0.1", actor_port=50052, seconds_between_pushes=5 policy: nn.Module, policy_lock: Lock, actor_host="127.0.0.1", actor_port=50052, seconds_between_pushes=5
): ):
@ -100,10 +100,10 @@ def learner_push_parameters(
and periodically push new parameters. and periodically push new parameters.
""" """
time.sleep(10) time.sleep(10)
# The Actor's server is presumably listening on a different port, e.g. 50052 channel = grpc.insecure_channel(
channel = grpc.insecure_channel(f"{actor_host}:{actor_port}", f"{actor_host}:{actor_port}",
options=[('grpc.max_send_message_length', -1), options=[("grpc.max_send_message_length", -1), ("grpc.max_receive_message_length", -1)],
('grpc.max_receive_message_length', -1)]) )
actor_stub = hilserl_pb2_grpc.ActorServiceStub(channel) actor_stub = hilserl_pb2_grpc.ActorServiceStub(channel)
while True: while True:
@ -116,20 +116,19 @@ def learner_push_parameters(
params_bytes = buf.getvalue() params_bytes = buf.getvalue()
# Push them to the Actors "SendParameters" method # Push them to the Actors "SendParameters" method
logging.info(f"[LEARNER] Pushing parameters to the Actor") logging.info("[LEARNER] Publishing parameters to the Actor")
response = actor_stub.SendParameters(hilserl_pb2.Parameters(parameter_bytes=params_bytes)) response = actor_stub.SendParameters(hilserl_pb2.Parameters(parameter_bytes=params_bytes)) # noqa: F841
time.sleep(seconds_between_pushes) time.sleep(seconds_between_pushes)
# Checked def add_actor_information_and_train(
def add_actor_information(
cfg, cfg,
device, device: str,
replay_buffer: ReplayBuffer, replay_buffer: ReplayBuffer,
offline_replay_buffer: ReplayBuffer, offline_replay_buffer: ReplayBuffer,
batch_size: int, batch_size: int,
optimizers, optimizers: dict[str, torch.optim.Optimizer],
policy, policy: nn.Module,
policy_lock: Lock, policy_lock: Lock,
buffer_lock: Lock, buffer_lock: Lock,
offline_buffer_lock: Lock, offline_buffer_lock: Lock,
@ -137,34 +136,52 @@ def add_actor_information(
logger: Logger, logger: Logger,
): ):
""" """
In a real application, you might run your training loop here, Handles data transfer from the actor to the learner, manages training updates,
reading from the transition queue and doing gradient updates. and logs training progress in an online reinforcement learning setup.
This function continuously:
- Transfers transitions from the actor to the replay buffer.
- Logs received interaction messages.
- Ensures training begins only when the replay buffer has a sufficient number of transitions.
- Samples batches from the replay buffer and performs multiple critic updates.
- Periodically updates the actor, critic, and temperature optimizers.
- Logs training statistics, including loss values and optimization frequency.
**NOTE:**
- This function performs multiple responsibilities (data transfer, training, and logging).
It should ideally be split into smaller functions in the future.
- Due to Python's **Global Interpreter Lock (GIL)**, running separate threads for different tasks
significantly reduces performance. Instead, this function executes all operations in a single thread.
Args:
cfg: Configuration object containing hyperparameters.
device (str): The computing device (`"cpu"` or `"cuda"`).
replay_buffer (ReplayBuffer): The primary replay buffer storing online transitions.
offline_replay_buffer (ReplayBuffer): An additional buffer for offline transitions.
batch_size (int): The number of transitions to sample per training step.
optimizers (Dict[str, torch.optim.Optimizer]): A dictionary of optimizers (`"actor"`, `"critic"`, `"temperature"`).
policy (nn.Module): The reinforcement learning policy with critic, actor, and temperature parameters.
policy_lock (Lock): A threading lock to ensure safe policy updates.
buffer_lock (Lock): A threading lock to safely access the online replay buffer.
offline_buffer_lock (Lock): A threading lock to safely access the offline replay buffer.
logger_lock (Lock): A threading lock to safely log training metrics.
logger (Logger): Logger instance for tracking training progress.
""" """
# NOTE: This function doesn't have a single responsibility; it should be split into multiple functions
# in the future. The reason we did it this way is the GIL in Python: with separate threads it's super slow,
# the performance is divided by 200. So we need a single thread that does all the work.
start = time.time() time.time()
optimization_step = 0 optimization_step = 0
timeout_for_adding_transitions = 1
while True: while True:
time_for_adding_transitions = time.time()
while not transition_queue.empty(): while not transition_queue.empty():
transition_list = transition_queue.get() transition_list = transition_queue.get()
for transition in transition_list: for transition in transition_list:
transition = move_transition_to_device(transition, device=device) transition = move_transition_to_device(transition, device=device)
replay_buffer.add(**transition) replay_buffer.add(**transition)
# logging.info(f"[LEARNER] size of replay buffer: {len(replay_buffer)}")
# logging.info(f"[LEARNER] size of transition queues: {transition_queue.qsize()}")
# logging.info(f"[LEARNER] size of replay buffer: {len(replay_buffer)}")
# logging.info(f"[LEARNER] size of transition queues: {transition }")
if len(replay_buffer) > cfg.training.online_step_before_learning:
logging.info(f"[LEARNER] size of replay buffer: {len(replay_buffer)}")
while not interaction_message_queue.empty(): while not interaction_message_queue.empty():
interaction_message = interaction_message_queue.get() interaction_message = interaction_message_queue.get()
logger.log_dict(interaction_message,mode="train",custom_step_key="interaction_step") logger.log_dict(interaction_message, mode="train", custom_step_key="Interaction step")
# logging.info(f"[LEARNER] size of interaction message queue: {interaction_message_queue.qsize()}")
if len(replay_buffer) < cfg.training.online_step_before_learning: if len(replay_buffer) < cfg.training.online_step_before_learning:
continue continue
@ -223,7 +240,6 @@ def add_actor_information(
training_infos = {} training_infos = {}
training_infos["loss_critic"] = loss_critic.item() training_infos["loss_critic"] = loss_critic.item()
if optimization_step % cfg.training.policy_update_freq == 0: if optimization_step % cfg.training.policy_update_freq == 0:
for _ in range(cfg.training.policy_update_freq): for _ in range(cfg.training.policy_update_freq):
with policy_lock: with policy_lock:
@ -242,18 +258,52 @@ def add_actor_information(
training_infos["loss_temperature"] = loss_temperature.item() training_infos["loss_temperature"] = loss_temperature.item()
policy.update_target_networks()
if optimization_step % cfg.training.log_freq == 0: if optimization_step % cfg.training.log_freq == 0:
logger.log_dict(training_infos, step=optimization_step, mode="train") logger.log_dict(training_infos, step=optimization_step, mode="train")
policy.update_target_networks()
optimization_step += 1
time_for_one_optimization_step = time.time() - time_for_one_optimization_step time_for_one_optimization_step = time.time() - time_for_one_optimization_step
frequency_for_one_optimization_step = 1 / (time_for_one_optimization_step + 1e-9)
logging.info(f"[LEARNER] Time for one optimization step: {time_for_one_optimization_step}") logging.debug(f"[LEARNER] Optimization frequency loop [Hz]: {frequency_for_one_optimization_step}")
logger.log_dict({"Time optimization step":time_for_one_optimization_step}, step=optimization_step, mode="train")
logger.log_dict(
{"Optimization frequency loop [Hz]": frequency_for_one_optimization_step},
step=optimization_step,
mode="train",
)
optimization_step += 1
if optimization_step % cfg.training.log_freq == 0:
logging.info(f"[LEARNER] Number of optimization step: {optimization_step}")
def make_optimizers_and_scheduler(cfg, policy): def make_optimizers_and_scheduler(cfg, policy: nn.Module):
"""
Creates and returns optimizers for the actor, critic, and temperature components of a reinforcement learning policy.
This function sets up Adam optimizers for:
- The **actor network**, ensuring that only relevant parameters are optimized.
- The **critic ensemble**, which evaluates the value function.
- The **temperature parameter**, which controls the entropy in soft actor-critic (SAC)-like methods.
It also initializes a learning rate scheduler, though currently, it is set to `None`.
**NOTE:**
- If the encoder is shared, its parameters are excluded from the actor's optimization process.
- The policy's log temperature (`log_alpha`) is wrapped in a list to ensure proper optimization as a standalone tensor.
Args:
cfg: Configuration object containing hyperparameters.
policy (nn.Module): The policy model containing the actor, critic, and temperature components.
Returns:
Tuple[Dict[str, torch.optim.Optimizer], Optional[torch.optim.lr_scheduler._LRScheduler]]:
A tuple containing:
- `optimizers`: A dictionary mapping component names ("actor", "critic", "temperature") to their respective Adam optimizers.
- `lr_scheduler`: Currently set to `None` but can be extended to support learning rate scheduling.
"""
optimizer_actor = torch.optim.Adam( optimizer_actor = torch.optim.Adam(
# NOTE: Handle the case of shared encoder where the encoder weights are not optimized with the gradient of the actor # NOTE: Handle the case of shared encoder where the encoder weights are not optimized with the gradient of the actor
params=policy.actor.parameters_to_optimize, params=policy.actor.parameters_to_optimize,
@ -273,8 +323,6 @@ def make_optimizers_and_scheduler(cfg, policy):
return optimizers, lr_scheduler return optimizers, lr_scheduler
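For context, a compact sketch of the three optimizers this function returns. Attribute and config names are taken from this diff (parameters_to_optimize, critic_ensemble, log_alpha, and the *_lr entries in the yaml config); the exact keyword wiring should be read as illustrative rather than the verbatim implementation:

import torch

def make_optimizers_sketch(cfg, policy):
    optimizers = {
        "actor": torch.optim.Adam(params=policy.actor.parameters_to_optimize, lr=cfg.policy.actor_lr),
        "critic": torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr),
        # log_alpha is a single tensor, so it is wrapped in a list for the optimizer.
        "temperature": torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.temperature_lr),
    }
    lr_scheduler = None  # No scheduler yet; kept as a placeholder for a future extension.
    return optimizers, lr_scheduler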
def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None): def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
if out_dir is None: if out_dir is None:
raise NotImplementedError() raise NotImplementedError()
@ -332,6 +380,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
batch_size = cfg.training.batch_size batch_size = cfg.training.batch_size
offline_buffer_lock = None offline_buffer_lock = None
offline_replay_buffer = None offline_replay_buffer = None
if cfg.dataset_repo_id is not None: if cfg.dataset_repo_id is not None:
logging.info("make_dataset offline buffer") logging.info("make_dataset offline buffer")
offline_dataset = make_dataset(cfg) offline_dataset = make_dataset(cfg)
@ -342,15 +391,24 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
offline_buffer_lock = Lock() offline_buffer_lock = Lock()
batch_size: int = batch_size // 2  # We will sample from both replay buffers
server_thread = Thread(target=stream_transitions_from_actor, args=(50051,), daemon=True) actor_ip = cfg.actor_learner_config.actor_ip
port = cfg.actor_learner_config.port
server_thread = Thread(
target=stream_transitions_from_actor,
args=(
actor_ip,
port,
),
daemon=True,
)
server_thread.start() server_thread.start()
# Start a background thread to process transitions from the queue
transition_thread = Thread( transition_thread = Thread(
target=add_actor_information, target=add_actor_information_and_train,
daemon=True, daemon=True,
args=(cfg, args=(
cfg,
device, device,
replay_buffer, replay_buffer,
offline_replay_buffer, offline_replay_buffer,
@ -361,29 +419,20 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
buffer_lock, buffer_lock,
offline_buffer_lock, offline_buffer_lock,
logger_lock, logger_lock,
logger), logger,
),
) )
transition_thread.start() transition_thread.start()
param_push_thread = Thread( param_push_thread = Thread(
target=learner_push_parameters, target=learner_push_parameters,
args=(policy, policy_lock, "127.0.0.1", 50051, 15), args=(policy, policy_lock, actor_ip, port, 15),
# args=("127.0.0.1", 50052),
daemon=True, daemon=True,
) )
param_push_thread.start() param_push_thread.start()
# interaction_thread = Thread(
# target=add_message_interaction_to_wandb,
# daemon=True,
# args=(cfg, logger, logger_lock),
# )
# interaction_thread.start()
transition_thread.join() transition_thread.join()
# param_push_thread.join()
server_thread.join() server_thread.join()
# interaction_thread.join()
@hydra.main(version_base="1.2", config_name="default", config_path="../../configs") @hydra.main(version_base="1.2", config_name="default", config_path="../../configs")