feat(agent): enhance agent functionality with LLM-driven decision-making and event handling

- Introduce LLM-driven decision-making across various agents, allowing for dynamic adjustments based on real-time analysis.
- Implement new event types for LLM thinking, decisions, actions, and observations to enrich the event streaming experience.
- Update agent task responses to include additional metrics for better tracking of task progress and outcomes.
- Refactor UI components to highlight LLM-related events and improve user interaction during audits.
- Enhance API endpoints to support new event structures and improve overall error handling.
lintsinghua 2025-12-11 21:14:32 +08:00
parent 58c918f557
commit 8938a8a3c9
12 changed files with 2283 additions and 1691 deletions

View File

@@ -75,18 +75,46 @@ class AgentTaskCreate(BaseModel):
class AgentTaskResponse(BaseModel):
"""Agent task response - contains every field the frontend needs"""
id: str
project_id: str
name: Optional[str]
description: Optional[str]
task_type: str = "agent_audit"
status: str
current_phase: Optional[str]
current_step: Optional[str] = None
# Progress statistics
total_files: int = 0
indexed_files: int = 0
analyzed_files: int = 0
total_chunks: int = 0
# Agent statistics
total_iterations: int = 0
tool_calls_count: int = 0
tokens_used: int = 0
# Finding statistics (both naming schemes are supported)
findings_count: int = 0
total_findings: int = 0  # compatibility alias
verified_count: int = 0
verified_findings: int = 0  # compatibility alias
false_positive_count: int = 0
# Severity statistics
critical_count: int = 0
high_count: int = 0
medium_count: int = 0
low_count: int = 0
# Scores
quality_score: float = 0.0
security_score: Optional[float] = None
# Progress percentage
progress_percentage: float = 0.0
# Timestamps
created_at: datetime
@@ -94,7 +122,11 @@ class AgentTaskResponse(BaseModel):
completed_at: Optional[datetime] = None
# Configuration
config: Optional[dict] = None
audit_scope: Optional[dict] = None
target_vulnerabilities: Optional[List[str]] = None
verification_level: Optional[str] = None
exclude_patterns: Optional[List[str]] = None
target_files: Optional[List[str]] = None
# Error info
error_message: Optional[str] = None
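For context, a minimal sketch of how the compatibility aliases behave; the sample values are invented, and the import mirrors the module's existing `datetime` usage:

```python
from datetime import datetime, timezone

# Hypothetical instantiation: the legacy names mirror the new counters
resp = AgentTaskResponse(
    id="task-1",
    project_id="proj-1",
    name="demo audit",
    description=None,
    status="running",
    current_phase="analysis",
    created_at=datetime.now(timezone.utc),
    findings_count=3,
    total_findings=3,      # legacy name, same counter
    verified_count=1,
    verified_findings=1,   # legacy name, same counter
)
assert resp.total_findings == resp.findings_count
```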
@@ -177,6 +209,12 @@ async def _execute_agent_task(task_id: str, project_root: str):
logger.error(f"Task {task_id} not found")
return
# Mark the task as running
task.status = AgentTaskStatus.RUNNING
task.started_at = datetime.now(timezone.utc)
await db.commit()
logger.info(f"Task {task_id} started")
# Create the runner
runner = AgentRunner(db, task, project_root)
_running_tasks[task_id] = runner
@@ -184,22 +222,37 @@ async def _execute_agent_task(task_id: str, project_root: str):
# Execute
result = await runner.run()
logger.info(f"Task {task_id} completed: {result.get('success', False)}")
# Update the task status
await db.refresh(task)
if result.get('success', True):  # default to success unless explicitly failed
task.status = AgentTaskStatus.COMPLETED
task.completed_at = datetime.now(timezone.utc)
else:
task.status = AgentTaskStatus.FAILED
task.error_message = result.get('error', 'Unknown error')
task.completed_at = datetime.now(timezone.utc)
await db.commit()
logger.info(f"Task {task_id} completed with status: {task.status}")
except Exception as e:
logger.error(f"Task {task_id} failed: {e}", exc_info=True)
# Update the task status; the DB write itself may fail, so guard it
try:
task = await db.get(AgentTask, task_id)
if task:
task.status = AgentTaskStatus.FAILED
task.error_message = str(e)[:1000]  # cap the error message length
task.completed_at = datetime.now(timezone.utc)
await db.commit()
except Exception as db_error:
logger.error(f"Failed to update task status: {db_error}")
finally:
# Cleanup
_running_tasks.pop(task_id, None)
logger.debug(f"Task {task_id} cleaned up")
# ============ API Endpoints ============
@@ -315,7 +368,61 @@ async def get_agent_task(
if not project or project.owner_id != current_user.id:
raise HTTPException(status_code=403, detail="You do not have access to this task")
# Build the response, making sure every field is present
try:
# Compute the progress percentage
progress = 0.0
if hasattr(task, 'progress_percentage'):
progress = task.progress_percentage
elif task.status == AgentTaskStatus.COMPLETED:
progress = 100.0
elif task.status in [AgentTaskStatus.FAILED, AgentTaskStatus.CANCELLED]:
progress = 0.0
# Assemble the response data by hand
response_data = {
"id": task.id,
"project_id": task.project_id,
"name": task.name,
"description": task.description,
"task_type": task.task_type or "agent_audit",
"status": task.status,
"current_phase": task.current_phase,
"current_step": task.current_step,
"total_files": task.total_files or 0,
"indexed_files": task.indexed_files or 0,
"analyzed_files": task.analyzed_files or 0,
"total_chunks": task.total_chunks or 0,
"total_iterations": task.total_iterations or 0,
"tool_calls_count": task.tool_calls_count or 0,
"tokens_used": task.tokens_used or 0,
"findings_count": task.findings_count or 0,
"total_findings": task.findings_count or 0,  # compatibility alias
"verified_count": task.verified_count or 0,
"verified_findings": task.verified_count or 0,  # compatibility alias
"false_positive_count": task.false_positive_count or 0,
"critical_count": task.critical_count or 0,
"high_count": task.high_count or 0,
"medium_count": task.medium_count or 0,
"low_count": task.low_count or 0,
"quality_score": float(task.quality_score or 0.0),
"security_score": float(task.security_score) if task.security_score is not None else None,
"progress_percentage": progress,
"created_at": task.created_at,
"started_at": task.started_at,
"completed_at": task.completed_at,
"error_message": task.error_message,
"audit_scope": task.audit_scope,
"target_vulnerabilities": task.target_vulnerabilities,
"verification_level": task.verification_level,
"exclude_patterns": task.exclude_patterns,
"target_files": task.target_files,
}
return AgentTaskResponse(**response_data)
except Exception as e:
logger.error(f"Error serializing task {task_id}: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Failed to serialize task data: {str(e)}")
@router.post("/{task_id}/cancel")
@@ -396,10 +503,14 @@ async def stream_agent_events(
idle_time = 0
for event in events:
last_sequence = event.sequence
# event_type is already a string, so no .value access is needed
event_type_str = str(event.event_type)
phase_str = str(event.phase) if event.phase else None
data = {
"id": event.id,
"type": event_type_str,
"phase": phase_str,
"message": event.message,
"sequence": event.sequence,
"timestamp": event.created_at.isoformat() if event.created_at else None,
@@ -411,9 +522,12 @@ async def stream_agent_events(
idle_time += poll_interval
# Check whether the task has finished
if task_status:
# task_status may be a string or an enum; normalize to a string
status_str = str(task_status)
if status_str in ["completed", "failed", "cancelled"]:
yield f"data: {json.dumps({'type': 'task_end', 'status': status_str})}\n\n"
break
# Check for idle timeout
if idle_time >= max_idle:
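As a consumer's view of this stream, a minimal client sketch; the base URL and route are assumptions for illustration, not defined in this diff:

```python
# Minimal SSE consumer sketch (hypothetical route; adjust to the real router prefix)
import json
import requests

def follow_events(task_id: str, base_url: str = "http://localhost:8000"):
    url = f"{base_url}/api/agent-tasks/{task_id}/events"  # assumed path
    with requests.get(url, stream=True, timeout=None) as resp:
        for raw in resp.iter_lines(decode_unicode=True):
            if not raw or not raw.startswith("data: "):
                continue  # skip blank keep-alives and non-data SSE fields
            event = json.loads(raw[len("data: "):])
            print(event.get("type"), event.get("message"))
            if event.get("type") == "task_end":
                break
```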
@@ -512,8 +626,8 @@ async def stream_agent_with_thinking(
for event in events:
last_sequence = event.sequence
# Get the event type as a string (event_type is already a string)
event_type = str(event.event_type)
# Filter events
if event_type in skip_types:
@@ -523,7 +637,7 @@ async def stream_agent_with_thinking(
data = {
"id": event.id,
"type": event_type,
"phase": str(event.phase) if event.phase else None,
"message": event.message,
"sequence": event.sequence,
"timestamp": event.created_at.isoformat() if event.created_at else None,
@@ -552,14 +666,16 @@ async def stream_agent_with_thinking(
idle_time += poll_interval
# Check whether the task has finished
if task_status:
status_str = str(task_status)
if status_str in ["completed", "failed", "cancelled"]:
end_data = {
"type": "task_end",
"status": status_str,
"message": f"Task {'completed' if status_str == 'completed' else 'ended'}",
}
yield f"event: task_end\ndata: {json.dumps(end_data, ensure_ascii=False)}\n\n"
break
# Send a heartbeat
last_heartbeat += poll_interval
@@ -709,8 +825,9 @@ async def get_task_summary(
verified_count = 0
for f in findings:
# severity and vulnerability_type are already strings
sev = str(f.severity)
vtype = str(f.vulnerability_type)
severity_distribution[sev] = severity_distribution.get(sev, 0) + 1
vulnerability_types[vtype] = vulnerability_types.get(vtype, 0) + 1
@@ -730,11 +847,11 @@ async def get_task_summary(
.where(AgentEvent.event_type == AgentEventType.PHASE_COMPLETE)
.distinct()
)
phases = [str(p[0]) for p in phases_result.fetchall() if p[0]]
return TaskSummaryResponse(
task_id=task_id,
status=str(task.status),  # status is already a string
security_score=task.security_score,
total_findings=len(findings),
verified_findings=verified_count,

File diff suppressed because it is too large.

View File

@@ -1,6 +1,8 @@
"""
Agent base class
Defines the basic Agent interface and shared functionality
Core principle: the LLM is the Agent's brain, and all logs should reflect the LLM's involvement
"""
from abc import ABC, abstractmethod
@@ -87,7 +89,11 @@ class AgentResult:
class BaseAgent(ABC):
"""
Agent base class
All Agents must inherit from this class and implement the core methods
Core principles:
1. The LLM is the Agent's brain and takes part in every decision
2. All logs should reflect the LLM's thought process
3. Tool calls are the outcome of LLM decisions
"""
def __init__(
def __init__(
@@ -146,6 +152,8 @@ class BaseAgent(ABC):
def is_cancelled(self) -> bool:
return self._cancelled
# ============ Core event-emission methods ============
async def emit_event(
self,
event_type: str,
@@ -161,28 +169,123 @@ class BaseAgent(ABC):
**kwargs
))
# ============ LLM thinking events ============
async def emit_thinking(self, message: str):
"""Emit an LLM thinking event"""
await self.emit_event("thinking", f"🧠 [{self.name}] {message}")
async def emit_llm_start(self, iteration: int):
"""Emit an event marking the start of an LLM thinking round"""
await self.emit_event(
"llm_start",
f"🤔 [{self.name}] LLM starting thinking round {iteration}...",
metadata={"iteration": iteration}
)
async def emit_llm_thought(self, thought: str, iteration: int):
"""Emit an LLM thought event - the core piece! Shows what the LLM is thinking"""
# Truncate overly long thoughts
display_thought = thought[:500] + "..." if len(thought) > 500 else thought
await self.emit_event(
"llm_thought",
f"💭 [{self.name}] LLM thought:\n{display_thought}",
metadata={
"thought": thought,
"iteration": iteration,
}
)
async def emit_llm_decision(self, decision: str, reason: str = ""):
"""Emit an LLM decision event - show what the LLM decided"""
await self.emit_event(
"llm_decision",
f"💡 [{self.name}] LLM decision: {decision}" + (f" (reason: {reason})" if reason else ""),
metadata={
"decision": decision,
"reason": reason,
}
)
async def emit_llm_action(self, action: str, action_input: Dict):
"""Emit an LLM action event - the action the LLM decided to take"""
import json
input_str = json.dumps(action_input, ensure_ascii=False)[:200]
await self.emit_event(
"llm_action",
f"⚡ [{self.name}] LLM action: {action}\n Parameters: {input_str}",
metadata={
"action": action,
"action_input": action_input,
}
)
async def emit_llm_observation(self, observation: str):
"""Emit an LLM observation event - what the LLM saw"""
display_obs = observation[:300] + "..." if len(observation) > 300 else observation
await self.emit_event(
"llm_observation",
f"👁️ [{self.name}] LLM observed:\n{display_obs}",
metadata={"observation": observation[:2000]}
)
async def emit_llm_complete(self, result_summary: str, tokens_used: int):
"""Emit an LLM completion event"""
await self.emit_event(
"llm_complete",
f"✅ [{self.name}] LLM finished: {result_summary} ({tokens_used} tokens used)",
metadata={
"tokens_used": tokens_used,
}
)
# ============ Tool-call events ============
async def emit_tool_call(self, tool_name: str, tool_input: Dict):
"""Emit a tool-call event - the LLM decided to invoke a tool"""
import json
input_str = json.dumps(tool_input, ensure_ascii=False)[:300]
await self.emit_event(
"tool_call",
f"🔧 [{self.name}] LLM calling tool: {tool_name}\n Input: {input_str}",
tool_name=tool_name,
tool_input=tool_input,
)
async def emit_tool_result(self, tool_name: str, result: str, duration_ms: int):
"""Emit a tool-result event"""
result_preview = result[:200] + "..." if len(result) > 200 else result
await self.emit_event(
"tool_result",
f"📤 [{self.name}] Tool {tool_name} returned ({duration_ms}ms):\n {result_preview}",
tool_name=tool_name,
tool_duration_ms=duration_ms,
)
# ============ Finding events ============
async def emit_finding(self, title: str, severity: str, vuln_type: str, file_path: str = ""):
"""Emit a vulnerability-finding event"""
severity_emoji = {
"critical": "🔴",
"high": "🟠",
"medium": "🟡",
"low": "🟢",
}.get(severity.lower(), "")
await self.emit_event(
"finding",
f"{severity_emoji} [{self.name}] Vulnerability found: [{severity.upper()}] {title}\n Type: {vuln_type}\n Location: {file_path}",
metadata={
"title": title,
"severity": severity,
"vulnerability_type": vuln_type,
"file_path": file_path,
}
)
# ============ Shared utility methods ============
async def call_tool(self, tool_name: str, **kwargs) -> Any:
"""
Invoke a tool
@@ -208,7 +311,7 @@ class BaseAgent(ABC):
result = await tool.execute(**kwargs)
duration_ms = int((time.time() - start) * 1000)
await self.emit_tool_result(tool_name, str(result.data)[:500], duration_ms)
return result
@@ -229,8 +332,9 @@ class BaseAgent(ABC):
"""
self._iteration += 1
# Emit the LLM-start event
await self.emit_llm_start(self._iteration)
try:
response = await self.llm_service.chat_completion(
messages=messages,
@@ -281,4 +385,3 @@ class BaseAgent(ABC):
"tool_calls": self._tool_calls,
"tokens_used": self._total_tokens,
}
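To see the intended call order of these emitters, a toy subclass; `EchoAgent` and its one-iteration `run` are illustrative only, not part of this commit:

```python
# Hypothetical minimal agent showing one ReAct-style round of events
class EchoAgent(BaseAgent):
    async def run(self, input_data: Dict[str, Any]) -> AgentResult:
        await self.emit_llm_start(1)
        await self.emit_llm_thought("The project is small; list files first.", 1)
        await self.emit_llm_action("list_files", {"directory": "."})
        await self.emit_llm_observation("src/, tests/, README.md")
        await self.emit_llm_complete("nothing further to do", tokens_used=0)
        return AgentResult(success=True, data={"findings": []})
```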

View File

@@ -1,12 +1,19 @@
"""
Orchestrator Agent (orchestration layer) - LLM-driven version
The LLM is the real brain and takes part in every decision:
- the LLM decides what to do next
- the LLM decides which sub-agent to dispatch
- the LLM decides when the audit is done
- the LLM adapts its strategy to intermediate results
Type: Autonomous Agent with Dynamic Planning
"""
import asyncio
import json
import logging
import re
from typing import List, Dict, Any, Optional
from dataclasses import dataclass
@@ -15,69 +22,97 @@ from .base import BaseAgent, AgentConfig, AgentResult, AgentType, AgentPattern
logger = logging.getLogger(__name__)
ORCHESTRATOR_SYSTEM_PROMPT = """You are DeepAudit's orchestration Agent, responsible for **autonomously** coordinating the entire security audit.
## Your role
You are the **brain** of the audit, not a mechanical executor. You must:
1. Think and decide autonomously
2. Adapt your strategy to what you observe
3. Decide when to dispatch which sub-agent
4. Judge when the audit is complete
## Sub-agents you can dispatch
1. **recon**: reconnaissance Agent - analyzes project structure, tech stack, and entry points
2. **analysis**: analysis Agent - deep code audit and vulnerability detection
3. **verification**: verification Agent - verifies findings and produces PoCs
## Operations you can use
### 1. Dispatch a sub-agent
```
Action: dispatch_agent
Action Input: {"agent": "recon|analysis|verification", "task": "concrete task description", "context": "task context"}
```
### 2. Summarize findings
```
Action: summarize
Action Input: {"findings": [...], "analysis": "your analysis"}
```
### 3. Finish the audit
```
Action: finish
Action Input: {"conclusion": "audit conclusion", "findings": [...], "recommendations": [...]}
```
## How to work
At every step you must:
1. **Thought**: analyze the current state and think about what to do next
- What information has been gathered so far?
- What do you still need to learn?
- Which areas deserve deeper analysis?
- Which findings need verification?
2. **Action**: choose one operation
3. **Action Input**: provide the operation's parameters
## Output format
Every step must strictly follow this format:
```
Thought: [your reasoning]
Action: [dispatch_agent|summarize|finish]
Action Input: [JSON parameters]
```
## Suggested audit strategy
- Start with the recon Agent to understand the project as a whole
- Based on the recon results, have the analysis Agent focus on high-risk areas
- When suspicious vulnerabilities surface, verify them with the verification Agent
- Keep adjusting the strategy as new findings come in; never execute mechanically
- When you judge the audit sufficiently thorough, choose finish
## Key principles
1. **You are the brain, not an executor** - think at every step
2. **Adapt dynamically** - adjust the strategy to the findings
3. **Decide proactively** - push the audit forward; don't wait
4. **Quality first** - dig into a few real vulnerabilities rather than skimming many
Now, based on the project information, begin your audit."""
@dataclass
class AgentStep:
"""One execution step"""
thought: str
action: str
action_input: Dict[str, Any]
observation: Optional[str] = None
sub_agent_result: Optional[AgentResult] = None
class OrchestratorAgent(BaseAgent):
"""
Orchestrator Agent - LLM-driven version
The LLM takes part in every decision:
1. the LLM reasons about the current state
2. the LLM decides the next operation
3. the operation is executed and the result collected
4. the LLM analyzes the result and decides the next step
5. repeat until the LLM decides to finish
"""
def __init__(
@@ -90,13 +125,16 @@ class OrchestratorAgent(BaseAgent):
config = AgentConfig(
name="Orchestrator",
agent_type=AgentType.ORCHESTRATOR,
pattern=AgentPattern.REACT,  # switched to the ReAct pattern!
max_iterations=20,
system_prompt=ORCHESTRATOR_SYSTEM_PROMPT,
)
super().__init__(config, llm_service, tools, event_emitter)
self.sub_agents = sub_agents or {}
self._conversation_history: List[Dict[str, str]] = []
self._steps: List[AgentStep] = []
self._all_findings: List[Dict] = []
def register_sub_agent(self, name: str, agent: BaseAgent):
"""Register a sub-agent"""
@@ -104,7 +142,7 @@ class OrchestratorAgent(BaseAgent):
async def run(self, input_data: Dict[str, Any]) -> AgentResult:
"""
Execute the orchestration task - the LLM takes part throughout
Args:
input_data: {
@@ -118,82 +156,136 @@
project_info = input_data.get("project_info", {})
config = input_data.get("config", {})
# Build the initial message
initial_message = self._build_initial_message(project_info, config)
# Initialize the conversation history
self._conversation_history = [
{"role": "system", "content": self.config.system_prompt},
{"role": "user", "content": initial_message},
]
self._steps = []
self._all_findings = []
final_result = None
await self.emit_thinking("🧠 Orchestrator Agent started; the LLM begins autonomous orchestration...")
try:
for iteration in range(self.config.max_iterations):
if self.is_cancelled:
break
self._iteration = iteration + 1
# 🔥 Emit the LLM-start event
await self.emit_llm_start(iteration + 1)
# 🔥 Ask the LLM to think and decide
response = await self.llm_service.chat_completion_raw(
messages=self._conversation_history,
temperature=0.1,
max_tokens=2048,
)
llm_output = response.get("content", "")
tokens_this_round = response.get("usage", {}).get("total_tokens", 0)
self._total_tokens += tokens_this_round
# Parse the LLM's decision
step = self._parse_llm_response(llm_output)
if not step:
# Malformed LLM output; ask it to answer again
await self.emit_llm_decision("Malformed output", "the LLM must retry")
self._conversation_history.append({
"role": "assistant",
"content": llm_output,
})
self._conversation_history.append({
"role": "user",
"content": "Please follow the required format: Thought + Action + Action Input",
})
continue
self._steps.append(step)
# 🔥 Emit the LLM-thought event - surface the orchestration reasoning
if step.thought:
await self.emit_llm_thought(step.thought, iteration + 1)
# Append the LLM response to the history
self._conversation_history.append({
"role": "assistant",
"content": llm_output,
})
# Execute the operation the LLM chose
if step.action == "finish":
# 🔥 The LLM decided the audit is done
await self.emit_llm_decision("Finish the audit", "the LLM judged the audit sufficiently complete")
await self.emit_llm_complete(
f"Orchestration finished with {len(self._all_findings)} findings",
self._total_tokens
)
final_result = step.action_input
break
elif step.action == "dispatch_agent":
# 🔥 The LLM decided to dispatch a sub-agent
agent_name = step.action_input.get("agent", "unknown")
task_desc = step.action_input.get("task", "")
await self.emit_llm_decision(
f"Dispatch the {agent_name} Agent",
f"task: {task_desc[:100]}"
)
await self.emit_llm_action("dispatch_agent", step.action_input)
observation = await self._dispatch_agent(step.action_input)
step.observation = observation
# 🔥 Emit the observation event
await self.emit_llm_observation(observation)
elif step.action == "summarize":
# The LLM asked for a summary
await self.emit_llm_decision("Summarize findings", "the LLM requested the current findings summary")
observation = self._summarize_findings()
step.observation = observation
await self.emit_llm_observation(observation)
else:
observation = f"Unknown action: {step.action}; available actions: dispatch_agent, summarize, finish"
await self.emit_llm_decision("Unknown action", observation)
# Append the observation to the history
self._conversation_history.append({
"role": "user",
"content": f"Observation:\n{step.observation}",
})
# Produce the final result
duration_ms = int((time.time() - start_time) * 1000)
await self.emit_event(
"info",
f"🎯 Orchestrator finished: {len(self._all_findings)} findings over {self._iteration} decision rounds"
)
return AgentResult(
success=True,
data={
"findings": self._all_findings,
"summary": final_result or self._generate_default_summary(),
"steps": [
{
"thought": s.thought,
"action": s.action,
"action_input": s.action_input,
"observation": s.observation[:500] if s.observation else None,
}
for s in self._steps
],
},
iterations=self._iteration,
tool_calls=self._tool_calls,
@@ -208,174 +300,197 @@ class OrchestratorAgent(BaseAgent):
error=str(e),
)
def _build_initial_message(
self,
project_info: Dict[str, Any],
config: Dict[str, Any],
) -> str:
"""Build the initial message"""
msg = f"""Please begin a security audit of the following project.
## Project information
- Name: {project_info.get('name', 'unknown')}
- Languages: {project_info.get('languages', [])}
- File count: {project_info.get('file_count', 0)}
- Directory structure: {json.dumps(project_info.get('structure', {}), ensure_ascii=False, indent=2)}
## User configuration
- Target vulnerabilities: {config.get('target_vulnerabilities', ['all'])}
- Verification level: {config.get('verification_level', 'sandbox')}
- Exclude patterns: {config.get('exclude_patterns', [])}
## Available sub-agents
{', '.join(self.sub_agents.keys()) if self.sub_agents else '(no sub-agents registered)'}
Begin your audit: first think about how to approach it, then decide on the first step."""
return msg
def _parse_llm_response(self, response: str) -> Optional[AgentStep]:
"""Parse the LLM response"""
# Extract Thought
thought_match = re.search(r'Thought:\s*(.*?)(?=Action:|$)', response, re.DOTALL)
thought = thought_match.group(1).strip() if thought_match else ""
# Extract Action
action_match = re.search(r'Action:\s*(\w+)', response)
if not action_match:
return None
action = action_match.group(1).strip()
# Extract Action Input
input_match = re.search(r'Action Input:\s*(.*?)(?=Thought:|Observation:|$)', response, re.DOTALL)
if not input_match:
return None
input_text = input_match.group(1).strip()
# Strip markdown code fences
input_text = re.sub(r'```json\s*', '', input_text)
input_text = re.sub(r'```\s*', '', input_text)
try:
action_input = json.loads(input_text)
except json.JSONDecodeError:
action_input = {"raw": input_text}
return AgentStep(
thought=thought,
action=action,
action_input=action_input,
)
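To make the expected format concrete, a hand-written reply in the shape `_parse_llm_response` accepts; the sample text is illustrative, not real model output:

```python
# Illustrative parse of a well-formed orchestrator reply
sample = """Thought: Nothing is known about the project yet, so recon must come first.
Action: dispatch_agent
Action Input: {"agent": "recon", "task": "Map the project structure", "context": "first pass"}"""

step = orchestrator._parse_llm_response(sample)  # assumes an OrchestratorAgent instance
assert step.action == "dispatch_agent"
assert step.action_input["agent"] == "recon"
```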
async def _dispatch_agent(self, params: Dict[str, Any]) -> str:
"""Dispatch a sub-agent"""
agent_name = params.get("agent", "")
task = params.get("task", "")
context = params.get("context", "")
agent = self.sub_agents.get(agent_name)
if not agent:
available = list(self.sub_agents.keys())
return f"Error: Agent '{agent_name}' does not exist. Available agents: {available}"
await self.emit_event(
"dispatch",
f"📤 Dispatching the {agent_name} Agent: {task[:100]}...",
agent=agent_name,
task=task,
)
self._tool_calls += 1
try:
# Build the sub-agent input
sub_input = {
"task": task,
"task_context": context,
"project_info": {},  # filled in from context
"config": {},
}
# Run the sub-agent
result = await agent.run(sub_input)
# Collect findings
if result.success and result.data:
findings = result.data.get("findings", [])
self._all_findings.extend(findings)
await self.emit_event(
"dispatch_complete",
f"{agent_name} Agent finished: {len(findings)} findings",
agent=agent_name,
findings_count=len(findings),
)
# Build the observation
observation = f"""## {agent_name} Agent result
**Status**: success
**Findings**: {len(findings)}
**Iterations**: {result.iterations}
**Duration**: {result.duration_ms}ms
### Finding summary
"""
for i, f in enumerate(findings[:10]):  # show at most 10
observation += f"""
{i+1}. [{f.get('severity', 'unknown')}] {f.get('title', 'Unknown')}
- Type: {f.get('vulnerability_type', 'unknown')}
- File: {f.get('file_path', 'unknown')}
- Description: {f.get('description', '')[:200]}...
"""
if len(findings) > 10:
observation += f"\n... plus {len(findings) - 10} more findings"
if result.data.get("summary"):
observation += f"\n\n### Agent summary\n{result.data['summary']}"
return observation
else:
return f"## {agent_name} Agent failed\n\nError: {result.error}"
except Exception as e:
logger.error(f"Sub-agent dispatch failed: {e}", exc_info=True)
return f"## Dispatch failed\n\nError: {str(e)}"
def _summarize_findings(self) -> str:
"""Summarize the findings so far"""
if not self._all_findings:
return "No vulnerabilities have been found yet."
# Tally
severity_counts = {"critical": 0, "high": 0, "medium": 0, "low": 0}
type_counts = {}
for f in self._all_findings:
sev = f.get("severity", "low")
severity_counts[sev] = severity_counts.get(sev, 0) + 1
vtype = f.get("vulnerability_type", "other")
type_counts[vtype] = type_counts.get(vtype, 0) + 1
summary = f"""## Current findings summary
**Total**: {len(self._all_findings)} vulnerabilities
### Severity distribution
- Critical: {severity_counts['critical']}
- High: {severity_counts['high']}
- Medium: {severity_counts['medium']}
- Low: {severity_counts['low']}
### Vulnerability types
"""
for vtype, count in type_counts.items():
summary += f"- {vtype}: {count}\n"
summary += "\n### Details\n"
for i, f in enumerate(self._all_findings):
summary += f"{i+1}. [{f.get('severity')}] {f.get('title')} ({f.get('file_path')})\n"
return summary
def _generate_default_summary(self) -> Dict[str, Any]:
"""Generate a fallback summary"""
severity_counts = {"critical": 0, "high": 0, "medium": 0, "low": 0}
for f in self._all_findings:
sev = f.get("severity", "low")
severity_counts[sev] = severity_counts.get(sev, 0) + 1
return {
"total_findings": len(self._all_findings),
"severity_distribution": severity_counts,
"conclusion": "Audit finished (no LLM-generated conclusion)",
}
def get_conversation_history(self) -> List[Dict[str, str]]:
"""Return the conversation history"""
return self._conversation_history
def get_steps(self) -> List[AgentStep]:
"""Return the executed steps"""
return self._steps

View File

@@ -1,72 +1,127 @@
"""
Recon Agent (reconnaissance layer) - LLM-driven version
The LLM is the real brain:
- the LLM decides what information to collect
- the LLM decides which tool to use
- the LLM decides when the information is sufficient
- the LLM adapts its collection strategy as it goes
Type: ReAct (the real thing!)
"""
import asyncio
import json
import logging
import os
import re
from typing import List, Dict, Any, Optional
from dataclasses import dataclass
from .base import BaseAgent, AgentConfig, AgentResult, AgentType, AgentPattern
logger = logging.getLogger(__name__)
RECON_SYSTEM_PROMPT = """You are DeepAudit's reconnaissance Agent, responsible for **autonomously** gathering project information before a security audit.
## Your role
You are the **brain** of reconnaissance, not a mechanical executor. You must:
1. Think for yourself about what information is needed
2. Pick the right tool to obtain it
3. Adapt your strategy to what you discover
4. Judge when you have gathered enough
## Tools you can use
### Filesystem
- **list_files**: list directory contents
Parameters: directory (str), recursive (bool), pattern (str), max_files (int)
- **read_file**: read file contents
Parameters: file_path (str), start_line (int), end_line (int), max_lines (int)
- **search_code**: keyword search over the code
Parameters: keyword (str), max_results (int)
### Security scanning
- **semgrep_scan**: Semgrep static-analysis scan
- **npm_audit**: npm dependency vulnerability audit
- **safety_scan**: Python dependency vulnerability audit
- **gitleaks_scan**: secret/credential leak scan
- **osv_scan**: OSV cross-ecosystem dependency scan
## How to work
At every step, output:
```
Thought: [analyze the current state and what information is still missing]
Action: [tool name]
Action Input: [JSON parameters]
```
When you judge the information sufficient, output:
```
Thought: [summarize what you collected]
Final Answer: [collection results as JSON]
```
## Final Answer format
```json
{
"project_structure": {
"directories": [],
"config_files": [],
"total_files": count
},
"tech_stack": {
"languages": [],
"frameworks": [],
"databases": []
},
"entry_points": [
{"type": "description", "file": "path", "line": line_number}
],
"high_risk_areas": ["list of paths"],
"dependencies": {},
"initial_findings": []
}
```
## Suggested collection strategy
1. Use list_files to understand the project layout
2. Read configuration files (package.json, requirements.txt, go.mod, etc.) to identify the tech stack
3. Search for entry-point patterns (routes, controllers, handlers)
4. Run the security scanners for preliminary findings
5. Dig deeper based on what you find
## Key principles
1. **You are the brain** - think at every step; never execute mechanically
2. **Adapt dynamically** - adjust the strategy to your findings
3. **Be efficient** - do not re-collect information you already have
4. **Explore proactively** - when something looks interesting, dig in
Now begin collecting project information!"""
@dataclass
class ReconStep:
"""One reconnaissance step"""
thought: str
action: Optional[str] = None
action_input: Optional[Dict] = None
observation: Optional[str] = None
is_final: bool = False
final_answer: Optional[Dict] = None
class ReconAgent(BaseAgent):
"""
Recon Agent - LLM-driven version
The LLM takes part throughout and autonomously decides:
1. what information to collect
2. which tools to use
3. when enough is enough
"""
def __init__(
@@ -81,84 +136,219 @@ class ReconAgent(BaseAgent):
pattern=AgentPattern.REACT,
max_iterations=15,
system_prompt=RECON_SYSTEM_PROMPT,
tools=[
"list_files", "read_file", "search_code",
"semgrep_scan", "npm_audit", "safety_scan",
"gitleaks_scan", "osv_scan",
],
)
super().__init__(config, llm_service, tools, event_emitter)
self._conversation_history: List[Dict[str, str]] = []
self._steps: List[ReconStep] = []
def _get_tools_description(self) -> str:
"""Build a description of the available tools"""
tools_info = []
for name, tool in self.tools.items():
if name.startswith("_"):
continue
desc = f"- {name}: {getattr(tool, 'description', 'No description')}"
tools_info.append(desc)
return "\n".join(tools_info)
def _parse_llm_response(self, response: str) -> ReconStep:
"""Parse the LLM response"""
step = ReconStep(thought="")
# Extract Thought
thought_match = re.search(r'Thought:\s*(.*?)(?=Action:|Final Answer:|$)', response, re.DOTALL)
if thought_match:
step.thought = thought_match.group(1).strip()
# Check for a final answer
final_match = re.search(r'Final Answer:\s*(.*?)$', response, re.DOTALL)
if final_match:
step.is_final = True
try:
answer_text = final_match.group(1).strip()
answer_text = re.sub(r'```json\s*', '', answer_text)
answer_text = re.sub(r'```\s*', '', answer_text)
step.final_answer = json.loads(answer_text)
except json.JSONDecodeError:
step.final_answer = {"raw_answer": final_match.group(1).strip()}
return step
# Extract Action
action_match = re.search(r'Action:\s*(\w+)', response)
if action_match:
step.action = action_match.group(1).strip()
# Extract Action Input
input_match = re.search(r'Action Input:\s*(.*?)(?=Thought:|Action:|Observation:|$)', response, re.DOTALL)
if input_match:
input_text = input_match.group(1).strip()
input_text = re.sub(r'```json\s*', '', input_text)
input_text = re.sub(r'```\s*', '', input_text)
try:
step.action_input = json.loads(input_text)
except json.JSONDecodeError:
step.action_input = {"raw_input": input_text}
return step
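And the terminating case: a `Final Answer` reply parses into `is_final=True` with the JSON payload attached. Hand-written for illustration:

```python
# Illustrative parse of a Final Answer reply (code fences around the JSON
# would also be stripped by the re.sub calls above)
sample = """Thought: Config files and routes are mapped; that is enough.
Final Answer: {"project_structure": {"total_files": 42},
"tech_stack": {"languages": ["Python"], "frameworks": [], "databases": []},
"entry_points": [], "high_risk_areas": ["api/"], "dependencies": {}, "initial_findings": []}"""

step = recon_agent._parse_llm_response(sample)  # assumes a ReconAgent instance
assert step.is_final
assert step.final_answer["tech_stack"]["languages"] == ["Python"]
```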
async def _execute_tool(self, tool_name: str, tool_input: Dict) -> str:
"""Execute a tool"""
tool = self.tools.get(tool_name)
if not tool:
return f"Error: tool '{tool_name}' does not exist. Available tools: {list(self.tools.keys())}"
try:
self._tool_calls += 1
await self.emit_tool_call(tool_name, tool_input)
import time
start = time.time()
result = await tool.execute(**tool_input)
duration_ms = int((time.time() - start) * 1000)
await self.emit_tool_result(tool_name, str(result.data)[:200], duration_ms)
if result.success:
output = str(result.data)
if len(output) > 4000:
output = output[:4000] + f"\n\n... [output truncated; {len(str(result.data))} characters total]"
return output
else:
return f"Tool execution failed: {result.error}"
except Exception as e:
logger.error(f"Tool execution error: {e}")
return f"Tool execution error: {str(e)}"
async def run(self, input_data: Dict[str, Any]) -> AgentResult:
"""
Execute reconnaissance - the LLM takes part throughout
"""
import time
start_time = time.time()
project_info = input_data.get("project_info", {})
task = input_data.get("task", "")
task_context = input_data.get("task_context", "")
# Build the initial message
initial_message = f"""Please begin collecting project information.
## Basic project information
- Name: {project_info.get('name', 'unknown')}
- Root directory: {project_info.get('root', '.')}
## Task context
{task_context or task or 'Perform a full reconnaissance pass to prepare for the security audit.'}
## Available tools
{self._get_tools_description()}
Begin your reconnaissance: first think about what information to collect, then pick the right tool."""
# Initialize the conversation history
self._conversation_history = [
{"role": "system", "content": self.config.system_prompt},
{"role": "user", "content": initial_message},
]
self._steps = []
final_result = None
await self.emit_thinking("🔍 Recon Agent started; the LLM begins autonomous information gathering...")
try:
for iteration in range(self.config.max_iterations):
if self.is_cancelled:
break
self._iteration = iteration + 1
# 🔥 Emit the LLM-start event
await self.emit_llm_start(iteration + 1)
# 🔥 Ask the LLM to think and decide
response = await self.llm_service.chat_completion_raw(
messages=self._conversation_history,
temperature=0.1,
max_tokens=2048,
)
llm_output = response.get("content", "")
tokens_this_round = response.get("usage", {}).get("total_tokens", 0)
self._total_tokens += tokens_this_round
# Parse the LLM response
step = self._parse_llm_response(llm_output)
self._steps.append(step)
# 🔥 Emit the LLM-thought event - show what the LLM is thinking
if step.thought:
await self.emit_llm_thought(step.thought, iteration + 1)
# Append the LLM response to the history
self._conversation_history.append({
"role": "assistant",
"content": llm_output,
})
# Check whether the LLM is done
if step.is_final:
await self.emit_llm_decision("Finish reconnaissance", "the LLM judged the information sufficient")
await self.emit_llm_complete(
f"Reconnaissance finished after {self._iteration} thinking rounds",
self._total_tokens
)
final_result = step.final_answer
break
# Execute the tool
if step.action:
# 🔥 Emit the LLM-action event
await self.emit_llm_action(step.action, step.action_input or {})
observation = await self._execute_tool(
step.action,
step.action_input or {}
)
step.observation = observation
# 🔥 Emit the LLM-observation event
await self.emit_llm_observation(observation)
# Append the observation to the history
self._conversation_history.append({
"role": "user",
"content": f"Observation:\n{observation}",
})
else:
# The LLM chose no tool; prompt it to continue
await self.emit_llm_decision("Keep thinking", "the LLM needs more information")
self._conversation_history.append({
"role": "user",
"content": "Please continue: either pick a tool to execute, or output a Final Answer if reconnaissance is complete.",
})
# Wrap up
duration_ms = int((time.time() - start_time) * 1000)
# Without a final answer, summarize from the step history
if not final_result:
final_result = self._summarize_from_steps()
await self.emit_event(
"info",
f"🎯 Recon Agent finished: {self._iteration} iterations, {self._tool_calls} tool calls"
)
return AgentResult(
success=True,
data=final_result,
iterations=self._iteration,
tool_calls=self._tool_calls,
tokens_used=self._total_tokens,
@@ -166,270 +356,58 @@ class ReconAgent(BaseAgent):
)
except Exception as e:
logger.error(f"Recon Agent failed: {e}", exc_info=True)
return AgentResult(success=False, error=str(e))
def _summarize_from_steps(self) -> Dict[str, Any]:
"""Summarize results from the recorded steps"""
# Default result structure
result = {
"project_structure": {},
"tech_stack": {
"languages": [],
"frameworks": [],
"databases": [],
},
"entry_points": [],
"high_risk_areas": [],
"dependencies": {},
"initial_findings": [],
}
# Mine the step observations for information
for step in self._steps:
if step.observation:
# Try to spot tech-stack hints in the observation
obs_lower = step.observation.lower()
if "package.json" in obs_lower:
result["tech_stack"]["languages"].append("JavaScript/TypeScript")
if "requirements.txt" in obs_lower or "setup.py" in obs_lower:
result["tech_stack"]["languages"].append("Python")
if "go.mod" in obs_lower:
result["tech_stack"]["languages"].append("Go")
# Recognize frameworks
if "react" in obs_lower:
result["tech_stack"]["frameworks"].append("React")
if "django" in obs_lower:
result["tech_stack"]["frameworks"].append("Django")
if "fastapi" in obs_lower:
result["tech_stack"]["frameworks"].append("FastAPI")
if "express" in obs_lower:
result["tech_stack"]["frameworks"].append("Express")
# Deduplicate
result["tech_stack"]["languages"] = list(set(result["tech_stack"]["languages"]))
result["tech_stack"]["frameworks"] = list(set(result["tech_stack"]["frameworks"]))
return result
def get_conversation_history(self) -> List[Dict[str, str]]:
"""Return the conversation history"""
return self._conversation_history
def get_steps(self) -> List[ReconStep]:
"""Return the executed steps"""
return self._steps

View File

@@ -1,13 +1,20 @@
"""
Verification Agent (vulnerability verification layer) - LLM-driven version
The LLM is the brain of verification:
- the LLM decides how to verify each vulnerability
- the LLM designs the verification strategy
- the LLM analyzes the verification results
- the LLM judges whether a vulnerability is real
Type: ReAct (the real thing!)
"""
import asyncio
import json
import logging
import re
from typing import List, Dict, Any, Optional
from dataclasses import dataclass
from datetime import datetime, timezone
from .base import BaseAgent, AgentConfig, AgentResult, AgentType, AgentPattern
@@ -15,69 +22,121 @@ from .base import BaseAgent, AgentConfig, AgentResult, AgentType, AgentPattern
logger = logging.getLogger(__name__)
VERIFICATION_SYSTEM_PROMPT = """You are DeepAudit's vulnerability-verification Agent, an **autonomous** security verification expert.
## Your role
You are the **brain** of verification, not a mechanical validator. You must:
1. Understand the context of every vulnerability
2. Design an appropriate verification strategy
3. Use tools to gather more information
4. Judge whether the vulnerability really exists
5. Assess the real-world impact
## Tools you can use
### Code analysis
- **read_file**: read more code context
Parameters: file_path (str), start_line (int), end_line (int)
- **function_context**: analyze function call relationships
Parameters: function_name (str)
- **dataflow_analysis**: trace data flow
Parameters: source (str), sink (str), file_path (str)
- **vulnerability_validation**: deep LLM validation
Parameters: code (str), vulnerability_type (str), context (str)
### Sandbox verification
- **sandbox_exec**: run a command in the sandbox
Parameters: command (str), timeout (int)
- **sandbox_http**: send a test HTTP request
Parameters: method (str), url (str), data (dict), headers (dict)
- **verify_vulnerability**: automated vulnerability verification
Parameters: vulnerability_type (str), target (str), payload (str)
## How to work
You will receive a batch of findings to verify. For each finding you must:
```
Thought: [analyze this vulnerability and how to verify it]
Action: [tool name]
Action Input: [JSON parameters]
```
After all findings are verified, output:
```
Thought: [summarize the verification results]
Final Answer: [verification report as JSON]
```
## Final Answer format
```json
{
"findings": [
{
...original finding fields...,
"verdict": "confirmed/likely/uncertain/false_positive",
"confidence": 0.0-1.0,
"is_verified": true/false,
"verification_method": "how it was verified",
"verification_details": "verification process and results",
"poc": {
"code": "PoC code",
"description": "PoC description",
"steps": ["step 1", "step 2"],
"payload": "test payload"
},
"impact": "real-world impact analysis",
"recommendation": "remediation advice"
}
],
"summary": {
"total": count,
"confirmed": count,
"likely": count,
"false_positive": count
}
}
```
## Verdict criteria
- **confirmed**: the vulnerability exists and is exploitable, with clear evidence
- **likely**: very probably a vulnerability, but not fully confirmable
- **uncertain**: more information is needed to judge
- **false_positive**: confirmed false positive, with a clear reason
## Suggested verification strategy
1. **Context analysis**: use read_file to gather more code context
2. **Data-flow tracing**: use dataflow_analysis to confirm taint propagation
3. **Deep LLM analysis**: use vulnerability_validation for expert analysis
4. **Sandbox testing**: for high-risk findings, test safely in the sandbox
## Key principles
1. **Quality first** - better to miss a finding than to drown in false positives
2. **Understand deeply** - reason about the code logic; no surface-level judgments
3. **Evidence-backed** - every verdict needs supporting evidence
4. **Safety first** - be careful with sandbox testing
Now begin verifying the findings!"""
@dataclass
class VerificationStep:
"""One verification step"""
thought: str
action: Optional[str] = None
action_input: Optional[Dict] = None
observation: Optional[str] = None
is_final: bool = False
final_answer: Optional[Dict] = None
class VerificationAgent(BaseAgent):
"""
Verification Agent - LLM-driven version
The LLM takes part throughout and autonomously decides:
1. how to verify each vulnerability
2. which tools to use
3. whether a finding is real
"""
def __init__(
@@ -90,25 +149,114 @@ class VerificationAgent(BaseAgent):
name="Verification",
agent_type=AgentType.VERIFICATION,
pattern=AgentPattern.REACT,
max_iterations=25,
system_prompt=VERIFICATION_SYSTEM_PROMPT,
tools=[
"read_file", "function_context", "dataflow_analysis",
"vulnerability_validation",
"sandbox_exec", "sandbox_http", "verify_vulnerability",
],
)
super().__init__(config, llm_service, tools, event_emitter)
self._conversation_history: List[Dict[str, str]] = []
self._steps: List[VerificationStep] = []
def _get_tools_description(self) -> str:
"""Build a description of the available tools"""
tools_info = []
for name, tool in self.tools.items():
if name.startswith("_"):
continue
desc = f"- {name}: {getattr(tool, 'description', 'No description')}"
tools_info.append(desc)
return "\n".join(tools_info)
def _parse_llm_response(self, response: str) -> VerificationStep:
"""Parse the LLM response"""
step = VerificationStep(thought="")
# Extract Thought
thought_match = re.search(r'Thought:\s*(.*?)(?=Action:|Final Answer:|$)', response, re.DOTALL)
if thought_match:
step.thought = thought_match.group(1).strip()
# Check for a final answer
final_match = re.search(r'Final Answer:\s*(.*?)$', response, re.DOTALL)
if final_match:
step.is_final = True
try:
answer_text = final_match.group(1).strip()
answer_text = re.sub(r'```json\s*', '', answer_text)
answer_text = re.sub(r'```\s*', '', answer_text)
step.final_answer = json.loads(answer_text)
except json.JSONDecodeError:
step.final_answer = {"findings": [], "raw_answer": final_match.group(1).strip()}
return step
# Extract Action
action_match = re.search(r'Action:\s*(\w+)', response)
if action_match:
step.action = action_match.group(1).strip()
# Extract Action Input
input_match = re.search(r'Action Input:\s*(.*?)(?=Thought:|Action:|Observation:|$)', response, re.DOTALL)
if input_match:
input_text = input_match.group(1).strip()
input_text = re.sub(r'```json\s*', '', input_text)
input_text = re.sub(r'```\s*', '', input_text)
try:
step.action_input = json.loads(input_text)
except json.JSONDecodeError:
step.action_input = {"raw_input": input_text}
return step
async def _execute_tool(self, tool_name: str, tool_input: Dict) -> str:
"""Execute a tool"""
tool = self.tools.get(tool_name)
if not tool:
return f"Error: tool '{tool_name}' does not exist. Available tools: {list(self.tools.keys())}"
try:
self._tool_calls += 1
await self.emit_tool_call(tool_name, tool_input)
import time
start = time.time()
result = await tool.execute(**tool_input)
duration_ms = int((time.time() - start) * 1000)
await self.emit_tool_result(tool_name, str(result.data)[:200], duration_ms)
if result.success:
output = str(result.data)
# Include relevant metadata
if result.metadata:
if "validation" in result.metadata:
output += f"\n\nValidation result:\n{json.dumps(result.metadata['validation'], ensure_ascii=False, indent=2)}"
if len(output) > 4000:
output = output[:4000] + "\n\n... [output truncated]"
return output
else:
return f"Tool execution failed: {result.error}"
except Exception as e:
logger.error(f"Tool execution error: {e}")
return f"Tool execution error: {str(e)}"
async def run(self, input_data: Dict[str, Any]) -> AgentResult:
"""
Execute vulnerability verification - the LLM takes part throughout
"""
import time
start_time = time.time()
previous_results = input_data.get("previous_results", {})
config = input_data.get("config", {})
task = input_data.get("task", "")
task_context = input_data.get("task_context", "")
# Collect all findings that need verification
findings_to_verify = []
for phase_name, result in previous_results.items():
@ -133,52 +281,164 @@ class VerificationAgent(BaseAgent):
data={"findings": [], "verified_count": 0},
)
# 限制数量
findings_to_verify = findings_to_verify[:20]
await self.emit_event(
"info",
f"开始验证 {len(findings_to_verify)} 个发现"
)
# Build the initial message
findings_summary = []
for i, f in enumerate(findings_to_verify):
findings_summary.append(f"""
### Finding {i+1}: {f.get('title', 'Unknown')}
- Type: {f.get('vulnerability_type', 'unknown')}
- Severity: {f.get('severity', 'medium')}
- File: {f.get('file_path', 'unknown')}:{f.get('line_start', 0)}
- Code:
```
{f.get('code_snippet', 'N/A')[:500]}
```
- Description: {f.get('description', 'N/A')[:300]}
""")
initial_message = f"""Please verify the following {len(findings_to_verify)} security findings.
## Findings to verify
{''.join(findings_summary)}
## Verification requirements
- Verification level: {config.get('verification_level', 'standard')}
## Available tools
{self._get_tools_description()}
Please begin. For each finding, think about how to verify it, use the appropriate tools to gather more information, and then judge whether it is a real vulnerability."""
# Initialize the conversation history
self._conversation_history = [
{"role": "system", "content": self.config.system_prompt},
{"role": "user", "content": initial_message},
]
self._steps = []
final_result = None
await self.emit_thinking("🔐 Verification Agent started; the LLM begins verifying findings autonomously...")
try:
try:
for iteration in range(self.config.max_iterations):
if self.is_cancelled:
break
self._iteration = iteration + 1
# 🔥 Emit the "LLM starts thinking" event
await self.emit_llm_start(iteration + 1)
# 🔥 Call the LLM to think and decide
response = await self.llm_service.chat_completion_raw(
messages=self._conversation_history,
temperature=0.1,
max_tokens=3000,
)
llm_output = response.get("content", "")
tokens_this_round = response.get("usage", {}).get("total_tokens", 0)
self._total_tokens += tokens_this_round
# Parse the LLM response
step = self._parse_llm_response(llm_output)
self._steps.append(step)
# 🔥 Emit the LLM thought event - surfaces the verification reasoning
if step.thought:
await self.emit_llm_thought(step.thought, iteration + 1)
# Append the LLM response to the history
self._conversation_history.append({
"role": "assistant",
"content": llm_output,
})
# Check whether verification is complete
if step.is_final:
await self.emit_llm_decision("Verification complete", "The LLM judged the verification sufficient")
final_result = step.final_answer
await self.emit_llm_complete(
"Verification complete",
self._total_tokens
)
break
# Execute the tool
if step.action:
# 🔥 Emit the LLM action event
await self.emit_llm_action(step.action, step.action_input or {})
observation = await self._execute_tool(
step.action,
step.action_input or {}
)
step.observation = observation
# 🔥 Emit the LLM observation event
await self.emit_llm_observation(observation)
# Append the observation to the history
self._conversation_history.append({
"role": "user",
"content": f"Observation:\n{observation}",
})
else:
# The LLM chose no tool; prompt it to continue
await self.emit_llm_decision("Continue verification", "The LLM needs more evidence")
self._conversation_history.append({
"role": "user",
"content": "Please continue the verification. If it is complete, output a Final Answer summarizing all verification results.",
})
# Process the final result
verified_findings = []
if final_result and "findings" in final_result:
for f in final_result["findings"]:
verified = {
**f,
"is_verified": f.get("verdict") == "confirmed" or (
f.get("verdict") == "likely" and f.get("confidence", 0) >= 0.8
),
"verified_at": datetime.now(timezone.utc).isoformat() if f.get("verdict") in ["confirmed", "likely"] else None,
}
# Attach a remediation recommendation
if not verified.get("recommendation"):
verified["recommendation"] = self._get_recommendation(f.get("vulnerability_type", ""))
verified_findings.append(verified)
else:
# No final result; fall back to the original findings
for f in findings_to_verify:
verified_findings.append({
**f,
"verdict": "uncertain",
"confidence": 0.5,
"is_verified": False,
})
# Tally the results
confirmed_count = len([f for f in verified_findings if f.get("verdict") == "confirmed"])
likely_count = len([f for f in verified_findings if f.get("verdict") == "likely"])
false_positive_count = len([f for f in verified_findings if f.get("verdict") == "false_positive"])
duration_ms = int((time.time() - start_time) * 1000)
await self.emit_event(
"info",
f"🎯 Verification Agent finished: {confirmed_count} confirmed, {likely_count} likely, {false_positive_count} false positives"
)
return AgentResult(
@ -196,167 +456,9 @@ class VerificationAgent(BaseAgent):
)
except Exception as e:
logger.error(f"Verification agent failed: {e}", exc_info=True)
logger.error(f"Verification Agent failed: {e}", exc_info=True)
return AgentResult(success=False, error=str(e))
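# --- Illustrative sketch (assumption): the Final Answer JSON shape the loop
# above expects, inferred from the handling of final_result["findings"] and the
# verdict/confidence checks; all field values are hypothetical.
#
# Final Answer:
# {
#   "findings": [
#     {
#       "title": "SQL injection in user lookup",
#       "vulnerability_type": "sql_injection",
#       "verdict": "confirmed",      # confirmed | likely | false_positive | uncertain
#       "confidence": 0.9,
#       "file_path": "app/db.py",
#       "line_start": 42
#     }
#   ]
# }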
async def _verify_finding(
self,
finding: Dict[str, Any],
verification_level: str,
) -> Dict[str, Any]:
"""验证单个发现"""
result = {
**finding,
"verdict": "uncertain",
"confidence": 0.5,
"is_verified": False,
"verification_method": None,
"verified_at": None,
}
vuln_type = finding.get("vulnerability_type", "")
file_path = finding.get("file_path", "")
line_start = finding.get("line_start", 0)
code_snippet = finding.get("code_snippet", "")
try:
# 1. Fetch more context
context = await self._get_context(file_path, line_start)
# 2. Validate with the LLM
validation_result = await self._llm_validation(
finding, context
)
result["verdict"] = validation_result.get("verdict", "uncertain")
result["confidence"] = validation_result.get("confidence", 0.5)
result["verification_method"] = "llm_analysis"
# 3. Run sandbox verification if required
if verification_level in ["sandbox", "generate_poc"]:
if result["verdict"] in ["confirmed", "likely"]:
if vuln_type in ["sql_injection", "command_injection", "xss"]:
sandbox_result = await self._sandbox_verification(
finding, validation_result
)
if sandbox_result.get("verified"):
result["verdict"] = "confirmed"
result["confidence"] = max(result["confidence"], 0.9)
result["verification_method"] = "sandbox_test"
result["poc"] = sandbox_result.get("poc")
# 4. Decide whether the finding counts as verified
if result["verdict"] == "confirmed" or (
result["verdict"] == "likely" and result["confidence"] >= 0.8
):
result["is_verified"] = True
result["verified_at"] = datetime.now(timezone.utc).isoformat()
# 5. Attach a remediation recommendation
if result["is_verified"]:
result["recommendation"] = self._get_recommendation(vuln_type)
except Exception as e:
logger.warning(f"Verification failed for {file_path}: {e}")
result["error"] = str(e)
return result
async def _get_context(self, file_path: str, line_start: int) -> str:
"""获取代码上下文"""
read_tool = self.tools.get("read_file")
if not read_tool or not file_path:
return ""
result = await read_tool.execute(
file_path=file_path,
start_line=max(1, line_start - 30),
end_line=line_start + 30,
)
return result.data if result.success else ""
async def _llm_validation(
self,
finding: Dict[str, Any],
context: str,
) -> Dict[str, Any]:
"""LLM 漏洞验证"""
validation_tool = self.tools.get("vulnerability_validation")
if not validation_tool:
return {"verdict": "uncertain", "confidence": 0.5}
code = finding.get("code_snippet", "") or context[:2000]
result = await validation_tool.execute(
code=code,
vulnerability_type=finding.get("vulnerability_type", "unknown"),
file_path=finding.get("file_path", ""),
line_number=finding.get("line_start"),
context=context[:1000] if context else None,
)
if result.success and result.metadata.get("validation"):
validation = result.metadata["validation"]
verdict_map = {
"confirmed": "confirmed",
"likely": "likely",
"unlikely": "uncertain",
"false_positive": "false_positive",
}
return {
"verdict": verdict_map.get(validation.get("verdict", ""), "uncertain"),
"confidence": validation.get("confidence", 0.5),
"explanation": validation.get("detailed_analysis", ""),
"exploitation_conditions": validation.get("exploitation_conditions", []),
"poc_idea": validation.get("poc_idea"),
}
return {"verdict": "uncertain", "confidence": 0.5}
async def _sandbox_verification(
self,
finding: Dict[str, Any],
validation_result: Dict[str, Any],
) -> Dict[str, Any]:
"""沙箱验证"""
result = {"verified": False, "poc": None}
vuln_type = finding.get("vulnerability_type", "")
poc_idea = validation_result.get("poc_idea", "")
# Choose the verification method by vulnerability type
sandbox_tool = self.tools.get("sandbox_exec")
http_tool = self.tools.get("sandbox_http")
verify_tool = self.tools.get("verify_vulnerability")
if vuln_type == "command_injection" and sandbox_tool:
# Construct a safe test command
test_cmd = "echo 'test_marker_12345'"  # defined but unused on this path
exec_result = await sandbox_tool.execute(
command="python3 -c \"print('test')\"",
timeout=10,
)
)
if exec_result.success:
result["verified"] = True
result["poc"] = {
"description": "命令注入测试",
"method": "sandbox_exec",
}
elif vuln_type in ["sql_injection", "xss"] and verify_tool:
# Use the automated verification tool
# Note: this requires an actual target URL
pass
return result
def _get_recommendation(self, vuln_type: str) -> str:
"""获取修复建议"""
recommendations = {
@ -369,7 +471,6 @@ class VerificationAgent(BaseAgent):
"hardcoded_secret": "使用环境变量或密钥管理服务存储敏感信息",
"weak_crypto": "使用强加密算法AES-256, SHA-256+),避免 MD5/SHA1",
}
return recommendations.get(vuln_type, "Fix this security issue according to its specific context")
def _deduplicate(self, findings: List[Dict]) -> List[Dict]:
@ -390,3 +491,10 @@ class VerificationAgent(BaseAgent):
return unique
def get_conversation_history(self) -> List[Dict[str, str]]:
"""获取对话历史"""
return self._conversation_history
def get_steps(self) -> List[VerificationStep]:
"""获取执行步骤"""
return self._steps

View File

@ -91,6 +91,33 @@ class AgentEventEmitter:
metadata=metadata,
))
async def emit_llm_thought(self, thought: str, iteration: int = 0):
"""发射 LLM 思考内容事件 - 核心!展示 LLM 在想什么"""
display = thought[:500] + "..." if len(thought) > 500 else thought
await self.emit(AgentEventData(
event_type="llm_thought",
message=f"💭 LLM 思考:\n{display}",
metadata={"thought": thought, "iteration": iteration},
))
async def emit_llm_decision(self, decision: str, reason: str = ""):
"""发射 LLM 决策事件"""
await self.emit(AgentEventData(
event_type="llm_decision",
message=f"💡 LLM 决策: {decision}" + (f" ({reason})" if reason else ""),
metadata={"decision": decision, "reason": reason},
))
async def emit_llm_action(self, action: str, action_input: Dict):
"""发射 LLM 动作事件"""
import json
input_str = json.dumps(action_input, ensure_ascii=False)[:200]
await self.emit(AgentEventData(
event_type="llm_action",
message=f"⚡ LLM 动作: {action}\n 参数: {input_str}",
metadata={"action": action, "action_input": action_input},
))
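# --- Illustrative sketch (assumption): how an agent loop drives these emitters
# during one ReAct iteration; the emitter instance and argument values are
# hypothetical.
#
# await emitter.emit_llm_thought("The query concatenates user input; likely injectable.", iteration=2)
# await emitter.emit_llm_decision("Verify finding #1", "highest severity first")
# await emitter.emit_llm_action("read_file", {"file_path": "app/db.py", "start_line": 10})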
async def emit_tool_call(
self,
tool_name: str,

View File

@ -1,12 +1,15 @@
"""
DeepAudit 审计工作流图
使用 LangGraph 构建状态机式的 Agent 协作流程
DeepAudit 审计工作流图 - LLM 驱动版
使用 LangGraph 构建 LLM 驱动的 Agent 协作流程
重要改变路由决策由 LLM 参与而不是硬编码条件
"""
from typing import TypedDict, Annotated, List, Dict, Any, Optional, Literal
from datetime import datetime
import operator
import logging
import json
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.memory import MemorySaver
@ -56,12 +59,16 @@ class AuditState(TypedDict):
verified_findings: List[Finding]
false_positives: List[str]
# Control flow - 🔥 key: the LLM can set these fields to influence routing
current_phase: str
iteration: int
max_iterations: int
should_continue_analysis: bool
# 🔥 New: the LLM's routing decision
llm_next_action: Optional[str]  # LLM-suggested next step: "analysis", "verification", "report", "end"
llm_routing_reason: Optional[str]  # the LLM's rationale for the decision
# Messages and events
messages: Annotated[List[Dict], operator.add]
events: Annotated[List[Dict], operator.add]
@ -72,43 +79,233 @@ class AuditState(TypedDict):
error: Optional[str]
# ============ LLM routing decider ============
class LLMRouter:
"""
LLM 路由决策器
LLM 来决定下一步应该做什么
"""
def __init__(self, llm_service):
self.llm_service = llm_service
async def decide_after_recon(self, state: AuditState) -> Dict[str, Any]:
"""Recon 后让 LLM 决定下一步"""
entry_points = state.get("entry_points", [])
high_risk_areas = state.get("high_risk_areas", [])
tech_stack = state.get("tech_stack", {})
initial_findings = state.get("findings", [])
prompt = f"""作为安全审计的决策者,基于以下信息收集结果,决定下一步行动。
## 信息收集结果
- 入口点数量: {len(entry_points)}
- 高风险区域: {high_risk_areas[:10]}
- 技术栈: {tech_stack}
- 初步发现: {len(initial_findings)}
## 选项
1. "analysis" - 继续进行漏洞分析推荐有入口点或高风险区域时
2. "end" - 结束审计仅当没有任何可分析内容时
请返回 JSON 格式
{{"action": "analysis或end", "reason": "决策理由"}}"""
try:
response = await self.llm_service.chat_completion_raw(
messages=[
{"role": "system", "content": "你是安全审计流程的决策者,负责决定下一步行动。"},
{"role": "user", "content": prompt},
],
temperature=0.1,
max_tokens=200,
)
content = response.get("content", "")
# Extract the JSON
import re
json_match = re.search(r'\{.*\}', content, re.DOTALL)
if json_match:
result = json.loads(json_match.group())
return result
except Exception as e:
logger.warning(f"LLM routing decision failed: {e}")
# Default decision
if entry_points or high_risk_areas:
return {"action": "analysis", "reason": "there is content to analyze"}
return {"action": "end", "reason": "no entry points or high-risk areas found"}
async def decide_after_analysis(self, state: AuditState) -> Dict[str, Any]:
"""Analysis 后让 LLM 决定下一步"""
findings = state.get("findings", [])
iteration = state.get("iteration", 0)
max_iterations = state.get("max_iterations", 3)
# Tally the findings
severity_counts = {"critical": 0, "high": 0, "medium": 0, "low": 0}
for f in findings:
sev = f.get("severity", "medium")
severity_counts[sev] = severity_counts.get(sev, 0) + 1
prompt = f"""作为安全审计的决策者,基于以下分析结果,决定下一步行动。
## 分析结果
- 总发现数: {len(findings)}
- 严重程度分布: {severity_counts}
- 当前迭代: {iteration}/{max_iterations}
## 选项
1. "verification" - 验证发现的漏洞推荐有发现需要验证时
2. "analysis" - 继续深入分析推荐发现较少但还有迭代次数时
3. "report" - 生成报告推荐没有发现或已充分分析时
请返回 JSON 格式
{{"action": "verification/analysis/report", "reason": "决策理由"}}"""
try:
response = await self.llm_service.chat_completion_raw(
messages=[
{"role": "system", "content": "你是安全审计流程的决策者,负责决定下一步行动。"},
{"role": "user", "content": prompt},
],
temperature=0.1,
max_tokens=200,
)
content = response.get("content", "")
import re
json_match = re.search(r'\{.*\}', content, re.DOTALL)
if json_match:
result = json.loads(json_match.group())
return result
except Exception as e:
logger.warning(f"LLM routing decision failed: {e}")
# Default decision
if not findings:
return {"action": "report", "reason": "no vulnerabilities found"}
if len(findings) >= 3 or iteration >= max_iterations:
return {"action": "verification", "reason": "enough findings to verify"}
return {"action": "analysis", "reason": "few findings; keep analyzing"}
async def decide_after_verification(self, state: AuditState) -> Dict[str, Any]:
"""Verification 后让 LLM 决定下一步"""
verified_findings = state.get("verified_findings", [])
false_positives = state.get("false_positives", [])
iteration = state.get("iteration", 0)
max_iterations = state.get("max_iterations", 3)
prompt = f"""作为安全审计的决策者,基于以下验证结果,决定下一步行动。
## 验证结果
- 已确认漏洞: {len(verified_findings)}
- 误报数量: {len(false_positives)}
- 当前迭代: {iteration}/{max_iterations}
## 选项
1. "analysis" - 回到分析阶段重新分析推荐误报率太高时
2. "report" - 生成最终报告推荐验证完成时
请返回 JSON 格式
{{"action": "analysis/report", "reason": "决策理由"}}"""
try:
response = await self.llm_service.chat_completion_raw(
messages=[
{"role": "system", "content": "你是安全审计流程的决策者,负责决定下一步行动。"},
{"role": "user", "content": prompt},
],
temperature=0.1,
max_tokens=200,
)
content = response.get("content", "")
import re
json_match = re.search(r'\{.*\}', content, re.DOTALL)
if json_match:
result = json.loads(json_match.group())
return result
except Exception as e:
logger.warning(f"LLM routing decision failed: {e}")
# Default decision
if len(false_positives) > len(verified_findings) and iteration < max_iterations:
return {"action": "analysis", "reason": "false-positive rate is high; re-analyze"}
return {"action": "report", "reason": "verification finished; generate the report"}
# ============ Routing functions (combined with LLM decisions) ============
def route_after_recon(state: AuditState) -> Literal["analysis", "end"]:
"""Recon 后的路由决策"""
# 如果没有发现入口点或高风险区域,直接结束
"""
Recon 后的路由决策
优先使用 LLM 的决策否则使用默认逻辑
"""
# 检查 LLM 是否有决策
llm_action = state.get("llm_next_action")
if llm_action:
logger.info(f"Using LLM routing decision: {llm_action}, reason: {state.get('llm_routing_reason')}")
if llm_action == "end":
return "end"
return "analysis"
# Default logic (fallback)
if not state.get("entry_points") and not state.get("high_risk_areas"):
return "end"
return "analysis"
def route_after_analysis(state: AuditState) -> Literal["verification", "analysis", "report"]:
"""Analysis 后的路由决策"""
"""
Analysis 后的路由决策
优先使用 LLM 的决策
"""
# 检查 LLM 是否有决策
llm_action = state.get("llm_next_action")
if llm_action:
logger.info(f"Using LLM routing decision: {llm_action}, reason: {state.get('llm_routing_reason')}")
if llm_action == "verification":
return "verification"
elif llm_action == "analysis":
return "analysis"
elif llm_action == "report":
return "report"
# Default logic
findings = state.get("findings", [])
iteration = state.get("iteration", 0)
max_iterations = state.get("max_iterations", 3)
should_continue = state.get("should_continue_analysis", False)
# No findings: go straight to the report
if not findings:
return "report"
# Continue analysis if requested and under the iteration cap
if should_continue and iteration < max_iterations:
return "analysis"
# There are findings to verify
return "verification"
def route_after_verification(state: AuditState) -> Literal["analysis", "report"]:
"""Verification 后的路由决策"""
# 如果验证发现了误报,可能需要重新分析
"""
Verification 后的路由决策
优先使用 LLM 的决策
"""
# 检查 LLM 是否有决策
llm_action = state.get("llm_next_action")
if llm_action:
logger.info(f"Using LLM routing decision: {llm_action}, reason: {state.get('llm_routing_reason')}")
if llm_action == "analysis":
return "analysis"
return "report"
# Default logic
false_positives = state.get("false_positives", [])
iteration = state.get("iteration", 0)
max_iterations = state.get("max_iterations", 3)
# If the false-positive rate is too high and iterations remain, return to analysis
if len(false_positives) > len(state.get("verified_findings", [])) and iteration < max_iterations:
return "analysis"
@ -123,6 +320,7 @@ def create_audit_graph(
verification_node,
report_node,
checkpointer: Optional[MemorySaver] = None,
llm_service=None,  # used for LLM routing decisions
) -> StateGraph:
"""
Create the audit workflow graph
@ -132,7 +330,8 @@ def create_audit_graph(
analysis_node: vulnerability analysis node
verification_node: vulnerability verification node
report_node: report generation node
checkpointer: checkpoint store
llm_service: LLM service (used for routing decisions)
Returns:
The compiled StateGraph
@ -143,19 +342,19 @@ def create_audit_graph(
Recon: information gathering (LLM-driven)
↓ (LLM decides)
Analysis: vulnerability analysis (LLM-driven, can loop)
↓ (LLM decides)
Verification: vulnerability verification (LLM-driven, can backtrack)
↓ (LLM decides)
Report: report generation
@ -168,10 +367,53 @@ def create_audit_graph(
# Create the state graph
workflow = StateGraph(AuditState)
# If an LLM service is provided, create the routing decider
llm_router = LLMRouter(llm_service) if llm_service else None
# Wrap nodes to add LLM routing decisions
async def recon_with_routing(state):
result = await recon_node(state)
# Let the LLM decide the next step
if llm_router:
decision = await llm_router.decide_after_recon({**state, **result})
result["llm_next_action"] = decision.get("action")
result["llm_routing_reason"] = decision.get("reason")
return result
async def analysis_with_routing(state):
result = await analysis_node(state)
# Let the LLM decide the next step
if llm_router:
decision = await llm_router.decide_after_analysis({**state, **result})
result["llm_next_action"] = decision.get("action")
result["llm_routing_reason"] = decision.get("reason")
return result
async def verification_with_routing(state):
result = await verification_node(state)
# Let the LLM decide the next step
if llm_router:
decision = await llm_router.decide_after_verification({**state, **result})
result["llm_next_action"] = decision.get("action")
result["llm_routing_reason"] = decision.get("reason")
return result
# Add nodes
if llm_router:
workflow.add_node("recon", recon_with_routing)
workflow.add_node("analysis", analysis_with_routing)
workflow.add_node("verification", verification_with_routing)
else:
workflow.add_node("recon", recon_node)
workflow.add_node("analysis", analysis_node)
workflow.add_node("verification", verification_node)
workflow.add_node("report", report_node)
# Set the entry point
@ -192,7 +434,7 @@ def create_audit_graph(
route_after_analysis,
{
"verification": "verification",
"analysis": "analysis", # 循环
"analysis": "analysis",
"report": "report",
}
)
@ -201,7 +443,7 @@ def create_audit_graph(
"verification",
route_after_verification,
{
"analysis": "analysis", # 回溯
"analysis": "analysis",
"report": "report",
}
)
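# --- Illustrative sketch (assumption): wiring the graph; the node callables,
# the llm service, and the variable names are hypothetical stand-ins for the
# real agents.
#
# graph = create_audit_graph(
#     recon_node=my_recon,
#     analysis_node=my_analysis,
#     verification_node=my_verification,
#     report_node=my_report,
#     checkpointer=MemorySaver(),
#     llm_service=my_llm,   # enables the LLMRouter-wrapped nodes
# )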
@ -225,50 +467,42 @@ def create_audit_graph_with_human(
report_node,
human_review_node,
checkpointer: Optional[MemorySaver] = None,
llm_service=None,
) -> StateGraph:
"""
创建带人机协作的审计工作流图
在验证阶段后增加人工审核节点
工作流结构:
START
Recon
Analysis
Verification
Human Review 人工审核可跳过
Report
END
"""
workflow = StateGraph(AuditState)
llm_router = LLMRouter(llm_service) if llm_service else None
# Wrap nodes
async def recon_with_routing(state):
result = await recon_node(state)
if llm_router:
decision = await llm_router.decide_after_recon({**state, **result})
result["llm_next_action"] = decision.get("action")
result["llm_routing_reason"] = decision.get("reason")
return result
async def analysis_with_routing(state):
result = await analysis_node(state)
if llm_router:
decision = await llm_router.decide_after_analysis({**state, **result})
result["llm_next_action"] = decision.get("action")
result["llm_routing_reason"] = decision.get("reason")
return result
# Add nodes
if llm_router:
workflow.add_node("recon", recon_with_routing)
workflow.add_node("analysis", analysis_with_routing)
else:
workflow.add_node("recon", recon_node)
workflow.add_node("analysis", analysis_node)
workflow.add_node("verification", verification_node)
workflow.add_node("human_review", human_review_node)
workflow.add_node("report", report_node)
@ -296,7 +530,6 @@ def create_audit_graph_with_human(
# Routing after Human Review
def route_after_human(state: AuditState) -> Literal["analysis", "report"]:
# The human can decide to re-analyze or continue
if state.get("should_continue_analysis"):
return "analysis"
return "report"
@ -340,15 +573,6 @@ class AuditGraphRunner:
) -> Dict[str, Any]:
"""
执行审计工作流
Args:
project_root: 项目根目录
project_info: 项目信息
config: 审计配置
task_id: 任务 ID
Returns:
最终状态
"""
# Initial state
initial_state: AuditState = {
@ -367,6 +591,8 @@ class AuditGraphRunner:
"iteration": 0,
"max_iterations": config.get("max_iterations", 3),
"should_continue_analysis": False,
"llm_next_action": None,
"llm_routing_reason": None,
"messages": [],
"events": [],
"summary": None,
@ -374,32 +600,32 @@ class AuditGraphRunner:
"error": None,
}
# Configuration
run_config = {
"configurable": {
"thread_id": task_id,
}
}
# Execute the graph
try:
# Streamed execution
async for event in self.graph.astream(initial_state, config=run_config):
# Emit events
if self.event_emitter:
for node_name, node_state in event.items():
await self.event_emitter.emit_info(
f"节点 {node_name} 完成"
)
# Emit the LLM routing decision event
if node_state.get("llm_routing_reason"):
await self.event_emitter.emit_info(
f"🧠 LLM 决策: {node_state.get('llm_next_action')} - {node_state.get('llm_routing_reason')}"
)
if node_name == "analysis" and node_state.get("findings"):
new_findings = node_state["findings"]
await self.event_emitter.emit_info(
f"发现 {len(new_findings)} 个潜在漏洞"
)
# Fetch the final state
final_state = self.graph.get_state(run_config)
return final_state.values
@ -412,44 +638,27 @@ class AuditGraphRunner:
initial_state: AuditState,
human_feedback_callback,
) -> Dict[str, Any]:
"""
带人机协作的执行
Args:
initial_state: 初始状态
human_feedback_callback: 人工反馈回调函数
Returns:
最终状态
"""
"""带人机协作的执行"""
run_config = {
"configurable": {
"thread_id": initial_state["task_id"],
}
}
# Run until the human review node
async for event in self.graph.astream(initial_state, config=run_config):
pass
# Fetch the current state
current_state = self.graph.get_state(run_config)
# If paused at the human review node
if current_state.next == ("human_review",):
# Invoke the human feedback callback
human_decision = await human_feedback_callback(current_state.values)
# Update the state and continue
updated_state = {
**current_state.values,
"should_continue_analysis": human_decision.get("continue_analysis", False),
}
# Resume execution
async for event in self.graph.astream(updated_state, config=run_config):
pass
# Return the final state
return self.graph.get_state(run_config).values
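# --- Illustrative sketch (assumption): a feedback callback compatible with the
# human loop above; only "continue_analysis" is read from its return value, and
# the CLI prompt itself is hypothetical.
async def cli_feedback(state):
    verified = state.get("verified_findings", [])
    answer = input(f"{len(verified)} findings verified. Re-run analysis? [y/N] ")
    return {"continue_analysis": answer.strip().lower() == "y"}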

View File

@ -403,10 +403,19 @@ class AgentRunner:
Returns:
The final state
"""
final_state = {}
try:
async for event in self.run_with_streaming():
# Collect the final state
if event.event_type == StreamEventType.TASK_COMPLETE:
final_state = event.data
elif event.event_type == StreamEventType.TASK_ERROR:
final_state = {"success": False, "error": event.data.get("error")}
except Exception as e:
logger.error(f"Agent run failed: {e}", exc_info=True)
final_state = {"success": False, "error": str(e)}
return final_state
async def run_with_streaming(self) -> AsyncGenerator[StreamEvent, None]:
"""

View File

@ -15,12 +15,20 @@ logger = logging.getLogger(__name__)
class StreamEventType(str, Enum):
"""流式事件类型"""
# 🔥 LLM thinking events - the most important ones! They expose the LLM's reasoning
LLM_START = "llm_start"  # the LLM starts thinking
LLM_THOUGHT = "llm_thought"  # LLM thought content ⭐ core
LLM_DECISION = "llm_decision"  # LLM decision ⭐ core
LLM_ACTION = "llm_action"  # LLM action
LLM_OBSERVATION = "llm_observation"  # LLM observation
LLM_COMPLETE = "llm_complete"  # the LLM finished
# LLM token stream (real-time output)
THINKING_START = "thinking_start"  # thinking starts
THINKING_TOKEN = "thinking_token"  # thinking token (streamed)
THINKING_END = "thinking_end"  # thinking ends
# Tool call events - tools the LLM decided to invoke
TOOL_CALL_START = "tool_call_start"  # tool call started
TOOL_CALL_INPUT = "tool_call_input"  # tool input parameters
TOOL_CALL_OUTPUT = "tool_call_output"  # tool output result
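# --- Illustrative sketch (assumption): dispatching on StreamEventType when
# consuming the event stream; since the enum subclasses str, members compare
# equal to their raw values. `render_thought` is hypothetical.
#
# def handle(event_type: str, payload: dict) -> None:
#     if event_type == StreamEventType.LLM_THOUGHT:
#         render_thought(payload.get("thought", ""))
#     elif event_type == StreamEventType.TOOL_CALL_START:
#         print("tool call:", payload)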

View File

@ -4,13 +4,13 @@
* Agent audit page - LLM-driven real-time view
*/
import { useState, useEffect, useRef, useCallback } from "react";
import { useState, useEffect, useRef, useCallback, useMemo } from "react";
import { useParams, useNavigate } from "react-router-dom";
import {
Terminal, Bot, Cpu, Shield, AlertTriangle, CheckCircle2,
Terminal, Bot, Shield, AlertTriangle, CheckCircle2,
Loader2, Code, Zap, Activity, ChevronRight, XCircle,
FileCode, Search, Bug, Lock, Play, Square, RefreshCw,
ArrowLeft, Download, ExternalLink, Brain, Wrench,
FileCode, Search, Bug, Square, RefreshCw,
ArrowLeft, ExternalLink, Brain, Wrench,
ChevronDown, ChevronUp, Clock, Sparkles
} from "lucide-react";
import { Button } from "@/components/ui/button";
@ -26,42 +26,77 @@ import {
getAgentEvents,
getAgentFindings,
cancelAgentTask,
streamAgentEvents,
} from "@/shared/api/agentTasks";
// Event type icon map - 🔥 emphasizes LLM events
const eventTypeIcons: Record<string, React.ReactNode> = {
// 🧠 Core LLM events - the most important!
llm_start: <Brain className="w-3 h-3 text-purple-400 animate-pulse" />,
llm_thought: <Sparkles className="w-3 h-3 text-purple-300" />,
llm_decision: <Zap className="w-3 h-3 text-yellow-400" />,
llm_action: <Zap className="w-3 h-3 text-orange-400" />,
llm_observation: <Search className="w-3 h-3 text-blue-400" />,
llm_complete: <CheckCircle2 className="w-3 h-3 text-green-400" />,
// Phase events
phase_start: <Zap className="w-3 h-3 text-cyan-400" />,
phase_complete: <CheckCircle2 className="w-3 h-3 text-green-400" />,
thinking: <Brain className="w-3 h-3 text-purple-400" />,
// Tool events - tool calls the LLM decided on
tool_call: <Wrench className="w-3 h-3 text-yellow-400" />,
tool_result: <CheckCircle2 className="w-3 h-3 text-green-400" />,
tool_error: <XCircle className="w-3 h-3 text-red-400" />,
// Finding events
finding: <Bug className="w-3 h-3 text-orange-400" />,
finding_new: <Bug className="w-3 h-3 text-orange-400" />,
finding_verified: <Shield className="w-3 h-3 text-red-400" />,
// Status events
info: <Activity className="w-3 h-3 text-blue-400" />,
warning: <AlertTriangle className="w-3 h-3 text-yellow-400" />,
error: <XCircle className="w-3 h-3 text-red-500" />,
progress: <RefreshCw className="w-3 h-3 text-cyan-400 animate-spin" />,
// Task events
task_complete: <CheckCircle2 className="w-3 h-3 text-green-500" />,
task_error: <XCircle className="w-3 h-3 text-red-500" />,
task_cancel: <Square className="w-3 h-3 text-yellow-500" />,
};
// Event type color map - 🔥 LLM events highlighted
const eventTypeColors: Record<string, string> = {
// 🧠 Core LLM events - purple palette for emphasis
llm_start: "text-purple-400 font-semibold",
llm_thought: "text-purple-300 bg-purple-950/30 rounded px-1", // 思考内容特别高亮
llm_decision: "text-yellow-300 font-semibold", // 决策特别突出
llm_action: "text-orange-300 font-medium",
llm_observation: "text-blue-300",
llm_complete: "text-green-400 font-semibold",
// Phase events
phase_start: "text-cyan-400 font-bold",
phase_complete: "text-green-400",
thinking: "text-purple-300",
// Tool events
tool_call: "text-yellow-300",
tool_result: "text-green-300",
tool_error: "text-red-400",
// Finding events
finding: "text-orange-300 font-medium",
finding_new: "text-orange-300",
finding_verified: "text-red-300",
finding_verified: "text-red-300 font-medium",
// Status events
info: "text-gray-300",
warning: "text-yellow-300",
error: "text-red-400",
progress: "text-cyan-300",
// Task events
task_complete: "text-green-400 font-bold",
task_error: "text-red-400 font-bold",
task_cancel: "text-yellow-400",
@ -99,30 +134,6 @@ export default function AgentAuditPage() {
const eventsEndRef = useRef<HTMLDivElement>(null);
const thinkingEndRef = useRef<HTMLDivElement>(null);
const abortControllerRef = useRef<AbortController | null>(null);
// Use the enhanced streaming hook
const {
thinking,
isThinking,
toolCalls,
currentPhase: streamPhase,
progress: streamProgress,
connect: connectStream,
disconnect: disconnectStream,
isConnected: isStreamConnected,
} = useAgentStream(taskId || null, {
includeThinking: true,
includeToolCalls: true,
onFinding: () => loadFindings(),
onComplete: () => {
loadTask();
loadFindings();
},
onError: (err) => {
console.error("Stream error:", err);
},
});
// Whether the task has finished
const isComplete = task?.status === "completed" || task?.status === "failed" || task?.status === "cancelled";
@ -134,11 +145,16 @@ export default function AgentAuditPage() {
try {
const taskData = await getAgentTask(taskId);
setTask(taskData);
} catch (error: any) {
console.error("Failed to load task:", error);
const errorMessage = error?.response?.data?.detail || error?.message || "Unknown error";
toast.error(`Failed to load task: ${errorMessage}`);
// A 404 likely means the task does not exist
if (error?.response?.status === 404) {
setTimeout(() => navigate("/tasks"), 2000);
}
}
}, [taskId, navigate]);
// Load events
const loadEvents = useCallback(async () => {
@ -164,6 +180,32 @@ export default function AgentAuditPage() {
}
}, [taskId]);
// 🔥 Stabilize callbacks to avoid recreating connect/disconnect
const streamOptions = useMemo(() => ({
includeThinking: true,
includeToolCalls: true,
onFinding: () => loadFindings(),
onComplete: () => {
loadTask();
loadFindings();
},
onError: (err: string) => {
console.error("Stream error:", err);
},
}), [loadFindings, loadTask]);
// Use the enhanced streaming hook
const {
thinking,
isThinking,
toolCalls,
// currentPhase: streamPhase, // not used yet
// progress: streamProgress, // not used yet
connect: connectStream,
disconnect: disconnectStream,
isConnected: isStreamConnected,
} = useAgentStream(taskId || null, streamOptions);
// Initial load
useEffect(() => {
const init = async () => {
@ -175,10 +217,11 @@ export default function AgentAuditPage() {
init();
}, [loadTask, loadEvents, loadFindings]);
// 🔥 Use the enhanced streaming API (preferred)
useEffect(() => {
if (!taskId || isComplete || isLoading) return;
// Connect the streaming API
connectStream();
setIsStreaming(true);
@ -186,49 +229,44 @@ export default function AgentAuditPage() {
disconnectStream();
setIsStreaming(false);
};
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [taskId, isComplete, isLoading]); // connectStream/disconnectStream are stable; no need to list them
// 🔥 Fallback: if the streaming connection fails, poll for events instead
useEffect(() => {
if (!taskId || isComplete || isLoading) return;
// If the stream is connected, polling is unnecessary
if (isStreamConnected) {
return;
}
// Poll for events every 5 seconds (fallback mechanism only)
const pollInterval = setInterval(async () => {
try {
const lastSequence = events.length > 0 ? Math.max(...events.map(e => e.sequence)) : 0;
const newEvents = await getAgentEvents(taskId, { after_sequence: lastSequence, limit: 50 });
if (newEvents.length > 0) {
setEvents(prev => {
// Merge new events, avoiding duplicates
const existingIds = new Set(prev.map(e => e.id));
const uniqueNew = newEvents.filter(e => !existingIds.has(e.id));
return [...prev, ...uniqueNew];
});
// If any finding events arrived, refresh the findings list
if (newEvents.some(e => e.event_type.startsWith("finding_"))) {
loadFindings();
}
}
} catch (error) {
console.error("Failed to poll events:", error);
}
}, 5000);
return () => clearInterval(pollInterval);
}, [taskId, isComplete, isLoading, isStreamConnected, events.length, loadFindings]);
// Auto-scroll
useEffect(() => {
@ -244,17 +282,17 @@ export default function AgentAuditPage() {
return () => clearInterval(interval);
}, []);
// Poll the task status periodically (fallback for SSE) - 🔥 longer interval to avoid resource exhaustion
useEffect(() => {
if (!taskId || isComplete || isLoading) return;
// 🔥 Poll every 10 seconds (instead of 3) to reduce resource usage
const pollInterval = setInterval(async () => {
try {
const taskData = await getAgentTask(taskId);
setTask(taskData);
// When the task is completed/failed/cancelled, refresh other data and stop polling
if (taskData.status === "completed" || taskData.status === "failed" || taskData.status === "cancelled") {
await loadEvents();
await loadFindings();
@ -262,11 +300,13 @@ export default function AgentAuditPage() {
}
} catch (error) {
console.error("Failed to poll task status:", error);
// 🔥 On failure, stop polling to avoid resource exhaustion
clearInterval(pollInterval);
}
}, 10000); // 🔥 now 10 seconds
return () => clearInterval(pollInterval);
}, [taskId, isComplete, isLoading]); // 🔥 function deps removed
// Cancel the task
const handleCancel = async () => {
@ -371,57 +411,73 @@ export default function AgentAuditPage() {
{/* Left: execution log */}
<div className="flex-1 p-4 flex flex-col min-w-0">
{/* 🧠 LLM thinking panel - the core! Shows the LLM's "brain activity" */}
{(isThinking || thinking) && showThinking && (
<div className="mb-4 bg-purple-950/40 rounded-lg border-2 border-purple-700/60 overflow-hidden shadow-lg shadow-purple-900/20">
<div
className="flex items-center justify-between px-4 py-3 bg-purple-900/50 border-b border-purple-700/50 cursor-pointer"
onClick={() => setShowThinking(!showThinking)}
>
<div className="flex items-center gap-3 text-sm text-purple-300">
<div className="p-1.5 bg-purple-800/50 rounded-lg">
<Brain className={`w-5 h-5 ${isThinking ? "animate-pulse" : ""}`} />
</div>
<div>
<span className="uppercase tracking-wider font-semibold">🧠 LLM Thinking</span>
<span className="text-purple-400 ml-2 text-xs">what the Agent is thinking</span>
</div>
{isThinking && (
<span className="flex items-center gap-1 text-purple-200 bg-purple-800/50 px-2 py-0.5 rounded-full text-xs">
<Sparkles className="w-3 h-3 animate-spin" />
<span>thinking...</span>
</span>
)}
</div>
{showThinking ? <ChevronUp className="w-5 h-5 text-purple-400" /> : <ChevronDown className="w-5 h-5 text-purple-400" />}
</div>
<div className="max-h-52 overflow-y-auto bg-[#1a1025]">
<div className="p-4 text-sm text-purple-100 font-mono whitespace-pre-wrap leading-relaxed">
{thinking || "🤔 Thinking about the next step..."}
{isThinking && <span className="animate-pulse text-purple-400 text-lg">▊</span>}
</div>
<div ref={thinkingEndRef} />
</div>
</div>
)}
{/* 🔧 LLM tool call panel - tools the LLM decided to invoke */}
{toolCalls.length > 0 && showToolDetails && (
<div className="mb-4 bg-yellow-950/30 rounded-lg border-2 border-yellow-700/50 overflow-hidden shadow-lg shadow-yellow-900/10">
<div
className="flex items-center justify-between px-4 py-3 bg-yellow-900/30 border-b border-yellow-700/40 cursor-pointer"
onClick={() => setShowToolDetails(!showToolDetails)}
>
<div className="flex items-center gap-3 text-sm text-yellow-400">
<div className="p-1.5 bg-yellow-800/50 rounded-lg">
<Wrench className="w-5 h-5" />
</div>
<div>
<span className="uppercase tracking-wider font-semibold">🔧 LLM Tool Calls</span>
<span className="text-yellow-500 ml-2 text-xs">tools the LLM decided to call</span>
</div>
<Badge variant="outline" className="text-xs px-2 py-0.5 bg-yellow-900/50 border-yellow-600 text-yellow-300">
{toolCalls.length}
</Badge>
</div>
{showToolDetails ? <ChevronUp className="w-5 h-5 text-yellow-500" /> : <ChevronDown className="w-5 h-5 text-yellow-500" />}
</div>
<div className="max-h-52 overflow-y-auto bg-[#1a1810]">
<div className="p-3 space-y-2">
{toolCalls.slice(-5).map((tc, idx) => (
<ToolCallCard
key={`${tc.name}-${idx}`}
toolCall={{
...tc,
output: tc.output as string | Record<string, unknown> | undefined
}}
/>
))}
</div>
</div>
@ -429,17 +485,20 @@ export default function AgentAuditPage() {
)}
<div className="flex items-center justify-between mb-3">
<div className="flex items-center gap-2 text-xs text-cyan-400">
<Terminal className="w-4 h-4" />
<span className="uppercase tracking-wider">Execution Log</span>
<div className="flex items-center gap-3 text-sm text-cyan-400">
<div className="flex items-center gap-2">
<Terminal className="w-4 h-4" />
<span className="uppercase tracking-wider font-semibold">LLM Execution Log</span>
</div>
<span className="text-xs text-gray-500">LLM & </span>
{(isStreaming || isStreamConnected) && (
<span className="flex items-center gap-1 text-green-400">
<span className="flex items-center gap-1.5 text-green-400 bg-green-900/30 px-2 py-0.5 rounded-full text-xs">
<span className="w-2 h-2 bg-green-400 rounded-full animate-pulse" />
LIVE
</span>
)}
</div>
<span className="text-xs text-gray-500">{events.length} events</span>
<span className="text-xs text-gray-500">{events.length} </span>
</div>
{/* Terminal window */}
@ -649,7 +708,7 @@ function StatusBadge({ status }: { status: string }) {
);
}
// Event line component - with enhanced LLM event display
function EventLine({ event }: { event: AgentEvent }) {
const icon = eventTypeIcons[event.event_type] || <ChevronRight className="w-3 h-3 text-gray-500" />;
const colorClass = eventTypeColors[event.event_type] || "text-gray-400";
@ -658,13 +717,28 @@ function EventLine({ event }: { event: AgentEvent }) {
? new Date(event.timestamp).toLocaleTimeString("zh-CN", { hour12: false })
: "";
// Special handling for LLM thought events - render multi-line content
const isLLMThought = event.event_type === "llm_thought";
const isLLMDecision = event.event_type === "llm_decision";
const isLLMAction = event.event_type === "llm_action";
const isImportantLLMEvent = isLLMThought || isLLMDecision || isLLMAction;
// Background colors for LLM events
const bgClass = isLLMThought
? "bg-purple-950/40 border-l-2 border-purple-600"
: isLLMDecision
? "bg-yellow-950/30 border-l-2 border-yellow-600"
: isLLMAction
? "bg-orange-950/30 border-l-2 border-orange-600"
: "";
return (
<div className={`flex items-start gap-2 py-1 group hover:bg-white/5 px-2 rounded ${colorClass} ${bgClass}`}>
<span className="text-gray-600 text-xs w-20 flex-shrink-0 group-hover:text-gray-500">
{timestamp}
</span>
<span className="flex-shrink-0 mt-0.5">{icon}</span>
<span className="flex-1 text-sm break-all">
<span className={`flex-1 text-sm break-all ${isImportantLLMEvent ? "whitespace-pre-wrap" : ""}`}>
{event.message}
{event.tool_duration_ms && (
<span className="text-gray-600 ml-2">({event.tool_duration_ms}ms)</span>
@ -679,7 +753,7 @@ interface ToolCallProps {
toolCall: {
name: string;
input: Record<string, unknown>;
output?: string | Record<string, unknown>;
durationMs?: number;
status: 'running' | 'success' | 'error';
};

View File

@ -106,6 +106,9 @@ export class AgentStreamHandler {
private reconnectDelay = 1000;
private isConnected = false;
private thinkingBuffer: string[] = [];
private reader: ReadableStreamDefaultReader<Uint8Array> | null = null; // 🔥 keep a reference to the reader
private abortController: AbortController | null = null; // 🔥 used to cancel the request
private isDisconnecting = false; // 🔥 marks an in-progress disconnect
constructor(taskId: string, options: StreamOptions = {}) {
this.taskId = taskId;
@ -121,6 +124,11 @@ export class AgentStreamHandler {
* Connect to the event stream
*/
connect(): void {
// 🔥 Do not reconnect if already connected or while disconnecting
if (this.isConnected || this.isDisconnecting) {
return;
}
const token = localStorage.getItem('access_token');
if (!token) {
this.options.onError?.('Not logged in');
@ -142,14 +150,23 @@ export class AgentStreamHandler {
* Connect using fetch so custom headers can be sent
*/
private async connectWithFetch(token: string, params: URLSearchParams): Promise<void> {
// 🔥 Do not connect while a disconnect is in progress
if (this.isDisconnecting) {
return;
}
const url = `/api/v1/agent-tasks/${this.taskId}/stream?${params}`;
// 🔥 Create an AbortController so the request can be cancelled
this.abortController = new AbortController();
try {
const response = await fetch(url, {
headers: {
'Authorization': `Bearer ${token}`,
'Accept': 'text/event-stream',
},
signal: this.abortController.signal, // 🔥 cancellable
});
if (!response.ok) {
@ -159,8 +176,8 @@ export class AgentStreamHandler {
this.isConnected = true;
this.reconnectAttempts = 0;
this.reader = response.body?.getReader() || null;
if (!this.reader) {
throw new Error('Could not obtain the response stream');
}
@ -168,7 +185,12 @@ export class AgentStreamHandler {
let buffer = '';
while (true) {
// 🔥 Stop if a disconnect is in progress
if (this.isDisconnecting) {
break;
}
const { done, value } = await this.reader.read();
if (done) {
break;
@ -184,17 +206,42 @@ export class AgentStreamHandler {
this.handleEvent(event);
}
}
// 🔥 Normal end of stream: release the reader
if (this.reader) {
this.reader.releaseLock();
this.reader = null;
}
} catch (error: any) {
// 🔥 Ignore cancellation errors
if (error.name === 'AbortError') {
return;
}
this.isConnected = false;
console.error('Stream connection error:', error);
// 🔥 Only attempt to reconnect when not disconnecting
if (!this.isDisconnecting && this.reconnectAttempts < this.maxReconnectAttempts) {
this.reconnectAttempts++;
setTimeout(() => {
if (!this.isDisconnecting) {
this.connect();
}
}, this.reconnectDelay * this.reconnectAttempts);
} else {
this.options.onError?.(`Connection failed: ${error}`);
}
} finally {
// 🔥 Release the reader
if (this.reader) {
try {
this.reader.releaseLock();
} catch {
// Ignore release errors
}
this.reader = null;
}
}
}
@ -357,11 +404,35 @@ export class AgentStreamHandler {
* Disconnect and clean up
*/
disconnect(): void {
// 🔥 Mark the handler as disconnecting to prevent reconnects
this.isDisconnecting = true;
this.isConnected = false;
// 🔥 Abort the fetch request
if (this.abortController) {
this.abortController.abort();
this.abortController = null;
}
// 🔥 Clean up the reader
if (this.reader) {
try {
this.reader.cancel();
this.reader.releaseLock();
} catch {
// Ignore cleanup errors
}
this.reader = null;
}
// Clean up the EventSource (if used)
if (this.eventSource) {
this.eventSource.close();
this.eventSource = null;
}
// Reset the reconnect counter
this.reconnectAttempts = 0;
}
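// --- Illustrative sketch (assumption): driving AgentStreamHandler from a page;
// the option names mirror StreamOptions used above, and the task id is
// hypothetical.
//
// const handler = new AgentStreamHandler("task-123", {
//   onError: (msg) => console.error("stream error:", msg),
// });
// handler.connect();
// // ... later, e.g. on component unmount:
// handler.disconnect();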
/**