Cleaned `learner_server.py`. Added several block functions to improve readability.
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
parent 367dfe51c6
commit 7c89bd1018
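The new module-level helpers introduced below (`handle_resume_logic`, `load_training_state`, `log_training_info`, `initialize_replay_buffer`, `start_learner_threads`) replace inline blocks of `train()`. As a reading aid, here is a minimal sketch of the resulting call order; the name `train_outline` and its explicit parameters are illustrative only (policy/optimizer construction, the policy lock, and the batch size are unchanged by this commit and simply assumed to exist):

# Illustrative sketch only (not the verbatim train() from this diff): shows the
# call order of the helpers added in this commit, assuming they are defined in
# lerobot/scripts/server/learner_server.py as in the hunks below.
def train_outline(cfg, out_dir, job_name, device, policy, optimizers,
                  batch_size, offline_replay_buffer, policy_lock):
    logger = Logger(cfg, out_dir, wandb_job_name=job_name)

    # Raises if the output dir already holds a checkpoint without resume=true;
    # on resume, returns the checkpoint config (with resume kept True).
    cfg = handle_resume_logic(cfg, out_dir)
    set_global_seed(cfg.seed)

    # (None, None) on a fresh run; restores optimizer/RNG state and returns the
    # step counters when resuming.
    resume_optimization_step, resume_interaction_step = load_training_state(cfg, logger, optimizers)
    log_training_info(cfg, out_dir, policy)

    # Fresh buffer, or one rebuilt from the on-disk LeRobotDataset when resuming.
    replay_buffer = initialize_replay_buffer(cfg, logger, device)

    # Spawns the gRPC transition stream, the training loop, and the parameter-push
    # thread, then joins them (this call blocks).
    start_learner_threads(
        cfg, device, replay_buffer, offline_replay_buffer, batch_size,
        optimizers, policy, policy_lock, logger,
        resume_optimization_step, resume_interaction_step,
    )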
@@ -21,7 +21,7 @@ training:
  eval_freq: 2500
  log_freq: 500
  save_freq: 1000000
  save_freq: 2000000

  online_steps: 1000000
  online_rollout_n_episodes: 10
@@ -34,8 +34,7 @@ from deepdiff import DeepDiff
from omegaconf import DictConfig, OmegaConf
from termcolor import colored
from torch import nn

from lerobot.common.datasets.factory import make_dataset
from torch.optim.optimizer import Optimizer

# TODO: Remove the import of maniskill
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
@@ -53,18 +52,164 @@ from lerobot.common.utils.utils import (
)
from lerobot.scripts.server.buffer import (
    ReplayBuffer,
    concatenate_batch_transitions,
    move_state_dict_to_device,
    move_transition_to_device,
)

logging.basicConfig(level=logging.INFO)

# TODO: Implement it in cleaner way maybe
transition_queue = queue.Queue()
interaction_message_queue = queue.Queue()


def handle_resume_logic(cfg: DictConfig, out_dir: str) -> DictConfig:
    if not cfg.resume:
        if Logger.get_last_checkpoint_dir(out_dir).exists():
            raise RuntimeError(
                f"Output directory {Logger.get_last_checkpoint_dir(out_dir)} already exists. "
                "Use `resume=true` to resume training."
            )
        return cfg

    # if resume == True
    checkpoint_dir = Logger.get_last_checkpoint_dir(out_dir)
    if not checkpoint_dir.exists():
        raise RuntimeError(f"No model checkpoint found in {checkpoint_dir} for resume=True")

    checkpoint_cfg_path = str(Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml")
    logging.info(
        colored(
            "Resume=True detected, resuming previous run",
            color="yellow",
            attrs=["bold"],
        )
    )

    checkpoint_cfg = init_hydra_config(checkpoint_cfg_path)
    diff = DeepDiff(OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg))

    if "values_changed" in diff and "root['resume']" in diff["values_changed"]:
        del diff["values_changed"]["root['resume']"]

    if len(diff) > 0:
        logging.warning(
            f"Differences between the checkpoint config and the provided config detected: \n{pformat(diff)}\n"
            "Checkpoint configuration takes precedence."
        )

    checkpoint_cfg.resume = True
    return checkpoint_cfg


def load_training_state(
    cfg: DictConfig,
    logger: Logger,
    optimizers: Optimizer | dict,
):
    if not cfg.resume:
        return None, None

    training_state = torch.load(logger.last_checkpoint_dir / logger.training_state_file_name)

    if isinstance(training_state["optimizer"], dict):
        assert set(training_state["optimizer"].keys()) == set(optimizers.keys())
        for k, v in training_state["optimizer"].items():
            optimizers[k].load_state_dict(v)
    else:
        optimizers.load_state_dict(training_state["optimizer"])

    set_global_random_state({k: training_state[k] for k in get_global_random_state()})
    return training_state["step"], training_state["interaction_step"]


def log_training_info(cfg: DictConfig, out_dir: str, policy: nn.Module) -> None:
    num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
    num_total_params = sum(p.numel() for p in policy.parameters())

    log_output_dir(out_dir)
    logging.info(f"{cfg.env.task=}")
    logging.info(f"{cfg.training.online_steps=}")
    logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
    logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")


def initialize_replay_buffer(cfg: DictConfig, logger: Logger, device: str) -> ReplayBuffer:
    if not cfg.resume:
        return ReplayBuffer(
            capacity=cfg.training.online_buffer_capacity,
            device=device,
            state_keys=cfg.policy.input_shapes.keys(),
        )

    dataset = LeRobotDataset(
        repo_id=cfg.dataset_repo_id, local_files_only=True, root=logger.log_dir / "dataset"
    )
    return ReplayBuffer.from_lerobot_dataset(
        lerobot_dataset=dataset,
        capacity=cfg.training.online_buffer_capacity,
        device=device,
        state_keys=cfg.policy.input_shapes.keys(),
    )


def start_learner_threads(
    cfg: DictConfig,
    device: str,
    replay_buffer: ReplayBuffer,
    offline_replay_buffer: ReplayBuffer,
    batch_size: int,
    optimizers: dict,
    policy: SACPolicy,
    policy_lock: Lock,
    logger: Logger,
    resume_optimization_step: int | None = None,
    resume_interaction_step: int | None = None,
) -> None:
    actor_ip = cfg.actor_learner_config.actor_ip
    port = cfg.actor_learner_config.port

    server_thread = Thread(
        target=stream_transitions_from_actor,
        args=(
            actor_ip,
            port,
        ),
        daemon=True,
    )

    transition_thread = Thread(
        target=add_actor_information_and_train,
        daemon=True,
        args=(
            cfg,
            device,
            replay_buffer,
            offline_replay_buffer,
            batch_size,
            optimizers,
            policy,
            policy_lock,
            logger,
            resume_optimization_step,
            resume_interaction_step,
        ),
    )

    param_push_thread = Thread(
        target=learner_push_parameters,
        args=(policy, policy_lock, actor_ip, port, 15),
        daemon=True,
    )

    server_thread.start()
    transition_thread.start()
    param_push_thread.start()

    param_push_thread.join()
    transition_thread.join()
    server_thread.join()


def stream_transitions_from_actor(host="127.0.0.1", port=50051):
    """
    Runs a gRPC client that listens for transition and interaction messages from an Actor service.
@@ -373,49 +518,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
    logging.info(pformat(OmegaConf.to_container(cfg)))

    logger = Logger(cfg, out_dir, wandb_job_name=job_name)

    ## Handle resume by reloading the state of the policy and optimization
    # If we are resuming a run, we need to check that a checkpoint exists in the log directory, and we need
    # to check for any differences between the provided config and the checkpoint's config.
    if cfg.resume:
        if not Logger.get_last_checkpoint_dir(out_dir).exists():
            raise RuntimeError(
                "You have set resume=True, but there is no model checkpoint in "
                f"{Logger.get_last_checkpoint_dir(out_dir)}"
            )
        checkpoint_cfg_path = str(Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml")
        logging.info(
            colored(
                "You have set resume=True, indicating that you wish to resume a run",
                color="yellow",
                attrs=["bold"],
            )
        )
        # Get the configuration file from the last checkpoint.
        checkpoint_cfg = init_hydra_config(checkpoint_cfg_path)
        # Check for differences between the checkpoint configuration and provided configuration.
        diff = DeepDiff(OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg))
        # Ignore the `resume` and parameters.
        if "values_changed" in diff and "root['resume']" in diff["values_changed"]:
            del diff["values_changed"]["root['resume']"]

        # Log a warning about differences between the checkpoint configuration and the provided
        # configuration.
        if len(diff) > 0:
            logging.warning(
                "At least one difference was detected between the checkpoint configuration and "
                f"the provided configuration: \n{pformat(diff)}\nNote that the checkpoint configuration "
                "takes precedence.",
            )
        # Use the checkpoint config instead of the provided config (but keep `resume` parameter).
        cfg = checkpoint_cfg
        cfg.resume = True
    elif Logger.get_last_checkpoint_dir(out_dir).exists():
        raise RuntimeError(
            f"The configured output directory {Logger.get_last_checkpoint_dir(out_dir)} already exists. If "
            "you meant to resume training, please use `resume=true` in your command or yaml configuration."
        )
    # ===========================
    cfg = handle_resume_logic(cfg, out_dir)

    set_global_seed(cfg.seed)
@@ -438,57 +541,14 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
        dataset_stats=None,
        pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
    )
    # device=device,
    # )
    assert isinstance(policy, nn.Module)

    optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg, policy)
    # load last training state
    # We can't use the logger function in `lerobot/common/logger.py`
    # because it only loads the optimization step and not the interaction one
    # to avoid altering that code, we will just load the optimization state manually
    resume_interaction_step, resume_optimization_step = None, None
    if cfg.resume:
        training_state = torch.load(logger.last_checkpoint_dir / logger.training_state_file_name)
        if type(training_state["optimizer"]) is dict:
            assert set(training_state["optimizer"].keys()) == set(optimizers.keys()), (
                "Optimizer dictionaries do not have the same keys during resume!"
            )
            for k, v in training_state["optimizer"].items():
                optimizers[k].load_state_dict(v)
        else:
            optimizers.load_state_dict(training_state["optimizer"])
        # Small hack to get the expected keys: use `get_global_random_state`.
        set_global_random_state({k: training_state[k] for k in get_global_random_state()})
        resume_optimization_step = training_state["step"]
        resume_interaction_step = training_state["interaction_step"]
    resume_optimization_step, resume_interaction_step = load_training_state(cfg, logger, optimizers)

    num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
    num_total_params = sum(p.numel() for p in policy.parameters())
    log_training_info(cfg, out_dir, policy)

    log_output_dir(out_dir)
    logging.info(f"{cfg.env.task=}")
    logging.info(f"{cfg.training.online_steps=}")
    logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
    logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")

    if not cfg.resume:
        replay_buffer = ReplayBuffer(
            capacity=cfg.training.online_buffer_capacity,
            device=device,
            state_keys=cfg.policy.input_shapes.keys(),
        )
    else:
        # Reload replay buffer
        dataset = LeRobotDataset(
            repo_id=cfg.dataset_repo_id, local_files_only=True, root=logger.log_dir / "dataset"
        )
        replay_buffer = ReplayBuffer.from_lerobot_dataset(
            lerobot_dataset=dataset,
            capacity=cfg.training.online_buffer_capacity,
            device=device,
            state_keys=cfg.policy.input_shapes.keys(),
        )
    replay_buffer = initialize_replay_buffer(cfg, logger, device)
    batch_size = cfg.training.batch_size
    offline_replay_buffer = None
@@ -501,47 +561,19 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
    # )
    # batch_size: int = batch_size // 2 # We will sample from both replay buffer

    actor_ip = cfg.actor_learner_config.actor_ip
    port = cfg.actor_learner_config.port

    server_thread = Thread(
        target=stream_transitions_from_actor,
        args=(
            actor_ip,
            port,
        ),
        daemon=True,
    start_learner_threads(
        cfg,
        device,
        replay_buffer,
        offline_replay_buffer,
        batch_size,
        optimizers,
        policy,
        policy_lock,
        logger,
        resume_optimization_step,
        resume_interaction_step,
    )
    server_thread.start()

    transition_thread = Thread(
        target=add_actor_information_and_train,
        daemon=True,
        args=(
            cfg,
            device,
            replay_buffer,
            offline_replay_buffer,
            batch_size,
            optimizers,
            policy,
            policy_lock,
            logger,
            resume_optimization_step,
            resume_interaction_step,
        ),
    )
    transition_thread.start()

    param_push_thread = Thread(
        target=learner_push_parameters,
        args=(policy, policy_lock, actor_ip, port, 15),
        daemon=True,
    )
    param_push_thread.start()

    transition_thread.join()
    server_thread.join()


@hydra.main(version_base="1.2", config_name="default", config_path="../../configs")