1184 lines
43 KiB
Python
1184 lines
43 KiB
Python
#!/usr/bin/env python
|
|
|
|
# Copyright 2024 The HuggingFace Inc. team.
|
|
# All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import logging
|
|
import os
|
|
import shutil
|
|
import time
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
from pathlib import Path
|
|
from pprint import pformat
|
|
|
|
import grpc
|
|
|
|
# Import generated stubs
|
|
import hilserl_pb2_grpc # type: ignore
|
|
import torch
|
|
from termcolor import colored
|
|
from torch import nn
|
|
from torch.multiprocessing import Queue
|
|
from torch.optim.optimizer import Optimizer
|
|
|
|
from lerobot.common.constants import (
|
|
CHECKPOINTS_DIR,
|
|
LAST_CHECKPOINT_LINK,
|
|
PRETRAINED_MODEL_DIR,
|
|
TRAINING_STATE_DIR,
|
|
)
|
|
from lerobot.common.datasets.factory import make_dataset
|
|
|
|
# TODO: Remove the import of maniskill
|
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
|
from lerobot.common.policies.factory import make_policy
|
|
from lerobot.common.policies.sac.modeling_sac import SACPolicy
|
|
from lerobot.common.utils.random_utils import set_seed
|
|
from lerobot.common.utils.train_utils import (
|
|
get_step_checkpoint_dir,
|
|
save_checkpoint,
|
|
update_last_checkpoint,
|
|
)
|
|
from lerobot.common.utils.train_utils import (
|
|
load_training_state as utils_load_training_state,
|
|
)
|
|
from lerobot.common.utils.utils import (
|
|
format_big_number,
|
|
get_safe_torch_device,
|
|
init_logging,
|
|
)
|
|
from lerobot.common.utils.wandb_utils import WandBLogger
|
|
from lerobot.configs import parser
|
|
from lerobot.configs.train import TrainPipelineConfig
|
|
from lerobot.scripts.server import learner_service
|
|
from lerobot.scripts.server.buffer import (
|
|
ReplayBuffer,
|
|
bytes_to_python_object,
|
|
bytes_to_transitions,
|
|
concatenate_batch_transitions,
|
|
move_state_dict_to_device,
|
|
move_transition_to_device,
|
|
state_to_bytes,
|
|
)
|
|
from lerobot.scripts.server.utils import setup_process_handlers
|
|
|
|
# Tag matching the "[LEARNER]" prefix that this module's log messages start with.
# NOTE(review): not referenced in the visible part of the file; the messages
# hard-code the literal instead.
LOG_PREFIX = "[LEARNER]"

# Set the root logger level to INFO at import time. `basicConfig` is a no-op
# when the root logger already has handlers configured.
logging.basicConfig(level=logging.INFO)
|
|
|
|
#################################################
|
|
# MAIN ENTRY POINTS AND CORE ALGORITHM FUNCTIONS #
|
|
#################################################
|
|
|
|
|
|
@parser.wrap()
def train_cli(cfg: TrainPipelineConfig):
    """
    Command-line entry point of the learner.

    When the configuration does not ask for threads, the torch multiprocessing
    start method is set to "spawn" before anything else runs. Training is then
    delegated to `train`, using the job name stored in the configuration.

    Args:
        cfg (TrainPipelineConfig): The training configuration supplied through `parser.wrap()`
    """
    threaded = use_threads(cfg)
    if not threaded:
        # Imported lazily: only the multi-process mode needs it.
        import torch.multiprocessing as torch_mp

        torch_mp.set_start_method("spawn")

    # The job name always comes from the configuration here.
    train(cfg, job_name=cfg.job_name)

    logging.info("[LEARNER] train_cli finished")
|
|
|
|
|
|
def train(cfg: TrainPipelineConfig, job_name: str | None = None):
    """
    Prepare logging, seeding and shutdown handling, then launch the learner.

    Args:
        cfg (TrainPipelineConfig): The training configuration
        job_name (str | None, optional): Job name used to name the learner log file.
            Falls back to `cfg.job_name` when None.

    Raises:
        ValueError: If neither the argument nor the configuration provides a job name.
    """
    cfg.validate()

    # Explicit argument first, configuration second.
    job_name = cfg.job_name if job_name is None else job_name
    if job_name is None:
        raise ValueError("Job name must be specified either in config or as a parameter")

    # The log directory has to exist before the log file is opened.
    logs_folder = os.path.join(cfg.output_dir, "logs")
    os.makedirs(logs_folder, exist_ok=True)
    log_file = os.path.join(logs_folder, f"learner_{job_name}.log")

    init_logging(log_file=log_file)
    logging.info(f"Learner logging initialized, writing to {log_file}")
    logging.info(pformat(cfg.to_dict()))

    # Metrics go to WandB only when it is enabled and a project is configured.
    wandb_logger = None
    if cfg.wandb.enable and cfg.wandb.project:
        from lerobot.common.utils.wandb_utils import WandBLogger

        wandb_logger = WandBLogger(cfg)
    else:
        logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))

    # From here on, use whatever configuration the resume logic hands back.
    cfg = handle_resume_logic(cfg)

    set_seed(seed=cfg.seed)

    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True

    stop_event = setup_process_handlers(use_threads(cfg))

    start_learner_threads(cfg=cfg, wandb_logger=wandb_logger, shutdown_event=stop_event)
|
|
|
|
|
|
def start_learner_threads(
    cfg: TrainPipelineConfig,
    wandb_logger: WandBLogger | None,
    shutdown_event: any,  # Event,
) -> None:
    """
    Launch the communication worker, run the training loop, then tear everything down.

    The communication worker (`start_learner_server`) runs as a daemon thread or
    a daemon torch process, depending on `use_threads(cfg)`. The training loop
    itself runs in the calling thread.

    Args:
        cfg (TrainPipelineConfig): Training configuration
        wandb_logger (WandBLogger | None): Logger for metrics
        shutdown_event: Event to signal shutdown
    """
    # Queues shared between the communication worker and the training loop.
    transition_queue, interaction_message_queue, parameters_queue = Queue(), Queue(), Queue()

    # Pick the worker flavour: thread or process.
    if use_threads(cfg):
        from threading import Thread as worker_cls
    else:
        from torch.multiprocessing import Process as worker_cls

    communication_process = worker_cls(
        target=start_learner_server,
        args=(parameters_queue, transition_queue, interaction_message_queue, shutdown_event, cfg),
        daemon=True,
    )
    communication_process.start()

    # Returns only once the training loop has exited.
    add_actor_information_and_train(
        cfg=cfg,
        wandb_logger=wandb_logger,
        shutdown_event=shutdown_event,
        transition_queue=transition_queue,
        interaction_message_queue=interaction_message_queue,
        parameters_queue=parameters_queue,
    )
    logging.info("[LEARNER] Training process stopped")

    all_queues = (transition_queue, interaction_message_queue, parameters_queue)

    logging.info("[LEARNER] Closing queues")
    for shared_queue in all_queues:
        shared_queue.close()

    communication_process.join()
    logging.info("[LEARNER] Communication process joined")

    logging.info("[LEARNER] join queues")
    for shared_queue in all_queues:
        shared_queue.cancel_join_thread()

    logging.info("[LEARNER] queues closed")
|
|
|
|
|
|
#################################################
|
|
# Core algorithm functions #
|
|
#################################################
|
|
|
|
|
|
def add_actor_information_and_train(
    cfg: TrainPipelineConfig,
    wandb_logger: WandBLogger | None,
    shutdown_event: any,  # Event,
    transition_queue: Queue,
    interaction_message_queue: Queue,
    parameters_queue: Queue,
):
    """
    Handles data transfer from the actor to the learner, manages training updates,
    and logs training progress in an online reinforcement learning setup.

    This function continuously:
    - Transfers transitions from the actor to the replay buffer.
    - Logs received interaction messages.
    - Ensures training begins only when the replay buffer has a sufficient number of transitions.
    - Samples batches from the replay buffer and performs multiple critic updates
      (`utd_ratio - 1` critic-only updates followed by one final, logged update).
    - Periodically updates the actor and temperature optimizers.
    - Pushes the policy to the parameters queue once enough wall-clock time has elapsed.
    - Logs training statistics, including loss values and optimization frequency.
    - Saves checkpoints at the configured frequency.

    The loop only exits when `shutdown_event` is set.

    NOTE: This function does not have a single responsibility and should be split into
    several functions in the future. It is written as one function because of the Python
    GIL: according to the original authors, spreading this work over several threads
    divided performance by roughly 200, so a single thread does all the work.

    Args:
        cfg (TrainPipelineConfig): Configuration object containing hyperparameters.
        wandb_logger (WandBLogger | None): Logger for tracking training progress.
        shutdown_event (Event): Event to signal shutdown.
        transition_queue (Queue): Queue for receiving transitions from the actor.
        interaction_message_queue (Queue): Queue for receiving interaction messages from the actor.
        parameters_queue (Queue): Queue for sending policy parameters to the actor.
    """
    # Extract all configuration variables at the beginning
    device = get_safe_torch_device(try_device=cfg.policy.device, log=True)
    storage_device = get_safe_torch_device(try_device=cfg.policy.storage_device)
    clip_grad_norm_value = cfg.policy.grad_clip_norm
    online_step_before_learning = cfg.policy.online_step_before_learning
    utd_ratio = cfg.policy.utd_ratio
    fps = cfg.env.fps
    log_freq = cfg.log_freq
    save_freq = cfg.save_freq
    policy_update_freq = cfg.policy.policy_update_freq
    # Compared below against elapsed wall-clock seconds (time.time() differences).
    policy_parameters_push_frequency = cfg.policy.actor_learner_config.policy_parameters_push_frequency
    saving_checkpoint = cfg.save_checkpoint
    online_steps = cfg.policy.online_steps
    async_prefetch = cfg.policy.async_prefetch

    # Initialize logging for multiprocessing
    if not use_threads(cfg):
        log_dir = os.path.join(cfg.output_dir, "logs")
        os.makedirs(log_dir, exist_ok=True)
        log_file = os.path.join(log_dir, f"learner_train_process_{os.getpid()}.log")
        init_logging(log_file=log_file)
        logging.info("Initialized logging for actor information and training process")

    logging.info("Initializing policy")

    policy: SACPolicy = make_policy(
        cfg=cfg.policy,
        # ds_meta=cfg.dataset,
        env_cfg=cfg.env,
    )

    assert isinstance(policy, nn.Module)
    policy.train()

    # Publish the initial policy before any optimization has happened.
    push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)

    last_time_policy_pushed = time.time()

    # NOTE(review): `lr_scheduler` is not used anywhere else in this function.
    optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg=cfg, policy=policy)
    # Both values are None when not resuming (or when loading the state failed).
    resume_optimization_step, resume_interaction_step = load_training_state(cfg=cfg, optimizers=optimizers)

    log_training_info(cfg=cfg, policy=policy)

    replay_buffer = initialize_replay_buffer(cfg, device, storage_device)
    batch_size = cfg.batch_size
    offline_replay_buffer = None

    if cfg.dataset is not None:
        active_action_dims = None
        # TODO: FIX THIS
        if cfg.env.wrapper.joint_masking_action_space is not None:
            # Keep the indices of the action dimensions whose mask entry is truthy.
            active_action_dims = [
                i for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space) if mask
            ]
        offline_replay_buffer = initialize_offline_replay_buffer(
            cfg=cfg,
            device=device,
            storage_device=storage_device,
            active_action_dims=active_action_dims,
        )
        # Each optimization batch is an online batch concatenated with an offline
        # batch, so each buffer contributes half of the configured batch size.
        batch_size: int = batch_size // 2  # We will sample from both replay buffer

    logging.info("Starting learner thread")
    # NOTE(review): `transition` is never read after this assignment.
    interaction_message, transition = None, None
    optimization_step = resume_optimization_step if resume_optimization_step is not None else 0
    interaction_step_shift = resume_interaction_step if resume_interaction_step is not None else 0

    dataset_repo_id = None
    if cfg.dataset is not None:
        dataset_repo_id = cfg.dataset.repo_id

    # Initialize iterators
    # They are created lazily inside the loop, once the online buffer is large enough.
    online_iterator = None
    offline_iterator = None
    # NOTE: THIS IS THE MAIN LOOP OF THE LEARNER
    while True:
        # Exit the training loop if shutdown is requested
        if shutdown_event is not None and shutdown_event.is_set():
            logging.info("[LEARNER] Shutdown signal received. Exiting...")
            break

        # Process all available transitions
        logging.debug("[LEARNER] Waiting for transitions")
        process_transitions(
            transition_queue=transition_queue,
            replay_buffer=replay_buffer,
            offline_replay_buffer=offline_replay_buffer,
            device=device,
            dataset_repo_id=dataset_repo_id,
            shutdown_event=shutdown_event,
        )
        logging.debug("[LEARNER] Received transitions")

        # Process all available interaction messages
        logging.debug("[LEARNER] Waiting for interactions")
        interaction_message = process_interaction_messages(
            interaction_message_queue=interaction_message_queue,
            interaction_step_shift=interaction_step_shift,
            wandb_logger=wandb_logger,
            shutdown_event=shutdown_event,
        )
        logging.debug("[LEARNER] Received interactions")

        # Wait until the replay buffer has enough samples
        if len(replay_buffer) < online_step_before_learning:
            continue

        if online_iterator is None:
            logging.debug("[LEARNER] Initializing online replay buffer iterator")
            online_iterator = replay_buffer.get_iterator(
                batch_size=batch_size, async_prefetch=async_prefetch, queue_size=2
            )

        if offline_replay_buffer is not None and offline_iterator is None:
            logging.debug("[LEARNER] Initializing offline replay buffer iterator")
            offline_iterator = offline_replay_buffer.get_iterator(
                batch_size=batch_size, async_prefetch=async_prefetch, queue_size=2
            )

        logging.debug("[LEARNER] Starting optimization loop")
        # Holds the start timestamp here; overwritten with the elapsed time further down.
        time_for_one_optimization_step = time.time()
        # Critic-only updates: the first `utd_ratio - 1` updates of this step.
        # Their losses and gradient norms are not logged.
        for _ in range(utd_ratio - 1):
            # Sample from the iterators
            batch = next(online_iterator)

            if dataset_repo_id is not None:
                batch_offline = next(offline_iterator)
                batch = concatenate_batch_transitions(
                    left_batch_transitions=batch, right_batch_transition=batch_offline
                )

            actions = batch["action"]
            rewards = batch["reward"]
            observations = batch["state"]
            next_observations = batch["next_state"]
            done = batch["done"]
            check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)

            observation_features, next_observation_features = get_observation_features(
                policy=policy, observations=observations, next_observations=next_observations
            )

            # Create a batch dictionary with all required elements for the forward method
            forward_batch = {
                "action": actions,
                "reward": rewards,
                "state": observations,
                "next_state": next_observations,
                "done": done,
                "observation_feature": observation_features,
                "next_observation_feature": next_observation_features,
            }

            # Compute the main critic loss through the policy's forward method
            # (the grasp critic loss is requested separately below).
            critic_output = policy.forward(forward_batch, model="critic")

            # Main critic optimization
            loss_critic = critic_output["loss_critic"]
            optimizers["critic"].zero_grad()
            loss_critic.backward()
            critic_grad_norm = torch.nn.utils.clip_grad_norm_(
                parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value
            )
            optimizers["critic"].step()

            # Grasp critic optimization (if available)
            if policy.config.num_discrete_actions is not None:
                discrete_critic_output = policy.forward(forward_batch, model="grasp_critic")
                loss_grasp_critic = discrete_critic_output["loss_grasp_critic"]
                optimizers["grasp_critic"].zero_grad()
                loss_grasp_critic.backward()
                grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
                    parameters=policy.grasp_critic.parameters_to_optimize, max_norm=clip_grad_norm_value
                )
                optimizers["grasp_critic"].step()

            # Update target networks
            policy.update_target_networks()

        # Sample for the last update in the UTD ratio
        batch = next(online_iterator)

        if dataset_repo_id is not None:
            batch_offline = next(offline_iterator)
            batch = concatenate_batch_transitions(
                left_batch_transitions=batch, right_batch_transition=batch_offline
            )

        actions = batch["action"]
        rewards = batch["reward"]
        observations = batch["state"]
        next_observations = batch["next_state"]
        done = batch["done"]

        check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)

        observation_features, next_observation_features = get_observation_features(
            policy=policy, observations=observations, next_observations=next_observations
        )

        # Create a batch dictionary with all required elements for the forward method
        forward_batch = {
            "action": actions,
            "reward": rewards,
            "state": observations,
            "next_state": next_observations,
            "done": done,
            "observation_feature": observation_features,
            "next_observation_feature": next_observation_features,
        }

        # Compute the main critic loss through the policy's forward method
        # (the grasp critic loss is requested separately below).
        critic_output = policy.forward(forward_batch, model="critic")

        # Main critic optimization
        loss_critic = critic_output["loss_critic"]
        optimizers["critic"].zero_grad()
        loss_critic.backward()
        # Unlike in the loop above, the norm is converted to a Python number for logging.
        critic_grad_norm = torch.nn.utils.clip_grad_norm_(
            parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value
        ).item()
        optimizers["critic"].step()

        # Initialize training info dictionary
        training_infos = {
            "loss_critic": loss_critic.item(),
            "critic_grad_norm": critic_grad_norm,
        }

        # Grasp critic optimization (if available)
        if policy.config.num_discrete_actions is not None:
            discrete_critic_output = policy.forward(forward_batch, model="grasp_critic")
            loss_grasp_critic = discrete_critic_output["loss_grasp_critic"]
            optimizers["grasp_critic"].zero_grad()
            loss_grasp_critic.backward()
            grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
                parameters=policy.grasp_critic.parameters_to_optimize, max_norm=clip_grad_norm_value
            ).item()
            optimizers["grasp_critic"].step()

            # Add grasp critic info to training info
            training_infos["loss_grasp_critic"] = loss_grasp_critic.item()
            training_infos["grasp_critic_grad_norm"] = grasp_critic_grad_norm

        # Actor and temperature optimization (at specified frequency)
        if optimization_step % policy_update_freq == 0:
            # Runs `policy_update_freq` actor/temperature updates on the same
            # forward batch; only the values of the last iteration are kept for logging.
            for _ in range(policy_update_freq):
                # Actor optimization
                actor_output = policy.forward(forward_batch, model="actor")
                loss_actor = actor_output["loss_actor"]
                optimizers["actor"].zero_grad()
                loss_actor.backward()
                actor_grad_norm = torch.nn.utils.clip_grad_norm_(
                    parameters=policy.actor.parameters_to_optimize, max_norm=clip_grad_norm_value
                ).item()
                optimizers["actor"].step()

                # Add actor info to training info
                training_infos["loss_actor"] = loss_actor.item()
                training_infos["actor_grad_norm"] = actor_grad_norm

                # Temperature optimization
                temperature_output = policy.forward(forward_batch, model="temperature")
                loss_temperature = temperature_output["loss_temperature"]
                optimizers["temperature"].zero_grad()
                loss_temperature.backward()
                temp_grad_norm = torch.nn.utils.clip_grad_norm_(
                    parameters=[policy.log_alpha], max_norm=clip_grad_norm_value
                ).item()
                optimizers["temperature"].step()

                # Add temperature info to training info
                training_infos["loss_temperature"] = loss_temperature.item()
                training_infos["temperature_grad_norm"] = temp_grad_norm
                # Read before `update_temperature()` is called just below.
                training_infos["temperature"] = policy.temperature

                # Update temperature
                policy.update_temperature()

        # Push policy to actors if needed
        if time.time() - last_time_policy_pushed > policy_parameters_push_frequency:
            push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)
            last_time_policy_pushed = time.time()

        # Update target networks
        policy.update_target_networks()

        # Log training metrics at specified intervals
        if optimization_step % log_freq == 0:
            training_infos["replay_buffer_size"] = len(replay_buffer)
            if offline_replay_buffer is not None:
                training_infos["offline_replay_buffer_size"] = len(offline_replay_buffer)
            training_infos["Optimization step"] = optimization_step

            # Log training metrics
            if wandb_logger:
                wandb_logger.log_dict(d=training_infos, mode="train", custom_step_key="Optimization step")

        # Calculate and log optimization frequency
        time_for_one_optimization_step = time.time() - time_for_one_optimization_step
        # The small epsilon guards against a division by zero.
        frequency_for_one_optimization_step = 1 / (time_for_one_optimization_step + 1e-9)

        logging.info(f"[LEARNER] Optimization frequency loop [Hz]: {frequency_for_one_optimization_step}")

        # Log optimization frequency
        if wandb_logger:
            wandb_logger.log_dict(
                {
                    "Optimization frequency loop [Hz]": frequency_for_one_optimization_step,
                    "Optimization step": optimization_step,
                },
                mode="train",
                custom_step_key="Optimization step",
            )

        optimization_step += 1
        if optimization_step % log_freq == 0:
            logging.info(f"[LEARNER] Number of optimization step: {optimization_step}")

        # Save checkpoint at specified intervals
        if saving_checkpoint and (optimization_step % save_freq == 0 or optimization_step == online_steps):
            save_training_checkpoint(
                cfg=cfg,
                optimization_step=optimization_step,
                online_steps=online_steps,
                interaction_message=interaction_message,
                policy=policy,
                optimizers=optimizers,
                replay_buffer=replay_buffer,
                offline_replay_buffer=offline_replay_buffer,
                dataset_repo_id=dataset_repo_id,
                fps=fps,
            )
|
|
|
|
|
|
def start_learner_server(
    parameters_queue: Queue,
    transition_queue: Queue,
    interaction_message_queue: Queue,
    shutdown_event: any,  # Event,
    cfg: TrainPipelineConfig,
):
    """
    Run the learner's gRPC server until shutdown is requested.

    Builds a `LearnerService` around the three queues, serves it on the host and
    port found in the actor/learner configuration, blocks on `shutdown_event`,
    and finally stops the server.

    Args:
        parameters_queue: Queue for sending policy parameters to the actor
        transition_queue: Queue for receiving transitions from the actor
        interaction_message_queue: Queue for receiving interaction messages from the actor
        shutdown_event: Event to signal shutdown
        cfg: Training configuration
    """
    if not use_threads(cfg):
        # Running as a separate process: give it a log file of its own.
        logs_folder = os.path.join(cfg.output_dir, "logs")
        os.makedirs(logs_folder, exist_ok=True)
        init_logging(log_file=os.path.join(logs_folder, f"learner_server_process_{os.getpid()}.log"))
        logging.info("Learner server process logging initialized")

        # Setup process handlers to handle shutdown signal
        # But use shutdown event from the main process
        # Return back for MP
        setup_process_handlers(False)

    actor_learner_cfg = cfg.policy.actor_learner_config

    service = learner_service.LearnerService(
        shutdown_event=shutdown_event,
        parameters_queue=parameters_queue,
        seconds_between_pushes=actor_learner_cfg.policy_parameters_push_frequency,
        transition_queue=transition_queue,
        interaction_message_queue=interaction_message_queue,
    )

    # Same size limit for incoming and outgoing messages.
    message_size_options = [
        ("grpc.max_receive_message_length", learner_service.MAX_MESSAGE_SIZE),
        ("grpc.max_send_message_length", learner_service.MAX_MESSAGE_SIZE),
    ]
    server = grpc.server(
        ThreadPoolExecutor(max_workers=learner_service.MAX_WORKERS),
        options=message_size_options,
    )

    hilserl_pb2_grpc.add_LearnerServiceServicer_to_server(service, server)

    bind_address = f"{actor_learner_cfg.learner_host}:{actor_learner_cfg.learner_port}"
    server.add_insecure_port(bind_address)
    server.start()
    logging.info("[LEARNER] gRPC server started")

    # Block here until shutdown is signalled.
    shutdown_event.wait()
    logging.info("[LEARNER] Stopping gRPC server...")
    server.stop(learner_service.STUTDOWN_TIMEOUT)
    logging.info("[LEARNER] gRPC server stopped")
|
|
|
|
|
|
def save_training_checkpoint(
    cfg: TrainPipelineConfig,
    optimization_step: int,
    online_steps: int,
    interaction_message: dict | None,
    policy: nn.Module,
    optimizers: dict[str, Optimizer],
    replay_buffer: ReplayBuffer,
    offline_replay_buffer: ReplayBuffer | None = None,
    dataset_repo_id: str | None = None,
    fps: int = 30,
) -> None:
    """
    Save training checkpoint and associated data.

    This function performs the following steps:
    1. Creates a checkpoint directory with the current optimization step
    2. Saves the policy model, configuration, and optimizer states
    3. Saves the current interaction step for resuming training
    4. Updates the "last" checkpoint symlink to point to this checkpoint
    5. Saves the replay buffer as a dataset for later use
    6. If an offline replay buffer exists, saves it as a separate dataset

    Args:
        cfg: Training configuration
        optimization_step: Current optimization step
        online_steps: Total number of online steps
        interaction_message: Dictionary containing interaction information. Its
            "Interaction step" entry is stored; 0 is stored when the message is None.
        policy: Policy model to save
        optimizers: Dictionary of optimizers
        replay_buffer: Replay buffer to save as dataset
        offline_replay_buffer: Optional offline replay buffer to save
        dataset_repo_id: Repository ID for dataset. When None, `cfg.env.task` is used instead.
        fps: Frames per second for dataset
    """
    logging.info(f"Checkpoint policy after step {optimization_step}")
    interaction_step = interaction_message["Interaction step"] if interaction_message is not None else 0

    # Create checkpoint directory
    checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, online_steps, optimization_step)

    # Save checkpoint
    save_checkpoint(
        checkpoint_dir=checkpoint_dir,
        step=optimization_step,
        cfg=cfg,
        policy=policy,
        optimizer=optimizers,
        scheduler=None,
    )

    # Save interaction step manually
    training_state_dir = os.path.join(checkpoint_dir, TRAINING_STATE_DIR)
    os.makedirs(training_state_dir, exist_ok=True)
    training_state = {"step": optimization_step, "interaction_step": interaction_step}
    torch.save(training_state, os.path.join(training_state_dir, "training_state.pt"))

    # Update the "last" symlink
    update_last_checkpoint(checkpoint_dir)

    # TODO : temporarily save replay buffer here, remove later when on the robot
    # We want to control this with the keyboard inputs
    dataset_dir = os.path.join(cfg.output_dir, "dataset")
    # `isdir` is already False for a missing path, so no separate existence check is needed.
    if os.path.isdir(dataset_dir):
        shutil.rmtree(dataset_dir)

    # Save dataset
    # NOTE: Handle the case where the dataset repo id is not specified in the config
    # eg. RL training without demonstrations data
    repo_id_buffer_save = cfg.env.task if dataset_repo_id is None else dataset_repo_id
    replay_buffer.to_lerobot_dataset(repo_id=repo_id_buffer_save, fps=fps, root=dataset_dir)

    if offline_replay_buffer is not None:
        dataset_offline_dir = os.path.join(cfg.output_dir, "dataset_offline")
        if os.path.isdir(dataset_offline_dir):
            shutil.rmtree(dataset_offline_dir)

        # Prefer the repo id from the dataset config; when the config has no
        # dataset section, reuse the repo id chosen for the online buffer
        # instead of failing on `None.repo_id`.
        if cfg.dataset is not None:
            offline_repo_id = cfg.dataset.repo_id
        else:
            offline_repo_id = repo_id_buffer_save

        offline_replay_buffer.to_lerobot_dataset(
            offline_repo_id,
            fps=fps,
            root=dataset_offline_dir,
        )

    logging.info("Resume training")
|
|
|
|
|
|
def make_optimizers_and_scheduler(cfg: TrainPipelineConfig, policy: nn.Module):
    """
    Build one Adam optimizer per trainable component of the policy.

    Optimizers created:
    - "actor": over `policy.actor.parameters_to_optimize`, with `cfg.policy.actor_lr`.
    - "critic": over the parameters of `policy.critic_ensemble`, with `cfg.policy.critic_lr`.
    - "temperature": over `policy.log_alpha` alone, with `cfg.policy.critic_lr`.
    - "grasp_critic": over `policy.grasp_critic.parameters_to_optimize`, with
      `cfg.policy.critic_lr`; only created when `cfg.policy.num_discrete_actions` is not None.

    **NOTE:**
    - The actor exposes `parameters_to_optimize` so that, with a shared encoder, the encoder
      weights are not optimized with the gradient of the actor.
    - `log_alpha` is wrapped in a list because the optimizer expects an iterable of tensors.

    Args:
        cfg: Configuration object containing hyperparameters.
        policy (nn.Module): The policy model containing the actor, critic, and temperature components.

    Returns:
        Tuple[Dict[str, torch.optim.Optimizer], None]:
        A tuple containing:
        - `optimizers`: A dictionary mapping component names to their Adam optimizers.
        - `lr_scheduler`: Always `None` for now; kept so scheduling can be added later.
    """
    policy_cfg = cfg.policy
    adam = torch.optim.Adam

    optimizers = {
        # NOTE: Handle the case of shared encoder where the encoder weights are not
        # optimized with the gradient of the actor
        "actor": adam(params=policy.actor.parameters_to_optimize, lr=policy_cfg.actor_lr),
        "critic": adam(params=policy.critic_ensemble.parameters(), lr=policy_cfg.critic_lr),
        "temperature": adam(params=[policy.log_alpha], lr=policy_cfg.critic_lr),
    }

    # The grasp critic only exists when discrete actions are configured.
    if policy_cfg.num_discrete_actions is not None:
        optimizers["grasp_critic"] = adam(
            params=policy.grasp_critic.parameters_to_optimize, lr=policy_cfg.critic_lr
        )

    return optimizers, None
|
|
|
|
|
|
#################################################
|
|
# Training setup functions #
|
|
#################################################
|
|
|
|
|
|
def handle_resume_logic(cfg: TrainPipelineConfig) -> TrainPipelineConfig:
    """
    Decide which configuration training should run with, depending on `cfg.resume`.

    Not resuming:
    - Refuses to continue when a "last" checkpoint already exists under the output
      directory (to prevent accidental overwriting)
    - Otherwise returns the configuration unchanged

    Resuming:
    - Requires the "last" checkpoint to exist
    - Loads the training configuration stored with that checkpoint
    - Returns the loaded configuration with `resume` forced to True

    Args:
        cfg (TrainPipelineConfig): The training configuration

    Returns:
        TrainPipelineConfig: The configuration to use from now on

    Raises:
        RuntimeError: If resume is True but no checkpoint is found, or if resume is False
            but a checkpoint already exists
    """
    last_checkpoint = os.path.join(cfg.output_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK)
    checkpoint_found = os.path.exists(last_checkpoint)

    if not cfg.resume:
        # Fresh run: never write on top of an earlier one.
        if checkpoint_found:
            raise RuntimeError(
                f"Output directory {last_checkpoint} already exists. Use `resume=true` to resume training."
            )
        return cfg

    if not checkpoint_found:
        raise RuntimeError(f"No model checkpoint found in {last_checkpoint} for resume=True")

    resume_notice = colored(
        "Valid checkpoint found: resume=True detected, resuming previous run",
        color="yellow",
        attrs=["bold"],
    )
    logging.info(resume_notice)

    # The training configuration is stored next to the pretrained model.
    stored_cfg_path = os.path.join(last_checkpoint, PRETRAINED_MODEL_DIR, "train_config.json")
    stored_cfg = TrainPipelineConfig.from_pretrained(stored_cfg_path)

    # Ensure resume flag is set in returned config
    stored_cfg.resume = True
    return stored_cfg
|
|
|
|
|
|
def load_training_state(
    cfg: TrainPipelineConfig,
    optimizers: Optimizer | dict[str, Optimizer],
):
    """
    Restore the optimizer state and step counters from the "last" checkpoint.

    Loading is best effort: any exception raised while reading the checkpoint is
    logged and reported as `(None, None)` rather than propagated.

    Args:
        cfg (TrainPipelineConfig): Training configuration
        optimizers (Optimizer | dict): Optimizers to load state into

    Returns:
        tuple: (optimization_step, interaction_step), or (None, None) when not resuming
            or when loading failed
    """
    if not cfg.resume:
        return None, None

    # Construct path to the last checkpoint directory
    checkpoint_dir = os.path.join(cfg.output_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK)

    logging.info(f"Loading training state from {checkpoint_dir}")

    try:
        # Use the utility function from train_utils which loads the optimizer state
        step, optimizers, _ = utils_load_training_state(Path(checkpoint_dir), optimizers, None)

        # The interaction step is kept in a separate training_state.pt file;
        # it defaults to 0 when the file or the key is missing.
        training_state_path = os.path.join(checkpoint_dir, TRAINING_STATE_DIR, "training_state.pt")
        if os.path.exists(training_state_path):
            # NOTE(review): weights_only=False unpickles arbitrary objects; only load trusted checkpoints.
            saved_state = torch.load(training_state_path, weights_only=False)
            interaction_step = saved_state.get("interaction_step", 0)
        else:
            interaction_step = 0

        logging.info(f"Resuming from step {step}, interaction step {interaction_step}")
        return step, interaction_step

    except Exception as e:
        logging.error(f"Failed to load training state: {e}")
        return None, None
|
|
|
|
|
|
def log_training_info(cfg: TrainPipelineConfig, policy: nn.Module) -> None:
    """
    Log a short summary of the run: output directory, task, online steps and
    the policy's parameter counts.

    Args:
        cfg (TrainPipelineConfig): Training configuration
        policy (nn.Module): Policy model whose parameters are counted
    """
    # Count every parameter once, tracking the trainable subset alongside the total.
    # The two names below appear verbatim in the log lines (f-string `=` specifier).
    num_learnable_params = 0
    num_total_params = 0
    for param in policy.parameters():
        size = param.numel()
        num_total_params += size
        if param.requires_grad:
            num_learnable_params += size

    output_dir_label = colored("Output dir:", "yellow", attrs=["bold"])
    logging.info(output_dir_label + f" {cfg.output_dir}")
    logging.info(f"{cfg.env.task=}")
    logging.info(f"{cfg.policy.online_steps=}")
    logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
    logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
|
|
|
|
|
|
def initialize_replay_buffer(cfg: TrainPipelineConfig, device: str, storage_device: str) -> ReplayBuffer:
    """
    Initialize the online replay buffer, either empty or from the saved dataset if resuming.

    Args:
        cfg (TrainPipelineConfig): Training configuration
        device (str): Value forwarded to the buffer as `device`
        storage_device (str): Value forwarded to the buffer as `storage_device`

    Returns:
        ReplayBuffer: Initialized replay buffer
    """
    if not cfg.resume:
        return ReplayBuffer(
            capacity=cfg.policy.online_buffer_capacity,
            device=device,
            state_keys=cfg.policy.input_features.keys(),
            storage_device=storage_device,
            optimize_memory=True,
        )

    logging.info("Resume training load the online dataset")
    dataset_path = os.path.join(cfg.output_dir, "dataset")

    # NOTE: In RL it is possible to not have a dataset configured; the online
    # dataset is then loaded from `dataset_path` with no repo id.
    repo_id = cfg.dataset.repo_id if cfg.dataset is not None else None
    dataset = LeRobotDataset(
        repo_id=repo_id,
        root=dataset_path,
    )
    return ReplayBuffer.from_lerobot_dataset(
        lerobot_dataset=dataset,
        capacity=cfg.policy.online_buffer_capacity,
        device=device,
        state_keys=cfg.policy.input_features.keys(),
        # Forward the caller's storage device on resume as well, so a resumed buffer
        # is configured the same way as a freshly created one.
        storage_device=storage_device,
        optimize_memory=True,
    )
|
|
|
|
|
|
def initialize_offline_replay_buffer(
    cfg: TrainPipelineConfig,
    device: str,
    storage_device: str,
    active_action_dims: list[int] | None = None,
) -> ReplayBuffer:
    """
    Build the offline replay buffer from a dataset.

    On a fresh run the dataset comes from `make_dataset(cfg)`; when resuming it is
    reloaded from `<cfg.output_dir>/dataset_offline`.

    Args:
        cfg (TrainPipelineConfig): Training configuration
        device (str): Value forwarded to the buffer as `device`
        storage_device (str): Value forwarded to the buffer as `storage_device`
        active_action_dims (list[int] | None): Forwarded to the buffer as `action_mask`

    Returns:
        ReplayBuffer: Offline replay buffer built from the dataset
    """
    if cfg.resume:
        logging.info("load offline dataset")
        offline_root = os.path.join(cfg.output_dir, "dataset_offline")
        source_dataset = LeRobotDataset(repo_id=cfg.dataset.repo_id, root=offline_root)
    else:
        logging.info("make_dataset offline buffer")
        source_dataset = make_dataset(cfg)

    logging.info("Convert to a offline replay buffer")
    return ReplayBuffer.from_lerobot_dataset(
        source_dataset,
        device=device,
        state_keys=cfg.policy.input_features.keys(),
        action_mask=active_action_dims,
        action_delta=cfg.env.wrapper.delta_action,
        storage_device=storage_device,
        optimize_memory=True,
        capacity=cfg.policy.offline_buffer_capacity,
    )
|
|
|
|
|
|
#################################################
|
|
#          Utilities/Helper functions           #
|
|
#################################################
|
|
|
|
|
|
def get_observation_features(
    policy: SACPolicy, observations: torch.Tensor, next_observations: torch.Tensor
) -> tuple[torch.Tensor | None, torch.Tensor | None]:
    """
    Pre-compute encoder features for the current and next observations.

    Features are only computed when the policy config names a vision encoder and
    marks it as frozen; they are produced under `torch.no_grad()` so callers can
    reuse them instead of re-running the encoder.

    Args:
        policy: The policy model (its `actor.encoder` is used)
        observations: The current observations
        next_observations: The next observations

    Returns:
        tuple: (observation_features, next_observation_features). Both are None when
            there is no vision encoder, when it is not frozen, or when the actor has
            no encoder.
    """
    policy_cfg = policy.config
    encoder_is_frozen = policy_cfg.vision_encoder_name is not None and policy_cfg.freeze_vision_encoder
    if not encoder_is_frozen:
        return None, None

    encoder = policy.actor.encoder
    if encoder is None:
        return None, None

    with torch.no_grad():
        current_features = encoder(observations)
        next_features = encoder(next_observations)

    return current_features, next_features
|
|
|
|
|
|
def use_threads(cfg: TrainPipelineConfig) -> bool:
    """Return True when the learner's configured concurrency mode is "threads"."""
    learner_mode = cfg.policy.concurrency.learner
    return learner_mode == "threads"
|
|
|
|
|
|
def check_nan_in_transition(
    observations: dict[str, torch.Tensor],
    actions: torch.Tensor,
    next_state: dict[str, torch.Tensor],
    raise_error: bool = False,
) -> bool:
    """
    Check for NaN values in transition data.

    Every offending entry is reported through `logging.error`. With
    `raise_error=True` the first offending entry aborts the check.

    Args:
        observations: Dictionary of observation tensors
        actions: Action tensor
        next_state: Dictionary of next state tensors
        raise_error: If True, raises ValueError when NaN is detected

    Returns:
        bool: True if NaN values were detected, False otherwise

    Raises:
        ValueError: If `raise_error` is True and a NaN value is found
    """
    nan_detected = False

    # Observations and next state are both dicts of tensors and are checked the same
    # way; only the label used in the messages differs.
    for label, tensors in (("observations", observations), ("next_state", next_state)):
        for key, tensor in tensors.items():
            if torch.isnan(tensor).any():
                logging.error(f"{label}[{key}] contains NaN values")
                nan_detected = True
                if raise_error:
                    raise ValueError(f"NaN detected in {label}[{key}]")

    # Actions are a single tensor
    if torch.isnan(actions).any():
        logging.error("actions contains NaN values")
        nan_detected = True
        if raise_error:
            raise ValueError("NaN detected in actions")

    return nan_detected
|
|
|
|
|
|
def push_actor_policy_to_queue(parameters_queue: Queue, policy: nn.Module):
    """
    Put the actor's current weights on the parameters queue.

    The actor state dict is moved to CPU, serialized with `state_to_bytes`, and the
    resulting payload is placed on `parameters_queue`.

    Args:
        parameters_queue: Queue that receives the serialized actor weights
        policy: Policy whose `actor` sub-module is exported
    """
    logging.debug("[LEARNER] Pushing actor policy to the queue")
    cpu_actor_weights = move_state_dict_to_device(policy.actor.state_dict(), device="cpu")
    parameters_queue.put(state_to_bytes(cpu_actor_weights))
|
|
|
|
|
|
def process_interaction_message(
    message, interaction_step_shift: int, wandb_logger: WandBLogger | None = None
):
    """
    Decode one interaction message, shift its step counter and optionally log it.

    Args:
        message: Serialized message, decoded with `bytes_to_python_object`
        interaction_step_shift: Offset added to the message's "Interaction step"
        wandb_logger: Logger used to record the message; skipped when falsy

    Returns:
        The decoded message with the shifted "Interaction step"
    """
    decoded = bytes_to_python_object(message)

    # Keep interaction steps consistent with the checkpointed state
    decoded["Interaction step"] += interaction_step_shift

    if not wandb_logger:
        return decoded

    wandb_logger.log_dict(d=decoded, mode="train", custom_step_key="Interaction step")
    return decoded
|
|
|
|
|
|
def process_transitions(
    transition_queue: Queue,
    replay_buffer: ReplayBuffer,
    offline_replay_buffer: ReplayBuffer,
    device: str,
    dataset_repo_id: str | None,
    shutdown_event: any,
):
    """Process all available transitions from the queue.

    Each queue item is decoded into a list of transitions. Every transition is moved
    to `device`, dropped if it contains NaN values, and otherwise added to
    `replay_buffer`. Transitions flagged as interventions are additionally added to
    `offline_replay_buffer` when a dataset repo id is configured.

    Args:
        transition_queue: Queue for receiving transitions from the actor
        replay_buffer: Replay buffer to add transitions to
        offline_replay_buffer: Offline replay buffer to add intervention transitions to
        device: Device to move transitions to
        dataset_repo_id: Repository ID for dataset; None disables the offline buffer path
        shutdown_event: Event-like object exposing `is_set()`; draining stops once it is set
    """
    while not transition_queue.empty() and not shutdown_event.is_set():
        transition_list = bytes_to_transitions(buffer=transition_queue.get())

        for transition in transition_list:
            transition = move_transition_to_device(transition=transition, device=device)

            # Skip transitions with NaN values
            if check_nan_in_transition(
                observations=transition["state"],
                actions=transition["action"],
                next_state=transition["next_state"],
            ):
                logging.warning("[LEARNER] NaN detected in transition, skipping")
                continue

            replay_buffer.add(**transition)

            if dataset_repo_id is None:
                continue

            # `dict.get(key, {})` only falls back to `{}` when the key is missing; a key
            # that is present but set to None would make the chained `.get` fail, so
            # normalize both cases to an empty dict.
            complementary_info = transition.get("complementary_info") or {}

            # Add to offline buffer if it's an intervention
            if complementary_info.get("is_intervention"):
                offline_replay_buffer.add(**transition)
|
|
|
|
|
|
def process_interaction_messages(
    interaction_message_queue: Queue,
    interaction_step_shift: int,
    wandb_logger: WandBLogger | None,
    shutdown_event: any,
) -> dict | None:
    """Drain the interaction message queue, processing each message in turn.

    Draining stops as soon as the queue reports empty or the shutdown event is set.

    Args:
        interaction_message_queue: Queue for receiving interaction messages
        interaction_step_shift: Amount to shift interaction step by
        wandb_logger: Logger for tracking progress
        shutdown_event: Event-like object exposing `is_set()`

    Returns:
        dict | None: The last interaction message processed, or None if none were processed
    """
    latest = None
    while True:
        # Same short-circuit as checking "not empty and not shut down": the shutdown
        # event is only consulted when the queue still has something in it.
        if interaction_message_queue.empty() or shutdown_event.is_set():
            break

        raw_message = interaction_message_queue.get()
        latest = process_interaction_message(
            message=raw_message,
            interaction_step_shift=interaction_step_shift,
            wandb_logger=wandb_logger,
        )

    return latest
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run the learner CLI, then record that it returned.
    train_cli()

    logging.info("[LEARNER] main finished")
|