Merge branch 'v3.0.0' of github.com:lintsinghua/DeepAudit into feature/git_ssh
# Conflicts: # backend/app/api/v1/endpoints/agent_tasks.py
This commit is contained in:
commit
869513e0c5
|
|
@ -428,7 +428,7 @@ DeepSeek-Coder · Codestral<br/>
|
||||||
<div align="center">
|
<div align="center">
|
||||||
|
|
||||||
**欢迎大家来和我交流探讨!无论是技术问题、功能建议还是合作意向,都期待与你沟通~**
|
**欢迎大家来和我交流探讨!无论是技术问题、功能建议还是合作意向,都期待与你沟通~**
|
||||||
|
(项目开发、投资孵化等合作洽谈请通过邮箱联系)
|
||||||
| 联系方式 | |
|
| 联系方式 | |
|
||||||
|:---:|:---:|
|
|:---:|:---:|
|
||||||
| 📧 **邮箱** | **lintsinghua@qq.com** |
|
| 📧 **邮箱** | **lintsinghua@qq.com** |
|
||||||
|
|
|
||||||
|
|
@ -294,6 +294,7 @@ async def _execute_agent_task(task_id: str):
|
||||||
other_config = (user_config or {}).get('otherConfig', {})
|
other_config = (user_config or {}).get('otherConfig', {})
|
||||||
github_token = other_config.get('githubToken') or settings.GITHUB_TOKEN
|
github_token = other_config.get('githubToken') or settings.GITHUB_TOKEN
|
||||||
gitlab_token = other_config.get('gitlabToken') or settings.GITLAB_TOKEN
|
gitlab_token = other_config.get('gitlabToken') or settings.GITLAB_TOKEN
|
||||||
|
gitea_token = other_config.get('giteaToken') or settings.GITEA_TOKEN
|
||||||
|
|
||||||
# 解密SSH私钥
|
# 解密SSH私钥
|
||||||
ssh_private_key = None
|
ssh_private_key = None
|
||||||
|
|
@ -313,6 +314,7 @@ async def _execute_agent_task(task_id: str):
|
||||||
task.branch_name,
|
task.branch_name,
|
||||||
github_token=github_token,
|
github_token=github_token,
|
||||||
gitlab_token=gitlab_token,
|
gitlab_token=gitlab_token,
|
||||||
|
gitea_token=gitea_token, # 🔥 新增
|
||||||
ssh_private_key=ssh_private_key, # 🔥 新增SSH密钥
|
ssh_private_key=ssh_private_key, # 🔥 新增SSH密钥
|
||||||
event_emitter=event_emitter, # 🔥 新增
|
event_emitter=event_emitter, # 🔥 新增
|
||||||
)
|
)
|
||||||
|
|
@ -2226,6 +2228,7 @@ async def _get_project_root(
|
||||||
branch_name: Optional[str] = None,
|
branch_name: Optional[str] = None,
|
||||||
github_token: Optional[str] = None,
|
github_token: Optional[str] = None,
|
||||||
gitlab_token: Optional[str] = None,
|
gitlab_token: Optional[str] = None,
|
||||||
|
gitea_token: Optional[str] = None, # 🔥 新增
|
||||||
ssh_private_key: Optional[str] = None, # 🔥 新增:SSH私钥(用于SSH认证)
|
ssh_private_key: Optional[str] = None, # 🔥 新增:SSH私钥(用于SSH认证)
|
||||||
event_emitter: Optional[Any] = None, # 🔥 新增:用于发送实时日志
|
event_emitter: Optional[Any] = None, # 🔥 新增:用于发送实时日志
|
||||||
) -> str:
|
) -> str:
|
||||||
|
|
@ -2242,6 +2245,7 @@ async def _get_project_root(
|
||||||
branch_name: 分支名称(仓库项目使用,优先于 project.default_branch)
|
branch_name: 分支名称(仓库项目使用,优先于 project.default_branch)
|
||||||
github_token: GitHub 访问令牌(用于私有仓库)
|
github_token: GitHub 访问令牌(用于私有仓库)
|
||||||
gitlab_token: GitLab 访问令牌(用于私有仓库)
|
gitlab_token: GitLab 访问令牌(用于私有仓库)
|
||||||
|
gitea_token: Gitea 访问令牌(用于私有仓库)
|
||||||
ssh_private_key: SSH私钥(用于SSH认证)
|
ssh_private_key: SSH私钥(用于SSH认证)
|
||||||
event_emitter: 事件发送器(用于发送实时日志)
|
event_emitter: 事件发送器(用于发送实时日志)
|
||||||
|
|
||||||
|
|
@ -2503,6 +2507,16 @@ async def _get_project_root(
|
||||||
parsed.fragment
|
parsed.fragment
|
||||||
))
|
))
|
||||||
await emit(f"🔐 使用 GitLab Token 认证")
|
await emit(f"🔐 使用 GitLab Token 认证")
|
||||||
|
elif repo_type == "gitea" and gitea_token:
|
||||||
|
auth_url = urlunparse((
|
||||||
|
parsed.scheme,
|
||||||
|
f"{gitea_token}@{parsed.netloc}",
|
||||||
|
parsed.path,
|
||||||
|
parsed.params,
|
||||||
|
parsed.query,
|
||||||
|
parsed.fragment
|
||||||
|
))
|
||||||
|
await emit(f"🔐 使用 Gitea Token 认证")
|
||||||
elif is_ssh_url and ssh_private_key:
|
elif is_ssh_url and ssh_private_key:
|
||||||
await emit(f"🔐 使用 SSH Key 认证")
|
await emit(f"🔐 使用 SSH Key 认证")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,7 @@ from app.db.session import get_db, AsyncSessionLocal
|
||||||
from app.models.project import Project
|
from app.models.project import Project
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.models.audit import AuditTask, AuditIssue
|
from app.models.audit import AuditTask, AuditIssue
|
||||||
|
from app.models.agent_task import AgentTask, AgentTaskStatus, AgentFinding
|
||||||
from app.models.user_config import UserConfig
|
from app.models.user_config import UserConfig
|
||||||
import zipfile
|
import zipfile
|
||||||
from app.services.scanner import scan_repo_task, get_github_files, get_gitlab_files, get_github_branches, get_gitlab_branches, get_gitea_branches, should_exclude, is_text_file
|
from app.services.scanner import scan_repo_task, get_github_files, get_gitlab_files, get_github_branches, get_gitlab_branches, get_gitea_branches, should_exclude, is_text_file
|
||||||
|
|
@ -162,26 +163,51 @@ async def get_stats(
|
||||||
projects = projects_result.scalars().all()
|
projects = projects_result.scalars().all()
|
||||||
project_ids = [p.id for p in projects]
|
project_ids = [p.id for p in projects]
|
||||||
|
|
||||||
# 只统计当前用户项目的任务
|
# 统计旧的 AuditTask
|
||||||
tasks_result = await db.execute(
|
tasks_result = await db.execute(
|
||||||
select(AuditTask).where(AuditTask.project_id.in_(project_ids)) if project_ids else select(AuditTask).where(False)
|
select(AuditTask).where(AuditTask.project_id.in_(project_ids)) if project_ids else select(AuditTask).where(False)
|
||||||
)
|
)
|
||||||
tasks = tasks_result.scalars().all()
|
tasks = tasks_result.scalars().all()
|
||||||
task_ids = [t.id for t in tasks]
|
task_ids = [t.id for t in tasks]
|
||||||
|
|
||||||
# 只统计当前用户任务的问题
|
# 统计旧的 AuditIssue
|
||||||
issues_result = await db.execute(
|
issues_result = await db.execute(
|
||||||
select(AuditIssue).where(AuditIssue.task_id.in_(task_ids)) if task_ids else select(AuditIssue).where(False)
|
select(AuditIssue).where(AuditIssue.task_id.in_(task_ids)) if task_ids else select(AuditIssue).where(False)
|
||||||
)
|
)
|
||||||
issues = issues_result.scalars().all()
|
issues = issues_result.scalars().all()
|
||||||
|
|
||||||
|
# 🔥 同时统计新的 AgentTask
|
||||||
|
agent_tasks_result = await db.execute(
|
||||||
|
select(AgentTask).where(AgentTask.project_id.in_(project_ids)) if project_ids else select(AgentTask).where(False)
|
||||||
|
)
|
||||||
|
agent_tasks = agent_tasks_result.scalars().all()
|
||||||
|
agent_task_ids = [t.id for t in agent_tasks]
|
||||||
|
|
||||||
|
# 🔥 统计 AgentFinding
|
||||||
|
agent_findings_result = await db.execute(
|
||||||
|
select(AgentFinding).where(AgentFinding.task_id.in_(agent_task_ids)) if agent_task_ids else select(AgentFinding).where(False)
|
||||||
|
)
|
||||||
|
agent_findings = agent_findings_result.scalars().all()
|
||||||
|
|
||||||
|
# 合并统计(旧任务 + 新 Agent 任务)
|
||||||
|
total_tasks = len(tasks) + len(agent_tasks)
|
||||||
|
completed_tasks = (
|
||||||
|
len([t for t in tasks if t.status == "completed"]) +
|
||||||
|
len([t for t in agent_tasks if t.status == AgentTaskStatus.COMPLETED])
|
||||||
|
)
|
||||||
|
total_issues = len(issues) + len(agent_findings)
|
||||||
|
resolved_issues = (
|
||||||
|
len([i for i in issues if i.status == "resolved"]) +
|
||||||
|
len([f for f in agent_findings if f.status == "resolved"])
|
||||||
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"total_projects": len(projects),
|
"total_projects": len(projects),
|
||||||
"active_projects": len([p for p in projects if p.is_active]),
|
"active_projects": len([p for p in projects if p.is_active]),
|
||||||
"total_tasks": len(tasks),
|
"total_tasks": total_tasks,
|
||||||
"completed_tasks": len([t for t in tasks if t.status == "completed"]),
|
"completed_tasks": completed_tasks,
|
||||||
"total_issues": len(issues),
|
"total_issues": total_issues,
|
||||||
"resolved_issues": len([i for i in issues if i.status == "resolved"]),
|
"resolved_issues": resolved_issues,
|
||||||
}
|
}
|
||||||
|
|
||||||
@router.get("/{id}", response_model=ProjectResponse)
|
@router.get("/{id}", response_model=ProjectResponse)
|
||||||
|
|
|
||||||
|
|
@ -1,29 +1,19 @@
|
||||||
"""
|
"""
|
||||||
DeepAudit Agent 服务模块
|
DeepAudit Agent 服务模块
|
||||||
基于 LangGraph 的 AI Agent 代码安全审计
|
基于动态 Agent 树架构的 AI 代码安全审计
|
||||||
|
|
||||||
架构升级版本 - 支持:
|
架构:
|
||||||
- 动态Agent树结构
|
- OrchestratorAgent 作为编排层,动态调度子 Agent
|
||||||
- 专业知识模块系统
|
- ReconAgent 负责侦察和文件分析
|
||||||
- Agent间通信机制
|
- AnalysisAgent 负责漏洞分析
|
||||||
- 完整状态管理
|
- VerificationAgent 负责验证发现
|
||||||
- Think工具和漏洞报告工具
|
|
||||||
|
|
||||||
工作流:
|
工作流:
|
||||||
START → Recon → Analysis ⟲ → Verification → Report → END
|
START → Orchestrator → [Recon/Analysis/Verification] → Report → END
|
||||||
|
|
||||||
支持动态创建子Agent进行专业化分析
|
支持动态创建子Agent进行专业化分析
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# 从 graph 模块导入主要组件
|
|
||||||
from .graph import (
|
|
||||||
AgentRunner,
|
|
||||||
run_agent_task,
|
|
||||||
LLMService,
|
|
||||||
AuditState,
|
|
||||||
create_audit_graph,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 事件管理
|
# 事件管理
|
||||||
from .event_manager import EventManager, AgentEventEmitter
|
from .event_manager import EventManager, AgentEventEmitter
|
||||||
|
|
||||||
|
|
@ -33,14 +23,14 @@ from .agents import (
|
||||||
OrchestratorAgent, ReconAgent, AnalysisAgent, VerificationAgent,
|
OrchestratorAgent, ReconAgent, AnalysisAgent, VerificationAgent,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 🔥 新增:核心模块(状态管理、注册表、消息)
|
# 核心模块(状态管理、注册表、消息)
|
||||||
from .core import (
|
from .core import (
|
||||||
AgentState, AgentStatus,
|
AgentState, AgentStatus,
|
||||||
AgentRegistry, agent_registry,
|
AgentRegistry, agent_registry,
|
||||||
AgentMessage, MessageType, MessagePriority, MessageBus,
|
AgentMessage, MessageType, MessagePriority, MessageBus,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 🔥 新增:知识模块系统(基于RAG)
|
# 知识模块系统(基于RAG)
|
||||||
from .knowledge import (
|
from .knowledge import (
|
||||||
KnowledgeLoader, knowledge_loader,
|
KnowledgeLoader, knowledge_loader,
|
||||||
get_available_modules, get_module_content,
|
get_available_modules, get_module_content,
|
||||||
|
|
@ -48,7 +38,7 @@ from .knowledge import (
|
||||||
SecurityKnowledgeQueryTool, GetVulnerabilityKnowledgeTool,
|
SecurityKnowledgeQueryTool, GetVulnerabilityKnowledgeTool,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 🔥 新增:协作工具
|
# 协作工具
|
||||||
from .tools import (
|
from .tools import (
|
||||||
ThinkTool, ReflectTool,
|
ThinkTool, ReflectTool,
|
||||||
CreateVulnerabilityReportTool,
|
CreateVulnerabilityReportTool,
|
||||||
|
|
@ -57,20 +47,11 @@ from .tools import (
|
||||||
WaitForMessageTool, AgentFinishTool,
|
WaitForMessageTool, AgentFinishTool,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 🔥 新增:遥测模块
|
# 遥测模块
|
||||||
from .telemetry import Tracer, get_global_tracer, set_global_tracer
|
from .telemetry import Tracer, get_global_tracer, set_global_tracer
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# 核心 Runner
|
|
||||||
"AgentRunner",
|
|
||||||
"run_agent_task",
|
|
||||||
"LLMService",
|
|
||||||
|
|
||||||
# LangGraph
|
|
||||||
"AuditState",
|
|
||||||
"create_audit_graph",
|
|
||||||
|
|
||||||
# 事件管理
|
# 事件管理
|
||||||
"EventManager",
|
"EventManager",
|
||||||
"AgentEventEmitter",
|
"AgentEventEmitter",
|
||||||
|
|
@ -84,7 +65,7 @@ __all__ = [
|
||||||
"AnalysisAgent",
|
"AnalysisAgent",
|
||||||
"VerificationAgent",
|
"VerificationAgent",
|
||||||
|
|
||||||
# 🔥 核心模块
|
# 核心模块
|
||||||
"AgentState",
|
"AgentState",
|
||||||
"AgentStatus",
|
"AgentStatus",
|
||||||
"AgentRegistry",
|
"AgentRegistry",
|
||||||
|
|
@ -94,7 +75,7 @@ __all__ = [
|
||||||
"MessagePriority",
|
"MessagePriority",
|
||||||
"MessageBus",
|
"MessageBus",
|
||||||
|
|
||||||
# 🔥 知识模块(基于RAG)
|
# 知识模块(基于RAG)
|
||||||
"KnowledgeLoader",
|
"KnowledgeLoader",
|
||||||
"knowledge_loader",
|
"knowledge_loader",
|
||||||
"get_available_modules",
|
"get_available_modules",
|
||||||
|
|
@ -104,7 +85,7 @@ __all__ = [
|
||||||
"SecurityKnowledgeQueryTool",
|
"SecurityKnowledgeQueryTool",
|
||||||
"GetVulnerabilityKnowledgeTool",
|
"GetVulnerabilityKnowledgeTool",
|
||||||
|
|
||||||
# 🔥 协作工具
|
# 协作工具
|
||||||
"ThinkTool",
|
"ThinkTool",
|
||||||
"ReflectTool",
|
"ReflectTool",
|
||||||
"CreateVulnerabilityReportTool",
|
"CreateVulnerabilityReportTool",
|
||||||
|
|
@ -115,9 +96,8 @@ __all__ = [
|
||||||
"WaitForMessageTool",
|
"WaitForMessageTool",
|
||||||
"AgentFinishTool",
|
"AgentFinishTool",
|
||||||
|
|
||||||
# 🔥 遥测模块
|
# 遥测模块
|
||||||
"Tracer",
|
"Tracer",
|
||||||
"get_global_tracer",
|
"get_global_tracer",
|
||||||
"set_global_tracer",
|
"set_global_tracer",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1024,10 +1024,18 @@ class BaseAgent(ABC):
|
||||||
elif chunk["type"] == "error":
|
elif chunk["type"] == "error":
|
||||||
accumulated = chunk.get("accumulated", "")
|
accumulated = chunk.get("accumulated", "")
|
||||||
error_msg = chunk.get("error", "Unknown error")
|
error_msg = chunk.get("error", "Unknown error")
|
||||||
logger.error(f"[{self.name}] Stream error: {error_msg}")
|
error_type = chunk.get("error_type", "unknown")
|
||||||
if accumulated:
|
user_message = chunk.get("user_message", error_msg)
|
||||||
total_tokens = chunk.get("usage", {}).get("total_tokens", 0)
|
logger.error(f"[{self.name}] Stream error ({error_type}): {error_msg}")
|
||||||
else:
|
|
||||||
|
if chunk.get("usage"):
|
||||||
|
total_tokens = chunk["usage"].get("total_tokens", 0)
|
||||||
|
|
||||||
|
# 使用特殊前缀标记 API 错误,让调用方能够识别
|
||||||
|
# 格式:[API_ERROR:error_type] user_message
|
||||||
|
if error_type in ("rate_limit", "quota_exceeded", "authentication", "connection"):
|
||||||
|
accumulated = f"[API_ERROR:{error_type}] {user_message}"
|
||||||
|
elif not accumulated:
|
||||||
accumulated = f"[系统错误: {error_msg}] 请重新思考并输出你的决策。"
|
accumulated = f"[系统错误: {error_msg}] 请重新思考并输出你的决策。"
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -285,6 +285,55 @@ Action Input: {{"参数": "值"}}
|
||||||
# 重置空响应计数器
|
# 重置空响应计数器
|
||||||
self._empty_retry_count = 0
|
self._empty_retry_count = 0
|
||||||
|
|
||||||
|
# 🔥 检查是否是 API 错误(而非格式错误)
|
||||||
|
if llm_output.startswith("[API_ERROR:"):
|
||||||
|
# 提取错误类型和消息
|
||||||
|
match = re.match(r"\[API_ERROR:(\w+)\]\s*(.*)", llm_output)
|
||||||
|
if match:
|
||||||
|
error_type = match.group(1)
|
||||||
|
error_message = match.group(2)
|
||||||
|
|
||||||
|
if error_type == "rate_limit":
|
||||||
|
# 速率限制 - 等待后重试
|
||||||
|
api_retry_count = getattr(self, '_api_retry_count', 0) + 1
|
||||||
|
self._api_retry_count = api_retry_count
|
||||||
|
if api_retry_count >= 3:
|
||||||
|
logger.error(f"[{self.name}] Too many rate limit errors, stopping")
|
||||||
|
await self.emit_event("error", f"API 速率限制重试次数过多: {error_message}")
|
||||||
|
break
|
||||||
|
logger.warning(f"[{self.name}] Rate limit hit, waiting before retry ({api_retry_count}/3)")
|
||||||
|
await self.emit_event("warning", f"API 速率限制,等待后重试 ({api_retry_count}/3)")
|
||||||
|
await asyncio.sleep(30) # 等待 30 秒后重试
|
||||||
|
continue
|
||||||
|
|
||||||
|
elif error_type == "quota_exceeded":
|
||||||
|
# 配额用尽 - 终止任务
|
||||||
|
logger.error(f"[{self.name}] API quota exceeded: {error_message}")
|
||||||
|
await self.emit_event("error", f"API 配额已用尽: {error_message}")
|
||||||
|
break
|
||||||
|
|
||||||
|
elif error_type == "authentication":
|
||||||
|
# 认证错误 - 终止任务
|
||||||
|
logger.error(f"[{self.name}] API authentication error: {error_message}")
|
||||||
|
await self.emit_event("error", f"API 认证失败: {error_message}")
|
||||||
|
break
|
||||||
|
|
||||||
|
elif error_type == "connection":
|
||||||
|
# 连接错误 - 重试
|
||||||
|
api_retry_count = getattr(self, '_api_retry_count', 0) + 1
|
||||||
|
self._api_retry_count = api_retry_count
|
||||||
|
if api_retry_count >= 3:
|
||||||
|
logger.error(f"[{self.name}] Too many connection errors, stopping")
|
||||||
|
await self.emit_event("error", f"API 连接错误重试次数过多: {error_message}")
|
||||||
|
break
|
||||||
|
logger.warning(f"[{self.name}] Connection error, retrying ({api_retry_count}/3)")
|
||||||
|
await self.emit_event("warning", f"API 连接错误,重试中 ({api_retry_count}/3)")
|
||||||
|
await asyncio.sleep(5) # 等待 5 秒后重试
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 重置 API 重试计数器(成功获取响应后)
|
||||||
|
self._api_retry_count = 0
|
||||||
|
|
||||||
# 解析 LLM 的决策
|
# 解析 LLM 的决策
|
||||||
step = self._parse_llm_response(llm_output)
|
step = self._parse_llm_response(llm_output)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,28 +0,0 @@
|
||||||
"""
|
|
||||||
LangGraph 工作流模块
|
|
||||||
使用状态图构建混合 Agent 审计流程
|
|
||||||
"""
|
|
||||||
|
|
||||||
from .audit_graph import AuditState, create_audit_graph, create_audit_graph_with_human
|
|
||||||
from .nodes import ReconNode, AnalysisNode, VerificationNode, ReportNode, HumanReviewNode
|
|
||||||
from .runner import AgentRunner, run_agent_task, LLMService
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
# 状态和图
|
|
||||||
"AuditState",
|
|
||||||
"create_audit_graph",
|
|
||||||
"create_audit_graph_with_human",
|
|
||||||
|
|
||||||
# 节点
|
|
||||||
"ReconNode",
|
|
||||||
"AnalysisNode",
|
|
||||||
"VerificationNode",
|
|
||||||
"ReportNode",
|
|
||||||
"HumanReviewNode",
|
|
||||||
|
|
||||||
# Runner
|
|
||||||
"AgentRunner",
|
|
||||||
"run_agent_task",
|
|
||||||
"LLMService",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
@ -1,677 +0,0 @@
|
||||||
"""
|
|
||||||
DeepAudit 审计工作流图 - LLM 驱动版
|
|
||||||
使用 LangGraph 构建 LLM 驱动的 Agent 协作流程
|
|
||||||
|
|
||||||
重要改变:路由决策由 LLM 参与,而不是硬编码条件!
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import TypedDict, Annotated, List, Dict, Any, Optional, Literal
|
|
||||||
from datetime import datetime
|
|
||||||
import operator
|
|
||||||
import logging
|
|
||||||
import json
|
|
||||||
|
|
||||||
from langgraph.graph import StateGraph, END
|
|
||||||
from langgraph.checkpoint.memory import MemorySaver
|
|
||||||
from langgraph.prebuilt import ToolNode
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
# ============ 状态定义 ============
|
|
||||||
|
|
||||||
class Finding(TypedDict):
|
|
||||||
"""漏洞发现"""
|
|
||||||
id: str
|
|
||||||
vulnerability_type: str
|
|
||||||
severity: str
|
|
||||||
title: str
|
|
||||||
description: str
|
|
||||||
file_path: Optional[str]
|
|
||||||
line_start: Optional[int]
|
|
||||||
code_snippet: Optional[str]
|
|
||||||
is_verified: bool
|
|
||||||
confidence: float
|
|
||||||
source: str
|
|
||||||
|
|
||||||
|
|
||||||
class AuditState(TypedDict):
|
|
||||||
"""
|
|
||||||
审计状态
|
|
||||||
在整个工作流中传递和更新
|
|
||||||
"""
|
|
||||||
# 输入
|
|
||||||
project_root: str
|
|
||||||
project_info: Dict[str, Any]
|
|
||||||
config: Dict[str, Any]
|
|
||||||
task_id: str
|
|
||||||
|
|
||||||
# Recon 阶段输出
|
|
||||||
tech_stack: Dict[str, Any]
|
|
||||||
entry_points: List[Dict[str, Any]]
|
|
||||||
high_risk_areas: List[str]
|
|
||||||
dependencies: Dict[str, Any]
|
|
||||||
|
|
||||||
# Analysis 阶段输出
|
|
||||||
findings: Annotated[List[Finding], operator.add] # 使用 add 合并多轮发现
|
|
||||||
|
|
||||||
# Verification 阶段输出
|
|
||||||
verified_findings: List[Finding]
|
|
||||||
false_positives: List[str]
|
|
||||||
# 🔥 NEW: 验证后的完整 findings(用于替换原始 findings)
|
|
||||||
_verified_findings_update: Optional[List[Finding]]
|
|
||||||
|
|
||||||
# 控制流 - 🔥 关键:LLM 可以设置这些来影响路由
|
|
||||||
current_phase: str
|
|
||||||
iteration: int
|
|
||||||
max_iterations: int
|
|
||||||
should_continue_analysis: bool
|
|
||||||
|
|
||||||
# 🔥 新增:LLM 的路由决策
|
|
||||||
llm_next_action: Optional[str] # LLM 建议的下一步: "continue_analysis", "verify", "report", "end"
|
|
||||||
llm_routing_reason: Optional[str] # LLM 的决策理由
|
|
||||||
|
|
||||||
# 🔥 新增:Agent 间协作的任务交接信息
|
|
||||||
recon_handoff: Optional[Dict[str, Any]] # Recon -> Analysis 的交接
|
|
||||||
analysis_handoff: Optional[Dict[str, Any]] # Analysis -> Verification 的交接
|
|
||||||
verification_handoff: Optional[Dict[str, Any]] # Verification -> Report 的交接
|
|
||||||
|
|
||||||
# 消息和事件
|
|
||||||
messages: Annotated[List[Dict], operator.add]
|
|
||||||
events: Annotated[List[Dict], operator.add]
|
|
||||||
|
|
||||||
# 最终输出
|
|
||||||
summary: Optional[Dict[str, Any]]
|
|
||||||
security_score: Optional[int]
|
|
||||||
error: Optional[str]
|
|
||||||
|
|
||||||
|
|
||||||
# ============ LLM 路由决策器 ============
|
|
||||||
|
|
||||||
class LLMRouter:
|
|
||||||
"""
|
|
||||||
LLM 路由决策器
|
|
||||||
让 LLM 来决定下一步应该做什么
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, llm_service):
|
|
||||||
self.llm_service = llm_service
|
|
||||||
|
|
||||||
async def decide_after_recon(self, state: AuditState) -> Dict[str, Any]:
|
|
||||||
"""Recon 后让 LLM 决定下一步"""
|
|
||||||
entry_points = state.get("entry_points", [])
|
|
||||||
high_risk_areas = state.get("high_risk_areas", [])
|
|
||||||
tech_stack = state.get("tech_stack", {})
|
|
||||||
initial_findings = state.get("findings", [])
|
|
||||||
|
|
||||||
prompt = f"""作为安全审计的决策者,基于以下信息收集结果,决定下一步行动。
|
|
||||||
|
|
||||||
## 信息收集结果
|
|
||||||
- 入口点数量: {len(entry_points)}
|
|
||||||
- 高风险区域: {high_risk_areas[:10]}
|
|
||||||
- 技术栈: {tech_stack}
|
|
||||||
- 初步发现: {len(initial_findings)} 个
|
|
||||||
|
|
||||||
## 选项
|
|
||||||
1. "analysis" - 继续进行漏洞分析(推荐:有入口点或高风险区域时)
|
|
||||||
2. "end" - 结束审计(仅当没有任何可分析内容时)
|
|
||||||
|
|
||||||
请返回 JSON 格式:
|
|
||||||
{{"action": "analysis或end", "reason": "决策理由"}}"""
|
|
||||||
|
|
||||||
try:
|
|
||||||
response = await self.llm_service.chat_completion_raw(
|
|
||||||
messages=[
|
|
||||||
{"role": "system", "content": "你是安全审计流程的决策者,负责决定下一步行动。"},
|
|
||||||
{"role": "user", "content": prompt},
|
|
||||||
],
|
|
||||||
# 🔥 不传递 temperature 和 max_tokens,使用用户配置
|
|
||||||
)
|
|
||||||
|
|
||||||
content = response.get("content", "")
|
|
||||||
# 提取 JSON
|
|
||||||
import re
|
|
||||||
json_match = re.search(r'\{.*\}', content, re.DOTALL)
|
|
||||||
if json_match:
|
|
||||||
result = json.loads(json_match.group())
|
|
||||||
return result
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"LLM routing decision failed: {e}")
|
|
||||||
|
|
||||||
# 默认决策
|
|
||||||
if entry_points or high_risk_areas:
|
|
||||||
return {"action": "analysis", "reason": "有可分析内容"}
|
|
||||||
return {"action": "end", "reason": "没有发现入口点或高风险区域"}
|
|
||||||
|
|
||||||
async def decide_after_analysis(self, state: AuditState) -> Dict[str, Any]:
|
|
||||||
"""Analysis 后让 LLM 决定下一步"""
|
|
||||||
findings = state.get("findings", [])
|
|
||||||
iteration = state.get("iteration", 0)
|
|
||||||
max_iterations = state.get("max_iterations", 3)
|
|
||||||
|
|
||||||
# 统计发现
|
|
||||||
severity_counts = {"critical": 0, "high": 0, "medium": 0, "low": 0}
|
|
||||||
for f in findings:
|
|
||||||
# 跳过非字典类型的 finding
|
|
||||||
if not isinstance(f, dict):
|
|
||||||
continue
|
|
||||||
sev = f.get("severity", "medium")
|
|
||||||
severity_counts[sev] = severity_counts.get(sev, 0) + 1
|
|
||||||
|
|
||||||
prompt = f"""作为安全审计的决策者,基于以下分析结果,决定下一步行动。
|
|
||||||
|
|
||||||
## 分析结果
|
|
||||||
- 总发现数: {len(findings)}
|
|
||||||
- 严重程度分布: {severity_counts}
|
|
||||||
- 当前迭代: {iteration}/{max_iterations}
|
|
||||||
|
|
||||||
## 选项
|
|
||||||
1. "verification" - 验证发现的漏洞(推荐:有发现需要验证时)
|
|
||||||
2. "analysis" - 继续深入分析(推荐:发现较少但还有迭代次数时)
|
|
||||||
3. "report" - 生成报告(推荐:没有发现或已充分分析时)
|
|
||||||
|
|
||||||
请返回 JSON 格式:
|
|
||||||
{{"action": "verification/analysis/report", "reason": "决策理由"}}"""
|
|
||||||
|
|
||||||
try:
|
|
||||||
response = await self.llm_service.chat_completion_raw(
|
|
||||||
messages=[
|
|
||||||
{"role": "system", "content": "你是安全审计流程的决策者,负责决定下一步行动。"},
|
|
||||||
{"role": "user", "content": prompt},
|
|
||||||
],
|
|
||||||
# 🔥 不传递 temperature 和 max_tokens,使用用户配置
|
|
||||||
)
|
|
||||||
|
|
||||||
content = response.get("content", "")
|
|
||||||
import re
|
|
||||||
json_match = re.search(r'\{.*\}', content, re.DOTALL)
|
|
||||||
if json_match:
|
|
||||||
result = json.loads(json_match.group())
|
|
||||||
return result
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"LLM routing decision failed: {e}")
|
|
||||||
|
|
||||||
# 默认决策
|
|
||||||
if not findings:
|
|
||||||
return {"action": "report", "reason": "没有发现漏洞"}
|
|
||||||
if len(findings) >= 3 or iteration >= max_iterations:
|
|
||||||
return {"action": "verification", "reason": "有足够的发现需要验证"}
|
|
||||||
return {"action": "analysis", "reason": "发现较少,继续分析"}
|
|
||||||
|
|
||||||
async def decide_after_verification(self, state: AuditState) -> Dict[str, Any]:
|
|
||||||
"""Verification 后让 LLM 决定下一步"""
|
|
||||||
verified_findings = state.get("verified_findings", [])
|
|
||||||
false_positives = state.get("false_positives", [])
|
|
||||||
iteration = state.get("iteration", 0)
|
|
||||||
max_iterations = state.get("max_iterations", 3)
|
|
||||||
|
|
||||||
prompt = f"""作为安全审计的决策者,基于以下验证结果,决定下一步行动。
|
|
||||||
|
|
||||||
## 验证结果
|
|
||||||
- 已确认漏洞: {len(verified_findings)}
|
|
||||||
- 误报数量: {len(false_positives)}
|
|
||||||
- 当前迭代: {iteration}/{max_iterations}
|
|
||||||
|
|
||||||
## 选项
|
|
||||||
1. "analysis" - 回到分析阶段重新分析(推荐:误报率太高时)
|
|
||||||
2. "report" - 生成最终报告(推荐:验证完成时)
|
|
||||||
|
|
||||||
请返回 JSON 格式:
|
|
||||||
{{"action": "analysis/report", "reason": "决策理由"}}"""
|
|
||||||
|
|
||||||
try:
|
|
||||||
response = await self.llm_service.chat_completion_raw(
|
|
||||||
messages=[
|
|
||||||
{"role": "system", "content": "你是安全审计流程的决策者,负责决定下一步行动。"},
|
|
||||||
{"role": "user", "content": prompt},
|
|
||||||
],
|
|
||||||
# 🔥 不传递 temperature 和 max_tokens,使用用户配置
|
|
||||||
)
|
|
||||||
|
|
||||||
content = response.get("content", "")
|
|
||||||
import re
|
|
||||||
json_match = re.search(r'\{.*\}', content, re.DOTALL)
|
|
||||||
if json_match:
|
|
||||||
result = json.loads(json_match.group())
|
|
||||||
return result
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"LLM routing decision failed: {e}")
|
|
||||||
|
|
||||||
# 默认决策
|
|
||||||
if len(false_positives) > len(verified_findings) and iteration < max_iterations:
|
|
||||||
return {"action": "analysis", "reason": "误报率较高,需要重新分析"}
|
|
||||||
return {"action": "report", "reason": "验证完成,生成报告"}
|
|
||||||
|
|
||||||
|
|
||||||
# ============ 路由函数 (结合 LLM 决策) ============
|
|
||||||
|
|
||||||
def route_after_recon(state: AuditState) -> Literal["analysis", "end"]:
|
|
||||||
"""
|
|
||||||
Recon 后的路由决策
|
|
||||||
优先使用 LLM 的决策,否则使用默认逻辑
|
|
||||||
"""
|
|
||||||
# 🔥 检查是否有错误
|
|
||||||
if state.get("error") or state.get("current_phase") == "error":
|
|
||||||
logger.error(f"Recon phase has error, routing to end: {state.get('error')}")
|
|
||||||
return "end"
|
|
||||||
|
|
||||||
# 检查 LLM 是否有决策
|
|
||||||
llm_action = state.get("llm_next_action")
|
|
||||||
if llm_action:
|
|
||||||
logger.info(f"Using LLM routing decision: {llm_action}, reason: {state.get('llm_routing_reason')}")
|
|
||||||
if llm_action == "end":
|
|
||||||
return "end"
|
|
||||||
return "analysis"
|
|
||||||
|
|
||||||
# 默认逻辑(作为 fallback)
|
|
||||||
if not state.get("entry_points") and not state.get("high_risk_areas"):
|
|
||||||
return "end"
|
|
||||||
return "analysis"
|
|
||||||
|
|
||||||
|
|
||||||
def route_after_analysis(state: AuditState) -> Literal["verification", "analysis", "report"]:
|
|
||||||
"""
|
|
||||||
Analysis 后的路由决策
|
|
||||||
优先使用 LLM 的决策
|
|
||||||
"""
|
|
||||||
# 检查 LLM 是否有决策
|
|
||||||
llm_action = state.get("llm_next_action")
|
|
||||||
if llm_action:
|
|
||||||
logger.info(f"Using LLM routing decision: {llm_action}, reason: {state.get('llm_routing_reason')}")
|
|
||||||
if llm_action == "verification":
|
|
||||||
return "verification"
|
|
||||||
elif llm_action == "analysis":
|
|
||||||
return "analysis"
|
|
||||||
elif llm_action == "report":
|
|
||||||
return "report"
|
|
||||||
|
|
||||||
# 默认逻辑
|
|
||||||
findings = state.get("findings", [])
|
|
||||||
iteration = state.get("iteration", 0)
|
|
||||||
max_iterations = state.get("max_iterations", 3)
|
|
||||||
should_continue = state.get("should_continue_analysis", False)
|
|
||||||
|
|
||||||
if not findings:
|
|
||||||
return "report"
|
|
||||||
|
|
||||||
if should_continue and iteration < max_iterations:
|
|
||||||
return "analysis"
|
|
||||||
|
|
||||||
return "verification"
|
|
||||||
|
|
||||||
|
|
||||||
def route_after_verification(state: AuditState) -> Literal["analysis", "report"]:
|
|
||||||
"""
|
|
||||||
Verification 后的路由决策
|
|
||||||
优先使用 LLM 的决策
|
|
||||||
"""
|
|
||||||
# 检查 LLM 是否有决策
|
|
||||||
llm_action = state.get("llm_next_action")
|
|
||||||
if llm_action:
|
|
||||||
logger.info(f"Using LLM routing decision: {llm_action}, reason: {state.get('llm_routing_reason')}")
|
|
||||||
if llm_action == "analysis":
|
|
||||||
return "analysis"
|
|
||||||
return "report"
|
|
||||||
|
|
||||||
# 默认逻辑
|
|
||||||
false_positives = state.get("false_positives", [])
|
|
||||||
iteration = state.get("iteration", 0)
|
|
||||||
max_iterations = state.get("max_iterations", 3)
|
|
||||||
|
|
||||||
if len(false_positives) > len(state.get("verified_findings", [])) and iteration < max_iterations:
|
|
||||||
return "analysis"
|
|
||||||
|
|
||||||
return "report"
|
|
||||||
|
|
||||||
|
|
||||||
# ============ 创建审计图 ============
|
|
||||||
|
|
||||||
def create_audit_graph(
|
|
||||||
recon_node,
|
|
||||||
analysis_node,
|
|
||||||
verification_node,
|
|
||||||
report_node,
|
|
||||||
checkpointer: Optional[MemorySaver] = None,
|
|
||||||
llm_service=None, # 用于 LLM 路由决策
|
|
||||||
) -> StateGraph:
|
|
||||||
"""
|
|
||||||
创建审计工作流图
|
|
||||||
|
|
||||||
Args:
|
|
||||||
recon_node: 信息收集节点
|
|
||||||
analysis_node: 漏洞分析节点
|
|
||||||
verification_node: 漏洞验证节点
|
|
||||||
report_node: 报告生成节点
|
|
||||||
checkpointer: 检查点存储器
|
|
||||||
llm_service: LLM 服务(用于路由决策)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
编译后的 StateGraph
|
|
||||||
|
|
||||||
工作流结构:
|
|
||||||
|
|
||||||
START
|
|
||||||
│
|
|
||||||
▼
|
|
||||||
┌──────┐
|
|
||||||
│Recon │ 信息收集 (LLM 驱动)
|
|
||||||
└──┬───┘
|
|
||||||
│ LLM 决定
|
|
||||||
▼
|
|
||||||
┌──────────┐
|
|
||||||
│ Analysis │◄─────┐ 漏洞分析 (LLM 驱动,可循环)
|
|
||||||
└────┬─────┘ │
|
|
||||||
│ LLM 决定 │
|
|
||||||
▼ │
|
|
||||||
┌────────────┐ │
|
|
||||||
│Verification│────┘ 漏洞验证 (LLM 驱动,可回溯)
|
|
||||||
└─────┬──────┘
|
|
||||||
│ LLM 决定
|
|
||||||
▼
|
|
||||||
┌──────────┐
|
|
||||||
│ Report │ 报告生成
|
|
||||||
└────┬─────┘
|
|
||||||
│
|
|
||||||
▼
|
|
||||||
END
|
|
||||||
"""
|
|
||||||
|
|
||||||
# 创建状态图
|
|
||||||
workflow = StateGraph(AuditState)
|
|
||||||
|
|
||||||
# 如果有 LLM 服务,创建路由决策器
|
|
||||||
llm_router = LLMRouter(llm_service) if llm_service else None
|
|
||||||
|
|
||||||
# 包装节点以添加 LLM 路由决策
|
|
||||||
async def recon_with_routing(state):
|
|
||||||
result = await recon_node(state)
|
|
||||||
|
|
||||||
# LLM 决定下一步
|
|
||||||
if llm_router:
|
|
||||||
decision = await llm_router.decide_after_recon({**state, **result})
|
|
||||||
result["llm_next_action"] = decision.get("action")
|
|
||||||
result["llm_routing_reason"] = decision.get("reason")
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
async def analysis_with_routing(state):
|
|
||||||
result = await analysis_node(state)
|
|
||||||
|
|
||||||
# LLM 决定下一步
|
|
||||||
if llm_router:
|
|
||||||
decision = await llm_router.decide_after_analysis({**state, **result})
|
|
||||||
result["llm_next_action"] = decision.get("action")
|
|
||||||
result["llm_routing_reason"] = decision.get("reason")
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
async def verification_with_routing(state):
|
|
||||||
result = await verification_node(state)
|
|
||||||
|
|
||||||
# LLM 决定下一步
|
|
||||||
if llm_router:
|
|
||||||
decision = await llm_router.decide_after_verification({**state, **result})
|
|
||||||
result["llm_next_action"] = decision.get("action")
|
|
||||||
result["llm_routing_reason"] = decision.get("reason")
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
# 添加节点
|
|
||||||
if llm_router:
|
|
||||||
workflow.add_node("recon", recon_with_routing)
|
|
||||||
workflow.add_node("analysis", analysis_with_routing)
|
|
||||||
workflow.add_node("verification", verification_with_routing)
|
|
||||||
else:
|
|
||||||
workflow.add_node("recon", recon_node)
|
|
||||||
workflow.add_node("analysis", analysis_node)
|
|
||||||
workflow.add_node("verification", verification_node)
|
|
||||||
|
|
||||||
workflow.add_node("report", report_node)
|
|
||||||
|
|
||||||
# 设置入口点
|
|
||||||
workflow.set_entry_point("recon")
|
|
||||||
|
|
||||||
# 添加条件边
|
|
||||||
workflow.add_conditional_edges(
|
|
||||||
"recon",
|
|
||||||
route_after_recon,
|
|
||||||
{
|
|
||||||
"analysis": "analysis",
|
|
||||||
"end": END,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
workflow.add_conditional_edges(
|
|
||||||
"analysis",
|
|
||||||
route_after_analysis,
|
|
||||||
{
|
|
||||||
"verification": "verification",
|
|
||||||
"analysis": "analysis",
|
|
||||||
"report": "report",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
workflow.add_conditional_edges(
|
|
||||||
"verification",
|
|
||||||
route_after_verification,
|
|
||||||
{
|
|
||||||
"analysis": "analysis",
|
|
||||||
"report": "report",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Report -> END
|
|
||||||
workflow.add_edge("report", END)
|
|
||||||
|
|
||||||
# 编译图
|
|
||||||
if checkpointer:
|
|
||||||
return workflow.compile(checkpointer=checkpointer)
|
|
||||||
else:
|
|
||||||
return workflow.compile()
|
|
||||||
|
|
||||||
|
|
||||||
# ============ 带人机协作的审计图 ============
|
|
||||||
|
|
||||||
def create_audit_graph_with_human(
|
|
||||||
recon_node,
|
|
||||||
analysis_node,
|
|
||||||
verification_node,
|
|
||||||
report_node,
|
|
||||||
human_review_node,
|
|
||||||
checkpointer: Optional[MemorySaver] = None,
|
|
||||||
llm_service=None,
|
|
||||||
) -> StateGraph:
|
|
||||||
"""
|
|
||||||
创建带人机协作的审计工作流图
|
|
||||||
|
|
||||||
在验证阶段后增加人工审核节点
|
|
||||||
"""
|
|
||||||
|
|
||||||
workflow = StateGraph(AuditState)
|
|
||||||
llm_router = LLMRouter(llm_service) if llm_service else None
|
|
||||||
|
|
||||||
# 包装节点
|
|
||||||
async def recon_with_routing(state):
|
|
||||||
result = await recon_node(state)
|
|
||||||
if llm_router:
|
|
||||||
decision = await llm_router.decide_after_recon({**state, **result})
|
|
||||||
result["llm_next_action"] = decision.get("action")
|
|
||||||
result["llm_routing_reason"] = decision.get("reason")
|
|
||||||
return result
|
|
||||||
|
|
||||||
async def analysis_with_routing(state):
|
|
||||||
result = await analysis_node(state)
|
|
||||||
if llm_router:
|
|
||||||
decision = await llm_router.decide_after_analysis({**state, **result})
|
|
||||||
result["llm_next_action"] = decision.get("action")
|
|
||||||
result["llm_routing_reason"] = decision.get("reason")
|
|
||||||
return result
|
|
||||||
|
|
||||||
# 添加节点
|
|
||||||
if llm_router:
|
|
||||||
workflow.add_node("recon", recon_with_routing)
|
|
||||||
workflow.add_node("analysis", analysis_with_routing)
|
|
||||||
else:
|
|
||||||
workflow.add_node("recon", recon_node)
|
|
||||||
workflow.add_node("analysis", analysis_node)
|
|
||||||
|
|
||||||
workflow.add_node("verification", verification_node)
|
|
||||||
workflow.add_node("human_review", human_review_node)
|
|
||||||
workflow.add_node("report", report_node)
|
|
||||||
|
|
||||||
workflow.set_entry_point("recon")
|
|
||||||
|
|
||||||
workflow.add_conditional_edges(
|
|
||||||
"recon",
|
|
||||||
route_after_recon,
|
|
||||||
{"analysis": "analysis", "end": END}
|
|
||||||
)
|
|
||||||
|
|
||||||
workflow.add_conditional_edges(
|
|
||||||
"analysis",
|
|
||||||
route_after_analysis,
|
|
||||||
{
|
|
||||||
"verification": "verification",
|
|
||||||
"analysis": "analysis",
|
|
||||||
"report": "report",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Verification -> Human Review
|
|
||||||
workflow.add_edge("verification", "human_review")
|
|
||||||
|
|
||||||
# Human Review 后的路由
|
|
||||||
def route_after_human(state: AuditState) -> Literal["analysis", "report"]:
|
|
||||||
if state.get("should_continue_analysis"):
|
|
||||||
return "analysis"
|
|
||||||
return "report"
|
|
||||||
|
|
||||||
workflow.add_conditional_edges(
|
|
||||||
"human_review",
|
|
||||||
route_after_human,
|
|
||||||
{"analysis": "analysis", "report": "report"}
|
|
||||||
)
|
|
||||||
|
|
||||||
workflow.add_edge("report", END)
|
|
||||||
|
|
||||||
if checkpointer:
|
|
||||||
return workflow.compile(checkpointer=checkpointer, interrupt_before=["human_review"])
|
|
||||||
else:
|
|
||||||
return workflow.compile()
|
|
||||||
|
|
||||||
|
|
||||||
# ============ 执行器 ============
|
|
||||||
|
|
||||||
class AuditGraphRunner:
|
|
||||||
"""
|
|
||||||
审计图执行器
|
|
||||||
封装 LangGraph 工作流的执行
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
graph: StateGraph,
|
|
||||||
event_emitter=None,
|
|
||||||
):
|
|
||||||
self.graph = graph
|
|
||||||
self.event_emitter = event_emitter
|
|
||||||
|
|
||||||
async def run(
|
|
||||||
self,
|
|
||||||
project_root: str,
|
|
||||||
project_info: Dict[str, Any],
|
|
||||||
config: Dict[str, Any],
|
|
||||||
task_id: str,
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
执行审计工作流
|
|
||||||
"""
|
|
||||||
# 初始状态
|
|
||||||
initial_state: AuditState = {
|
|
||||||
"project_root": project_root,
|
|
||||||
"project_info": project_info,
|
|
||||||
"config": config,
|
|
||||||
"task_id": task_id,
|
|
||||||
"tech_stack": {},
|
|
||||||
"entry_points": [],
|
|
||||||
"high_risk_areas": [],
|
|
||||||
"dependencies": {},
|
|
||||||
"findings": [],
|
|
||||||
"verified_findings": [],
|
|
||||||
"false_positives": [],
|
|
||||||
"_verified_findings_update": None, # 🔥 NEW: 验证后的 findings 更新
|
|
||||||
"current_phase": "start",
|
|
||||||
"iteration": 0,
|
|
||||||
"max_iterations": config.get("max_iterations", 3),
|
|
||||||
"should_continue_analysis": False,
|
|
||||||
"llm_next_action": None,
|
|
||||||
"llm_routing_reason": None,
|
|
||||||
"messages": [],
|
|
||||||
"events": [],
|
|
||||||
"summary": None,
|
|
||||||
"security_score": None,
|
|
||||||
"error": None,
|
|
||||||
}
|
|
||||||
|
|
||||||
run_config = {
|
|
||||||
"configurable": {
|
|
||||||
"thread_id": task_id,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
|
||||||
async for event in self.graph.astream(initial_state, config=run_config):
|
|
||||||
if self.event_emitter:
|
|
||||||
for node_name, node_state in event.items():
|
|
||||||
await self.event_emitter.emit_info(
|
|
||||||
f"节点 {node_name} 完成"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 发射 LLM 路由决策事件
|
|
||||||
if node_state.get("llm_routing_reason"):
|
|
||||||
await self.event_emitter.emit_info(
|
|
||||||
f"🧠 LLM 决策: {node_state.get('llm_next_action')} - {node_state.get('llm_routing_reason')}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if node_name == "analysis" and node_state.get("findings"):
|
|
||||||
new_findings = node_state["findings"]
|
|
||||||
await self.event_emitter.emit_info(
|
|
||||||
f"发现 {len(new_findings)} 个潜在漏洞"
|
|
||||||
)
|
|
||||||
|
|
||||||
final_state = self.graph.get_state(run_config)
|
|
||||||
return final_state.values
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Graph execution failed: {e}", exc_info=True)
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def run_with_human_review(
|
|
||||||
self,
|
|
||||||
initial_state: AuditState,
|
|
||||||
human_feedback_callback,
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""带人机协作的执行"""
|
|
||||||
run_config = {
|
|
||||||
"configurable": {
|
|
||||||
"thread_id": initial_state["task_id"],
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async for event in self.graph.astream(initial_state, config=run_config):
|
|
||||||
pass
|
|
||||||
|
|
||||||
current_state = self.graph.get_state(run_config)
|
|
||||||
|
|
||||||
if current_state.next == ("human_review",):
|
|
||||||
human_decision = await human_feedback_callback(current_state.values)
|
|
||||||
|
|
||||||
updated_state = {
|
|
||||||
**current_state.values,
|
|
||||||
"should_continue_analysis": human_decision.get("continue_analysis", False),
|
|
||||||
}
|
|
||||||
|
|
||||||
async for event in self.graph.astream(updated_state, config=run_config):
|
|
||||||
pass
|
|
||||||
|
|
||||||
return self.graph.get_state(run_config).values
|
|
||||||
|
|
@ -1,556 +0,0 @@
|
||||||
"""
|
|
||||||
LangGraph 节点实现
|
|
||||||
每个节点封装一个 Agent 的执行逻辑
|
|
||||||
|
|
||||||
协作增强:节点之间通过 TaskHandoff 传递结构化的上下文和洞察
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Dict, Any, List, Optional
|
|
||||||
import logging
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# 延迟导入避免循环依赖
|
|
||||||
def get_audit_state_type():
|
|
||||||
from .audit_graph import AuditState
|
|
||||||
return AuditState
|
|
||||||
|
|
||||||
|
|
||||||
class BaseNode:
|
|
||||||
"""节点基类"""
|
|
||||||
|
|
||||||
def __init__(self, agent=None, event_emitter=None):
|
|
||||||
self.agent = agent
|
|
||||||
self.event_emitter = event_emitter
|
|
||||||
|
|
||||||
async def emit_event(self, event_type: str, message: str, **kwargs):
|
|
||||||
"""发射事件"""
|
|
||||||
if self.event_emitter:
|
|
||||||
try:
|
|
||||||
await self.event_emitter.emit_info(message)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to emit event: {e}")
|
|
||||||
|
|
||||||
def _extract_handoff_from_state(self, state: Dict[str, Any], from_phase: str):
|
|
||||||
"""从状态中提取前序 Agent 的 handoff"""
|
|
||||||
handoff_data = state.get(f"{from_phase}_handoff")
|
|
||||||
if handoff_data:
|
|
||||||
from ..agents.base import TaskHandoff
|
|
||||||
return TaskHandoff.from_dict(handoff_data)
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
class ReconNode(BaseNode):
|
|
||||||
"""
|
|
||||||
信息收集节点
|
|
||||||
|
|
||||||
输入: project_root, project_info, config
|
|
||||||
输出: tech_stack, entry_points, high_risk_areas, dependencies, recon_handoff
|
|
||||||
"""
|
|
||||||
|
|
||||||
async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]:
|
|
||||||
"""执行信息收集"""
|
|
||||||
await self.emit_event("phase_start", "🔍 开始信息收集阶段")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 调用 Recon Agent
|
|
||||||
result = await self.agent.run({
|
|
||||||
"project_info": state["project_info"],
|
|
||||||
"config": state["config"],
|
|
||||||
})
|
|
||||||
|
|
||||||
if result.success and result.data:
|
|
||||||
data = result.data
|
|
||||||
|
|
||||||
# 🔥 创建交接信息给 Analysis Agent
|
|
||||||
handoff = self.agent.create_handoff(
|
|
||||||
to_agent="Analysis",
|
|
||||||
summary=f"项目信息收集完成。发现 {len(data.get('entry_points', []))} 个入口点,{len(data.get('high_risk_areas', []))} 个高风险区域。",
|
|
||||||
key_findings=data.get("initial_findings", []),
|
|
||||||
suggested_actions=[
|
|
||||||
{
|
|
||||||
"type": "deep_analysis",
|
|
||||||
"description": f"深入分析高风险区域: {', '.join(data.get('high_risk_areas', [])[:5])}",
|
|
||||||
"priority": "high",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "entry_point_audit",
|
|
||||||
"description": "审计所有入口点的输入验证",
|
|
||||||
"priority": "high",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
attention_points=[
|
|
||||||
f"技术栈: {data.get('tech_stack', {}).get('frameworks', [])}",
|
|
||||||
f"主要语言: {data.get('tech_stack', {}).get('languages', [])}",
|
|
||||||
],
|
|
||||||
priority_areas=data.get("high_risk_areas", [])[:10],
|
|
||||||
context_data={
|
|
||||||
"tech_stack": data.get("tech_stack", {}),
|
|
||||||
"entry_points": data.get("entry_points", []),
|
|
||||||
"dependencies": data.get("dependencies", {}),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
await self.emit_event(
|
|
||||||
"phase_complete",
|
|
||||||
f"✅ 信息收集完成: 发现 {len(data.get('entry_points', []))} 个入口点"
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"tech_stack": data.get("tech_stack", {}),
|
|
||||||
"entry_points": data.get("entry_points", []),
|
|
||||||
"high_risk_areas": data.get("high_risk_areas", []),
|
|
||||||
"dependencies": data.get("dependencies", {}),
|
|
||||||
"current_phase": "recon_complete",
|
|
||||||
"findings": data.get("initial_findings", []),
|
|
||||||
# 🔥 保存交接信息
|
|
||||||
"recon_handoff": handoff.to_dict(),
|
|
||||||
"events": [{
|
|
||||||
"type": "recon_complete",
|
|
||||||
"data": {
|
|
||||||
"entry_points_count": len(data.get("entry_points", [])),
|
|
||||||
"high_risk_areas_count": len(data.get("high_risk_areas", [])),
|
|
||||||
"handoff_summary": handoff.summary,
|
|
||||||
}
|
|
||||||
}],
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
return {
|
|
||||||
"error": result.error or "Recon failed",
|
|
||||||
"current_phase": "error",
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Recon node failed: {e}", exc_info=True)
|
|
||||||
return {
|
|
||||||
"error": str(e),
|
|
||||||
"current_phase": "error",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class AnalysisNode(BaseNode):
|
|
||||||
"""
|
|
||||||
漏洞分析节点
|
|
||||||
|
|
||||||
输入: tech_stack, entry_points, high_risk_areas, recon_handoff
|
|
||||||
输出: findings (累加), should_continue_analysis, analysis_handoff
|
|
||||||
"""
|
|
||||||
|
|
||||||
async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]:
|
|
||||||
"""执行漏洞分析"""
|
|
||||||
iteration = state.get("iteration", 0) + 1
|
|
||||||
|
|
||||||
await self.emit_event(
|
|
||||||
"phase_start",
|
|
||||||
f"🔬 开始漏洞分析阶段 (迭代 {iteration})"
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 🔥 提取 Recon 的交接信息
|
|
||||||
recon_handoff = self._extract_handoff_from_state(state, "recon")
|
|
||||||
if recon_handoff:
|
|
||||||
self.agent.receive_handoff(recon_handoff)
|
|
||||||
await self.emit_event(
|
|
||||||
"handoff_received",
|
|
||||||
f"📨 收到 Recon Agent 交接: {recon_handoff.summary[:50]}..."
|
|
||||||
)
|
|
||||||
|
|
||||||
# 构建分析输入
|
|
||||||
analysis_input = {
|
|
||||||
"phase_name": "analysis",
|
|
||||||
"project_info": state["project_info"],
|
|
||||||
"config": state["config"],
|
|
||||||
"plan": {
|
|
||||||
"high_risk_areas": state.get("high_risk_areas", []),
|
|
||||||
},
|
|
||||||
"previous_results": {
|
|
||||||
"recon": {
|
|
||||||
"data": {
|
|
||||||
"tech_stack": state.get("tech_stack", {}),
|
|
||||||
"entry_points": state.get("entry_points", []),
|
|
||||||
"high_risk_areas": state.get("high_risk_areas", []),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
# 🔥 传递交接信息
|
|
||||||
"handoff": recon_handoff,
|
|
||||||
}
|
|
||||||
|
|
||||||
# 调用 Analysis Agent
|
|
||||||
result = await self.agent.run(analysis_input)
|
|
||||||
|
|
||||||
if result.success and result.data:
|
|
||||||
new_findings = result.data.get("findings", [])
|
|
||||||
logger.info(f"[AnalysisNode] Agent returned {len(new_findings)} findings")
|
|
||||||
|
|
||||||
# 判断是否需要继续分析
|
|
||||||
should_continue = (
|
|
||||||
len(new_findings) >= 5 and
|
|
||||||
iteration < state.get("max_iterations", 3)
|
|
||||||
)
|
|
||||||
|
|
||||||
# 🔥 创建交接信息给 Verification Agent
|
|
||||||
# 统计严重程度
|
|
||||||
severity_counts = {"critical": 0, "high": 0, "medium": 0, "low": 0}
|
|
||||||
for f in new_findings:
|
|
||||||
if isinstance(f, dict):
|
|
||||||
sev = f.get("severity", "medium")
|
|
||||||
severity_counts[sev] = severity_counts.get(sev, 0) + 1
|
|
||||||
|
|
||||||
handoff = self.agent.create_handoff(
|
|
||||||
to_agent="Verification",
|
|
||||||
summary=f"漏洞分析完成。发现 {len(new_findings)} 个潜在漏洞 (Critical: {severity_counts['critical']}, High: {severity_counts['high']}, Medium: {severity_counts['medium']}, Low: {severity_counts['low']})",
|
|
||||||
key_findings=new_findings[:20], # 传递前20个发现
|
|
||||||
suggested_actions=[
|
|
||||||
{
|
|
||||||
"type": "verify_critical",
|
|
||||||
"description": "优先验证 Critical 和 High 级别的漏洞",
|
|
||||||
"priority": "critical",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "poc_generation",
|
|
||||||
"description": "为确认的漏洞生成 PoC",
|
|
||||||
"priority": "high",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
attention_points=[
|
|
||||||
f"共 {severity_counts['critical']} 个 Critical 级别漏洞需要立即验证",
|
|
||||||
f"共 {severity_counts['high']} 个 High 级别漏洞需要优先验证",
|
|
||||||
"注意检查是否有误报,特别是静态分析工具的结果",
|
|
||||||
],
|
|
||||||
priority_areas=[
|
|
||||||
f.get("file_path", "") for f in new_findings
|
|
||||||
if f.get("severity") in ["critical", "high"]
|
|
||||||
][:10],
|
|
||||||
context_data={
|
|
||||||
"severity_distribution": severity_counts,
|
|
||||||
"total_findings": len(new_findings),
|
|
||||||
"iteration": iteration,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
await self.emit_event(
|
|
||||||
"phase_complete",
|
|
||||||
f"✅ 分析迭代 {iteration} 完成: 发现 {len(new_findings)} 个潜在漏洞"
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"findings": new_findings,
|
|
||||||
"iteration": iteration,
|
|
||||||
"should_continue_analysis": should_continue,
|
|
||||||
"current_phase": "analysis_complete",
|
|
||||||
# 🔥 保存交接信息
|
|
||||||
"analysis_handoff": handoff.to_dict(),
|
|
||||||
"events": [{
|
|
||||||
"type": "analysis_iteration",
|
|
||||||
"data": {
|
|
||||||
"iteration": iteration,
|
|
||||||
"findings_count": len(new_findings),
|
|
||||||
"severity_distribution": severity_counts,
|
|
||||||
"handoff_summary": handoff.summary,
|
|
||||||
}
|
|
||||||
}],
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
return {
|
|
||||||
"iteration": iteration,
|
|
||||||
"should_continue_analysis": False,
|
|
||||||
"current_phase": "analysis_complete",
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Analysis node failed: {e}", exc_info=True)
|
|
||||||
return {
|
|
||||||
"error": str(e),
|
|
||||||
"should_continue_analysis": False,
|
|
||||||
"current_phase": "error",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class VerificationNode(BaseNode):
|
|
||||||
"""
|
|
||||||
漏洞验证节点
|
|
||||||
|
|
||||||
输入: findings, analysis_handoff
|
|
||||||
输出: verified_findings, false_positives, verification_handoff
|
|
||||||
"""
|
|
||||||
|
|
||||||
async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]:
|
|
||||||
"""执行漏洞验证"""
|
|
||||||
findings = state.get("findings", [])
|
|
||||||
logger.info(f"[VerificationNode] Received {len(findings)} findings to verify")
|
|
||||||
|
|
||||||
if not findings:
|
|
||||||
return {
|
|
||||||
"verified_findings": [],
|
|
||||||
"false_positives": [],
|
|
||||||
"current_phase": "verification_complete",
|
|
||||||
}
|
|
||||||
|
|
||||||
await self.emit_event(
|
|
||||||
"phase_start",
|
|
||||||
f"🔐 开始漏洞验证阶段 ({len(findings)} 个待验证)"
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 🔥 提取 Analysis 的交接信息
|
|
||||||
analysis_handoff = self._extract_handoff_from_state(state, "analysis")
|
|
||||||
if analysis_handoff:
|
|
||||||
self.agent.receive_handoff(analysis_handoff)
|
|
||||||
await self.emit_event(
|
|
||||||
"handoff_received",
|
|
||||||
f"📨 收到 Analysis Agent 交接: {analysis_handoff.summary[:50]}..."
|
|
||||||
)
|
|
||||||
|
|
||||||
# 构建验证输入
|
|
||||||
verification_input = {
|
|
||||||
"previous_results": {
|
|
||||||
"analysis": {
|
|
||||||
"data": {
|
|
||||||
"findings": findings,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"config": state["config"],
|
|
||||||
# 🔥 传递交接信息
|
|
||||||
"handoff": analysis_handoff,
|
|
||||||
}
|
|
||||||
|
|
||||||
# 调用 Verification Agent
|
|
||||||
result = await self.agent.run(verification_input)
|
|
||||||
|
|
||||||
if result.success and result.data:
|
|
||||||
all_verified_findings = result.data.get("findings", [])
|
|
||||||
verified = [f for f in all_verified_findings if f.get("is_verified")]
|
|
||||||
false_pos = [f.get("id", f.get("title", "unknown")) for f in all_verified_findings
|
|
||||||
if f.get("verdict") == "false_positive"]
|
|
||||||
|
|
||||||
# 🔥 CRITICAL FIX: 用验证结果更新原始 findings
|
|
||||||
# 创建 findings 的更新映射,基于 (file_path, line_start, vulnerability_type)
|
|
||||||
verified_map = {}
|
|
||||||
for vf in all_verified_findings:
|
|
||||||
key = (
|
|
||||||
vf.get("file_path", ""),
|
|
||||||
vf.get("line_start", 0),
|
|
||||||
vf.get("vulnerability_type", ""),
|
|
||||||
)
|
|
||||||
verified_map[key] = vf
|
|
||||||
|
|
||||||
# 合并验证结果到原始 findings
|
|
||||||
updated_findings = []
|
|
||||||
seen_keys = set()
|
|
||||||
|
|
||||||
# 首先处理原始 findings,用验证结果更新
|
|
||||||
for f in findings:
|
|
||||||
if not isinstance(f, dict):
|
|
||||||
continue
|
|
||||||
key = (
|
|
||||||
f.get("file_path", ""),
|
|
||||||
f.get("line_start", 0),
|
|
||||||
f.get("vulnerability_type", ""),
|
|
||||||
)
|
|
||||||
if key in verified_map:
|
|
||||||
# 使用验证后的版本
|
|
||||||
updated_findings.append(verified_map[key])
|
|
||||||
seen_keys.add(key)
|
|
||||||
else:
|
|
||||||
# 保留原始(未验证)
|
|
||||||
updated_findings.append(f)
|
|
||||||
seen_keys.add(key)
|
|
||||||
|
|
||||||
# 添加验证结果中的新发现(如果有)
|
|
||||||
for key, vf in verified_map.items():
|
|
||||||
if key not in seen_keys:
|
|
||||||
updated_findings.append(vf)
|
|
||||||
|
|
||||||
logger.info(f"[VerificationNode] Updated findings: {len(updated_findings)} total, {len(verified)} verified")
|
|
||||||
|
|
||||||
# 🔥 创建交接信息给 Report 节点
|
|
||||||
handoff = self.agent.create_handoff(
|
|
||||||
to_agent="Report",
|
|
||||||
summary=f"漏洞验证完成。{len(verified)} 个漏洞已确认,{len(false_pos)} 个误报已排除。",
|
|
||||||
key_findings=verified,
|
|
||||||
suggested_actions=[
|
|
||||||
{
|
|
||||||
"type": "generate_report",
|
|
||||||
"description": "生成详细的安全审计报告",
|
|
||||||
"priority": "high",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "remediation_plan",
|
|
||||||
"description": "为确认的漏洞制定修复计划",
|
|
||||||
"priority": "high",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
attention_points=[
|
|
||||||
f"共 {len(verified)} 个漏洞已确认存在",
|
|
||||||
f"共 {len(false_pos)} 个误报已排除",
|
|
||||||
"建议按严重程度优先修复 Critical 和 High 级别漏洞",
|
|
||||||
],
|
|
||||||
context_data={
|
|
||||||
"verified_count": len(verified),
|
|
||||||
"false_positive_count": len(false_pos),
|
|
||||||
"total_analyzed": len(findings),
|
|
||||||
"verification_rate": len(verified) / len(findings) if findings else 0,
|
|
||||||
},
|
|
                )

                await self.emit_event(
                    "phase_complete",
                    f"✅ 验证完成: {len(verified)} 已确认, {len(false_pos)} 误报"
                )

                return {
                    # 🔥 CRITICAL: return the updated findings; this replaces the findings in state
                    # Note: because LangGraph uses operator.add, the merge has to be handled in the runner
                    # Here we return _verified_findings_update as a special field
                    "_verified_findings_update": updated_findings,
                    "verified_findings": verified,
                    "false_positives": false_pos,
                    "current_phase": "verification_complete",
                    # 🔥 Persist the handoff information
                    "verification_handoff": handoff.to_dict(),
                    "events": [{
                        "type": "verification_complete",
                        "data": {
                            "verified_count": len(verified),
                            "false_positive_count": len(false_pos),
                            "total_findings": len(updated_findings),
                            "handoff_summary": handoff.summary,
                        }
                    }],
                }
            else:
                return {
                    "verified_findings": [],
                    "false_positives": [],
                    "current_phase": "verification_complete",
                }

        except Exception as e:
            logger.error(f"Verification node failed: {e}", exc_info=True)
            return {
                "error": str(e),
                "current_phase": "error",
            }


class ReportNode(BaseNode):
    """
    Report generation node

    Input: all state
    Output: summary, security_score
    """

    async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]:
        """Generate the audit report"""
        await self.emit_event("phase_start", "📊 生成审计报告")

        try:
            # 🔥 CRITICAL FIX: prefer the post-verification findings update
            findings = state.get("_verified_findings_update") or state.get("findings", [])
            verified = state.get("verified_findings", [])
            false_positives = state.get("false_positives", [])

            logger.info(f"[ReportNode] State contains {len(findings)} findings, {len(verified)} verified")

            # Tally the vulnerability distribution
            severity_counts = {"critical": 0, "high": 0, "medium": 0, "low": 0}
            type_counts = {}

            for finding in findings:
                # Skip non-dict findings (guards against malformed data)
                if not isinstance(finding, dict):
                    logger.warning(f"Skipping invalid finding (not a dict): {type(finding)}")
                    continue

                sev = finding.get("severity", "medium")
                severity_counts[sev] = severity_counts.get(sev, 0) + 1

                vtype = finding.get("vulnerability_type", "other")
                type_counts[vtype] = type_counts.get(vtype, 0) + 1

            # Compute the security score
            base_score = 100
            deductions = (
                severity_counts["critical"] * 25 +
                severity_counts["high"] * 15 +
                severity_counts["medium"] * 8 +
                severity_counts["low"] * 3
            )
            security_score = max(0, base_score - deductions)

            # Build the summary
            summary = {
                "total_findings": len(findings),
                "verified_count": len(verified),
                "false_positive_count": len(false_positives),
                "severity_distribution": severity_counts,
                "vulnerability_types": type_counts,
                "tech_stack": state.get("tech_stack", {}),
                "entry_points_analyzed": len(state.get("entry_points", [])),
                "high_risk_areas": state.get("high_risk_areas", []),
                "iterations": state.get("iteration", 1),
            }

            await self.emit_event(
                "phase_complete",
                f"报告生成完成: 安全评分 {security_score}/100"
            )

            return {
                "summary": summary,
                "security_score": security_score,
                "current_phase": "complete",
                "events": [{
                    "type": "audit_complete",
                    "data": {
                        "security_score": security_score,
                        "total_findings": len(findings),
                        "verified_count": len(verified),
                    }
                }],
            }

        except Exception as e:
            logger.error(f"Report node failed: {e}", exc_info=True)
            return {
                "error": str(e),
                "current_phase": "error",
            }


class HumanReviewNode(BaseNode):
    """
    Human review node

    Pauses at this node and waits for human feedback
    """

    async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]:
        """
        Human review node

        This node is paused via interrupt_before.
        The user can:
        1. Confirm findings
        2. Mark false positives
        3. Request re-analysis
        """
        await self.emit_event(
            "human_review",
            f"⏸️ 等待人工审核 ({len(state.get('verified_findings', []))} 个待确认)"
        )

        # Return the current state unchanged;
        # human feedback is injected via update_state
        return {
            "current_phase": "human_review",
            "messages": [{
                "role": "system",
                "content": "等待人工审核",
                "findings_for_review": state.get("verified_findings", []),
            }],
        }
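For reference, the scoring rule used by ReportNode above can be exercised on its own. This is a minimal sketch, not code from the repository — the helper name compute_security_score is made up; only the 25/15/8/3 weights and the floor at 0 come from the node:

def compute_security_score(severity_counts: dict) -> int:
    # Same weights as ReportNode: critical 25, high 15, medium 8, low 3, floored at 0.
    deductions = (
        severity_counts.get("critical", 0) * 25
        + severity_counts.get("high", 0) * 15
        + severity_counts.get("medium", 0) * 8
        + severity_counts.get("low", 0) * 3
    )
    return max(0, 100 - deductions)

# Example: 1 critical + 2 high + 3 medium -> 100 - (25 + 30 + 24) = 21
print(compute_security_score({"critical": 1, "high": 2, "medium": 3}))  # 21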
(File diff suppressed because it is too large.)
@@ -6,6 +6,7 @@
 import os
 import re
 import fnmatch
+import asyncio
 from typing import Optional, List, Dict, Any
 from pydantic import BaseModel, Field

@@ -45,6 +46,36 @@ class FileReadTool(AgentTool):
         self.exclude_patterns = exclude_patterns or []
         self.target_files = set(target_files) if target_files else None

+    @staticmethod
+    def _read_file_lines_sync(file_path: str, start_idx: int, end_idx: int) -> tuple:
+        """Synchronously read a line range from a file (used with asyncio.to_thread)"""
+        selected_lines = []
+        total_lines = 0
+        file_size = os.path.getsize(file_path)
+
+        with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
+            for i, line in enumerate(f):
+                total_lines = i + 1
+                if i >= start_idx and i < end_idx:
+                    selected_lines.append(line)
+                elif i >= end_idx:
+                    if i < end_idx + 1000:
+                        continue
+                    else:
+                        remaining_bytes = file_size - f.tell()
+                        avg_line_size = f.tell() / (i + 1)
+                        estimated_remaining_lines = int(remaining_bytes / avg_line_size) if avg_line_size > 0 else 0
+                        total_lines = i + 1 + estimated_remaining_lines
+                        break
+
+        return selected_lines, total_lines
+
+    @staticmethod
+    def _read_all_lines_sync(file_path: str) -> List[str]:
+        """Synchronously read all lines of a file (used with asyncio.to_thread)"""
+        with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
+            return f.readlines()
+
     @property
     def name(self) -> str:
         return "read_file"

@@ -136,37 +167,20 @@ class FileReadTool(AgentTool):

             # 🔥 For large files, stream-read the requested line range
             if is_large_file and (start_line is not None or end_line is not None):
-                # Stream the file to avoid loading it all at once
-                selected_lines = []
-                total_lines = 0
-
                 # Work out the effective start and end lines
                 start_idx = max(0, (start_line or 1) - 1)
                 end_idx = end_line if end_line else start_idx + max_lines

-                with open(full_path, 'r', encoding='utf-8', errors='ignore') as f:
-                    for i, line in enumerate(f):
-                        total_lines = i + 1
-                        if i >= start_idx and i < end_idx:
-                            selected_lines.append(line)
-                        elif i >= end_idx:
-                            # Keep counting to get the total, but cap how much more we read
-                            if i < end_idx + 1000:  # read at most 1000 extra lines to estimate the total
-                                continue
-                            else:
-                                # Estimate the remaining number of lines
-                                remaining_bytes = file_size - f.tell()
-                                avg_line_size = f.tell() / (i + 1)
-                                estimated_remaining_lines = int(remaining_bytes / avg_line_size) if avg_line_size > 0 else 0
-                                total_lines = i + 1 + estimated_remaining_lines
-                                break
+                # Read the file asynchronously so the event loop is not blocked
+                selected_lines, total_lines = await asyncio.to_thread(
+                    self._read_file_lines_sync, full_path, start_idx, end_idx
+                )

                 # Update the effective end index
                 end_idx = min(end_idx, start_idx + len(selected_lines))
             else:
-                # Plain read for small files
-                with open(full_path, 'r', encoding='utf-8', errors='ignore') as f:
-                    lines = f.readlines()
+                # Read small files asynchronously so the event loop is not blocked
+                lines = await asyncio.to_thread(self._read_all_lines_sync, full_path)

                 total_lines = len(lines)

@@ -268,6 +282,12 @@ class FileSearchTool(AgentTool):
             elif "/" not in pattern and "*" not in pattern:
                 self.exclude_dirs.add(pattern)

+    @staticmethod
+    def _read_file_lines_sync(file_path: str) -> List[str]:
+        """Synchronously read all lines of a file (used with asyncio.to_thread)"""
+        with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
+            return f.readlines()
+
     @property
     def name(self) -> str:
         return "search_code"

@@ -360,8 +380,10 @@ class FileSearchTool(AgentTool):
                 continue

             try:
-                with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
-                    lines = f.readlines()
+                # Read the file asynchronously so the event loop is not blocked
+                lines = await asyncio.to_thread(
+                    self._read_file_lines_sync, file_path
+                )

                 files_searched += 1
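All of the tool changes above follow the same pattern: the blocking open()/readlines() call stays in a small synchronous helper, and the async caller hands it to asyncio.to_thread so the event loop keeps serving other tasks. A minimal standalone sketch of that pattern (function and file names here are illustrative, not taken from the repository):

import asyncio

def _read_all_lines_sync(path: str) -> list[str]:
    # Blocking I/O stays in an ordinary function...
    with open(path, "r", encoding="utf-8", errors="ignore") as f:
        return f.readlines()

async def read_lines(path: str) -> list[str]:
    # ...and the coroutine pushes it onto the default thread pool.
    return await asyncio.to_thread(_read_all_lines_sync, path)

# asyncio.run(read_lines("some_file.txt"))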
@@ -416,13 +416,93 @@ class LiteLLMAdapter(BaseLLMAdapter):
                 "finish_reason": "complete",
             }

-        except Exception as e:
-            # 🔥 Even on error, still try to return an estimated usage
-            logger.error(f"Stream error: {e}")
+        except litellm.exceptions.RateLimitError as e:
+            # Rate-limit errors need special handling
+            logger.error(f"Stream rate limit error: {e}")
+            error_msg = str(e)
+            # Distinguish "insufficient balance" from "rate exceeded"
+            if any(keyword in error_msg.lower() for keyword in ["余额不足", "资源包", "充值", "quota", "exceeded", "billing"]):
+                error_type = "quota_exceeded"
+                user_message = "API 配额已用尽,请检查账户余额或升级计划"
+            else:
+                error_type = "rate_limit"
+                # Try to extract the retry delay from the error message
+                import re
+                retry_match = re.search(r"retry\s*(?:in|after)\s*(\d+(?:\.\d+)?)\s*s", error_msg, re.IGNORECASE)
+                retry_seconds = float(retry_match.group(1)) if retry_match else 60
+                user_message = f"API 调用频率超限,建议等待 {int(retry_seconds)} 秒后重试"
+
             output_tokens_estimate = estimate_tokens(accumulated_content) if accumulated_content else 0
             yield {
                 "type": "error",
+                "error_type": error_type,
+                "error": error_msg,
+                "user_message": user_message,
+                "accumulated": accumulated_content,
+                "usage": {
+                    "prompt_tokens": input_tokens_estimate,
+                    "completion_tokens": output_tokens_estimate,
+                    "total_tokens": input_tokens_estimate + output_tokens_estimate,
+                } if accumulated_content else None,
+            }
+
+        except litellm.exceptions.AuthenticationError as e:
+            # Authentication error - the API key is invalid
+            logger.error(f"Stream authentication error: {e}")
+            yield {
+                "type": "error",
+                "error_type": "authentication",
                 "error": str(e),
+                "user_message": "API Key 无效或已过期,请检查配置",
+                "accumulated": accumulated_content,
+                "usage": None,
+            }
+
+        except litellm.exceptions.APIConnectionError as e:
+            # Connection error - network problem
+            logger.error(f"Stream connection error: {e}")
+            yield {
+                "type": "error",
+                "error_type": "connection",
+                "error": str(e),
+                "user_message": "无法连接到 API 服务,请检查网络连接",
+                "accumulated": accumulated_content,
+                "usage": None,
+            }
+
+        except Exception as e:
+            # Other errors - check whether this wraps a rate-limit error
+            error_msg = str(e)
+            logger.error(f"Stream error: {e}")
+
+            # Check for a wrapped rate-limit error (e.g. ServiceUnavailableError wrapping RateLimitError)
+            is_rate_limit = any(keyword in error_msg.lower() for keyword in [
+                "ratelimiterror", "rate limit", "429", "resource_exhausted",
+                "quota exceeded", "too many requests"
+            ])
+
+            if is_rate_limit:
+                # Handle it as a rate-limit error
+                import re
+                # Check whether the quota is exhausted
+                if any(keyword in error_msg.lower() for keyword in ["quota", "exceeded", "billing"]):
+                    error_type = "quota_exceeded"
+                    user_message = "API 配额已用尽,请检查账户余额或升级计划"
+                else:
+                    error_type = "rate_limit"
+                    retry_match = re.search(r"retry\s*(?:in|after)\s*(\d+(?:\.\d+)?)\s*s", error_msg, re.IGNORECASE)
+                    retry_seconds = float(retry_match.group(1)) if retry_match else 60
+                    user_message = f"API 调用频率超限,建议等待 {int(retry_seconds)} 秒后重试"
+            else:
+                error_type = "unknown"
+                user_message = "LLM 调用发生错误,请重试"
+
+            output_tokens_estimate = estimate_tokens(accumulated_content) if accumulated_content else 0
+            yield {
+                "type": "error",
+                "error_type": error_type,
+                "error": error_msg,
+                "user_message": user_message,
                 "accumulated": accumulated_content,
                 "usage": {
                     "prompt_tokens": input_tokens_estimate,
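The retry-delay hint above is pulled out of the provider's error message with a single regular expression. A small standalone check of that pattern against a made-up error message (the helper name retry_seconds_from is illustrative):

import re

def retry_seconds_from(error_msg: str, default: float = 60.0) -> float:
    # Same pattern as the adapter: matches e.g. "retry in 30s" or "retry after 1.5 s".
    m = re.search(r"retry\s*(?:in|after)\s*(\d+(?:\.\d+)?)\s*s", error_msg, re.IGNORECASE)
    return float(m.group(1)) if m else default

print(retry_seconds_from("RateLimitError: please retry in 30s"))  # 30.0
print(retry_seconds_from("quota exceeded"))                       # 60.0 (fallback)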
@@ -739,6 +739,20 @@ class CodeIndexer:
         self._needs_rebuild = False
         self._rebuild_reason = ""

+    @staticmethod
+    def _read_file_sync(file_path: str) -> str:
+        """
+        Synchronously read a file's content (wrapped with asyncio.to_thread)
+
+        Args:
+            file_path: path of the file
+
+        Returns:
+            The file content
+        """
+        with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
+            return f.read()
+
     async def initialize(self, force_rebuild: bool = False) -> Tuple[bool, str]:
         """
         Initialize the indexer and detect whether the index needs rebuilding

@@ -916,8 +930,10 @@ class CodeIndexer:
             try:
                 relative_path = os.path.relpath(file_path, directory)

-                with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
-                    content = f.read()
+                # Read the file asynchronously so the event loop is not blocked
+                content = await asyncio.to_thread(
+                    self._read_file_sync, file_path
+                )

                 if not content.strip():
                     progress.processed_files += 1

@@ -932,8 +948,8 @@ class CodeIndexer:
                 if len(content) > 500000:
                     content = content[:500000]

-                # Split into chunks
-                chunks = self.splitter.split_file(content, relative_path)
+                # Split asynchronously so Tree-sitter parsing does not block the event loop
+                chunks = await self.splitter.split_file_async(content, relative_path)

                 # Add file_hash to every chunk
                 for chunk in chunks:

@@ -1018,8 +1034,10 @@ class CodeIndexer:
         for relative_path in files_to_check:
             file_path = current_file_map[relative_path]
             try:
-                with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
-                    content = f.read()
+                # Read the file asynchronously so the event loop is not blocked
+                content = await asyncio.to_thread(
+                    self._read_file_sync, file_path
+                )
                 current_hash = hashlib.md5(content.encode()).hexdigest()
                 if current_hash != indexed_file_hashes.get(relative_path):
                     files_to_update.add(relative_path)

@@ -1055,8 +1073,10 @@ class CodeIndexer:
             is_update = relative_path in files_to_update

             try:
-                with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
-                    content = f.read()
+                # Read the file asynchronously so the event loop is not blocked
+                content = await asyncio.to_thread(
+                    self._read_file_sync, file_path
+                )

                 if not content.strip():
                     progress.processed_files += 1

@@ -1075,8 +1095,8 @@ class CodeIndexer:
                 if len(content) > 500000:
                     content = content[:500000]

-                # Split into chunks
-                chunks = self.splitter.split_file(content, relative_path)
+                # Split asynchronously so Tree-sitter parsing does not block the event loop
+                chunks = await self.splitter.split_file_async(content, relative_path)

                 # Add file_hash to every chunk
                 for chunk in chunks:
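The incremental re-index path above decides whether a file changed by comparing an MD5 of its current content against the hash stored when it was last indexed. A minimal sketch of that check, assuming a plain dict of previously indexed hashes (names are illustrative):

import hashlib

def needs_reindex(content: str, relative_path: str, indexed_file_hashes: dict) -> bool:
    # Hash the current content and compare it with the hash recorded at index time.
    current_hash = hashlib.md5(content.encode()).hexdigest()
    return current_hash != indexed_file_hashes.get(relative_path)

indexed = {"src/app.py": hashlib.md5(b"print('hi')").hexdigest()}
print(needs_reindex("print('hi')", "src/app.py", indexed))   # False - unchanged
print(needs_reindex("print('bye')", "src/app.py", indexed))  # True  - content changed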
@@ -4,6 +4,7 @@
 """

 import re
+import asyncio
 import hashlib
 import logging
 from typing import List, Dict, Any, Optional, Tuple, Set

@@ -154,7 +155,7 @@ class TreeSitterParser:
         ".c": "c",
         ".h": "c",
         ".hpp": "cpp",
-        ".cs": "c_sharp",
+        ".cs": "csharp",
         ".php": "php",
         ".rb": "ruby",
         ".kt": "kotlin",

@@ -197,7 +198,7 @@ class TreeSitterParser:
     # Languages supported by tree-sitter-languages
     SUPPORTED_LANGUAGES = {
         "python", "javascript", "typescript", "tsx", "java", "go", "rust",
-        "c", "cpp", "c_sharp", "php", "ruby", "kotlin", "swift", "bash",
+        "c", "cpp", "csharp", "php", "ruby", "kotlin", "swift", "bash",
         "json", "yaml", "html", "css", "sql", "markdown",
     }

@@ -230,7 +231,7 @@ class TreeSitterParser:
         return False

     def parse(self, code: str, language: str) -> Optional[Any]:
-        """Parse code and return the AST"""
+        """Parse code and return the AST (synchronous method)"""
         if not self._ensure_initialized(language):
             return None

@@ -245,6 +246,15 @@ class TreeSitterParser:
             logger.warning(f"Failed to parse code: {e}")
             return None

+    async def parse_async(self, code: str, language: str) -> Optional[Any]:
+        """
+        Asynchronously parse code and return the AST
+
+        Runs the CPU-bound Tree-sitter parse in a thread pool
+        so the event loop is not blocked
+        """
+        return await asyncio.to_thread(self.parse, code, language)
+
     def extract_definitions(self, tree: Any, code: str, language: str) -> List[Dict[str, Any]]:
         """Extract definitions from the AST"""
         if tree is None:

@@ -452,6 +462,28 @@ class CodeSplitter:

         return chunks

+    async def split_file_async(
+        self,
+        content: str,
+        file_path: str,
+        language: Optional[str] = None
+    ) -> List[CodeChunk]:
+        """
+        Asynchronously split a single file
+
+        Runs the CPU-bound chunking (including Tree-sitter parsing) in a
+        thread pool so the event loop is not blocked.
+
+        Args:
+            content: file content
+            file_path: file path
+            language: programming language (optional)
+
+        Returns:
+            List of code chunks
+        """
+        return await asyncio.to_thread(self.split_file, content, file_path, language)
+
     def _split_by_ast(
         self,
         content: str,
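The ".cs" rename matters because the extension map and SUPPORTED_LANGUAGES must agree on the same identifier; if they drift apart, files of that language silently fall back to non-AST chunking. A minimal consistency check in that spirit (the two constants below are trimmed stand-ins for illustration, not the module's full tables):

EXTENSION_TO_LANGUAGE = {".py": "python", ".ts": "typescript", ".cs": "csharp"}
SUPPORTED_LANGUAGES = {"python", "typescript", "csharp"}

# Every language the extension map can produce should be parseable.
unsupported = {lang for lang in EXTENSION_TO_LANGUAGE.values() if lang not in SUPPORTED_LANGUAGES}
assert not unsupported, f"extension map points at unsupported languages: {unsupported}"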
@@ -1,355 +0,0 @@
-"""
-Agent integration tests
-Tests the complete audit flow
-"""
-
-import pytest
-import asyncio
-import os
-from unittest.mock import MagicMock, AsyncMock, patch
-from datetime import datetime
-
-from app.services.agent.graph.runner import AgentRunner, LLMService
-from app.services.agent.graph.audit_graph import AuditState, create_audit_graph
-from app.services.agent.graph.nodes import ReconNode, AnalysisNode, VerificationNode, ReportNode
-from app.services.agent.event_manager import EventManager, AgentEventEmitter
-
-
-class TestLLMService:
-    """LLM service tests"""
-
-    @pytest.mark.asyncio
-    async def test_llm_service_initialization(self):
-        """Test LLM service initialization"""
-        with patch("app.core.config.settings") as mock_settings:
-            mock_settings.LLM_MODEL = "gpt-4o-mini"
-            mock_settings.LLM_API_KEY = "test-key"
-
-            service = LLMService()
-
-            assert service.model == "gpt-4o-mini"
-
-
-class TestEventManager:
-    """Event manager tests"""
-
-    def test_event_manager_initialization(self):
-        """Test event manager initialization"""
-        manager = EventManager()
-
-        assert manager._event_queues == {}
-        assert manager._event_callbacks == {}
-
-    @pytest.mark.asyncio
-    async def test_event_emitter(self):
-        """Test the event emitter"""
-        manager = EventManager()
-        emitter = AgentEventEmitter("test-task-id", manager)
-
-        await emitter.emit_info("Test message")
-
-        assert emitter._sequence == 1
-
-    @pytest.mark.asyncio
-    async def test_event_emitter_phase_tracking(self):
-        """Test event emitter phase tracking"""
-        manager = EventManager()
-        emitter = AgentEventEmitter("test-task-id", manager)
-
-        await emitter.emit_phase_start("recon", "开始信息收集")
-
-        assert emitter._current_phase == "recon"
-
-    @pytest.mark.asyncio
-    async def test_event_emitter_task_complete(self):
-        """Test the task-complete event"""
-        manager = EventManager()
-        emitter = AgentEventEmitter("test-task-id", manager)
-
-        await emitter.emit_task_complete(findings_count=5, duration_ms=1000)
-
-        assert emitter._sequence == 1
-
-
-class TestAuditGraph:
-    """Audit graph tests"""
-
-    def test_create_audit_graph(self, mock_event_emitter):
-        """Test creating the audit graph"""
-        # Create mock nodes
-        recon_node = MagicMock()
-        analysis_node = MagicMock()
-        verification_node = MagicMock()
-        report_node = MagicMock()
-
-        graph = create_audit_graph(
-            recon_node=recon_node,
-            analysis_node=analysis_node,
-            verification_node=verification_node,
-            report_node=report_node,
-        )
-
-        assert graph is not None
-
-
-class TestReconNode:
-    """Recon node tests"""
-
-    @pytest.fixture
-    def recon_node_with_mock_agent(self, mock_event_emitter):
-        """Create a Recon node with a mock agent"""
-        mock_agent = MagicMock()
-        mock_agent.run = AsyncMock(return_value=MagicMock(
-            success=True,
-            data={
-                "tech_stack": {"languages": ["Python"]},
-                "entry_points": [{"path": "src/app.py", "type": "api"}],
-                "high_risk_areas": ["src/sql_vuln.py"],
-                "dependencies": {},
-                "initial_findings": [],
-            }
-        ))
-
-        return ReconNode(mock_agent, mock_event_emitter)
-
-    @pytest.mark.asyncio
-    async def test_recon_node_success(self, recon_node_with_mock_agent):
-        """Test successful Recon node execution"""
-        state = {
-            "project_info": {"name": "Test"},
-            "config": {},
-        }
-
-        result = await recon_node_with_mock_agent(state)
-
-        assert "tech_stack" in result
-        assert "entry_points" in result
-        assert result["current_phase"] == "recon_complete"
-
-    @pytest.mark.asyncio
-    async def test_recon_node_failure(self, mock_event_emitter):
-        """Test Recon node failure handling"""
-        mock_agent = MagicMock()
-        mock_agent.run = AsyncMock(return_value=MagicMock(
-            success=False,
-            error="Test error",
-            data=None,
-        ))
-
-        node = ReconNode(mock_agent, mock_event_emitter)
-
-        result = await node({
-            "project_info": {},
-            "config": {},
-        })
-
-        assert "error" in result
-        assert result["current_phase"] == "error"
-
-
-class TestAnalysisNode:
-    """Analysis node tests"""
-
-    @pytest.fixture
-    def analysis_node_with_mock_agent(self, mock_event_emitter):
-        """Create an Analysis node with a mock agent"""
-        mock_agent = MagicMock()
-        mock_agent.run = AsyncMock(return_value=MagicMock(
-            success=True,
-            data={
-                "findings": [
-                    {
-                        "id": "finding-1",
-                        "title": "SQL Injection",
-                        "severity": "high",
-                        "vulnerability_type": "sql_injection",
-                        "file_path": "src/sql_vuln.py",
-                        "line_start": 10,
-                        "description": "SQL injection vulnerability",
-                    }
-                ],
-                "should_continue": False,
-            }
-        ))
-
-        return AnalysisNode(mock_agent, mock_event_emitter)
-
-    @pytest.mark.asyncio
-    async def test_analysis_node_success(self, analysis_node_with_mock_agent):
-        """Test successful Analysis node execution"""
-        state = {
-            "project_info": {"name": "Test"},
-            "tech_stack": {"languages": ["Python"]},
-            "entry_points": [],
-            "high_risk_areas": ["src/sql_vuln.py"],
-            "config": {},
-            "iteration": 0,
-            "findings": [],
-        }
-
-        result = await analysis_node_with_mock_agent(state)
-
-        assert "findings" in result
-        assert len(result["findings"]) > 0
-        assert result["iteration"] == 1
-
-
-class TestIntegrationFlow:
-    """End-to-end flow integration tests"""
-
-    @pytest.mark.asyncio
-    async def test_full_audit_flow_mock(self, temp_project_dir, mock_db_session, mock_task):
-        """Test the full audit flow (with mocks)"""
-        # This test verifies that the whole pipeline is wired together
-
-        # Create the event manager
-        event_manager = EventManager()
-        emitter = AgentEventEmitter(mock_task.id, event_manager)
-
-        # Mock the LLM service
-        mock_llm = MagicMock()
-        mock_llm.chat_completion_raw = AsyncMock(return_value={
-            "content": "Analysis complete",
-            "usage": {"total_tokens": 100},
-        })
-
-        # Verify event emission
-        await emitter.emit_phase_start("init", "初始化")
-        await emitter.emit_info("测试消息")
-        await emitter.emit_phase_complete("init", "初始化完成")
-
-        assert emitter._sequence == 3
-
-    @pytest.mark.asyncio
-    async def test_audit_state_typing(self):
-        """Test the audit state type definition"""
-        state: AuditState = {
-            "project_root": "/tmp/test",
-            "project_info": {"name": "Test"},
-            "config": {},
-            "task_id": "test-id",
-            "tech_stack": {},
-            "entry_points": [],
-            "high_risk_areas": [],
-            "dependencies": {},
-            "findings": [],
-            "verified_findings": [],
-            "false_positives": [],
-            "current_phase": "start",
-            "iteration": 0,
-            "max_iterations": 50,
-            "should_continue_analysis": False,
-            "messages": [],
-            "events": [],
-            "summary": None,
-            "security_score": None,
-            "error": None,
-        }
-
-        assert state["current_phase"] == "start"
-        assert state["max_iterations"] == 50
-
-
-class TestToolIntegration:
-    """Tool integration tests"""
-
-    @pytest.mark.asyncio
-    async def test_tools_work_together(self, temp_project_dir):
-        """Test that the tools work together"""
-        from app.services.agent.tools import (
-            FileReadTool, FileSearchTool, ListFilesTool, PatternMatchTool,
-        )
-
-        # 1. List files
-        list_tool = ListFilesTool(temp_project_dir)
-        list_result = await list_tool.execute(directory="src", recursive=False)
-        assert list_result.success is True
-
-        # 2. Search key code
-        search_tool = FileSearchTool(temp_project_dir)
-        search_result = await search_tool.execute(keyword="execute")
-        assert search_result.success is True
-
-        # 3. Read file content
-        read_tool = FileReadTool(temp_project_dir)
-        read_result = await read_tool.execute(file_path="src/sql_vuln.py")
-        assert read_result.success is True
-
-        # 4. Pattern matching
-        pattern_tool = PatternMatchTool(temp_project_dir)
-        pattern_result = await pattern_tool.execute(
-            code=read_result.data,
-            file_path="src/sql_vuln.py",
-            language="python"
-        )
-        assert pattern_result.success is True
-
-
-class TestErrorHandling:
-    """Error handling tests"""
-
-    @pytest.mark.asyncio
-    async def test_tool_error_handling(self, temp_project_dir):
-        """Test tool error handling"""
-        from app.services.agent.tools import FileReadTool
-
-        tool = FileReadTool(temp_project_dir)
-
-        # Try to read a file that does not exist
-        result = await tool.execute(file_path="nonexistent/file.py")
-
-        assert result.success is False
-        assert result.error is not None
-
-    @pytest.mark.asyncio
-    async def test_agent_graceful_degradation(self, mock_event_emitter):
-        """Test that the Agent degrades gracefully"""
-        # Create an Agent that will fail
-        mock_agent = MagicMock()
-        mock_agent.run = AsyncMock(side_effect=Exception("Simulated error"))
-
-        node = ReconNode(mock_agent, mock_event_emitter)
-
-        result = await node({
-            "project_info": {},
-            "config": {},
-        })
-
-        # It should return an error state rather than crash
-        assert "error" in result
-        assert result["current_phase"] == "error"
-
-
-class TestPerformance:
-    """Performance tests"""
-
-    @pytest.mark.asyncio
-    async def test_tool_response_time(self, temp_project_dir):
-        """Test tool response time"""
-        from app.services.agent.tools import ListFilesTool
-        import time
-
-        tool = ListFilesTool(temp_project_dir)
-
-        start = time.time()
-        await tool.execute(directory=".", recursive=True)
-        duration = time.time() - start
-
-        # The tool should respond within a reasonable time
-        assert duration < 5.0  # within 5 seconds
-
-    @pytest.mark.asyncio
-    async def test_multiple_tool_calls(self, temp_project_dir):
-        """Test repeated tool calls"""
-        from app.services.agent.tools import FileSearchTool
-
-        tool = FileSearchTool(temp_project_dir)
-
-        # Run several calls
-        for _ in range(5):
-            result = await tool.execute(keyword="def")
-            assert result.success is True
-
-        # Verify the call count
-        assert tool._call_count == 5
@@ -9,7 +9,7 @@ import {
 } from "@/components/ui/select";
 import { GitBranch, Zap, Info } from "lucide-react";
 import type { Project, CreateAuditTaskForm } from "@/shared/types";
-import { isRepositoryProject, isZipProject } from "@/shared/utils/projectUtils";
+import { isRepositoryProject, isZipProject, getRepositoryPlatformLabel } from "@/shared/utils/projectUtils";
 import ZipFileSection from "./ZipFileSection";
 import type { ZipFileMeta } from "@/shared/utils/zipStorage";

@@ -138,7 +138,7 @@ function ProjectInfoCard({ project }: { project: Project }) {
       {isRepo && (
         <>
           <p>
-            仓库平台:{project.repository_type?.toUpperCase() || "OTHER"}
+            仓库平台:{getRepositoryPlatformLabel(project.repository_type)}
           </p>
           <p>默认分支:{project.default_branch}</p>
         </>
@@ -34,13 +34,13 @@ import { api } from "@/shared/config/database";
 import { runRepositoryAudit, scanStoredZipFile } from "@/features/projects/services";
 import type { Project, AuditTask, CreateProjectForm } from "@/shared/types";
 import { hasZipFile } from "@/shared/utils/zipStorage";
-import { isRepositoryProject, getSourceTypeLabel } from "@/shared/utils/projectUtils";
+import { isRepositoryProject, getSourceTypeLabel, getRepositoryPlatformLabel } from "@/shared/utils/projectUtils";
 import { toast } from "sonner";
 import CreateTaskDialog from "@/components/audit/CreateTaskDialog";
 import FileSelectionDialog from "@/components/audit/FileSelectionDialog";
 import TerminalProgressDialog from "@/components/audit/TerminalProgressDialog";
 import { Dialog, DialogContent, DialogHeader, DialogTitle, DialogFooter } from "@/components/ui/dialog";
-import { SUPPORTED_LANGUAGES } from "@/shared/constants";
+import { SUPPORTED_LANGUAGES, REPOSITORY_PLATFORMS } from "@/shared/constants";

 export default function ProjectDetail() {
   const { id } = useParams<{ id: string }>();

@@ -475,8 +475,7 @@ export default function ProjectDetail() {
                 <div className="flex items-center justify-between">
                   <span className="text-sm text-muted-foreground uppercase">仓库平台</span>
                   <Badge className="cyber-badge-muted">
-                    {project.repository_type === 'github' ? 'GitHub' :
-                     project.repository_type === 'gitlab' ? 'GitLab' : '其他'}
+                    {getRepositoryPlatformLabel(project.repository_type)}
                   </Badge>
                 </div>

@@ -529,8 +528,7 @@ export default function ProjectDetail() {
                     className="flex items-center justify-between p-3 bg-muted/50 rounded-lg hover:bg-muted transition-all group"
                   >
                     <div className="flex items-center space-x-3">
-                      <div className={`w-8 h-8 rounded-lg flex items-center justify-center ${
-                        task.status === 'completed' ? 'bg-emerald-500/20' :
+                      <div className={`w-8 h-8 rounded-lg flex items-center justify-center ${task.status === 'completed' ? 'bg-emerald-500/20' :
                         task.status === 'running' ? 'bg-sky-500/20' :
                         task.status === 'failed' ? 'bg-rose-500/20' :
                         'bg-muted'

@@ -579,8 +577,7 @@ export default function ProjectDetail() {
               <div key={task.id} className="cyber-card p-6">
                 <div className="flex items-center justify-between mb-4 pb-4 border-b border-border">
                   <div className="flex items-center space-x-3">
-                    <div className={`w-10 h-10 rounded-lg flex items-center justify-center ${
-                      task.status === 'completed' ? 'bg-emerald-500/20' :
+                    <div className={`w-10 h-10 rounded-lg flex items-center justify-center ${task.status === 'completed' ? 'bg-emerald-500/20' :
                       task.status === 'running' ? 'bg-sky-500/20' :
                       task.status === 'failed' ? 'bg-rose-500/20' :
                       'bg-muted'

@@ -676,8 +673,7 @@ export default function ProjectDetail() {
               <div key={index} className="cyber-card p-4 hover:border-border transition-all">
                 <div className="flex items-start justify-between">
                   <div className="flex items-start space-x-3">
-                    <div className={`w-8 h-8 rounded-lg flex items-center justify-center ${
-                      issue.severity === 'critical' ? 'bg-rose-500/20 text-rose-600 dark:text-rose-400' :
+                    <div className={`w-8 h-8 rounded-lg flex items-center justify-center ${issue.severity === 'critical' ? 'bg-rose-500/20 text-rose-600 dark:text-rose-400' :
                       issue.severity === 'high' ? 'bg-orange-500/20 text-orange-600 dark:text-orange-400' :
                       issue.severity === 'medium' ? 'bg-amber-500/20 text-amber-600 dark:text-amber-400' :
                       'bg-sky-500/20 text-sky-600 dark:text-sky-400'

@@ -783,9 +779,11 @@ export default function ProjectDetail() {
                 <SelectValue />
               </SelectTrigger>
               <SelectContent className="cyber-dialog border-border">
-                <SelectItem value="github">GitHub</SelectItem>
-                <SelectItem value="gitlab">GitLab</SelectItem>
-                <SelectItem value="other">其他</SelectItem>
+                {REPOSITORY_PLATFORMS.map((platform) => (
+                  <SelectItem key={platform.value} value={platform.value}>
+                    {platform.label}
+                  </SelectItem>
+                ))}
               </SelectContent>
             </Select>
           </div>
@@ -45,7 +45,7 @@ import { Link } from "react-router-dom";
 import { toast } from "sonner";
 import CreateTaskDialog from "@/components/audit/CreateTaskDialog";
 import TerminalProgressDialog from "@/components/audit/TerminalProgressDialog";
-import { SUPPORTED_LANGUAGES } from "@/shared/constants";
+import { SUPPORTED_LANGUAGES, REPOSITORY_PLATFORMS } from "@/shared/constants";

 export default function Projects() {
   const [projects, setProjects] = useState<Project[]>([]);

@@ -487,10 +487,11 @@ export default function Projects() {
                 <SelectValue />
               </SelectTrigger>
               <SelectContent className="cyber-dialog border-border">
-                <SelectItem value="github">GITHUB</SelectItem>
-                <SelectItem value="gitlab">GITLAB</SelectItem>
-                <SelectItem value="gitea">GITEA</SelectItem>
-                <SelectItem value="other">OTHER</SelectItem>
+                {REPOSITORY_PLATFORMS.map((platform) => (
+                  <SelectItem key={platform.value} value={platform.value}>
+                    {platform.label}
+                  </SelectItem>
+                ))}
               </SelectContent>
             </Select>
           </div>

@@ -1046,10 +1047,11 @@ export default function Projects() {
                 <SelectValue />
               </SelectTrigger>
               <SelectContent className="cyber-dialog border-border">
-                <SelectItem value="github">GITHUB</SelectItem>
-                <SelectItem value="gitlab">GITLAB</SelectItem>
-                <SelectItem value="gitea">GITEA</SelectItem>
-                <SelectItem value="other">OTHER</SelectItem>
+                {REPOSITORY_PLATFORMS.map((platform) => (
+                  <SelectItem key={platform.value} value={platform.value}>
+                    {platform.label}
+                  </SelectItem>
+                ))}
               </SelectContent>
             </Select>
           </div>
@@ -36,7 +36,7 @@ import type { AuditTask, AuditIssue } from "@/shared/types";
 import { toast } from "sonner";
 import ExportReportDialog from "@/components/reports/ExportReportDialog";
 import { calculateTaskProgress } from "@/shared/utils/utils";
-import { isRepositoryProject, getSourceTypeLabel } from "@/shared/utils/projectUtils";
+import { isRepositoryProject, getSourceTypeLabel, getRepositoryPlatformLabel } from "@/shared/utils/projectUtils";

 // AI explanation parser
 function parseAIExplanation(aiExplanation: string) {

@@ -86,8 +86,7 @@ function IssuesList({ issues }: { issues: AuditIssue[] }) {
         <div key={issue.id || index} className="cyber-card p-4 hover:border-border transition-all group">
           <div className="flex items-start justify-between mb-3">
             <div className="flex items-start space-x-3">
-              <div className={`w-10 h-10 rounded-lg flex items-center justify-center ${
-                issue.severity === 'critical' ? 'bg-rose-500/20 text-rose-400' :
+              <div className={`w-10 h-10 rounded-lg flex items-center justify-center ${issue.severity === 'critical' ? 'bg-rose-500/20 text-rose-400' :
                 issue.severity === 'high' ? 'bg-orange-500/20 text-orange-400' :
                 issue.severity === 'medium' ? 'bg-amber-500/20 text-amber-400' :
                 'bg-sky-500/20 text-sky-400'

@@ -702,7 +701,7 @@ export default function TaskDetail() {
               {isRepositoryProject(task.project) && (
                 <div>
                   <p className="text-xs font-bold text-muted-foreground uppercase mb-1">仓库平台</p>
-                  <p className="text-base font-bold text-foreground">{task.project.repository_type?.toUpperCase() || 'OTHER'}</p>
+                  <p className="text-base font-bold text-foreground">{getRepositoryPlatformLabel(task.project.repository_type)}</p>
                 </div>
               )}
               {task.project.programming_languages && (
@@ -62,13 +62,6 @@ export const PROJECT_SOURCE_TYPES = {
   ZIP: 'zip',
 } as const;

-// Repository platform types
-export const REPOSITORY_TYPES = {
-  GITHUB: 'github',
-  GITLAB: 'gitlab',
-  OTHER: 'other',
-} as const;
-
 // Analysis depth
 export const ANALYSIS_DEPTH = {
   BASIC: 'basic',
@@ -22,17 +22,23 @@ export const PROJECT_SOURCE_TYPES: Array<{
   }
 ];

+// Display names for repository platforms
+export const REPOSITORY_PLATFORM_LABELS: Record<RepositoryPlatform, string> = {
+  github: 'GitHub',
+  gitlab: 'GitLab',
+  gitea: 'Gitea',
+  other: '其他',
+};
+
 // Repository platform options
 export const REPOSITORY_PLATFORMS: Array<{
   value: RepositoryPlatform;
   label: string;
   icon?: string;
-}> = [
-  { value: 'github', label: 'GitHub' },
-  { value: 'gitlab', label: 'GitLab' },
-  { value: 'gitea', label: 'Gitea' },
-  { value: 'other', label: '其他' }
-];
+}> = Object.entries(REPOSITORY_PLATFORM_LABELS).map(([value, label]) => ({
+  value: value as RepositoryPlatform,
+  label
+}));

 // Color configuration for project source types
 export const SOURCE_TYPE_COLORS: Record<ProjectSourceType, {
@@ -4,6 +4,7 @@
  */

 import type { Project, ProjectSourceType } from '@/shared/types';
+import { REPOSITORY_PLATFORM_LABELS } from '@/shared/constants/projectTypes';

 /**
  * Whether a project is a repository-type project

@@ -45,13 +46,7 @@ export function getSourceTypeBadge(sourceType: ProjectSourceType): string {
  * Get the display label for a repository platform
  */
 export function getRepositoryPlatformLabel(platform?: string): string {
-  const labels: Record<string, string> = {
-    github: 'GitHub',
-    gitlab: 'GitLab',
-    gitea: 'Gitea',
-    other: '其他'
-  };
-  return labels[platform || 'other'] || '其他';
+  return REPOSITORY_PLATFORM_LABELS[platform as keyof typeof REPOSITORY_PLATFORM_LABELS] || REPOSITORY_PLATFORM_LABELS.other;
 }

 /**