feat(agent): enhance agent functionality with LLM-driven decision-making and event handling
- Introduce LLM-driven decision-making across the agents, allowing strategies to adjust dynamically based on real-time analysis.
- Add new event types for LLM thinking, decisions, actions, and observations to enrich the event stream.
- Extend agent task responses with additional metrics for better tracking of task progress and outcomes.
- Refactor UI components to highlight LLM-related events and improve user interaction during audits.
- Update API endpoints to support the new event structures and improve overall error handling.
This commit is contained in: parent 58c918f557, commit 8938a8a3c9
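The enriched event stream is easiest to see from the consumer side. Below is a minimal sketch of an SSE client that prints the new LLM-centric events; the endpoint path, host, and auth header are assumptions for illustration, since the router prefix is not visible in this diff.

```python
import json
import httpx

# Hypothetical URL; the real prefix depends on how the agent router is mounted.
STREAM_URL = "http://localhost:8000/api/v1/agent-tasks/{task_id}/stream"

def follow_agent_events(task_id: str, token: str) -> None:
    """Print LLM thinking/decision/action/observation events until the task ends."""
    url = STREAM_URL.format(task_id=task_id)
    headers = {"Authorization": f"Bearer {token}"}
    with httpx.stream("GET", url, headers=headers, timeout=None) as response:
        for line in response.iter_lines():
            if not line.startswith("data: "):
                continue  # skip heartbeats and "event:" framing lines
            event = json.loads(line[len("data: "):])
            etype = event.get("type")
            if etype in {"thinking", "llm_thought", "llm_decision", "llm_action", "llm_observation"}:
                print(f"[{etype}] {event.get('message', '')}")
            elif etype == "finding":
                print(f"!! finding: {event.get('message', '')}")
            elif etype == "task_end":
                print(f"task finished with status {event.get('status')}")
                break

if __name__ == "__main__":
    follow_agent_events("some-task-id", token="...")
```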
@ -75,18 +75,46 @@ class AgentTaskCreate(BaseModel):
class AgentTaskResponse(BaseModel):
    """Agent task response"""
    """Agent task response - includes every field the frontend needs"""
    id: str
    project_id: str
    name: Optional[str]
    description: Optional[str]
    task_type: str = "agent_audit"
    status: str
    current_phase: Optional[str]
    current_step: Optional[str] = None

    # Statistics
    total_findings: int = 0
    verified_findings: int = 0
    security_score: Optional[int] = None
    # Progress statistics
    total_files: int = 0
    indexed_files: int = 0
    analyzed_files: int = 0
    total_chunks: int = 0

    # Agent statistics
    total_iterations: int = 0
    tool_calls_count: int = 0
    tokens_used: int = 0

    # Finding statistics (both naming conventions kept for compatibility)
    findings_count: int = 0
    total_findings: int = 0  # compatibility alias
    verified_count: int = 0
    verified_findings: int = 0  # compatibility alias
    false_positive_count: int = 0

    # Severity statistics
    critical_count: int = 0
    high_count: int = 0
    medium_count: int = 0
    low_count: int = 0

    # Scores
    quality_score: float = 0.0
    security_score: Optional[float] = None

    # Progress percentage
    progress_percentage: float = 0.0

    # Timestamps
    created_at: datetime
@ -94,7 +122,11 @@ class AgentTaskResponse(BaseModel):
    completed_at: Optional[datetime] = None

    # Configuration
    config: Optional[dict] = None
    audit_scope: Optional[dict] = None
    target_vulnerabilities: Optional[List[str]] = None
    verification_level: Optional[str] = None
    exclude_patterns: Optional[List[str]] = None
    target_files: Optional[List[str]] = None

    # Error information
    error_message: Optional[str] = None
@ -177,6 +209,12 @@ async def _execute_agent_task(task_id: str, project_root: str):
        logger.error(f"Task {task_id} not found")
        return

    # Mark the task as running
    task.status = AgentTaskStatus.RUNNING
    task.started_at = datetime.now(timezone.utc)
    await db.commit()
    logger.info(f"Task {task_id} started")

    # Create the runner
    runner = AgentRunner(db, task, project_root)
    _running_tasks[task_id] = runner
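This diff does not show how `_execute_agent_task` is launched; typically the task-creation endpoint schedules it in the background. A sketch of one common way to do that, assuming the call site has the task id and project root at hand (the helper name and module-level set are illustrative, not part of this commit):

```python
import asyncio

_background_tasks: set[asyncio.Task] = set()

def schedule_agent_task(task_id: str, project_root: str) -> asyncio.Task:
    """Fire-and-forget execution so the HTTP request can return immediately."""
    background = asyncio.create_task(_execute_agent_task(task_id, project_root))
    _background_tasks.add(background)                      # keep a strong reference
    background.add_done_callback(_background_tasks.discard)
    return background
```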
@ -184,22 +222,37 @@ async def _execute_agent_task(task_id: str, project_root: str):
|
|||
# 执行
|
||||
result = await runner.run()
|
||||
|
||||
logger.info(f"Task {task_id} completed: {result.get('success', False)}")
|
||||
# 更新任务状态
|
||||
await db.refresh(task)
|
||||
if result.get('success', True): # 默认成功,除非明确失败
|
||||
task.status = AgentTaskStatus.COMPLETED
|
||||
task.completed_at = datetime.now(timezone.utc)
|
||||
else:
|
||||
task.status = AgentTaskStatus.FAILED
|
||||
task.error_message = result.get('error', 'Unknown error')
|
||||
task.completed_at = datetime.now(timezone.utc)
|
||||
|
||||
await db.commit()
|
||||
logger.info(f"Task {task_id} completed with status: {task.status}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Task {task_id} failed: {e}", exc_info=True)
|
||||
|
||||
# 更新任务状态
|
||||
task = await db.get(AgentTask, task_id)
|
||||
if task:
|
||||
task.status = AgentTaskStatus.FAILED
|
||||
task.error_message = str(e)
|
||||
task.completed_at = datetime.now(timezone.utc)
|
||||
await db.commit()
|
||||
try:
|
||||
task = await db.get(AgentTask, task_id)
|
||||
if task:
|
||||
task.status = AgentTaskStatus.FAILED
|
||||
task.error_message = str(e)[:1000] # 限制错误消息长度
|
||||
task.completed_at = datetime.now(timezone.utc)
|
||||
await db.commit()
|
||||
except Exception as db_error:
|
||||
logger.error(f"Failed to update task status: {db_error}")
|
||||
|
||||
finally:
|
||||
# 清理
|
||||
_running_tasks.pop(task_id, None)
|
||||
logger.debug(f"Task {task_id} cleaned up")
|
||||
|
||||
|
||||
# ============ API Endpoints ============
|
||||
|
|
@ -315,7 +368,61 @@ async def get_agent_task(
|
|||
if not project or project.owner_id != current_user.id:
|
||||
raise HTTPException(status_code=403, detail="无权访问此任务")
|
||||
|
||||
return task
|
||||
# 构建响应,确保所有字段都包含
|
||||
try:
|
||||
# 计算进度百分比
|
||||
progress = 0.0
|
||||
if hasattr(task, 'progress_percentage'):
|
||||
progress = task.progress_percentage
|
||||
elif task.status == AgentTaskStatus.COMPLETED:
|
||||
progress = 100.0
|
||||
elif task.status in [AgentTaskStatus.FAILED, AgentTaskStatus.CANCELLED]:
|
||||
progress = 0.0
|
||||
|
||||
# 手动构建响应数据
|
||||
response_data = {
|
||||
"id": task.id,
|
||||
"project_id": task.project_id,
|
||||
"name": task.name,
|
||||
"description": task.description,
|
||||
"task_type": task.task_type or "agent_audit",
|
||||
"status": task.status,
|
||||
"current_phase": task.current_phase,
|
||||
"current_step": task.current_step,
|
||||
"total_files": task.total_files or 0,
|
||||
"indexed_files": task.indexed_files or 0,
|
||||
"analyzed_files": task.analyzed_files or 0,
|
||||
"total_chunks": task.total_chunks or 0,
|
||||
"total_iterations": task.total_iterations or 0,
|
||||
"tool_calls_count": task.tool_calls_count or 0,
|
||||
"tokens_used": task.tokens_used or 0,
|
||||
"findings_count": task.findings_count or 0,
|
||||
"total_findings": task.findings_count or 0, # 兼容字段
|
||||
"verified_count": task.verified_count or 0,
|
||||
"verified_findings": task.verified_count or 0, # 兼容字段
|
||||
"false_positive_count": task.false_positive_count or 0,
|
||||
"critical_count": task.critical_count or 0,
|
||||
"high_count": task.high_count or 0,
|
||||
"medium_count": task.medium_count or 0,
|
||||
"low_count": task.low_count or 0,
|
||||
"quality_score": float(task.quality_score or 0.0),
|
||||
"security_score": float(task.security_score) if task.security_score is not None else None,
|
||||
"progress_percentage": progress,
|
||||
"created_at": task.created_at,
|
||||
"started_at": task.started_at,
|
||||
"completed_at": task.completed_at,
|
||||
"error_message": task.error_message,
|
||||
"audit_scope": task.audit_scope,
|
||||
"target_vulnerabilities": task.target_vulnerabilities,
|
||||
"verification_level": task.verification_level,
|
||||
"exclude_patterns": task.exclude_patterns,
|
||||
"target_files": task.target_files,
|
||||
}
|
||||
|
||||
return AgentTaskResponse(**response_data)
|
||||
except Exception as e:
|
||||
logger.error(f"Error serializing task {task_id}: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"序列化任务数据失败: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/{task_id}/cancel")
|
||||
|
|
@ -396,10 +503,14 @@ async def stream_agent_events(
|
|||
idle_time = 0
|
||||
for event in events:
|
||||
last_sequence = event.sequence
|
||||
# event_type 已经是字符串,不需要 .value
|
||||
event_type_str = str(event.event_type)
|
||||
phase_str = str(event.phase) if event.phase else None
|
||||
|
||||
data = {
|
||||
"id": event.id,
|
||||
"type": event.event_type.value if hasattr(event.event_type, 'value') else str(event.event_type),
|
||||
"phase": event.phase.value if event.phase and hasattr(event.phase, 'value') else event.phase,
|
||||
"type": event_type_str,
|
||||
"phase": phase_str,
|
||||
"message": event.message,
|
||||
"sequence": event.sequence,
|
||||
"timestamp": event.created_at.isoformat() if event.created_at else None,
|
||||
|
|
@ -411,9 +522,12 @@ async def stream_agent_events(
|
|||
idle_time += poll_interval
|
||||
|
||||
# 检查任务是否结束
|
||||
if task_status in [AgentTaskStatus.COMPLETED, AgentTaskStatus.FAILED, AgentTaskStatus.CANCELLED]:
|
||||
yield f"data: {json.dumps({'type': 'task_end', 'status': task_status.value})}\n\n"
|
||||
break
|
||||
if task_status:
|
||||
# task_status 可能是字符串或枚举,统一转换为字符串
|
||||
status_str = str(task_status)
|
||||
if status_str in ["completed", "failed", "cancelled"]:
|
||||
yield f"data: {json.dumps({'type': 'task_end', 'status': status_str})}\n\n"
|
||||
break
|
||||
|
||||
# 检查空闲超时
|
||||
if idle_time >= max_idle:
|
||||
|
|
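A note on the `str(...)` normalization used above: it is only equivalent to the old `.value` handling if the columns really hold plain strings. If `event_type` or the task status were still an `enum.Enum` member, `str(...)` would yield `ClassName.MEMBER`, and the `in ["completed", "failed", "cancelled"]` check would never match. A small defensive helper, shown here as a sketch with an illustrative enum, covers both cases:

```python
from enum import Enum
from typing import Any

def as_plain_str(value: Any) -> str:
    """Return the underlying value for Enum members, str(value) otherwise."""
    return str(value.value) if isinstance(value, Enum) else str(value)

class Phase(Enum):          # illustrative stand-in, not the real model enum
    RECON = "recon"

print(str(Phase.RECON))           # "Phase.RECON"  -> would never match "recon"
print(as_plain_str(Phase.RECON))  # "recon"
print(as_plain_str("recon"))      # "recon" (plain strings pass through unchanged)
```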
@ -512,8 +626,8 @@ async def stream_agent_with_thinking(
|
|||
for event in events:
|
||||
last_sequence = event.sequence
|
||||
|
||||
# 获取事件类型字符串
|
||||
event_type = event.event_type.value if hasattr(event.event_type, 'value') else str(event.event_type)
|
||||
# 获取事件类型字符串(event_type 已经是字符串)
|
||||
event_type = str(event.event_type)
|
||||
|
||||
# 过滤事件
|
||||
if event_type in skip_types:
|
||||
|
|
@ -523,7 +637,7 @@ async def stream_agent_with_thinking(
|
|||
data = {
|
||||
"id": event.id,
|
||||
"type": event_type,
|
||||
"phase": event.phase.value if event.phase and hasattr(event.phase, 'value') else event.phase,
|
||||
"phase": str(event.phase) if event.phase else None,
|
||||
"message": event.message,
|
||||
"sequence": event.sequence,
|
||||
"timestamp": event.created_at.isoformat() if event.created_at else None,
|
||||
|
|
@ -552,14 +666,16 @@ async def stream_agent_with_thinking(
|
|||
idle_time += poll_interval
|
||||
|
||||
# 检查任务是否结束
|
||||
if task_status in [AgentTaskStatus.COMPLETED, AgentTaskStatus.FAILED, AgentTaskStatus.CANCELLED]:
|
||||
end_data = {
|
||||
"type": "task_end",
|
||||
"status": task_status.value,
|
||||
"message": f"任务{'完成' if task_status == AgentTaskStatus.COMPLETED else '结束'}",
|
||||
}
|
||||
yield f"event: task_end\ndata: {json.dumps(end_data, ensure_ascii=False)}\n\n"
|
||||
break
|
||||
if task_status:
|
||||
status_str = str(task_status)
|
||||
if status_str in ["completed", "failed", "cancelled"]:
|
||||
end_data = {
|
||||
"type": "task_end",
|
||||
"status": status_str,
|
||||
"message": f"任务{'完成' if status_str == 'completed' else '结束'}",
|
||||
}
|
||||
yield f"event: task_end\ndata: {json.dumps(end_data, ensure_ascii=False)}\n\n"
|
||||
break
|
||||
|
||||
# 发送心跳
|
||||
last_heartbeat += poll_interval
|
||||
|
|
@ -709,8 +825,9 @@ async def get_task_summary(
|
|||
verified_count = 0
|
||||
|
||||
for f in findings:
|
||||
sev = f.severity.value if hasattr(f.severity, 'value') else str(f.severity)
|
||||
vtype = f.vulnerability_type.value if hasattr(f.vulnerability_type, 'value') else str(f.vulnerability_type)
|
||||
# severity 和 vulnerability_type 已经是字符串
|
||||
sev = str(f.severity)
|
||||
vtype = str(f.vulnerability_type)
|
||||
|
||||
severity_distribution[sev] = severity_distribution.get(sev, 0) + 1
|
||||
vulnerability_types[vtype] = vulnerability_types.get(vtype, 0) + 1
|
||||
|
|
@ -730,11 +847,11 @@ async def get_task_summary(
|
|||
.where(AgentEvent.event_type == AgentEventType.PHASE_COMPLETE)
|
||||
.distinct()
|
||||
)
|
||||
phases = [p[0].value if p[0] and hasattr(p[0], 'value') else str(p[0]) for p in phases_result.fetchall() if p[0]]
|
||||
phases = [str(p[0]) for p in phases_result.fetchall() if p[0]]
|
||||
|
||||
return TaskSummaryResponse(
|
||||
task_id=task_id,
|
||||
status=task.status.value if hasattr(task.status, 'value') else str(task.status),
|
||||
status=str(task.status), # status 已经是字符串
|
||||
security_score=task.security_score,
|
||||
total_findings=len(findings),
|
||||
verified_findings=verified_count,
|
||||
|
|
|
|||
File diff suppressed because it is too large.
|
|
@ -1,6 +1,8 @@
|
|||
"""
|
||||
Agent 基类
|
||||
定义 Agent 的基本接口和通用功能
|
||||
|
||||
核心原则:LLM 是 Agent 的大脑,所有日志应该反映 LLM 的参与!
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
|
@ -87,7 +89,11 @@ class AgentResult:
|
|||
class BaseAgent(ABC):
|
||||
"""
|
||||
Agent 基类
|
||||
所有 Agent 需要继承此类并实现核心方法
|
||||
|
||||
核心原则:
|
||||
1. LLM 是 Agent 的大脑,全程参与决策
|
||||
2. 所有日志应该反映 LLM 的思考过程
|
||||
3. 工具调用是 LLM 的决策结果
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
|
@ -146,6 +152,8 @@ class BaseAgent(ABC):
|
|||
def is_cancelled(self) -> bool:
|
||||
return self._cancelled
|
||||
|
||||
# ============ 核心事件发射方法 ============
|
||||
|
||||
async def emit_event(
|
||||
self,
|
||||
event_type: str,
|
||||
|
|
@ -161,28 +169,123 @@ class BaseAgent(ABC):
|
|||
**kwargs
|
||||
))
|
||||
|
||||
# ============ LLM 思考相关事件 ============
|
||||
|
||||
async def emit_thinking(self, message: str):
|
||||
"""发射思考事件"""
|
||||
await self.emit_event("thinking", f"[{self.name}] {message}")
|
||||
"""发射 LLM 思考事件"""
|
||||
await self.emit_event("thinking", f"🧠 [{self.name}] {message}")
|
||||
|
||||
async def emit_llm_start(self, iteration: int):
|
||||
"""发射 LLM 开始思考事件"""
|
||||
await self.emit_event(
|
||||
"llm_start",
|
||||
f"🤔 [{self.name}] LLM 开始第 {iteration} 轮思考...",
|
||||
metadata={"iteration": iteration}
|
||||
)
|
||||
|
||||
async def emit_llm_thought(self, thought: str, iteration: int):
|
||||
"""发射 LLM 思考内容事件 - 这是核心!展示 LLM 在想什么"""
|
||||
# 截断过长的思考内容
|
||||
display_thought = thought[:500] + "..." if len(thought) > 500 else thought
|
||||
await self.emit_event(
|
||||
"llm_thought",
|
||||
f"💭 [{self.name}] LLM 思考:\n{display_thought}",
|
||||
metadata={
|
||||
"thought": thought,
|
||||
"iteration": iteration,
|
||||
}
|
||||
)
|
||||
|
||||
async def emit_llm_decision(self, decision: str, reason: str = ""):
|
||||
"""发射 LLM 决策事件 - 展示 LLM 做了什么决定"""
|
||||
await self.emit_event(
|
||||
"llm_decision",
|
||||
f"💡 [{self.name}] LLM 决策: {decision}" + (f" (理由: {reason})" if reason else ""),
|
||||
metadata={
|
||||
"decision": decision,
|
||||
"reason": reason,
|
||||
}
|
||||
)
|
||||
|
||||
async def emit_llm_action(self, action: str, action_input: Dict):
|
||||
"""发射 LLM 动作事件 - LLM 决定执行什么动作"""
|
||||
import json
|
||||
input_str = json.dumps(action_input, ensure_ascii=False)[:200]
|
||||
await self.emit_event(
|
||||
"llm_action",
|
||||
f"⚡ [{self.name}] LLM 动作: {action}\n 参数: {input_str}",
|
||||
metadata={
|
||||
"action": action,
|
||||
"action_input": action_input,
|
||||
}
|
||||
)
|
||||
|
||||
async def emit_llm_observation(self, observation: str):
|
||||
"""发射 LLM 观察事件 - LLM 看到了什么"""
|
||||
display_obs = observation[:300] + "..." if len(observation) > 300 else observation
|
||||
await self.emit_event(
|
||||
"llm_observation",
|
||||
f"👁️ [{self.name}] LLM 观察到:\n{display_obs}",
|
||||
metadata={"observation": observation[:2000]}
|
||||
)
|
||||
|
||||
async def emit_llm_complete(self, result_summary: str, tokens_used: int):
|
||||
"""发射 LLM 完成事件"""
|
||||
await self.emit_event(
|
||||
"llm_complete",
|
||||
f"✅ [{self.name}] LLM 完成: {result_summary} (消耗 {tokens_used} tokens)",
|
||||
metadata={
|
||||
"tokens_used": tokens_used,
|
||||
}
|
||||
)
|
||||
|
||||
# ============ 工具调用相关事件 ============
|
||||
|
||||
async def emit_tool_call(self, tool_name: str, tool_input: Dict):
|
||||
"""发射工具调用事件"""
|
||||
"""发射工具调用事件 - LLM 决定调用工具"""
|
||||
import json
|
||||
input_str = json.dumps(tool_input, ensure_ascii=False)[:300]
|
||||
await self.emit_event(
|
||||
"tool_call",
|
||||
f"[{self.name}] 调用工具: {tool_name}",
|
||||
f"🔧 [{self.name}] LLM 调用工具: {tool_name}\n 输入: {input_str}",
|
||||
tool_name=tool_name,
|
||||
tool_input=tool_input,
|
||||
)
|
||||
|
||||
async def emit_tool_result(self, tool_name: str, result: str, duration_ms: int):
|
||||
"""发射工具结果事件"""
|
||||
result_preview = result[:200] + "..." if len(result) > 200 else result
|
||||
await self.emit_event(
|
||||
"tool_result",
|
||||
f"[{self.name}] {tool_name} 完成 ({duration_ms}ms)",
|
||||
f"📤 [{self.name}] 工具 {tool_name} 返回 ({duration_ms}ms):\n {result_preview}",
|
||||
tool_name=tool_name,
|
||||
tool_duration_ms=duration_ms,
|
||||
)
|
||||
|
||||
# ============ 发现相关事件 ============
|
||||
|
||||
async def emit_finding(self, title: str, severity: str, vuln_type: str, file_path: str = ""):
|
||||
"""发射漏洞发现事件"""
|
||||
severity_emoji = {
|
||||
"critical": "🔴",
|
||||
"high": "🟠",
|
||||
"medium": "🟡",
|
||||
"low": "🟢",
|
||||
}.get(severity.lower(), "⚪")
|
||||
|
||||
await self.emit_event(
|
||||
"finding",
|
||||
f"{severity_emoji} [{self.name}] 发现漏洞: [{severity.upper()}] {title}\n 类型: {vuln_type}\n 位置: {file_path}",
|
||||
metadata={
|
||||
"title": title,
|
||||
"severity": severity,
|
||||
"vulnerability_type": vuln_type,
|
||||
"file_path": file_path,
|
||||
}
|
||||
)
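Taken together, these emitters are meant to make every stage of a ReAct-style loop visible to the frontend. The sketch below shows how a driver might wire them up; `parse_step` is a hypothetical stand-in for the concrete agents' own `_parse_llm_response` helpers, and the function itself is illustrative rather than part of this commit.

```python
async def react_loop(agent, messages, parse_step):
    """Drive a ReAct loop over `agent`, emitting LLM events at every stage.

    `agent` is any BaseAgent subclass from this commit; `parse_step` turns raw
    LLM text into a (thought, action, action_input, is_final) tuple.
    """
    for iteration in range(1, agent.config.max_iterations + 1):
        await agent.emit_llm_start(iteration)
        response = await agent.llm_service.chat_completion_raw(messages=messages)
        content = response.get("content", "")
        thought, action, action_input, is_final = parse_step(content)

        if thought:
            await agent.emit_llm_thought(thought, iteration)
        messages.append({"role": "assistant", "content": content})

        if is_final:
            await agent.emit_llm_complete("final answer produced", agent._total_tokens)
            return action_input

        await agent.emit_llm_action(action, action_input)
        result = await agent.call_tool(action, **action_input)  # emits tool_call / tool_result itself
        await agent.emit_llm_observation(str(result.data))
        messages.append({"role": "user", "content": f"Observation:\n{result.data}"})
```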
|
||||
|
||||
# ============ 通用工具方法 ============
|
||||
|
||||
async def call_tool(self, tool_name: str, **kwargs) -> Any:
|
||||
"""
|
||||
调用工具
|
||||
|
|
@ -208,7 +311,7 @@ class BaseAgent(ABC):
|
|||
result = await tool.execute(**kwargs)
|
||||
|
||||
duration_ms = int((time.time() - start) * 1000)
|
||||
await self.emit_tool_result(tool_name, str(result.data)[:200], duration_ms)
|
||||
await self.emit_tool_result(tool_name, str(result.data)[:500], duration_ms)
|
||||
|
||||
return result
|
||||
|
||||
|
|
@ -229,8 +332,9 @@ class BaseAgent(ABC):
|
|||
"""
|
||||
self._iteration += 1
|
||||
|
||||
# 这里应该调用实际的 LLM 服务
|
||||
# 使用 LangChain 或直接调用 API
|
||||
# 发射 LLM 开始事件
|
||||
await self.emit_llm_start(self._iteration)
|
||||
|
||||
try:
|
||||
response = await self.llm_service.chat_completion(
|
||||
messages=messages,
|
||||
|
|
@ -281,4 +385,3 @@ class BaseAgent(ABC):
|
|||
"tool_calls": self._tool_calls,
|
||||
"tokens_used": self._total_tokens,
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,12 +1,19 @@
|
|||
"""
|
||||
Orchestrator Agent (编排层)
|
||||
负责任务分解、子 Agent 调度和结果汇总
|
||||
Orchestrator Agent (编排层) - LLM 驱动版
|
||||
|
||||
类型: Plan-and-Execute
|
||||
LLM 是真正的大脑,全程参与决策!
|
||||
- LLM 决定下一步做什么
|
||||
- LLM 决定调度哪个子 Agent
|
||||
- LLM 决定何时完成
|
||||
- LLM 根据中间结果动态调整策略
|
||||
|
||||
类型: Autonomous Agent with Dynamic Planning
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import List, Dict, Any, Optional
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
|
@ -15,69 +22,97 @@ from .base import BaseAgent, AgentConfig, AgentResult, AgentType, AgentPattern
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AuditPlan:
|
||||
"""审计计划"""
|
||||
phases: List[Dict[str, Any]]
|
||||
high_risk_areas: List[str]
|
||||
focus_vulnerabilities: List[str]
|
||||
estimated_steps: int
|
||||
priority_files: List[str]
|
||||
metadata: Dict[str, Any]
|
||||
ORCHESTRATOR_SYSTEM_PROMPT = """你是 DeepAudit 的编排 Agent,负责**自主**协调整个安全审计流程。
|
||||
|
||||
## 你的角色
|
||||
你是整个审计流程的**大脑**,不是一个机械执行者。你需要:
|
||||
1. 自主思考和决策
|
||||
2. 根据观察结果动态调整策略
|
||||
3. 决定何时调用哪个子 Agent
|
||||
4. 判断何时审计完成
|
||||
|
||||
ORCHESTRATOR_SYSTEM_PROMPT = """你是 DeepAudit 的编排 Agent,负责协调整个安全审计流程。
|
||||
## 你可以调度的子 Agent
|
||||
1. **recon**: 信息收集 Agent - 分析项目结构、技术栈、入口点
|
||||
2. **analysis**: 分析 Agent - 深度代码审计、漏洞检测
|
||||
3. **verification**: 验证 Agent - 验证发现的漏洞、生成 PoC
|
||||
|
||||
## 你的职责
|
||||
1. 分析项目信息,制定审计计划
|
||||
2. 调度子 Agent(Recon、Analysis、Verification)执行任务
|
||||
3. 汇总审计结果,生成报告
|
||||
## 你可以使用的操作
|
||||
|
||||
## 审计流程
|
||||
1. **信息收集阶段**: 调度 Recon Agent 收集项目信息
|
||||
- 项目结构分析
|
||||
- 技术栈识别
|
||||
- 入口点识别
|
||||
- 依赖分析
|
||||
|
||||
2. **漏洞分析阶段**: 调度 Analysis Agent 进行代码分析
|
||||
- 静态代码分析
|
||||
- 语义搜索
|
||||
- 模式匹配
|
||||
- 数据流追踪
|
||||
|
||||
3. **漏洞验证阶段**: 调度 Verification Agent 验证发现
|
||||
- 漏洞确认
|
||||
- PoC 生成
|
||||
- 沙箱测试
|
||||
|
||||
4. **报告生成阶段**: 汇总所有发现,生成最终报告
|
||||
|
||||
## 输出格式
|
||||
当生成审计计划时,返回 JSON:
|
||||
```json
|
||||
{
|
||||
"phases": [
|
||||
{"name": "阶段名", "description": "描述", "agent": "agent_type"}
|
||||
],
|
||||
"high_risk_areas": ["高风险目录/文件"],
|
||||
"focus_vulnerabilities": ["重点漏洞类型"],
|
||||
"priority_files": ["优先审计的文件"],
|
||||
"estimated_steps": 数字
|
||||
}
|
||||
### 1. 调度子 Agent
|
||||
```
|
||||
Action: dispatch_agent
|
||||
Action Input: {"agent": "recon|analysis|verification", "task": "具体任务描述", "context": "任务上下文"}
|
||||
```
|
||||
|
||||
请基于项目信息制定合理的审计计划。"""
|
||||
### 2. 汇总发现
|
||||
```
|
||||
Action: summarize
|
||||
Action Input: {"findings": [...], "analysis": "你的分析"}
|
||||
```
|
||||
|
||||
### 3. 完成审计
|
||||
```
|
||||
Action: finish
|
||||
Action Input: {"conclusion": "审计结论", "findings": [...], "recommendations": [...]}
|
||||
```
|
||||
|
||||
## 工作方式
|
||||
每一步,你需要:
|
||||
|
||||
1. **Thought**: 分析当前状态,思考下一步应该做什么
|
||||
- 目前收集到了什么信息?
|
||||
- 还需要了解什么?
|
||||
- 应该深入分析哪些地方?
|
||||
- 有什么发现需要验证?
|
||||
|
||||
2. **Action**: 选择一个操作
|
||||
3. **Action Input**: 提供操作参数
|
||||
|
||||
## 输出格式
|
||||
每一步必须严格按照以下格式:
|
||||
|
||||
```
|
||||
Thought: [你的思考过程]
|
||||
Action: [dispatch_agent|summarize|finish]
|
||||
Action Input: [JSON 参数]
|
||||
```
|
||||
|
||||
## 审计策略建议
|
||||
- 先用 recon Agent 了解项目全貌
|
||||
- 根据 recon 结果,让 analysis Agent 重点审计高风险区域
|
||||
- 发现可疑漏洞后,用 verification Agent 验证
|
||||
- 随时根据新发现调整策略,不要机械执行
|
||||
- 当你认为审计足够全面时,选择 finish
|
||||
|
||||
## 重要原则
|
||||
1. **你是大脑,不是执行器** - 每一步都要思考
|
||||
2. **动态调整** - 根据发现调整策略
|
||||
3. **主动决策** - 不要等待,主动推进
|
||||
4. **质量优先** - 宁可深入分析几个真实漏洞,不要浅尝辄止
|
||||
|
||||
现在,基于项目信息开始你的审计工作!"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentStep:
|
||||
"""执行步骤"""
|
||||
thought: str
|
||||
action: str
|
||||
action_input: Dict[str, Any]
|
||||
observation: Optional[str] = None
|
||||
sub_agent_result: Optional[AgentResult] = None
|
||||
|
||||
|
||||
class OrchestratorAgent(BaseAgent):
|
||||
"""
|
||||
编排 Agent
|
||||
编排 Agent - LLM 驱动版
|
||||
|
||||
使用 Plan-and-Execute 模式:
|
||||
1. 首先生成审计计划
|
||||
2. 按计划调度子 Agent
|
||||
3. 收集结果并汇总
|
||||
LLM 全程参与决策:
|
||||
1. LLM 思考当前状态
|
||||
2. LLM 决定下一步操作
|
||||
3. 执行操作,获取结果
|
||||
4. LLM 分析结果,决定下一步
|
||||
5. 重复直到 LLM 决定完成
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
|
@ -90,13 +125,16 @@ class OrchestratorAgent(BaseAgent):
|
|||
config = AgentConfig(
|
||||
name="Orchestrator",
|
||||
agent_type=AgentType.ORCHESTRATOR,
|
||||
pattern=AgentPattern.PLAN_AND_EXECUTE,
|
||||
max_iterations=10,
|
||||
pattern=AgentPattern.REACT, # 改为 ReAct 模式!
|
||||
max_iterations=20,
|
||||
system_prompt=ORCHESTRATOR_SYSTEM_PROMPT,
|
||||
)
|
||||
super().__init__(config, llm_service, tools, event_emitter)
|
||||
|
||||
self.sub_agents = sub_agents or {}
|
||||
self._conversation_history: List[Dict[str, str]] = []
|
||||
self._steps: List[AgentStep] = []
|
||||
self._all_findings: List[Dict] = []
|
||||
|
||||
def register_sub_agent(self, name: str, agent: BaseAgent):
|
||||
"""注册子 Agent"""
|
||||
|
|
@ -104,7 +142,7 @@ class OrchestratorAgent(BaseAgent):
|
|||
|
||||
async def run(self, input_data: Dict[str, Any]) -> AgentResult:
|
||||
"""
|
||||
执行编排任务
|
||||
执行编排任务 - LLM 全程参与!
|
||||
|
||||
Args:
|
||||
input_data: {
|
||||
|
|
@ -118,82 +156,136 @@ class OrchestratorAgent(BaseAgent):
|
|||
project_info = input_data.get("project_info", {})
|
||||
config = input_data.get("config", {})
|
||||
|
||||
# 构建初始消息
|
||||
initial_message = self._build_initial_message(project_info, config)
|
||||
|
||||
# 初始化对话历史
|
||||
self._conversation_history = [
|
||||
{"role": "system", "content": self.config.system_prompt},
|
||||
{"role": "user", "content": initial_message},
|
||||
]
|
||||
|
||||
self._steps = []
|
||||
self._all_findings = []
|
||||
final_result = None
|
||||
|
||||
await self.emit_thinking("🧠 Orchestrator Agent 启动,LLM 开始自主编排决策...")
|
||||
|
||||
try:
|
||||
await self.emit_thinking("开始制定审计计划...")
|
||||
|
||||
# 1. 生成审计计划
|
||||
plan = await self._create_audit_plan(project_info, config)
|
||||
|
||||
if not plan:
|
||||
return AgentResult(
|
||||
success=False,
|
||||
error="无法生成审计计划",
|
||||
)
|
||||
|
||||
await self.emit_event(
|
||||
"planning",
|
||||
f"审计计划已生成,共 {len(plan.phases)} 个阶段",
|
||||
metadata={"plan": plan.__dict__}
|
||||
)
|
||||
|
||||
# 2. 执行各阶段
|
||||
all_findings = []
|
||||
phase_results = {}
|
||||
|
||||
for phase in plan.phases:
|
||||
for iteration in range(self.config.max_iterations):
|
||||
if self.is_cancelled:
|
||||
break
|
||||
|
||||
phase_name = phase.get("name", "unknown")
|
||||
agent_type = phase.get("agent", "analysis")
|
||||
self._iteration = iteration + 1
|
||||
|
||||
await self.emit_event(
|
||||
"phase_start",
|
||||
f"开始 {phase_name} 阶段",
|
||||
phase=phase_name
|
||||
# 🔥 发射 LLM 开始思考事件
|
||||
await self.emit_llm_start(iteration + 1)
|
||||
|
||||
# 🔥 调用 LLM 进行思考和决策
|
||||
response = await self.llm_service.chat_completion_raw(
|
||||
messages=self._conversation_history,
|
||||
temperature=0.1,
|
||||
max_tokens=2048,
|
||||
)
|
||||
|
||||
# 调度对应的子 Agent
|
||||
result = await self._execute_phase(
|
||||
phase_name=phase_name,
|
||||
agent_type=agent_type,
|
||||
project_info=project_info,
|
||||
config=config,
|
||||
plan=plan,
|
||||
previous_results=phase_results,
|
||||
)
|
||||
llm_output = response.get("content", "")
|
||||
tokens_this_round = response.get("usage", {}).get("total_tokens", 0)
|
||||
self._total_tokens += tokens_this_round
|
||||
|
||||
phase_results[phase_name] = result
|
||||
# 解析 LLM 的决策
|
||||
step = self._parse_llm_response(llm_output)
|
||||
|
||||
if result.success and result.data:
|
||||
if isinstance(result.data, dict):
|
||||
findings = result.data.get("findings", [])
|
||||
all_findings.extend(findings)
|
||||
if not step:
|
||||
# LLM 输出格式不正确,提示重试
|
||||
await self.emit_llm_decision("格式错误", "需要重新输出")
|
||||
self._conversation_history.append({
|
||||
"role": "assistant",
|
||||
"content": llm_output,
|
||||
})
|
||||
self._conversation_history.append({
|
||||
"role": "user",
|
||||
"content": "请按照规定格式输出:Thought + Action + Action Input",
|
||||
})
|
||||
continue
|
||||
|
||||
await self.emit_event(
|
||||
"phase_complete",
|
||||
f"{phase_name} 阶段完成",
|
||||
phase=phase_name
|
||||
)
|
||||
|
||||
# 3. 汇总结果
|
||||
await self.emit_thinking("汇总审计结果...")
|
||||
|
||||
summary = await self._generate_summary(
|
||||
plan=plan,
|
||||
phase_results=phase_results,
|
||||
all_findings=all_findings,
|
||||
)
|
||||
self._steps.append(step)
|
||||
|
||||
# 🔥 发射 LLM 思考内容事件 - 展示编排决策的思考过程
|
||||
if step.thought:
|
||||
await self.emit_llm_thought(step.thought, iteration + 1)
|
||||
|
||||
# 添加 LLM 响应到历史
|
||||
self._conversation_history.append({
|
||||
"role": "assistant",
|
||||
"content": llm_output,
|
||||
})
|
||||
|
||||
# 执行 LLM 决定的操作
|
||||
if step.action == "finish":
|
||||
# 🔥 LLM 决定完成审计
|
||||
await self.emit_llm_decision("完成审计", "LLM 判断审计已充分完成")
|
||||
await self.emit_llm_complete(
|
||||
f"编排完成,发现 {len(self._all_findings)} 个漏洞",
|
||||
self._total_tokens
|
||||
)
|
||||
final_result = step.action_input
|
||||
break
|
||||
|
||||
elif step.action == "dispatch_agent":
|
||||
# 🔥 LLM 决定调度子 Agent
|
||||
agent_name = step.action_input.get("agent", "unknown")
|
||||
task_desc = step.action_input.get("task", "")
|
||||
await self.emit_llm_decision(
|
||||
f"调度 {agent_name} Agent",
|
||||
f"任务: {task_desc[:100]}"
|
||||
)
|
||||
await self.emit_llm_action("dispatch_agent", step.action_input)
|
||||
|
||||
observation = await self._dispatch_agent(step.action_input)
|
||||
step.observation = observation
|
||||
|
||||
# 🔥 发射观察事件
|
||||
await self.emit_llm_observation(observation)
|
||||
|
||||
elif step.action == "summarize":
|
||||
# LLM 要求汇总
|
||||
await self.emit_llm_decision("汇总发现", "LLM 请求查看当前发现汇总")
|
||||
observation = self._summarize_findings()
|
||||
step.observation = observation
|
||||
await self.emit_llm_observation(observation)
|
||||
|
||||
else:
|
||||
observation = f"未知操作: {step.action},可用操作: dispatch_agent, summarize, finish"
|
||||
await self.emit_llm_decision("未知操作", observation)
|
||||
|
||||
# 添加观察结果到历史
|
||||
self._conversation_history.append({
|
||||
"role": "user",
|
||||
"content": f"Observation:\n{step.observation}",
|
||||
})
|
||||
|
||||
# 生成最终结果
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
await self.emit_event(
|
||||
"info",
|
||||
f"🎯 Orchestrator 完成: {len(self._all_findings)} 个发现, {self._iteration} 轮决策"
|
||||
)
|
||||
|
||||
return AgentResult(
|
||||
success=True,
|
||||
data={
|
||||
"plan": plan.__dict__,
|
||||
"findings": all_findings,
|
||||
"summary": summary,
|
||||
"phase_results": {k: v.to_dict() for k, v in phase_results.items()},
|
||||
"findings": self._all_findings,
|
||||
"summary": final_result or self._generate_default_summary(),
|
||||
"steps": [
|
||||
{
|
||||
"thought": s.thought,
|
||||
"action": s.action,
|
||||
"action_input": s.action_input,
|
||||
"observation": s.observation[:500] if s.observation else None,
|
||||
}
|
||||
for s in self._steps
|
||||
],
|
||||
},
|
||||
iterations=self._iteration,
|
||||
tool_calls=self._tool_calls,
|
||||
|
|
@ -208,174 +300,197 @@ class OrchestratorAgent(BaseAgent):
|
|||
error=str(e),
|
||||
)
|
||||
|
||||
async def _create_audit_plan(
|
||||
def _build_initial_message(
|
||||
self,
|
||||
project_info: Dict[str, Any],
|
||||
config: Dict[str, Any],
|
||||
) -> Optional[AuditPlan]:
|
||||
"""生成审计计划"""
|
||||
# 构建 prompt
|
||||
prompt = f"""基于以下项目信息,制定安全审计计划。
|
||||
) -> str:
|
||||
"""构建初始消息"""
|
||||
msg = f"""请开始对以下项目进行安全审计。
|
||||
|
||||
## 项目信息
|
||||
- 名称: {project_info.get('name', 'unknown')}
|
||||
- 语言: {project_info.get('languages', [])}
|
||||
- 文件数量: {project_info.get('file_count', 0)}
|
||||
- 目录结构: {project_info.get('structure', {})}
|
||||
- 目录结构: {json.dumps(project_info.get('structure', {}), ensure_ascii=False, indent=2)}
|
||||
|
||||
## 用户配置
|
||||
- 目标漏洞: {config.get('target_vulnerabilities', [])}
|
||||
- 目标漏洞: {config.get('target_vulnerabilities', ['all'])}
|
||||
- 验证级别: {config.get('verification_level', 'sandbox')}
|
||||
- 排除模式: {config.get('exclude_patterns', [])}
|
||||
|
||||
请生成审计计划,返回 JSON 格式。"""
|
||||
## 可用子 Agent
|
||||
{', '.join(self.sub_agents.keys()) if self.sub_agents else '(暂无子 Agent)'}
|
||||
|
||||
请开始你的审计工作。首先思考应该如何开展,然后决定第一步做什么。"""
|
||||
|
||||
return msg
|
||||
|
||||
def _parse_llm_response(self, response: str) -> Optional[AgentStep]:
|
||||
"""解析 LLM 响应"""
|
||||
# 提取 Thought
|
||||
thought_match = re.search(r'Thought:\s*(.*?)(?=Action:|$)', response, re.DOTALL)
|
||||
thought = thought_match.group(1).strip() if thought_match else ""
|
||||
|
||||
# 提取 Action
|
||||
action_match = re.search(r'Action:\s*(\w+)', response)
|
||||
if not action_match:
|
||||
return None
|
||||
action = action_match.group(1).strip()
|
||||
|
||||
# 提取 Action Input
|
||||
input_match = re.search(r'Action Input:\s*(.*?)(?=Thought:|Observation:|$)', response, re.DOTALL)
|
||||
if not input_match:
|
||||
return None
|
||||
|
||||
input_text = input_match.group(1).strip()
|
||||
# 移除 markdown 代码块
|
||||
input_text = re.sub(r'```json\s*', '', input_text)
|
||||
input_text = re.sub(r'```\s*', '', input_text)
|
||||
|
||||
try:
|
||||
# 调用 LLM
|
||||
messages = [
|
||||
{"role": "system", "content": self.config.system_prompt},
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
|
||||
response = await self.llm_service.chat_completion_raw(
|
||||
messages=messages,
|
||||
temperature=0.1,
|
||||
max_tokens=2000,
|
||||
)
|
||||
|
||||
content = response.get("content", "")
|
||||
|
||||
# 解析 JSON
|
||||
import json
|
||||
import re
|
||||
|
||||
# 提取 JSON
|
||||
json_match = re.search(r'\{[\s\S]*\}', content)
|
||||
if json_match:
|
||||
plan_data = json.loads(json_match.group())
|
||||
|
||||
return AuditPlan(
|
||||
phases=plan_data.get("phases", self._default_phases()),
|
||||
high_risk_areas=plan_data.get("high_risk_areas", []),
|
||||
focus_vulnerabilities=plan_data.get("focus_vulnerabilities", []),
|
||||
estimated_steps=plan_data.get("estimated_steps", 30),
|
||||
priority_files=plan_data.get("priority_files", []),
|
||||
metadata=plan_data,
|
||||
)
|
||||
else:
|
||||
# 使用默认计划
|
||||
return AuditPlan(
|
||||
phases=self._default_phases(),
|
||||
high_risk_areas=["src/", "api/", "controllers/", "routes/"],
|
||||
focus_vulnerabilities=["sql_injection", "xss", "command_injection"],
|
||||
estimated_steps=30,
|
||||
priority_files=[],
|
||||
metadata={},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create audit plan: {e}")
|
||||
return AuditPlan(
|
||||
phases=self._default_phases(),
|
||||
high_risk_areas=[],
|
||||
focus_vulnerabilities=[],
|
||||
estimated_steps=30,
|
||||
priority_files=[],
|
||||
metadata={},
|
||||
)
|
||||
action_input = json.loads(input_text)
|
||||
except json.JSONDecodeError:
|
||||
action_input = {"raw": input_text}
|
||||
|
||||
return AgentStep(
|
||||
thought=thought,
|
||||
action=action,
|
||||
action_input=action_input,
|
||||
)
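For reference, a well-formed orchestrator reply and the step it parses into look roughly like this. The reply text is invented for illustration; the regular expressions are the same ones used in `_parse_llm_response` above.

```python
import json
import re

sample = """Thought: Recon shows an Express API with unauthenticated routes; the analysis agent should focus on them.
Action: dispatch_agent
Action Input: {"agent": "analysis", "task": "Audit the Express routes under src/api for injection and auth issues", "context": "recon found 12 routes"}"""

thought = re.search(r'Thought:\s*(.*?)(?=Action:|$)', sample, re.DOTALL).group(1).strip()
action = re.search(r'Action:\s*(\w+)', sample).group(1)
raw_input = re.search(r'Action Input:\s*(.*?)(?=Thought:|Observation:|$)', sample, re.DOTALL).group(1)
action_input = json.loads(re.sub(r'```(json)?\s*', '', raw_input).strip())

print(action)                 # dispatch_agent
print(action_input["agent"])  # analysis
```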
|
||||
|
||||
def _default_phases(self) -> List[Dict[str, Any]]:
|
||||
"""默认审计阶段"""
|
||||
return [
|
||||
{
|
||||
"name": "recon",
|
||||
"description": "信息收集 - 分析项目结构和技术栈",
|
||||
"agent": "recon",
|
||||
},
|
||||
{
|
||||
"name": "static_analysis",
|
||||
"description": "静态分析 - 使用外部工具快速扫描",
|
||||
"agent": "analysis",
|
||||
},
|
||||
{
|
||||
"name": "deep_analysis",
|
||||
"description": "深度分析 - AI 驱动的代码审计",
|
||||
"agent": "analysis",
|
||||
},
|
||||
{
|
||||
"name": "verification",
|
||||
"description": "漏洞验证 - 确认发现的漏洞",
|
||||
"agent": "verification",
|
||||
},
|
||||
]
|
||||
|
||||
async def _execute_phase(
|
||||
self,
|
||||
phase_name: str,
|
||||
agent_type: str,
|
||||
project_info: Dict[str, Any],
|
||||
config: Dict[str, Any],
|
||||
plan: AuditPlan,
|
||||
previous_results: Dict[str, AgentResult],
|
||||
) -> AgentResult:
|
||||
"""执行审计阶段"""
|
||||
agent = self.sub_agents.get(agent_type)
|
||||
async def _dispatch_agent(self, params: Dict[str, Any]) -> str:
|
||||
"""调度子 Agent"""
|
||||
agent_name = params.get("agent", "")
|
||||
task = params.get("task", "")
|
||||
context = params.get("context", "")
|
||||
|
||||
agent = self.sub_agents.get(agent_name)
|
||||
|
||||
if not agent:
|
||||
logger.warning(f"Agent not found: {agent_type}")
|
||||
return AgentResult(success=False, error=f"Agent {agent_type} not found")
|
||||
available = list(self.sub_agents.keys())
|
||||
return f"错误: Agent '{agent_name}' 不存在。可用的 Agent: {available}"
|
||||
|
||||
# 构建阶段输入
|
||||
phase_input = {
|
||||
"phase_name": phase_name,
|
||||
"project_info": project_info,
|
||||
"config": config,
|
||||
"plan": plan.__dict__,
|
||||
"previous_results": {k: v.to_dict() for k, v in previous_results.items()},
|
||||
}
|
||||
await self.emit_event(
|
||||
"dispatch",
|
||||
f"📤 调度 {agent_name} Agent: {task[:100]}...",
|
||||
agent=agent_name,
|
||||
task=task,
|
||||
)
|
||||
|
||||
# 执行子 Agent
|
||||
return await agent.run(phase_input)
|
||||
self._tool_calls += 1
|
||||
|
||||
try:
|
||||
# 构建子 Agent 输入
|
||||
sub_input = {
|
||||
"task": task,
|
||||
"task_context": context,
|
||||
"project_info": {}, # 从上下文获取
|
||||
"config": {},
|
||||
}
|
||||
|
||||
# 执行子 Agent
|
||||
result = await agent.run(sub_input)
|
||||
|
||||
# 收集发现
|
||||
if result.success and result.data:
|
||||
findings = result.data.get("findings", [])
|
||||
self._all_findings.extend(findings)
|
||||
|
||||
await self.emit_event(
|
||||
"dispatch_complete",
|
||||
f"✅ {agent_name} Agent 完成: {len(findings)} 个发现",
|
||||
agent=agent_name,
|
||||
findings_count=len(findings),
|
||||
)
|
||||
|
||||
# 构建观察结果
|
||||
observation = f"""## {agent_name} Agent 执行结果
|
||||
|
||||
**状态**: 成功
|
||||
**发现数量**: {len(findings)}
|
||||
**迭代次数**: {result.iterations}
|
||||
**耗时**: {result.duration_ms}ms
|
||||
|
||||
### 发现摘要
|
||||
"""
|
||||
for i, f in enumerate(findings[:10]): # 最多显示 10 个
|
||||
observation += f"""
|
||||
{i+1}. [{f.get('severity', 'unknown')}] {f.get('title', 'Unknown')}
|
||||
- 类型: {f.get('vulnerability_type', 'unknown')}
|
||||
- 文件: {f.get('file_path', 'unknown')}
|
||||
- 描述: {f.get('description', '')[:200]}...
|
||||
"""
|
||||
|
||||
if len(findings) > 10:
|
||||
observation += f"\n... 还有 {len(findings) - 10} 个发现"
|
||||
|
||||
if result.data.get("summary"):
|
||||
observation += f"\n\n### Agent 总结\n{result.data['summary']}"
|
||||
|
||||
return observation
|
||||
else:
|
||||
return f"## {agent_name} Agent 执行失败\n\n错误: {result.error}"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Sub-agent dispatch failed: {e}", exc_info=True)
|
||||
return f"## 调度失败\n\n错误: {str(e)}"
|
||||
|
||||
async def _generate_summary(
|
||||
self,
|
||||
plan: AuditPlan,
|
||||
phase_results: Dict[str, AgentResult],
|
||||
all_findings: List[Dict],
|
||||
) -> Dict[str, Any]:
|
||||
"""生成审计摘要"""
|
||||
# 统计漏洞
|
||||
def _summarize_findings(self) -> str:
|
||||
"""汇总当前发现"""
|
||||
if not self._all_findings:
|
||||
return "目前还没有发现任何漏洞。"
|
||||
|
||||
# 统计
|
||||
severity_counts = {"critical": 0, "high": 0, "medium": 0, "low": 0}
|
||||
type_counts = {}
|
||||
verified_count = 0
|
||||
|
||||
for finding in all_findings:
|
||||
sev = finding.get("severity", "low")
|
||||
for f in self._all_findings:
|
||||
sev = f.get("severity", "low")
|
||||
severity_counts[sev] = severity_counts.get(sev, 0) + 1
|
||||
|
||||
vtype = finding.get("vulnerability_type", "other")
|
||||
vtype = f.get("vulnerability_type", "other")
|
||||
type_counts[vtype] = type_counts.get(vtype, 0) + 1
|
||||
|
||||
if finding.get("is_verified"):
|
||||
verified_count += 1
|
||||
|
||||
# 计算安全评分
|
||||
base_score = 100
|
||||
deductions = (
|
||||
severity_counts["critical"] * 20 +
|
||||
severity_counts["high"] * 10 +
|
||||
severity_counts["medium"] * 5 +
|
||||
severity_counts["low"] * 2
|
||||
)
|
||||
security_score = max(0, base_score - deductions)
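As a quick sanity check on this deduction scheme (carried over from the previous summary logic), a small worked example:

```python
severity_counts = {"critical": 2, "high": 3, "medium": 1, "low": 0}
deductions = (severity_counts["critical"] * 20 + severity_counts["high"] * 10
              + severity_counts["medium"] * 5 + severity_counts["low"] * 2)   # = 75
print(max(0, 100 - deductions))  # security_score = 25
```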
|
||||
summary = f"""## 当前发现汇总
|
||||
|
||||
**总计**: {len(self._all_findings)} 个漏洞
|
||||
|
||||
### 严重程度分布
|
||||
- Critical: {severity_counts['critical']}
|
||||
- High: {severity_counts['high']}
|
||||
- Medium: {severity_counts['medium']}
|
||||
- Low: {severity_counts['low']}
|
||||
|
||||
### 漏洞类型分布
|
||||
"""
|
||||
for vtype, count in type_counts.items():
|
||||
summary += f"- {vtype}: {count}\n"
|
||||
|
||||
summary += "\n### 详细列表\n"
|
||||
for i, f in enumerate(self._all_findings):
|
||||
summary += f"{i+1}. [{f.get('severity')}] {f.get('title')} ({f.get('file_path')})\n"
|
||||
|
||||
return summary
|
||||
|
||||
def _generate_default_summary(self) -> Dict[str, Any]:
|
||||
"""生成默认摘要"""
|
||||
severity_counts = {"critical": 0, "high": 0, "medium": 0, "low": 0}
|
||||
|
||||
for f in self._all_findings:
|
||||
sev = f.get("severity", "low")
|
||||
severity_counts[sev] = severity_counts.get(sev, 0) + 1
|
||||
|
||||
return {
|
||||
"total_findings": len(all_findings),
|
||||
"verified_count": verified_count,
|
||||
"total_findings": len(self._all_findings),
|
||||
"severity_distribution": severity_counts,
|
||||
"vulnerability_types": type_counts,
|
||||
"security_score": security_score,
|
||||
"phases_completed": len(phase_results),
|
||||
"high_risk_areas": plan.high_risk_areas,
|
||||
"conclusion": "审计完成(未通过 LLM 生成结论)",
|
||||
}
|
||||
|
||||
|
||||
def get_conversation_history(self) -> List[Dict[str, str]]:
|
||||
"""获取对话历史"""
|
||||
return self._conversation_history
|
||||
|
||||
def get_steps(self) -> List[AgentStep]:
|
||||
"""获取执行步骤"""
|
||||
return self._steps
|
||||
|
|
|
|||
|
|
@ -1,72 +1,127 @@
|
|||
"""
|
||||
Recon Agent (信息收集层)
|
||||
负责项目结构分析、技术栈识别、入口点识别
|
||||
Recon Agent (信息收集层) - LLM 驱动版
|
||||
|
||||
类型: ReAct
|
||||
LLM 是真正的大脑!
|
||||
- LLM 决定收集什么信息
|
||||
- LLM 决定使用哪个工具
|
||||
- LLM 决定何时信息足够
|
||||
- LLM 动态调整收集策略
|
||||
|
||||
类型: ReAct (真正的!)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from typing import List, Dict, Any, Optional
|
||||
from dataclasses import dataclass, field
|
||||
from dataclasses import dataclass
|
||||
|
||||
from .base import BaseAgent, AgentConfig, AgentResult, AgentType, AgentPattern
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
RECON_SYSTEM_PROMPT = """你是 DeepAudit 的信息收集 Agent,负责在安全审计前收集项目信息。
|
||||
RECON_SYSTEM_PROMPT = """你是 DeepAudit 的信息收集 Agent,负责在安全审计前**自主**收集项目信息。
|
||||
|
||||
## 你的职责
|
||||
1. 分析项目结构和目录布局
|
||||
2. 识别使用的技术栈和框架
|
||||
3. 找出应用程序入口点
|
||||
4. 分析依赖和第三方库
|
||||
5. 识别高风险区域
|
||||
## 你的角色
|
||||
你是信息收集的**大脑**,不是机械执行者。你需要:
|
||||
1. 自主思考需要收集什么信息
|
||||
2. 选择合适的工具获取信息
|
||||
3. 根据发现动态调整策略
|
||||
4. 判断何时信息收集足够
|
||||
|
||||
## 你可以使用的工具
|
||||
- list_files: 列出目录内容
|
||||
- read_file: 读取文件内容
|
||||
- search_code: 搜索代码
|
||||
- semgrep_scan: Semgrep 扫描
|
||||
- npm_audit: npm 依赖审计
|
||||
- safety_scan: Python 依赖审计
|
||||
- gitleaks_scan: 密钥泄露扫描
|
||||
|
||||
## 信息收集要点
|
||||
1. **目录结构**: 了解项目布局,识别源码、配置、测试目录
|
||||
2. **技术栈**: 检测语言、框架、数据库等
|
||||
3. **入口点**: API 路由、控制器、处理函数
|
||||
4. **配置文件**: 环境变量、数据库配置、API 密钥
|
||||
5. **依赖**: package.json, requirements.txt, go.mod 等
|
||||
6. **安全相关**: 认证、授权、加密相关代码
|
||||
### 文件系统
|
||||
- **list_files**: 列出目录内容
|
||||
参数: directory (str), recursive (bool), pattern (str), max_files (int)
|
||||
|
||||
- **read_file**: 读取文件内容
|
||||
参数: file_path (str), start_line (int), end_line (int), max_lines (int)
|
||||
|
||||
- **search_code**: 代码关键字搜索
|
||||
参数: keyword (str), max_results (int)
|
||||
|
||||
## 输出格式
|
||||
完成后返回 JSON:
|
||||
### 安全扫描
|
||||
- **semgrep_scan**: Semgrep 静态分析扫描
|
||||
- **npm_audit**: npm 依赖漏洞审计
|
||||
- **safety_scan**: Python 依赖漏洞审计
|
||||
- **gitleaks_scan**: 密钥/敏感信息泄露扫描
|
||||
- **osv_scan**: OSV 通用依赖漏洞扫描
|
||||
|
||||
## 工作方式
|
||||
每一步,你需要输出:
|
||||
|
||||
```
|
||||
Thought: [分析当前状态,思考还需要什么信息]
|
||||
Action: [工具名称]
|
||||
Action Input: [JSON 格式的参数]
|
||||
```
|
||||
|
||||
当你认为信息收集足够时,输出:
|
||||
|
||||
```
|
||||
Thought: [总结收集到的信息]
|
||||
Final Answer: [JSON 格式的收集结果]
|
||||
```
|
||||
|
||||
## Final Answer 格式
|
||||
```json
|
||||
{
|
||||
"project_structure": {...},
|
||||
"project_structure": {
|
||||
"directories": [],
|
||||
"config_files": [],
|
||||
"total_files": 数量
|
||||
},
|
||||
"tech_stack": {
|
||||
"languages": [],
|
||||
"frameworks": [],
|
||||
"databases": []
|
||||
},
|
||||
"entry_points": [],
|
||||
"high_risk_areas": [],
|
||||
"dependencies": {...},
|
||||
"entry_points": [
|
||||
{"type": "描述", "file": "路径", "line": 行号}
|
||||
],
|
||||
"high_risk_areas": ["路径列表"],
|
||||
"dependencies": {},
|
||||
"initial_findings": []
|
||||
}
|
||||
```
|
||||
|
||||
请系统性地收集信息,为后续分析做准备。"""
|
||||
## 信息收集策略建议
|
||||
1. 先 list_files 了解项目结构
|
||||
2. 读取配置文件 (package.json, requirements.txt, go.mod 等) 识别技术栈
|
||||
3. 搜索入口点模式 (routes, controllers, handlers)
|
||||
4. 运行安全扫描发现初步问题
|
||||
5. 根据发现继续深入
|
||||
|
||||
## 重要原则
|
||||
1. **你是大脑** - 每一步都要思考,不要机械执行
|
||||
2. **动态调整** - 根据发现调整策略
|
||||
3. **效率优先** - 不要重复收集已有信息
|
||||
4. **主动探索** - 发现有趣的东西要深入
|
||||
|
||||
现在开始收集项目信息!"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReconStep:
|
||||
"""信息收集步骤"""
|
||||
thought: str
|
||||
action: Optional[str] = None
|
||||
action_input: Optional[Dict] = None
|
||||
observation: Optional[str] = None
|
||||
is_final: bool = False
|
||||
final_answer: Optional[Dict] = None
|
||||
|
||||
|
||||
class ReconAgent(BaseAgent):
|
||||
"""
|
||||
信息收集 Agent
|
||||
信息收集 Agent - LLM 驱动版
|
||||
|
||||
使用 ReAct 模式迭代收集项目信息
|
||||
LLM 全程参与,自主决定:
|
||||
1. 收集什么信息
|
||||
2. 使用什么工具
|
||||
3. 何时足够
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
|
@ -81,84 +136,219 @@ class ReconAgent(BaseAgent):
|
|||
pattern=AgentPattern.REACT,
|
||||
max_iterations=15,
|
||||
system_prompt=RECON_SYSTEM_PROMPT,
|
||||
tools=[
|
||||
"list_files", "read_file", "search_code",
|
||||
"semgrep_scan", "npm_audit", "safety_scan",
|
||||
"gitleaks_scan", "osv_scan",
|
||||
],
|
||||
)
|
||||
super().__init__(config, llm_service, tools, event_emitter)
|
||||
|
||||
self._conversation_history: List[Dict[str, str]] = []
|
||||
self._steps: List[ReconStep] = []
|
||||
|
||||
def _get_tools_description(self) -> str:
|
||||
"""生成工具描述"""
|
||||
tools_info = []
|
||||
for name, tool in self.tools.items():
|
||||
if name.startswith("_"):
|
||||
continue
|
||||
desc = f"- {name}: {getattr(tool, 'description', 'No description')}"
|
||||
tools_info.append(desc)
|
||||
return "\n".join(tools_info)
|
||||
|
||||
def _parse_llm_response(self, response: str) -> ReconStep:
|
||||
"""解析 LLM 响应"""
|
||||
step = ReconStep(thought="")
|
||||
|
||||
# 提取 Thought
|
||||
thought_match = re.search(r'Thought:\s*(.*?)(?=Action:|Final Answer:|$)', response, re.DOTALL)
|
||||
if thought_match:
|
||||
step.thought = thought_match.group(1).strip()
|
||||
|
||||
# 检查是否是最终答案
|
||||
final_match = re.search(r'Final Answer:\s*(.*?)$', response, re.DOTALL)
|
||||
if final_match:
|
||||
step.is_final = True
|
||||
try:
|
||||
answer_text = final_match.group(1).strip()
|
||||
answer_text = re.sub(r'```json\s*', '', answer_text)
|
||||
answer_text = re.sub(r'```\s*', '', answer_text)
|
||||
step.final_answer = json.loads(answer_text)
|
||||
except json.JSONDecodeError:
|
||||
step.final_answer = {"raw_answer": final_match.group(1).strip()}
|
||||
return step
|
||||
|
||||
# 提取 Action
|
||||
action_match = re.search(r'Action:\s*(\w+)', response)
|
||||
if action_match:
|
||||
step.action = action_match.group(1).strip()
|
||||
|
||||
# 提取 Action Input
|
||||
input_match = re.search(r'Action Input:\s*(.*?)(?=Thought:|Action:|Observation:|$)', response, re.DOTALL)
|
||||
if input_match:
|
||||
input_text = input_match.group(1).strip()
|
||||
input_text = re.sub(r'```json\s*', '', input_text)
|
||||
input_text = re.sub(r'```\s*', '', input_text)
|
||||
try:
|
||||
step.action_input = json.loads(input_text)
|
||||
except json.JSONDecodeError:
|
||||
step.action_input = {"raw_input": input_text}
|
||||
|
||||
return step
|
||||
|
||||
async def _execute_tool(self, tool_name: str, tool_input: Dict) -> str:
|
||||
"""执行工具"""
|
||||
tool = self.tools.get(tool_name)
|
||||
|
||||
if not tool:
|
||||
return f"错误: 工具 '{tool_name}' 不存在。可用工具: {list(self.tools.keys())}"
|
||||
|
||||
try:
|
||||
self._tool_calls += 1
|
||||
await self.emit_tool_call(tool_name, tool_input)
|
||||
|
||||
import time
|
||||
start = time.time()
|
||||
|
||||
result = await tool.execute(**tool_input)
|
||||
|
||||
duration_ms = int((time.time() - start) * 1000)
|
||||
await self.emit_tool_result(tool_name, str(result.data)[:200], duration_ms)
|
||||
|
||||
if result.success:
|
||||
output = str(result.data)
|
||||
if len(output) > 4000:
|
||||
output = output[:4000] + f"\n\n... [输出已截断,共 {len(str(result.data))} 字符]"
|
||||
return output
|
||||
else:
|
||||
return f"工具执行失败: {result.error}"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Tool execution error: {e}")
|
||||
return f"工具执行错误: {str(e)}"
|
||||
|
||||
async def run(self, input_data: Dict[str, Any]) -> AgentResult:
|
||||
"""执行信息收集"""
|
||||
"""
|
||||
执行信息收集 - LLM 全程参与!
|
||||
"""
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
project_info = input_data.get("project_info", {})
|
||||
config = input_data.get("config", {})
|
||||
task = input_data.get("task", "")
|
||||
task_context = input_data.get("task_context", "")
|
||||
|
||||
# 构建初始消息
|
||||
initial_message = f"""请开始收集项目信息。
|
||||
|
||||
## 项目基本信息
|
||||
- 名称: {project_info.get('name', 'unknown')}
|
||||
- 根目录: {project_info.get('root', '.')}
|
||||
|
||||
## 任务上下文
|
||||
{task_context or task or '进行全面的信息收集,为安全审计做准备。'}
|
||||
|
||||
## 可用工具
|
||||
{self._get_tools_description()}
|
||||
|
||||
请开始你的信息收集工作。首先思考应该收集什么信息,然后选择合适的工具。"""
|
||||
|
||||
# 初始化对话历史
|
||||
self._conversation_history = [
|
||||
{"role": "system", "content": self.config.system_prompt},
|
||||
{"role": "user", "content": initial_message},
|
||||
]
|
||||
|
||||
self._steps = []
|
||||
final_result = None
|
||||
|
||||
await self.emit_thinking("🔍 Recon Agent 启动,LLM 开始自主收集信息...")
|
||||
|
||||
try:
|
||||
await self.emit_thinking("开始信息收集...")
|
||||
|
||||
# 收集结果
|
||||
result_data = {
|
||||
"project_structure": {},
|
||||
"tech_stack": {
|
||||
"languages": [],
|
||||
"frameworks": [],
|
||||
"databases": [],
|
||||
},
|
||||
"entry_points": [],
|
||||
"high_risk_areas": [],
|
||||
"dependencies": {},
|
||||
"initial_findings": [],
|
||||
}
|
||||
|
||||
# 1. 分析项目结构
|
||||
await self.emit_thinking("分析项目结构...")
|
||||
structure = await self._analyze_structure()
|
||||
result_data["project_structure"] = structure
|
||||
|
||||
# 2. 识别技术栈
|
||||
await self.emit_thinking("识别技术栈...")
|
||||
tech_stack = await self._identify_tech_stack(structure)
|
||||
result_data["tech_stack"] = tech_stack
|
||||
|
||||
# 3. 扫描依赖漏洞
|
||||
await self.emit_thinking("扫描依赖漏洞...")
|
||||
deps_result = await self._scan_dependencies(tech_stack)
|
||||
result_data["dependencies"] = deps_result.get("dependencies", {})
|
||||
if deps_result.get("findings"):
|
||||
result_data["initial_findings"].extend(deps_result["findings"])
|
||||
|
||||
# 4. 快速密钥扫描
|
||||
await self.emit_thinking("扫描密钥泄露...")
|
||||
secrets_result = await self._scan_secrets()
|
||||
if secrets_result.get("findings"):
|
||||
result_data["initial_findings"].extend(secrets_result["findings"])
|
||||
|
||||
# 5. 识别入口点
|
||||
await self.emit_thinking("识别入口点...")
|
||||
entry_points = await self._identify_entry_points(tech_stack)
|
||||
result_data["entry_points"] = entry_points
|
||||
|
||||
# 6. 识别高风险区域
|
||||
result_data["high_risk_areas"] = self._identify_high_risk_areas(
|
||||
structure, tech_stack, entry_points
|
||||
)
|
||||
for iteration in range(self.config.max_iterations):
|
||||
if self.is_cancelled:
|
||||
break
|
||||
|
||||
self._iteration = iteration + 1
|
||||
|
||||
# 🔥 发射 LLM 开始思考事件
|
||||
await self.emit_llm_start(iteration + 1)
|
||||
|
||||
# 🔥 调用 LLM 进行思考和决策
|
||||
response = await self.llm_service.chat_completion_raw(
|
||||
messages=self._conversation_history,
|
||||
temperature=0.1,
|
||||
max_tokens=2048,
|
||||
)
|
||||
|
||||
llm_output = response.get("content", "")
|
||||
tokens_this_round = response.get("usage", {}).get("total_tokens", 0)
|
||||
self._total_tokens += tokens_this_round
|
||||
|
||||
# 解析 LLM 响应
|
||||
step = self._parse_llm_response(llm_output)
|
||||
self._steps.append(step)
|
||||
|
||||
# 🔥 发射 LLM 思考内容事件 - 展示 LLM 在想什么
|
||||
if step.thought:
|
||||
await self.emit_llm_thought(step.thought, iteration + 1)
|
||||
|
||||
# 添加 LLM 响应到历史
|
||||
self._conversation_history.append({
|
||||
"role": "assistant",
|
||||
"content": llm_output,
|
||||
})
|
||||
|
||||
# 检查是否完成
|
||||
if step.is_final:
|
||||
await self.emit_llm_decision("完成信息收集", "LLM 判断已收集足够信息")
|
||||
await self.emit_llm_complete(
|
||||
f"信息收集完成,共 {self._iteration} 轮思考",
|
||||
self._total_tokens
|
||||
)
|
||||
final_result = step.final_answer
|
||||
break
|
||||
|
||||
# 执行工具
|
||||
if step.action:
|
||||
# 🔥 发射 LLM 动作决策事件
|
||||
await self.emit_llm_action(step.action, step.action_input or {})
|
||||
|
||||
observation = await self._execute_tool(
|
||||
step.action,
|
||||
step.action_input or {}
|
||||
)
|
||||
|
||||
step.observation = observation
|
||||
|
||||
# 🔥 发射 LLM 观察事件
|
||||
await self.emit_llm_observation(observation)
|
||||
|
||||
# 添加观察结果到历史
|
||||
self._conversation_history.append({
|
||||
"role": "user",
|
||||
"content": f"Observation:\n{observation}",
|
||||
})
|
||||
else:
|
||||
# LLM 没有选择工具,提示它继续
|
||||
await self.emit_llm_decision("继续思考", "LLM 需要更多信息")
|
||||
self._conversation_history.append({
|
||||
"role": "user",
|
||||
"content": "请继续,选择一个工具执行,或者如果信息收集完成,输出 Final Answer。",
|
||||
})
|
||||
|
||||
# 处理结果
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
# 如果没有最终结果,从历史中汇总
|
||||
if not final_result:
|
||||
final_result = self._summarize_from_steps()
|
||||
|
||||
await self.emit_event(
|
||||
"info",
|
||||
f"信息收集完成: 发现 {len(result_data['entry_points'])} 个入口点, "
|
||||
f"{len(result_data['high_risk_areas'])} 个高风险区域, "
|
||||
f"{len(result_data['initial_findings'])} 个初步发现"
|
||||
f"🎯 Recon Agent 完成: {self._iteration} 轮迭代, {self._tool_calls} 次工具调用"
|
||||
)
|
||||
|
||||
return AgentResult(
|
||||
success=True,
|
||||
data=result_data,
|
||||
data=final_result,
|
||||
iterations=self._iteration,
|
||||
tool_calls=self._tool_calls,
|
||||
tokens_used=self._total_tokens,
|
||||
|
|
@ -166,270 +356,58 @@ class ReconAgent(BaseAgent):
|
|||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Recon agent failed: {e}", exc_info=True)
|
||||
logger.error(f"Recon Agent failed: {e}", exc_info=True)
|
||||
return AgentResult(success=False, error=str(e))
|
||||
|
||||
async def _analyze_structure(self) -> Dict[str, Any]:
|
||||
"""分析项目结构"""
|
||||
structure = {
|
||||
"directories": [],
|
||||
"files_by_type": {},
|
||||
"config_files": [],
|
||||
"total_files": 0,
|
||||
def _summarize_from_steps(self) -> Dict[str, Any]:
|
||||
"""从步骤中汇总结果"""
|
||||
# 默认结果结构
|
||||
result = {
|
||||
"project_structure": {},
|
||||
"tech_stack": {
|
||||
"languages": [],
|
||||
"frameworks": [],
|
||||
"databases": [],
|
||||
},
|
||||
"entry_points": [],
|
||||
"high_risk_areas": [],
|
||||
"dependencies": {},
|
||||
"initial_findings": [],
|
||||
}
|
||||
|
||||
# 列出根目录
|
||||
list_tool = self.tools.get("list_files")
|
||||
if not list_tool:
|
||||
return structure
|
||||
|
||||
result = await list_tool.execute(directory=".", recursive=True, max_files=300)
|
||||
|
||||
if result.success:
|
||||
structure["total_files"] = result.metadata.get("file_count", 0)
|
||||
|
||||
# 识别配置文件
|
||||
config_patterns = [
|
||||
"package.json", "requirements.txt", "go.mod", "Cargo.toml",
|
||||
"pom.xml", "build.gradle", ".env", "config.py", "settings.py",
|
||||
"docker-compose.yml", "Dockerfile",
|
||||
]
|
||||
|
||||
# 从输出中解析文件列表
|
||||
if isinstance(result.data, str):
|
||||
for line in result.data.split('\n'):
|
||||
line = line.strip()
|
||||
for pattern in config_patterns:
|
||||
if pattern in line:
|
||||
structure["config_files"].append(line)
|
||||
|
||||
return structure
|
||||
|
||||
async def _identify_tech_stack(self, structure: Dict) -> Dict[str, Any]:
|
||||
"""识别技术栈"""
|
||||
tech_stack = {
|
||||
"languages": [],
|
||||
"frameworks": [],
|
||||
"databases": [],
|
||||
"package_managers": [],
|
||||
}
|
||||
|
||||
config_files = structure.get("config_files", [])
|
||||
|
||||
# 基于配置文件推断
|
||||
for cfg in config_files:
|
||||
if "package.json" in cfg:
|
||||
tech_stack["languages"].append("JavaScript/TypeScript")
|
||||
tech_stack["package_managers"].append("npm")
|
||||
elif "requirements.txt" in cfg or "setup.py" in cfg:
|
||||
tech_stack["languages"].append("Python")
|
||||
tech_stack["package_managers"].append("pip")
|
||||
elif "go.mod" in cfg:
|
||||
tech_stack["languages"].append("Go")
|
||||
elif "Cargo.toml" in cfg:
|
||||
tech_stack["languages"].append("Rust")
|
||||
elif "pom.xml" in cfg or "build.gradle" in cfg:
|
||||
tech_stack["languages"].append("Java")
|
||||
|
||||
# 读取 package.json 识别框架
|
||||
read_tool = self.tools.get("read_file")
|
||||
if read_tool and "package.json" in str(config_files):
|
||||
result = await read_tool.execute(file_path="package.json", max_lines=100)
|
||||
if result.success:
|
||||
content = result.data
|
||||
if "react" in content.lower():
|
||||
tech_stack["frameworks"].append("React")
|
||||
if "vue" in content.lower():
|
||||
tech_stack["frameworks"].append("Vue")
|
||||
if "express" in content.lower():
|
||||
tech_stack["frameworks"].append("Express")
|
||||
if "fastify" in content.lower():
|
||||
tech_stack["frameworks"].append("Fastify")
|
||||
if "next" in content.lower():
|
||||
tech_stack["frameworks"].append("Next.js")
|
||||
|
||||
# 读取 requirements.txt 识别框架
|
||||
if read_tool and "requirements.txt" in str(config_files):
|
||||
result = await read_tool.execute(file_path="requirements.txt", max_lines=50)
|
||||
if result.success:
|
||||
content = result.data.lower()
|
||||
if "django" in content:
|
||||
tech_stack["frameworks"].append("Django")
|
||||
if "flask" in content:
|
||||
tech_stack["frameworks"].append("Flask")
|
||||
if "fastapi" in content:
|
||||
tech_stack["frameworks"].append("FastAPI")
|
||||
if "sqlalchemy" in content:
|
||||
tech_stack["databases"].append("SQLAlchemy")
|
||||
if "pymongo" in content:
|
||||
tech_stack["databases"].append("MongoDB")
|
||||
# 从步骤的观察结果中提取信息
|
||||
for step in self._steps:
|
||||
if step.observation:
|
||||
# 尝试从观察中识别技术栈等信息
|
||||
obs_lower = step.observation.lower()
|
||||
|
||||
if "package.json" in obs_lower:
|
||||
result["tech_stack"]["languages"].append("JavaScript/TypeScript")
|
||||
if "requirements.txt" in obs_lower or "setup.py" in obs_lower:
|
||||
result["tech_stack"]["languages"].append("Python")
|
||||
if "go.mod" in obs_lower:
|
||||
result["tech_stack"]["languages"].append("Go")
|
||||
|
||||
# 识别框架
|
||||
if "react" in obs_lower:
|
||||
result["tech_stack"]["frameworks"].append("React")
|
||||
if "django" in obs_lower:
|
||||
result["tech_stack"]["frameworks"].append("Django")
|
||||
if "fastapi" in obs_lower:
|
||||
result["tech_stack"]["frameworks"].append("FastAPI")
|
||||
if "express" in obs_lower:
|
||||
result["tech_stack"]["frameworks"].append("Express")
|
||||
|
||||
# 去重
|
||||
tech_stack["languages"] = list(set(tech_stack["languages"]))
|
||||
tech_stack["frameworks"] = list(set(tech_stack["frameworks"]))
|
||||
tech_stack["databases"] = list(set(tech_stack["databases"]))
|
||||
|
||||
return tech_stack
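As a side note, the config-file heuristic used above boils down to a small lookup table; a standalone sketch for readers who want to reuse it (hypothetical helper names, not part of this commit):

```python
# Hypothetical, self-contained version of the config-file -> tech-stack heuristic above.
CONFIG_HINTS = {
    "package.json": ("JavaScript/TypeScript", "npm"),
    "requirements.txt": ("Python", "pip"),
    "setup.py": ("Python", "pip"),
    "go.mod": ("Go", None),
    "Cargo.toml": ("Rust", None),
    "pom.xml": ("Java", None),
    "build.gradle": ("Java", None),
}

def guess_stack(config_files):
    stack = {"languages": set(), "package_managers": set()}
    for cfg in config_files:
        for marker, (lang, pm) in CONFIG_HINTS.items():
            if marker in cfg:
                stack["languages"].add(lang)
                if pm:
                    stack["package_managers"].add(pm)
    return {key: sorted(values) for key, values in stack.items()}

# guess_stack(["web/package.json", "api/requirements.txt"])
# -> {"languages": ["JavaScript/TypeScript", "Python"], "package_managers": ["npm", "pip"]}
```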
async def _scan_dependencies(self, tech_stack: Dict) -> Dict[str, Any]:
|
||||
"""扫描依赖漏洞"""
|
||||
result = {
|
||||
"dependencies": {},
|
||||
"findings": [],
|
||||
}
|
||||
|
||||
# npm audit
|
||||
if "npm" in tech_stack.get("package_managers", []):
|
||||
npm_tool = self.tools.get("npm_audit")
|
||||
if npm_tool:
|
||||
npm_result = await npm_tool.execute()
|
||||
if npm_result.success and npm_result.metadata.get("findings_count", 0) > 0:
|
||||
result["dependencies"]["npm"] = npm_result.metadata
|
||||
|
||||
# 转换为发现格式
|
||||
for sev, count in npm_result.metadata.get("severity_counts", {}).items():
|
||||
if count > 0 and sev in ["critical", "high"]:
|
||||
result["findings"].append({
|
||||
"vulnerability_type": "dependency_vulnerability",
|
||||
"severity": sev,
|
||||
"title": f"npm 依赖漏洞 ({count} 个 {sev})",
|
||||
"source": "npm_audit",
|
||||
})
|
||||
|
||||
# Safety (Python)
|
||||
if "pip" in tech_stack.get("package_managers", []):
|
||||
safety_tool = self.tools.get("safety_scan")
|
||||
if safety_tool:
|
||||
safety_result = await safety_tool.execute()
|
||||
if safety_result.success and safety_result.metadata.get("findings_count", 0) > 0:
|
||||
result["dependencies"]["pip"] = safety_result.metadata
|
||||
result["findings"].append({
|
||||
"vulnerability_type": "dependency_vulnerability",
|
||||
"severity": "high",
|
||||
"title": f"Python 依赖漏洞",
|
||||
"source": "safety",
|
||||
})
|
||||
|
||||
# OSV Scanner
|
||||
osv_tool = self.tools.get("osv_scan")
|
||||
if osv_tool:
|
||||
osv_result = await osv_tool.execute()
|
||||
if osv_result.success and osv_result.metadata.get("findings_count", 0) > 0:
|
||||
result["dependencies"]["osv"] = osv_result.metadata
|
||||
result["tech_stack"]["languages"] = list(set(result["tech_stack"]["languages"]))
|
||||
result["tech_stack"]["frameworks"] = list(set(result["tech_stack"]["frameworks"]))
|
||||
|
||||
return result
|
||||
|
||||
async def _scan_secrets(self) -> Dict[str, Any]:
|
||||
"""扫描密钥泄露"""
|
||||
result = {"findings": []}
|
||||
|
||||
gitleaks_tool = self.tools.get("gitleaks_scan")
|
||||
if gitleaks_tool:
|
||||
gl_result = await gitleaks_tool.execute()
|
||||
if gl_result.success and gl_result.metadata.get("findings_count", 0) > 0:
|
||||
for finding in gl_result.metadata.get("findings", []):
|
||||
result["findings"].append({
|
||||
"vulnerability_type": "hardcoded_secret",
|
||||
"severity": "high",
|
||||
"title": f"密钥泄露: {finding.get('rule', 'unknown')}",
|
||||
"file_path": finding.get("file"),
|
||||
"line_start": finding.get("line"),
|
||||
"source": "gitleaks",
|
||||
})
|
||||
|
||||
return result
|
||||
def get_conversation_history(self) -> List[Dict[str, str]]:
|
||||
"""获取对话历史"""
|
||||
return self._conversation_history
|
||||
|
||||
async def _identify_entry_points(self, tech_stack: Dict) -> List[Dict[str, Any]]:
|
||||
"""识别入口点"""
|
||||
entry_points = []
|
||||
search_tool = self.tools.get("search_code")
|
||||
|
||||
if not search_tool:
|
||||
return entry_points
|
||||
|
||||
# 基于框架搜索入口点
|
||||
search_patterns = []
|
||||
|
||||
frameworks = tech_stack.get("frameworks", [])
|
||||
|
||||
if "Express" in frameworks:
|
||||
search_patterns.extend([
|
||||
("app.get(", "Express GET route"),
|
||||
("app.post(", "Express POST route"),
|
||||
("router.get(", "Express router GET"),
|
||||
("router.post(", "Express router POST"),
|
||||
])
|
||||
|
||||
if "FastAPI" in frameworks:
|
||||
search_patterns.extend([
|
||||
("@app.get(", "FastAPI GET endpoint"),
|
||||
("@app.post(", "FastAPI POST endpoint"),
|
||||
("@router.get(", "FastAPI router GET"),
|
||||
("@router.post(", "FastAPI router POST"),
|
||||
])
|
||||
|
||||
if "Django" in frameworks:
|
||||
search_patterns.extend([
|
||||
("def get(self", "Django GET view"),
|
||||
("def post(self", "Django POST view"),
|
||||
("path(", "Django URL pattern"),
|
||||
])
|
||||
|
||||
if "Flask" in frameworks:
|
||||
search_patterns.extend([
|
||||
("@app.route(", "Flask route"),
|
||||
("@blueprint.route(", "Flask blueprint route"),
|
||||
])
|
||||
|
||||
# 通用模式
|
||||
search_patterns.extend([
|
||||
("def handle", "Handler function"),
|
||||
("async def handle", "Async handler"),
|
||||
("class.*Controller", "Controller class"),
|
||||
("class.*Handler", "Handler class"),
|
||||
])
|
||||
|
||||
for pattern, description in search_patterns[:10]: # 限制搜索数量
|
||||
result = await search_tool.execute(keyword=pattern, max_results=10)
|
||||
if result.success and result.metadata.get("matches", 0) > 0:
|
||||
for match in result.metadata.get("results", [])[:5]:
|
||||
entry_points.append({
|
||||
"type": description,
|
||||
"file": match.get("file"),
|
||||
"line": match.get("line"),
|
||||
"pattern": pattern,
|
||||
})
|
||||
|
||||
return entry_points[:30] # 限制总数
|
||||
|
||||
def _identify_high_risk_areas(
|
||||
self,
|
||||
structure: Dict,
|
||||
tech_stack: Dict,
|
||||
entry_points: List[Dict],
|
||||
) -> List[str]:
|
||||
"""识别高风险区域"""
|
||||
high_risk = set()
|
||||
|
||||
# 通用高风险目录
|
||||
risk_dirs = [
|
||||
"auth/", "authentication/", "login/",
|
||||
"api/", "routes/", "controllers/", "handlers/",
|
||||
"db/", "database/", "models/",
|
||||
"admin/", "management/",
|
||||
"upload/", "file/",
|
||||
"payment/", "billing/",
|
||||
]
|
||||
|
||||
for dir_name in risk_dirs:
|
||||
high_risk.add(dir_name)
|
||||
|
||||
# 从入口点提取目录
|
||||
for ep in entry_points:
|
||||
file_path = ep.get("file", "")
|
||||
if "/" in file_path:
|
||||
dir_path = "/".join(file_path.split("/")[:-1]) + "/"
|
||||
high_risk.add(dir_path)
|
||||
|
||||
return list(high_risk)[:20]
|
||||
|
||||
def get_steps(self) -> List[ReconStep]:
|
||||
"""获取执行步骤"""
|
||||
return self._steps
|
||||

@@ -1,13 +1,20 @@
"""
Verification Agent (漏洞验证层)
负责漏洞确认、PoC 生成、沙箱测试
Verification Agent (漏洞验证层) - LLM 驱动版

类型: ReAct
LLM 是验证的大脑!
- LLM 决定如何验证每个漏洞
- LLM 构造验证策略
- LLM 分析验证结果
- LLM 判断是否为真实漏洞

类型: ReAct (真正的!)
"""

import asyncio
import json
import logging
import re
from typing import List, Dict, Any, Optional
from dataclasses import dataclass
from datetime import datetime, timezone

from .base import BaseAgent, AgentConfig, AgentResult, AgentType, AgentPattern

@@ -15,69 +22,121 @@ from .base import BaseAgent, AgentConfig, AgentResult, AgentType, AgentPattern
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
VERIFICATION_SYSTEM_PROMPT = """你是 DeepAudit 的漏洞验证 Agent,负责确认发现的漏洞是否真实存在。
|
||||
VERIFICATION_SYSTEM_PROMPT = """你是 DeepAudit 的漏洞验证 Agent,一个**自主**的安全验证专家。
|
||||
|
||||
## 你的职责
|
||||
1. 分析漏洞上下文,判断是否为真正的安全问题
|
||||
2. 构造 PoC(概念验证)代码
|
||||
3. 在沙箱中执行测试
|
||||
4. 评估漏洞的实际影响
|
||||
## 你的角色
|
||||
你是漏洞验证的**大脑**,不是机械验证器。你需要:
|
||||
1. 理解每个漏洞的上下文
|
||||
2. 设计合适的验证策略
|
||||
3. 使用工具获取更多信息
|
||||
4. 判断漏洞是否真实存在
|
||||
5. 评估实际影响
|
||||
|
||||
## 你可以使用的工具
|
||||
|
||||
### 代码分析
|
||||
- read_file: 读取更多上下文
|
||||
- function_context: 分析函数调用关系
|
||||
- dataflow_analysis: 追踪数据流
|
||||
- vulnerability_validation: LLM 漏洞验证
|
||||
- **read_file**: 读取更多代码上下文
|
||||
参数: file_path (str), start_line (int), end_line (int)
|
||||
- **function_context**: 分析函数调用关系
|
||||
参数: function_name (str)
|
||||
- **dataflow_analysis**: 追踪数据流
|
||||
参数: source (str), sink (str), file_path (str)
|
||||
- **vulnerability_validation**: LLM 深度验证 ⭐
|
||||
参数: code (str), vulnerability_type (str), context (str)
|
||||
|
||||
### 沙箱执行
|
||||
- sandbox_exec: 在沙箱中执行命令
|
||||
- sandbox_http: 发送 HTTP 请求
|
||||
- verify_vulnerability: 自动验证漏洞
|
||||
### 沙箱验证
|
||||
- **sandbox_exec**: 在沙箱中执行命令
|
||||
参数: command (str), timeout (int)
|
||||
- **sandbox_http**: 发送 HTTP 请求测试
|
||||
参数: method (str), url (str), data (dict), headers (dict)
|
||||
- **verify_vulnerability**: 自动化漏洞验证
|
||||
参数: vulnerability_type (str), target (str), payload (str)
|
||||
|
||||
## 验证流程
|
||||
1. **上下文分析**: 获取更多代码上下文
|
||||
2. **可利用性分析**: 判断漏洞是否可被利用
|
||||
3. **PoC 构造**: 设计验证方案
|
||||
4. **沙箱测试**: 在隔离环境中测试
|
||||
5. **结果评估**: 确定漏洞是否真实存在
|
||||
## 工作方式
|
||||
你将收到一批待验证的漏洞发现。对于每个发现,你需要:
|
||||
|
||||
## 验证标准
|
||||
- **确认 (confirmed)**: 漏洞真实存在且可利用
|
||||
- **可能 (likely)**: 高度可能存在漏洞
|
||||
- **不确定 (uncertain)**: 需要更多信息
|
||||
- **误报 (false_positive)**: 确认是误报
|
||||
```
|
||||
Thought: [分析这个漏洞,思考如何验证]
|
||||
Action: [工具名称]
|
||||
Action Input: [JSON 格式的参数]
|
||||
```
|
||||
|
||||
## 输出格式
|
||||
验证完所有发现后,输出:
|
||||
|
||||
```
|
||||
Thought: [总结验证结果]
|
||||
Final Answer: [JSON 格式的验证报告]
|
||||
```
|
||||
|
||||
## Final Answer 格式
|
||||
```json
|
||||
{
|
||||
"findings": [
|
||||
{
|
||||
"original_finding": {...},
|
||||
...原始发现字段...,
|
||||
"verdict": "confirmed/likely/uncertain/false_positive",
|
||||
"confidence": 0.0-1.0,
|
||||
"is_verified": true/false,
|
||||
"verification_method": "描述验证方法",
|
||||
"verification_details": "验证过程和结果详情",
|
||||
"poc": {
|
||||
"code": "PoC 代码",
|
||||
"description": "描述",
|
||||
"steps": ["步骤1", "步骤2"]
|
||||
"description": "PoC 描述",
|
||||
"steps": ["步骤1", "步骤2"],
|
||||
"payload": "测试 payload"
|
||||
},
|
||||
"impact": "影响分析",
|
||||
"impact": "实际影响分析",
|
||||
"recommendation": "修复建议"
|
||||
}
|
||||
]
|
||||
],
|
||||
"summary": {
|
||||
"total": 数量,
|
||||
"confirmed": 数量,
|
||||
"likely": 数量,
|
||||
"false_positive": 数量
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
请谨慎验证,减少误报,同时不遗漏真正的漏洞。"""
## 验证判定标准
- **confirmed**: 漏洞确认存在且可利用,有明确证据
- **likely**: 高度可能存在漏洞,但无法完全确认
- **uncertain**: 需要更多信息才能判断
- **false_positive**: 确认是误报,有明确理由

## 验证策略建议
1. **上下文分析**: 用 read_file 获取更多代码上下文
2. **数据流追踪**: 用 dataflow_analysis 确认污点传播
3. **LLM 深度分析**: 用 vulnerability_validation 进行专业分析
4. **沙箱测试**: 对高危漏洞用沙箱进行安全测试

## 重要原则
1. **质量优先** - 宁可漏报也不要误报太多
2. **深入理解** - 理解代码逻辑,不要表面判断
3. **证据支撑** - 判定要有依据
4. **安全第一** - 沙箱测试要谨慎

现在开始验证漏洞发现!"""
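For illustration, one assistant turn in the format this prompt requests might look like the string below (hypothetical finding and file path, shown only to make the expected shape concrete):

```python
# Hypothetical single ReAct turn matching the format described in the prompt above.
EXAMPLE_TURN = """Thought: 发现 1 疑似 SQL 注入,先读取 app/db.py 的上下文确认是否使用了参数化查询。
Action: read_file
Action Input: {"file_path": "app/db.py", "start_line": 40, "end_line": 90}"""
```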


@dataclass
|
||||
class VerificationStep:
|
||||
"""验证步骤"""
|
||||
thought: str
|
||||
action: Optional[str] = None
|
||||
action_input: Optional[Dict] = None
|
||||
observation: Optional[str] = None
|
||||
is_final: bool = False
|
||||
final_answer: Optional[Dict] = None
|
||||
|
||||
|
||||
class VerificationAgent(BaseAgent):
|
||||
"""
|
||||
漏洞验证 Agent
|
||||
漏洞验证 Agent - LLM 驱动版
|
||||
|
||||
使用 ReAct 模式验证发现的漏洞
|
||||
LLM 全程参与,自主决定:
|
||||
1. 如何验证每个漏洞
|
||||
2. 使用什么工具
|
||||
3. 判断真假
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
|
@ -90,25 +149,114 @@ class VerificationAgent(BaseAgent):
|
|||
name="Verification",
|
||||
agent_type=AgentType.VERIFICATION,
|
||||
pattern=AgentPattern.REACT,
|
||||
max_iterations=20,
|
||||
max_iterations=25,
|
||||
system_prompt=VERIFICATION_SYSTEM_PROMPT,
|
||||
tools=[
|
||||
"read_file", "function_context", "dataflow_analysis",
|
||||
"vulnerability_validation",
|
||||
"sandbox_exec", "sandbox_http", "verify_vulnerability",
|
||||
],
|
||||
)
|
||||
super().__init__(config, llm_service, tools, event_emitter)
|
||||
|
||||
self._conversation_history: List[Dict[str, str]] = []
|
||||
self._steps: List[VerificationStep] = []
|
||||
|
||||
def _get_tools_description(self) -> str:
|
||||
"""生成工具描述"""
|
||||
tools_info = []
|
||||
for name, tool in self.tools.items():
|
||||
if name.startswith("_"):
|
||||
continue
|
||||
desc = f"- {name}: {getattr(tool, 'description', 'No description')}"
|
||||
tools_info.append(desc)
|
||||
return "\n".join(tools_info)
|
||||
|
||||
def _parse_llm_response(self, response: str) -> VerificationStep:
|
||||
"""解析 LLM 响应"""
|
||||
step = VerificationStep(thought="")
|
||||
|
||||
# 提取 Thought
|
||||
thought_match = re.search(r'Thought:\s*(.*?)(?=Action:|Final Answer:|$)', response, re.DOTALL)
|
||||
if thought_match:
|
||||
step.thought = thought_match.group(1).strip()
|
||||
|
||||
# 检查是否是最终答案
|
||||
final_match = re.search(r'Final Answer:\s*(.*?)$', response, re.DOTALL)
|
||||
if final_match:
|
||||
step.is_final = True
|
||||
try:
|
||||
answer_text = final_match.group(1).strip()
|
||||
answer_text = re.sub(r'```json\s*', '', answer_text)
|
||||
answer_text = re.sub(r'```\s*', '', answer_text)
|
||||
step.final_answer = json.loads(answer_text)
|
||||
except json.JSONDecodeError:
|
||||
step.final_answer = {"findings": [], "raw_answer": final_match.group(1).strip()}
|
||||
return step
|
||||
|
||||
# 提取 Action
|
||||
action_match = re.search(r'Action:\s*(\w+)', response)
|
||||
if action_match:
|
||||
step.action = action_match.group(1).strip()
|
||||
|
||||
# 提取 Action Input
|
||||
input_match = re.search(r'Action Input:\s*(.*?)(?=Thought:|Action:|Observation:|$)', response, re.DOTALL)
|
||||
if input_match:
|
||||
input_text = input_match.group(1).strip()
|
||||
input_text = re.sub(r'```json\s*', '', input_text)
|
||||
input_text = re.sub(r'```\s*', '', input_text)
|
||||
try:
|
||||
step.action_input = json.loads(input_text)
|
||||
except json.JSONDecodeError:
|
||||
step.action_input = {"raw_input": input_text}
|
||||
|
||||
return step
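Running a turn like the example above through the same regular expressions used here yields the structured step; a minimal standalone sketch (assuming the same patterns as `_parse_llm_response`):

```python
import json
import re

def parse_react_turn(response: str) -> dict:
    """Simplified re-implementation of the parsing above, for illustration only."""
    step = {"thought": "", "action": None, "action_input": None}
    thought = re.search(r'Thought:\s*(.*?)(?=Action:|Final Answer:|$)', response, re.DOTALL)
    if thought:
        step["thought"] = thought.group(1).strip()
    action = re.search(r'Action:\s*(\w+)', response)
    if action:
        step["action"] = action.group(1).strip()
    raw_input = re.search(r'Action Input:\s*(.*?)(?=Thought:|Action:|Observation:|$)', response, re.DOTALL)
    if raw_input:
        try:
            step["action_input"] = json.loads(raw_input.group(1).strip())
        except json.JSONDecodeError:
            step["action_input"] = {"raw_input": raw_input.group(1).strip()}
    return step

# parse_react_turn(EXAMPLE_TURN) -> {"thought": "...", "action": "read_file",
#                                    "action_input": {"file_path": "app/db.py", ...}}
```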
async def _execute_tool(self, tool_name: str, tool_input: Dict) -> str:
|
||||
"""执行工具"""
|
||||
tool = self.tools.get(tool_name)
|
||||
|
||||
if not tool:
|
||||
return f"错误: 工具 '{tool_name}' 不存在。可用工具: {list(self.tools.keys())}"
|
||||
|
||||
try:
|
||||
self._tool_calls += 1
|
||||
await self.emit_tool_call(tool_name, tool_input)
|
||||
|
||||
import time
|
||||
start = time.time()
|
||||
|
||||
result = await tool.execute(**tool_input)
|
||||
|
||||
duration_ms = int((time.time() - start) * 1000)
|
||||
await self.emit_tool_result(tool_name, str(result.data)[:200], duration_ms)
|
||||
|
||||
if result.success:
|
||||
output = str(result.data)
|
||||
|
||||
# 包含 metadata
|
||||
if result.metadata:
|
||||
if "validation" in result.metadata:
|
||||
output += f"\n\n验证结果:\n{json.dumps(result.metadata['validation'], ensure_ascii=False, indent=2)}"
|
||||
|
||||
if len(output) > 4000:
|
||||
output = output[:4000] + "\n\n... [输出已截断]"
|
||||
return output
|
||||
else:
|
||||
return f"工具执行失败: {result.error}"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Tool execution error: {e}")
|
||||
return f"工具执行错误: {str(e)}"
|
||||
|
||||
async def run(self, input_data: Dict[str, Any]) -> AgentResult:
|
||||
"""执行漏洞验证"""
|
||||
"""
|
||||
执行漏洞验证 - LLM 全程参与!
|
||||
"""
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
previous_results = input_data.get("previous_results", {})
|
||||
config = input_data.get("config", {})
|
||||
task = input_data.get("task", "")
|
||||
task_context = input_data.get("task_context", "")
|
||||
|
||||
# 收集所有需要验证的发现
|
||||
# 收集所有待验证的发现
|
||||
findings_to_verify = []
|
||||
|
||||
for phase_name, result in previous_results.items():
|
||||
|
|
@ -133,52 +281,164 @@ class VerificationAgent(BaseAgent):
|
|||
data={"findings": [], "verified_count": 0},
|
||||
)
|
||||
|
||||
# 限制数量
|
||||
findings_to_verify = findings_to_verify[:20]
|
||||
|
||||
await self.emit_event(
|
||||
"info",
|
||||
f"开始验证 {len(findings_to_verify)} 个发现"
|
||||
)
|
||||
|
||||
# 构建初始消息
|
||||
findings_summary = []
|
||||
for i, f in enumerate(findings_to_verify):
|
||||
findings_summary.append(f"""
|
||||
### 发现 {i+1}: {f.get('title', 'Unknown')}
|
||||
- 类型: {f.get('vulnerability_type', 'unknown')}
|
||||
- 严重度: {f.get('severity', 'medium')}
|
||||
- 文件: {f.get('file_path', 'unknown')}:{f.get('line_start', 0)}
|
||||
- 代码:
|
||||
```
|
||||
{f.get('code_snippet', 'N/A')[:500]}
|
||||
```
|
||||
- 描述: {f.get('description', 'N/A')[:300]}
|
||||
""")
|
||||
|
||||
initial_message = f"""请验证以下 {len(findings_to_verify)} 个安全发现。
|
||||
|
||||
## 待验证发现
|
||||
{''.join(findings_summary)}
|
||||
|
||||
## 验证要求
|
||||
- 验证级别: {config.get('verification_level', 'standard')}
|
||||
|
||||
## 可用工具
|
||||
{self._get_tools_description()}
|
||||
|
||||
请开始验证。对于每个发现,思考如何验证它,使用合适的工具获取更多信息,然后判断是否为真实漏洞。"""
|
||||
|
||||
# 初始化对话历史
|
||||
self._conversation_history = [
|
||||
{"role": "system", "content": self.config.system_prompt},
|
||||
{"role": "user", "content": initial_message},
|
||||
]
|
||||
|
||||
self._steps = []
|
||||
final_result = None
|
||||
|
||||
await self.emit_thinking("🔐 Verification Agent 启动,LLM 开始自主验证漏洞...")
|
||||
|
||||
try:
|
||||
verified_findings = []
|
||||
verification_level = config.get("verification_level", "sandbox")
|
||||
|
||||
for i, finding in enumerate(findings_to_verify[:20]): # 限制数量
|
||||
for iteration in range(self.config.max_iterations):
|
||||
if self.is_cancelled:
|
||||
break
|
||||
|
||||
await self.emit_thinking(
|
||||
f"验证 [{i+1}/{min(len(findings_to_verify), 20)}]: {finding.get('title', 'unknown')}"
|
||||
self._iteration = iteration + 1
|
||||
|
||||
# 🔥 发射 LLM 开始思考事件
|
||||
await self.emit_llm_start(iteration + 1)
|
||||
|
||||
# 🔥 调用 LLM 进行思考和决策
|
||||
response = await self.llm_service.chat_completion_raw(
|
||||
messages=self._conversation_history,
|
||||
temperature=0.1,
|
||||
max_tokens=3000,
|
||||
)
|
||||
|
||||
# 执行验证
|
||||
verified = await self._verify_finding(finding, verification_level)
|
||||
verified_findings.append(verified)
|
||||
llm_output = response.get("content", "")
|
||||
tokens_this_round = response.get("usage", {}).get("total_tokens", 0)
|
||||
self._total_tokens += tokens_this_round
|
||||
|
||||
# 发射事件
|
||||
if verified.get("is_verified"):
|
||||
await self.emit_event(
|
||||
"finding_verified",
|
||||
f"✅ 已确认: {verified.get('title', '')}",
|
||||
finding_id=verified.get("id"),
|
||||
metadata={"severity": verified.get("severity")}
|
||||
# 解析 LLM 响应
|
||||
step = self._parse_llm_response(llm_output)
|
||||
self._steps.append(step)
|
||||
|
||||
# 🔥 发射 LLM 思考内容事件 - 展示验证的思考过程
|
||||
if step.thought:
|
||||
await self.emit_llm_thought(step.thought, iteration + 1)
|
||||
|
||||
# 添加 LLM 响应到历史
|
||||
self._conversation_history.append({
|
||||
"role": "assistant",
|
||||
"content": llm_output,
|
||||
})
|
||||
|
||||
# 检查是否完成
|
||||
if step.is_final:
|
||||
await self.emit_llm_decision("完成漏洞验证", "LLM 判断验证已充分")
|
||||
final_result = step.final_answer
|
||||
await self.emit_llm_complete(
|
||||
f"验证完成",
|
||||
self._total_tokens
|
||||
)
|
||||
elif verified.get("verdict") == "false_positive":
|
||||
await self.emit_event(
|
||||
"finding_false_positive",
|
||||
f"❌ 误报: {verified.get('title', '')}",
|
||||
finding_id=verified.get("id"),
|
||||
break
|
||||
|
||||
# 执行工具
|
||||
if step.action:
|
||||
# 🔥 发射 LLM 动作决策事件
|
||||
await self.emit_llm_action(step.action, step.action_input or {})
|
||||
|
||||
observation = await self._execute_tool(
|
||||
step.action,
|
||||
step.action_input or {}
|
||||
)
|
||||
|
||||
step.observation = observation
|
||||
|
||||
# 🔥 发射 LLM 观察事件
|
||||
await self.emit_llm_observation(observation)
|
||||
|
||||
# 添加观察结果到历史
|
||||
self._conversation_history.append({
|
||||
"role": "user",
|
||||
"content": f"Observation:\n{observation}",
|
||||
})
|
||||
else:
|
||||
# LLM 没有选择工具,提示它继续
|
||||
await self.emit_llm_decision("继续验证", "LLM 需要更多验证")
|
||||
self._conversation_history.append({
|
||||
"role": "user",
|
||||
"content": "请继续验证。如果验证完成,输出 Final Answer 汇总所有验证结果。",
|
||||
})
|
||||
|
||||
# 处理结果
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
# 处理最终结果
|
||||
verified_findings = []
|
||||
if final_result and "findings" in final_result:
|
||||
for f in final_result["findings"]:
|
||||
verified = {
|
||||
**f,
|
||||
"is_verified": f.get("verdict") == "confirmed" or (
|
||||
f.get("verdict") == "likely" and f.get("confidence", 0) >= 0.8
|
||||
),
|
||||
"verified_at": datetime.now(timezone.utc).isoformat() if f.get("verdict") in ["confirmed", "likely"] else None,
|
||||
}
|
||||
|
||||
# 添加修复建议
|
||||
if not verified.get("recommendation"):
|
||||
verified["recommendation"] = self._get_recommendation(f.get("vulnerability_type", ""))
|
||||
|
||||
verified_findings.append(verified)
|
||||
else:
|
||||
# 如果没有最终结果,使用原始发现
|
||||
for f in findings_to_verify:
|
||||
verified_findings.append({
|
||||
**f,
|
||||
"verdict": "uncertain",
|
||||
"confidence": 0.5,
|
||||
"is_verified": False,
|
||||
})
|
||||
|
||||
# 统计
|
||||
confirmed_count = len([f for f in verified_findings if f.get("is_verified")])
|
||||
confirmed_count = len([f for f in verified_findings if f.get("verdict") == "confirmed"])
|
||||
likely_count = len([f for f in verified_findings if f.get("verdict") == "likely"])
|
||||
false_positive_count = len([f for f in verified_findings if f.get("verdict") == "false_positive"])
|
||||
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
await self.emit_event(
|
||||
"info",
|
||||
f"验证完成: {confirmed_count} 确认, {likely_count} 可能, {false_positive_count} 误报"
|
||||
f"🎯 Verification Agent 完成: {confirmed_count} 确认, {likely_count} 可能, {false_positive_count} 误报"
|
||||
)
|
||||
|
||||
return AgentResult(
|
||||
|
|
@ -196,167 +456,9 @@ class VerificationAgent(BaseAgent):
|
|||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Verification agent failed: {e}", exc_info=True)
|
||||
logger.error(f"Verification Agent failed: {e}", exc_info=True)
|
||||
return AgentResult(success=False, error=str(e))
|
||||
|
||||
async def _verify_finding(
|
||||
self,
|
||||
finding: Dict[str, Any],
|
||||
verification_level: str,
|
||||
) -> Dict[str, Any]:
|
||||
"""验证单个发现"""
|
||||
result = {
|
||||
**finding,
|
||||
"verdict": "uncertain",
|
||||
"confidence": 0.5,
|
||||
"is_verified": False,
|
||||
"verification_method": None,
|
||||
"verified_at": None,
|
||||
}
|
||||
|
||||
vuln_type = finding.get("vulnerability_type", "")
|
||||
file_path = finding.get("file_path", "")
|
||||
line_start = finding.get("line_start", 0)
|
||||
code_snippet = finding.get("code_snippet", "")
|
||||
|
||||
try:
|
||||
# 1. 获取更多上下文
|
||||
context = await self._get_context(file_path, line_start)
|
||||
|
||||
# 2. LLM 验证
|
||||
validation_result = await self._llm_validation(
|
||||
finding, context
|
||||
)
|
||||
|
||||
result["verdict"] = validation_result.get("verdict", "uncertain")
|
||||
result["confidence"] = validation_result.get("confidence", 0.5)
|
||||
result["verification_method"] = "llm_analysis"
|
||||
|
||||
# 3. 如果需要沙箱验证
|
||||
if verification_level in ["sandbox", "generate_poc"]:
|
||||
if result["verdict"] in ["confirmed", "likely"]:
|
||||
if vuln_type in ["sql_injection", "command_injection", "xss"]:
|
||||
sandbox_result = await self._sandbox_verification(
|
||||
finding, validation_result
|
||||
)
|
||||
|
||||
if sandbox_result.get("verified"):
|
||||
result["verdict"] = "confirmed"
|
||||
result["confidence"] = max(result["confidence"], 0.9)
|
||||
result["verification_method"] = "sandbox_test"
|
||||
result["poc"] = sandbox_result.get("poc")
|
||||
|
||||
# 4. 判断是否已验证
|
||||
if result["verdict"] == "confirmed" or (
|
||||
result["verdict"] == "likely" and result["confidence"] >= 0.8
|
||||
):
|
||||
result["is_verified"] = True
|
||||
result["verified_at"] = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
# 5. 添加修复建议
|
||||
if result["is_verified"]:
|
||||
result["recommendation"] = self._get_recommendation(vuln_type)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Verification failed for {file_path}: {e}")
|
||||
result["error"] = str(e)
|
||||
|
||||
return result
|
||||
|
||||
async def _get_context(self, file_path: str, line_start: int) -> str:
|
||||
"""获取代码上下文"""
|
||||
read_tool = self.tools.get("read_file")
|
||||
if not read_tool or not file_path:
|
||||
return ""
|
||||
|
||||
result = await read_tool.execute(
|
||||
file_path=file_path,
|
||||
start_line=max(1, line_start - 30),
|
||||
end_line=line_start + 30,
|
||||
)
|
||||
|
||||
return result.data if result.success else ""
|
||||
|
||||
async def _llm_validation(
|
||||
self,
|
||||
finding: Dict[str, Any],
|
||||
context: str,
|
||||
) -> Dict[str, Any]:
|
||||
"""LLM 漏洞验证"""
|
||||
validation_tool = self.tools.get("vulnerability_validation")
|
||||
|
||||
if not validation_tool:
|
||||
return {"verdict": "uncertain", "confidence": 0.5}
|
||||
|
||||
code = finding.get("code_snippet", "") or context[:2000]
|
||||
|
||||
result = await validation_tool.execute(
|
||||
code=code,
|
||||
vulnerability_type=finding.get("vulnerability_type", "unknown"),
|
||||
file_path=finding.get("file_path", ""),
|
||||
line_number=finding.get("line_start"),
|
||||
context=context[:1000] if context else None,
|
||||
)
|
||||
|
||||
if result.success and result.metadata.get("validation"):
|
||||
validation = result.metadata["validation"]
|
||||
|
||||
verdict_map = {
|
||||
"confirmed": "confirmed",
|
||||
"likely": "likely",
|
||||
"unlikely": "uncertain",
|
||||
"false_positive": "false_positive",
|
||||
}
|
||||
|
||||
return {
|
||||
"verdict": verdict_map.get(validation.get("verdict", ""), "uncertain"),
|
||||
"confidence": validation.get("confidence", 0.5),
|
||||
"explanation": validation.get("detailed_analysis", ""),
|
||||
"exploitation_conditions": validation.get("exploitation_conditions", []),
|
||||
"poc_idea": validation.get("poc_idea"),
|
||||
}
|
||||
|
||||
return {"verdict": "uncertain", "confidence": 0.5}
|
||||
|
||||
async def _sandbox_verification(
|
||||
self,
|
||||
finding: Dict[str, Any],
|
||||
validation_result: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
"""沙箱验证"""
|
||||
result = {"verified": False, "poc": None}
|
||||
|
||||
vuln_type = finding.get("vulnerability_type", "")
|
||||
poc_idea = validation_result.get("poc_idea", "")
|
||||
|
||||
# 根据漏洞类型选择验证方法
|
||||
sandbox_tool = self.tools.get("sandbox_exec")
|
||||
http_tool = self.tools.get("sandbox_http")
|
||||
verify_tool = self.tools.get("verify_vulnerability")
|
||||
|
||||
if vuln_type == "command_injection" and sandbox_tool:
|
||||
# 构造安全的测试命令
|
||||
test_cmd = "echo 'test_marker_12345'"
|
||||
|
||||
exec_result = await sandbox_tool.execute(
|
||||
command=f"python3 -c \"print('test')\"",
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
if exec_result.success:
|
||||
result["verified"] = True
|
||||
result["poc"] = {
|
||||
"description": "命令注入测试",
|
||||
"method": "sandbox_exec",
|
||||
}
|
||||
|
||||
elif vuln_type in ["sql_injection", "xss"] and verify_tool:
|
||||
# 使用自动验证工具
|
||||
# 注意:这需要实际的目标 URL
|
||||
pass
|
||||
|
||||
return result
|
||||
|
||||
def _get_recommendation(self, vuln_type: str) -> str:
|
||||
"""获取修复建议"""
|
||||
recommendations = {
|
||||
|
|
@ -369,7 +471,6 @@ class VerificationAgent(BaseAgent):
|
|||
"hardcoded_secret": "使用环境变量或密钥管理服务存储敏感信息",
|
||||
"weak_crypto": "使用强加密算法(AES-256, SHA-256+),避免 MD5/SHA1",
|
||||
}
|
||||
|
||||
return recommendations.get(vuln_type, "请根据具体情况修复此安全问题")
|
||||
|
||||
def _deduplicate(self, findings: List[Dict]) -> List[Dict]:
|
||||
|
|
@ -389,4 +490,11 @@ class VerificationAgent(BaseAgent):
|
|||
unique.append(f)
|
||||
|
||||
return unique
|
||||
|
||||
|
||||
def get_conversation_history(self) -> List[Dict[str, str]]:
|
||||
"""获取对话历史"""
|
||||
return self._conversation_history
|
||||
|
||||
def get_steps(self) -> List[VerificationStep]:
|
||||
"""获取执行步骤"""
|
||||
return self._steps
|
||||
|
|
|
|||
|
|
@ -91,6 +91,33 @@ class AgentEventEmitter:
|
|||
metadata=metadata,
|
||||
))
|
||||
|
||||
async def emit_llm_thought(self, thought: str, iteration: int = 0):
|
||||
"""发射 LLM 思考内容事件 - 核心!展示 LLM 在想什么"""
|
||||
display = thought[:500] + "..." if len(thought) > 500 else thought
|
||||
await self.emit(AgentEventData(
|
||||
event_type="llm_thought",
|
||||
message=f"💭 LLM 思考:\n{display}",
|
||||
metadata={"thought": thought, "iteration": iteration},
|
||||
))
|
||||
|
||||
async def emit_llm_decision(self, decision: str, reason: str = ""):
|
||||
"""发射 LLM 决策事件"""
|
||||
await self.emit(AgentEventData(
|
||||
event_type="llm_decision",
|
||||
message=f"💡 LLM 决策: {decision}" + (f" ({reason})" if reason else ""),
|
||||
metadata={"decision": decision, "reason": reason},
|
||||
))
|
||||
|
||||
async def emit_llm_action(self, action: str, action_input: Dict):
|
||||
"""发射 LLM 动作事件"""
|
||||
import json
|
||||
input_str = json.dumps(action_input, ensure_ascii=False)[:200]
|
||||
await self.emit(AgentEventData(
|
||||
event_type="llm_action",
|
||||
message=f"⚡ LLM 动作: {action}\n 参数: {input_str}",
|
||||
metadata={"action": action, "action_input": action_input},
|
||||
))
|
||||
|
||||
async def emit_tool_call(
|
||||
self,
|
||||
tool_name: str,
|
||||
|
|
|
|||
|
|
@ -1,12 +1,15 @@
|
|||
"""
|
||||
DeepAudit 审计工作流图
|
||||
使用 LangGraph 构建状态机式的 Agent 协作流程
|
||||
DeepAudit 审计工作流图 - LLM 驱动版
|
||||
使用 LangGraph 构建 LLM 驱动的 Agent 协作流程
|
||||
|
||||
重要改变:路由决策由 LLM 参与,而不是硬编码条件!
|
||||
"""
|
||||
|
||||
from typing import TypedDict, Annotated, List, Dict, Any, Optional, Literal
|
||||
from datetime import datetime
|
||||
import operator
|
||||
import logging
|
||||
import json
|
||||
|
||||
from langgraph.graph import StateGraph, END
|
||||
from langgraph.checkpoint.memory import MemorySaver
|
||||
|
|
@ -56,12 +59,16 @@ class AuditState(TypedDict):
|
|||
verified_findings: List[Finding]
|
||||
false_positives: List[str]
|
||||
|
||||
# 控制流
|
||||
# 控制流 - 🔥 关键:LLM 可以设置这些来影响路由
|
||||
current_phase: str
|
||||
iteration: int
|
||||
max_iterations: int
|
||||
should_continue_analysis: bool
|
||||
|
||||
# 🔥 新增:LLM 的路由决策
|
||||
llm_next_action: Optional[str] # LLM 建议的下一步: "continue_analysis", "verify", "report", "end"
|
||||
llm_routing_reason: Optional[str] # LLM 的决策理由
|
||||
|
||||
# 消息和事件
|
||||
messages: Annotated[List[Dict], operator.add]
|
||||
events: Annotated[List[Dict], operator.add]
|
||||
|
|
@ -72,43 +79,233 @@ class AuditState(TypedDict):
|
|||
error: Optional[str]
|
||||
|
||||
|
||||
# ============ 路由函数 ============
|
||||
# ============ LLM 路由决策器 ============
|
||||
|
||||
class LLMRouter:
|
||||
"""
|
||||
LLM 路由决策器
|
||||
让 LLM 来决定下一步应该做什么
|
||||
"""
|
||||
|
||||
def __init__(self, llm_service):
|
||||
self.llm_service = llm_service
|
||||
|
||||
async def decide_after_recon(self, state: AuditState) -> Dict[str, Any]:
|
||||
"""Recon 后让 LLM 决定下一步"""
|
||||
entry_points = state.get("entry_points", [])
|
||||
high_risk_areas = state.get("high_risk_areas", [])
|
||||
tech_stack = state.get("tech_stack", {})
|
||||
initial_findings = state.get("findings", [])
|
||||
|
||||
prompt = f"""作为安全审计的决策者,基于以下信息收集结果,决定下一步行动。
|
||||
|
||||
## 信息收集结果
|
||||
- 入口点数量: {len(entry_points)}
|
||||
- 高风险区域: {high_risk_areas[:10]}
|
||||
- 技术栈: {tech_stack}
|
||||
- 初步发现: {len(initial_findings)} 个
|
||||
|
||||
## 选项
|
||||
1. "analysis" - 继续进行漏洞分析(推荐:有入口点或高风险区域时)
|
||||
2. "end" - 结束审计(仅当没有任何可分析内容时)
|
||||
|
||||
请返回 JSON 格式:
|
||||
{{"action": "analysis或end", "reason": "决策理由"}}"""
|
||||
|
||||
try:
|
||||
response = await self.llm_service.chat_completion_raw(
|
||||
messages=[
|
||||
{"role": "system", "content": "你是安全审计流程的决策者,负责决定下一步行动。"},
|
||||
{"role": "user", "content": prompt},
|
||||
],
|
||||
temperature=0.1,
|
||||
max_tokens=200,
|
||||
)
|
||||
|
||||
content = response.get("content", "")
|
||||
# 提取 JSON
|
||||
import re
|
||||
json_match = re.search(r'\{.*\}', content, re.DOTALL)
|
||||
if json_match:
|
||||
result = json.loads(json_match.group())
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.warning(f"LLM routing decision failed: {e}")
|
||||
|
||||
# 默认决策
|
||||
if entry_points or high_risk_areas:
|
||||
return {"action": "analysis", "reason": "有可分析内容"}
|
||||
return {"action": "end", "reason": "没有发现入口点或高风险区域"}
|
||||
|
||||
async def decide_after_analysis(self, state: AuditState) -> Dict[str, Any]:
|
||||
"""Analysis 后让 LLM 决定下一步"""
|
||||
findings = state.get("findings", [])
|
||||
iteration = state.get("iteration", 0)
|
||||
max_iterations = state.get("max_iterations", 3)
|
||||
|
||||
# 统计发现
|
||||
severity_counts = {"critical": 0, "high": 0, "medium": 0, "low": 0}
|
||||
for f in findings:
|
||||
sev = f.get("severity", "medium")
|
||||
severity_counts[sev] = severity_counts.get(sev, 0) + 1
|
||||
|
||||
prompt = f"""作为安全审计的决策者,基于以下分析结果,决定下一步行动。
|
||||
|
||||
## 分析结果
|
||||
- 总发现数: {len(findings)}
|
||||
- 严重程度分布: {severity_counts}
|
||||
- 当前迭代: {iteration}/{max_iterations}
|
||||
|
||||
## 选项
|
||||
1. "verification" - 验证发现的漏洞(推荐:有发现需要验证时)
|
||||
2. "analysis" - 继续深入分析(推荐:发现较少但还有迭代次数时)
|
||||
3. "report" - 生成报告(推荐:没有发现或已充分分析时)
|
||||
|
||||
请返回 JSON 格式:
|
||||
{{"action": "verification/analysis/report", "reason": "决策理由"}}"""
|
||||
|
||||
try:
|
||||
response = await self.llm_service.chat_completion_raw(
|
||||
messages=[
|
||||
{"role": "system", "content": "你是安全审计流程的决策者,负责决定下一步行动。"},
|
||||
{"role": "user", "content": prompt},
|
||||
],
|
||||
temperature=0.1,
|
||||
max_tokens=200,
|
||||
)
|
||||
|
||||
content = response.get("content", "")
|
||||
import re
|
||||
json_match = re.search(r'\{.*\}', content, re.DOTALL)
|
||||
if json_match:
|
||||
result = json.loads(json_match.group())
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.warning(f"LLM routing decision failed: {e}")
|
||||
|
||||
# 默认决策
|
||||
if not findings:
|
||||
return {"action": "report", "reason": "没有发现漏洞"}
|
||||
if len(findings) >= 3 or iteration >= max_iterations:
|
||||
return {"action": "verification", "reason": "有足够的发现需要验证"}
|
||||
return {"action": "analysis", "reason": "发现较少,继续分析"}
|
||||
|
||||
async def decide_after_verification(self, state: AuditState) -> Dict[str, Any]:
|
||||
"""Verification 后让 LLM 决定下一步"""
|
||||
verified_findings = state.get("verified_findings", [])
|
||||
false_positives = state.get("false_positives", [])
|
||||
iteration = state.get("iteration", 0)
|
||||
max_iterations = state.get("max_iterations", 3)
|
||||
|
||||
prompt = f"""作为安全审计的决策者,基于以下验证结果,决定下一步行动。
|
||||
|
||||
## 验证结果
|
||||
- 已确认漏洞: {len(verified_findings)}
|
||||
- 误报数量: {len(false_positives)}
|
||||
- 当前迭代: {iteration}/{max_iterations}
|
||||
|
||||
## 选项
|
||||
1. "analysis" - 回到分析阶段重新分析(推荐:误报率太高时)
|
||||
2. "report" - 生成最终报告(推荐:验证完成时)
|
||||
|
||||
请返回 JSON 格式:
|
||||
{{"action": "analysis/report", "reason": "决策理由"}}"""
|
||||
|
||||
try:
|
||||
response = await self.llm_service.chat_completion_raw(
|
||||
messages=[
|
||||
{"role": "system", "content": "你是安全审计流程的决策者,负责决定下一步行动。"},
|
||||
{"role": "user", "content": prompt},
|
||||
],
|
||||
temperature=0.1,
|
||||
max_tokens=200,
|
||||
)
|
||||
|
||||
content = response.get("content", "")
|
||||
import re
|
||||
json_match = re.search(r'\{.*\}', content, re.DOTALL)
|
||||
if json_match:
|
||||
result = json.loads(json_match.group())
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.warning(f"LLM routing decision failed: {e}")
|
||||
|
||||
# 默认决策
|
||||
if len(false_positives) > len(verified_findings) and iteration < max_iterations:
|
||||
return {"action": "analysis", "reason": "误报率较高,需要重新分析"}
|
||||
return {"action": "report", "reason": "验证完成,生成报告"}


# ============ 路由函数 (结合 LLM 决策) ============

def route_after_recon(state: AuditState) -> Literal["analysis", "end"]:
|
||||
"""Recon 后的路由决策"""
|
||||
# 如果没有发现入口点或高风险区域,直接结束
|
||||
"""
|
||||
Recon 后的路由决策
|
||||
优先使用 LLM 的决策,否则使用默认逻辑
|
||||
"""
|
||||
# 检查 LLM 是否有决策
|
||||
llm_action = state.get("llm_next_action")
|
||||
if llm_action:
|
||||
logger.info(f"Using LLM routing decision: {llm_action}, reason: {state.get('llm_routing_reason')}")
|
||||
if llm_action == "end":
|
||||
return "end"
|
||||
return "analysis"
|
||||
|
||||
# 默认逻辑(作为 fallback)
|
||||
if not state.get("entry_points") and not state.get("high_risk_areas"):
|
||||
return "end"
|
||||
return "analysis"
|
||||
|
||||
|
||||
def route_after_analysis(state: AuditState) -> Literal["verification", "analysis", "report"]:
|
||||
"""Analysis 后的路由决策"""
|
||||
"""
|
||||
Analysis 后的路由决策
|
||||
优先使用 LLM 的决策
|
||||
"""
|
||||
# 检查 LLM 是否有决策
|
||||
llm_action = state.get("llm_next_action")
|
||||
if llm_action:
|
||||
logger.info(f"Using LLM routing decision: {llm_action}, reason: {state.get('llm_routing_reason')}")
|
||||
if llm_action == "verification":
|
||||
return "verification"
|
||||
elif llm_action == "analysis":
|
||||
return "analysis"
|
||||
elif llm_action == "report":
|
||||
return "report"
|
||||
|
||||
# 默认逻辑
|
||||
findings = state.get("findings", [])
|
||||
iteration = state.get("iteration", 0)
|
||||
max_iterations = state.get("max_iterations", 3)
|
||||
should_continue = state.get("should_continue_analysis", False)
|
||||
|
||||
# 如果没有发现,直接生成报告
|
||||
if not findings:
|
||||
return "report"
|
||||
|
||||
# 如果需要继续分析且未达到最大迭代
|
||||
if should_continue and iteration < max_iterations:
|
||||
return "analysis"
|
||||
|
||||
# 有发现需要验证
|
||||
return "verification"
|
||||
|
||||
|
||||
def route_after_verification(state: AuditState) -> Literal["analysis", "report"]:
|
||||
"""Verification 后的路由决策"""
|
||||
# 如果验证发现了误报,可能需要重新分析
|
||||
"""
|
||||
Verification 后的路由决策
|
||||
优先使用 LLM 的决策
|
||||
"""
|
||||
# 检查 LLM 是否有决策
|
||||
llm_action = state.get("llm_next_action")
|
||||
if llm_action:
|
||||
logger.info(f"Using LLM routing decision: {llm_action}, reason: {state.get('llm_routing_reason')}")
|
||||
if llm_action == "analysis":
|
||||
return "analysis"
|
||||
return "report"
|
||||
|
||||
# 默认逻辑
|
||||
false_positives = state.get("false_positives", [])
|
||||
iteration = state.get("iteration", 0)
|
||||
max_iterations = state.get("max_iterations", 3)
|
||||
|
||||
# 如果误报率太高且还有迭代次数,回到分析
|
||||
if len(false_positives) > len(state.get("verified_findings", [])) and iteration < max_iterations:
|
||||
return "analysis"
|
||||
|
||||
|
|
@ -123,6 +320,7 @@ def create_audit_graph(
|
|||
verification_node,
|
||||
report_node,
|
||||
checkpointer: Optional[MemorySaver] = None,
|
||||
llm_service=None, # 用于 LLM 路由决策
|
||||
) -> StateGraph:
|
||||
"""
|
||||
创建审计工作流图
|
||||
|
|
@ -132,7 +330,8 @@ def create_audit_graph(
|
|||
analysis_node: 漏洞分析节点
|
||||
verification_node: 漏洞验证节点
|
||||
report_node: 报告生成节点
|
||||
checkpointer: 检查点存储器(用于状态持久化)
|
||||
checkpointer: 检查点存储器
|
||||
llm_service: LLM 服务(用于路由决策)
|
||||
|
||||
Returns:
|
||||
编译后的 StateGraph
|
||||
|
|
@ -143,19 +342,19 @@ def create_audit_graph(
|
|||
│
|
||||
▼
|
||||
┌──────┐
|
||||
│Recon │ 信息收集
|
||||
│Recon │ 信息收集 (LLM 驱动)
|
||||
└──┬───┘
|
||||
│
|
||||
│ LLM 决定
|
||||
▼
|
||||
┌──────────┐
|
||||
│ Analysis │◄─────┐ 漏洞分析(可循环)
|
||||
│ Analysis │◄─────┐ 漏洞分析 (LLM 驱动,可循环)
|
||||
└────┬─────┘ │
|
||||
│ │
|
||||
│ LLM 决定 │
|
||||
▼ │
|
||||
┌────────────┐ │
|
||||
│Verification│────┘ 漏洞验证(可回溯)
|
||||
│Verification│────┘ 漏洞验证 (LLM 驱动,可回溯)
|
||||
└─────┬──────┘
|
||||
│
|
||||
│ LLM 决定
|
||||
▼
|
||||
┌──────────┐
|
||||
│ Report │ 报告生成
|
||||
|
|
@ -168,10 +367,53 @@ def create_audit_graph(
|
|||
# 创建状态图
|
||||
workflow = StateGraph(AuditState)
|
||||
|
||||
# 如果有 LLM 服务,创建路由决策器
|
||||
llm_router = LLMRouter(llm_service) if llm_service else None
|
||||
|
||||
# 包装节点以添加 LLM 路由决策
|
||||
async def recon_with_routing(state):
|
||||
result = await recon_node(state)
|
||||
|
||||
# LLM 决定下一步
|
||||
if llm_router:
|
||||
decision = await llm_router.decide_after_recon({**state, **result})
|
||||
result["llm_next_action"] = decision.get("action")
|
||||
result["llm_routing_reason"] = decision.get("reason")
|
||||
|
||||
return result
|
||||
|
||||
async def analysis_with_routing(state):
|
||||
result = await analysis_node(state)
|
||||
|
||||
# LLM 决定下一步
|
||||
if llm_router:
|
||||
decision = await llm_router.decide_after_analysis({**state, **result})
|
||||
result["llm_next_action"] = decision.get("action")
|
||||
result["llm_routing_reason"] = decision.get("reason")
|
||||
|
||||
return result
|
||||
|
||||
async def verification_with_routing(state):
|
||||
result = await verification_node(state)
|
||||
|
||||
# LLM 决定下一步
|
||||
if llm_router:
|
||||
decision = await llm_router.decide_after_verification({**state, **result})
|
||||
result["llm_next_action"] = decision.get("action")
|
||||
result["llm_routing_reason"] = decision.get("reason")
|
||||
|
||||
return result
|
||||
|
||||
# 添加节点
|
||||
workflow.add_node("recon", recon_node)
|
||||
workflow.add_node("analysis", analysis_node)
|
||||
workflow.add_node("verification", verification_node)
|
||||
if llm_router:
|
||||
workflow.add_node("recon", recon_with_routing)
|
||||
workflow.add_node("analysis", analysis_with_routing)
|
||||
workflow.add_node("verification", verification_with_routing)
|
||||
else:
|
||||
workflow.add_node("recon", recon_node)
|
||||
workflow.add_node("analysis", analysis_node)
|
||||
workflow.add_node("verification", verification_node)
|
||||
|
||||
workflow.add_node("report", report_node)
|
||||
|
||||
# 设置入口点
|
||||
|
|
@@ -192,7 +434,7 @@ def create_audit_graph(
route_after_analysis,
{
"verification": "verification",
"analysis": "analysis", # 循环
"analysis": "analysis",
"report": "report",
}
)

@@ -201,7 +443,7 @@ def create_audit_graph(
"verification",
route_after_verification,
{
"analysis": "analysis", # 回溯
"analysis": "analysis",
"report": "report",
}
)
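For readers new to LangGraph, a stripped-down sketch of the same wiring pattern (stub nodes passed in; the real graph swaps in the *_with_routing wrappers when an llm_service is provided):

```python
# Minimal LangGraph wiring sketch using the route functions defined above (illustrative only).
from langgraph.graph import StateGraph, END

def build_demo_graph(analysis_node, verification_node, report_node):
    wf = StateGraph(AuditState)
    wf.add_node("analysis", analysis_node)
    wf.add_node("verification", verification_node)
    wf.add_node("report", report_node)
    wf.set_entry_point("analysis")
    # The route function returns one of the mapping keys, so the LLM's
    # llm_next_action effectively selects the next node.
    wf.add_conditional_edges("analysis", route_after_analysis, {
        "verification": "verification",
        "analysis": "analysis",   # loop back for another analysis pass
        "report": "report",
    })
    wf.add_conditional_edges("verification", route_after_verification, {
        "analysis": "analysis",   # re-analyse when false positives dominate
        "report": "report",
    })
    wf.add_edge("report", END)
    return wf.compile()
```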
@ -225,50 +467,42 @@ def create_audit_graph_with_human(
|
|||
report_node,
|
||||
human_review_node,
|
||||
checkpointer: Optional[MemorySaver] = None,
|
||||
llm_service=None,
|
||||
) -> StateGraph:
|
||||
"""
|
||||
创建带人机协作的审计工作流图
|
||||
|
||||
在验证阶段后增加人工审核节点
|
||||
|
||||
工作流结构:
|
||||
|
||||
START
|
||||
│
|
||||
▼
|
||||
┌──────┐
|
||||
│Recon │
|
||||
└──┬───┘
|
||||
│
|
||||
▼
|
||||
┌──────────┐
|
||||
│ Analysis │◄─────┐
|
||||
└────┬─────┘ │
|
||||
│ │
|
||||
▼ │
|
||||
┌────────────┐ │
|
||||
│Verification│────┘
|
||||
└─────┬──────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────┐
|
||||
│Human Review │ ← 人工审核(可跳过)
|
||||
└──────┬──────┘
|
||||
│
|
||||
▼
|
||||
┌──────────┐
|
||||
│ Report │
|
||||
└────┬─────┘
|
||||
│
|
||||
▼
|
||||
END
|
||||
"""
|
||||
|
||||
workflow = StateGraph(AuditState)
|
||||
llm_router = LLMRouter(llm_service) if llm_service else None
|
||||
|
||||
# 包装节点
|
||||
async def recon_with_routing(state):
|
||||
result = await recon_node(state)
|
||||
if llm_router:
|
||||
decision = await llm_router.decide_after_recon({**state, **result})
|
||||
result["llm_next_action"] = decision.get("action")
|
||||
result["llm_routing_reason"] = decision.get("reason")
|
||||
return result
|
||||
|
||||
async def analysis_with_routing(state):
|
||||
result = await analysis_node(state)
|
||||
if llm_router:
|
||||
decision = await llm_router.decide_after_analysis({**state, **result})
|
||||
result["llm_next_action"] = decision.get("action")
|
||||
result["llm_routing_reason"] = decision.get("reason")
|
||||
return result
|
||||
|
||||
# 添加节点
|
||||
workflow.add_node("recon", recon_node)
|
||||
workflow.add_node("analysis", analysis_node)
|
||||
if llm_router:
|
||||
workflow.add_node("recon", recon_with_routing)
|
||||
workflow.add_node("analysis", analysis_with_routing)
|
||||
else:
|
||||
workflow.add_node("recon", recon_node)
|
||||
workflow.add_node("analysis", analysis_node)
|
||||
|
||||
workflow.add_node("verification", verification_node)
|
||||
workflow.add_node("human_review", human_review_node)
|
||||
workflow.add_node("report", report_node)
|
||||
|
|
@ -296,7 +530,6 @@ def create_audit_graph_with_human(
|
|||
|
||||
# Human Review 后的路由
|
||||
def route_after_human(state: AuditState) -> Literal["analysis", "report"]:
|
||||
# 人工可以决定重新分析或继续
|
||||
if state.get("should_continue_analysis"):
|
||||
return "analysis"
|
||||
return "report"
|
||||
|
|
@ -340,15 +573,6 @@ class AuditGraphRunner:
|
|||
) -> Dict[str, Any]:
|
||||
"""
|
||||
执行审计工作流
|
||||
|
||||
Args:
|
||||
project_root: 项目根目录
|
||||
project_info: 项目信息
|
||||
config: 审计配置
|
||||
task_id: 任务 ID
|
||||
|
||||
Returns:
|
||||
最终状态
|
||||
"""
|
||||
# 初始状态
|
||||
initial_state: AuditState = {
|
||||
|
|
@ -367,6 +591,8 @@ class AuditGraphRunner:
|
|||
"iteration": 0,
|
||||
"max_iterations": config.get("max_iterations", 3),
|
||||
"should_continue_analysis": False,
|
||||
"llm_next_action": None,
|
||||
"llm_routing_reason": None,
|
||||
"messages": [],
|
||||
"events": [],
|
||||
"summary": None,
|
||||
|
|
@ -374,32 +600,32 @@ class AuditGraphRunner:
|
|||
"error": None,
|
||||
}
|
||||
|
||||
# 配置
|
||||
run_config = {
|
||||
"configurable": {
|
||||
"thread_id": task_id,
|
||||
}
|
||||
}
|
||||
|
||||
# 执行图
|
||||
try:
|
||||
# 流式执行
|
||||
async for event in self.graph.astream(initial_state, config=run_config):
|
||||
# 发射事件
|
||||
if self.event_emitter:
|
||||
for node_name, node_state in event.items():
|
||||
await self.event_emitter.emit_info(
|
||||
f"节点 {node_name} 完成"
|
||||
)
|
||||
|
||||
# 发射发现事件
|
||||
# 发射 LLM 路由决策事件
|
||||
if node_state.get("llm_routing_reason"):
|
||||
await self.event_emitter.emit_info(
|
||||
f"🧠 LLM 决策: {node_state.get('llm_next_action')} - {node_state.get('llm_routing_reason')}"
|
||||
)
|
||||
|
||||
if node_name == "analysis" and node_state.get("findings"):
|
||||
new_findings = node_state["findings"]
|
||||
await self.event_emitter.emit_info(
|
||||
f"发现 {len(new_findings)} 个潜在漏洞"
|
||||
)
|
||||
|
||||
# 获取最终状态
|
||||
final_state = self.graph.get_state(run_config)
|
||||
return final_state.values
|
||||
|
||||
|
|
@ -412,44 +638,27 @@ class AuditGraphRunner:
|
|||
initial_state: AuditState,
|
||||
human_feedback_callback,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
带人机协作的执行
|
||||
|
||||
Args:
|
||||
initial_state: 初始状态
|
||||
human_feedback_callback: 人工反馈回调函数
|
||||
|
||||
Returns:
|
||||
最终状态
|
||||
"""
|
||||
"""带人机协作的执行"""
|
||||
run_config = {
|
||||
"configurable": {
|
||||
"thread_id": initial_state["task_id"],
|
||||
}
|
||||
}
|
||||
|
||||
# 执行到人工审核节点
|
||||
async for event in self.graph.astream(initial_state, config=run_config):
|
||||
pass
|
||||
|
||||
# 获取当前状态
|
||||
current_state = self.graph.get_state(run_config)
|
||||
|
||||
# 如果在人工审核节点暂停
|
||||
if current_state.next == ("human_review",):
|
||||
# 调用人工反馈
|
||||
human_decision = await human_feedback_callback(current_state.values)
|
||||
|
||||
# 更新状态并继续
|
||||
updated_state = {
|
||||
**current_state.values,
|
||||
"should_continue_analysis": human_decision.get("continue_analysis", False),
|
||||
}
|
||||
|
||||
# 继续执行
|
||||
async for event in self.graph.astream(updated_state, config=run_config):
|
||||
pass
|
||||
|
||||
# 返回最终状态
|
||||
return self.graph.get_state(run_config).values
|
||||
|
||||
|
|
|
|||
|
|
@ -403,10 +403,19 @@ class AgentRunner:
|
|||
Returns:
|
||||
最终状态
|
||||
"""
|
||||
result = {}
|
||||
async for _ in self.run_with_streaming():
|
||||
pass # 消费所有事件
|
||||
return result
|
||||
final_state = {}
|
||||
try:
|
||||
async for event in self.run_with_streaming():
|
||||
# 收集最终状态
|
||||
if event.event_type == StreamEventType.TASK_COMPLETE:
|
||||
final_state = event.data
|
||||
elif event.event_type == StreamEventType.TASK_ERROR:
|
||||
final_state = {"success": False, "error": event.data.get("error")}
|
||||
except Exception as e:
|
||||
logger.error(f"Agent run failed: {e}", exc_info=True)
|
||||
final_state = {"success": False, "error": str(e)}
|
||||
|
||||
return final_state
|
||||
|
||||
async def run_with_streaming(self) -> AsyncGenerator[StreamEvent, None]:
|
||||
"""
|
||||
|
|
|
|||
|
|
@@ -15,12 +15,20 @@ logger = logging.getLogger(__name__)

class StreamEventType(str, Enum):
"""流式事件类型"""
# LLM 相关
# 🔥 LLM 思考相关 - 这些是最重要的!展示 LLM 的大脑活动
LLM_START = "llm_start" # LLM 开始思考
LLM_THOUGHT = "llm_thought" # LLM 思考内容 ⭐ 核心
LLM_DECISION = "llm_decision" # LLM 决策 ⭐ 核心
LLM_ACTION = "llm_action" # LLM 动作
LLM_OBSERVATION = "llm_observation" # LLM 观察结果
LLM_COMPLETE = "llm_complete" # LLM 完成

# LLM Token 流 (实时输出)
THINKING_START = "thinking_start" # 开始思考
THINKING_TOKEN = "thinking_token" # 思考 Token
THINKING_TOKEN = "thinking_token" # 思考 Token (流式)
THINKING_END = "thinking_end" # 思考结束

# 工具调用相关
# 工具调用相关 - LLM 决定调用工具
TOOL_CALL_START = "tool_call_start" # 工具调用开始
TOOL_CALL_INPUT = "tool_call_input" # 工具输入参数
TOOL_CALL_OUTPUT = "tool_call_output" # 工具输出结果
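A sketch of how a consumer could fan these events out (hypothetical relay helper; assumes the `event_type` / `data` attributes used by `AgentRunner.run` elsewhere in this commit):

```python
# Hypothetical consumer that surfaces the new llm_* events separately from tool output.
async def relay_events(runner, send):
    async for event in runner.run_with_streaming():
        if event.event_type in (
            StreamEventType.LLM_THOUGHT,
            StreamEventType.LLM_DECISION,
            StreamEventType.LLM_ACTION,
        ):
            await send({"channel": "llm", "type": event.event_type, "data": event.data})
        elif event.event_type in (StreamEventType.TOOL_CALL_START, StreamEventType.TOOL_CALL_OUTPUT):
            await send({"channel": "tools", "type": event.event_type, "data": event.data})
        else:
            await send({"channel": "general", "type": event.event_type, "data": event.data})
```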
|
||||
|
|
|
|||
|
|
@ -4,13 +4,13 @@
|
|||
* 支持 LLM 思考过程和工具调用的实时流式展示
|
||||
*/
|
||||
|
||||
import { useState, useEffect, useRef, useCallback } from "react";
|
||||
import { useState, useEffect, useRef, useCallback, useMemo } from "react";
|
||||
import { useParams, useNavigate } from "react-router-dom";
|
||||
import {
|
||||
Terminal, Bot, Cpu, Shield, AlertTriangle, CheckCircle2,
|
||||
Terminal, Bot, Shield, AlertTriangle, CheckCircle2,
|
||||
Loader2, Code, Zap, Activity, ChevronRight, XCircle,
|
||||
FileCode, Search, Bug, Lock, Play, Square, RefreshCw,
|
||||
ArrowLeft, Download, ExternalLink, Brain, Wrench,
|
||||
FileCode, Search, Bug, Square, RefreshCw,
|
||||
ArrowLeft, ExternalLink, Brain, Wrench,
|
||||
ChevronDown, ChevronUp, Clock, Sparkles
|
||||
} from "lucide-react";
|
||||
import { Button } from "@/components/ui/button";
|
||||
|
|
@ -26,42 +26,77 @@ import {
|
|||
getAgentEvents,
|
||||
getAgentFindings,
|
||||
cancelAgentTask,
|
||||
streamAgentEvents,
|
||||
} from "@/shared/api/agentTasks";
|
||||
|
||||
// 事件类型图标映射
|
||||
// 事件类型图标映射 - 🔥 重点展示 LLM 相关事件
|
||||
const eventTypeIcons: Record<string, React.ReactNode> = {
|
||||
// 🧠 LLM 核心事件 - 最重要!
|
||||
llm_start: <Brain className="w-3 h-3 text-purple-400 animate-pulse" />,
|
||||
llm_thought: <Sparkles className="w-3 h-3 text-purple-300" />,
|
||||
llm_decision: <Zap className="w-3 h-3 text-yellow-400" />,
|
||||
llm_action: <Zap className="w-3 h-3 text-orange-400" />,
|
||||
llm_observation: <Search className="w-3 h-3 text-blue-400" />,
|
||||
llm_complete: <CheckCircle2 className="w-3 h-3 text-green-400" />,
|
||||
|
||||
// 阶段相关
|
||||
phase_start: <Zap className="w-3 h-3 text-cyan-400" />,
|
||||
phase_complete: <CheckCircle2 className="w-3 h-3 text-green-400" />,
|
||||
thinking: <Cpu className="w-3 h-3 text-purple-400" />,
|
||||
tool_call: <Code className="w-3 h-3 text-yellow-400" />,
|
||||
thinking: <Brain className="w-3 h-3 text-purple-400" />,
|
||||
|
||||
// 工具相关 - LLM 决定的工具调用
|
||||
tool_call: <Wrench className="w-3 h-3 text-yellow-400" />,
|
||||
tool_result: <CheckCircle2 className="w-3 h-3 text-green-400" />,
|
||||
tool_error: <XCircle className="w-3 h-3 text-red-400" />,
|
||||
|
||||
// 发现相关
|
||||
finding: <Bug className="w-3 h-3 text-orange-400" />,
|
||||
finding_new: <Bug className="w-3 h-3 text-orange-400" />,
|
||||
finding_verified: <Shield className="w-3 h-3 text-red-400" />,
|
||||
|
||||
// 状态相关
|
||||
info: <Activity className="w-3 h-3 text-blue-400" />,
|
||||
warning: <AlertTriangle className="w-3 h-3 text-yellow-400" />,
|
||||
error: <XCircle className="w-3 h-3 text-red-500" />,
|
||||
progress: <RefreshCw className="w-3 h-3 text-cyan-400 animate-spin" />,
|
||||
|
||||
// 任务相关
|
||||
task_complete: <CheckCircle2 className="w-3 h-3 text-green-500" />,
|
||||
task_error: <XCircle className="w-3 h-3 text-red-500" />,
|
||||
task_cancel: <Square className="w-3 h-3 text-yellow-500" />,
|
||||
};
|
||||
|
||||
// 事件类型颜色映射
|
||||
// 事件类型颜色映射 - 🔥 LLM 事件突出显示
|
||||
const eventTypeColors: Record<string, string> = {
|
||||
// 🧠 LLM 核心事件 - 使用紫色系突出
|
||||
llm_start: "text-purple-400 font-semibold",
|
||||
llm_thought: "text-purple-300 bg-purple-950/30 rounded px-1", // 思考内容特别高亮
|
||||
llm_decision: "text-yellow-300 font-semibold", // 决策特别突出
|
||||
llm_action: "text-orange-300 font-medium",
|
||||
llm_observation: "text-blue-300",
|
||||
llm_complete: "text-green-400 font-semibold",
|
||||
|
||||
// 阶段相关
|
||||
phase_start: "text-cyan-400 font-bold",
|
||||
phase_complete: "text-green-400",
|
||||
thinking: "text-purple-300",
|
||||
|
||||
// 工具相关
|
||||
tool_call: "text-yellow-300",
|
||||
tool_result: "text-green-300",
|
||||
tool_error: "text-red-400",
|
||||
|
||||
// 发现相关
|
||||
finding: "text-orange-300 font-medium",
|
||||
finding_new: "text-orange-300",
|
||||
finding_verified: "text-red-300",
|
||||
finding_verified: "text-red-300 font-medium",
|
||||
|
||||
// 状态相关
|
||||
info: "text-gray-300",
|
||||
warning: "text-yellow-300",
|
||||
error: "text-red-400",
|
||||
progress: "text-cyan-300",
|
||||
|
||||
// 任务相关
|
||||
task_complete: "text-green-400 font-bold",
|
||||
task_error: "text-red-400 font-bold",
|
||||
task_cancel: "text-yellow-400",
|
||||
|
|
@ -99,30 +134,6 @@ export default function AgentAuditPage() {
|
|||
|
||||
const eventsEndRef = useRef<HTMLDivElement>(null);
|
||||
const thinkingEndRef = useRef<HTMLDivElement>(null);
|
||||
const abortControllerRef = useRef<AbortController | null>(null);
|
||||
|
||||
// 使用增强版流式 Hook
|
||||
const {
|
||||
thinking,
|
||||
isThinking,
|
||||
toolCalls,
|
||||
currentPhase: streamPhase,
|
||||
progress: streamProgress,
|
||||
connect: connectStream,
|
||||
disconnect: disconnectStream,
|
||||
isConnected: isStreamConnected,
|
||||
} = useAgentStream(taskId || null, {
|
||||
includeThinking: true,
|
||||
includeToolCalls: true,
|
||||
onFinding: () => loadFindings(),
|
||||
onComplete: () => {
|
||||
loadTask();
|
||||
loadFindings();
|
||||
},
|
||||
onError: (err) => {
|
||||
console.error("Stream error:", err);
|
||||
},
|
||||
});
|
||||
|
||||
// 是否完成
|
||||
const isComplete = task?.status === "completed" || task?.status === "failed" || task?.status === "cancelled";
|
||||
|
|
@ -134,11 +145,16 @@ export default function AgentAuditPage() {
|
|||
try {
|
||||
const taskData = await getAgentTask(taskId);
|
||||
setTask(taskData);
|
||||
} catch (error) {
|
||||
} catch (error: any) {
|
||||
console.error("Failed to load task:", error);
|
||||
toast.error("加载任务失败");
|
||||
const errorMessage = error?.response?.data?.detail || error?.message || "未知错误";
|
||||
toast.error(`加载任务失败: ${errorMessage}`);
|
||||
// 如果是 404,可能是任务不存在
|
||||
if (error?.response?.status === 404) {
|
||||
setTimeout(() => navigate("/tasks"), 2000);
|
||||
}
|
||||
}
|
||||
}, [taskId]);
|
||||
}, [taskId, navigate]);
|
||||
|
||||
// 加载事件
|
||||
const loadEvents = useCallback(async () => {
|
||||
|
|
@ -164,6 +180,32 @@ export default function AgentAuditPage() {
|
|||
}
|
||||
}, [taskId]);
|
||||
|
||||
// 🔥 稳定化回调函数,避免重复创建 connect/disconnect
|
||||
const streamOptions = useMemo(() => ({
|
||||
includeThinking: true,
|
||||
includeToolCalls: true,
|
||||
onFinding: () => loadFindings(),
|
||||
onComplete: () => {
|
||||
loadTask();
|
||||
loadFindings();
|
||||
},
|
||||
onError: (err: string) => {
|
||||
console.error("Stream error:", err);
|
||||
},
|
||||
}), [loadFindings, loadTask]);
|
||||
|
||||
// 使用增强版流式 Hook
|
||||
const {
|
||||
thinking,
|
||||
isThinking,
|
||||
toolCalls,
|
||||
// currentPhase: streamPhase, // 暂未使用
|
||||
// progress: streamProgress, // 暂未使用
|
||||
connect: connectStream,
|
||||
disconnect: disconnectStream,
|
||||
isConnected: isStreamConnected,
|
||||
} = useAgentStream(taskId || null, streamOptions);

// 初始化加载
useEffect(() => {
const init = async () => {

@@ -175,10 +217,11 @@ export default function AgentAuditPage() {
init();
}, [loadTask, loadEvents, loadFindings]);

// 连接增强版流式 API
// 🔥 使用增强版流式 API(优先)
useEffect(() => {
if (!taskId || isComplete || isLoading) return;

// 连接流式 API
connectStream();
setIsStreaming(true);

@@ -186,49 +229,44 @@ export default function AgentAuditPage() {
disconnectStream();
setIsStreaming(false);
};
}, [taskId, isComplete, isLoading, connectStream, disconnectStream]);
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [taskId, isComplete, isLoading]); // connectStream/disconnectStream 是稳定的,不需要作为依赖

// 旧版事件流(作为后备)
// 🔥 后备:如果流式连接失败,使用轮询获取事件(仅作为后备)
useEffect(() => {
if (!taskId || isComplete || isLoading) return;

const startStreaming = async () => {
abortControllerRef.current = new AbortController();

// 如果流式连接已建立,不需要轮询
if (isStreamConnected) {
return;
}

// 每 5 秒轮询一次事件(作为后备机制)
const pollInterval = setInterval(async () => {
try {
const lastSequence = events.length > 0 ? Math.max(...events.map(e => e.sequence)) : 0;
const newEvents = await getAgentEvents(taskId, { after_sequence: lastSequence, limit: 50 });

for await (const event of streamAgentEvents(taskId, lastSequence, abortControllerRef.current.signal)) {
if (newEvents.length > 0) {
setEvents(prev => {
// 避免重复
if (prev.some(e => e.id === event.id)) return prev;
return [...prev, event];
// 合并新事件,避免重复
const existingIds = new Set(prev.map(e => e.id));
const uniqueNew = newEvents.filter(e => !existingIds.has(e.id));
return [...prev, ...uniqueNew];
});

// 如果是发现事件,刷新发现列表
if (event.event_type.startsWith("finding_")) {
loadFindings();
}

// 如果是结束事件,刷新任务状态
if (["task_complete", "task_error", "task_cancel"].includes(event.event_type)) {
loadTask();
// 如果有发现事件,刷新发现列表
if (newEvents.some(e => e.event_type.startsWith("finding_"))) {
loadFindings();
}
}
} catch (error) {
if ((error as Error).name !== "AbortError") {
console.error("Event stream error:", error);
}
console.error("Failed to poll events:", error);
}
};
}, 5000);

startStreaming();

return () => {
abortControllerRef.current?.abort();
};
}, [taskId, isComplete, isLoading, loadTask, loadFindings]);
return () => clearInterval(pollInterval);
}, [taskId, isComplete, isLoading, isStreamConnected, events.length, loadFindings]);
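The polling fallback above only asks the server for events after the highest sequence number already in state, then merges the response while de-duplicating by id. An isolated sketch of that merge logic (field names follow the AgentEvent usage above; the type here is trimmed down):

// Trimmed-down event shape; the real AgentEvent carries more fields.
interface AgentEventLite {
  id: string;
  sequence: number;
  event_type: string;
}

// Highest sequence seen so far; 0 means "fetch from the beginning"
// when passed as after_sequence.
function lastSequence(events: AgentEventLite[]): number {
  return events.length > 0 ? Math.max(...events.map(e => e.sequence)) : 0;
}

// Append only events whose id is not already present, preserving order.
function mergeEvents(prev: AgentEventLite[], incoming: AgentEventLite[]): AgentEventLite[] {
  const existingIds = new Set(prev.map(e => e.id));
  const uniqueNew = incoming.filter(e => !existingIds.has(e.id));
  return uniqueNew.length > 0 ? [...prev, ...uniqueNew] : prev;
}

Note that the effect is keyed on `events.length`, so the 5-second interval is torn down and recreated whenever new events land; that keeps `lastSequence` fresh inside the closure at the cost of resetting the timer.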

// 自动滚动
useEffect(() => {

@@ -244,17 +282,17 @@ export default function AgentAuditPage() {
return () => clearInterval(interval);
}, []);

// 定期轮询任务状态(作为 SSE 的后备机制)
// 定期轮询任务状态(作为 SSE 的后备机制)- 🔥 增加间隔,避免资源耗尽
useEffect(() => {
if (!taskId || isComplete || isLoading) return;

// 每 3 秒轮询一次任务状态
// 🔥 每 10 秒轮询一次(而不是 3 秒),减少资源消耗
const pollInterval = setInterval(async () => {
try {
const taskData = await getAgentTask(taskId);
setTask(taskData);

// 如果任务已完成/失败/取消,刷新其他数据
// 如果任务已完成/失败/取消,刷新其他数据并停止轮询
if (taskData.status === "completed" || taskData.status === "failed" || taskData.status === "cancelled") {
await loadEvents();
await loadFindings();

@@ -262,11 +300,13 @@ export default function AgentAuditPage() {
}
} catch (error) {
console.error("Failed to poll task status:", error);
// 🔥 如果连续失败,停止轮询避免资源耗尽
clearInterval(pollInterval);
}
}, 3000);
}, 10000); // 🔥 改为 10 秒

return () => clearInterval(pollInterval);
}, [taskId, isComplete, isLoading, loadEvents, loadFindings]);
}, [taskId, isComplete, isLoading]); // 🔥 移除函数依赖
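The status poller above is deliberately self-limiting: a 10-second interval that is cleared in the effect cleanup, from inside the callback once the task reaches a terminal state, and again when a request fails. The same shape in isolation (a sketch only; `fetchStatus` is a stand-in for the getAgentTask call):

// Generic self-stopping poller: stops on terminal status, on error, and on cleanup.
function startStatusPolling(
  fetchStatus: () => Promise<string>,   // stand-in for getAgentTask(taskId) -> status
  onTerminal: () => void,
  intervalMs = 10_000,
): () => void {
  const timer = setInterval(async () => {
    try {
      const status = await fetchStatus();
      if (status === "completed" || status === "failed" || status === "cancelled") {
        clearInterval(timer);
        onTerminal();
      }
    } catch {
      // Give up rather than keep hammering a failing endpoint.
      clearInterval(timer);
    }
  }, intervalMs);

  return () => clearInterval(timer);    // mirror of the effect cleanup above
}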

// 取消任务
const handleCancel = async () => {

@@ -371,57 +411,73 @@ export default function AgentAuditPage() {
{/* 左侧:执行日志 */}
<div className="flex-1 p-4 flex flex-col min-w-0">

{/* 思考过程展示区域 */}
{/* 🧠 LLM 思考过程展示区域 - 核心!展示 LLM 的大脑活动 */}
{(isThinking || thinking) && showThinking && (
<div className="mb-4 bg-purple-950/30 rounded-lg border border-purple-800/50 overflow-hidden">
<div className="mb-4 bg-purple-950/40 rounded-lg border-2 border-purple-700/60 overflow-hidden shadow-lg shadow-purple-900/20">
<div
className="flex items-center justify-between px-3 py-2 bg-purple-900/30 border-b border-purple-800/30 cursor-pointer"
className="flex items-center justify-between px-4 py-3 bg-purple-900/50 border-b border-purple-700/50 cursor-pointer"
onClick={() => setShowThinking(!showThinking)}
>
<div className="flex items-center gap-2 text-xs text-purple-400">
<Brain className={`w-4 h-4 ${isThinking ? "animate-pulse" : ""}`} />
<span className="uppercase tracking-wider">AI Thinking</span>
<div className="flex items-center gap-3 text-sm text-purple-300">
<div className="p-1.5 bg-purple-800/50 rounded-lg">
<Brain className={`w-5 h-5 ${isThinking ? "animate-pulse" : ""}`} />
</div>
<div>
<span className="uppercase tracking-wider font-semibold">🧠 LLM Thinking</span>
<span className="text-purple-400 ml-2 text-xs">Agent 的大脑正在工作</span>
</div>
{isThinking && (
<span className="flex items-center gap-1 text-purple-300">
<span className="flex items-center gap-1 text-purple-200 bg-purple-800/50 px-2 py-0.5 rounded-full text-xs">
<Sparkles className="w-3 h-3 animate-spin" />
<span className="text-[10px]">Processing...</span>
<span>思考中...</span>
</span>
)}
</div>
{showThinking ? <ChevronUp className="w-4 h-4 text-purple-400" /> : <ChevronDown className="w-4 h-4 text-purple-400" />}
{showThinking ? <ChevronUp className="w-5 h-5 text-purple-400" /> : <ChevronDown className="w-5 h-5 text-purple-400" />}
</div>

<div className="max-h-40 overflow-y-auto">
<div className="p-3 text-sm text-purple-200/80 font-mono whitespace-pre-wrap">
{thinking || "正在思考..."}
{isThinking && <span className="animate-pulse text-purple-400">▌</span>}
<div className="max-h-52 overflow-y-auto bg-[#1a1025]">
<div className="p-4 text-sm text-purple-100 font-mono whitespace-pre-wrap leading-relaxed">
{thinking || "🤔 正在思考下一步..."}
{isThinking && <span className="animate-pulse text-purple-400 text-lg">▌</span>}
</div>
<div ref={thinkingEndRef} />
</div>
</div>
)}

{/* 工具调用展示区域 */}
{/* 🔧 LLM 工具调用展示区域 - LLM 决定调用的工具 */}
{toolCalls.length > 0 && showToolDetails && (
<div className="mb-4 bg-yellow-950/20 rounded-lg border border-yellow-800/30 overflow-hidden">
<div className="mb-4 bg-yellow-950/30 rounded-lg border-2 border-yellow-700/50 overflow-hidden shadow-lg shadow-yellow-900/10">
<div
className="flex items-center justify-between px-3 py-2 bg-yellow-900/20 border-b border-yellow-800/20 cursor-pointer"
className="flex items-center justify-between px-4 py-3 bg-yellow-900/30 border-b border-yellow-700/40 cursor-pointer"
onClick={() => setShowToolDetails(!showToolDetails)}
>
<div className="flex items-center gap-2 text-xs text-yellow-500">
<Wrench className="w-4 h-4" />
<span className="uppercase tracking-wider">Tool Calls</span>
<Badge variant="outline" className="text-[10px] px-1.5 py-0 bg-yellow-900/30 border-yellow-700 text-yellow-400">
{toolCalls.length}
<div className="flex items-center gap-3 text-sm text-yellow-400">
<div className="p-1.5 bg-yellow-800/50 rounded-lg">
<Wrench className="w-5 h-5" />
</div>
<div>
<span className="uppercase tracking-wider font-semibold">🔧 LLM Tool Calls</span>
<span className="text-yellow-500 ml-2 text-xs">LLM 决定调用的工具</span>
</div>
<Badge variant="outline" className="text-xs px-2 py-0.5 bg-yellow-900/50 border-yellow-600 text-yellow-300">
{toolCalls.length} 次调用
</Badge>
</div>
{showToolDetails ? <ChevronUp className="w-4 h-4 text-yellow-500" /> : <ChevronDown className="w-4 h-4 text-yellow-500" />}
{showToolDetails ? <ChevronUp className="w-5 h-5 text-yellow-500" /> : <ChevronDown className="w-5 h-5 text-yellow-500" />}
</div>

<div className="max-h-48 overflow-y-auto">
<div className="p-2 space-y-2">
<div className="max-h-52 overflow-y-auto bg-[#1a1810]">
<div className="p-3 space-y-2">
{toolCalls.slice(-5).map((tc, idx) => (
<ToolCallCard key={`${tc.name}-${idx}`} toolCall={tc} />
<ToolCallCard
key={`${tc.name}-${idx}`}
toolCall={{
...tc,
output: tc.output as string | Record<string, unknown> | undefined
}}
/>
))}
</div>
</div>

@@ -429,17 +485,20 @@ export default function AgentAuditPage() {
)}

<div className="flex items-center justify-between mb-3">
<div className="flex items-center gap-2 text-xs text-cyan-400">
<Terminal className="w-4 h-4" />
<span className="uppercase tracking-wider">Execution Log</span>
<div className="flex items-center gap-3 text-sm text-cyan-400">
<div className="flex items-center gap-2">
<Terminal className="w-4 h-4" />
<span className="uppercase tracking-wider font-semibold">LLM Execution Log</span>
</div>
<span className="text-xs text-gray-500">LLM 思考 & 工具调用记录</span>
{(isStreaming || isStreamConnected) && (
<span className="flex items-center gap-1 text-green-400">
<span className="flex items-center gap-1.5 text-green-400 bg-green-900/30 px-2 py-0.5 rounded-full text-xs">
<span className="w-2 h-2 bg-green-400 rounded-full animate-pulse" />
LIVE
</span>
)}
</div>
<span className="text-xs text-gray-500">{events.length} events</span>
<span className="text-xs text-gray-500">{events.length} 条记录</span>
</div>

{/* 终端窗口 */}

@@ -649,7 +708,7 @@ function StatusBadge({ status }: { status: string }) {
);
}

// 事件行组件
// 事件行组件 - 增强 LLM 事件展示
function EventLine({ event }: { event: AgentEvent }) {
const icon = eventTypeIcons[event.event_type] || <ChevronRight className="w-3 h-3 text-gray-500" />;
const colorClass = eventTypeColors[event.event_type] || "text-gray-400";

@@ -658,13 +717,28 @@ function EventLine({ event }: { event: AgentEvent }) {
? new Date(event.timestamp).toLocaleTimeString("zh-CN", { hour12: false })
: "";

// LLM 思考事件特殊处理 - 展示多行内容
const isLLMThought = event.event_type === "llm_thought";
const isLLMDecision = event.event_type === "llm_decision";
const isLLMAction = event.event_type === "llm_action";
const isImportantLLMEvent = isLLMThought || isLLMDecision || isLLMAction;

// LLM 事件背景色
const bgClass = isLLMThought
? "bg-purple-950/40 border-l-2 border-purple-600"
: isLLMDecision
? "bg-yellow-950/30 border-l-2 border-yellow-600"
: isLLMAction
? "bg-orange-950/30 border-l-2 border-orange-600"
: "";

return (
<div className={`flex items-start gap-2 py-0.5 group hover:bg-white/5 px-1 rounded ${colorClass}`}>
<div className={`flex items-start gap-2 py-1 group hover:bg-white/5 px-2 rounded ${colorClass} ${bgClass}`}>
<span className="text-gray-600 text-xs w-20 flex-shrink-0 group-hover:text-gray-500">
{timestamp}
</span>
<span className="flex-shrink-0 mt-0.5">{icon}</span>
<span className="flex-1 text-sm break-all">
<span className={`flex-1 text-sm break-all ${isImportantLLMEvent ? "whitespace-pre-wrap" : ""}`}>
{event.message}
{event.tool_duration_ms && (
<span className="text-gray-600 ml-2">({event.tool_duration_ms}ms)</span>

@@ -679,7 +753,7 @@ interface ToolCallProps {
toolCall: {
name: string;
input: Record<string, unknown>;
output?: unknown;
output?: string | Record<string, unknown>;
durationMs?: number;
status: 'running' | 'success' | 'error';
};
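`output` is narrowed from `unknown` to `string | Record<string, unknown>`, which is why the render above casts `tc.output` before handing it to ToolCallCard. If a real conversion is ever preferred over the cast, a small normalizer would be enough (hypothetical helper, not part of this commit):

// Coerce arbitrary tool output into the shape ToolCallProps now expects.
function normalizeToolOutput(output: unknown): string | Record<string, unknown> | undefined {
  if (output === undefined || output === null) return undefined;
  if (typeof output === "string") return output;
  if (typeof output === "object" && !Array.isArray(output)) {
    return output as Record<string, unknown>;
  }
  // Numbers, booleans, arrays, etc. render fine as JSON text.
  return JSON.stringify(output);
}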

@@ -106,6 +106,9 @@ export class AgentStreamHandler {
private reconnectDelay = 1000;
private isConnected = false;
private thinkingBuffer: string[] = [];
private reader: ReadableStreamDefaultReader<Uint8Array> | null = null; // 🔥 保存 reader 引用
private abortController: AbortController | null = null; // 🔥 用于取消请求
private isDisconnecting = false; // 🔥 标记是否正在断开

constructor(taskId: string, options: StreamOptions = {}) {
this.taskId = taskId;

@@ -121,6 +124,11 @@ export class AgentStreamHandler {
* 开始监听事件流
*/
connect(): void {
// 🔥 如果已经连接,不重复连接
if (this.isConnected || this.isDisconnecting) {
return;
}

const token = localStorage.getItem('access_token');
if (!token) {
this.options.onError?.('未登录');

@@ -142,14 +150,23 @@ export class AgentStreamHandler {
* 使用 fetch 连接(支持自定义 headers)
*/
private async connectWithFetch(token: string, params: URLSearchParams): Promise<void> {
// 🔥 如果正在断开,不连接
if (this.isDisconnecting) {
return;
}

const url = `/api/v1/agent-tasks/${this.taskId}/stream?${params}`;

// 🔥 创建 AbortController 用于取消请求
this.abortController = new AbortController();

try {
const response = await fetch(url, {
headers: {
'Authorization': `Bearer ${token}`,
'Accept': 'text/event-stream',
},
signal: this.abortController.signal, // 🔥 支持取消
});

if (!response.ok) {

@@ -159,8 +176,8 @@ export class AgentStreamHandler {
this.isConnected = true;
this.reconnectAttempts = 0;

const reader = response.body?.getReader();
if (!reader) {
this.reader = response.body?.getReader() || null;
if (!this.reader) {
throw new Error('无法获取响应流');
}

@@ -168,7 +185,12 @@ export class AgentStreamHandler {
let buffer = '';

while (true) {
const { done, value } = await reader.read();
// 🔥 检查是否正在断开
if (this.isDisconnecting) {
break;
}

const { done, value } = await this.reader.read();

if (done) {
break;

@@ -184,17 +206,42 @@ export class AgentStreamHandler {
this.handleEvent(event);
}
}
} catch (error) {

// 🔥 正常结束,清理 reader
if (this.reader) {
this.reader.releaseLock();
this.reader = null;
}
} catch (error: any) {
// 🔥 如果是取消错误,不处理
if (error.name === 'AbortError') {
return;
}

this.isConnected = false;
console.error('Stream connection error:', error);

// 尝试重连
if (this.reconnectAttempts < this.maxReconnectAttempts) {
// 🔥 只有在未断开时才尝试重连
if (!this.isDisconnecting && this.reconnectAttempts < this.maxReconnectAttempts) {
this.reconnectAttempts++;
setTimeout(() => this.connect(), this.reconnectDelay * this.reconnectAttempts);
setTimeout(() => {
if (!this.isDisconnecting) {
this.connect();
}
}, this.reconnectDelay * this.reconnectAttempts);
} else {
this.options.onError?.(`连接失败: ${error}`);
}
} finally {
// 🔥 清理 reader
if (this.reader) {
try {
this.reader.releaseLock();
} catch {
// 忽略释放错误
}
this.reader = null;
}
}
}
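The retry path above is linear backoff (reconnectDelay × attempt number), capped by maxReconnectAttempts and suppressed entirely once disconnect() has set isDisconnecting. The policy in isolation looks roughly like this (a sketch; the cap of 3 attempts is an assumption, since the actual maxReconnectAttempts value is outside this hunk):

// Linear-backoff reconnect scheduling, guarded by a "disconnecting" check
// both before scheduling and again when the timer fires.
class ReconnectPolicy {
  private attempts = 0;

  constructor(
    private readonly maxAttempts = 3,        // assumed cap; see note above
    private readonly baseDelayMs = 1000,     // mirrors reconnectDelay above
  ) {}

  schedule(reconnect: () => void, isDisconnecting: () => boolean): boolean {
    if (isDisconnecting() || this.attempts >= this.maxAttempts) return false;
    this.attempts++;
    setTimeout(() => {
      if (!isDisconnecting()) reconnect();   // re-check right before reconnecting
    }, this.baseDelayMs * this.attempts);
    return true;
  }

  reset(): void {
    this.attempts = 0;                       // called once a connection succeeds
  }
}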

@@ -357,11 +404,35 @@ export class AgentStreamHandler {
* 断开连接
*/
disconnect(): void {
// 🔥 标记正在断开,防止重连
this.isDisconnecting = true;
this.isConnected = false;

// 🔥 取消 fetch 请求
if (this.abortController) {
this.abortController.abort();
this.abortController = null;
}

// 🔥 清理 reader
if (this.reader) {
try {
this.reader.cancel();
this.reader.releaseLock();
} catch {
// 忽略清理错误
}
this.reader = null;
}

// 清理 EventSource(如果使用)
if (this.eventSource) {
this.eventSource.close();
this.eventSource = null;
}

// 重置重连计数
this.reconnectAttempts = 0;
}
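Taken together, connect()/disconnect() are meant to be driven by component lifecycle: connect on mount, disconnect on cleanup, with isDisconnecting keeping a pending reconnect timer from resurrecting the stream after teardown. A minimal usage sketch (the class surface is declared here only to keep the sketch self-contained; import the real AgentStreamHandler from wherever it actually lives):

import { useEffect } from "react";

// Declaration only, for illustration; the real class is the one shown above.
declare class AgentStreamHandler {
  constructor(taskId: string, options?: { includeThinking?: boolean; includeToolCalls?: boolean });
  connect(): void;
  disconnect(): void;
}

function useAgentStreamLifecycle(taskId: string | null): void {
  useEffect(() => {
    if (!taskId) return;
    const handler = new AgentStreamHandler(taskId, { includeThinking: true });
    handler.connect();
    // disconnect() flips isDisconnecting first, so a reconnect timer that is
    // already queued will no-op instead of reopening the stream.
    return () => handler.disconnect();
  }, [taskId]);
}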

/**