From 60668423791a93c9f6f10741b3df9d0d43d60118 Mon Sep 17 00:00:00 2001 From: Saifeddine ALOUI Date: Wed, 17 Jan 2024 14:27:02 +0100 Subject: [PATCH] Update main.py --- ollama_proxy_server/main.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/ollama_proxy_server/main.py b/ollama_proxy_server/main.py index b825443..253e415 100644 --- a/ollama_proxy_server/main.py +++ b/ollama_proxy_server/main.py @@ -47,19 +47,19 @@ def main(): authorized_users = get_authorized_users(args.users_list) class RequestHandler(BaseHTTPRequestHandler): - def add_access_log_entry(self, user, ip_address, access, server): + def add_access_log_entry(self, event, user, ip_address, access, server, nb_queued_requests_on_server, error=""): log_file_path = Path(args.log_path) if not log_file_path.exists(): with open(log_file_path, mode='w', newline='') as csvfile: - fieldnames = ['time_stamp', 'user_name', 'ip_address','access','server'] + fieldnames = ['time_stamp', 'event', 'user_name', 'ip_address', 'access', 'server', 'nb_queued_requests_on_server', 'error'] writer = csv.DictWriter(csvfile, fieldnames=fieldnames) writer.writeheader() with open(log_file_path, mode='a', newline='') as csvfile: - fieldnames = ['time_stamp', 'user_name', 'ip_address', 'access', 'server'] + fieldnames = ['time_stamp', 'event', 'user_name', 'ip_address', 'access', 'server', 'nb_queued_requests_on_server', 'error'] writer = csv.DictWriter(csvfile, fieldnames=fieldnames) - row = {'time_stamp': str(datetime.datetime.now()), 'user_name': user, 'ip_address': ip_address, 'access': access, 'server': server} + row = {'time_stamp': str(datetime.datetime.now()), 'event':event, 'user_name': user, 'ip_address': ip_address, 'access': access, 'server': server, 'nb_queued_requests_on_server': nb_queued_requests_on_server, 'error': error} writer.writerow(row) def _send_response(self, response): @@ -96,7 +96,7 @@ def main(): if not self._validate_user_and_key(): ASCIIColors.red(f'User is not authorized') client_ip, client_port = self.client_address - self.add_access_log_entry(user="unknown", ip_address=client_ip, access="Denied", server="None") + self.add_access_log_entry(user="unknown", ip_address=client_ip, access="Denied", server="None", -1) self.send_response(403) self.end_headers() return @@ -124,13 +124,16 @@ def main(): if path == '/api/generate': que = min_queued_server[1]['queue'] client_ip, client_port = self.client_address - self.add_access_log_entry(user=self.user, ip_address=client_ip, access="Authorized", server=min_queued_server[0]) + self.add_access_log_entry(event="gen_request", user=self.user, ip_address=client_ip, access="Authorized", server=min_queued_server[0], nb_queued_requests_on_server=que.qsize()) que.put_nowait(1) try: response = requests.request(self.command, min_queued_server[1]['url'] + path, params=get_params, data=post_params) self._send_response(response) + except Exception as ex: + self.add_access_log_entry(event="gen_error",user=self.user, ip_address=client_ip, access="Authorized", server=min_queued_server[0], nb_queued_requests_on_server=que.qsize(),error=ex) finally: que.get_nowait() + self.add_access_log_entry(event="gen_done",user=self.user, ip_address=client_ip, access="Authorized", server=min_queued_server[0], nb_queued_requests_on_server=que.qsize()) else: # For other endpoints, just mirror the request. response = requests.request(self.command, min_queued_server[1]['url'] + path, params=get_params, data=post_params)