From 5491714cce62045ee5ce3eecc002c27256a2bd46 Mon Sep 17 00:00:00 2001 From: vinland100 Date: Fri, 6 Feb 2026 14:12:51 +0800 Subject: [PATCH] feat: Implement shared HTTP client with retry logic and optimized connection settings for GitHub/Gitea API calls to improve reliability and performance. --- backend/app/services/scanner.py | 683 +++++++++++++++++++++----------- 1 file changed, 448 insertions(+), 235 deletions(-) diff --git a/backend/app/services/scanner.py b/backend/app/services/scanner.py index 3a5791b..9409add 100644 --- a/backend/app/services/scanner.py +++ b/backend/app/services/scanner.py @@ -4,6 +4,7 @@ import os import asyncio +import logging import httpx from typing import List, Dict, Any, Optional from datetime import datetime, timezone @@ -15,7 +16,65 @@ from app.models.audit import AuditTask, AuditIssue from app.models.project import Project from app.services.llm.service import LLMService from app.core.config import settings -from app.core.file_filter import is_text_file as core_is_text_file, should_exclude as core_should_exclude, TEXT_EXTENSIONS as CORE_TEXT_EXTENSIONS +from app.core.file_filter import ( + is_text_file as core_is_text_file, + should_exclude as core_should_exclude, + TEXT_EXTENSIONS as CORE_TEXT_EXTENSIONS, +) + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# 共享 HTTP 客户端(模块级别单例) +# 通过复用连接池和 DNS 缓存,避免每次请求都重新建立 TCP 连接, +# 解决"每天第一次打开时因 DNS/连接冷启动导致超时"的问题。 +# --------------------------------------------------------------------------- +_GIT_API_TIMEOUT = httpx.Timeout( + connect=15.0, # 连接超时 15 秒(原来 120 秒太长,用户体验差) + read=60.0, # 读取超时 60 秒(大仓库 tree API 可能较慢) + write=30.0, + pool=15.0, # 从连接池获取连接的超时 +) + +_http_client: Optional[httpx.AsyncClient] = None + + +def _get_http_client() -> httpx.AsyncClient: + """获取或创建共享的 HTTP 客户端(惰性初始化)""" + global _http_client + if _http_client is None or _http_client.is_closed: + _http_client = httpx.AsyncClient( + timeout=_GIT_API_TIMEOUT, + limits=httpx.Limits( + max_connections=20, + max_keepalive_connections=10, + keepalive_expiry=300, # 保持连接 5 分钟 + ), + follow_redirects=True, + ) + return _http_client + + +async def _request_with_retry( + client: httpx.AsyncClient, + url: str, + headers: Dict[str, str], + max_retries: int = 3, +) -> httpx.Response: + """带自动重试的 HTTP GET 请求,针对连接超时进行重试""" + last_exc: Optional[Exception] = None + for attempt in range(max_retries): + try: + return await client.get(url, headers=headers) + except (httpx.ConnectTimeout, httpx.ConnectError) as e: + last_exc = e + if attempt < max_retries - 1: + wait = (attempt + 1) * 2 + logger.info(f"[API] 连接失败 (第 {attempt + 1} 次), {wait}s 后重试: {url} - {e}") + await asyncio.sleep(wait) + else: + logger.error(f"[API] 连接最终失败 (共 {max_retries} 次): {url} - {e}") + raise last_exc # type: ignore[misc] def get_analysis_config(user_config: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: @@ -28,18 +87,19 @@ def get_analysis_config(user_config: Optional[Dict[str, Any]] = None) -> Dict[st - llm_concurrency: LLM 并发数 - llm_gap_ms: LLM 请求间隔(毫秒) """ - other_config = (user_config or {}).get('otherConfig', {}) + other_config = (user_config or {}).get("otherConfig", {}) return { - 'max_analyze_files': other_config.get('maxAnalyzeFiles') or settings.MAX_ANALYZE_FILES, - 'llm_concurrency': other_config.get('llmConcurrency') or settings.LLM_CONCURRENCY, - 'llm_gap_ms': other_config.get('llmGapMs') or settings.LLM_GAP_MS, + "max_analyze_files": other_config.get("maxAnalyzeFiles") or settings.MAX_ANALYZE_FILES, + "llm_concurrency": other_config.get("llmConcurrency") or settings.LLM_CONCURRENCY, + "llm_gap_ms": other_config.get("llmGapMs") or settings.LLM_GAP_MS, } # 支持的文本文件扩展名使用全局定义 TEXT_EXTENSIONS = list(CORE_TEXT_EXTENSIONS) + def is_text_file(path: str) -> bool: """检查是否为文本文件""" return core_is_text_file(path) @@ -53,47 +113,85 @@ def should_exclude(path: str, exclude_patterns: List[str] = None) -> bool: def get_language_from_path(path: str) -> str: """从文件路径获取语言类型""" - ext = path.split('.')[-1].lower() if '.' in path else '' + ext = path.split(".")[-1].lower() if "." in path else "" language_map = { - 'py': 'python', - 'js': 'javascript', 'jsx': 'javascript', - 'ts': 'typescript', 'tsx': 'typescript', - 'java': 'java', 'go': 'go', 'rs': 'rust', - 'cpp': 'cpp', 'c': 'c', 'cc': 'cpp', 'h': 'c', 'hh': 'cpp', - 'hpp': 'cpp', 'hxx': 'cpp', - 'cs': 'csharp', 'php': 'php', 'rb': 'ruby', - 'kt': 'kotlin', 'ktm': 'kotlin', 'kts': 'kotlin', - 'swift': 'swift', 'dart': 'dart', - 'scala': 'scala', 'sc': 'scala', - 'groovy': 'groovy', 'gsh': 'groovy', 'gvy': 'groovy', 'gy': 'groovy', - 'sql': 'sql', 'sh': 'bash', 'bash': 'bash', 'zsh': 'bash', - 'pl': 'perl', 'pm': 'perl', 't': 'perl', - 'lua': 'lua', 'hs': 'haskell', 'lhs': 'haskell', - 'clj': 'clojure', 'cljs': 'clojure', 'cljc': 'clojure', 'edn': 'clojure', - 'ex': 'elixir', 'exs': 'elixir', 'erl': 'erlang', 'hrl': 'erlang', - 'm': 'objective-c', 'mm': 'objective-c', - 'r': 'r', 'rmd': 'r', - 'vb': 'visual-basic', 'fs': 'fsharp', 'fsi': 'fsharp', 'fsx': 'fsharp', - 'tf': 'hcl', 'hcl': 'hcl', 'dockerfile': 'dockerfile' + "py": "python", + "js": "javascript", + "jsx": "javascript", + "ts": "typescript", + "tsx": "typescript", + "java": "java", + "go": "go", + "rs": "rust", + "cpp": "cpp", + "c": "c", + "cc": "cpp", + "h": "c", + "hh": "cpp", + "hpp": "cpp", + "hxx": "cpp", + "cs": "csharp", + "php": "php", + "rb": "ruby", + "kt": "kotlin", + "ktm": "kotlin", + "kts": "kotlin", + "swift": "swift", + "dart": "dart", + "scala": "scala", + "sc": "scala", + "groovy": "groovy", + "gsh": "groovy", + "gvy": "groovy", + "gy": "groovy", + "sql": "sql", + "sh": "bash", + "bash": "bash", + "zsh": "bash", + "pl": "perl", + "pm": "perl", + "t": "perl", + "lua": "lua", + "hs": "haskell", + "lhs": "haskell", + "clj": "clojure", + "cljs": "clojure", + "cljc": "clojure", + "edn": "clojure", + "ex": "elixir", + "exs": "elixir", + "erl": "erlang", + "hrl": "erlang", + "m": "objective-c", + "mm": "objective-c", + "r": "r", + "rmd": "r", + "vb": "visual-basic", + "fs": "fsharp", + "fsi": "fsharp", + "fsx": "fsharp", + "tf": "hcl", + "hcl": "hcl", + "dockerfile": "dockerfile", } - return language_map.get(ext, 'text') + return language_map.get(ext, "text") class TaskControlManager: """任务控制管理器 - 用于取消运行中的任务""" - + def __init__(self): self._cancelled_tasks: set = set() - + def cancel_task(self, task_id: str): """取消任务""" self._cancelled_tasks.add(task_id) print(f"🛑 任务 {task_id} 已标记为取消") - + def is_cancelled(self, task_id: str) -> bool: """检查任务是否被取消""" return task_id in self._cancelled_tasks - + def cleanup_task(self, task_id: str): """清理已完成任务的控制状态""" self._cancelled_tasks.discard(task_id) @@ -107,29 +205,47 @@ async def github_api(url: str, token: str = None) -> Any: """调用GitHub API""" headers = {"Accept": "application/vnd.github+json"} t = token or settings.GITHUB_TOKEN - - async with httpx.AsyncClient(timeout=30) as client: - # First try with token if available - if t: - headers["Authorization"] = f"Bearer {t}" - try: - response = await client.get(url, headers=headers) - if response.status_code == 200: - return response.json() - if response.status_code != 401: - if response.status_code == 403: - raise Exception("GitHub API 403:请配置 GITHUB_TOKEN 或确认仓库权限/频率限制") - raise Exception(f"GitHub API {response.status_code}: {url}") - # If 401, fall through to retry without token - print(f"[API] GitHub API 401 (Unauthorized) with token, retrying without token for: {url}") - except Exception as e: - if "GitHub API 401" not in str(e) and "401" not in str(e): - raise - + client = _get_http_client() + + # First try with token if available + if t: + headers["Authorization"] = f"Bearer {t}" + try: + response = await _request_with_retry(client, url, headers) + if response.status_code == 200: + return response.json() + if response.status_code != 401: + if response.status_code == 403: + raise Exception("GitHub API 403:请配置 GITHUB_TOKEN 或确认仓库权限/频率限制") + raise Exception(f"GitHub API {response.status_code}: {url}") + # If 401, fall through to retry without token + logger.info( + f"[API] GitHub API 401 (Unauthorized) with token, retrying without token for: {url}" + ) + except Exception as e: + if "GitHub API 401" not in str(e) and "401" not in str(e): + raise + + # Try without token + headers.pop("Authorization", None) + + try: + response = await _request_with_retry(client, url, headers) + if response.status_code == 200: + return response.json() + if response.status_code == 403: + raise Exception("GitHub API 403:请配置 GITHUB_TOKEN 或确认仓库权限/频率限制") + if response.status_code == 401: + raise Exception("GitHub API 401:请配置 GITHUB_TOKEN 或确认仓库权限") + raise Exception(f"GitHub API {response.status_code}: {url}") + except Exception as e: + logger.error(f"[API] GitHub API 调用失败: {url}, 错误: {e}") + raise + # Try without token if "Authorization" in headers: del headers["Authorization"] - + try: response = await client.get(url, headers=headers) if response.status_code == 200: @@ -144,34 +260,51 @@ async def github_api(url: str, token: str = None) -> Any: raise - async def gitea_api(url: str, token: str = None) -> Any: """调用Gitea API""" headers = {"Content-Type": "application/json"} t = token or settings.GITEA_TOKEN - - async with httpx.AsyncClient(timeout=120) as client: - # First try with token if available - if t: - headers["Authorization"] = f"token {t}" - try: - response = await client.get(url, headers=headers) - if response.status_code == 200: - return response.json() - if response.status_code != 401: - if response.status_code == 403: - raise Exception("Gitea API 403:请确认仓库权限/频率限制") - raise Exception(f"Gitea API {response.status_code}: {url}") - # If 401, fall through to retry without token - print(f"[API] Gitea API 401 (Unauthorized) with token, retrying without token for: {url}") - except Exception as e: - if "Gitea API 401" not in str(e) and "401" not in str(e): - raise - + client = _get_http_client() + + # First try with token if available + if t: + headers["Authorization"] = f"token {t}" + try: + response = await _request_with_retry(client, url, headers) + if response.status_code == 200: + return response.json() + if response.status_code != 401: + if response.status_code == 403: + raise Exception("Gitea API 403:请确认仓库权限/频率限制") + raise Exception(f"Gitea API {response.status_code}: {url}") + # If 401, fall through to retry without token + logger.info( + f"[API] Gitea API 401 (Unauthorized) with token, retrying without token for: {url}" + ) + except Exception as e: + if "Gitea API 401" not in str(e) and "401" not in str(e): + raise + + # Try without token + headers.pop("Authorization", None) + + try: + response = await _request_with_retry(client, url, headers) + if response.status_code == 200: + return response.json() + if response.status_code == 401: + raise Exception("Gitea API 401:请配置 GITEA_TOKEN 或确认仓库权限") + if response.status_code == 403: + raise Exception("Gitea API 403:请确认仓库权限/频率限制") + raise Exception(f"Gitea API {response.status_code}: {url}") + except Exception as e: + logger.error(f"[API] Gitea API 调用失败: {url}, 错误: {e}") + raise + # Try without token if "Authorization" in headers: del headers["Authorization"] - + try: response = await client.get(url, headers=headers) if response.status_code == 200: @@ -190,29 +323,47 @@ async def gitlab_api(url: str, token: str = None) -> Any: """调用GitLab API""" headers = {"Content-Type": "application/json"} t = token or settings.GITLAB_TOKEN - - async with httpx.AsyncClient(timeout=30) as client: - # First try with token if available - if t: - headers["PRIVATE-TOKEN"] = t - try: - response = await client.get(url, headers=headers) - if response.status_code == 200: - return response.json() - if response.status_code != 401: - if response.status_code == 403: - raise Exception("GitLab API 403:请确认仓库权限/频率限制") - raise Exception(f"GitLab API {response.status_code}: {url}") - # If 401, fall through to retry without token - print(f"[API] GitLab API 401 (Unauthorized) with token, retrying without token for: {url}") - except Exception as e: - if "GitLab API 401" not in str(e) and "401" not in str(e): - raise - + client = _get_http_client() + + # First try with token if available + if t: + headers["PRIVATE-TOKEN"] = t + try: + response = await _request_with_retry(client, url, headers) + if response.status_code == 200: + return response.json() + if response.status_code != 401: + if response.status_code == 403: + raise Exception("GitLab API 403:请确认仓库权限/频率限制") + raise Exception(f"GitLab API {response.status_code}: {url}") + # If 401, fall through to retry without token + logger.info( + f"[API] GitLab API 401 (Unauthorized) with token, retrying without token for: {url}" + ) + except Exception as e: + if "GitLab API 401" not in str(e) and "401" not in str(e): + raise + + # Try without token + headers.pop("PRIVATE-TOKEN", None) + + try: + response = await _request_with_retry(client, url, headers) + if response.status_code == 200: + return response.json() + if response.status_code == 401: + raise Exception("GitLab API 401:请配置 GITLAB_TOKEN 或确认仓库权限") + if response.status_code == 403: + raise Exception("GitLab API 403:请确认仓库权限/频率限制") + raise Exception(f"GitLab API {response.status_code}: {url}") + except Exception as e: + logger.error(f"[API] GitLab API 调用失败: {url}, 错误: {e}") + raise + # Try without token if "PRIVATE-TOKEN" in headers: del headers["PRIVATE-TOKEN"] - + try: response = await client.get(url, headers=headers) if response.status_code == 200: @@ -229,167 +380,194 @@ async def gitlab_api(url: str, token: str = None) -> Any: async def fetch_file_content(url: str, headers: Dict[str, str] = None) -> Optional[str]: """获取文件内容""" - async with httpx.AsyncClient(timeout=30) as client: - try: - response = await client.get(url, headers=headers or {}) + client = _get_http_client() + req_headers = headers or {} + try: + response = await _request_with_retry(client, url, req_headers) + if response.status_code == 200: + return response.text + + # 如果带 Token 请求失败(401/403),尝试不带 Token 请求(针对公开仓库) + if response.status_code in (401, 403) and headers: + logger.info( + f"[API] 获取文件内容返回 {response.status_code},尝试不带 Token 重试: {url}" + ) + response = await _request_with_retry(client, url, {}) if response.status_code == 200: return response.text - - # 如果带 Token 请求失败(401/403),尝试不带 Token 请求(针对公开仓库) - if response.status_code in (401, 403) and headers: - print(f"[API] 获取文件内容返回 {response.status_code},尝试不带 Token 重试: {url}") - response = await client.get(url) - if response.status_code == 200: - return response.text - - except Exception as e: - print(f"获取文件内容失败: {url}, 错误: {e}") + + except Exception as e: + logger.error(f"获取文件内容失败: {url}, 错误: {e}") return None async def get_github_branches(repo_url: str, token: str = None) -> List[str]: """获取GitHub仓库分支列表""" repo_info = parse_repository_url(repo_url, "github") - owner, repo = repo_info['owner'], repo_info['repo'] - + owner, repo = repo_info["owner"], repo_info["repo"] + branches_url = f"https://api.github.com/repos/{owner}/{repo}/branches?per_page=100" branches_data = await github_api(branches_url, token) - + if not isinstance(branches_data, list): print(f"[Branch] 警告: 获取 GitHub 分支列表返回非列表数据: {branches_data}") return [] - + return [b["name"] for b in branches_data if isinstance(b, dict) and "name" in b] - - - async def get_gitea_branches(repo_url: str, token: str = None) -> List[str]: """获取Gitea仓库分支列表""" repo_info = parse_repository_url(repo_url, "gitea") - base_url = repo_info['base_url'] # This is {base}/api/v1 - owner, repo = repo_info['owner'], repo_info['repo'] - + base_url = repo_info["base_url"] # This is {base}/api/v1 + owner, repo = repo_info["owner"], repo_info["repo"] + branches_url = f"{base_url}/repos/{owner}/{repo}/branches" branches_data = await gitea_api(branches_url, token) - + if not isinstance(branches_data, list): print(f"[Branch] 警告: 获取 Gitea 分支列表返回非列表数据: {branches_data}") return [] - + return [b["name"] for b in branches_data if isinstance(b, dict) and "name" in b] async def get_gitlab_branches(repo_url: str, token: str = None) -> List[str]: """获取GitLab仓库分支列表""" parsed = urlparse(repo_url) - + extracted_token = token if parsed.username: - if parsed.username == 'oauth2' and parsed.password: + if parsed.username == "oauth2" and parsed.password: extracted_token = parsed.password elif parsed.username and not parsed.password: extracted_token = parsed.username - + repo_info = parse_repository_url(repo_url, "gitlab") - base_url = repo_info['base_url'] - project_path = quote(repo_info['project_path'], safe='') - + base_url = repo_info["base_url"] + project_path = quote(repo_info["project_path"], safe="") + branches_url = f"{base_url}/projects/{project_path}/repository/branches?per_page=100" branches_data = await gitlab_api(branches_url, extracted_token) - + if not isinstance(branches_data, list): print(f"[Branch] 警告: 获取 GitLab 分支列表返回非列表数据: {branches_data}") return [] - + return [b["name"] for b in branches_data if isinstance(b, dict) and "name" in b] -async def get_github_files(repo_url: str, branch: str, token: str = None, exclude_patterns: List[str] = None) -> List[Dict[str, str]]: +async def get_github_files( + repo_url: str, branch: str, token: str = None, exclude_patterns: List[str] = None +) -> List[Dict[str, str]]: """获取GitHub仓库文件列表""" # 解析仓库URL repo_info = parse_repository_url(repo_url, "github") - owner, repo = repo_info['owner'], repo_info['repo'] - + owner, repo = repo_info["owner"], repo_info["repo"] + # 获取仓库文件树 tree_url = f"https://api.github.com/repos/{owner}/{repo}/git/trees/{quote(branch)}?recursive=1" tree_data = await github_api(tree_url, token) - + files = [] for item in tree_data.get("tree", []): - if item.get("type") == "blob" and is_text_file(item["path"]) and not should_exclude(item["path"], exclude_patterns): + if ( + item.get("type") == "blob" + and is_text_file(item["path"]) + and not should_exclude(item["path"], exclude_patterns) + ): size = item.get("size", 0) if size <= settings.MAX_FILE_SIZE_BYTES: - files.append({ - "path": item["path"], - "url": f"https://raw.githubusercontent.com/{owner}/{repo}/{quote(branch)}/{item['path']}" - }) - + files.append( + { + "path": item["path"], + "url": f"https://raw.githubusercontent.com/{owner}/{repo}/{quote(branch)}/{item['path']}", + } + ) + return files -async def get_gitlab_files(repo_url: str, branch: str, token: str = None, exclude_patterns: List[str] = None) -> List[Dict[str, str]]: +async def get_gitlab_files( + repo_url: str, branch: str, token: str = None, exclude_patterns: List[str] = None +) -> List[Dict[str, str]]: """获取GitLab仓库文件列表""" parsed = urlparse(repo_url) - + # 从URL中提取token(如果存在) extracted_token = token if parsed.username: - if parsed.username == 'oauth2' and parsed.password: + if parsed.username == "oauth2" and parsed.password: extracted_token = parsed.password elif parsed.username and not parsed.password: extracted_token = parsed.username - + # 解析项目路径 repo_info = parse_repository_url(repo_url, "gitlab") - base_url = repo_info['base_url'] # {base}/api/v4 - project_path = quote(repo_info['project_path'], safe='') - + base_url = repo_info["base_url"] # {base}/api/v4 + project_path = quote(repo_info["project_path"], safe="") + # 获取仓库文件树 tree_url = f"{base_url}/projects/{project_path}/repository/tree?ref={quote(branch)}&recursive=true&per_page=100" tree_data = await gitlab_api(tree_url, extracted_token) - + files = [] for item in tree_data: - if item.get("type") == "blob" and is_text_file(item["path"]) and not should_exclude(item["path"], exclude_patterns): - files.append({ - "path": item["path"], - "url": f"{base_url}/projects/{project_path}/repository/files/{quote(item['path'], safe='')}/raw?ref={quote(branch)}", - "token": extracted_token - }) - + if ( + item.get("type") == "blob" + and is_text_file(item["path"]) + and not should_exclude(item["path"], exclude_patterns) + ): + files.append( + { + "path": item["path"], + "url": f"{base_url}/projects/{project_path}/repository/files/{quote(item['path'], safe='')}/raw?ref={quote(branch)}", + "token": extracted_token, + } + ) + return files - -async def get_gitea_files(repo_url: str, branch: str, token: str = None, exclude_patterns: List[str] = None) -> List[Dict[str, str]]: +async def get_gitea_files( + repo_url: str, branch: str, token: str = None, exclude_patterns: List[str] = None +) -> List[Dict[str, str]]: """获取Gitea仓库文件列表""" repo_info = parse_repository_url(repo_url, "gitea") - base_url = repo_info['base_url'] - owner, repo = repo_info['owner'], repo_info['repo'] - + base_url = repo_info["base_url"] + owner, repo = repo_info["owner"], repo_info["repo"] + # Gitea tree API: GET /repos/{owner}/{repo}/git/trees/{sha}?recursive=true # 可以直接使用分支名作为sha - tree_url = f"{base_url}/repos/{quote(owner)}/{quote(repo)}/git/trees/{quote(branch)}?recursive=true" + tree_url = ( + f"{base_url}/repos/{quote(owner)}/{quote(repo)}/git/trees/{quote(branch)}?recursive=true" + ) tree_data = await gitea_api(tree_url, token) - + files = [] for item in tree_data.get("tree", []): - # Gitea API returns 'type': 'blob' for files - if item.get("type") == "blob" and is_text_file(item["path"]) and not should_exclude(item["path"], exclude_patterns): + # Gitea API returns 'type': 'blob' for files + if ( + item.get("type") == "blob" + and is_text_file(item["path"]) + and not should_exclude(item["path"], exclude_patterns) + ): # 使用API raw endpoint: GET /repos/{owner}/{repo}/raw/{filepath}?ref={branch} - files.append({ - "path": item["path"], - "url": f"{base_url}/repos/{owner}/{repo}/raw/{quote(item['path'])}?ref={quote(branch)}", - "token": token # 传递token以便fetch_file_content使用 - }) - + files.append( + { + "path": item["path"], + "url": f"{base_url}/repos/{owner}/{repo}/raw/{quote(item['path'])}?ref={quote(branch)}", + "token": token, # 传递token以便fetch_file_content使用 + } + ) + return files + + async def scan_repo_task(task_id: str, db_session_factory, user_config: dict = None): """ 后台仓库扫描任务 - + Args: task_id: 任务ID db_session_factory: 数据库会话工厂 @@ -405,7 +583,7 @@ async def scan_repo_task(task_id: str, db_session_factory, user_config: dict = N task.status = "running" task.started_at = datetime.now(timezone.utc) await db.commit() - + # 创建使用用户配置的LLM服务实例 llm_service = LLMService(user_config=user_config or {}) @@ -413,21 +591,22 @@ async def scan_repo_task(task_id: str, db_session_factory, user_config: dict = N project = await db.get(Project, task.project_id) if not project: raise Exception("项目不存在") - + # 检查项目类型 - 仅支持仓库类型项目 - source_type = getattr(project, 'source_type', 'repository') - if source_type == 'zip': + source_type = getattr(project, "source_type", "repository") + if source_type == "zip": raise Exception("ZIP类型项目请使用ZIP上传扫描接口") - + if not project.repository_url: raise Exception("仓库地址不存在") repo_url = project.repository_url branch = task.branch_name or project.default_branch or "main" repo_type = project.repository_type or "other" - + # 解析任务的排除模式 import json as json_module + task_exclude_patterns = [] if task.exclude_patterns: try: @@ -435,7 +614,9 @@ async def scan_repo_task(task_id: str, db_session_factory, user_config: dict = N except: pass - print(f"🚀 开始扫描仓库: {repo_url}, 分支: {branch}, 类型: {repo_type}, 来源: {source_type}") + print( + f"🚀 开始扫描仓库: {repo_url}, 分支: {branch}, 类型: {repo_type}, 来源: {source_type}" + ) if task_exclude_patterns: print(f"📋 排除模式: {task_exclude_patterns}") @@ -445,20 +626,20 @@ async def scan_repo_task(task_id: str, db_session_factory, user_config: dict = N gitlab_token = settings.GITLAB_TOKEN gitea_token = settings.GITEA_TOKEN - - # 获取SSH私钥(如果配置了) - user_other_config = user_config.get('otherConfig', {}) if user_config else {} + user_other_config = user_config.get("otherConfig", {}) if user_config else {} ssh_private_key = None - if 'sshPrivateKey' in user_other_config: + if "sshPrivateKey" in user_other_config: from app.core.encryption import decrypt_sensitive_data - ssh_private_key = decrypt_sensitive_data(user_other_config['sshPrivateKey']) + + ssh_private_key = decrypt_sensitive_data(user_other_config["sshPrivateKey"]) files: List[Dict[str, str]] = [] extracted_gitlab_token = None # 检查是否为SSH URL from app.services.git_ssh_service import GitSSHOperations + is_ssh_url = GitSSHOperations.is_ssh_url(repo_url) if is_ssh_url: @@ -472,7 +653,9 @@ async def scan_repo_task(task_id: str, db_session_factory, user_config: dict = N repo_url, ssh_private_key, branch, task_exclude_patterns ) # 转换为统一格式 - files = [{'path': f['path'], 'content': f['content']} for f in files_with_content] + files = [ + {"path": f["path"], "content": f["content"]} for f in files_with_content + ] actual_branch = branch print(f"✅ 通过SSH成功获取 {len(files)} 个文件") except Exception as e: @@ -494,21 +677,29 @@ async def scan_repo_task(task_id: str, db_session_factory, user_config: dict = N try: print(f"🔄 尝试获取分支 {try_branch} 的文件列表...") if repo_type == "github": - files = await get_github_files(repo_url, try_branch, github_token, task_exclude_patterns) + files = await get_github_files( + repo_url, try_branch, github_token, task_exclude_patterns + ) elif repo_type == "gitlab": - files = await get_gitlab_files(repo_url, try_branch, gitlab_token, task_exclude_patterns) + files = await get_gitlab_files( + repo_url, try_branch, gitlab_token, task_exclude_patterns + ) # GitLab文件可能带有token - if files and 'token' in files[0]: - extracted_gitlab_token = files[0].get('token') + if files and "token" in files[0]: + extracted_gitlab_token = files[0].get("token") elif repo_type == "gitea": - files = await get_gitea_files(repo_url, try_branch, gitea_token, task_exclude_patterns) + files = await get_gitea_files( + repo_url, try_branch, gitea_token, task_exclude_patterns + ) else: raise Exception("不支持的仓库类型,仅支持 GitHub, GitLab 和 Gitea 仓库") if files: actual_branch = try_branch if try_branch != branch: - print(f"⚠️ 分支 {branch} 不存在或无法访问,已降级到分支 {try_branch}") + print( + f"⚠️ 分支 {branch} 不存在或无法访问,已降级到分支 {try_branch}" + ) break except Exception as e: last_error = str(e) @@ -530,23 +721,25 @@ async def scan_repo_task(task_id: str, db_session_factory, user_config: dict = N # 获取分析配置(优先使用用户配置) analysis_config = get_analysis_config(user_config) - max_analyze_files = analysis_config['max_analyze_files'] - analysis_concurrency = analysis_config['llm_concurrency'] # 并发数 - llm_gap_ms = analysis_config['llm_gap_ms'] + max_analyze_files = analysis_config["max_analyze_files"] + analysis_concurrency = analysis_config["llm_concurrency"] # 并发数 + llm_gap_ms = analysis_config["llm_gap_ms"] # 限制文件数量 # 如果指定了特定文件,则只分析这些文件 - target_files = (user_config or {}).get('scan_config', {}).get('file_paths', []) + target_files = (user_config or {}).get("scan_config", {}).get("file_paths", []) if target_files: print(f"🎯 指定分析 {len(target_files)} 个文件") - files = [f for f in files if f['path'] in target_files] + files = [f for f in files if f["path"] in target_files] elif max_analyze_files > 0: files = files[:max_analyze_files] task.total_files = len(files) await db.commit() - print(f"📊 获取到 {len(files)} 个文件,开始分析 (最大文件数: {max_analyze_files}, 请求间隔: {llm_gap_ms}ms)") + print( + f"📊 获取到 {len(files)} 个文件,开始分析 (最大文件数: {max_analyze_files}, 请求间隔: {llm_gap_ms}ms)" + ) # 4. 分析文件 total_issues = 0 @@ -561,69 +754,72 @@ async def scan_repo_task(task_id: str, db_session_factory, user_config: dict = N # 4. 并行分析文件 print(f"🧬 启动并行分析: {len(files)} 个文件, 并发数: {analysis_concurrency}") - + semaphore = asyncio.Semaphore(analysis_concurrency) - + async def analyze_single_file(file_info): """内部函数:分析单个文件并返回结果""" nonlocal consecutive_failures, last_error - + async with semaphore: if task_control.is_cancelled(task_id): return None - - f_path = file_info['path'] + + f_path = file_info["path"] MAX_RETRIES = 3 for attempt in range(MAX_RETRIES): try: # 4.1 获取文件内容 (仅在第一次尝试或内容获取失败时获取) if attempt == 0: if is_ssh_url: - content = file_info.get('content', '') + content = file_info.get("content", "") else: headers = {} if repo_type == "gitlab": - token_to_use = file_info.get('token') or gitlab_token - if token_to_use: headers["PRIVATE-TOKEN"] = token_to_use + token_to_use = file_info.get("token") or gitlab_token + if token_to_use: + headers["PRIVATE-TOKEN"] = token_to_use elif repo_type == "gitea": - token_to_use = file_info.get('token') or gitea_token - if token_to_use: headers["Authorization"] = f"token {token_to_use}" + token_to_use = file_info.get("token") or gitea_token + if token_to_use: + headers["Authorization"] = f"token {token_to_use}" elif repo_type == "github" and github_token: headers["Authorization"] = f"Bearer {github_token}" - + content = await fetch_file_content(file_info["url"], headers) if not content or not content.strip(): return {"type": "skip", "reason": "empty", "path": f_path} - + if len(content) > settings.MAX_FILE_SIZE_BYTES: return {"type": "skip", "reason": "too_large", "path": f_path} - + if task_control.is_cancelled(task_id): return None # 4.2 LLM 分析 language = get_language_from_path(f_path) - scan_config = (user_config or {}).get('scan_config', {}) - rule_set_id = scan_config.get('rule_set_id') - prompt_template_id = scan_config.get('prompt_template_id') - + scan_config = (user_config or {}).get("scan_config", {}) + rule_set_id = scan_config.get("rule_set_id") + prompt_template_id = scan_config.get("prompt_template_id") + if rule_set_id or prompt_template_id: analysis_result = await llm_service.analyze_code_with_rules( - content, language, + content, + language, rule_set_id=rule_set_id, prompt_template_id=prompt_template_id, - db_session=None + db_session=None, ) else: analysis_result = await llm_service.analyze_code(content, language) - + return { "type": "success", "path": f_path, "content": content, "language": language, - "analysis": analysis_result + "analysis": analysis_result, } except asyncio.CancelledError: # 捕获取消异常,不再重试 @@ -633,10 +829,18 @@ async def scan_repo_task(task_id: str, db_session_factory, user_config: dict = N wait_time = (attempt + 1) * 2 # 特殊处理限流错误提示 error_str = str(e) - if "429" in error_str or "rate limit" in error_str.lower() or "额度不足" in error_str: - print(f"🚫 [限流提示] 仓库扫描任务触发 LLM 频率限制 (429),建议在设置中降低并发数或增加请求间隔。文件: {f_path}") - - print(f"⚠️ 分析文件失败 ({f_path}), 正在进行第 {attempt+1} 次重试... 错误: {e}") + if ( + "429" in error_str + or "rate limit" in error_str.lower() + or "额度不足" in error_str + ): + print( + f"🚫 [限流提示] 仓库扫描任务触发 LLM 频率限制 (429),建议在设置中降低并发数或增加请求间隔。文件: {f_path}" + ) + + print( + f"⚠️ 分析文件失败 ({f_path}), 正在进行第 {attempt + 1} 次重试... 错误: {e}" + ) await asyncio.sleep(wait_time) continue else: @@ -646,7 +850,7 @@ async def scan_repo_task(task_id: str, db_session_factory, user_config: dict = N # 创建所有分析任务对象以便跟踪 task_objects = [asyncio.create_task(analyze_single_file(f)) for f in files] - + try: # 使用 as_completed 处理结果,这样可以实时更新进度且安全使用当前 db session for future in asyncio.as_completed(task_objects): @@ -660,7 +864,8 @@ async def scan_repo_task(task_id: str, db_session_factory, user_config: dict = N except asyncio.CancelledError: continue - if not res: continue + if not res: + continue if res["type"] == "skip": skipped_files += 1 @@ -671,19 +876,21 @@ async def scan_repo_task(task_id: str, db_session_factory, user_config: dict = N elif res["type"] == "success": consecutive_failures = 0 scanned_files += 1 - + f_path = res["path"] analysis = res["analysis"] - file_lines = res["content"].split('\n') + file_lines = res["content"].split("\n") total_lines += len(file_lines) - + # 保存问题 issues = analysis.get("issues", []) for issue in issues: try: # 防御性检查:确保 issue 是字典 if not isinstance(issue, dict): - print(f"⚠️ 警告: 任务 {task_id} 中文件 {f_path} 的分析结果包含无效的问题格式: {issue}") + print( + f"⚠️ 警告: 任务 {task_id} 中文件 {f_path} 的分析结果包含无效的问题格式: {issue}" + ) continue # 辅助函数:清理字符串中 PostgreSQL 不支持的字符 @@ -693,7 +900,7 @@ async def scan_repo_task(task_id: str, db_session_factory, user_config: dict = N if not isinstance(text, str): text = str(text) # 移除 NULL 字节 (PostgreSQL 不支持) - text = text.replace('\x00', '') + text = text.replace("\x00", "") return text line_num = issue.get("line", 1) @@ -703,7 +910,7 @@ async def scan_repo_task(task_id: str, db_session_factory, user_config: dict = N idx = max(0, int(line_num) - 1) start = max(0, idx - 2) end = min(len(file_lines), idx + 3) - code_snippet = '\n'.join(file_lines[start:end]) + code_snippet = "\n".join(file_lines[start:end]) except Exception: code_snippet = "" @@ -715,18 +922,20 @@ async def scan_repo_task(task_id: str, db_session_factory, user_config: dict = N issue_type=issue.get("type", "maintainability"), severity=issue.get("severity", "low"), title=sanitize_for_db(issue.get("title", "Issue")), - message=sanitize_for_db(issue.get("description") or issue.get("title", "Issue")), + message=sanitize_for_db( + issue.get("description") or issue.get("title", "Issue") + ), suggestion=sanitize_for_db(issue.get("suggestion")), code_snippet=code_snippet, ai_explanation=sanitize_for_db(issue.get("ai_explanation")), - status="open" + status="open", ) db.add(audit_issue) total_issues += 1 except Exception as e: print(f"⚠️ 处理单个问题时出错 (文件 {f_path}): {e}") continue - + if "quality_score" in analysis: try: quality_score = float(analysis["quality_score"]) @@ -739,10 +948,12 @@ async def scan_repo_task(task_id: str, db_session_factory, user_config: dict = N task.scanned_files = processed_count task.total_lines = total_lines task.issues_count = total_issues - await db.commit() # 这里的 commit 是在一个协程里按序进行的,是安全的 - + await db.commit() # 这里的 commit 是在一个协程里按序进行的,是安全的 + if processed_count % 10 == 0 or processed_count == len(files): - print(f"📈 任务 {task_id}: 进度 {processed_count}/{len(files)} ({int(processed_count/len(files)*100) if len(files) > 0 else 0}%)") + print( + f"📈 任务 {task_id}: 进度 {processed_count}/{len(files)} ({int(processed_count / len(files) * 100) if len(files) > 0 else 0}%)" + ) if consecutive_failures >= MAX_CONSECUTIVE_FAILURES: print(f"❌ 任务 {task_id}: 连续失败 {consecutive_failures} 次,停止分析") @@ -754,15 +965,17 @@ async def scan_repo_task(task_id: str, db_session_factory, user_config: dict = N if not t.done(): t.cancel() pending_count += 1 - + if pending_count > 0: print(f"🧹 任务 {task_id}: 已清理 {pending_count} 个后台待处理或执行中的任务") # 等待一下让取消逻辑执行完毕,但不阻塞太久 await asyncio.gather(*task_objects, return_exceptions=True) # 5. 完成任务 - avg_quality_score = sum(quality_scores) / len(quality_scores) if quality_scores else 100.0 - + avg_quality_score = ( + sum(quality_scores) / len(quality_scores) if quality_scores else 100.0 + ) + # 判断任务状态 # 如果所有文件都被跳过(空文件等),标记为完成但给出提示 if len(files) > 0 and scanned_files == 0 and skipped_files == len(files): @@ -782,7 +995,7 @@ async def scan_repo_task(task_id: str, db_session_factory, user_config: dict = N task.total_lines = total_lines task.issues_count = 0 task.quality_score = 0 - + # 尝试从最后一个错误中获取更详细的系统提示 error_msg = f"{failed_files} 个文件分析失败,请检查 LLM API 配置。最近一个错误: {str(last_error) if 'last_error' in locals() else '未知错误'}" task.error_message = error_msg @@ -797,7 +1010,7 @@ async def scan_repo_task(task_id: str, db_session_factory, user_config: dict = N task.issues_count = total_issues task.quality_score = avg_quality_score await db.commit() - + result_msg = f"✅ 任务 {task_id} 完成: 成功分析 {scanned_files} 个文件" if failed_files > 0: result_msg += f", {failed_files} 个文件失败"