Cleaned `learner_server.py`. Added several block functions to improve readability.
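
After this change, the body of train() reduces roughly to the outline below. The helper names and call signatures are taken from the diff; policy, logger, device, and lock setup are elided, so read it as a sketch of the control flow rather than the exact final code:

    cfg = handle_resume_logic(cfg, out_dir)
    set_global_seed(cfg.seed)

    optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg, policy)
    resume_optimization_step, resume_interaction_step = load_training_state(cfg, logger, optimizers)

    log_training_info(cfg, out_dir, policy)
    replay_buffer = initialize_replay_buffer(cfg, logger, device)

    start_learner_threads(
        cfg, device, replay_buffer, offline_replay_buffer,
        cfg.training.batch_size, optimizers, policy, policy_lock,
        logger, resume_optimization_step, resume_interaction_step,
    )

Each helper owns a single concern: checkpoint/resume validation, optimizer and RNG state restoration, parameter-count logging, replay-buffer construction, and startup of the transition-streaming, training, and parameter-push threads.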

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
Michel Aractingi 2025-01-31 08:33:33 +00:00
parent 367dfe51c6
commit 7c89bd1018
2 changed files with 166 additions and 134 deletions


@@ -21,7 +21,7 @@ training:
eval_freq: 2500
log_freq: 500
-save_freq: 1000000
+save_freq: 2000000
online_steps: 1000000
online_rollout_n_episodes: 10


@@ -34,8 +34,7 @@ from deepdiff import DeepDiff
from omegaconf import DictConfig, OmegaConf
from termcolor import colored
from torch import nn
from lerobot.common.datasets.factory import make_dataset
from torch.optim.optimizer import Optimizer
# TODO: Remove the import of maniskill
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
@@ -53,18 +52,164 @@ from lerobot.common.utils.utils import (
)
from lerobot.scripts.server.buffer import (
ReplayBuffer,
concatenate_batch_transitions,
move_state_dict_to_device,
move_transition_to_device,
)
logging.basicConfig(level=logging.INFO)
# TODO: Implement this in a cleaner way
transition_queue = queue.Queue()
interaction_message_queue = queue.Queue()

def handle_resume_logic(cfg: DictConfig, out_dir: str) -> DictConfig:
if not cfg.resume:
if Logger.get_last_checkpoint_dir(out_dir).exists():
raise RuntimeError(
f"Output directory {Logger.get_last_checkpoint_dir(out_dir)} already exists. "
"Use `resume=true` to resume training."
)
return cfg
# if resume == True
checkpoint_dir = Logger.get_last_checkpoint_dir(out_dir)
if not checkpoint_dir.exists():
raise RuntimeError(f"No model checkpoint found in {checkpoint_dir} for resume=True")
checkpoint_cfg_path = str(Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml")
logging.info(
colored(
"Resume=True detected, resuming previous run",
color="yellow",
attrs=["bold"],
)
)
checkpoint_cfg = init_hydra_config(checkpoint_cfg_path)
diff = DeepDiff(OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg))
if "values_changed" in diff and "root['resume']" in diff["values_changed"]:
del diff["values_changed"]["root['resume']"]
if len(diff) > 0:
logging.warning(
f"Differences between the checkpoint config and the provided config detected: \n{pformat(diff)}\n"
"Checkpoint configuration takes precedence."
)
checkpoint_cfg.resume = True
return checkpoint_cfg

def load_training_state(
cfg: DictConfig,
logger: Logger,
optimizers: Optimizer | dict,
):
if not cfg.resume:
return None, None
training_state = torch.load(logger.last_checkpoint_dir / logger.training_state_file_name)
if isinstance(training_state["optimizer"], dict):
assert set(training_state["optimizer"].keys()) == set(optimizers.keys())
for k, v in training_state["optimizer"].items():
optimizers[k].load_state_dict(v)
else:
optimizers.load_state_dict(training_state["optimizer"])
set_global_random_state({k: training_state[k] for k in get_global_random_state()})
return training_state["step"], training_state["interaction_step"]

def log_training_info(cfg: DictConfig, out_dir: str, policy: nn.Module) -> None:
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
num_total_params = sum(p.numel() for p in policy.parameters())
log_output_dir(out_dir)
logging.info(f"{cfg.env.task=}")
logging.info(f"{cfg.training.online_steps=}")
logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")

def initialize_replay_buffer(cfg: DictConfig, logger: Logger, device: str) -> ReplayBuffer:
if not cfg.resume:
return ReplayBuffer(
capacity=cfg.training.online_buffer_capacity,
device=device,
state_keys=cfg.policy.input_shapes.keys(),
)
dataset = LeRobotDataset(
repo_id=cfg.dataset_repo_id, local_files_only=True, root=logger.log_dir / "dataset"
)
return ReplayBuffer.from_lerobot_dataset(
lerobot_dataset=dataset,
capacity=cfg.training.online_buffer_capacity,
device=device,
state_keys=cfg.policy.input_shapes.keys(),
)

def start_learner_threads(
cfg: DictConfig,
device: str,
replay_buffer: ReplayBuffer,
offline_replay_buffer: ReplayBuffer,
batch_size: int,
optimizers: dict,
policy: SACPolicy,
policy_lock: Lock,
logger: Logger,
resume_optimization_step: int | None = None,
resume_interaction_step: int | None = None,
) -> None:
actor_ip = cfg.actor_learner_config.actor_ip
port = cfg.actor_learner_config.port
server_thread = Thread(
target=stream_transitions_from_actor,
args=(
actor_ip,
port,
),
daemon=True,
)
transition_thread = Thread(
target=add_actor_information_and_train,
daemon=True,
args=(
cfg,
device,
replay_buffer,
offline_replay_buffer,
batch_size,
optimizers,
policy,
policy_lock,
logger,
resume_optimization_step,
resume_interaction_step,
),
)
param_push_thread = Thread(
target=learner_push_parameters,
args=(policy, policy_lock, actor_ip, port, 15),
daemon=True,
)
server_thread.start()
transition_thread.start()
param_push_thread.start()
param_push_thread.join()
transition_thread.join()
server_thread.join()

def stream_transitions_from_actor(host="127.0.0.1", port=50051):
"""
Runs a gRPC client that listens for transition and interaction messages from an Actor service.
@@ -373,49 +518,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
logging.info(pformat(OmegaConf.to_container(cfg)))
logger = Logger(cfg, out_dir, wandb_job_name=job_name)
## Handle resume by reloading the state of the policy and optimization
# If we are resuming a run, we need to check that a checkpoint exists in the log directory, and we need
# to check for any differences between the provided config and the checkpoint's config.
if cfg.resume:
if not Logger.get_last_checkpoint_dir(out_dir).exists():
raise RuntimeError(
"You have set resume=True, but there is no model checkpoint in "
f"{Logger.get_last_checkpoint_dir(out_dir)}"
)
checkpoint_cfg_path = str(Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml")
logging.info(
colored(
"You have set resume=True, indicating that you wish to resume a run",
color="yellow",
attrs=["bold"],
)
)
# Get the configuration file from the last checkpoint.
checkpoint_cfg = init_hydra_config(checkpoint_cfg_path)
# Check for differences between the checkpoint configuration and provided configuration.
diff = DeepDiff(OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg))
# Ignore the `resume` and parameters.
if "values_changed" in diff and "root['resume']" in diff["values_changed"]:
del diff["values_changed"]["root['resume']"]
# Log a warning about differences between the checkpoint configuration and the provided
# configuration.
if len(diff) > 0:
logging.warning(
"At least one difference was detected between the checkpoint configuration and "
f"the provided configuration: \n{pformat(diff)}\nNote that the checkpoint configuration "
"takes precedence.",
)
# Use the checkpoint config instead of the provided config (but keep `resume` parameter).
cfg = checkpoint_cfg
cfg.resume = True
elif Logger.get_last_checkpoint_dir(out_dir).exists():
raise RuntimeError(
f"The configured output directory {Logger.get_last_checkpoint_dir(out_dir)} already exists. If "
"you meant to resume training, please use `resume=true` in your command or yaml configuration."
)
# ===========================
cfg = handle_resume_logic(cfg, out_dir)
set_global_seed(cfg.seed)
@@ -438,57 +541,14 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
dataset_stats=None,
pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
)
# device=device,
# )
assert isinstance(policy, nn.Module)
optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg, policy)
# load last training state
# We can't use the logger function in `lerobot/common/logger.py`
# because it only loads the optimization step and not the interaction one
# to avoid altering that code, we will just load the optimization state manually
resume_interaction_step, resume_optimization_step = None, None
if cfg.resume:
training_state = torch.load(logger.last_checkpoint_dir / logger.training_state_file_name)
if type(training_state["optimizer"]) is dict:
assert set(training_state["optimizer"].keys()) == set(optimizers.keys()), (
"Optimizer dictionaries do not have the same keys during resume!"
)
for k, v in training_state["optimizer"].items():
optimizers[k].load_state_dict(v)
else:
optimizers.load_state_dict(training_state["optimizer"])
# Small hack to get the expected keys: use `get_global_random_state`.
set_global_random_state({k: training_state[k] for k in get_global_random_state()})
resume_optimization_step = training_state["step"]
resume_interaction_step = training_state["interaction_step"]
resume_optimization_step, resume_interaction_step = load_training_state(cfg, logger, optimizers)
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
num_total_params = sum(p.numel() for p in policy.parameters())
log_training_info(cfg, out_dir, policy)
log_output_dir(out_dir)
logging.info(f"{cfg.env.task=}")
logging.info(f"{cfg.training.online_steps=}")
logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
if not cfg.resume:
replay_buffer = ReplayBuffer(
capacity=cfg.training.online_buffer_capacity,
device=device,
state_keys=cfg.policy.input_shapes.keys(),
)
else:
# Reload replay buffer
dataset = LeRobotDataset(
repo_id=cfg.dataset_repo_id, local_files_only=True, root=logger.log_dir / "dataset"
)
replay_buffer = ReplayBuffer.from_lerobot_dataset(
lerobot_dataset=dataset,
capacity=cfg.training.online_buffer_capacity,
device=device,
state_keys=cfg.policy.input_shapes.keys(),
)
replay_buffer = initialize_replay_buffer(cfg, logger, device)
batch_size = cfg.training.batch_size
offline_replay_buffer = None
@@ -501,47 +561,19 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
# )
# batch_size: int = batch_size // 2 # We will sample from both replay buffer
actor_ip = cfg.actor_learner_config.actor_ip
port = cfg.actor_learner_config.port
server_thread = Thread(
target=stream_transitions_from_actor,
args=(
actor_ip,
port,
),
daemon=True,
start_learner_threads(
cfg,
device,
replay_buffer,
offline_replay_buffer,
batch_size,
optimizers,
policy,
policy_lock,
logger,
resume_optimization_step,
resume_interaction_step,
)
server_thread.start()
transition_thread = Thread(
target=add_actor_information_and_train,
daemon=True,
args=(
cfg,
device,
replay_buffer,
offline_replay_buffer,
batch_size,
optimizers,
policy,
policy_lock,
logger,
resume_optimization_step,
resume_interaction_step,
),
)
transition_thread.start()
param_push_thread = Thread(
target=learner_push_parameters,
args=(policy, policy_lock, actor_ip, port, 15),
daemon=True,
)
param_push_thread.start()
transition_thread.join()
server_thread.join()
@hydra.main(version_base="1.2", config_name="default", config_path="../../configs")
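
For context on the resume check in handle_resume_logic above, the following standalone snippet illustrates how DeepDiff compares the checkpoint config with the provided config while ignoring the expected flip of the resume flag. The config values here are toy examples, not the project's real settings:

    from deepdiff import DeepDiff
    from omegaconf import OmegaConf

    checkpoint_cfg = OmegaConf.create({"resume": False, "training": {"save_freq": 1000000}})
    provided_cfg = OmegaConf.create({"resume": True, "training": {"save_freq": 2000000}})

    # Compare plain dict representations of both configs.
    diff = DeepDiff(OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(provided_cfg))

    # `resume` is expected to differ when resuming, so drop it from the report.
    if "values_changed" in diff and "root['resume']" in diff["values_changed"]:
        del diff["values_changed"]["root['resume']"]

    if len(diff) > 0:
        print("Config drift detected (checkpoint configuration takes precedence):")
        print(diff)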