feat: Implement shared HTTP client with retry logic and optimized connection settings for GitHub/Gitea API calls to improve reliability and performance.
Build and Push CodeReview / build (push) Has been cancelled Details

This commit is contained in:
vinland100 2026-02-06 14:12:51 +08:00
parent f0981ae187
commit 5491714cce
1 changed files with 448 additions and 235 deletions

View File

@ -4,6 +4,7 @@
import os
import asyncio
import logging
import httpx
from typing import List, Dict, Any, Optional
from datetime import datetime, timezone
@ -15,7 +16,65 @@ from app.models.audit import AuditTask, AuditIssue
from app.models.project import Project
from app.services.llm.service import LLMService
from app.core.config import settings
from app.core.file_filter import is_text_file as core_is_text_file, should_exclude as core_should_exclude, TEXT_EXTENSIONS as CORE_TEXT_EXTENSIONS
from app.core.file_filter import (
is_text_file as core_is_text_file,
should_exclude as core_should_exclude,
TEXT_EXTENSIONS as CORE_TEXT_EXTENSIONS,
)
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# 共享 HTTP 客户端(模块级别单例)
# 通过复用连接池和 DNS 缓存,避免每次请求都重新建立 TCP 连接,
# 解决"每天第一次打开时因 DNS/连接冷启动导致超时"的问题。
# ---------------------------------------------------------------------------
_GIT_API_TIMEOUT = httpx.Timeout(
connect=15.0, # 连接超时 15 秒(原来 120 秒太长,用户体验差)
read=60.0, # 读取超时 60 秒(大仓库 tree API 可能较慢)
write=30.0,
pool=15.0, # 从连接池获取连接的超时
)
_http_client: Optional[httpx.AsyncClient] = None
def _get_http_client() -> httpx.AsyncClient:
"""获取或创建共享的 HTTP 客户端(惰性初始化)"""
global _http_client
if _http_client is None or _http_client.is_closed:
_http_client = httpx.AsyncClient(
timeout=_GIT_API_TIMEOUT,
limits=httpx.Limits(
max_connections=20,
max_keepalive_connections=10,
keepalive_expiry=300, # 保持连接 5 分钟
),
follow_redirects=True,
)
return _http_client
async def _request_with_retry(
client: httpx.AsyncClient,
url: str,
headers: Dict[str, str],
max_retries: int = 3,
) -> httpx.Response:
"""带自动重试的 HTTP GET 请求,针对连接超时进行重试"""
last_exc: Optional[Exception] = None
for attempt in range(max_retries):
try:
return await client.get(url, headers=headers)
except (httpx.ConnectTimeout, httpx.ConnectError) as e:
last_exc = e
if attempt < max_retries - 1:
wait = (attempt + 1) * 2
logger.info(f"[API] 连接失败 (第 {attempt + 1} 次), {wait}s 后重试: {url} - {e}")
await asyncio.sleep(wait)
else:
logger.error(f"[API] 连接最终失败 (共 {max_retries} 次): {url} - {e}")
raise last_exc # type: ignore[misc]
def get_analysis_config(user_config: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
@ -28,18 +87,19 @@ def get_analysis_config(user_config: Optional[Dict[str, Any]] = None) -> Dict[st
- llm_concurrency: LLM 并发数
- llm_gap_ms: LLM 请求间隔毫秒
"""
other_config = (user_config or {}).get('otherConfig', {})
other_config = (user_config or {}).get("otherConfig", {})
return {
'max_analyze_files': other_config.get('maxAnalyzeFiles') or settings.MAX_ANALYZE_FILES,
'llm_concurrency': other_config.get('llmConcurrency') or settings.LLM_CONCURRENCY,
'llm_gap_ms': other_config.get('llmGapMs') or settings.LLM_GAP_MS,
"max_analyze_files": other_config.get("maxAnalyzeFiles") or settings.MAX_ANALYZE_FILES,
"llm_concurrency": other_config.get("llmConcurrency") or settings.LLM_CONCURRENCY,
"llm_gap_ms": other_config.get("llmGapMs") or settings.LLM_GAP_MS,
}
# 支持的文本文件扩展名使用全局定义
TEXT_EXTENSIONS = list(CORE_TEXT_EXTENSIONS)
def is_text_file(path: str) -> bool:
"""检查是否为文本文件"""
return core_is_text_file(path)
@ -53,30 +113,68 @@ def should_exclude(path: str, exclude_patterns: List[str] = None) -> bool:
def get_language_from_path(path: str) -> str:
"""从文件路径获取语言类型"""
ext = path.split('.')[-1].lower() if '.' in path else ''
ext = path.split(".")[-1].lower() if "." in path else ""
language_map = {
'py': 'python',
'js': 'javascript', 'jsx': 'javascript',
'ts': 'typescript', 'tsx': 'typescript',
'java': 'java', 'go': 'go', 'rs': 'rust',
'cpp': 'cpp', 'c': 'c', 'cc': 'cpp', 'h': 'c', 'hh': 'cpp',
'hpp': 'cpp', 'hxx': 'cpp',
'cs': 'csharp', 'php': 'php', 'rb': 'ruby',
'kt': 'kotlin', 'ktm': 'kotlin', 'kts': 'kotlin',
'swift': 'swift', 'dart': 'dart',
'scala': 'scala', 'sc': 'scala',
'groovy': 'groovy', 'gsh': 'groovy', 'gvy': 'groovy', 'gy': 'groovy',
'sql': 'sql', 'sh': 'bash', 'bash': 'bash', 'zsh': 'bash',
'pl': 'perl', 'pm': 'perl', 't': 'perl',
'lua': 'lua', 'hs': 'haskell', 'lhs': 'haskell',
'clj': 'clojure', 'cljs': 'clojure', 'cljc': 'clojure', 'edn': 'clojure',
'ex': 'elixir', 'exs': 'elixir', 'erl': 'erlang', 'hrl': 'erlang',
'm': 'objective-c', 'mm': 'objective-c',
'r': 'r', 'rmd': 'r',
'vb': 'visual-basic', 'fs': 'fsharp', 'fsi': 'fsharp', 'fsx': 'fsharp',
'tf': 'hcl', 'hcl': 'hcl', 'dockerfile': 'dockerfile'
"py": "python",
"js": "javascript",
"jsx": "javascript",
"ts": "typescript",
"tsx": "typescript",
"java": "java",
"go": "go",
"rs": "rust",
"cpp": "cpp",
"c": "c",
"cc": "cpp",
"h": "c",
"hh": "cpp",
"hpp": "cpp",
"hxx": "cpp",
"cs": "csharp",
"php": "php",
"rb": "ruby",
"kt": "kotlin",
"ktm": "kotlin",
"kts": "kotlin",
"swift": "swift",
"dart": "dart",
"scala": "scala",
"sc": "scala",
"groovy": "groovy",
"gsh": "groovy",
"gvy": "groovy",
"gy": "groovy",
"sql": "sql",
"sh": "bash",
"bash": "bash",
"zsh": "bash",
"pl": "perl",
"pm": "perl",
"t": "perl",
"lua": "lua",
"hs": "haskell",
"lhs": "haskell",
"clj": "clojure",
"cljs": "clojure",
"cljc": "clojure",
"edn": "clojure",
"ex": "elixir",
"exs": "elixir",
"erl": "erlang",
"hrl": "erlang",
"m": "objective-c",
"mm": "objective-c",
"r": "r",
"rmd": "r",
"vb": "visual-basic",
"fs": "fsharp",
"fsi": "fsharp",
"fsx": "fsharp",
"tf": "hcl",
"hcl": "hcl",
"dockerfile": "dockerfile",
}
return language_map.get(ext, 'text')
return language_map.get(ext, "text")
class TaskControlManager:
@ -107,13 +205,13 @@ async def github_api(url: str, token: str = None) -> Any:
"""调用GitHub API"""
headers = {"Accept": "application/vnd.github+json"}
t = token or settings.GITHUB_TOKEN
client = _get_http_client()
async with httpx.AsyncClient(timeout=30) as client:
# First try with token if available
if t:
headers["Authorization"] = f"Bearer {t}"
try:
response = await client.get(url, headers=headers)
response = await _request_with_retry(client, url, headers)
if response.status_code == 200:
return response.json()
if response.status_code != 401:
@ -121,11 +219,29 @@ async def github_api(url: str, token: str = None) -> Any:
raise Exception("GitHub API 403请配置 GITHUB_TOKEN 或确认仓库权限/频率限制")
raise Exception(f"GitHub API {response.status_code}: {url}")
# If 401, fall through to retry without token
print(f"[API] GitHub API 401 (Unauthorized) with token, retrying without token for: {url}")
logger.info(
f"[API] GitHub API 401 (Unauthorized) with token, retrying without token for: {url}"
)
except Exception as e:
if "GitHub API 401" not in str(e) and "401" not in str(e):
raise
# Try without token
headers.pop("Authorization", None)
try:
response = await _request_with_retry(client, url, headers)
if response.status_code == 200:
return response.json()
if response.status_code == 403:
raise Exception("GitHub API 403请配置 GITHUB_TOKEN 或确认仓库权限/频率限制")
if response.status_code == 401:
raise Exception("GitHub API 401请配置 GITHUB_TOKEN 或确认仓库权限")
raise Exception(f"GitHub API {response.status_code}: {url}")
except Exception as e:
logger.error(f"[API] GitHub API 调用失败: {url}, 错误: {e}")
raise
# Try without token
if "Authorization" in headers:
del headers["Authorization"]
@ -144,18 +260,17 @@ async def github_api(url: str, token: str = None) -> Any:
raise
async def gitea_api(url: str, token: str = None) -> Any:
"""调用Gitea API"""
headers = {"Content-Type": "application/json"}
t = token or settings.GITEA_TOKEN
client = _get_http_client()
async with httpx.AsyncClient(timeout=120) as client:
# First try with token if available
if t:
headers["Authorization"] = f"token {t}"
try:
response = await client.get(url, headers=headers)
response = await _request_with_retry(client, url, headers)
if response.status_code == 200:
return response.json()
if response.status_code != 401:
@ -163,11 +278,29 @@ async def gitea_api(url: str, token: str = None) -> Any:
raise Exception("Gitea API 403请确认仓库权限/频率限制")
raise Exception(f"Gitea API {response.status_code}: {url}")
# If 401, fall through to retry without token
print(f"[API] Gitea API 401 (Unauthorized) with token, retrying without token for: {url}")
logger.info(
f"[API] Gitea API 401 (Unauthorized) with token, retrying without token for: {url}"
)
except Exception as e:
if "Gitea API 401" not in str(e) and "401" not in str(e):
raise
# Try without token
headers.pop("Authorization", None)
try:
response = await _request_with_retry(client, url, headers)
if response.status_code == 200:
return response.json()
if response.status_code == 401:
raise Exception("Gitea API 401请配置 GITEA_TOKEN 或确认仓库权限")
if response.status_code == 403:
raise Exception("Gitea API 403请确认仓库权限/频率限制")
raise Exception(f"Gitea API {response.status_code}: {url}")
except Exception as e:
logger.error(f"[API] Gitea API 调用失败: {url}, 错误: {e}")
raise
# Try without token
if "Authorization" in headers:
del headers["Authorization"]
@ -190,13 +323,13 @@ async def gitlab_api(url: str, token: str = None) -> Any:
"""调用GitLab API"""
headers = {"Content-Type": "application/json"}
t = token or settings.GITLAB_TOKEN
client = _get_http_client()
async with httpx.AsyncClient(timeout=30) as client:
# First try with token if available
if t:
headers["PRIVATE-TOKEN"] = t
try:
response = await client.get(url, headers=headers)
response = await _request_with_retry(client, url, headers)
if response.status_code == 200:
return response.json()
if response.status_code != 401:
@ -204,11 +337,29 @@ async def gitlab_api(url: str, token: str = None) -> Any:
raise Exception("GitLab API 403请确认仓库权限/频率限制")
raise Exception(f"GitLab API {response.status_code}: {url}")
# If 401, fall through to retry without token
print(f"[API] GitLab API 401 (Unauthorized) with token, retrying without token for: {url}")
logger.info(
f"[API] GitLab API 401 (Unauthorized) with token, retrying without token for: {url}"
)
except Exception as e:
if "GitLab API 401" not in str(e) and "401" not in str(e):
raise
# Try without token
headers.pop("PRIVATE-TOKEN", None)
try:
response = await _request_with_retry(client, url, headers)
if response.status_code == 200:
return response.json()
if response.status_code == 401:
raise Exception("GitLab API 401请配置 GITLAB_TOKEN 或确认仓库权限")
if response.status_code == 403:
raise Exception("GitLab API 403请确认仓库权限/频率限制")
raise Exception(f"GitLab API {response.status_code}: {url}")
except Exception as e:
logger.error(f"[API] GitLab API 调用失败: {url}, 错误: {e}")
raise
# Try without token
if "PRIVATE-TOKEN" in headers:
del headers["PRIVATE-TOKEN"]
@ -229,28 +380,31 @@ async def gitlab_api(url: str, token: str = None) -> Any:
async def fetch_file_content(url: str, headers: Dict[str, str] = None) -> Optional[str]:
"""获取文件内容"""
async with httpx.AsyncClient(timeout=30) as client:
client = _get_http_client()
req_headers = headers or {}
try:
response = await client.get(url, headers=headers or {})
response = await _request_with_retry(client, url, req_headers)
if response.status_code == 200:
return response.text
# 如果带 Token 请求失败401/403尝试不带 Token 请求(针对公开仓库)
if response.status_code in (401, 403) and headers:
print(f"[API] 获取文件内容返回 {response.status_code},尝试不带 Token 重试: {url}")
response = await client.get(url)
logger.info(
f"[API] 获取文件内容返回 {response.status_code},尝试不带 Token 重试: {url}"
)
response = await _request_with_retry(client, url, {})
if response.status_code == 200:
return response.text
except Exception as e:
print(f"获取文件内容失败: {url}, 错误: {e}")
logger.error(f"获取文件内容失败: {url}, 错误: {e}")
return None
async def get_github_branches(repo_url: str, token: str = None) -> List[str]:
"""获取GitHub仓库分支列表"""
repo_info = parse_repository_url(repo_url, "github")
owner, repo = repo_info['owner'], repo_info['repo']
owner, repo = repo_info["owner"], repo_info["repo"]
branches_url = f"https://api.github.com/repos/{owner}/{repo}/branches?per_page=100"
branches_data = await github_api(branches_url, token)
@ -262,14 +416,11 @@ async def get_github_branches(repo_url: str, token: str = None) -> List[str]:
return [b["name"] for b in branches_data if isinstance(b, dict) and "name" in b]
async def get_gitea_branches(repo_url: str, token: str = None) -> List[str]:
"""获取Gitea仓库分支列表"""
repo_info = parse_repository_url(repo_url, "gitea")
base_url = repo_info['base_url'] # This is {base}/api/v1
owner, repo = repo_info['owner'], repo_info['repo']
base_url = repo_info["base_url"] # This is {base}/api/v1
owner, repo = repo_info["owner"], repo_info["repo"]
branches_url = f"{base_url}/repos/{owner}/{repo}/branches"
branches_data = await gitea_api(branches_url, token)
@ -287,14 +438,14 @@ async def get_gitlab_branches(repo_url: str, token: str = None) -> List[str]:
extracted_token = token
if parsed.username:
if parsed.username == 'oauth2' and parsed.password:
if parsed.username == "oauth2" and parsed.password:
extracted_token = parsed.password
elif parsed.username and not parsed.password:
extracted_token = parsed.username
repo_info = parse_repository_url(repo_url, "gitlab")
base_url = repo_info['base_url']
project_path = quote(repo_info['project_path'], safe='')
base_url = repo_info["base_url"]
project_path = quote(repo_info["project_path"], safe="")
branches_url = f"{base_url}/projects/{project_path}/repository/branches?per_page=100"
branches_data = await gitlab_api(branches_url, extracted_token)
@ -306,11 +457,13 @@ async def get_gitlab_branches(repo_url: str, token: str = None) -> List[str]:
return [b["name"] for b in branches_data if isinstance(b, dict) and "name" in b]
async def get_github_files(repo_url: str, branch: str, token: str = None, exclude_patterns: List[str] = None) -> List[Dict[str, str]]:
async def get_github_files(
repo_url: str, branch: str, token: str = None, exclude_patterns: List[str] = None
) -> List[Dict[str, str]]:
"""获取GitHub仓库文件列表"""
# 解析仓库URL
repo_info = parse_repository_url(repo_url, "github")
owner, repo = repo_info['owner'], repo_info['repo']
owner, repo = repo_info["owner"], repo_info["repo"]
# 获取仓库文件树
tree_url = f"https://api.github.com/repos/{owner}/{repo}/git/trees/{quote(branch)}?recursive=1"
@ -318,33 +471,41 @@ async def get_github_files(repo_url: str, branch: str, token: str = None, exclud
files = []
for item in tree_data.get("tree", []):
if item.get("type") == "blob" and is_text_file(item["path"]) and not should_exclude(item["path"], exclude_patterns):
if (
item.get("type") == "blob"
and is_text_file(item["path"])
and not should_exclude(item["path"], exclude_patterns)
):
size = item.get("size", 0)
if size <= settings.MAX_FILE_SIZE_BYTES:
files.append({
files.append(
{
"path": item["path"],
"url": f"https://raw.githubusercontent.com/{owner}/{repo}/{quote(branch)}/{item['path']}"
})
"url": f"https://raw.githubusercontent.com/{owner}/{repo}/{quote(branch)}/{item['path']}",
}
)
return files
async def get_gitlab_files(repo_url: str, branch: str, token: str = None, exclude_patterns: List[str] = None) -> List[Dict[str, str]]:
async def get_gitlab_files(
repo_url: str, branch: str, token: str = None, exclude_patterns: List[str] = None
) -> List[Dict[str, str]]:
"""获取GitLab仓库文件列表"""
parsed = urlparse(repo_url)
# 从URL中提取token如果存在
extracted_token = token
if parsed.username:
if parsed.username == 'oauth2' and parsed.password:
if parsed.username == "oauth2" and parsed.password:
extracted_token = parsed.password
elif parsed.username and not parsed.password:
extracted_token = parsed.username
# 解析项目路径
repo_info = parse_repository_url(repo_url, "gitlab")
base_url = repo_info['base_url'] # {base}/api/v4
project_path = quote(repo_info['project_path'], safe='')
base_url = repo_info["base_url"] # {base}/api/v4
project_path = quote(repo_info["project_path"], safe="")
# 获取仓库文件树
tree_url = f"{base_url}/projects/{project_path}/repository/tree?ref={quote(branch)}&recursive=true&per_page=100"
@ -352,40 +513,57 @@ async def get_gitlab_files(repo_url: str, branch: str, token: str = None, exclud
files = []
for item in tree_data:
if item.get("type") == "blob" and is_text_file(item["path"]) and not should_exclude(item["path"], exclude_patterns):
files.append({
if (
item.get("type") == "blob"
and is_text_file(item["path"])
and not should_exclude(item["path"], exclude_patterns)
):
files.append(
{
"path": item["path"],
"url": f"{base_url}/projects/{project_path}/repository/files/{quote(item['path'], safe='')}/raw?ref={quote(branch)}",
"token": extracted_token
})
"token": extracted_token,
}
)
return files
async def get_gitea_files(repo_url: str, branch: str, token: str = None, exclude_patterns: List[str] = None) -> List[Dict[str, str]]:
async def get_gitea_files(
repo_url: str, branch: str, token: str = None, exclude_patterns: List[str] = None
) -> List[Dict[str, str]]:
"""获取Gitea仓库文件列表"""
repo_info = parse_repository_url(repo_url, "gitea")
base_url = repo_info['base_url']
owner, repo = repo_info['owner'], repo_info['repo']
base_url = repo_info["base_url"]
owner, repo = repo_info["owner"], repo_info["repo"]
# Gitea tree API: GET /repos/{owner}/{repo}/git/trees/{sha}?recursive=true
# 可以直接使用分支名作为sha
tree_url = f"{base_url}/repos/{quote(owner)}/{quote(repo)}/git/trees/{quote(branch)}?recursive=true"
tree_url = (
f"{base_url}/repos/{quote(owner)}/{quote(repo)}/git/trees/{quote(branch)}?recursive=true"
)
tree_data = await gitea_api(tree_url, token)
files = []
for item in tree_data.get("tree", []):
# Gitea API returns 'type': 'blob' for files
if item.get("type") == "blob" and is_text_file(item["path"]) and not should_exclude(item["path"], exclude_patterns):
if (
item.get("type") == "blob"
and is_text_file(item["path"])
and not should_exclude(item["path"], exclude_patterns)
):
# 使用API raw endpoint: GET /repos/{owner}/{repo}/raw/{filepath}?ref={branch}
files.append({
files.append(
{
"path": item["path"],
"url": f"{base_url}/repos/{owner}/{repo}/raw/{quote(item['path'])}?ref={quote(branch)}",
"token": token # 传递token以便fetch_file_content使用
})
"token": token, # 传递token以便fetch_file_content使用
}
)
return files
async def scan_repo_task(task_id: str, db_session_factory, user_config: dict = None):
"""
后台仓库扫描任务
@ -415,8 +593,8 @@ async def scan_repo_task(task_id: str, db_session_factory, user_config: dict = N
raise Exception("项目不存在")
# 检查项目类型 - 仅支持仓库类型项目
source_type = getattr(project, 'source_type', 'repository')
if source_type == 'zip':
source_type = getattr(project, "source_type", "repository")
if source_type == "zip":
raise Exception("ZIP类型项目请使用ZIP上传扫描接口")
if not project.repository_url:
@ -428,6 +606,7 @@ async def scan_repo_task(task_id: str, db_session_factory, user_config: dict = N
# 解析任务的排除模式
import json as json_module
task_exclude_patterns = []
if task.exclude_patterns:
try:
@ -435,7 +614,9 @@ async def scan_repo_task(task_id: str, db_session_factory, user_config: dict = N
except:
pass
print(f"🚀 开始扫描仓库: {repo_url}, 分支: {branch}, 类型: {repo_type}, 来源: {source_type}")
print(
f"🚀 开始扫描仓库: {repo_url}, 分支: {branch}, 类型: {repo_type}, 来源: {source_type}"
)
if task_exclude_patterns:
print(f"📋 排除模式: {task_exclude_patterns}")
@ -445,20 +626,20 @@ async def scan_repo_task(task_id: str, db_session_factory, user_config: dict = N
gitlab_token = settings.GITLAB_TOKEN
gitea_token = settings.GITEA_TOKEN
# 获取SSH私钥如果配置了
user_other_config = user_config.get('otherConfig', {}) if user_config else {}
user_other_config = user_config.get("otherConfig", {}) if user_config else {}
ssh_private_key = None
if 'sshPrivateKey' in user_other_config:
if "sshPrivateKey" in user_other_config:
from app.core.encryption import decrypt_sensitive_data
ssh_private_key = decrypt_sensitive_data(user_other_config['sshPrivateKey'])
ssh_private_key = decrypt_sensitive_data(user_other_config["sshPrivateKey"])
files: List[Dict[str, str]] = []
extracted_gitlab_token = None
# 检查是否为SSH URL
from app.services.git_ssh_service import GitSSHOperations
is_ssh_url = GitSSHOperations.is_ssh_url(repo_url)
if is_ssh_url:
@ -472,7 +653,9 @@ async def scan_repo_task(task_id: str, db_session_factory, user_config: dict = N
repo_url, ssh_private_key, branch, task_exclude_patterns
)
# 转换为统一格式
files = [{'path': f['path'], 'content': f['content']} for f in files_with_content]
files = [
{"path": f["path"], "content": f["content"]} for f in files_with_content
]
actual_branch = branch
print(f"✅ 通过SSH成功获取 {len(files)} 个文件")
except Exception as e:
@ -494,21 +677,29 @@ async def scan_repo_task(task_id: str, db_session_factory, user_config: dict = N
try:
print(f"🔄 尝试获取分支 {try_branch} 的文件列表...")
if repo_type == "github":
files = await get_github_files(repo_url, try_branch, github_token, task_exclude_patterns)
files = await get_github_files(
repo_url, try_branch, github_token, task_exclude_patterns
)
elif repo_type == "gitlab":
files = await get_gitlab_files(repo_url, try_branch, gitlab_token, task_exclude_patterns)
files = await get_gitlab_files(
repo_url, try_branch, gitlab_token, task_exclude_patterns
)
# GitLab文件可能带有token
if files and 'token' in files[0]:
extracted_gitlab_token = files[0].get('token')
if files and "token" in files[0]:
extracted_gitlab_token = files[0].get("token")
elif repo_type == "gitea":
files = await get_gitea_files(repo_url, try_branch, gitea_token, task_exclude_patterns)
files = await get_gitea_files(
repo_url, try_branch, gitea_token, task_exclude_patterns
)
else:
raise Exception("不支持的仓库类型,仅支持 GitHub, GitLab 和 Gitea 仓库")
if files:
actual_branch = try_branch
if try_branch != branch:
print(f"⚠️ 分支 {branch} 不存在或无法访问,已降级到分支 {try_branch}")
print(
f"⚠️ 分支 {branch} 不存在或无法访问,已降级到分支 {try_branch}"
)
break
except Exception as e:
last_error = str(e)
@ -530,23 +721,25 @@ async def scan_repo_task(task_id: str, db_session_factory, user_config: dict = N
# 获取分析配置(优先使用用户配置)
analysis_config = get_analysis_config(user_config)
max_analyze_files = analysis_config['max_analyze_files']
analysis_concurrency = analysis_config['llm_concurrency'] # 并发数
llm_gap_ms = analysis_config['llm_gap_ms']
max_analyze_files = analysis_config["max_analyze_files"]
analysis_concurrency = analysis_config["llm_concurrency"] # 并发数
llm_gap_ms = analysis_config["llm_gap_ms"]
# 限制文件数量
# 如果指定了特定文件,则只分析这些文件
target_files = (user_config or {}).get('scan_config', {}).get('file_paths', [])
target_files = (user_config or {}).get("scan_config", {}).get("file_paths", [])
if target_files:
print(f"🎯 指定分析 {len(target_files)} 个文件")
files = [f for f in files if f['path'] in target_files]
files = [f for f in files if f["path"] in target_files]
elif max_analyze_files > 0:
files = files[:max_analyze_files]
task.total_files = len(files)
await db.commit()
print(f"📊 获取到 {len(files)} 个文件,开始分析 (最大文件数: {max_analyze_files}, 请求间隔: {llm_gap_ms}ms)")
print(
f"📊 获取到 {len(files)} 个文件,开始分析 (最大文件数: {max_analyze_files}, 请求间隔: {llm_gap_ms}ms)"
)
# 4. 分析文件
total_issues = 0
@ -572,22 +765,24 @@ async def scan_repo_task(task_id: str, db_session_factory, user_config: dict = N
if task_control.is_cancelled(task_id):
return None
f_path = file_info['path']
f_path = file_info["path"]
MAX_RETRIES = 3
for attempt in range(MAX_RETRIES):
try:
# 4.1 获取文件内容 (仅在第一次尝试或内容获取失败时获取)
if attempt == 0:
if is_ssh_url:
content = file_info.get('content', '')
content = file_info.get("content", "")
else:
headers = {}
if repo_type == "gitlab":
token_to_use = file_info.get('token') or gitlab_token
if token_to_use: headers["PRIVATE-TOKEN"] = token_to_use
token_to_use = file_info.get("token") or gitlab_token
if token_to_use:
headers["PRIVATE-TOKEN"] = token_to_use
elif repo_type == "gitea":
token_to_use = file_info.get('token') or gitea_token
if token_to_use: headers["Authorization"] = f"token {token_to_use}"
token_to_use = file_info.get("token") or gitea_token
if token_to_use:
headers["Authorization"] = f"token {token_to_use}"
elif repo_type == "github" and github_token:
headers["Authorization"] = f"Bearer {github_token}"
@ -604,16 +799,17 @@ async def scan_repo_task(task_id: str, db_session_factory, user_config: dict = N
# 4.2 LLM 分析
language = get_language_from_path(f_path)
scan_config = (user_config or {}).get('scan_config', {})
rule_set_id = scan_config.get('rule_set_id')
prompt_template_id = scan_config.get('prompt_template_id')
scan_config = (user_config or {}).get("scan_config", {})
rule_set_id = scan_config.get("rule_set_id")
prompt_template_id = scan_config.get("prompt_template_id")
if rule_set_id or prompt_template_id:
analysis_result = await llm_service.analyze_code_with_rules(
content, language,
content,
language,
rule_set_id=rule_set_id,
prompt_template_id=prompt_template_id,
db_session=None
db_session=None,
)
else:
analysis_result = await llm_service.analyze_code(content, language)
@ -623,7 +819,7 @@ async def scan_repo_task(task_id: str, db_session_factory, user_config: dict = N
"path": f_path,
"content": content,
"language": language,
"analysis": analysis_result
"analysis": analysis_result,
}
except asyncio.CancelledError:
# 捕获取消异常,不再重试
@ -633,10 +829,18 @@ async def scan_repo_task(task_id: str, db_session_factory, user_config: dict = N
wait_time = (attempt + 1) * 2
# 特殊处理限流错误提示
error_str = str(e)
if "429" in error_str or "rate limit" in error_str.lower() or "额度不足" in error_str:
print(f"🚫 [限流提示] 仓库扫描任务触发 LLM 频率限制 (429),建议在设置中降低并发数或增加请求间隔。文件: {f_path}")
if (
"429" in error_str
or "rate limit" in error_str.lower()
or "额度不足" in error_str
):
print(
f"🚫 [限流提示] 仓库扫描任务触发 LLM 频率限制 (429),建议在设置中降低并发数或增加请求间隔。文件: {f_path}"
)
print(f"⚠️ 分析文件失败 ({f_path}), 正在进行第 {attempt+1} 次重试... 错误: {e}")
print(
f"⚠️ 分析文件失败 ({f_path}), 正在进行第 {attempt + 1} 次重试... 错误: {e}"
)
await asyncio.sleep(wait_time)
continue
else:
@ -660,7 +864,8 @@ async def scan_repo_task(task_id: str, db_session_factory, user_config: dict = N
except asyncio.CancelledError:
continue
if not res: continue
if not res:
continue
if res["type"] == "skip":
skipped_files += 1
@ -674,7 +879,7 @@ async def scan_repo_task(task_id: str, db_session_factory, user_config: dict = N
f_path = res["path"]
analysis = res["analysis"]
file_lines = res["content"].split('\n')
file_lines = res["content"].split("\n")
total_lines += len(file_lines)
# 保存问题
@ -683,7 +888,9 @@ async def scan_repo_task(task_id: str, db_session_factory, user_config: dict = N
try:
# 防御性检查:确保 issue 是字典
if not isinstance(issue, dict):
print(f"⚠️ 警告: 任务 {task_id} 中文件 {f_path} 的分析结果包含无效的问题格式: {issue}")
print(
f"⚠️ 警告: 任务 {task_id} 中文件 {f_path} 的分析结果包含无效的问题格式: {issue}"
)
continue
# 辅助函数:清理字符串中 PostgreSQL 不支持的字符
@ -693,7 +900,7 @@ async def scan_repo_task(task_id: str, db_session_factory, user_config: dict = N
if not isinstance(text, str):
text = str(text)
# 移除 NULL 字节 (PostgreSQL 不支持)
text = text.replace('\x00', '')
text = text.replace("\x00", "")
return text
line_num = issue.get("line", 1)
@ -703,7 +910,7 @@ async def scan_repo_task(task_id: str, db_session_factory, user_config: dict = N
idx = max(0, int(line_num) - 1)
start = max(0, idx - 2)
end = min(len(file_lines), idx + 3)
code_snippet = '\n'.join(file_lines[start:end])
code_snippet = "\n".join(file_lines[start:end])
except Exception:
code_snippet = ""
@ -715,11 +922,13 @@ async def scan_repo_task(task_id: str, db_session_factory, user_config: dict = N
issue_type=issue.get("type", "maintainability"),
severity=issue.get("severity", "low"),
title=sanitize_for_db(issue.get("title", "Issue")),
message=sanitize_for_db(issue.get("description") or issue.get("title", "Issue")),
message=sanitize_for_db(
issue.get("description") or issue.get("title", "Issue")
),
suggestion=sanitize_for_db(issue.get("suggestion")),
code_snippet=code_snippet,
ai_explanation=sanitize_for_db(issue.get("ai_explanation")),
status="open"
status="open",
)
db.add(audit_issue)
total_issues += 1
@ -742,7 +951,9 @@ async def scan_repo_task(task_id: str, db_session_factory, user_config: dict = N
await db.commit() # 这里的 commit 是在一个协程里按序进行的,是安全的
if processed_count % 10 == 0 or processed_count == len(files):
print(f"📈 任务 {task_id}: 进度 {processed_count}/{len(files)} ({int(processed_count/len(files)*100) if len(files) > 0 else 0}%)")
print(
f"📈 任务 {task_id}: 进度 {processed_count}/{len(files)} ({int(processed_count / len(files) * 100) if len(files) > 0 else 0}%)"
)
if consecutive_failures >= MAX_CONSECUTIVE_FAILURES:
print(f"❌ 任务 {task_id}: 连续失败 {consecutive_failures} 次,停止分析")
@ -761,7 +972,9 @@ async def scan_repo_task(task_id: str, db_session_factory, user_config: dict = N
await asyncio.gather(*task_objects, return_exceptions=True)
# 5. 完成任务
avg_quality_score = sum(quality_scores) / len(quality_scores) if quality_scores else 100.0
avg_quality_score = (
sum(quality_scores) / len(quality_scores) if quality_scores else 100.0
)
# 判断任务状态
# 如果所有文件都被跳过(空文件等),标记为完成但给出提示