Cleaned `learner_server.py`: extracted several helper functions to improve readability.
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
Parent: 367dfe51c6
Commit: 7c89bd1018
Training config:

```diff
@@ -21,7 +21,7 @@ training:
   eval_freq: 2500
   log_freq: 500
-  save_freq: 1000000
+  save_freq: 2000000
 
   online_steps: 1000000
   online_rollout_n_episodes: 10
 
```
`learner_server.py`:

```diff
@@ -34,8 +34,7 @@ from deepdiff import DeepDiff
 from omegaconf import DictConfig, OmegaConf
 from termcolor import colored
 from torch import nn
+from torch.optim.optimizer import Optimizer
-from lerobot.common.datasets.factory import make_dataset
-
 
 # TODO: Remove the import of maniskill
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
```
```diff
@@ -53,18 +52,164 @@ from lerobot.common.utils.utils import (
 )
 from lerobot.scripts.server.buffer import (
     ReplayBuffer,
-    concatenate_batch_transitions,
     move_state_dict_to_device,
     move_transition_to_device,
 )
 
 logging.basicConfig(level=logging.INFO)
 
-# TODO: Implement it in cleaner way maybe
 transition_queue = queue.Queue()
 interaction_message_queue = queue.Queue()
 
 
+def handle_resume_logic(cfg: DictConfig, out_dir: str) -> DictConfig:
+    if not cfg.resume:
+        if Logger.get_last_checkpoint_dir(out_dir).exists():
+            raise RuntimeError(
+                f"Output directory {Logger.get_last_checkpoint_dir(out_dir)} already exists. "
+                "Use `resume=true` to resume training."
+            )
+        return cfg
+
+    # if resume == True
+    checkpoint_dir = Logger.get_last_checkpoint_dir(out_dir)
+    if not checkpoint_dir.exists():
+        raise RuntimeError(f"No model checkpoint found in {checkpoint_dir} for resume=True")
+
+    checkpoint_cfg_path = str(Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml")
+    logging.info(
+        colored(
+            "Resume=True detected, resuming previous run",
+            color="yellow",
+            attrs=["bold"],
+        )
+    )
+
+    checkpoint_cfg = init_hydra_config(checkpoint_cfg_path)
+    diff = DeepDiff(OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg))
+
+    if "values_changed" in diff and "root['resume']" in diff["values_changed"]:
+        del diff["values_changed"]["root['resume']"]
+
+    if len(diff) > 0:
+        logging.warning(
+            f"Differences between the checkpoint config and the provided config detected: \n{pformat(diff)}\n"
+            "Checkpoint configuration takes precedence."
+        )
+
+    checkpoint_cfg.resume = True
+    return checkpoint_cfg
+
+
+def load_training_state(
+    cfg: DictConfig,
+    logger: Logger,
+    optimizers: Optimizer | dict,
+):
+    if not cfg.resume:
+        return None, None
+
+    training_state = torch.load(logger.last_checkpoint_dir / logger.training_state_file_name)
+
+    if isinstance(training_state["optimizer"], dict):
+        assert set(training_state["optimizer"].keys()) == set(optimizers.keys())
+        for k, v in training_state["optimizer"].items():
+            optimizers[k].load_state_dict(v)
+    else:
+        optimizers.load_state_dict(training_state["optimizer"])
+
+    set_global_random_state({k: training_state[k] for k in get_global_random_state()})
+    return training_state["step"], training_state["interaction_step"]
+
+
+def log_training_info(cfg: DictConfig, out_dir: str, policy: nn.Module) -> None:
+    num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
+    num_total_params = sum(p.numel() for p in policy.parameters())
+
+    log_output_dir(out_dir)
+    logging.info(f"{cfg.env.task=}")
+    logging.info(f"{cfg.training.online_steps=}")
+    logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
+    logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
+
+
+def initialize_replay_buffer(cfg: DictConfig, logger: Logger, device: str) -> ReplayBuffer:
+    if not cfg.resume:
+        return ReplayBuffer(
+            capacity=cfg.training.online_buffer_capacity,
+            device=device,
+            state_keys=cfg.policy.input_shapes.keys(),
+        )
+
+    dataset = LeRobotDataset(
+        repo_id=cfg.dataset_repo_id, local_files_only=True, root=logger.log_dir / "dataset"
+    )
+    return ReplayBuffer.from_lerobot_dataset(
+        lerobot_dataset=dataset,
+        capacity=cfg.training.online_buffer_capacity,
+        device=device,
+        state_keys=cfg.policy.input_shapes.keys(),
+    )
+
+
+def start_learner_threads(
+    cfg: DictConfig,
+    device: str,
+    replay_buffer: ReplayBuffer,
+    offline_replay_buffer: ReplayBuffer,
+    batch_size: int,
+    optimizers: dict,
+    policy: SACPolicy,
+    policy_lock: Lock,
+    logger: Logger,
+    resume_optimization_step: int | None = None,
+    resume_interaction_step: int | None = None,
+) -> None:
+    actor_ip = cfg.actor_learner_config.actor_ip
+    port = cfg.actor_learner_config.port
+
+    server_thread = Thread(
+        target=stream_transitions_from_actor,
+        args=(
+            actor_ip,
+            port,
+        ),
+        daemon=True,
+    )
+
+    transition_thread = Thread(
+        target=add_actor_information_and_train,
+        daemon=True,
+        args=(
+            cfg,
+            device,
+            replay_buffer,
+            offline_replay_buffer,
+            batch_size,
+            optimizers,
+            policy,
+            policy_lock,
+            logger,
+            resume_optimization_step,
+            resume_interaction_step,
+        ),
+    )
+
+    param_push_thread = Thread(
+        target=learner_push_parameters,
+        args=(policy, policy_lock, actor_ip, port, 15),
+        daemon=True,
+    )
+
+    server_thread.start()
+    transition_thread.start()
+    param_push_thread.start()
+
+    param_push_thread.join()
+    transition_thread.join()
+    server_thread.join()
+
+
 def stream_transitions_from_actor(host="127.0.0.1", port=50051):
     """
     Runs a gRPC client that listens for transition and interaction messages from an Actor service.
```
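For context on the config check inside `handle_resume_logic`: DeepDiff addresses each changed value by a path string such as `root['resume']`, which is why that single entry can be deleted before deciding whether the two configs genuinely diverge. A minimal sketch with made-up configs:

```python
# Sketch of the DeepDiff report format consumed by handle_resume_logic.
# The configs here are invented for illustration.
from deepdiff import DeepDiff

checkpoint_cfg = {"resume": False, "training": {"save_freq": 1_000_000}}
provided_cfg = {"resume": True, "training": {"save_freq": 2_000_000}}

diff = DeepDiff(checkpoint_cfg, provided_cfg)
# diff["values_changed"] maps path strings to old/new value pairs:
#   "root['resume']":                {"old_value": False, "new_value": True}
#   "root['training']['save_freq']": {"old_value": 1000000, "new_value": 2000000}

del diff["values_changed"]["root['resume']"]  # the resume flag itself is expected to differ
print(len(diff) > 0)  # True: save_freq still differs, so a warning would be logged
```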
```diff
@@ -373,49 +518,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
     logging.info(pformat(OmegaConf.to_container(cfg)))
 
     logger = Logger(cfg, out_dir, wandb_job_name=job_name)
+    cfg = handle_resume_logic(cfg, out_dir)
-    ## Handle resume by reloading the state of the policy and optimization
-    # If we are resuming a run, we need to check that a checkpoint exists in the log directory, and we need
-    # to check for any differences between the provided config and the checkpoint's config.
-    if cfg.resume:
-        if not Logger.get_last_checkpoint_dir(out_dir).exists():
-            raise RuntimeError(
-                "You have set resume=True, but there is no model checkpoint in "
-                f"{Logger.get_last_checkpoint_dir(out_dir)}"
-            )
-        checkpoint_cfg_path = str(Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml")
-        logging.info(
-            colored(
-                "You have set resume=True, indicating that you wish to resume a run",
-                color="yellow",
-                attrs=["bold"],
-            )
-        )
-        # Get the configuration file from the last checkpoint.
-        checkpoint_cfg = init_hydra_config(checkpoint_cfg_path)
-        # Check for differences between the checkpoint configuration and provided configuration.
-        diff = DeepDiff(OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg))
-        # Ignore the `resume` and parameters.
-        if "values_changed" in diff and "root['resume']" in diff["values_changed"]:
-            del diff["values_changed"]["root['resume']"]
-
-        # Log a warning about differences between the checkpoint configuration and the provided
-        # configuration.
-        if len(diff) > 0:
-            logging.warning(
-                "At least one difference was detected between the checkpoint configuration and "
-                f"the provided configuration: \n{pformat(diff)}\nNote that the checkpoint configuration "
-                "takes precedence.",
-            )
-        # Use the checkpoint config instead of the provided config (but keep `resume` parameter).
-        cfg = checkpoint_cfg
-        cfg.resume = True
-    elif Logger.get_last_checkpoint_dir(out_dir).exists():
-        raise RuntimeError(
-            f"The configured output directory {Logger.get_last_checkpoint_dir(out_dir)} already exists. If "
-            "you meant to resume training, please use `resume=true` in your command or yaml configuration."
-        )
-    # ===========================
 
     set_global_seed(cfg.seed)
```
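Both configs go through `OmegaConf.to_container` before the comparison, so DeepDiff operates on plain Python dicts and lists rather than `DictConfig` objects. A quick sketch (values invented):

```python
# OmegaConf round-trip used before diffing: DictConfig -> plain container.
from omegaconf import OmegaConf

cfg = OmegaConf.create({"resume": True, "training": {"save_freq": 2_000_000}})
plain = OmegaConf.to_container(cfg)

assert isinstance(plain, dict)
assert plain == {"resume": True, "training": {"save_freq": 2_000_000}}
```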
```diff
@@ -438,57 +541,14 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
         dataset_stats=None,
         pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
     )
-    # device=device,
-    # )
     assert isinstance(policy, nn.Module)
 
     optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg, policy)
-    # load last training state
-    # We can't use the logger function in `lerobot/common/logger.py`
-    # because it only loads the optimization step and not the interaction one
-    # to avoid altering that code, we will just load the optimization state manually
-    resume_interaction_step, resume_optimization_step = None, None
-    if cfg.resume:
-        training_state = torch.load(logger.last_checkpoint_dir / logger.training_state_file_name)
-        if type(training_state["optimizer"]) is dict:
-            assert set(training_state["optimizer"].keys()) == set(optimizers.keys()), (
-                "Optimizer dictionaries do not have the same keys during resume!"
-            )
-            for k, v in training_state["optimizer"].items():
-                optimizers[k].load_state_dict(v)
-        else:
-            optimizers.load_state_dict(training_state["optimizer"])
-        # Small hack to get the expected keys: use `get_global_random_state`.
-        set_global_random_state({k: training_state[k] for k in get_global_random_state()})
-        resume_optimization_step = training_state["step"]
-        resume_interaction_step = training_state["interaction_step"]
+    resume_optimization_step, resume_interaction_step = load_training_state(cfg, logger, optimizers)
 
-    num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
-    num_total_params = sum(p.numel() for p in policy.parameters())
-
-    log_output_dir(out_dir)
-    logging.info(f"{cfg.env.task=}")
-    logging.info(f"{cfg.training.online_steps=}")
-    logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
-    logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
+    log_training_info(cfg, out_dir, policy)
 
-    if not cfg.resume:
-        replay_buffer = ReplayBuffer(
-            capacity=cfg.training.online_buffer_capacity,
-            device=device,
-            state_keys=cfg.policy.input_shapes.keys(),
-        )
-    else:
-        # Reload replay buffer
-        dataset = LeRobotDataset(
-            repo_id=cfg.dataset_repo_id, local_files_only=True, root=logger.log_dir / "dataset"
-        )
-        replay_buffer = ReplayBuffer.from_lerobot_dataset(
-            lerobot_dataset=dataset,
-            capacity=cfg.training.online_buffer_capacity,
-            device=device,
-            state_keys=cfg.policy.input_shapes.keys(),
-        )
+    replay_buffer = initialize_replay_buffer(cfg, logger, device)
 
     batch_size = cfg.training.batch_size
     offline_replay_buffer = None
```
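`load_training_state` expects the training-state file saved alongside a checkpoint to carry the optimization step, the interaction step, the optimizer state(s), and the global RNG states. A sketch of a payload with that shape (the values, file name, and directory are illustrative; the real name comes from `logger.training_state_file_name`):

```python
# Illustrative writer counterpart for load_training_state; not the actual
# checkpointing code, just the expected payload shape.
from pathlib import Path

import torch

model = torch.nn.Linear(4, 2)  # hypothetical stand-in for the policy
optimizer = torch.optim.Adam(model.parameters())
checkpoint_dir = Path("outputs/train/checkpoints/last")
checkpoint_dir.mkdir(parents=True, exist_ok=True)

training_state = {
    "step": 120_000,                      # becomes resume_optimization_step
    "interaction_step": 450_000,          # becomes resume_interaction_step
    "optimizer": optimizer.state_dict(),  # or {name: state_dict} with several optimizers
    # ...plus the RNG entries, keyed as get_global_random_state() returns them.
}
torch.save(training_state, checkpoint_dir / "training_state.pth")
```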
```diff
@@ -501,47 +561,19 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
     # )
     # batch_size: int = batch_size // 2  # We will sample from both replay buffer
 
-    actor_ip = cfg.actor_learner_config.actor_ip
-    port = cfg.actor_learner_config.port
-
-    server_thread = Thread(
-        target=stream_transitions_from_actor,
-        args=(
-            actor_ip,
-            port,
-        ),
-        daemon=True,
-    )
-    server_thread.start()
-
-    transition_thread = Thread(
-        target=add_actor_information_and_train,
-        daemon=True,
-        args=(
-            cfg,
-            device,
-            replay_buffer,
-            offline_replay_buffer,
-            batch_size,
-            optimizers,
-            policy,
-            policy_lock,
-            logger,
-            resume_optimization_step,
-            resume_interaction_step,
-        ),
-    )
-    transition_thread.start()
-
-    param_push_thread = Thread(
-        target=learner_push_parameters,
-        args=(policy, policy_lock, actor_ip, port, 15),
-        daemon=True,
-    )
-    param_push_thread.start()
-
-    transition_thread.join()
-    server_thread.join()
+    start_learner_threads(
+        cfg,
+        device,
+        replay_buffer,
+        offline_replay_buffer,
+        batch_size,
+        optimizers,
+        policy,
+        policy_lock,
+        logger,
+        resume_optimization_step,
+        resume_interaction_step,
+    )
 
 
 @hydra.main(version_base="1.2", config_name="default", config_path="../../configs")
```
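`start_learner_threads` uses the usual start-then-join pattern for daemon threads: daemon threads die as soon as the main thread exits, so the launcher must join them to keep the process alive while they work. A generic, self-contained sketch:

```python
# Start-then-join pattern for daemon threads (generic illustration,
# independent of the learner's actual thread targets).
import time
from threading import Thread

def worker(name: str) -> None:
    time.sleep(0.1)  # stand-in for streaming / training / parameter pushes
    print(f"{name} finished")

threads = [
    Thread(target=worker, args=(name,), daemon=True)
    for name in ("server", "transition", "param_push")
]
for t in threads:
    t.start()
for t in threads:
    t.join()  # without the joins, the daemon threads could be killed mid-work
```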