From da68e7f1cc36700f9cd324b29278ce9bef7ff041 Mon Sep 17 00:00:00 2001 From: Saifeddine ALOUI Date: Tue, 16 Jan 2024 11:18:23 +0100 Subject: [PATCH] Update main.py --- ollama_proxy_server/main.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/ollama_proxy_server/main.py b/ollama_proxy_server/main.py index 5f81706..5686c2c 100644 --- a/ollama_proxy_server/main.py +++ b/ollama_proxy_server/main.py @@ -44,19 +44,19 @@ def main(): authorized_users = get_authorized_users(args.users_list) class RequestHandler(BaseHTTPRequestHandler): - def add_access_log_entry(self, user, ip_address, access): + def add_access_log_entry(self, user, ip_address, access, server): log_file_path = Path(sys.argv[1]) if not log_file_path.exists(): with open(log_file_path, mode='w', newline='') as csvfile: - fieldnames = ['time_stamp', 'user_name', 'ip_address','access'] + fieldnames = ['time_stamp', 'user_name', 'ip_address','access','server'] writer = csv.DictWriter(csvfile, fieldnames=fieldnames) writer.writeheader() with open(log_file_path, mode='a', newline='') as csvfile: fieldnames = ['time_stamp', 'user_name', 'ip_address'] writer = csv.DictWriter(csvfile, fieldnames=fieldnames) - row = {'time_stamp': str(datetime.datetime.now()), 'user_name': user, 'ip_address': ip_address, 'access': access} + row = {'time_stamp': str(datetime.datetime.now()), 'user_name': user, 'ip_address': ip_address, 'access': access, 'server': server} writer.writerow(row) def _send_response(self, response): @@ -82,12 +82,18 @@ def main(): user, key = token.split(':') # Check if the user and key are in the list of authorized users - return authorized_users.get(user) == key + if authorized_users.get(user) == key: + self.user = user + return True + else: + self.user = "unknown" + return False def proxy(self): if not self._validate_user_and_key(): ASCIIColors.red(f'User is not authorized') - # self.add_access_log_entry("unknown") + client_ip, client_port = self.client_address + self.add_access_log_entry(user="unknown", ip_address=client_ip, "Denied", "None") self.send_response(403) self.end_headers() return @@ -114,6 +120,7 @@ def main(): # Apply the queuing mechanism only for a specific endpoint. if path == '/api/generate': que = min_queued_server[1]['queue'] + self.add_access_log_entry(user=self.user, ip_address=client_ip, "Authorized", min_queued_server[0]) que.put_nowait(1) try: response = requests.request(self.command, min_queued_server[1]['url'] + path, params=get_params, data=post_params)