Merge branch 'v3.0.0' of github.com:lintsinghua/DeepAudit into feature/git_ssh
# Conflicts: # backend/app/api/v1/endpoints/agent_tasks.py
This commit is contained in:
commit
869513e0c5
|
|
@ -428,7 +428,7 @@ DeepSeek-Coder · Codestral<br/>
|
|||
<div align="center">
|
||||
|
||||
**欢迎大家来和我交流探讨!无论是技术问题、功能建议还是合作意向,都期待与你沟通~**
|
||||
|
||||
(项目开发、投资孵化等合作洽谈请通过邮箱联系)
|
||||
| 联系方式 | |
|
||||
|:---:|:---:|
|
||||
| 📧 **邮箱** | **lintsinghua@qq.com** |
|
||||
|
|
|
|||
|
|
@ -294,6 +294,7 @@ async def _execute_agent_task(task_id: str):
|
|||
other_config = (user_config or {}).get('otherConfig', {})
|
||||
github_token = other_config.get('githubToken') or settings.GITHUB_TOKEN
|
||||
gitlab_token = other_config.get('gitlabToken') or settings.GITLAB_TOKEN
|
||||
gitea_token = other_config.get('giteaToken') or settings.GITEA_TOKEN
|
||||
|
||||
# 解密SSH私钥
|
||||
ssh_private_key = None
|
||||
|
|
@ -313,6 +314,7 @@ async def _execute_agent_task(task_id: str):
|
|||
task.branch_name,
|
||||
github_token=github_token,
|
||||
gitlab_token=gitlab_token,
|
||||
gitea_token=gitea_token, # 🔥 新增
|
||||
ssh_private_key=ssh_private_key, # 🔥 新增SSH密钥
|
||||
event_emitter=event_emitter, # 🔥 新增
|
||||
)
|
||||
|
|
@ -2226,6 +2228,7 @@ async def _get_project_root(
|
|||
branch_name: Optional[str] = None,
|
||||
github_token: Optional[str] = None,
|
||||
gitlab_token: Optional[str] = None,
|
||||
gitea_token: Optional[str] = None, # 🔥 新增
|
||||
ssh_private_key: Optional[str] = None, # 🔥 新增:SSH私钥(用于SSH认证)
|
||||
event_emitter: Optional[Any] = None, # 🔥 新增:用于发送实时日志
|
||||
) -> str:
|
||||
|
|
@ -2242,6 +2245,7 @@ async def _get_project_root(
|
|||
branch_name: 分支名称(仓库项目使用,优先于 project.default_branch)
|
||||
github_token: GitHub 访问令牌(用于私有仓库)
|
||||
gitlab_token: GitLab 访问令牌(用于私有仓库)
|
||||
gitea_token: Gitea 访问令牌(用于私有仓库)
|
||||
ssh_private_key: SSH私钥(用于SSH认证)
|
||||
event_emitter: 事件发送器(用于发送实时日志)
|
||||
|
||||
|
|
@ -2503,6 +2507,16 @@ async def _get_project_root(
|
|||
parsed.fragment
|
||||
))
|
||||
await emit(f"🔐 使用 GitLab Token 认证")
|
||||
elif repo_type == "gitea" and gitea_token:
|
||||
auth_url = urlunparse((
|
||||
parsed.scheme,
|
||||
f"{gitea_token}@{parsed.netloc}",
|
||||
parsed.path,
|
||||
parsed.params,
|
||||
parsed.query,
|
||||
parsed.fragment
|
||||
))
|
||||
await emit(f"🔐 使用 Gitea Token 认证")
|
||||
elif is_ssh_url and ssh_private_key:
|
||||
await emit(f"🔐 使用 SSH Key 认证")
|
||||
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ from app.db.session import get_db, AsyncSessionLocal
|
|||
from app.models.project import Project
|
||||
from app.models.user import User
|
||||
from app.models.audit import AuditTask, AuditIssue
|
||||
from app.models.agent_task import AgentTask, AgentTaskStatus, AgentFinding
|
||||
from app.models.user_config import UserConfig
|
||||
import zipfile
|
||||
from app.services.scanner import scan_repo_task, get_github_files, get_gitlab_files, get_github_branches, get_gitlab_branches, get_gitea_branches, should_exclude, is_text_file
|
||||
|
|
@ -162,26 +163,51 @@ async def get_stats(
|
|||
projects = projects_result.scalars().all()
|
||||
project_ids = [p.id for p in projects]
|
||||
|
||||
# 只统计当前用户项目的任务
|
||||
# 统计旧的 AuditTask
|
||||
tasks_result = await db.execute(
|
||||
select(AuditTask).where(AuditTask.project_id.in_(project_ids)) if project_ids else select(AuditTask).where(False)
|
||||
)
|
||||
tasks = tasks_result.scalars().all()
|
||||
task_ids = [t.id for t in tasks]
|
||||
|
||||
# 只统计当前用户任务的问题
|
||||
# 统计旧的 AuditIssue
|
||||
issues_result = await db.execute(
|
||||
select(AuditIssue).where(AuditIssue.task_id.in_(task_ids)) if task_ids else select(AuditIssue).where(False)
|
||||
)
|
||||
issues = issues_result.scalars().all()
|
||||
|
||||
# 🔥 同时统计新的 AgentTask
|
||||
agent_tasks_result = await db.execute(
|
||||
select(AgentTask).where(AgentTask.project_id.in_(project_ids)) if project_ids else select(AgentTask).where(False)
|
||||
)
|
||||
agent_tasks = agent_tasks_result.scalars().all()
|
||||
agent_task_ids = [t.id for t in agent_tasks]
|
||||
|
||||
# 🔥 统计 AgentFinding
|
||||
agent_findings_result = await db.execute(
|
||||
select(AgentFinding).where(AgentFinding.task_id.in_(agent_task_ids)) if agent_task_ids else select(AgentFinding).where(False)
|
||||
)
|
||||
agent_findings = agent_findings_result.scalars().all()
|
||||
|
||||
# 合并统计(旧任务 + 新 Agent 任务)
|
||||
total_tasks = len(tasks) + len(agent_tasks)
|
||||
completed_tasks = (
|
||||
len([t for t in tasks if t.status == "completed"]) +
|
||||
len([t for t in agent_tasks if t.status == AgentTaskStatus.COMPLETED])
|
||||
)
|
||||
total_issues = len(issues) + len(agent_findings)
|
||||
resolved_issues = (
|
||||
len([i for i in issues if i.status == "resolved"]) +
|
||||
len([f for f in agent_findings if f.status == "resolved"])
|
||||
)
|
||||
|
||||
return {
|
||||
"total_projects": len(projects),
|
||||
"active_projects": len([p for p in projects if p.is_active]),
|
||||
"total_tasks": len(tasks),
|
||||
"completed_tasks": len([t for t in tasks if t.status == "completed"]),
|
||||
"total_issues": len(issues),
|
||||
"resolved_issues": len([i for i in issues if i.status == "resolved"]),
|
||||
"total_tasks": total_tasks,
|
||||
"completed_tasks": completed_tasks,
|
||||
"total_issues": total_issues,
|
||||
"resolved_issues": resolved_issues,
|
||||
}
|
||||
|
||||
@router.get("/{id}", response_model=ProjectResponse)
|
||||
|
|
|
|||
|
|
@ -1,29 +1,19 @@
|
|||
"""
|
||||
DeepAudit Agent 服务模块
|
||||
基于 LangGraph 的 AI Agent 代码安全审计
|
||||
基于动态 Agent 树架构的 AI 代码安全审计
|
||||
|
||||
架构升级版本 - 支持:
|
||||
- 动态Agent树结构
|
||||
- 专业知识模块系统
|
||||
- Agent间通信机制
|
||||
- 完整状态管理
|
||||
- Think工具和漏洞报告工具
|
||||
架构:
|
||||
- OrchestratorAgent 作为编排层,动态调度子 Agent
|
||||
- ReconAgent 负责侦察和文件分析
|
||||
- AnalysisAgent 负责漏洞分析
|
||||
- VerificationAgent 负责验证发现
|
||||
|
||||
工作流:
|
||||
START → Recon → Analysis ⟲ → Verification → Report → END
|
||||
START → Orchestrator → [Recon/Analysis/Verification] → Report → END
|
||||
|
||||
支持动态创建子Agent进行专业化分析
|
||||
"""
|
||||
|
||||
# 从 graph 模块导入主要组件
|
||||
from .graph import (
|
||||
AgentRunner,
|
||||
run_agent_task,
|
||||
LLMService,
|
||||
AuditState,
|
||||
create_audit_graph,
|
||||
)
|
||||
|
||||
# 事件管理
|
||||
from .event_manager import EventManager, AgentEventEmitter
|
||||
|
||||
|
|
@ -33,14 +23,14 @@ from .agents import (
|
|||
OrchestratorAgent, ReconAgent, AnalysisAgent, VerificationAgent,
|
||||
)
|
||||
|
||||
# 🔥 新增:核心模块(状态管理、注册表、消息)
|
||||
# 核心模块(状态管理、注册表、消息)
|
||||
from .core import (
|
||||
AgentState, AgentStatus,
|
||||
AgentRegistry, agent_registry,
|
||||
AgentMessage, MessageType, MessagePriority, MessageBus,
|
||||
)
|
||||
|
||||
# 🔥 新增:知识模块系统(基于RAG)
|
||||
# 知识模块系统(基于RAG)
|
||||
from .knowledge import (
|
||||
KnowledgeLoader, knowledge_loader,
|
||||
get_available_modules, get_module_content,
|
||||
|
|
@ -48,7 +38,7 @@ from .knowledge import (
|
|||
SecurityKnowledgeQueryTool, GetVulnerabilityKnowledgeTool,
|
||||
)
|
||||
|
||||
# 🔥 新增:协作工具
|
||||
# 协作工具
|
||||
from .tools import (
|
||||
ThinkTool, ReflectTool,
|
||||
CreateVulnerabilityReportTool,
|
||||
|
|
@ -57,20 +47,11 @@ from .tools import (
|
|||
WaitForMessageTool, AgentFinishTool,
|
||||
)
|
||||
|
||||
# 🔥 新增:遥测模块
|
||||
# 遥测模块
|
||||
from .telemetry import Tracer, get_global_tracer, set_global_tracer
|
||||
|
||||
|
||||
__all__ = [
|
||||
# 核心 Runner
|
||||
"AgentRunner",
|
||||
"run_agent_task",
|
||||
"LLMService",
|
||||
|
||||
# LangGraph
|
||||
"AuditState",
|
||||
"create_audit_graph",
|
||||
|
||||
# 事件管理
|
||||
"EventManager",
|
||||
"AgentEventEmitter",
|
||||
|
|
@ -84,7 +65,7 @@ __all__ = [
|
|||
"AnalysisAgent",
|
||||
"VerificationAgent",
|
||||
|
||||
# 🔥 核心模块
|
||||
# 核心模块
|
||||
"AgentState",
|
||||
"AgentStatus",
|
||||
"AgentRegistry",
|
||||
|
|
@ -94,7 +75,7 @@ __all__ = [
|
|||
"MessagePriority",
|
||||
"MessageBus",
|
||||
|
||||
# 🔥 知识模块(基于RAG)
|
||||
# 知识模块(基于RAG)
|
||||
"KnowledgeLoader",
|
||||
"knowledge_loader",
|
||||
"get_available_modules",
|
||||
|
|
@ -104,7 +85,7 @@ __all__ = [
|
|||
"SecurityKnowledgeQueryTool",
|
||||
"GetVulnerabilityKnowledgeTool",
|
||||
|
||||
# 🔥 协作工具
|
||||
# 协作工具
|
||||
"ThinkTool",
|
||||
"ReflectTool",
|
||||
"CreateVulnerabilityReportTool",
|
||||
|
|
@ -115,9 +96,8 @@ __all__ = [
|
|||
"WaitForMessageTool",
|
||||
"AgentFinishTool",
|
||||
|
||||
# 🔥 遥测模块
|
||||
# 遥测模块
|
||||
"Tracer",
|
||||
"get_global_tracer",
|
||||
"set_global_tracer",
|
||||
]
|
||||
|
||||
|
|
|
|||
|
|
@ -1024,10 +1024,18 @@ class BaseAgent(ABC):
|
|||
elif chunk["type"] == "error":
|
||||
accumulated = chunk.get("accumulated", "")
|
||||
error_msg = chunk.get("error", "Unknown error")
|
||||
logger.error(f"[{self.name}] Stream error: {error_msg}")
|
||||
if accumulated:
|
||||
total_tokens = chunk.get("usage", {}).get("total_tokens", 0)
|
||||
else:
|
||||
error_type = chunk.get("error_type", "unknown")
|
||||
user_message = chunk.get("user_message", error_msg)
|
||||
logger.error(f"[{self.name}] Stream error ({error_type}): {error_msg}")
|
||||
|
||||
if chunk.get("usage"):
|
||||
total_tokens = chunk["usage"].get("total_tokens", 0)
|
||||
|
||||
# 使用特殊前缀标记 API 错误,让调用方能够识别
|
||||
# 格式:[API_ERROR:error_type] user_message
|
||||
if error_type in ("rate_limit", "quota_exceeded", "authentication", "connection"):
|
||||
accumulated = f"[API_ERROR:{error_type}] {user_message}"
|
||||
elif not accumulated:
|
||||
accumulated = f"[系统错误: {error_msg}] 请重新思考并输出你的决策。"
|
||||
break
|
||||
|
||||
|
|
|
|||
|
|
@ -285,6 +285,55 @@ Action Input: {{"参数": "值"}}
|
|||
# 重置空响应计数器
|
||||
self._empty_retry_count = 0
|
||||
|
||||
# 🔥 检查是否是 API 错误(而非格式错误)
|
||||
if llm_output.startswith("[API_ERROR:"):
|
||||
# 提取错误类型和消息
|
||||
match = re.match(r"\[API_ERROR:(\w+)\]\s*(.*)", llm_output)
|
||||
if match:
|
||||
error_type = match.group(1)
|
||||
error_message = match.group(2)
|
||||
|
||||
if error_type == "rate_limit":
|
||||
# 速率限制 - 等待后重试
|
||||
api_retry_count = getattr(self, '_api_retry_count', 0) + 1
|
||||
self._api_retry_count = api_retry_count
|
||||
if api_retry_count >= 3:
|
||||
logger.error(f"[{self.name}] Too many rate limit errors, stopping")
|
||||
await self.emit_event("error", f"API 速率限制重试次数过多: {error_message}")
|
||||
break
|
||||
logger.warning(f"[{self.name}] Rate limit hit, waiting before retry ({api_retry_count}/3)")
|
||||
await self.emit_event("warning", f"API 速率限制,等待后重试 ({api_retry_count}/3)")
|
||||
await asyncio.sleep(30) # 等待 30 秒后重试
|
||||
continue
|
||||
|
||||
elif error_type == "quota_exceeded":
|
||||
# 配额用尽 - 终止任务
|
||||
logger.error(f"[{self.name}] API quota exceeded: {error_message}")
|
||||
await self.emit_event("error", f"API 配额已用尽: {error_message}")
|
||||
break
|
||||
|
||||
elif error_type == "authentication":
|
||||
# 认证错误 - 终止任务
|
||||
logger.error(f"[{self.name}] API authentication error: {error_message}")
|
||||
await self.emit_event("error", f"API 认证失败: {error_message}")
|
||||
break
|
||||
|
||||
elif error_type == "connection":
|
||||
# 连接错误 - 重试
|
||||
api_retry_count = getattr(self, '_api_retry_count', 0) + 1
|
||||
self._api_retry_count = api_retry_count
|
||||
if api_retry_count >= 3:
|
||||
logger.error(f"[{self.name}] Too many connection errors, stopping")
|
||||
await self.emit_event("error", f"API 连接错误重试次数过多: {error_message}")
|
||||
break
|
||||
logger.warning(f"[{self.name}] Connection error, retrying ({api_retry_count}/3)")
|
||||
await self.emit_event("warning", f"API 连接错误,重试中 ({api_retry_count}/3)")
|
||||
await asyncio.sleep(5) # 等待 5 秒后重试
|
||||
continue
|
||||
|
||||
# 重置 API 重试计数器(成功获取响应后)
|
||||
self._api_retry_count = 0
|
||||
|
||||
# 解析 LLM 的决策
|
||||
step = self._parse_llm_response(llm_output)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,28 +0,0 @@
|
|||
"""
|
||||
LangGraph 工作流模块
|
||||
使用状态图构建混合 Agent 审计流程
|
||||
"""
|
||||
|
||||
from .audit_graph import AuditState, create_audit_graph, create_audit_graph_with_human
|
||||
from .nodes import ReconNode, AnalysisNode, VerificationNode, ReportNode, HumanReviewNode
|
||||
from .runner import AgentRunner, run_agent_task, LLMService
|
||||
|
||||
__all__ = [
|
||||
# 状态和图
|
||||
"AuditState",
|
||||
"create_audit_graph",
|
||||
"create_audit_graph_with_human",
|
||||
|
||||
# 节点
|
||||
"ReconNode",
|
||||
"AnalysisNode",
|
||||
"VerificationNode",
|
||||
"ReportNode",
|
||||
"HumanReviewNode",
|
||||
|
||||
# Runner
|
||||
"AgentRunner",
|
||||
"run_agent_task",
|
||||
"LLMService",
|
||||
]
|
||||
|
||||
|
|
@ -1,677 +0,0 @@
|
|||
"""
|
||||
DeepAudit 审计工作流图 - LLM 驱动版
|
||||
使用 LangGraph 构建 LLM 驱动的 Agent 协作流程
|
||||
|
||||
重要改变:路由决策由 LLM 参与,而不是硬编码条件!
|
||||
"""
|
||||
|
||||
from typing import TypedDict, Annotated, List, Dict, Any, Optional, Literal
|
||||
from datetime import datetime
|
||||
import operator
|
||||
import logging
|
||||
import json
|
||||
|
||||
from langgraph.graph import StateGraph, END
|
||||
from langgraph.checkpoint.memory import MemorySaver
|
||||
from langgraph.prebuilt import ToolNode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ============ 状态定义 ============
|
||||
|
||||
class Finding(TypedDict):
|
||||
"""漏洞发现"""
|
||||
id: str
|
||||
vulnerability_type: str
|
||||
severity: str
|
||||
title: str
|
||||
description: str
|
||||
file_path: Optional[str]
|
||||
line_start: Optional[int]
|
||||
code_snippet: Optional[str]
|
||||
is_verified: bool
|
||||
confidence: float
|
||||
source: str
|
||||
|
||||
|
||||
class AuditState(TypedDict):
|
||||
"""
|
||||
审计状态
|
||||
在整个工作流中传递和更新
|
||||
"""
|
||||
# 输入
|
||||
project_root: str
|
||||
project_info: Dict[str, Any]
|
||||
config: Dict[str, Any]
|
||||
task_id: str
|
||||
|
||||
# Recon 阶段输出
|
||||
tech_stack: Dict[str, Any]
|
||||
entry_points: List[Dict[str, Any]]
|
||||
high_risk_areas: List[str]
|
||||
dependencies: Dict[str, Any]
|
||||
|
||||
# Analysis 阶段输出
|
||||
findings: Annotated[List[Finding], operator.add] # 使用 add 合并多轮发现
|
||||
|
||||
# Verification 阶段输出
|
||||
verified_findings: List[Finding]
|
||||
false_positives: List[str]
|
||||
# 🔥 NEW: 验证后的完整 findings(用于替换原始 findings)
|
||||
_verified_findings_update: Optional[List[Finding]]
|
||||
|
||||
# 控制流 - 🔥 关键:LLM 可以设置这些来影响路由
|
||||
current_phase: str
|
||||
iteration: int
|
||||
max_iterations: int
|
||||
should_continue_analysis: bool
|
||||
|
||||
# 🔥 新增:LLM 的路由决策
|
||||
llm_next_action: Optional[str] # LLM 建议的下一步: "continue_analysis", "verify", "report", "end"
|
||||
llm_routing_reason: Optional[str] # LLM 的决策理由
|
||||
|
||||
# 🔥 新增:Agent 间协作的任务交接信息
|
||||
recon_handoff: Optional[Dict[str, Any]] # Recon -> Analysis 的交接
|
||||
analysis_handoff: Optional[Dict[str, Any]] # Analysis -> Verification 的交接
|
||||
verification_handoff: Optional[Dict[str, Any]] # Verification -> Report 的交接
|
||||
|
||||
# 消息和事件
|
||||
messages: Annotated[List[Dict], operator.add]
|
||||
events: Annotated[List[Dict], operator.add]
|
||||
|
||||
# 最终输出
|
||||
summary: Optional[Dict[str, Any]]
|
||||
security_score: Optional[int]
|
||||
error: Optional[str]
|
||||
|
||||
|
||||
# ============ LLM 路由决策器 ============
|
||||
|
||||
class LLMRouter:
|
||||
"""
|
||||
LLM 路由决策器
|
||||
让 LLM 来决定下一步应该做什么
|
||||
"""
|
||||
|
||||
def __init__(self, llm_service):
|
||||
self.llm_service = llm_service
|
||||
|
||||
async def decide_after_recon(self, state: AuditState) -> Dict[str, Any]:
|
||||
"""Recon 后让 LLM 决定下一步"""
|
||||
entry_points = state.get("entry_points", [])
|
||||
high_risk_areas = state.get("high_risk_areas", [])
|
||||
tech_stack = state.get("tech_stack", {})
|
||||
initial_findings = state.get("findings", [])
|
||||
|
||||
prompt = f"""作为安全审计的决策者,基于以下信息收集结果,决定下一步行动。
|
||||
|
||||
## 信息收集结果
|
||||
- 入口点数量: {len(entry_points)}
|
||||
- 高风险区域: {high_risk_areas[:10]}
|
||||
- 技术栈: {tech_stack}
|
||||
- 初步发现: {len(initial_findings)} 个
|
||||
|
||||
## 选项
|
||||
1. "analysis" - 继续进行漏洞分析(推荐:有入口点或高风险区域时)
|
||||
2. "end" - 结束审计(仅当没有任何可分析内容时)
|
||||
|
||||
请返回 JSON 格式:
|
||||
{{"action": "analysis或end", "reason": "决策理由"}}"""
|
||||
|
||||
try:
|
||||
response = await self.llm_service.chat_completion_raw(
|
||||
messages=[
|
||||
{"role": "system", "content": "你是安全审计流程的决策者,负责决定下一步行动。"},
|
||||
{"role": "user", "content": prompt},
|
||||
],
|
||||
# 🔥 不传递 temperature 和 max_tokens,使用用户配置
|
||||
)
|
||||
|
||||
content = response.get("content", "")
|
||||
# 提取 JSON
|
||||
import re
|
||||
json_match = re.search(r'\{.*\}', content, re.DOTALL)
|
||||
if json_match:
|
||||
result = json.loads(json_match.group())
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.warning(f"LLM routing decision failed: {e}")
|
||||
|
||||
# 默认决策
|
||||
if entry_points or high_risk_areas:
|
||||
return {"action": "analysis", "reason": "有可分析内容"}
|
||||
return {"action": "end", "reason": "没有发现入口点或高风险区域"}
|
||||
|
||||
async def decide_after_analysis(self, state: AuditState) -> Dict[str, Any]:
|
||||
"""Analysis 后让 LLM 决定下一步"""
|
||||
findings = state.get("findings", [])
|
||||
iteration = state.get("iteration", 0)
|
||||
max_iterations = state.get("max_iterations", 3)
|
||||
|
||||
# 统计发现
|
||||
severity_counts = {"critical": 0, "high": 0, "medium": 0, "low": 0}
|
||||
for f in findings:
|
||||
# 跳过非字典类型的 finding
|
||||
if not isinstance(f, dict):
|
||||
continue
|
||||
sev = f.get("severity", "medium")
|
||||
severity_counts[sev] = severity_counts.get(sev, 0) + 1
|
||||
|
||||
prompt = f"""作为安全审计的决策者,基于以下分析结果,决定下一步行动。
|
||||
|
||||
## 分析结果
|
||||
- 总发现数: {len(findings)}
|
||||
- 严重程度分布: {severity_counts}
|
||||
- 当前迭代: {iteration}/{max_iterations}
|
||||
|
||||
## 选项
|
||||
1. "verification" - 验证发现的漏洞(推荐:有发现需要验证时)
|
||||
2. "analysis" - 继续深入分析(推荐:发现较少但还有迭代次数时)
|
||||
3. "report" - 生成报告(推荐:没有发现或已充分分析时)
|
||||
|
||||
请返回 JSON 格式:
|
||||
{{"action": "verification/analysis/report", "reason": "决策理由"}}"""
|
||||
|
||||
try:
|
||||
response = await self.llm_service.chat_completion_raw(
|
||||
messages=[
|
||||
{"role": "system", "content": "你是安全审计流程的决策者,负责决定下一步行动。"},
|
||||
{"role": "user", "content": prompt},
|
||||
],
|
||||
# 🔥 不传递 temperature 和 max_tokens,使用用户配置
|
||||
)
|
||||
|
||||
content = response.get("content", "")
|
||||
import re
|
||||
json_match = re.search(r'\{.*\}', content, re.DOTALL)
|
||||
if json_match:
|
||||
result = json.loads(json_match.group())
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.warning(f"LLM routing decision failed: {e}")
|
||||
|
||||
# 默认决策
|
||||
if not findings:
|
||||
return {"action": "report", "reason": "没有发现漏洞"}
|
||||
if len(findings) >= 3 or iteration >= max_iterations:
|
||||
return {"action": "verification", "reason": "有足够的发现需要验证"}
|
||||
return {"action": "analysis", "reason": "发现较少,继续分析"}
|
||||
|
||||
async def decide_after_verification(self, state: AuditState) -> Dict[str, Any]:
|
||||
"""Verification 后让 LLM 决定下一步"""
|
||||
verified_findings = state.get("verified_findings", [])
|
||||
false_positives = state.get("false_positives", [])
|
||||
iteration = state.get("iteration", 0)
|
||||
max_iterations = state.get("max_iterations", 3)
|
||||
|
||||
prompt = f"""作为安全审计的决策者,基于以下验证结果,决定下一步行动。
|
||||
|
||||
## 验证结果
|
||||
- 已确认漏洞: {len(verified_findings)}
|
||||
- 误报数量: {len(false_positives)}
|
||||
- 当前迭代: {iteration}/{max_iterations}
|
||||
|
||||
## 选项
|
||||
1. "analysis" - 回到分析阶段重新分析(推荐:误报率太高时)
|
||||
2. "report" - 生成最终报告(推荐:验证完成时)
|
||||
|
||||
请返回 JSON 格式:
|
||||
{{"action": "analysis/report", "reason": "决策理由"}}"""
|
||||
|
||||
try:
|
||||
response = await self.llm_service.chat_completion_raw(
|
||||
messages=[
|
||||
{"role": "system", "content": "你是安全审计流程的决策者,负责决定下一步行动。"},
|
||||
{"role": "user", "content": prompt},
|
||||
],
|
||||
# 🔥 不传递 temperature 和 max_tokens,使用用户配置
|
||||
)
|
||||
|
||||
content = response.get("content", "")
|
||||
import re
|
||||
json_match = re.search(r'\{.*\}', content, re.DOTALL)
|
||||
if json_match:
|
||||
result = json.loads(json_match.group())
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.warning(f"LLM routing decision failed: {e}")
|
||||
|
||||
# 默认决策
|
||||
if len(false_positives) > len(verified_findings) and iteration < max_iterations:
|
||||
return {"action": "analysis", "reason": "误报率较高,需要重新分析"}
|
||||
return {"action": "report", "reason": "验证完成,生成报告"}
|
||||
|
||||
|
||||
# ============ 路由函数 (结合 LLM 决策) ============
|
||||
|
||||
def route_after_recon(state: AuditState) -> Literal["analysis", "end"]:
|
||||
"""
|
||||
Recon 后的路由决策
|
||||
优先使用 LLM 的决策,否则使用默认逻辑
|
||||
"""
|
||||
# 🔥 检查是否有错误
|
||||
if state.get("error") or state.get("current_phase") == "error":
|
||||
logger.error(f"Recon phase has error, routing to end: {state.get('error')}")
|
||||
return "end"
|
||||
|
||||
# 检查 LLM 是否有决策
|
||||
llm_action = state.get("llm_next_action")
|
||||
if llm_action:
|
||||
logger.info(f"Using LLM routing decision: {llm_action}, reason: {state.get('llm_routing_reason')}")
|
||||
if llm_action == "end":
|
||||
return "end"
|
||||
return "analysis"
|
||||
|
||||
# 默认逻辑(作为 fallback)
|
||||
if not state.get("entry_points") and not state.get("high_risk_areas"):
|
||||
return "end"
|
||||
return "analysis"
|
||||
|
||||
|
||||
def route_after_analysis(state: AuditState) -> Literal["verification", "analysis", "report"]:
|
||||
"""
|
||||
Analysis 后的路由决策
|
||||
优先使用 LLM 的决策
|
||||
"""
|
||||
# 检查 LLM 是否有决策
|
||||
llm_action = state.get("llm_next_action")
|
||||
if llm_action:
|
||||
logger.info(f"Using LLM routing decision: {llm_action}, reason: {state.get('llm_routing_reason')}")
|
||||
if llm_action == "verification":
|
||||
return "verification"
|
||||
elif llm_action == "analysis":
|
||||
return "analysis"
|
||||
elif llm_action == "report":
|
||||
return "report"
|
||||
|
||||
# 默认逻辑
|
||||
findings = state.get("findings", [])
|
||||
iteration = state.get("iteration", 0)
|
||||
max_iterations = state.get("max_iterations", 3)
|
||||
should_continue = state.get("should_continue_analysis", False)
|
||||
|
||||
if not findings:
|
||||
return "report"
|
||||
|
||||
if should_continue and iteration < max_iterations:
|
||||
return "analysis"
|
||||
|
||||
return "verification"
|
||||
|
||||
|
||||
def route_after_verification(state: AuditState) -> Literal["analysis", "report"]:
|
||||
"""
|
||||
Verification 后的路由决策
|
||||
优先使用 LLM 的决策
|
||||
"""
|
||||
# 检查 LLM 是否有决策
|
||||
llm_action = state.get("llm_next_action")
|
||||
if llm_action:
|
||||
logger.info(f"Using LLM routing decision: {llm_action}, reason: {state.get('llm_routing_reason')}")
|
||||
if llm_action == "analysis":
|
||||
return "analysis"
|
||||
return "report"
|
||||
|
||||
# 默认逻辑
|
||||
false_positives = state.get("false_positives", [])
|
||||
iteration = state.get("iteration", 0)
|
||||
max_iterations = state.get("max_iterations", 3)
|
||||
|
||||
if len(false_positives) > len(state.get("verified_findings", [])) and iteration < max_iterations:
|
||||
return "analysis"
|
||||
|
||||
return "report"
|
||||
|
||||
|
||||
# ============ 创建审计图 ============
|
||||
|
||||
def create_audit_graph(
|
||||
recon_node,
|
||||
analysis_node,
|
||||
verification_node,
|
||||
report_node,
|
||||
checkpointer: Optional[MemorySaver] = None,
|
||||
llm_service=None, # 用于 LLM 路由决策
|
||||
) -> StateGraph:
|
||||
"""
|
||||
创建审计工作流图
|
||||
|
||||
Args:
|
||||
recon_node: 信息收集节点
|
||||
analysis_node: 漏洞分析节点
|
||||
verification_node: 漏洞验证节点
|
||||
report_node: 报告生成节点
|
||||
checkpointer: 检查点存储器
|
||||
llm_service: LLM 服务(用于路由决策)
|
||||
|
||||
Returns:
|
||||
编译后的 StateGraph
|
||||
|
||||
工作流结构:
|
||||
|
||||
START
|
||||
│
|
||||
▼
|
||||
┌──────┐
|
||||
│Recon │ 信息收集 (LLM 驱动)
|
||||
└──┬───┘
|
||||
│ LLM 决定
|
||||
▼
|
||||
┌──────────┐
|
||||
│ Analysis │◄─────┐ 漏洞分析 (LLM 驱动,可循环)
|
||||
└────┬─────┘ │
|
||||
│ LLM 决定 │
|
||||
▼ │
|
||||
┌────────────┐ │
|
||||
│Verification│────┘ 漏洞验证 (LLM 驱动,可回溯)
|
||||
└─────┬──────┘
|
||||
│ LLM 决定
|
||||
▼
|
||||
┌──────────┐
|
||||
│ Report │ 报告生成
|
||||
└────┬─────┘
|
||||
│
|
||||
▼
|
||||
END
|
||||
"""
|
||||
|
||||
# 创建状态图
|
||||
workflow = StateGraph(AuditState)
|
||||
|
||||
# 如果有 LLM 服务,创建路由决策器
|
||||
llm_router = LLMRouter(llm_service) if llm_service else None
|
||||
|
||||
# 包装节点以添加 LLM 路由决策
|
||||
async def recon_with_routing(state):
|
||||
result = await recon_node(state)
|
||||
|
||||
# LLM 决定下一步
|
||||
if llm_router:
|
||||
decision = await llm_router.decide_after_recon({**state, **result})
|
||||
result["llm_next_action"] = decision.get("action")
|
||||
result["llm_routing_reason"] = decision.get("reason")
|
||||
|
||||
return result
|
||||
|
||||
async def analysis_with_routing(state):
|
||||
result = await analysis_node(state)
|
||||
|
||||
# LLM 决定下一步
|
||||
if llm_router:
|
||||
decision = await llm_router.decide_after_analysis({**state, **result})
|
||||
result["llm_next_action"] = decision.get("action")
|
||||
result["llm_routing_reason"] = decision.get("reason")
|
||||
|
||||
return result
|
||||
|
||||
async def verification_with_routing(state):
|
||||
result = await verification_node(state)
|
||||
|
||||
# LLM 决定下一步
|
||||
if llm_router:
|
||||
decision = await llm_router.decide_after_verification({**state, **result})
|
||||
result["llm_next_action"] = decision.get("action")
|
||||
result["llm_routing_reason"] = decision.get("reason")
|
||||
|
||||
return result
|
||||
|
||||
# 添加节点
|
||||
if llm_router:
|
||||
workflow.add_node("recon", recon_with_routing)
|
||||
workflow.add_node("analysis", analysis_with_routing)
|
||||
workflow.add_node("verification", verification_with_routing)
|
||||
else:
|
||||
workflow.add_node("recon", recon_node)
|
||||
workflow.add_node("analysis", analysis_node)
|
||||
workflow.add_node("verification", verification_node)
|
||||
|
||||
workflow.add_node("report", report_node)
|
||||
|
||||
# 设置入口点
|
||||
workflow.set_entry_point("recon")
|
||||
|
||||
# 添加条件边
|
||||
workflow.add_conditional_edges(
|
||||
"recon",
|
||||
route_after_recon,
|
||||
{
|
||||
"analysis": "analysis",
|
||||
"end": END,
|
||||
}
|
||||
)
|
||||
|
||||
workflow.add_conditional_edges(
|
||||
"analysis",
|
||||
route_after_analysis,
|
||||
{
|
||||
"verification": "verification",
|
||||
"analysis": "analysis",
|
||||
"report": "report",
|
||||
}
|
||||
)
|
||||
|
||||
workflow.add_conditional_edges(
|
||||
"verification",
|
||||
route_after_verification,
|
||||
{
|
||||
"analysis": "analysis",
|
||||
"report": "report",
|
||||
}
|
||||
)
|
||||
|
||||
# Report -> END
|
||||
workflow.add_edge("report", END)
|
||||
|
||||
# 编译图
|
||||
if checkpointer:
|
||||
return workflow.compile(checkpointer=checkpointer)
|
||||
else:
|
||||
return workflow.compile()
|
||||
|
||||
|
||||
# ============ 带人机协作的审计图 ============
|
||||
|
||||
def create_audit_graph_with_human(
|
||||
recon_node,
|
||||
analysis_node,
|
||||
verification_node,
|
||||
report_node,
|
||||
human_review_node,
|
||||
checkpointer: Optional[MemorySaver] = None,
|
||||
llm_service=None,
|
||||
) -> StateGraph:
|
||||
"""
|
||||
创建带人机协作的审计工作流图
|
||||
|
||||
在验证阶段后增加人工审核节点
|
||||
"""
|
||||
|
||||
workflow = StateGraph(AuditState)
|
||||
llm_router = LLMRouter(llm_service) if llm_service else None
|
||||
|
||||
# 包装节点
|
||||
async def recon_with_routing(state):
|
||||
result = await recon_node(state)
|
||||
if llm_router:
|
||||
decision = await llm_router.decide_after_recon({**state, **result})
|
||||
result["llm_next_action"] = decision.get("action")
|
||||
result["llm_routing_reason"] = decision.get("reason")
|
||||
return result
|
||||
|
||||
async def analysis_with_routing(state):
|
||||
result = await analysis_node(state)
|
||||
if llm_router:
|
||||
decision = await llm_router.decide_after_analysis({**state, **result})
|
||||
result["llm_next_action"] = decision.get("action")
|
||||
result["llm_routing_reason"] = decision.get("reason")
|
||||
return result
|
||||
|
||||
# 添加节点
|
||||
if llm_router:
|
||||
workflow.add_node("recon", recon_with_routing)
|
||||
workflow.add_node("analysis", analysis_with_routing)
|
||||
else:
|
||||
workflow.add_node("recon", recon_node)
|
||||
workflow.add_node("analysis", analysis_node)
|
||||
|
||||
workflow.add_node("verification", verification_node)
|
||||
workflow.add_node("human_review", human_review_node)
|
||||
workflow.add_node("report", report_node)
|
||||
|
||||
workflow.set_entry_point("recon")
|
||||
|
||||
workflow.add_conditional_edges(
|
||||
"recon",
|
||||
route_after_recon,
|
||||
{"analysis": "analysis", "end": END}
|
||||
)
|
||||
|
||||
workflow.add_conditional_edges(
|
||||
"analysis",
|
||||
route_after_analysis,
|
||||
{
|
||||
"verification": "verification",
|
||||
"analysis": "analysis",
|
||||
"report": "report",
|
||||
}
|
||||
)
|
||||
|
||||
# Verification -> Human Review
|
||||
workflow.add_edge("verification", "human_review")
|
||||
|
||||
# Human Review 后的路由
|
||||
def route_after_human(state: AuditState) -> Literal["analysis", "report"]:
|
||||
if state.get("should_continue_analysis"):
|
||||
return "analysis"
|
||||
return "report"
|
||||
|
||||
workflow.add_conditional_edges(
|
||||
"human_review",
|
||||
route_after_human,
|
||||
{"analysis": "analysis", "report": "report"}
|
||||
)
|
||||
|
||||
workflow.add_edge("report", END)
|
||||
|
||||
if checkpointer:
|
||||
return workflow.compile(checkpointer=checkpointer, interrupt_before=["human_review"])
|
||||
else:
|
||||
return workflow.compile()
|
||||
|
||||
|
||||
# ============ 执行器 ============
|
||||
|
||||
class AuditGraphRunner:
|
||||
"""
|
||||
审计图执行器
|
||||
封装 LangGraph 工作流的执行
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
graph: StateGraph,
|
||||
event_emitter=None,
|
||||
):
|
||||
self.graph = graph
|
||||
self.event_emitter = event_emitter
|
||||
|
||||
async def run(
|
||||
self,
|
||||
project_root: str,
|
||||
project_info: Dict[str, Any],
|
||||
config: Dict[str, Any],
|
||||
task_id: str,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
执行审计工作流
|
||||
"""
|
||||
# 初始状态
|
||||
initial_state: AuditState = {
|
||||
"project_root": project_root,
|
||||
"project_info": project_info,
|
||||
"config": config,
|
||||
"task_id": task_id,
|
||||
"tech_stack": {},
|
||||
"entry_points": [],
|
||||
"high_risk_areas": [],
|
||||
"dependencies": {},
|
||||
"findings": [],
|
||||
"verified_findings": [],
|
||||
"false_positives": [],
|
||||
"_verified_findings_update": None, # 🔥 NEW: 验证后的 findings 更新
|
||||
"current_phase": "start",
|
||||
"iteration": 0,
|
||||
"max_iterations": config.get("max_iterations", 3),
|
||||
"should_continue_analysis": False,
|
||||
"llm_next_action": None,
|
||||
"llm_routing_reason": None,
|
||||
"messages": [],
|
||||
"events": [],
|
||||
"summary": None,
|
||||
"security_score": None,
|
||||
"error": None,
|
||||
}
|
||||
|
||||
run_config = {
|
||||
"configurable": {
|
||||
"thread_id": task_id,
|
||||
}
|
||||
}
|
||||
|
||||
try:
|
||||
async for event in self.graph.astream(initial_state, config=run_config):
|
||||
if self.event_emitter:
|
||||
for node_name, node_state in event.items():
|
||||
await self.event_emitter.emit_info(
|
||||
f"节点 {node_name} 完成"
|
||||
)
|
||||
|
||||
# 发射 LLM 路由决策事件
|
||||
if node_state.get("llm_routing_reason"):
|
||||
await self.event_emitter.emit_info(
|
||||
f"🧠 LLM 决策: {node_state.get('llm_next_action')} - {node_state.get('llm_routing_reason')}"
|
||||
)
|
||||
|
||||
if node_name == "analysis" and node_state.get("findings"):
|
||||
new_findings = node_state["findings"]
|
||||
await self.event_emitter.emit_info(
|
||||
f"发现 {len(new_findings)} 个潜在漏洞"
|
||||
)
|
||||
|
||||
final_state = self.graph.get_state(run_config)
|
||||
return final_state.values
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Graph execution failed: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def run_with_human_review(
|
||||
self,
|
||||
initial_state: AuditState,
|
||||
human_feedback_callback,
|
||||
) -> Dict[str, Any]:
|
||||
"""带人机协作的执行"""
|
||||
run_config = {
|
||||
"configurable": {
|
||||
"thread_id": initial_state["task_id"],
|
||||
}
|
||||
}
|
||||
|
||||
async for event in self.graph.astream(initial_state, config=run_config):
|
||||
pass
|
||||
|
||||
current_state = self.graph.get_state(run_config)
|
||||
|
||||
if current_state.next == ("human_review",):
|
||||
human_decision = await human_feedback_callback(current_state.values)
|
||||
|
||||
updated_state = {
|
||||
**current_state.values,
|
||||
"should_continue_analysis": human_decision.get("continue_analysis", False),
|
||||
}
|
||||
|
||||
async for event in self.graph.astream(updated_state, config=run_config):
|
||||
pass
|
||||
|
||||
return self.graph.get_state(run_config).values
|
||||
|
|
@ -1,556 +0,0 @@
|
|||
"""
|
||||
LangGraph 节点实现
|
||||
每个节点封装一个 Agent 的执行逻辑
|
||||
|
||||
协作增强:节点之间通过 TaskHandoff 传递结构化的上下文和洞察
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List, Optional
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 延迟导入避免循环依赖
|
||||
def get_audit_state_type():
|
||||
from .audit_graph import AuditState
|
||||
return AuditState
|
||||
|
||||
|
||||
class BaseNode:
|
||||
"""节点基类"""
|
||||
|
||||
def __init__(self, agent=None, event_emitter=None):
|
||||
self.agent = agent
|
||||
self.event_emitter = event_emitter
|
||||
|
||||
async def emit_event(self, event_type: str, message: str, **kwargs):
|
||||
"""发射事件"""
|
||||
if self.event_emitter:
|
||||
try:
|
||||
await self.event_emitter.emit_info(message)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to emit event: {e}")
|
||||
|
||||
def _extract_handoff_from_state(self, state: Dict[str, Any], from_phase: str):
|
||||
"""从状态中提取前序 Agent 的 handoff"""
|
||||
handoff_data = state.get(f"{from_phase}_handoff")
|
||||
if handoff_data:
|
||||
from ..agents.base import TaskHandoff
|
||||
return TaskHandoff.from_dict(handoff_data)
|
||||
return None
|
||||
|
||||
|
||||
class ReconNode(BaseNode):
|
||||
"""
|
||||
信息收集节点
|
||||
|
||||
输入: project_root, project_info, config
|
||||
输出: tech_stack, entry_points, high_risk_areas, dependencies, recon_handoff
|
||||
"""
|
||||
|
||||
async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""执行信息收集"""
|
||||
await self.emit_event("phase_start", "🔍 开始信息收集阶段")
|
||||
|
||||
try:
|
||||
# 调用 Recon Agent
|
||||
result = await self.agent.run({
|
||||
"project_info": state["project_info"],
|
||||
"config": state["config"],
|
||||
})
|
||||
|
||||
if result.success and result.data:
|
||||
data = result.data
|
||||
|
||||
# 🔥 创建交接信息给 Analysis Agent
|
||||
handoff = self.agent.create_handoff(
|
||||
to_agent="Analysis",
|
||||
summary=f"项目信息收集完成。发现 {len(data.get('entry_points', []))} 个入口点,{len(data.get('high_risk_areas', []))} 个高风险区域。",
|
||||
key_findings=data.get("initial_findings", []),
|
||||
suggested_actions=[
|
||||
{
|
||||
"type": "deep_analysis",
|
||||
"description": f"深入分析高风险区域: {', '.join(data.get('high_risk_areas', [])[:5])}",
|
||||
"priority": "high",
|
||||
},
|
||||
{
|
||||
"type": "entry_point_audit",
|
||||
"description": "审计所有入口点的输入验证",
|
||||
"priority": "high",
|
||||
},
|
||||
],
|
||||
attention_points=[
|
||||
f"技术栈: {data.get('tech_stack', {}).get('frameworks', [])}",
|
||||
f"主要语言: {data.get('tech_stack', {}).get('languages', [])}",
|
||||
],
|
||||
priority_areas=data.get("high_risk_areas", [])[:10],
|
||||
context_data={
|
||||
"tech_stack": data.get("tech_stack", {}),
|
||||
"entry_points": data.get("entry_points", []),
|
||||
"dependencies": data.get("dependencies", {}),
|
||||
},
|
||||
)
|
||||
|
||||
await self.emit_event(
|
||||
"phase_complete",
|
||||
f"✅ 信息收集完成: 发现 {len(data.get('entry_points', []))} 个入口点"
|
||||
)
|
||||
|
||||
return {
|
||||
"tech_stack": data.get("tech_stack", {}),
|
||||
"entry_points": data.get("entry_points", []),
|
||||
"high_risk_areas": data.get("high_risk_areas", []),
|
||||
"dependencies": data.get("dependencies", {}),
|
||||
"current_phase": "recon_complete",
|
||||
"findings": data.get("initial_findings", []),
|
||||
# 🔥 保存交接信息
|
||||
"recon_handoff": handoff.to_dict(),
|
||||
"events": [{
|
||||
"type": "recon_complete",
|
||||
"data": {
|
||||
"entry_points_count": len(data.get("entry_points", [])),
|
||||
"high_risk_areas_count": len(data.get("high_risk_areas", [])),
|
||||
"handoff_summary": handoff.summary,
|
||||
}
|
||||
}],
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"error": result.error or "Recon failed",
|
||||
"current_phase": "error",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Recon node failed: {e}", exc_info=True)
|
||||
return {
|
||||
"error": str(e),
|
||||
"current_phase": "error",
|
||||
}
|
||||
|
||||
|
||||
class AnalysisNode(BaseNode):
|
||||
"""
|
||||
漏洞分析节点
|
||||
|
||||
输入: tech_stack, entry_points, high_risk_areas, recon_handoff
|
||||
输出: findings (累加), should_continue_analysis, analysis_handoff
|
||||
"""
|
||||
|
||||
async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""执行漏洞分析"""
|
||||
iteration = state.get("iteration", 0) + 1
|
||||
|
||||
await self.emit_event(
|
||||
"phase_start",
|
||||
f"🔬 开始漏洞分析阶段 (迭代 {iteration})"
|
||||
)
|
||||
|
||||
try:
|
||||
# 🔥 提取 Recon 的交接信息
|
||||
recon_handoff = self._extract_handoff_from_state(state, "recon")
|
||||
if recon_handoff:
|
||||
self.agent.receive_handoff(recon_handoff)
|
||||
await self.emit_event(
|
||||
"handoff_received",
|
||||
f"📨 收到 Recon Agent 交接: {recon_handoff.summary[:50]}..."
|
||||
)
|
||||
|
||||
# 构建分析输入
|
||||
analysis_input = {
|
||||
"phase_name": "analysis",
|
||||
"project_info": state["project_info"],
|
||||
"config": state["config"],
|
||||
"plan": {
|
||||
"high_risk_areas": state.get("high_risk_areas", []),
|
||||
},
|
||||
"previous_results": {
|
||||
"recon": {
|
||||
"data": {
|
||||
"tech_stack": state.get("tech_stack", {}),
|
||||
"entry_points": state.get("entry_points", []),
|
||||
"high_risk_areas": state.get("high_risk_areas", []),
|
||||
}
|
||||
}
|
||||
},
|
||||
# 🔥 传递交接信息
|
||||
"handoff": recon_handoff,
|
||||
}
|
||||
|
||||
# 调用 Analysis Agent
|
||||
result = await self.agent.run(analysis_input)
|
||||
|
||||
if result.success and result.data:
|
||||
new_findings = result.data.get("findings", [])
|
||||
logger.info(f"[AnalysisNode] Agent returned {len(new_findings)} findings")
|
||||
|
||||
# 判断是否需要继续分析
|
||||
should_continue = (
|
||||
len(new_findings) >= 5 and
|
||||
iteration < state.get("max_iterations", 3)
|
||||
)
|
||||
|
||||
# 🔥 创建交接信息给 Verification Agent
|
||||
# 统计严重程度
|
||||
severity_counts = {"critical": 0, "high": 0, "medium": 0, "low": 0}
|
||||
for f in new_findings:
|
||||
if isinstance(f, dict):
|
||||
sev = f.get("severity", "medium")
|
||||
severity_counts[sev] = severity_counts.get(sev, 0) + 1
|
||||
|
||||
handoff = self.agent.create_handoff(
|
||||
to_agent="Verification",
|
||||
summary=f"漏洞分析完成。发现 {len(new_findings)} 个潜在漏洞 (Critical: {severity_counts['critical']}, High: {severity_counts['high']}, Medium: {severity_counts['medium']}, Low: {severity_counts['low']})",
|
||||
key_findings=new_findings[:20], # 传递前20个发现
|
||||
suggested_actions=[
|
||||
{
|
||||
"type": "verify_critical",
|
||||
"description": "优先验证 Critical 和 High 级别的漏洞",
|
||||
"priority": "critical",
|
||||
},
|
||||
{
|
||||
"type": "poc_generation",
|
||||
"description": "为确认的漏洞生成 PoC",
|
||||
"priority": "high",
|
||||
},
|
||||
],
|
||||
attention_points=[
|
||||
f"共 {severity_counts['critical']} 个 Critical 级别漏洞需要立即验证",
|
||||
f"共 {severity_counts['high']} 个 High 级别漏洞需要优先验证",
|
||||
"注意检查是否有误报,特别是静态分析工具的结果",
|
||||
],
|
||||
priority_areas=[
|
||||
f.get("file_path", "") for f in new_findings
|
||||
if f.get("severity") in ["critical", "high"]
|
||||
][:10],
|
||||
context_data={
|
||||
"severity_distribution": severity_counts,
|
||||
"total_findings": len(new_findings),
|
||||
"iteration": iteration,
|
||||
},
|
||||
)
|
||||
|
||||
await self.emit_event(
|
||||
"phase_complete",
|
||||
f"✅ 分析迭代 {iteration} 完成: 发现 {len(new_findings)} 个潜在漏洞"
|
||||
)
|
||||
|
||||
return {
|
||||
"findings": new_findings,
|
||||
"iteration": iteration,
|
||||
"should_continue_analysis": should_continue,
|
||||
"current_phase": "analysis_complete",
|
||||
# 🔥 保存交接信息
|
||||
"analysis_handoff": handoff.to_dict(),
|
||||
"events": [{
|
||||
"type": "analysis_iteration",
|
||||
"data": {
|
||||
"iteration": iteration,
|
||||
"findings_count": len(new_findings),
|
||||
"severity_distribution": severity_counts,
|
||||
"handoff_summary": handoff.summary,
|
||||
}
|
||||
}],
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"iteration": iteration,
|
||||
"should_continue_analysis": False,
|
||||
"current_phase": "analysis_complete",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Analysis node failed: {e}", exc_info=True)
|
||||
return {
|
||||
"error": str(e),
|
||||
"should_continue_analysis": False,
|
||||
"current_phase": "error",
|
||||
}
|
||||
|
||||
|
||||
class VerificationNode(BaseNode):
|
||||
"""
|
||||
漏洞验证节点
|
||||
|
||||
输入: findings, analysis_handoff
|
||||
输出: verified_findings, false_positives, verification_handoff
|
||||
"""
|
||||
|
||||
async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""执行漏洞验证"""
|
||||
findings = state.get("findings", [])
|
||||
logger.info(f"[VerificationNode] Received {len(findings)} findings to verify")
|
||||
|
||||
if not findings:
|
||||
return {
|
||||
"verified_findings": [],
|
||||
"false_positives": [],
|
||||
"current_phase": "verification_complete",
|
||||
}
|
||||
|
||||
await self.emit_event(
|
||||
"phase_start",
|
||||
f"🔐 开始漏洞验证阶段 ({len(findings)} 个待验证)"
|
||||
)
|
||||
|
||||
try:
|
||||
# 🔥 提取 Analysis 的交接信息
|
||||
analysis_handoff = self._extract_handoff_from_state(state, "analysis")
|
||||
if analysis_handoff:
|
||||
self.agent.receive_handoff(analysis_handoff)
|
||||
await self.emit_event(
|
||||
"handoff_received",
|
||||
f"📨 收到 Analysis Agent 交接: {analysis_handoff.summary[:50]}..."
|
||||
)
|
||||
|
||||
# 构建验证输入
|
||||
verification_input = {
|
||||
"previous_results": {
|
||||
"analysis": {
|
||||
"data": {
|
||||
"findings": findings,
|
||||
}
|
||||
}
|
||||
},
|
||||
"config": state["config"],
|
||||
# 🔥 传递交接信息
|
||||
"handoff": analysis_handoff,
|
||||
}
|
||||
|
||||
# 调用 Verification Agent
|
||||
result = await self.agent.run(verification_input)
|
||||
|
||||
if result.success and result.data:
|
||||
all_verified_findings = result.data.get("findings", [])
|
||||
verified = [f for f in all_verified_findings if f.get("is_verified")]
|
||||
false_pos = [f.get("id", f.get("title", "unknown")) for f in all_verified_findings
|
||||
if f.get("verdict") == "false_positive"]
|
||||
|
||||
# 🔥 CRITICAL FIX: 用验证结果更新原始 findings
|
||||
# 创建 findings 的更新映射,基于 (file_path, line_start, vulnerability_type)
|
||||
verified_map = {}
|
||||
for vf in all_verified_findings:
|
||||
key = (
|
||||
vf.get("file_path", ""),
|
||||
vf.get("line_start", 0),
|
||||
vf.get("vulnerability_type", ""),
|
||||
)
|
||||
verified_map[key] = vf
|
||||
|
||||
# 合并验证结果到原始 findings
|
||||
updated_findings = []
|
||||
seen_keys = set()
|
||||
|
||||
# 首先处理原始 findings,用验证结果更新
|
||||
for f in findings:
|
||||
if not isinstance(f, dict):
|
||||
continue
|
||||
key = (
|
||||
f.get("file_path", ""),
|
||||
f.get("line_start", 0),
|
||||
f.get("vulnerability_type", ""),
|
||||
)
|
||||
if key in verified_map:
|
||||
# 使用验证后的版本
|
||||
updated_findings.append(verified_map[key])
|
||||
seen_keys.add(key)
|
||||
else:
|
||||
# 保留原始(未验证)
|
||||
updated_findings.append(f)
|
||||
seen_keys.add(key)
|
||||
|
||||
# 添加验证结果中的新发现(如果有)
|
||||
for key, vf in verified_map.items():
|
||||
if key not in seen_keys:
|
||||
updated_findings.append(vf)
|
||||
|
||||
logger.info(f"[VerificationNode] Updated findings: {len(updated_findings)} total, {len(verified)} verified")
|
||||
|
||||
# 🔥 创建交接信息给 Report 节点
|
||||
handoff = self.agent.create_handoff(
|
||||
to_agent="Report",
|
||||
summary=f"漏洞验证完成。{len(verified)} 个漏洞已确认,{len(false_pos)} 个误报已排除。",
|
||||
key_findings=verified,
|
||||
suggested_actions=[
|
||||
{
|
||||
"type": "generate_report",
|
||||
"description": "生成详细的安全审计报告",
|
||||
"priority": "high",
|
||||
},
|
||||
{
|
||||
"type": "remediation_plan",
|
||||
"description": "为确认的漏洞制定修复计划",
|
||||
"priority": "high",
|
||||
},
|
||||
],
|
||||
attention_points=[
|
||||
f"共 {len(verified)} 个漏洞已确认存在",
|
||||
f"共 {len(false_pos)} 个误报已排除",
|
||||
"建议按严重程度优先修复 Critical 和 High 级别漏洞",
|
||||
],
|
||||
context_data={
|
||||
"verified_count": len(verified),
|
||||
"false_positive_count": len(false_pos),
|
||||
"total_analyzed": len(findings),
|
||||
"verification_rate": len(verified) / len(findings) if findings else 0,
|
||||
},
|
||||
)
|
||||
|
||||
await self.emit_event(
|
||||
"phase_complete",
|
||||
f"✅ 验证完成: {len(verified)} 已确认, {len(false_pos)} 误报"
|
||||
)
|
||||
|
||||
return {
|
||||
# 🔥 CRITICAL: 返回更新后的 findings,这会替换状态中的 findings
|
||||
# 注意:由于 LangGraph 使用 operator.add,我们需要在 runner 中处理合并
|
||||
# 这里我们返回 _verified_findings_update 作为特殊字段
|
||||
"_verified_findings_update": updated_findings,
|
||||
"verified_findings": verified,
|
||||
"false_positives": false_pos,
|
||||
"current_phase": "verification_complete",
|
||||
# 🔥 保存交接信息
|
||||
"verification_handoff": handoff.to_dict(),
|
||||
"events": [{
|
||||
"type": "verification_complete",
|
||||
"data": {
|
||||
"verified_count": len(verified),
|
||||
"false_positive_count": len(false_pos),
|
||||
"total_findings": len(updated_findings),
|
||||
"handoff_summary": handoff.summary,
|
||||
}
|
||||
}],
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"verified_findings": [],
|
||||
"false_positives": [],
|
||||
"current_phase": "verification_complete",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Verification node failed: {e}", exc_info=True)
|
||||
return {
|
||||
"error": str(e),
|
||||
"current_phase": "error",
|
||||
}
|
||||
|
||||
|
||||
class ReportNode(BaseNode):
|
||||
"""
|
||||
报告生成节点
|
||||
|
||||
输入: all state
|
||||
输出: summary, security_score
|
||||
"""
|
||||
|
||||
async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""生成审计报告"""
|
||||
await self.emit_event("phase_start", "📊 生成审计报告")
|
||||
|
||||
try:
|
||||
# 🔥 CRITICAL FIX: 优先使用验证后的 findings 更新
|
||||
findings = state.get("_verified_findings_update") or state.get("findings", [])
|
||||
verified = state.get("verified_findings", [])
|
||||
false_positives = state.get("false_positives", [])
|
||||
|
||||
logger.info(f"[ReportNode] State contains {len(findings)} findings, {len(verified)} verified")
|
||||
|
||||
# 统计漏洞分布
|
||||
severity_counts = {"critical": 0, "high": 0, "medium": 0, "low": 0}
|
||||
type_counts = {}
|
||||
|
||||
for finding in findings:
|
||||
# 跳过非字典类型的 finding(防止数据格式异常)
|
||||
if not isinstance(finding, dict):
|
||||
logger.warning(f"Skipping invalid finding (not a dict): {type(finding)}")
|
||||
continue
|
||||
|
||||
sev = finding.get("severity", "medium")
|
||||
severity_counts[sev] = severity_counts.get(sev, 0) + 1
|
||||
|
||||
vtype = finding.get("vulnerability_type", "other")
|
||||
type_counts[vtype] = type_counts.get(vtype, 0) + 1
|
||||
|
||||
# 计算安全评分
|
||||
base_score = 100
|
||||
deductions = (
|
||||
severity_counts["critical"] * 25 +
|
||||
severity_counts["high"] * 15 +
|
||||
severity_counts["medium"] * 8 +
|
||||
severity_counts["low"] * 3
|
||||
)
|
||||
security_score = max(0, base_score - deductions)
|
||||
|
||||
# 生成摘要
|
||||
summary = {
|
||||
"total_findings": len(findings),
|
||||
"verified_count": len(verified),
|
||||
"false_positive_count": len(false_positives),
|
||||
"severity_distribution": severity_counts,
|
||||
"vulnerability_types": type_counts,
|
||||
"tech_stack": state.get("tech_stack", {}),
|
||||
"entry_points_analyzed": len(state.get("entry_points", [])),
|
||||
"high_risk_areas": state.get("high_risk_areas", []),
|
||||
"iterations": state.get("iteration", 1),
|
||||
}
|
||||
|
||||
await self.emit_event(
|
||||
"phase_complete",
|
||||
f"报告生成完成: 安全评分 {security_score}/100"
|
||||
)
|
||||
|
||||
return {
|
||||
"summary": summary,
|
||||
"security_score": security_score,
|
||||
"current_phase": "complete",
|
||||
"events": [{
|
||||
"type": "audit_complete",
|
||||
"data": {
|
||||
"security_score": security_score,
|
||||
"total_findings": len(findings),
|
||||
"verified_count": len(verified),
|
||||
}
|
||||
}],
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Report node failed: {e}", exc_info=True)
|
||||
return {
|
||||
"error": str(e),
|
||||
"current_phase": "error",
|
||||
}
|
||||
|
||||
|
||||
class HumanReviewNode(BaseNode):
|
||||
"""
|
||||
人工审核节点
|
||||
|
||||
在此节点暂停,等待人工反馈
|
||||
"""
|
||||
|
||||
async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
人工审核节点
|
||||
|
||||
这个节点会被 interrupt_before 暂停
|
||||
用户可以:
|
||||
1. 确认发现
|
||||
2. 标记误报
|
||||
3. 请求重新分析
|
||||
"""
|
||||
await self.emit_event(
|
||||
"human_review",
|
||||
f"⏸️ 等待人工审核 ({len(state.get('verified_findings', []))} 个待确认)"
|
||||
)
|
||||
|
||||
# 返回当前状态,不做修改
|
||||
# 人工反馈会通过 update_state 传入
|
||||
return {
|
||||
"current_phase": "human_review",
|
||||
"messages": [{
|
||||
"role": "system",
|
||||
"content": "等待人工审核",
|
||||
"findings_for_review": state.get("verified_findings", []),
|
||||
}],
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -6,6 +6,7 @@
|
|||
import os
|
||||
import re
|
||||
import fnmatch
|
||||
import asyncio
|
||||
from typing import Optional, List, Dict, Any
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
|
@ -45,6 +46,36 @@ class FileReadTool(AgentTool):
|
|||
self.exclude_patterns = exclude_patterns or []
|
||||
self.target_files = set(target_files) if target_files else None
|
||||
|
||||
@staticmethod
|
||||
def _read_file_lines_sync(file_path: str, start_idx: int, end_idx: int) -> tuple:
|
||||
"""同步读取文件指定行范围(用于 asyncio.to_thread)"""
|
||||
selected_lines = []
|
||||
total_lines = 0
|
||||
file_size = os.path.getsize(file_path)
|
||||
|
||||
with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
|
||||
for i, line in enumerate(f):
|
||||
total_lines = i + 1
|
||||
if i >= start_idx and i < end_idx:
|
||||
selected_lines.append(line)
|
||||
elif i >= end_idx:
|
||||
if i < end_idx + 1000:
|
||||
continue
|
||||
else:
|
||||
remaining_bytes = file_size - f.tell()
|
||||
avg_line_size = f.tell() / (i + 1)
|
||||
estimated_remaining_lines = int(remaining_bytes / avg_line_size) if avg_line_size > 0 else 0
|
||||
total_lines = i + 1 + estimated_remaining_lines
|
||||
break
|
||||
|
||||
return selected_lines, total_lines
|
||||
|
||||
@staticmethod
|
||||
def _read_all_lines_sync(file_path: str) -> List[str]:
|
||||
"""同步读取文件所有行(用于 asyncio.to_thread)"""
|
||||
with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
|
||||
return f.readlines()
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "read_file"
|
||||
|
|
@ -136,37 +167,20 @@ class FileReadTool(AgentTool):
|
|||
|
||||
# 🔥 对于大文件,使用流式读取指定行范围
|
||||
if is_large_file and (start_line is not None or end_line is not None):
|
||||
# 流式读取,避免一次性加载整个文件
|
||||
selected_lines = []
|
||||
total_lines = 0
|
||||
|
||||
# 计算实际的起始和结束行
|
||||
start_idx = max(0, (start_line or 1) - 1)
|
||||
end_idx = end_line if end_line else start_idx + max_lines
|
||||
|
||||
with open(full_path, 'r', encoding='utf-8', errors='ignore') as f:
|
||||
for i, line in enumerate(f):
|
||||
total_lines = i + 1
|
||||
if i >= start_idx and i < end_idx:
|
||||
selected_lines.append(line)
|
||||
elif i >= end_idx:
|
||||
# 继续计数以获取总行数,但限制读取量
|
||||
if i < end_idx + 1000: # 最多再读1000行来估算总行数
|
||||
continue
|
||||
else:
|
||||
# 估算剩余行数
|
||||
remaining_bytes = file_size - f.tell()
|
||||
avg_line_size = f.tell() / (i + 1)
|
||||
estimated_remaining_lines = int(remaining_bytes / avg_line_size) if avg_line_size > 0 else 0
|
||||
total_lines = i + 1 + estimated_remaining_lines
|
||||
break
|
||||
# 异步读取文件,避免阻塞事件循环
|
||||
selected_lines, total_lines = await asyncio.to_thread(
|
||||
self._read_file_lines_sync, full_path, start_idx, end_idx
|
||||
)
|
||||
|
||||
# 更新实际的结束索引
|
||||
end_idx = min(end_idx, start_idx + len(selected_lines))
|
||||
else:
|
||||
# 正常读取小文件
|
||||
with open(full_path, 'r', encoding='utf-8', errors='ignore') as f:
|
||||
lines = f.readlines()
|
||||
# 异步读取小文件,避免阻塞事件循环
|
||||
lines = await asyncio.to_thread(self._read_all_lines_sync, full_path)
|
||||
|
||||
total_lines = len(lines)
|
||||
|
||||
|
|
@ -268,6 +282,12 @@ class FileSearchTool(AgentTool):
|
|||
elif "/" not in pattern and "*" not in pattern:
|
||||
self.exclude_dirs.add(pattern)
|
||||
|
||||
@staticmethod
|
||||
def _read_file_lines_sync(file_path: str) -> List[str]:
|
||||
"""同步读取文件所有行(用于 asyncio.to_thread)"""
|
||||
with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
|
||||
return f.readlines()
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "search_code"
|
||||
|
|
@ -360,8 +380,10 @@ class FileSearchTool(AgentTool):
|
|||
continue
|
||||
|
||||
try:
|
||||
with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
|
||||
lines = f.readlines()
|
||||
# 异步读取文件,避免阻塞事件循环
|
||||
lines = await asyncio.to_thread(
|
||||
self._read_file_lines_sync, file_path
|
||||
)
|
||||
|
||||
files_searched += 1
|
||||
|
||||
|
|
|
|||
|
|
@ -416,13 +416,93 @@ class LiteLLMAdapter(BaseLLMAdapter):
|
|||
"finish_reason": "complete",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
# 🔥 即使出错,也尝试返回估算的 usage
|
||||
logger.error(f"Stream error: {e}")
|
||||
except litellm.exceptions.RateLimitError as e:
|
||||
# 速率限制错误 - 需要特殊处理
|
||||
logger.error(f"Stream rate limit error: {e}")
|
||||
error_msg = str(e)
|
||||
# 区分"余额不足"和"频率超限"
|
||||
if any(keyword in error_msg.lower() for keyword in ["余额不足", "资源包", "充值", "quota", "exceeded", "billing"]):
|
||||
error_type = "quota_exceeded"
|
||||
user_message = "API 配额已用尽,请检查账户余额或升级计划"
|
||||
else:
|
||||
error_type = "rate_limit"
|
||||
# 尝试从错误消息中提取重试时间
|
||||
import re
|
||||
retry_match = re.search(r"retry\s*(?:in|after)\s*(\d+(?:\.\d+)?)\s*s", error_msg, re.IGNORECASE)
|
||||
retry_seconds = float(retry_match.group(1)) if retry_match else 60
|
||||
user_message = f"API 调用频率超限,建议等待 {int(retry_seconds)} 秒后重试"
|
||||
|
||||
output_tokens_estimate = estimate_tokens(accumulated_content) if accumulated_content else 0
|
||||
yield {
|
||||
"type": "error",
|
||||
"error_type": error_type,
|
||||
"error": error_msg,
|
||||
"user_message": user_message,
|
||||
"accumulated": accumulated_content,
|
||||
"usage": {
|
||||
"prompt_tokens": input_tokens_estimate,
|
||||
"completion_tokens": output_tokens_estimate,
|
||||
"total_tokens": input_tokens_estimate + output_tokens_estimate,
|
||||
} if accumulated_content else None,
|
||||
}
|
||||
|
||||
except litellm.exceptions.AuthenticationError as e:
|
||||
# 认证错误 - API Key 无效
|
||||
logger.error(f"Stream authentication error: {e}")
|
||||
yield {
|
||||
"type": "error",
|
||||
"error_type": "authentication",
|
||||
"error": str(e),
|
||||
"user_message": "API Key 无效或已过期,请检查配置",
|
||||
"accumulated": accumulated_content,
|
||||
"usage": None,
|
||||
}
|
||||
|
||||
except litellm.exceptions.APIConnectionError as e:
|
||||
# 连接错误 - 网络问题
|
||||
logger.error(f"Stream connection error: {e}")
|
||||
yield {
|
||||
"type": "error",
|
||||
"error_type": "connection",
|
||||
"error": str(e),
|
||||
"user_message": "无法连接到 API 服务,请检查网络连接",
|
||||
"accumulated": accumulated_content,
|
||||
"usage": None,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
# 其他错误 - 检查是否是包装的速率限制错误
|
||||
error_msg = str(e)
|
||||
logger.error(f"Stream error: {e}")
|
||||
|
||||
# 检查是否是包装的速率限制错误(如 ServiceUnavailableError 包装 RateLimitError)
|
||||
is_rate_limit = any(keyword in error_msg.lower() for keyword in [
|
||||
"ratelimiterror", "rate limit", "429", "resource_exhausted",
|
||||
"quota exceeded", "too many requests"
|
||||
])
|
||||
|
||||
if is_rate_limit:
|
||||
# 按速率限制错误处理
|
||||
import re
|
||||
# 检查是否是配额用尽
|
||||
if any(keyword in error_msg.lower() for keyword in ["quota", "exceeded", "billing"]):
|
||||
error_type = "quota_exceeded"
|
||||
user_message = "API 配额已用尽,请检查账户余额或升级计划"
|
||||
else:
|
||||
error_type = "rate_limit"
|
||||
retry_match = re.search(r"retry\s*(?:in|after)\s*(\d+(?:\.\d+)?)\s*s", error_msg, re.IGNORECASE)
|
||||
retry_seconds = float(retry_match.group(1)) if retry_match else 60
|
||||
user_message = f"API 调用频率超限,建议等待 {int(retry_seconds)} 秒后重试"
|
||||
else:
|
||||
error_type = "unknown"
|
||||
user_message = "LLM 调用发生错误,请重试"
|
||||
|
||||
output_tokens_estimate = estimate_tokens(accumulated_content) if accumulated_content else 0
|
||||
yield {
|
||||
"type": "error",
|
||||
"error_type": error_type,
|
||||
"error": error_msg,
|
||||
"user_message": user_message,
|
||||
"accumulated": accumulated_content,
|
||||
"usage": {
|
||||
"prompt_tokens": input_tokens_estimate,
|
||||
|
|
|
|||
|
|
@ -739,6 +739,20 @@ class CodeIndexer:
|
|||
self._needs_rebuild = False
|
||||
self._rebuild_reason = ""
|
||||
|
||||
@staticmethod
|
||||
def _read_file_sync(file_path: str) -> str:
|
||||
"""
|
||||
同步读取文件内容(用于 asyncio.to_thread 包装)
|
||||
|
||||
Args:
|
||||
file_path: 文件路径
|
||||
|
||||
Returns:
|
||||
文件内容
|
||||
"""
|
||||
with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
|
||||
return f.read()
|
||||
|
||||
async def initialize(self, force_rebuild: bool = False) -> Tuple[bool, str]:
|
||||
"""
|
||||
初始化索引器,检测是否需要重建索引
|
||||
|
|
@ -916,8 +930,10 @@ class CodeIndexer:
|
|||
try:
|
||||
relative_path = os.path.relpath(file_path, directory)
|
||||
|
||||
with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
|
||||
content = f.read()
|
||||
# 异步读取文件,避免阻塞事件循环
|
||||
content = await asyncio.to_thread(
|
||||
self._read_file_sync, file_path
|
||||
)
|
||||
|
||||
if not content.strip():
|
||||
progress.processed_files += 1
|
||||
|
|
@ -932,8 +948,8 @@ class CodeIndexer:
|
|||
if len(content) > 500000:
|
||||
content = content[:500000]
|
||||
|
||||
# 分块
|
||||
chunks = self.splitter.split_file(content, relative_path)
|
||||
# 异步分块,避免 Tree-sitter 解析阻塞事件循环
|
||||
chunks = await self.splitter.split_file_async(content, relative_path)
|
||||
|
||||
# 为每个 chunk 添加 file_hash
|
||||
for chunk in chunks:
|
||||
|
|
@ -1018,8 +1034,10 @@ class CodeIndexer:
|
|||
for relative_path in files_to_check:
|
||||
file_path = current_file_map[relative_path]
|
||||
try:
|
||||
with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
|
||||
content = f.read()
|
||||
# 异步读取文件,避免阻塞事件循环
|
||||
content = await asyncio.to_thread(
|
||||
self._read_file_sync, file_path
|
||||
)
|
||||
current_hash = hashlib.md5(content.encode()).hexdigest()
|
||||
if current_hash != indexed_file_hashes.get(relative_path):
|
||||
files_to_update.add(relative_path)
|
||||
|
|
@ -1055,8 +1073,10 @@ class CodeIndexer:
|
|||
is_update = relative_path in files_to_update
|
||||
|
||||
try:
|
||||
with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
|
||||
content = f.read()
|
||||
# 异步读取文件,避免阻塞事件循环
|
||||
content = await asyncio.to_thread(
|
||||
self._read_file_sync, file_path
|
||||
)
|
||||
|
||||
if not content.strip():
|
||||
progress.processed_files += 1
|
||||
|
|
@ -1075,8 +1095,8 @@ class CodeIndexer:
|
|||
if len(content) > 500000:
|
||||
content = content[:500000]
|
||||
|
||||
# 分块
|
||||
chunks = self.splitter.split_file(content, relative_path)
|
||||
# 异步分块,避免 Tree-sitter 解析阻塞事件循环
|
||||
chunks = await self.splitter.split_file_async(content, relative_path)
|
||||
|
||||
# 为每个 chunk 添加 file_hash
|
||||
for chunk in chunks:
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@
|
|||
"""
|
||||
|
||||
import re
|
||||
import asyncio
|
||||
import hashlib
|
||||
import logging
|
||||
from typing import List, Dict, Any, Optional, Tuple, Set
|
||||
|
|
@ -154,7 +155,7 @@ class TreeSitterParser:
|
|||
".c": "c",
|
||||
".h": "c",
|
||||
".hpp": "cpp",
|
||||
".cs": "c_sharp",
|
||||
".cs": "csharp",
|
||||
".php": "php",
|
||||
".rb": "ruby",
|
||||
".kt": "kotlin",
|
||||
|
|
@ -197,7 +198,7 @@ class TreeSitterParser:
|
|||
# tree-sitter-languages 支持的语言列表
|
||||
SUPPORTED_LANGUAGES = {
|
||||
"python", "javascript", "typescript", "tsx", "java", "go", "rust",
|
||||
"c", "cpp", "c_sharp", "php", "ruby", "kotlin", "swift", "bash",
|
||||
"c", "cpp", "csharp", "php", "ruby", "kotlin", "swift", "bash",
|
||||
"json", "yaml", "html", "css", "sql", "markdown",
|
||||
}
|
||||
|
||||
|
|
@ -230,7 +231,7 @@ class TreeSitterParser:
|
|||
return False
|
||||
|
||||
def parse(self, code: str, language: str) -> Optional[Any]:
|
||||
"""解析代码返回 AST"""
|
||||
"""解析代码返回 AST(同步方法)"""
|
||||
if not self._ensure_initialized(language):
|
||||
return None
|
||||
|
||||
|
|
@ -245,6 +246,15 @@ class TreeSitterParser:
|
|||
logger.warning(f"Failed to parse code: {e}")
|
||||
return None
|
||||
|
||||
async def parse_async(self, code: str, language: str) -> Optional[Any]:
|
||||
"""
|
||||
异步解析代码返回 AST
|
||||
|
||||
将 CPU 密集型的 Tree-sitter 解析操作放到线程池中执行,
|
||||
避免阻塞事件循环
|
||||
"""
|
||||
return await asyncio.to_thread(self.parse, code, language)
|
||||
|
||||
def extract_definitions(self, tree: Any, code: str, language: str) -> List[Dict[str, Any]]:
|
||||
"""从 AST 提取定义"""
|
||||
if tree is None:
|
||||
|
|
@ -452,6 +462,28 @@ class CodeSplitter:
|
|||
|
||||
return chunks
|
||||
|
||||
async def split_file_async(
|
||||
self,
|
||||
content: str,
|
||||
file_path: str,
|
||||
language: Optional[str] = None
|
||||
) -> List[CodeChunk]:
|
||||
"""
|
||||
异步分割单个文件
|
||||
|
||||
将 CPU 密集型的分块操作(包括 Tree-sitter 解析)放到线程池中执行,
|
||||
避免阻塞事件循环。
|
||||
|
||||
Args:
|
||||
content: 文件内容
|
||||
file_path: 文件路径
|
||||
language: 编程语言(可选)
|
||||
|
||||
Returns:
|
||||
代码块列表
|
||||
"""
|
||||
return await asyncio.to_thread(self.split_file, content, file_path, language)
|
||||
|
||||
def _split_by_ast(
|
||||
self,
|
||||
content: str,
|
||||
|
|
|
|||
|
|
@ -1,355 +0,0 @@
|
|||
"""
|
||||
Agent 集成测试
|
||||
测试完整的审计流程
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
import os
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
from datetime import datetime
|
||||
|
||||
from app.services.agent.graph.runner import AgentRunner, LLMService
|
||||
from app.services.agent.graph.audit_graph import AuditState, create_audit_graph
|
||||
from app.services.agent.graph.nodes import ReconNode, AnalysisNode, VerificationNode, ReportNode
|
||||
from app.services.agent.event_manager import EventManager, AgentEventEmitter
|
||||
|
||||
|
||||
class TestLLMService:
|
||||
"""LLM 服务测试"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_service_initialization(self):
|
||||
"""测试 LLM 服务初始化"""
|
||||
with patch("app.core.config.settings") as mock_settings:
|
||||
mock_settings.LLM_MODEL = "gpt-4o-mini"
|
||||
mock_settings.LLM_API_KEY = "test-key"
|
||||
|
||||
service = LLMService()
|
||||
|
||||
assert service.model == "gpt-4o-mini"
|
||||
|
||||
|
||||
class TestEventManager:
|
||||
"""事件管理器测试"""
|
||||
|
||||
def test_event_manager_initialization(self):
|
||||
"""测试事件管理器初始化"""
|
||||
manager = EventManager()
|
||||
|
||||
assert manager._event_queues == {}
|
||||
assert manager._event_callbacks == {}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_emitter(self):
|
||||
"""测试事件发射器"""
|
||||
manager = EventManager()
|
||||
emitter = AgentEventEmitter("test-task-id", manager)
|
||||
|
||||
await emitter.emit_info("Test message")
|
||||
|
||||
assert emitter._sequence == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_emitter_phase_tracking(self):
|
||||
"""测试事件发射器阶段跟踪"""
|
||||
manager = EventManager()
|
||||
emitter = AgentEventEmitter("test-task-id", manager)
|
||||
|
||||
await emitter.emit_phase_start("recon", "开始信息收集")
|
||||
|
||||
assert emitter._current_phase == "recon"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_emitter_task_complete(self):
|
||||
"""测试任务完成事件"""
|
||||
manager = EventManager()
|
||||
emitter = AgentEventEmitter("test-task-id", manager)
|
||||
|
||||
await emitter.emit_task_complete(findings_count=5, duration_ms=1000)
|
||||
|
||||
assert emitter._sequence == 1
|
||||
|
||||
|
||||
class TestAuditGraph:
|
||||
"""审计图测试"""
|
||||
|
||||
def test_create_audit_graph(self, mock_event_emitter):
|
||||
"""测试创建审计图"""
|
||||
# 创建模拟节点
|
||||
recon_node = MagicMock()
|
||||
analysis_node = MagicMock()
|
||||
verification_node = MagicMock()
|
||||
report_node = MagicMock()
|
||||
|
||||
graph = create_audit_graph(
|
||||
recon_node=recon_node,
|
||||
analysis_node=analysis_node,
|
||||
verification_node=verification_node,
|
||||
report_node=report_node,
|
||||
)
|
||||
|
||||
assert graph is not None
|
||||
|
||||
|
||||
class TestReconNode:
|
||||
"""Recon 节点测试"""
|
||||
|
||||
@pytest.fixture
|
||||
def recon_node_with_mock_agent(self, mock_event_emitter):
|
||||
"""创建带模拟 Agent 的 Recon 节点"""
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.run = AsyncMock(return_value=MagicMock(
|
||||
success=True,
|
||||
data={
|
||||
"tech_stack": {"languages": ["Python"]},
|
||||
"entry_points": [{"path": "src/app.py", "type": "api"}],
|
||||
"high_risk_areas": ["src/sql_vuln.py"],
|
||||
"dependencies": {},
|
||||
"initial_findings": [],
|
||||
}
|
||||
))
|
||||
|
||||
return ReconNode(mock_agent, mock_event_emitter)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_recon_node_success(self, recon_node_with_mock_agent):
|
||||
"""测试 Recon 节点成功执行"""
|
||||
state = {
|
||||
"project_info": {"name": "Test"},
|
||||
"config": {},
|
||||
}
|
||||
|
||||
result = await recon_node_with_mock_agent(state)
|
||||
|
||||
assert "tech_stack" in result
|
||||
assert "entry_points" in result
|
||||
assert result["current_phase"] == "recon_complete"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_recon_node_failure(self, mock_event_emitter):
|
||||
"""测试 Recon 节点失败处理"""
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.run = AsyncMock(return_value=MagicMock(
|
||||
success=False,
|
||||
error="Test error",
|
||||
data=None,
|
||||
))
|
||||
|
||||
node = ReconNode(mock_agent, mock_event_emitter)
|
||||
|
||||
result = await node({
|
||||
"project_info": {},
|
||||
"config": {},
|
||||
})
|
||||
|
||||
assert "error" in result
|
||||
assert result["current_phase"] == "error"
|
||||
|
||||
|
||||
class TestAnalysisNode:
|
||||
"""Analysis 节点测试"""
|
||||
|
||||
@pytest.fixture
|
||||
def analysis_node_with_mock_agent(self, mock_event_emitter):
|
||||
"""创建带模拟 Agent 的 Analysis 节点"""
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.run = AsyncMock(return_value=MagicMock(
|
||||
success=True,
|
||||
data={
|
||||
"findings": [
|
||||
{
|
||||
"id": "finding-1",
|
||||
"title": "SQL Injection",
|
||||
"severity": "high",
|
||||
"vulnerability_type": "sql_injection",
|
||||
"file_path": "src/sql_vuln.py",
|
||||
"line_start": 10,
|
||||
"description": "SQL injection vulnerability",
|
||||
}
|
||||
],
|
||||
"should_continue": False,
|
||||
}
|
||||
))
|
||||
|
||||
return AnalysisNode(mock_agent, mock_event_emitter)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_analysis_node_success(self, analysis_node_with_mock_agent):
|
||||
"""测试 Analysis 节点成功执行"""
|
||||
state = {
|
||||
"project_info": {"name": "Test"},
|
||||
"tech_stack": {"languages": ["Python"]},
|
||||
"entry_points": [],
|
||||
"high_risk_areas": ["src/sql_vuln.py"],
|
||||
"config": {},
|
||||
"iteration": 0,
|
||||
"findings": [],
|
||||
}
|
||||
|
||||
result = await analysis_node_with_mock_agent(state)
|
||||
|
||||
assert "findings" in result
|
||||
assert len(result["findings"]) > 0
|
||||
assert result["iteration"] == 1
|
||||
|
||||
|
||||
class TestIntegrationFlow:
|
||||
"""完整流程集成测试"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_full_audit_flow_mock(self, temp_project_dir, mock_db_session, mock_task):
|
||||
"""测试完整审计流程(使用模拟)"""
|
||||
# 这个测试验证整个流程的连接性
|
||||
|
||||
# 创建事件管理器
|
||||
event_manager = EventManager()
|
||||
emitter = AgentEventEmitter(mock_task.id, event_manager)
|
||||
|
||||
# 模拟 LLM 服务
|
||||
mock_llm = MagicMock()
|
||||
mock_llm.chat_completion_raw = AsyncMock(return_value={
|
||||
"content": "Analysis complete",
|
||||
"usage": {"total_tokens": 100},
|
||||
})
|
||||
|
||||
# 验证事件发射
|
||||
await emitter.emit_phase_start("init", "初始化")
|
||||
await emitter.emit_info("测试消息")
|
||||
await emitter.emit_phase_complete("init", "初始化完成")
|
||||
|
||||
assert emitter._sequence == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_audit_state_typing(self):
|
||||
"""测试审计状态类型定义"""
|
||||
state: AuditState = {
|
||||
"project_root": "/tmp/test",
|
||||
"project_info": {"name": "Test"},
|
||||
"config": {},
|
||||
"task_id": "test-id",
|
||||
"tech_stack": {},
|
||||
"entry_points": [],
|
||||
"high_risk_areas": [],
|
||||
"dependencies": {},
|
||||
"findings": [],
|
||||
"verified_findings": [],
|
||||
"false_positives": [],
|
||||
"current_phase": "start",
|
||||
"iteration": 0,
|
||||
"max_iterations": 50,
|
||||
"should_continue_analysis": False,
|
||||
"messages": [],
|
||||
"events": [],
|
||||
"summary": None,
|
||||
"security_score": None,
|
||||
"error": None,
|
||||
}
|
||||
|
||||
assert state["current_phase"] == "start"
|
||||
assert state["max_iterations"] == 50
|
||||
|
||||
|
||||
class TestToolIntegration:
|
||||
"""工具集成测试"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tools_work_together(self, temp_project_dir):
|
||||
"""测试工具协同工作"""
|
||||
from app.services.agent.tools import (
|
||||
FileReadTool, FileSearchTool, ListFilesTool, PatternMatchTool,
|
||||
)
|
||||
|
||||
# 1. 列出文件
|
||||
list_tool = ListFilesTool(temp_project_dir)
|
||||
list_result = await list_tool.execute(directory="src", recursive=False)
|
||||
assert list_result.success is True
|
||||
|
||||
# 2. 搜索关键代码
|
||||
search_tool = FileSearchTool(temp_project_dir)
|
||||
search_result = await search_tool.execute(keyword="execute")
|
||||
assert search_result.success is True
|
||||
|
||||
# 3. 读取文件内容
|
||||
read_tool = FileReadTool(temp_project_dir)
|
||||
read_result = await read_tool.execute(file_path="src/sql_vuln.py")
|
||||
assert read_result.success is True
|
||||
|
||||
# 4. 模式匹配
|
||||
pattern_tool = PatternMatchTool(temp_project_dir)
|
||||
pattern_result = await pattern_tool.execute(
|
||||
code=read_result.data,
|
||||
file_path="src/sql_vuln.py",
|
||||
language="python"
|
||||
)
|
||||
assert pattern_result.success is True
|
||||
|
||||
|
||||
class TestErrorHandling:
|
||||
"""错误处理测试"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_error_handling(self, temp_project_dir):
|
||||
"""测试工具错误处理"""
|
||||
from app.services.agent.tools import FileReadTool
|
||||
|
||||
tool = FileReadTool(temp_project_dir)
|
||||
|
||||
# 尝试读取不存在的文件
|
||||
result = await tool.execute(file_path="nonexistent/file.py")
|
||||
|
||||
assert result.success is False
|
||||
assert result.error is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_graceful_degradation(self, mock_event_emitter):
|
||||
"""测试 Agent 优雅降级"""
|
||||
# 创建一个会失败的 Agent
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.run = AsyncMock(side_effect=Exception("Simulated error"))
|
||||
|
||||
node = ReconNode(mock_agent, mock_event_emitter)
|
||||
|
||||
result = await node({
|
||||
"project_info": {},
|
||||
"config": {},
|
||||
})
|
||||
|
||||
# 应该返回错误状态而不是崩溃
|
||||
assert "error" in result
|
||||
assert result["current_phase"] == "error"
|
||||
|
||||
|
||||
class TestPerformance:
|
||||
"""性能测试"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_response_time(self, temp_project_dir):
|
||||
"""测试工具响应时间"""
|
||||
from app.services.agent.tools import ListFilesTool
|
||||
import time
|
||||
|
||||
tool = ListFilesTool(temp_project_dir)
|
||||
|
||||
start = time.time()
|
||||
await tool.execute(directory=".", recursive=True)
|
||||
duration = time.time() - start
|
||||
|
||||
# 工具应该在合理时间内响应
|
||||
assert duration < 5.0 # 5 秒内
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_tool_calls(self, temp_project_dir):
|
||||
"""测试多次工具调用"""
|
||||
from app.services.agent.tools import FileSearchTool
|
||||
|
||||
tool = FileSearchTool(temp_project_dir)
|
||||
|
||||
# 执行多次调用
|
||||
for _ in range(5):
|
||||
result = await tool.execute(keyword="def")
|
||||
assert result.success is True
|
||||
|
||||
# 验证调用计数
|
||||
assert tool._call_count == 5
|
||||
|
||||
|
|
@ -9,7 +9,7 @@ import {
|
|||
} from "@/components/ui/select";
|
||||
import { GitBranch, Zap, Info } from "lucide-react";
|
||||
import type { Project, CreateAuditTaskForm } from "@/shared/types";
|
||||
import { isRepositoryProject, isZipProject } from "@/shared/utils/projectUtils";
|
||||
import { isRepositoryProject, isZipProject, getRepositoryPlatformLabel } from "@/shared/utils/projectUtils";
|
||||
import ZipFileSection from "./ZipFileSection";
|
||||
import type { ZipFileMeta } from "@/shared/utils/zipStorage";
|
||||
|
||||
|
|
@ -138,7 +138,7 @@ function ProjectInfoCard({ project }: { project: Project }) {
|
|||
{isRepo && (
|
||||
<>
|
||||
<p>
|
||||
仓库平台:{project.repository_type?.toUpperCase() || "OTHER"}
|
||||
仓库平台:{getRepositoryPlatformLabel(project.repository_type)}
|
||||
</p>
|
||||
<p>默认分支:{project.default_branch}</p>
|
||||
</>
|
||||
|
|
|
|||
|
|
@ -34,13 +34,13 @@ import { api } from "@/shared/config/database";
|
|||
import { runRepositoryAudit, scanStoredZipFile } from "@/features/projects/services";
|
||||
import type { Project, AuditTask, CreateProjectForm } from "@/shared/types";
|
||||
import { hasZipFile } from "@/shared/utils/zipStorage";
|
||||
import { isRepositoryProject, getSourceTypeLabel } from "@/shared/utils/projectUtils";
|
||||
import { isRepositoryProject, getSourceTypeLabel, getRepositoryPlatformLabel } from "@/shared/utils/projectUtils";
|
||||
import { toast } from "sonner";
|
||||
import CreateTaskDialog from "@/components/audit/CreateTaskDialog";
|
||||
import FileSelectionDialog from "@/components/audit/FileSelectionDialog";
|
||||
import TerminalProgressDialog from "@/components/audit/TerminalProgressDialog";
|
||||
import { Dialog, DialogContent, DialogHeader, DialogTitle, DialogFooter } from "@/components/ui/dialog";
|
||||
import { SUPPORTED_LANGUAGES } from "@/shared/constants";
|
||||
import { SUPPORTED_LANGUAGES, REPOSITORY_PLATFORMS } from "@/shared/constants";
|
||||
|
||||
export default function ProjectDetail() {
|
||||
const { id } = useParams<{ id: string }>();
|
||||
|
|
@ -475,8 +475,7 @@ export default function ProjectDetail() {
|
|||
<div className="flex items-center justify-between">
|
||||
<span className="text-sm text-muted-foreground uppercase">仓库平台</span>
|
||||
<Badge className="cyber-badge-muted">
|
||||
{project.repository_type === 'github' ? 'GitHub' :
|
||||
project.repository_type === 'gitlab' ? 'GitLab' : '其他'}
|
||||
{getRepositoryPlatformLabel(project.repository_type)}
|
||||
</Badge>
|
||||
</div>
|
||||
|
||||
|
|
@ -529,12 +528,11 @@ export default function ProjectDetail() {
|
|||
className="flex items-center justify-between p-3 bg-muted/50 rounded-lg hover:bg-muted transition-all group"
|
||||
>
|
||||
<div className="flex items-center space-x-3">
|
||||
<div className={`w-8 h-8 rounded-lg flex items-center justify-center ${
|
||||
task.status === 'completed' ? 'bg-emerald-500/20' :
|
||||
<div className={`w-8 h-8 rounded-lg flex items-center justify-center ${task.status === 'completed' ? 'bg-emerald-500/20' :
|
||||
task.status === 'running' ? 'bg-sky-500/20' :
|
||||
task.status === 'failed' ? 'bg-rose-500/20' :
|
||||
'bg-muted'
|
||||
}`}>
|
||||
task.status === 'failed' ? 'bg-rose-500/20' :
|
||||
'bg-muted'
|
||||
}`}>
|
||||
{getStatusIcon(task.status)}
|
||||
</div>
|
||||
<div>
|
||||
|
|
@ -579,12 +577,11 @@ export default function ProjectDetail() {
|
|||
<div key={task.id} className="cyber-card p-6">
|
||||
<div className="flex items-center justify-between mb-4 pb-4 border-b border-border">
|
||||
<div className="flex items-center space-x-3">
|
||||
<div className={`w-10 h-10 rounded-lg flex items-center justify-center ${
|
||||
task.status === 'completed' ? 'bg-emerald-500/20' :
|
||||
<div className={`w-10 h-10 rounded-lg flex items-center justify-center ${task.status === 'completed' ? 'bg-emerald-500/20' :
|
||||
task.status === 'running' ? 'bg-sky-500/20' :
|
||||
task.status === 'failed' ? 'bg-rose-500/20' :
|
||||
'bg-muted'
|
||||
}`}>
|
||||
task.status === 'failed' ? 'bg-rose-500/20' :
|
||||
'bg-muted'
|
||||
}`}>
|
||||
{getStatusIcon(task.status)}
|
||||
</div>
|
||||
<div>
|
||||
|
|
@ -676,12 +673,11 @@ export default function ProjectDetail() {
|
|||
<div key={index} className="cyber-card p-4 hover:border-border transition-all">
|
||||
<div className="flex items-start justify-between">
|
||||
<div className="flex items-start space-x-3">
|
||||
<div className={`w-8 h-8 rounded-lg flex items-center justify-center ${
|
||||
issue.severity === 'critical' ? 'bg-rose-500/20 text-rose-600 dark:text-rose-400' :
|
||||
<div className={`w-8 h-8 rounded-lg flex items-center justify-center ${issue.severity === 'critical' ? 'bg-rose-500/20 text-rose-600 dark:text-rose-400' :
|
||||
issue.severity === 'high' ? 'bg-orange-500/20 text-orange-600 dark:text-orange-400' :
|
||||
issue.severity === 'medium' ? 'bg-amber-500/20 text-amber-600 dark:text-amber-400' :
|
||||
'bg-sky-500/20 text-sky-600 dark:text-sky-400'
|
||||
}`}>
|
||||
issue.severity === 'medium' ? 'bg-amber-500/20 text-amber-600 dark:text-amber-400' :
|
||||
'bg-sky-500/20 text-sky-600 dark:text-sky-400'
|
||||
}`}>
|
||||
<AlertTriangle className="w-4 h-4" />
|
||||
</div>
|
||||
<div>
|
||||
|
|
@ -695,13 +691,13 @@ export default function ProjectDetail() {
|
|||
<Badge className={`
|
||||
${issue.severity === 'critical' ? 'severity-critical' :
|
||||
issue.severity === 'high' ? 'severity-high' :
|
||||
issue.severity === 'medium' ? 'severity-medium' :
|
||||
'severity-low'}
|
||||
issue.severity === 'medium' ? 'severity-medium' :
|
||||
'severity-low'}
|
||||
font-bold uppercase px-2 py-1 rounded text-xs
|
||||
`}>
|
||||
{issue.severity === 'critical' ? '严重' :
|
||||
issue.severity === 'high' ? '高' :
|
||||
issue.severity === 'medium' ? '中等' : '低'}
|
||||
issue.severity === 'medium' ? '中等' : '低'}
|
||||
</Badge>
|
||||
</div>
|
||||
<p className="mt-3 text-sm text-muted-foreground font-mono border-t border-border pt-3">
|
||||
|
|
@ -783,9 +779,11 @@ export default function ProjectDetail() {
|
|||
<SelectValue />
|
||||
</SelectTrigger>
|
||||
<SelectContent className="cyber-dialog border-border">
|
||||
<SelectItem value="github">GitHub</SelectItem>
|
||||
<SelectItem value="gitlab">GitLab</SelectItem>
|
||||
<SelectItem value="other">其他</SelectItem>
|
||||
{REPOSITORY_PLATFORMS.map((platform) => (
|
||||
<SelectItem key={platform.value} value={platform.value}>
|
||||
{platform.label}
|
||||
</SelectItem>
|
||||
))}
|
||||
</SelectContent>
|
||||
</Select>
|
||||
</div>
|
||||
|
|
@ -831,14 +829,14 @@ export default function ProjectDetail() {
|
|||
className={`flex items-center space-x-2 p-3 border cursor-pointer transition-all rounded ${editForm.programming_languages?.includes(lang)
|
||||
? 'border-primary bg-primary/10 text-primary'
|
||||
: 'border-border hover:border-border text-muted-foreground'
|
||||
}`}
|
||||
}`}
|
||||
onClick={() => handleToggleLanguage(lang)}
|
||||
>
|
||||
<div
|
||||
className={`w-4 h-4 border-2 rounded-sm flex items-center justify-center ${editForm.programming_languages?.includes(lang)
|
||||
? 'bg-primary border-primary'
|
||||
: 'border-border'
|
||||
}`}
|
||||
}`}
|
||||
>
|
||||
{editForm.programming_languages?.includes(lang) && (
|
||||
<CheckCircle className="w-3 h-3 text-foreground" />
|
||||
|
|
|
|||
|
|
@ -45,7 +45,7 @@ import { Link } from "react-router-dom";
|
|||
import { toast } from "sonner";
|
||||
import CreateTaskDialog from "@/components/audit/CreateTaskDialog";
|
||||
import TerminalProgressDialog from "@/components/audit/TerminalProgressDialog";
|
||||
import { SUPPORTED_LANGUAGES } from "@/shared/constants";
|
||||
import { SUPPORTED_LANGUAGES, REPOSITORY_PLATFORMS } from "@/shared/constants";
|
||||
|
||||
export default function Projects() {
|
||||
const [projects, setProjects] = useState<Project[]>([]);
|
||||
|
|
@ -487,10 +487,11 @@ export default function Projects() {
|
|||
<SelectValue />
|
||||
</SelectTrigger>
|
||||
<SelectContent className="cyber-dialog border-border">
|
||||
<SelectItem value="github">GITHUB</SelectItem>
|
||||
<SelectItem value="gitlab">GITLAB</SelectItem>
|
||||
<SelectItem value="gitea">GITEA</SelectItem>
|
||||
<SelectItem value="other">OTHER</SelectItem>
|
||||
{REPOSITORY_PLATFORMS.map((platform) => (
|
||||
<SelectItem key={platform.value} value={platform.value}>
|
||||
{platform.label}
|
||||
</SelectItem>
|
||||
))}
|
||||
</SelectContent>
|
||||
</Select>
|
||||
</div>
|
||||
|
|
@ -1046,10 +1047,11 @@ export default function Projects() {
|
|||
<SelectValue />
|
||||
</SelectTrigger>
|
||||
<SelectContent className="cyber-dialog border-border">
|
||||
<SelectItem value="github">GITHUB</SelectItem>
|
||||
<SelectItem value="gitlab">GITLAB</SelectItem>
|
||||
<SelectItem value="gitea">GITEA</SelectItem>
|
||||
<SelectItem value="other">OTHER</SelectItem>
|
||||
{REPOSITORY_PLATFORMS.map((platform) => (
|
||||
<SelectItem key={platform.value} value={platform.value}>
|
||||
{platform.label}
|
||||
</SelectItem>
|
||||
))}
|
||||
</SelectContent>
|
||||
</Select>
|
||||
</div>
|
||||
|
|
|
|||
|
|
@ -36,7 +36,7 @@ import type { AuditTask, AuditIssue } from "@/shared/types";
|
|||
import { toast } from "sonner";
|
||||
import ExportReportDialog from "@/components/reports/ExportReportDialog";
|
||||
import { calculateTaskProgress } from "@/shared/utils/utils";
|
||||
import { isRepositoryProject, getSourceTypeLabel } from "@/shared/utils/projectUtils";
|
||||
import { isRepositoryProject, getSourceTypeLabel, getRepositoryPlatformLabel } from "@/shared/utils/projectUtils";
|
||||
|
||||
// AI explanation parser
|
||||
function parseAIExplanation(aiExplanation: string) {
|
||||
|
|
@ -86,12 +86,11 @@ function IssuesList({ issues }: { issues: AuditIssue[] }) {
|
|||
<div key={issue.id || index} className="cyber-card p-4 hover:border-border transition-all group">
|
||||
<div className="flex items-start justify-between mb-3">
|
||||
<div className="flex items-start space-x-3">
|
||||
<div className={`w-10 h-10 rounded-lg flex items-center justify-center ${
|
||||
issue.severity === 'critical' ? 'bg-rose-500/20 text-rose-400' :
|
||||
issue.severity === 'high' ? 'bg-orange-500/20 text-orange-400' :
|
||||
issue.severity === 'medium' ? 'bg-amber-500/20 text-amber-400' :
|
||||
'bg-sky-500/20 text-sky-400'
|
||||
}`}>
|
||||
<div className={`w-10 h-10 rounded-lg flex items-center justify-center ${issue.severity === 'critical' ? 'bg-rose-500/20 text-rose-400' :
|
||||
issue.severity === 'high' ? 'bg-orange-500/20 text-orange-400' :
|
||||
issue.severity === 'medium' ? 'bg-amber-500/20 text-amber-400' :
|
||||
'bg-sky-500/20 text-sky-400'
|
||||
}`}>
|
||||
{getTypeIcon(issue.issue_type)}
|
||||
</div>
|
||||
<div className="flex-1">
|
||||
|
|
@ -112,7 +111,7 @@ function IssuesList({ issues }: { issues: AuditIssue[] }) {
|
|||
<Badge className={`${getSeverityClasses(issue.severity)} font-bold uppercase px-2 py-1 rounded text-xs`}>
|
||||
{issue.severity === 'critical' ? '严重' :
|
||||
issue.severity === 'high' ? '高' :
|
||||
issue.severity === 'medium' ? '中等' : '低'}
|
||||
issue.severity === 'medium' ? '中等' : '低'}
|
||||
</Badge>
|
||||
</div>
|
||||
|
||||
|
|
@ -702,7 +701,7 @@ export default function TaskDetail() {
|
|||
{isRepositoryProject(task.project) && (
|
||||
<div>
|
||||
<p className="text-xs font-bold text-muted-foreground uppercase mb-1">仓库平台</p>
|
||||
<p className="text-base font-bold text-foreground">{task.project.repository_type?.toUpperCase() || 'OTHER'}</p>
|
||||
<p className="text-base font-bold text-foreground">{getRepositoryPlatformLabel(task.project.repository_type)}</p>
|
||||
</div>
|
||||
)}
|
||||
{task.project.programming_languages && (
|
||||
|
|
|
|||
|
|
@ -62,13 +62,6 @@ export const PROJECT_SOURCE_TYPES = {
|
|||
ZIP: 'zip',
|
||||
} as const;
|
||||
|
||||
// 仓库平台类型
|
||||
export const REPOSITORY_TYPES = {
|
||||
GITHUB: 'github',
|
||||
GITLAB: 'gitlab',
|
||||
OTHER: 'other',
|
||||
} as const;
|
||||
|
||||
// 分析深度
|
||||
export const ANALYSIS_DEPTH = {
|
||||
BASIC: 'basic',
|
||||
|
|
|
|||
|
|
@ -22,17 +22,23 @@ export const PROJECT_SOURCE_TYPES: Array<{
|
|||
}
|
||||
];
|
||||
|
||||
// 仓库平台显示名称
|
||||
export const REPOSITORY_PLATFORM_LABELS: Record<RepositoryPlatform, string> = {
|
||||
github: 'GitHub',
|
||||
gitlab: 'GitLab',
|
||||
gitea: 'Gitea',
|
||||
other: '其他',
|
||||
};
|
||||
|
||||
// 仓库平台选项
|
||||
export const REPOSITORY_PLATFORMS: Array<{
|
||||
value: RepositoryPlatform;
|
||||
label: string;
|
||||
icon?: string;
|
||||
}> = [
|
||||
{ value: 'github', label: 'GitHub' },
|
||||
{ value: 'gitlab', label: 'GitLab' },
|
||||
{ value: 'gitea', label: 'Gitea' },
|
||||
{ value: 'other', label: '其他' }
|
||||
];
|
||||
}> = Object.entries(REPOSITORY_PLATFORM_LABELS).map(([value, label]) => ({
|
||||
value: value as RepositoryPlatform,
|
||||
label
|
||||
}));
|
||||
|
||||
// 项目来源类型的颜色配置
|
||||
export const SOURCE_TYPE_COLORS: Record<ProjectSourceType, {
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@
|
|||
*/
|
||||
|
||||
import type { Project, ProjectSourceType } from '@/shared/types';
|
||||
import { REPOSITORY_PLATFORM_LABELS } from '@/shared/constants/projectTypes';
|
||||
|
||||
/**
|
||||
* 判断项目是否为仓库类型
|
||||
|
|
@ -45,13 +46,7 @@ export function getSourceTypeBadge(sourceType: ProjectSourceType): string {
|
|||
* 获取仓库平台的显示名称
|
||||
*/
|
||||
export function getRepositoryPlatformLabel(platform?: string): string {
|
||||
const labels: Record<string, string> = {
|
||||
github: 'GitHub',
|
||||
gitlab: 'GitLab',
|
||||
gitea: 'Gitea',
|
||||
other: '其他'
|
||||
};
|
||||
return labels[platform || 'other'] || '其他';
|
||||
return REPOSITORY_PLATFORM_LABELS[platform as keyof typeof REPOSITORY_PLATFORM_LABELS] || REPOSITORY_PLATFORM_LABELS.other;
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
|||
Loading…
Reference in New Issue