diff --git a/lerobot/common/logger.py b/lerobot/common/logger.py
index 569cad69..6dc724db 100644
--- a/lerobot/common/logger.py
+++ b/lerobot/common/logger.py
@@ -174,18 +174,32 @@ class Logger:
         self,
         save_dir: Path,
         train_step: int,
-        optimizer: Optimizer,
+        optimizer: Optimizer | dict,
         scheduler: LRScheduler | None,
+        interaction_step: int | None = None,
     ):
         """Checkpoint the global training_step, optimizer state, scheduler state, and random state.

         All of these are saved as "training_state.pth" under the checkpoint directory.
         """
+        # In SAC, for example, we have a dictionary of torch.optim.Optimizer instances.
+        if type(optimizer) is dict:
+            optimizer_state_dict = {}
+            for k in optimizer:
+                optimizer_state_dict[k] = optimizer[k].state_dict()
+        else:
+            optimizer_state_dict = optimizer.state_dict()
+
         training_state = {
             "step": train_step,
-            "optimizer": optimizer.state_dict(),
+            "optimizer": optimizer_state_dict,
             **get_global_random_state(),
         }
+        # The interaction step is specific to the distributed training code.
+        # In that setup there are two kinds of steps: the online interaction step with the env and the optimization step.
+        # We need to save both in order to resume the optimization properly and not break the logs that depend on the interaction step.
+        if interaction_step is not None:
+            training_state["interaction_step"] = interaction_step
         if scheduler is not None:
             training_state["scheduler"] = scheduler.state_dict()
         torch.save(training_state, save_dir / self.training_state_file_name)
@@ -197,6 +211,7 @@ class Logger:
         optimizer: Optimizer,
         scheduler: LRScheduler | None,
         identifier: str,
+        interaction_step: int | None = None,
     ):
         """Checkpoint the model weights and the training state."""
         checkpoint_dir = self.checkpoints_dir / str(identifier)
@@ -208,16 +223,24 @@ class Logger:
         self.save_model(
             checkpoint_dir / self.pretrained_model_dir_name, policy, wandb_artifact_name=wandb_artifact_name
         )
-        self.save_training_state(checkpoint_dir, train_step, optimizer, scheduler)
+        self.save_training_state(checkpoint_dir, train_step, optimizer, scheduler, interaction_step)
         os.symlink(checkpoint_dir.absolute(), self.last_checkpoint_dir)

-    def load_last_training_state(self, optimizer: Optimizer, scheduler: LRScheduler | None) -> int:
+    def load_last_training_state(self, optimizer: Optimizer | dict, scheduler: LRScheduler | None) -> int:
         """
         Given the last checkpoint in the logging directory, load the optimizer state, scheduler state, and
         random state, and return the global training step.
         """
         training_state = torch.load(self.last_checkpoint_dir / self.training_state_file_name)
-        optimizer.load_state_dict(training_state["optimizer"])
+        # For the case where the optimizer is a dictionary of optimizers (e.g., SAC)
+        if type(training_state["optimizer"]) is dict:
+            assert set(training_state["optimizer"].keys()) == set(optimizer.keys()), (
+                "Optimizer dictionaries do not have the same keys during resume!"
+            )
+            for k, v in training_state["optimizer"].items():
+                optimizer[k].load_state_dict(v)
+        else:
+            optimizer.load_state_dict(training_state["optimizer"])
         if scheduler is not None:
             scheduler.load_state_dict(training_state["scheduler"])
         elif "scheduler" in training_state:
@@ -228,7 +251,7 @@ class Logger:
         set_global_random_state({k: training_state[k] for k in get_global_random_state()})
         return training_state["step"]

-    def log_dict(self, d, step:int | None = None, mode="train", custom_step_key: str | None = None):
+    def log_dict(self, d, step: int | None = None, mode="train", custom_step_key: str | None = None):
         """Log a dictionary of metrics to WandB."""
         assert mode in {"train", "eval"}
         # TODO(alexander-soare): Add local text log.
@@ -236,10 +259,9 @@ class Logger:
             raise ValueError("Either step or custom_step_key must be provided.")

         if self._wandb is not None:
-
-            # NOTE: This is not simple. Wandb step is it must always monotonically increase and it
+            # NOTE: This is not simple. The WandB step must always monotonically increase and it
             # increases with each wandb.log call, but in the case of asynchronous RL for example,
-            # multiple time steps is possible for example, the interaction step with the environment,
+            # multiple time steps are possible, for example the interaction step with the environment,
             # the training step, the evaluation step, etc. So we need to define a custom step key
             # to log the correct step for each metric.
             if custom_step_key is not None and self._wandb_custom_step_key is None:
@@ -247,7 +269,7 @@ class Logger:
                 # custom step.
                 self._wandb_custom_step_key = f"{mode}/{custom_step_key}"
                 self._wandb.define_metric(self._wandb_custom_step_key, hidden=True)
-
+
             for k, v in d.items():
                 if not isinstance(v, (int, float, str, wandb.Table)):
                     logging.warning(
@@ -267,8 +289,6 @@ class Logger:

                 self._wandb.log({f"{mode}/{k}": v}, step=step)

-
-
     def log_video(self, video_path: str, step: int, mode: str = "train"):
         assert mode in {"train", "eval"}
         assert self._wandb is not None
diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py
index 8567313d..64688b1b 100644
--- a/lerobot/common/policies/sac/modeling_sac.py
+++ b/lerobot/common/policies/sac/modeling_sac.py
@@ -29,6 +29,7 @@ from torch import Tensor

 from lerobot.common.policies.normalize import Normalize, Unnormalize
 from lerobot.common.policies.sac.configuration_sac import SACConfig
+from lerobot.common.policies.utils import get_device_from_parameters


 class SACPolicy(
@@ -44,7 +45,6 @@ class SACPolicy(
         self,
         config: SACConfig | None = None,
         dataset_stats: dict[str, dict[str, Tensor]] | None = None,
-        device: str = "cpu",
     ):
         super().__init__()

@@ -92,7 +92,6 @@ class SACPolicy(
                     for _ in range(config.num_critics)
                 ]
             ),
-            device=device,
         )

         self.critic_target = CriticEnsemble(
@@ -106,7 +105,6 @@ class SACPolicy(
                     for _ in range(config.num_critics)
                 ]
             ),
-            device=device,
         )

         self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
@@ -115,7 +113,6 @@ class SACPolicy(
             encoder=encoder_actor,
             network=MLP(input_dim=encoder_actor.output_dim, **config.actor_network_kwargs),
             action_dim=config.output_shapes["action"][0],
-            device=device,
             encoder_is_shared=config.shared_encoder,
             **config.policy_kwargs,
         )
@@ -123,13 +120,22 @@ class SACPolicy(
             config.target_entropy = -np.prod(config.output_shapes["action"][0]) / 2  # (-dim(A)/2)

         # TODO (azouitine): Handle the case where the temparameter is a fixed
-        self.log_alpha = torch.zeros(1, requires_grad=True, device=device)
+        # TODO (michel-aractingi): Put log_alpha on CUDA by default, because otherwise
+        # it triggers "can't optimize a non-leaf Tensor"
+        self.log_alpha = torch.zeros(1, requires_grad=True, device=torch.device("cuda:0"))
         self.temperature = self.log_alpha.exp().item()

     def reset(self):
         """Reset the policy"""
         pass

+    def to(self, *args, **kwargs):
+        """Override .to(device) so that log_alpha and the actor's fixed_std are moved as well."""
+        if self.actor.fixed_std is not None:
+            self.actor.fixed_std = self.actor.fixed_std.to(*args, **kwargs)
+        self.log_alpha = self.log_alpha.to(*args, **kwargs)
+        return super().to(*args, **kwargs)
+
     @torch.no_grad()
     def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         """Select action for inference/evaluation"""
@@ -308,17 +314,12 @@ class CriticEnsemble(nn.Module):
         encoder: Optional[nn.Module],
         network_list: nn.Module,
         init_final: Optional[float] = None,
-        device: str = "cpu",
     ):
         super().__init__()
-        self.device = torch.device(device)
         self.encoder = encoder
         self.network_list = network_list
         self.init_final = init_final

-        # for network in network_list:
-        #     network.to(self.device)
-
         # Find the last Linear layer's output dimension
         for layer in reversed(network_list[0].net):
             if isinstance(layer, nn.Linear):
@@ -329,29 +330,28 @@ class CriticEnsemble(nn.Module):
         self.output_layers = []
         if init_final is not None:
             for _ in network_list:
-                output_layer = nn.Linear(out_features, 1, device=device)
+                output_layer = nn.Linear(out_features, 1)
                 nn.init.uniform_(output_layer.weight, -init_final, init_final)
                 nn.init.uniform_(output_layer.bias, -init_final, init_final)
                 self.output_layers.append(output_layer)
         else:
             self.output_layers = []
             for _ in network_list:
-                output_layer = nn.Linear(out_features, 1, device=device)
+                output_layer = nn.Linear(out_features, 1)
                 orthogonal_init()(output_layer.weight)
                 self.output_layers.append(output_layer)

         self.output_layers = nn.ModuleList(self.output_layers)

-        self.to(self.device)
-
     def forward(
         self,
         observations: dict[str, torch.Tensor],
         actions: torch.Tensor,
     ) -> torch.Tensor:
+        device = get_device_from_parameters(self)
         # Move each tensor in observations to device
-        observations = {k: v.to(self.device) for k, v in observations.items()}
-        actions = actions.to(self.device)
+        observations = {k: v.to(device) for k, v in observations.items()}
+        actions = actions.to(device)

         obs_enc = observations if self.encoder is None else self.encoder(observations)

@@ -375,17 +375,15 @@ class Policy(nn.Module):
         fixed_std: Optional[torch.Tensor] = None,
         init_final: Optional[float] = None,
         use_tanh_squash: bool = False,
-        device: str = "cpu",
         encoder_is_shared: bool = False,
     ):
         super().__init__()
-        self.device = torch.device(device)
         self.encoder = encoder
         self.network = network
         self.action_dim = action_dim
         self.log_std_min = log_std_min
         self.log_std_max = log_std_max
-        self.fixed_std = fixed_std.to(self.device) if fixed_std is not None else None
+        self.fixed_std = fixed_std
         self.use_tanh_squash = use_tanh_squash
         self.parameters_to_optimize = []

@@ -417,8 +415,6 @@ class Policy(nn.Module):
             orthogonal_init()(self.std_layer.weight)
             self.parameters_to_optimize += list(self.std_layer.parameters())

-        self.to(self.device)
-
     def forward(
         self,
         observations: torch.Tensor,
@@ -460,7 +456,8 @@ class Policy(nn.Module):

     def get_features(self, observations: torch.Tensor) -> torch.Tensor:
         """Get encoded features from observations"""
-        observations = observations.to(self.device)
+        device = get_device_from_parameters(self)
+        observations = observations.to(device)
         if self.encoder is not None:
             with torch.inference_mode():
                 return self.encoder(observations)
diff --git a/lerobot/configs/policy/sac_maniskill.yaml b/lerobot/configs/policy/sac_maniskill.yaml
index 59f42247..fa3dca37 100644
--- a/lerobot/configs/policy/sac_maniskill.yaml
+++ b/lerobot/configs/policy/sac_maniskill.yaml
@@ -8,7 +8,7 @@
 #    env.gym.obs_type=environment_state_agent_pos \

 seed: 1
-dataset_repo_id: null
+dataset_repo_id: aractingi/hil-serl-maniskill-pushcube

 training:
   # Offline training dataloader
@@ -21,7 +21,7 @@ training:

   eval_freq: 2500
   log_freq: 500
-  save_freq: 50000
+  save_freq: 1000000

   online_steps: 1000000
   online_rollout_n_episodes: 10
diff --git a/lerobot/scripts/server/actor_server.py b/lerobot/scripts/server/actor_server.py
index 0d2a1a5e..294f07a6 100644
--- a/lerobot/scripts/server/actor_server.py
+++ b/lerobot/scripts/server/actor_server.py
@@ -152,7 +152,7 @@ def serve_actor_service(port=50052):
     server.wait_for_termination()


-def act_with_policy(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
+def act_with_policy(cfg: DictConfig):
     """
     Executes policy interaction within the environment.

     Args:
         cfg (DictConfig): Configuration settings for the interaction process.
-        out_dir (Optional[str]): Directory to store output logs or results. Defaults to None.
-        job_name (Optional[str]): Name of the job for logging or tracking purposes. Defaults to None.
     """

     logging.info("make_env online")
@@ -189,9 +187,10 @@ def act_with_policy(cfg: DictConfig, out_dir: str | None = None, job_name: str |
         # Hack: But if we do online training, we do not need dataset_stats
         dataset_stats=None,
         # TODO: Handle resume training
-        pretrained_policy_name_or_path=None,
-        device=device,
     )
+    # pretrained_policy_name_or_path=None,
+    # device=device,
+    # )
     assert isinstance(policy, nn.Module)

     # HACK for maniskill
@@ -295,11 +294,7 @@ def actor_cli(cfg: dict):
     policy_thread = Thread(
         target=act_with_policy,
         daemon=True,
-        args=(
-            cfg,
-            hydra.core.hydra_config.HydraConfig.get().run.dir,
-            hydra.core.hydra_config.HydraConfig.get().job.name,
-        ),
+        args=(cfg,),
     )
     policy_thread.start()
     policy_thread.join()
diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py
index a9375972..cf6d8c76 100644
--- a/lerobot/scripts/server/learner_server.py
+++ b/lerobot/scripts/server/learner_server.py
@@ -18,6 +18,7 @@ import io
 import logging
 import pickle
 import queue
+import shutil
 import time
 from pprint import pformat
 from threading import Lock, Thread
@@ -29,18 +30,25 @@ import hilserl_pb2  # type: ignore
 import hilserl_pb2_grpc  # type: ignore
 import hydra
 import torch
+from deepdiff import DeepDiff
 from omegaconf import DictConfig, OmegaConf
+from termcolor import colored
 from torch import nn

-# TODO: Remove the import of maniskill
 from lerobot.common.datasets.factory import make_dataset
+
+# TODO: Remove the import of maniskill
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
 from lerobot.common.logger import Logger, log_output_dir
 from lerobot.common.policies.factory import make_policy
 from lerobot.common.policies.sac.modeling_sac import SACPolicy
 from lerobot.common.utils.utils import (
     format_big_number,
+    get_global_random_state,
     get_safe_torch_device,
+    init_hydra_config,
     init_logging,
+    set_global_random_state,
     set_global_seed,
 )
 from lerobot.scripts.server.buffer import (
@@ -127,10 +135,9 @@ def add_actor_information_and_train(
     optimizers: dict[str, torch.optim.Optimizer],
     policy: nn.Module,
     policy_lock: Lock,
-    buffer_lock: Lock,
-    offline_buffer_lock: Lock,
-    logger_lock: Lock,
     logger: Logger,
+    resume_optimization_step: int | None = None,
+    resume_interaction_step: int | None = None,
 ):
     """
     Handles data transfer from the actor to the learner, manages training updates,
@@ -159,16 +166,17 @@
         optimizers (Dict[str, torch.optim.Optimizer]): A dictionary of optimizers (`"actor"`, `"critic"`, `"temperature"`).
         policy (nn.Module): The reinforcement learning policy with critic, actor, and temperature parameters.
         policy_lock (Lock): A threading lock to ensure safe policy updates.
-        buffer_lock (Lock): A threading lock to safely access the online replay buffer.
-        offline_buffer_lock (Lock): A threading lock to safely access the offline replay buffer.
-        logger_lock (Lock): A threading lock to safely log training metrics.
         logger (Logger): Logger instance for tracking training progress.
+        resume_optimization_step (int | None): When resuming training, start from the last optimization step reached.
+        resume_interaction_step (int | None): When resuming training, shift the interaction step by the last saved step so that logging is not broken.
     """
     # NOTE: This function doesn't have a single responsibility, it should be split into multiple functions
     # in the future. The reason why we did that is the GIL in Python. It's super slow the performance
     # are divided by 200. So we need to have a single thread that does all the work.
     time.time()
-    optimization_step = 0
+    interaction_message, transition = None, None
+    optimization_step = resume_optimization_step if resume_optimization_step is not None else 0
+    interaction_step_shift = resume_interaction_step if resume_interaction_step is not None else 0
     while True:
         while not transition_queue.empty():
             transition_list = transition_queue.get()
@@ -178,6 +186,8 @@

         while not interaction_message_queue.empty():
             interaction_message = interaction_message_queue.get()
+            # If cfg.resume, shift the interaction step by the last checkpointed step in order to not break the logging
+            interaction_message["Interaction step"] += interaction_step_shift
             logger.log_dict(interaction_message, mode="train", custom_step_key="Interaction step")

         if len(replay_buffer) < cfg.training.online_step_before_learning:
@@ -186,9 +196,9 @@
         for _ in range(cfg.policy.utd_ratio - 1):
             batch = replay_buffer.sample(batch_size)

-            if cfg.dataset_repo_id is not None:
-                batch_offline = offline_replay_buffer.sample(batch_size)
-                batch = concatenate_batch_transitions(batch, batch_offline)
+            # if cfg.offline_dataset_repo_id is not None:
+            #     batch_offline = offline_replay_buffer.sample(batch_size)
+            #     batch = concatenate_batch_transitions(batch, batch_offline)

             actions = batch["action"]
             rewards = batch["reward"]
@@ -210,11 +220,11 @@

         batch = replay_buffer.sample(batch_size)

-        if cfg.dataset_repo_id is not None:
-            batch_offline = offline_replay_buffer.sample(batch_size)
-            batch = concatenate_batch_transitions(
-                left_batch_transitions=batch, right_batch_transition=batch_offline
-            )
+        # if cfg.offline_dataset_repo_id is not None:
+        #     batch_offline = offline_replay_buffer.sample(batch_size)
+        #     batch = concatenate_batch_transitions(
+        #         left_batch_transitions=batch, right_batch_transition=batch_offline
+        #     )

         actions = batch["action"]
         rewards = batch["reward"]
@@ -274,6 +284,39 @@
         if optimization_step % cfg.training.log_freq == 0:
             logging.info(f"[LEARNER] Number of optimization step: {optimization_step}")

+        if cfg.training.save_checkpoint and (
+            optimization_step % cfg.training.save_freq == 0 or optimization_step == cfg.training.online_steps
+        ):
+            logging.info(f"Checkpoint policy after step {optimization_step}")
+            # Note: Save with step as the identifier, and format it to have at least 6 digits but more if
+            # needed (choose 6 as a minimum for consistency without being overkill).
+            _num_digits = max(6, len(str(cfg.training.online_steps)))
+            step_identifier = f"{optimization_step:0{_num_digits}d}"
+            interaction_step = (
+                interaction_message["Interaction step"] if interaction_message is not None else 0
+            )
+            logger.save_checkpoint(
+                optimization_step,
+                policy,
+                optimizers,
+                scheduler=None,
+                identifier=step_identifier,
+                interaction_step=interaction_step,
+            )
+
+            # TODO: temporarily save the replay buffer here; remove later when on the robot.
+            # We want to control this with the keyboard inputs.
+            dataset_dir = logger.log_dir / "dataset"
+            if dataset_dir.exists() and dataset_dir.is_dir():
+                shutil.rmtree(
+                    dataset_dir,
+                )
+            replay_buffer.to_lerobot_dataset(
+                cfg.dataset_repo_id, fps=cfg.fps, root=logger.log_dir / "dataset"
+            )
+
+            logging.info("Resume training")
+

 def make_optimizers_and_scheduler(cfg, policy: nn.Module):
     """
@@ -330,7 +373,49 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
     logging.info(pformat(OmegaConf.to_container(cfg)))

     logger = Logger(cfg, out_dir, wandb_job_name=job_name)
-    logger_lock = Lock()
+
+    ## Handle resume by reloading the state of the policy and optimization.
+    # If we are resuming a run, we need to check that a checkpoint exists in the log directory, and we need
+    # to check for any differences between the provided config and the checkpoint's config.
+    if cfg.resume:
+        if not Logger.get_last_checkpoint_dir(out_dir).exists():
+            raise RuntimeError(
+                "You have set resume=True, but there is no model checkpoint in "
+                f"{Logger.get_last_checkpoint_dir(out_dir)}"
+            )
+        checkpoint_cfg_path = str(Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml")
+        logging.info(
+            colored(
+                "You have set resume=True, indicating that you wish to resume a run",
+                color="yellow",
+                attrs=["bold"],
+            )
+        )
+        # Get the configuration file from the last checkpoint.
+        checkpoint_cfg = init_hydra_config(checkpoint_cfg_path)
+        # Check for differences between the checkpoint configuration and the provided configuration.
+        diff = DeepDiff(OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg))
+        # Ignore the `resume` parameter.
+        if "values_changed" in diff and "root['resume']" in diff["values_changed"]:
+            del diff["values_changed"]["root['resume']"]
+
+        # Log a warning about differences between the checkpoint configuration and the provided
+        # configuration.
+        if len(diff) > 0:
+            logging.warning(
+                "At least one difference was detected between the checkpoint configuration and "
+                f"the provided configuration: \n{pformat(diff)}\nNote that the checkpoint configuration "
+                "takes precedence.",
+            )
+        # Use the checkpoint config instead of the provided config (but keep the `resume` parameter).
+        cfg = checkpoint_cfg
+        cfg.resume = True
+    elif Logger.get_last_checkpoint_dir(out_dir).exists():
+        raise RuntimeError(
+            f"The configured output directory {Logger.get_last_checkpoint_dir(out_dir)} already exists. If "
+            "you meant to resume training, please use `resume=true` in your command or yaml configuration."
+        )
+
+    # ===========================

     set_global_seed(cfg.seed)

@@ -346,20 +431,38 @@
     ### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters
     # TODO: At some point we should just need make sac policy
     policy_lock = Lock()
-    with logger_lock:
-        policy: SACPolicy = make_policy(
-            hydra_cfg=cfg,
-            # dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
-            # Hack: But if we do online traning, we do not need dataset_stats
-            dataset_stats=None,
-            pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
-            device=device,
-        )
+    policy: SACPolicy = make_policy(
+        hydra_cfg=cfg,
+        # dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
+        # Hack: But if we do online training, we do not need dataset_stats
+        dataset_stats=None,
+        pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
+    )
+    # device=device,
+    # )
     assert isinstance(policy, nn.Module)

     optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg, policy)
+    # Load the last training state.
+    # We can't use the logger function in `lerobot/common/logger.py` here because it only loads the
+    # optimization step and not the interaction one; to avoid altering that code, we just load the
+    # optimization state manually.
+    resume_interaction_step, resume_optimization_step = None, None
+    if cfg.resume:
+        training_state = torch.load(logger.last_checkpoint_dir / logger.training_state_file_name)
+        if type(training_state["optimizer"]) is dict:
+            assert set(training_state["optimizer"].keys()) == set(optimizers.keys()), (
+                "Optimizer dictionaries do not have the same keys during resume!"
+            )
+            for k, v in training_state["optimizer"].items():
+                optimizers[k].load_state_dict(v)
+        else:
+            optimizers.load_state_dict(training_state["optimizer"])
+        # Small hack to get the expected keys: use `get_global_random_state`.
+        set_global_random_state({k: training_state[k] for k in get_global_random_state()})
+        resume_optimization_step = training_state["step"]
+        resume_interaction_step = training_state["interaction_step"]

-    # TODO: Handle resume
     num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
     num_total_params = sum(p.numel() for p in policy.parameters())
@@ -369,24 +472,34 @@
     logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
     logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")

-    buffer_lock = Lock()
-    replay_buffer = ReplayBuffer(
-        capacity=cfg.training.online_buffer_capacity, device=device, state_keys=cfg.policy.input_shapes.keys()
-    )
-
+    if not cfg.resume:
+        replay_buffer = ReplayBuffer(
+            capacity=cfg.training.online_buffer_capacity,
+            device=device,
+            state_keys=cfg.policy.input_shapes.keys(),
+        )
+    else:
+        # Reload replay buffer
+        dataset = LeRobotDataset(
+            repo_id=cfg.dataset_repo_id, local_files_only=True, root=logger.log_dir / "dataset"
+        )
+        replay_buffer = ReplayBuffer.from_lerobot_dataset(
+            lerobot_dataset=dataset,
+            capacity=cfg.training.online_buffer_capacity,
+            device=device,
+            state_keys=cfg.policy.input_shapes.keys(),
+        )
     batch_size = cfg.training.batch_size

-    offline_buffer_lock = None
     offline_replay_buffer = None
-    if cfg.dataset_repo_id is not None:
-        logging.info("make_dataset offline buffer")
-        offline_dataset = make_dataset(cfg)
-        logging.info("Convertion to a offline replay buffer")
-        offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
-            offline_dataset, device=device, state_keys=cfg.policy.input_shapes.keys()
-        )
-        offline_buffer_lock = Lock()
-        batch_size: int = batch_size // 2  # We will sample from both replay buffer
+    # if cfg.dataset_repo_id is not None:
+    #     logging.info("make_dataset offline buffer")
+    #     offline_dataset = make_dataset(cfg)
+    #     logging.info("Convertion to a offline replay buffer")
+    #     offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
+    #         offline_dataset, device=device, state_keys=cfg.policy.input_shapes.keys()
+    #     )
+    #     batch_size: int = batch_size // 2  # We will sample from both replay buffer

     actor_ip = cfg.actor_learner_config.actor_ip
     port = cfg.actor_learner_config.port
@@ -413,10 +526,9 @@
             optimizers,
             policy,
             policy_lock,
-            buffer_lock,
-            offline_buffer_lock,
-            logger_lock,
             logger,
+            resume_optimization_step,
+            resume_interaction_step,
         ),
     )
     transition_thread.start()
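
For illustration only, and not part of the patch above: a minimal standalone sketch of the dict-of-optimizers checkpoint round-trip that the new `save_training_state` and the learner's manual resume path rely on. The helper names and the simplified `training_state` layout below are assumptions made for this sketch, not code from the repository.

# Standalone sketch: save and reload a dict of named optimizers the way the patch does
# (one state_dict per optimizer, keyed by name, stored under "optimizer").
from pathlib import Path

import torch
from torch import nn


def save_training_state(save_dir: Path, step: int, optimizers: dict[str, torch.optim.Optimizer]) -> None:
    training_state = {
        "step": step,
        "optimizer": {name: opt.state_dict() for name, opt in optimizers.items()},
    }
    torch.save(training_state, save_dir / "training_state.pth")


def load_training_state(save_dir: Path, optimizers: dict[str, torch.optim.Optimizer]) -> int:
    training_state = torch.load(save_dir / "training_state.pth")
    # Same key check as in the patch: the saved and current optimizer dicts must match.
    assert set(training_state["optimizer"].keys()) == set(optimizers.keys()), (
        "Optimizer dictionaries do not have the same keys during resume!"
    )
    for name, state in training_state["optimizer"].items():
        optimizers[name].load_state_dict(state)
    return training_state["step"]


if __name__ == "__main__":
    # Two tiny modules stand in for the SAC actor and critic.
    actor, critic = nn.Linear(4, 2), nn.Linear(6, 1)
    optimizers = {
        "actor": torch.optim.Adam(actor.parameters(), lr=3e-4),
        "critic": torch.optim.Adam(critic.parameters(), lr=3e-4),
    }
    save_training_state(Path("."), step=100, optimizers=optimizers)
    print(load_training_state(Path("."), optimizers))  # prints: 100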