139 lines
5.4 KiB
Python
139 lines
5.4 KiB
Python
"""
|
|
project: ollama_proxy_server
|
|
file: main.py
|
|
author: ParisNeo
|
|
description: This is a proxy server that adds a security layer to one or multiple ollama servers and routes each request to the right server in order to balance the load across them.
|
|
"""
|
|
|
|
import argparse
import base64
import configparser
import csv
import datetime
import hmac
import threading
from http.server import BaseHTTPRequestHandler, HTTPServer
from pathlib import Path
from queue import Queue
from socketserver import ThreadingMixIn
from urllib.parse import urlparse, parse_qs

import requests
from ascii_colors import ASCIIColors
|
|
|
|
def get_config(filename):
    """Load the backend server list from an INI file.

    Every section describes one backend and must provide a ``url`` key.

    Returns:
        A list of ``(section_name, {'url': <url>, 'queue': <fresh Queue>})``
        tuples, in the order the sections appear in the file. A file that
        cannot be opened is silently skipped by ``ConfigParser.read``, which
        yields an empty list.

    Raises:
        KeyError: if a section has no ``url`` entry.
    """
    parser = configparser.ConfigParser()
    parser.read(filename)

    backends = []
    for section_name in parser.sections():
        section = parser[section_name]
        # Each backend gets its own queue, used as an in-flight request counter.
        backends.append((section_name, {'url': section['url'], 'queue': Queue()}))
    return backends
|
|
|
|
# Read the authorized users and their keys from a file
|
|
def get_authorized_users(filename):
|
|
with open(filename, 'r') as f:
|
|
lines = f.readlines()
|
|
authorized_users = {}
|
|
for line in lines:
|
|
user, key = line.strip().split(':')
|
|
authorized_users[user] = key
|
|
return authorized_users
|
|
|
|
|
|
|
|
def main():
    """Parse command-line options and run the authenticating Ollama proxy.

    Options:
        --config      INI file listing the backend servers; every section must
                      define a ``url`` key.
        --log_path    CSV file that receives one row per authorized or denied
                      request (created with a header row if it does not exist).
        --users_list  Text file of ``user:key`` lines accepted as bearer tokens.
        --port        TCP port the proxy listens on.

    Clients must send ``Authorization: Bearer <user>:<key>``. Authorized
    requests are forwarded to the backend with the fewest in-flight
    ``/api/generate`` requests; all other requests are answered with ``403``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default="config.ini", help='Path to the config file')
    parser.add_argument('--log_path', default="access_log.txt", help='Path to the access log file')
    parser.add_argument('--users_list', default="authorized_users.txt", help='Path to the authorized users list')
    parser.add_argument('--port', type=int, default=8000, help='Port number for the server')
    args = parser.parse_args()

    servers = get_config(args.config)
    if not servers:
        # Backend selection needs at least one server; fail at startup rather
        # than with an IndexError on the first request.
        parser.error(f'no servers defined in {args.config}')
    authorized_users = get_authorized_users(args.users_list)

    # Handler threads run concurrently; serialize writes to the shared log file.
    log_lock = threading.Lock()
    log_fieldnames = ['time_stamp', 'user_name', 'ip_address', 'access']

    class RequestHandler(BaseHTTPRequestHandler):
        """Authenticate each request and relay it to the least-loaded backend."""

        def add_access_log_entry(self, user, ip_address, access):
            """Append one row to the CSV access log at ``--log_path``.

            The header row is written first when the file does not exist yet.
            """
            log_file_path = Path(args.log_path)
            row = {
                'time_stamp': str(datetime.datetime.now()),
                'user_name': user,
                'ip_address': ip_address,
                'access': access,
            }
            with log_lock:
                write_header = not log_file_path.exists()
                with open(log_file_path, mode='a', newline='', encoding='utf-8') as csvfile:
                    writer = csv.DictWriter(csvfile, fieldnames=log_fieldnames)
                    if write_header:
                        writer.writeheader()
                    writer.writerow(row)

        def _log_access(self, user, access):
            """Record an access decision; a failed log write is reported, not raised."""
            try:
                self.add_access_log_entry(user, self.client_address[0], access)
            except OSError as exc:
                self.log_error('Could not write access log %s: %s', args.log_path, exc)

        def _send_response(self, response):
            """Relay status code, content type and body of an upstream ``requests`` response."""
            self.send_response(response.status_code)
            # Not every upstream reply carries a content type; don't crash on it.
            content_type = response.headers.get('content-type')
            if content_type:
                self.send_header('Content-type', content_type)
            self.send_header('Content-Length', str(len(response.content)))
            self.end_headers()
            self.wfile.write(response.content)

        def do_GET(self):
            self.log_request()
            self.proxy()

        def do_POST(self):
            self.log_request()
            self.proxy()

        def _parse_bearer_token(self):
            """Return ``(user, key)`` from ``Authorization: Bearer user:key``, or ``None`` if absent/malformed."""
            auth_header = self.headers.get('Authorization')
            if not auth_header or not auth_header.startswith('Bearer '):
                return None
            token = auth_header[len('Bearer '):].strip()
            user, sep, key = token.partition(':')
            if not sep:
                return None
            return user, key

        def _validate_user_and_key(self):
            """Return True if the bearer token names an authorized user with the matching key."""
            credentials = self._parse_bearer_token()
            if credentials is None:
                return False
            user, key = credentials
            expected = authorized_users.get(user)
            if expected is None:
                return False
            # Constant-time comparison so response timing does not leak key prefixes.
            return hmac.compare_digest(expected.encode('utf-8'), key.encode('utf-8'))

        def _forward(self, target_url, params, body):
            """Send the request upstream and relay the reply; answer 502 if the backend fails."""
            try:
                response = requests.request(self.command, target_url, params=params, data=body)
            except requests.RequestException as exc:
                self.log_error('Upstream request to %s failed: %s', target_url, exc)
                self.send_error(502, 'Bad Gateway')
                return
            self._send_response(response)

        def proxy(self):
            """Authenticate the request, pick a backend and forward the request to it."""
            if not self._validate_user_and_key():
                ASCIIColors.red('User is not authorized')
                credentials = self._parse_bearer_token()
                self._log_access(credentials[0] if credentials else 'unknown', 'Denied')
                self.send_response(403)
                self.end_headers()
                return
            user, _ = self._parse_bearer_token()
            self._log_access(user, 'Authorized')

            url = urlparse(self.path)
            path = url.path
            get_params = parse_qs(url.query)

            body = None
            if self.command == "POST":
                try:
                    # A missing Content-Length header means an empty body.
                    content_length = int(self.headers.get('Content-Length', 0))
                    if content_length < 0:
                        raise ValueError(content_length)
                except ValueError:
                    self.send_error(400, 'Invalid Content-Length')
                    return
                body = self.rfile.read(content_length)

            # Backend with the fewest queued /api/generate requests; on a tie the
            # first server listed in the config wins.
            _, backend = min(servers, key=lambda entry: entry[1]['queue'].qsize())
            target_url = backend['url'] + path

            if path != '/api/generate':
                # Other endpoints are mirrored without counting towards the load.
                self._forward(target_url, get_params, body)
                return

            # The queue acts as an in-flight counter for this backend; the slot
            # is held until the reply has been relayed to the client.
            que = backend['queue']
            que.put_nowait(1)
            try:
                self._forward(target_url, get_params, body)
            finally:
                que.get_nowait()

    class ThreadedHTTPServer(ThreadingMixIn, HTTPServer):
        """HTTP server that handles each request in its own thread."""

        # Don't let in-flight handler threads keep the process alive on shutdown.
        daemon_threads = True

    print('Starting server')
    server = ThreadedHTTPServer(('', args.port), RequestHandler)
    print(f'Running server on port {args.port}')
    try:
        server.serve_forever()
    except KeyboardInterrupt:
        pass
    finally:
        server.server_close()
|
|
|
|
if __name__ == "__main__":
    # Start the proxy only when run as a script, not when imported.
    main()
|