From 5db0fd0665852005a0879be80804122561e44362 Mon Sep 17 00:00:00 2001 From: Petrit Avdylaj Date: Wed, 10 Apr 2024 10:41:50 +0200 Subject: [PATCH] fix: stream issue --- README.md | 2 +- ollama_proxy_server/main.py | 28 ++++++++++++++++++++++------ requirements.txt | 14 ++++++++------ 3 files changed, 31 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index de43124..3b661b9 100644 --- a/README.md +++ b/README.md @@ -54,7 +54,7 @@ ollama_proxy_add_user --users_list [path to the authorized `authorized_users.txt ### Starting the server Start the Ollama Proxy Server by running the following command in your terminal: ```bash -ollama_proxy_server --config [configuration file path] --users_list [users list file path] --port [port number to access the proxy] +python3 ollama_proxy_server/main.py --config [configuration file path] --users_list [users list file path] --port [port number to access the proxy] ``` The server will listen on port 808x, with x being the number of available ports starting from 0 (e.g., 8080, 8081, etc.). The first available port will be automatically selected if no other instance is running. diff --git a/ollama_proxy_server/main.py b/ollama_proxy_server/main.py index fa8a914..31c20d8 100644 --- a/ollama_proxy_server/main.py +++ b/ollama_proxy_server/main.py @@ -7,13 +7,12 @@ description: This is a proxy server that adds a security layer to one or multipl import configparser from http.server import BaseHTTPRequestHandler, HTTPServer +import json from socketserver import ThreadingMixIn from urllib.parse import urlparse, parse_qs from queue import Queue import requests -import threading import argparse -import base64 from ascii_colors import ASCIIColors from pathlib import Path import csv @@ -70,12 +69,23 @@ def main(): writer = csv.DictWriter(csvfile, fieldnames=fieldnames) row = {'time_stamp': str(datetime.datetime.now()), 'event':event, 'user_name': user, 'ip_address': ip_address, 'access': access, 'server': server, 'nb_queued_requests_on_server': nb_queued_requests_on_server, 'error': error} writer.writerow(row) - + def _send_response(self, response): self.send_response(response.status_code) - self.send_header('Content-type', response.headers['content-type']) + for key, value in response.headers.items(): + if key.lower() not in ['content-length', 'transfer-encoding', 'content-encoding']: + self.send_header(key, value) + self.send_header('Transfer-Encoding', 'chunked') self.end_headers() - self.wfile.write(response.content) + + try: + for chunk in response.iter_content(chunk_size=1024): + if chunk: + self.wfile.write(b"%X\r\n%s\r\n" % (len(chunk), chunk)) + self.wfile.flush() + self.wfile.write(b"0\r\n\r\n") + except BrokenPipeError: + pass def do_GET(self): self.log_request() @@ -146,7 +156,13 @@ def main(): self.add_access_log_entry(event="gen_request", user=self.user, ip_address=client_ip, access="Authorized", server=min_queued_server[0], nb_queued_requests_on_server=que.qsize()) que.put_nowait(1) try: - response = requests.request(self.command, min_queued_server[1]['url'] + path, params=get_params, data=post_params) + post_data_dict = {} + + if isinstance(post_data, bytes): + post_data_str = post_data.decode('utf-8') + post_data_dict = json.loads(post_data_str) + + response = requests.request(self.command, min_queued_server[1]['url'] + path, params=get_params, data=post_params, stream=post_data_dict.get("stream", False)) self._send_response(response) except Exception as ex: self.add_access_log_entry(event="gen_error",user=self.user, ip_address=client_ip, access="Authorized", server=min_queued_server[0], nb_queued_requests_on_server=que.qsize(),error=ex) diff --git a/requirements.txt b/requirements.txt index 38abe2f..63879a8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,8 @@ -configparser -queues -requests -urllib3 -requests -ascii_colors +ascii-colors==0.2.2 +certifi==2024.2.2 +charset-normalizer==3.3.2 +configparser==6.0.1 +idna==3.6 +queues==0.6.3 +requests==2.31.0 +urllib3==2.2.1 \ No newline at end of file