feat: Implement shared HTTP client with retry logic and optimized connection settings for GitHub/Gitea API calls to improve reliability and performance.
Build and Push CodeReview / build (push) Has been cancelled Details

This commit is contained in:
vinland100 2026-02-06 14:12:51 +08:00
parent f0981ae187
commit 5491714cce
1 changed files with 448 additions and 235 deletions

View File

@@ -4,6 +4,7 @@
import os import os
import asyncio import asyncio
import logging
import httpx import httpx
from typing import List, Dict, Any, Optional from typing import List, Dict, Any, Optional
from datetime import datetime, timezone from datetime import datetime, timezone
@@ -15,7 +16,65 @@ from app.models.audit import AuditTask, AuditIssue
from app.models.project import Project from app.models.project import Project
from app.services.llm.service import LLMService from app.services.llm.service import LLMService
from app.core.config import settings from app.core.config import settings
from app.core.file_filter import is_text_file as core_is_text_file, should_exclude as core_should_exclude, TEXT_EXTENSIONS as CORE_TEXT_EXTENSIONS from app.core.file_filter import (
is_text_file as core_is_text_file,
should_exclude as core_should_exclude,
TEXT_EXTENSIONS as CORE_TEXT_EXTENSIONS,
)
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Shared HTTP client (module-level singleton).
# Reusing the connection pool and DNS cache avoids re-establishing a TCP
# connection for every request, fixing "the first request of the day times
# out because of DNS/connection cold start".
# ---------------------------------------------------------------------------
_GIT_API_TIMEOUT = httpx.Timeout(
    connect=15.0,  # connect timeout 15s (the previous 120s was far too long for UX)
    read=60.0,  # read timeout 60s (the tree API can be slow on large repos)
    write=30.0,
    pool=15.0,  # timeout for acquiring a connection from the pool
)
# Lazily-created shared client; access it only through _get_http_client().
_http_client: Optional[httpx.AsyncClient] = None
def _get_http_client() -> httpx.AsyncClient:
    """Return the shared HTTP client, creating it lazily on first use.

    NOTE(review): the client is bound to the event loop it was created on;
    if the application ever runs more than one loop, a stale-but-open client
    could be handed out — confirm a single long-lived loop is used.
    """
    global _http_client
    client = _http_client
    if client is None or client.is_closed:
        pool_limits = httpx.Limits(
            max_connections=20,
            max_keepalive_connections=10,
            keepalive_expiry=300,  # keep idle connections alive for 5 minutes
        )
        client = httpx.AsyncClient(
            timeout=_GIT_API_TIMEOUT,
            limits=pool_limits,
            follow_redirects=True,
        )
        _http_client = client
    return client
async def _request_with_retry(
    client: httpx.AsyncClient,
    url: str,
    headers: Dict[str, str],
    max_retries: int = 3,
) -> httpx.Response:
    """HTTP GET with automatic retry on connection-level failures.

    Only connection establishment problems (``ConnectTimeout`` /
    ``ConnectError``) are retried, with linear backoff (2s, 4s, ...);
    read timeouts and HTTP error statuses are returned/raised as-is.

    Args:
        client: shared AsyncClient to issue the request on.
        url: request URL.
        headers: request headers.
        max_retries: total number of attempts (values < 1 still try once).

    Raises:
        httpx.ConnectTimeout / httpx.ConnectError: when every attempt fails.
    """
    # Guard: with max_retries < 1 the old code raised `None` (a TypeError);
    # always make at least one attempt and re-raise the real exception.
    attempts = max(1, max_retries)
    for attempt in range(1, attempts + 1):
        try:
            return await client.get(url, headers=headers)
        except (httpx.ConnectTimeout, httpx.ConnectError) as e:
            if attempt >= attempts:
                logger.error(f"[API] 连接最终失败 (共 {max_retries} 次): {url} - {e}")
                raise  # re-raise the original exception, never a bare None
            wait = attempt * 2
            logger.info(f"[API] 连接失败 (第 {attempt} 次), {wait}s 后重试: {url} - {e}")
            await asyncio.sleep(wait)
def get_analysis_config(user_config: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """Resolve analysis settings, preferring user overrides over global defaults.

    Returns a dict with:
        - max_analyze_files: cap on the number of files to analyze
        - llm_concurrency: number of concurrent LLM requests
        - llm_gap_ms: delay between LLM requests, in milliseconds
    """
    opts = (user_config or {}).get("otherConfig", {})
    # (output key, user-config key, global settings attribute)
    spec = (
        ("max_analyze_files", "maxAnalyzeFiles", "MAX_ANALYZE_FILES"),
        ("llm_concurrency", "llmConcurrency", "LLM_CONCURRENCY"),
        ("llm_gap_ms", "llmGapMs", "LLM_GAP_MS"),
    )
    config: Dict[str, Any] = {}
    for out_key, user_key, default_attr in spec:
        # getattr is only reached when the user value is missing/falsy,
        # preserving the original lazy `value or settings.X` semantics.
        config[out_key] = opts.get(user_key) or getattr(settings, default_attr)
    return config
# Supported text-file extensions come from the global definition.
TEXT_EXTENSIONS = list(CORE_TEXT_EXTENSIONS)


def is_text_file(path: str) -> bool:
    """Return True when *path* is recognized as a text file (delegates to the core filter)."""
    return core_is_text_file(path)
@ -53,30 +113,68 @@ def should_exclude(path: str, exclude_patterns: List[str] = None) -> bool:
def get_language_from_path(path: str) -> str:
    """Map a file path to a language identifier via its extension.

    Returns ``"text"`` when the path has no extension or an unknown one.
    """
    # Grouped by language for readability; flattened to ext -> language below.
    groups = {
        "python": ("py",),
        "javascript": ("js", "jsx"),
        "typescript": ("ts", "tsx"),
        "java": ("java",),
        "go": ("go",),
        "rust": ("rs",),
        "cpp": ("cpp", "cc", "hh", "hpp", "hxx"),
        "c": ("c", "h"),
        "csharp": ("cs",),
        "php": ("php",),
        "ruby": ("rb",),
        "kotlin": ("kt", "ktm", "kts"),
        "swift": ("swift",),
        "dart": ("dart",),
        "scala": ("scala", "sc"),
        "groovy": ("groovy", "gsh", "gvy", "gy"),
        "sql": ("sql",),
        "bash": ("sh", "bash", "zsh"),
        "perl": ("pl", "pm", "t"),
        "lua": ("lua",),
        "haskell": ("hs", "lhs"),
        "clojure": ("clj", "cljs", "cljc", "edn"),
        "elixir": ("ex", "exs"),
        "erlang": ("erl", "hrl"),
        "objective-c": ("m", "mm"),
        "r": ("r", "rmd"),
        "visual-basic": ("vb",),
        "fsharp": ("fs", "fsi", "fsx"),
        "hcl": ("tf", "hcl"),
        "dockerfile": ("dockerfile",),
    }
    ext_to_lang = {ext: lang for lang, exts in groups.items() for ext in exts}
    if "." not in path:
        return "text"
    extension = path.rsplit(".", 1)[-1].lower()
    return ext_to_lang.get(extension, "text")
class TaskControlManager: class TaskControlManager:
@ -107,24 +205,42 @@ async def github_api(url: str, token: str = None) -> Any:
"""调用GitHub API""" """调用GitHub API"""
headers = {"Accept": "application/vnd.github+json"} headers = {"Accept": "application/vnd.github+json"}
t = token or settings.GITHUB_TOKEN t = token or settings.GITHUB_TOKEN
client = _get_http_client()
async with httpx.AsyncClient(timeout=30) as client: # First try with token if available
# First try with token if available if t:
if t: headers["Authorization"] = f"Bearer {t}"
headers["Authorization"] = f"Bearer {t}" try:
try: response = await _request_with_retry(client, url, headers)
response = await client.get(url, headers=headers) if response.status_code == 200:
if response.status_code == 200: return response.json()
return response.json() if response.status_code != 401:
if response.status_code != 401: if response.status_code == 403:
if response.status_code == 403: raise Exception("GitHub API 403请配置 GITHUB_TOKEN 或确认仓库权限/频率限制")
raise Exception("GitHub API 403请配置 GITHUB_TOKEN 或确认仓库权限/频率限制") raise Exception(f"GitHub API {response.status_code}: {url}")
raise Exception(f"GitHub API {response.status_code}: {url}") # If 401, fall through to retry without token
# If 401, fall through to retry without token logger.info(
print(f"[API] GitHub API 401 (Unauthorized) with token, retrying without token for: {url}") f"[API] GitHub API 401 (Unauthorized) with token, retrying without token for: {url}"
except Exception as e: )
if "GitHub API 401" not in str(e) and "401" not in str(e): except Exception as e:
raise if "GitHub API 401" not in str(e) and "401" not in str(e):
raise
# Try without token
headers.pop("Authorization", None)
try:
response = await _request_with_retry(client, url, headers)
if response.status_code == 200:
return response.json()
if response.status_code == 403:
raise Exception("GitHub API 403请配置 GITHUB_TOKEN 或确认仓库权限/频率限制")
if response.status_code == 401:
raise Exception("GitHub API 401请配置 GITHUB_TOKEN 或确认仓库权限")
raise Exception(f"GitHub API {response.status_code}: {url}")
except Exception as e:
logger.error(f"[API] GitHub API 调用失败: {url}, 错误: {e}")
raise
# Try without token # Try without token
if "Authorization" in headers: if "Authorization" in headers:
@ -144,29 +260,46 @@ async def github_api(url: str, token: str = None) -> Any:
raise raise
async def gitea_api(url: str, token: str = None) -> Any: async def gitea_api(url: str, token: str = None) -> Any:
"""调用Gitea API""" """调用Gitea API"""
headers = {"Content-Type": "application/json"} headers = {"Content-Type": "application/json"}
t = token or settings.GITEA_TOKEN t = token or settings.GITEA_TOKEN
client = _get_http_client()
async with httpx.AsyncClient(timeout=120) as client: # First try with token if available
# First try with token if available if t:
if t: headers["Authorization"] = f"token {t}"
headers["Authorization"] = f"token {t}" try:
try: response = await _request_with_retry(client, url, headers)
response = await client.get(url, headers=headers) if response.status_code == 200:
if response.status_code == 200: return response.json()
return response.json() if response.status_code != 401:
if response.status_code != 401: if response.status_code == 403:
if response.status_code == 403: raise Exception("Gitea API 403请确认仓库权限/频率限制")
raise Exception("Gitea API 403请确认仓库权限/频率限制") raise Exception(f"Gitea API {response.status_code}: {url}")
raise Exception(f"Gitea API {response.status_code}: {url}") # If 401, fall through to retry without token
# If 401, fall through to retry without token logger.info(
print(f"[API] Gitea API 401 (Unauthorized) with token, retrying without token for: {url}") f"[API] Gitea API 401 (Unauthorized) with token, retrying without token for: {url}"
except Exception as e: )
if "Gitea API 401" not in str(e) and "401" not in str(e): except Exception as e:
raise if "Gitea API 401" not in str(e) and "401" not in str(e):
raise
# Try without token
headers.pop("Authorization", None)
try:
response = await _request_with_retry(client, url, headers)
if response.status_code == 200:
return response.json()
if response.status_code == 401:
raise Exception("Gitea API 401请配置 GITEA_TOKEN 或确认仓库权限")
if response.status_code == 403:
raise Exception("Gitea API 403请确认仓库权限/频率限制")
raise Exception(f"Gitea API {response.status_code}: {url}")
except Exception as e:
logger.error(f"[API] Gitea API 调用失败: {url}, 错误: {e}")
raise
# Try without token # Try without token
if "Authorization" in headers: if "Authorization" in headers:
@ -190,24 +323,42 @@ async def gitlab_api(url: str, token: str = None) -> Any:
"""调用GitLab API""" """调用GitLab API"""
headers = {"Content-Type": "application/json"} headers = {"Content-Type": "application/json"}
t = token or settings.GITLAB_TOKEN t = token or settings.GITLAB_TOKEN
client = _get_http_client()
async with httpx.AsyncClient(timeout=30) as client: # First try with token if available
# First try with token if available if t:
if t: headers["PRIVATE-TOKEN"] = t
headers["PRIVATE-TOKEN"] = t try:
try: response = await _request_with_retry(client, url, headers)
response = await client.get(url, headers=headers) if response.status_code == 200:
if response.status_code == 200: return response.json()
return response.json() if response.status_code != 401:
if response.status_code != 401: if response.status_code == 403:
if response.status_code == 403: raise Exception("GitLab API 403请确认仓库权限/频率限制")
raise Exception("GitLab API 403请确认仓库权限/频率限制") raise Exception(f"GitLab API {response.status_code}: {url}")
raise Exception(f"GitLab API {response.status_code}: {url}") # If 401, fall through to retry without token
# If 401, fall through to retry without token logger.info(
print(f"[API] GitLab API 401 (Unauthorized) with token, retrying without token for: {url}") f"[API] GitLab API 401 (Unauthorized) with token, retrying without token for: {url}"
except Exception as e: )
if "GitLab API 401" not in str(e) and "401" not in str(e): except Exception as e:
raise if "GitLab API 401" not in str(e) and "401" not in str(e):
raise
# Try without token
headers.pop("PRIVATE-TOKEN", None)
try:
response = await _request_with_retry(client, url, headers)
if response.status_code == 200:
return response.json()
if response.status_code == 401:
raise Exception("GitLab API 401请配置 GITLAB_TOKEN 或确认仓库权限")
if response.status_code == 403:
raise Exception("GitLab API 403请确认仓库权限/频率限制")
raise Exception(f"GitLab API {response.status_code}: {url}")
except Exception as e:
logger.error(f"[API] GitLab API 调用失败: {url}, 错误: {e}")
raise
# Try without token # Try without token
if "PRIVATE-TOKEN" in headers: if "PRIVATE-TOKEN" in headers:
@ -229,28 +380,31 @@ async def gitlab_api(url: str, token: str = None) -> Any:
async def fetch_file_content(url: str, headers: Dict[str, str] = None) -> Optional[str]:
    """Fetch a raw file over HTTP and return its text, or None on any failure.

    If an authenticated request is rejected (401/403), one anonymous retry is
    made, which covers files in public repositories.
    """
    client = _get_http_client()
    req_headers = headers or {}
    try:
        response = await _request_with_retry(client, url, req_headers)
        if response.status_code == 200:
            return response.text
        # Token-bearing request rejected: retry once without credentials.
        if response.status_code in (401, 403) and headers:
            logger.info(
                f"[API] 获取文件内容返回 {response.status_code},尝试不带 Token 重试: {url}"
            )
            response = await _request_with_retry(client, url, {})
            if response.status_code == 200:
                return response.text
    except Exception as e:
        logger.error(f"获取文件内容失败: {url}, 错误: {e}")
    return None
async def get_github_branches(repo_url: str, token: str = None) -> List[str]: async def get_github_branches(repo_url: str, token: str = None) -> List[str]:
"""获取GitHub仓库分支列表""" """获取GitHub仓库分支列表"""
repo_info = parse_repository_url(repo_url, "github") repo_info = parse_repository_url(repo_url, "github")
owner, repo = repo_info['owner'], repo_info['repo'] owner, repo = repo_info["owner"], repo_info["repo"]
branches_url = f"https://api.github.com/repos/{owner}/{repo}/branches?per_page=100" branches_url = f"https://api.github.com/repos/{owner}/{repo}/branches?per_page=100"
branches_data = await github_api(branches_url, token) branches_data = await github_api(branches_url, token)
@ -262,14 +416,11 @@ async def get_github_branches(repo_url: str, token: str = None) -> List[str]:
return [b["name"] for b in branches_data if isinstance(b, dict) and "name" in b] return [b["name"] for b in branches_data if isinstance(b, dict) and "name" in b]
async def get_gitea_branches(repo_url: str, token: str = None) -> List[str]: async def get_gitea_branches(repo_url: str, token: str = None) -> List[str]:
"""获取Gitea仓库分支列表""" """获取Gitea仓库分支列表"""
repo_info = parse_repository_url(repo_url, "gitea") repo_info = parse_repository_url(repo_url, "gitea")
base_url = repo_info['base_url'] # This is {base}/api/v1 base_url = repo_info["base_url"] # This is {base}/api/v1
owner, repo = repo_info['owner'], repo_info['repo'] owner, repo = repo_info["owner"], repo_info["repo"]
branches_url = f"{base_url}/repos/{owner}/{repo}/branches" branches_url = f"{base_url}/repos/{owner}/{repo}/branches"
branches_data = await gitea_api(branches_url, token) branches_data = await gitea_api(branches_url, token)
@ -287,14 +438,14 @@ async def get_gitlab_branches(repo_url: str, token: str = None) -> List[str]:
extracted_token = token extracted_token = token
if parsed.username: if parsed.username:
if parsed.username == 'oauth2' and parsed.password: if parsed.username == "oauth2" and parsed.password:
extracted_token = parsed.password extracted_token = parsed.password
elif parsed.username and not parsed.password: elif parsed.username and not parsed.password:
extracted_token = parsed.username extracted_token = parsed.username
repo_info = parse_repository_url(repo_url, "gitlab") repo_info = parse_repository_url(repo_url, "gitlab")
base_url = repo_info['base_url'] base_url = repo_info["base_url"]
project_path = quote(repo_info['project_path'], safe='') project_path = quote(repo_info["project_path"], safe="")
branches_url = f"{base_url}/projects/{project_path}/repository/branches?per_page=100" branches_url = f"{base_url}/projects/{project_path}/repository/branches?per_page=100"
branches_data = await gitlab_api(branches_url, extracted_token) branches_data = await gitlab_api(branches_url, extracted_token)
@ -306,11 +457,13 @@ async def get_gitlab_branches(repo_url: str, token: str = None) -> List[str]:
return [b["name"] for b in branches_data if isinstance(b, dict) and "name" in b] return [b["name"] for b in branches_data if isinstance(b, dict) and "name" in b]
async def get_github_files(repo_url: str, branch: str, token: str = None, exclude_patterns: List[str] = None) -> List[Dict[str, str]]: async def get_github_files(
repo_url: str, branch: str, token: str = None, exclude_patterns: List[str] = None
) -> List[Dict[str, str]]:
"""获取GitHub仓库文件列表""" """获取GitHub仓库文件列表"""
# 解析仓库URL # 解析仓库URL
repo_info = parse_repository_url(repo_url, "github") repo_info = parse_repository_url(repo_url, "github")
owner, repo = repo_info['owner'], repo_info['repo'] owner, repo = repo_info["owner"], repo_info["repo"]
# 获取仓库文件树 # 获取仓库文件树
tree_url = f"https://api.github.com/repos/{owner}/{repo}/git/trees/{quote(branch)}?recursive=1" tree_url = f"https://api.github.com/repos/{owner}/{repo}/git/trees/{quote(branch)}?recursive=1"
@ -318,33 +471,41 @@ async def get_github_files(repo_url: str, branch: str, token: str = None, exclud
files = [] files = []
for item in tree_data.get("tree", []): for item in tree_data.get("tree", []):
if item.get("type") == "blob" and is_text_file(item["path"]) and not should_exclude(item["path"], exclude_patterns): if (
item.get("type") == "blob"
and is_text_file(item["path"])
and not should_exclude(item["path"], exclude_patterns)
):
size = item.get("size", 0) size = item.get("size", 0)
if size <= settings.MAX_FILE_SIZE_BYTES: if size <= settings.MAX_FILE_SIZE_BYTES:
files.append({ files.append(
"path": item["path"], {
"url": f"https://raw.githubusercontent.com/{owner}/{repo}/{quote(branch)}/{item['path']}" "path": item["path"],
}) "url": f"https://raw.githubusercontent.com/{owner}/{repo}/{quote(branch)}/{item['path']}",
}
)
return files return files
async def get_gitlab_files(repo_url: str, branch: str, token: str = None, exclude_patterns: List[str] = None) -> List[Dict[str, str]]: async def get_gitlab_files(
repo_url: str, branch: str, token: str = None, exclude_patterns: List[str] = None
) -> List[Dict[str, str]]:
"""获取GitLab仓库文件列表""" """获取GitLab仓库文件列表"""
parsed = urlparse(repo_url) parsed = urlparse(repo_url)
# 从URL中提取token如果存在 # 从URL中提取token如果存在
extracted_token = token extracted_token = token
if parsed.username: if parsed.username:
if parsed.username == 'oauth2' and parsed.password: if parsed.username == "oauth2" and parsed.password:
extracted_token = parsed.password extracted_token = parsed.password
elif parsed.username and not parsed.password: elif parsed.username and not parsed.password:
extracted_token = parsed.username extracted_token = parsed.username
# 解析项目路径 # 解析项目路径
repo_info = parse_repository_url(repo_url, "gitlab") repo_info = parse_repository_url(repo_url, "gitlab")
base_url = repo_info['base_url'] # {base}/api/v4 base_url = repo_info["base_url"] # {base}/api/v4
project_path = quote(repo_info['project_path'], safe='') project_path = quote(repo_info["project_path"], safe="")
# 获取仓库文件树 # 获取仓库文件树
tree_url = f"{base_url}/projects/{project_path}/repository/tree?ref={quote(branch)}&recursive=true&per_page=100" tree_url = f"{base_url}/projects/{project_path}/repository/tree?ref={quote(branch)}&recursive=true&per_page=100"
@ -352,40 +513,57 @@ async def get_gitlab_files(repo_url: str, branch: str, token: str = None, exclud
files = [] files = []
for item in tree_data: for item in tree_data:
if item.get("type") == "blob" and is_text_file(item["path"]) and not should_exclude(item["path"], exclude_patterns): if (
files.append({ item.get("type") == "blob"
"path": item["path"], and is_text_file(item["path"])
"url": f"{base_url}/projects/{project_path}/repository/files/{quote(item['path'], safe='')}/raw?ref={quote(branch)}", and not should_exclude(item["path"], exclude_patterns)
"token": extracted_token ):
}) files.append(
{
"path": item["path"],
"url": f"{base_url}/projects/{project_path}/repository/files/{quote(item['path'], safe='')}/raw?ref={quote(branch)}",
"token": extracted_token,
}
)
return files return files
async def get_gitea_files(
async def get_gitea_files(repo_url: str, branch: str, token: str = None, exclude_patterns: List[str] = None) -> List[Dict[str, str]]: repo_url: str, branch: str, token: str = None, exclude_patterns: List[str] = None
) -> List[Dict[str, str]]:
"""获取Gitea仓库文件列表""" """获取Gitea仓库文件列表"""
repo_info = parse_repository_url(repo_url, "gitea") repo_info = parse_repository_url(repo_url, "gitea")
base_url = repo_info['base_url'] base_url = repo_info["base_url"]
owner, repo = repo_info['owner'], repo_info['repo'] owner, repo = repo_info["owner"], repo_info["repo"]
# Gitea tree API: GET /repos/{owner}/{repo}/git/trees/{sha}?recursive=true # Gitea tree API: GET /repos/{owner}/{repo}/git/trees/{sha}?recursive=true
# 可以直接使用分支名作为sha # 可以直接使用分支名作为sha
tree_url = f"{base_url}/repos/{quote(owner)}/{quote(repo)}/git/trees/{quote(branch)}?recursive=true" tree_url = (
f"{base_url}/repos/{quote(owner)}/{quote(repo)}/git/trees/{quote(branch)}?recursive=true"
)
tree_data = await gitea_api(tree_url, token) tree_data = await gitea_api(tree_url, token)
files = [] files = []
for item in tree_data.get("tree", []): for item in tree_data.get("tree", []):
# Gitea API returns 'type': 'blob' for files # Gitea API returns 'type': 'blob' for files
if item.get("type") == "blob" and is_text_file(item["path"]) and not should_exclude(item["path"], exclude_patterns): if (
item.get("type") == "blob"
and is_text_file(item["path"])
and not should_exclude(item["path"], exclude_patterns)
):
# 使用API raw endpoint: GET /repos/{owner}/{repo}/raw/{filepath}?ref={branch} # 使用API raw endpoint: GET /repos/{owner}/{repo}/raw/{filepath}?ref={branch}
files.append({ files.append(
"path": item["path"], {
"url": f"{base_url}/repos/{owner}/{repo}/raw/{quote(item['path'])}?ref={quote(branch)}", "path": item["path"],
"token": token # 传递token以便fetch_file_content使用 "url": f"{base_url}/repos/{owner}/{repo}/raw/{quote(item['path'])}?ref={quote(branch)}",
}) "token": token, # 传递token以便fetch_file_content使用
}
)
return files return files
async def scan_repo_task(task_id: str, db_session_factory, user_config: dict = None): async def scan_repo_task(task_id: str, db_session_factory, user_config: dict = None):
""" """
后台仓库扫描任务 后台仓库扫描任务
@ -415,8 +593,8 @@ async def scan_repo_task(task_id: str, db_session_factory, user_config: dict = N
raise Exception("项目不存在") raise Exception("项目不存在")
# 检查项目类型 - 仅支持仓库类型项目 # 检查项目类型 - 仅支持仓库类型项目
source_type = getattr(project, 'source_type', 'repository') source_type = getattr(project, "source_type", "repository")
if source_type == 'zip': if source_type == "zip":
raise Exception("ZIP类型项目请使用ZIP上传扫描接口") raise Exception("ZIP类型项目请使用ZIP上传扫描接口")
if not project.repository_url: if not project.repository_url:
@ -428,6 +606,7 @@ async def scan_repo_task(task_id: str, db_session_factory, user_config: dict = N
# 解析任务的排除模式 # 解析任务的排除模式
import json as json_module import json as json_module
task_exclude_patterns = [] task_exclude_patterns = []
if task.exclude_patterns: if task.exclude_patterns:
try: try:
@ -435,7 +614,9 @@ async def scan_repo_task(task_id: str, db_session_factory, user_config: dict = N
except: except:
pass pass
print(f"🚀 开始扫描仓库: {repo_url}, 分支: {branch}, 类型: {repo_type}, 来源: {source_type}") print(
f"🚀 开始扫描仓库: {repo_url}, 分支: {branch}, 类型: {repo_type}, 来源: {source_type}"
)
if task_exclude_patterns: if task_exclude_patterns:
print(f"📋 排除模式: {task_exclude_patterns}") print(f"📋 排除模式: {task_exclude_patterns}")
@ -445,20 +626,20 @@ async def scan_repo_task(task_id: str, db_session_factory, user_config: dict = N
gitlab_token = settings.GITLAB_TOKEN gitlab_token = settings.GITLAB_TOKEN
gitea_token = settings.GITEA_TOKEN gitea_token = settings.GITEA_TOKEN
# 获取SSH私钥如果配置了 # 获取SSH私钥如果配置了
user_other_config = user_config.get('otherConfig', {}) if user_config else {} user_other_config = user_config.get("otherConfig", {}) if user_config else {}
ssh_private_key = None ssh_private_key = None
if 'sshPrivateKey' in user_other_config: if "sshPrivateKey" in user_other_config:
from app.core.encryption import decrypt_sensitive_data from app.core.encryption import decrypt_sensitive_data
ssh_private_key = decrypt_sensitive_data(user_other_config['sshPrivateKey'])
ssh_private_key = decrypt_sensitive_data(user_other_config["sshPrivateKey"])
files: List[Dict[str, str]] = [] files: List[Dict[str, str]] = []
extracted_gitlab_token = None extracted_gitlab_token = None
# 检查是否为SSH URL # 检查是否为SSH URL
from app.services.git_ssh_service import GitSSHOperations from app.services.git_ssh_service import GitSSHOperations
is_ssh_url = GitSSHOperations.is_ssh_url(repo_url) is_ssh_url = GitSSHOperations.is_ssh_url(repo_url)
if is_ssh_url: if is_ssh_url:
@ -472,7 +653,9 @@ async def scan_repo_task(task_id: str, db_session_factory, user_config: dict = N
repo_url, ssh_private_key, branch, task_exclude_patterns repo_url, ssh_private_key, branch, task_exclude_patterns
) )
# 转换为统一格式 # 转换为统一格式
files = [{'path': f['path'], 'content': f['content']} for f in files_with_content] files = [
{"path": f["path"], "content": f["content"]} for f in files_with_content
]
actual_branch = branch actual_branch = branch
print(f"✅ 通过SSH成功获取 {len(files)} 个文件") print(f"✅ 通过SSH成功获取 {len(files)} 个文件")
except Exception as e: except Exception as e:
@ -494,21 +677,29 @@ async def scan_repo_task(task_id: str, db_session_factory, user_config: dict = N
try: try:
print(f"🔄 尝试获取分支 {try_branch} 的文件列表...") print(f"🔄 尝试获取分支 {try_branch} 的文件列表...")
if repo_type == "github": if repo_type == "github":
files = await get_github_files(repo_url, try_branch, github_token, task_exclude_patterns) files = await get_github_files(
repo_url, try_branch, github_token, task_exclude_patterns
)
elif repo_type == "gitlab": elif repo_type == "gitlab":
files = await get_gitlab_files(repo_url, try_branch, gitlab_token, task_exclude_patterns) files = await get_gitlab_files(
repo_url, try_branch, gitlab_token, task_exclude_patterns
)
# GitLab文件可能带有token # GitLab文件可能带有token
if files and 'token' in files[0]: if files and "token" in files[0]:
extracted_gitlab_token = files[0].get('token') extracted_gitlab_token = files[0].get("token")
elif repo_type == "gitea": elif repo_type == "gitea":
files = await get_gitea_files(repo_url, try_branch, gitea_token, task_exclude_patterns) files = await get_gitea_files(
repo_url, try_branch, gitea_token, task_exclude_patterns
)
else: else:
raise Exception("不支持的仓库类型,仅支持 GitHub, GitLab 和 Gitea 仓库") raise Exception("不支持的仓库类型,仅支持 GitHub, GitLab 和 Gitea 仓库")
if files: if files:
actual_branch = try_branch actual_branch = try_branch
if try_branch != branch: if try_branch != branch:
print(f"⚠️ 分支 {branch} 不存在或无法访问,已降级到分支 {try_branch}") print(
f"⚠️ 分支 {branch} 不存在或无法访问,已降级到分支 {try_branch}"
)
break break
except Exception as e: except Exception as e:
last_error = str(e) last_error = str(e)
@ -530,23 +721,25 @@ async def scan_repo_task(task_id: str, db_session_factory, user_config: dict = N
# 获取分析配置(优先使用用户配置) # 获取分析配置(优先使用用户配置)
analysis_config = get_analysis_config(user_config) analysis_config = get_analysis_config(user_config)
max_analyze_files = analysis_config['max_analyze_files'] max_analyze_files = analysis_config["max_analyze_files"]
analysis_concurrency = analysis_config['llm_concurrency'] # 并发数 analysis_concurrency = analysis_config["llm_concurrency"] # 并发数
llm_gap_ms = analysis_config['llm_gap_ms'] llm_gap_ms = analysis_config["llm_gap_ms"]
# 限制文件数量 # 限制文件数量
# 如果指定了特定文件,则只分析这些文件 # 如果指定了特定文件,则只分析这些文件
target_files = (user_config or {}).get('scan_config', {}).get('file_paths', []) target_files = (user_config or {}).get("scan_config", {}).get("file_paths", [])
if target_files: if target_files:
print(f"🎯 指定分析 {len(target_files)} 个文件") print(f"🎯 指定分析 {len(target_files)} 个文件")
files = [f for f in files if f['path'] in target_files] files = [f for f in files if f["path"] in target_files]
elif max_analyze_files > 0: elif max_analyze_files > 0:
files = files[:max_analyze_files] files = files[:max_analyze_files]
task.total_files = len(files) task.total_files = len(files)
await db.commit() await db.commit()
print(f"📊 获取到 {len(files)} 个文件,开始分析 (最大文件数: {max_analyze_files}, 请求间隔: {llm_gap_ms}ms)") print(
f"📊 获取到 {len(files)} 个文件,开始分析 (最大文件数: {max_analyze_files}, 请求间隔: {llm_gap_ms}ms)"
)
# 4. 分析文件 # 4. 分析文件
total_issues = 0 total_issues = 0
@ -572,22 +765,24 @@ async def scan_repo_task(task_id: str, db_session_factory, user_config: dict = N
if task_control.is_cancelled(task_id): if task_control.is_cancelled(task_id):
return None return None
f_path = file_info['path'] f_path = file_info["path"]
MAX_RETRIES = 3 MAX_RETRIES = 3
for attempt in range(MAX_RETRIES): for attempt in range(MAX_RETRIES):
try: try:
# 4.1 获取文件内容 (仅在第一次尝试或内容获取失败时获取) # 4.1 获取文件内容 (仅在第一次尝试或内容获取失败时获取)
if attempt == 0: if attempt == 0:
if is_ssh_url: if is_ssh_url:
content = file_info.get('content', '') content = file_info.get("content", "")
else: else:
headers = {} headers = {}
if repo_type == "gitlab": if repo_type == "gitlab":
token_to_use = file_info.get('token') or gitlab_token token_to_use = file_info.get("token") or gitlab_token
if token_to_use: headers["PRIVATE-TOKEN"] = token_to_use if token_to_use:
headers["PRIVATE-TOKEN"] = token_to_use
elif repo_type == "gitea": elif repo_type == "gitea":
token_to_use = file_info.get('token') or gitea_token token_to_use = file_info.get("token") or gitea_token
if token_to_use: headers["Authorization"] = f"token {token_to_use}" if token_to_use:
headers["Authorization"] = f"token {token_to_use}"
elif repo_type == "github" and github_token: elif repo_type == "github" and github_token:
headers["Authorization"] = f"Bearer {github_token}" headers["Authorization"] = f"Bearer {github_token}"
@ -604,16 +799,17 @@ async def scan_repo_task(task_id: str, db_session_factory, user_config: dict = N
# 4.2 LLM 分析 # 4.2 LLM 分析
language = get_language_from_path(f_path) language = get_language_from_path(f_path)
scan_config = (user_config or {}).get('scan_config', {}) scan_config = (user_config or {}).get("scan_config", {})
rule_set_id = scan_config.get('rule_set_id') rule_set_id = scan_config.get("rule_set_id")
prompt_template_id = scan_config.get('prompt_template_id') prompt_template_id = scan_config.get("prompt_template_id")
if rule_set_id or prompt_template_id: if rule_set_id or prompt_template_id:
analysis_result = await llm_service.analyze_code_with_rules( analysis_result = await llm_service.analyze_code_with_rules(
content, language, content,
language,
rule_set_id=rule_set_id, rule_set_id=rule_set_id,
prompt_template_id=prompt_template_id, prompt_template_id=prompt_template_id,
db_session=None db_session=None,
) )
else: else:
analysis_result = await llm_service.analyze_code(content, language) analysis_result = await llm_service.analyze_code(content, language)
@ -623,7 +819,7 @@ async def scan_repo_task(task_id: str, db_session_factory, user_config: dict = N
"path": f_path, "path": f_path,
"content": content, "content": content,
"language": language, "language": language,
"analysis": analysis_result "analysis": analysis_result,
} }
except asyncio.CancelledError: except asyncio.CancelledError:
# 捕获取消异常,不再重试 # 捕获取消异常,不再重试
@ -633,10 +829,18 @@ async def scan_repo_task(task_id: str, db_session_factory, user_config: dict = N
wait_time = (attempt + 1) * 2 wait_time = (attempt + 1) * 2
# 特殊处理限流错误提示 # 特殊处理限流错误提示
error_str = str(e) error_str = str(e)
if "429" in error_str or "rate limit" in error_str.lower() or "额度不足" in error_str: if (
print(f"🚫 [限流提示] 仓库扫描任务触发 LLM 频率限制 (429),建议在设置中降低并发数或增加请求间隔。文件: {f_path}") "429" in error_str
or "rate limit" in error_str.lower()
or "额度不足" in error_str
):
print(
f"🚫 [限流提示] 仓库扫描任务触发 LLM 频率限制 (429),建议在设置中降低并发数或增加请求间隔。文件: {f_path}"
)
print(f"⚠️ 分析文件失败 ({f_path}), 正在进行第 {attempt+1} 次重试... 错误: {e}") print(
f"⚠️ 分析文件失败 ({f_path}), 正在进行第 {attempt + 1} 次重试... 错误: {e}"
)
await asyncio.sleep(wait_time) await asyncio.sleep(wait_time)
continue continue
else: else:
@ -660,7 +864,8 @@ async def scan_repo_task(task_id: str, db_session_factory, user_config: dict = N
except asyncio.CancelledError: except asyncio.CancelledError:
continue continue
if not res: continue if not res:
continue
if res["type"] == "skip": if res["type"] == "skip":
skipped_files += 1 skipped_files += 1
@ -674,7 +879,7 @@ async def scan_repo_task(task_id: str, db_session_factory, user_config: dict = N
f_path = res["path"] f_path = res["path"]
analysis = res["analysis"] analysis = res["analysis"]
file_lines = res["content"].split('\n') file_lines = res["content"].split("\n")
total_lines += len(file_lines) total_lines += len(file_lines)
# 保存问题 # 保存问题
@ -683,7 +888,9 @@ async def scan_repo_task(task_id: str, db_session_factory, user_config: dict = N
try: try:
# 防御性检查:确保 issue 是字典 # 防御性检查:确保 issue 是字典
if not isinstance(issue, dict): if not isinstance(issue, dict):
print(f"⚠️ 警告: 任务 {task_id} 中文件 {f_path} 的分析结果包含无效的问题格式: {issue}") print(
f"⚠️ 警告: 任务 {task_id} 中文件 {f_path} 的分析结果包含无效的问题格式: {issue}"
)
continue continue
# 辅助函数:清理字符串中 PostgreSQL 不支持的字符 # 辅助函数:清理字符串中 PostgreSQL 不支持的字符
@ -693,7 +900,7 @@ async def scan_repo_task(task_id: str, db_session_factory, user_config: dict = N
if not isinstance(text, str): if not isinstance(text, str):
text = str(text) text = str(text)
# 移除 NULL 字节 (PostgreSQL 不支持) # 移除 NULL 字节 (PostgreSQL 不支持)
text = text.replace('\x00', '') text = text.replace("\x00", "")
return text return text
line_num = issue.get("line", 1) line_num = issue.get("line", 1)
@ -703,7 +910,7 @@ async def scan_repo_task(task_id: str, db_session_factory, user_config: dict = N
idx = max(0, int(line_num) - 1) idx = max(0, int(line_num) - 1)
start = max(0, idx - 2) start = max(0, idx - 2)
end = min(len(file_lines), idx + 3) end = min(len(file_lines), idx + 3)
code_snippet = '\n'.join(file_lines[start:end]) code_snippet = "\n".join(file_lines[start:end])
except Exception: except Exception:
code_snippet = "" code_snippet = ""
@ -715,11 +922,13 @@ async def scan_repo_task(task_id: str, db_session_factory, user_config: dict = N
issue_type=issue.get("type", "maintainability"), issue_type=issue.get("type", "maintainability"),
severity=issue.get("severity", "low"), severity=issue.get("severity", "low"),
title=sanitize_for_db(issue.get("title", "Issue")), title=sanitize_for_db(issue.get("title", "Issue")),
message=sanitize_for_db(issue.get("description") or issue.get("title", "Issue")), message=sanitize_for_db(
issue.get("description") or issue.get("title", "Issue")
),
suggestion=sanitize_for_db(issue.get("suggestion")), suggestion=sanitize_for_db(issue.get("suggestion")),
code_snippet=code_snippet, code_snippet=code_snippet,
ai_explanation=sanitize_for_db(issue.get("ai_explanation")), ai_explanation=sanitize_for_db(issue.get("ai_explanation")),
status="open" status="open",
) )
db.add(audit_issue) db.add(audit_issue)
total_issues += 1 total_issues += 1
@ -739,10 +948,12 @@ async def scan_repo_task(task_id: str, db_session_factory, user_config: dict = N
task.scanned_files = processed_count task.scanned_files = processed_count
task.total_lines = total_lines task.total_lines = total_lines
task.issues_count = total_issues task.issues_count = total_issues
await db.commit() # 这里的 commit 是在一个协程里按序进行的,是安全的 await db.commit() # 这里的 commit 是在一个协程里按序进行的,是安全的
if processed_count % 10 == 0 or processed_count == len(files): if processed_count % 10 == 0 or processed_count == len(files):
print(f"📈 任务 {task_id}: 进度 {processed_count}/{len(files)} ({int(processed_count/len(files)*100) if len(files) > 0 else 0}%)") print(
f"📈 任务 {task_id}: 进度 {processed_count}/{len(files)} ({int(processed_count / len(files) * 100) if len(files) > 0 else 0}%)"
)
if consecutive_failures >= MAX_CONSECUTIVE_FAILURES: if consecutive_failures >= MAX_CONSECUTIVE_FAILURES:
print(f"❌ 任务 {task_id}: 连续失败 {consecutive_failures} 次,停止分析") print(f"❌ 任务 {task_id}: 连续失败 {consecutive_failures} 次,停止分析")
@ -761,7 +972,9 @@ async def scan_repo_task(task_id: str, db_session_factory, user_config: dict = N
await asyncio.gather(*task_objects, return_exceptions=True) await asyncio.gather(*task_objects, return_exceptions=True)
# 5. 完成任务 # 5. 完成任务
avg_quality_score = sum(quality_scores) / len(quality_scores) if quality_scores else 100.0 avg_quality_score = (
sum(quality_scores) / len(quality_scores) if quality_scores else 100.0
)
# 判断任务状态 # 判断任务状态
# 如果所有文件都被跳过(空文件等),标记为完成但给出提示 # 如果所有文件都被跳过(空文件等),标记为完成但给出提示