From ad3379a73ab88616434889644afcdf58a4d3302f Mon Sep 17 00:00:00 2001
From: Cadene
Date: Fri, 5 Apr 2024 10:59:32 +0000
Subject: [PATCH] fix memory leak due to itertools.cycle

---
 lerobot/common/datasets/utils.py |  9 +++++++++
 lerobot/common/utils.py          | 12 ++++++++++++
 lerobot/scripts/train.py         |  9 +++++++--
 3 files changed, 28 insertions(+), 2 deletions(-)

diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py
index 6b207b4d..18b091cd 100644
--- a/lerobot/common/datasets/utils.py
+++ b/lerobot/common/datasets/utils.py
@@ -203,3 +203,12 @@ def compute_or_load_stats(dataset, batch_size=32, max_num_samples=None):
     torch.save(stats, stats_path)
 
     return stats
+
+
+def cycle(iterable):
+    iterator = iter(iterable)
+    while True:
+        try:
+            yield next(iterator)
+        except StopIteration:
+            iterator = iter(iterable)
diff --git a/lerobot/common/utils.py b/lerobot/common/utils.py
index 7ed29334..e3e22832 100644
--- a/lerobot/common/utils.py
+++ b/lerobot/common/utils.py
@@ -95,3 +95,15 @@ def init_hydra_config(config_path: str, overrides: list[str] | None = None) -> D
     )
     cfg = hydra.compose(Path(config_path).stem, overrides)
     return cfg
+
+
+def print_cuda_memory_usage():
+    import gc
+
+    gc.collect()
+    # Also clear the cache if you want to fully release the memory
+    torch.cuda.empty_cache()
+    print("Current GPU Memory Allocated: {:.2f} MB".format(torch.cuda.memory_allocated(0) / 1024**2))
+    print("Maximum GPU Memory Allocated: {:.2f} MB".format(torch.cuda.max_memory_allocated(0) / 1024**2))
+    print("Current GPU Memory Reserved: {:.2f} MB".format(torch.cuda.memory_reserved(0) / 1024**2))
+    print("Maximum GPU Memory Reserved: {:.2f} MB".format(torch.cuda.max_memory_reserved(0) / 1024**2))
diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index 602fa5ab..cca26902 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -1,5 +1,4 @@
 import logging
-from itertools import cycle
 from pathlib import Path
 
 import hydra
@@ -7,10 +6,16 @@ import numpy as np
 import torch
 
 from lerobot.common.datasets.factory import make_dataset
+from lerobot.common.datasets.utils import cycle
 from lerobot.common.envs.factory import make_env
 from lerobot.common.logger import Logger, log_output_dir
 from lerobot.common.policies.factory import make_policy
-from lerobot.common.utils import format_big_number, get_safe_torch_device, init_logging, set_global_seed
+from lerobot.common.utils import (
+    format_big_number,
+    get_safe_torch_device,
+    init_logging,
+    set_global_seed,
+)
 from lerobot.scripts.eval import eval_policy
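
Why the fix works (explanatory note, not part of the patch): itertools.cycle saves a copy of every element it yields so it can replay them once the iterable is exhausted. When the iterable is a PyTorch DataLoader, every batch of tensors ever produced stays referenced, so memory grows steadily over a long training run. The cycle generator added to lerobot/common/datasets/utils.py instead restarts the underlying iterator on StopIteration and keeps no references to past batches. Below is a minimal sketch of the intended usage pattern; the toy dataset and loop are hypothetical illustrations, not code taken from train.py.

    import torch
    from torch.utils.data import DataLoader, TensorDataset


    def cycle(iterable):
        # Same definition as the helper added in lerobot/common/datasets/utils.py:
        # restart the iterator when it is exhausted instead of caching its elements.
        iterator = iter(iterable)
        while True:
            try:
                yield next(iterator)
            except StopIteration:
                iterator = iter(iterable)


    # Hypothetical stand-in for a real LeRobot dataset, for illustration only.
    dataset = TensorDataset(torch.randn(100, 8))
    dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

    # itertools.cycle(dataloader) would keep a reference to every batch it has yielded;
    # this generator keeps memory bounded to the current batch (plus DataLoader prefetch).
    dl_iter = cycle(dataloader)
    for step in range(10):
        (batch,) = next(dl_iter)
        # ... forward/backward pass on `batch` would go here ...

The only behavioural difference from caching the batches is that the DataLoader is re-iterated (and reshuffled, when shuffle=True) each time it is exhausted, which is usually what a step-based training loop wants anyway.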