- Added additional logging information in wandb around the timings of the policy loop and optimization loop.
- Optimized critic design that improves the performance of the learner loop by a factor of 2 - Cleaned the code and fixed style issues - Completed the config with actor_learner_config field that contains host-ip and port elemnts that are necessary for the actor-learner servers. Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
This commit is contained in:
parent
2ae657f568
commit
8cd44ae163
|
@ -45,6 +45,14 @@ class SACConfig:
|
|||
"action": {"min": [-1, -1], "max": [1, 1]},
|
||||
}
|
||||
)
|
||||
# TODO: Move it outside of the config
|
||||
actor_learner_config: dict[str, str | int] = field(
|
||||
default_factory=lambda: {
|
||||
"actor_ip": "127.0.0.1",
|
||||
"port": 50051,
|
||||
"learner_ip": "127.0.0.1",
|
||||
}
|
||||
)
|
||||
camera_number: int = 1
|
||||
# Add type annotations for these fields:
|
||||
image_encoder_hidden_dim: int = 32
|
||||
|
|
|
@ -17,8 +17,7 @@
|
|||
|
||||
# TODO: (1) better device management
|
||||
|
||||
from collections import deque
|
||||
from typing import Callable, Optional, Sequence, Tuple, Union
|
||||
from typing import Callable, Optional, Tuple
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
|
@ -74,43 +73,42 @@ class SACPolicy(
|
|||
config.output_shapes, config.output_normalization_modes, dataset_stats
|
||||
)
|
||||
|
||||
# NOTE: For images the encoder should be shared between the actor and critic
|
||||
if config.shared_encoder:
|
||||
encoder_critic = SACObservationEncoder(config)
|
||||
encoder_actor: SACObservationEncoder = encoder_critic
|
||||
else:
|
||||
encoder_critic = SACObservationEncoder(config)
|
||||
encoder_actor = SACObservationEncoder(config)
|
||||
# Define networks
|
||||
critic_nets = []
|
||||
for _ in range(config.num_critics):
|
||||
critic_net = Critic(
|
||||
encoder=encoder_critic,
|
||||
network=MLP(
|
||||
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
|
||||
**config.critic_network_kwargs,
|
||||
),
|
||||
device=device,
|
||||
)
|
||||
critic_nets.append(critic_net)
|
||||
|
||||
target_critic_nets = []
|
||||
for _ in range(config.num_critics):
|
||||
target_critic_net = Critic(
|
||||
encoder=encoder_critic,
|
||||
network=MLP(
|
||||
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
|
||||
**config.critic_network_kwargs,
|
||||
),
|
||||
device=device,
|
||||
)
|
||||
target_critic_nets.append(target_critic_net)
|
||||
self.critic_ensemble = CriticEnsemble(
|
||||
encoder=encoder_critic,
|
||||
network_list=nn.ModuleList(
|
||||
[
|
||||
MLP(
|
||||
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
|
||||
**config.critic_network_kwargs,
|
||||
)
|
||||
for _ in range(config.num_critics)
|
||||
]
|
||||
),
|
||||
device=device,
|
||||
)
|
||||
|
||||
self.critic_ensemble = create_critic_ensemble(
|
||||
critics=critic_nets, num_critics=config.num_critics, device=device
|
||||
)
|
||||
self.critic_target = create_critic_ensemble(
|
||||
critics=target_critic_nets, num_critics=config.num_critics, device=device
|
||||
self.critic_target = CriticEnsemble(
|
||||
encoder=encoder_critic,
|
||||
network_list=nn.ModuleList(
|
||||
[
|
||||
MLP(
|
||||
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
|
||||
**config.critic_network_kwargs,
|
||||
)
|
||||
for _ in range(config.num_critics)
|
||||
]
|
||||
),
|
||||
device=device,
|
||||
)
|
||||
|
||||
self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
|
||||
|
||||
self.actor = Policy(
|
||||
|
@ -123,7 +121,8 @@ class SACPolicy(
|
|||
)
|
||||
if config.target_entropy is None:
|
||||
config.target_entropy = -np.prod(config.output_shapes["action"][0]) / 2 # (-dim(A)/2)
|
||||
# TODO: Handle the case where the temparameter is a fixed
|
||||
|
||||
# TODO (azouitine): Handle the case where the temparameter is a fixed
|
||||
self.log_alpha = torch.zeros(1, requires_grad=True, device=device)
|
||||
self.temperature = self.log_alpha.exp().item()
|
||||
|
||||
|
@ -152,18 +151,19 @@ class SACPolicy(
|
|||
Tensor of Q-values from all critics
|
||||
"""
|
||||
critics = self.critic_target if use_target else self.critic_ensemble
|
||||
q_values = torch.stack([critic(observations, actions) for critic in critics])
|
||||
q_values = critics(observations, actions)
|
||||
return q_values
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]: ...
|
||||
def update_target_networks(self):
|
||||
"""Update target networks with exponential moving average"""
|
||||
for target_critic, critic in zip(self.critic_target, self.critic_ensemble, strict=False):
|
||||
for target_param, param in zip(target_critic.parameters(), critic.parameters(), strict=False):
|
||||
target_param.data.copy_(
|
||||
param.data * self.config.critic_target_update_weight
|
||||
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
|
||||
)
|
||||
for target_param, param in zip(
|
||||
self.critic_target.parameters(), self.critic_ensemble.parameters(), strict=False
|
||||
):
|
||||
target_param.data.copy_(
|
||||
param.data * self.config.critic_target_update_weight
|
||||
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
|
||||
)
|
||||
|
||||
def compute_loss_critic(self, observations, actions, rewards, next_observations, done) -> Tensor:
|
||||
temperature = self.log_alpha.exp().item()
|
||||
|
@ -264,34 +264,83 @@ class MLP(nn.Module):
|
|||
return self.net(x)
|
||||
|
||||
|
||||
class Critic(nn.Module):
|
||||
class CriticEnsemble(nn.Module):
|
||||
"""
|
||||
┌──────────────────┬─────────────────────────────────────────────────────────┐
|
||||
│ Critic Ensemble │ │
|
||||
├──────────────────┘ │
|
||||
│ │
|
||||
│ ┌────┐ ┌────┐ ┌────┐ │
|
||||
│ │ Q1 │ │ Q2 │ │ Qn │ │
|
||||
│ └────┘ └────┘ └────┘ │
|
||||
│ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ │
|
||||
│ │ │ │ │ │ │ │
|
||||
│ │ MLP 1 │ │ MLP 2 │ │ MLP │ │
|
||||
│ │ │ │ │ ... │ num_critics │ │
|
||||
│ │ │ │ │ │ │ │
|
||||
│ └──────────────┘ └──────────────┘ └──────────────┘ │
|
||||
│ ▲ ▲ ▲ │
|
||||
│ └───────────────────┴───────┬────────────────────────────┘ │
|
||||
│ │ │
|
||||
│ │ │
|
||||
│ ┌───────────────────┐ │
|
||||
│ │ Embedding │ │
|
||||
│ │ │ │
|
||||
│ └───────────────────┘ │
|
||||
│ ▲ │
|
||||
│ │ │
|
||||
│ ┌─────────────┴────────────┐ │
|
||||
│ │ │ │
|
||||
│ │ SACObservationEncoder │ │
|
||||
│ │ │ │
|
||||
│ └──────────────────────────┘ │
|
||||
│ ▲ │
|
||||
│ │ │
|
||||
│ │ │
|
||||
│ │ │
|
||||
└───────────────────────────┬────────────────────┬───────────────────────────┘
|
||||
│ Observation │
|
||||
└────────────────────┘
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
encoder: Optional[nn.Module],
|
||||
network: nn.Module,
|
||||
network_list: nn.Module,
|
||||
init_final: Optional[float] = None,
|
||||
device: str = "cpu",
|
||||
):
|
||||
super().__init__()
|
||||
self.device = torch.device(device)
|
||||
self.encoder = encoder
|
||||
self.network = network
|
||||
self.network_list = network_list
|
||||
self.init_final = init_final
|
||||
|
||||
# for network in network_list:
|
||||
# network.to(self.device)
|
||||
|
||||
# Find the last Linear layer's output dimension
|
||||
for layer in reversed(network.net):
|
||||
for layer in reversed(network_list[0].net):
|
||||
if isinstance(layer, nn.Linear):
|
||||
out_features = layer.out_features
|
||||
break
|
||||
|
||||
# Output layer
|
||||
self.output_layers = []
|
||||
if init_final is not None:
|
||||
self.output_layer = nn.Linear(out_features, 1)
|
||||
nn.init.uniform_(self.output_layer.weight, -init_final, init_final)
|
||||
nn.init.uniform_(self.output_layer.bias, -init_final, init_final)
|
||||
for _ in network_list:
|
||||
output_layer = nn.Linear(out_features, 1, device=device)
|
||||
nn.init.uniform_(output_layer.weight, -init_final, init_final)
|
||||
nn.init.uniform_(output_layer.bias, -init_final, init_final)
|
||||
self.output_layers.append(output_layer)
|
||||
else:
|
||||
self.output_layer = nn.Linear(out_features, 1)
|
||||
orthogonal_init()(self.output_layer.weight)
|
||||
self.output_layers = []
|
||||
for _ in network_list:
|
||||
output_layer = nn.Linear(out_features, 1, device=device)
|
||||
orthogonal_init()(output_layer.weight)
|
||||
self.output_layers.append(output_layer)
|
||||
|
||||
self.output_layers = nn.ModuleList(self.output_layers)
|
||||
|
||||
self.to(self.device)
|
||||
|
||||
|
@ -307,9 +356,12 @@ class Critic(nn.Module):
|
|||
obs_enc = observations if self.encoder is None else self.encoder(observations)
|
||||
|
||||
inputs = torch.cat([obs_enc, actions], dim=-1)
|
||||
x = self.network(inputs)
|
||||
value = self.output_layer(x)
|
||||
return value.squeeze(-1)
|
||||
list_q_values = []
|
||||
for network, output_layer in zip(self.network_list, self.output_layers, strict=False):
|
||||
x = network(inputs)
|
||||
value = output_layer(x)
|
||||
list_q_values.append(value.squeeze(-1))
|
||||
return torch.stack(list_q_values)
|
||||
|
||||
|
||||
class Policy(nn.Module):
|
||||
|
@ -416,9 +468,7 @@ class Policy(nn.Module):
|
|||
|
||||
|
||||
class SACObservationEncoder(nn.Module):
|
||||
"""Encode image and/or state vector observations.
|
||||
TODO(ke-wang): The original work allows for (1) stacking multiple history frames and (2) using pretrained resnet encoders.
|
||||
"""
|
||||
"""Encode image and/or state vector observations."""
|
||||
|
||||
def __init__(self, config: SACConfig):
|
||||
"""
|
||||
|
@ -513,8 +563,7 @@ class SACObservationEncoder(nn.Module):
|
|||
feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
|
||||
if "observation.state" in self.config.input_shapes:
|
||||
feat.append(self.state_enc_layers(obs_dict["observation.state"]))
|
||||
# TODO(ke-wang): currently average over all features, concatenate all features maybe a better way
|
||||
# return torch.stack(feat, dim=0).mean(0)
|
||||
|
||||
features = torch.cat(tensors=feat, dim=-1)
|
||||
features = self.aggregation_layer(features)
|
||||
|
||||
|
@ -530,12 +579,8 @@ def orthogonal_init():
|
|||
return lambda x: torch.nn.init.orthogonal_(x, gain=1.0)
|
||||
|
||||
|
||||
def create_critic_ensemble(critics: list[nn.Module], num_critics: int, device: str = "cpu") -> nn.ModuleList:
|
||||
"""Creates an ensemble of critic networks"""
|
||||
assert len(critics) == num_critics, f"Expected {num_critics} critics, got {len(critics)}"
|
||||
return nn.ModuleList(critics).to(device)
|
||||
|
||||
|
||||
# TODO (azouitine): I think in our case this function is not usefull we should remove it
|
||||
# after some investigation
|
||||
# borrowed from tdmpc
|
||||
def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor:
|
||||
"""Helper to temporarily flatten extra dims at the start of the image tensor.
|
||||
|
|
|
@ -8,8 +8,7 @@
|
|||
# env.gym.obs_type=environment_state_agent_pos \
|
||||
|
||||
seed: 1
|
||||
dataset_repo_id: null
|
||||
|
||||
dataset_repo_id: null
|
||||
|
||||
training:
|
||||
# Offline training dataloader
|
||||
|
@ -75,15 +74,18 @@ policy:
|
|||
# discount: 0.99
|
||||
discount: 0.80
|
||||
temperature_init: 1.0
|
||||
num_critics: 2
|
||||
num_critics: 2 #10
|
||||
num_subsample_critics: null
|
||||
critic_lr: 3e-4
|
||||
actor_lr: 3e-4
|
||||
temperature_lr: 3e-4
|
||||
# critic_target_update_weight: 0.005
|
||||
critic_target_update_weight: 0.01
|
||||
utd_ratio: 2
|
||||
utd_ratio: 2 # 10
|
||||
|
||||
actor_learner_config:
|
||||
actor_ip: "127.0.0.1"
|
||||
port: 50051
|
||||
|
||||
# # Loss coefficients.
|
||||
# reward_coeff: 0.5
|
||||
|
|
|
@ -13,117 +13,123 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import io
|
||||
import logging
|
||||
import functools
|
||||
from pprint import pformat
|
||||
import random
|
||||
from typing import Optional, Sequence, TypedDict, Callable
|
||||
import pickle
|
||||
import queue
|
||||
import time
|
||||
from concurrent import futures
|
||||
from statistics import mean, quantiles
|
||||
|
||||
import hydra
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from tqdm import tqdm
|
||||
from deepdiff import DeepDiff
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
# TODO: Remove the import of maniskill
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.envs.factory import make_env, make_maniskill_env
|
||||
from lerobot.common.envs.utils import preprocess_observation, preprocess_maniskill_observation
|
||||
from lerobot.common.logger import Logger, log_output_dir
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.common.policies.utils import get_device_from_parameters
|
||||
from lerobot.common.utils.utils import (
|
||||
format_big_number,
|
||||
get_safe_torch_device,
|
||||
init_hydra_config,
|
||||
init_logging,
|
||||
set_global_seed,
|
||||
)
|
||||
# from lerobot.scripts.eval import eval_policy
|
||||
from threading import Thread
|
||||
import queue
|
||||
|
||||
import grpc
|
||||
from lerobot.scripts.server import hilserl_pb2, hilserl_pb2_grpc
|
||||
import io
|
||||
import time
|
||||
import logging
|
||||
from concurrent import futures
|
||||
from threading import Thread
|
||||
from lerobot.scripts.server.buffer import move_state_dict_to_device, move_transition_to_device, Transition
|
||||
import hydra
|
||||
import torch
|
||||
from omegaconf import DictConfig
|
||||
from torch import nn
|
||||
|
||||
import faulthandler
|
||||
import signal
|
||||
# TODO: Remove the import of maniskill
|
||||
from lerobot.common.envs.factory import make_maniskill_env
|
||||
from lerobot.common.envs.utils import preprocess_maniskill_observation
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.common.utils.utils import (
|
||||
get_safe_torch_device,
|
||||
set_global_seed,
|
||||
)
|
||||
from lerobot.scripts.server import hilserl_pb2, hilserl_pb2_grpc
|
||||
from lerobot.scripts.server.buffer import Transition, move_state_dict_to_device, move_transition_to_device
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
parameters_queue = queue.Queue(maxsize=1)
|
||||
message_queue = queue.Queue(maxsize=1_000_000)
|
||||
|
||||
|
||||
class ActorInformation:
|
||||
"""
|
||||
This helper class is used to differentiate between two types of messages that are placed in the same queue during streaming:
|
||||
|
||||
- **Transition Data:** Contains experience tuples (observation, action, reward, next observation) collected during interaction.
|
||||
- **Interaction Messages:** Encapsulates statistics related to the interaction process.
|
||||
|
||||
Attributes:
|
||||
transition (Optional): Transition data to be sent to the learner.
|
||||
interaction_message (Optional): Iteraction message providing additional statistics for logging.
|
||||
"""
|
||||
|
||||
def __init__(self, transition=None, interaction_message=None):
|
||||
self.transition = transition
|
||||
self.interaction_message = interaction_message
|
||||
|
||||
|
||||
# 1) Implement ActorService so the Learner can send parameters to this Actor.
|
||||
class ActorServiceServicer(hilserl_pb2_grpc.ActorServiceServicer):
|
||||
def StreamTransition(self, request, context):
|
||||
"""
|
||||
gRPC service for actor-learner communication in reinforcement learning.
|
||||
|
||||
This service is responsible for:
|
||||
1. Streaming batches of transition data and statistical metrics from the actor to the learner.
|
||||
2. Receiving updated network parameters from the learner.
|
||||
"""
|
||||
|
||||
def StreamTransition(self, request, context): # noqa: N802
|
||||
"""
|
||||
Streams data from the actor to the learner.
|
||||
|
||||
This function continuously retrieves messages from the queue and processes them based on their type:
|
||||
|
||||
- **Transition Data:**
|
||||
- A batch of transitions (observation, action, reward, next observation) is collected.
|
||||
- Transitions are moved to the CPU and serialized using PyTorch.
|
||||
- The serialized data is wrapped in a `hilserl_pb2.Transition` message and sent to the learner.
|
||||
|
||||
- **Interaction Messages:**
|
||||
- Contains useful statistics about episodic rewards and policy timings.
|
||||
- The message is serialized using `pickle` and sent to the learner.
|
||||
|
||||
Yields:
|
||||
hilserl_pb2.ActorInformation: The response message containing either transition data or an interaction message.
|
||||
"""
|
||||
while True:
|
||||
# logging.info(f"[ACTOR] before message.empty()")
|
||||
# logging.info(f"[ACTOR] size transition queue {message_queue.qsize()}")
|
||||
# time.sleep(0.01)
|
||||
# if message_queue.empty():
|
||||
# continue
|
||||
# logging.info(f"[ACTOR] after message.empty()")
|
||||
start = time.time()
|
||||
message = message_queue.get(block=True)
|
||||
# logging.info(f"[ACTOR] Message queue get time {time.time() - start}")
|
||||
|
||||
if message.transition is not None:
|
||||
# transition_to_send_to_learner = move_transition_to_device(message.transition, device="cpu")
|
||||
transition_to_send_to_learner = [move_transition_to_device(T, device="cpu") for T in message.transition]
|
||||
# logging.info(f"[ACTOR] Message queue get time {time.time() - start}")
|
||||
transition_to_send_to_learner = [
|
||||
move_transition_to_device(T, device="cpu") for T in message.transition
|
||||
]
|
||||
|
||||
# Serialize it
|
||||
buf = io.BytesIO()
|
||||
torch.save(transition_to_send_to_learner, buf)
|
||||
transition_bytes = buf.getvalue()
|
||||
|
||||
transition_message = hilserl_pb2.Transition(
|
||||
transition_bytes=transition_bytes
|
||||
)
|
||||
|
||||
response = hilserl_pb2.ActorInformation(
|
||||
transition=transition_message
|
||||
)
|
||||
logging.info(f"[ACTOR] time to yield transition response {time.time() - start}")
|
||||
logging.info(f"[ACTOR] size transition queue {message_queue.qsize()}")
|
||||
|
||||
transition_message = hilserl_pb2.Transition(transition_bytes=transition_bytes)
|
||||
|
||||
response = hilserl_pb2.ActorInformation(transition=transition_message)
|
||||
|
||||
elif message.interaction_message is not None:
|
||||
# Serialize it and send it to the Learner's server
|
||||
content = hilserl_pb2.InteractionMessage(
|
||||
interaction_message_bytes=pickle.dumps(message.interaction_message)
|
||||
)
|
||||
response = hilserl_pb2.ActorInformation(
|
||||
interaction_message=content
|
||||
)
|
||||
response = hilserl_pb2.ActorInformation(interaction_message=content)
|
||||
|
||||
# logging.info(f"[ACTOR] yield response before")
|
||||
yield response
|
||||
# logging.info(f"[ACTOR] response yielded after")
|
||||
|
||||
def SendParameters(self, request, context):
|
||||
def SendParameters(self, request, context): # noqa: N802
|
||||
"""
|
||||
Learner calls this with updated Parameters -> Actor
|
||||
Receives updated parameters from the learner and updates the actor.
|
||||
|
||||
The learner calls this method to send new model parameters. The received parameters are deserialized
|
||||
and placed in a queue to be consumed by the actor.
|
||||
|
||||
Args:
|
||||
request (hilserl_pb2.ParameterUpdate): The request containing serialized network parameters.
|
||||
context (grpc.ServicerContext): The gRPC context.
|
||||
|
||||
Returns:
|
||||
hilserl_pb2.Empty: An empty response to acknowledge receipt.
|
||||
"""
|
||||
# logging.info("[ACTOR] Received parameters from Learner.")
|
||||
buffer = io.BytesIO(request.parameter_bytes)
|
||||
params = torch.load(buffer)
|
||||
parameters_queue.put(params)
|
||||
|
@ -132,38 +138,38 @@ class ActorServiceServicer(hilserl_pb2_grpc.ActorServiceServicer):
|
|||
|
||||
def serve_actor_service(port=50052):
|
||||
"""
|
||||
Runs a gRPC server so that the Learner can push parameters to the Actor.
|
||||
Runs a gRPC server to start streaming the data from the actor to the learner.
|
||||
Throught this server the learner can push parameters to the Actor as well.
|
||||
"""
|
||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=20),
|
||||
options=[('grpc.max_send_message_length', -1),
|
||||
('grpc.max_receive_message_length', -1)])
|
||||
hilserl_pb2_grpc.add_ActorServiceServicer_to_server(
|
||||
ActorServiceServicer(), server
|
||||
server = grpc.server(
|
||||
futures.ThreadPoolExecutor(max_workers=20),
|
||||
options=[("grpc.max_send_message_length", -1), ("grpc.max_receive_message_length", -1)],
|
||||
)
|
||||
server.add_insecure_port(f'[::]:{port}')
|
||||
hilserl_pb2_grpc.add_ActorServiceServicer_to_server(ActorServiceServicer(), server)
|
||||
server.add_insecure_port(f"[::]:{port}")
|
||||
server.start()
|
||||
logging.info(f"[ACTOR] gRPC server listening on port {port}")
|
||||
server.wait_for_termination()
|
||||
|
||||
def act_with_policy(cfg: DictConfig,
|
||||
out_dir: str | None = None,
|
||||
job_name: str | None = None):
|
||||
|
||||
if out_dir is None:
|
||||
raise NotImplementedError()
|
||||
if job_name is None:
|
||||
raise NotImplementedError()
|
||||
def act_with_policy(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
|
||||
"""
|
||||
Executes policy interaction within the environment.
|
||||
|
||||
This function rolls out the policy in the environment, collecting interaction data and pushing it to a queue for streaming to the learner.
|
||||
Once an episode is completed, updated network parameters received from the learner are retrieved from a queue and loaded into the network.
|
||||
|
||||
Args:
|
||||
cfg (DictConfig): Configuration settings for the interaction process.
|
||||
out_dir (Optional[str]): Directory to store output logs or results. Defaults to None.
|
||||
job_name (Optional[str]): Name of the job for logging or tracking purposes. Defaults to None.
|
||||
"""
|
||||
|
||||
logging.info("make_env online")
|
||||
|
||||
# online_env = make_env(cfg, n_envs=1)
|
||||
# TODO: Remove the import of maniskill and unifiy with make env
|
||||
online_env = make_maniskill_env(cfg, n_envs=1)
|
||||
if cfg.training.eval_freq > 0:
|
||||
logging.info("make_env eval")
|
||||
# eval_env = make_env(cfg, n_envs=1)
|
||||
# TODO: Remove the import of maniskill and unifiy with make env
|
||||
eval_env = make_maniskill_env(cfg, n_envs=1)
|
||||
|
||||
set_global_seed(cfg.seed)
|
||||
device = get_safe_torch_device(cfg.device, log=True)
|
||||
|
@ -172,8 +178,7 @@ def act_with_policy(cfg: DictConfig,
|
|||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
||||
logging.info("make_policy")
|
||||
|
||||
|
||||
|
||||
### Instantiate the policy in both the actor and learner processes
|
||||
### To avoid sending a SACPolicy object through the port, we create a policy intance
|
||||
### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters
|
||||
|
@ -181,7 +186,7 @@ def act_with_policy(cfg: DictConfig,
|
|||
policy: SACPolicy = make_policy(
|
||||
hydra_cfg=cfg,
|
||||
# dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
|
||||
# Hack: But if we do online traning, we do not need dataset_stats
|
||||
# Hack: But if we do online training, we do not need dataset_stats
|
||||
dataset_stats=None,
|
||||
# TODO: Handle resume training
|
||||
pretrained_policy_name_or_path=None,
|
||||
|
@ -195,17 +200,22 @@ def act_with_policy(cfg: DictConfig,
|
|||
# obs = preprocess_observation(obs)
|
||||
obs = preprocess_maniskill_observation(obs)
|
||||
obs = {key: obs[key].to(device, non_blocking=True) for key in obs}
|
||||
### ACTOR ==================
|
||||
|
||||
# NOTE: For the moment we will solely handle the case of a single environment
|
||||
sum_reward_episode = 0
|
||||
list_transition_to_send_to_learner = []
|
||||
list_policy_fps = []
|
||||
|
||||
for interaction_step in range(cfg.training.online_steps):
|
||||
# NOTE: At some point we should use a wrapper to handle the observation
|
||||
|
||||
# start = time.time()
|
||||
if interaction_step >= cfg.training.online_step_before_learning:
|
||||
start = time.perf_counter()
|
||||
action = policy.select_action(batch=obs)
|
||||
list_policy_fps.append(1.0 / (time.perf_counter() - start + 1e-9))
|
||||
if list_policy_fps[-1] < cfg.fps:
|
||||
logging.warning(
|
||||
f"[ACTOR] policy frame rate {list_policy_fps[-1]} during interaction step {interaction_step} is below the required control frame rate {cfg.fps}"
|
||||
)
|
||||
|
||||
next_obs, reward, done, truncated, info = online_env.step(action.cpu().numpy())
|
||||
else:
|
||||
action = online_env.action_space.sample()
|
||||
|
@ -213,70 +223,88 @@ def act_with_policy(cfg: DictConfig,
|
|||
# HACK
|
||||
action = torch.tensor(action, dtype=torch.float32).to(device, non_blocking=True)
|
||||
|
||||
# logging.info(f"[ACTOR] Time for env step {time.time() - start}")
|
||||
|
||||
# HACK: For maniskill
|
||||
# next_obs = preprocess_observation(next_obs)
|
||||
next_obs = preprocess_maniskill_observation(next_obs)
|
||||
next_obs = {key: next_obs[key].to(device, non_blocking=True) for key in obs}
|
||||
sum_reward_episode += float(reward[0])
|
||||
# Because we are using a single environment
|
||||
# we can safely assume that the episode is done
|
||||
|
||||
# Because we are using a single environment we can index at zero
|
||||
if done[0].item() or truncated[0].item():
|
||||
# TODO: Handle logging for episode information
|
||||
logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}")
|
||||
|
||||
if not parameters_queue.empty():
|
||||
logging.info("[ACTOR] Load new parameters from Learner.")
|
||||
# Load new parameters from Learner
|
||||
logging.debug("[ACTOR] Load new parameters from Learner.")
|
||||
state_dict = parameters_queue.get()
|
||||
state_dict = move_state_dict_to_device(state_dict, device=device)
|
||||
policy.actor.load_state_dict(state_dict)
|
||||
|
||||
|
||||
if len(list_transition_to_send_to_learner) > 0:
|
||||
logging.info(f"[ACTOR] Sending {len(list_transition_to_send_to_learner)} transitions to Learner.")
|
||||
logging.debug(
|
||||
f"[ACTOR] Sending {len(list_transition_to_send_to_learner)} transitions to Learner."
|
||||
)
|
||||
message_queue.put(ActorInformation(transition=list_transition_to_send_to_learner))
|
||||
list_transition_to_send_to_learner = []
|
||||
|
||||
stats = {}
|
||||
if len(list_policy_fps) > 0:
|
||||
policy_fps = mean(list_policy_fps)
|
||||
quantiles_90 = quantiles(list_policy_fps, n=10)[-1]
|
||||
logging.debug(f"[ACTOR] Average policy frame rate: {policy_fps}")
|
||||
logging.debug(f"[ACTOR] Policy frame rate 90th percentile: {quantiles_90}")
|
||||
stats = {"Policy frequency [Hz]": policy_fps, "Policy frequency 90th-p [Hz]": quantiles_90}
|
||||
list_policy_fps = []
|
||||
|
||||
# Send episodic reward to the learner
|
||||
message_queue.put(ActorInformation(interaction_message={"episodic_reward": sum_reward_episode,"interaction_step": interaction_step}))
|
||||
message_queue.put(
|
||||
ActorInformation(
|
||||
interaction_message={
|
||||
"Episodic reward": sum_reward_episode,
|
||||
"Interaction step": interaction_step,
|
||||
**stats,
|
||||
}
|
||||
)
|
||||
)
|
||||
sum_reward_episode = 0.0
|
||||
|
||||
# ============================
|
||||
# Prepare transition to send
|
||||
# ============================
|
||||
# Label the reward
|
||||
# TODO (michel-aractingi): Label the reward
|
||||
# if config.label_reward_on_actor:
|
||||
# reward = reward_classifier(obs)
|
||||
|
||||
list_transition_to_send_to_learner.append(Transition(
|
||||
# transition_to_send_to_learner = Transition(
|
||||
state=obs,
|
||||
action=action,
|
||||
reward=reward,
|
||||
next_state=next_obs,
|
||||
done=done,
|
||||
complementary_info=None,
|
||||
)
|
||||
list_transition_to_send_to_learner.append(
|
||||
Transition(
|
||||
state=obs,
|
||||
action=action,
|
||||
reward=reward,
|
||||
next_state=next_obs,
|
||||
done=done,
|
||||
complementary_info=None,
|
||||
)
|
||||
)
|
||||
# message_queue.put(ActorInformation(transition=transition_to_send_to_learner))
|
||||
|
||||
# assign obs to the next obs and continue the rollout
|
||||
obs = next_obs
|
||||
|
||||
|
||||
@hydra.main(version_base="1.2", config_name="default", config_path="../../configs")
|
||||
def actor_cli(cfg: dict):
|
||||
server_thread = Thread(target=serve_actor_service, args=(50051,), daemon=True)
|
||||
server_thread.start()
|
||||
policy_thread = Thread(target=act_with_policy,
|
||||
daemon=True,
|
||||
args=(cfg,hydra.core.hydra_config.HydraConfig.get().run.dir, hydra.core.hydra_config.HydraConfig.get().job.name))
|
||||
policy_thread.start()
|
||||
policy_thread.join()
|
||||
server_thread.join()
|
||||
port = cfg.actor_learner_config.port
|
||||
server_thread = Thread(target=serve_actor_service, args=(port,), daemon=True)
|
||||
server_thread.start()
|
||||
policy_thread = Thread(
|
||||
target=act_with_policy,
|
||||
daemon=True,
|
||||
args=(
|
||||
cfg,
|
||||
hydra.core.hydra_config.HydraConfig.get().run.dir,
|
||||
hydra.core.hydra_config.HydraConfig.get().job.name,
|
||||
),
|
||||
)
|
||||
policy_thread.start()
|
||||
policy_thread.join()
|
||||
server_thread.join()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
with open("traceback.log", "w") as f:
|
||||
faulthandler.register(signal.SIGUSR1, file=f)
|
||||
|
||||
actor_cli()
|
||||
actor_cli()
|
||||
|
|
|
@ -1,3 +1,19 @@
|
|||
// !/usr/bin/env python
|
||||
|
||||
// Copyright 2024 The HuggingFace Inc. team.
|
||||
// All rights reserved.
|
||||
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
syntax = "proto3";
|
||||
|
||||
package hil_serl;
|
||||
|
|
|
@ -1,97 +1,97 @@
|
|||
import grpc
|
||||
from concurrent import futures
|
||||
import functools
|
||||
import logging
|
||||
import queue
|
||||
import pickle
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import io
|
||||
import logging
|
||||
import pickle
|
||||
import queue
|
||||
import time
|
||||
|
||||
from pprint import pformat
|
||||
import random
|
||||
from typing import Optional, Sequence, TypedDict, Callable
|
||||
from threading import Lock, Thread
|
||||
|
||||
import grpc
|
||||
|
||||
# Import generated stubs
|
||||
import hilserl_pb2 # type: ignore
|
||||
import hilserl_pb2_grpc # type: ignore
|
||||
import hydra
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from tqdm import tqdm
|
||||
from deepdiff import DeepDiff
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from threading import Thread, Lock
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from torch import nn
|
||||
|
||||
# TODO: Remove the import of maniskill
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.logger import Logger, log_output_dir
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.common.policies.utils import get_device_from_parameters
|
||||
from lerobot.common.utils.utils import (
|
||||
format_big_number,
|
||||
get_safe_torch_device,
|
||||
init_hydra_config,
|
||||
init_logging,
|
||||
set_global_seed,
|
||||
)
|
||||
from lerobot.scripts.server.buffer import ReplayBuffer, move_transition_to_device, concatenate_batch_transitions, move_state_dict_to_device, Transition
|
||||
|
||||
# Import generated stubs
|
||||
import hilserl_pb2
|
||||
import hilserl_pb2_grpc
|
||||
from lerobot.scripts.server.buffer import (
|
||||
ReplayBuffer,
|
||||
concatenate_batch_transitions,
|
||||
move_state_dict_to_device,
|
||||
move_transition_to_device,
|
||||
)
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
|
||||
# TODO: Implement it in cleaner way maybe
|
||||
transition_queue = queue.Queue()
|
||||
interaction_message_queue = queue.Queue()
|
||||
|
||||
|
||||
# 1) Implement the LearnerService so the Actor can send transitions here.
|
||||
class LearnerServiceServicer(hilserl_pb2_grpc.LearnerServiceServicer):
|
||||
# def SendTransition(self, request, context):
|
||||
# """
|
||||
# Actor calls this method to push a Transition -> Learner.
|
||||
# """
|
||||
# buffer = io.BytesIO(request.transition_bytes)
|
||||
# transition = torch.load(buffer)
|
||||
# transition_queue.put(transition)
|
||||
# return hilserl_pb2.Empty()
|
||||
def SendInteractionMessage(self, request, context):
|
||||
"""
|
||||
Actor calls this method to push a Transition -> Learner.
|
||||
"""
|
||||
content = pickle.loads(request.interaction_message_bytes)
|
||||
interaction_message_queue.put(content)
|
||||
return hilserl_pb2.Empty()
|
||||
|
||||
|
||||
|
||||
def stream_transitions_from_actor(port=50051):
|
||||
def stream_transitions_from_actor(host="127.0.0.1", port=50051):
|
||||
"""
|
||||
Runs a gRPC server listening for transitions from the Actor.
|
||||
Runs a gRPC client that listens for transition and interaction messages from an Actor service.
|
||||
|
||||
This function establishes a gRPC connection with the given `host` and `port`, then continuously
|
||||
streams transition data from the `ActorServiceStub`. The received transition data is deserialized
|
||||
and stored in a queue (`transition_queue`). Similarly, interaction messages are also deserialized
|
||||
and stored in a separate queue (`interaction_message_queue`).
|
||||
|
||||
Args:
|
||||
host (str, optional): The IP address or hostname of the gRPC server. Defaults to `"127.0.0.1"`.
|
||||
port (int, optional): The port number on which the gRPC server is running. Defaults to `50051`.
|
||||
|
||||
"""
|
||||
# NOTE: This is waiting for the handshake to be done
|
||||
# In the future we will do it in a canonical way with a proper handshake
|
||||
time.sleep(10)
|
||||
channel = grpc.insecure_channel(f'127.0.0.1:{port}',
|
||||
options=[('grpc.max_send_message_length', -1),
|
||||
('grpc.max_receive_message_length', -1)])
|
||||
channel = grpc.insecure_channel(
|
||||
f"{host}:{port}",
|
||||
options=[("grpc.max_send_message_length", -1), ("grpc.max_receive_message_length", -1)],
|
||||
)
|
||||
stub = hilserl_pb2_grpc.ActorServiceStub(channel)
|
||||
for response in stub.StreamTransition(hilserl_pb2.Empty()):
|
||||
if response.HasField('transition'):
|
||||
if response.HasField("transition"):
|
||||
buffer = io.BytesIO(response.transition.transition_bytes)
|
||||
transition = torch.load(buffer)
|
||||
transition_queue.put(transition)
|
||||
if response.HasField('interaction_message'):
|
||||
if response.HasField("interaction_message"):
|
||||
content = pickle.loads(response.interaction_message.interaction_message_bytes)
|
||||
interaction_message_queue.put(content)
|
||||
# NOTE: Cool down the CPU, if you comment this line you will make a huge bottleneck
|
||||
# TODO: LOOK TO REMOVE IT
|
||||
time.sleep(0.001)
|
||||
|
||||
|
||||
def learner_push_parameters(
|
||||
policy: nn.Module, policy_lock: Lock, actor_host="127.0.0.1", actor_port=50052, seconds_between_pushes=5
|
||||
):
|
||||
|
@ -100,10 +100,10 @@ def learner_push_parameters(
|
|||
and periodically push new parameters.
|
||||
"""
|
||||
time.sleep(10)
|
||||
# The Actor's server is presumably listening on a different port, e.g. 50052
|
||||
channel = grpc.insecure_channel(f"{actor_host}:{actor_port}",
|
||||
options=[('grpc.max_send_message_length', -1),
|
||||
('grpc.max_receive_message_length', -1)])
|
||||
channel = grpc.insecure_channel(
|
||||
f"{actor_host}:{actor_port}",
|
||||
options=[("grpc.max_send_message_length", -1), ("grpc.max_receive_message_length", -1)],
|
||||
)
|
||||
actor_stub = hilserl_pb2_grpc.ActorServiceStub(channel)
|
||||
|
||||
while True:
|
||||
|
@ -116,20 +116,19 @@ def learner_push_parameters(
|
|||
params_bytes = buf.getvalue()
|
||||
|
||||
# Push them to the Actor’s "SendParameters" method
|
||||
logging.info(f"[LEARNER] Pushing parameters to the Actor")
|
||||
response = actor_stub.SendParameters(hilserl_pb2.Parameters(parameter_bytes=params_bytes))
|
||||
logging.info("[LEARNER] Publishing parameters to the Actor")
|
||||
response = actor_stub.SendParameters(hilserl_pb2.Parameters(parameter_bytes=params_bytes)) # noqa: F841
|
||||
time.sleep(seconds_between_pushes)
|
||||
|
||||
|
||||
# Checked
|
||||
def add_actor_information(
|
||||
def add_actor_information_and_train(
|
||||
cfg,
|
||||
device,
|
||||
device: str,
|
||||
replay_buffer: ReplayBuffer,
|
||||
offline_replay_buffer: ReplayBuffer,
|
||||
batch_size: int,
|
||||
optimizers,
|
||||
policy,
|
||||
optimizers: dict[str, torch.optim.Optimizer],
|
||||
policy: nn.Module,
|
||||
policy_lock: Lock,
|
||||
buffer_lock: Lock,
|
||||
offline_buffer_lock: Lock,
|
||||
|
@ -137,34 +136,52 @@ def add_actor_information(
|
|||
logger: Logger,
|
||||
):
|
||||
"""
|
||||
In a real application, you might run your training loop here,
|
||||
reading from the transition queue and doing gradient updates.
|
||||
Handles data transfer from the actor to the learner, manages training updates,
|
||||
and logs training progress in an online reinforcement learning setup.
|
||||
|
||||
This function continuously:
|
||||
- Transfers transitions from the actor to the replay buffer.
|
||||
- Logs received interaction messages.
|
||||
- Ensures training begins only when the replay buffer has a sufficient number of transitions.
|
||||
- Samples batches from the replay buffer and performs multiple critic updates.
|
||||
- Periodically updates the actor, critic, and temperature optimizers.
|
||||
- Logs training statistics, including loss values and optimization frequency.
|
||||
|
||||
**NOTE:**
|
||||
- This function performs multiple responsibilities (data transfer, training, and logging).
|
||||
It should ideally be split into smaller functions in the future.
|
||||
- Due to Python's **Global Interpreter Lock (GIL)**, running separate threads for different tasks
|
||||
significantly reduces performance. Instead, this function executes all operations in a single thread.
|
||||
|
||||
Args:
|
||||
cfg: Configuration object containing hyperparameters.
|
||||
device (str): The computing device (`"cpu"` or `"cuda"`).
|
||||
replay_buffer (ReplayBuffer): The primary replay buffer storing online transitions.
|
||||
offline_replay_buffer (ReplayBuffer): An additional buffer for offline transitions.
|
||||
batch_size (int): The number of transitions to sample per training step.
|
||||
optimizers (Dict[str, torch.optim.Optimizer]): A dictionary of optimizers (`"actor"`, `"critic"`, `"temperature"`).
|
||||
policy (nn.Module): The reinforcement learning policy with critic, actor, and temperature parameters.
|
||||
policy_lock (Lock): A threading lock to ensure safe policy updates.
|
||||
buffer_lock (Lock): A threading lock to safely access the online replay buffer.
|
||||
offline_buffer_lock (Lock): A threading lock to safely access the offline replay buffer.
|
||||
logger_lock (Lock): A threading lock to safely log training metrics.
|
||||
logger (Logger): Logger instance for tracking training progress.
|
||||
"""
|
||||
# NOTE: This function doesn't have a single responsibility, it should be split into multiple functions
|
||||
# in the future. The reason why we did that is the GIL in Python. It's super slow the performance
|
||||
# are divided by 200. So we need to have a single thread that does all the work.
|
||||
start = time.time()
|
||||
time.time()
|
||||
optimization_step = 0
|
||||
timeout_for_adding_transitions = 1
|
||||
while True:
|
||||
time_for_adding_transitions = time.time()
|
||||
while not transition_queue.empty():
|
||||
|
||||
transition_list = transition_queue.get()
|
||||
for transition in transition_list:
|
||||
transition = move_transition_to_device(transition, device=device)
|
||||
replay_buffer.add(**transition)
|
||||
# logging.info(f"[LEARNER] size of replay buffer: {len(replay_buffer)}")
|
||||
# logging.info(f"[LEARNER] size of transition queues: {transition_queue.qsize()}")
|
||||
# logging.info(f"[LEARNER] size of replay buffer: {len(replay_buffer)}")
|
||||
# logging.info(f"[LEARNER] size of transition queues: {transition }")
|
||||
if len(replay_buffer) > cfg.training.online_step_before_learning:
|
||||
logging.info(f"[LEARNER] size of replay buffer: {len(replay_buffer)}")
|
||||
|
||||
while not interaction_message_queue.empty():
|
||||
interaction_message = interaction_message_queue.get()
|
||||
logger.log_dict(interaction_message,mode="train",custom_step_key="interaction_step")
|
||||
# logging.info(f"[LEARNER] size of interaction message queue: {interaction_message_queue.qsize()}")
|
||||
logger.log_dict(interaction_message, mode="train", custom_step_key="Interaction step")
|
||||
|
||||
if len(replay_buffer) < cfg.training.online_step_before_learning:
|
||||
continue
|
||||
|
@ -212,7 +229,7 @@ def add_actor_information(
|
|||
loss_critic = policy.compute_loss_critic(
|
||||
observations=observations,
|
||||
actions=actions,
|
||||
rewards=rewards,
|
||||
rewards=rewards,
|
||||
next_observations=next_observations,
|
||||
done=done,
|
||||
)
|
||||
|
@ -223,7 +240,6 @@ def add_actor_information(
|
|||
training_infos = {}
|
||||
training_infos["loss_critic"] = loss_critic.item()
|
||||
|
||||
|
||||
if optimization_step % cfg.training.policy_update_freq == 0:
|
||||
for _ in range(cfg.training.policy_update_freq):
|
||||
with policy_lock:
|
||||
|
@ -242,18 +258,52 @@ def add_actor_information(
|
|||
|
||||
training_infos["loss_temperature"] = loss_temperature.item()
|
||||
|
||||
policy.update_target_networks()
|
||||
if optimization_step % cfg.training.log_freq == 0:
|
||||
logger.log_dict(training_infos, step=optimization_step, mode="train")
|
||||
|
||||
policy.update_target_networks()
|
||||
optimization_step += 1
|
||||
time_for_one_optimization_step = time.time() - time_for_one_optimization_step
|
||||
frequency_for_one_optimization_step = 1 / (time_for_one_optimization_step + 1e-9)
|
||||
|
||||
logging.info(f"[LEARNER] Time for one optimization step: {time_for_one_optimization_step}")
|
||||
logger.log_dict({"Time optimization step":time_for_one_optimization_step}, step=optimization_step, mode="train")
|
||||
logging.debug(f"[LEARNER] Optimization frequency loop [Hz]: {frequency_for_one_optimization_step}")
|
||||
|
||||
logger.log_dict(
|
||||
{"Optimization frequency loop [Hz]": frequency_for_one_optimization_step},
|
||||
step=optimization_step,
|
||||
mode="train",
|
||||
)
|
||||
|
||||
optimization_step += 1
|
||||
if optimization_step % cfg.training.log_freq == 0:
|
||||
logging.info(f"[LEARNER] Number of optimization step: {optimization_step}")
|
||||
|
||||
|
||||
def make_optimizers_and_scheduler(cfg, policy):
|
||||
def make_optimizers_and_scheduler(cfg, policy: nn.Module):
|
||||
"""
|
||||
Creates and returns optimizers for the actor, critic, and temperature components of a reinforcement learning policy.
|
||||
|
||||
This function sets up Adam optimizers for:
|
||||
- The **actor network**, ensuring that only relevant parameters are optimized.
|
||||
- The **critic ensemble**, which evaluates the value function.
|
||||
- The **temperature parameter**, which controls the entropy in soft actor-critic (SAC)-like methods.
|
||||
|
||||
It also initializes a learning rate scheduler, though currently, it is set to `None`.
|
||||
|
||||
**NOTE:**
|
||||
- If the encoder is shared, its parameters are excluded from the actor’s optimization process.
|
||||
- The policy’s log temperature (`log_alpha`) is wrapped in a list to ensure proper optimization as a standalone tensor.
|
||||
|
||||
Args:
|
||||
cfg: Configuration object containing hyperparameters.
|
||||
policy (nn.Module): The policy model containing the actor, critic, and temperature components.
|
||||
|
||||
Returns:
|
||||
Tuple[Dict[str, torch.optim.Optimizer], Optional[torch.optim.lr_scheduler._LRScheduler]]:
|
||||
A tuple containing:
|
||||
- `optimizers`: A dictionary mapping component names ("actor", "critic", "temperature") to their respective Adam optimizers.
|
||||
- `lr_scheduler`: Currently set to `None` but can be extended to support learning rate scheduling.
|
||||
|
||||
"""
|
||||
optimizer_actor = torch.optim.Adam(
|
||||
# NOTE: Handle the case of shared encoder where the encoder weights are not optimized with the gradient of the actor
|
||||
params=policy.actor.parameters_to_optimize,
|
||||
|
@ -273,8 +323,6 @@ def make_optimizers_and_scheduler(cfg, policy):
|
|||
return optimizers, lr_scheduler
|
||||
|
||||
|
||||
|
||||
|
||||
def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
|
||||
if out_dir is None:
|
||||
raise NotImplementedError()
|
||||
|
@ -332,6 +380,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
|||
batch_size = cfg.training.batch_size
|
||||
offline_buffer_lock = None
|
||||
offline_replay_buffer = None
|
||||
|
||||
if cfg.dataset_repo_id is not None:
|
||||
logging.info("make_dataset offline buffer")
|
||||
offline_dataset = make_dataset(cfg)
|
||||
|
@ -342,48 +391,48 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
|||
offline_buffer_lock = Lock()
|
||||
batch_size: int = batch_size // 2 # We will sample from both replay buffer
|
||||
|
||||
server_thread = Thread(target=stream_transitions_from_actor, args=(50051,), daemon=True)
|
||||
actor_ip = cfg.actor_learner_config.actor_ip
|
||||
port = cfg.actor_learner_config.port
|
||||
|
||||
server_thread = Thread(
|
||||
target=stream_transitions_from_actor,
|
||||
args=(
|
||||
actor_ip,
|
||||
port,
|
||||
),
|
||||
daemon=True,
|
||||
)
|
||||
server_thread.start()
|
||||
|
||||
|
||||
# Start a background thread to process transitions from the queue
|
||||
transition_thread = Thread(
|
||||
target=add_actor_information,
|
||||
target=add_actor_information_and_train,
|
||||
daemon=True,
|
||||
args=(cfg,
|
||||
device,
|
||||
replay_buffer,
|
||||
offline_replay_buffer,
|
||||
batch_size,
|
||||
optimizers,
|
||||
policy,
|
||||
policy_lock,
|
||||
buffer_lock,
|
||||
offline_buffer_lock,
|
||||
logger_lock,
|
||||
logger),
|
||||
args=(
|
||||
cfg,
|
||||
device,
|
||||
replay_buffer,
|
||||
offline_replay_buffer,
|
||||
batch_size,
|
||||
optimizers,
|
||||
policy,
|
||||
policy_lock,
|
||||
buffer_lock,
|
||||
offline_buffer_lock,
|
||||
logger_lock,
|
||||
logger,
|
||||
),
|
||||
)
|
||||
transition_thread.start()
|
||||
|
||||
param_push_thread = Thread(
|
||||
target=learner_push_parameters,
|
||||
args=(policy, policy_lock, "127.0.0.1", 50051, 15),
|
||||
# args=("127.0.0.1", 50052),
|
||||
args=(policy, policy_lock, actor_ip, port, 15),
|
||||
daemon=True,
|
||||
)
|
||||
param_push_thread.start()
|
||||
|
||||
# interaction_thread = Thread(
|
||||
# target=add_message_interaction_to_wandb,
|
||||
# daemon=True,
|
||||
# args=(cfg, logger, logger_lock),
|
||||
# )
|
||||
# interaction_thread.start()
|
||||
|
||||
transition_thread.join()
|
||||
# param_push_thread.join()
|
||||
server_thread.join()
|
||||
# interaction_thread.join()
|
||||
|
||||
|
||||
@hydra.main(version_base="1.2", config_name="default", config_path="../../configs")
|
||||
|
|
Loading…
Reference in New Issue