Merge branch 'v3.0.0' of github.com:lintsinghua/DeepAudit into feature/git_ssh

# Conflicts:
#	backend/app/api/v1/endpoints/agent_tasks.py
This commit is contained in:
Image 2025-12-26 09:39:25 +08:00
commit 869513e0c5
22 changed files with 405 additions and 2839 deletions

View File

@ -428,7 +428,7 @@ DeepSeek-Coder · Codestral<br/>
<div align="center">
**欢迎大家来和我交流探讨!无论是技术问题、功能建议还是合作意向,都期待与你沟通~**
(项目开发、投资孵化等合作洽谈请通过邮箱联系)
| 联系方式 | |
|:---:|:---:|
| 📧 **邮箱** | **lintsinghua@qq.com** |

View File

@ -294,6 +294,7 @@ async def _execute_agent_task(task_id: str):
other_config = (user_config or {}).get('otherConfig', {})
github_token = other_config.get('githubToken') or settings.GITHUB_TOKEN
gitlab_token = other_config.get('gitlabToken') or settings.GITLAB_TOKEN
gitea_token = other_config.get('giteaToken') or settings.GITEA_TOKEN
# 解密SSH私钥
ssh_private_key = None
@ -313,6 +314,7 @@ async def _execute_agent_task(task_id: str):
task.branch_name,
github_token=github_token,
gitlab_token=gitlab_token,
gitea_token=gitea_token, # 🔥 新增
ssh_private_key=ssh_private_key, # 🔥 新增SSH密钥
event_emitter=event_emitter, # 🔥 新增
)
@ -2226,6 +2228,7 @@ async def _get_project_root(
branch_name: Optional[str] = None,
github_token: Optional[str] = None,
gitlab_token: Optional[str] = None,
gitea_token: Optional[str] = None, # 🔥 新增
ssh_private_key: Optional[str] = None, # 🔥 新增SSH私钥用于SSH认证
event_emitter: Optional[Any] = None, # 🔥 新增:用于发送实时日志
) -> str:
@ -2242,6 +2245,7 @@ async def _get_project_root(
branch_name: 分支名称仓库项目使用优先于 project.default_branch
github_token: GitHub 访问令牌用于私有仓库
gitlab_token: GitLab 访问令牌用于私有仓库
gitea_token: Gitea 访问令牌用于私有仓库
ssh_private_key: SSH私钥用于SSH认证
event_emitter: 事件发送器用于发送实时日志
@ -2503,6 +2507,16 @@ async def _get_project_root(
parsed.fragment
))
await emit(f"🔐 使用 GitLab Token 认证")
elif repo_type == "gitea" and gitea_token:
auth_url = urlunparse((
parsed.scheme,
f"{gitea_token}@{parsed.netloc}",
parsed.path,
parsed.params,
parsed.query,
parsed.fragment
))
await emit(f"🔐 使用 Gitea Token 认证")
elif is_ssh_url and ssh_private_key:
await emit(f"🔐 使用 SSH Key 认证")

View File

@ -16,6 +16,7 @@ from app.db.session import get_db, AsyncSessionLocal
from app.models.project import Project
from app.models.user import User
from app.models.audit import AuditTask, AuditIssue
from app.models.agent_task import AgentTask, AgentTaskStatus, AgentFinding
from app.models.user_config import UserConfig
import zipfile
from app.services.scanner import scan_repo_task, get_github_files, get_gitlab_files, get_github_branches, get_gitlab_branches, get_gitea_branches, should_exclude, is_text_file
@ -162,26 +163,51 @@ async def get_stats(
projects = projects_result.scalars().all()
project_ids = [p.id for p in projects]
# 只统计当前用户项目的任务
# 统计旧的 AuditTask
tasks_result = await db.execute(
select(AuditTask).where(AuditTask.project_id.in_(project_ids)) if project_ids else select(AuditTask).where(False)
)
tasks = tasks_result.scalars().all()
task_ids = [t.id for t in tasks]
# 只统计当前用户任务的问题
# 统计旧的 AuditIssue
issues_result = await db.execute(
select(AuditIssue).where(AuditIssue.task_id.in_(task_ids)) if task_ids else select(AuditIssue).where(False)
)
issues = issues_result.scalars().all()
# 🔥 同时统计新的 AgentTask
agent_tasks_result = await db.execute(
select(AgentTask).where(AgentTask.project_id.in_(project_ids)) if project_ids else select(AgentTask).where(False)
)
agent_tasks = agent_tasks_result.scalars().all()
agent_task_ids = [t.id for t in agent_tasks]
# 🔥 统计 AgentFinding
agent_findings_result = await db.execute(
select(AgentFinding).where(AgentFinding.task_id.in_(agent_task_ids)) if agent_task_ids else select(AgentFinding).where(False)
)
agent_findings = agent_findings_result.scalars().all()
# 合并统计(旧任务 + 新 Agent 任务)
total_tasks = len(tasks) + len(agent_tasks)
completed_tasks = (
len([t for t in tasks if t.status == "completed"]) +
len([t for t in agent_tasks if t.status == AgentTaskStatus.COMPLETED])
)
total_issues = len(issues) + len(agent_findings)
resolved_issues = (
len([i for i in issues if i.status == "resolved"]) +
len([f for f in agent_findings if f.status == "resolved"])
)
return {
"total_projects": len(projects),
"active_projects": len([p for p in projects if p.is_active]),
"total_tasks": len(tasks),
"completed_tasks": len([t for t in tasks if t.status == "completed"]),
"total_issues": len(issues),
"resolved_issues": len([i for i in issues if i.status == "resolved"]),
"total_tasks": total_tasks,
"completed_tasks": completed_tasks,
"total_issues": total_issues,
"resolved_issues": resolved_issues,
}
@router.get("/{id}", response_model=ProjectResponse)

View File

@ -1,29 +1,19 @@
"""
DeepAudit Agent 服务模块
基于 LangGraph AI Agent 代码安全审计
基于动态 Agent 树架构的 AI 代码安全审计
架构升级版本 - 支持
- 动态Agent树结构
- 专业知识模块系统
- Agent间通信机制
- 完整状态管理
- Think工具和漏洞报告工具
架构:
- OrchestratorAgent 作为编排层动态调度子 Agent
- ReconAgent 负责侦察和文件分析
- AnalysisAgent 负责漏洞分析
- VerificationAgent 负责验证发现
工作流:
START Recon Analysis Verification Report END
START Orchestrator [Recon/Analysis/Verification] Report END
支持动态创建子Agent进行专业化分析
"""
# 从 graph 模块导入主要组件
from .graph import (
AgentRunner,
run_agent_task,
LLMService,
AuditState,
create_audit_graph,
)
# 事件管理
from .event_manager import EventManager, AgentEventEmitter
@ -33,14 +23,14 @@ from .agents import (
OrchestratorAgent, ReconAgent, AnalysisAgent, VerificationAgent,
)
# 🔥 新增:核心模块(状态管理、注册表、消息)
# 核心模块(状态管理、注册表、消息)
from .core import (
AgentState, AgentStatus,
AgentRegistry, agent_registry,
AgentMessage, MessageType, MessagePriority, MessageBus,
)
# 🔥 新增:知识模块系统基于RAG
# 知识模块系统基于RAG
from .knowledge import (
KnowledgeLoader, knowledge_loader,
get_available_modules, get_module_content,
@ -48,7 +38,7 @@ from .knowledge import (
SecurityKnowledgeQueryTool, GetVulnerabilityKnowledgeTool,
)
# 🔥 新增:协作工具
# 协作工具
from .tools import (
ThinkTool, ReflectTool,
CreateVulnerabilityReportTool,
@ -57,20 +47,11 @@ from .tools import (
WaitForMessageTool, AgentFinishTool,
)
# 🔥 新增:遥测模块
# 遥测模块
from .telemetry import Tracer, get_global_tracer, set_global_tracer
__all__ = [
# 核心 Runner
"AgentRunner",
"run_agent_task",
"LLMService",
# LangGraph
"AuditState",
"create_audit_graph",
# 事件管理
"EventManager",
"AgentEventEmitter",
@ -84,7 +65,7 @@ __all__ = [
"AnalysisAgent",
"VerificationAgent",
# 🔥 核心模块
# 核心模块
"AgentState",
"AgentStatus",
"AgentRegistry",
@ -94,7 +75,7 @@ __all__ = [
"MessagePriority",
"MessageBus",
# 🔥 知识模块基于RAG
# 知识模块基于RAG
"KnowledgeLoader",
"knowledge_loader",
"get_available_modules",
@ -104,7 +85,7 @@ __all__ = [
"SecurityKnowledgeQueryTool",
"GetVulnerabilityKnowledgeTool",
# 🔥 协作工具
# 协作工具
"ThinkTool",
"ReflectTool",
"CreateVulnerabilityReportTool",
@ -115,9 +96,8 @@ __all__ = [
"WaitForMessageTool",
"AgentFinishTool",
# 🔥 遥测模块
# 遥测模块
"Tracer",
"get_global_tracer",
"set_global_tracer",
]

View File

@ -1024,10 +1024,18 @@ class BaseAgent(ABC):
elif chunk["type"] == "error":
accumulated = chunk.get("accumulated", "")
error_msg = chunk.get("error", "Unknown error")
logger.error(f"[{self.name}] Stream error: {error_msg}")
if accumulated:
total_tokens = chunk.get("usage", {}).get("total_tokens", 0)
else:
error_type = chunk.get("error_type", "unknown")
user_message = chunk.get("user_message", error_msg)
logger.error(f"[{self.name}] Stream error ({error_type}): {error_msg}")
if chunk.get("usage"):
total_tokens = chunk["usage"].get("total_tokens", 0)
# 使用特殊前缀标记 API 错误,让调用方能够识别
# 格式:[API_ERROR:error_type] user_message
if error_type in ("rate_limit", "quota_exceeded", "authentication", "connection"):
accumulated = f"[API_ERROR:{error_type}] {user_message}"
elif not accumulated:
accumulated = f"[系统错误: {error_msg}] 请重新思考并输出你的决策。"
break

View File

@ -285,6 +285,55 @@ Action Input: {{"参数": "值"}}
# 重置空响应计数器
self._empty_retry_count = 0
# 🔥 检查是否是 API 错误(而非格式错误)
if llm_output.startswith("[API_ERROR:"):
# 提取错误类型和消息
match = re.match(r"\[API_ERROR:(\w+)\]\s*(.*)", llm_output)
if match:
error_type = match.group(1)
error_message = match.group(2)
if error_type == "rate_limit":
# 速率限制 - 等待后重试
api_retry_count = getattr(self, '_api_retry_count', 0) + 1
self._api_retry_count = api_retry_count
if api_retry_count >= 3:
logger.error(f"[{self.name}] Too many rate limit errors, stopping")
await self.emit_event("error", f"API 速率限制重试次数过多: {error_message}")
break
logger.warning(f"[{self.name}] Rate limit hit, waiting before retry ({api_retry_count}/3)")
await self.emit_event("warning", f"API 速率限制,等待后重试 ({api_retry_count}/3)")
await asyncio.sleep(30) # 等待 30 秒后重试
continue
elif error_type == "quota_exceeded":
# 配额用尽 - 终止任务
logger.error(f"[{self.name}] API quota exceeded: {error_message}")
await self.emit_event("error", f"API 配额已用尽: {error_message}")
break
elif error_type == "authentication":
# 认证错误 - 终止任务
logger.error(f"[{self.name}] API authentication error: {error_message}")
await self.emit_event("error", f"API 认证失败: {error_message}")
break
elif error_type == "connection":
# 连接错误 - 重试
api_retry_count = getattr(self, '_api_retry_count', 0) + 1
self._api_retry_count = api_retry_count
if api_retry_count >= 3:
logger.error(f"[{self.name}] Too many connection errors, stopping")
await self.emit_event("error", f"API 连接错误重试次数过多: {error_message}")
break
logger.warning(f"[{self.name}] Connection error, retrying ({api_retry_count}/3)")
await self.emit_event("warning", f"API 连接错误,重试中 ({api_retry_count}/3)")
await asyncio.sleep(5) # 等待 5 秒后重试
continue
# 重置 API 重试计数器(成功获取响应后)
self._api_retry_count = 0
# 解析 LLM 的决策
step = self._parse_llm_response(llm_output)

View File

@ -1,28 +0,0 @@
"""
LangGraph 工作流模块
使用状态图构建混合 Agent 审计流程
"""
from .audit_graph import AuditState, create_audit_graph, create_audit_graph_with_human
from .nodes import ReconNode, AnalysisNode, VerificationNode, ReportNode, HumanReviewNode
from .runner import AgentRunner, run_agent_task, LLMService
__all__ = [
# 状态和图
"AuditState",
"create_audit_graph",
"create_audit_graph_with_human",
# 节点
"ReconNode",
"AnalysisNode",
"VerificationNode",
"ReportNode",
"HumanReviewNode",
# Runner
"AgentRunner",
"run_agent_task",
"LLMService",
]

View File

@ -1,677 +0,0 @@
"""
DeepAudit 审计工作流图 - LLM 驱动版
使用 LangGraph 构建 LLM 驱动的 Agent 协作流程
重要改变路由决策由 LLM 参与而不是硬编码条件
"""
from typing import TypedDict, Annotated, List, Dict, Any, Optional, Literal
from datetime import datetime
import operator
import logging
import json
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt import ToolNode
logger = logging.getLogger(__name__)
# ============ 状态定义 ============
class Finding(TypedDict):
"""漏洞发现"""
id: str
vulnerability_type: str
severity: str
title: str
description: str
file_path: Optional[str]
line_start: Optional[int]
code_snippet: Optional[str]
is_verified: bool
confidence: float
source: str
class AuditState(TypedDict):
"""
审计状态
在整个工作流中传递和更新
"""
# 输入
project_root: str
project_info: Dict[str, Any]
config: Dict[str, Any]
task_id: str
# Recon 阶段输出
tech_stack: Dict[str, Any]
entry_points: List[Dict[str, Any]]
high_risk_areas: List[str]
dependencies: Dict[str, Any]
# Analysis 阶段输出
findings: Annotated[List[Finding], operator.add] # 使用 add 合并多轮发现
# Verification 阶段输出
verified_findings: List[Finding]
false_positives: List[str]
# 🔥 NEW: 验证后的完整 findings用于替换原始 findings
_verified_findings_update: Optional[List[Finding]]
# 控制流 - 🔥 关键LLM 可以设置这些来影响路由
current_phase: str
iteration: int
max_iterations: int
should_continue_analysis: bool
# 🔥 新增LLM 的路由决策
llm_next_action: Optional[str] # LLM 建议的下一步: "continue_analysis", "verify", "report", "end"
llm_routing_reason: Optional[str] # LLM 的决策理由
# 🔥 新增Agent 间协作的任务交接信息
recon_handoff: Optional[Dict[str, Any]] # Recon -> Analysis 的交接
analysis_handoff: Optional[Dict[str, Any]] # Analysis -> Verification 的交接
verification_handoff: Optional[Dict[str, Any]] # Verification -> Report 的交接
# 消息和事件
messages: Annotated[List[Dict], operator.add]
events: Annotated[List[Dict], operator.add]
# 最终输出
summary: Optional[Dict[str, Any]]
security_score: Optional[int]
error: Optional[str]
# ============ LLM 路由决策器 ============
class LLMRouter:
"""
LLM 路由决策器
LLM 来决定下一步应该做什么
"""
def __init__(self, llm_service):
self.llm_service = llm_service
async def decide_after_recon(self, state: AuditState) -> Dict[str, Any]:
"""Recon 后让 LLM 决定下一步"""
entry_points = state.get("entry_points", [])
high_risk_areas = state.get("high_risk_areas", [])
tech_stack = state.get("tech_stack", {})
initial_findings = state.get("findings", [])
prompt = f"""作为安全审计的决策者,基于以下信息收集结果,决定下一步行动。
## 信息收集结果
- 入口点数量: {len(entry_points)}
- 高风险区域: {high_risk_areas[:10]}
- 技术栈: {tech_stack}
- 初步发现: {len(initial_findings)}
## 选项
1. "analysis" - 继续进行漏洞分析推荐有入口点或高风险区域时
2. "end" - 结束审计仅当没有任何可分析内容时
请返回 JSON 格式
{{"action": "analysis或end", "reason": "决策理由"}}"""
try:
response = await self.llm_service.chat_completion_raw(
messages=[
{"role": "system", "content": "你是安全审计流程的决策者,负责决定下一步行动。"},
{"role": "user", "content": prompt},
],
# 🔥 不传递 temperature 和 max_tokens使用用户配置
)
content = response.get("content", "")
# 提取 JSON
import re
json_match = re.search(r'\{.*\}', content, re.DOTALL)
if json_match:
result = json.loads(json_match.group())
return result
except Exception as e:
logger.warning(f"LLM routing decision failed: {e}")
# 默认决策
if entry_points or high_risk_areas:
return {"action": "analysis", "reason": "有可分析内容"}
return {"action": "end", "reason": "没有发现入口点或高风险区域"}
async def decide_after_analysis(self, state: AuditState) -> Dict[str, Any]:
"""Analysis 后让 LLM 决定下一步"""
findings = state.get("findings", [])
iteration = state.get("iteration", 0)
max_iterations = state.get("max_iterations", 3)
# 统计发现
severity_counts = {"critical": 0, "high": 0, "medium": 0, "low": 0}
for f in findings:
# 跳过非字典类型的 finding
if not isinstance(f, dict):
continue
sev = f.get("severity", "medium")
severity_counts[sev] = severity_counts.get(sev, 0) + 1
prompt = f"""作为安全审计的决策者,基于以下分析结果,决定下一步行动。
## 分析结果
- 总发现数: {len(findings)}
- 严重程度分布: {severity_counts}
- 当前迭代: {iteration}/{max_iterations}
## 选项
1. "verification" - 验证发现的漏洞推荐有发现需要验证时
2. "analysis" - 继续深入分析推荐发现较少但还有迭代次数时
3. "report" - 生成报告推荐没有发现或已充分分析时
请返回 JSON 格式
{{"action": "verification/analysis/report", "reason": "决策理由"}}"""
try:
response = await self.llm_service.chat_completion_raw(
messages=[
{"role": "system", "content": "你是安全审计流程的决策者,负责决定下一步行动。"},
{"role": "user", "content": prompt},
],
# 🔥 不传递 temperature 和 max_tokens使用用户配置
)
content = response.get("content", "")
import re
json_match = re.search(r'\{.*\}', content, re.DOTALL)
if json_match:
result = json.loads(json_match.group())
return result
except Exception as e:
logger.warning(f"LLM routing decision failed: {e}")
# 默认决策
if not findings:
return {"action": "report", "reason": "没有发现漏洞"}
if len(findings) >= 3 or iteration >= max_iterations:
return {"action": "verification", "reason": "有足够的发现需要验证"}
return {"action": "analysis", "reason": "发现较少,继续分析"}
async def decide_after_verification(self, state: AuditState) -> Dict[str, Any]:
"""Verification 后让 LLM 决定下一步"""
verified_findings = state.get("verified_findings", [])
false_positives = state.get("false_positives", [])
iteration = state.get("iteration", 0)
max_iterations = state.get("max_iterations", 3)
prompt = f"""作为安全审计的决策者,基于以下验证结果,决定下一步行动。
## 验证结果
- 已确认漏洞: {len(verified_findings)}
- 误报数量: {len(false_positives)}
- 当前迭代: {iteration}/{max_iterations}
## 选项
1. "analysis" - 回到分析阶段重新分析推荐误报率太高时
2. "report" - 生成最终报告推荐验证完成时
请返回 JSON 格式
{{"action": "analysis/report", "reason": "决策理由"}}"""
try:
response = await self.llm_service.chat_completion_raw(
messages=[
{"role": "system", "content": "你是安全审计流程的决策者,负责决定下一步行动。"},
{"role": "user", "content": prompt},
],
# 🔥 不传递 temperature 和 max_tokens使用用户配置
)
content = response.get("content", "")
import re
json_match = re.search(r'\{.*\}', content, re.DOTALL)
if json_match:
result = json.loads(json_match.group())
return result
except Exception as e:
logger.warning(f"LLM routing decision failed: {e}")
# 默认决策
if len(false_positives) > len(verified_findings) and iteration < max_iterations:
return {"action": "analysis", "reason": "误报率较高,需要重新分析"}
return {"action": "report", "reason": "验证完成,生成报告"}
# ============ 路由函数 (结合 LLM 决策) ============
def route_after_recon(state: AuditState) -> Literal["analysis", "end"]:
"""
Recon 后的路由决策
优先使用 LLM 的决策否则使用默认逻辑
"""
# 🔥 检查是否有错误
if state.get("error") or state.get("current_phase") == "error":
logger.error(f"Recon phase has error, routing to end: {state.get('error')}")
return "end"
# 检查 LLM 是否有决策
llm_action = state.get("llm_next_action")
if llm_action:
logger.info(f"Using LLM routing decision: {llm_action}, reason: {state.get('llm_routing_reason')}")
if llm_action == "end":
return "end"
return "analysis"
# 默认逻辑(作为 fallback
if not state.get("entry_points") and not state.get("high_risk_areas"):
return "end"
return "analysis"
def route_after_analysis(state: AuditState) -> Literal["verification", "analysis", "report"]:
"""
Analysis 后的路由决策
优先使用 LLM 的决策
"""
# 检查 LLM 是否有决策
llm_action = state.get("llm_next_action")
if llm_action:
logger.info(f"Using LLM routing decision: {llm_action}, reason: {state.get('llm_routing_reason')}")
if llm_action == "verification":
return "verification"
elif llm_action == "analysis":
return "analysis"
elif llm_action == "report":
return "report"
# 默认逻辑
findings = state.get("findings", [])
iteration = state.get("iteration", 0)
max_iterations = state.get("max_iterations", 3)
should_continue = state.get("should_continue_analysis", False)
if not findings:
return "report"
if should_continue and iteration < max_iterations:
return "analysis"
return "verification"
def route_after_verification(state: AuditState) -> Literal["analysis", "report"]:
"""
Verification 后的路由决策
优先使用 LLM 的决策
"""
# 检查 LLM 是否有决策
llm_action = state.get("llm_next_action")
if llm_action:
logger.info(f"Using LLM routing decision: {llm_action}, reason: {state.get('llm_routing_reason')}")
if llm_action == "analysis":
return "analysis"
return "report"
# 默认逻辑
false_positives = state.get("false_positives", [])
iteration = state.get("iteration", 0)
max_iterations = state.get("max_iterations", 3)
if len(false_positives) > len(state.get("verified_findings", [])) and iteration < max_iterations:
return "analysis"
return "report"
# ============ 创建审计图 ============
def create_audit_graph(
recon_node,
analysis_node,
verification_node,
report_node,
checkpointer: Optional[MemorySaver] = None,
llm_service=None, # 用于 LLM 路由决策
) -> StateGraph:
"""
创建审计工作流图
Args:
recon_node: 信息收集节点
analysis_node: 漏洞分析节点
verification_node: 漏洞验证节点
report_node: 报告生成节点
checkpointer: 检查点存储器
llm_service: LLM 服务用于路由决策
Returns:
编译后的 StateGraph
工作流结构:
START
Recon 信息收集 (LLM 驱动)
LLM 决定
Analysis 漏洞分析 (LLM 驱动可循环)
LLM 决定
Verification 漏洞验证 (LLM 驱动可回溯)
LLM 决定
Report 报告生成
END
"""
# 创建状态图
workflow = StateGraph(AuditState)
# 如果有 LLM 服务,创建路由决策器
llm_router = LLMRouter(llm_service) if llm_service else None
# 包装节点以添加 LLM 路由决策
async def recon_with_routing(state):
result = await recon_node(state)
# LLM 决定下一步
if llm_router:
decision = await llm_router.decide_after_recon({**state, **result})
result["llm_next_action"] = decision.get("action")
result["llm_routing_reason"] = decision.get("reason")
return result
async def analysis_with_routing(state):
result = await analysis_node(state)
# LLM 决定下一步
if llm_router:
decision = await llm_router.decide_after_analysis({**state, **result})
result["llm_next_action"] = decision.get("action")
result["llm_routing_reason"] = decision.get("reason")
return result
async def verification_with_routing(state):
result = await verification_node(state)
# LLM 决定下一步
if llm_router:
decision = await llm_router.decide_after_verification({**state, **result})
result["llm_next_action"] = decision.get("action")
result["llm_routing_reason"] = decision.get("reason")
return result
# 添加节点
if llm_router:
workflow.add_node("recon", recon_with_routing)
workflow.add_node("analysis", analysis_with_routing)
workflow.add_node("verification", verification_with_routing)
else:
workflow.add_node("recon", recon_node)
workflow.add_node("analysis", analysis_node)
workflow.add_node("verification", verification_node)
workflow.add_node("report", report_node)
# 设置入口点
workflow.set_entry_point("recon")
# 添加条件边
workflow.add_conditional_edges(
"recon",
route_after_recon,
{
"analysis": "analysis",
"end": END,
}
)
workflow.add_conditional_edges(
"analysis",
route_after_analysis,
{
"verification": "verification",
"analysis": "analysis",
"report": "report",
}
)
workflow.add_conditional_edges(
"verification",
route_after_verification,
{
"analysis": "analysis",
"report": "report",
}
)
# Report -> END
workflow.add_edge("report", END)
# 编译图
if checkpointer:
return workflow.compile(checkpointer=checkpointer)
else:
return workflow.compile()
# ============ 带人机协作的审计图 ============
def create_audit_graph_with_human(
recon_node,
analysis_node,
verification_node,
report_node,
human_review_node,
checkpointer: Optional[MemorySaver] = None,
llm_service=None,
) -> StateGraph:
"""
创建带人机协作的审计工作流图
在验证阶段后增加人工审核节点
"""
workflow = StateGraph(AuditState)
llm_router = LLMRouter(llm_service) if llm_service else None
# 包装节点
async def recon_with_routing(state):
result = await recon_node(state)
if llm_router:
decision = await llm_router.decide_after_recon({**state, **result})
result["llm_next_action"] = decision.get("action")
result["llm_routing_reason"] = decision.get("reason")
return result
async def analysis_with_routing(state):
result = await analysis_node(state)
if llm_router:
decision = await llm_router.decide_after_analysis({**state, **result})
result["llm_next_action"] = decision.get("action")
result["llm_routing_reason"] = decision.get("reason")
return result
# 添加节点
if llm_router:
workflow.add_node("recon", recon_with_routing)
workflow.add_node("analysis", analysis_with_routing)
else:
workflow.add_node("recon", recon_node)
workflow.add_node("analysis", analysis_node)
workflow.add_node("verification", verification_node)
workflow.add_node("human_review", human_review_node)
workflow.add_node("report", report_node)
workflow.set_entry_point("recon")
workflow.add_conditional_edges(
"recon",
route_after_recon,
{"analysis": "analysis", "end": END}
)
workflow.add_conditional_edges(
"analysis",
route_after_analysis,
{
"verification": "verification",
"analysis": "analysis",
"report": "report",
}
)
# Verification -> Human Review
workflow.add_edge("verification", "human_review")
# Human Review 后的路由
def route_after_human(state: AuditState) -> Literal["analysis", "report"]:
if state.get("should_continue_analysis"):
return "analysis"
return "report"
workflow.add_conditional_edges(
"human_review",
route_after_human,
{"analysis": "analysis", "report": "report"}
)
workflow.add_edge("report", END)
if checkpointer:
return workflow.compile(checkpointer=checkpointer, interrupt_before=["human_review"])
else:
return workflow.compile()
# ============ 执行器 ============
class AuditGraphRunner:
"""
审计图执行器
封装 LangGraph 工作流的执行
"""
def __init__(
self,
graph: StateGraph,
event_emitter=None,
):
self.graph = graph
self.event_emitter = event_emitter
async def run(
self,
project_root: str,
project_info: Dict[str, Any],
config: Dict[str, Any],
task_id: str,
) -> Dict[str, Any]:
"""
执行审计工作流
"""
# 初始状态
initial_state: AuditState = {
"project_root": project_root,
"project_info": project_info,
"config": config,
"task_id": task_id,
"tech_stack": {},
"entry_points": [],
"high_risk_areas": [],
"dependencies": {},
"findings": [],
"verified_findings": [],
"false_positives": [],
"_verified_findings_update": None, # 🔥 NEW: 验证后的 findings 更新
"current_phase": "start",
"iteration": 0,
"max_iterations": config.get("max_iterations", 3),
"should_continue_analysis": False,
"llm_next_action": None,
"llm_routing_reason": None,
"messages": [],
"events": [],
"summary": None,
"security_score": None,
"error": None,
}
run_config = {
"configurable": {
"thread_id": task_id,
}
}
try:
async for event in self.graph.astream(initial_state, config=run_config):
if self.event_emitter:
for node_name, node_state in event.items():
await self.event_emitter.emit_info(
f"节点 {node_name} 完成"
)
# 发射 LLM 路由决策事件
if node_state.get("llm_routing_reason"):
await self.event_emitter.emit_info(
f"🧠 LLM 决策: {node_state.get('llm_next_action')} - {node_state.get('llm_routing_reason')}"
)
if node_name == "analysis" and node_state.get("findings"):
new_findings = node_state["findings"]
await self.event_emitter.emit_info(
f"发现 {len(new_findings)} 个潜在漏洞"
)
final_state = self.graph.get_state(run_config)
return final_state.values
except Exception as e:
logger.error(f"Graph execution failed: {e}", exc_info=True)
raise
async def run_with_human_review(
self,
initial_state: AuditState,
human_feedback_callback,
) -> Dict[str, Any]:
"""带人机协作的执行"""
run_config = {
"configurable": {
"thread_id": initial_state["task_id"],
}
}
async for event in self.graph.astream(initial_state, config=run_config):
pass
current_state = self.graph.get_state(run_config)
if current_state.next == ("human_review",):
human_decision = await human_feedback_callback(current_state.values)
updated_state = {
**current_state.values,
"should_continue_analysis": human_decision.get("continue_analysis", False),
}
async for event in self.graph.astream(updated_state, config=run_config):
pass
return self.graph.get_state(run_config).values

View File

@ -1,556 +0,0 @@
"""
LangGraph 节点实现
每个节点封装一个 Agent 的执行逻辑
协作增强节点之间通过 TaskHandoff 传递结构化的上下文和洞察
"""
from typing import Dict, Any, List, Optional
import logging
logger = logging.getLogger(__name__)
# 延迟导入避免循环依赖
def get_audit_state_type():
from .audit_graph import AuditState
return AuditState
class BaseNode:
"""节点基类"""
def __init__(self, agent=None, event_emitter=None):
self.agent = agent
self.event_emitter = event_emitter
async def emit_event(self, event_type: str, message: str, **kwargs):
"""发射事件"""
if self.event_emitter:
try:
await self.event_emitter.emit_info(message)
except Exception as e:
logger.warning(f"Failed to emit event: {e}")
def _extract_handoff_from_state(self, state: Dict[str, Any], from_phase: str):
"""从状态中提取前序 Agent 的 handoff"""
handoff_data = state.get(f"{from_phase}_handoff")
if handoff_data:
from ..agents.base import TaskHandoff
return TaskHandoff.from_dict(handoff_data)
return None
class ReconNode(BaseNode):
"""
信息收集节点
输入: project_root, project_info, config
输出: tech_stack, entry_points, high_risk_areas, dependencies, recon_handoff
"""
async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]:
"""执行信息收集"""
await self.emit_event("phase_start", "🔍 开始信息收集阶段")
try:
# 调用 Recon Agent
result = await self.agent.run({
"project_info": state["project_info"],
"config": state["config"],
})
if result.success and result.data:
data = result.data
# 🔥 创建交接信息给 Analysis Agent
handoff = self.agent.create_handoff(
to_agent="Analysis",
summary=f"项目信息收集完成。发现 {len(data.get('entry_points', []))} 个入口点,{len(data.get('high_risk_areas', []))} 个高风险区域。",
key_findings=data.get("initial_findings", []),
suggested_actions=[
{
"type": "deep_analysis",
"description": f"深入分析高风险区域: {', '.join(data.get('high_risk_areas', [])[:5])}",
"priority": "high",
},
{
"type": "entry_point_audit",
"description": "审计所有入口点的输入验证",
"priority": "high",
},
],
attention_points=[
f"技术栈: {data.get('tech_stack', {}).get('frameworks', [])}",
f"主要语言: {data.get('tech_stack', {}).get('languages', [])}",
],
priority_areas=data.get("high_risk_areas", [])[:10],
context_data={
"tech_stack": data.get("tech_stack", {}),
"entry_points": data.get("entry_points", []),
"dependencies": data.get("dependencies", {}),
},
)
await self.emit_event(
"phase_complete",
f"✅ 信息收集完成: 发现 {len(data.get('entry_points', []))} 个入口点"
)
return {
"tech_stack": data.get("tech_stack", {}),
"entry_points": data.get("entry_points", []),
"high_risk_areas": data.get("high_risk_areas", []),
"dependencies": data.get("dependencies", {}),
"current_phase": "recon_complete",
"findings": data.get("initial_findings", []),
# 🔥 保存交接信息
"recon_handoff": handoff.to_dict(),
"events": [{
"type": "recon_complete",
"data": {
"entry_points_count": len(data.get("entry_points", [])),
"high_risk_areas_count": len(data.get("high_risk_areas", [])),
"handoff_summary": handoff.summary,
}
}],
}
else:
return {
"error": result.error or "Recon failed",
"current_phase": "error",
}
except Exception as e:
logger.error(f"Recon node failed: {e}", exc_info=True)
return {
"error": str(e),
"current_phase": "error",
}
class AnalysisNode(BaseNode):
"""
漏洞分析节点
输入: tech_stack, entry_points, high_risk_areas, recon_handoff
输出: findings (累加), should_continue_analysis, analysis_handoff
"""
async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]:
"""执行漏洞分析"""
iteration = state.get("iteration", 0) + 1
await self.emit_event(
"phase_start",
f"🔬 开始漏洞分析阶段 (迭代 {iteration})"
)
try:
# 🔥 提取 Recon 的交接信息
recon_handoff = self._extract_handoff_from_state(state, "recon")
if recon_handoff:
self.agent.receive_handoff(recon_handoff)
await self.emit_event(
"handoff_received",
f"📨 收到 Recon Agent 交接: {recon_handoff.summary[:50]}..."
)
# 构建分析输入
analysis_input = {
"phase_name": "analysis",
"project_info": state["project_info"],
"config": state["config"],
"plan": {
"high_risk_areas": state.get("high_risk_areas", []),
},
"previous_results": {
"recon": {
"data": {
"tech_stack": state.get("tech_stack", {}),
"entry_points": state.get("entry_points", []),
"high_risk_areas": state.get("high_risk_areas", []),
}
}
},
# 🔥 传递交接信息
"handoff": recon_handoff,
}
# 调用 Analysis Agent
result = await self.agent.run(analysis_input)
if result.success and result.data:
new_findings = result.data.get("findings", [])
logger.info(f"[AnalysisNode] Agent returned {len(new_findings)} findings")
# 判断是否需要继续分析
should_continue = (
len(new_findings) >= 5 and
iteration < state.get("max_iterations", 3)
)
# 🔥 创建交接信息给 Verification Agent
# 统计严重程度
severity_counts = {"critical": 0, "high": 0, "medium": 0, "low": 0}
for f in new_findings:
if isinstance(f, dict):
sev = f.get("severity", "medium")
severity_counts[sev] = severity_counts.get(sev, 0) + 1
handoff = self.agent.create_handoff(
to_agent="Verification",
summary=f"漏洞分析完成。发现 {len(new_findings)} 个潜在漏洞 (Critical: {severity_counts['critical']}, High: {severity_counts['high']}, Medium: {severity_counts['medium']}, Low: {severity_counts['low']})",
key_findings=new_findings[:20], # 传递前20个发现
suggested_actions=[
{
"type": "verify_critical",
"description": "优先验证 Critical 和 High 级别的漏洞",
"priority": "critical",
},
{
"type": "poc_generation",
"description": "为确认的漏洞生成 PoC",
"priority": "high",
},
],
attention_points=[
f"{severity_counts['critical']} 个 Critical 级别漏洞需要立即验证",
f"{severity_counts['high']} 个 High 级别漏洞需要优先验证",
"注意检查是否有误报,特别是静态分析工具的结果",
],
priority_areas=[
f.get("file_path", "") for f in new_findings
if f.get("severity") in ["critical", "high"]
][:10],
context_data={
"severity_distribution": severity_counts,
"total_findings": len(new_findings),
"iteration": iteration,
},
)
await self.emit_event(
"phase_complete",
f"✅ 分析迭代 {iteration} 完成: 发现 {len(new_findings)} 个潜在漏洞"
)
return {
"findings": new_findings,
"iteration": iteration,
"should_continue_analysis": should_continue,
"current_phase": "analysis_complete",
# 🔥 保存交接信息
"analysis_handoff": handoff.to_dict(),
"events": [{
"type": "analysis_iteration",
"data": {
"iteration": iteration,
"findings_count": len(new_findings),
"severity_distribution": severity_counts,
"handoff_summary": handoff.summary,
}
}],
}
else:
return {
"iteration": iteration,
"should_continue_analysis": False,
"current_phase": "analysis_complete",
}
except Exception as e:
logger.error(f"Analysis node failed: {e}", exc_info=True)
return {
"error": str(e),
"should_continue_analysis": False,
"current_phase": "error",
}
class VerificationNode(BaseNode):
"""
漏洞验证节点
输入: findings, analysis_handoff
输出: verified_findings, false_positives, verification_handoff
"""
async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]:
"""执行漏洞验证"""
findings = state.get("findings", [])
logger.info(f"[VerificationNode] Received {len(findings)} findings to verify")
if not findings:
return {
"verified_findings": [],
"false_positives": [],
"current_phase": "verification_complete",
}
await self.emit_event(
"phase_start",
f"🔐 开始漏洞验证阶段 ({len(findings)} 个待验证)"
)
try:
# 🔥 提取 Analysis 的交接信息
analysis_handoff = self._extract_handoff_from_state(state, "analysis")
if analysis_handoff:
self.agent.receive_handoff(analysis_handoff)
await self.emit_event(
"handoff_received",
f"📨 收到 Analysis Agent 交接: {analysis_handoff.summary[:50]}..."
)
# 构建验证输入
verification_input = {
"previous_results": {
"analysis": {
"data": {
"findings": findings,
}
}
},
"config": state["config"],
# 🔥 传递交接信息
"handoff": analysis_handoff,
}
# 调用 Verification Agent
result = await self.agent.run(verification_input)
if result.success and result.data:
all_verified_findings = result.data.get("findings", [])
verified = [f for f in all_verified_findings if f.get("is_verified")]
false_pos = [f.get("id", f.get("title", "unknown")) for f in all_verified_findings
if f.get("verdict") == "false_positive"]
# 🔥 CRITICAL FIX: 用验证结果更新原始 findings
# 创建 findings 的更新映射,基于 (file_path, line_start, vulnerability_type)
verified_map = {}
for vf in all_verified_findings:
key = (
vf.get("file_path", ""),
vf.get("line_start", 0),
vf.get("vulnerability_type", ""),
)
verified_map[key] = vf
# 合并验证结果到原始 findings
updated_findings = []
seen_keys = set()
# 首先处理原始 findings用验证结果更新
for f in findings:
if not isinstance(f, dict):
continue
key = (
f.get("file_path", ""),
f.get("line_start", 0),
f.get("vulnerability_type", ""),
)
if key in verified_map:
# 使用验证后的版本
updated_findings.append(verified_map[key])
seen_keys.add(key)
else:
# 保留原始(未验证)
updated_findings.append(f)
seen_keys.add(key)
# 添加验证结果中的新发现(如果有)
for key, vf in verified_map.items():
if key not in seen_keys:
updated_findings.append(vf)
logger.info(f"[VerificationNode] Updated findings: {len(updated_findings)} total, {len(verified)} verified")
# 🔥 创建交接信息给 Report 节点
handoff = self.agent.create_handoff(
to_agent="Report",
summary=f"漏洞验证完成。{len(verified)} 个漏洞已确认,{len(false_pos)} 个误报已排除。",
key_findings=verified,
suggested_actions=[
{
"type": "generate_report",
"description": "生成详细的安全审计报告",
"priority": "high",
},
{
"type": "remediation_plan",
"description": "为确认的漏洞制定修复计划",
"priority": "high",
},
],
attention_points=[
f"{len(verified)} 个漏洞已确认存在",
f"{len(false_pos)} 个误报已排除",
"建议按严重程度优先修复 Critical 和 High 级别漏洞",
],
context_data={
"verified_count": len(verified),
"false_positive_count": len(false_pos),
"total_analyzed": len(findings),
"verification_rate": len(verified) / len(findings) if findings else 0,
},
)
await self.emit_event(
"phase_complete",
f"✅ 验证完成: {len(verified)} 已确认, {len(false_pos)} 误报"
)
return {
# 🔥 CRITICAL: 返回更新后的 findings这会替换状态中的 findings
# 注意:由于 LangGraph 使用 operator.add我们需要在 runner 中处理合并
# 这里我们返回 _verified_findings_update 作为特殊字段
"_verified_findings_update": updated_findings,
"verified_findings": verified,
"false_positives": false_pos,
"current_phase": "verification_complete",
# 🔥 保存交接信息
"verification_handoff": handoff.to_dict(),
"events": [{
"type": "verification_complete",
"data": {
"verified_count": len(verified),
"false_positive_count": len(false_pos),
"total_findings": len(updated_findings),
"handoff_summary": handoff.summary,
}
}],
}
else:
return {
"verified_findings": [],
"false_positives": [],
"current_phase": "verification_complete",
}
except Exception as e:
logger.error(f"Verification node failed: {e}", exc_info=True)
return {
"error": str(e),
"current_phase": "error",
}
class ReportNode(BaseNode):
"""
报告生成节点
输入: all state
输出: summary, security_score
"""
async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]:
"""生成审计报告"""
await self.emit_event("phase_start", "📊 生成审计报告")
try:
# 🔥 CRITICAL FIX: 优先使用验证后的 findings 更新
findings = state.get("_verified_findings_update") or state.get("findings", [])
verified = state.get("verified_findings", [])
false_positives = state.get("false_positives", [])
logger.info(f"[ReportNode] State contains {len(findings)} findings, {len(verified)} verified")
# 统计漏洞分布
severity_counts = {"critical": 0, "high": 0, "medium": 0, "low": 0}
type_counts = {}
for finding in findings:
# 跳过非字典类型的 finding防止数据格式异常
if not isinstance(finding, dict):
logger.warning(f"Skipping invalid finding (not a dict): {type(finding)}")
continue
sev = finding.get("severity", "medium")
severity_counts[sev] = severity_counts.get(sev, 0) + 1
vtype = finding.get("vulnerability_type", "other")
type_counts[vtype] = type_counts.get(vtype, 0) + 1
# 计算安全评分
base_score = 100
deductions = (
severity_counts["critical"] * 25 +
severity_counts["high"] * 15 +
severity_counts["medium"] * 8 +
severity_counts["low"] * 3
)
security_score = max(0, base_score - deductions)
# 生成摘要
summary = {
"total_findings": len(findings),
"verified_count": len(verified),
"false_positive_count": len(false_positives),
"severity_distribution": severity_counts,
"vulnerability_types": type_counts,
"tech_stack": state.get("tech_stack", {}),
"entry_points_analyzed": len(state.get("entry_points", [])),
"high_risk_areas": state.get("high_risk_areas", []),
"iterations": state.get("iteration", 1),
}
await self.emit_event(
"phase_complete",
f"报告生成完成: 安全评分 {security_score}/100"
)
return {
"summary": summary,
"security_score": security_score,
"current_phase": "complete",
"events": [{
"type": "audit_complete",
"data": {
"security_score": security_score,
"total_findings": len(findings),
"verified_count": len(verified),
}
}],
}
except Exception as e:
logger.error(f"Report node failed: {e}", exc_info=True)
return {
"error": str(e),
"current_phase": "error",
}
class HumanReviewNode(BaseNode):
"""
人工审核节点
在此节点暂停等待人工反馈
"""
async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]:
"""
人工审核节点
这个节点会被 interrupt_before 暂停
用户可以
1. 确认发现
2. 标记误报
3. 请求重新分析
"""
await self.emit_event(
"human_review",
f"⏸️ 等待人工审核 ({len(state.get('verified_findings', []))} 个待确认)"
)
# 返回当前状态,不做修改
# 人工反馈会通过 update_state 传入
return {
"current_phase": "human_review",
"messages": [{
"role": "system",
"content": "等待人工审核",
"findings_for_review": state.get("verified_findings", []),
}],
}

File diff suppressed because it is too large Load Diff

View File

@ -6,6 +6,7 @@
import os
import re
import fnmatch
import asyncio
from typing import Optional, List, Dict, Any
from pydantic import BaseModel, Field
@ -45,6 +46,36 @@ class FileReadTool(AgentTool):
self.exclude_patterns = exclude_patterns or []
self.target_files = set(target_files) if target_files else None
@staticmethod
def _read_file_lines_sync(file_path: str, start_idx: int, end_idx: int) -> tuple:
"""同步读取文件指定行范围(用于 asyncio.to_thread"""
selected_lines = []
total_lines = 0
file_size = os.path.getsize(file_path)
with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
for i, line in enumerate(f):
total_lines = i + 1
if i >= start_idx and i < end_idx:
selected_lines.append(line)
elif i >= end_idx:
if i < end_idx + 1000:
continue
else:
remaining_bytes = file_size - f.tell()
avg_line_size = f.tell() / (i + 1)
estimated_remaining_lines = int(remaining_bytes / avg_line_size) if avg_line_size > 0 else 0
total_lines = i + 1 + estimated_remaining_lines
break
return selected_lines, total_lines
@staticmethod
def _read_all_lines_sync(file_path: str) -> List[str]:
"""同步读取文件所有行(用于 asyncio.to_thread"""
with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
return f.readlines()
@property
def name(self) -> str:
return "read_file"
@ -136,37 +167,20 @@ class FileReadTool(AgentTool):
# 🔥 对于大文件,使用流式读取指定行范围
if is_large_file and (start_line is not None or end_line is not None):
# 流式读取,避免一次性加载整个文件
selected_lines = []
total_lines = 0
# 计算实际的起始和结束行
start_idx = max(0, (start_line or 1) - 1)
end_idx = end_line if end_line else start_idx + max_lines
with open(full_path, 'r', encoding='utf-8', errors='ignore') as f:
for i, line in enumerate(f):
total_lines = i + 1
if i >= start_idx and i < end_idx:
selected_lines.append(line)
elif i >= end_idx:
# 继续计数以获取总行数,但限制读取量
if i < end_idx + 1000: # 最多再读1000行来估算总行数
continue
else:
# 估算剩余行数
remaining_bytes = file_size - f.tell()
avg_line_size = f.tell() / (i + 1)
estimated_remaining_lines = int(remaining_bytes / avg_line_size) if avg_line_size > 0 else 0
total_lines = i + 1 + estimated_remaining_lines
break
# 异步读取文件,避免阻塞事件循环
selected_lines, total_lines = await asyncio.to_thread(
self._read_file_lines_sync, full_path, start_idx, end_idx
)
# 更新实际的结束索引
end_idx = min(end_idx, start_idx + len(selected_lines))
else:
# 正常读取小文件
with open(full_path, 'r', encoding='utf-8', errors='ignore') as f:
lines = f.readlines()
# 异步读取小文件,避免阻塞事件循环
lines = await asyncio.to_thread(self._read_all_lines_sync, full_path)
total_lines = len(lines)
@ -268,6 +282,12 @@ class FileSearchTool(AgentTool):
elif "/" not in pattern and "*" not in pattern:
self.exclude_dirs.add(pattern)
@staticmethod
def _read_file_lines_sync(file_path: str) -> List[str]:
"""同步读取文件所有行(用于 asyncio.to_thread"""
with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
return f.readlines()
@property
def name(self) -> str:
return "search_code"
@ -360,8 +380,10 @@ class FileSearchTool(AgentTool):
continue
try:
with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
lines = f.readlines()
# 异步读取文件,避免阻塞事件循环
lines = await asyncio.to_thread(
self._read_file_lines_sync, file_path
)
files_searched += 1

View File

@ -416,13 +416,93 @@ class LiteLLMAdapter(BaseLLMAdapter):
"finish_reason": "complete",
}
except Exception as e:
# 🔥 即使出错,也尝试返回估算的 usage
logger.error(f"Stream error: {e}")
except litellm.exceptions.RateLimitError as e:
# 速率限制错误 - 需要特殊处理
logger.error(f"Stream rate limit error: {e}")
error_msg = str(e)
# 区分"余额不足"和"频率超限"
if any(keyword in error_msg.lower() for keyword in ["余额不足", "资源包", "充值", "quota", "exceeded", "billing"]):
error_type = "quota_exceeded"
user_message = "API 配额已用尽,请检查账户余额或升级计划"
else:
error_type = "rate_limit"
# 尝试从错误消息中提取重试时间
import re
retry_match = re.search(r"retry\s*(?:in|after)\s*(\d+(?:\.\d+)?)\s*s", error_msg, re.IGNORECASE)
retry_seconds = float(retry_match.group(1)) if retry_match else 60
user_message = f"API 调用频率超限,建议等待 {int(retry_seconds)} 秒后重试"
output_tokens_estimate = estimate_tokens(accumulated_content) if accumulated_content else 0
yield {
"type": "error",
"error_type": error_type,
"error": error_msg,
"user_message": user_message,
"accumulated": accumulated_content,
"usage": {
"prompt_tokens": input_tokens_estimate,
"completion_tokens": output_tokens_estimate,
"total_tokens": input_tokens_estimate + output_tokens_estimate,
} if accumulated_content else None,
}
except litellm.exceptions.AuthenticationError as e:
# 认证错误 - API Key 无效
logger.error(f"Stream authentication error: {e}")
yield {
"type": "error",
"error_type": "authentication",
"error": str(e),
"user_message": "API Key 无效或已过期,请检查配置",
"accumulated": accumulated_content,
"usage": None,
}
except litellm.exceptions.APIConnectionError as e:
# 连接错误 - 网络问题
logger.error(f"Stream connection error: {e}")
yield {
"type": "error",
"error_type": "connection",
"error": str(e),
"user_message": "无法连接到 API 服务,请检查网络连接",
"accumulated": accumulated_content,
"usage": None,
}
except Exception as e:
# 其他错误 - 检查是否是包装的速率限制错误
error_msg = str(e)
logger.error(f"Stream error: {e}")
# 检查是否是包装的速率限制错误(如 ServiceUnavailableError 包装 RateLimitError
is_rate_limit = any(keyword in error_msg.lower() for keyword in [
"ratelimiterror", "rate limit", "429", "resource_exhausted",
"quota exceeded", "too many requests"
])
if is_rate_limit:
# 按速率限制错误处理
import re
# 检查是否是配额用尽
if any(keyword in error_msg.lower() for keyword in ["quota", "exceeded", "billing"]):
error_type = "quota_exceeded"
user_message = "API 配额已用尽,请检查账户余额或升级计划"
else:
error_type = "rate_limit"
retry_match = re.search(r"retry\s*(?:in|after)\s*(\d+(?:\.\d+)?)\s*s", error_msg, re.IGNORECASE)
retry_seconds = float(retry_match.group(1)) if retry_match else 60
user_message = f"API 调用频率超限,建议等待 {int(retry_seconds)} 秒后重试"
else:
error_type = "unknown"
user_message = "LLM 调用发生错误,请重试"
output_tokens_estimate = estimate_tokens(accumulated_content) if accumulated_content else 0
yield {
"type": "error",
"error_type": error_type,
"error": error_msg,
"user_message": user_message,
"accumulated": accumulated_content,
"usage": {
"prompt_tokens": input_tokens_estimate,

View File

@ -739,6 +739,20 @@ class CodeIndexer:
self._needs_rebuild = False
self._rebuild_reason = ""
@staticmethod
def _read_file_sync(file_path: str) -> str:
"""
同步读取文件内容用于 asyncio.to_thread 包装
Args:
file_path: 文件路径
Returns:
文件内容
"""
with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
return f.read()
async def initialize(self, force_rebuild: bool = False) -> Tuple[bool, str]:
"""
初始化索引器检测是否需要重建索引
@ -916,8 +930,10 @@ class CodeIndexer:
try:
relative_path = os.path.relpath(file_path, directory)
with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
content = f.read()
# 异步读取文件,避免阻塞事件循环
content = await asyncio.to_thread(
self._read_file_sync, file_path
)
if not content.strip():
progress.processed_files += 1
@ -932,8 +948,8 @@ class CodeIndexer:
if len(content) > 500000:
content = content[:500000]
# 分块
chunks = self.splitter.split_file(content, relative_path)
# 异步分块,避免 Tree-sitter 解析阻塞事件循环
chunks = await self.splitter.split_file_async(content, relative_path)
# 为每个 chunk 添加 file_hash
for chunk in chunks:
@ -1018,8 +1034,10 @@ class CodeIndexer:
for relative_path in files_to_check:
file_path = current_file_map[relative_path]
try:
with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
content = f.read()
# 异步读取文件,避免阻塞事件循环
content = await asyncio.to_thread(
self._read_file_sync, file_path
)
current_hash = hashlib.md5(content.encode()).hexdigest()
if current_hash != indexed_file_hashes.get(relative_path):
files_to_update.add(relative_path)
@ -1055,8 +1073,10 @@ class CodeIndexer:
is_update = relative_path in files_to_update
try:
with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
content = f.read()
# 异步读取文件,避免阻塞事件循环
content = await asyncio.to_thread(
self._read_file_sync, file_path
)
if not content.strip():
progress.processed_files += 1
@ -1075,8 +1095,8 @@ class CodeIndexer:
if len(content) > 500000:
content = content[:500000]
# 分块
chunks = self.splitter.split_file(content, relative_path)
# 异步分块,避免 Tree-sitter 解析阻塞事件循环
chunks = await self.splitter.split_file_async(content, relative_path)
# 为每个 chunk 添加 file_hash
for chunk in chunks:

View File

@ -4,6 +4,7 @@
"""
import re
import asyncio
import hashlib
import logging
from typing import List, Dict, Any, Optional, Tuple, Set
@ -154,7 +155,7 @@ class TreeSitterParser:
".c": "c",
".h": "c",
".hpp": "cpp",
".cs": "c_sharp",
".cs": "csharp",
".php": "php",
".rb": "ruby",
".kt": "kotlin",
@ -197,7 +198,7 @@ class TreeSitterParser:
# tree-sitter-languages 支持的语言列表
SUPPORTED_LANGUAGES = {
"python", "javascript", "typescript", "tsx", "java", "go", "rust",
"c", "cpp", "c_sharp", "php", "ruby", "kotlin", "swift", "bash",
"c", "cpp", "csharp", "php", "ruby", "kotlin", "swift", "bash",
"json", "yaml", "html", "css", "sql", "markdown",
}
@ -230,7 +231,7 @@ class TreeSitterParser:
return False
def parse(self, code: str, language: str) -> Optional[Any]:
"""解析代码返回 AST"""
"""解析代码返回 AST(同步方法)"""
if not self._ensure_initialized(language):
return None
@ -245,6 +246,15 @@ class TreeSitterParser:
logger.warning(f"Failed to parse code: {e}")
return None
async def parse_async(self, code: str, language: str) -> Optional[Any]:
"""
异步解析代码返回 AST
CPU 密集型的 Tree-sitter 解析操作放到线程池中执行
避免阻塞事件循环
"""
return await asyncio.to_thread(self.parse, code, language)
def extract_definitions(self, tree: Any, code: str, language: str) -> List[Dict[str, Any]]:
"""从 AST 提取定义"""
if tree is None:
@ -452,6 +462,28 @@ class CodeSplitter:
return chunks
async def split_file_async(
self,
content: str,
file_path: str,
language: Optional[str] = None
) -> List[CodeChunk]:
"""
异步分割单个文件
CPU 密集型的分块操作包括 Tree-sitter 解析放到线程池中执行
避免阻塞事件循环
Args:
content: 文件内容
file_path: 文件路径
language: 编程语言可选
Returns:
代码块列表
"""
return await asyncio.to_thread(self.split_file, content, file_path, language)
def _split_by_ast(
self,
content: str,

View File

@ -1,355 +0,0 @@
"""
Agent 集成测试
测试完整的审计流程
"""
import pytest
import asyncio
import os
from unittest.mock import MagicMock, AsyncMock, patch
from datetime import datetime
from app.services.agent.graph.runner import AgentRunner, LLMService
from app.services.agent.graph.audit_graph import AuditState, create_audit_graph
from app.services.agent.graph.nodes import ReconNode, AnalysisNode, VerificationNode, ReportNode
from app.services.agent.event_manager import EventManager, AgentEventEmitter
class TestLLMService:
"""LLM 服务测试"""
@pytest.mark.asyncio
async def test_llm_service_initialization(self):
"""测试 LLM 服务初始化"""
with patch("app.core.config.settings") as mock_settings:
mock_settings.LLM_MODEL = "gpt-4o-mini"
mock_settings.LLM_API_KEY = "test-key"
service = LLMService()
assert service.model == "gpt-4o-mini"
class TestEventManager:
"""事件管理器测试"""
def test_event_manager_initialization(self):
"""测试事件管理器初始化"""
manager = EventManager()
assert manager._event_queues == {}
assert manager._event_callbacks == {}
@pytest.mark.asyncio
async def test_event_emitter(self):
"""测试事件发射器"""
manager = EventManager()
emitter = AgentEventEmitter("test-task-id", manager)
await emitter.emit_info("Test message")
assert emitter._sequence == 1
@pytest.mark.asyncio
async def test_event_emitter_phase_tracking(self):
"""测试事件发射器阶段跟踪"""
manager = EventManager()
emitter = AgentEventEmitter("test-task-id", manager)
await emitter.emit_phase_start("recon", "开始信息收集")
assert emitter._current_phase == "recon"
@pytest.mark.asyncio
async def test_event_emitter_task_complete(self):
"""测试任务完成事件"""
manager = EventManager()
emitter = AgentEventEmitter("test-task-id", manager)
await emitter.emit_task_complete(findings_count=5, duration_ms=1000)
assert emitter._sequence == 1
class TestAuditGraph:
"""审计图测试"""
def test_create_audit_graph(self, mock_event_emitter):
"""测试创建审计图"""
# 创建模拟节点
recon_node = MagicMock()
analysis_node = MagicMock()
verification_node = MagicMock()
report_node = MagicMock()
graph = create_audit_graph(
recon_node=recon_node,
analysis_node=analysis_node,
verification_node=verification_node,
report_node=report_node,
)
assert graph is not None
class TestReconNode:
"""Recon 节点测试"""
@pytest.fixture
def recon_node_with_mock_agent(self, mock_event_emitter):
"""创建带模拟 Agent 的 Recon 节点"""
mock_agent = MagicMock()
mock_agent.run = AsyncMock(return_value=MagicMock(
success=True,
data={
"tech_stack": {"languages": ["Python"]},
"entry_points": [{"path": "src/app.py", "type": "api"}],
"high_risk_areas": ["src/sql_vuln.py"],
"dependencies": {},
"initial_findings": [],
}
))
return ReconNode(mock_agent, mock_event_emitter)
@pytest.mark.asyncio
async def test_recon_node_success(self, recon_node_with_mock_agent):
"""测试 Recon 节点成功执行"""
state = {
"project_info": {"name": "Test"},
"config": {},
}
result = await recon_node_with_mock_agent(state)
assert "tech_stack" in result
assert "entry_points" in result
assert result["current_phase"] == "recon_complete"
@pytest.mark.asyncio
async def test_recon_node_failure(self, mock_event_emitter):
"""测试 Recon 节点失败处理"""
mock_agent = MagicMock()
mock_agent.run = AsyncMock(return_value=MagicMock(
success=False,
error="Test error",
data=None,
))
node = ReconNode(mock_agent, mock_event_emitter)
result = await node({
"project_info": {},
"config": {},
})
assert "error" in result
assert result["current_phase"] == "error"
class TestAnalysisNode:
"""Analysis 节点测试"""
@pytest.fixture
def analysis_node_with_mock_agent(self, mock_event_emitter):
"""创建带模拟 Agent 的 Analysis 节点"""
mock_agent = MagicMock()
mock_agent.run = AsyncMock(return_value=MagicMock(
success=True,
data={
"findings": [
{
"id": "finding-1",
"title": "SQL Injection",
"severity": "high",
"vulnerability_type": "sql_injection",
"file_path": "src/sql_vuln.py",
"line_start": 10,
"description": "SQL injection vulnerability",
}
],
"should_continue": False,
}
))
return AnalysisNode(mock_agent, mock_event_emitter)
@pytest.mark.asyncio
async def test_analysis_node_success(self, analysis_node_with_mock_agent):
"""测试 Analysis 节点成功执行"""
state = {
"project_info": {"name": "Test"},
"tech_stack": {"languages": ["Python"]},
"entry_points": [],
"high_risk_areas": ["src/sql_vuln.py"],
"config": {},
"iteration": 0,
"findings": [],
}
result = await analysis_node_with_mock_agent(state)
assert "findings" in result
assert len(result["findings"]) > 0
assert result["iteration"] == 1
class TestIntegrationFlow:
"""完整流程集成测试"""
@pytest.mark.asyncio
async def test_full_audit_flow_mock(self, temp_project_dir, mock_db_session, mock_task):
"""测试完整审计流程(使用模拟)"""
# 这个测试验证整个流程的连接性
# 创建事件管理器
event_manager = EventManager()
emitter = AgentEventEmitter(mock_task.id, event_manager)
# 模拟 LLM 服务
mock_llm = MagicMock()
mock_llm.chat_completion_raw = AsyncMock(return_value={
"content": "Analysis complete",
"usage": {"total_tokens": 100},
})
# 验证事件发射
await emitter.emit_phase_start("init", "初始化")
await emitter.emit_info("测试消息")
await emitter.emit_phase_complete("init", "初始化完成")
assert emitter._sequence == 3
@pytest.mark.asyncio
async def test_audit_state_typing(self):
"""测试审计状态类型定义"""
state: AuditState = {
"project_root": "/tmp/test",
"project_info": {"name": "Test"},
"config": {},
"task_id": "test-id",
"tech_stack": {},
"entry_points": [],
"high_risk_areas": [],
"dependencies": {},
"findings": [],
"verified_findings": [],
"false_positives": [],
"current_phase": "start",
"iteration": 0,
"max_iterations": 50,
"should_continue_analysis": False,
"messages": [],
"events": [],
"summary": None,
"security_score": None,
"error": None,
}
assert state["current_phase"] == "start"
assert state["max_iterations"] == 50
class TestToolIntegration:
"""工具集成测试"""
@pytest.mark.asyncio
async def test_tools_work_together(self, temp_project_dir):
"""测试工具协同工作"""
from app.services.agent.tools import (
FileReadTool, FileSearchTool, ListFilesTool, PatternMatchTool,
)
# 1. 列出文件
list_tool = ListFilesTool(temp_project_dir)
list_result = await list_tool.execute(directory="src", recursive=False)
assert list_result.success is True
# 2. 搜索关键代码
search_tool = FileSearchTool(temp_project_dir)
search_result = await search_tool.execute(keyword="execute")
assert search_result.success is True
# 3. 读取文件内容
read_tool = FileReadTool(temp_project_dir)
read_result = await read_tool.execute(file_path="src/sql_vuln.py")
assert read_result.success is True
# 4. 模式匹配
pattern_tool = PatternMatchTool(temp_project_dir)
pattern_result = await pattern_tool.execute(
code=read_result.data,
file_path="src/sql_vuln.py",
language="python"
)
assert pattern_result.success is True
class TestErrorHandling:
"""错误处理测试"""
@pytest.mark.asyncio
async def test_tool_error_handling(self, temp_project_dir):
"""测试工具错误处理"""
from app.services.agent.tools import FileReadTool
tool = FileReadTool(temp_project_dir)
# 尝试读取不存在的文件
result = await tool.execute(file_path="nonexistent/file.py")
assert result.success is False
assert result.error is not None
@pytest.mark.asyncio
async def test_agent_graceful_degradation(self, mock_event_emitter):
"""测试 Agent 优雅降级"""
# 创建一个会失败的 Agent
mock_agent = MagicMock()
mock_agent.run = AsyncMock(side_effect=Exception("Simulated error"))
node = ReconNode(mock_agent, mock_event_emitter)
result = await node({
"project_info": {},
"config": {},
})
# 应该返回错误状态而不是崩溃
assert "error" in result
assert result["current_phase"] == "error"
class TestPerformance:
"""性能测试"""
@pytest.mark.asyncio
async def test_tool_response_time(self, temp_project_dir):
"""测试工具响应时间"""
from app.services.agent.tools import ListFilesTool
import time
tool = ListFilesTool(temp_project_dir)
start = time.time()
await tool.execute(directory=".", recursive=True)
duration = time.time() - start
# 工具应该在合理时间内响应
assert duration < 5.0 # 5 秒内
@pytest.mark.asyncio
async def test_multiple_tool_calls(self, temp_project_dir):
"""测试多次工具调用"""
from app.services.agent.tools import FileSearchTool
tool = FileSearchTool(temp_project_dir)
# 执行多次调用
for _ in range(5):
result = await tool.execute(keyword="def")
assert result.success is True
# 验证调用计数
assert tool._call_count == 5

View File

@ -9,7 +9,7 @@ import {
} from "@/components/ui/select";
import { GitBranch, Zap, Info } from "lucide-react";
import type { Project, CreateAuditTaskForm } from "@/shared/types";
import { isRepositoryProject, isZipProject } from "@/shared/utils/projectUtils";
import { isRepositoryProject, isZipProject, getRepositoryPlatformLabel } from "@/shared/utils/projectUtils";
import ZipFileSection from "./ZipFileSection";
import type { ZipFileMeta } from "@/shared/utils/zipStorage";
@ -138,7 +138,7 @@ function ProjectInfoCard({ project }: { project: Project }) {
{isRepo && (
<>
<p>
{project.repository_type?.toUpperCase() || "OTHER"}
{getRepositoryPlatformLabel(project.repository_type)}
</p>
<p>{project.default_branch}</p>
</>

View File

@ -34,13 +34,13 @@ import { api } from "@/shared/config/database";
import { runRepositoryAudit, scanStoredZipFile } from "@/features/projects/services";
import type { Project, AuditTask, CreateProjectForm } from "@/shared/types";
import { hasZipFile } from "@/shared/utils/zipStorage";
import { isRepositoryProject, getSourceTypeLabel } from "@/shared/utils/projectUtils";
import { isRepositoryProject, getSourceTypeLabel, getRepositoryPlatformLabel } from "@/shared/utils/projectUtils";
import { toast } from "sonner";
import CreateTaskDialog from "@/components/audit/CreateTaskDialog";
import FileSelectionDialog from "@/components/audit/FileSelectionDialog";
import TerminalProgressDialog from "@/components/audit/TerminalProgressDialog";
import { Dialog, DialogContent, DialogHeader, DialogTitle, DialogFooter } from "@/components/ui/dialog";
import { SUPPORTED_LANGUAGES } from "@/shared/constants";
import { SUPPORTED_LANGUAGES, REPOSITORY_PLATFORMS } from "@/shared/constants";
export default function ProjectDetail() {
const { id } = useParams<{ id: string }>();
@ -475,8 +475,7 @@ export default function ProjectDetail() {
<div className="flex items-center justify-between">
<span className="text-sm text-muted-foreground uppercase"></span>
<Badge className="cyber-badge-muted">
{project.repository_type === 'github' ? 'GitHub' :
project.repository_type === 'gitlab' ? 'GitLab' : '其他'}
{getRepositoryPlatformLabel(project.repository_type)}
</Badge>
</div>
@ -529,12 +528,11 @@ export default function ProjectDetail() {
className="flex items-center justify-between p-3 bg-muted/50 rounded-lg hover:bg-muted transition-all group"
>
<div className="flex items-center space-x-3">
<div className={`w-8 h-8 rounded-lg flex items-center justify-center ${
task.status === 'completed' ? 'bg-emerald-500/20' :
<div className={`w-8 h-8 rounded-lg flex items-center justify-center ${task.status === 'completed' ? 'bg-emerald-500/20' :
task.status === 'running' ? 'bg-sky-500/20' :
task.status === 'failed' ? 'bg-rose-500/20' :
'bg-muted'
}`}>
task.status === 'failed' ? 'bg-rose-500/20' :
'bg-muted'
}`}>
{getStatusIcon(task.status)}
</div>
<div>
@ -579,12 +577,11 @@ export default function ProjectDetail() {
<div key={task.id} className="cyber-card p-6">
<div className="flex items-center justify-between mb-4 pb-4 border-b border-border">
<div className="flex items-center space-x-3">
<div className={`w-10 h-10 rounded-lg flex items-center justify-center ${
task.status === 'completed' ? 'bg-emerald-500/20' :
<div className={`w-10 h-10 rounded-lg flex items-center justify-center ${task.status === 'completed' ? 'bg-emerald-500/20' :
task.status === 'running' ? 'bg-sky-500/20' :
task.status === 'failed' ? 'bg-rose-500/20' :
'bg-muted'
}`}>
task.status === 'failed' ? 'bg-rose-500/20' :
'bg-muted'
}`}>
{getStatusIcon(task.status)}
</div>
<div>
@ -676,12 +673,11 @@ export default function ProjectDetail() {
<div key={index} className="cyber-card p-4 hover:border-border transition-all">
<div className="flex items-start justify-between">
<div className="flex items-start space-x-3">
<div className={`w-8 h-8 rounded-lg flex items-center justify-center ${
issue.severity === 'critical' ? 'bg-rose-500/20 text-rose-600 dark:text-rose-400' :
<div className={`w-8 h-8 rounded-lg flex items-center justify-center ${issue.severity === 'critical' ? 'bg-rose-500/20 text-rose-600 dark:text-rose-400' :
issue.severity === 'high' ? 'bg-orange-500/20 text-orange-600 dark:text-orange-400' :
issue.severity === 'medium' ? 'bg-amber-500/20 text-amber-600 dark:text-amber-400' :
'bg-sky-500/20 text-sky-600 dark:text-sky-400'
}`}>
issue.severity === 'medium' ? 'bg-amber-500/20 text-amber-600 dark:text-amber-400' :
'bg-sky-500/20 text-sky-600 dark:text-sky-400'
}`}>
<AlertTriangle className="w-4 h-4" />
</div>
<div>
@ -695,13 +691,13 @@ export default function ProjectDetail() {
<Badge className={`
${issue.severity === 'critical' ? 'severity-critical' :
issue.severity === 'high' ? 'severity-high' :
issue.severity === 'medium' ? 'severity-medium' :
'severity-low'}
issue.severity === 'medium' ? 'severity-medium' :
'severity-low'}
font-bold uppercase px-2 py-1 rounded text-xs
`}>
{issue.severity === 'critical' ? '严重' :
issue.severity === 'high' ? '高' :
issue.severity === 'medium' ? '中等' : '低'}
issue.severity === 'medium' ? '中等' : '低'}
</Badge>
</div>
<p className="mt-3 text-sm text-muted-foreground font-mono border-t border-border pt-3">
@ -783,9 +779,11 @@ export default function ProjectDetail() {
<SelectValue />
</SelectTrigger>
<SelectContent className="cyber-dialog border-border">
<SelectItem value="github">GitHub</SelectItem>
<SelectItem value="gitlab">GitLab</SelectItem>
<SelectItem value="other"></SelectItem>
{REPOSITORY_PLATFORMS.map((platform) => (
<SelectItem key={platform.value} value={platform.value}>
{platform.label}
</SelectItem>
))}
</SelectContent>
</Select>
</div>
@ -831,14 +829,14 @@ export default function ProjectDetail() {
className={`flex items-center space-x-2 p-3 border cursor-pointer transition-all rounded ${editForm.programming_languages?.includes(lang)
? 'border-primary bg-primary/10 text-primary'
: 'border-border hover:border-border text-muted-foreground'
}`}
}`}
onClick={() => handleToggleLanguage(lang)}
>
<div
className={`w-4 h-4 border-2 rounded-sm flex items-center justify-center ${editForm.programming_languages?.includes(lang)
? 'bg-primary border-primary'
: 'border-border'
}`}
}`}
>
{editForm.programming_languages?.includes(lang) && (
<CheckCircle className="w-3 h-3 text-foreground" />

View File

@ -45,7 +45,7 @@ import { Link } from "react-router-dom";
import { toast } from "sonner";
import CreateTaskDialog from "@/components/audit/CreateTaskDialog";
import TerminalProgressDialog from "@/components/audit/TerminalProgressDialog";
import { SUPPORTED_LANGUAGES } from "@/shared/constants";
import { SUPPORTED_LANGUAGES, REPOSITORY_PLATFORMS } from "@/shared/constants";
export default function Projects() {
const [projects, setProjects] = useState<Project[]>([]);
@ -487,10 +487,11 @@ export default function Projects() {
<SelectValue />
</SelectTrigger>
<SelectContent className="cyber-dialog border-border">
<SelectItem value="github">GITHUB</SelectItem>
<SelectItem value="gitlab">GITLAB</SelectItem>
<SelectItem value="gitea">GITEA</SelectItem>
<SelectItem value="other">OTHER</SelectItem>
{REPOSITORY_PLATFORMS.map((platform) => (
<SelectItem key={platform.value} value={platform.value}>
{platform.label}
</SelectItem>
))}
</SelectContent>
</Select>
</div>
@ -1046,10 +1047,11 @@ export default function Projects() {
<SelectValue />
</SelectTrigger>
<SelectContent className="cyber-dialog border-border">
<SelectItem value="github">GITHUB</SelectItem>
<SelectItem value="gitlab">GITLAB</SelectItem>
<SelectItem value="gitea">GITEA</SelectItem>
<SelectItem value="other">OTHER</SelectItem>
{REPOSITORY_PLATFORMS.map((platform) => (
<SelectItem key={platform.value} value={platform.value}>
{platform.label}
</SelectItem>
))}
</SelectContent>
</Select>
</div>

View File

@ -36,7 +36,7 @@ import type { AuditTask, AuditIssue } from "@/shared/types";
import { toast } from "sonner";
import ExportReportDialog from "@/components/reports/ExportReportDialog";
import { calculateTaskProgress } from "@/shared/utils/utils";
import { isRepositoryProject, getSourceTypeLabel } from "@/shared/utils/projectUtils";
import { isRepositoryProject, getSourceTypeLabel, getRepositoryPlatformLabel } from "@/shared/utils/projectUtils";
// AI explanation parser
function parseAIExplanation(aiExplanation: string) {
@ -86,12 +86,11 @@ function IssuesList({ issues }: { issues: AuditIssue[] }) {
<div key={issue.id || index} className="cyber-card p-4 hover:border-border transition-all group">
<div className="flex items-start justify-between mb-3">
<div className="flex items-start space-x-3">
<div className={`w-10 h-10 rounded-lg flex items-center justify-center ${
issue.severity === 'critical' ? 'bg-rose-500/20 text-rose-400' :
issue.severity === 'high' ? 'bg-orange-500/20 text-orange-400' :
issue.severity === 'medium' ? 'bg-amber-500/20 text-amber-400' :
'bg-sky-500/20 text-sky-400'
}`}>
<div className={`w-10 h-10 rounded-lg flex items-center justify-center ${issue.severity === 'critical' ? 'bg-rose-500/20 text-rose-400' :
issue.severity === 'high' ? 'bg-orange-500/20 text-orange-400' :
issue.severity === 'medium' ? 'bg-amber-500/20 text-amber-400' :
'bg-sky-500/20 text-sky-400'
}`}>
{getTypeIcon(issue.issue_type)}
</div>
<div className="flex-1">
@ -112,7 +111,7 @@ function IssuesList({ issues }: { issues: AuditIssue[] }) {
<Badge className={`${getSeverityClasses(issue.severity)} font-bold uppercase px-2 py-1 rounded text-xs`}>
{issue.severity === 'critical' ? '严重' :
issue.severity === 'high' ? '高' :
issue.severity === 'medium' ? '中等' : '低'}
issue.severity === 'medium' ? '中等' : '低'}
</Badge>
</div>
@ -702,7 +701,7 @@ export default function TaskDetail() {
{isRepositoryProject(task.project) && (
<div>
<p className="text-xs font-bold text-muted-foreground uppercase mb-1"></p>
<p className="text-base font-bold text-foreground">{task.project.repository_type?.toUpperCase() || 'OTHER'}</p>
<p className="text-base font-bold text-foreground">{getRepositoryPlatformLabel(task.project.repository_type)}</p>
</div>
)}
{task.project.programming_languages && (

View File

@ -62,13 +62,6 @@ export const PROJECT_SOURCE_TYPES = {
ZIP: 'zip',
} as const;
// 仓库平台类型
export const REPOSITORY_TYPES = {
GITHUB: 'github',
GITLAB: 'gitlab',
OTHER: 'other',
} as const;
// 分析深度
export const ANALYSIS_DEPTH = {
BASIC: 'basic',

View File

@ -22,17 +22,23 @@ export const PROJECT_SOURCE_TYPES: Array<{
}
];
// 仓库平台显示名称
export const REPOSITORY_PLATFORM_LABELS: Record<RepositoryPlatform, string> = {
github: 'GitHub',
gitlab: 'GitLab',
gitea: 'Gitea',
other: '其他',
};
// 仓库平台选项
export const REPOSITORY_PLATFORMS: Array<{
value: RepositoryPlatform;
label: string;
icon?: string;
}> = [
{ value: 'github', label: 'GitHub' },
{ value: 'gitlab', label: 'GitLab' },
{ value: 'gitea', label: 'Gitea' },
{ value: 'other', label: '其他' }
];
}> = Object.entries(REPOSITORY_PLATFORM_LABELS).map(([value, label]) => ({
value: value as RepositoryPlatform,
label
}));
// 项目来源类型的颜色配置
export const SOURCE_TYPE_COLORS: Record<ProjectSourceType, {

View File

@ -4,6 +4,7 @@
*/
import type { Project, ProjectSourceType } from '@/shared/types';
import { REPOSITORY_PLATFORM_LABELS } from '@/shared/constants/projectTypes';
/**
*
@ -45,13 +46,7 @@ export function getSourceTypeBadge(sourceType: ProjectSourceType): string {
*
*/
export function getRepositoryPlatformLabel(platform?: string): string {
const labels: Record<string, string> = {
github: 'GitHub',
gitlab: 'GitLab',
gitea: 'Gitea',
other: '其他'
};
return labels[platform || 'other'] || '其他';
return REPOSITORY_PLATFORM_LABELS[platform as keyof typeof REPOSITORY_PLATFORM_LABELS] || REPOSITORY_PLATFORM_LABELS.other;
}
/**