93 lines
2.7 KiB
Python
93 lines
2.7 KiB
Python
#!/usr/bin/env python
|
|
|
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import threading
|
|
import time
|
|
from contextlib import ContextDecorator
|
|
|
|
|
|
class TimeBenchmark(ContextDecorator):
    """
    Measures execution time using a context manager or decorator.

    This class supports both context manager and decorator usage. Timings are kept in
    thread-local storage: each thread records and reads back its own measurement, so
    threads sharing one instance do not overwrite each other's results. Nested entries
    within the same thread (including recursive calls of a decorated function) are also
    handled: every exit is paired with its own matching entry.

    Args:
        print: If True, prints the elapsed time upon exiting the context or completing the
            function. Defaults to False.

    Examples:

        Using as a context manager:

        >>> benchmark = TimeBenchmark()
        >>> with benchmark:
        ...     time.sleep(1)
        >>> print(f"Block took {benchmark.result:.4f} seconds")  # doctest: +SKIP
        Block took 1.0001 seconds

        Using with multithreading:

        ```python
        import threading

        benchmark = TimeBenchmark()

        def context_manager_example():
            with benchmark:
                time.sleep(0.01)
            print(f"Block took {benchmark.result_ms:.2f} milliseconds")

        threads = []
        for _ in range(3):
            t1 = threading.Thread(target=context_manager_example)
            threads.append(t1)

        for t in threads:
            t.start()

        for t in threads:
            t.join()
        ```

        Expected output (exact values vary slightly from run to run, each close to 10 ms):
        Block took 10.05 milliseconds
        Block took 10.07 milliseconds
        Block took 10.04 milliseconds
    """

    def __init__(self, print=False):
        # `print` shadows the builtin only inside this method; the parameter name is
        # kept as-is for backward compatibility with existing keyword callers.
        self.local = threading.local()
        self.print_time = print

    def __enter__(self):
        # Per-thread stack of start times. Each thread gets its own stack, so nested or
        # recursive entries pair with their own exit instead of clobbering one slot.
        starts = getattr(self.local, "start_stack", None)
        if starts is None:
            starts = self.local.start_stack = []
        self.local.start_time = time.perf_counter()
        starts.append(self.local.start_time)
        return self

    def __exit__(self, *exc):
        self.local.end_time = time.perf_counter()
        # Pop the start time of the innermost still-open entry: the one this exit closes.
        self.local.start_time = self.local.start_stack.pop()
        self.local.elapsed_time = self.local.end_time - self.local.start_time
        if self.print_time:
            print(f"Elapsed time: {self.local.elapsed_time:.4f} seconds")
        # Never suppress exceptions raised inside the timed block.
        return False

    @property
    def result(self):
        """Elapsed seconds of the calling thread's most recently exited timing, or None."""
        return getattr(self.local, "elapsed_time", None)

    @property
    def result_ms(self):
        """Same as `result` but in milliseconds; None if this thread has measured nothing yet."""
        seconds = self.result
        return None if seconds is None else seconds * 1e3
|