supporting accelerate
parent 8699a28be0 · commit 86e8c1d997
@@ -13,7 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any
+from typing import Any, Callable

 from lerobot.common.utils.utils import format_big_number
@@ -93,12 +93,14 @@ class MetricsTracker:
         num_episodes: int,
         metrics: dict[str, AverageMeter],
         initial_step: int = 0,
+        accelerator: Callable = None,
     ):
         self.__dict__.update({k: None for k in self.__keys__})
         self._batch_size = batch_size
         self._num_frames = num_frames
         self._avg_samples_per_ep = num_frames / num_episodes
         self.metrics = metrics
+        self.accelerator = accelerator

         self.steps = initial_step
         # A sample is an (observation,action) pair, where observation and action
@@ -128,7 +130,7 @@ class MetricsTracker:
         Updates metrics that depend on 'step' for one step.
         """
         self.steps += 1
-        self.samples += self._batch_size
+        self.samples += self._batch_size * (self.accelerator.num_processes if self.accelerator else 1)
         self.episodes = self.samples / self._avg_samples_per_ep
         self.epochs = self.samples / self._num_frames
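
Note: the tracker (this hunk appears to belong to the metrics utilities) now counts samples globally rather than per process. A minimal sketch of that accounting, assuming a data-parallel launch where every process steps through its own batch (the names below are illustrative, not part of the diff):

batch_size = 8
num_processes = 4  # e.g. accelerator.num_processes under `accelerate launch`
# One optimizer step consumes one batch on every process, so the global
# sample counter advances by batch_size * num_processes.
samples_per_step = batch_size * (num_processes if num_processes else 1)
print(samples_per_step)  # 32
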
@@ -16,7 +16,7 @@
 import random
 from contextlib import contextmanager
 from pathlib import Path
-from typing import Any, Generator
+from typing import Any, Generator, Callable

 import numpy as np
 import torch
@@ -163,14 +163,16 @@ def set_rng_state(random_state_dict: dict[str, Any]):
     torch.cuda.random.set_rng_state(random_state_dict["torch_cuda_random_state"])


-def set_seed(seed) -> None:
+def set_seed(seed, accelerator: Callable = None) -> None:
     """Set seed for reproducibility."""
     random.seed(seed)
     np.random.seed(seed)
     torch.manual_seed(seed)
     if torch.cuda.is_available():
         torch.cuda.manual_seed_all(seed)

+    if accelerator:
+        from accelerate.utils import set_seed
+        set_seed(seed)

 @contextmanager
 def seeded_context(seed: int) -> Generator[None, None, None]:
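
Note: with an accelerator, seeding defers to accelerate's own helper, which seeds Python, NumPy and torch (including CUDA) on every process. A condensed sketch of that behaviour, assuming `accelerate` is installed (the function name below is illustrative):

import random

import numpy as np
import torch


def seed_everything(seed: int, accelerator=None) -> None:
    # With accelerate, its helper covers python/numpy/torch in one call.
    if accelerator is not None:
        from accelerate.utils import set_seed as accelerate_set_seed
        accelerate_set_seed(seed)
        return
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


seed_everything(42)  # single-process path
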
@@ -20,10 +20,10 @@ import platform
 from copy import copy
 from datetime import datetime, timezone
 from pathlib import Path

+from typing import Callable
 import numpy as np
 import torch

-from typing import Any

 def none_or_int(value):
     if value == "None":
@@ -50,12 +50,12 @@ def auto_select_torch_device() -> torch.device:
     return torch.device("cpu")


-def get_safe_torch_device(try_device: str, log: bool = False) -> torch.device:
+def get_safe_torch_device(try_device: str, log: bool = False, accelerator: Callable = None) -> torch.device:
     """Given a string, return a torch.device with checks on whether the device is available."""
     match try_device:
         case "cuda":
             assert torch.cuda.is_available()
-            device = torch.device("cuda")
+            device = accelerator.device if accelerator else torch.device("cuda")
         case "mps":
             assert torch.backends.mps.is_available()
             device = torch.device("mps")
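
Note: under accelerate each rank trains on its own GPU, so "cuda" resolves to the per-process device (cuda:0, cuda:1, ...) exposed by the Accelerator rather than the bare default device. A small sketch of that selection rule (function name is illustrative, not part of the diff):

import torch


def resolve_device(try_device: str, accelerator=None) -> torch.device:
    # Prefer the device accelerate assigned to this process over a bare "cuda".
    if try_device == "cuda" and accelerator is not None:
        return accelerator.device
    return torch.device(try_device)


print(resolve_device("cpu"))  # cpu when no accelerator is passed
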
@@ -103,7 +103,7 @@ def is_amp_available(device: str):
         raise ValueError(f"Unknown device '{device}.")


-def init_logging():
+def init_logging(accelerator: Callable = None):
     def custom_format(record):
         dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
         fnameline = f"{record.pathname}:{record.lineno}"
@@ -120,7 +120,10 @@ def init_logging():
     console_handler = logging.StreamHandler()
     console_handler.setFormatter(formatter)
     logging.getLogger().addHandler(console_handler)

+    if accelerator is not None and not accelerator.is_main_process:
+        # Disable duplicate logging on non-main processes
+        logging.info(f"Setting logging level on non-main process {accelerator.process_index} to WARNING.")
+        logging.getLogger().setLevel(logging.WARNING)

 def format_big_number(num, precision=0):
     suffixes = ["", "K", "M", "B", "T", "Q"]
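
Note: the extra branch keeps INFO-level logs on the main process only, so each message is printed once instead of once per rank. A minimal sketch of the same rule (function name is illustrative):

import logging


def configure_rank_logging(accelerator=None) -> None:
    logging.basicConfig(level=logging.INFO)
    # Non-main ranks are raised to WARNING to avoid duplicated log lines.
    if accelerator is not None and not accelerator.is_main_process:
        logging.getLogger().setLevel(logging.WARNING)


configure_rank_logging()  # single-process call keeps INFO logging
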
@@ -216,3 +219,18 @@ def is_valid_numpy_dtype_string(dtype_str: str) -> bool:
     except TypeError:
         # If a TypeError is raised, the string is not a valid dtype
         return False
+
+def is_launched_with_accelerate() -> bool:
+    return "ACCELERATE_MIXED_PRECISION" in os.environ
+
+def get_accelerate_config(accelerator: Callable = None) -> dict[str, Any]:
+    config = {}
+    if not accelerator:
+        return config
+    config["num_processes"] = accelerator.num_processes
+    config["device"] = str(accelerator.device)
+    config["distributed_type"] = str(accelerator.distributed_type)
+    config["mixed_precision"] = accelerator.mixed_precision
+    config["gradient_accumulation_steps"] = accelerator.gradient_accumulation_steps
+
+    return config
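
Note: the launcher check relies on `accelerate launch` exporting the ACCELERATE_MIXED_PRECISION variable to its child processes; a plain `python` run does not define it. The config helper snapshots the distributed setup so it can be logged next to the training config. A sketch of how the two are meant to be used together (the sample values in the comment are hypothetical):

import os


def launched_with_accelerate() -> bool:
    return "ACCELERATE_MIXED_PRECISION" in os.environ


def accelerate_snapshot(accelerator=None) -> dict:
    if accelerator is None:
        return {}
    # On a hypothetical 2-GPU fp16 run this would look roughly like:
    # {"num_processes": 2, "device": "cuda:0", "mixed_precision": "fp16", ...}
    return {
        "num_processes": accelerator.num_processes,
        "device": str(accelerator.device),
        "distributed_type": str(accelerator.distributed_type),
        "mixed_precision": accelerator.mixed_precision,
        "gradient_accumulation_steps": accelerator.gradient_accumulation_steps,
    }


print(launched_with_accelerate())  # False under a plain `python` invocation
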
@@ -17,7 +17,7 @@ import logging
 import time
 from contextlib import nullcontext
 from pprint import pformat
-from typing import Any
+from typing import Any, Callable

 import torch
 from termcolor import colored
@@ -46,6 +46,8 @@ from lerobot.common.utils.utils import (
     get_safe_torch_device,
     has_method,
     init_logging,
+    get_accelerate_config,
+    is_launched_with_accelerate
 )
 from lerobot.common.utils.wandb_utils import WandBLogger
 from lerobot.configs import parser
@@ -63,15 +65,26 @@ def update_policy(
     lr_scheduler=None,
     use_amp: bool = False,
     lock=None,
+    accelerator: Callable = None,
 ) -> tuple[MetricsTracker, dict]:
     start_time = time.perf_counter()
     device = get_device_from_parameters(policy)
     policy.train()
-    with torch.autocast(device_type=device.type) if use_amp else nullcontext():
+    with torch.autocast(device_type=device.type) if use_amp and accelerator is None else nullcontext():
         loss, output_dict = policy.forward(batch)
         # TODO(rcadene): policy.unnormalize_outputs(out_dict)
-    grad_scaler.scale(loss).backward()

+    if accelerator:
+        accelerator.backward(loss)
+        accelerator.unscale_gradients(optimizer=optimizer)
+        grad_norm = torch.nn.utils.clip_grad_norm_(
+            policy.parameters(),
+            grad_clip_norm,
+            error_if_nonfinite=False,
+        )
+        optimizer.step()
+    else:
+        grad_scaler.scale(loss).backward()
         # Unscale the gradient of the optimizer's assigned params in-place **prior to gradient clipping**.
         grad_scaler.unscale_(optimizer)
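
Note: when an accelerator is passed, mixed precision and loss scaling are owned by accelerate (`accelerator.backward` scales the loss, `unscale_gradients` unscales before clipping), so the manual GradScaler path only runs in the single-process branch. A condensed sketch of the two paths, omitting the lock, lr_scheduler and timing bookkeeping (function name is illustrative):

import torch


def backward_and_step(loss, policy, optimizer, grad_clip_norm, grad_scaler=None, accelerator=None):
    if accelerator is not None:
        accelerator.backward(loss)  # accelerate applies loss scaling itself
        accelerator.unscale_gradients(optimizer=optimizer)
        grad_norm = torch.nn.utils.clip_grad_norm_(policy.parameters(), grad_clip_norm)
        optimizer.step()
    else:
        grad_scaler.scale(loss).backward()  # manual AMP path
        grad_scaler.unscale_(optimizer)
        grad_norm = torch.nn.utils.clip_grad_norm_(policy.parameters(), grad_clip_norm)
        grad_scaler.step(optimizer)
        grad_scaler.update()
    optimizer.zero_grad()
    return grad_norm
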
@@ -94,6 +107,10 @@ def update_policy(
     if lr_scheduler is not None:
         lr_scheduler.step()

+    if accelerator:
+        if has_method(accelerator.unwrap_model(policy, keep_fp32_wrapper=True), "update"): # FIXME(mshukor): avoid accelerator.unwrap_model ?
+            accelerator.unwrap_model(policy, keep_fp32_wrapper=True).update()
+    else:
         if has_method(policy, "update"):
             # To possibly update an internal buffer (for instance an Exponential Moving Average like in TDMPC).
             policy.update()
@@ -106,10 +123,14 @@


 @parser.wrap()
-def train(cfg: TrainPipelineConfig):
+def train(cfg: TrainPipelineConfig, accelerator: Callable = None):
     cfg.validate()
     logging.info(pformat(cfg.to_dict()))

+    if accelerator and not accelerator.is_main_process:
+        # Disable logging on non-main processes.
+        cfg.wandb.enable = False
+
     if cfg.wandb.enable and cfg.wandb.project:
         wandb_logger = WandBLogger(cfg)
     else:
|
|||
logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
|
||||
|
||||
if cfg.seed is not None:
|
||||
set_seed(cfg.seed)
|
||||
set_seed(cfg.seed, accelerator=accelerator)
|
||||
|
||||
# Check device is available
|
||||
device = get_safe_torch_device(cfg.device, log=True)
|
||||
device = get_safe_torch_device(cfg.device, log=True, accelerator=accelerator)
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
||||
|
@@ -141,7 +162,7 @@ def train(cfg: TrainPipelineConfig):
         device=device,
         ds_meta=dataset.meta,
     )
-
+    policy.to(device)
     logging.info("Creating optimizer and scheduler")
     optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
     grad_scaler = GradScaler(device, enabled=cfg.use_amp)
@@ -184,6 +205,10 @@ def train(cfg: TrainPipelineConfig):
         pin_memory=device.type != "cpu",
         drop_last=False,
     )
+    if accelerator:
+        policy, optimizer, dataloader, lr_scheduler = accelerator.prepare(
+            policy, optimizer, dataloader, lr_scheduler
+        )
     dl_iter = cycle(dataloader)

     policy.train()
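
Note: `accelerator.prepare` moves the policy to the per-process device (wrapping it for DDP when several processes are launched), shards the dataloader across ranks, and wraps the optimizer/scheduler so they cooperate with mixed precision. A self-contained toy version of the same call, assuming `accelerate` is installed:

import torch
from accelerate import Accelerator

accelerator = Accelerator()
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
dataset = torch.utils.data.TensorDataset(torch.randn(32, 4), torch.randn(32, 2))
dataloader = torch.utils.data.DataLoader(dataset, batch_size=8)

# Same pattern as the hunk above, on toy objects.
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
print(accelerator.device, accelerator.num_processes)
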
@@ -197,7 +222,7 @@ def train(cfg: TrainPipelineConfig):
     }

     train_tracker = MetricsTracker(
-        cfg.batch_size, dataset.num_frames, dataset.num_episodes, train_metrics, initial_step=step
+        cfg.batch_size, dataset.num_frames, dataset.num_episodes, train_metrics, initial_step=step, accelerator=accelerator
     )

     logging.info("Start offline training on a fixed dataset")
@@ -219,6 +244,7 @@ def train(cfg: TrainPipelineConfig):
             grad_scaler=grad_scaler,
             lr_scheduler=lr_scheduler,
             use_amp=cfg.use_amp,
+            accelerator=accelerator,
         )

         # Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
@@ -238,21 +264,26 @@ def train(cfg: TrainPipelineConfig):
                 wandb_logger.log_dict(wandb_log_dict, step)
             train_tracker.reset_averages()

-        if cfg.save_checkpoint and is_saving_step:
+        if cfg.save_checkpoint and is_saving_step and (not accelerator or accelerator.is_main_process):
             logging.info(f"Checkpoint policy after step {step}")
             checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
-            save_checkpoint(checkpoint_dir, step, cfg, policy, optimizer, lr_scheduler)
+            save_checkpoint(checkpoint_dir, step, cfg, policy if not accelerator else accelerator.unwrap_model(policy), optimizer, lr_scheduler)
             update_last_checkpoint(checkpoint_dir)
             if wandb_logger:
                 wandb_logger.log_policy(checkpoint_dir)

+        if accelerator:
+            accelerator.wait_for_everyone()
         if cfg.env and is_eval_step:
             step_id = get_step_identifier(step, cfg.steps)
             logging.info(f"Eval policy at step {step}")
-            with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
+            with (
+                torch.no_grad(),
+                torch.autocast(device_type=device.type) if cfg.use_amp and not accelerator else nullcontext(),
+            ):
                 eval_info = eval_policy(
                     eval_env,
-                    policy,
+                    policy if not accelerator else accelerator.unwrap_model(policy),
                     cfg.eval.n_episodes,
                     videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}",
                     max_episodes_rendered=4,
@@ -265,7 +296,7 @@ def train(cfg: TrainPipelineConfig):
                 "eval_s": AverageMeter("eval_s", ":.3f"),
             }
             eval_tracker = MetricsTracker(
-                cfg.batch_size, dataset.num_frames, dataset.num_episodes, eval_metrics, initial_step=step
+                cfg.batch_size, dataset.num_frames, dataset.num_episodes, eval_metrics, initial_step=step, accelerator=None
             )
             eval_tracker.eval_s = eval_info["aggregated"].pop("eval_s")
             eval_tracker.avg_sum_reward = eval_info["aggregated"].pop("avg_sum_reward")
@@ -283,4 +314,11 @@ def train(cfg: TrainPipelineConfig):

 if __name__ == "__main__":
     init_logging()
+    if is_launched_with_accelerate():
+        import accelerate
+        # We set step_scheduler_with_optimizer False to prevent accelerate from
+        # adjusting the lr_scheduler steps based on the num_processes
+        accelerator = accelerate.Accelerator(step_scheduler_with_optimizer=False)
+        train(accelerator=accelerator)
+    else:
         train()
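
Note: with this guard, distributed runs are started through the accelerate CLI (for example `accelerate launch --num_processes=2 <path to this training script>`; the script path is not shown in this diff). The launcher exports the ACCELERATE_MIXED_PRECISION variable that `is_launched_with_accelerate()` keys on, while a plain `python` invocation falls through to the original single-process `train()` call.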