1045 lines
43 KiB
Python
1045 lines
43 KiB
Python
"""
|
||
仓库扫描服务 - 支持GitHub, GitLab 和 Gitea 仓库扫描
|
||
"""
|
||
|
||
import os
|
||
import asyncio
|
||
import logging
|
||
import httpx
|
||
from typing import List, Dict, Any, Optional
|
||
from datetime import datetime, timezone
|
||
from urllib.parse import urlparse, quote
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
|
||
from app.utils.repo_utils import parse_repository_url
|
||
from app.models.audit import AuditTask, AuditIssue
|
||
from app.models.project import Project
|
||
from app.services.llm.service import LLMService
|
||
from app.core.config import settings
|
||
from app.core.file_filter import (
|
||
is_text_file as core_is_text_file,
|
||
should_exclude as core_should_exclude,
|
||
TEXT_EXTENSIONS as CORE_TEXT_EXTENSIONS,
|
||
)
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# 共享 HTTP 客户端(模块级别单例)
|
||
# 通过复用连接池和 DNS 缓存,避免每次请求都重新建立 TCP 连接,
|
||
# 解决"每天第一次打开时因 DNS/连接冷启动导致超时"的问题。
|
||
# ---------------------------------------------------------------------------
|
||
# ---------------------------------------------------------------------------
# Shared HTTP client (module-level singleton).
# Reusing the connection pool and DNS cache avoids re-establishing a TCP
# connection on every request, which fixes the "first request of the day
# times out due to DNS/connection cold start" problem.
# ---------------------------------------------------------------------------
_GIT_API_TIMEOUT = httpx.Timeout(
    connect=15.0,  # connect timeout 15s (the previous 120s was far too long for UX)
    read=60.0,  # read timeout 60s (the tree API can be slow on large repos)
    write=30.0,
    pool=15.0,  # timeout for acquiring a connection from the pool
)

# Lazily-created singleton client; always access it via _get_http_client().
_http_client: Optional[httpx.AsyncClient] = None
||
def _get_http_client() -> httpx.AsyncClient:
    """Return the shared HTTP client, (re)creating it lazily when absent or closed."""
    global _http_client

    if _http_client is not None and not _http_client.is_closed:
        return _http_client

    _http_client = httpx.AsyncClient(
        timeout=_GIT_API_TIMEOUT,
        limits=httpx.Limits(
            max_connections=20,
            max_keepalive_connections=10,
            keepalive_expiry=300,  # keep idle connections alive for 5 minutes
        ),
        follow_redirects=True,
    )
    return _http_client
||
async def _request_with_retry(
    client: httpx.AsyncClient,
    url: str,
    headers: Dict[str, str],
    max_retries: int = 3,
) -> httpx.Response:
    """HTTP GET with automatic retries on connection-level failures.

    The first attempt uses a short timeout (2s connect / 10s read) to probe
    connectivity quickly; subsequent attempts fall back to the client's
    default timeout so users are not blocked for long on a cold start.

    Args:
        client: Shared AsyncClient (see _get_http_client()).
        url: Target URL.
        headers: Request headers.
        max_retries: Total number of attempts (first try included).

    Raises:
        httpx.ConnectTimeout / ConnectError / ReadTimeout: the last error
        when every attempt has failed. Other exceptions propagate immediately.
    """
    last_exc: Optional[Exception] = None

    # First attempt uses a short timeout (2s) so a dead connection is
    # detected quickly instead of hanging for the full default timeout.
    first_attempt_timeout = httpx.Timeout(
        connect=2.0,  # first-attempt connect timeout: 2 seconds
        read=10.0,  # first-attempt read timeout: 10 seconds
        write=10.0,
        pool=2.0,
    )

    for attempt in range(max_retries):
        try:
            if attempt == 0:
                # First attempt: short timeout, quick connectivity probe.
                return await client.get(url, headers=headers, timeout=first_attempt_timeout)
            else:
                # Retries: the client's default timeout (_GIT_API_TIMEOUT).
                return await client.get(url, headers=headers)
        except (httpx.ConnectTimeout, httpx.ConnectError, httpx.ReadTimeout) as e:
            last_exc = e
            if attempt < max_retries - 1:
                # Linear backoff: 2s, 4s, ...
                wait = (attempt + 1) * 2
                logger.info(f"[API] 连接失败 (第 {attempt + 1} 次), {wait}s 后重试: {url} - {e}")
                await asyncio.sleep(wait)
            else:
                logger.error(f"[API] 连接最终失败 (共 {max_retries} 次): {url} - {e}")
                raise last_exc  # type: ignore[misc]
||
def get_analysis_config(user_config: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """Resolve the analysis tuning parameters for a scan run.

    Values supplied by the user in ``otherConfig`` take precedence; any
    missing or falsy entry falls back to the system-wide settings.

    Returns:
        A dict with the keys:
        - max_analyze_files: maximum number of files to analyze
        - llm_concurrency: number of concurrent LLM requests
        - llm_gap_ms: delay between LLM requests, in milliseconds
    """
    overrides = (user_config or {}).get("otherConfig", {})

    # `or` keeps the short-circuit: settings are only consulted when the
    # user did not provide a usable (truthy) override.
    resolved: Dict[str, Any] = {}
    resolved["max_analyze_files"] = overrides.get("maxAnalyzeFiles") or settings.MAX_ANALYZE_FILES
    resolved["llm_concurrency"] = overrides.get("llmConcurrency") or settings.LLM_CONCURRENCY
    resolved["llm_gap_ms"] = overrides.get("llmGapMs") or settings.LLM_GAP_MS
    return resolved
|
||
# Supported text-file extensions come from the shared global definition
# (app.core.file_filter); materialized as a list here for module-local use.
TEXT_EXTENSIONS = list(CORE_TEXT_EXTENSIONS)
||
def is_text_file(path: str) -> bool:
    """Return True if *path* looks like a text file.

    Thin delegate to the shared core filter (app.core.file_filter), kept so
    existing callers of this module's API keep working.
    """
    return core_is_text_file(path)
||
def should_exclude(path: str, exclude_patterns: Optional[List[str]] = None) -> bool:
    """Return True if the file should be excluded from scanning.

    Delegates to the shared core filter, passing both the full path and the
    bare filename (the core filter matches on either — TODO confirm).
    """
    filename = os.path.basename(path)
    return core_should_exclude(path, filename, exclude_patterns)
||
def get_language_from_path(path: str) -> str:
    """Map a file path to a language identifier for the LLM analyzer.

    The original implementation took the last ``.``-separated segment of the
    WHOLE path, so a dotted directory name could pollute the lookup and an
    extension-less file such as ``Dockerfile`` could never match (making the
    ``"dockerfile"`` map entry unreachable for the common bare filename).
    This version extracts the extension from the basename via
    ``os.path.splitext`` and falls back to the bare (lowercased) filename,
    so ``Dockerfile`` resolves correctly.

    Args:
        path: File path (relative or absolute).

    Returns:
        A lowercase language identifier, or "text" when unknown.
    """
    name = os.path.basename(path).lower()
    _root, dot_ext = os.path.splitext(name)
    # Prefer the real extension; for extension-less files fall back to the
    # bare filename so entries like "dockerfile" can match.
    ext = dot_ext[1:] if dot_ext else name
    language_map = {
        "py": "python",
        "js": "javascript",
        "jsx": "javascript",
        "ts": "typescript",
        "tsx": "typescript",
        "java": "java",
        "go": "go",
        "rs": "rust",
        "cpp": "cpp",
        "c": "c",
        "cc": "cpp",
        "h": "c",
        "hh": "cpp",
        "hpp": "cpp",
        "hxx": "cpp",
        "cs": "csharp",
        "php": "php",
        "rb": "ruby",
        "kt": "kotlin",
        "ktm": "kotlin",
        "kts": "kotlin",
        "swift": "swift",
        "dart": "dart",
        "scala": "scala",
        "sc": "scala",
        "groovy": "groovy",
        "gsh": "groovy",
        "gvy": "groovy",
        "gy": "groovy",
        "sql": "sql",
        "sh": "bash",
        "bash": "bash",
        "zsh": "bash",
        "pl": "perl",
        "pm": "perl",
        "t": "perl",
        "lua": "lua",
        "hs": "haskell",
        "lhs": "haskell",
        "clj": "clojure",
        "cljs": "clojure",
        "cljc": "clojure",
        "edn": "clojure",
        "ex": "elixir",
        "exs": "elixir",
        "erl": "erlang",
        "hrl": "erlang",
        "m": "objective-c",
        "mm": "objective-c",
        "r": "r",
        "rmd": "r",
        "vb": "visual-basic",
        "fs": "fsharp",
        "fsi": "fsharp",
        "fsx": "fsharp",
        "tf": "hcl",
        "hcl": "hcl",
        "dockerfile": "dockerfile",
    }
    return language_map.get(ext, "text")
||
class TaskControlManager:
    """Tracks cancellation flags so that running scan tasks can be stopped."""

    def __init__(self) -> None:
        # IDs of tasks that have been asked to stop.
        self._cancel_requests: set = set()

    def cancel_task(self, task_id: str):
        """Flag *task_id* as cancelled; running coroutines poll this flag."""
        self._cancel_requests.add(task_id)
        print(f"🛑 任务 {task_id} 已标记为取消")

    def is_cancelled(self, task_id: str) -> bool:
        """Return True when a cancellation has been requested for *task_id*."""
        return task_id in self._cancel_requests

    def cleanup_task(self, task_id: str):
        """Drop the cancellation flag of a finished task (no-op when absent)."""
        self._cancel_requests.discard(task_id)


# Module-wide singleton used by the scan coroutines below.
task_control = TaskControlManager()
||
async def github_api(url: str, token: str = None) -> Any:
    """Call the GitHub API and return the decoded JSON payload.

    An authenticated request is attempted first (explicit *token* or the
    configured GITHUB_TOKEN). On a 401 the call is retried anonymously,
    which keeps public repositories reachable with a stale/invalid token.

    Fix: removed the duplicated "try without token" block that followed the
    second attempt — it was unreachable dead code, because the preceding
    try/except always either returns or re-raises.

    Raises:
        Exception: for any unrecoverable non-200 response.
    """
    headers = {"Accept": "application/vnd.github+json"}
    t = token or settings.GITHUB_TOKEN
    client = _get_http_client()

    # First attempt: authenticated, when a token is available.
    if t:
        headers["Authorization"] = f"Bearer {t}"
        try:
            response = await _request_with_retry(client, url, headers)
            if response.status_code == 200:
                return response.json()
            if response.status_code != 401:
                if response.status_code == 403:
                    raise Exception("GitHub API 403:请配置 GITHUB_TOKEN 或确认仓库权限/频率限制")
                raise Exception(f"GitHub API {response.status_code}: {url}")
            # 401 with a token: fall through and retry anonymously below.
            logger.info(
                f"[API] GitHub API 401 (Unauthorized) with token, retrying without token for: {url}"
            )
        except Exception as e:
            # Only 401-related failures may fall through to the anonymous retry.
            if "GitHub API 401" not in str(e) and "401" not in str(e):
                raise

    # Second attempt: anonymous request.
    headers.pop("Authorization", None)

    try:
        response = await _request_with_retry(client, url, headers)
        if response.status_code == 200:
            return response.json()
        if response.status_code == 403:
            raise Exception("GitHub API 403:请配置 GITHUB_TOKEN 或确认仓库权限/频率限制")
        if response.status_code == 401:
            raise Exception("GitHub API 401:请配置 GITHUB_TOKEN 或确认仓库权限")
        raise Exception(f"GitHub API {response.status_code}: {url}")
    except Exception as e:
        logger.error(f"[API] GitHub API 调用失败: {url}, 错误: {e}")
        raise
||
async def gitea_api(url: str, token: str = None) -> Any:
    """Call the Gitea API and return the decoded JSON payload.

    An authenticated request is attempted first (explicit *token* or the
    configured GITEA_TOKEN). On a 401 the call is retried anonymously so
    that public repositories still work with a stale/invalid token.

    Fix: removed the duplicated "try without token" block after the second
    attempt — unreachable dead code, since the preceding try/except always
    either returns or re-raises.

    Raises:
        Exception: for any unrecoverable non-200 response.
    """
    headers = {"Content-Type": "application/json"}
    t = token or settings.GITEA_TOKEN
    client = _get_http_client()

    # First attempt: authenticated, when a token is available.
    if t:
        headers["Authorization"] = f"token {t}"
        try:
            response = await _request_with_retry(client, url, headers)
            if response.status_code == 200:
                return response.json()
            if response.status_code != 401:
                if response.status_code == 403:
                    raise Exception("Gitea API 403:请确认仓库权限/频率限制")
                raise Exception(f"Gitea API {response.status_code}: {url}")
            # 401 with a token: fall through and retry anonymously below.
            logger.info(
                f"[API] Gitea API 401 (Unauthorized) with token, retrying without token for: {url}"
            )
        except Exception as e:
            # Only 401-related failures may fall through to the anonymous retry.
            if "Gitea API 401" not in str(e) and "401" not in str(e):
                raise

    # Second attempt: anonymous request.
    headers.pop("Authorization", None)

    try:
        response = await _request_with_retry(client, url, headers)
        if response.status_code == 200:
            return response.json()
        if response.status_code == 401:
            raise Exception("Gitea API 401:请配置 GITEA_TOKEN 或确认仓库权限")
        if response.status_code == 403:
            raise Exception("Gitea API 403:请确认仓库权限/频率限制")
        raise Exception(f"Gitea API {response.status_code}: {url}")
    except Exception as e:
        logger.error(f"[API] Gitea API 调用失败: {url}, 错误: {e}")
        raise
||
async def gitlab_api(url: str, token: str = None) -> Any:
    """Call the GitLab API and return the decoded JSON payload.

    An authenticated request (PRIVATE-TOKEN header) is attempted first,
    using the explicit *token* or the configured GITLAB_TOKEN. On a 401 the
    call is retried anonymously, which keeps public projects reachable with
    a stale/invalid token.

    Fix: removed the duplicated "try without token" block after the second
    attempt — unreachable dead code, since the preceding try/except always
    either returns or re-raises.

    Raises:
        Exception: for any unrecoverable non-200 response.
    """
    headers = {"Content-Type": "application/json"}
    t = token or settings.GITLAB_TOKEN
    client = _get_http_client()

    # First attempt: authenticated, when a token is available.
    if t:
        headers["PRIVATE-TOKEN"] = t
        try:
            response = await _request_with_retry(client, url, headers)
            if response.status_code == 200:
                return response.json()
            if response.status_code != 401:
                if response.status_code == 403:
                    raise Exception("GitLab API 403:请确认仓库权限/频率限制")
                raise Exception(f"GitLab API {response.status_code}: {url}")
            # 401 with a token: fall through and retry anonymously below.
            logger.info(
                f"[API] GitLab API 401 (Unauthorized) with token, retrying without token for: {url}"
            )
        except Exception as e:
            # Only 401-related failures may fall through to the anonymous retry.
            if "GitLab API 401" not in str(e) and "401" not in str(e):
                raise

    # Second attempt: anonymous request.
    headers.pop("PRIVATE-TOKEN", None)

    try:
        response = await _request_with_retry(client, url, headers)
        if response.status_code == 200:
            return response.json()
        if response.status_code == 401:
            raise Exception("GitLab API 401:请配置 GITLAB_TOKEN 或确认仓库权限")
        if response.status_code == 403:
            raise Exception("GitLab API 403:请确认仓库权限/频率限制")
        raise Exception(f"GitLab API {response.status_code}: {url}")
    except Exception as e:
        logger.error(f"[API] GitLab API 调用失败: {url}, 错误: {e}")
        raise
||
async def fetch_file_content(url: str, headers: Dict[str, str] = None) -> Optional[str]:
    """Download a raw file and return its text, or None on any failure.

    If an authenticated request is rejected (401/403), the download is
    retried once without credentials, which covers public repositories
    when the configured token is stale. Errors are logged and swallowed by
    design — callers treat None as "content unavailable" (best effort).
    """
    client = _get_http_client()
    req_headers = headers or {}
    try:
        response = await _request_with_retry(client, url, req_headers)
        if response.status_code == 200:
            return response.text

        # A tokened request that came back 401/403 is retried anonymously
        # (works for public repositories).
        if response.status_code in (401, 403) and headers:
            logger.info(
                f"[API] 获取文件内容返回 {response.status_code},尝试不带 Token 重试: {url}"
            )
            response = await _request_with_retry(client, url, {})
            if response.status_code == 200:
                return response.text

    except Exception as e:
        logger.error(f"获取文件内容失败: {url}, 错误: {e}")
    # Any non-200 outcome or exception collapses to None.
    return None
||
async def get_github_branches(repo_url: str, token: str = None) -> List[str]:
    """Return the branch names of a GitHub repository (first 100)."""
    info = parse_repository_url(repo_url, "github")
    endpoint = (
        f"https://api.github.com/repos/{info['owner']}/{info['repo']}/branches?per_page=100"
    )
    payload = await github_api(endpoint, token)

    if not isinstance(payload, list):
        print(f"[Branch] 警告: 获取 GitHub 分支列表返回非列表数据: {payload}")
        return []

    # Keep only well-formed branch entries.
    names: List[str] = []
    for entry in payload:
        if isinstance(entry, dict) and "name" in entry:
            names.append(entry["name"])
    return names
||
async def get_gitea_branches(repo_url: str, token: str = None) -> List[str]:
    """Return the branch names of a Gitea repository."""
    info = parse_repository_url(repo_url, "gitea")
    # info["base_url"] already points at {base}/api/v1.
    endpoint = f"{info['base_url']}/repos/{info['owner']}/{info['repo']}/branches"
    payload = await gitea_api(endpoint, token)

    if not isinstance(payload, list):
        print(f"[Branch] 警告: 获取 Gitea 分支列表返回非列表数据: {payload}")
        return []

    # Keep only well-formed branch entries.
    names: List[str] = []
    for entry in payload:
        if isinstance(entry, dict) and "name" in entry:
            names.append(entry["name"])
    return names
||
async def get_gitlab_branches(repo_url: str, token: str = None) -> List[str]:
    """Return the branch names of a GitLab repository (first 100).

    Credentials embedded in the URL ("oauth2:<token>@host" or
    "<token>@host") take precedence over the *token* argument.
    """
    parsed = urlparse(repo_url)

    effective_token = token
    if parsed.username:
        if parsed.username == "oauth2" and parsed.password:
            effective_token = parsed.password
        elif not parsed.password:
            effective_token = parsed.username

    info = parse_repository_url(repo_url, "gitlab")
    encoded_project = quote(info["project_path"], safe="")
    endpoint = f"{info['base_url']}/projects/{encoded_project}/repository/branches?per_page=100"
    payload = await gitlab_api(endpoint, effective_token)

    if not isinstance(payload, list):
        print(f"[Branch] 警告: 获取 GitLab 分支列表返回非列表数据: {payload}")
        return []

    # Keep only well-formed branch entries.
    names: List[str] = []
    for entry in payload:
        if isinstance(entry, dict) and "name" in entry:
            names.append(entry["name"])
    return names
||
async def get_github_files(
    repo_url: str, branch: str, token: str = None, exclude_patterns: List[str] = None
) -> List[Dict[str, str]]:
    """Return the analyzable text files of a GitHub repository branch.

    Each entry carries the repo-relative path and a raw.githubusercontent.com
    download URL. Only text blobs that pass the exclusion filter and the
    configured size cap are included.
    """
    info = parse_repository_url(repo_url, "github")
    owner, repo = info["owner"], info["repo"]

    # Recursive tree listing for the branch.
    tree_url = f"https://api.github.com/repos/{owner}/{repo}/git/trees/{quote(branch)}?recursive=1"
    tree = await github_api(tree_url, token)

    def _wanted(entry) -> bool:
        # Regular text blob, not excluded, and within the size limit.
        return (
            entry.get("type") == "blob"
            and is_text_file(entry["path"])
            and not should_exclude(entry["path"], exclude_patterns)
            and entry.get("size", 0) <= settings.MAX_FILE_SIZE_BYTES
        )

    return [
        {
            "path": entry["path"],
            "url": f"https://raw.githubusercontent.com/{owner}/{repo}/{quote(branch)}/{entry['path']}",
        }
        for entry in tree.get("tree", [])
        if _wanted(entry)
    ]
||
async def get_gitlab_files(
    repo_url: str, branch: str, token: str = None, exclude_patterns: List[str] = None
) -> List[Dict[str, str]]:
    """Return the analyzable text files of a GitLab repository branch.

    Each entry carries the path, a raw-file API URL, and the token to use
    for that download. Credentials embedded in the URL ("oauth2:<token>@host"
    or "<token>@host") take precedence over the *token* argument.
    """
    parsed = urlparse(repo_url)

    # Extract a token embedded in the URL, if any.
    effective_token = token
    if parsed.username:
        if parsed.username == "oauth2" and parsed.password:
            effective_token = parsed.password
        elif not parsed.password:
            effective_token = parsed.username

    info = parse_repository_url(repo_url, "gitlab")
    api_base = info["base_url"]  # {base}/api/v4
    encoded_project = quote(info["project_path"], safe="")

    # Recursive repository tree for the requested branch.
    tree_url = f"{api_base}/projects/{encoded_project}/repository/tree?ref={quote(branch)}&recursive=true&per_page=100"
    entries = await gitlab_api(tree_url, effective_token)

    def _wanted(entry) -> bool:
        # Regular text blob that survives the exclusion filter.
        return (
            entry.get("type") == "blob"
            and is_text_file(entry["path"])
            and not should_exclude(entry["path"], exclude_patterns)
        )

    return [
        {
            "path": entry["path"],
            "url": f"{api_base}/projects/{encoded_project}/repository/files/{quote(entry['path'], safe='')}/raw?ref={quote(branch)}",
            "token": effective_token,
        }
        for entry in entries
        if _wanted(entry)
    ]
||
async def get_gitea_files(
    repo_url: str, branch: str, token: str = None, exclude_patterns: List[str] = None
) -> List[Dict[str, str]]:
    """Return the analyzable text files of a Gitea repository branch.

    Uses the Gitea tree API (the branch name is accepted in place of the
    tree sha) and points each entry at the raw-content endpoint.
    """
    info = parse_repository_url(repo_url, "gitea")
    api_base = info["base_url"]
    owner, repo = info["owner"], info["repo"]

    # GET /repos/{owner}/{repo}/git/trees/{sha}?recursive=true — a branch
    # name works directly as the sha.
    tree_url = (
        f"{api_base}/repos/{quote(owner)}/{quote(repo)}/git/trees/{quote(branch)}?recursive=true"
    )
    tree = await gitea_api(tree_url, token)

    collected: List[Dict[str, str]] = []
    for entry in tree.get("tree", []):
        # Gitea marks regular files with type == "blob".
        if entry.get("type") != "blob":
            continue
        if not is_text_file(entry["path"]) or should_exclude(entry["path"], exclude_patterns):
            continue
        # Raw endpoint: GET /repos/{owner}/{repo}/raw/{filepath}?ref={branch}
        collected.append(
            {
                "path": entry["path"],
                "url": f"{api_base}/repos/{owner}/{repo}/raw/{quote(entry['path'])}?ref={quote(branch)}",
                "token": token,  # forwarded so fetch_file_content can authenticate
            }
        )

    return collected
|
||
async def scan_repo_task(task_id: str, db_session_factory, user_config: Optional[dict] = None):
    """
    Background repository scan task.

    Resolves the target project, fetches the branch's file list (via SSH or
    the GitHub/GitLab/Gitea APIs with branch fallback), analyzes each file
    through the LLM service under a concurrency limit, persists discovered
    issues as AuditIssue rows, and finally records the AuditTask outcome.

    Args:
        task_id: Primary key of the AuditTask to run.
        db_session_factory: Async DB session factory (used as a context manager).
        user_config: Per-user configuration dict (contains llmConfig and
            otherConfig); optional.
    """
    async with db_session_factory() as db:
        task = await db.get(AuditTask, task_id)
        if not task:
            # Task row no longer exists; nothing to do.
            return

        try:
            # 1. Mark the task as running.
            task.status = "running"
            task.started_at = datetime.now(timezone.utc)
            await db.commit()

            # LLM service instance bound to the user's configuration.
            llm_service = LLMService(user_config=user_config or {})

            # 2. Load the owning project.
            project = await db.get(Project, task.project_id)
            if not project:
                raise Exception("项目不存在")

            # Only repository-type projects are handled here; ZIP uploads
            # have their own scan endpoint.
            source_type = getattr(project, "source_type", "repository")
            if source_type == "zip":
                raise Exception("ZIP类型项目请使用ZIP上传扫描接口")

            if not project.repository_url:
                raise Exception("仓库地址不存在")

            repo_url = project.repository_url
            branch = task.branch_name or project.default_branch or "main"
            repo_type = project.repository_type or "other"

            # Parse the task's exclude patterns (stored as a JSON string).
            import json as json_module

            task_exclude_patterns = []
            if task.exclude_patterns:
                try:
                    task_exclude_patterns = json_module.loads(task.exclude_patterns)
                except:  # NOTE(review): bare except — malformed JSON is silently ignored (best effort)
                    pass

            print(
                f"🚀 开始扫描仓库: {repo_url}, 分支: {branch}, 类型: {repo_type}, 来源: {source_type}"
            )
            if task_exclude_patterns:
                print(f"📋 排除模式: {task_exclude_patterns}")

            # 3. Build the file list.
            # Git tokens always come from the system defaults (.env) — by design.
            github_token = settings.GITHUB_TOKEN
            gitlab_token = settings.GITLAB_TOKEN
            gitea_token = settings.GITEA_TOKEN

            # SSH private key, if the user configured one (stored encrypted).
            user_other_config = user_config.get("otherConfig", {}) if user_config else {}
            ssh_private_key = None
            if "sshPrivateKey" in user_other_config:
                from app.core.encryption import decrypt_sensitive_data

                ssh_private_key = decrypt_sensitive_data(user_other_config["sshPrivateKey"])

            files: List[Dict[str, str]] = []
            extracted_gitlab_token = None

            # SSH URLs are served by cloning rather than the HTTP APIs.
            from app.services.git_ssh_service import GitSSHOperations

            is_ssh_url = GitSSHOperations.is_ssh_url(repo_url)

            if is_ssh_url:
                # SSH path: file contents are fetched up front.
                if not ssh_private_key:
                    raise Exception("仓库使用SSH URL,但未配置SSH密钥。请先生成并配置SSH密钥。")

                print(f"🔐 使用SSH方式访问仓库: {repo_url}")
                try:
                    files_with_content = GitSSHOperations.get_repo_files_via_ssh(
                        repo_url, ssh_private_key, branch, task_exclude_patterns
                    )
                    # Normalize to the common {"path", "content"} shape.
                    files = [
                        {"path": f["path"], "content": f["content"]} for f in files_with_content
                    ]
                    actual_branch = branch
                    print(f"✅ 通过SSH成功获取 {len(files)} 个文件")
                except Exception as e:
                    raise Exception(f"SSH方式获取仓库文件失败: {str(e)}")
            else:
                # API path (original logic). Branch fallback order:
                # requested branch -> project default -> "main" -> "master".
                branches_to_try = [branch]
                if project.default_branch and project.default_branch != branch:
                    branches_to_try.append(project.default_branch)
                for common_branch in ["main", "master"]:
                    if common_branch not in branches_to_try:
                        branches_to_try.append(common_branch)

                actual_branch = branch  # the branch that actually worked
                last_error = None

                for try_branch in branches_to_try:
                    try:
                        print(f"🔄 尝试获取分支 {try_branch} 的文件列表...")
                        if repo_type == "github":
                            files = await get_github_files(
                                repo_url, try_branch, github_token, task_exclude_patterns
                            )
                        elif repo_type == "gitlab":
                            files = await get_gitlab_files(
                                repo_url, try_branch, gitlab_token, task_exclude_patterns
                            )
                            # GitLab entries may carry a URL-embedded token.
                            if files and "token" in files[0]:
                                extracted_gitlab_token = files[0].get("token")
                        elif repo_type == "gitea":
                            files = await get_gitea_files(
                                repo_url, try_branch, gitea_token, task_exclude_patterns
                            )
                        else:
                            raise Exception("不支持的仓库类型,仅支持 GitHub, GitLab 和 Gitea 仓库")

                        if files:
                            actual_branch = try_branch
                            if try_branch != branch:
                                print(
                                    f"⚠️ 分支 {branch} 不存在或无法访问,已降级到分支 {try_branch}"
                                )
                            break
                    except Exception as e:
                        last_error = str(e)
                        print(f"⚠️ 获取分支 {try_branch} 失败: {last_error[:100]}")
                        continue

                if not files:
                    # Map the last error onto a friendlier user-facing message.
                    error_msg = f"无法获取仓库文件,所有分支尝试均失败"
                    if last_error:
                        if "404" in last_error or "Not Found" in last_error:
                            error_msg = f"仓库或分支不存在: {branch}"
                        elif "401" in last_error or "403" in last_error:
                            error_msg = "无访问权限,请检查 Token 配置"
                        else:
                            error_msg = f"获取文件失败: {last_error[:100]}"
                    raise Exception(error_msg)

                print(f"✅ 成功获取分支 {actual_branch} 的文件列表")

            # Analysis parameters (user overrides win over system settings).
            analysis_config = get_analysis_config(user_config)
            max_analyze_files = analysis_config["max_analyze_files"]
            analysis_concurrency = analysis_config["llm_concurrency"]  # concurrency level
            llm_gap_ms = analysis_config["llm_gap_ms"]

            # Limit which files are analyzed: either an explicit selection
            # from scan_config.file_paths, or the configured maximum count.
            target_files = (user_config or {}).get("scan_config", {}).get("file_paths", [])
            if target_files:
                print(f"🎯 指定分析 {len(target_files)} 个文件")
                files = [f for f in files if f["path"] in target_files]
            elif max_analyze_files > 0:
                files = files[:max_analyze_files]

            task.total_files = len(files)
            await db.commit()

            print(
                f"📊 获取到 {len(files)} 个文件,开始分析 (最大文件数: {max_analyze_files}, 请求间隔: {llm_gap_ms}ms)"
            )

            # 4. Per-file analysis bookkeeping.
            total_issues = 0
            total_lines = 0
            quality_scores = []
            scanned_files = 0
            failed_files = 0
            skipped_files = 0  # skipped files (empty, too large, ...)
            consecutive_failures = 0
            MAX_CONSECUTIVE_FAILURES = 5  # abort after this many failures in a row
            last_error = None

            # 4. Analyze files concurrently, bounded by a semaphore.
            print(f"🧬 启动并行分析: {len(files)} 个文件, 并发数: {analysis_concurrency}")

            semaphore = asyncio.Semaphore(analysis_concurrency)

            async def analyze_single_file(file_info):
                """Analyze one file; returns a tagged result dict, or None when cancelled."""
                nonlocal consecutive_failures, last_error

                async with semaphore:
                    if task_control.is_cancelled(task_id):
                        return None

                    f_path = file_info["path"]
                    MAX_RETRIES = 3
                    for attempt in range(MAX_RETRIES):
                        try:
                            # 4.1 Fetch the content (first attempt only; retries reuse it).
                            if attempt == 0:
                                if is_ssh_url:
                                    # SSH entries already carry their content.
                                    content = file_info.get("content", "")
                                else:
                                    # Build the auth headers for the raw download.
                                    headers = {}
                                    if repo_type == "gitlab":
                                        token_to_use = file_info.get("token") or gitlab_token
                                        if token_to_use:
                                            headers["PRIVATE-TOKEN"] = token_to_use
                                    elif repo_type == "gitea":
                                        token_to_use = file_info.get("token") or gitea_token
                                        if token_to_use:
                                            headers["Authorization"] = f"token {token_to_use}"
                                    elif repo_type == "github" and github_token:
                                        headers["Authorization"] = f"Bearer {github_token}"

                                    content = await fetch_file_content(file_info["url"], headers)

                            if not content or not content.strip():
                                return {"type": "skip", "reason": "empty", "path": f_path}

                            if len(content) > settings.MAX_FILE_SIZE_BYTES:
                                return {"type": "skip", "reason": "too_large", "path": f_path}

                            if task_control.is_cancelled(task_id):
                                return None

                            # 4.2 LLM analysis (rule-set / prompt-template aware).
                            language = get_language_from_path(f_path)
                            scan_config = (user_config or {}).get("scan_config", {})
                            rule_set_id = scan_config.get("rule_set_id")
                            prompt_template_id = scan_config.get("prompt_template_id")

                            if rule_set_id or prompt_template_id:
                                analysis_result = await llm_service.analyze_code_with_rules(
                                    content,
                                    language,
                                    rule_set_id=rule_set_id,
                                    prompt_template_id=prompt_template_id,
                                    db_session=None,
                                )
                            else:
                                analysis_result = await llm_service.analyze_code(content, language)

                            return {
                                "type": "success",
                                "path": f_path,
                                "content": content,
                                "language": language,
                                "analysis": analysis_result,
                            }
                        except asyncio.CancelledError:
                            # The task was cancelled; do not retry.
                            return None
                        except Exception as e:
                            if attempt < MAX_RETRIES - 1:
                                wait_time = (attempt + 1) * 2  # linear backoff
                                # Surface rate-limit errors with a dedicated hint.
                                error_str = str(e)
                                if (
                                    "429" in error_str
                                    or "rate limit" in error_str.lower()
                                    or "额度不足" in error_str
                                ):
                                    print(
                                        f"🚫 [限流提示] 仓库扫描任务触发 LLM 频率限制 (429),建议在设置中降低并发数或增加请求间隔。文件: {f_path}"
                                    )

                                print(
                                    f"⚠️ 分析文件失败 ({f_path}), 正在进行第 {attempt + 1} 次重试... 错误: {e}"
                                )
                                await asyncio.sleep(wait_time)
                                continue
                            else:
                                print(f"❌ 分析文件最终失败 ({f_path}): {e}")
                                last_error = str(e)
                                return {"type": "error", "path": f_path, "error": str(e)}

            # Create all analysis tasks up front so they can be tracked and cancelled.
            task_objects = [asyncio.create_task(analyze_single_file(f)) for f in files]

            try:
                # Consume results as they complete: keeps progress updates live
                # and keeps all DB writes on this coroutine's session.
                for future in asyncio.as_completed(task_objects):
                    if task_control.is_cancelled(task_id):
                        # Stop consuming further results.
                        print(f"🛑 任务 {task_id} 检测到取消信号,停止主循环")
                        break

                    try:
                        res = await future
                    except asyncio.CancelledError:
                        continue

                    if not res:
                        continue

                    if res["type"] == "skip":
                        skipped_files += 1
                        task.total_files = max(0, task.total_files - 1)
                    elif res["type"] == "error":
                        failed_files += 1
                        consecutive_failures += 1
                    elif res["type"] == "success":
                        consecutive_failures = 0
                        scanned_files += 1

                        f_path = res["path"]
                        analysis = res["analysis"]
                        file_lines = res["content"].split("\n")
                        total_lines += len(file_lines)

                        # Persist every reported issue.
                        issues = analysis.get("issues", [])
                        for issue in issues:
                            try:
                                # Defensive: LLM output may contain non-dict entries.
                                if not isinstance(issue, dict):
                                    print(
                                        f"⚠️ 警告: 任务 {task_id} 中文件 {f_path} 的分析结果包含无效的问题格式: {issue}"
                                    )
                                    continue

                                # Helper: strip characters PostgreSQL cannot store.
                                def sanitize_for_db(text):
                                    if text is None:
                                        return None
                                    if not isinstance(text, str):
                                        text = str(text)
                                    # Remove NULL bytes (unsupported by PostgreSQL).
                                    text = text.replace("\x00", "")
                                    return text

                                line_num = issue.get("line", 1)
                                code_snippet = sanitize_for_db(issue.get("code_snippet"))
                                # Fall back to a 5-line context window around the
                                # reported line when no usable snippet was returned.
                                if not code_snippet or len(code_snippet.strip()) < 5:
                                    try:
                                        idx = max(0, int(line_num) - 1)
                                        start = max(0, idx - 2)
                                        end = min(len(file_lines), idx + 3)
                                        code_snippet = "\n".join(file_lines[start:end])
                                    except Exception:
                                        code_snippet = ""

                                audit_issue = AuditIssue(
                                    task_id=task.id,
                                    file_path=f_path,
                                    line_number=line_num,
                                    column_number=issue.get("column"),
                                    issue_type=issue.get("type", "maintainability"),
                                    severity=issue.get("severity", "low"),
                                    title=sanitize_for_db(issue.get("title", "Issue")),
                                    message=sanitize_for_db(
                                        issue.get("description") or issue.get("title", "Issue")
                                    ),
                                    suggestion=sanitize_for_db(issue.get("suggestion")),
                                    code_snippet=code_snippet,
                                    ai_explanation=sanitize_for_db(issue.get("ai_explanation")),
                                    status="open",
                                )
                                db.add(audit_issue)
                                total_issues += 1
                            except Exception as e:
                                print(f"⚠️ 处理单个问题时出错 (文件 {f_path}): {e}")
                                continue

                        if "quality_score" in analysis:
                            try:
                                quality_score = float(analysis["quality_score"])
                                quality_scores.append(quality_score)
                            except (ValueError, TypeError):
                                pass

                    # Update overall task progress.
                    processed_count = scanned_files + failed_files
                    task.scanned_files = processed_count
                    task.total_lines = total_lines
                    task.issues_count = total_issues
                    await db.commit()  # commits run sequentially in this coroutine, so this is safe

                    if processed_count % 10 == 0 or processed_count == len(files):
                        print(
                            f"📈 任务 {task_id}: 进度 {processed_count}/{len(files)} ({int(processed_count / len(files) * 100) if len(files) > 0 else 0}%)"
                        )

                    if consecutive_failures >= MAX_CONSECUTIVE_FAILURES:
                        print(f"❌ 任务 {task_id}: 连续失败 {consecutive_failures} 次,停止分析")
                        break
            finally:
                # Whatever happened (normal end, early break, or exception),
                # cancel every analysis task that is still pending or running.
                pending_count = 0
                for t in task_objects:
                    if not t.done():
                        t.cancel()
                        pending_count += 1

                if pending_count > 0:
                    print(f"🧹 任务 {task_id}: 已清理 {pending_count} 个后台待处理或执行中的任务")
                    # Give cancellation a moment to finish, without blocking for long.
                    await asyncio.gather(*task_objects, return_exceptions=True)

            # 5. Finalize the task.
            avg_quality_score = (
                sum(quality_scores) / len(quality_scores) if quality_scores else 100.0
            )

            # Outcome classification:
            # every file skipped (all empty, etc.) -> completed with a notice,
            if len(files) > 0 and scanned_files == 0 and skipped_files == len(files):
                task.status = "completed"
                task.completed_at = datetime.now(timezone.utc)
                task.scanned_files = 0
                task.total_lines = 0
                task.issues_count = 0
                task.quality_score = 100.0
                await db.commit()
                print(f"⚠️ 任务 {task_id} 完成: 所有 {len(files)} 个文件均为空或被跳过,无需分析")
            # there were analyzable files but every one failed (LLM errors) -> failed,
            elif len(files) > 0 and scanned_files == 0 and failed_files > 0:
                task.status = "failed"
                task.completed_at = datetime.now(timezone.utc)
                task.scanned_files = 0
                task.total_lines = total_lines
                task.issues_count = 0
                task.quality_score = 0

                # Include the most recent error to help diagnose the LLM setup.
                error_msg = f"{failed_files} 个文件分析失败,请检查 LLM API 配置。最近一个错误: {str(last_error) if 'last_error' in locals() else '未知错误'}"
                task.error_message = error_msg
                await db.commit()
                print(f"❌ 任务 {task_id} 失败: {error_msg}")
            # otherwise -> completed with aggregated metrics.
            else:
                task.status = "completed"
                task.completed_at = datetime.now(timezone.utc)
                # The reported count is the number of successfully analyzed files.
                task.scanned_files = scanned_files
                task.total_lines = total_lines
                task.issues_count = total_issues
                task.quality_score = avg_quality_score
                await db.commit()

                result_msg = f"✅ 任务 {task_id} 完成: 成功分析 {scanned_files} 个文件"
                if failed_files > 0:
                    result_msg += f", {failed_files} 个文件失败"
                result_msg += f", 发现 {total_issues} 个问题, 质量分 {avg_quality_score:.1f}"
                print(result_msg)
                # NOTE(review): the skipped/failed branches above do not clear
                # the cancel flag — verify whether that is intentional.
                task_control.cleanup_task(task_id)

        except Exception as e:
            # Any unhandled error marks the task as failed and clears the flag.
            print(f"❌ 扫描失败: {e}")
            task.status = "failed"
            task.completed_at = datetime.now(timezone.utc)
            await db.commit()
            task_control.cleanup_task(task_id)