Added support for checkpointing the policy. We can save and load the policy state dict, the optimizer states, the optimization step, and the interaction step.
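A minimal sketch of the intended round-trip, assuming a SAC-style dict of optimizers; the real `Logger` keeps the model weights and `training_state.pth` in separate files, and the helper names here are illustrative:

```python
import torch
from torch import nn


def save_training_state(path, policy: nn.Module, optimizers, optimization_step, interaction_step):
    # SAC keeps one optimizer per component (actor, critic, temperature),
    # so `optimizers` may be a plain Optimizer or a dict of them.
    if isinstance(optimizers, dict):
        optimizer_state = {k: opt.state_dict() for k, opt in optimizers.items()}
    else:
        optimizer_state = optimizers.state_dict()
    torch.save(
        {
            "policy": policy.state_dict(),
            "optimizer": optimizer_state,
            "step": optimization_step,
            "interaction_step": interaction_step,
        },
        path,
    )


def load_training_state(path, policy: nn.Module, optimizers):
    state = torch.load(path)
    policy.load_state_dict(state["policy"])
    if isinstance(state["optimizer"], dict):
        for k, opt_state in state["optimizer"].items():
            optimizers[k].load_state_dict(opt_state)
    else:
        optimizers.load_state_dict(state["optimizer"])
    return state["step"], state.get("interaction_step", 0)
```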

Added functions for converting the replay buffer to and from LeRobotDataset. To save the replay buffer, we first convert it to the LeRobotDataset format and store it locally; to restore it, we load that local dataset and convert it back into a replay buffer.
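Roughly, the conversion pair is used like this (the import paths and arguments mirror the code further down in this diff and should be treated as assumptions):

```python
from pathlib import Path

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.scripts.server.buffer import ReplayBuffer


def save_buffer(replay_buffer: ReplayBuffer, repo_id: str, root: Path, fps: int) -> None:
    # Serialize the in-memory transitions as a LeRobotDataset stored locally under `root`.
    replay_buffer.to_lerobot_dataset(repo_id, fps=fps, root=root)


def load_buffer(repo_id: str, root: Path, capacity: int, device: str, state_keys: list) -> ReplayBuffer:
    # Load the locally saved dataset and rebuild a replay buffer from it.
    dataset = LeRobotDataset(repo_id=repo_id, local_files_only=True, root=root)
    return ReplayBuffer.from_lerobot_dataset(
        lerobot_dataset=dataset,
        capacity=capacity,
        device=device,
        state_keys=state_keys,
    )
```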

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
Michel Aractingi 2025-01-30 17:39:41 +00:00
parent e856ffc91e
commit 367dfe51c6
7 changed files with 217 additions and 92 deletions

View File

@@ -275,6 +275,7 @@ def get_hf_features_from_features(features: dict) -> datasets.Features:
             hf_features[key] = datasets.Sequence(
                 length=ft["shape"][0], feature=datasets.Value(dtype=ft["dtype"])
             )
+    # TODO: (alibers, azouitine) Add support for ft["shape"] == 0 as Value
     return datasets.Features(hf_features)

View File

@@ -174,18 +174,32 @@ class Logger:
         self,
         save_dir: Path,
         train_step: int,
-        optimizer: Optimizer,
+        optimizer: Optimizer | dict,
         scheduler: LRScheduler | None,
+        interaction_step: int | None = None,
     ):
         """Checkpoint the global training_step, optimizer state, scheduler state, and random state.
         All of these are saved as "training_state.pth" under the checkpoint directory.
         """
+        # In SAC, for example, we have a dictionary of torch.optim.Optimizer
+        if type(optimizer) is dict:
+            optimizer_state_dict = {}
+            for k in optimizer:
+                optimizer_state_dict[k] = optimizer[k].state_dict()
+        else:
+            optimizer_state_dict = optimizer.state_dict()
         training_state = {
             "step": train_step,
-            "optimizer": optimizer.state_dict(),
+            "optimizer": optimizer_state_dict,
             **get_global_random_state(),
         }
+        # Interaction step is related to the distributed training code
+        # In that setup, we have two kinds of steps, the online step of the env and the optimization step
+        # We need to save both in order to resume the optimization properly and not break the logs dependent on the interaction step
+        if interaction_step is not None:
+            training_state["interaction_step"] = interaction_step
         if scheduler is not None:
             training_state["scheduler"] = scheduler.state_dict()
         torch.save(training_state, save_dir / self.training_state_file_name)
@@ -197,6 +211,7 @@ class Logger:
         optimizer: Optimizer,
         scheduler: LRScheduler | None,
         identifier: str,
+        interaction_step: int | None = None,
     ):
         """Checkpoint the model weights and the training state."""
         checkpoint_dir = self.checkpoints_dir / str(identifier)
@@ -208,16 +223,24 @@ class Logger:
         self.save_model(
             checkpoint_dir / self.pretrained_model_dir_name, policy, wandb_artifact_name=wandb_artifact_name
         )
-        self.save_training_state(checkpoint_dir, train_step, optimizer, scheduler)
+        self.save_training_state(checkpoint_dir, train_step, optimizer, scheduler, interaction_step)
         os.symlink(checkpoint_dir.absolute(), self.last_checkpoint_dir)

-    def load_last_training_state(self, optimizer: Optimizer, scheduler: LRScheduler | None) -> int:
+    def load_last_training_state(self, optimizer: Optimizer | dict, scheduler: LRScheduler | None) -> int:
         """
         Given the last checkpoint in the logging directory, load the optimizer state, scheduler state, and
         random state, and return the global training step.
         """
         training_state = torch.load(self.last_checkpoint_dir / self.training_state_file_name)
-        optimizer.load_state_dict(training_state["optimizer"])
+        # For the case where the optimizer is a dictionary of optimizers (e.g., sac)
+        if type(training_state["optimizer"]) is dict:
+            assert set(training_state["optimizer"].keys()) == set(optimizer.keys()), (
+                "Optimizer dictionaries do not have the same keys during resume!"
+            )
+            for k, v in training_state["optimizer"].items():
+                optimizer[k].load_state_dict(v)
+        else:
+            optimizer.load_state_dict(training_state["optimizer"])
         if scheduler is not None:
             scheduler.load_state_dict(training_state["scheduler"])
         elif "scheduler" in training_state:
@@ -228,7 +251,7 @@ class Logger:
         set_global_random_state({k: training_state[k] for k in get_global_random_state()})
         return training_state["step"]

-    def log_dict(self, d, step:int | None = None, mode="train", custom_step_key: str | None = None):
+    def log_dict(self, d, step: int | None = None, mode="train", custom_step_key: str | None = None):
         """Log a dictionary of metrics to WandB."""
         assert mode in {"train", "eval"}
         # TODO(alexander-soare): Add local text log.
@@ -236,10 +259,9 @@ class Logger:
            raise ValueError("Either step or custom_step_key must be provided.")

         if self._wandb is not None:
-            # NOTE: This is not simple. Wandb step is it must always monotonically increase and it
             # NOTE: This is not simple. Wandb step is it must always monotonically increase and it
             # increases with each wandb.log call, but in the case of asynchronous RL for example,
             # multiple time steps is possible for example, the interaction step with the environment,
             # the training step, the evaluation step, etc. So we need to define a custom step key
             # to log the correct step for each metric.
             if custom_step_key is not None and self._wandb_custom_step_key is None:
@@ -247,7 +269,7 @@ class Logger:
                 # custom step.
                 self._wandb_custom_step_key = f"{mode}/{custom_step_key}"
                 self._wandb.define_metric(self._wandb_custom_step_key, hidden=True)

             for k, v in d.items():
                 if not isinstance(v, (int, float, str, wandb.Table)):
                     logging.warning(
@@ -267,8 +289,6 @@ class Logger:
                 self._wandb.log({f"{mode}/{k}": v}, step=step)

     def log_video(self, video_path: str, step: int, mode: str = "train"):
         assert mode in {"train", "eval"}
         assert self._wandb is not None

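For reference, the custom-step pattern described in the NOTE above boils down to the following plain wandb calls; the run and metric names are illustrative:

```python
import wandb

run = wandb.init(project="hil-serl-example")

# Declare a hidden custom step metric and tie the train/ metrics to it, so their
# x-axis is the interaction step rather than wandb's monotonically increasing step.
run.define_metric("train/Interaction step", hidden=True)
run.define_metric("train/*", step_metric="train/Interaction step")

# Log each value together with its custom step instead of passing `step=`.
run.log({"train/reward": 1.0, "train/Interaction step": 120})
run.log({"train/reward": 1.5, "train/Interaction step": 240})
```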
View File

@@ -106,7 +106,7 @@ def make_policy(
         # Make a fresh policy.
         # HACK: We pass *args and **kwargs to the policy constructor to allow for additional arguments
         # for example device for the sac policy.
-        policy = policy_cls(*args, **kwargs, config=policy_cfg, dataset_stats=dataset_stats)
+        policy = policy_cls(config=policy_cfg, dataset_stats=dataset_stats)
     else:
         # Load a pretrained policy and override the config if needed (for example, if there are inference-time
         # hyperparameters that we want to vary).

View File

@@ -29,6 +29,7 @@ from torch import Tensor

 from lerobot.common.policies.normalize import Normalize, Unnormalize
 from lerobot.common.policies.sac.configuration_sac import SACConfig
+from lerobot.common.policies.utils import get_device_from_parameters


 class SACPolicy(
@@ -44,7 +45,6 @@ class SACPolicy(
         self,
         config: SACConfig | None = None,
         dataset_stats: dict[str, dict[str, Tensor]] | None = None,
-        device: str = "cpu",
     ):
         super().__init__()
@@ -92,7 +92,6 @@ class SACPolicy(
                     for _ in range(config.num_critics)
                 ]
             ),
-            device=device,
         )

         self.critic_target = CriticEnsemble(
@@ -106,7 +105,6 @@ class SACPolicy(
                     for _ in range(config.num_critics)
                 ]
             ),
-            device=device,
         )

         self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
@@ -115,7 +113,6 @@ class SACPolicy(
             encoder=encoder_actor,
             network=MLP(input_dim=encoder_actor.output_dim, **config.actor_network_kwargs),
             action_dim=config.output_shapes["action"][0],
-            device=device,
             encoder_is_shared=config.shared_encoder,
             **config.policy_kwargs,
         )
@@ -123,13 +120,22 @@ class SACPolicy(
             config.target_entropy = -np.prod(config.output_shapes["action"][0]) / 2  # (-dim(A)/2)

         # TODO (azouitine): Handle the case where the temparameter is a fixed
-        self.log_alpha = torch.zeros(1, requires_grad=True, device=device)
+        # TODO (michel-aractingi): Put the log_alpha in cuda by default because otherwise
+        # it triggers "can't optimize a non-leaf Tensor"
+        self.log_alpha = torch.zeros(1, requires_grad=True, device=torch.device("cuda:0"))
         self.temperature = self.log_alpha.exp().item()

     def reset(self):
         """Reset the policy"""
         pass

+    def to(self, *args, **kwargs):
+        """Override .to(device) method to involve moving the log_alpha and fixed_std"""
+        if self.actor.fixed_std is not None:
+            self.actor.fixed_std = self.actor.fixed_std.to(*args, **kwargs)
+        self.log_alpha = self.log_alpha.to(*args, **kwargs)
+        super().to(*args, **kwargs)
+
     @torch.no_grad()
     def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         """Select action for inference/evaluation"""
@@ -308,17 +314,12 @@ class CriticEnsemble(nn.Module):
         encoder: Optional[nn.Module],
         network_list: nn.Module,
         init_final: Optional[float] = None,
-        device: str = "cpu",
     ):
         super().__init__()
-        self.device = torch.device(device)
         self.encoder = encoder
         self.network_list = network_list
         self.init_final = init_final
-        # for network in network_list:
-        #     network.to(self.device)

         # Find the last Linear layer's output dimension
         for layer in reversed(network_list[0].net):
             if isinstance(layer, nn.Linear):
@@ -329,29 +330,28 @@ class CriticEnsemble(nn.Module):
         self.output_layers = []
         if init_final is not None:
             for _ in network_list:
-                output_layer = nn.Linear(out_features, 1, device=device)
+                output_layer = nn.Linear(out_features, 1)
                 nn.init.uniform_(output_layer.weight, -init_final, init_final)
                 nn.init.uniform_(output_layer.bias, -init_final, init_final)
                 self.output_layers.append(output_layer)
         else:
             self.output_layers = []
             for _ in network_list:
-                output_layer = nn.Linear(out_features, 1, device=device)
+                output_layer = nn.Linear(out_features, 1)
                 orthogonal_init()(output_layer.weight)
                 self.output_layers.append(output_layer)

         self.output_layers = nn.ModuleList(self.output_layers)
-        self.to(self.device)

     def forward(
         self,
         observations: dict[str, torch.Tensor],
         actions: torch.Tensor,
     ) -> torch.Tensor:
+        device = get_device_from_parameters(self)
         # Move each tensor in observations to device
-        observations = {k: v.to(self.device) for k, v in observations.items()}
-        actions = actions.to(self.device)
+        observations = {k: v.to(device) for k, v in observations.items()}
+        actions = actions.to(device)

         obs_enc = observations if self.encoder is None else self.encoder(observations)
@@ -375,17 +375,15 @@ class Policy(nn.Module):
         fixed_std: Optional[torch.Tensor] = None,
         init_final: Optional[float] = None,
         use_tanh_squash: bool = False,
-        device: str = "cpu",
         encoder_is_shared: bool = False,
     ):
         super().__init__()
-        self.device = torch.device(device)
         self.encoder = encoder
         self.network = network
         self.action_dim = action_dim
         self.log_std_min = log_std_min
         self.log_std_max = log_std_max
-        self.fixed_std = fixed_std.to(self.device) if fixed_std is not None else None
+        self.fixed_std = fixed_std
         self.use_tanh_squash = use_tanh_squash
         self.parameters_to_optimize = []
@@ -417,8 +415,6 @@ class Policy(nn.Module):
             orthogonal_init()(self.std_layer.weight)
             self.parameters_to_optimize += list(self.std_layer.parameters())

-        self.to(self.device)
-
     def forward(
         self,
         observations: torch.Tensor,
@@ -460,7 +456,8 @@ class Policy(nn.Module):

     def get_features(self, observations: torch.Tensor) -> torch.Tensor:
         """Get encoded features from observations"""
-        observations = observations.to(self.device)
+        device = get_device_from_parameters(self)
+        observations = observations.to(device)
         if self.encoder is not None:
             with torch.inference_mode():
                 return self.encoder(observations)
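The device handling above follows a common pattern: infer the device from the module's parameters at call time, and override `to()` for tensors that are not registered as parameters or buffers. A generic sketch of that pattern (not the SACPolicy code itself):

```python
import torch
from torch import nn


def get_device_from_parameters(module: nn.Module) -> torch.device:
    # Assumes all parameters of the module live on a single device.
    return next(iter(module.parameters())).device


class TemperatureHead(nn.Module):
    """Toy module that keeps a plain tensor attribute next to real parameters."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 1)
        # Plain tensor attribute: nn.Module.to() will not move it automatically.
        self.log_alpha = torch.zeros(1, requires_grad=True)

    def to(self, *args, **kwargs):
        # Move the unregistered tensor by hand, then defer to the default behavior.
        # Caveat: .to() on a leaf tensor that requires grad returns a non-leaf copy,
        # so build any optimizer over log_alpha only after the final device move.
        self.log_alpha = self.log_alpha.to(*args, **kwargs)
        return super().to(*args, **kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pull inputs onto whatever device the parameters currently live on.
        device = get_device_from_parameters(self)
        return self.proj(x.to(device)) * self.log_alpha.exp()
```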

View File

@@ -8,7 +8,7 @@
 #   env.gym.obs_type=environment_state_agent_pos \

 seed: 1
-dataset_repo_id: null
+dataset_repo_id: aractingi/hil-serl-maniskill-pushcube

 training:
   # Offline training dataloader
@@ -21,7 +21,7 @@ training:
   eval_freq: 2500
   log_freq: 500
-  save_freq: 50000
+  save_freq: 1000000

   online_steps: 1000000
   online_rollout_n_episodes: 10

View File

@@ -152,7 +152,7 @@ def serve_actor_service(port=50052):
     server.wait_for_termination()


-def act_with_policy(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
+def act_with_policy(cfg: DictConfig):
     """
     Executes policy interaction within the environment.
@@ -161,8 +161,6 @@ def act_with_policy(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
     Args:
         cfg (DictConfig): Configuration settings for the interaction process.
-        out_dir (Optional[str]): Directory to store output logs or results. Defaults to None.
-        job_name (Optional[str]): Name of the job for logging or tracking purposes. Defaults to None.
     """

     logging.info("make_env online")
@@ -189,9 +187,10 @@ def act_with_policy(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
         # Hack: But if we do online training, we do not need dataset_stats
         dataset_stats=None,
         # TODO: Handle resume training
-        pretrained_policy_name_or_path=None,
-        device=device,
     )
+    # pretrained_policy_name_or_path=None,
+    # device=device,
+    # )
     assert isinstance(policy, nn.Module)

     # HACK for maniskill
@@ -295,11 +294,7 @@ def actor_cli(cfg: dict):
     policy_thread = Thread(
         target=act_with_policy,
         daemon=True,
-        args=(
-            cfg,
-            hydra.core.hydra_config.HydraConfig.get().run.dir,
-            hydra.core.hydra_config.HydraConfig.get().job.name,
-        ),
+        args=(cfg,),
     )
     policy_thread.start()
     policy_thread.join()

View File

@@ -18,6 +18,7 @@ import io
 import logging
 import pickle
 import queue
+import shutil
 import time
 from pprint import pformat
 from threading import Lock, Thread
@@ -29,18 +30,25 @@ import hilserl_pb2  # type: ignore
 import hilserl_pb2_grpc  # type: ignore
 import hydra
 import torch
+from deepdiff import DeepDiff
 from omegaconf import DictConfig, OmegaConf
+from termcolor import colored
 from torch import nn

-# TODO: Remove the import of maniskill
 from lerobot.common.datasets.factory import make_dataset
+
+# TODO: Remove the import of maniskill
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
 from lerobot.common.logger import Logger, log_output_dir
 from lerobot.common.policies.factory import make_policy
 from lerobot.common.policies.sac.modeling_sac import SACPolicy
 from lerobot.common.utils.utils import (
     format_big_number,
+    get_global_random_state,
     get_safe_torch_device,
+    init_hydra_config,
     init_logging,
+    set_global_random_state,
     set_global_seed,
 )
 from lerobot.scripts.server.buffer import (
@@ -127,10 +135,9 @@ def add_actor_information_and_train(
     optimizers: dict[str, torch.optim.Optimizer],
     policy: nn.Module,
     policy_lock: Lock,
-    buffer_lock: Lock,
-    offline_buffer_lock: Lock,
-    logger_lock: Lock,
     logger: Logger,
+    resume_optimization_step: int | None = None,
+    resume_interaction_step: int | None = None,
 ):
     """
     Handles data transfer from the actor to the learner, manages training updates,
@@ -159,16 +166,17 @@ def add_actor_information_and_train(
         optimizers (Dict[str, torch.optim.Optimizer]): A dictionary of optimizers (`"actor"`, `"critic"`, `"temperature"`).
         policy (nn.Module): The reinforcement learning policy with critic, actor, and temperature parameters.
         policy_lock (Lock): A threading lock to ensure safe policy updates.
-        buffer_lock (Lock): A threading lock to safely access the online replay buffer.
-        offline_buffer_lock (Lock): A threading lock to safely access the offline replay buffer.
-        logger_lock (Lock): A threading lock to safely log training metrics.
         logger (Logger): Logger instance for tracking training progress.
+        resume_optimization_step (int | None): In the case of resume training, start from the last optimization step reached.
+        resume_interaction_step (int | None): In the case of resume training, shift the interaction step with the last saved step in order to not break logging.
     """
     # NOTE: This function doesn't have a single responsibility, it should be split into multiple functions
     # in the future. The reason why we did that is the GIL in Python. It's super slow the performance
     # are divided by 200. So we need to have a single thread that does all the work.
     time.time()
-    optimization_step = 0
+    interaction_message, transition = None, None
+    optimization_step = resume_optimization_step if resume_optimization_step is not None else 0
+    interaction_step_shift = resume_interaction_step if resume_interaction_step is not None else 0

     while True:
         while not transition_queue.empty():
             transition_list = transition_queue.get()
@@ -178,6 +186,8 @@ def add_actor_information_and_train(
         while not interaction_message_queue.empty():
             interaction_message = interaction_message_queue.get()
+            # If cfg.resume, shift the interaction step with the last checkpointed step in order to not break the logging
+            interaction_message["Interaction step"] += interaction_step_shift
             logger.log_dict(interaction_message, mode="train", custom_step_key="Interaction step")

         if len(replay_buffer) < cfg.training.online_step_before_learning:
@@ -186,9 +196,9 @@ def add_actor_information_and_train(
         for _ in range(cfg.policy.utd_ratio - 1):
             batch = replay_buffer.sample(batch_size)

-            if cfg.dataset_repo_id is not None:
-                batch_offline = offline_replay_buffer.sample(batch_size)
-                batch = concatenate_batch_transitions(batch, batch_offline)
+            # if cfg.offline_dataset_repo_id is not None:
+            # batch_offline = offline_replay_buffer.sample(batch_size)
+            # batch = concatenate_batch_transitions(batch, batch_offline)

             actions = batch["action"]
             rewards = batch["reward"]
@@ -210,11 +220,11 @@ def add_actor_information_and_train(
         batch = replay_buffer.sample(batch_size)

-        if cfg.dataset_repo_id is not None:
-            batch_offline = offline_replay_buffer.sample(batch_size)
-            batch = concatenate_batch_transitions(
-                left_batch_transitions=batch, right_batch_transition=batch_offline
-            )
+        # if cfg.offline_dataset_repo_id is not None:
+        # batch_offline = offline_replay_buffer.sample(batch_size)
+        # batch = concatenate_batch_transitions(
+        # left_batch_transitions=batch, right_batch_transition=batch_offline
+        # )

         actions = batch["action"]
         rewards = batch["reward"]
@@ -274,6 +284,39 @@ def add_actor_information_and_train(
         if optimization_step % cfg.training.log_freq == 0:
             logging.info(f"[LEARNER] Number of optimization step: {optimization_step}")

+        if cfg.training.save_checkpoint and (
+            optimization_step % cfg.training.save_freq == 0 or optimization_step == cfg.training.online_steps
+        ):
+            logging.info(f"Checkpoint policy after step {optimization_step}")
+            # Note: Save with step as the identifier, and format it to have at least 6 digits but more if
+            # needed (choose 6 as a minimum for consistency without being overkill).
+            _num_digits = max(6, len(str(cfg.training.online_steps)))
+            step_identifier = f"{optimization_step:0{_num_digits}d}"
+            interaction_step = (
+                interaction_message["Interaction step"] if interaction_message is not None else 0
+            )
+            logger.save_checkpoint(
+                optimization_step,
+                policy,
+                optimizers,
+                scheduler=None,
+                identifier=step_identifier,
+                interaction_step=interaction_step,
+            )
+
+            # TODO: temporarily save replay buffer here, remove later when on the robot
+            # We want to control this with the keyboard inputs
+            dataset_dir = logger.log_dir / "dataset"
+            if dataset_dir.exists() and dataset_dir.is_dir():
+                shutil.rmtree(
+                    dataset_dir,
+                )
+            replay_buffer.to_lerobot_dataset(
+                cfg.dataset_repo_id, fps=cfg.fps, root=logger.log_dir / "dataset"
+            )
+
+            logging.info("Resume training")
+

 def make_optimizers_and_scheduler(cfg, policy: nn.Module):
     """
@@ -330,7 +373,49 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
     logging.info(pformat(OmegaConf.to_container(cfg)))

     logger = Logger(cfg, out_dir, wandb_job_name=job_name)
-    logger_lock = Lock()
+
+    ## Handle resume by reloading the state of the policy and optimization
+    # If we are resuming a run, we need to check that a checkpoint exists in the log directory, and we need
+    # to check for any differences between the provided config and the checkpoint's config.
+    if cfg.resume:
+        if not Logger.get_last_checkpoint_dir(out_dir).exists():
+            raise RuntimeError(
+                "You have set resume=True, but there is no model checkpoint in "
+                f"{Logger.get_last_checkpoint_dir(out_dir)}"
+            )
+        checkpoint_cfg_path = str(Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml")
+        logging.info(
+            colored(
+                "You have set resume=True, indicating that you wish to resume a run",
+                color="yellow",
+                attrs=["bold"],
+            )
+        )
+        # Get the configuration file from the last checkpoint.
+        checkpoint_cfg = init_hydra_config(checkpoint_cfg_path)
+        # Check for differences between the checkpoint configuration and provided configuration.
+        diff = DeepDiff(OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg))
+        # Ignore the `resume` parameter.
+        if "values_changed" in diff and "root['resume']" in diff["values_changed"]:
+            del diff["values_changed"]["root['resume']"]
+        # Log a warning about differences between the checkpoint configuration and the provided
+        # configuration.
+        if len(diff) > 0:
+            logging.warning(
+                "At least one difference was detected between the checkpoint configuration and "
+                f"the provided configuration: \n{pformat(diff)}\nNote that the checkpoint configuration "
+                "takes precedence.",
+            )
+        # Use the checkpoint config instead of the provided config (but keep `resume` parameter).
+        cfg = checkpoint_cfg
+        cfg.resume = True
+    elif Logger.get_last_checkpoint_dir(out_dir).exists():
+        raise RuntimeError(
+            f"The configured output directory {Logger.get_last_checkpoint_dir(out_dir)} already exists. If "
+            "you meant to resume training, please use `resume=true` in your command or yaml configuration."
+        )
+    # ===========================

     set_global_seed(cfg.seed)
@@ -346,20 +431,38 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
     ### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters
     # TODO: At some point we should just need make sac policy
     policy_lock = Lock()
-    with logger_lock:
-        policy: SACPolicy = make_policy(
-            hydra_cfg=cfg,
-            # dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
-            # Hack: But if we do online traning, we do not need dataset_stats
-            dataset_stats=None,
-            pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
-            device=device,
-        )
+    policy: SACPolicy = make_policy(
+        hydra_cfg=cfg,
+        # dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
+        # Hack: But if we do online traning, we do not need dataset_stats
+        dataset_stats=None,
+        pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
+    )
+    # device=device,
+    # )
     assert isinstance(policy, nn.Module)

     optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg, policy)

-    # TODO: Handle resume
+    # load last training state
+    # We can't use the logger function in `lerobot/common/logger.py`
+    # because it only loads the optimization step and not the interaction one
+    # to avoid altering that code, we will just load the optimization state manually
+    resume_interaction_step, resume_optimization_step = None, None
+    if cfg.resume:
+        training_state = torch.load(logger.last_checkpoint_dir / logger.training_state_file_name)
+        if type(training_state["optimizer"]) is dict:
+            assert set(training_state["optimizer"].keys()) == set(optimizers.keys()), (
+                "Optimizer dictionaries do not have the same keys during resume!"
+            )
+            for k, v in training_state["optimizer"].items():
+                optimizers[k].load_state_dict(v)
+        else:
+            optimizers.load_state_dict(training_state["optimizer"])
+        # Small hack to get the expected keys: use `get_global_random_state`.
+        set_global_random_state({k: training_state[k] for k in get_global_random_state()})
+        resume_optimization_step = training_state["step"]
+        resume_interaction_step = training_state["interaction_step"]
+
     num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
     num_total_params = sum(p.numel() for p in policy.parameters())
@@ -369,24 +472,34 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
     logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
     logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")

-    buffer_lock = Lock()
-    replay_buffer = ReplayBuffer(
-        capacity=cfg.training.online_buffer_capacity, device=device, state_keys=cfg.policy.input_shapes.keys()
-    )
+    if not cfg.resume:
+        replay_buffer = ReplayBuffer(
+            capacity=cfg.training.online_buffer_capacity,
+            device=device,
+            state_keys=cfg.policy.input_shapes.keys(),
+        )
+    else:
+        # Reload replay buffer
+        dataset = LeRobotDataset(
+            repo_id=cfg.dataset_repo_id, local_files_only=True, root=logger.log_dir / "dataset"
+        )
+        replay_buffer = ReplayBuffer.from_lerobot_dataset(
+            lerobot_dataset=dataset,
+            capacity=cfg.training.online_buffer_capacity,
+            device=device,
+            state_keys=cfg.policy.input_shapes.keys(),
+        )

     batch_size = cfg.training.batch_size

-    offline_buffer_lock = None
     offline_replay_buffer = None
-    if cfg.dataset_repo_id is not None:
-        logging.info("make_dataset offline buffer")
-        offline_dataset = make_dataset(cfg)
-        logging.info("Convertion to a offline replay buffer")
-        offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
-            offline_dataset, device=device, state_keys=cfg.policy.input_shapes.keys()
-        )
-        offline_buffer_lock = Lock()
-        batch_size: int = batch_size // 2  # We will sample from both replay buffer
+    # if cfg.dataset_repo_id is not None:
+    # logging.info("make_dataset offline buffer")
+    # offline_dataset = make_dataset(cfg)
+    # logging.info("Convertion to a offline replay buffer")
+    # offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
+    # offline_dataset, device=device, state_keys=cfg.policy.input_shapes.keys()
+    # )
+    # batch_size: int = batch_size // 2  # We will sample from both replay buffer

     actor_ip = cfg.actor_learner_config.actor_ip
     port = cfg.actor_learner_config.port
@@ -413,10 +526,9 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
             optimizers,
             policy,
             policy_lock,
-            buffer_lock,
-            offline_buffer_lock,
-            logger_lock,
             logger,
+            resume_optimization_step,
+            resume_interaction_step,
         ),
     )
     transition_thread.start()