diff --git a/lerobot/common/logger.py b/lerobot/common/logger.py
index 4015492d..35c12062 100644
--- a/lerobot/common/logger.py
+++ b/lerobot/common/logger.py
@@ -127,6 +127,8 @@ class Logger:
                 job_type="train_eval",
                 resume="must" if cfg.resume else None,
             )
+            # Handle custom step key for asynchronous RL training.
+            self._wandb_custom_step_key = None
             print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
             logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
             self._wandb = wandb
@@ -226,18 +228,47 @@ class Logger:
         set_global_random_state({k: training_state[k] for k in get_global_random_state()})
         return training_state["step"]
 
-    def log_dict(self, d, step, mode="train"):
+    def log_dict(self, d, step: int | None = None, mode="train", custom_step_key: str | None = None):
+        """Log a dictionary of metrics to WandB."""
         assert mode in {"train", "eval"}
         # TODO(alexander-soare): Add local text log.
+        if step is None and custom_step_key is None:
+            raise ValueError("Either step or custom_step_key must be provided.")
+
         if self._wandb is not None:
+            # NOTE: WandB's native step must increase monotonically and is incremented on every
+            # wandb.log call. In asynchronous RL, however, several notions of step coexist (e.g.
+            # the interaction step with the environment, the training step, the evaluation step),
+            # so we define a custom step key to log each metric against the correct step.
+            if custom_step_key is not None and self._wandb_custom_step_key is None:
+                # NOTE: Define the custom step key once; for now this implementation supports only
+                # a single custom step.
+                self._wandb_custom_step_key = f"{mode}/{custom_step_key}"
+                self._wandb.define_metric(self._wandb_custom_step_key, hidden=True)
+
             for k, v in d.items():
                 if not isinstance(v, (int, float, str, wandb.Table)):
                     logging.warning(
                         f'WandB logging of key "{k}" was ignored as its type is not handled by this wrapper.'
                     )
                     continue
+
+                # Do not log the custom step key itself as a metric.
+                if k == custom_step_key:
+                    continue
+
+                if self._wandb_custom_step_key is not None:
+                    # Log the metric against the custom step key.
+                    value_custom_step_key = d[custom_step_key]
+                    self._wandb.log({f"{mode}/{k}": v, self._wandb_custom_step_key: value_custom_step_key})
+                    continue
+
                 self._wandb.log({f"{mode}/{k}": v}, step=step)
 
+
     def log_video(self, video_path: str, step: int, mode: str = "train"):
         assert mode in {"train", "eval"}
         assert self._wandb is not None
diff --git a/lerobot/scripts/server/actor_server.py b/lerobot/scripts/server/actor_server.py
new file mode 100644
index 00000000..afa6a6e0
--- /dev/null
+++ b/lerobot/scripts/server/actor_server.py
@@ -0,0 +1,282 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
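+
+"""Actor process for asynchronous ("hil_serl") reinforcement learning.
+
+The actor runs the policy in the environment, streams transitions and episodic
+interaction messages to the Learner over gRPC (ActorService.StreamTransition),
+and receives updated policy parameters pushed by the Learner
+(ActorService.SendParameters).
+"""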
+import logging +import functools +from pprint import pformat +import random +from typing import Optional, Sequence, TypedDict, Callable +import pickle + +import hydra +import torch +import torch.nn.functional as F +from torch import nn +from tqdm import tqdm +from deepdiff import DeepDiff +from omegaconf import DictConfig, OmegaConf + +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset + +# TODO: Remove the import of maniskill +from lerobot.common.datasets.factory import make_dataset +from lerobot.common.envs.factory import make_env, make_maniskill_env +from lerobot.common.envs.utils import preprocess_observation, preprocess_maniskill_observation +from lerobot.common.logger import Logger, log_output_dir +from lerobot.common.policies.factory import make_policy +from lerobot.common.policies.sac.modeling_sac import SACPolicy +from lerobot.common.policies.utils import get_device_from_parameters +from lerobot.common.utils.utils import ( + format_big_number, + get_safe_torch_device, + init_hydra_config, + init_logging, + set_global_seed, +) +# from lerobot.scripts.eval import eval_policy +from threading import Thread +import queue + +import grpc +from lerobot.scripts.server import hilserl_pb2, hilserl_pb2_grpc +import io +import time +import logging +from concurrent import futures +from threading import Thread +from lerobot.scripts.server.buffer import move_state_dict_to_device, move_transition_to_device, Transition + +import faulthandler +import signal + +logging.basicConfig(level=logging.INFO) + +parameters_queue = queue.Queue(maxsize=1) +message_queue = queue.Queue(maxsize=1_000_000) + +class ActorInformation: + def __init__(self, transition=None, interaction_message=None): + self.transition = transition + self.interaction_message = interaction_message + + +# 1) Implement ActorService so the Learner can send parameters to this Actor. 
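+# Queues shared between the gRPC servicer below and the acting loop in act_with_policy():
+#   - parameters_queue: updated policy parameters pushed by the Learner via SendParameters.
+#   - message_queue: transitions and interaction messages produced while acting, streamed
+#     back to the Learner through StreamTransition.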
+class ActorServiceServicer(hilserl_pb2_grpc.ActorServiceServicer): + def StreamTransition(self, request, context): + while True: + # logging.info(f"[ACTOR] before message.empty()") + # logging.info(f"[ACTOR] size transition queue {message_queue.qsize()}") + # time.sleep(0.01) + # if message_queue.empty(): + # continue + # logging.info(f"[ACTOR] after message.empty()") + start = time.time() + message = message_queue.get(block=True) + # logging.info(f"[ACTOR] Message queue get time {time.time() - start}") + + if message.transition is not None: + # transition_to_send_to_learner = move_transition_to_device(message.transition, device="cpu") + transition_to_send_to_learner = [move_transition_to_device(T, device="cpu") for T in message.transition] + # logging.info(f"[ACTOR] Message queue get time {time.time() - start}") + + # Serialize it + buf = io.BytesIO() + torch.save(transition_to_send_to_learner, buf) + transition_bytes = buf.getvalue() + + transition_message = hilserl_pb2.Transition( + transition_bytes=transition_bytes + ) + + response = hilserl_pb2.ActorInformation( + transition=transition_message + ) + logging.info(f"[ACTOR] time to yield transition response {time.time() - start}") + logging.info(f"[ACTOR] size transition queue {message_queue.qsize()}") + + elif message.interaction_message is not None: + # Serialize it and send it to the Learner's server + content = hilserl_pb2.InteractionMessage( + interaction_message_bytes=pickle.dumps(message.interaction_message) + ) + response = hilserl_pb2.ActorInformation( + interaction_message=content + ) + + # logging.info(f"[ACTOR] yield response before") + yield response + # logging.info(f"[ACTOR] response yielded after") + + def SendParameters(self, request, context): + """ + Learner calls this with updated Parameters -> Actor + """ + # logging.info("[ACTOR] Received parameters from Learner.") + buffer = io.BytesIO(request.parameter_bytes) + params = torch.load(buffer) + parameters_queue.put(params) + return hilserl_pb2.Empty() + + +def serve_actor_service(port=50052): + """ + Runs a gRPC server so that the Learner can push parameters to the Actor. 
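+
+    The server implements the ActorService defined in hilserl.proto:
+      - SendParameters: unary RPC the Learner calls to push updated policy parameters.
+      - StreamTransition: server-streaming RPC over which this Actor sends transitions
+        and interaction messages back to the Learner.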
+    """
+    server = grpc.server(futures.ThreadPoolExecutor(max_workers=20),
+                         options=[('grpc.max_send_message_length', -1),
+                                  ('grpc.max_receive_message_length', -1)])
+    hilserl_pb2_grpc.add_ActorServiceServicer_to_server(
+        ActorServiceServicer(), server
+    )
+    server.add_insecure_port(f'[::]:{port}')
+    server.start()
+    logging.info(f"[ACTOR] gRPC server listening on port {port}")
+    server.wait_for_termination()
+
+def act_with_policy(cfg: DictConfig,
+                    out_dir: str | None = None,
+                    job_name: str | None = None):
+
+    if out_dir is None:
+        raise NotImplementedError()
+    if job_name is None:
+        raise NotImplementedError()
+
+    logging.info("make_env online")
+
+    # online_env = make_env(cfg, n_envs=1)
+    # TODO: Remove the import of maniskill and unify with make_env
+    online_env = make_maniskill_env(cfg, n_envs=1)
+    if cfg.training.eval_freq > 0:
+        logging.info("make_env eval")
+        # eval_env = make_env(cfg, n_envs=1)
+        # TODO: Remove the import of maniskill and unify with make_env
+        eval_env = make_maniskill_env(cfg, n_envs=1)
+
+    set_global_seed(cfg.seed)
+    device = get_safe_torch_device(cfg.device, log=True)
+
+    torch.backends.cudnn.benchmark = True
+    torch.backends.cuda.matmul.allow_tf32 = True
+
+    logging.info("make_policy")
+
+    ### Instantiate the policy in both the actor and learner processes.
+    ### To avoid sending a SACPolicy object over the network, we create a policy instance
+    ### on each side; the learner then sends updated parameters every n steps to refresh the actor.
+    # TODO: At some point we should only need to make the SAC policy directly
+    policy: SACPolicy = make_policy(
+        hydra_cfg=cfg,
+        # dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
+        # Hack: in online training we do not need dataset_stats
+        dataset_stats=None,
+        # TODO: Handle resume training
+        pretrained_policy_name_or_path=None,
+        device=device,
+    )
+    assert isinstance(policy, nn.Module)
+
+    # HACK for maniskill
+    obs, info = online_env.reset()
+
+    # obs = preprocess_observation(obs)
+    obs = preprocess_maniskill_observation(obs)
+    obs = {key: obs[key].to(device, non_blocking=True) for key in obs}
+    ### ACTOR ==================
+    # NOTE: For the moment we solely handle the case of a single environment
+    sum_reward_episode = 0
+    list_transition_to_send_to_learner = []
+
+    for interaction_step in range(cfg.training.online_steps):
+        # NOTE: At some point we should use a wrapper to handle the observation
+
+        # start = time.time()
+        if interaction_step >= cfg.training.online_step_before_learning:
+            action = policy.select_action(batch=obs)
+            next_obs, reward, done, truncated, info = online_env.step(action.cpu().numpy())
+        else:
+            action = online_env.action_space.sample()
+            next_obs, reward, done, truncated, info = online_env.step(action)
+            # HACK
+            action = torch.tensor(action, dtype=torch.float32).to(device, non_blocking=True)
+
+        # logging.info(f"[ACTOR] Time for env step {time.time() - start}")
+
+        # HACK: For maniskill
+        # next_obs = preprocess_observation(next_obs)
+        next_obs = preprocess_maniskill_observation(next_obs)
+        next_obs = {key: next_obs[key].to(device, non_blocking=True) for key in obs}
+        sum_reward_episode += float(reward[0])
+        # Since a single environment is used, we can index the first (and only)
+        # element to check whether the episode has ended.
+        if done[0].item() or truncated[0].item():
+            # TODO: Handle logging for episode information
+            logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}")
+
+            if not parameters_queue.empty():
+                logging.info("[ACTOR] Load new 
parameters from Learner.") + # Load new parameters from Learner + state_dict = parameters_queue.get() + state_dict = move_state_dict_to_device(state_dict, device=device) + policy.actor.load_state_dict(state_dict) + + if len(list_transition_to_send_to_learner) > 0: + logging.info(f"[ACTOR] Sending {len(list_transition_to_send_to_learner)} transitions to Learner.") + message_queue.put(ActorInformation(transition=list_transition_to_send_to_learner)) + list_transition_to_send_to_learner = [] + + # Send episodic reward to the learner + message_queue.put(ActorInformation(interaction_message={"episodic_reward": sum_reward_episode,"interaction_step": interaction_step})) + sum_reward_episode = 0.0 + + # ============================ + # Prepare transition to send + # ============================ + # Label the reward + # if config.label_reward_on_actor: + # reward = reward_classifier(obs) + + list_transition_to_send_to_learner.append(Transition( + # transition_to_send_to_learner = Transition( + state=obs, + action=action, + reward=reward, + next_state=next_obs, + done=done, + complementary_info=None, + ) + ) + # message_queue.put(ActorInformation(transition=transition_to_send_to_learner)) + + # assign obs to the next obs and continue the rollout + obs = next_obs + +@hydra.main(version_base="1.2", config_name="default", config_path="../../configs") +def actor_cli(cfg: dict): + server_thread = Thread(target=serve_actor_service, args=(50051,), daemon=True) + server_thread.start() + policy_thread = Thread(target=act_with_policy, + daemon=True, + args=(cfg,hydra.core.hydra_config.HydraConfig.get().run.dir, hydra.core.hydra_config.HydraConfig.get().job.name)) + policy_thread.start() + policy_thread.join() + server_thread.join() + +if __name__ == "__main__": + with open("traceback.log", "w") as f: + faulthandler.register(signal.SIGUSR1, file=f) + + actor_cli() \ No newline at end of file diff --git a/lerobot/scripts/server/hilserl.proto b/lerobot/scripts/server/hilserl.proto new file mode 100644 index 00000000..41f85100 --- /dev/null +++ b/lerobot/scripts/server/hilserl.proto @@ -0,0 +1,42 @@ +syntax = "proto3"; + +package hil_serl; + +// LearnerService: the Actor calls this to push transitions. +// The Learner implements this service. +service LearnerService { + // Actor -> Learner to store transitions + rpc SendTransition(Transition) returns (Empty); + rpc SendInteractionMessage(InteractionMessage) returns (Empty); +} + +// ActorService: the Learner calls this to push parameters. +// The Actor implements this service. 
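+// StreamTransition is a server-streaming RPC: the Learner opens the stream and the
+// Actor pushes Transition and InteractionMessage payloads back over it, while
+// SendParameters lets the Learner push updated policy parameters to the Actor.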
+service ActorService {
+  // Actor -> Learner: stream of transitions and interaction messages
+  rpc StreamTransition(Empty) returns (stream ActorInformation) {};
+  // Learner -> Actor: push new parameters
+  rpc SendParameters(Parameters) returns (Empty);
+}
+
+
+message ActorInformation {
+  oneof data {
+    Transition transition = 1;
+    InteractionMessage interaction_message = 2;
+  }
+}
+
+// Messages
+message Transition {
+  bytes transition_bytes = 1;
+}
+
+message Parameters {
+  bytes parameter_bytes = 1;
+}
+
+message InteractionMessage {
+  bytes interaction_message_bytes = 1;
+}
+
+message Empty {}
\ No newline at end of file
diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py
new file mode 100644
index 00000000..22777a26
--- /dev/null
+++ b/lerobot/scripts/server/learner_server.py
@@ -0,0 +1,394 @@
+import grpc
+from concurrent import futures
+import functools
+import logging
+import queue
+import pickle
+import torch
+import torch.nn.functional as F
+import io
+import time
+
+from pprint import pformat
+import random
+from typing import Optional, Sequence, TypedDict, Callable
+
+import hydra
+from torch import nn
+from tqdm import tqdm
+from deepdiff import DeepDiff
+from omegaconf import DictConfig, OmegaConf
+from threading import Thread, Lock
+
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+
+# TODO: Remove the import of maniskill
+from lerobot.common.datasets.factory import make_dataset
+from lerobot.common.logger import Logger, log_output_dir
+from lerobot.common.policies.factory import make_policy
+from lerobot.common.policies.sac.modeling_sac import SACPolicy
+from lerobot.common.policies.utils import get_device_from_parameters
+from lerobot.common.utils.utils import (
+    format_big_number,
+    get_safe_torch_device,
+    init_hydra_config,
+    init_logging,
+    set_global_seed,
+)
+from lerobot.scripts.server.buffer import ReplayBuffer, move_transition_to_device, concatenate_batch_transitions, move_state_dict_to_device, Transition
+
+# Import generated stubs
+import hilserl_pb2
+import hilserl_pb2_grpc
+
+logging.basicConfig(level=logging.INFO)
+
+
+# TODO: Implement this in a cleaner way
+transition_queue = queue.Queue()
+interaction_message_queue = queue.Queue()
+
+
+# 1) Implement the LearnerService so the Actor can send transitions here.
+class LearnerServiceServicer(hilserl_pb2_grpc.LearnerServiceServicer):
+    # def SendTransition(self, request, context):
+    #     """
+    #     Actor calls this method to push a Transition -> Learner.
+    #     """
+    #     buffer = io.BytesIO(request.transition_bytes)
+    #     transition = torch.load(buffer)
+    #     transition_queue.put(transition)
+    #     return hilserl_pb2.Empty()
+    def SendInteractionMessage(self, request, context):
+        """
+        Actor calls this method to push an InteractionMessage -> Learner.
+        """
+        content = pickle.loads(request.interaction_message_bytes)
+        interaction_message_queue.put(content)
+        return hilserl_pb2.Empty()
+
+
+def stream_transitions_from_actor(port=50051):
+    """
+    Open a gRPC client channel to the Actor and consume its StreamTransition stream.
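+
+    Transitions received on the stream are pushed onto transition_queue and interaction
+    messages onto interaction_message_queue, where the learner thread consumes them.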
+    """
+    time.sleep(10)
+    channel = grpc.insecure_channel(f'127.0.0.1:{port}',
+                                    options=[('grpc.max_send_message_length', -1),
+                                             ('grpc.max_receive_message_length', -1)])
+    stub = hilserl_pb2_grpc.ActorServiceStub(channel)
+    for response in stub.StreamTransition(hilserl_pb2.Empty()):
+        if response.HasField('transition'):
+            buffer = io.BytesIO(response.transition.transition_bytes)
+            transition = torch.load(buffer)
+            transition_queue.put(transition)
+        if response.HasField('interaction_message'):
+            content = pickle.loads(response.interaction_message.interaction_message_bytes)
+            interaction_message_queue.put(content)
+        # NOTE: Brief sleep to relieve CPU pressure; removing it creates a severe bottleneck.
+        time.sleep(0.001)
+
+def learner_push_parameters(
+    policy: nn.Module, policy_lock: Lock, actor_host="127.0.0.1", actor_port=50052, seconds_between_pushes=5
+):
+    """
+    As a client, connect to the Actor's gRPC server (ActorService)
+    and periodically push new parameters.
+    """
+    time.sleep(10)
+    # The Actor's server is presumably listening on a different port, e.g. 50052
+    channel = grpc.insecure_channel(f"{actor_host}:{actor_port}",
+                                    options=[('grpc.max_send_message_length', -1),
+                                             ('grpc.max_receive_message_length', -1)])
+    actor_stub = hilserl_pb2_grpc.ActorServiceStub(channel)
+
+    while True:
+        with policy_lock:
+            params_dict = policy.actor.state_dict()
+        params_dict = move_state_dict_to_device(params_dict, device="cpu")
+        # Serialize
+        buf = io.BytesIO()
+        torch.save(params_dict, buf)
+        params_bytes = buf.getvalue()
+
+        # Push them to the Actor's "SendParameters" method
+        response = actor_stub.SendParameters(hilserl_pb2.Parameters(parameter_bytes=params_bytes))
+        time.sleep(seconds_between_pushes)
+
+
+# Checked
+def add_actor_information(
+    cfg,
+    device,
+    replay_buffer: ReplayBuffer,
+    offline_replay_buffer: ReplayBuffer,
+    batch_size: int,
+    optimizers,
+    policy,
+    policy_lock: Lock,
+    buffer_lock: Lock,
+    offline_buffer_lock: Lock,
+    logger_lock: Lock,
+    logger: Logger,
+):
+    """
+    Learner-side worker loop: reads from the transition and interaction-message queues,
+    fills the replay buffer, and runs the gradient updates.
+    """
+    # NOTE: This function currently has more than one responsibility and should be split into
+    # multiple functions in the future. It is written as a single thread because of Python's GIL:
+    # spreading this work across several threads degraded performance by roughly a factor of 200,
+    # so a single thread does all the work.
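+    # Loop overview: drain transition_queue into the replay buffer, drain
+    # interaction_message_queue into the logger (logged against the custom
+    # "interaction_step" step key, see Logger.log_dict), then run the SAC updates
+    # (the update code below is currently commented out).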
+ start = time.time() + optimization_step = 0 + + while True: + time_for_adding_transitions = time.time() + while not transition_queue.empty(): + + transition_list = transition_queue.get() + for transition in transition_list: + transition = move_transition_to_device(transition, device=device) + replay_buffer.add(**transition) + logging.info(f"[LEARNER] size of replay buffer: {len(replay_buffer)}") + logging.info(f"[LEARNER] size of transition queues: {transition_queue.qsize()}") + + + while not interaction_message_queue.empty(): + interaction_message = interaction_message_queue.get() + logger.log_dict(interaction_message,mode="train",custom_step_key="interaction_step") + logging.info(f"[LEARNER] size of interaction message queue: {interaction_message_queue.qsize()}") + + # if len(replay_buffer.memory) < cfg.training.online_step_before_learning: + # continue + + # for _ in range(cfg.policy.utd_ratio - 1): + + # batch = replay_buffer.sample(batch_size) + # if cfg.dataset_repo_id is not None: + # batch_offline = offline_replay_buffer.sample(batch_size) + # batch = concatenate_batch_transitions(batch, batch_offline) + + # actions = batch["action"] + # rewards = batch["reward"] + # observations = batch["state"] + # next_observations = batch["next_state"] + # done = batch["done"] + + # with policy_lock: + # loss_critic = policy.compute_loss_critic( + # observations=observations, + # actions=actions, + # rewards=rewards, + # next_observations=next_observations, + # done=done, + # ) + # optimizers["critic"].zero_grad() + # loss_critic.backward() + # optimizers["critic"].step() + + # batch = replay_buffer.sample(batch_size) + + # if cfg.dataset_repo_id is not None: + # batch_offline = offline_replay_buffer.sample(batch_size) + # batch = concatenate_batch_transitions( + # left_batch_transitions=batch, right_batch_transition=batch_offline + # ) + + # actions = batch["action"] + # rewards = batch["reward"] + # observations = batch["state"] + # next_observations = batch["next_state"] + # done = batch["done"] + + # with policy_lock: + # loss_critic = policy.compute_loss_critic( + # observations=observations, + # actions=actions, + # rewards=rewards, + # next_observations=next_observations, + # done=done, + # ) + # optimizers["critic"].zero_grad() + # loss_critic.backward() + # optimizers["critic"].step() + + # training_infos = {} + # training_infos["loss_critic"] = loss_critic.item() + + # if optimization_step % cfg.training.policy_update_freq == 0: + # for _ in range(cfg.training.policy_update_freq): + # with policy_lock: + # loss_actor = policy.compute_loss_actor(observations=observations) + + # optimizers["actor"].zero_grad() + # loss_actor.backward() + # optimizers["actor"].step() + + # training_infos["loss_actor"] = loss_actor.item() + + # loss_temperature = policy.compute_loss_temperature(observations=observations) + # optimizers["temperature"].zero_grad() + # loss_temperature.backward() + # optimizers["temperature"].step() + + # training_infos["loss_temperature"] = loss_temperature.item() + + # if optimization_step % cfg.training.log_freq == 0: + # logger.log_dict(training_infos, step=optimization_step, mode="train") + + # policy.update_target_networks() + # optimization_step += 1 + # time_for_one_optimization_step = time.time() - time_for_one_optimization_step + + # logger.log_dict({"[LEARNER] Time optimization step":time_for_one_optimization_step}, step=optimization_step, mode="train") + # time_for_one_optimization_step = time.time() + + +def make_optimizers_and_scheduler(cfg, policy): + 
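+    """Create the Adam optimizers for the actor, the critic ensemble, and the temperature.
+
+    Returns a tuple (optimizers, lr_scheduler), where optimizers is a dict with keys
+    "actor", "critic", and "temperature", and lr_scheduler is None for now.
+    """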
+    optimizer_actor = torch.optim.Adam(
+        # NOTE: Handles the shared-encoder case, where the encoder weights are not optimized with the actor's gradient
+        params=policy.actor.parameters_to_optimize,
+        lr=policy.config.actor_lr,
+    )
+    optimizer_critic = torch.optim.Adam(
+        params=policy.critic_ensemble.parameters(), lr=policy.config.critic_lr
+    )
+    # We wrap the policy's log temperature in a list because it is a torch.Tensor, not an nn.Module
+    optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=policy.config.critic_lr)
+    lr_scheduler = None
+    optimizers = {
+        "actor": optimizer_actor,
+        "critic": optimizer_critic,
+        "temperature": optimizer_temperature,
+    }
+    return optimizers, lr_scheduler
+
+
+def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
+    if out_dir is None:
+        raise NotImplementedError()
+    if job_name is None:
+        raise NotImplementedError()
+
+    init_logging()
+    logging.info(pformat(OmegaConf.to_container(cfg)))
+
+    logger = Logger(cfg, out_dir, wandb_job_name=job_name)
+    logger_lock = Lock()
+
+    set_global_seed(cfg.seed)
+
+    device = get_safe_torch_device(cfg.device, log=True)
+
+    torch.backends.cudnn.benchmark = True
+    torch.backends.cuda.matmul.allow_tf32 = True
+
+    logging.info("make_policy")
+
+    ### Instantiate the policy in both the actor and learner processes.
+    ### To avoid sending a SACPolicy object over the network, we create a policy instance
+    ### on each side; the learner then sends updated parameters every n steps to refresh the actor.
+    # TODO: At some point we should only need to make the SAC policy directly
+    policy_lock = Lock()
+    with logger_lock:
+        policy: SACPolicy = make_policy(
+            hydra_cfg=cfg,
+            # dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
+            # Hack: in online training we do not need dataset_stats
+            dataset_stats=None,
+            pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
+            device=device,
+        )
+    assert isinstance(policy, nn.Module)
+
+    optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg, policy)
+
+    # TODO: Handle resume
+    num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
+    num_total_params = sum(p.numel() for p in policy.parameters())
+
+    log_output_dir(out_dir)
+    logging.info(f"{cfg.env.task=}")
+    logging.info(f"{cfg.training.online_steps=}")
+    logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
+    logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
+
+    buffer_lock = Lock()
+    replay_buffer = ReplayBuffer(
+        capacity=cfg.training.online_buffer_capacity, device=device, state_keys=cfg.policy.input_shapes.keys()
+    )
+
+    batch_size = cfg.training.batch_size
+    offline_buffer_lock = None
+    offline_replay_buffer = None
+    if cfg.dataset_repo_id is not None:
+        logging.info("make_dataset offline buffer")
+        offline_dataset = make_dataset(cfg)
+        logging.info("Converting to an offline replay buffer")
+        offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
+            offline_dataset, device=device, state_keys=cfg.policy.input_shapes.keys()
+        )
+        offline_buffer_lock = Lock()
+        batch_size: int = batch_size // 2  # We will sample from both replay buffers
+
+    server_thread = Thread(target=stream_transitions_from_actor, args=(50051,), daemon=True)
+    server_thread.start()
+
+    # Start a background thread to process transitions from the queue
+    transition_thread = Thread(
+        target=add_actor_information,
+        daemon=True,
+        args=(cfg,
+              device,
+              replay_buffer,
offline_replay_buffer, + batch_size, + optimizers, + policy, + policy_lock, + buffer_lock, + offline_buffer_lock, + logger_lock, + logger), + ) + transition_thread.start() + + # param_push_thread = Thread( + # target=learner_push_parameters, + # args=(policy, policy_lock, "127.0.0.1", 50052, 15), + # # args=("127.0.0.1", 50052), + # daemon=True, + # ) + # param_push_thread.start() + + # interaction_thread = Thread( + # target=add_message_interaction_to_wandb, + # daemon=True, + # args=(cfg, logger, logger_lock), + # ) + # interaction_thread.start() + + transition_thread.join() + # param_push_thread.join() + server_thread.join() + # interaction_thread.join() + + +@hydra.main(version_base="1.2", config_name="default", config_path="../../configs") +def train_cli(cfg: dict): + train( + cfg, + out_dir=hydra.core.hydra_config.HydraConfig.get().run.dir, + job_name=hydra.core.hydra_config.HydraConfig.get().job.name, + ) + + +if __name__ == "__main__": + train_cli() diff --git a/lerobot/scripts/train_sac.py b/lerobot/scripts/train_sac.py index 866415d0..936d65ee 100644 --- a/lerobot/scripts/train_sac.py +++ b/lerobot/scripts/train_sac.py @@ -177,6 +177,7 @@ class ReplayBuffer: ) self.position: int = (self.position + 1) % self.capacity + # TODO: ADD image_augmentation and use_drq arguments in this function in order to instantiate the class with them @classmethod def from_lerobot_dataset( cls,
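# Illustrative sketch only (not part of the patch): one possible signature for the TODO above.
# The argument names come from the TODO comment; the types and defaults are assumptions.
#
#     @classmethod
#     def from_lerobot_dataset(
#         cls,
#         lerobot_dataset: LeRobotDataset,
#         device: str = "cuda:0",
#         state_keys: Sequence[str] | None = None,
#         image_augmentation: Callable | None = None,  # hypothetical
#         use_drq: bool = False,                        # hypothetical
#     ) -> "ReplayBuffer":
#         ...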