diff --git a/README.md b/README.md
index 8d64351..1d45b61 100644
--- a/README.md
+++ b/README.md
@@ -428,7 +428,7 @@ DeepSeek-Coder · Codestral
**欢迎大家来和我交流探讨!无论是技术问题、功能建议还是合作意向,都期待与你沟通~**
-
+(项目开发、投资孵化等合作洽谈请通过邮箱联系)
| 联系方式 | |
|:---:|:---:|
| 📧 **邮箱** | **lintsinghua@qq.com** |
diff --git a/backend/app/api/v1/endpoints/agent_tasks.py b/backend/app/api/v1/endpoints/agent_tasks.py
index a9052f6..b68b478 100644
--- a/backend/app/api/v1/endpoints/agent_tasks.py
+++ b/backend/app/api/v1/endpoints/agent_tasks.py
@@ -294,6 +294,7 @@ async def _execute_agent_task(task_id: str):
other_config = (user_config or {}).get('otherConfig', {})
github_token = other_config.get('githubToken') or settings.GITHUB_TOKEN
gitlab_token = other_config.get('gitlabToken') or settings.GITLAB_TOKEN
+ gitea_token = other_config.get('giteaToken') or settings.GITEA_TOKEN
# 解密SSH私钥
ssh_private_key = None
@@ -313,6 +314,7 @@ async def _execute_agent_task(task_id: str):
task.branch_name,
github_token=github_token,
gitlab_token=gitlab_token,
+ gitea_token=gitea_token, # 🔥 新增
ssh_private_key=ssh_private_key, # 🔥 新增SSH密钥
event_emitter=event_emitter, # 🔥 新增
)
@@ -2226,6 +2228,7 @@ async def _get_project_root(
branch_name: Optional[str] = None,
github_token: Optional[str] = None,
gitlab_token: Optional[str] = None,
+ gitea_token: Optional[str] = None, # 🔥 新增
ssh_private_key: Optional[str] = None, # 🔥 新增:SSH私钥(用于SSH认证)
event_emitter: Optional[Any] = None, # 🔥 新增:用于发送实时日志
) -> str:
@@ -2242,6 +2245,7 @@ async def _get_project_root(
branch_name: 分支名称(仓库项目使用,优先于 project.default_branch)
github_token: GitHub 访问令牌(用于私有仓库)
gitlab_token: GitLab 访问令牌(用于私有仓库)
+ gitea_token: Gitea 访问令牌(用于私有仓库)
ssh_private_key: SSH私钥(用于SSH认证)
event_emitter: 事件发送器(用于发送实时日志)
@@ -2503,9 +2507,19 @@ async def _get_project_root(
parsed.fragment
))
await emit(f"🔐 使用 GitLab Token 认证")
+ elif repo_type == "gitea" and gitea_token:
+ auth_url = urlunparse((
+ parsed.scheme,
+ f"{gitea_token}@{parsed.netloc}",
+ parsed.path,
+ parsed.params,
+ parsed.query,
+ parsed.fragment
+ ))
+ await emit(f"🔐 使用 Gitea Token 认证")
elif is_ssh_url and ssh_private_key:
await emit(f"🔐 使用 SSH Key 认证")
-
+
for branch in branches_to_try:
check_cancelled()
diff --git a/backend/app/api/v1/endpoints/projects.py b/backend/app/api/v1/endpoints/projects.py
index dac1138..e26143a 100644
--- a/backend/app/api/v1/endpoints/projects.py
+++ b/backend/app/api/v1/endpoints/projects.py
@@ -16,6 +16,7 @@ from app.db.session import get_db, AsyncSessionLocal
from app.models.project import Project
from app.models.user import User
from app.models.audit import AuditTask, AuditIssue
+from app.models.agent_task import AgentTask, AgentTaskStatus, AgentFinding
from app.models.user_config import UserConfig
import zipfile
from app.services.scanner import scan_repo_task, get_github_files, get_gitlab_files, get_github_branches, get_gitlab_branches, get_gitea_branches, should_exclude, is_text_file
@@ -161,27 +162,52 @@ async def get_stats(
)
projects = projects_result.scalars().all()
project_ids = [p.id for p in projects]
-
- # 只统计当前用户项目的任务
+
+ # 统计旧的 AuditTask
tasks_result = await db.execute(
select(AuditTask).where(AuditTask.project_id.in_(project_ids)) if project_ids else select(AuditTask).where(False)
)
tasks = tasks_result.scalars().all()
task_ids = [t.id for t in tasks]
-
- # 只统计当前用户任务的问题
+
+ # 统计旧的 AuditIssue
issues_result = await db.execute(
select(AuditIssue).where(AuditIssue.task_id.in_(task_ids)) if task_ids else select(AuditIssue).where(False)
)
issues = issues_result.scalars().all()
-
+
+ # 🔥 同时统计新的 AgentTask
+ agent_tasks_result = await db.execute(
+ select(AgentTask).where(AgentTask.project_id.in_(project_ids)) if project_ids else select(AgentTask).where(False)
+ )
+ agent_tasks = agent_tasks_result.scalars().all()
+ agent_task_ids = [t.id for t in agent_tasks]
+
+ # 🔥 统计 AgentFinding
+ agent_findings_result = await db.execute(
+ select(AgentFinding).where(AgentFinding.task_id.in_(agent_task_ids)) if agent_task_ids else select(AgentFinding).where(False)
+ )
+ agent_findings = agent_findings_result.scalars().all()
+
+ # 合并统计(旧任务 + 新 Agent 任务)
+ total_tasks = len(tasks) + len(agent_tasks)
+ completed_tasks = (
+ len([t for t in tasks if t.status == "completed"]) +
+ len([t for t in agent_tasks if t.status == AgentTaskStatus.COMPLETED])
+ )
+ total_issues = len(issues) + len(agent_findings)
+ resolved_issues = (
+ len([i for i in issues if i.status == "resolved"]) +
+ len([f for f in agent_findings if f.status == "resolved"])
+ )
+
return {
"total_projects": len(projects),
"active_projects": len([p for p in projects if p.is_active]),
- "total_tasks": len(tasks),
- "completed_tasks": len([t for t in tasks if t.status == "completed"]),
- "total_issues": len(issues),
- "resolved_issues": len([i for i in issues if i.status == "resolved"]),
+ "total_tasks": total_tasks,
+ "completed_tasks": completed_tasks,
+ "total_issues": total_issues,
+ "resolved_issues": resolved_issues,
}
@router.get("/{id}", response_model=ProjectResponse)
diff --git a/backend/app/services/agent/__init__.py b/backend/app/services/agent/__init__.py
index fee2a92..f169c80 100644
--- a/backend/app/services/agent/__init__.py
+++ b/backend/app/services/agent/__init__.py
@@ -1,29 +1,19 @@
"""
DeepAudit Agent 服务模块
-基于 LangGraph 的 AI Agent 代码安全审计
+基于动态 Agent 树架构的 AI 代码安全审计
-架构升级版本 - 支持:
-- 动态Agent树结构
-- 专业知识模块系统
-- Agent间通信机制
-- 完整状态管理
-- Think工具和漏洞报告工具
+架构:
+- OrchestratorAgent 作为编排层,动态调度子 Agent
+- ReconAgent 负责侦察和文件分析
+- AnalysisAgent 负责漏洞分析
+- VerificationAgent 负责验证发现
工作流:
- START → Recon → Analysis ⟲ → Verification → Report → END
-
+ START → Orchestrator → [Recon/Analysis/Verification] → Report → END
+
支持动态创建子Agent进行专业化分析
"""
-# 从 graph 模块导入主要组件
-from .graph import (
- AgentRunner,
- run_agent_task,
- LLMService,
- AuditState,
- create_audit_graph,
-)
-
# 事件管理
from .event_manager import EventManager, AgentEventEmitter
@@ -33,14 +23,14 @@ from .agents import (
OrchestratorAgent, ReconAgent, AnalysisAgent, VerificationAgent,
)
-# 🔥 新增:核心模块(状态管理、注册表、消息)
+# 核心模块(状态管理、注册表、消息)
from .core import (
AgentState, AgentStatus,
AgentRegistry, agent_registry,
AgentMessage, MessageType, MessagePriority, MessageBus,
)
-# 🔥 新增:知识模块系统(基于RAG)
+# 知识模块系统(基于RAG)
from .knowledge import (
KnowledgeLoader, knowledge_loader,
get_available_modules, get_module_content,
@@ -48,7 +38,7 @@ from .knowledge import (
SecurityKnowledgeQueryTool, GetVulnerabilityKnowledgeTool,
)
-# 🔥 新增:协作工具
+# 协作工具
from .tools import (
ThinkTool, ReflectTool,
CreateVulnerabilityReportTool,
@@ -57,24 +47,15 @@ from .tools import (
WaitForMessageTool, AgentFinishTool,
)
-# 🔥 新增:遥测模块
+# 遥测模块
from .telemetry import Tracer, get_global_tracer, set_global_tracer
__all__ = [
- # 核心 Runner
- "AgentRunner",
- "run_agent_task",
- "LLMService",
-
- # LangGraph
- "AuditState",
- "create_audit_graph",
-
# 事件管理
"EventManager",
"AgentEventEmitter",
-
+
# Agent 类
"BaseAgent",
"AgentConfig",
@@ -83,8 +64,8 @@ __all__ = [
"ReconAgent",
"AnalysisAgent",
"VerificationAgent",
-
- # 🔥 核心模块
+
+ # 核心模块
"AgentState",
"AgentStatus",
"AgentRegistry",
@@ -93,8 +74,8 @@ __all__ = [
"MessageType",
"MessagePriority",
"MessageBus",
-
- # 🔥 知识模块(基于RAG)
+
+ # 知识模块(基于RAG)
"KnowledgeLoader",
"knowledge_loader",
"get_available_modules",
@@ -103,8 +84,8 @@ __all__ = [
"security_knowledge_rag",
"SecurityKnowledgeQueryTool",
"GetVulnerabilityKnowledgeTool",
-
- # 🔥 协作工具
+
+ # 协作工具
"ThinkTool",
"ReflectTool",
"CreateVulnerabilityReportTool",
@@ -114,10 +95,9 @@ __all__ = [
"ViewAgentGraphTool",
"WaitForMessageTool",
"AgentFinishTool",
-
- # 🔥 遥测模块
+
+ # 遥测模块
"Tracer",
"get_global_tracer",
"set_global_tracer",
]
-
diff --git a/backend/app/services/agent/agents/base.py b/backend/app/services/agent/agents/base.py
index e0da612..bdc5188 100644
--- a/backend/app/services/agent/agents/base.py
+++ b/backend/app/services/agent/agents/base.py
@@ -1024,10 +1024,18 @@ class BaseAgent(ABC):
elif chunk["type"] == "error":
accumulated = chunk.get("accumulated", "")
error_msg = chunk.get("error", "Unknown error")
- logger.error(f"[{self.name}] Stream error: {error_msg}")
- if accumulated:
- total_tokens = chunk.get("usage", {}).get("total_tokens", 0)
- else:
+ error_type = chunk.get("error_type", "unknown")
+ user_message = chunk.get("user_message", error_msg)
+ logger.error(f"[{self.name}] Stream error ({error_type}): {error_msg}")
+
+ if chunk.get("usage"):
+ total_tokens = chunk["usage"].get("total_tokens", 0)
+
+ # 使用特殊前缀标记 API 错误,让调用方能够识别
+ # 格式:[API_ERROR:error_type] user_message
+ if error_type in ("rate_limit", "quota_exceeded", "authentication", "connection"):
+ accumulated = f"[API_ERROR:{error_type}] {user_message}"
+ elif not accumulated:
accumulated = f"[系统错误: {error_msg}] 请重新思考并输出你的决策。"
break
diff --git a/backend/app/services/agent/agents/orchestrator.py b/backend/app/services/agent/agents/orchestrator.py
index 73c5e41..1e3c8c1 100644
--- a/backend/app/services/agent/agents/orchestrator.py
+++ b/backend/app/services/agent/agents/orchestrator.py
@@ -284,7 +284,56 @@ Action Input: {{"参数": "值"}}
# 重置空响应计数器
self._empty_retry_count = 0
-
+
+ # 🔥 检查是否是 API 错误(而非格式错误)
+ if llm_output.startswith("[API_ERROR:"):
+ # 提取错误类型和消息
+ match = re.match(r"\[API_ERROR:(\w+)\]\s*(.*)", llm_output)
+ if match:
+ error_type = match.group(1)
+ error_message = match.group(2)
+
+ if error_type == "rate_limit":
+ # 速率限制 - 等待后重试
+ api_retry_count = getattr(self, '_api_retry_count', 0) + 1
+ self._api_retry_count = api_retry_count
+ if api_retry_count >= 3:
+ logger.error(f"[{self.name}] Too many rate limit errors, stopping")
+ await self.emit_event("error", f"API 速率限制重试次数过多: {error_message}")
+ break
+ logger.warning(f"[{self.name}] Rate limit hit, waiting before retry ({api_retry_count}/3)")
+ await self.emit_event("warning", f"API 速率限制,等待后重试 ({api_retry_count}/3)")
+ await asyncio.sleep(30) # 等待 30 秒后重试
+ continue
+
+ elif error_type == "quota_exceeded":
+ # 配额用尽 - 终止任务
+ logger.error(f"[{self.name}] API quota exceeded: {error_message}")
+ await self.emit_event("error", f"API 配额已用尽: {error_message}")
+ break
+
+ elif error_type == "authentication":
+ # 认证错误 - 终止任务
+ logger.error(f"[{self.name}] API authentication error: {error_message}")
+ await self.emit_event("error", f"API 认证失败: {error_message}")
+ break
+
+ elif error_type == "connection":
+ # 连接错误 - 重试
+ api_retry_count = getattr(self, '_api_retry_count', 0) + 1
+ self._api_retry_count = api_retry_count
+ if api_retry_count >= 3:
+ logger.error(f"[{self.name}] Too many connection errors, stopping")
+ await self.emit_event("error", f"API 连接错误重试次数过多: {error_message}")
+ break
+ logger.warning(f"[{self.name}] Connection error, retrying ({api_retry_count}/3)")
+ await self.emit_event("warning", f"API 连接错误,重试中 ({api_retry_count}/3)")
+ await asyncio.sleep(5) # 等待 5 秒后重试
+ continue
+
+ # 重置 API 重试计数器(成功获取响应后)
+ self._api_retry_count = 0
+
# 解析 LLM 的决策
step = self._parse_llm_response(llm_output)
diff --git a/backend/app/services/agent/graph/__init__.py b/backend/app/services/agent/graph/__init__.py
deleted file mode 100644
index 5bc17d1..0000000
--- a/backend/app/services/agent/graph/__init__.py
+++ /dev/null
@@ -1,28 +0,0 @@
-"""
-LangGraph 工作流模块
-使用状态图构建混合 Agent 审计流程
-"""
-
-from .audit_graph import AuditState, create_audit_graph, create_audit_graph_with_human
-from .nodes import ReconNode, AnalysisNode, VerificationNode, ReportNode, HumanReviewNode
-from .runner import AgentRunner, run_agent_task, LLMService
-
-__all__ = [
- # 状态和图
- "AuditState",
- "create_audit_graph",
- "create_audit_graph_with_human",
-
- # 节点
- "ReconNode",
- "AnalysisNode",
- "VerificationNode",
- "ReportNode",
- "HumanReviewNode",
-
- # Runner
- "AgentRunner",
- "run_agent_task",
- "LLMService",
-]
-
diff --git a/backend/app/services/agent/graph/audit_graph.py b/backend/app/services/agent/graph/audit_graph.py
deleted file mode 100644
index 4b0eada..0000000
--- a/backend/app/services/agent/graph/audit_graph.py
+++ /dev/null
@@ -1,677 +0,0 @@
-"""
-DeepAudit 审计工作流图 - LLM 驱动版
-使用 LangGraph 构建 LLM 驱动的 Agent 协作流程
-
-重要改变:路由决策由 LLM 参与,而不是硬编码条件!
-"""
-
-from typing import TypedDict, Annotated, List, Dict, Any, Optional, Literal
-from datetime import datetime
-import operator
-import logging
-import json
-
-from langgraph.graph import StateGraph, END
-from langgraph.checkpoint.memory import MemorySaver
-from langgraph.prebuilt import ToolNode
-
-logger = logging.getLogger(__name__)
-
-
-# ============ 状态定义 ============
-
-class Finding(TypedDict):
- """漏洞发现"""
- id: str
- vulnerability_type: str
- severity: str
- title: str
- description: str
- file_path: Optional[str]
- line_start: Optional[int]
- code_snippet: Optional[str]
- is_verified: bool
- confidence: float
- source: str
-
-
-class AuditState(TypedDict):
- """
- 审计状态
- 在整个工作流中传递和更新
- """
- # 输入
- project_root: str
- project_info: Dict[str, Any]
- config: Dict[str, Any]
- task_id: str
-
- # Recon 阶段输出
- tech_stack: Dict[str, Any]
- entry_points: List[Dict[str, Any]]
- high_risk_areas: List[str]
- dependencies: Dict[str, Any]
-
- # Analysis 阶段输出
- findings: Annotated[List[Finding], operator.add] # 使用 add 合并多轮发现
-
- # Verification 阶段输出
- verified_findings: List[Finding]
- false_positives: List[str]
- # 🔥 NEW: 验证后的完整 findings(用于替换原始 findings)
- _verified_findings_update: Optional[List[Finding]]
-
- # 控制流 - 🔥 关键:LLM 可以设置这些来影响路由
- current_phase: str
- iteration: int
- max_iterations: int
- should_continue_analysis: bool
-
- # 🔥 新增:LLM 的路由决策
- llm_next_action: Optional[str] # LLM 建议的下一步: "continue_analysis", "verify", "report", "end"
- llm_routing_reason: Optional[str] # LLM 的决策理由
-
- # 🔥 新增:Agent 间协作的任务交接信息
- recon_handoff: Optional[Dict[str, Any]] # Recon -> Analysis 的交接
- analysis_handoff: Optional[Dict[str, Any]] # Analysis -> Verification 的交接
- verification_handoff: Optional[Dict[str, Any]] # Verification -> Report 的交接
-
- # 消息和事件
- messages: Annotated[List[Dict], operator.add]
- events: Annotated[List[Dict], operator.add]
-
- # 最终输出
- summary: Optional[Dict[str, Any]]
- security_score: Optional[int]
- error: Optional[str]
-
-
-# ============ LLM 路由决策器 ============
-
-class LLMRouter:
- """
- LLM 路由决策器
- 让 LLM 来决定下一步应该做什么
- """
-
- def __init__(self, llm_service):
- self.llm_service = llm_service
-
- async def decide_after_recon(self, state: AuditState) -> Dict[str, Any]:
- """Recon 后让 LLM 决定下一步"""
- entry_points = state.get("entry_points", [])
- high_risk_areas = state.get("high_risk_areas", [])
- tech_stack = state.get("tech_stack", {})
- initial_findings = state.get("findings", [])
-
- prompt = f"""作为安全审计的决策者,基于以下信息收集结果,决定下一步行动。
-
-## 信息收集结果
-- 入口点数量: {len(entry_points)}
-- 高风险区域: {high_risk_areas[:10]}
-- 技术栈: {tech_stack}
-- 初步发现: {len(initial_findings)} 个
-
-## 选项
-1. "analysis" - 继续进行漏洞分析(推荐:有入口点或高风险区域时)
-2. "end" - 结束审计(仅当没有任何可分析内容时)
-
-请返回 JSON 格式:
-{{"action": "analysis或end", "reason": "决策理由"}}"""
-
- try:
- response = await self.llm_service.chat_completion_raw(
- messages=[
- {"role": "system", "content": "你是安全审计流程的决策者,负责决定下一步行动。"},
- {"role": "user", "content": prompt},
- ],
- # 🔥 不传递 temperature 和 max_tokens,使用用户配置
- )
-
- content = response.get("content", "")
- # 提取 JSON
- import re
- json_match = re.search(r'\{.*\}', content, re.DOTALL)
- if json_match:
- result = json.loads(json_match.group())
- return result
- except Exception as e:
- logger.warning(f"LLM routing decision failed: {e}")
-
- # 默认决策
- if entry_points or high_risk_areas:
- return {"action": "analysis", "reason": "有可分析内容"}
- return {"action": "end", "reason": "没有发现入口点或高风险区域"}
-
- async def decide_after_analysis(self, state: AuditState) -> Dict[str, Any]:
- """Analysis 后让 LLM 决定下一步"""
- findings = state.get("findings", [])
- iteration = state.get("iteration", 0)
- max_iterations = state.get("max_iterations", 3)
-
- # 统计发现
- severity_counts = {"critical": 0, "high": 0, "medium": 0, "low": 0}
- for f in findings:
- # 跳过非字典类型的 finding
- if not isinstance(f, dict):
- continue
- sev = f.get("severity", "medium")
- severity_counts[sev] = severity_counts.get(sev, 0) + 1
-
- prompt = f"""作为安全审计的决策者,基于以下分析结果,决定下一步行动。
-
-## 分析结果
-- 总发现数: {len(findings)}
-- 严重程度分布: {severity_counts}
-- 当前迭代: {iteration}/{max_iterations}
-
-## 选项
-1. "verification" - 验证发现的漏洞(推荐:有发现需要验证时)
-2. "analysis" - 继续深入分析(推荐:发现较少但还有迭代次数时)
-3. "report" - 生成报告(推荐:没有发现或已充分分析时)
-
-请返回 JSON 格式:
-{{"action": "verification/analysis/report", "reason": "决策理由"}}"""
-
- try:
- response = await self.llm_service.chat_completion_raw(
- messages=[
- {"role": "system", "content": "你是安全审计流程的决策者,负责决定下一步行动。"},
- {"role": "user", "content": prompt},
- ],
- # 🔥 不传递 temperature 和 max_tokens,使用用户配置
- )
-
- content = response.get("content", "")
- import re
- json_match = re.search(r'\{.*\}', content, re.DOTALL)
- if json_match:
- result = json.loads(json_match.group())
- return result
- except Exception as e:
- logger.warning(f"LLM routing decision failed: {e}")
-
- # 默认决策
- if not findings:
- return {"action": "report", "reason": "没有发现漏洞"}
- if len(findings) >= 3 or iteration >= max_iterations:
- return {"action": "verification", "reason": "有足够的发现需要验证"}
- return {"action": "analysis", "reason": "发现较少,继续分析"}
-
- async def decide_after_verification(self, state: AuditState) -> Dict[str, Any]:
- """Verification 后让 LLM 决定下一步"""
- verified_findings = state.get("verified_findings", [])
- false_positives = state.get("false_positives", [])
- iteration = state.get("iteration", 0)
- max_iterations = state.get("max_iterations", 3)
-
- prompt = f"""作为安全审计的决策者,基于以下验证结果,决定下一步行动。
-
-## 验证结果
-- 已确认漏洞: {len(verified_findings)}
-- 误报数量: {len(false_positives)}
-- 当前迭代: {iteration}/{max_iterations}
-
-## 选项
-1. "analysis" - 回到分析阶段重新分析(推荐:误报率太高时)
-2. "report" - 生成最终报告(推荐:验证完成时)
-
-请返回 JSON 格式:
-{{"action": "analysis/report", "reason": "决策理由"}}"""
-
- try:
- response = await self.llm_service.chat_completion_raw(
- messages=[
- {"role": "system", "content": "你是安全审计流程的决策者,负责决定下一步行动。"},
- {"role": "user", "content": prompt},
- ],
- # 🔥 不传递 temperature 和 max_tokens,使用用户配置
- )
-
- content = response.get("content", "")
- import re
- json_match = re.search(r'\{.*\}', content, re.DOTALL)
- if json_match:
- result = json.loads(json_match.group())
- return result
- except Exception as e:
- logger.warning(f"LLM routing decision failed: {e}")
-
- # 默认决策
- if len(false_positives) > len(verified_findings) and iteration < max_iterations:
- return {"action": "analysis", "reason": "误报率较高,需要重新分析"}
- return {"action": "report", "reason": "验证完成,生成报告"}
-
-
-# ============ 路由函数 (结合 LLM 决策) ============
-
-def route_after_recon(state: AuditState) -> Literal["analysis", "end"]:
- """
- Recon 后的路由决策
- 优先使用 LLM 的决策,否则使用默认逻辑
- """
- # 🔥 检查是否有错误
- if state.get("error") or state.get("current_phase") == "error":
- logger.error(f"Recon phase has error, routing to end: {state.get('error')}")
- return "end"
-
- # 检查 LLM 是否有决策
- llm_action = state.get("llm_next_action")
- if llm_action:
- logger.info(f"Using LLM routing decision: {llm_action}, reason: {state.get('llm_routing_reason')}")
- if llm_action == "end":
- return "end"
- return "analysis"
-
- # 默认逻辑(作为 fallback)
- if not state.get("entry_points") and not state.get("high_risk_areas"):
- return "end"
- return "analysis"
-
-
-def route_after_analysis(state: AuditState) -> Literal["verification", "analysis", "report"]:
- """
- Analysis 后的路由决策
- 优先使用 LLM 的决策
- """
- # 检查 LLM 是否有决策
- llm_action = state.get("llm_next_action")
- if llm_action:
- logger.info(f"Using LLM routing decision: {llm_action}, reason: {state.get('llm_routing_reason')}")
- if llm_action == "verification":
- return "verification"
- elif llm_action == "analysis":
- return "analysis"
- elif llm_action == "report":
- return "report"
-
- # 默认逻辑
- findings = state.get("findings", [])
- iteration = state.get("iteration", 0)
- max_iterations = state.get("max_iterations", 3)
- should_continue = state.get("should_continue_analysis", False)
-
- if not findings:
- return "report"
-
- if should_continue and iteration < max_iterations:
- return "analysis"
-
- return "verification"
-
-
-def route_after_verification(state: AuditState) -> Literal["analysis", "report"]:
- """
- Verification 后的路由决策
- 优先使用 LLM 的决策
- """
- # 检查 LLM 是否有决策
- llm_action = state.get("llm_next_action")
- if llm_action:
- logger.info(f"Using LLM routing decision: {llm_action}, reason: {state.get('llm_routing_reason')}")
- if llm_action == "analysis":
- return "analysis"
- return "report"
-
- # 默认逻辑
- false_positives = state.get("false_positives", [])
- iteration = state.get("iteration", 0)
- max_iterations = state.get("max_iterations", 3)
-
- if len(false_positives) > len(state.get("verified_findings", [])) and iteration < max_iterations:
- return "analysis"
-
- return "report"
-
-
-# ============ 创建审计图 ============
-
-def create_audit_graph(
- recon_node,
- analysis_node,
- verification_node,
- report_node,
- checkpointer: Optional[MemorySaver] = None,
- llm_service=None, # 用于 LLM 路由决策
-) -> StateGraph:
- """
- 创建审计工作流图
-
- Args:
- recon_node: 信息收集节点
- analysis_node: 漏洞分析节点
- verification_node: 漏洞验证节点
- report_node: 报告生成节点
- checkpointer: 检查点存储器
- llm_service: LLM 服务(用于路由决策)
-
- Returns:
- 编译后的 StateGraph
-
- 工作流结构:
-
- START
- │
- ▼
- ┌──────┐
- │Recon │ 信息收集 (LLM 驱动)
- └──┬───┘
- │ LLM 决定
- ▼
- ┌──────────┐
- │ Analysis │◄─────┐ 漏洞分析 (LLM 驱动,可循环)
- └────┬─────┘ │
- │ LLM 决定 │
- ▼ │
- ┌────────────┐ │
- │Verification│────┘ 漏洞验证 (LLM 驱动,可回溯)
- └─────┬──────┘
- │ LLM 决定
- ▼
- ┌──────────┐
- │ Report │ 报告生成
- └────┬─────┘
- │
- ▼
- END
- """
-
- # 创建状态图
- workflow = StateGraph(AuditState)
-
- # 如果有 LLM 服务,创建路由决策器
- llm_router = LLMRouter(llm_service) if llm_service else None
-
- # 包装节点以添加 LLM 路由决策
- async def recon_with_routing(state):
- result = await recon_node(state)
-
- # LLM 决定下一步
- if llm_router:
- decision = await llm_router.decide_after_recon({**state, **result})
- result["llm_next_action"] = decision.get("action")
- result["llm_routing_reason"] = decision.get("reason")
-
- return result
-
- async def analysis_with_routing(state):
- result = await analysis_node(state)
-
- # LLM 决定下一步
- if llm_router:
- decision = await llm_router.decide_after_analysis({**state, **result})
- result["llm_next_action"] = decision.get("action")
- result["llm_routing_reason"] = decision.get("reason")
-
- return result
-
- async def verification_with_routing(state):
- result = await verification_node(state)
-
- # LLM 决定下一步
- if llm_router:
- decision = await llm_router.decide_after_verification({**state, **result})
- result["llm_next_action"] = decision.get("action")
- result["llm_routing_reason"] = decision.get("reason")
-
- return result
-
- # 添加节点
- if llm_router:
- workflow.add_node("recon", recon_with_routing)
- workflow.add_node("analysis", analysis_with_routing)
- workflow.add_node("verification", verification_with_routing)
- else:
- workflow.add_node("recon", recon_node)
- workflow.add_node("analysis", analysis_node)
- workflow.add_node("verification", verification_node)
-
- workflow.add_node("report", report_node)
-
- # 设置入口点
- workflow.set_entry_point("recon")
-
- # 添加条件边
- workflow.add_conditional_edges(
- "recon",
- route_after_recon,
- {
- "analysis": "analysis",
- "end": END,
- }
- )
-
- workflow.add_conditional_edges(
- "analysis",
- route_after_analysis,
- {
- "verification": "verification",
- "analysis": "analysis",
- "report": "report",
- }
- )
-
- workflow.add_conditional_edges(
- "verification",
- route_after_verification,
- {
- "analysis": "analysis",
- "report": "report",
- }
- )
-
- # Report -> END
- workflow.add_edge("report", END)
-
- # 编译图
- if checkpointer:
- return workflow.compile(checkpointer=checkpointer)
- else:
- return workflow.compile()
-
-
-# ============ 带人机协作的审计图 ============
-
-def create_audit_graph_with_human(
- recon_node,
- analysis_node,
- verification_node,
- report_node,
- human_review_node,
- checkpointer: Optional[MemorySaver] = None,
- llm_service=None,
-) -> StateGraph:
- """
- 创建带人机协作的审计工作流图
-
- 在验证阶段后增加人工审核节点
- """
-
- workflow = StateGraph(AuditState)
- llm_router = LLMRouter(llm_service) if llm_service else None
-
- # 包装节点
- async def recon_with_routing(state):
- result = await recon_node(state)
- if llm_router:
- decision = await llm_router.decide_after_recon({**state, **result})
- result["llm_next_action"] = decision.get("action")
- result["llm_routing_reason"] = decision.get("reason")
- return result
-
- async def analysis_with_routing(state):
- result = await analysis_node(state)
- if llm_router:
- decision = await llm_router.decide_after_analysis({**state, **result})
- result["llm_next_action"] = decision.get("action")
- result["llm_routing_reason"] = decision.get("reason")
- return result
-
- # 添加节点
- if llm_router:
- workflow.add_node("recon", recon_with_routing)
- workflow.add_node("analysis", analysis_with_routing)
- else:
- workflow.add_node("recon", recon_node)
- workflow.add_node("analysis", analysis_node)
-
- workflow.add_node("verification", verification_node)
- workflow.add_node("human_review", human_review_node)
- workflow.add_node("report", report_node)
-
- workflow.set_entry_point("recon")
-
- workflow.add_conditional_edges(
- "recon",
- route_after_recon,
- {"analysis": "analysis", "end": END}
- )
-
- workflow.add_conditional_edges(
- "analysis",
- route_after_analysis,
- {
- "verification": "verification",
- "analysis": "analysis",
- "report": "report",
- }
- )
-
- # Verification -> Human Review
- workflow.add_edge("verification", "human_review")
-
- # Human Review 后的路由
- def route_after_human(state: AuditState) -> Literal["analysis", "report"]:
- if state.get("should_continue_analysis"):
- return "analysis"
- return "report"
-
- workflow.add_conditional_edges(
- "human_review",
- route_after_human,
- {"analysis": "analysis", "report": "report"}
- )
-
- workflow.add_edge("report", END)
-
- if checkpointer:
- return workflow.compile(checkpointer=checkpointer, interrupt_before=["human_review"])
- else:
- return workflow.compile()
-
-
-# ============ 执行器 ============
-
-class AuditGraphRunner:
- """
- 审计图执行器
- 封装 LangGraph 工作流的执行
- """
-
- def __init__(
- self,
- graph: StateGraph,
- event_emitter=None,
- ):
- self.graph = graph
- self.event_emitter = event_emitter
-
- async def run(
- self,
- project_root: str,
- project_info: Dict[str, Any],
- config: Dict[str, Any],
- task_id: str,
- ) -> Dict[str, Any]:
- """
- 执行审计工作流
- """
- # 初始状态
- initial_state: AuditState = {
- "project_root": project_root,
- "project_info": project_info,
- "config": config,
- "task_id": task_id,
- "tech_stack": {},
- "entry_points": [],
- "high_risk_areas": [],
- "dependencies": {},
- "findings": [],
- "verified_findings": [],
- "false_positives": [],
- "_verified_findings_update": None, # 🔥 NEW: 验证后的 findings 更新
- "current_phase": "start",
- "iteration": 0,
- "max_iterations": config.get("max_iterations", 3),
- "should_continue_analysis": False,
- "llm_next_action": None,
- "llm_routing_reason": None,
- "messages": [],
- "events": [],
- "summary": None,
- "security_score": None,
- "error": None,
- }
-
- run_config = {
- "configurable": {
- "thread_id": task_id,
- }
- }
-
- try:
- async for event in self.graph.astream(initial_state, config=run_config):
- if self.event_emitter:
- for node_name, node_state in event.items():
- await self.event_emitter.emit_info(
- f"节点 {node_name} 完成"
- )
-
- # 发射 LLM 路由决策事件
- if node_state.get("llm_routing_reason"):
- await self.event_emitter.emit_info(
- f"🧠 LLM 决策: {node_state.get('llm_next_action')} - {node_state.get('llm_routing_reason')}"
- )
-
- if node_name == "analysis" and node_state.get("findings"):
- new_findings = node_state["findings"]
- await self.event_emitter.emit_info(
- f"发现 {len(new_findings)} 个潜在漏洞"
- )
-
- final_state = self.graph.get_state(run_config)
- return final_state.values
-
- except Exception as e:
- logger.error(f"Graph execution failed: {e}", exc_info=True)
- raise
-
- async def run_with_human_review(
- self,
- initial_state: AuditState,
- human_feedback_callback,
- ) -> Dict[str, Any]:
- """带人机协作的执行"""
- run_config = {
- "configurable": {
- "thread_id": initial_state["task_id"],
- }
- }
-
- async for event in self.graph.astream(initial_state, config=run_config):
- pass
-
- current_state = self.graph.get_state(run_config)
-
- if current_state.next == ("human_review",):
- human_decision = await human_feedback_callback(current_state.values)
-
- updated_state = {
- **current_state.values,
- "should_continue_analysis": human_decision.get("continue_analysis", False),
- }
-
- async for event in self.graph.astream(updated_state, config=run_config):
- pass
-
- return self.graph.get_state(run_config).values
diff --git a/backend/app/services/agent/graph/nodes.py b/backend/app/services/agent/graph/nodes.py
deleted file mode 100644
index 0f8b034..0000000
--- a/backend/app/services/agent/graph/nodes.py
+++ /dev/null
@@ -1,556 +0,0 @@
-"""
-LangGraph 节点实现
-每个节点封装一个 Agent 的执行逻辑
-
-协作增强:节点之间通过 TaskHandoff 传递结构化的上下文和洞察
-"""
-
-from typing import Dict, Any, List, Optional
-import logging
-
-logger = logging.getLogger(__name__)
-
-# 延迟导入避免循环依赖
-def get_audit_state_type():
- from .audit_graph import AuditState
- return AuditState
-
-
-class BaseNode:
- """节点基类"""
-
- def __init__(self, agent=None, event_emitter=None):
- self.agent = agent
- self.event_emitter = event_emitter
-
- async def emit_event(self, event_type: str, message: str, **kwargs):
- """发射事件"""
- if self.event_emitter:
- try:
- await self.event_emitter.emit_info(message)
- except Exception as e:
- logger.warning(f"Failed to emit event: {e}")
-
- def _extract_handoff_from_state(self, state: Dict[str, Any], from_phase: str):
- """从状态中提取前序 Agent 的 handoff"""
- handoff_data = state.get(f"{from_phase}_handoff")
- if handoff_data:
- from ..agents.base import TaskHandoff
- return TaskHandoff.from_dict(handoff_data)
- return None
-
-
-class ReconNode(BaseNode):
- """
- 信息收集节点
-
- 输入: project_root, project_info, config
- 输出: tech_stack, entry_points, high_risk_areas, dependencies, recon_handoff
- """
-
- async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]:
- """执行信息收集"""
- await self.emit_event("phase_start", "🔍 开始信息收集阶段")
-
- try:
- # 调用 Recon Agent
- result = await self.agent.run({
- "project_info": state["project_info"],
- "config": state["config"],
- })
-
- if result.success and result.data:
- data = result.data
-
- # 🔥 创建交接信息给 Analysis Agent
- handoff = self.agent.create_handoff(
- to_agent="Analysis",
- summary=f"项目信息收集完成。发现 {len(data.get('entry_points', []))} 个入口点,{len(data.get('high_risk_areas', []))} 个高风险区域。",
- key_findings=data.get("initial_findings", []),
- suggested_actions=[
- {
- "type": "deep_analysis",
- "description": f"深入分析高风险区域: {', '.join(data.get('high_risk_areas', [])[:5])}",
- "priority": "high",
- },
- {
- "type": "entry_point_audit",
- "description": "审计所有入口点的输入验证",
- "priority": "high",
- },
- ],
- attention_points=[
- f"技术栈: {data.get('tech_stack', {}).get('frameworks', [])}",
- f"主要语言: {data.get('tech_stack', {}).get('languages', [])}",
- ],
- priority_areas=data.get("high_risk_areas", [])[:10],
- context_data={
- "tech_stack": data.get("tech_stack", {}),
- "entry_points": data.get("entry_points", []),
- "dependencies": data.get("dependencies", {}),
- },
- )
-
- await self.emit_event(
- "phase_complete",
- f"✅ 信息收集完成: 发现 {len(data.get('entry_points', []))} 个入口点"
- )
-
- return {
- "tech_stack": data.get("tech_stack", {}),
- "entry_points": data.get("entry_points", []),
- "high_risk_areas": data.get("high_risk_areas", []),
- "dependencies": data.get("dependencies", {}),
- "current_phase": "recon_complete",
- "findings": data.get("initial_findings", []),
- # 🔥 保存交接信息
- "recon_handoff": handoff.to_dict(),
- "events": [{
- "type": "recon_complete",
- "data": {
- "entry_points_count": len(data.get("entry_points", [])),
- "high_risk_areas_count": len(data.get("high_risk_areas", [])),
- "handoff_summary": handoff.summary,
- }
- }],
- }
- else:
- return {
- "error": result.error or "Recon failed",
- "current_phase": "error",
- }
-
- except Exception as e:
- logger.error(f"Recon node failed: {e}", exc_info=True)
- return {
- "error": str(e),
- "current_phase": "error",
- }
-
-
-class AnalysisNode(BaseNode):
- """
- 漏洞分析节点
-
- 输入: tech_stack, entry_points, high_risk_areas, recon_handoff
- 输出: findings (累加), should_continue_analysis, analysis_handoff
- """
-
- async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]:
- """执行漏洞分析"""
- iteration = state.get("iteration", 0) + 1
-
- await self.emit_event(
- "phase_start",
- f"🔬 开始漏洞分析阶段 (迭代 {iteration})"
- )
-
- try:
- # 🔥 提取 Recon 的交接信息
- recon_handoff = self._extract_handoff_from_state(state, "recon")
- if recon_handoff:
- self.agent.receive_handoff(recon_handoff)
- await self.emit_event(
- "handoff_received",
- f"📨 收到 Recon Agent 交接: {recon_handoff.summary[:50]}..."
- )
-
- # 构建分析输入
- analysis_input = {
- "phase_name": "analysis",
- "project_info": state["project_info"],
- "config": state["config"],
- "plan": {
- "high_risk_areas": state.get("high_risk_areas", []),
- },
- "previous_results": {
- "recon": {
- "data": {
- "tech_stack": state.get("tech_stack", {}),
- "entry_points": state.get("entry_points", []),
- "high_risk_areas": state.get("high_risk_areas", []),
- }
- }
- },
- # 🔥 传递交接信息
- "handoff": recon_handoff,
- }
-
- # 调用 Analysis Agent
- result = await self.agent.run(analysis_input)
-
- if result.success and result.data:
- new_findings = result.data.get("findings", [])
- logger.info(f"[AnalysisNode] Agent returned {len(new_findings)} findings")
-
- # 判断是否需要继续分析
- should_continue = (
- len(new_findings) >= 5 and
- iteration < state.get("max_iterations", 3)
- )
-
- # 🔥 创建交接信息给 Verification Agent
- # 统计严重程度
- severity_counts = {"critical": 0, "high": 0, "medium": 0, "low": 0}
- for f in new_findings:
- if isinstance(f, dict):
- sev = f.get("severity", "medium")
- severity_counts[sev] = severity_counts.get(sev, 0) + 1
-
- handoff = self.agent.create_handoff(
- to_agent="Verification",
- summary=f"漏洞分析完成。发现 {len(new_findings)} 个潜在漏洞 (Critical: {severity_counts['critical']}, High: {severity_counts['high']}, Medium: {severity_counts['medium']}, Low: {severity_counts['low']})",
- key_findings=new_findings[:20], # 传递前20个发现
- suggested_actions=[
- {
- "type": "verify_critical",
- "description": "优先验证 Critical 和 High 级别的漏洞",
- "priority": "critical",
- },
- {
- "type": "poc_generation",
- "description": "为确认的漏洞生成 PoC",
- "priority": "high",
- },
- ],
- attention_points=[
- f"共 {severity_counts['critical']} 个 Critical 级别漏洞需要立即验证",
- f"共 {severity_counts['high']} 个 High 级别漏洞需要优先验证",
- "注意检查是否有误报,特别是静态分析工具的结果",
- ],
- priority_areas=[
- f.get("file_path", "") for f in new_findings
- if f.get("severity") in ["critical", "high"]
- ][:10],
- context_data={
- "severity_distribution": severity_counts,
- "total_findings": len(new_findings),
- "iteration": iteration,
- },
- )
-
- await self.emit_event(
- "phase_complete",
- f"✅ 分析迭代 {iteration} 完成: 发现 {len(new_findings)} 个潜在漏洞"
- )
-
- return {
- "findings": new_findings,
- "iteration": iteration,
- "should_continue_analysis": should_continue,
- "current_phase": "analysis_complete",
- # 🔥 保存交接信息
- "analysis_handoff": handoff.to_dict(),
- "events": [{
- "type": "analysis_iteration",
- "data": {
- "iteration": iteration,
- "findings_count": len(new_findings),
- "severity_distribution": severity_counts,
- "handoff_summary": handoff.summary,
- }
- }],
- }
- else:
- return {
- "iteration": iteration,
- "should_continue_analysis": False,
- "current_phase": "analysis_complete",
- }
-
- except Exception as e:
- logger.error(f"Analysis node failed: {e}", exc_info=True)
- return {
- "error": str(e),
- "should_continue_analysis": False,
- "current_phase": "error",
- }
-
-
-class VerificationNode(BaseNode):
- """
- 漏洞验证节点
-
- 输入: findings, analysis_handoff
- 输出: verified_findings, false_positives, verification_handoff
- """
-
- async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]:
- """执行漏洞验证"""
- findings = state.get("findings", [])
- logger.info(f"[VerificationNode] Received {len(findings)} findings to verify")
-
- if not findings:
- return {
- "verified_findings": [],
- "false_positives": [],
- "current_phase": "verification_complete",
- }
-
- await self.emit_event(
- "phase_start",
- f"🔐 开始漏洞验证阶段 ({len(findings)} 个待验证)"
- )
-
- try:
- # 🔥 提取 Analysis 的交接信息
- analysis_handoff = self._extract_handoff_from_state(state, "analysis")
- if analysis_handoff:
- self.agent.receive_handoff(analysis_handoff)
- await self.emit_event(
- "handoff_received",
- f"📨 收到 Analysis Agent 交接: {analysis_handoff.summary[:50]}..."
- )
-
- # 构建验证输入
- verification_input = {
- "previous_results": {
- "analysis": {
- "data": {
- "findings": findings,
- }
- }
- },
- "config": state["config"],
- # 🔥 传递交接信息
- "handoff": analysis_handoff,
- }
-
- # 调用 Verification Agent
- result = await self.agent.run(verification_input)
-
- if result.success and result.data:
- all_verified_findings = result.data.get("findings", [])
- verified = [f for f in all_verified_findings if f.get("is_verified")]
- false_pos = [f.get("id", f.get("title", "unknown")) for f in all_verified_findings
- if f.get("verdict") == "false_positive"]
-
- # 🔥 CRITICAL FIX: 用验证结果更新原始 findings
- # 创建 findings 的更新映射,基于 (file_path, line_start, vulnerability_type)
- verified_map = {}
- for vf in all_verified_findings:
- key = (
- vf.get("file_path", ""),
- vf.get("line_start", 0),
- vf.get("vulnerability_type", ""),
- )
- verified_map[key] = vf
-
- # 合并验证结果到原始 findings
- updated_findings = []
- seen_keys = set()
-
- # 首先处理原始 findings,用验证结果更新
- for f in findings:
- if not isinstance(f, dict):
- continue
- key = (
- f.get("file_path", ""),
- f.get("line_start", 0),
- f.get("vulnerability_type", ""),
- )
- if key in verified_map:
- # 使用验证后的版本
- updated_findings.append(verified_map[key])
- seen_keys.add(key)
- else:
- # 保留原始(未验证)
- updated_findings.append(f)
- seen_keys.add(key)
-
- # 添加验证结果中的新发现(如果有)
- for key, vf in verified_map.items():
- if key not in seen_keys:
- updated_findings.append(vf)
-
- logger.info(f"[VerificationNode] Updated findings: {len(updated_findings)} total, {len(verified)} verified")
-
- # 🔥 创建交接信息给 Report 节点
- handoff = self.agent.create_handoff(
- to_agent="Report",
- summary=f"漏洞验证完成。{len(verified)} 个漏洞已确认,{len(false_pos)} 个误报已排除。",
- key_findings=verified,
- suggested_actions=[
- {
- "type": "generate_report",
- "description": "生成详细的安全审计报告",
- "priority": "high",
- },
- {
- "type": "remediation_plan",
- "description": "为确认的漏洞制定修复计划",
- "priority": "high",
- },
- ],
- attention_points=[
- f"共 {len(verified)} 个漏洞已确认存在",
- f"共 {len(false_pos)} 个误报已排除",
- "建议按严重程度优先修复 Critical 和 High 级别漏洞",
- ],
- context_data={
- "verified_count": len(verified),
- "false_positive_count": len(false_pos),
- "total_analyzed": len(findings),
- "verification_rate": len(verified) / len(findings) if findings else 0,
- },
- )
-
- await self.emit_event(
- "phase_complete",
- f"✅ 验证完成: {len(verified)} 已确认, {len(false_pos)} 误报"
- )
-
- return {
- # 🔥 CRITICAL: 返回更新后的 findings,这会替换状态中的 findings
- # 注意:由于 LangGraph 使用 operator.add,我们需要在 runner 中处理合并
- # 这里我们返回 _verified_findings_update 作为特殊字段
- "_verified_findings_update": updated_findings,
- "verified_findings": verified,
- "false_positives": false_pos,
- "current_phase": "verification_complete",
- # 🔥 保存交接信息
- "verification_handoff": handoff.to_dict(),
- "events": [{
- "type": "verification_complete",
- "data": {
- "verified_count": len(verified),
- "false_positive_count": len(false_pos),
- "total_findings": len(updated_findings),
- "handoff_summary": handoff.summary,
- }
- }],
- }
- else:
- return {
- "verified_findings": [],
- "false_positives": [],
- "current_phase": "verification_complete",
- }
-
- except Exception as e:
- logger.error(f"Verification node failed: {e}", exc_info=True)
- return {
- "error": str(e),
- "current_phase": "error",
- }
-
-
-class ReportNode(BaseNode):
- """
- 报告生成节点
-
- 输入: all state
- 输出: summary, security_score
- """
-
- async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]:
- """生成审计报告"""
- await self.emit_event("phase_start", "📊 生成审计报告")
-
- try:
- # 🔥 CRITICAL FIX: 优先使用验证后的 findings 更新
- findings = state.get("_verified_findings_update") or state.get("findings", [])
- verified = state.get("verified_findings", [])
- false_positives = state.get("false_positives", [])
-
- logger.info(f"[ReportNode] State contains {len(findings)} findings, {len(verified)} verified")
-
- # 统计漏洞分布
- severity_counts = {"critical": 0, "high": 0, "medium": 0, "low": 0}
- type_counts = {}
-
- for finding in findings:
- # 跳过非字典类型的 finding(防止数据格式异常)
- if not isinstance(finding, dict):
- logger.warning(f"Skipping invalid finding (not a dict): {type(finding)}")
- continue
-
- sev = finding.get("severity", "medium")
- severity_counts[sev] = severity_counts.get(sev, 0) + 1
-
- vtype = finding.get("vulnerability_type", "other")
- type_counts[vtype] = type_counts.get(vtype, 0) + 1
-
- # 计算安全评分
- base_score = 100
- deductions = (
- severity_counts["critical"] * 25 +
- severity_counts["high"] * 15 +
- severity_counts["medium"] * 8 +
- severity_counts["low"] * 3
- )
- security_score = max(0, base_score - deductions)
-
- # 生成摘要
- summary = {
- "total_findings": len(findings),
- "verified_count": len(verified),
- "false_positive_count": len(false_positives),
- "severity_distribution": severity_counts,
- "vulnerability_types": type_counts,
- "tech_stack": state.get("tech_stack", {}),
- "entry_points_analyzed": len(state.get("entry_points", [])),
- "high_risk_areas": state.get("high_risk_areas", []),
- "iterations": state.get("iteration", 1),
- }
-
- await self.emit_event(
- "phase_complete",
- f"报告生成完成: 安全评分 {security_score}/100"
- )
-
- return {
- "summary": summary,
- "security_score": security_score,
- "current_phase": "complete",
- "events": [{
- "type": "audit_complete",
- "data": {
- "security_score": security_score,
- "total_findings": len(findings),
- "verified_count": len(verified),
- }
- }],
- }
-
- except Exception as e:
- logger.error(f"Report node failed: {e}", exc_info=True)
- return {
- "error": str(e),
- "current_phase": "error",
- }
-
-
-class HumanReviewNode(BaseNode):
- """
- 人工审核节点
-
- 在此节点暂停,等待人工反馈
- """
-
- async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]:
- """
- 人工审核节点
-
- 这个节点会被 interrupt_before 暂停
- 用户可以:
- 1. 确认发现
- 2. 标记误报
- 3. 请求重新分析
- """
- await self.emit_event(
- "human_review",
- f"⏸️ 等待人工审核 ({len(state.get('verified_findings', []))} 个待确认)"
- )
-
- # 返回当前状态,不做修改
- # 人工反馈会通过 update_state 传入
- return {
- "current_phase": "human_review",
- "messages": [{
- "role": "system",
- "content": "等待人工审核",
- "findings_for_review": state.get("verified_findings", []),
- }],
- }
-
diff --git a/backend/app/services/agent/graph/runner.py b/backend/app/services/agent/graph/runner.py
deleted file mode 100644
index e0afd12..0000000
--- a/backend/app/services/agent/graph/runner.py
+++ /dev/null
@@ -1,1042 +0,0 @@
-"""
-DeepAudit LangGraph Runner
-基于 LangGraph 的 Agent 审计执行器
-"""
-
-import asyncio
-import logging
-import os
-import uuid
-from datetime import datetime, timezone
-from typing import Dict, List, Optional, Any, AsyncGenerator
-
-from sqlalchemy.ext.asyncio import AsyncSession
-
-from langgraph.graph import StateGraph, END
-from langgraph.checkpoint.memory import MemorySaver
-
-from app.services.agent.streaming import StreamHandler, StreamEvent, StreamEventType
-from app.models.agent_task import (
- AgentTask, AgentEvent, AgentFinding,
- AgentTaskStatus, AgentTaskPhase, AgentEventType,
- VulnerabilitySeverity, VulnerabilityType, FindingStatus,
-)
-from app.services.agent.event_manager import EventManager, AgentEventEmitter
-from app.services.agent.tools import (
- RAGQueryTool, SecurityCodeSearchTool, FunctionContextTool,
- PatternMatchTool, CodeAnalysisTool, DataFlowAnalysisTool, VulnerabilityValidationTool,
- FileReadTool, FileSearchTool, ListFilesTool,
- SandboxTool, SandboxHttpTool, VulnerabilityVerifyTool, SandboxManager,
- SemgrepTool, BanditTool, GitleaksTool, NpmAuditTool, SafetyTool,
- TruffleHogTool, OSVScannerTool,
-)
-from app.services.rag import CodeIndexer, CodeRetriever, EmbeddingService
-from app.core.config import settings
-
-from .audit_graph import AuditState, create_audit_graph
-from .nodes import ReconNode, AnalysisNode, VerificationNode, ReportNode
-
-logger = logging.getLogger(__name__)
-
-
-# 🔥 使用系统统一的 LLMService(支持用户配置)
-from app.services.llm.service import LLMService
-
-
-class AgentRunner:
- """
- DeepAudit LangGraph Agent Runner
-
- 基于 LangGraph 状态图的审计执行器
-
- 工作流:
- START → Recon → Analysis ⟲ → Verification → Report → END
- """
-
- def __init__(
- self,
- db: AsyncSession,
- task: AgentTask,
- project_root: str,
- user_config: Optional[Dict[str, Any]] = None,
- ):
- self.db = db
- self.task = task
- self.project_root = project_root
-
- # 🔥 保存用户配置,供 RAG 初始化使用
- self.user_config = user_config or {}
-
- # 事件管理 - 传入 db_session_factory 以持久化事件
- from app.db.session import async_session_factory
- self.event_manager = EventManager(db_session_factory=async_session_factory)
- self.event_emitter = AgentEventEmitter(task.id, self.event_manager)
-
- # 🔥 CRITICAL: 立即创建事件队列,确保在 Agent 开始执行前队列就存在
- # 这样即使前端 SSE 连接稍晚,token 事件也不会丢失
- self.event_manager.create_queue(task.id)
-
- # 🔥 LLM 服务 - 使用用户配置(从系统配置页面获取)
- self.llm_service = LLMService(user_config=self.user_config)
-
- # 工具集
- self.tools: Dict[str, Any] = {}
-
- # RAG 组件
- self.retriever: Optional[CodeRetriever] = None
- self.indexer: Optional[CodeIndexer] = None
-
- # 沙箱
- self.sandbox_manager: Optional[SandboxManager] = None
-
- # LangGraph
- self.graph: Optional[StateGraph] = None
- self.checkpointer = MemorySaver()
-
- # 状态
- self._cancelled = False
- self._running_task: Optional[asyncio.Task] = None
-
- # Agent 引用(用于取消传播)
- self._agents: List[Any] = []
-
- # 流式处理器
- self.stream_handler = StreamHandler(task.id)
-
- def cancel(self):
- """取消任务"""
- self._cancelled = True
-
- # 🔥 取消所有 Agent
- for agent in self._agents:
- if hasattr(agent, 'cancel'):
- agent.cancel()
- logger.debug(f"Cancelled agent: {agent.name if hasattr(agent, 'name') else 'unknown'}")
-
- # 取消运行中的任务
- if self._running_task and not self._running_task.done():
- self._running_task.cancel()
-
- logger.info(f"Task {self.task.id} cancellation requested")
-
- @property
- def is_cancelled(self) -> bool:
- """检查是否已取消"""
- return self._cancelled
-
- async def initialize(self):
- """初始化 Runner"""
- await self.event_emitter.emit_info("🚀 正在初始化 DeepAudit LangGraph Agent...")
-
- # 1. 初始化 RAG 系统
- await self._initialize_rag()
-
- # 2. 初始化工具
- await self._initialize_tools()
-
- # 3. 构建 LangGraph
- await self._build_graph()
-
- await self.event_emitter.emit_info("✅ LangGraph 系统初始化完成")
-
- async def _initialize_rag(self):
- """初始化 RAG 系统"""
- await self.event_emitter.emit_info("📚 初始化 RAG 代码检索系统...")
-
- try:
- # 🔥 从用户配置中获取配置
- # 优先级:用户嵌入配置 > 用户 LLM 配置 > 环境变量
- user_llm_config = self.user_config.get('llmConfig', {})
- user_other_config = self.user_config.get('otherConfig', {})
- user_embedding_config = user_other_config.get('embedding_config', {})
-
- # 🔥 Embedding Provider 优先级:用户嵌入配置 > 环境变量
- embedding_provider = (
- user_embedding_config.get('provider') or
- getattr(settings, 'EMBEDDING_PROVIDER', 'openai')
- )
-
- # 🔥 Embedding Model 优先级:用户嵌入配置 > 环境变量
- embedding_model = (
- user_embedding_config.get('model') or
- getattr(settings, 'EMBEDDING_MODEL', 'text-embedding-3-small')
- )
-
- # 🔥 API Key 优先级:用户嵌入配置 > 用户 LLM 配置 > 环境变量
- embedding_api_key = (
- user_embedding_config.get('api_key') or
- user_llm_config.get('llmApiKey') or
- getattr(settings, 'LLM_API_KEY', '') or
- ''
- )
-
- # 🔥 Base URL 优先级:用户嵌入配置 > 用户 LLM 配置 > 环境变量
- embedding_base_url = (
- user_embedding_config.get('base_url') or
- user_llm_config.get('llmBaseUrl') or
- getattr(settings, 'LLM_BASE_URL', None) or
- None
- )
-
- logger.info(f"RAG 配置: provider={embedding_provider}, model={embedding_model}")
- await self.event_emitter.emit_info(f"嵌入模型: {embedding_provider}/{embedding_model}")
-
- embedding_service = EmbeddingService(
- provider=embedding_provider,
- model=embedding_model,
- api_key=embedding_api_key,
- base_url=embedding_base_url,
- )
-
- self.indexer = CodeIndexer(
- collection_name=f"project_{self.task.project_id}",
- embedding_service=embedding_service,
- persist_directory=settings.VECTOR_DB_PATH,
- )
-
- self.retriever = CodeRetriever(
- collection_name=f"project_{self.task.project_id}",
- embedding_service=embedding_service,
- persist_directory=settings.VECTOR_DB_PATH,
- )
-
- except Exception as e:
- logger.warning(f"RAG initialization failed: {e}")
- await self.event_emitter.emit_warning(f"RAG 系统初始化失败: {e}")
-
- async def _initialize_tools(self):
- """初始化工具集"""
- await self.event_emitter.emit_info("初始化 Agent 工具集...")
-
- # 🔥 导入新工具
- from app.services.agent.tools import (
- ThinkTool, ReflectTool,
- CreateVulnerabilityReportTool,
- # 多语言代码测试工具
- PhpTestTool, PythonTestTool, JavaScriptTestTool, JavaTestTool,
- GoTestTool, RubyTestTool, ShellTestTool, UniversalCodeTestTool,
- # 漏洞验证专用工具
- CommandInjectionTestTool, SqlInjectionTestTool, XssTestTool,
- PathTraversalTestTool, SstiTestTool, DeserializationTestTool,
- UniversalVulnTestTool,
- # Kunlun-M 静态代码分析工具 (MIT License)
- KunlunMTool, KunlunRuleListTool, KunlunPluginTool,
- )
- # 🔥 导入知识查询工具
- from app.services.agent.knowledge import (
- SecurityKnowledgeQueryTool,
- GetVulnerabilityKnowledgeTool,
- )
-
- # 🔥 获取排除模式和目标文件
- exclude_patterns = self.task.exclude_patterns or []
- target_files = self.task.target_files or None
-
- # ============ 🔥 提前初始化 SandboxManager(供所有外部工具共享)============
- self.sandbox_manager = None
- try:
- from app.services.agent.tools.sandbox_tool import SandboxConfig
- sandbox_config = SandboxConfig(
- image=settings.SANDBOX_IMAGE,
- memory_limit=settings.SANDBOX_MEMORY_LIMIT,
- cpu_limit=settings.SANDBOX_CPU_LIMIT,
- timeout=settings.SANDBOX_TIMEOUT,
- network_mode=settings.SANDBOX_NETWORK_MODE,
- )
- self.sandbox_manager = SandboxManager(config=sandbox_config)
- # 🔥 必须调用 initialize() 来连接 Docker
- await self.sandbox_manager.initialize()
- logger.info(f"✅ SandboxManager initialized early (Docker available: {self.sandbox_manager.is_available})")
- except Exception as e:
- logger.warning(f"❌ Early Sandbox Manager initialization failed: {e}")
- import traceback
- logger.warning(f"Traceback: {traceback.format_exc()}")
- # 尝试创建默认管理器作为后备
- try:
- self.sandbox_manager = SandboxManager()
- await self.sandbox_manager.initialize()
- logger.info(f"⚠️ Created fallback SandboxManager (Docker available: {self.sandbox_manager.is_available})")
- except Exception as e2:
- logger.error(f"❌ Failed to create fallback SandboxManager: {e2}")
-
- # ============ 基础工具(所有 Agent 共享)============
- base_tools = {
- "read_file": FileReadTool(self.project_root, exclude_patterns, target_files),
- "list_files": ListFilesTool(self.project_root, exclude_patterns, target_files),
- # 🔥 新增:思考工具(所有Agent可用)
- "think": ThinkTool(),
- }
-
- # ============ Recon Agent 专属工具 ============
- # 职责:信息收集、项目结构分析、技术栈识别
- # 🔥 新增:外部工具也可用于Recon阶段的快速扫描
- self.recon_tools = {
- **base_tools,
- "search_code": FileSearchTool(self.project_root, exclude_patterns, target_files),
- # 🔥 新增:反思工具
- "reflect": ReflectTool(),
- # 🔥 外部安全工具(共享 SandboxManager 实例)
- "semgrep_scan": SemgrepTool(self.project_root, self.sandbox_manager),
- "bandit_scan": BanditTool(self.project_root, self.sandbox_manager),
- "gitleaks_scan": GitleaksTool(self.project_root, self.sandbox_manager),
- "safety_scan": SafetyTool(self.project_root, self.sandbox_manager),
- "npm_audit": NpmAuditTool(self.project_root, self.sandbox_manager),
- }
-
- # RAG 工具(Recon 用于语义搜索)
- if self.retriever:
- self.recon_tools["rag_query"] = RAGQueryTool(self.retriever)
- logger.info("✅ RAG 工具已注册到 Recon Agent")
- else:
- logger.warning("⚠️ RAG 未初始化,rag_query 工具不可用")
-
- # ============ Analysis Agent 专属工具 ============
- # 职责:漏洞分析、代码审计、模式匹配
- self.analysis_tools = {
- **base_tools,
- "search_code": FileSearchTool(self.project_root, exclude_patterns, target_files),
- # 模式匹配和代码分析
- "pattern_match": PatternMatchTool(self.project_root),
- # TODO: code_analysis 工具暂时禁用,因为 LLM 调用经常失败
- # "code_analysis": CodeAnalysisTool(self.llm_service),
- "dataflow_analysis": DataFlowAnalysisTool(self.llm_service),
- # 🔥 外部静态分析工具(共享 SandboxManager 实例)
- "semgrep_scan": SemgrepTool(self.project_root, self.sandbox_manager),
- "bandit_scan": BanditTool(self.project_root, self.sandbox_manager),
- "gitleaks_scan": GitleaksTool(self.project_root, self.sandbox_manager),
- "trufflehog_scan": TruffleHogTool(self.project_root, self.sandbox_manager),
- "npm_audit": NpmAuditTool(self.project_root, self.sandbox_manager),
- "safety_scan": SafetyTool(self.project_root, self.sandbox_manager),
- "osv_scan": OSVScannerTool(self.project_root, self.sandbox_manager),
- # 🔥 Kunlun-M 静态代码分析工具 (MIT License - https://github.com/LoRexxar/Kunlun-M)
- "kunlun_scan": KunlunMTool(self.project_root),
- "kunlun_list_rules": KunlunRuleListTool(self.project_root),
- "kunlun_plugin": KunlunPluginTool(self.project_root),
- # 🔥 新增:反思工具
- "reflect": ReflectTool(),
- # 🔥 新增:安全知识查询工具(基于RAG)
- "query_security_knowledge": SecurityKnowledgeQueryTool(),
- "get_vulnerability_knowledge": GetVulnerabilityKnowledgeTool(),
- }
-
- # RAG 工具(Analysis 用于安全相关代码搜索)
- if self.retriever:
- self.analysis_tools["rag_query"] = RAGQueryTool(self.retriever) # 通用语义搜索
- self.analysis_tools["security_search"] = SecurityCodeSearchTool(self.retriever) # 安全代码搜索
- self.analysis_tools["function_context"] = FunctionContextTool(self.retriever) # 函数上下文
- logger.info("✅ RAG 工具已注册到 Analysis Agent (rag_query, security_search, function_context)")
-
- # ============ Verification Agent 专属工具 ============
- # 职责:漏洞验证、PoC 执行、误报排除
- self.verification_tools = {
- **base_tools,
- # 验证工具 - 移除旧的 vulnerability_validation 和 dataflow_analysis,强制使用沙箱
- # 🔥 新增:漏洞报告工具(仅Verification可用)- v2.1: 传递 project_root
- "create_vulnerability_report": CreateVulnerabilityReportTool(self.project_root),
- # 🔥 新增:反思工具
- "reflect": ReflectTool(),
- }
-
- # 🔥 注册沙箱工具(使用提前初始化的 SandboxManager)
- if self.sandbox_manager:
- # 🔥 沙箱核心工具
- self.verification_tools["sandbox_exec"] = SandboxTool(self.sandbox_manager)
- self.verification_tools["sandbox_http"] = SandboxHttpTool(self.sandbox_manager)
- self.verification_tools["verify_vulnerability"] = VulnerabilityVerifyTool(self.sandbox_manager)
-
- # 🔥 多语言代码测试工具
- self.verification_tools["php_test"] = PhpTestTool(self.sandbox_manager, self.project_root)
- self.verification_tools["python_test"] = PythonTestTool(self.sandbox_manager, self.project_root)
- self.verification_tools["javascript_test"] = JavaScriptTestTool(self.sandbox_manager, self.project_root)
- self.verification_tools["java_test"] = JavaTestTool(self.sandbox_manager, self.project_root)
- self.verification_tools["go_test"] = GoTestTool(self.sandbox_manager, self.project_root)
- self.verification_tools["ruby_test"] = RubyTestTool(self.sandbox_manager, self.project_root)
- self.verification_tools["shell_test"] = ShellTestTool(self.sandbox_manager, self.project_root)
- self.verification_tools["universal_code_test"] = UniversalCodeTestTool(self.sandbox_manager, self.project_root)
-
- # 🔥 漏洞验证专用工具
- self.verification_tools["test_command_injection"] = CommandInjectionTestTool(self.sandbox_manager, self.project_root)
- self.verification_tools["test_sql_injection"] = SqlInjectionTestTool(self.sandbox_manager, self.project_root)
- self.verification_tools["test_xss"] = XssTestTool(self.sandbox_manager, self.project_root)
- self.verification_tools["test_path_traversal"] = PathTraversalTestTool(self.sandbox_manager, self.project_root)
- self.verification_tools["test_ssti"] = SstiTestTool(self.sandbox_manager, self.project_root)
- self.verification_tools["test_deserialization"] = DeserializationTestTool(self.sandbox_manager, self.project_root)
- self.verification_tools["universal_vuln_test"] = UniversalVulnTestTool(self.sandbox_manager, self.project_root)
-
- logger.info(f"✅ Sandbox tools initialized (Docker available: {self.sandbox_manager.is_available})")
- else:
- logger.error("❌ Sandbox tools NOT initialized due to critical manager failure")
-
- logger.info(f"✅ Verification tools: {list(self.verification_tools.keys())}")
-
- # 统计总工具数
- total_tools = len(set(
- list(self.recon_tools.keys()) +
- list(self.analysis_tools.keys()) +
- list(self.verification_tools.keys())
- ))
- await self.event_emitter.emit_info(f"已加载 {total_tools} 个工具")
-
- async def _build_graph(self):
- """构建 LangGraph 审计图"""
- await self.event_emitter.emit_info("📊 构建 LangGraph 审计工作流...")
-
- # 导入 Agent
- from app.services.agent.agents import ReconAgent, AnalysisAgent, VerificationAgent
-
- # 创建 Agent 实例(每个 Agent 使用专属工具集)
- recon_agent = ReconAgent(
- llm_service=self.llm_service,
- tools=self.recon_tools, # Recon 专属工具
- event_emitter=self.event_emitter,
- )
-
- analysis_agent = AnalysisAgent(
- llm_service=self.llm_service,
- tools=self.analysis_tools, # Analysis 专属工具
- event_emitter=self.event_emitter,
- )
-
- verification_agent = VerificationAgent(
- llm_service=self.llm_service,
- tools=self.verification_tools, # Verification 专属工具
- event_emitter=self.event_emitter,
- )
-
- # 🔥 保存 Agent 引用以便取消时传播信号
- self._agents = [recon_agent, analysis_agent, verification_agent]
-
- # 创建节点
- recon_node = ReconNode(recon_agent, self.event_emitter)
- analysis_node = AnalysisNode(analysis_agent, self.event_emitter)
- verification_node = VerificationNode(verification_agent, self.event_emitter)
- report_node = ReportNode(None, self.event_emitter)
-
- # 构建图
- self.graph = create_audit_graph(
- recon_node=recon_node,
- analysis_node=analysis_node,
- verification_node=verification_node,
- report_node=report_node,
- checkpointer=self.checkpointer,
- )
-
- await self.event_emitter.emit_info("✅ LangGraph 工作流构建完成")
-
- async def run(self) -> Dict[str, Any]:
- """
- 执行 LangGraph 审计
-
- Returns:
- 最终状态
- """
- final_state = {}
- try:
- async for event in self.run_with_streaming():
- # 收集最终状态
- if event.event_type == StreamEventType.TASK_COMPLETE:
- final_state = event.data
- elif event.event_type == StreamEventType.TASK_ERROR:
- final_state = {"success": False, "error": event.data.get("error")}
- except Exception as e:
- logger.error(f"Agent run failed: {e}", exc_info=True)
- final_state = {"success": False, "error": str(e)}
-
- return final_state
-
- async def run_with_streaming(self) -> AsyncGenerator[StreamEvent, None]:
- """
- 带流式输出的审计执行
-
- Yields:
- StreamEvent: 流式事件(包含 LLM 思考、工具调用等)
- """
- import time
- start_time = time.time()
-
- try:
- # 初始化
- await self.initialize()
-
- # 更新任务状态
- await self._update_task_status(AgentTaskStatus.RUNNING)
-
- # 发射任务开始事件
- yield StreamEvent(
- event_type=StreamEventType.TASK_START,
- sequence=self.stream_handler._next_sequence(),
- data={"task_id": self.task.id, "message": "🚀 审计任务开始"},
- )
-
- # 1. 索引代码
- await self._index_code()
-
- if self._cancelled:
- yield StreamEvent(
- event_type=StreamEventType.TASK_CANCEL,
- sequence=self.stream_handler._next_sequence(),
- data={"message": "任务已取消"},
- )
- return
-
- # 2. 收集项目信息
- project_info = await self._collect_project_info()
-
- # 3. 构建初始状态
- task_config = {
- "target_vulnerabilities": self.task.target_vulnerabilities or [],
- "verification_level": self.task.verification_level or "sandbox",
- "exclude_patterns": self.task.exclude_patterns or [],
- "target_files": self.task.target_files or [],
- "max_iterations": self.task.max_iterations or 50,
- "timeout_seconds": self.task.timeout_seconds or 1800,
- }
-
- initial_state: AuditState = {
- "project_root": self.project_root,
- "project_info": project_info,
- "config": task_config,
- "task_id": self.task.id,
- "tech_stack": {},
- "entry_points": [],
- "high_risk_areas": [],
- "dependencies": {},
- "findings": [],
- "verified_findings": [],
- "false_positives": [],
- "_verified_findings_update": None, # 🔥 NEW: 验证后的 findings 更新
- "current_phase": "start",
- "iteration": 0,
- "max_iterations": self.task.max_iterations or 50,
- "should_continue_analysis": False,
- # 🔥 Agent 协作交接信息
- "recon_handoff": None,
- "analysis_handoff": None,
- "verification_handoff": None,
- "messages": [],
- "events": [],
- "summary": None,
- "security_score": None,
- "error": None,
- }
-
- # 4. 执行 LangGraph with astream_events
- await self.event_emitter.emit_phase_start("langgraph", "🔄 启动 LangGraph 工作流")
-
- run_config = {
- "configurable": {
- "thread_id": self.task.id,
- }
- }
-
- final_state = None
-
- # 使用 astream_events 获取详细事件流
- try:
- async for event in self.graph.astream_events(
- initial_state,
- config=run_config,
- version="v2",
- ):
- if self._cancelled:
- break
-
- # 处理 LangGraph 事件
- stream_event = await self.stream_handler.process_langgraph_event(event)
- if stream_event:
- # 同步到 event_emitter 以持久化
- await self._sync_stream_event_to_db(stream_event)
- yield stream_event
-
- # 更新最终状态
- if event.get("event") == "on_chain_end":
- output = event.get("data", {}).get("output")
- if isinstance(output, dict):
- final_state = output
-
- except Exception as e:
- # 如果 astream_events 不可用,回退到 astream
- logger.warning(f"astream_events not available, falling back to astream: {e}")
- async for event in self.graph.astream(initial_state, config=run_config):
- if self._cancelled:
- break
-
- for node_name, node_output in event.items():
- await self._handle_node_output(node_name, node_output)
-
- # 发射节点事件
- yield StreamEvent(
- event_type=StreamEventType.NODE_END,
- sequence=self.stream_handler._next_sequence(),
- node_name=node_name,
- data={"message": f"节点 {node_name} 完成"},
- )
-
- phase_map = {
- "recon": AgentTaskPhase.RECONNAISSANCE,
- "analysis": AgentTaskPhase.ANALYSIS,
- "verification": AgentTaskPhase.VERIFICATION,
- "report": AgentTaskPhase.REPORTING,
- }
- if node_name in phase_map:
- await self._update_task_phase(phase_map[node_name])
-
- final_state = node_output
-
- # 5. 获取最终状态
- # 🔥 CRITICAL FIX: 始终从 graph 获取完整的累积状态
- # 因为每个节点只返回自己的输出,findings 等字段是通过 operator.add 累积的
- # 直接使用 node_output 会丢失之前节点累积的 findings
- graph_state = self.graph.get_state(run_config)
- if graph_state and graph_state.values:
- # 合并完整状态和最后节点的输出
- full_state = graph_state.values
- if final_state:
- # 保留最后节点的输出(如 summary, security_score)
- full_state = {**full_state, **final_state}
- final_state = full_state
- logger.info(f"[Runner] Got full state from graph with {len(final_state.get('findings', []))} findings")
- elif not final_state:
- final_state = {}
- logger.warning("[Runner] No final state available from graph")
-
- # 🔥 CRITICAL FIX: 如果有验证后的 findings 更新,使用它替换原始 findings
- # 这是因为 LangGraph 的 operator.add 累积器不适合更新已有 findings
- verified_findings_update = final_state.get("_verified_findings_update")
- if verified_findings_update:
- logger.info(f"[Runner] Using verified findings update: {len(verified_findings_update)} findings")
- final_state["findings"] = verified_findings_update
- else:
- # 🔥 FALLBACK: 如果没有 _verified_findings_update,尝试从 verified_findings 合并
- findings = final_state.get("findings", [])
- verified_findings = final_state.get("verified_findings", [])
-
- if verified_findings and findings:
- # 创建合并后的 findings 列表
- merged_findings = self._merge_findings_with_verification(findings, verified_findings)
- final_state["findings"] = merged_findings
- logger.info(f"[Runner] Merged findings: {len(merged_findings)} total")
- elif verified_findings and not findings:
- # 如果只有 verified_findings,直接使用
- final_state["findings"] = verified_findings
- logger.info(f"[Runner] Using verified_findings directly: {len(verified_findings)}")
-
- logger.info(f"[Runner] Final findings count: {len(final_state.get('findings', []))}")
-
- # 🔥 检查是否有错误
- error = final_state.get("error")
- if error:
- # 检查是否是 LLM 认证错误
- error_str = str(error)
- if "AuthenticationError" in error_str or "API key" in error_str or "invalid_api_key" in error_str:
- error_message = "LLM API 密钥配置错误。请检查环境变量 LLM_API_KEY 或配置中的 API 密钥是否正确。"
- logger.error(f"LLM authentication error: {error}")
- else:
- error_message = error_str
-
- duration_ms = int((time.time() - start_time) * 1000)
-
- # 标记任务为失败
- await self._update_task_status(AgentTaskStatus.FAILED, error_message)
- await self.event_emitter.emit_task_error(error_message)
-
- yield StreamEvent(
- event_type=StreamEventType.TASK_ERROR,
- sequence=self.stream_handler._next_sequence(),
- data={
- "error": error_message,
- "message": f"❌ 任务失败: {error_message}",
- },
- )
- return
-
- # 6. 保存发现
- findings = final_state.get("findings", [])
- await self._save_findings(findings)
-
- # 发射发现事件
- for finding in findings[:10]: # 限制数量
- yield self.stream_handler.create_finding_event(
- finding,
- is_verified=finding.get("is_verified", False),
- )
-
- # 7. 更新任务摘要
- summary = final_state.get("summary", {})
- security_score = final_state.get("security_score", 100)
-
- await self._update_task_summary(
- total_findings=len(findings),
- verified_count=len(final_state.get("verified_findings", [])),
- security_score=security_score,
- )
-
- # 8. 完成
- duration_ms = int((time.time() - start_time) * 1000)
-
- await self._update_task_status(AgentTaskStatus.COMPLETED)
- await self.event_emitter.emit_task_complete(
- findings_count=len(findings),
- duration_ms=duration_ms,
- )
-
- yield StreamEvent(
- event_type=StreamEventType.TASK_COMPLETE,
- sequence=self.stream_handler._next_sequence(),
- data={
- "findings_count": len(findings),
- "verified_count": len(final_state.get("verified_findings", [])),
- "security_score": security_score,
- "duration_ms": duration_ms,
- "message": f"✅ 审计完成!发现 {len(findings)} 个漏洞",
- },
- )
-
- except asyncio.CancelledError:
- await self._update_task_status(AgentTaskStatus.CANCELLED)
- yield StreamEvent(
- event_type=StreamEventType.TASK_CANCEL,
- sequence=self.stream_handler._next_sequence(),
- data={"message": "任务已取消"},
- )
-
- except Exception as e:
- logger.error(f"LangGraph run failed: {e}", exc_info=True)
- await self._update_task_status(AgentTaskStatus.FAILED, str(e))
- await self.event_emitter.emit_error(str(e))
-
- yield StreamEvent(
- event_type=StreamEventType.TASK_ERROR,
- sequence=self.stream_handler._next_sequence(),
- data={"error": str(e), "message": f"❌ 审计失败: {e}"},
- )
-
- finally:
- await self._cleanup()
-
- async def _sync_stream_event_to_db(self, event: StreamEvent):
- """同步流式事件到数据库"""
- try:
- # 将 StreamEvent 转换为 AgentEventData
- await self.event_manager.add_event(
- task_id=self.task.id,
- event_type=event.event_type.value,
- sequence=event.sequence,
- phase=event.phase,
- message=event.data.get("message"),
- tool_name=event.tool_name,
- tool_input=event.data.get("input") or event.data.get("input_params"),
- tool_output=event.data.get("output") or event.data.get("output_data"),
- tool_duration_ms=event.data.get("duration_ms"),
- metadata=event.data,
- )
- except Exception as e:
- logger.warning(f"Failed to sync stream event to db: {e}")
-
- async def _handle_node_output(self, node_name: str, output: Dict[str, Any]):
- """处理节点输出"""
- # 发射节点事件
- events = output.get("events", [])
- for evt in events:
- await self.event_emitter.emit_info(
- f"[{node_name}] {evt.get('type', 'event')}: {evt.get('data', {})}"
- )
-
- # 处理新发现
- if node_name == "analysis":
- new_findings = output.get("findings", [])
- if new_findings:
- for finding in new_findings[:5]: # 限制事件数量
- await self.event_emitter.emit_finding(
- title=finding.get("title", "Unknown"),
- severity=finding.get("severity", "medium"),
- file_path=finding.get("file_path"),
- )
-
- # 处理验证结果
- if node_name == "verification":
- verified = output.get("verified_findings", [])
- for v in verified[:5]:
- await self.event_emitter.emit_info(
- f"✅ 已验证: {v.get('title', 'Unknown')}"
- )
-
- # 处理错误
- if output.get("error"):
- await self.event_emitter.emit_error(output["error"])
-
- async def _index_code(self):
- """索引代码"""
- if not self.indexer:
- await self.event_emitter.emit_warning("RAG 未初始化,跳过代码索引")
- return
-
- await self._update_task_phase(AgentTaskPhase.INDEXING)
- await self.event_emitter.emit_phase_start("indexing", "📝 开始代码索引")
-
- try:
- async for progress in self.indexer.index_directory(self.project_root):
- if self._cancelled:
- return
-
- await self.event_emitter.emit_progress(
- progress.processed_files,
- progress.total_files,
- f"正在索引: {progress.current_file or 'N/A'}"
- )
-
- await self.event_emitter.emit_phase_complete("indexing", "✅ 代码索引完成")
-
- except Exception as e:
- logger.warning(f"Code indexing failed: {e}")
- await self.event_emitter.emit_warning(f"代码索引失败: {e}")
-
- async def _collect_project_info(self) -> Dict[str, Any]:
- """收集项目信息"""
- info = {
- "name": self.task.project.name if self.task.project else "unknown",
- "root": self.project_root,
- "languages": [],
- "file_count": 0,
- }
-
- try:
- exclude_dirs = {
- "node_modules", "__pycache__", ".git", "venv", ".venv",
- "build", "dist", "target", ".idea", ".vscode",
- }
-
- for root, dirs, files in os.walk(self.project_root):
- dirs[:] = [d for d in dirs if d not in exclude_dirs]
- info["file_count"] += len(files)
-
- lang_map = {
- ".py": "Python", ".js": "JavaScript", ".ts": "TypeScript",
- ".java": "Java", ".go": "Go", ".php": "PHP",
- ".rb": "Ruby", ".rs": "Rust", ".c": "C", ".cpp": "C++",
- }
-
- for f in files:
- ext = os.path.splitext(f)[1].lower()
- if ext in lang_map and lang_map[ext] not in info["languages"]:
- info["languages"].append(lang_map[ext])
-
- except Exception as e:
- logger.warning(f"Failed to collect project info: {e}")
-
- return info
-
- async def _save_findings(self, findings: List[Dict]):
- """保存发现到数据库"""
- logger.info(f"[Runner] Saving {len(findings)} findings to database for task {self.task.id}")
-
- if not findings:
- logger.info("[Runner] No findings to save")
- return
-
- severity_map = {
- "critical": VulnerabilitySeverity.CRITICAL,
- "high": VulnerabilitySeverity.HIGH,
- "medium": VulnerabilitySeverity.MEDIUM,
- "low": VulnerabilitySeverity.LOW,
- "info": VulnerabilitySeverity.INFO,
- }
-
- type_map = {
- "sql_injection": VulnerabilityType.SQL_INJECTION,
- "nosql_injection": VulnerabilityType.NOSQL_INJECTION,
- "xss": VulnerabilityType.XSS,
- "command_injection": VulnerabilityType.COMMAND_INJECTION,
- "code_injection": VulnerabilityType.CODE_INJECTION,
- "path_traversal": VulnerabilityType.PATH_TRAVERSAL,
- "file_inclusion": VulnerabilityType.FILE_INCLUSION,
- "ssrf": VulnerabilityType.SSRF,
- "xxe": VulnerabilityType.XXE,
- "deserialization": VulnerabilityType.DESERIALIZATION,
- "auth_bypass": VulnerabilityType.AUTH_BYPASS,
- "idor": VulnerabilityType.IDOR,
- "sensitive_data_exposure": VulnerabilityType.SENSITIVE_DATA_EXPOSURE,
- "hardcoded_secret": VulnerabilityType.HARDCODED_SECRET,
- "weak_crypto": VulnerabilityType.WEAK_CRYPTO,
- "race_condition": VulnerabilityType.RACE_CONDITION,
- "business_logic": VulnerabilityType.BUSINESS_LOGIC,
- "memory_corruption": VulnerabilityType.MEMORY_CORRUPTION,
- }
-
- for finding in findings:
- try:
- # 确保 finding 是字典
- if not isinstance(finding, dict):
- logger.warning(f"Skipping invalid finding (not a dict): {finding}")
- continue
-
- db_finding = AgentFinding(
- id=str(uuid.uuid4()),
- task_id=self.task.id,
- vulnerability_type=type_map.get(
- finding.get("vulnerability_type", "other"),
- VulnerabilityType.OTHER
- ),
- severity=severity_map.get(
- finding.get("severity", "medium"),
- VulnerabilitySeverity.MEDIUM
- ),
- title=finding.get("title", "Unknown"),
- description=finding.get("description", ""),
- file_path=finding.get("file_path"),
- line_start=finding.get("line_start"),
- line_end=finding.get("line_end"),
- code_snippet=finding.get("code_snippet"),
- source=finding.get("source"),
- sink=finding.get("sink"),
- suggestion=finding.get("suggestion") or finding.get("recommendation"),
- is_verified=finding.get("is_verified", False),
- confidence=finding.get("confidence", 0.5),
- poc=finding.get("poc"),
- status=FindingStatus.VERIFIED if finding.get("is_verified") else FindingStatus.NEW,
- )
-
- self.db.add(db_finding)
-
- except Exception as e:
- logger.warning(f"Failed to save finding: {e}")
-
- try:
- await self.db.commit()
- logger.info(f"[Runner] Successfully saved {len(findings)} findings to database")
- except Exception as e:
- logger.error(f"Failed to commit findings: {e}")
- await self.db.rollback()
-
- async def _update_task_status(
- self,
- status: AgentTaskStatus,
- error: Optional[str] = None
- ):
- """更新任务状态"""
- self.task.status = status
-
- if status == AgentTaskStatus.RUNNING:
- self.task.started_at = datetime.now(timezone.utc)
- elif status in [AgentTaskStatus.COMPLETED, AgentTaskStatus.FAILED, AgentTaskStatus.CANCELLED]:
- self.task.finished_at = datetime.now(timezone.utc)
-
- if error:
- self.task.error_message = error
-
- try:
- await self.db.commit()
- except Exception as e:
- logger.error(f"Failed to update task status: {e}")
-
- async def _update_task_phase(self, phase: AgentTaskPhase):
- """更新任务阶段"""
- self.task.current_phase = phase
- try:
- await self.db.commit()
- except Exception as e:
- logger.error(f"Failed to update task phase: {e}")
-
- async def _update_task_summary(
- self,
- total_findings: int,
- verified_count: int,
- security_score: int,
- ):
- """更新任务摘要"""
- self.task.total_findings = total_findings
- self.task.verified_findings = verified_count
- self.task.security_score = security_score
-
- try:
- await self.db.commit()
- except Exception as e:
- logger.error(f"Failed to update task summary: {e}")
-
- def _merge_findings_with_verification(
- self,
- findings: List[Dict],
- verified_findings: List[Dict],
- ) -> List[Dict]:
- """
- 合并原始 findings 和验证结果
-
- Args:
- findings: 原始 findings 列表
- verified_findings: 验证后的 findings 列表
-
- Returns:
- 合并后的 findings 列表
- """
- # 创建验证结果的查找映射
- verified_map = {}
- for vf in verified_findings:
- if not isinstance(vf, dict):
- continue
- key = (
- vf.get("file_path", ""),
- vf.get("line_start", 0),
- vf.get("vulnerability_type", ""),
- )
- verified_map[key] = vf
-
- merged = []
- seen_keys = set()
-
- # 首先处理原始 findings
- for f in findings:
- if not isinstance(f, dict):
- continue
-
- key = (
- f.get("file_path", ""),
- f.get("line_start", 0),
- f.get("vulnerability_type", ""),
- )
-
- if key in verified_map:
- # 使用验证后的版本(包含 is_verified, poc 等)
- merged.append(verified_map[key])
- else:
- # 保留原始 finding
- merged.append(f)
-
- seen_keys.add(key)
-
- # 添加验证结果中的新发现(如果有)
- for key, vf in verified_map.items():
- if key not in seen_keys:
- merged.append(vf)
-
- return merged
-
- async def _cleanup(self):
- """清理资源"""
- try:
- if self.sandbox_manager:
- await self.sandbox_manager.cleanup()
- await self.event_manager.close()
- except Exception as e:
- logger.warning(f"Cleanup error: {e}")
-
-
-# 便捷函数
-async def run_agent_task(
- db: AsyncSession,
- task: AgentTask,
- project_root: str,
-) -> Dict[str, Any]:
- """
- 运行 Agent 审计任务
-
- Args:
- db: 数据库会话
- task: Agent 任务
- project_root: 项目根目录
-
- Returns:
- 审计结果
- """
- runner = AgentRunner(db, task, project_root)
- return await runner.run()
-
diff --git a/backend/app/services/agent/tools/file_tool.py b/backend/app/services/agent/tools/file_tool.py
index f8b1d49..009cf44 100644
--- a/backend/app/services/agent/tools/file_tool.py
+++ b/backend/app/services/agent/tools/file_tool.py
@@ -6,6 +6,7 @@
import os
import re
import fnmatch
+import asyncio
from typing import Optional, List, Dict, Any
from pydantic import BaseModel, Field
@@ -44,7 +45,37 @@ class FileReadTool(AgentTool):
self.project_root = project_root
self.exclude_patterns = exclude_patterns or []
self.target_files = set(target_files) if target_files else None
-
+
+ @staticmethod
+ def _read_file_lines_sync(file_path: str, start_idx: int, end_idx: int) -> tuple:
+ """同步读取文件指定行范围(用于 asyncio.to_thread)"""
+ selected_lines = []
+ total_lines = 0
+ file_size = os.path.getsize(file_path)
+
+ with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
+ for i, line in enumerate(f):
+ total_lines = i + 1
+ if i >= start_idx and i < end_idx:
+ selected_lines.append(line)
+ elif i >= end_idx:
+ if i < end_idx + 1000:
+ continue
+ else:
+ remaining_bytes = file_size - f.tell()
+ avg_line_size = f.tell() / (i + 1)
+ estimated_remaining_lines = int(remaining_bytes / avg_line_size) if avg_line_size > 0 else 0
+ total_lines = i + 1 + estimated_remaining_lines
+ break
+
+ return selected_lines, total_lines
+
+ @staticmethod
+ def _read_all_lines_sync(file_path: str) -> List[str]:
+ """同步读取文件所有行(用于 asyncio.to_thread)"""
+ with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
+ return f.readlines()
+
@property
def name(self) -> str:
return "read_file"
@@ -136,51 +167,34 @@ class FileReadTool(AgentTool):
# 🔥 对于大文件,使用流式读取指定行范围
if is_large_file and (start_line is not None or end_line is not None):
- # 流式读取,避免一次性加载整个文件
- selected_lines = []
- total_lines = 0
-
# 计算实际的起始和结束行
start_idx = max(0, (start_line or 1) - 1)
end_idx = end_line if end_line else start_idx + max_lines
-
- with open(full_path, 'r', encoding='utf-8', errors='ignore') as f:
- for i, line in enumerate(f):
- total_lines = i + 1
- if i >= start_idx and i < end_idx:
- selected_lines.append(line)
- elif i >= end_idx:
- # 继续计数以获取总行数,但限制读取量
- if i < end_idx + 1000: # 最多再读1000行来估算总行数
- continue
- else:
- # 估算剩余行数
- remaining_bytes = file_size - f.tell()
- avg_line_size = f.tell() / (i + 1)
- estimated_remaining_lines = int(remaining_bytes / avg_line_size) if avg_line_size > 0 else 0
- total_lines = i + 1 + estimated_remaining_lines
- break
-
+
+ # 异步读取文件,避免阻塞事件循环
+ selected_lines, total_lines = await asyncio.to_thread(
+ self._read_file_lines_sync, full_path, start_idx, end_idx
+ )
+
# 更新实际的结束索引
end_idx = min(end_idx, start_idx + len(selected_lines))
else:
- # 正常读取小文件
- with open(full_path, 'r', encoding='utf-8', errors='ignore') as f:
- lines = f.readlines()
-
+ # 异步读取小文件,避免阻塞事件循环
+ lines = await asyncio.to_thread(self._read_all_lines_sync, full_path)
+
total_lines = len(lines)
-
+
# 处理行范围
if start_line is not None:
start_idx = max(0, start_line - 1)
else:
start_idx = 0
-
+
if end_line is not None:
end_idx = min(total_lines, end_line)
else:
end_idx = min(total_lines, start_idx + max_lines)
-
+
# 截取指定行
selected_lines = lines[start_idx:end_idx]
@@ -259,7 +273,7 @@ class FileSearchTool(AgentTool):
self.project_root = project_root
self.exclude_patterns = exclude_patterns or []
self.target_files = set(target_files) if target_files else None
-
+
# 从 exclude_patterns 中提取目录排除
self.exclude_dirs = set(self.DEFAULT_EXCLUDE_DIRS)
for pattern in self.exclude_patterns:
@@ -267,7 +281,13 @@ class FileSearchTool(AgentTool):
self.exclude_dirs.add(pattern[:-3])
elif "/" not in pattern and "*" not in pattern:
self.exclude_dirs.add(pattern)
-
+
+ @staticmethod
+ def _read_file_lines_sync(file_path: str) -> List[str]:
+ """同步读取文件所有行(用于 asyncio.to_thread)"""
+ with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
+ return f.readlines()
+
@property
def name(self) -> str:
return "search_code"
@@ -360,11 +380,13 @@ class FileSearchTool(AgentTool):
continue
try:
- with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
- lines = f.readlines()
-
+ # 异步读取文件,避免阻塞事件循环
+ lines = await asyncio.to_thread(
+ self._read_file_lines_sync, file_path
+ )
+
files_searched += 1
-
+
for i, line in enumerate(lines):
if pattern.search(line):
# 获取上下文
diff --git a/backend/app/services/llm/adapters/litellm_adapter.py b/backend/app/services/llm/adapters/litellm_adapter.py
index 842406c..3a78549 100644
--- a/backend/app/services/llm/adapters/litellm_adapter.py
+++ b/backend/app/services/llm/adapters/litellm_adapter.py
@@ -416,13 +416,93 @@ class LiteLLMAdapter(BaseLLMAdapter):
"finish_reason": "complete",
}
- except Exception as e:
- # 🔥 即使出错,也尝试返回估算的 usage
- logger.error(f"Stream error: {e}")
+ except litellm.exceptions.RateLimitError as e:
+ # 速率限制错误 - 需要特殊处理
+ logger.error(f"Stream rate limit error: {e}")
+ error_msg = str(e)
+ # 区分"余额不足"和"频率超限"
+ if any(keyword in error_msg.lower() for keyword in ["余额不足", "资源包", "充值", "quota", "exceeded", "billing"]):
+ error_type = "quota_exceeded"
+ user_message = "API 配额已用尽,请检查账户余额或升级计划"
+ else:
+ error_type = "rate_limit"
+ # 尝试从错误消息中提取重试时间
+ import re
+ retry_match = re.search(r"retry\s*(?:in|after)\s*(\d+(?:\.\d+)?)\s*s", error_msg, re.IGNORECASE)
+ retry_seconds = float(retry_match.group(1)) if retry_match else 60
+ user_message = f"API 调用频率超限,建议等待 {int(retry_seconds)} 秒后重试"
+
output_tokens_estimate = estimate_tokens(accumulated_content) if accumulated_content else 0
yield {
"type": "error",
+ "error_type": error_type,
+ "error": error_msg,
+ "user_message": user_message,
+ "accumulated": accumulated_content,
+ "usage": {
+ "prompt_tokens": input_tokens_estimate,
+ "completion_tokens": output_tokens_estimate,
+ "total_tokens": input_tokens_estimate + output_tokens_estimate,
+ } if accumulated_content else None,
+ }
+
+ except litellm.exceptions.AuthenticationError as e:
+ # 认证错误 - API Key 无效
+ logger.error(f"Stream authentication error: {e}")
+ yield {
+ "type": "error",
+ "error_type": "authentication",
"error": str(e),
+ "user_message": "API Key 无效或已过期,请检查配置",
+ "accumulated": accumulated_content,
+ "usage": None,
+ }
+
+ except litellm.exceptions.APIConnectionError as e:
+ # 连接错误 - 网络问题
+ logger.error(f"Stream connection error: {e}")
+ yield {
+ "type": "error",
+ "error_type": "connection",
+ "error": str(e),
+ "user_message": "无法连接到 API 服务,请检查网络连接",
+ "accumulated": accumulated_content,
+ "usage": None,
+ }
+
+ except Exception as e:
+ # 其他错误 - 检查是否是包装的速率限制错误
+ error_msg = str(e)
+ logger.error(f"Stream error: {e}")
+
+ # 检查是否是包装的速率限制错误(如 ServiceUnavailableError 包装 RateLimitError)
+ is_rate_limit = any(keyword in error_msg.lower() for keyword in [
+ "ratelimiterror", "rate limit", "429", "resource_exhausted",
+ "quota exceeded", "too many requests"
+ ])
+
+ if is_rate_limit:
+ # 按速率限制错误处理
+ import re
+ # 检查是否是配额用尽
+ if any(keyword in error_msg.lower() for keyword in ["quota", "exceeded", "billing"]):
+ error_type = "quota_exceeded"
+ user_message = "API 配额已用尽,请检查账户余额或升级计划"
+ else:
+ error_type = "rate_limit"
+ retry_match = re.search(r"retry\s*(?:in|after)\s*(\d+(?:\.\d+)?)\s*s", error_msg, re.IGNORECASE)
+ retry_seconds = float(retry_match.group(1)) if retry_match else 60
+ user_message = f"API 调用频率超限,建议等待 {int(retry_seconds)} 秒后重试"
+ else:
+ error_type = "unknown"
+ user_message = "LLM 调用发生错误,请重试"
+
+ output_tokens_estimate = estimate_tokens(accumulated_content) if accumulated_content else 0
+ yield {
+ "type": "error",
+ "error_type": error_type,
+ "error": error_msg,
+ "user_message": user_message,
"accumulated": accumulated_content,
"usage": {
"prompt_tokens": input_tokens_estimate,
diff --git a/backend/app/services/rag/indexer.py b/backend/app/services/rag/indexer.py
index d82ba68..bdb15ed 100644
--- a/backend/app/services/rag/indexer.py
+++ b/backend/app/services/rag/indexer.py
@@ -739,6 +739,20 @@ class CodeIndexer:
self._needs_rebuild = False
self._rebuild_reason = ""
+ @staticmethod
+ def _read_file_sync(file_path: str) -> str:
+ """
+ 同步读取文件内容(用于 asyncio.to_thread 包装)
+
+ Args:
+ file_path: 文件路径
+
+ Returns:
+ 文件内容
+ """
+ with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
+ return f.read()
+
async def initialize(self, force_rebuild: bool = False) -> Tuple[bool, str]:
"""
初始化索引器,检测是否需要重建索引
@@ -916,8 +930,10 @@ class CodeIndexer:
try:
relative_path = os.path.relpath(file_path, directory)
- with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
- content = f.read()
+ # 异步读取文件,避免阻塞事件循环
+ content = await asyncio.to_thread(
+ self._read_file_sync, file_path
+ )
if not content.strip():
progress.processed_files += 1
@@ -932,8 +948,8 @@ class CodeIndexer:
if len(content) > 500000:
content = content[:500000]
- # 分块
- chunks = self.splitter.split_file(content, relative_path)
+ # 异步分块,避免 Tree-sitter 解析阻塞事件循环
+ chunks = await self.splitter.split_file_async(content, relative_path)
# 为每个 chunk 添加 file_hash
for chunk in chunks:
@@ -1018,8 +1034,10 @@ class CodeIndexer:
for relative_path in files_to_check:
file_path = current_file_map[relative_path]
try:
- with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
- content = f.read()
+ # 异步读取文件,避免阻塞事件循环
+ content = await asyncio.to_thread(
+ self._read_file_sync, file_path
+ )
current_hash = hashlib.md5(content.encode()).hexdigest()
if current_hash != indexed_file_hashes.get(relative_path):
files_to_update.add(relative_path)
@@ -1055,8 +1073,10 @@ class CodeIndexer:
is_update = relative_path in files_to_update
try:
- with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
- content = f.read()
+ # 异步读取文件,避免阻塞事件循环
+ content = await asyncio.to_thread(
+ self._read_file_sync, file_path
+ )
if not content.strip():
progress.processed_files += 1
@@ -1075,8 +1095,8 @@ class CodeIndexer:
if len(content) > 500000:
content = content[:500000]
- # 分块
- chunks = self.splitter.split_file(content, relative_path)
+ # 异步分块,避免 Tree-sitter 解析阻塞事件循环
+ chunks = await self.splitter.split_file_async(content, relative_path)
# 为每个 chunk 添加 file_hash
for chunk in chunks:
diff --git a/backend/app/services/rag/splitter.py b/backend/app/services/rag/splitter.py
index 184db35..5c350b0 100644
--- a/backend/app/services/rag/splitter.py
+++ b/backend/app/services/rag/splitter.py
@@ -4,6 +4,7 @@
"""
import re
+import asyncio
import hashlib
import logging
from typing import List, Dict, Any, Optional, Tuple, Set
@@ -154,7 +155,7 @@ class TreeSitterParser:
".c": "c",
".h": "c",
".hpp": "cpp",
- ".cs": "c_sharp",
+ ".cs": "csharp",
".php": "php",
".rb": "ruby",
".kt": "kotlin",
@@ -197,7 +198,7 @@ class TreeSitterParser:
# tree-sitter-languages 支持的语言列表
SUPPORTED_LANGUAGES = {
"python", "javascript", "typescript", "tsx", "java", "go", "rust",
- "c", "cpp", "c_sharp", "php", "ruby", "kotlin", "swift", "bash",
+ "c", "cpp", "csharp", "php", "ruby", "kotlin", "swift", "bash",
"json", "yaml", "html", "css", "sql", "markdown",
}
@@ -230,21 +231,30 @@ class TreeSitterParser:
return False
def parse(self, code: str, language: str) -> Optional[Any]:
- """解析代码返回 AST"""
+ """解析代码返回 AST(同步方法)"""
if not self._ensure_initialized(language):
return None
-
+
parser = self._parsers.get(language)
if not parser:
return None
-
+
try:
tree = parser.parse(code.encode())
return tree
except Exception as e:
logger.warning(f"Failed to parse code: {e}")
return None
-
+
+ async def parse_async(self, code: str, language: str) -> Optional[Any]:
+ """
+ 异步解析代码返回 AST
+
+ 将 CPU 密集型的 Tree-sitter 解析操作放到线程池中执行,
+ 避免阻塞事件循环
+ """
+ return await asyncio.to_thread(self.parse, code, language)
+
def extract_definitions(self, tree: Any, code: str, language: str) -> List[Dict[str, Any]]:
"""从 AST 提取定义"""
if tree is None:
@@ -449,9 +459,31 @@ class CodeSplitter:
except Exception as e:
logger.warning(f"分块失败 {file_path}: {e}, 使用简单分块")
chunks = self._split_by_lines(content, file_path, language)
-
+
return chunks
-
+
+ async def split_file_async(
+ self,
+ content: str,
+ file_path: str,
+ language: Optional[str] = None
+ ) -> List[CodeChunk]:
+ """
+ 异步分割单个文件
+
+ 将 CPU 密集型的分块操作(包括 Tree-sitter 解析)放到线程池中执行,
+ 避免阻塞事件循环。
+
+ Args:
+ content: 文件内容
+ file_path: 文件路径
+ language: 编程语言(可选)
+
+ Returns:
+ 代码块列表
+ """
+ return await asyncio.to_thread(self.split_file, content, file_path, language)
+
def _split_by_ast(
self,
content: str,
diff --git a/backend/tests/agent/test_integration.py b/backend/tests/agent/test_integration.py
deleted file mode 100644
index d314404..0000000
--- a/backend/tests/agent/test_integration.py
+++ /dev/null
@@ -1,355 +0,0 @@
-"""
-Agent 集成测试
-测试完整的审计流程
-"""
-
-import pytest
-import asyncio
-import os
-from unittest.mock import MagicMock, AsyncMock, patch
-from datetime import datetime
-
-from app.services.agent.graph.runner import AgentRunner, LLMService
-from app.services.agent.graph.audit_graph import AuditState, create_audit_graph
-from app.services.agent.graph.nodes import ReconNode, AnalysisNode, VerificationNode, ReportNode
-from app.services.agent.event_manager import EventManager, AgentEventEmitter
-
-
-class TestLLMService:
- """LLM 服务测试"""
-
- @pytest.mark.asyncio
- async def test_llm_service_initialization(self):
- """测试 LLM 服务初始化"""
- with patch("app.core.config.settings") as mock_settings:
- mock_settings.LLM_MODEL = "gpt-4o-mini"
- mock_settings.LLM_API_KEY = "test-key"
-
- service = LLMService()
-
- assert service.model == "gpt-4o-mini"
-
-
-class TestEventManager:
- """事件管理器测试"""
-
- def test_event_manager_initialization(self):
- """测试事件管理器初始化"""
- manager = EventManager()
-
- assert manager._event_queues == {}
- assert manager._event_callbacks == {}
-
- @pytest.mark.asyncio
- async def test_event_emitter(self):
- """测试事件发射器"""
- manager = EventManager()
- emitter = AgentEventEmitter("test-task-id", manager)
-
- await emitter.emit_info("Test message")
-
- assert emitter._sequence == 1
-
- @pytest.mark.asyncio
- async def test_event_emitter_phase_tracking(self):
- """测试事件发射器阶段跟踪"""
- manager = EventManager()
- emitter = AgentEventEmitter("test-task-id", manager)
-
- await emitter.emit_phase_start("recon", "开始信息收集")
-
- assert emitter._current_phase == "recon"
-
- @pytest.mark.asyncio
- async def test_event_emitter_task_complete(self):
- """测试任务完成事件"""
- manager = EventManager()
- emitter = AgentEventEmitter("test-task-id", manager)
-
- await emitter.emit_task_complete(findings_count=5, duration_ms=1000)
-
- assert emitter._sequence == 1
-
-
-class TestAuditGraph:
- """审计图测试"""
-
- def test_create_audit_graph(self, mock_event_emitter):
- """测试创建审计图"""
- # 创建模拟节点
- recon_node = MagicMock()
- analysis_node = MagicMock()
- verification_node = MagicMock()
- report_node = MagicMock()
-
- graph = create_audit_graph(
- recon_node=recon_node,
- analysis_node=analysis_node,
- verification_node=verification_node,
- report_node=report_node,
- )
-
- assert graph is not None
-
-
-class TestReconNode:
- """Recon 节点测试"""
-
- @pytest.fixture
- def recon_node_with_mock_agent(self, mock_event_emitter):
- """创建带模拟 Agent 的 Recon 节点"""
- mock_agent = MagicMock()
- mock_agent.run = AsyncMock(return_value=MagicMock(
- success=True,
- data={
- "tech_stack": {"languages": ["Python"]},
- "entry_points": [{"path": "src/app.py", "type": "api"}],
- "high_risk_areas": ["src/sql_vuln.py"],
- "dependencies": {},
- "initial_findings": [],
- }
- ))
-
- return ReconNode(mock_agent, mock_event_emitter)
-
- @pytest.mark.asyncio
- async def test_recon_node_success(self, recon_node_with_mock_agent):
- """测试 Recon 节点成功执行"""
- state = {
- "project_info": {"name": "Test"},
- "config": {},
- }
-
- result = await recon_node_with_mock_agent(state)
-
- assert "tech_stack" in result
- assert "entry_points" in result
- assert result["current_phase"] == "recon_complete"
-
- @pytest.mark.asyncio
- async def test_recon_node_failure(self, mock_event_emitter):
- """测试 Recon 节点失败处理"""
- mock_agent = MagicMock()
- mock_agent.run = AsyncMock(return_value=MagicMock(
- success=False,
- error="Test error",
- data=None,
- ))
-
- node = ReconNode(mock_agent, mock_event_emitter)
-
- result = await node({
- "project_info": {},
- "config": {},
- })
-
- assert "error" in result
- assert result["current_phase"] == "error"
-
-
-class TestAnalysisNode:
- """Analysis 节点测试"""
-
- @pytest.fixture
- def analysis_node_with_mock_agent(self, mock_event_emitter):
- """创建带模拟 Agent 的 Analysis 节点"""
- mock_agent = MagicMock()
- mock_agent.run = AsyncMock(return_value=MagicMock(
- success=True,
- data={
- "findings": [
- {
- "id": "finding-1",
- "title": "SQL Injection",
- "severity": "high",
- "vulnerability_type": "sql_injection",
- "file_path": "src/sql_vuln.py",
- "line_start": 10,
- "description": "SQL injection vulnerability",
- }
- ],
- "should_continue": False,
- }
- ))
-
- return AnalysisNode(mock_agent, mock_event_emitter)
-
- @pytest.mark.asyncio
- async def test_analysis_node_success(self, analysis_node_with_mock_agent):
- """测试 Analysis 节点成功执行"""
- state = {
- "project_info": {"name": "Test"},
- "tech_stack": {"languages": ["Python"]},
- "entry_points": [],
- "high_risk_areas": ["src/sql_vuln.py"],
- "config": {},
- "iteration": 0,
- "findings": [],
- }
-
- result = await analysis_node_with_mock_agent(state)
-
- assert "findings" in result
- assert len(result["findings"]) > 0
- assert result["iteration"] == 1
-
-
-class TestIntegrationFlow:
- """完整流程集成测试"""
-
- @pytest.mark.asyncio
- async def test_full_audit_flow_mock(self, temp_project_dir, mock_db_session, mock_task):
- """测试完整审计流程(使用模拟)"""
- # 这个测试验证整个流程的连接性
-
- # 创建事件管理器
- event_manager = EventManager()
- emitter = AgentEventEmitter(mock_task.id, event_manager)
-
- # 模拟 LLM 服务
- mock_llm = MagicMock()
- mock_llm.chat_completion_raw = AsyncMock(return_value={
- "content": "Analysis complete",
- "usage": {"total_tokens": 100},
- })
-
- # 验证事件发射
- await emitter.emit_phase_start("init", "初始化")
- await emitter.emit_info("测试消息")
- await emitter.emit_phase_complete("init", "初始化完成")
-
- assert emitter._sequence == 3
-
- @pytest.mark.asyncio
- async def test_audit_state_typing(self):
- """测试审计状态类型定义"""
- state: AuditState = {
- "project_root": "/tmp/test",
- "project_info": {"name": "Test"},
- "config": {},
- "task_id": "test-id",
- "tech_stack": {},
- "entry_points": [],
- "high_risk_areas": [],
- "dependencies": {},
- "findings": [],
- "verified_findings": [],
- "false_positives": [],
- "current_phase": "start",
- "iteration": 0,
- "max_iterations": 50,
- "should_continue_analysis": False,
- "messages": [],
- "events": [],
- "summary": None,
- "security_score": None,
- "error": None,
- }
-
- assert state["current_phase"] == "start"
- assert state["max_iterations"] == 50
-
-
-class TestToolIntegration:
- """工具集成测试"""
-
- @pytest.mark.asyncio
- async def test_tools_work_together(self, temp_project_dir):
- """测试工具协同工作"""
- from app.services.agent.tools import (
- FileReadTool, FileSearchTool, ListFilesTool, PatternMatchTool,
- )
-
- # 1. 列出文件
- list_tool = ListFilesTool(temp_project_dir)
- list_result = await list_tool.execute(directory="src", recursive=False)
- assert list_result.success is True
-
- # 2. 搜索关键代码
- search_tool = FileSearchTool(temp_project_dir)
- search_result = await search_tool.execute(keyword="execute")
- assert search_result.success is True
-
- # 3. 读取文件内容
- read_tool = FileReadTool(temp_project_dir)
- read_result = await read_tool.execute(file_path="src/sql_vuln.py")
- assert read_result.success is True
-
- # 4. 模式匹配
- pattern_tool = PatternMatchTool(temp_project_dir)
- pattern_result = await pattern_tool.execute(
- code=read_result.data,
- file_path="src/sql_vuln.py",
- language="python"
- )
- assert pattern_result.success is True
-
-
-class TestErrorHandling:
- """错误处理测试"""
-
- @pytest.mark.asyncio
- async def test_tool_error_handling(self, temp_project_dir):
- """测试工具错误处理"""
- from app.services.agent.tools import FileReadTool
-
- tool = FileReadTool(temp_project_dir)
-
- # 尝试读取不存在的文件
- result = await tool.execute(file_path="nonexistent/file.py")
-
- assert result.success is False
- assert result.error is not None
-
- @pytest.mark.asyncio
- async def test_agent_graceful_degradation(self, mock_event_emitter):
- """测试 Agent 优雅降级"""
- # 创建一个会失败的 Agent
- mock_agent = MagicMock()
- mock_agent.run = AsyncMock(side_effect=Exception("Simulated error"))
-
- node = ReconNode(mock_agent, mock_event_emitter)
-
- result = await node({
- "project_info": {},
- "config": {},
- })
-
- # 应该返回错误状态而不是崩溃
- assert "error" in result
- assert result["current_phase"] == "error"
-
-
-class TestPerformance:
- """性能测试"""
-
- @pytest.mark.asyncio
- async def test_tool_response_time(self, temp_project_dir):
- """测试工具响应时间"""
- from app.services.agent.tools import ListFilesTool
- import time
-
- tool = ListFilesTool(temp_project_dir)
-
- start = time.time()
- await tool.execute(directory=".", recursive=True)
- duration = time.time() - start
-
- # 工具应该在合理时间内响应
- assert duration < 5.0 # 5 秒内
-
- @pytest.mark.asyncio
- async def test_multiple_tool_calls(self, temp_project_dir):
- """测试多次工具调用"""
- from app.services.agent.tools import FileSearchTool
-
- tool = FileSearchTool(temp_project_dir)
-
- # 执行多次调用
- for _ in range(5):
- result = await tool.execute(keyword="def")
- assert result.success is True
-
- # 验证调用计数
- assert tool._call_count == 5
-
diff --git a/frontend/src/components/audit/components/BasicConfig.tsx b/frontend/src/components/audit/components/BasicConfig.tsx
index c82968e..b613f85 100644
--- a/frontend/src/components/audit/components/BasicConfig.tsx
+++ b/frontend/src/components/audit/components/BasicConfig.tsx
@@ -9,7 +9,7 @@ import {
} from "@/components/ui/select";
import { GitBranch, Zap, Info } from "lucide-react";
import type { Project, CreateAuditTaskForm } from "@/shared/types";
-import { isRepositoryProject, isZipProject } from "@/shared/utils/projectUtils";
+import { isRepositoryProject, isZipProject, getRepositoryPlatformLabel } from "@/shared/utils/projectUtils";
import ZipFileSection from "./ZipFileSection";
import type { ZipFileMeta } from "@/shared/utils/zipStorage";
@@ -138,7 +138,7 @@ function ProjectInfoCard({ project }: { project: Project }) {
{isRepo && (
<>
- 仓库平台:{project.repository_type?.toUpperCase() || "OTHER"}
+ 仓库平台:{getRepositoryPlatformLabel(project.repository_type)}
默认分支:{project.default_branch}
>
diff --git a/frontend/src/pages/ProjectDetail.tsx b/frontend/src/pages/ProjectDetail.tsx
index 9d282c2..28ea711 100644
--- a/frontend/src/pages/ProjectDetail.tsx
+++ b/frontend/src/pages/ProjectDetail.tsx
@@ -34,13 +34,13 @@ import { api } from "@/shared/config/database";
import { runRepositoryAudit, scanStoredZipFile } from "@/features/projects/services";
import type { Project, AuditTask, CreateProjectForm } from "@/shared/types";
import { hasZipFile } from "@/shared/utils/zipStorage";
-import { isRepositoryProject, getSourceTypeLabel } from "@/shared/utils/projectUtils";
+import { isRepositoryProject, getSourceTypeLabel, getRepositoryPlatformLabel } from "@/shared/utils/projectUtils";
import { toast } from "sonner";
import CreateTaskDialog from "@/components/audit/CreateTaskDialog";
import FileSelectionDialog from "@/components/audit/FileSelectionDialog";
import TerminalProgressDialog from "@/components/audit/TerminalProgressDialog";
import { Dialog, DialogContent, DialogHeader, DialogTitle, DialogFooter } from "@/components/ui/dialog";
-import { SUPPORTED_LANGUAGES } from "@/shared/constants";
+import { SUPPORTED_LANGUAGES, REPOSITORY_PLATFORMS } from "@/shared/constants";
export default function ProjectDetail() {
const { id } = useParams<{ id: string }>();
@@ -475,8 +475,7 @@ export default function ProjectDetail() {
仓库平台
- {project.repository_type === 'github' ? 'GitHub' :
- project.repository_type === 'gitlab' ? 'GitLab' : '其他'}
+ {getRepositoryPlatformLabel(project.repository_type)}
@@ -529,12 +528,11 @@ export default function ProjectDetail() {
className="flex items-center justify-between p-3 bg-muted/50 rounded-lg hover:bg-muted transition-all group"
>
-
+ task.status === 'failed' ? 'bg-rose-500/20' :
+ 'bg-muted'
+ }`}>
{getStatusIcon(task.status)}
@@ -579,12 +577,11 @@ export default function ProjectDetail() {
-
+ task.status === 'failed' ? 'bg-rose-500/20' :
+ 'bg-muted'
+ }`}>
{getStatusIcon(task.status)}
@@ -676,12 +673,11 @@ export default function ProjectDetail() {
-
+ issue.severity === 'medium' ? 'bg-amber-500/20 text-amber-600 dark:text-amber-400' :
+ 'bg-sky-500/20 text-sky-600 dark:text-sky-400'
+ }`}>
@@ -695,13 +691,13 @@ export default function ProjectDetail() {
{issue.severity === 'critical' ? '严重' :
issue.severity === 'high' ? '高' :
- issue.severity === 'medium' ? '中等' : '低'}
+ issue.severity === 'medium' ? '中等' : '低'}
@@ -783,9 +779,11 @@ export default function ProjectDetail() {
- GitHub
- GitLab
- 其他
+ {REPOSITORY_PLATFORMS.map((platform) => (
+
+ {platform.label}
+
+ ))}
@@ -831,14 +829,14 @@ export default function ProjectDetail() {
className={`flex items-center space-x-2 p-3 border cursor-pointer transition-all rounded ${editForm.programming_languages?.includes(lang)
? 'border-primary bg-primary/10 text-primary'
: 'border-border hover:border-border text-muted-foreground'
- }`}
+ }`}
onClick={() => handleToggleLanguage(lang)}
>
{editForm.programming_languages?.includes(lang) && (
diff --git a/frontend/src/pages/Projects.tsx b/frontend/src/pages/Projects.tsx
index 979f188..7a94a48 100644
--- a/frontend/src/pages/Projects.tsx
+++ b/frontend/src/pages/Projects.tsx
@@ -45,7 +45,7 @@ import { Link } from "react-router-dom";
import { toast } from "sonner";
import CreateTaskDialog from "@/components/audit/CreateTaskDialog";
import TerminalProgressDialog from "@/components/audit/TerminalProgressDialog";
-import { SUPPORTED_LANGUAGES } from "@/shared/constants";
+import { SUPPORTED_LANGUAGES, REPOSITORY_PLATFORMS } from "@/shared/constants";
export default function Projects() {
const [projects, setProjects] = useState
([]);
@@ -487,10 +487,11 @@ export default function Projects() {
- GITHUB
- GITLAB
- GITEA
- OTHER
+ {REPOSITORY_PLATFORMS.map((platform) => (
+
+ {platform.label}
+
+ ))}
@@ -1046,10 +1047,11 @@ export default function Projects() {
- GITHUB
- GITLAB
- GITEA
- OTHER
+ {REPOSITORY_PLATFORMS.map((platform) => (
+
+ {platform.label}
+
+ ))}
diff --git a/frontend/src/pages/TaskDetail.tsx b/frontend/src/pages/TaskDetail.tsx
index 0d05359..21d3372 100644
--- a/frontend/src/pages/TaskDetail.tsx
+++ b/frontend/src/pages/TaskDetail.tsx
@@ -36,7 +36,7 @@ import type { AuditTask, AuditIssue } from "@/shared/types";
import { toast } from "sonner";
import ExportReportDialog from "@/components/reports/ExportReportDialog";
import { calculateTaskProgress } from "@/shared/utils/utils";
-import { isRepositoryProject, getSourceTypeLabel } from "@/shared/utils/projectUtils";
+import { isRepositoryProject, getSourceTypeLabel, getRepositoryPlatformLabel } from "@/shared/utils/projectUtils";
// AI explanation parser
function parseAIExplanation(aiExplanation: string) {
@@ -86,12 +86,11 @@ function IssuesList({ issues }: { issues: AuditIssue[] }) {
-
+
{getTypeIcon(issue.issue_type)}
@@ -112,7 +111,7 @@ function IssuesList({ issues }: { issues: AuditIssue[] }) {
{issue.severity === 'critical' ? '严重' :
issue.severity === 'high' ? '高' :
- issue.severity === 'medium' ? '中等' : '低'}
+ issue.severity === 'medium' ? '中等' : '低'}
@@ -702,7 +701,7 @@ export default function TaskDetail() {
{isRepositoryProject(task.project) && (
仓库平台
-
{task.project.repository_type?.toUpperCase() || 'OTHER'}
+
{getRepositoryPlatformLabel(task.project.repository_type)}
)}
{task.project.programming_languages && (
diff --git a/frontend/src/shared/constants/index.ts b/frontend/src/shared/constants/index.ts
index 2f48818..401ecb2 100644
--- a/frontend/src/shared/constants/index.ts
+++ b/frontend/src/shared/constants/index.ts
@@ -62,13 +62,6 @@ export const PROJECT_SOURCE_TYPES = {
ZIP: 'zip',
} as const;
-// 仓库平台类型
-export const REPOSITORY_TYPES = {
- GITHUB: 'github',
- GITLAB: 'gitlab',
- OTHER: 'other',
-} as const;
-
// 分析深度
export const ANALYSIS_DEPTH = {
BASIC: 'basic',
diff --git a/frontend/src/shared/constants/projectTypes.ts b/frontend/src/shared/constants/projectTypes.ts
index 55e6d47..4e78e9b 100644
--- a/frontend/src/shared/constants/projectTypes.ts
+++ b/frontend/src/shared/constants/projectTypes.ts
@@ -22,17 +22,23 @@ export const PROJECT_SOURCE_TYPES: Array<{
}
];
+// 仓库平台显示名称
+export const REPOSITORY_PLATFORM_LABELS: Record
= {
+ github: 'GitHub',
+ gitlab: 'GitLab',
+ gitea: 'Gitea',
+ other: '其他',
+};
+
// 仓库平台选项
export const REPOSITORY_PLATFORMS: Array<{
value: RepositoryPlatform;
label: string;
icon?: string;
-}> = [
- { value: 'github', label: 'GitHub' },
- { value: 'gitlab', label: 'GitLab' },
- { value: 'gitea', label: 'Gitea' },
- { value: 'other', label: '其他' }
- ];
+}> = Object.entries(REPOSITORY_PLATFORM_LABELS).map(([value, label]) => ({
+ value: value as RepositoryPlatform,
+ label
+}));
// 项目来源类型的颜色配置
export const SOURCE_TYPE_COLORS: Record = {
- github: 'GitHub',
- gitlab: 'GitLab',
- gitea: 'Gitea',
- other: '其他'
- };
- return labels[platform || 'other'] || '其他';
+ return REPOSITORY_PLATFORM_LABELS[platform as keyof typeof REPOSITORY_PLATFORM_LABELS] || REPOSITORY_PLATFORM_LABELS.other;
}
/**