Added server directory in `lerobot/scripts` that contains scripts and the protobuf message types to split training into two processes, acting and learning. The actor rollouts the policy and collects interaction data while the learner recieves the data, trains the policy and sends the updated parameters to the actor. The two scripts are ran simultaneously
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
This commit is contained in:
parent
d75b44f89f
commit
322a78a378
|
@ -127,6 +127,8 @@ class Logger:
|
|||
job_type="train_eval",
|
||||
resume="must" if cfg.resume else None,
|
||||
)
|
||||
# Handle custom step key for rl asynchronous training.
|
||||
self._wandb_custom_step_key = None
|
||||
print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
|
||||
logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
|
||||
self._wandb = wandb
|
||||
|
@ -226,18 +228,47 @@ class Logger:
|
|||
set_global_random_state({k: training_state[k] for k in get_global_random_state()})
|
||||
return training_state["step"]
|
||||
|
||||
def log_dict(self, d, step, mode="train"):
|
||||
def log_dict(self, d, step:int | None = None, mode="train", custom_step_key: str | None = None):
|
||||
"""Log a dictionary of metrics to WandB."""
|
||||
assert mode in {"train", "eval"}
|
||||
# TODO(alexander-soare): Add local text log.
|
||||
if step is None and custom_step_key is None:
|
||||
raise ValueError("Either step or custom_step_key must be provided.")
|
||||
|
||||
if self._wandb is not None:
|
||||
|
||||
# NOTE: This is not simple. Wandb step is it must always monotonically increase and it
|
||||
# increases with each wandb.log call, but in the case of asynchronous RL for example,
|
||||
# multiple time steps is possible for example, the interaction step with the environment,
|
||||
# the training step, the evaluation step, etc. So we need to define a custom step key
|
||||
# to log the correct step for each metric.
|
||||
if custom_step_key is not None and self._wandb_custom_step_key is None:
|
||||
# NOTE: Define the custom step key, once for the moment this implementation support only one
|
||||
# custom step.
|
||||
self._wandb_custom_step_key = f"{mode}/{custom_step_key}"
|
||||
self._wandb.define_metric(self._wandb_custom_step_key, hidden=True)
|
||||
|
||||
for k, v in d.items():
|
||||
if not isinstance(v, (int, float, str, wandb.Table)):
|
||||
logging.warning(
|
||||
f'WandB logging of key "{k}" was ignored as its type is not handled by this wrapper.'
|
||||
)
|
||||
continue
|
||||
|
||||
# We don't want to log the custom step
|
||||
if k == custom_step_key:
|
||||
continue
|
||||
|
||||
if self._wandb_custom_step_key is not None:
|
||||
# NOTE: Log the metric with the custom step key.
|
||||
value_custom_step_key = d[custom_step_key]
|
||||
self._wandb.log({f"{mode}/{k}": v, self._wandb_custom_step_key: value_custom_step_key})
|
||||
continue
|
||||
|
||||
self._wandb.log({f"{mode}/{k}": v}, step=step)
|
||||
|
||||
|
||||
|
||||
def log_video(self, video_path: str, step: int, mode: str = "train"):
|
||||
assert mode in {"train", "eval"}
|
||||
assert self._wandb is not None
|
||||
|
|
|
@ -76,7 +76,11 @@ def get_policy_and_config_classes(name: str) -> tuple[Policy, object]:
|
|||
|
||||
|
||||
def make_policy(
|
||||
hydra_cfg: DictConfig, pretrained_policy_name_or_path: str | None = None, dataset_stats=None
|
||||
hydra_cfg: DictConfig,
|
||||
pretrained_policy_name_or_path: str | None = None,
|
||||
dataset_stats=None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> Policy:
|
||||
"""Make an instance of a policy class.
|
||||
|
||||
|
@ -100,7 +104,9 @@ def make_policy(
|
|||
policy_cfg = _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg)
|
||||
if pretrained_policy_name_or_path is None:
|
||||
# Make a fresh policy.
|
||||
policy = policy_cls(policy_cfg, dataset_stats)
|
||||
# HACK: We pass *args and **kwargs to the policy constructor to allow for additional arguments
|
||||
# for example device for the sac policy.
|
||||
policy = policy_cls(*args, **kwargs, config=policy_cfg, dataset_stats=dataset_stats)
|
||||
else:
|
||||
# Load a pretrained policy and override the config if needed (for example, if there are inference-time
|
||||
# hyperparameters that we want to vary).
|
||||
|
|
|
@ -0,0 +1,282 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import functools
|
||||
from pprint import pformat
|
||||
import random
|
||||
from typing import Optional, Sequence, TypedDict, Callable
|
||||
import pickle
|
||||
|
||||
import hydra
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from tqdm import tqdm
|
||||
from deepdiff import DeepDiff
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
# TODO: Remove the import of maniskill
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.envs.factory import make_env, make_maniskill_env
|
||||
from lerobot.common.envs.utils import preprocess_observation, preprocess_maniskill_observation
|
||||
from lerobot.common.logger import Logger, log_output_dir
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.common.policies.utils import get_device_from_parameters
|
||||
from lerobot.common.utils.utils import (
|
||||
format_big_number,
|
||||
get_safe_torch_device,
|
||||
init_hydra_config,
|
||||
init_logging,
|
||||
set_global_seed,
|
||||
)
|
||||
# from lerobot.scripts.eval import eval_policy
|
||||
from threading import Thread
|
||||
import queue
|
||||
|
||||
import grpc
|
||||
from lerobot.scripts.server import hilserl_pb2, hilserl_pb2_grpc
|
||||
import io
|
||||
import time
|
||||
import logging
|
||||
from concurrent import futures
|
||||
from threading import Thread
|
||||
from lerobot.scripts.server.buffer import move_state_dict_to_device, move_transition_to_device, Transition
|
||||
|
||||
import faulthandler
|
||||
import signal
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
parameters_queue = queue.Queue(maxsize=1)
|
||||
message_queue = queue.Queue(maxsize=1_000_000)
|
||||
|
||||
class ActorInformation:
|
||||
def __init__(self, transition=None, interaction_message=None):
|
||||
self.transition = transition
|
||||
self.interaction_message = interaction_message
|
||||
|
||||
|
||||
# 1) Implement ActorService so the Learner can send parameters to this Actor.
|
||||
class ActorServiceServicer(hilserl_pb2_grpc.ActorServiceServicer):
|
||||
def StreamTransition(self, request, context):
|
||||
while True:
|
||||
# logging.info(f"[ACTOR] before message.empty()")
|
||||
# logging.info(f"[ACTOR] size transition queue {message_queue.qsize()}")
|
||||
# time.sleep(0.01)
|
||||
# if message_queue.empty():
|
||||
# continue
|
||||
# logging.info(f"[ACTOR] after message.empty()")
|
||||
start = time.time()
|
||||
message = message_queue.get(block=True)
|
||||
# logging.info(f"[ACTOR] Message queue get time {time.time() - start}")
|
||||
|
||||
if message.transition is not None:
|
||||
# transition_to_send_to_learner = move_transition_to_device(message.transition, device="cpu")
|
||||
transition_to_send_to_learner = [move_transition_to_device(T, device="cpu") for T in message.transition]
|
||||
# logging.info(f"[ACTOR] Message queue get time {time.time() - start}")
|
||||
|
||||
# Serialize it
|
||||
buf = io.BytesIO()
|
||||
torch.save(transition_to_send_to_learner, buf)
|
||||
transition_bytes = buf.getvalue()
|
||||
|
||||
transition_message = hilserl_pb2.Transition(
|
||||
transition_bytes=transition_bytes
|
||||
)
|
||||
|
||||
response = hilserl_pb2.ActorInformation(
|
||||
transition=transition_message
|
||||
)
|
||||
logging.info(f"[ACTOR] time to yield transition response {time.time() - start}")
|
||||
logging.info(f"[ACTOR] size transition queue {message_queue.qsize()}")
|
||||
|
||||
elif message.interaction_message is not None:
|
||||
# Serialize it and send it to the Learner's server
|
||||
content = hilserl_pb2.InteractionMessage(
|
||||
interaction_message_bytes=pickle.dumps(message.interaction_message)
|
||||
)
|
||||
response = hilserl_pb2.ActorInformation(
|
||||
interaction_message=content
|
||||
)
|
||||
|
||||
# logging.info(f"[ACTOR] yield response before")
|
||||
yield response
|
||||
# logging.info(f"[ACTOR] response yielded after")
|
||||
|
||||
def SendParameters(self, request, context):
|
||||
"""
|
||||
Learner calls this with updated Parameters -> Actor
|
||||
"""
|
||||
# logging.info("[ACTOR] Received parameters from Learner.")
|
||||
buffer = io.BytesIO(request.parameter_bytes)
|
||||
params = torch.load(buffer)
|
||||
parameters_queue.put(params)
|
||||
return hilserl_pb2.Empty()
|
||||
|
||||
|
||||
def serve_actor_service(port=50052):
|
||||
"""
|
||||
Runs a gRPC server so that the Learner can push parameters to the Actor.
|
||||
"""
|
||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=20),
|
||||
options=[('grpc.max_send_message_length', -1),
|
||||
('grpc.max_receive_message_length', -1)])
|
||||
hilserl_pb2_grpc.add_ActorServiceServicer_to_server(
|
||||
ActorServiceServicer(), server
|
||||
)
|
||||
server.add_insecure_port(f'[::]:{port}')
|
||||
server.start()
|
||||
logging.info(f"[ACTOR] gRPC server listening on port {port}")
|
||||
server.wait_for_termination()
|
||||
|
||||
def act_with_policy(cfg: DictConfig,
|
||||
out_dir: str | None = None,
|
||||
job_name: str | None = None):
|
||||
|
||||
if out_dir is None:
|
||||
raise NotImplementedError()
|
||||
if job_name is None:
|
||||
raise NotImplementedError()
|
||||
|
||||
logging.info("make_env online")
|
||||
|
||||
# online_env = make_env(cfg, n_envs=1)
|
||||
# TODO: Remove the import of maniskill and unifiy with make env
|
||||
online_env = make_maniskill_env(cfg, n_envs=1)
|
||||
if cfg.training.eval_freq > 0:
|
||||
logging.info("make_env eval")
|
||||
# eval_env = make_env(cfg, n_envs=1)
|
||||
# TODO: Remove the import of maniskill and unifiy with make env
|
||||
eval_env = make_maniskill_env(cfg, n_envs=1)
|
||||
|
||||
set_global_seed(cfg.seed)
|
||||
device = get_safe_torch_device(cfg.device, log=True)
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
||||
logging.info("make_policy")
|
||||
|
||||
|
||||
### Instantiate the policy in both the actor and learner processes
|
||||
### To avoid sending a SACPolicy object through the port, we create a policy intance
|
||||
### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters
|
||||
# TODO: At some point we should just need make sac policy
|
||||
policy: SACPolicy = make_policy(
|
||||
hydra_cfg=cfg,
|
||||
# dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
|
||||
# Hack: But if we do online traning, we do not need dataset_stats
|
||||
dataset_stats=None,
|
||||
# TODO: Handle resume training
|
||||
pretrained_policy_name_or_path=None,
|
||||
device=device,
|
||||
)
|
||||
assert isinstance(policy, nn.Module)
|
||||
|
||||
# HACK for maniskill
|
||||
obs, info = online_env.reset()
|
||||
|
||||
# obs = preprocess_observation(obs)
|
||||
obs = preprocess_maniskill_observation(obs)
|
||||
obs = {key: obs[key].to(device, non_blocking=True) for key in obs}
|
||||
### ACTOR ==================
|
||||
# NOTE: For the moment we will solely handle the case of a single environment
|
||||
sum_reward_episode = 0
|
||||
list_transition_to_send_to_learner = []
|
||||
|
||||
for interaction_step in range(cfg.training.online_steps):
|
||||
# NOTE: At some point we should use a wrapper to handle the observation
|
||||
|
||||
# start = time.time()
|
||||
if interaction_step >= cfg.training.online_step_before_learning:
|
||||
action = policy.select_action(batch=obs)
|
||||
next_obs, reward, done, truncated, info = online_env.step(action.cpu().numpy())
|
||||
else:
|
||||
action = online_env.action_space.sample()
|
||||
next_obs, reward, done, truncated, info = online_env.step(action)
|
||||
# HACK
|
||||
action = torch.tensor(action, dtype=torch.float32).to(device, non_blocking=True)
|
||||
|
||||
# logging.info(f"[ACTOR] Time for env step {time.time() - start}")
|
||||
|
||||
# HACK: For maniskill
|
||||
# next_obs = preprocess_observation(next_obs)
|
||||
next_obs = preprocess_maniskill_observation(next_obs)
|
||||
next_obs = {key: next_obs[key].to(device, non_blocking=True) for key in obs}
|
||||
sum_reward_episode += float(reward[0])
|
||||
# Because we are using a single environment
|
||||
# we can safely assume that the episode is done
|
||||
if done[0].item() or truncated[0].item():
|
||||
# TODO: Handle logging for episode information
|
||||
logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}")
|
||||
|
||||
if not parameters_queue.empty():
|
||||
logging.info("[ACTOR] Load new parameters from Learner.")
|
||||
# Load new parameters from Learner
|
||||
state_dict = parameters_queue.get()
|
||||
state_dict = move_state_dict_to_device(state_dict, device=device)
|
||||
policy.actor.load_state_dict(state_dict)
|
||||
|
||||
if len(list_transition_to_send_to_learner) > 0:
|
||||
logging.info(f"[ACTOR] Sending {len(list_transition_to_send_to_learner)} transitions to Learner.")
|
||||
message_queue.put(ActorInformation(transition=list_transition_to_send_to_learner))
|
||||
list_transition_to_send_to_learner = []
|
||||
|
||||
# Send episodic reward to the learner
|
||||
message_queue.put(ActorInformation(interaction_message={"episodic_reward": sum_reward_episode,"interaction_step": interaction_step}))
|
||||
sum_reward_episode = 0.0
|
||||
|
||||
# ============================
|
||||
# Prepare transition to send
|
||||
# ============================
|
||||
# Label the reward
|
||||
# if config.label_reward_on_actor:
|
||||
# reward = reward_classifier(obs)
|
||||
|
||||
list_transition_to_send_to_learner.append(Transition(
|
||||
# transition_to_send_to_learner = Transition(
|
||||
state=obs,
|
||||
action=action,
|
||||
reward=reward,
|
||||
next_state=next_obs,
|
||||
done=done,
|
||||
complementary_info=None,
|
||||
)
|
||||
)
|
||||
# message_queue.put(ActorInformation(transition=transition_to_send_to_learner))
|
||||
|
||||
# assign obs to the next obs and continue the rollout
|
||||
obs = next_obs
|
||||
|
||||
@hydra.main(version_base="1.2", config_name="default", config_path="../../configs")
|
||||
def actor_cli(cfg: dict):
|
||||
server_thread = Thread(target=serve_actor_service, args=(50051,), daemon=True)
|
||||
server_thread.start()
|
||||
policy_thread = Thread(target=act_with_policy,
|
||||
daemon=True,
|
||||
args=(cfg,hydra.core.hydra_config.HydraConfig.get().run.dir, hydra.core.hydra_config.HydraConfig.get().job.name))
|
||||
policy_thread.start()
|
||||
policy_thread.join()
|
||||
server_thread.join()
|
||||
|
||||
if __name__ == "__main__":
|
||||
with open("traceback.log", "w") as f:
|
||||
faulthandler.register(signal.SIGUSR1, file=f)
|
||||
|
||||
actor_cli()
|
|
@ -0,0 +1,42 @@
|
|||
syntax = "proto3";
|
||||
|
||||
package hil_serl;
|
||||
|
||||
// LearnerService: the Actor calls this to push transitions.
|
||||
// The Learner implements this service.
|
||||
service LearnerService {
|
||||
// Actor -> Learner to store transitions
|
||||
rpc SendTransition(Transition) returns (Empty);
|
||||
rpc SendInteractionMessage(InteractionMessage) returns (Empty);
|
||||
}
|
||||
|
||||
// ActorService: the Learner calls this to push parameters.
|
||||
// The Actor implements this service.
|
||||
service ActorService {
|
||||
// Learner -> Actor to send new parameters
|
||||
rpc StreamTransition(Empty) returns (stream ActorInformation) {};
|
||||
rpc SendParameters(Parameters) returns (Empty);
|
||||
}
|
||||
|
||||
|
||||
message ActorInformation {
|
||||
oneof data {
|
||||
Transition transition = 1;
|
||||
InteractionMessage interaction_message = 2;
|
||||
}
|
||||
}
|
||||
|
||||
// Messages
|
||||
message Transition {
|
||||
bytes transition_bytes = 1;
|
||||
}
|
||||
|
||||
message Parameters {
|
||||
bytes parameter_bytes = 1;
|
||||
}
|
||||
|
||||
message InteractionMessage {
|
||||
bytes interaction_message_bytes = 1;
|
||||
}
|
||||
|
||||
message Empty {}
|
|
@ -0,0 +1,394 @@
|
|||
import grpc
|
||||
from concurrent import futures
|
||||
import functools
|
||||
import logging
|
||||
import queue
|
||||
import pickle
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import io
|
||||
import time
|
||||
|
||||
from pprint import pformat
|
||||
import random
|
||||
from typing import Optional, Sequence, TypedDict, Callable
|
||||
|
||||
import hydra
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from tqdm import tqdm
|
||||
from deepdiff import DeepDiff
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from threading import Thread, Lock
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
# TODO: Remove the import of maniskill
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.logger import Logger, log_output_dir
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.common.policies.utils import get_device_from_parameters
|
||||
from lerobot.common.utils.utils import (
|
||||
format_big_number,
|
||||
get_safe_torch_device,
|
||||
init_hydra_config,
|
||||
init_logging,
|
||||
set_global_seed,
|
||||
)
|
||||
from lerobot.scripts.server.buffer import ReplayBuffer, move_transition_to_device, concatenate_batch_transitions, move_state_dict_to_device, Transition
|
||||
|
||||
# Import generated stubs
|
||||
import hilserl_pb2
|
||||
import hilserl_pb2_grpc
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
|
||||
# TODO: Implement it in cleaner way maybe
|
||||
transition_queue = queue.Queue()
|
||||
interaction_message_queue = queue.Queue()
|
||||
|
||||
|
||||
# 1) Implement the LearnerService so the Actor can send transitions here.
|
||||
class LearnerServiceServicer(hilserl_pb2_grpc.LearnerServiceServicer):
|
||||
# def SendTransition(self, request, context):
|
||||
# """
|
||||
# Actor calls this method to push a Transition -> Learner.
|
||||
# """
|
||||
# buffer = io.BytesIO(request.transition_bytes)
|
||||
# transition = torch.load(buffer)
|
||||
# transition_queue.put(transition)
|
||||
# return hilserl_pb2.Empty()
|
||||
def SendInteractionMessage(self, request, context):
|
||||
"""
|
||||
Actor calls this method to push a Transition -> Learner.
|
||||
"""
|
||||
content = pickle.loads(request.interaction_message_bytes)
|
||||
interaction_message_queue.put(content)
|
||||
return hilserl_pb2.Empty()
|
||||
|
||||
|
||||
|
||||
def stream_transitions_from_actor(port=50051):
|
||||
"""
|
||||
Runs a gRPC server listening for transitions from the Actor.
|
||||
"""
|
||||
time.sleep(10)
|
||||
channel = grpc.insecure_channel(f'127.0.0.1:{port}',
|
||||
options=[('grpc.max_send_message_length', -1),
|
||||
('grpc.max_receive_message_length', -1)])
|
||||
stub = hilserl_pb2_grpc.ActorServiceStub(channel)
|
||||
for response in stub.StreamTransition(hilserl_pb2.Empty()):
|
||||
if response.HasField('transition'):
|
||||
buffer = io.BytesIO(response.transition.transition_bytes)
|
||||
transition = torch.load(buffer)
|
||||
transition_queue.put(transition)
|
||||
if response.HasField('interaction_message'):
|
||||
content = pickle.loads(response.interaction_message.interaction_message_bytes)
|
||||
interaction_message_queue.put(content)
|
||||
# NOTE: Cool down the CPU, if you comment this line you will make a huge bottleneck
|
||||
time.sleep(0.001)
|
||||
|
||||
def learner_push_parameters(
|
||||
policy: nn.Module, policy_lock: Lock, actor_host="127.0.0.1", actor_port=50052, seconds_between_pushes=5
|
||||
):
|
||||
"""
|
||||
As a client, connect to the Actor's gRPC server (ActorService)
|
||||
and periodically push new parameters.
|
||||
"""
|
||||
time.sleep(10)
|
||||
# The Actor's server is presumably listening on a different port, e.g. 50052
|
||||
channel = grpc.insecure_channel(f"{actor_host}:{actor_port}",
|
||||
options=[('grpc.max_send_message_length', -1),
|
||||
('grpc.max_receive_message_length', -1)])
|
||||
actor_stub = hilserl_pb2_grpc.ActorServiceStub(channel)
|
||||
|
||||
while True:
|
||||
with policy_lock:
|
||||
params_dict = policy.actor.state_dict()
|
||||
params_dict = move_state_dict_to_device(params_dict, device="cpu")
|
||||
# Serialize
|
||||
buf = io.BytesIO()
|
||||
torch.save(params_dict, buf)
|
||||
params_bytes = buf.getvalue()
|
||||
|
||||
# Push them to the Actor’s "SendParameters" method
|
||||
response = actor_stub.SendParameters(hilserl_pb2.Parameters(parameter_bytes=params_bytes))
|
||||
time.sleep(seconds_between_pushes)
|
||||
|
||||
|
||||
# Checked
|
||||
def add_actor_information(
|
||||
cfg,
|
||||
device,
|
||||
replay_buffer: ReplayBuffer,
|
||||
offline_replay_buffer: ReplayBuffer,
|
||||
batch_size: int,
|
||||
optimizers,
|
||||
policy,
|
||||
policy_lock: Lock,
|
||||
buffer_lock: Lock,
|
||||
offline_buffer_lock: Lock,
|
||||
logger_lock: Lock,
|
||||
logger: Logger,
|
||||
):
|
||||
"""
|
||||
In a real application, you might run your training loop here,
|
||||
reading from the transition queue and doing gradient updates.
|
||||
"""
|
||||
# NOTE: This function doesn't have a single responsibility, it should be split into multiple functions
|
||||
# in the future. The reason why we did that is the GIL in Python. It's super slow the performance
|
||||
# are divided by 200. So we need to have a single thread that does all the work.
|
||||
start = time.time()
|
||||
optimization_step = 0
|
||||
|
||||
while True:
|
||||
time_for_adding_transitions = time.time()
|
||||
while not transition_queue.empty():
|
||||
|
||||
transition_list = transition_queue.get()
|
||||
for transition in transition_list:
|
||||
transition = move_transition_to_device(transition, device=device)
|
||||
replay_buffer.add(**transition)
|
||||
logging.info(f"[LEARNER] size of replay buffer: {len(replay_buffer)}")
|
||||
logging.info(f"[LEARNER] size of transition queues: {transition_queue.qsize()}")
|
||||
|
||||
|
||||
while not interaction_message_queue.empty():
|
||||
interaction_message = interaction_message_queue.get()
|
||||
logger.log_dict(interaction_message,mode="train",custom_step_key="interaction_step")
|
||||
logging.info(f"[LEARNER] size of interaction message queue: {interaction_message_queue.qsize()}")
|
||||
|
||||
# if len(replay_buffer.memory) < cfg.training.online_step_before_learning:
|
||||
# continue
|
||||
|
||||
# for _ in range(cfg.policy.utd_ratio - 1):
|
||||
|
||||
# batch = replay_buffer.sample(batch_size)
|
||||
# if cfg.dataset_repo_id is not None:
|
||||
# batch_offline = offline_replay_buffer.sample(batch_size)
|
||||
# batch = concatenate_batch_transitions(batch, batch_offline)
|
||||
|
||||
# actions = batch["action"]
|
||||
# rewards = batch["reward"]
|
||||
# observations = batch["state"]
|
||||
# next_observations = batch["next_state"]
|
||||
# done = batch["done"]
|
||||
|
||||
# with policy_lock:
|
||||
# loss_critic = policy.compute_loss_critic(
|
||||
# observations=observations,
|
||||
# actions=actions,
|
||||
# rewards=rewards,
|
||||
# next_observations=next_observations,
|
||||
# done=done,
|
||||
# )
|
||||
# optimizers["critic"].zero_grad()
|
||||
# loss_critic.backward()
|
||||
# optimizers["critic"].step()
|
||||
|
||||
# batch = replay_buffer.sample(batch_size)
|
||||
|
||||
# if cfg.dataset_repo_id is not None:
|
||||
# batch_offline = offline_replay_buffer.sample(batch_size)
|
||||
# batch = concatenate_batch_transitions(
|
||||
# left_batch_transitions=batch, right_batch_transition=batch_offline
|
||||
# )
|
||||
|
||||
# actions = batch["action"]
|
||||
# rewards = batch["reward"]
|
||||
# observations = batch["state"]
|
||||
# next_observations = batch["next_state"]
|
||||
# done = batch["done"]
|
||||
|
||||
# with policy_lock:
|
||||
# loss_critic = policy.compute_loss_critic(
|
||||
# observations=observations,
|
||||
# actions=actions,
|
||||
# rewards=rewards,
|
||||
# next_observations=next_observations,
|
||||
# done=done,
|
||||
# )
|
||||
# optimizers["critic"].zero_grad()
|
||||
# loss_critic.backward()
|
||||
# optimizers["critic"].step()
|
||||
|
||||
# training_infos = {}
|
||||
# training_infos["loss_critic"] = loss_critic.item()
|
||||
|
||||
# if optimization_step % cfg.training.policy_update_freq == 0:
|
||||
# for _ in range(cfg.training.policy_update_freq):
|
||||
# with policy_lock:
|
||||
# loss_actor = policy.compute_loss_actor(observations=observations)
|
||||
|
||||
# optimizers["actor"].zero_grad()
|
||||
# loss_actor.backward()
|
||||
# optimizers["actor"].step()
|
||||
|
||||
# training_infos["loss_actor"] = loss_actor.item()
|
||||
|
||||
# loss_temperature = policy.compute_loss_temperature(observations=observations)
|
||||
# optimizers["temperature"].zero_grad()
|
||||
# loss_temperature.backward()
|
||||
# optimizers["temperature"].step()
|
||||
|
||||
# training_infos["loss_temperature"] = loss_temperature.item()
|
||||
|
||||
# if optimization_step % cfg.training.log_freq == 0:
|
||||
# logger.log_dict(training_infos, step=optimization_step, mode="train")
|
||||
|
||||
# policy.update_target_networks()
|
||||
# optimization_step += 1
|
||||
# time_for_one_optimization_step = time.time() - time_for_one_optimization_step
|
||||
|
||||
# logger.log_dict({"[LEARNER] Time optimization step":time_for_one_optimization_step}, step=optimization_step, mode="train")
|
||||
# time_for_one_optimization_step = time.time()
|
||||
|
||||
|
||||
def make_optimizers_and_scheduler(cfg, policy):
|
||||
optimizer_actor = torch.optim.Adam(
|
||||
# NOTE: Handle the case of shared encoder where the encoder weights are not optimized with the gradient of the actor
|
||||
params=policy.actor.parameters_to_optimize,
|
||||
lr=policy.config.actor_lr,
|
||||
)
|
||||
optimizer_critic = torch.optim.Adam(
|
||||
params=policy.critic_ensemble.parameters(), lr=policy.config.critic_lr
|
||||
)
|
||||
# We wrap policy log temperature in list because this is a torch tensor and not a nn.Module
|
||||
optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=policy.config.critic_lr)
|
||||
lr_scheduler = None
|
||||
optimizers = {
|
||||
"actor": optimizer_actor,
|
||||
"critic": optimizer_critic,
|
||||
"temperature": optimizer_temperature,
|
||||
}
|
||||
return optimizers, lr_scheduler
|
||||
|
||||
|
||||
|
||||
|
||||
def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
|
||||
if out_dir is None:
|
||||
raise NotImplementedError()
|
||||
if job_name is None:
|
||||
raise NotImplementedError()
|
||||
|
||||
init_logging()
|
||||
logging.info(pformat(OmegaConf.to_container(cfg)))
|
||||
|
||||
logger = Logger(cfg, out_dir, wandb_job_name=job_name)
|
||||
logger_lock = Lock()
|
||||
|
||||
set_global_seed(cfg.seed)
|
||||
|
||||
device = get_safe_torch_device(cfg.device, log=True)
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
||||
logging.info("make_policy")
|
||||
|
||||
### Instantiate the policy in both the actor and learner processes
|
||||
### To avoid sending a SACPolicy object through the port, we create a policy intance
|
||||
### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters
|
||||
# TODO: At some point we should just need make sac policy
|
||||
policy_lock = Lock()
|
||||
with logger_lock:
|
||||
policy: SACPolicy = make_policy(
|
||||
hydra_cfg=cfg,
|
||||
# dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
|
||||
# Hack: But if we do online traning, we do not need dataset_stats
|
||||
dataset_stats=None,
|
||||
pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
|
||||
device=device,
|
||||
)
|
||||
assert isinstance(policy, nn.Module)
|
||||
|
||||
optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg, policy)
|
||||
|
||||
# TODO: Handle resume
|
||||
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
|
||||
num_total_params = sum(p.numel() for p in policy.parameters())
|
||||
|
||||
log_output_dir(out_dir)
|
||||
logging.info(f"{cfg.env.task=}")
|
||||
logging.info(f"{cfg.training.online_steps=}")
|
||||
logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
|
||||
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
|
||||
|
||||
buffer_lock = Lock()
|
||||
replay_buffer = ReplayBuffer(
|
||||
capacity=cfg.training.online_buffer_capacity, device=device, state_keys=cfg.policy.input_shapes.keys()
|
||||
)
|
||||
|
||||
batch_size = cfg.training.batch_size
|
||||
offline_buffer_lock = None
|
||||
offline_replay_buffer = None
|
||||
if cfg.dataset_repo_id is not None:
|
||||
logging.info("make_dataset offline buffer")
|
||||
offline_dataset = make_dataset(cfg)
|
||||
logging.info("Convertion to a offline replay buffer")
|
||||
offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
|
||||
offline_dataset, device=device, state_keys=cfg.policy.input_shapes.keys()
|
||||
)
|
||||
offline_buffer_lock = Lock()
|
||||
batch_size: int = batch_size // 2 # We will sample from both replay buffer
|
||||
|
||||
server_thread = Thread(target=stream_transitions_from_actor, args=(50051,), daemon=True)
|
||||
server_thread.start()
|
||||
|
||||
|
||||
# Start a background thread to process transitions from the queue
|
||||
transition_thread = Thread(
|
||||
target=add_actor_information,
|
||||
daemon=True,
|
||||
args=(cfg,
|
||||
device,
|
||||
replay_buffer,
|
||||
offline_replay_buffer,
|
||||
batch_size,
|
||||
optimizers,
|
||||
policy,
|
||||
policy_lock,
|
||||
buffer_lock,
|
||||
offline_buffer_lock,
|
||||
logger_lock,
|
||||
logger),
|
||||
)
|
||||
transition_thread.start()
|
||||
|
||||
# param_push_thread = Thread(
|
||||
# target=learner_push_parameters,
|
||||
# args=(policy, policy_lock, "127.0.0.1", 50052, 15),
|
||||
# # args=("127.0.0.1", 50052),
|
||||
# daemon=True,
|
||||
# )
|
||||
# param_push_thread.start()
|
||||
|
||||
# interaction_thread = Thread(
|
||||
# target=add_message_interaction_to_wandb,
|
||||
# daemon=True,
|
||||
# args=(cfg, logger, logger_lock),
|
||||
# )
|
||||
# interaction_thread.start()
|
||||
|
||||
transition_thread.join()
|
||||
# param_push_thread.join()
|
||||
server_thread.join()
|
||||
# interaction_thread.join()
|
||||
|
||||
|
||||
@hydra.main(version_base="1.2", config_name="default", config_path="../../configs")
|
||||
def train_cli(cfg: dict):
|
||||
train(
|
||||
cfg,
|
||||
out_dir=hydra.core.hydra_config.HydraConfig.get().run.dir,
|
||||
job_name=hydra.core.hydra_config.HydraConfig.get().job.name,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train_cli()
|
|
@ -177,6 +177,7 @@ class ReplayBuffer:
|
|||
)
|
||||
self.position: int = (self.position + 1) % self.capacity
|
||||
|
||||
# TODO: ADD image_augmentation and use_drq arguments in this function in order to instantiate the class with them
|
||||
@classmethod
|
||||
def from_lerobot_dataset(
|
||||
cls,
|
||||
|
|
Loading…
Reference in New Issue