Cleaned `learner_server.py`. Added several block functions to improve readability.

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
Michel Aractingi 2025-01-31 08:33:33 +00:00
parent 367dfe51c6
commit 7c89bd1018
2 changed files with 166 additions and 134 deletions


@@ -21,7 +21,7 @@ training:
   eval_freq: 2500
   log_freq: 500
-  save_freq: 1000000
+  save_freq: 2000000
   online_steps: 1000000
   online_rollout_n_episodes: 10


@@ -34,8 +34,7 @@ from deepdiff import DeepDiff
 from omegaconf import DictConfig, OmegaConf
 from termcolor import colored
 from torch import nn
+from torch.optim.optimizer import Optimizer
-from lerobot.common.datasets.factory import make_dataset

 # TODO: Remove the import of maniskill
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
@@ -53,18 +52,164 @@ from lerobot.common.utils.utils import (
 )
 from lerobot.scripts.server.buffer import (
     ReplayBuffer,
-    concatenate_batch_transitions,
     move_state_dict_to_device,
     move_transition_to_device,
 )

 logging.basicConfig(level=logging.INFO)

+# TODO: Implement it in cleaner way maybe
 transition_queue = queue.Queue()
 interaction_message_queue = queue.Queue()
+
+
+def handle_resume_logic(cfg: DictConfig, out_dir: str) -> DictConfig:
+    if not cfg.resume:
+        if Logger.get_last_checkpoint_dir(out_dir).exists():
+            raise RuntimeError(
+                f"Output directory {Logger.get_last_checkpoint_dir(out_dir)} already exists. "
+                "Use `resume=true` to resume training."
+            )
+        return cfg
+
+    # if resume == True
+    checkpoint_dir = Logger.get_last_checkpoint_dir(out_dir)
+    if not checkpoint_dir.exists():
+        raise RuntimeError(f"No model checkpoint found in {checkpoint_dir} for resume=True")
+
+    checkpoint_cfg_path = str(Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml")
+    logging.info(
+        colored(
+            "Resume=True detected, resuming previous run",
+            color="yellow",
+            attrs=["bold"],
+        )
+    )
+
+    checkpoint_cfg = init_hydra_config(checkpoint_cfg_path)
+    diff = DeepDiff(OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg))
+
+    if "values_changed" in diff and "root['resume']" in diff["values_changed"]:
+        del diff["values_changed"]["root['resume']"]
+
+    if len(diff) > 0:
+        logging.warning(
+            f"Differences between the checkpoint config and the provided config detected: \n{pformat(diff)}\n"
+            "Checkpoint configuration takes precedence."
+        )
+
+    checkpoint_cfg.resume = True
+    return checkpoint_cfg
+
+
+def load_training_state(
+    cfg: DictConfig,
+    logger: Logger,
+    optimizers: Optimizer | dict,
+):
+    if not cfg.resume:
+        return None, None
+
+    training_state = torch.load(logger.last_checkpoint_dir / logger.training_state_file_name)
+
+    if isinstance(training_state["optimizer"], dict):
+        assert set(training_state["optimizer"].keys()) == set(optimizers.keys())
+        for k, v in training_state["optimizer"].items():
+            optimizers[k].load_state_dict(v)
+    else:
+        optimizers.load_state_dict(training_state["optimizer"])
+
+    set_global_random_state({k: training_state[k] for k in get_global_random_state()})
+    return training_state["step"], training_state["interaction_step"]
+
+
+def log_training_info(cfg: DictConfig, out_dir: str, policy: nn.Module) -> None:
+    num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
+    num_total_params = sum(p.numel() for p in policy.parameters())
+
+    log_output_dir(out_dir)
+    logging.info(f"{cfg.env.task=}")
+    logging.info(f"{cfg.training.online_steps=}")
+    logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
+    logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
+
+
+def initialize_replay_buffer(cfg: DictConfig, logger: Logger, device: str) -> ReplayBuffer:
+    if not cfg.resume:
+        return ReplayBuffer(
+            capacity=cfg.training.online_buffer_capacity,
+            device=device,
+            state_keys=cfg.policy.input_shapes.keys(),
+        )
+
+    dataset = LeRobotDataset(
+        repo_id=cfg.dataset_repo_id, local_files_only=True, root=logger.log_dir / "dataset"
+    )
+    return ReplayBuffer.from_lerobot_dataset(
+        lerobot_dataset=dataset,
+        capacity=cfg.training.online_buffer_capacity,
+        device=device,
+        state_keys=cfg.policy.input_shapes.keys(),
+    )
+
+
+def start_learner_threads(
+    cfg: DictConfig,
+    device: str,
+    replay_buffer: ReplayBuffer,
+    offline_replay_buffer: ReplayBuffer,
+    batch_size: int,
+    optimizers: dict,
+    policy: SACPolicy,
+    policy_lock: Lock,
+    logger: Logger,
+    resume_optimization_step: int | None = None,
+    resume_interaction_step: int | None = None,
+) -> None:
+    actor_ip = cfg.actor_learner_config.actor_ip
+    port = cfg.actor_learner_config.port
+
+    server_thread = Thread(
+        target=stream_transitions_from_actor,
+        args=(
+            actor_ip,
+            port,
+        ),
+        daemon=True,
+    )
+
+    transition_thread = Thread(
+        target=add_actor_information_and_train,
+        daemon=True,
+        args=(
+            cfg,
+            device,
+            replay_buffer,
+            offline_replay_buffer,
+            batch_size,
+            optimizers,
+            policy,
+            policy_lock,
+            logger,
+            resume_optimization_step,
+            resume_interaction_step,
+        ),
+    )
+
+    param_push_thread = Thread(
+        target=learner_push_parameters,
+        args=(policy, policy_lock, actor_ip, port, 15),
+        daemon=True,
+    )
+
+    server_thread.start()
+    transition_thread.start()
+    param_push_thread.start()
+
+    param_push_thread.join()
+    transition_thread.join()
+    server_thread.join()

 def stream_transitions_from_actor(host="127.0.0.1", port=50051):
     """
     Runs a gRPC client that listens for transition and interaction messages from an Actor service.
@@ -373,49 +518,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
     logging.info(pformat(OmegaConf.to_container(cfg)))

     logger = Logger(cfg, out_dir, wandb_job_name=job_name)

-    ## Handle resume by reloading the state of the policy and optimization
-    # If we are resuming a run, we need to check that a checkpoint exists in the log directory, and we need
-    # to check for any differences between the provided config and the checkpoint's config.
-    if cfg.resume:
-        if not Logger.get_last_checkpoint_dir(out_dir).exists():
-            raise RuntimeError(
-                "You have set resume=True, but there is no model checkpoint in "
-                f"{Logger.get_last_checkpoint_dir(out_dir)}"
-            )
-
-        checkpoint_cfg_path = str(Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml")
-        logging.info(
-            colored(
-                "You have set resume=True, indicating that you wish to resume a run",
-                color="yellow",
-                attrs=["bold"],
-            )
-        )
-
-        # Get the configuration file from the last checkpoint.
-        checkpoint_cfg = init_hydra_config(checkpoint_cfg_path)
-
-        # Check for differences between the checkpoint configuration and provided configuration.
-        diff = DeepDiff(OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg))
-
-        # Ignore the `resume` and parameters.
-        if "values_changed" in diff and "root['resume']" in diff["values_changed"]:
-            del diff["values_changed"]["root['resume']"]
-
-        # Log a warning about differences between the checkpoint configuration and the provided
-        # configuration.
-        if len(diff) > 0:
-            logging.warning(
-                "At least one difference was detected between the checkpoint configuration and "
-                f"the provided configuration: \n{pformat(diff)}\nNote that the checkpoint configuration "
-                "takes precedence.",
-            )
-
-        # Use the checkpoint config instead of the provided config (but keep `resume` parameter).
-        cfg = checkpoint_cfg
-        cfg.resume = True
-
-    elif Logger.get_last_checkpoint_dir(out_dir).exists():
-        raise RuntimeError(
-            f"The configured output directory {Logger.get_last_checkpoint_dir(out_dir)} already exists. If "
-            "you meant to resume training, please use `resume=true` in your command or yaml configuration."
-        )
-    # ===========================
+    cfg = handle_resume_logic(cfg, out_dir)

     set_global_seed(cfg.seed)
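Aside, not part of the patch: the config reconciliation that both the removed inline block and the new handle_resume_logic perform can be exercised in isolation. A minimal, self-contained sketch with toy configs (the values are made up; only the DeepDiff handling mirrors the code above):

from deepdiff import DeepDiff
from omegaconf import OmegaConf

# Toy stand-ins for the checkpoint config and the config provided on the command line.
checkpoint_cfg = OmegaConf.create({"resume": False, "training": {"save_freq": 1000000}})
provided_cfg = OmegaConf.create({"resume": True, "training": {"save_freq": 2000000}})

# Same handling as handle_resume_logic: diff the two configs, ignore the `resume`
# flag itself, warn about the rest, and let the checkpoint config take precedence.
diff = DeepDiff(OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(provided_cfg))
if "values_changed" in diff and "root['resume']" in diff["values_changed"]:
    del diff["values_changed"]["root['resume']"]
if len(diff) > 0:
    print(f"Differences detected, checkpoint configuration takes precedence:\n{diff}")

checkpoint_cfg.resume = True  # keep resuming enabled while using the checkpoint's values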
@@ -438,57 +541,14 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
         dataset_stats=None,
         pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
     )
-    # device=device,
-    # )
     assert isinstance(policy, nn.Module)

     optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg, policy)

-    # load last training state
-    # We can't use the logger function in `lerobot/common/logger.py`
-    # because it only loads the optimization step and not the interaction one
-    # to avoid altering that code, we will just load the optimization state manually
-    resume_interaction_step, resume_optimization_step = None, None
-    if cfg.resume:
-        training_state = torch.load(logger.last_checkpoint_dir / logger.training_state_file_name)
-        if type(training_state["optimizer"]) is dict:
-            assert set(training_state["optimizer"].keys()) == set(optimizers.keys()), (
-                "Optimizer dictionaries do not have the same keys during resume!"
-            )
-            for k, v in training_state["optimizer"].items():
-                optimizers[k].load_state_dict(v)
-        else:
-            optimizers.load_state_dict(training_state["optimizer"])
-        # Small hack to get the expected keys: use `get_global_random_state`.
-        set_global_random_state({k: training_state[k] for k in get_global_random_state()})
-        resume_optimization_step = training_state["step"]
-        resume_interaction_step = training_state["interaction_step"]
+    resume_optimization_step, resume_interaction_step = load_training_state(cfg, logger, optimizers)

-    num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
-    num_total_params = sum(p.numel() for p in policy.parameters())
-
-    log_output_dir(out_dir)
-    logging.info(f"{cfg.env.task=}")
-    logging.info(f"{cfg.training.online_steps=}")
-    logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
-    logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
+    log_training_info(cfg, out_dir, policy)

-    if not cfg.resume:
-        replay_buffer = ReplayBuffer(
-            capacity=cfg.training.online_buffer_capacity,
-            device=device,
-            state_keys=cfg.policy.input_shapes.keys(),
-        )
-    else:
-        # Reload replay buffer
-        dataset = LeRobotDataset(
-            repo_id=cfg.dataset_repo_id, local_files_only=True, root=logger.log_dir / "dataset"
-        )
-        replay_buffer = ReplayBuffer.from_lerobot_dataset(
-            lerobot_dataset=dataset,
-            capacity=cfg.training.online_buffer_capacity,
-            device=device,
-            state_keys=cfg.policy.input_shapes.keys(),
-        )
+    replay_buffer = initialize_replay_buffer(cfg, logger, device)

     batch_size = cfg.training.batch_size
     offline_replay_buffer = None
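Aside, not part of the patch: load_training_state, added above, reads a training_state file with an "optimizer" entry (a single optimizer state dict or a dict of them), a "step", an "interaction_step", and the global RNG-state entries. A self-contained sketch of that layout with a toy model; the optimizer names and the file name are hypothetical, and the RNG entries are omitted:

import torch
from torch import nn

policy = nn.Linear(4, 2)
optimizers = {
    "actor": torch.optim.Adam(policy.parameters(), lr=3e-4),
    "critic": torch.optim.Adam(policy.parameters(), lr=3e-4),
}

# Shape of the checkpoint that load_training_state expects (the real file also carries
# the entries consumed by set_global_random_state, omitted here).
training_state = {
    "step": 1000,              # learner optimization step
    "interaction_step": 5000,  # actor interaction step
    "optimizer": {k: opt.state_dict() for k, opt in optimizers.items()},
}
torch.save(training_state, "training_state.pt")

# Resume side, mirroring the dict branch of load_training_state.
loaded = torch.load("training_state.pt")
assert set(loaded["optimizer"].keys()) == set(optimizers.keys())
for k, v in loaded["optimizer"].items():
    optimizers[k].load_state_dict(v)
print(loaded["step"], loaded["interaction_step"])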
@@ -501,47 +561,19 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
     # )
     # batch_size: int = batch_size // 2 # We will sample from both replay buffer

-    actor_ip = cfg.actor_learner_config.actor_ip
-    port = cfg.actor_learner_config.port
-
-    server_thread = Thread(
-        target=stream_transitions_from_actor,
-        args=(
-            actor_ip,
-            port,
-        ),
-        daemon=True,
-    )
-    server_thread.start()
-
-    transition_thread = Thread(
-        target=add_actor_information_and_train,
-        daemon=True,
-        args=(
-            cfg,
-            device,
-            replay_buffer,
-            offline_replay_buffer,
-            batch_size,
-            optimizers,
-            policy,
-            policy_lock,
-            logger,
-            resume_optimization_step,
-            resume_interaction_step,
-        ),
-    )
-    transition_thread.start()
-
-    param_push_thread = Thread(
-        target=learner_push_parameters,
-        args=(policy, policy_lock, actor_ip, port, 15),
-        daemon=True,
-    )
-    param_push_thread.start()
-
-    transition_thread.join()
-    server_thread.join()
+    start_learner_threads(
+        cfg,
+        device,
+        replay_buffer,
+        offline_replay_buffer,
+        batch_size,
+        optimizers,
+        policy,
+        policy_lock,
+        logger,
+        resume_optimization_step,
+        resume_interaction_step,
+    )


 @hydra.main(version_base="1.2", config_name="default", config_path="../../configs")
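Aside, not part of the patch: the thread layout that start_learner_threads now encapsulates (a streamer filling a queue, a training loop draining it, and a periodic parameter push) can be illustrated with a toy, self-contained analogue. Names, workloads, and intervals below are made up; the real code streams gRPC transitions and pushes parameters every 15 seconds:

import queue
import threading
import time

transition_queue = queue.Queue()
stop_event = threading.Event()

def stream_transitions():
    # Stand-in for the gRPC stream that feeds transition_queue.
    for step in range(20):
        transition_queue.put({"step": step})
        time.sleep(0.01)
    stop_event.set()

def train_from_queue():
    # Stand-in for the training loop draining the queue.
    while not (stop_event.is_set() and transition_queue.empty()):
        try:
            transition = transition_queue.get(timeout=0.1)
        except queue.Empty:
            continue
        _ = transition["step"] * 2  # pretend optimization step

def push_parameters(interval_s=0.05):
    # Stand-in for the periodic parameter push (the real interval is 15 s).
    while not stop_event.is_set():
        time.sleep(interval_s)

threads = [
    threading.Thread(target=stream_transitions, daemon=True),
    threading.Thread(target=train_from_queue, daemon=True),
    threading.Thread(target=push_parameters, daemon=True),
]
for t in threads:
    t.start()
for t in threads:
    t.join()
print("all learner threads finished")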