feat(agent): implement streaming support for agent events and enhance UI components
- Introduce streaming capabilities for agent events, allowing real-time updates during audits.
- Add new hooks for managing agent stream events in React components.
- Enhance the AgentAudit page to display LLM thinking processes and tool call details in real time.
- Update API endpoints to support streaming event data and improve error handling.
- Refactor UI components for better organization and user experience during audits.
This commit is contained in:
parent
a43ebf1793
commit
58c918f557
|
|
@ -59,3 +59,5 @@ CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -103,3 +103,5 @@ datefmt = %H:%M:%S
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -90,3 +90,5 @@ else:
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -25,3 +25,5 @@ def downgrade() -> None:
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -29,6 +29,7 @@ from app.models.agent_task import (
|
||||||
from app.models.project import Project
|
from app.models.project import Project
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.services.agent import AgentRunner, EventManager, run_agent_task
|
from app.services.agent import AgentRunner, EventManager, run_agent_task
|
||||||
|
from app.services.agent.streaming import StreamHandler, StreamEvent, StreamEventType
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
@ -432,6 +433,171 @@ async def stream_agent_events(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{task_id}/stream")
async def stream_agent_with_thinking(
    task_id: str,
    include_thinking: bool = Query(True, description="是否包含 LLM 思考过程"),
    include_tool_calls: bool = Query(True, description="是否包含工具调用详情"),
    after_sequence: int = Query(0, ge=0, description="从哪个序号之后开始"),
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(deps.get_current_user),
):
    """
    Stream a task's agent events to the client as Server-Sent Events (SSE).

    Stored ``AgentEvent`` rows whose sequence is greater than
    ``after_sequence`` are polled in batches and forwarded in sequence order,
    each under its own stored event type.

    Frames generated by this endpoint itself:
    - task_end: the task reached COMPLETED / FAILED / CANCELLED and every
      stored event has been forwarded; the stream then closes
    - heartbeat: sent roughly every 15 seconds of polling
    - timeout: no new event for about 10 minutes; the stream then closes
    - error: an exception occurred while polling; the stream then closes

    Args:
        task_id: Identifier of the agent task to follow.
        include_thinking: When false, ``thinking_start`` / ``thinking_token`` /
            ``thinking_end`` events are dropped.
        include_tool_calls: When false, ``tool_call_*`` events are dropped and
            the ``tool`` detail object is never attached.
        after_sequence: Only events with a larger sequence number are sent.
        db: Request-scoped database session (used for the access checks only).
        current_user: The authenticated user.

    Returns:
        A ``StreamingResponse`` with media type ``text/event-stream``.

    Raises:
        HTTPException: 404 when the task does not exist, 403 when the task's
            project is missing or not owned by ``current_user``.
    """
    task = await db.get(AgentTask, task_id)
    if not task:
        raise HTTPException(status_code=404, detail="任务不存在")

    project = await db.get(Project, task.project_id)
    if not project or project.owner_id != current_user.id:
        raise HTTPException(status_code=403, detail="无权访问此任务")

    def sse(event_name, payload, *, ensure_ascii=True):
        """Format one SSE frame: an ``event:`` line plus a JSON ``data:`` line."""
        body = json.dumps(payload, ensure_ascii=ensure_ascii)
        return f"event: {event_name}\ndata: {body}\n\n"

    async def enhanced_event_generator():
        """Poll the database and yield SSE frames until the task ends or idles out."""
        last_sequence = after_sequence
        poll_interval = 0.3  # short interval so token-level events feel live
        heartbeat_interval = 15  # seconds of polling between heartbeats
        max_idle = 600  # close after 10 minutes without any new event
        batch_size = 100
        idle_time = 0
        last_heartbeat = 0
        terminal_statuses = [
            AgentTaskStatus.COMPLETED,
            AgentTaskStatus.FAILED,
            AgentTaskStatus.CANCELLED,
        ]

        # Event types the client asked not to receive.
        skip_types = set()
        if not include_thinking:
            skip_types.update(["thinking_start", "thinking_token", "thinking_end"])
        if not include_tool_calls:
            skip_types.update(["tool_call_start", "tool_call_input", "tool_call_output", "tool_call_end"])

        while True:
            try:
                frames = []
                # A fresh session per poll; payloads are built while it is open
                # and yielded only after it is closed, so a slow client never
                # keeps a database session checked out.
                async with async_session_factory() as session:
                    result = await session.execute(
                        select(AgentEvent)
                        .where(AgentEvent.task_id == task_id)
                        .where(AgentEvent.sequence > last_sequence)
                        .order_by(AgentEvent.sequence)
                        .limit(batch_size)
                    )
                    events = result.scalars().all()

                    current_task = await session.get(AgentTask, task_id)
                    task_status = current_task.status if current_task else None

                    for event in events:
                        last_sequence = event.sequence

                        # event_type may be an enum or a plain string.
                        event_type = event.event_type.value if hasattr(event.event_type, 'value') else str(event.event_type)
                        if event_type in skip_types:
                            continue

                        data = {
                            "id": event.id,
                            "type": event_type,
                            "phase": event.phase.value if event.phase and hasattr(event.phase, 'value') else event.phase,
                            "message": event.message,
                            "sequence": event.sequence,
                            "timestamp": event.created_at.isoformat() if event.created_at else None,
                        }

                        if include_tool_calls and event.tool_name:
                            data["tool"] = {
                                "name": event.tool_name,
                                "input": event.tool_input,
                                "output": event.tool_output,
                                "duration_ms": event.tool_duration_ms,
                            }

                        if event.event_metadata:
                            data["metadata"] = event.event_metadata

                        if event.tokens_used:
                            data["tokens_used"] = event.tokens_used

                        frames.append(sse(event_type, data, ensure_ascii=False))

                for frame in frames:
                    yield frame

                if events:
                    idle_time = 0
                else:
                    idle_time += poll_interval

                if task_status in terminal_statuses:
                    if events:
                        # The task is finished but this poll still returned
                        # events: more may be pending (batch limit, or rows
                        # committed between the two queries). Poll again at
                        # once and only close after an empty poll, so no
                        # event is dropped.
                        continue
                    end_data = {
                        "type": "task_end",
                        "status": task_status.value,
                        "message": f"任务{'完成' if task_status == AgentTaskStatus.COMPLETED else '结束'}",
                    }
                    yield sse("task_end", end_data, ensure_ascii=False)
                    break

                # The heartbeat is driven by accumulated poll intervals.
                last_heartbeat += poll_interval
                if last_heartbeat >= heartbeat_interval:
                    last_heartbeat = 0
                    heartbeat_data = {
                        "type": "heartbeat",
                        "timestamp": datetime.now(timezone.utc).isoformat(),
                        "last_sequence": last_sequence,
                    }
                    yield sse("heartbeat", heartbeat_data)

                if idle_time >= max_idle:
                    timeout_data = {"type": "timeout", "message": "连接超时"}
                    yield sse("timeout", timeout_data)
                    break

                await asyncio.sleep(poll_interval)

            except Exception as e:
                # Best effort: tell the client and close the stream. The
                # traceback goes to the server log.
                logger.exception("Stream error: %s", e)
                error_data = {"type": "error", "message": str(e)}
                yield sse("error", error_data)
                break

    return StreamingResponse(
        enhanced_event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
            "Content-Type": "text/event-stream; charset=utf-8",
        }
    )
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{task_id}/events/list", response_model=List[AgentEventResponse])
|
@router.get("/{task_id}/events/list", response_model=List[AgentEventResponse])
|
||||||
async def list_agent_events(
|
async def list_agent_events(
|
||||||
task_id: str,
|
task_id: str,
|
||||||
|
|
|
||||||
|
|
@ -210,3 +210,5 @@ async def remove_project_member(
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -225,3 +225,5 @@ async def toggle_user_status(
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -29,3 +29,5 @@ def get_password_hash(password: str) -> str:
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -12,3 +12,5 @@ class Base:
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -28,3 +28,5 @@ async def async_session_factory():
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -24,3 +24,5 @@ class InstantAnalysis(Base):
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -25,3 +25,5 @@ class User(Base):
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -30,3 +30,5 @@ class UserConfig(Base):
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -10,3 +10,5 @@ class TokenPayload(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -41,3 +41,5 @@ class UserListResponse(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -147,11 +147,11 @@ class AnalysisAgent(BaseAgent):
|
||||||
deep_findings = await self._analyze_entry_points(entry_points)
|
deep_findings = await self._analyze_entry_points(entry_points)
|
||||||
all_findings.extend(deep_findings)
|
all_findings.extend(deep_findings)
|
||||||
|
|
||||||
# 分析高风险区域
|
# 分析高风险区域(现在会调用 LLM)
|
||||||
risk_findings = await self._analyze_high_risk_areas(high_risk_areas)
|
risk_findings = await self._analyze_high_risk_areas(high_risk_areas)
|
||||||
all_findings.extend(risk_findings)
|
all_findings.extend(risk_findings)
|
||||||
|
|
||||||
# 语义搜索常见漏洞
|
# 语义搜索常见漏洞(现在会调用 LLM)
|
||||||
vuln_types = config.get("target_vulnerabilities", [
|
vuln_types = config.get("target_vulnerabilities", [
|
||||||
"sql_injection", "xss", "command_injection",
|
"sql_injection", "xss", "command_injection",
|
||||||
"path_traversal", "ssrf", "hardcoded_secret",
|
"path_traversal", "ssrf", "hardcoded_secret",
|
||||||
|
|
@ -164,6 +164,12 @@ class AnalysisAgent(BaseAgent):
|
||||||
await self.emit_thinking(f"搜索 {vuln_type} 相关代码...")
|
await self.emit_thinking(f"搜索 {vuln_type} 相关代码...")
|
||||||
vuln_findings = await self._search_vulnerability_pattern(vuln_type)
|
vuln_findings = await self._search_vulnerability_pattern(vuln_type)
|
||||||
all_findings.extend(vuln_findings)
|
all_findings.extend(vuln_findings)
|
||||||
|
|
||||||
|
# 🔥 3. 如果还没有发现,使用 LLM 进行全面扫描
|
||||||
|
if len(all_findings) < 3:
|
||||||
|
await self.emit_thinking("执行 LLM 全面代码扫描...")
|
||||||
|
llm_findings = await self._llm_comprehensive_scan(tech_stack)
|
||||||
|
all_findings.extend(llm_findings)
|
||||||
|
|
||||||
# 去重
|
# 去重
|
||||||
all_findings = self._deduplicate_findings(all_findings)
|
all_findings = self._deduplicate_findings(all_findings)
|
||||||
|
|
@ -292,12 +298,12 @@ class AnalysisAgent(BaseAgent):
|
||||||
return findings
|
return findings
|
||||||
|
|
||||||
async def _analyze_high_risk_areas(self, high_risk_areas: List[str]) -> List[Dict]:
|
async def _analyze_high_risk_areas(self, high_risk_areas: List[str]) -> List[Dict]:
|
||||||
"""分析高风险区域"""
|
"""分析高风险区域 - 使用 LLM 深度分析"""
|
||||||
findings = []
|
findings = []
|
||||||
|
|
||||||
pattern_tool = self.tools.get("pattern_match")
|
|
||||||
read_tool = self.tools.get("read_file")
|
read_tool = self.tools.get("read_file")
|
||||||
search_tool = self.tools.get("search_code")
|
search_tool = self.tools.get("search_code")
|
||||||
|
code_analysis_tool = self.tools.get("code_analysis")
|
||||||
|
|
||||||
if not search_tool:
|
if not search_tool:
|
||||||
return findings
|
return findings
|
||||||
|
|
@ -305,36 +311,92 @@ class AnalysisAgent(BaseAgent):
|
||||||
# 在高风险区域搜索危险模式
|
# 在高风险区域搜索危险模式
|
||||||
dangerous_patterns = [
|
dangerous_patterns = [
|
||||||
("execute(", "sql_injection"),
|
("execute(", "sql_injection"),
|
||||||
|
("query(", "sql_injection"),
|
||||||
("eval(", "code_injection"),
|
("eval(", "code_injection"),
|
||||||
("system(", "command_injection"),
|
("system(", "command_injection"),
|
||||||
("exec(", "command_injection"),
|
("exec(", "command_injection"),
|
||||||
|
("subprocess", "command_injection"),
|
||||||
("innerHTML", "xss"),
|
("innerHTML", "xss"),
|
||||||
("document.write", "xss"),
|
("document.write", "xss"),
|
||||||
|
("open(", "path_traversal"),
|
||||||
|
("requests.get", "ssrf"),
|
||||||
]
|
]
|
||||||
|
|
||||||
for pattern, vuln_type in dangerous_patterns[:5]:
|
analyzed_files = set()
|
||||||
|
|
||||||
|
for pattern, vuln_type in dangerous_patterns[:8]:
|
||||||
if self.is_cancelled:
|
if self.is_cancelled:
|
||||||
break
|
break
|
||||||
|
|
||||||
result = await search_tool.execute(keyword=pattern, max_results=10)
|
result = await search_tool.execute(keyword=pattern, max_results=10)
|
||||||
|
|
||||||
if result.success and result.metadata.get("matches", 0) > 0:
|
if result.success and result.metadata.get("matches", 0) > 0:
|
||||||
for match in result.metadata.get("results", [])[:3]:
|
for match in result.metadata.get("results", [])[:5]:
|
||||||
file_path = match.get("file", "")
|
file_path = match.get("file", "")
|
||||||
|
line = match.get("line", 0)
|
||||||
|
|
||||||
# 检查是否在高风险区域
|
# 避免重复分析同一个文件的同一区域
|
||||||
in_high_risk = any(
|
file_key = f"{file_path}:{line // 50}"
|
||||||
area in file_path for area in high_risk_areas
|
if file_key in analyzed_files:
|
||||||
)
|
continue
|
||||||
|
analyzed_files.add(file_key)
|
||||||
|
|
||||||
if in_high_risk or True: # 暂时包含所有
|
# 🔥 使用 LLM 深度分析找到的代码
|
||||||
|
if read_tool and code_analysis_tool:
|
||||||
|
await self.emit_thinking(f"LLM 分析 {file_path}:{line} 的 {vuln_type} 风险...")
|
||||||
|
|
||||||
|
# 读取代码上下文
|
||||||
|
read_result = await read_tool.execute(
|
||||||
|
file_path=file_path,
|
||||||
|
start_line=max(1, line - 15),
|
||||||
|
end_line=line + 25,
|
||||||
|
)
|
||||||
|
|
||||||
|
if read_result.success:
|
||||||
|
# 调用 LLM 分析
|
||||||
|
analysis_result = await code_analysis_tool.execute(
|
||||||
|
code=read_result.data,
|
||||||
|
file_path=file_path,
|
||||||
|
focus=vuln_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
if analysis_result.success and analysis_result.metadata.get("issues"):
|
||||||
|
for issue in analysis_result.metadata["issues"]:
|
||||||
|
findings.append({
|
||||||
|
"vulnerability_type": issue.get("type", vuln_type),
|
||||||
|
"severity": issue.get("severity", "medium"),
|
||||||
|
"title": issue.get("title", f"LLM 发现: {vuln_type}"),
|
||||||
|
"description": issue.get("description", ""),
|
||||||
|
"file_path": file_path,
|
||||||
|
"line_start": issue.get("line", line),
|
||||||
|
"code_snippet": issue.get("code_snippet", match.get("match", "")),
|
||||||
|
"suggestion": issue.get("suggestion", ""),
|
||||||
|
"ai_explanation": issue.get("ai_explanation", ""),
|
||||||
|
"source": "llm_analysis",
|
||||||
|
"needs_verification": True,
|
||||||
|
})
|
||||||
|
elif analysis_result.success:
|
||||||
|
# LLM 分析了但没发现问题,仍记录原始发现
|
||||||
|
findings.append({
|
||||||
|
"vulnerability_type": vuln_type,
|
||||||
|
"severity": "low",
|
||||||
|
"title": f"疑似 {vuln_type}: {pattern}",
|
||||||
|
"description": f"在 {file_path} 中发现危险模式,但 LLM 分析未确认",
|
||||||
|
"file_path": file_path,
|
||||||
|
"line_start": line,
|
||||||
|
"code_snippet": match.get("match", ""),
|
||||||
|
"source": "pattern_search",
|
||||||
|
"needs_verification": True,
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
# 没有 LLM 工具,使用基础模式匹配
|
||||||
findings.append({
|
findings.append({
|
||||||
"vulnerability_type": vuln_type,
|
"vulnerability_type": vuln_type,
|
||||||
"severity": "high" if in_high_risk else "medium",
|
"severity": "medium",
|
||||||
"title": f"疑似 {vuln_type}: {pattern}",
|
"title": f"疑似 {vuln_type}: {pattern}",
|
||||||
"description": f"在 {file_path} 中发现危险模式 {pattern}",
|
"description": f"在 {file_path} 中发现危险模式 {pattern}",
|
||||||
"file_path": file_path,
|
"file_path": file_path,
|
||||||
"line_start": match.get("line", 0),
|
"line_start": line,
|
||||||
"code_snippet": match.get("match", ""),
|
"code_snippet": match.get("match", ""),
|
||||||
"source": "pattern_search",
|
"source": "pattern_search",
|
||||||
"needs_verification": True,
|
"needs_verification": True,
|
||||||
|
|
@ -343,10 +405,13 @@ class AnalysisAgent(BaseAgent):
|
||||||
return findings
|
return findings
|
||||||
|
|
||||||
async def _search_vulnerability_pattern(self, vuln_type: str) -> List[Dict]:
|
async def _search_vulnerability_pattern(self, vuln_type: str) -> List[Dict]:
|
||||||
"""搜索特定漏洞模式"""
|
"""搜索特定漏洞模式 - 使用 RAG + LLM"""
|
||||||
findings = []
|
findings = []
|
||||||
|
|
||||||
security_tool = self.tools.get("security_search")
|
security_tool = self.tools.get("security_search")
|
||||||
|
code_analysis_tool = self.tools.get("code_analysis")
|
||||||
|
read_tool = self.tools.get("read_file")
|
||||||
|
|
||||||
if not security_tool:
|
if not security_tool:
|
||||||
return findings
|
return findings
|
||||||
|
|
||||||
|
|
@ -357,20 +422,176 @@ class AnalysisAgent(BaseAgent):
|
||||||
|
|
||||||
if result.success and result.metadata.get("results_count", 0) > 0:
|
if result.success and result.metadata.get("results_count", 0) > 0:
|
||||||
for item in result.metadata.get("results", [])[:5]:
|
for item in result.metadata.get("results", [])[:5]:
|
||||||
findings.append({
|
file_path = item.get("file_path", "")
|
||||||
"vulnerability_type": vuln_type,
|
line_start = item.get("line_start", 0)
|
||||||
"severity": "medium",
|
content = item.get("content", "")[:2000]
|
||||||
"title": f"疑似 {vuln_type}",
|
|
||||||
"description": f"通过语义搜索发现可能存在 {vuln_type}",
|
# 🔥 使用 LLM 验证 RAG 搜索结果
|
||||||
"file_path": item.get("file_path", ""),
|
if code_analysis_tool and content:
|
||||||
"line_start": item.get("line_start", 0),
|
await self.emit_thinking(f"LLM 验证 RAG 发现的 {vuln_type}...")
|
||||||
"code_snippet": item.get("content", "")[:500],
|
|
||||||
"source": "rag_search",
|
analysis_result = await code_analysis_tool.execute(
|
||||||
"needs_verification": True,
|
code=content,
|
||||||
})
|
file_path=file_path,
|
||||||
|
focus=vuln_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
if analysis_result.success and analysis_result.metadata.get("issues"):
|
||||||
|
for issue in analysis_result.metadata["issues"]:
|
||||||
|
findings.append({
|
||||||
|
"vulnerability_type": issue.get("type", vuln_type),
|
||||||
|
"severity": issue.get("severity", "medium"),
|
||||||
|
"title": issue.get("title", f"LLM 确认: {vuln_type}"),
|
||||||
|
"description": issue.get("description", ""),
|
||||||
|
"file_path": file_path,
|
||||||
|
"line_start": issue.get("line", line_start),
|
||||||
|
"code_snippet": issue.get("code_snippet", content[:500]),
|
||||||
|
"suggestion": issue.get("suggestion", ""),
|
||||||
|
"ai_explanation": issue.get("ai_explanation", ""),
|
||||||
|
"source": "rag_llm_analysis",
|
||||||
|
"needs_verification": True,
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
# RAG 找到但 LLM 未确认
|
||||||
|
findings.append({
|
||||||
|
"vulnerability_type": vuln_type,
|
||||||
|
"severity": "low",
|
||||||
|
"title": f"疑似 {vuln_type} (待确认)",
|
||||||
|
"description": f"RAG 搜索发现可能存在 {vuln_type},但 LLM 未确认",
|
||||||
|
"file_path": file_path,
|
||||||
|
"line_start": line_start,
|
||||||
|
"code_snippet": content[:500],
|
||||||
|
"source": "rag_search",
|
||||||
|
"needs_verification": True,
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
findings.append({
|
||||||
|
"vulnerability_type": vuln_type,
|
||||||
|
"severity": "medium",
|
||||||
|
"title": f"疑似 {vuln_type}",
|
||||||
|
"description": f"通过语义搜索发现可能存在 {vuln_type}",
|
||||||
|
"file_path": file_path,
|
||||||
|
"line_start": line_start,
|
||||||
|
"code_snippet": content[:500],
|
||||||
|
"source": "rag_search",
|
||||||
|
"needs_verification": True,
|
||||||
|
})
|
||||||
|
|
||||||
return findings
|
return findings
|
||||||
|
|
||||||
|
async def _llm_comprehensive_scan(self, tech_stack: Dict) -> List[Dict]:
    """
    Fallback scan: run the ``code_analysis`` tool directly over key source files.

    Files are discovered with the ``list_files`` tool in a fixed list of
    directories, read with ``read_file`` (first 200 lines) and passed to
    ``code_analysis``. At most 10 files are analysed. Each path is analysed
    only once, even though the trailing ``"."`` entry is listed recursively
    and so overlaps the directories before it.

    Args:
        tech_stack: Project tech-stack info; only its ``"languages"`` entry is
            read, to choose which file globs to list.

    Returns:
        One finding dict per issue reported by ``code_analysis``. The list is
        empty when any of the three tools is not registered.
    """
    findings = []

    list_tool = self.tools.get("list_files")
    read_tool = self.tools.get("read_file")
    code_analysis_tool = self.tools.get("code_analysis")

    if not all([list_tool, read_tool, code_analysis_tool]):
        return findings

    await self.emit_thinking("LLM 全面扫描关键代码文件...")

    # Choose file globs from the detected languages (None is treated as empty).
    languages = tech_stack.get("languages") or []
    file_patterns = []

    if "Python" in languages:
        file_patterns.extend(["*.py"])
    if "JavaScript" in languages or "TypeScript" in languages:
        file_patterns.extend(["*.js", "*.ts"])
    if "Go" in languages:
        file_patterns.extend(["*.go"])
    if "Java" in languages:
        file_patterns.extend(["*.java"])
    if "PHP" in languages:
        file_patterns.extend(["*.php"])

    if not file_patterns:
        file_patterns = ["*.py", "*.js", "*.ts", "*.go", "*.java", "*.php"]

    key_dirs = ["src", "app", "api", "routes", "controllers", "handlers", "lib", "utils", "."]
    skip_markers = ['test', 'spec', 'mock', '__pycache__', 'node_modules']
    file_marker = '📄 '  # prefix of file entries in the list_files output
    scanned_files = 0
    max_files_to_scan = 10
    seen_paths = set()  # paths already considered, across all directories

    for key_dir in key_dirs:
        if scanned_files >= max_files_to_scan or self.is_cancelled:
            break

        for pattern in file_patterns[:3]:
            if scanned_files >= max_files_to_scan or self.is_cancelled:
                break

            list_result = await list_tool.execute(
                directory=key_dir,
                pattern=pattern,
                recursive=True,
                max_files=20,
            )

            if not list_result.success:
                continue

            # Extract file paths from the textual listing.
            output = str(list_result.data or "")
            file_paths = []
            for line in output.split('\n'):
                line = line.strip()
                if line.startswith(file_marker):
                    file_paths.append(line[len(file_marker):].strip())

            for file_path in file_paths[:5]:
                if scanned_files >= max_files_to_scan or self.is_cancelled:
                    break

                # The recursive "." listing repeats files from earlier dirs.
                if file_path in seen_paths:
                    continue
                seen_paths.add(file_path)

                # Skip tests, mocks and vendored/cache directories.
                lowered = file_path.lower()
                if any(skip in lowered for skip in skip_markers):
                    continue

                await self.emit_thinking(f"LLM 分析文件: {file_path}")

                read_result = await read_tool.execute(
                    file_path=file_path,
                    max_lines=200,
                )

                if not read_result.success:
                    continue

                scanned_files += 1

                analysis_result = await code_analysis_tool.execute(
                    code=read_result.data,
                    file_path=file_path,
                )

                if analysis_result.success and analysis_result.metadata.get("issues"):
                    for issue in analysis_result.metadata["issues"]:
                        findings.append({
                            "vulnerability_type": issue.get("type", "other"),
                            "severity": issue.get("severity", "medium"),
                            "title": issue.get("title", "LLM 发现的安全问题"),
                            "description": issue.get("description", ""),
                            "file_path": file_path,
                            "line_start": issue.get("line", 0),
                            "code_snippet": issue.get("code_snippet", ""),
                            "suggestion": issue.get("suggestion", ""),
                            "ai_explanation": issue.get("ai_explanation", ""),
                            "source": "llm_comprehensive_scan",
                            "needs_verification": True,
                        })

    await self.emit_thinking(f"LLM 全面扫描完成,分析了 {scanned_files} 个文件")
    return findings
|
||||||
|
|
||||||
def _deduplicate_findings(self, findings: List[Dict]) -> List[Dict]:
|
def _deduplicate_findings(self, findings: List[Dict]) -> List[Dict]:
|
||||||
"""去重发现"""
|
"""去重发现"""
|
||||||
seen = set()
|
seen = set()
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,375 @@
|
||||||
|
"""
|
||||||
|
真正的 ReAct Agent 实现
|
||||||
|
LLM 是大脑,全程参与决策!
|
||||||
|
|
||||||
|
ReAct 循环:
|
||||||
|
1. Thought: LLM 思考当前状态和下一步
|
||||||
|
2. Action: LLM 决定调用哪个工具
|
||||||
|
3. Observation: 执行工具,获取结果
|
||||||
|
4. 重复直到 LLM 决定完成
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from typing import List, Dict, Any, Optional, Tuple
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from .base import BaseAgent, AgentConfig, AgentResult, AgentType, AgentPattern
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
REACT_SYSTEM_PROMPT = """你是 DeepAudit 安全审计 Agent,一个专业的代码安全分析专家。
|
||||||
|
|
||||||
|
## 你的任务
|
||||||
|
对目标项目进行全面的安全审计,发现潜在的安全漏洞。
|
||||||
|
|
||||||
|
## 你的工具
|
||||||
|
{tools_description}
|
||||||
|
|
||||||
|
## 工作方式
|
||||||
|
你需要通过 **思考-行动-观察** 循环来完成任务:
|
||||||
|
|
||||||
|
1. **Thought**: 分析当前情况,思考下一步应该做什么
|
||||||
|
2. **Action**: 选择一个工具并执行
|
||||||
|
3. **Observation**: 观察工具返回的结果
|
||||||
|
4. 重复上述过程直到你认为审计完成
|
||||||
|
|
||||||
|
## 输出格式
|
||||||
|
每一步必须严格按照以下格式输出:
|
||||||
|
|
||||||
|
```
|
||||||
|
Thought: [你的思考过程,分析当前状态,决定下一步]
|
||||||
|
Action: [工具名称]
|
||||||
|
Action Input: [工具参数,JSON 格式]
|
||||||
|
```
|
||||||
|
|
||||||
|
当你完成分析后,输出:
|
||||||
|
```
|
||||||
|
Thought: [总结分析结果]
|
||||||
|
Final Answer: [JSON 格式的最终发现]
|
||||||
|
```
|
||||||
|
|
||||||
|
## Final Answer 格式
|
||||||
|
```json
|
||||||
|
{{
|
||||||
|
"findings": [
|
||||||
|
{{
|
||||||
|
"vulnerability_type": "sql_injection",
|
||||||
|
"severity": "high",
|
||||||
|
"title": "SQL 注入漏洞",
|
||||||
|
"description": "详细描述",
|
||||||
|
"file_path": "path/to/file.py",
|
||||||
|
"line_start": 42,
|
||||||
|
"code_snippet": "危险代码片段",
|
||||||
|
"suggestion": "修复建议"
|
||||||
|
}}
|
||||||
|
],
|
||||||
|
"summary": "审计总结"
|
||||||
|
}}
|
||||||
|
```
|
||||||
|
|
||||||
|
## 审计策略建议
|
||||||
|
1. 先用 list_files 了解项目结构
|
||||||
|
2. 识别关键文件(路由、控制器、数据库操作)
|
||||||
|
3. 使用 search_code 搜索危险模式(eval, exec, query, innerHTML 等)
|
||||||
|
4. 读取可疑文件进行深度分析
|
||||||
|
5. 如果有 semgrep,用它进行全面扫描
|
||||||
|
|
||||||
|
## 重点关注的漏洞类型
|
||||||
|
- SQL 注入 (query, execute, raw SQL)
|
||||||
|
- XSS (innerHTML, document.write, v-html)
|
||||||
|
- 命令注入 (exec, system, subprocess, child_process)
|
||||||
|
- 路径遍历 (open, readFile, path concatenation)
|
||||||
|
- SSRF (requests, fetch, http client)
|
||||||
|
- 硬编码密钥 (password, secret, api_key, token)
|
||||||
|
- 不安全的反序列化 (pickle, yaml.load, eval)
|
||||||
|
|
||||||
|
现在开始审计!"""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class AgentStep:
    """One step of the ReAct loop, as parsed from a single LLM response."""
    # Text captured after "Thought:" in the response ("" when none is found).
    thought: str
    # Tool name captured after "Action:"; None when no action was found.
    action: Optional[str] = None
    # Parsed "Action Input:" JSON; {"raw_input": <text>} when it is not valid JSON.
    action_input: Optional[Dict] = None
    # Presumably the tool output recorded for this step; it is not set by the
    # parser. TODO confirm where it is filled in.
    observation: Optional[str] = None
    # True when the response contained "Final Answer:".
    is_final: bool = False
    # Parsed "Final Answer:" JSON; {"raw_answer": <text>} when it is not valid JSON.
    final_answer: Optional[Dict] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ReActAgent(BaseAgent):
|
||||||
|
"""
|
||||||
|
真正的 ReAct Agent
|
||||||
|
|
||||||
|
LLM 全程参与决策,自主选择工具和分析策略
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
    self,
    llm_service,
    tools: Dict[str, Any],
    event_emitter=None,
    agent_type: AgentType = AgentType.ANALYSIS,
    max_iterations: int = 30,
) -> None:
    """
    Create a ReAct agent configured with the module's ReAct system prompt.

    Args:
        llm_service: LLM client, passed through to ``BaseAgent``.
        tools: Mapping of tool name to tool object available to the agent.
        event_emitter: Optional event sink, passed through to ``BaseAgent``.
        agent_type: Agent type recorded in the ``AgentConfig``.
        max_iterations: Upper bound on ReAct loop iterations.
    """
    config = AgentConfig(
        name="ReActAgent",
        agent_type=agent_type,
        pattern=AgentPattern.REACT,
        max_iterations=max_iterations,
        system_prompt=REACT_SYSTEM_PROMPT,
    )
    super().__init__(config, llm_service, tools, event_emitter)

    # Chat messages ({"role": ..., "content": ...}) exchanged with the LLM.
    self._conversation_history: List[Dict[str, str]] = []
    # Parsed steps of the current run.
    self._steps: List[AgentStep] = []
|
||||||
|
|
||||||
|
def _get_tools_description(self) -> str:
    """
    Render the registered tools as Markdown for the system prompt.

    Each tool whose name does not start with ``_`` becomes a ``### name``
    section holding its description, followed by one line per parameter when
    the tool exposes an ``args_schema``. A tool with no ``description``
    attribute gets an empty description line instead of raising.

    Returns:
        The sections joined by a blank line; ``""`` when no tool is visible.
    """
    sections = []

    for name, tool in self.tools.items():
        if name.startswith("_"):
            continue

        lines = [f"### {name}", f"{getattr(tool, 'description', '')}"]

        schema_cls = getattr(tool, 'args_schema', None)
        if schema_cls:
            # Presumably a Pydantic model: use model_json_schema() when it
            # exists (Pydantic v2, where .schema() is deprecated), otherwise
            # fall back to .schema().
            schema_fn = getattr(schema_cls, "model_json_schema", None) or schema_cls.schema
            properties = schema_fn().get("properties", {})
            if properties:
                lines.append("参数:")
                for param_name, param_info in properties.items():
                    param_desc = param_info.get("description", "")
                    param_type = param_info.get("type", "string")
                    lines.append(f" - {param_name} ({param_type}): {param_desc}")

        # Every section ends with a newline, as the prompt layout expects.
        sections.append("\n".join(lines) + "\n")

    return "\n".join(sections)
|
||||||
|
|
||||||
|
def _build_system_prompt(self, project_info: Dict, task_context: str = "") -> str:
    """
    Build the system prompt for one run.

    The configured prompt template is filled with the tool descriptions, then
    optional project-information and task-context sections are appended.

    Args:
        project_info: Optional project facts; ``name``, ``languages`` and
            ``file_count`` are read. ``languages`` may be a list, a single
            string, or missing/``None``/empty (rendered as ``unknown``).
        task_context: Optional free text appended as its own section.

    Returns:
        The complete system prompt.
    """
    tools_desc = self._get_tools_description()
    parts = [self.config.system_prompt.format(tools_description=tools_desc)]

    if project_info:
        languages = project_info.get('languages') or ['unknown']
        if isinstance(languages, str):
            # A bare string would otherwise be joined character by character.
            languages = [languages]
        language_text = ', '.join(str(lang) for lang in languages)

        parts.append("\n\n## 项目信息\n")
        parts.append(f"- 名称: {project_info.get('name', 'unknown')}\n")
        parts.append(f"- 语言: {language_text}\n")
        parts.append(f"- 文件数: {project_info.get('file_count', 'unknown')}\n")

    if task_context:
        parts.append(f"\n\n## 任务上下文\n{task_context}")

    return "".join(parts)
|
||||||
|
|
||||||
|
def _parse_llm_response(self, response: str) -> AgentStep:
    """
    Parse one ReAct-formatted LLM response into an ``AgentStep``.

    Recognised markers are ``Thought:``, ``Action:``, ``Action Input:`` and
    ``Final Answer:``. When ``Final Answer:`` is present the step is final
    and any action in the same response is ignored.

    JSON payloads are parsed leniently: Markdown code fences are removed, and
    if the remaining text is not pure JSON the first ``{...}`` object in it
    is used. Only JSON objects are accepted, because the results are used as
    dicts. Anything else is kept as ``{"raw_answer": ...}`` or
    ``{"raw_input": ...}``.

    Args:
        response: Raw text returned by the LLM.

    Returns:
        The parsed step. Fields whose marker is missing keep their defaults.
    """

    def strip_fences(text):
        """Remove Markdown code-fence markers from the text."""
        text = re.sub(r'```json\s*', '', text)
        return re.sub(r'```\s*', '', text)

    def loads_object(text):
        """Return the JSON object in the text, or None if there is none."""
        try:
            value = json.loads(text)
        except json.JSONDecodeError:
            # The model often wraps the JSON in prose; try the first "{".
            start = text.find("{")
            if start == -1:
                return None
            try:
                value, _ = json.JSONDecoder().raw_decode(text, start)
            except json.JSONDecodeError:
                return None
        return value if isinstance(value, dict) else None

    step = AgentStep(thought="")

    thought_match = re.search(r'Thought:\s*(.*?)(?=Action:|Final Answer:|$)', response, re.DOTALL)
    if thought_match:
        step.thought = thought_match.group(1).strip()

    # A final answer ends the loop; it takes precedence over any action.
    final_match = re.search(r'Final Answer:\s*(.*?)$', response, re.DOTALL)
    if final_match:
        step.is_final = True
        raw_answer = final_match.group(1).strip()
        parsed_answer = loads_object(strip_fences(raw_answer))
        if parsed_answer is None:
            parsed_answer = {"raw_answer": raw_answer}
        step.final_answer = parsed_answer
        return step

    action_match = re.search(r'Action:\s*(\w+)', response)
    if action_match:
        step.action = action_match.group(1).strip()

    input_match = re.search(r'Action Input:\s*(.*?)(?=Thought:|Action:|Observation:|$)', response, re.DOTALL)
    if input_match:
        input_text = strip_fences(input_match.group(1).strip())
        parsed_input = loads_object(input_text)
        if parsed_input is None:
            # Keep the raw text so the failure is visible to the caller.
            parsed_input = {"raw_input": input_text}
        step.action_input = parsed_input

    return step
|
||||||
|
|
||||||
|
async def _execute_tool(self, tool_name: str, tool_input: Dict) -> str:
    """
    Run one tool call and return the text fed back to the LLM as observation.

    Failures are returned as text, never raised, so the LLM can react to them.

    Args:
        tool_name: Name of the tool to run, as chosen by the LLM.
        tool_input: Keyword arguments for the tool. ``None`` (no Action Input
            was parsed) is treated as "no arguments"; any other non-mapping
            is rejected with an error message.

    Returns:
        The tool output as a string, truncated to 4000 characters with a
        note, or an error message when the tool is unknown, reports failure,
        or raises.
    """
    tool = self.tools.get(tool_name)

    if not tool:
        return f"错误: 工具 '{tool_name}' 不存在。可用工具: {list(self.tools.keys())}"

    if tool_input is None:
        tool_input = {}
    elif not isinstance(tool_input, dict):
        # ``**tool_input`` below needs a mapping; reject before counting or
        # emitting the call.
        return f"工具执行错误: Action Input 必须是 JSON 对象,实际为 {type(tool_input).__name__}"

    try:
        self._tool_calls += 1
        await self.emit_tool_call(tool_name, tool_input)

        import time
        # perf_counter is monotonic, so the duration cannot go negative if
        # the system clock is adjusted.
        start = time.perf_counter()

        result = await tool.execute(**tool_input)

        duration_ms = int((time.perf_counter() - start) * 1000)
        output = str(result.data)
        await self.emit_tool_result(tool_name, output[:200], duration_ms)

        if not result.success:
            return f"工具执行失败: {result.error}"

        # Truncate long output to keep the conversation small.
        total_chars = len(output)
        if total_chars > 4000:
            output = output[:4000] + "\n\n... [输出已截断,共 {} 字符]".format(total_chars)
        return output

    except Exception as e:
        # Deliberately broad: any tool failure becomes an observation.
        logger.exception("Tool execution error: %s", e)
        return f"工具执行错误: {str(e)}"
|
||||||
|
|
||||||
|
async def run(self, input_data: Dict[str, Any]) -> AgentResult:
    """Drive the ReAct loop: the LLM thinks, picks a tool, observes, repeats.

    Args:
        input_data: Reads ``project_info`` and ``task_context`` to build the
            system prompt.

    Returns:
        ``AgentResult(success=True, ...)`` carrying the findings taken from
        the final answer and a per-step trace (observations cut to 500
        characters), or ``AgentResult(success=False, error=...)`` when any
        exception escapes the loop.
    """
    import time

    # Monotonic clock: only used to compute the elapsed duration.
    started = time.perf_counter()

    project_info = input_data.get("project_info", {})
    task_context = input_data.get("task_context", "")

    system_prompt = self._build_system_prompt(project_info, task_context)

    # Fresh conversation for this run.
    self._conversation_history = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": "请开始对项目进行安全审计。首先了解项目结构,然后系统性地搜索和分析潜在的安全漏洞。"},
    ]

    self._steps = []
    # Reset so a run that never enters the loop does not report a stale count.
    self._iteration = 0
    all_findings = []

    await self.emit_thinking("🤖 ReAct Agent 启动,LLM 开始自主分析...")

    try:
        for iteration in range(self.config.max_iterations):
            if self.is_cancelled:
                break

            self._iteration = iteration + 1

            await self.emit_thinking(f"💭 第 {iteration + 1} 轮思考...")

            response = await self.llm_service.chat_completion_raw(
                messages=self._conversation_history,
                temperature=0.1,
                max_tokens=2048,
            )

            # Providers may send explicit nulls; normalise before slicing
            # or reading token counts.
            llm_output = response.get("content") or ""
            usage = response.get("usage") or {}
            self._total_tokens += usage.get("total_tokens", 0)

            await self.emit_event("thinking", f"LLM: {llm_output[:500]}...")

            step = self._parse_llm_response(llm_output)
            self._steps.append(step)

            self._conversation_history.append({
                "role": "assistant",
                "content": llm_output,
            })

            if step.is_final:
                await self.emit_thinking("✅ LLM 完成分析,生成最终报告")

                # The final answer is parsed JSON and may not be an object;
                # only a dict can carry a "findings" key.
                answer = step.final_answer
                if isinstance(answer, dict) and "findings" in answer:
                    all_findings = answer["findings"]
                break

            if step.action:
                await self.emit_thinking(f"🔧 LLM 决定调用工具: {step.action}")

                observation = await self._execute_tool(
                    step.action,
                    step.action_input or {},
                )
                step.observation = observation

                # Feed the tool output back as the next user turn.
                self._conversation_history.append({
                    "role": "user",
                    "content": f"Observation: {observation}",
                })
            else:
                # No tool chosen and no final answer: nudge the model on.
                self._conversation_history.append({
                    "role": "user",
                    "content": "请继续分析,选择一个工具执行,或者如果分析完成,输出 Final Answer。",
                })

        duration_ms = int((time.perf_counter() - started) * 1000)

        await self.emit_event(
            "info",
            f"🎯 ReAct Agent 完成: {len(all_findings)} 个发现, {self._iteration} 轮迭代, {self._tool_calls} 次工具调用"
        )

        return AgentResult(
            success=True,
            data={
                "findings": all_findings,
                "steps": [
                    {
                        "thought": s.thought,
                        "action": s.action,
                        "action_input": s.action_input,
                        "observation": s.observation[:500] if s.observation else None,
                    }
                    for s in self._steps
                ],
            },
            iterations=self._iteration,
            tool_calls=self._tool_calls,
            tokens_used=self._total_tokens,
            duration_ms=duration_ms,
        )

    except Exception as e:
        logger.error(f"ReAct Agent failed: {e}", exc_info=True)
        return AgentResult(success=False, error=str(e))
|
||||||
|
|
||||||
|
def get_conversation_history(self) -> List[Dict[str, str]]:
    """Return the agent's message history.

    The live list is handed out, not a copy, so changes made by the caller
    are visible to the agent.
    """
    history = self._conversation_history
    return history
|
||||||
|
|
||||||
|
def get_steps(self) -> "List[AgentStep]":
    """Return the recorded steps.

    The live list is handed out, not a copy.
    """
    recorded = self._steps
    return recorded
|
||||||
|
|
@ -192,6 +192,37 @@ class AgentEventEmitter:
|
||||||
"percentage": percentage,
|
"percentage": percentage,
|
||||||
},
|
},
|
||||||
))
|
))
|
||||||
|
|
||||||
|
async def emit_task_complete(
    self,
    findings_count: int,
    duration_ms: int,
    message: Optional[str] = None,
):
    """Emit a ``task_complete`` event.

    Args:
        findings_count: Number of findings, stored in the event metadata.
        duration_ms: Elapsed time in milliseconds, stored in the metadata.
        message: Optional text; when empty or None a default summary with
            the count and the duration in seconds is used.
    """
    if message:
        text = message
    else:
        text = f"✅ 审计完成!发现 {findings_count} 个漏洞,耗时 {duration_ms/1000:.1f}秒"

    details = {
        "findings_count": findings_count,
        "duration_ms": duration_ms,
    }
    event = AgentEventData(
        event_type="task_complete",
        message=text,
        metadata=details,
    )
    await self.emit(event)
|
||||||
|
|
||||||
|
async def emit_task_error(self, error: str, message: Optional[str] = None):
    """Emit a ``task_error`` event.

    Args:
        error: Error text, stored under ``metadata["error"]``.
        message: Optional text; when empty or None a default message that
            embeds ``error`` is used.
    """
    text = message if message else f"❌ 任务失败: {error}"
    event = AgentEventData(
        event_type="task_error",
        message=text,
        metadata={"error": error},
    )
    await self.emit(event)
|
||||||
|
|
||||||
|
async def emit_task_cancelled(self, message: Optional[str] = None):
    """Emit a ``task_cancel`` event.

    Args:
        message: Optional text; when empty or None a default cancellation
            notice is used.
    """
    text = message if message else "⚠️ 任务已取消"
    event = AgentEventData(
        event_type="task_cancel",
        message=text,
    )
    await self.emit(event)
|
||||||
|
|
||||||
|
|
||||||
class EventManager:
|
class EventManager:
|
||||||
|
|
@ -368,4 +399,15 @@ class EventManager:
|
||||||
def create_emitter(self, task_id: str) -> AgentEventEmitter:
|
def create_emitter(self, task_id: str) -> AgentEventEmitter:
|
||||||
"""创建事件发射器"""
|
"""创建事件发射器"""
|
||||||
return AgentEventEmitter(task_id, self)
|
return AgentEventEmitter(task_id, self)
|
||||||
|
|
||||||
|
async def close(self):
    """Shut the event manager down and release what it holds.

    Removes every registered queue through ``remove_queue`` and clears all
    registered callbacks.
    """
    # Snapshot the keys first: remove_queue presumably mutates the mapping.
    queued_ids = list(self._event_queues.keys())
    for queued_id in queued_ids:
        self.remove_queue(queued_id)

    self._event_callbacks.clear()

    logger.debug("EventManager closed")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -15,6 +15,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from langgraph.graph import StateGraph, END
|
from langgraph.graph import StateGraph, END
|
||||||
from langgraph.checkpoint.memory import MemorySaver
|
from langgraph.checkpoint.memory import MemorySaver
|
||||||
|
|
||||||
|
from app.services.agent.streaming import StreamHandler, StreamEvent, StreamEventType
|
||||||
from app.models.agent_task import (
|
from app.models.agent_task import (
|
||||||
AgentTask, AgentEvent, AgentFinding,
|
AgentTask, AgentEvent, AgentFinding,
|
||||||
AgentTaskStatus, AgentTaskPhase, AgentEventType,
|
AgentTaskStatus, AgentTaskPhase, AgentEventType,
|
||||||
|
|
@ -39,11 +40,15 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class LLMService:
|
class LLMService:
|
||||||
"""LLM 服务封装"""
|
"""
|
||||||
|
LLM 服务封装
|
||||||
|
提供代码分析、漏洞检测等 AI 功能
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, model: Optional[str] = None, api_key: Optional[str] = None):
|
def __init__(self, model: Optional[str] = None, api_key: Optional[str] = None):
|
||||||
self.model = model or settings.LLM_MODEL or "gpt-4o-mini"
|
self.model = model or settings.LLM_MODEL or "gpt-4o-mini"
|
||||||
self.api_key = api_key or settings.LLM_API_KEY
|
self.api_key = api_key or settings.LLM_API_KEY
|
||||||
|
self.base_url = settings.LLM_BASE_URL
|
||||||
|
|
||||||
async def chat_completion_raw(
|
async def chat_completion_raw(
|
||||||
self,
|
self,
|
||||||
|
|
@ -61,6 +66,7 @@ class LLMService:
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
api_key=self.api_key,
|
api_key=self.api_key,
|
||||||
|
base_url=self.base_url,
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
|
@ -75,6 +81,125 @@ class LLMService:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"LLM call failed: {e}")
|
logger.error(f"LLM call failed: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
async def analyze_code(self, code: str, language: str) -> Dict[str, Any]:
    """Ask the LLM to audit a code snippet for security issues.

    Args:
        code: Source text; only the first 8000 characters are sent.
        language: Language name used in the prompt and the code fence.

    Returns:
        The JSON object parsed from the model reply. The reply is tried as
        raw JSON, then as a fenced markdown block, then from the first ``{``
        in the text. If no JSON object can be recovered the result is
        ``{"issues": [], "quality_score": 80}``; if the LLM call itself
        raises, it is ``{"issues": [], "quality_score": 0, "error": ...}``.
    """
    import json
    import re

    prompt = f"""请分析以下 {language} 代码的安全问题。

代码:
```{language}
{code[:8000]}
```

请识别所有潜在的安全漏洞,包括但不限于:
- SQL 注入
- XSS (跨站脚本)
- 命令注入
- 路径遍历
- 不安全的反序列化
- 硬编码密钥/密码
- 不安全的加密
- SSRF
- 认证/授权问题

对于每个发现的问题,请提供:
1. 漏洞类型
2. 严重程度 (critical/high/medium/low)
3. 问题描述
4. 具体行号
5. 修复建议

请以 JSON 格式返回结果:
{{
"issues": [
{{
"type": "漏洞类型",
"severity": "严重程度",
"title": "问题标题",
"description": "详细描述",
"line": 行号,
"code_snippet": "相关代码片段",
"suggestion": "修复建议"
}}
],
"quality_score": 0-100
}}

如果没有发现安全问题,返回空的 issues 数组和较高的 quality_score。"""

    def _as_object(text):
        # Return the parsed dict, or None when text is not a JSON object.
        try:
            parsed = json.loads(text)
        except (json.JSONDecodeError, TypeError):
            return None
        return parsed if isinstance(parsed, dict) else None

    try:
        result = await self.chat_completion_raw(
            messages=[
                {"role": "system", "content": "你是一位专业的代码安全审计专家,擅长发现代码中的安全漏洞。请只返回 JSON 格式的结果,不要包含其他内容。"},
                {"role": "user", "content": prompt},
            ],
            temperature=0.1,
            max_tokens=4096,
        )

        # A missing or null reply is treated as "nothing recoverable".
        content = result.get("content") or ""

        # 1. The whole reply is a JSON object.
        parsed = _as_object(content)
        if parsed is not None:
            return parsed

        # 2. The object sits inside a markdown code fence.
        fenced = re.search(r'```(?:json)?\s*([\s\S]*?)\s*```', content)
        if fenced:
            parsed = _as_object(fenced.group(1))
            if parsed is not None:
                return parsed

        # 3. The object is embedded in surrounding prose: decode from the
        #    first "{" and ignore whatever trails the object.
        brace = content.find("{")
        if brace != -1:
            try:
                parsed, _ = json.JSONDecoder().raw_decode(content, brace)
            except json.JSONDecodeError:
                parsed = None
            if isinstance(parsed, dict):
                return parsed

        # Nothing usable: report an empty result.
        return {"issues": [], "quality_score": 80}

    except Exception as e:
        logger.error(f"Code analysis failed: {e}")
        return {"issues": [], "quality_score": 0, "error": str(e)}
|
||||||
|
|
||||||
|
async def analyze_code_with_custom_prompt(
    self,
    code: str,
    language: str,
    prompt: str,
    **kwargs
) -> Dict[str, Any]:
    """Analyze code with a caller-supplied prompt template.

    Args:
        code: Source text substituted for every ``{code}`` placeholder.
        language: Language name substituted for every ``{language}``
            placeholder.
        prompt: Template text containing the placeholders.
        **kwargs: Accepted for call compatibility; not used.

    Returns:
        ``{"analysis": <model reply>, "usage": <usage dict>}``, or
        ``{"analysis": "", "error": <message>}`` when the LLM call raises.
    """
    import re

    # Substitute both placeholders in a single pass over the template, so a
    # literal "{language}" or "{code}" inside the inserted code is left alone.
    values = {"{code}": code, "{language}": language}
    full_prompt = re.sub(
        r"\{(?:code|language)\}",
        lambda found: values[found.group(0)],
        prompt,
    )

    try:
        result = await self.chat_completion_raw(
            messages=[
                {"role": "system", "content": "你是一位专业的代码安全审计专家。"},
                {"role": "user", "content": full_prompt},
            ],
            temperature=0.1,
        )

        return {
            "analysis": result.get("content", ""),
            "usage": result.get("usage", {}),
        }

    except Exception as e:
        logger.error(f"Custom analysis failed: {e}")
        return {"analysis": "", "error": str(e)}
|
||||||
|
|
||||||
|
|
||||||
class AgentRunner:
|
class AgentRunner:
|
||||||
|
|
@ -97,8 +222,9 @@ class AgentRunner:
|
||||||
self.task = task
|
self.task = task
|
||||||
self.project_root = project_root
|
self.project_root = project_root
|
||||||
|
|
||||||
# 事件管理
|
# 事件管理 - 传入 db_session_factory 以持久化事件
|
||||||
self.event_manager = EventManager()
|
from app.db.session import async_session_factory
|
||||||
|
self.event_manager = EventManager(db_session_factory=async_session_factory)
|
||||||
self.event_emitter = AgentEventEmitter(task.id, self.event_manager)
|
self.event_emitter = AgentEventEmitter(task.id, self.event_manager)
|
||||||
|
|
||||||
# LLM 服务
|
# LLM 服务
|
||||||
|
|
@ -120,6 +246,22 @@ class AgentRunner:
|
||||||
|
|
||||||
# 状态
|
# 状态
|
||||||
self._cancelled = False
|
self._cancelled = False
|
||||||
|
self._running_task: Optional[asyncio.Task] = None
|
||||||
|
|
||||||
|
# 流式处理器
|
||||||
|
self.stream_handler = StreamHandler(task.id)
|
||||||
|
|
||||||
|
def cancel(self):
    """Request cancellation of this runner.

    Sets the cancelled flag and, if a running asyncio task is tracked and
    not yet done, cancels it.
    """
    self._cancelled = True

    in_flight = self._running_task
    if in_flight:
        if not in_flight.done():
            in_flight.cancel()

    logger.info(f"Task {self.task.id} cancellation requested")
|
||||||
|
|
||||||
|
@property
def is_cancelled(self) -> bool:
    """Whether cancellation has been requested for this runner."""
    requested = self._cancelled
    return requested
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
"""初始化 Runner"""
|
"""初始化 Runner"""
|
||||||
|
|
@ -149,15 +291,15 @@ class AgentRunner:
|
||||||
)
|
)
|
||||||
|
|
||||||
self.indexer = CodeIndexer(
|
self.indexer = CodeIndexer(
|
||||||
embedding_service=embedding_service,
|
|
||||||
vector_db_path=settings.VECTOR_DB_PATH,
|
|
||||||
collection_name=f"project_{self.task.project_id}",
|
collection_name=f"project_{self.task.project_id}",
|
||||||
|
embedding_service=embedding_service,
|
||||||
|
persist_directory=settings.VECTOR_DB_PATH,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.retriever = CodeRetriever(
|
self.retriever = CodeRetriever(
|
||||||
embedding_service=embedding_service,
|
|
||||||
vector_db_path=settings.VECTOR_DB_PATH,
|
|
||||||
collection_name=f"project_{self.task.project_id}",
|
collection_name=f"project_{self.task.project_id}",
|
||||||
|
embedding_service=embedding_service,
|
||||||
|
persist_directory=settings.VECTOR_DB_PATH,
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -261,6 +403,18 @@ class AgentRunner:
|
||||||
Returns:
|
Returns:
|
||||||
最终状态
|
最终状态
|
||||||
"""
|
"""
|
||||||
|
result = {}
|
||||||
|
async for _ in self.run_with_streaming():
|
||||||
|
pass # 消费所有事件
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def run_with_streaming(self) -> AsyncGenerator[StreamEvent, None]:
|
||||||
|
"""
|
||||||
|
带流式输出的审计执行
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
StreamEvent: 流式事件(包含 LLM 思考、工具调用等)
|
||||||
|
"""
|
||||||
import time
|
import time
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
|
|
@ -271,17 +425,28 @@ class AgentRunner:
|
||||||
# 更新任务状态
|
# 更新任务状态
|
||||||
await self._update_task_status(AgentTaskStatus.RUNNING)
|
await self._update_task_status(AgentTaskStatus.RUNNING)
|
||||||
|
|
||||||
|
# 发射任务开始事件
|
||||||
|
yield StreamEvent(
|
||||||
|
event_type=StreamEventType.TASK_START,
|
||||||
|
sequence=self.stream_handler._next_sequence(),
|
||||||
|
data={"task_id": self.task.id, "message": "🚀 审计任务开始"},
|
||||||
|
)
|
||||||
|
|
||||||
# 1. 索引代码
|
# 1. 索引代码
|
||||||
await self._index_code()
|
await self._index_code()
|
||||||
|
|
||||||
if self._cancelled:
|
if self._cancelled:
|
||||||
return {"success": False, "error": "任务已取消"}
|
yield StreamEvent(
|
||||||
|
event_type=StreamEventType.TASK_CANCEL,
|
||||||
|
sequence=self.stream_handler._next_sequence(),
|
||||||
|
data={"message": "任务已取消"},
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
# 2. 收集项目信息
|
# 2. 收集项目信息
|
||||||
project_info = await self._collect_project_info()
|
project_info = await self._collect_project_info()
|
||||||
|
|
||||||
# 3. 构建初始状态
|
# 3. 构建初始状态
|
||||||
# 从任务字段构建配置
|
|
||||||
task_config = {
|
task_config = {
|
||||||
"target_vulnerabilities": self.task.target_vulnerabilities or [],
|
"target_vulnerabilities": self.task.target_vulnerabilities or [],
|
||||||
"verification_level": self.task.verification_level or "sandbox",
|
"verification_level": self.task.verification_level or "sandbox",
|
||||||
|
|
@ -314,7 +479,7 @@ class AgentRunner:
|
||||||
"error": None,
|
"error": None,
|
||||||
}
|
}
|
||||||
|
|
||||||
# 4. 执行 LangGraph
|
# 4. 执行 LangGraph with astream_events
|
||||||
await self.event_emitter.emit_phase_start("langgraph", "🔄 启动 LangGraph 工作流")
|
await self.event_emitter.emit_phase_start("langgraph", "🔄 启动 LangGraph 工作流")
|
||||||
|
|
||||||
run_config = {
|
run_config = {
|
||||||
|
|
@ -325,26 +490,57 @@ class AgentRunner:
|
||||||
|
|
||||||
final_state = None
|
final_state = None
|
||||||
|
|
||||||
# 流式执行并发射事件
|
# 使用 astream_events 获取详细事件流
|
||||||
async for event in self.graph.astream(initial_state, config=run_config):
|
try:
|
||||||
if self._cancelled:
|
async for event in self.graph.astream_events(
|
||||||
break
|
initial_state,
|
||||||
|
config=run_config,
|
||||||
# 处理每个节点的输出
|
version="v2",
|
||||||
for node_name, node_output in event.items():
|
):
|
||||||
await self._handle_node_output(node_name, node_output)
|
if self._cancelled:
|
||||||
|
break
|
||||||
|
|
||||||
# 更新阶段
|
# 处理 LangGraph 事件
|
||||||
phase_map = {
|
stream_event = await self.stream_handler.process_langgraph_event(event)
|
||||||
"recon": AgentTaskPhase.RECONNAISSANCE,
|
if stream_event:
|
||||||
"analysis": AgentTaskPhase.ANALYSIS,
|
# 同步到 event_emitter 以持久化
|
||||||
"verification": AgentTaskPhase.VERIFICATION,
|
await self._sync_stream_event_to_db(stream_event)
|
||||||
"report": AgentTaskPhase.REPORTING,
|
yield stream_event
|
||||||
}
|
|
||||||
if node_name in phase_map:
|
|
||||||
await self._update_task_phase(phase_map[node_name])
|
|
||||||
|
|
||||||
final_state = node_output
|
# 更新最终状态
|
||||||
|
if event.get("event") == "on_chain_end":
|
||||||
|
output = event.get("data", {}).get("output")
|
||||||
|
if isinstance(output, dict):
|
||||||
|
final_state = output
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# 如果 astream_events 不可用,回退到 astream
|
||||||
|
logger.warning(f"astream_events not available, falling back to astream: {e}")
|
||||||
|
async for event in self.graph.astream(initial_state, config=run_config):
|
||||||
|
if self._cancelled:
|
||||||
|
break
|
||||||
|
|
||||||
|
for node_name, node_output in event.items():
|
||||||
|
await self._handle_node_output(node_name, node_output)
|
||||||
|
|
||||||
|
# 发射节点事件
|
||||||
|
yield StreamEvent(
|
||||||
|
event_type=StreamEventType.NODE_END,
|
||||||
|
sequence=self.stream_handler._next_sequence(),
|
||||||
|
node_name=node_name,
|
||||||
|
data={"message": f"节点 {node_name} 完成"},
|
||||||
|
)
|
||||||
|
|
||||||
|
phase_map = {
|
||||||
|
"recon": AgentTaskPhase.RECONNAISSANCE,
|
||||||
|
"analysis": AgentTaskPhase.ANALYSIS,
|
||||||
|
"verification": AgentTaskPhase.VERIFICATION,
|
||||||
|
"report": AgentTaskPhase.REPORTING,
|
||||||
|
}
|
||||||
|
if node_name in phase_map:
|
||||||
|
await self._update_task_phase(phase_map[node_name])
|
||||||
|
|
||||||
|
final_state = node_output
|
||||||
|
|
||||||
# 5. 获取最终状态
|
# 5. 获取最终状态
|
||||||
if not final_state:
|
if not final_state:
|
||||||
|
|
@ -355,6 +551,13 @@ class AgentRunner:
|
||||||
findings = final_state.get("findings", [])
|
findings = final_state.get("findings", [])
|
||||||
await self._save_findings(findings)
|
await self._save_findings(findings)
|
||||||
|
|
||||||
|
# 发射发现事件
|
||||||
|
for finding in findings[:10]: # 限制数量
|
||||||
|
yield self.stream_handler.create_finding_event(
|
||||||
|
finding,
|
||||||
|
is_verified=finding.get("is_verified", False),
|
||||||
|
)
|
||||||
|
|
||||||
# 7. 更新任务摘要
|
# 7. 更新任务摘要
|
||||||
summary = final_state.get("summary", {})
|
summary = final_state.get("summary", {})
|
||||||
security_score = final_state.get("security_score", 100)
|
security_score = final_state.get("security_score", 100)
|
||||||
|
|
@ -374,30 +577,59 @@ class AgentRunner:
|
||||||
duration_ms=duration_ms,
|
duration_ms=duration_ms,
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
yield StreamEvent(
|
||||||
"success": True,
|
event_type=StreamEventType.TASK_COMPLETE,
|
||||||
"data": {
|
sequence=self.stream_handler._next_sequence(),
|
||||||
"findings": findings,
|
data={
|
||||||
"verified_findings": final_state.get("verified_findings", []),
|
"findings_count": len(findings),
|
||||||
"summary": summary,
|
"verified_count": len(final_state.get("verified_findings", [])),
|
||||||
"security_score": security_score,
|
"security_score": security_score,
|
||||||
|
"duration_ms": duration_ms,
|
||||||
|
"message": f"✅ 审计完成!发现 {len(findings)} 个漏洞",
|
||||||
},
|
},
|
||||||
"duration_ms": duration_ms,
|
)
|
||||||
}
|
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
await self._update_task_status(AgentTaskStatus.CANCELLED)
|
await self._update_task_status(AgentTaskStatus.CANCELLED)
|
||||||
return {"success": False, "error": "任务已取消"}
|
yield StreamEvent(
|
||||||
|
event_type=StreamEventType.TASK_CANCEL,
|
||||||
|
sequence=self.stream_handler._next_sequence(),
|
||||||
|
data={"message": "任务已取消"},
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"LangGraph run failed: {e}", exc_info=True)
|
logger.error(f"LangGraph run failed: {e}", exc_info=True)
|
||||||
await self._update_task_status(AgentTaskStatus.FAILED, str(e))
|
await self._update_task_status(AgentTaskStatus.FAILED, str(e))
|
||||||
await self.event_emitter.emit_error(str(e))
|
await self.event_emitter.emit_error(str(e))
|
||||||
return {"success": False, "error": str(e)}
|
|
||||||
|
yield StreamEvent(
|
||||||
|
event_type=StreamEventType.TASK_ERROR,
|
||||||
|
sequence=self.stream_handler._next_sequence(),
|
||||||
|
data={"error": str(e), "message": f"❌ 审计失败: {e}"},
|
||||||
|
)
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
await self._cleanup()
|
await self._cleanup()
|
||||||
|
|
||||||
|
async def _sync_stream_event_to_db(self, event: StreamEvent):
    """Forward one stream event to the event manager (best effort).

    Tool input, output and duration are read from ``event.data``; the whole
    ``event.data`` dict is also passed as metadata. Any failure is logged as
    a warning and swallowed so streaming is never interrupted.
    """
    try:
        record = self.event_manager.add_event
        payload = event.data
        fields = {
            "task_id": self.task.id,
            "event_type": event.event_type.value,
            "sequence": event.sequence,
            "phase": event.phase,
            "message": payload.get("message"),
            "tool_name": event.tool_name,
            # Two spellings are accepted for tool input and output.
            "tool_input": payload.get("input") or payload.get("input_params"),
            "tool_output": payload.get("output") or payload.get("output_data"),
            "tool_duration_ms": payload.get("duration_ms"),
            "metadata": payload,
        }
        await record(**fields)
    except Exception as e:
        logger.warning(f"Failed to sync stream event to db: {e}")
|
||||||
|
|
||||||
async def _handle_node_output(self, node_name: str, output: Dict[str, Any]):
|
async def _handle_node_output(self, node_name: str, output: Dict[str, Any]):
|
||||||
"""处理节点输出"""
|
"""处理节点输出"""
|
||||||
# 发射节点事件
|
# 发射节点事件
|
||||||
|
|
@ -445,7 +677,8 @@ class AgentRunner:
|
||||||
return
|
return
|
||||||
|
|
||||||
await self.event_emitter.emit_progress(
|
await self.event_emitter.emit_progress(
|
||||||
progress.processed / max(progress.total, 1) * 100,
|
progress.processed_files,
|
||||||
|
progress.total_files,
|
||||||
f"正在索引: {progress.current_file or 'N/A'}"
|
f"正在索引: {progress.current_file or 'N/A'}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -502,13 +735,23 @@ class AgentRunner:
|
||||||
|
|
||||||
type_map = {
|
type_map = {
|
||||||
"sql_injection": VulnerabilityType.SQL_INJECTION,
|
"sql_injection": VulnerabilityType.SQL_INJECTION,
|
||||||
|
"nosql_injection": VulnerabilityType.NOSQL_INJECTION,
|
||||||
"xss": VulnerabilityType.XSS,
|
"xss": VulnerabilityType.XSS,
|
||||||
"command_injection": VulnerabilityType.COMMAND_INJECTION,
|
"command_injection": VulnerabilityType.COMMAND_INJECTION,
|
||||||
|
"code_injection": VulnerabilityType.CODE_INJECTION,
|
||||||
"path_traversal": VulnerabilityType.PATH_TRAVERSAL,
|
"path_traversal": VulnerabilityType.PATH_TRAVERSAL,
|
||||||
|
"file_inclusion": VulnerabilityType.FILE_INCLUSION,
|
||||||
"ssrf": VulnerabilityType.SSRF,
|
"ssrf": VulnerabilityType.SSRF,
|
||||||
|
"xxe": VulnerabilityType.XXE,
|
||||||
|
"deserialization": VulnerabilityType.DESERIALIZATION,
|
||||||
|
"auth_bypass": VulnerabilityType.AUTH_BYPASS,
|
||||||
|
"idor": VulnerabilityType.IDOR,
|
||||||
|
"sensitive_data_exposure": VulnerabilityType.SENSITIVE_DATA_EXPOSURE,
|
||||||
"hardcoded_secret": VulnerabilityType.HARDCODED_SECRET,
|
"hardcoded_secret": VulnerabilityType.HARDCODED_SECRET,
|
||||||
"deserialization": VulnerabilityType.INSECURE_DESERIALIZATION,
|
|
||||||
"weak_crypto": VulnerabilityType.WEAK_CRYPTO,
|
"weak_crypto": VulnerabilityType.WEAK_CRYPTO,
|
||||||
|
"race_condition": VulnerabilityType.RACE_CONDITION,
|
||||||
|
"business_logic": VulnerabilityType.BUSINESS_LOGIC,
|
||||||
|
"memory_corruption": VulnerabilityType.MEMORY_CORRUPTION,
|
||||||
}
|
}
|
||||||
|
|
||||||
for finding in findings:
|
for finding in findings:
|
||||||
|
|
@ -536,7 +779,7 @@ class AgentRunner:
|
||||||
is_verified=finding.get("is_verified", False),
|
is_verified=finding.get("is_verified", False),
|
||||||
confidence=finding.get("confidence", 0.5),
|
confidence=finding.get("confidence", 0.5),
|
||||||
poc=finding.get("poc"),
|
poc=finding.get("poc"),
|
||||||
status=FindingStatus.VERIFIED if finding.get("is_verified") else FindingStatus.OPEN,
|
status=FindingStatus.VERIFIED if finding.get("is_verified") else FindingStatus.NEW,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.db.add(db_finding)
|
self.db.add(db_finding)
|
||||||
|
|
@ -603,10 +846,6 @@ class AgentRunner:
|
||||||
await self.event_manager.close()
|
await self.event_manager.close()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Cleanup error: {e}")
|
logger.warning(f"Cleanup error: {e}")
|
||||||
|
|
||||||
def cancel(self):
|
|
||||||
"""取消任务"""
|
|
||||||
self._cancelled = True
|
|
||||||
|
|
||||||
|
|
||||||
# 便捷函数
|
# 便捷函数
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,17 @@
|
||||||
|
"""
|
||||||
|
Agent 流式输出模块
|
||||||
|
支持 LLM Token 流式输出、工具调用展示、思考过程展示
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .stream_handler import StreamHandler, StreamEvent, StreamEventType
|
||||||
|
from .token_streamer import TokenStreamer
|
||||||
|
from .tool_stream import ToolStreamHandler
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"StreamHandler",
|
||||||
|
"StreamEvent",
|
||||||
|
"StreamEventType",
|
||||||
|
"TokenStreamer",
|
||||||
|
"ToolStreamHandler",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
@ -0,0 +1,453 @@
|
||||||
|
"""
|
||||||
|
流式事件处理器
|
||||||
|
处理 LangGraph 的各种流式事件并转换为前端可消费的格式
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, Dict, Optional, AsyncGenerator, List
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class StreamEventType(str, Enum):
    """Kinds of events carried on the agent stream.

    Members are ``str`` subclasses, so each compares equal to its wire value.
    """

    # --- LLM output ---
    THINKING_START = "thinking_start"      # model started producing output
    THINKING_TOKEN = "thinking_token"      # one streamed token
    THINKING_END = "thinking_end"          # model finished

    # --- Tool calls ---
    TOOL_CALL_START = "tool_call_start"    # tool invocation began
    TOOL_CALL_INPUT = "tool_call_input"    # tool input parameters
    TOOL_CALL_OUTPUT = "tool_call_output"  # tool output
    TOOL_CALL_END = "tool_call_end"        # tool invocation finished
    TOOL_CALL_ERROR = "tool_call_error"    # tool invocation failed

    # --- Graph nodes ---
    NODE_START = "node_start"
    NODE_END = "node_end"

    # --- Phases ---
    PHASE_START = "phase_start"
    PHASE_END = "phase_end"

    # --- Findings ---
    FINDING_NEW = "finding_new"            # newly reported finding
    FINDING_VERIFIED = "finding_verified"  # finding passed verification

    # --- Status messages ---
    PROGRESS = "progress"
    INFO = "info"
    WARNING = "warning"
    ERROR = "error"

    # --- Task lifecycle ---
    TASK_START = "task_start"
    TASK_COMPLETE = "task_complete"
    TASK_ERROR = "task_error"
    TASK_CANCEL = "task_cancel"

    # --- Keep-alive ---
    HEARTBEAT = "heartbeat"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class StreamEvent:
    """One event on the agent stream, serializable to SSE or a plain dict."""

    event_type: StreamEventType
    data: Dict[str, Any] = field(default_factory=dict)
    # ISO-8601 UTC timestamp taken when the event object is created.
    timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
    sequence: int = 0

    # Optional context; omitted from the SSE payload when empty.
    node_name: Optional[str] = None
    phase: Optional[str] = None
    tool_name: Optional[str] = None

    def to_sse(self) -> str:
        """Render the event as a Server-Sent Events frame.

        Returns:
            ``event: <type>\\ndata: <json>\\n\\n``. The JSON holds ``type``,
            ``data``, ``timestamp``, ``sequence`` and, when set, ``node``,
            ``phase`` and ``tool``.
        """
        event_data = {
            "type": self.event_type.value,
            "data": self.data,
            "timestamp": self.timestamp,
            "sequence": self.sequence,
        }

        optional_fields = (
            ("node", self.node_name),
            ("phase", self.phase),
            ("tool", self.tool_name),
        )
        for key, value in optional_fields:
            if value:
                event_data[key] = value

        # default=str is deliberate: ``data`` may carry objects that json
        # cannot encode, and one such value must not abort the whole stream.
        # They are sent as their string form instead.
        body = json.dumps(event_data, ensure_ascii=False, default=str)
        return f"event: {self.event_type.value}\ndata: {body}\n\n"

    def to_dict(self) -> Dict[str, Any]:
        """Return every field as a plain dict, with the type as its string value."""
        return {
            "event_type": self.event_type.value,
            "data": self.data,
            "timestamp": self.timestamp,
            "sequence": self.sequence,
            "node_name": self.node_name,
            "phase": self.phase,
            "tool_name": self.tool_name,
        }
|
||||||
|
|
||||||
|
|
||||||
|
class StreamHandler:
|
||||||
|
"""
|
||||||
|
流式事件处理器
|
||||||
|
|
||||||
|
最佳实践:
|
||||||
|
1. 使用 astream_events 捕获所有 LangGraph 事件
|
||||||
|
2. 将内部事件转换为前端友好的格式
|
||||||
|
3. 支持多种事件类型的分发
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, task_id: str):
    """Create a handler for one task's event stream.

    Args:
        task_id: Identifier of the task, stored on the handler.
    """
    self.task_id = task_id
    # Last sequence number handed out; _next_sequence increments it.
    self._sequence = 0
    # Phase and node attached to emitted events; unset until known.
    self._current_phase = self._current_node = None
    # Token fragments collected while a model response streams in.
    self._thinking_buffer = []
    self._tool_states: Dict[str, Dict] = {}
|
||||||
|
|
||||||
|
def _next_sequence(self) -> int:
    """Advance the sequence counter by one and return the new value."""
    bumped = self._sequence + 1
    self._sequence = bumped
    return bumped
|
||||||
|
|
||||||
|
async def process_langgraph_event(self, event: Dict[str, Any]) -> Optional[StreamEvent]:
    """Translate one LangGraph event into a stream event.

    Dispatches on ``event["event"]``:

    - ``on_chat_model_start`` / ``on_chat_model_stream`` /
      ``on_chat_model_end``: model lifecycle and token stream
    - ``on_tool_start`` / ``on_tool_end``: tool invocation
    - ``on_chain_start`` / ``on_chain_end``: only when the event name is
      recognised as a graph node
    - ``on_custom_event``: custom events

    Returns:
        Whatever the matching handler returns, or None for any other event.
    """
    kind = event.get("event", "")
    name = event.get("name", "")
    payload = event.get("data", {})

    # Model events: handlers take (data, name).
    if kind == "on_chat_model_stream":
        return await self._handle_llm_stream(payload, name)
    if kind == "on_chat_model_start":
        return await self._handle_llm_start(payload, name)
    if kind == "on_chat_model_end":
        return await self._handle_llm_end(payload, name)

    # Tool events: handlers take (name, data).
    if kind == "on_tool_start":
        return await self._handle_tool_start(name, payload)
    if kind == "on_tool_end":
        return await self._handle_tool_end(name, payload)

    # Chain events are only forwarded for graph nodes.
    if kind == "on_chain_start" and self._is_node_event(name):
        return await self._handle_node_start(name, payload)
    if kind == "on_chain_end" and self._is_node_event(name):
        return await self._handle_node_end(name, payload)

    if kind == "on_custom_event":
        return await self._handle_custom_event(name, payload)

    return None
|
||||||
|
|
||||||
|
def _is_node_event(self, name: str) -> bool:
    """
    Report whether an event name refers to one of the graph's nodes.

    The match is a case-insensitive substring test, so "ReconNode",
    "recon" and "my_recon_step" all count.
    """
    lowered = name.lower()
    # The class-style names (ReconNode, AnalysisNode, ...) each contain one
    # of these keywords once lowercased, so the short forms are sufficient.
    for keyword in ("recon", "analysis", "verification", "report"):
        if keyword in lowered:
            return True
    return False
|
||||||
|
|
||||||
|
async def _handle_llm_start(self, data: Dict, name: str) -> StreamEvent:
    """
    Handle an ``on_chat_model_start`` event.

    Clears the token buffer so the upcoming response is accumulated from
    scratch, then announces that the model has started.

    Args:
        data: Event payload (not read here).
        name: Event name; reported as the ``model`` field.

    Returns:
        A THINKING_START event tagged with the current node and phase.
    """
    self._thinking_buffer = []

    payload = {
        "model": name,
        "message": "🤔 正在思考...",
    }
    return StreamEvent(
        event_type=StreamEventType.THINKING_START,
        sequence=self._next_sequence(),
        node_name=self._current_node,
        phase=self._current_phase,
        data=payload,
    )
|
||||||
|
|
||||||
|
async def _handle_llm_stream(self, data: Dict, name: str) -> Optional[StreamEvent]:
    """
    Handle an ``on_chat_model_stream`` event carrying one model output chunk.

    The chunk's text is appended to the token buffer and returned together
    with everything accumulated so far.

    Args:
        data: Event payload; the ``chunk`` key holds either an object with a
            ``content`` attribute or a dict with a ``content`` key.
        name: Event name (not read here).

    Returns:
        A THINKING_TOKEN event, or None when the chunk is missing or carries
        no text.
    """
    chunk = data.get("chunk")
    if not chunk:
        return None

    # Pull the raw content off either a message-like object or a plain dict.
    if hasattr(chunk, "content"):
        raw_content = chunk.content
    elif isinstance(chunk, dict):
        raw_content = chunk.get("content", "")
    else:
        raw_content = ""

    if isinstance(raw_content, list):
        # Content may arrive as a list of blocks (plain strings or dicts with
        # a "text" field). Appending such a list to the buffer would make the
        # "".join() below raise, so flatten it to text and drop the rest.
        pieces = []
        for block in raw_content:
            if isinstance(block, str):
                pieces.append(block)
            elif isinstance(block, dict) and isinstance(block.get("text"), str):
                pieces.append(block["text"])
        content = "".join(pieces)
    elif isinstance(raw_content, str):
        content = raw_content
    else:
        # Any other truthy value is stringified so the buffer stays joinable.
        content = str(raw_content) if raw_content else ""

    if not content:
        return None

    self._thinking_buffer.append(content)

    return StreamEvent(
        event_type=StreamEventType.THINKING_TOKEN,
        sequence=self._next_sequence(),
        node_name=self._current_node,
        phase=self._current_phase,
        data={
            "token": content,
            "accumulated": "".join(self._thinking_buffer),
        },
    )
|
||||||
|
|
||||||
|
async def _handle_llm_end(self, data: Dict, name: str) -> StreamEvent:
    """
    Handle an ``on_chat_model_end`` event.

    Flushes the token buffer into the final response text and reports token
    usage when the model output exposes ``usage_metadata``.

    Args:
        data: Event payload; the ``output`` key is inspected for usage data.
        name: Event name (not read here).

    Returns:
        A THINKING_END event whose ``response`` is capped at 2000 characters.
    """
    full_response = "".join(self._thinking_buffer)
    self._thinking_buffer = []

    usage = {}
    output = data.get("output")
    if output and hasattr(output, "usage_metadata"):
        metadata = output.usage_metadata
        # usage_metadata may be a plain dict rather than an object; getattr()
        # on a dict never finds the keys and would always report 0 tokens.
        is_mapping = isinstance(metadata, dict)
        usage = {
            key: (metadata.get(key, 0) if is_mapping else getattr(metadata, key, 0))
            for key in ("input_tokens", "output_tokens")
        }

    return StreamEvent(
        event_type=StreamEventType.THINKING_END,
        sequence=self._next_sequence(),
        node_name=self._current_node,
        phase=self._current_phase,
        data={
            "response": full_response[:2000],  # cap long responses
            "usage": usage,
            "message": "💡 思考完成",
        },
    )
|
||||||
|
|
||||||
|
async def _handle_tool_start(self, tool_name: str, data: Dict) -> StreamEvent:
    """
    Handle an ``on_tool_start`` event.

    Records the start time and input of the call so the matching end event
    can report a duration, then announces the call.

    Args:
        tool_name: Name of the tool being invoked.
        data: Event payload; the ``input`` key holds the tool arguments.

    Returns:
        A TOOL_CALL_START event with the (truncated) input.
    """
    import time

    arguments = data.get("input", {})

    # NOTE(review): state is keyed by tool name only, so a second call to the
    # same tool before the first one ends overwrites this entry.
    self._tool_states[tool_name] = {"start_time": time.time(), "input": arguments}

    payload = {
        "tool_name": tool_name,
        "input": self._truncate_data(arguments),
        "message": f"🔧 调用工具: {tool_name}",
    }
    return StreamEvent(
        event_type=StreamEventType.TOOL_CALL_START,
        sequence=self._next_sequence(),
        node_name=self._current_node,
        phase=self._current_phase,
        tool_name=tool_name,
        data=payload,
    )
|
||||||
|
|
||||||
|
async def _handle_tool_end(self, tool_name: str, data: Dict) -> StreamEvent:
    """
    Handle an ``on_tool_end`` event.

    Computes how long the call took from the state recorded at start (0 when
    no start was recorded), forgets that state, and reports the output.

    Args:
        tool_name: Name of the tool that finished.
        data: Event payload; the ``output`` key holds the result. If the
            result has a ``content`` attribute, that attribute is reported.

    Returns:
        A TOOL_CALL_END event with the (truncated) output and duration.
    """
    import time

    if tool_name in self._tool_states:
        started_at = self._tool_states[tool_name].get("start_time", time.time())
        elapsed_ms = int((time.time() - started_at) * 1000)
        del self._tool_states[tool_name]
    else:
        elapsed_ms = 0

    result = data.get("output", "")
    if hasattr(result, "content"):
        # Message-like results are unwrapped to their content.
        result = result.content

    payload = {
        "tool_name": tool_name,
        "output": self._truncate_data(result),
        "duration_ms": elapsed_ms,
        "message": f"✅ 工具 {tool_name} 完成 ({elapsed_ms}ms)",
    }
    return StreamEvent(
        event_type=StreamEventType.TOOL_CALL_END,
        sequence=self._next_sequence(),
        node_name=self._current_node,
        phase=self._current_phase,
        tool_name=tool_name,
        data=payload,
    )
|
||||||
|
|
||||||
|
async def _handle_node_start(self, node_name: str, data: Dict) -> StreamEvent:
    """
    Handle the start of a graph node.

    Makes the node the current one and, when its name contains a known
    keyword, switches the current phase accordingly. The phase is left
    untouched when no keyword matches.

    Args:
        node_name: Name of the node that started.
        data: Event payload (not read here).

    Returns:
        A NODE_START event for the node.
    """
    self._current_node = node_name

    # Keyword -> phase, checked in this order; the first hit wins.
    keyword_phases = (
        ("recon", "reconnaissance"),
        ("analysis", "analysis"),
        ("verification", "verification"),
        ("report", "reporting"),
    )
    lowered = node_name.lower()
    matched_phase = next(
        (phase for keyword, phase in keyword_phases if keyword in lowered),
        None,
    )
    if matched_phase is not None:
        self._current_phase = matched_phase

    return StreamEvent(
        event_type=StreamEventType.NODE_START,
        sequence=self._next_sequence(),
        node_name=node_name,
        phase=self._current_phase,
        data={
            "node": node_name,
            "phase": self._current_phase,
            "message": f"▶️ 开始节点: {node_name}",
        },
    )
|
||||||
|
|
||||||
|
async def _handle_node_end(self, node_name: str, data: Dict) -> StreamEvent:
    """
    Handle the end of a graph node.

    Builds a small summary of item counts from the node's output dict and
    reports it together with the node name and current phase.

    Args:
        node_name: Name of the node that finished.
        data: Event payload; the ``output`` key is summarised when it is a
            dict.

    Returns:
        A NODE_END event whose ``summary`` holds a count for each known
        output key that is present.
    """
    output = data.get("output", {})

    # Output key -> summary key for every collection worth counting.
    counted_fields = (
        ("findings", "findings_count"),
        ("entry_points", "entry_points_count"),
        ("high_risk_areas", "high_risk_areas_count"),
        ("verified_findings", "verified_count"),
    )

    summary = {}
    if isinstance(output, dict):
        for source_key, summary_key in counted_fields:
            if source_key not in output:
                continue
            value = output[source_key]
            # A key that is present but None (or otherwise unsized) counts as
            # 0 instead of raising from len() and aborting the stream.
            summary[summary_key] = len(value) if hasattr(value, "__len__") else 0

    return StreamEvent(
        event_type=StreamEventType.NODE_END,
        sequence=self._next_sequence(),
        node_name=node_name,
        phase=self._current_phase,
        data={
            "node": node_name,
            "phase": self._current_phase,
            "summary": summary,
            "message": f"⏹️ 节点完成: {node_name}",
        },
    )
|
||||||
|
|
||||||
|
async def _handle_custom_event(self, event_name: str, data: Dict) -> StreamEvent:
    """
    Handle an ``on_custom_event`` event.

    The custom event name selects the stream event type; names that are not
    recognised are reported as INFO. The payload is forwarded unchanged.

    Args:
        event_name: Custom event name.
        data: Event payload, passed through as the event data.

    Returns:
        A StreamEvent of the mapped type.
    """
    known_types = {
        "finding": StreamEventType.FINDING_NEW,
        "finding_verified": StreamEventType.FINDING_VERIFIED,
        "progress": StreamEventType.PROGRESS,
        "warning": StreamEventType.WARNING,
        "error": StreamEventType.ERROR,
    }

    if event_name in known_types:
        resolved_type = known_types[event_name]
    else:
        resolved_type = StreamEventType.INFO

    return StreamEvent(
        event_type=resolved_type,
        sequence=self._next_sequence(),
        node_name=self._current_node,
        phase=self._current_phase,
        data=data,
    )
|
||||||
|
|
||||||
|
def _truncate_data(self, data: Any, max_length: int = 1000) -> Any:
    """
    Shrink arbitrary data so it is safe to embed in a stream event.

    - Strings longer than ``max_length`` are cut and suffixed with "...".
    - Dicts keep their first 10 items; each value gets half the budget.
    - Lists keep their first 10 items; the budget is split evenly between
      the items that are kept.
    - Anything else is converted with ``str()`` and cut to ``max_length``.

    Args:
        data: Value to truncate.
        max_length: Character budget for this value.

    Returns:
        The truncated value (str, dict or list).
    """
    if isinstance(data, str):
        return data[:max_length] + "..." if len(data) > max_length else data
    elif isinstance(data, dict):
        return {k: self._truncate_data(v, max_length // 2) for k, v in list(data.items())[:10]}
    elif isinstance(data, list):
        kept = data[:10]
        if not kept:
            return []
        # Divide the budget by the number of items actually kept. Dividing by
        # len(data) gave a budget of 0 for long lists, which reduced every
        # string item to just "...".
        per_item = max_length // len(kept)
        return [self._truncate_data(item, per_item) for item in kept]
    else:
        return str(data)[:max_length]
|
||||||
|
|
||||||
|
def create_progress_event(
    self,
    current: int,
    total: int,
    message: Optional[str] = None,
) -> StreamEvent:
    """
    Build a PROGRESS event.

    Args:
        current: Number of completed units.
        total: Total number of units; a non-positive total yields 0%.
        message: Optional text; a default "current/total" message is used
            when it is empty or None.

    Returns:
        A PROGRESS event with the percentage rounded to one decimal.
    """
    if total > 0:
        percentage = current / total * 100
    else:
        percentage = 0

    label = message or f"进度: {current}/{total}"

    return StreamEvent(
        event_type=StreamEventType.PROGRESS,
        sequence=self._next_sequence(),
        node_name=self._current_node,
        phase=self._current_phase,
        data={
            "current": current,
            "total": total,
            "percentage": round(percentage, 1),
            "message": label,
        },
    )
|
||||||
|
|
||||||
|
def create_finding_event(
    self,
    finding: Dict[str, Any],
    is_verified: bool = False,
) -> StreamEvent:
    """
    Build an event announcing a finding.

    Args:
        finding: Finding dict; ``title``, ``severity``, ``vulnerability_type``,
            ``file_path`` and ``line_start`` are read, with defaults for the
            first three.
        is_verified: Selects FINDING_VERIFIED instead of FINDING_NEW and the
            matching message prefix.

    Returns:
        A FINDING_NEW or FINDING_VERIFIED event.
    """
    if is_verified:
        event_type = StreamEventType.FINDING_VERIFIED
    else:
        event_type = StreamEventType.FINDING_NEW

    payload = {
        "title": finding.get("title", "Unknown"),
        "severity": finding.get("severity", "medium"),
        "vulnerability_type": finding.get("vulnerability_type", "other"),
        "file_path": finding.get("file_path"),
        "line_start": finding.get("line_start"),
        "is_verified": is_verified,
        "message": f"{'✅ 已验证' if is_verified else '🔍 新发现'}: [{finding.get('severity', 'medium').upper()}] {finding.get('title', 'Unknown')}",
    }

    return StreamEvent(
        event_type=event_type,
        sequence=self._next_sequence(),
        node_name=self._current_node,
        phase=self._current_phase,
        data=payload,
    )
|
||||||
|
|
||||||
|
def create_heartbeat(self) -> StreamEvent:
    """
    Build a HEARTBEAT event.

    The current sequence number is reused as-is: heartbeats do not advance
    the sequence counter.
    """
    current_sequence = self._sequence
    return StreamEvent(
        event_type=StreamEventType.HEARTBEAT,
        sequence=current_sequence,
        data={"message": "ping"},
    )
|
||||||
|
|
||||||
|
|
@ -0,0 +1,261 @@
|
||||||
|
"""
|
||||||
|
LLM Token 流式输出处理器
|
||||||
|
支持多种 LLM 提供商的流式输出
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from typing import Any, Dict, Optional, AsyncGenerator, Callable
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class TokenChunk:
    """One streamed piece of model output plus running totals."""

    # Text carried by this chunk (may be empty).
    content: str
    # Number of tokens attributed to this chunk.
    token_count: int = 1
    # Provider finish reason, set on the chunk that ends the stream.
    finish_reason: Optional[str] = None
    # Model identifier that produced the chunk.
    model: Optional[str] = None

    # Running statistics at the time the chunk was produced.
    accumulated_content: str = ""
    total_tokens: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
class TokenStreamer:
    """
    Streaming wrapper around LiteLLM chat completions.

    The streamer keeps the text accumulated so far and a running counter of
    non-empty content chunks (used as an approximate token count). Both
    persist across calls until :meth:`reset` is invoked. Streaming stops at
    the next chunk after :meth:`cancel` has been called.
    """

    def __init__(
        self,
        model: str,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        on_token: Optional[Callable[[TokenChunk], None]] = None,
    ):
        """
        Args:
            model: Model identifier passed to LiteLLM.
            api_key: Optional API key passed to LiteLLM.
            base_url: Optional base URL passed to LiteLLM.
            on_token: Optional synchronous callback invoked with every
                TokenChunk produced by :meth:`stream_completion`.
        """
        self.model = model
        self.api_key = api_key
        self.base_url = base_url
        self.on_token = on_token

        self._cancelled = False
        self._accumulated_content = ""
        self._total_tokens = 0

    def cancel(self):
        """Request that any running stream stop at its next chunk."""
        self._cancelled = True

    async def stream_completion(
        self,
        messages: list[Dict[str, str]],
        temperature: float = 0.1,
        max_tokens: int = 4096,
    ) -> AsyncGenerator[TokenChunk, None]:
        """
        Stream a chat completion chunk by chunk.

        Args:
            messages: Chat messages.
            temperature: Sampling temperature.
            max_tokens: Maximum number of tokens to generate.

        Yields:
            TokenChunk: One chunk per provider chunk, including chunks with
            empty content (for example the one carrying the finish reason).

        Raises:
            asyncio.CancelledError: Re-raised after logging.
            Exception: Any provider error is logged and re-raised.
        """
        try:
            import litellm

            response = await litellm.acompletion(
                model=self.model,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
                api_key=self.api_key,
                base_url=self.base_url,
                stream=True,  # enable streaming
            )

            async for chunk in response:
                if self._cancelled:
                    break

                content = ""
                finish_reason = None

                if hasattr(chunk, "choices") and chunk.choices:
                    choice = chunk.choices[0]
                    if hasattr(choice, "delta") and choice.delta:
                        content = getattr(choice.delta, "content", "") or ""
                    finish_reason = getattr(choice, "finish_reason", None)

                if content:
                    self._accumulated_content += content
                    self._total_tokens += 1

                token_chunk = TokenChunk(
                    content=content,
                    token_count=1,
                    finish_reason=finish_reason,
                    model=self.model,
                    accumulated_content=self._accumulated_content,
                    total_tokens=self._total_tokens,
                )

                if self.on_token:
                    self.on_token(token_chunk)

                yield token_chunk

                # Stop once the provider reports completion.
                if finish_reason:
                    break

        except asyncio.CancelledError:
            logger.info("Token streaming cancelled")
            raise

        except Exception as e:
            logger.error(f"Token streaming error: {e}")
            raise

    async def stream_with_tools(
        self,
        messages: list[Dict[str, str]],
        tools: list[Dict[str, Any]],
        temperature: float = 0.1,
        max_tokens: int = 4096,
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """
        Stream a chat completion that may contain tool calls.

        Args:
            messages: Chat messages.
            tools: Tool definitions.
            temperature: Sampling temperature.
            max_tokens: Maximum number of tokens to generate.

        Yields:
            Dicts whose ``type`` is one of:
            - "token": a text fragment plus running totals
            - "tool_call_chunk": snapshot of a tool call assembled so far
            - "tool_call_complete": final tool call, emitted when the finish
              reason is "tool_calls"
            - "finish": the finish reason plus running totals
            - "error": emitted instead of raising when a provider error occurs

        Raises:
            asyncio.CancelledError: Re-raised after logging.
        """
        try:
            import litellm

            response = await litellm.acompletion(
                model=self.model,
                messages=messages,
                tools=tools,
                temperature=temperature,
                max_tokens=max_tokens,
                api_key=self.api_key,
                base_url=self.base_url,
                stream=True,
            )

            # Tool calls arrive in fragments; assemble them per index.
            tool_calls_accumulator: Dict[int, Dict] = {}

            async for chunk in response:
                if self._cancelled:
                    break

                if not hasattr(chunk, "choices") or not chunk.choices:
                    continue

                choice = chunk.choices[0]
                delta = getattr(choice, "delta", None)
                finish_reason = getattr(choice, "finish_reason", None)

                if delta:
                    # Text content
                    content = getattr(delta, "content", "") or ""
                    if content:
                        self._accumulated_content += content
                        self._total_tokens += 1

                        yield {
                            "type": "token",
                            "content": content,
                            "accumulated": self._accumulated_content,
                            "total_tokens": self._total_tokens,
                        }

                    # Tool-call fragments
                    tool_calls = getattr(delta, "tool_calls", None) or []
                    for tool_call in tool_calls:
                        idx = tool_call.index

                        entry = tool_calls_accumulator.setdefault(
                            idx, {"id": "", "name": "", "arguments": ""}
                        )

                        # Keep the first non-empty id, even if it arrives
                        # after the first fragment for this index.
                        if tool_call.id and not entry["id"]:
                            entry["id"] = tool_call.id

                        if tool_call.function:
                            if tool_call.function.name:
                                entry["name"] = tool_call.function.name
                            if tool_call.function.arguments:
                                entry["arguments"] += tool_call.function.arguments

                        # Yield a copy: the accumulator entry keeps mutating,
                        # and handing out the same dict would make every
                        # previously yielded chunk show the final state.
                        yield {
                            "type": "tool_call_chunk",
                            "index": idx,
                            "tool_call": dict(entry),
                        }

                # Emit the assembled tool calls once the model is done.
                if finish_reason == "tool_calls":
                    for idx, tool_call in tool_calls_accumulator.items():
                        yield {
                            "type": "tool_call_complete",
                            "index": idx,
                            "tool_call": dict(tool_call),
                        }

                if finish_reason:
                    yield {
                        "type": "finish",
                        "reason": finish_reason,
                        "accumulated": self._accumulated_content,
                        "total_tokens": self._total_tokens,
                    }
                    break

        except asyncio.CancelledError:
            logger.info("Tool streaming cancelled")
            raise

        except Exception as e:
            logger.error(f"Tool streaming error: {e}")
            yield {
                "type": "error",
                "error": str(e),
            }

    def get_accumulated_content(self) -> str:
        """Return all text accumulated since the last reset."""
        return self._accumulated_content

    def get_total_tokens(self) -> int:
        """Return the running count of non-empty content chunks."""
        return self._total_tokens

    def reset(self):
        """Clear the cancel flag, the accumulated text and the counter."""
        self._cancelled = False
        self._accumulated_content = ""
        self._total_tokens = 0
|
||||||
|
|
||||||
|
|
@ -0,0 +1,319 @@
|
||||||
|
"""
|
||||||
|
工具调用流式处理器
|
||||||
|
展示工具调用的输入、执行过程和输出
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
import logging
|
||||||
|
from typing import Any, Dict, Optional, AsyncGenerator, List, Callable
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ToolCallState(str, Enum):
    """Lifecycle state of a tool call."""

    PENDING = "pending"  # waiting to run
    RUNNING = "running"  # currently executing
    SUCCESS = "success"  # finished normally
    ERROR = "error"      # finished with an error
    TIMEOUT = "timeout"  # exceeded its time limit
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ToolCallEvent:
    """Snapshot of one tool call: identity, state, I/O and timing."""

    tool_name: str
    state: ToolCallState

    # Input / output
    input_params: Dict[str, Any] = field(default_factory=dict)
    output_data: Optional[Any] = None
    error_message: Optional[str] = None

    # Timing (start/end are time.time() values set by the handler)
    start_time: Optional[float] = None
    end_time: Optional[float] = None
    duration_ms: int = 0

    # Metadata
    call_id: Optional[str] = None
    sequence: int = 0
    timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialise the event to a plain dict.

        Input and output are truncated; start/end times are not included.
        """
        serialised: Dict[str, Any] = {}
        serialised["tool_name"] = self.tool_name
        serialised["state"] = self.state.value
        serialised["input_params"] = self._truncate(self.input_params)
        serialised["output_data"] = self._truncate(self.output_data)
        serialised["error_message"] = self.error_message
        serialised["duration_ms"] = self.duration_ms
        serialised["call_id"] = self.call_id
        serialised["sequence"] = self.sequence
        serialised["timestamp"] = self.timestamp
        return serialised

    def _truncate(self, data: Any, max_length: int = 500) -> Any:
        """
        Shrink a value for serialisation.

        None passes through. Dicts keep their first 20 items at half the
        budget each. Lists keep their first 20 items and share the budget
        between them. Everything else is rendered as text and cut to
        ``max_length`` with a "..." suffix when it was longer.
        """
        if data is None:
            return None

        if isinstance(data, dict):
            child_budget = max_length // 2
            return {
                key: self._truncate(value, child_budget)
                for key, value in list(data.items())[:20]
            }

        if isinstance(data, list):
            max_items = min(20, len(data))
            return [self._truncate(item, max_length // max_items) for item in data[:max_items]]

        text = data if isinstance(data, str) else str(data)
        if len(text) > max_length:
            return text[:max_length] + "..."
        return text
|
||||||
|
|
||||||
|
|
||||||
|
class ToolStreamHandler:
    """
    Tracks tool invocations and reports their lifecycle as ToolCallEvent objects.

    Calls in flight are kept in ``_active_calls`` (keyed by call id); finished
    calls are moved to ``_history``. Every state change is passed to the
    optional synchronous ``on_event`` callback.

    NOTE(review): ``_history`` grows without bound until :meth:`clear` is
    called.
    """

    def __init__(
        self,
        on_event: Optional[Callable[[ToolCallEvent], None]] = None,
    ):
        """
        Args:
            on_event: Optional callback invoked with each emitted event.
        """
        self.on_event = on_event
        self._sequence = 0
        self._active_calls: Dict[str, ToolCallEvent] = {}
        self._history: List[ToolCallEvent] = []

    def _next_sequence(self) -> int:
        """Advance and return the event sequence counter."""
        self._sequence += 1
        return self._sequence

    def _generate_call_id(self) -> str:
        """Return a short call id: the first 8 characters of a random UUID."""
        import uuid
        return str(uuid.uuid4())[:8]

    async def emit_tool_start(
        self,
        tool_name: str,
        input_params: Dict[str, Any],
        call_id: Optional[str] = None,
    ) -> ToolCallEvent:
        """
        Register a tool call as running and emit its start event.

        Args:
            tool_name: Name of the tool.
            input_params: Arguments the tool is called with.
            call_id: Call id to use; one is generated when omitted.

        Returns:
            The RUNNING event that was registered.
        """
        call_id = call_id or self._generate_call_id()

        event = ToolCallEvent(
            tool_name=tool_name,
            state=ToolCallState.RUNNING,
            input_params=input_params,
            start_time=time.time(),
            call_id=call_id,
            sequence=self._next_sequence(),
        )

        self._active_calls[call_id] = event

        if self.on_event:
            self.on_event(event)

        return event

    async def emit_tool_end(
        self,
        call_id: str,
        output_data: Any,
        is_error: bool = False,
        error_message: Optional[str] = None,
    ) -> Optional[ToolCallEvent]:
        """
        Mark an active call as finished and emit its end event.

        Args:
            call_id: Id of the call to finish.
            output_data: Result of the call.
            is_error: Whether the call failed.
            error_message: Failure description; falls back to
                ``str(output_data)`` when the call failed and this is empty.

        Returns:
            The updated event, or None when ``call_id`` is not active.
        """
        if call_id not in self._active_calls:
            logger.warning(f"Unknown tool call: {call_id}")
            return None

        event = self._active_calls[call_id]
        event.end_time = time.time()
        event.duration_ms = int((event.end_time - event.start_time) * 1000) if event.start_time else 0
        event.output_data = output_data
        event.sequence = self._next_sequence()

        if is_error:
            event.state = ToolCallState.ERROR
            event.error_message = error_message or str(output_data)
        else:
            event.state = ToolCallState.SUCCESS

        # Move the call from the active set to the history.
        del self._active_calls[call_id]
        self._history.append(event)

        if self.on_event:
            self.on_event(event)

        return event

    async def emit_tool_timeout(self, call_id: str, timeout_seconds: int) -> Optional[ToolCallEvent]:
        """
        Mark an active call as timed out and emit the corresponding event.

        Args:
            call_id: Id of the call that timed out.
            timeout_seconds: Time limit that was exceeded, used in the message.

        Returns:
            The updated event, or None when ``call_id`` is not active.
        """
        if call_id not in self._active_calls:
            return None

        event = self._active_calls[call_id]
        event.end_time = time.time()
        event.duration_ms = int((event.end_time - event.start_time) * 1000) if event.start_time else 0
        event.state = ToolCallState.TIMEOUT
        event.error_message = f"Tool execution timed out after {timeout_seconds}s"
        event.sequence = self._next_sequence()

        del self._active_calls[call_id]
        self._history.append(event)

        if self.on_event:
            self.on_event(event)

        return event

    def wrap_tool(
        self,
        tool_func: Callable,
        tool_name: str,
        timeout: Optional[int] = None,
    ) -> Callable:
        """
        Wrap a tool so that every invocation is tracked automatically.

        Args:
            tool_func: Sync or async tool function.
            tool_name: Name reported in the events.
            timeout: Optional time limit in seconds.

        Returns:
            An async function with the same call signature that emits start,
            end, timeout and error events around the call. The wrapped
            function's name and docstring are preserved.
        """
        import functools

        @functools.wraps(tool_func)
        async def wrapped(*args, **kwargs):
            call_id = self._generate_call_id()

            await self.emit_tool_start(
                tool_name=tool_name,
                input_params={"args": args, "kwargs": kwargs},
                call_id=call_id,
            )

            try:
                if asyncio.iscoroutinefunction(tool_func):
                    if timeout:
                        result = await asyncio.wait_for(
                            tool_func(*args, **kwargs),
                            timeout=timeout,
                        )
                    else:
                        result = await tool_func(*args, **kwargs)
                else:
                    if timeout:
                        result = await asyncio.wait_for(
                            asyncio.to_thread(tool_func, *args, **kwargs),
                            timeout=timeout,
                        )
                    else:
                        # NOTE(review): runs inline, so a slow sync tool
                        # blocks the event loop for its whole duration.
                        result = tool_func(*args, **kwargs)

                await self.emit_tool_end(call_id, result)

                return result

            except asyncio.TimeoutError:
                await self.emit_tool_timeout(call_id, timeout or 0)
                raise

            except asyncio.CancelledError:
                # CancelledError is not an Exception subclass, so the generic
                # handler below never sees it. Close the call here, otherwise
                # it would stay in _active_calls forever.
                await self.emit_tool_end(
                    call_id, None, is_error=True, error_message="Tool execution cancelled"
                )
                raise

            except Exception as e:
                await self.emit_tool_end(call_id, None, is_error=True, error_message=str(e))
                raise

        return wrapped

    def get_active_calls(self) -> List[ToolCallEvent]:
        """Return the calls that are currently in flight."""
        return list(self._active_calls.values())

    def get_history(self, limit: int = 100) -> List[ToolCallEvent]:
        """
        Return the most recent finished calls, oldest first.

        Args:
            limit: Maximum number of entries; a non-positive limit yields an
                empty list.
        """
        if limit <= 0:
            # self._history[-0:] would be the whole history, not "nothing".
            return []
        return self._history[-limit:]

    def get_stats(self) -> Dict[str, Any]:
        """
        Summarise the finished calls.

        Returns:
            Totals, success rate, durations, the number of active calls and a
            per-tool breakdown (timeouts count as errors in the breakdown).
        """
        total_calls = len(self._history)
        success_calls = sum(1 for e in self._history if e.state == ToolCallState.SUCCESS)
        error_calls = sum(1 for e in self._history if e.state == ToolCallState.ERROR)
        timeout_calls = sum(1 for e in self._history if e.state == ToolCallState.TIMEOUT)

        total_duration = sum(e.duration_ms for e in self._history)
        avg_duration = total_duration / total_calls if total_calls > 0 else 0

        # Per-tool breakdown
        tool_stats = {}
        for event in self._history:
            if event.tool_name not in tool_stats:
                tool_stats[event.tool_name] = {
                    "calls": 0,
                    "success": 0,
                    "errors": 0,
                    "total_duration_ms": 0,
                }
            tool_stats[event.tool_name]["calls"] += 1
            if event.state == ToolCallState.SUCCESS:
                tool_stats[event.tool_name]["success"] += 1
            elif event.state in [ToolCallState.ERROR, ToolCallState.TIMEOUT]:
                tool_stats[event.tool_name]["errors"] += 1
            tool_stats[event.tool_name]["total_duration_ms"] += event.duration_ms

        return {
            "total_calls": total_calls,
            "success_calls": success_calls,
            "error_calls": error_calls,
            "timeout_calls": timeout_calls,
            "success_rate": success_calls / total_calls if total_calls > 0 else 0,
            "total_duration_ms": total_duration,
            "avg_duration_ms": round(avg_duration, 2),
            "active_calls": len(self._active_calls),
            "by_tool": tool_stats,
        }

    def clear(self):
        """Drop all active calls and history and restart the sequence at 0."""
        self._active_calls.clear()
        self._history.clear()
        self._sequence = 0
|
||||||
|
|
||||||
|
|
@ -41,6 +41,16 @@ class PatternMatchTool(AgentTool):
|
||||||
使用正则表达式快速扫描代码中的危险模式
|
使用正则表达式快速扫描代码中的危险模式
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def __init__(self, project_root: str = None):
    """
    Set up the pattern-matching tool.

    Args:
        project_root: Optional project root directory. It is stored on the
            instance for context and not otherwise used here.
    """
    super().__init__()
    self.project_root = project_root
|
||||||
|
|
||||||
# 危险模式定义
|
# 危险模式定义
|
||||||
PATTERNS: Dict[str, Dict[str, Any]] = {
|
PATTERNS: Dict[str, Dict[str, Any]] = {
|
||||||
# SQL 注入模式
|
# SQL 注入模式
|
||||||
|
|
|
||||||
|
|
@ -145,3 +145,5 @@ class BaiduAdapter(BaseLLMAdapter):
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -83,3 +83,5 @@ class DoubaoAdapter(BaseLLMAdapter):
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -86,3 +86,5 @@ class MinimaxAdapter(BaseLLMAdapter):
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -134,3 +134,5 @@ class BaseLLMAdapter(ABC):
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -120,3 +120,5 @@ DEFAULT_BASE_URLS: Dict[LLMProvider, str] = {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,5 @@
|
||||||
|
"""
|
||||||
|
DeepAudit Agent 测试套件
|
||||||
|
企业级测试框架,覆盖工具、Agent、节点和完整流程
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
@ -0,0 +1,295 @@
|
||||||
|
"""
|
||||||
|
Agent 测试配置和 Fixtures
|
||||||
|
提供测试所需的公共设施
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import asyncio
|
||||||
|
import tempfile
|
||||||
|
import shutil
|
||||||
|
import os
|
||||||
|
from typing import Dict, Any, Optional
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
||||||
|
# ============ 测试配置 ============
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
def event_loop():
    """Provide one event loop for the whole test session; close it at teardown."""
    session_loop = asyncio.new_event_loop()
    yield session_loop
    session_loop.close()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def temp_project_dir():
    """Create a temporary project tree populated with sample source files.

    The tree contains a ``src`` directory with deliberately vulnerable
    snippets (SQL injection, command injection, XSS, path traversal and
    hard-coded secrets) plus one safe module, a ``config`` directory with a
    settings module, and a ``requirements.txt``.

    Yields:
        The absolute path of the temporary directory. The directory is
        removed again once the consuming test has finished.
    """
    temp_dir = tempfile.mkdtemp(prefix="deepaudit_test_")

    # try/finally guarantees the directory is removed even when populating
    # it fails before the fixture ever reaches the yield.
    try:
        # Project layout.
        os.makedirs(os.path.join(temp_dir, "src"), exist_ok=True)
        os.makedirs(os.path.join(temp_dir, "config"), exist_ok=True)

        # Vulnerable sample: SQL injection.
        sql_vuln_code = '''
import sqlite3

def get_user(user_id):
    """危险:SQL 注入漏洞"""
    conn = sqlite3.connect("users.db")
    cursor = conn.cursor()
    # 直接拼接用户输入,存在 SQL 注入风险
    query = f"SELECT * FROM users WHERE id = '{user_id}'"
    cursor.execute(query)
    return cursor.fetchone()

def search_users(name):
    """危险:SQL 注入漏洞"""
    conn = sqlite3.connect("users.db")
    cursor = conn.cursor()
    cursor.execute("SELECT * FROM users WHERE name LIKE '%" + name + "%'")
    return cursor.fetchall()
'''

        # Vulnerable sample: command injection.
        cmd_vuln_code = '''
import os
import subprocess

def run_command(user_input):
    """危险:命令注入漏洞"""
    # 直接执行用户输入
    os.system(f"echo {user_input}")

def execute_script(script_name):
    """危险:命令注入漏洞"""
    subprocess.call(f"bash {script_name}", shell=True)
'''

        # Vulnerable sample: XSS.
        xss_vuln_code = '''
from flask import Flask, request, render_template_string

app = Flask(__name__)

@app.route("/greet")
def greet():
    """危险:XSS 漏洞"""
    name = request.args.get("name", "")
    # 直接将用户输入嵌入 HTML,存在 XSS 风险
    return f"<h1>Hello, {name}!</h1>"

@app.route("/search")
def search():
    """危险:XSS 漏洞"""
    query = request.args.get("q", "")
    html = f"<p>搜索结果: {query}</p>"
    return render_template_string(html)
'''

        # Vulnerable sample: path traversal.
        path_vuln_code = '''
import os

def read_file(filename):
    """危险:路径遍历漏洞"""
    # 没有验证文件路径
    filepath = os.path.join("/app/data", filename)
    with open(filepath, "r") as f:
        return f.read()

def download_file(user_path):
    """危险:路径遍历漏洞"""
    # 直接使用用户输入作为文件路径
    with open(user_path, "rb") as f:
        return f.read()
'''

        # Vulnerable sample: hard-coded secrets.
        secret_vuln_code = '''
# 配置文件
DATABASE_URL = "postgresql://user:password123@localhost/db"
API_KEY = "sk-1234567890abcdef1234567890abcdef"
SECRET_KEY = "super_secret_key_dont_share"
AWS_SECRET_ACCESS_KEY = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"

def connect_database():
    password = "admin123" # 硬编码密码
    return f"mysql://root:{password}@localhost/mydb"
'''

        # Safe sample, kept for comparison with the vulnerable ones.
        safe_code = '''
import sqlite3
from typing import Optional

def get_user_safe(user_id: int) -> Optional[dict]:
    """安全:使用参数化查询"""
    conn = sqlite3.connect("users.db")
    cursor = conn.cursor()
    cursor.execute("SELECT * FROM users WHERE id = ?", (user_id,))
    return cursor.fetchone()

def validate_input(user_input: str) -> str:
    """输入验证"""
    import re
    if not re.match(r'^[a-zA-Z0-9_]+$', user_input):
        raise ValueError("Invalid input")
    return user_input
'''

        # Settings module.
        config_code = '''
import os

class Config:
    """安全配置"""
    DATABASE_URL = os.environ.get("DATABASE_URL")
    SECRET_KEY = os.environ.get("SECRET_KEY")
    DEBUG = False
'''

        # requirements.txt
        requirements = '''
flask>=2.0.0
sqlalchemy>=2.0.0
requests>=2.28.0
'''

        # Relative path components -> file content, written in this order.
        project_files = {
            ("src", "sql_vuln.py"): sql_vuln_code,
            ("src", "cmd_vuln.py"): cmd_vuln_code,
            ("src", "xss_vuln.py"): xss_vuln_code,
            ("src", "path_vuln.py"): path_vuln_code,
            ("src", "secrets.py"): secret_vuln_code,
            ("src", "safe_code.py"): safe_code,
            ("config", "settings.py"): config_code,
            ("requirements.txt",): requirements,
        }
        for path_parts, content in project_files.items():
            # The samples contain non-ASCII text, so the encoding is pinned
            # instead of relying on the platform's locale default.
            with open(os.path.join(temp_dir, *path_parts), "w", encoding="utf-8") as f:
                f.write(content)

        yield temp_dir
    finally:
        # Best-effort cleanup; errors while deleting are ignored.
        shutil.rmtree(temp_dir, ignore_errors=True)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def mock_llm_service():
    """Provide a mock LLM service whose ``chat_completion_raw`` is an async stub.

    Returns:
        A ``MagicMock`` whose ``chat_completion_raw`` coroutine resolves to a
        fixed response dict with ``content`` and ``usage`` keys.
    """
    canned_response = {
        "content": "测试响应",
        "usage": {"total_tokens": 100},
    }
    llm = MagicMock()
    llm.chat_completion_raw = AsyncMock(return_value=canned_response)
    return llm
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def mock_event_emitter():
    """Provide a mock event emitter whose emit methods are all awaitable.

    Returns:
        A ``MagicMock`` on which every listed ``emit*`` attribute has been
        replaced by its own ``AsyncMock``.
    """
    awaitable_methods = (
        "emit_info",
        "emit_warning",
        "emit_error",
        "emit_thinking",
        "emit_tool_call",
        "emit_tool_result",
        "emit_finding",
        "emit_progress",
        "emit_phase_start",
        "emit_phase_complete",
        "emit_task_complete",
        "emit",
    )
    stub = MagicMock()
    for method_name in awaitable_methods:
        setattr(stub, method_name, AsyncMock())
    return stub
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def mock_db_session():
    """Provide a mock database session.

    Returns:
        An ``AsyncMock`` whose ``add`` is a plain ``MagicMock`` and whose
        ``commit``, ``rollback``, ``get`` and ``execute`` are ``AsyncMock``
        instances; ``get`` resolves to ``None``.
    """
    db = AsyncMock()
    db.configure_mock(
        add=MagicMock(),
        commit=AsyncMock(),
        rollback=AsyncMock(),
        get=AsyncMock(return_value=None),
        execute=AsyncMock(),
    )
    return db
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class MockProject:
    """Lightweight stand-in for a project record, with fixed default values."""

    # Identifier used by default for the fake project.
    id: str = "test-project-id"
    # Display name of the fake project.
    name: str = "Test Project"
    # Free-text description of the fake project.
    description: str = "Test project for unit tests"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class MockAgentTask:
    """Lightweight stand-in for an agent task record.

    Fields that default to ``None`` are filled in after construction: the
    project becomes a fresh ``MockProject`` and each list field becomes its
    own empty list, so instances never share mutable defaults.
    """

    id: str = "test-task-id"
    project_id: str = "test-project-id"
    project: MockProject = None
    name: str = "Test Agent Task"
    status: str = "pending"
    current_phase: str = "planning"
    target_vulnerabilities: list = None
    verification_level: str = "sandbox"
    exclude_patterns: list = None
    target_files: list = None
    max_iterations: int = 50
    timeout_seconds: int = 1800

    def __post_init__(self):
        """Replace ``None`` placeholders with per-instance default objects."""
        if self.project is None:
            self.project = MockProject()
        for list_field in ("target_vulnerabilities", "exclude_patterns", "target_files"):
            if getattr(self, list_field) is None:
                setattr(self, list_field, [])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def mock_task():
    """Provide a ``MockAgentTask`` built entirely from its default values."""
    task = MockAgentTask()
    return task
|
||||||
|
|
||||||
|
|
||||||
|
# ============ 测试辅助函数 ============
|
||||||
|
|
||||||
|
def assert_finding_valid(finding: Dict[str, Any]):
    """Assert that a finding dict has the expected shape.

    Args:
        finding: The finding to validate.

    Raises:
        AssertionError: If ``title``, ``severity`` or ``vulnerability_type``
            is missing, or if ``severity`` is not one of the accepted levels.
    """
    for field in ("title", "severity", "vulnerability_type"):
        assert field in finding, f"Missing required field: {field}"

    accepted_levels = ("critical", "high", "medium", "low", "info")
    assert finding["severity"] in accepted_levels, f"Invalid severity: {finding['severity']}"
|
||||||
|
|
||||||
|
|
||||||
|
def count_findings_by_type(findings: list, vuln_type: str) -> int:
    """Return how many findings have ``vulnerability_type`` equal to *vuln_type*."""
    matching = [item for item in findings if item.get("vulnerability_type") == vuln_type]
    return len(matching)
|
||||||
|
|
||||||
|
|
||||||
|
def count_findings_by_severity(findings: list, severity: str) -> int:
    """Return how many findings have ``severity`` equal to *severity*."""
    total = 0
    for item in findings:
        if item.get("severity") == severity:
            total += 1
    return total
|
||||||
|
|
||||||
|
|
@ -0,0 +1,96 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
"""
|
||||||
|
Agent 测试运行器
|
||||||
|
运行所有 Agent 相关测试并生成报告
|
||||||
|
"""
|
||||||
|
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
def run_tests():
    """Run the agent test suite with pytest, stopping at the first failure.

    The working directory is switched to the directory three levels above
    this file before pytest is launched as a subprocess.

    Returns:
        The exit code of the pytest subprocess.
    """
    # Resolve the project root relative to this script.
    script_path = Path(__file__)
    project_root = script_path.parent.parent.parent
    os.chdir(project_root)

    banner = "=" * 60

    print(banner)
    print("DeepAudit Agent 测试套件")
    print(f"时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print(banner)
    print()

    # pytest invocation.
    cmd = [sys.executable, "-m", "pytest"]
    cmd += [
        "tests/agent/",
        "-v",
        "--tb=short",
        "-x",  # stop at the first failure
        "--color=yes",
    ]

    print(f"运行命令: {' '.join(cmd)}")
    print()

    # Execute the suite.
    result = subprocess.run(cmd, cwd=project_root)
    exit_code = result.returncode

    print()
    print(banner)
    if exit_code != 0:
        print(f"❌ 测试失败 (退出码: {exit_code})")
    else:
        print("✅ 所有测试通过!")
    print(banner)

    return exit_code
|
||||||
|
|
||||||
|
|
||||||
|
def run_tests_with_coverage():
    """Run the agent test suite with coverage reporting enabled.

    Coverage is requested for ``app/services/agent`` with a terminal report
    of missing lines and an HTML report in ``coverage_agent``.

    Returns:
        The exit code of the pytest subprocess.
    """
    script_path = Path(__file__)
    project_root = script_path.parent.parent.parent
    os.chdir(project_root)

    coverage_flags = [
        "--cov=app/services/agent",
        "--cov-report=term-missing",
        "--cov-report=html:coverage_agent",
    ]
    cmd = [sys.executable, "-m", "pytest", "tests/agent/", "-v", *coverage_flags]

    completed = subprocess.run(cmd, cwd=project_root)
    return completed.returncode
|
||||||
|
|
||||||
|
|
||||||
|
def run_specific_test(test_name: str):
    """Run a single test target under ``tests/agent/`` with verbose output.

    Args:
        test_name: Path of the test target relative to ``tests/agent/``.

    Returns:
        The exit code of the pytest subprocess.
    """
    script_path = Path(__file__)
    project_root = script_path.parent.parent.parent
    os.chdir(project_root)

    target = f"tests/agent/{test_name}"
    cmd = [
        sys.executable, "-m", "pytest",
        target,
        "-v",
        "--tb=long",
        "-s",  # show print output
    ]

    completed = subprocess.run(cmd, cwd=project_root)
    return completed.returncode
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # No argument: full suite. "--coverage": suite with coverage.
    # Anything else is treated as a test target under tests/agent/.
    cli_args = sys.argv[1:]
    if not cli_args:
        exit_code = run_tests()
    elif cli_args[0] == "--coverage":
        exit_code = run_tests_with_coverage()
    else:
        exit_code = run_specific_test(cli_args[0])
    sys.exit(exit_code)
|
||||||
|
|
||||||
|
|
@ -0,0 +1,213 @@
|
||||||
|
"""
|
||||||
|
Agent 单元测试
|
||||||
|
测试各个 Agent 的功能
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
from unittest.mock import MagicMock, AsyncMock, patch
|
||||||
|
|
||||||
|
from app.services.agent.agents.base import BaseAgent, AgentConfig, AgentResult, AgentType, AgentPattern
|
||||||
|
from app.services.agent.agents.recon import ReconAgent
|
||||||
|
from app.services.agent.agents.analysis import AnalysisAgent
|
||||||
|
from app.services.agent.agents.verification import VerificationAgent
|
||||||
|
|
||||||
|
|
||||||
|
class TestReconAgent:
    """Tests for the Recon agent."""

    @pytest.fixture
    def recon_agent(self, temp_project_dir, mock_llm_service, mock_event_emitter):
        """Build a ReconAgent wired to file tools rooted at the temp project."""
        from app.services.agent.tools import (
            FileReadTool, FileSearchTool, ListFilesTool,
        )

        tool_classes = {
            "list_files": ListFilesTool,
            "read_file": FileReadTool,
            "search_code": FileSearchTool,
        }
        tools = {
            tool_name: tool_cls(temp_project_dir)
            for tool_name, tool_cls in tool_classes.items()
        }

        return ReconAgent(
            llm_service=mock_llm_service,
            tools=tools,
            event_emitter=mock_event_emitter,
        )

    @pytest.mark.asyncio
    async def test_recon_agent_run(self, recon_agent, temp_project_dir):
        """The agent run succeeds and returns the expected data keys."""
        agent_input = {
            "project_info": {
                "name": "Test Project",
                "root": temp_project_dir,
            },
            "config": {},
        }
        result = await recon_agent.run(agent_input)

        assert result.success is True
        assert result.data is not None

        # Check the shape of the returned data.
        data = result.data
        assert "tech_stack" in data
        assert "entry_points" in data or "high_risk_areas" in data

    @pytest.mark.asyncio
    async def test_recon_agent_identifies_python(self, recon_agent, temp_project_dir):
        """The agent reports a language list for the Python sample project."""
        agent_input = {
            "project_info": {"root": temp_project_dir},
            "config": {},
        }
        result = await recon_agent.run(agent_input)

        assert result.success is True
        languages = result.data.get("tech_stack", {}).get("languages", [])

        # Python is expected; note the check also passes for any non-empty list.
        assert "Python" in languages or len(languages) > 0

    @pytest.mark.asyncio
    async def test_recon_agent_finds_high_risk_areas(self, recon_agent, temp_project_dir):
        """The agent reports at least one high-risk area."""
        agent_input = {
            "project_info": {"root": temp_project_dir},
            "config": {},
        }
        result = await recon_agent.run(agent_input)

        assert result.success is True
        high_risk_areas = result.data.get("high_risk_areas", [])

        # At least one high-risk area should be reported.
        assert len(high_risk_areas) > 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestAnalysisAgent:
    """Tests for the Analysis agent."""

    @pytest.fixture
    def analysis_agent(self, temp_project_dir, mock_llm_service, mock_event_emitter):
        """Build an AnalysisAgent wired to tools rooted at the temp project."""
        from app.services.agent.tools import (
            FileReadTool, FileSearchTool, PatternMatchTool,
        )

        tool_classes = {
            "read_file": FileReadTool,
            "search_code": FileSearchTool,
            "pattern_match": PatternMatchTool,
        }
        tools = {
            tool_name: tool_cls(temp_project_dir)
            for tool_name, tool_cls in tool_classes.items()
        }

        return AnalysisAgent(
            llm_service=mock_llm_service,
            tools=tools,
            event_emitter=mock_event_emitter,
        )

    @pytest.mark.asyncio
    async def test_analysis_agent_run(self, analysis_agent, temp_project_dir):
        """The agent run succeeds and returns data."""
        agent_input = {
            "tech_stack": {"languages": ["Python"]},
            "entry_points": [],
            "high_risk_areas": ["src/sql_vuln.py", "src/cmd_vuln.py"],
            "config": {},
        }
        result = await analysis_agent.run(agent_input)

        assert result.success is True
        assert result.data is not None

    @pytest.mark.asyncio
    async def test_analysis_agent_finds_vulnerabilities(self, analysis_agent, temp_project_dir):
        """The agent returns its findings as a list."""
        risky_files = [
            "src/sql_vuln.py",
            "src/cmd_vuln.py",
            "src/xss_vuln.py",
            "src/secrets.py",
        ]
        agent_input = {
            "tech_stack": {"languages": ["Python"]},
            "entry_points": [],
            "high_risk_areas": risky_files,
            "config": {},
        }
        result = await analysis_agent.run(agent_input)

        assert result.success is True
        findings = result.data.get("findings", [])

        # Some vulnerabilities should be found; the exact number depends on
        # the analysis logic, so only the container type is checked.
        assert isinstance(findings, list)
|
||||||
|
|
||||||
|
|
||||||
|
class TestAgentResult:
    """Tests for the AgentResult value object."""

    def test_agent_result_success(self):
        """A successful result keeps its iteration and tool-call counts."""
        fields = {
            "success": True,
            "data": {"findings": []},
            "iterations": 5,
            "tool_calls": 10,
        }
        result = AgentResult(**fields)

        assert result.success is True
        assert result.iterations == 5
        assert result.tool_calls == 10

    def test_agent_result_failure(self):
        """A failed result keeps its error message."""
        result = AgentResult(success=False, error="Test error")

        assert result.success is False
        assert result.error == "Test error"

    def test_agent_result_to_dict(self):
        """``to_dict`` exposes the success flag and the iteration count."""
        result = AgentResult(success=True, data={"key": "value"}, iterations=3)

        as_dict = result.to_dict()

        assert as_dict["success"] is True
        assert as_dict["iterations"] == 3
|
||||||
|
|
||||||
|
|
||||||
|
class TestAgentConfig:
    """Tests for the AgentConfig value object."""

    def test_agent_config_defaults(self):
        """Unspecified options fall back to the expected defaults."""
        config = AgentConfig(name="Test", agent_type=AgentType.RECON)

        assert config.pattern == AgentPattern.REACT
        assert config.max_iterations == 20
        assert config.temperature == 0.1

    def test_agent_config_custom(self):
        """Explicitly passed options are stored unchanged."""
        overrides = {
            "pattern": AgentPattern.PLAN_AND_EXECUTE,
            "max_iterations": 50,
            "temperature": 0.5,
        }
        config = AgentConfig(
            name="Custom",
            agent_type=AgentType.ANALYSIS,
            **overrides,
        )

        assert config.pattern == AgentPattern.PLAN_AND_EXECUTE
        assert config.max_iterations == 50
        assert config.temperature == 0.5
|
||||||
|
|
||||||
|
|
@ -0,0 +1,355 @@
|
||||||
|
"""
|
||||||
|
Agent 集成测试
|
||||||
|
测试完整的审计流程
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
from unittest.mock import MagicMock, AsyncMock, patch
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from app.services.agent.graph.runner import AgentRunner, LLMService
|
||||||
|
from app.services.agent.graph.audit_graph import AuditState, create_audit_graph
|
||||||
|
from app.services.agent.graph.nodes import ReconNode, AnalysisNode, VerificationNode, ReportNode
|
||||||
|
from app.services.agent.event_manager import EventManager, AgentEventEmitter
|
||||||
|
|
||||||
|
|
||||||
|
class TestLLMService:
    """Tests for the LLM service wrapper."""

    @pytest.mark.asyncio
    async def test_llm_service_initialization(self):
        """The service picks up the model name while settings are patched."""
        with patch("app.core.config.settings") as patched_settings:
            patched_settings.LLM_MODEL = "gpt-4o-mini"
            patched_settings.LLM_API_KEY = "test-key"

            llm = LLMService()

            assert llm.model == "gpt-4o-mini"
|
||||||
|
|
||||||
|
|
||||||
|
class TestEventManager:
    """Tests for the event manager and its emitter."""

    def test_event_manager_initialization(self):
        """A fresh manager starts with no queues and no callbacks."""
        fresh_manager = EventManager()

        assert fresh_manager._event_queues == {}
        assert fresh_manager._event_callbacks == {}

    @pytest.mark.asyncio
    async def test_event_emitter(self):
        """Emitting one info event advances the sequence counter to 1."""
        emitter = AgentEventEmitter("test-task-id", EventManager())

        await emitter.emit_info("Test message")

        assert emitter._sequence == 1

    @pytest.mark.asyncio
    async def test_event_emitter_phase_tracking(self):
        """Starting a phase records it as the emitter's current phase."""
        emitter = AgentEventEmitter("test-task-id", EventManager())

        await emitter.emit_phase_start("recon", "开始信息收集")

        assert emitter._current_phase == "recon"

    @pytest.mark.asyncio
    async def test_event_emitter_task_complete(self):
        """Emitting task completion advances the sequence counter to 1."""
        emitter = AgentEventEmitter("test-task-id", EventManager())

        await emitter.emit_task_complete(findings_count=5, duration_ms=1000)

        assert emitter._sequence == 1
|
||||||
|
|
||||||
|
|
||||||
|
class TestAuditGraph:
    """Tests for audit graph construction."""

    def test_create_audit_graph(self, mock_event_emitter):
        """A graph object is returned when every node is a mock."""
        # One independent mock per graph node, passed by keyword.
        node_names = ("recon_node", "analysis_node", "verification_node", "report_node")
        mock_nodes = {node_name: MagicMock() for node_name in node_names}

        graph = create_audit_graph(**mock_nodes)

        assert graph is not None
|
||||||
|
|
||||||
|
|
||||||
|
class TestReconNode:
    """Tests for the Recon graph node."""

    @pytest.fixture
    def recon_node_with_mock_agent(self, mock_event_emitter):
        """Return a ReconNode whose agent reports success with canned data."""
        recon_data = {
            "tech_stack": {"languages": ["Python"]},
            "entry_points": [{"path": "src/app.py", "type": "api"}],
            "high_risk_areas": ["src/sql_vuln.py"],
            "dependencies": {},
            "initial_findings": [],
        }
        agent_result = MagicMock(success=True, data=recon_data)

        stub_agent = MagicMock()
        stub_agent.run = AsyncMock(return_value=agent_result)

        return ReconNode(stub_agent, mock_event_emitter)

    @pytest.mark.asyncio
    async def test_recon_node_success(self, recon_node_with_mock_agent):
        """A successful agent run is translated into the recon-complete state."""
        state = {"project_info": {"name": "Test"}, "config": {}}

        result = await recon_node_with_mock_agent(state)

        assert "tech_stack" in result
        assert "entry_points" in result
        assert result["current_phase"] == "recon_complete"

    @pytest.mark.asyncio
    async def test_recon_node_failure(self, mock_event_emitter):
        """A failed agent run is translated into the error state."""
        failed_result = MagicMock(success=False, error="Test error", data=None)

        stub_agent = MagicMock()
        stub_agent.run = AsyncMock(return_value=failed_result)

        node = ReconNode(stub_agent, mock_event_emitter)

        state = {"project_info": {}, "config": {}}
        result = await node(state)

        assert "error" in result
        assert result["current_phase"] == "error"
|
||||||
|
|
||||||
|
|
||||||
|
class TestAnalysisNode:
    """Tests for the Analysis graph node."""

    @pytest.fixture
    def analysis_node_with_mock_agent(self, mock_event_emitter):
        """Return an AnalysisNode whose agent reports one canned finding."""
        sql_finding = {
            "id": "finding-1",
            "title": "SQL Injection",
            "severity": "high",
            "vulnerability_type": "sql_injection",
            "file_path": "src/sql_vuln.py",
            "line_start": 10,
            "description": "SQL injection vulnerability",
        }
        agent_result = MagicMock(
            success=True,
            data={
                "findings": [sql_finding],
                "should_continue": False,
            },
        )

        stub_agent = MagicMock()
        stub_agent.run = AsyncMock(return_value=agent_result)

        return AnalysisNode(stub_agent, mock_event_emitter)

    @pytest.mark.asyncio
    async def test_analysis_node_success(self, analysis_node_with_mock_agent):
        """A successful run yields findings and bumps the iteration to 1."""
        state = {
            "project_info": {"name": "Test"},
            "tech_stack": {"languages": ["Python"]},
            "entry_points": [],
            "high_risk_areas": ["src/sql_vuln.py"],
            "config": {},
            "iteration": 0,
            "findings": [],
        }

        result = await analysis_node_with_mock_agent(state)

        assert "findings" in result
        assert len(result["findings"]) > 0
        assert result["iteration"] == 1
|
||||||
|
|
||||||
|
|
||||||
|
class TestIntegrationFlow:
    """End-to-end flow tests."""

    @pytest.mark.asyncio
    async def test_full_audit_flow_mock(self, temp_project_dir, mock_db_session, mock_task):
        """Check that the event plumbing of the audit flow is connected.

        Three events are emitted through a real ``AgentEventEmitter`` and the
        emitter's sequence counter is expected to reflect all of them.
        """
        # Event manager and an emitter bound to the mock task.
        event_manager = EventManager()
        emitter = AgentEventEmitter(mock_task.id, event_manager)

        # Emit a start / info / complete triple.
        await emitter.emit_phase_start("init", "初始化")
        await emitter.emit_info("测试消息")
        await emitter.emit_phase_complete("init", "初始化完成")

        assert emitter._sequence == 3

    @pytest.mark.asyncio
    async def test_audit_state_typing(self):
        """An ``AuditState`` literal with every key can be built and read back."""
        state: AuditState = {
            "project_root": "/tmp/test",
            "project_info": {"name": "Test"},
            "config": {},
            "task_id": "test-id",
            "tech_stack": {},
            "entry_points": [],
            "high_risk_areas": [],
            "dependencies": {},
            "findings": [],
            "verified_findings": [],
            "false_positives": [],
            "current_phase": "start",
            "iteration": 0,
            "max_iterations": 50,
            "should_continue_analysis": False,
            "messages": [],
            "events": [],
            "summary": None,
            "security_score": None,
            "error": None,
        }

        assert state["current_phase"] == "start"
        assert state["max_iterations"] == 50
|
||||||
|
|
||||||
|
|
||||||
|
class TestToolIntegration:
    """Tests that chain several tools together."""

    @pytest.mark.asyncio
    async def test_tools_work_together(self, temp_project_dir):
        """List, search, read and pattern-match against the same project."""
        from app.services.agent.tools import (
            FileReadTool, FileSearchTool, ListFilesTool, PatternMatchTool,
        )

        # Step 1: list files.
        lister = ListFilesTool(temp_project_dir)
        listing = await lister.execute(directory="src", recursive=False)
        assert listing.success is True

        # Step 2: search for a keyword.
        searcher = FileSearchTool(temp_project_dir)
        hits = await searcher.execute(keyword="execute")
        assert hits.success is True

        # Step 3: read a file.
        reader = FileReadTool(temp_project_dir)
        read_result = await reader.execute(file_path="src/sql_vuln.py")
        assert read_result.success is True

        # Step 4: pattern-match the content that was just read.
        matcher = PatternMatchTool(temp_project_dir)
        match_args = {
            "code": read_result.data,
            "file_path": "src/sql_vuln.py",
            "language": "python",
        }
        match_result = await matcher.execute(**match_args)
        assert match_result.success is True
|
||||||
|
|
||||||
|
|
||||||
|
class TestErrorHandling:
    """Tests for error handling."""

    @pytest.mark.asyncio
    async def test_tool_error_handling(self, temp_project_dir):
        """Reading a missing file yields a failed result with an error set."""
        from app.services.agent.tools import FileReadTool

        reader = FileReadTool(temp_project_dir)

        # The requested file does not exist in the temp project.
        outcome = await reader.execute(file_path="nonexistent/file.py")

        assert outcome.success is False
        assert outcome.error is not None

    @pytest.mark.asyncio
    async def test_agent_graceful_degradation(self, mock_event_emitter):
        """An agent that raises makes the node return an error state."""
        # Agent whose run() always raises.
        failing_agent = MagicMock()
        failing_agent.run = AsyncMock(side_effect=Exception("Simulated error"))

        node = ReconNode(failing_agent, mock_event_emitter)

        state = {"project_info": {}, "config": {}}
        result = await node(state)

        # The node should report the error rather than propagate it.
        assert "error" in result
        assert result["current_phase"] == "error"
|
||||||
|
|
||||||
|
|
||||||
|
class TestPerformance:
    """Coarse performance checks for the tools."""

    @pytest.mark.asyncio
    async def test_tool_response_time(self, temp_project_dir):
        """A recursive listing of the temp project finishes within 5 seconds."""
        from app.services.agent.tools import ListFilesTool
        import time

        tool = ListFilesTool(temp_project_dir)

        # perf_counter is monotonic, so the measured duration cannot be
        # distorted by wall-clock adjustments during the call.
        start = time.perf_counter()
        await tool.execute(directory=".", recursive=True)
        duration = time.perf_counter() - start

        # The tool should respond within a reasonable time.
        assert duration < 5.0  # within 5 seconds

    @pytest.mark.asyncio
    async def test_multiple_tool_calls(self, temp_project_dir):
        """Five consecutive searches all succeed and are counted by the tool."""
        from app.services.agent.tools import FileSearchTool

        tool = FileSearchTool(temp_project_dir)

        # Run the same search several times.
        for _ in range(5):
            result = await tool.execute(keyword="def")
            assert result.success is True

        # The tool's call counter should match the number of calls.
        assert tool._call_count == 5
|
||||||
|
|
||||||
|
|
@ -0,0 +1,248 @@
|
||||||
|
"""
|
||||||
|
Agent 工具单元测试
|
||||||
|
测试各种安全分析工具的功能
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
from unittest.mock import MagicMock, AsyncMock, patch
|
||||||
|
|
||||||
|
# 导入工具
|
||||||
|
from app.services.agent.tools import (
|
||||||
|
FileReadTool, FileSearchTool, ListFilesTool,
|
||||||
|
PatternMatchTool,
|
||||||
|
)
|
||||||
|
from app.services.agent.tools.base import ToolResult
|
||||||
|
|
||||||
|
|
||||||
|
class TestFileTools:
    """Tests for the file-oriented tools."""

    @pytest.mark.asyncio
    async def test_file_read_tool_success(self, temp_project_dir):
        """Reading an existing file returns its content."""
        reader = FileReadTool(temp_project_dir)

        outcome = await reader.execute(file_path="src/sql_vuln.py")

        assert outcome.success is True
        assert "SELECT * FROM users" in outcome.data
        assert "sql_injection" in outcome.data.lower() or "cursor.execute" in outcome.data

    @pytest.mark.asyncio
    async def test_file_read_tool_not_found(self, temp_project_dir):
        """Reading a missing file fails with a not-found style message."""
        reader = FileReadTool(temp_project_dir)

        outcome = await reader.execute(file_path="nonexistent.py")

        assert outcome.success is False
        assert "不存在" in outcome.error or "not found" in outcome.error.lower()

    @pytest.mark.asyncio
    async def test_file_read_tool_path_traversal_blocked(self, temp_project_dir):
        """A path escaping the project root fails with a security message."""
        reader = FileReadTool(temp_project_dir)

        outcome = await reader.execute(file_path="../../../etc/passwd")

        assert outcome.success is False
        assert "安全" in outcome.error or "security" in outcome.error.lower()

    @pytest.mark.asyncio
    async def test_file_search_tool(self, temp_project_dir):
        """Searching for a keyword reports the file that contains it."""
        searcher = FileSearchTool(temp_project_dir)

        outcome = await searcher.execute(keyword="cursor.execute")

        assert outcome.success is True
        assert "sql_vuln.py" in outcome.data

    @pytest.mark.asyncio
    async def test_list_files_tool(self, temp_project_dir):
        """A recursive listing from the root includes nested and top-level files."""
        lister = ListFilesTool(temp_project_dir)

        outcome = await lister.execute(directory=".", recursive=True)

        assert outcome.success is True
        for expected_name in ("sql_vuln.py", "requirements.txt"):
            assert expected_name in outcome.data

    @pytest.mark.asyncio
    async def test_list_files_tool_pattern(self, temp_project_dir):
        """Listing with a glob pattern still includes matching files."""
        lister = ListFilesTool(temp_project_dir)

        outcome = await lister.execute(directory="src", pattern="*.py")

        assert outcome.success is True
        assert "sql_vuln.py" in outcome.data
|
||||||
|
|
||||||
|
|
||||||
|
class TestPatternMatchTool:
    """Tests for the pattern-matching tool."""

    @staticmethod
    async def _scan_fixture(project_dir, file_name, pattern_types, **execute_kwargs):
        """Run PatternMatchTool over ``src/<file_name>`` and return its result.

        The fixture file is read from the temporary project and passed to the
        tool as ``code``; any extra keyword arguments (e.g. ``language``) are
        forwarded to ``execute`` unchanged.
        """
        tool = PatternMatchTool(project_dir)
        with open(os.path.join(project_dir, "src", file_name)) as handle:
            source_text = handle.read()
        return await tool.execute(
            code=source_text,
            file_path=f"src/{file_name}",
            pattern_types=pattern_types,
            **execute_kwargs,
        )

    @pytest.mark.asyncio
    async def test_pattern_match_sql_injection(self, temp_project_dir):
        """Scanning the vulnerable SQL fixture succeeds and reports SQL findings."""
        outcome = await self._scan_fixture(
            temp_project_dir, "sql_vuln.py", ["sql_injection"], language="python"
        )

        assert outcome.success is True
        # The content check only applies when the tool returned some data.
        if not outcome.data:
            return
        reported_matches = outcome.metadata.get("matches", [])
        assert "sql" in str(outcome.data).lower() or len(reported_matches) > 0

    @pytest.mark.asyncio
    async def test_pattern_match_command_injection(self, temp_project_dir):
        """Scanning the command-injection fixture completes successfully."""
        outcome = await self._scan_fixture(
            temp_project_dir, "cmd_vuln.py", ["command_injection"], language="python"
        )

        assert outcome.success is True

    @pytest.mark.asyncio
    async def test_pattern_match_xss(self, temp_project_dir):
        """Scanning the XSS fixture completes successfully."""
        outcome = await self._scan_fixture(
            temp_project_dir, "xss_vuln.py", ["xss"], language="python"
        )

        assert outcome.success is True

    @pytest.mark.asyncio
    async def test_pattern_match_hardcoded_secrets(self, temp_project_dir):
        """Scanning the secrets fixture (no language given) completes successfully."""
        outcome = await self._scan_fixture(
            temp_project_dir, "secrets.py", ["hardcoded_secret"]
        )

        assert outcome.success is True

    @pytest.mark.asyncio
    async def test_pattern_match_safe_code(self, temp_project_dir):
        """The safe fixture yields at most one SQL-injection match."""
        outcome = await self._scan_fixture(
            temp_project_dir, "safe_code.py", ["sql_injection"], language="python"
        )

        assert outcome.success is True

        # Only inspect the matches when the metadata exposes them as a list.
        matches = outcome.metadata.get("matches", [])
        if not isinstance(matches, list):
            return

        sql_hits = [
            entry
            for entry in matches
            if isinstance(entry, dict) and "sql" in entry.get("pattern_type", "").lower()
        ]
        # A small number of false positives is tolerated.
        assert len(sql_hits) <= 1
|
||||||
|
|
||||||
|
|
||||||
|
class TestToolResult:
    """Tests for ToolResult."""

    def test_tool_result_success(self):
        """A successful result keeps its data and carries no error."""
        outcome = ToolResult(success=True, data="test data")

        assert outcome.success is True
        assert outcome.data == "test data"
        assert outcome.error is None

    def test_tool_result_failure(self):
        """A failed result keeps the error message it was given."""
        outcome = ToolResult(success=False, error="test error")

        assert outcome.success is False
        assert outcome.error == "test error"

    def test_tool_result_to_string(self):
        """``to_string`` output contains both the key and the value of dict data."""
        rendered = ToolResult(success=True, data={"key": "value"}).to_string()

        assert "key" in rendered
        assert "value" in rendered

    def test_tool_result_to_string_truncate(self):
        """``to_string(max_length=...)`` shortens long data and marks it as truncated."""
        payload = "x" * 10000
        rendered = ToolResult(success=True, data=payload).to_string(max_length=100)

        assert len(rendered) < len(payload)
        assert "truncated" in rendered.lower()
|
||||||
|
|
||||||
|
|
||||||
|
class TestToolMetadata:
    """Tests for the counters a tool keeps about its own executions."""

    @pytest.mark.asyncio
    async def test_tool_call_count(self, temp_project_dir):
        """Two executions leave the call counter at two."""
        lister = ListFilesTool(temp_project_dir)

        for directory in (".", "src"):
            await lister.execute(directory=directory)

        assert lister._call_count == 2

    @pytest.mark.asyncio
    async def test_tool_duration_tracking(self, temp_project_dir):
        """Both the per-call and the accumulated duration are non-negative."""
        lister = ListFilesTool(temp_project_dir)

        outcome = await lister.execute(directory=".")

        assert outcome.duration_ms >= 0
        assert lister._total_duration_ms >= 0
|
||||||
|
|
||||||
|
|
@ -0,0 +1,229 @@
|
||||||
|
# DeepAudit Agent 审计功能部署清单
|
||||||
|
|
||||||
|
## 📋 生产部署前必须完成的检查
|
||||||
|
|
||||||
|
### 1. 环境依赖 ✅
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 后端依赖
|
||||||
|
cd backend
|
||||||
|
uv pip install chromadb litellm langchain langgraph
|
||||||
|
|
||||||
|
# 外部安全工具(可选但推荐)
|
||||||
|
pip install semgrep bandit safety
|
||||||
|
|
||||||
|
# 或者使用系统包管理器
|
||||||
|
brew install semgrep # macOS
|
||||||
|
apt install semgrep # Ubuntu
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. LLM 配置 ✅
|
||||||
|
|
||||||
|
在 `.env` 文件中配置:
|
||||||
|
|
||||||
|
```env
|
||||||
|
# LLM 配置(必须)
|
||||||
|
LLM_PROVIDER=openai # 或 azure, anthropic, ollama 等
|
||||||
|
LLM_MODEL=gpt-4o-mini # 推荐使用 gpt-4 系列
|
||||||
|
LLM_API_KEY=sk-xxx # 你的 API Key
|
||||||
|
LLM_BASE_URL= # 可选,自定义端点
|
||||||
|
|
||||||
|
# 嵌入模型配置(RAG 需要)
|
||||||
|
EMBEDDING_PROVIDER=openai
|
||||||
|
EMBEDDING_MODEL=text-embedding-3-small
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. 数据库迁移 ✅
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd backend
|
||||||
|
alembic upgrade head
|
||||||
|
```
|
||||||
|
|
||||||
|
确保以下表已创建:
|
||||||
|
- `agent_tasks`
|
||||||
|
- `agent_events`
|
||||||
|
- `agent_findings`
|
||||||
|
|
||||||
|
### 4. 向量数据库 ✅
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 创建向量数据库目录
|
||||||
|
mkdir -p /var/data/deepaudit/vector_db
|
||||||
|
|
||||||
|
# 在 .env 中配置
|
||||||
|
VECTOR_DB_PATH=/var/data/deepaudit/vector_db
|
||||||
|
```
|
||||||
|
|
||||||
|
### 5. Docker 沙箱(可选)
|
||||||
|
|
||||||
|
如果需要漏洞验证功能:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 拉取沙箱镜像
|
||||||
|
docker pull python:3.11-slim
|
||||||
|
|
||||||
|
# 配置沙箱参数
|
||||||
|
SANDBOX_IMAGE=python:3.11-slim
|
||||||
|
SANDBOX_MEMORY_LIMIT=256m
|
||||||
|
SANDBOX_CPU_LIMIT=0.5
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 🔬 功能测试检查
|
||||||
|
|
||||||
|
### 测试 1: 基础流程
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd backend
|
||||||
|
PYTHONPATH=. uv run pytest tests/agent/ -v
|
||||||
|
```
|
||||||
|
|
||||||
|
预期结果:所有测试全部通过(编写本文档时共 43 个)
|
||||||
|
|
||||||
|
### 测试 2: LLM 连接
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd backend
|
||||||
|
PYTHONPATH=. uv run python -c "
|
||||||
|
import asyncio
|
||||||
|
from app.services.agent.graph.runner import LLMService
|
||||||
|
|
||||||
|
async def test():
|
||||||
|
llm = LLMService()
|
||||||
|
result = await llm.analyze_code('print(\"hello\")', 'python')
|
||||||
|
print('LLM 连接成功:', 'issues' in result)
|
||||||
|
|
||||||
|
asyncio.run(test())
|
||||||
|
"
|
||||||
|
```
|
||||||
|
|
||||||
|
### 测试 3: 外部工具
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 测试 Semgrep
|
||||||
|
semgrep --version
|
||||||
|
|
||||||
|
# 测试 Bandit
|
||||||
|
bandit --version
|
||||||
|
```
|
||||||
|
|
||||||
|
### 测试 4: 端到端测试
|
||||||
|
|
||||||
|
1. 启动后端:`cd backend && uv run uvicorn app.main:app --reload`
|
||||||
|
2. 启动前端:`cd frontend && npm run dev`
|
||||||
|
3. 创建一个项目并上传代码
|
||||||
|
4. 选择 "Agent 审计模式" 创建任务
|
||||||
|
5. 观察执行日志和发现
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## ⚠️ 已知限制
|
||||||
|
|
||||||
|
| 限制 | 影响 | 解决方案 |
|
||||||
|
|------|------|---------|
|
||||||
|
| **LLM 成本** | 每次审计消耗 Token | 使用 gpt-4o-mini 降低成本 |
|
||||||
|
| **扫描时间** | 大项目需要较长时间 | 设置合理的超时时间 |
|
||||||
|
| **误报率** | AI 可能产生误报 | 启用验证阶段过滤 |
|
||||||
|
| **外部工具依赖** | 需要手动安装 | 提供 Docker 镜像 |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 🚀 生产环境建议
|
||||||
|
|
||||||
|
### 1. 资源配置
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
# Kubernetes 部署示例
|
||||||
|
resources:
|
||||||
|
limits:
|
||||||
|
memory: "2Gi"
|
||||||
|
cpu: "2"
|
||||||
|
requests:
|
||||||
|
memory: "1Gi"
|
||||||
|
cpu: "1"
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. 并发控制
|
||||||
|
|
||||||
|
```env
|
||||||
|
# 限制同时运行的任务数
|
||||||
|
MAX_CONCURRENT_AGENT_TASKS=3
|
||||||
|
AGENT_TASK_TIMEOUT=1800 # 30 分钟
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. 日志监控
|
||||||
|
|
||||||
|
```env
|
||||||
|
# 配置日志级别
|
||||||
|
LOG_LEVEL=INFO
|
||||||
|
# 启用 SQLAlchemy 日志(调试用)
|
||||||
|
SQLALCHEMY_ECHO=false
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4. 安全考虑
|
||||||
|
|
||||||
|
- [ ] 限制上传文件大小
|
||||||
|
- [ ] 限制扫描目录范围
|
||||||
|
- [ ] 启用沙箱隔离
|
||||||
|
- [ ] 配置 API 速率限制
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## ✅ 部署状态检查
|
||||||
|
|
||||||
|
运行以下命令验证部署状态:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd backend
|
||||||
|
PYTHONPATH=. uv run python -c "
|
||||||
|
print('检查部署状态...')
|
||||||
|
|
||||||
|
# 1. 检查数据库连接
|
||||||
|
try:
|
||||||
|
from app.db.session import async_session_factory
|
||||||
|
print('✅ 数据库配置正确')
|
||||||
|
except Exception as e:
|
||||||
|
print(f'❌ 数据库错误: {e}')
|
||||||
|
|
||||||
|
# 2. 检查 LLM 配置
|
||||||
|
from app.core.config import settings
|
||||||
|
if settings.LLM_API_KEY:
|
||||||
|
print('✅ LLM API Key 已配置')
|
||||||
|
else:
|
||||||
|
print('⚠️ LLM API Key 未配置')
|
||||||
|
|
||||||
|
# 3. 检查向量数据库
|
||||||
|
import os
|
||||||
|
if os.path.exists(settings.VECTOR_DB_PATH or '/tmp'):
|
||||||
|
print('✅ 向量数据库路径存在')
|
||||||
|
else:
|
||||||
|
print('⚠️ 向量数据库路径不存在')
|
||||||
|
|
||||||
|
# 4. 检查外部工具
|
||||||
|
import shutil
|
||||||
|
tools = ['semgrep', 'bandit']
|
||||||
|
for tool in tools:
|
||||||
|
if shutil.which(tool):
|
||||||
|
print(f'✅ {tool} 已安装')
|
||||||
|
else:
|
||||||
|
print(f'⚠️ {tool} 未安装(可选)')
|
||||||
|
|
||||||
|
print()
|
||||||
|
print('部署检查完成!')
|
||||||
|
"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 📝 结论
|
||||||
|
|
||||||
|
Agent 审计功能已经具备**基本的生产能力**,但建议:
|
||||||
|
|
||||||
|
1. **先在测试环境验证** - 用一个小项目测试完整流程
|
||||||
|
2. **监控 LLM 成本** - 观察 Token 消耗情况
|
||||||
|
3. **逐步开放** - 先给少数用户使用,收集反馈
|
||||||
|
4. **持续优化** - 根据实际效果调整 prompt 和阈值
|
||||||
|
|
||||||
|
如有问题,请查看日志或联系开发团队。
|
||||||
|
|
@ -55,3 +55,5 @@ EXPOSE 3000
|
||||||
ENTRYPOINT ["/docker-entrypoint.sh"]
|
ENTRYPOINT ["/docker-entrypoint.sh"]
|
||||||
CMD ["serve", "-s", "dist", "-l", "3000"]
|
CMD ["serve", "-s", "dist", "-l", "3000"]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -18,3 +18,5 @@ export const ProtectedRoute = () => {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,280 @@
|
||||||
|
/**
|
||||||
|
* Agent 流式事件 React Hook
|
||||||
|
*
|
||||||
|
* 用于在 React 组件中消费 Agent 审计的实时事件流
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { useState, useEffect, useCallback, useRef } from 'react';
|
||||||
|
import {
|
||||||
|
AgentStreamHandler,
|
||||||
|
StreamEventData,
|
||||||
|
StreamOptions,
|
||||||
|
AgentStreamState,
|
||||||
|
} from '../shared/api/agentStream';
|
||||||
|
|
||||||
|
/**
 * Options accepted by `useAgentStream`.
 *
 * Every `StreamOptions` field except `onEvent` may be passed; the hook
 * installs its own `onEvent` to maintain the `events` list.
 */
export interface UseAgentStreamOptions extends Omit<StreamOptions, 'onEvent'> {
  /** Upper bound on the number of events kept in `events` (the hook defaults it to 500). */
  maxEvents?: number;
  /** Connect from an effect as soon as a task id is available (the hook defaults it to false). */
  autoConnect?: boolean;
}
|
||||||
|
|
||||||
|
/**
 * Value returned by `useAgentStream`: the accumulated stream state plus
 * the controls for the underlying connection.
 */
export interface UseAgentStreamReturn extends AgentStreamState {
  /** True between a `connect()` call and the next disconnect, completion or error. */
  isConnected: boolean;
  /** Reset the accumulated state and open a new stream for the current task. */
  connect: () => void;
  /** Close the current stream, if any. */
  disconnect: () => void;
  /** Empty the `events` list without touching the connection. */
  clearEvents: () => void;
}
|
||||||
|
|
||||||
|
/**
 * React hook that consumes the live event stream of an agent audit task.
 *
 * The hook accumulates stream state (events, thinking text, tool calls,
 * phase, progress, findings) and forwards every stream callback to the
 * matching callback passed in `options`.
 *
 * Callbacks are read through a ref, so `connect` keeps a stable identity even
 * when the caller passes a fresh options object / inline callbacks on every
 * render. This makes `connect` safe to list as an effect dependency.
 *
 * @param taskId  Task to stream; `connect()` is a no-op while this is null.
 * @param options Stream filters, `autoConnect`, `maxEvents` and callbacks.
 * @returns The accumulated stream state plus `connect`, `disconnect`,
 *          `isConnected` and `clearEvents`.
 *
 * @example
 * ```tsx
 * function AgentAuditPanel({ taskId }: { taskId: string }) {
 *   const {
 *     thinking,
 *     isThinking,
 *     toolCalls,
 *     findings,
 *     connect,
 *     disconnect,
 *   } = useAgentStream(taskId);
 *
 *   useEffect(() => {
 *     connect();
 *     return () => disconnect();
 *   }, [connect, disconnect]);
 *
 *   return (
 *     <div>
 *       {isThinking && <ThinkingIndicator text={thinking} />}
 *       {toolCalls.map((tc, i) => <ToolCallCard key={`${tc.name}-${i}`} {...tc} />)}
 *       {findings.map((f, i) => <FindingCard key={i} {...f} />)}
 *     </div>
 *   );
 * }
 * ```
 */
export function useAgentStream(
  taskId: string | null,
  options: UseAgentStreamOptions = {}
): UseAgentStreamReturn {
  const {
    autoConnect = false,
    maxEvents = 500,
    includeThinking = true,
    includeToolCalls = true,
    afterSequence = 0,
    ...callbackOptions
  } = options;

  // Accumulated stream state
  const [events, setEvents] = useState<StreamEventData[]>([]);
  const [thinking, setThinking] = useState('');
  const [isThinking, setIsThinking] = useState(false);
  const [toolCalls, setToolCalls] = useState<AgentStreamState['toolCalls']>([]);
  const [currentPhase, setCurrentPhase] = useState('');
  const [progress, setProgress] = useState({ current: 0, total: 100, percentage: 0 });
  const [findings, setFindings] = useState<Record<string, unknown>[]>([]);
  const [isComplete, setIsComplete] = useState(false);
  const [error, setError] = useState<string | null>(null);
  const [isConnected, setIsConnected] = useState(false);

  // The active stream handler, if any
  const handlerRef = useRef<AgentStreamHandler | null>(null);

  // `callbackOptions` is a brand-new object on every render (rest
  // destructuring). Listing it as a dependency of `connect` would change
  // `connect`'s identity each render and make every effect that depends on it
  // disconnect and reconnect in a loop. Keep the latest callbacks in a ref
  // instead and look them up at call time.
  const callbacksRef = useRef(callbackOptions);
  useEffect(() => {
    callbacksRef.current = callbackOptions;
  });

  // Open a stream for the current task, replacing any existing one
  const connect = useCallback(() => {
    if (!taskId) return;

    // Drop the existing connection first
    if (handlerRef.current) {
      handlerRef.current.disconnect();
    }

    // Reset accumulated state
    setEvents([]);
    setThinking('');
    setIsThinking(false);
    setToolCalls([]);
    setCurrentPhase('');
    setProgress({ current: 0, total: 100, percentage: 0 });
    setFindings([]);
    setIsComplete(false);
    setError(null);

    // Never keep more than this many events; negative limits behave like 0.
    const eventLimit = Math.max(0, maxEvents);

    handlerRef.current = new AgentStreamHandler(taskId, {
      includeThinking,
      includeToolCalls,
      afterSequence,

      onEvent: (event) => {
        setEvents((prev) => {
          // Append first, then trim from the front. A negative-index slice
          // (`slice(-limit + 1)`) turns into `slice(0)` when the limit is 1
          // and would let the list grow without bound.
          const next = [...prev, event];
          return next.length > eventLimit
            ? next.slice(next.length - eventLimit)
            : next;
        });
      },

      onThinkingStart: () => {
        setIsThinking(true);
        setThinking('');
        callbacksRef.current.onThinkingStart?.();
      },

      onThinkingToken: (token, accumulated) => {
        setThinking(accumulated);
        callbacksRef.current.onThinkingToken?.(token, accumulated);
      },

      onThinkingEnd: (response) => {
        setIsThinking(false);
        setThinking(response);
        callbacksRef.current.onThinkingEnd?.(response);
      },

      onToolStart: (name, input) => {
        setToolCalls((prev) => [
          ...prev,
          { name, input, status: 'running' as const },
        ]);
        callbacksRef.current.onToolStart?.(name, input);
      },

      onToolEnd: (name, output, durationMs) => {
        // Tool-end events are matched to running calls by name only.
        setToolCalls((prev) =>
          prev.map((tc) =>
            tc.name === name && tc.status === 'running'
              ? { ...tc, output, durationMs, status: 'success' as const }
              : tc
          )
        );
        callbacksRef.current.onToolEnd?.(name, output, durationMs);
      },

      onNodeStart: (nodeName, phase) => {
        setCurrentPhase(phase);
        callbacksRef.current.onNodeStart?.(nodeName, phase);
      },

      onNodeEnd: (nodeName, summary) => {
        callbacksRef.current.onNodeEnd?.(nodeName, summary);
      },

      onProgress: (current, total, message) => {
        setProgress({
          current,
          total,
          // Guard against division by zero when the total is unknown
          percentage: total > 0 ? Math.round((current / total) * 100) : 0,
        });
        callbacksRef.current.onProgress?.(current, total, message);
      },

      onFinding: (finding, isVerified) => {
        setFindings((prev) => [...prev, finding]);
        callbacksRef.current.onFinding?.(finding, isVerified);
      },

      onComplete: (data) => {
        setIsComplete(true);
        setIsConnected(false);
        callbacksRef.current.onComplete?.(data);
      },

      onError: (err) => {
        setError(err);
        setIsComplete(true);
        setIsConnected(false);
        callbacksRef.current.onError?.(err);
      },

      onHeartbeat: () => {
        callbacksRef.current.onHeartbeat?.();
      },
    });

    handlerRef.current.connect();
    setIsConnected(true);
  }, [taskId, includeThinking, includeToolCalls, afterSequence, maxEvents]);

  // Close the current stream
  const disconnect = useCallback(() => {
    if (handlerRef.current) {
      handlerRef.current.disconnect();
      handlerRef.current = null;
    }
    setIsConnected(false);
  }, []);

  // Empty the event list
  const clearEvents = useCallback(() => {
    setEvents([]);
  }, []);

  // Auto-connect when requested; always disconnect when the inputs change
  useEffect(() => {
    if (autoConnect && taskId) {
      connect();
    }
    return () => {
      disconnect();
    };
  }, [taskId, autoConnect, connect, disconnect]);

  // Make sure the stream is closed on unmount
  useEffect(() => {
    return () => {
      if (handlerRef.current) {
        handlerRef.current.disconnect();
      }
    };
  }, []);

  return {
    events,
    thinking,
    isThinking,
    toolCalls,
    currentPhase,
    progress,
    findings,
    isComplete,
    error,
    connect,
    disconnect,
    isConnected,
    clearEvents,
  };
}
|
||||||
|
|
||||||
|
/**
 * Convenience hook exposing only the thinking text of an agent stream.
 *
 * Wraps `useAgentStream` with tool-call events turned off.
 */
export function useAgentThinking(taskId: string | null) {
  const stream = useAgentStream(taskId, { includeToolCalls: false });

  return {
    thinking: stream.thinking,
    isThinking: stream.isThinking,
    connect: stream.connect,
    disconnect: stream.disconnect,
  };
}
|
||||||
|
|
||||||
|
/**
 * Convenience hook exposing only the tool calls of an agent stream.
 *
 * Wraps `useAgentStream` with thinking events turned off.
 */
export function useAgentToolCalls(taskId: string | null) {
  const stream = useAgentStream(taskId, { includeThinking: false });

  return {
    toolCalls: stream.toolCalls,
    connect: stream.connect,
    disconnect: stream.disconnect,
  };
}
|
||||||
|
|
||||||
|
export default useAgentStream;
|
||||||
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
/**
|
/**
|
||||||
* Agent 审计页面
|
* Agent 审计页面
|
||||||
* 机械终端风格的 AI Agent 审计界面
|
* 机械终端风格的 AI Agent 审计界面
|
||||||
|
* 支持 LLM 思考过程和工具调用的实时流式展示
|
||||||
*/
|
*/
|
||||||
|
|
||||||
import { useState, useEffect, useRef, useCallback } from "react";
|
import { useState, useEffect, useRef, useCallback } from "react";
|
||||||
|
|
@ -9,12 +10,14 @@ import {
|
||||||
Terminal, Bot, Cpu, Shield, AlertTriangle, CheckCircle2,
|
Terminal, Bot, Cpu, Shield, AlertTriangle, CheckCircle2,
|
||||||
Loader2, Code, Zap, Activity, ChevronRight, XCircle,
|
Loader2, Code, Zap, Activity, ChevronRight, XCircle,
|
||||||
FileCode, Search, Bug, Lock, Play, Square, RefreshCw,
|
FileCode, Search, Bug, Lock, Play, Square, RefreshCw,
|
||||||
ArrowLeft, Download, ExternalLink
|
ArrowLeft, Download, ExternalLink, Brain, Wrench,
|
||||||
|
ChevronDown, ChevronUp, Clock, Sparkles
|
||||||
} from "lucide-react";
|
} from "lucide-react";
|
||||||
import { Button } from "@/components/ui/button";
|
import { Button } from "@/components/ui/button";
|
||||||
import { Badge } from "@/components/ui/badge";
|
import { Badge } from "@/components/ui/badge";
|
||||||
import { ScrollArea } from "@/components/ui/scroll-area";
|
import { ScrollArea } from "@/components/ui/scroll-area";
|
||||||
import { toast } from "sonner";
|
import { toast } from "sonner";
|
||||||
|
import { useAgentStream } from "@/hooks/useAgentStream";
|
||||||
import {
|
import {
|
||||||
type AgentTask,
|
type AgentTask,
|
||||||
type AgentEvent,
|
type AgentEvent,
|
||||||
|
|
@ -91,10 +94,36 @@ export default function AgentAuditPage() {
|
||||||
const [isLoading, setIsLoading] = useState(true);
|
const [isLoading, setIsLoading] = useState(true);
|
||||||
const [isStreaming, setIsStreaming] = useState(false);
|
const [isStreaming, setIsStreaming] = useState(false);
|
||||||
const [currentTime, setCurrentTime] = useState(new Date().toLocaleTimeString("zh-CN", { hour12: false }));
|
const [currentTime, setCurrentTime] = useState(new Date().toLocaleTimeString("zh-CN", { hour12: false }));
|
||||||
|
const [showThinking, setShowThinking] = useState(true);
|
||||||
|
const [showToolDetails, setShowToolDetails] = useState(true);
|
||||||
|
|
||||||
const eventsEndRef = useRef<HTMLDivElement>(null);
|
const eventsEndRef = useRef<HTMLDivElement>(null);
|
||||||
|
const thinkingEndRef = useRef<HTMLDivElement>(null);
|
||||||
const abortControllerRef = useRef<AbortController | null>(null);
|
const abortControllerRef = useRef<AbortController | null>(null);
|
||||||
|
|
||||||
|
// 使用增强版流式 Hook
|
||||||
|
const {
|
||||||
|
thinking,
|
||||||
|
isThinking,
|
||||||
|
toolCalls,
|
||||||
|
currentPhase: streamPhase,
|
||||||
|
progress: streamProgress,
|
||||||
|
connect: connectStream,
|
||||||
|
disconnect: disconnectStream,
|
||||||
|
isConnected: isStreamConnected,
|
||||||
|
} = useAgentStream(taskId || null, {
|
||||||
|
includeThinking: true,
|
||||||
|
includeToolCalls: true,
|
||||||
|
onFinding: () => loadFindings(),
|
||||||
|
onComplete: () => {
|
||||||
|
loadTask();
|
||||||
|
loadFindings();
|
||||||
|
},
|
||||||
|
onError: (err) => {
|
||||||
|
console.error("Stream error:", err);
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
// 是否完成
|
// 是否完成
|
||||||
const isComplete = task?.status === "completed" || task?.status === "failed" || task?.status === "cancelled";
|
const isComplete = task?.status === "completed" || task?.status === "failed" || task?.status === "cancelled";
|
||||||
|
|
||||||
|
|
@ -146,12 +175,24 @@ export default function AgentAuditPage() {
|
||||||
init();
|
init();
|
||||||
}, [loadTask, loadEvents, loadFindings]);
|
}, [loadTask, loadEvents, loadFindings]);
|
||||||
|
|
||||||
// 事件流
|
// 连接增强版流式 API
|
||||||
|
useEffect(() => {
|
||||||
|
if (!taskId || isComplete || isLoading) return;
|
||||||
|
|
||||||
|
connectStream();
|
||||||
|
setIsStreaming(true);
|
||||||
|
|
||||||
|
return () => {
|
||||||
|
disconnectStream();
|
||||||
|
setIsStreaming(false);
|
||||||
|
};
|
||||||
|
}, [taskId, isComplete, isLoading, connectStream, disconnectStream]);
|
||||||
|
|
||||||
|
// 旧版事件流(作为后备)
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (!taskId || isComplete || isLoading) return;
|
if (!taskId || isComplete || isLoading) return;
|
||||||
|
|
||||||
const startStreaming = async () => {
|
const startStreaming = async () => {
|
||||||
setIsStreaming(true);
|
|
||||||
abortControllerRef.current = new AbortController();
|
abortControllerRef.current = new AbortController();
|
||||||
|
|
||||||
try {
|
try {
|
||||||
|
|
@ -179,8 +220,6 @@ export default function AgentAuditPage() {
|
||||||
if ((error as Error).name !== "AbortError") {
|
if ((error as Error).name !== "AbortError") {
|
||||||
console.error("Event stream error:", error);
|
console.error("Event stream error:", error);
|
||||||
}
|
}
|
||||||
} finally {
|
|
||||||
setIsStreaming(false);
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -205,6 +244,30 @@ export default function AgentAuditPage() {
|
||||||
return () => clearInterval(interval);
|
return () => clearInterval(interval);
|
||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
|
// 定期轮询任务状态(作为 SSE 的后备机制)
|
||||||
|
useEffect(() => {
|
||||||
|
if (!taskId || isComplete || isLoading) return;
|
||||||
|
|
||||||
|
// 每 3 秒轮询一次任务状态
|
||||||
|
const pollInterval = setInterval(async () => {
|
||||||
|
try {
|
||||||
|
const taskData = await getAgentTask(taskId);
|
||||||
|
setTask(taskData);
|
||||||
|
|
||||||
|
// 如果任务已完成/失败/取消,刷新其他数据
|
||||||
|
if (taskData.status === "completed" || taskData.status === "failed" || taskData.status === "cancelled") {
|
||||||
|
await loadEvents();
|
||||||
|
await loadFindings();
|
||||||
|
clearInterval(pollInterval);
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.error("Failed to poll task status:", error);
|
||||||
|
}
|
||||||
|
}, 3000);
|
||||||
|
|
||||||
|
return () => clearInterval(pollInterval);
|
||||||
|
}, [taskId, isComplete, isLoading, loadEvents, loadFindings]);
|
||||||
|
|
||||||
// 取消任务
|
// 取消任务
|
||||||
const handleCancel = async () => {
|
const handleCancel = async () => {
|
||||||
if (!taskId) return;
|
if (!taskId) return;
|
||||||
|
|
@ -291,14 +354,85 @@ export default function AgentAuditPage() {
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
{/* 错误提示 */}
|
||||||
|
{task.status === "failed" && task.error_message && (
|
||||||
|
<div className="mx-4 mt-2 p-3 bg-red-900/30 border border-red-700 rounded-lg">
|
||||||
|
<div className="flex items-start gap-2">
|
||||||
|
<XCircle className="w-5 h-5 text-red-400 flex-shrink-0 mt-0.5" />
|
||||||
|
<div>
|
||||||
|
<p className="text-red-400 font-semibold text-sm">任务执行失败</p>
|
||||||
|
<p className="text-red-300/80 text-xs mt-1 font-mono break-all">{task.error_message}</p>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
<div className="flex h-[calc(100vh-56px)]">
|
<div className="flex h-[calc(100vh-56px)]">
|
||||||
{/* 左侧:执行日志 */}
|
{/* 左侧:执行日志 */}
|
||||||
<div className="flex-1 p-4 flex flex-col min-w-0">
|
<div className="flex-1 p-4 flex flex-col min-w-0">
|
||||||
|
|
||||||
|
{/* 思考过程展示区域 */}
|
||||||
|
{(isThinking || thinking) && showThinking && (
|
||||||
|
<div className="mb-4 bg-purple-950/30 rounded-lg border border-purple-800/50 overflow-hidden">
|
||||||
|
<div
|
||||||
|
className="flex items-center justify-between px-3 py-2 bg-purple-900/30 border-b border-purple-800/30 cursor-pointer"
|
||||||
|
onClick={() => setShowThinking(!showThinking)}
|
||||||
|
>
|
||||||
|
<div className="flex items-center gap-2 text-xs text-purple-400">
|
||||||
|
<Brain className={`w-4 h-4 ${isThinking ? "animate-pulse" : ""}`} />
|
||||||
|
<span className="uppercase tracking-wider">AI Thinking</span>
|
||||||
|
{isThinking && (
|
||||||
|
<span className="flex items-center gap-1 text-purple-300">
|
||||||
|
<Sparkles className="w-3 h-3 animate-spin" />
|
||||||
|
<span className="text-[10px]">Processing...</span>
|
||||||
|
</span>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
{showThinking ? <ChevronUp className="w-4 h-4 text-purple-400" /> : <ChevronDown className="w-4 h-4 text-purple-400" />}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div className="max-h-40 overflow-y-auto">
|
||||||
|
<div className="p-3 text-sm text-purple-200/80 font-mono whitespace-pre-wrap">
|
||||||
|
{thinking || "正在思考..."}
|
||||||
|
{isThinking && <span className="animate-pulse text-purple-400">▌</span>}
|
||||||
|
</div>
|
||||||
|
<div ref={thinkingEndRef} />
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{/* 工具调用展示区域 */}
|
||||||
|
{toolCalls.length > 0 && showToolDetails && (
|
||||||
|
<div className="mb-4 bg-yellow-950/20 rounded-lg border border-yellow-800/30 overflow-hidden">
|
||||||
|
<div
|
||||||
|
className="flex items-center justify-between px-3 py-2 bg-yellow-900/20 border-b border-yellow-800/20 cursor-pointer"
|
||||||
|
onClick={() => setShowToolDetails(!showToolDetails)}
|
||||||
|
>
|
||||||
|
<div className="flex items-center gap-2 text-xs text-yellow-500">
|
||||||
|
<Wrench className="w-4 h-4" />
|
||||||
|
<span className="uppercase tracking-wider">Tool Calls</span>
|
||||||
|
<Badge variant="outline" className="text-[10px] px-1.5 py-0 bg-yellow-900/30 border-yellow-700 text-yellow-400">
|
||||||
|
{toolCalls.length}
|
||||||
|
</Badge>
|
||||||
|
</div>
|
||||||
|
{showToolDetails ? <ChevronUp className="w-4 h-4 text-yellow-500" /> : <ChevronDown className="w-4 h-4 text-yellow-500" />}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div className="max-h-48 overflow-y-auto">
|
||||||
|
<div className="p-2 space-y-2">
|
||||||
|
{toolCalls.slice(-5).map((tc, idx) => (
|
||||||
|
<ToolCallCard key={`${tc.name}-${idx}`} toolCall={tc} />
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
<div className="flex items-center justify-between mb-3">
|
<div className="flex items-center justify-between mb-3">
|
||||||
<div className="flex items-center gap-2 text-xs text-cyan-400">
|
<div className="flex items-center gap-2 text-xs text-cyan-400">
|
||||||
<Terminal className="w-4 h-4" />
|
<Terminal className="w-4 h-4" />
|
||||||
<span className="uppercase tracking-wider">Execution Log</span>
|
<span className="uppercase tracking-wider">Execution Log</span>
|
||||||
{isStreaming && (
|
{(isStreaming || isStreamConnected) && (
|
||||||
<span className="flex items-center gap-1 text-green-400">
|
<span className="flex items-center gap-1 text-green-400">
|
||||||
<span className="w-2 h-2 bg-green-400 rounded-full animate-pulse" />
|
<span className="w-2 h-2 bg-green-400 rounded-full animate-pulse" />
|
||||||
LIVE
|
LIVE
|
||||||
|
|
@ -540,6 +674,94 @@ function EventLine({ event }: { event: AgentEvent }) {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 工具调用卡片组件
|
||||||
|
/** Props for `ToolCallCard`: a single tool invocation to render. */
interface ToolCallProps {
  toolCall: {
    /** Tool name shown in the card header. */
    name: string;
    /** One of the three states the card has a visual style for. */
    status: 'running' | 'success' | 'error';
    /** Arguments the tool was called with. */
    input: Record<string, unknown>;
    /** Tool output, when one has been recorded. */
    output?: unknown;
    /** Duration of the call in milliseconds, when one has been recorded. */
    durationMs?: number;
  };
}
|
||||||
|
|
||||||
|
/**
 * Collapsible card describing one tool call: status icon, tool name, optional
 * duration and a status badge in the header, plus the (truncated) input and
 * output payloads once expanded.
 *
 * @param toolCall - The tool invocation to render.
 */
function ToolCallCard({ toolCall }: ToolCallProps) {
  const [expanded, setExpanded] = useState(false);

  // Longest payload excerpt (in characters) rendered inside the card.
  const maxPreviewLength = 500;

  const statusConfig = {
    running: {
      icon: <Loader2 className="w-3 h-3 animate-spin text-yellow-400" />,
      badge: "bg-yellow-900/30 border-yellow-700 text-yellow-400",
      text: "Running",
    },
    success: {
      icon: <CheckCircle2 className="w-3 h-3 text-green-400" />,
      badge: "bg-green-900/30 border-green-700 text-green-400",
      text: "Done",
    },
    error: {
      icon: <XCircle className="w-3 h-3 text-red-400" />,
      badge: "bg-red-900/30 border-red-700 text-red-400",
      text: "Error",
    },
  };

  const config = statusConfig[toolCall.status];

  const hasInput = Boolean(toolCall.input) && Object.keys(toolCall.input).length > 0;

  // Explicit presence check: `output && <jsx>` would render a stray "0" for a
  // numeric zero and would hide legitimate falsy results such as `false`.
  const hasOutput =
    toolCall.output !== undefined && toolCall.output !== null && toolCall.output !== '';

  // JSON.stringify returns undefined for values it cannot serialise, so fall
  // back to String() before slicing.
  const outputText = hasOutput
    ? (typeof toolCall.output === 'string'
        ? toolCall.output
        : JSON.stringify(toolCall.output, null, 2) ?? String(toolCall.output)
      ).slice(0, maxPreviewLength)
    : '';

  return (
    <div className="bg-gray-900/50 rounded border border-gray-700/50 overflow-hidden">
      <div
        className="flex items-center justify-between px-2 py-1.5 cursor-pointer hover:bg-gray-800/50"
        onClick={() => setExpanded(!expanded)}
      >
        <div className="flex items-center gap-2">
          {config.icon}
          <span className="text-xs font-mono text-gray-300">{toolCall.name}</span>
        </div>
        <div className="flex items-center gap-2">
          {/* Ternary instead of `&&` so a duration of 0 does not render a bare "0". */}
          {toolCall.durationMs ? (
            <span className="text-[10px] text-gray-500">
              <Clock className="w-2.5 h-2.5 inline mr-0.5" />
              {toolCall.durationMs}ms
            </span>
          ) : null}
          <Badge variant="outline" className={`text-[10px] px-1 py-0 ${config.badge}`}>
            {config.text}
          </Badge>
          {expanded ? <ChevronUp className="w-3 h-3 text-gray-500" /> : <ChevronDown className="w-3 h-3 text-gray-500" />}
        </div>
      </div>

      {expanded && (
        <div className="border-t border-gray-700/50 text-[11px] font-mono">
          {/* Input */}
          {hasInput && (
            <div className="p-2 border-b border-gray-800/50">
              <span className="text-gray-500 text-[10px] uppercase">Input:</span>
              <pre className="mt-1 text-cyan-300/80 whitespace-pre-wrap break-all max-h-24 overflow-y-auto">
                {JSON.stringify(toolCall.input, null, 2).slice(0, maxPreviewLength)}
              </pre>
            </div>
          )}

          {/* Output */}
          {hasOutput && (
            <div className="p-2">
              <span className="text-gray-500 text-[10px] uppercase">Output:</span>
              <pre className="mt-1 text-green-300/80 whitespace-pre-wrap break-all max-h-24 overflow-y-auto">
                {outputText}
              </pre>
            </div>
          )}
        </div>
      )}
    </div>
  );
}
|
||||||
|
|
||||||
// 发现卡片组件
|
// 发现卡片组件
|
||||||
function FindingCard({ finding }: { finding: AgentFinding }) {
|
function FindingCard({ finding }: { finding: AgentFinding }) {
|
||||||
const colorClass = severityColors[finding.severity] || severityColors.info;
|
const colorClass = severityColors[finding.severity] || severityColors.info;
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,496 @@
|
||||||
|
/**
|
||||||
|
* Agent 流式事件处理
|
||||||
|
*
|
||||||
|
* 最佳实践:
|
||||||
|
* 1. 使用 EventSource API 或 fetch + ReadableStream
|
||||||
|
* 2. 支持重连机制
|
||||||
|
* 3. 分类处理不同事件类型
|
||||||
|
*/
|
||||||
|
|
||||||
|
/**
 * Every event name that can appear on the agent event stream, grouped by the
 * part of the audit lifecycle it belongs to.
 */
export type StreamEventType =
  // LLM thinking
  | 'thinking_start' | 'thinking_token' | 'thinking_end'
  // Tool calls
  | 'tool_call_start' | 'tool_call_input' | 'tool_call_output' | 'tool_call_end' | 'tool_call_error'
  // Graph nodes
  | 'node_start' | 'node_end'
  // Phases
  | 'phase_start' | 'phase_end' | 'phase_complete'
  // Findings
  | 'finding_new' | 'finding_verified'
  // Status messages
  | 'progress' | 'info' | 'warning' | 'error'
  // Task lifecycle
  | 'task_start' | 'task_complete' | 'task_error' | 'task_cancel' | 'task_end'
  // Keep-alive
  | 'heartbeat';
|
||||||
|
|
||||||
|
/** Tool call details attached to tool-related stream events. */
export interface ToolCallDetail {
  /** Name of the tool. */
  name: string;
  /** Arguments of the call, when the event carries them. */
  input?: Record<string, unknown>;
  /** Result of the call, when the event carries it. */
  output?: unknown;
  /** Duration of the call in milliseconds (wire format, snake_case). */
  duration_ms?: number;
}
|
||||||
|
|
||||||
|
/**
 * One event as delivered on the stream. Only `type` is mandatory; the other
 * fields are present depending on the event type.
 */
export interface StreamEventData {
  id?: string;
  /** Event name; decides which callback the handler invokes. */
  type: StreamEventType;
  phase?: string;
  message?: string;
  sequence?: number;
  timestamp?: string;
  /** Tool call details, read by the handler for tool_call_* events. */
  tool?: ToolCallDetail;
  metadata?: Record<string, unknown>;
  tokens_used?: number;

  // ---- Fields specific to particular event types ----

  /** thinking_token: the newly produced token. */
  token?: string;
  /** thinking_token / thinking_end: the text accumulated so far. */
  accumulated?: string;
  /** task_end: final task status. */
  status?: string;
  /** task_error: error description. */
  error?: string;
  /** task_complete: number of findings. */
  findings_count?: number;
  /** task_complete: security score. */
  security_score?: number;
}
|
||||||
|
|
||||||
|
/** Signature of a generic listener that receives every stream event. */
export type StreamEventCallback = (
  event: StreamEventData,
) => void;
|
||||||
|
|
||||||
|
/**
 * Options for an agent event stream: three request parameters followed by
 * optional per-event callbacks. Every member is optional.
 */
export interface StreamOptions {
  // ---- Request parameters (sent as query string values) ----

  /** Sent as `include_thinking`. */
  includeThinking?: boolean;
  /** Sent as `include_tool_calls`. */
  includeToolCalls?: boolean;
  /** Sent as `after_sequence`. */
  afterSequence?: number;

  // ---- LLM thinking ----

  onThinkingStart?: () => void;
  /** Receives the new token and the text accumulated so far. */
  onThinkingToken?: (token: string, accumulated: string) => void;
  /** Receives the complete thinking text. */
  onThinkingEnd?: (fullResponse: string) => void;

  // ---- Tool calls ----

  onToolStart?: (toolName: string, input: Record<string, unknown>) => void;
  onToolEnd?: (toolName: string, output: unknown, durationMs: number) => void;

  // ---- Graph nodes ----

  onNodeStart?: (nodeName: string, phase: string) => void;
  onNodeEnd?: (nodeName: string, summary: Record<string, unknown>) => void;

  // ---- Findings and progress ----

  /** `isVerified` is true for `finding_verified` events. */
  onFinding?: (finding: Record<string, unknown>, isVerified: boolean) => void;
  onProgress?: (current: number, total: number, message: string) => void;

  // ---- Task lifecycle ----

  onComplete?: (data: { findingsCount: number; securityScore: number }) => void;
  onError?: (error: string) => void;
  onHeartbeat?: () => void;

  /** Generic callback invoked for every event. */
  onEvent?: StreamEventCallback;
}
|
||||||
|
|
||||||
|
/**
 * Agent stream event handler.
 *
 * Opens the `/api/v1/agent-tasks/{taskId}/stream` endpoint with `fetch`
 * (so an Authorization header can be sent), decodes the Server-Sent Events
 * body incrementally and forwards each event to the callbacks supplied in
 * `StreamOptions`.
 *
 * Connection failures are retried up to `maxReconnectAttempts` times with a
 * linearly growing delay; `disconnect()` aborts the in-flight request and
 * cancels any pending retry.
 */
export class AgentStreamHandler {
  private taskId: string;
  // Never assigned by the fetch transport; kept so disconnect() still closes
  // an EventSource should one ever be attached.
  private eventSource: EventSource | null = null;
  private options: StreamOptions;
  private reconnectAttempts = 0;
  private maxReconnectAttempts = 5;
  private reconnectDelay = 1000;
  private isConnected = false;
  private thinkingBuffer: string[] = [];
  // Aborts the in-flight fetch / body reader of the current connection.
  private abortController: AbortController | null = null;
  // Pending reconnect timer, if a retry has been scheduled.
  private reconnectTimer: ReturnType<typeof setTimeout> | null = null;
  // Highest `sequence` seen so far; sent as `after_sequence` when reconnecting.
  // NOTE(review): assumes the server replays only events with a greater
  // sequence number — confirm against the endpoint.
  private lastSequence = 0;

  /**
   * @param taskId - Identifier of the agent task whose events are streamed.
   * @param options - Request parameters and callbacks; thinking and tool-call
   *   events are requested by default, starting after sequence 0.
   */
  constructor(taskId: string, options: StreamOptions = {}) {
    this.taskId = taskId;
    this.options = {
      includeThinking: true,
      includeToolCalls: true,
      afterSequence: 0,
      ...options,
    };
    this.lastSequence = this.options.afterSequence ?? 0;
  }

  /**
   * Start listening to the event stream.
   *
   * Reports `onError` and returns without connecting when no access token is
   * stored. Any connection that is still open is aborted first, so calling
   * this twice never produces two concurrent readers.
   */
  connect(): void {
    const token = localStorage.getItem('access_token');
    if (!token) {
      this.options.onError?.('未登录');
      return;
    }

    this.clearReconnectTimer();
    this.abortController?.abort();

    const params = new URLSearchParams({
      include_thinking: String(this.options.includeThinking ?? true),
      include_tool_calls: String(this.options.includeToolCalls ?? true),
      after_sequence: String(this.lastSequence),
    });

    // EventSource cannot send custom headers, so fetch + ReadableStream is used.
    void this.connectWithFetch(token, params);
  }

  /**
   * Connect with fetch (supports custom headers) and pump the response body
   * until the stream ends, fails or is aborted.
   */
  private async connectWithFetch(token: string, params: URLSearchParams): Promise<void> {
    const url = `/api/v1/agent-tasks/${this.taskId}/stream?${params}`;
    const controller = new AbortController();
    this.abortController = controller;

    try {
      const response = await fetch(url, {
        headers: {
          'Authorization': `Bearer ${token}`,
          'Accept': 'text/event-stream',
        },
        signal: controller.signal,
      });

      if (!response.ok) {
        throw new Error(`HTTP ${response.status}: ${response.statusText}`);
      }

      this.isConnected = true;
      this.reconnectAttempts = 0;

      const reader = response.body?.getReader();
      if (!reader) {
        throw new Error('无法获取响应流');
      }

      const decoder = new TextDecoder();
      let buffer = '';

      while (true) {
        const { done, value } = await reader.read();
        if (done) {
          break;
        }

        buffer += decoder.decode(value, { stream: true });

        // Parse the complete SSE events; keep the unfinished tail for later.
        const events = this.parseSSE(buffer);
        buffer = events.remaining;

        for (const event of events.parsed) {
          // A terminal event (or the caller) may have disconnected us while
          // this batch was being dispatched; drop whatever follows.
          if (controller.signal.aborted) {
            return;
          }
          this.handleEvent(event);
        }
      }

      // The server closed the stream normally.
      if (this.abortController === controller) {
        this.isConnected = false;
      }
    } catch (error) {
      // An abort is a deliberate disconnect, not a failure: no retry, no error.
      if (controller.signal.aborted) {
        return;
      }

      this.isConnected = false;
      console.error('Stream connection error:', error);

      // Try to reconnect.
      if (this.reconnectAttempts < this.maxReconnectAttempts) {
        this.reconnectAttempts++;
        this.reconnectTimer = setTimeout(
          () => this.connect(),
          this.reconnectDelay * this.reconnectAttempts
        );
      } else {
        this.options.onError?.(`连接失败: ${error}`);
      }
    } finally {
      if (this.abortController === controller) {
        this.abortController = null;
      }
    }
  }

  /**
   * Parse SSE text.
   *
   * Only events terminated by a blank line are returned. Everything after the
   * last blank line — including an `event:` line whose `data:` has not arrived
   * yet — is handed back in `remaining` so it can be re-parsed once more bytes
   * are appended.
   */
  private parseSSE(buffer: string): { parsed: StreamEventData[]; remaining: string } {
    const parsed: StreamEventData[] = [];
    const text = buffer.replace(/\r\n/g, '\n');

    let start = 0;
    let end = text.indexOf('\n\n', start);
    while (end !== -1) {
      const event = this.parseEventBlock(text.slice(start, end));
      if (event) {
        parsed.push(event);
      }
      start = end + 2;
      end = text.indexOf('\n\n', start);
    }

    return { parsed, remaining: text.slice(start) };
  }

  /**
   * Turn the lines of one complete SSE event into a StreamEventData.
   * Returns null when the block yields no event type.
   */
  private parseEventBlock(block: string): StreamEventData | null {
    let eventType: StreamEventType | undefined;
    const dataLines: string[] = [];

    for (const line of block.split('\n')) {
      if (line === '' || line.startsWith(':')) {
        continue; // blank line or SSE comment
      }
      if (line.startsWith('event:')) {
        eventType = line.slice(6).trim() as StreamEventType;
      } else if (line.startsWith('data:')) {
        dataLines.push(line.slice(5).trim());
      }
    }

    const merged: Partial<StreamEventData> = {};
    if (eventType) {
      merged.type = eventType;
    }
    // Fields of the JSON payload (including `type`) override the event: line.
    for (const payload of this.decodeData(dataLines)) {
      Object.assign(merged, payload);
    }

    return merged.type ? (merged as StreamEventData) : null;
  }

  /**
   * Decode the `data:` lines of one event. The lines are first joined and
   * parsed as a single JSON document; if that fails, each line is parsed on
   * its own and unparsable lines are ignored.
   */
  private decodeData(dataLines: string[]): Array<Record<string, unknown>> {
    const isRecord = (value: unknown): value is Record<string, unknown> =>
      typeof value === 'object' && value !== null;

    if (dataLines.length === 0) {
      return [];
    }

    try {
      const whole: unknown = JSON.parse(dataLines.join('\n'));
      return isRecord(whole) ? [whole] : [];
    } catch {
      const payloads: Array<Record<string, unknown>> = [];
      for (const line of dataLines) {
        try {
          const single: unknown = JSON.parse(line);
          if (isRecord(single)) {
            payloads.push(single);
          }
        } catch {
          // Ignore parse errors.
        }
      }
      return payloads;
    }
  }

  /**
   * Handle one event: remember its sequence number, invoke the generic
   * callback, then the callback matching the event type.
   */
  private handleEvent(event: StreamEventData): void {
    if (typeof event.sequence === 'number' && event.sequence > this.lastSequence) {
      this.lastSequence = event.sequence;
    }

    // Generic callback.
    this.options.onEvent?.(event);

    // Dispatch by type.
    switch (event.type) {
      // LLM thinking
      case 'thinking_start':
        this.thinkingBuffer = [];
        this.options.onThinkingStart?.();
        break;

      case 'thinking_token':
        if (event.token) {
          this.thinkingBuffer.push(event.token);
          this.options.onThinkingToken?.(
            event.token,
            event.accumulated || this.thinkingBuffer.join('')
          );
        }
        break;

      case 'thinking_end': {
        const fullResponse = event.accumulated || this.thinkingBuffer.join('');
        this.thinkingBuffer = [];
        this.options.onThinkingEnd?.(fullResponse);
        break;
      }

      // Tool calls
      case 'tool_call_start':
        if (event.tool) {
          this.options.onToolStart?.(
            event.tool.name,
            event.tool.input || {}
          );
        }
        break;

      case 'tool_call_end':
        if (event.tool) {
          this.options.onToolEnd?.(
            event.tool.name,
            event.tool.output,
            event.tool.duration_ms || 0
          );
        }
        break;

      // Nodes
      case 'node_start':
        this.options.onNodeStart?.(
          event.metadata?.node as string || 'unknown',
          event.phase || ''
        );
        break;

      case 'node_end':
        this.options.onNodeEnd?.(
          event.metadata?.node as string || 'unknown',
          event.metadata?.summary as Record<string, unknown> || {}
        );
        break;

      // Findings
      case 'finding_new':
      case 'finding_verified':
        this.options.onFinding?.(
          event.metadata || {},
          event.type === 'finding_verified'
        );
        break;

      // Progress
      case 'progress':
        this.options.onProgress?.(
          event.metadata?.current as number || 0,
          event.metadata?.total as number || 100,
          event.message || ''
        );
        break;

      // Task finished
      case 'task_complete':
      case 'task_end':
        if (event.status !== 'cancelled' && event.status !== 'failed') {
          this.options.onComplete?.({
            findingsCount:
              event.findings_count ??
              (event.metadata?.findings_count as number | undefined) ??
              0,
            // `??` rather than `||`: a real score of 0 must not become 100.
            securityScore:
              event.security_score ??
              (event.metadata?.security_score as number | undefined) ??
              100,
          });
        }
        this.disconnect();
        break;

      // Errors
      case 'task_error':
      case 'error':
        this.options.onError?.(event.error || event.message || '未知错误');
        this.disconnect();
        break;

      // Heartbeat
      case 'heartbeat':
        this.options.onHeartbeat?.();
        break;
    }
  }

  /** Cancel a scheduled reconnect attempt, if there is one. */
  private clearReconnectTimer(): void {
    if (this.reconnectTimer !== null) {
      clearTimeout(this.reconnectTimer);
      this.reconnectTimer = null;
    }
  }

  /**
   * Disconnect: abort the in-flight request, cancel any pending reconnect and
   * close the EventSource if one is attached.
   */
  disconnect(): void {
    this.isConnected = false;
    this.clearReconnectTimer();
    if (this.abortController) {
      this.abortController.abort();
      this.abortController = null;
    }
    if (this.eventSource) {
      this.eventSource.close();
      this.eventSource = null;
    }
  }

  /**
   * Whether the stream is currently connected.
   */
  get connected(): boolean {
    return this.isConnected;
  }
}
|
||||||
|
|
||||||
|
/**
 * Convenience factory: builds an `AgentStreamHandler` for the given task.
 * The handler is returned unconnected.
 *
 * @param taskId - Identifier of the agent task to stream.
 * @param options - Stream options and callbacks; defaults to none.
 * @returns The newly created handler.
 */
export function createAgentStream(
  taskId: string,
  options: StreamOptions = {}
): AgentStreamHandler {
  const handler = new AgentStreamHandler(taskId, options);
  return handler;
}
|
||||||
|
|
||||||
|
/**
 * Aggregated view of an agent stream, shaped for use as React state.
 *
 * Example of React-hook style usage:
 *
 * ```tsx
 * const { events, thinking, toolCalls, connect, disconnect } = useAgentStream(taskId);
 *
 * useEffect(() => {
 *   connect();
 *   return () => disconnect();
 * }, [taskId]);
 * ```
 */
export interface AgentStreamState {
  /** Events received from the stream. */
  events: StreamEventData[];

  /** Current thinking text. */
  thinking: string;
  /** Whether a thinking pass is in progress. */
  isThinking: boolean;

  /** Tool calls observed so far together with their status. */
  toolCalls: Array<{
    name: string;
    input: Record<string, unknown>;
    output?: unknown;
    durationMs?: number;
    status: 'running' | 'success' | 'error';
  }>;

  /** Phase reported by the most recent node start. */
  currentPhase: string;
  /** Progress counters plus the derived percentage. */
  progress: {
    current: number;
    total: number;
    percentage: number;
  };

  /** Findings collected from the stream. */
  findings: Array<Record<string, unknown>>;

  /** True once the task has completed or an error was reported. */
  isComplete: boolean;
  /** Error message, or null when no error was reported. */
  error: string | null;
}
|
||||||
|
|
||||||
|
/**
 * Create a stream handler that maintains an `AgentStreamState` for React
 * state management.
 *
 * The handler's callbacks fold every event into one internal state object;
 * after each change `onStateChange` receives a fresh shallow copy of it. The
 * returned handler is not connected yet.
 *
 * @param taskId - Identifier of the agent task to stream.
 * @param onStateChange - Receives a shallow copy of the state after each update.
 * @returns The handler wired to the state reducer.
 */
export function createAgentStreamWithState(
  taskId: string,
  onStateChange: (state: AgentStreamState) => void
): AgentStreamHandler {
  const state: AgentStreamState = {
    events: [],
    thinking: '',
    isThinking: false,
    toolCalls: [],
    currentPhase: '',
    progress: { current: 0, total: 100, percentage: 0 },
    findings: [],
    isComplete: false,
    error: null,
  };

  // Apply a partial update, then publish a copy so React sees a new reference.
  const updateState = (updates: Partial<AgentStreamState>) => {
    Object.assign(state, updates);
    onStateChange({ ...state });
  };

  return new AgentStreamHandler(taskId, {
    onEvent: (event) => {
      updateState({
        events: [...state.events, event].slice(-500), // keep the latest 500
      });
    },
    onThinkingStart: () => {
      updateState({ isThinking: true, thinking: '' });
    },
    onThinkingToken: (_, accumulated) => {
      updateState({ thinking: accumulated });
    },
    onThinkingEnd: (response) => {
      updateState({ isThinking: false, thinking: response });
    },
    onToolStart: (name, input) => {
      updateState({
        toolCalls: [
          ...state.toolCalls,
          { name, input, status: 'running' },
        ],
      });
    },
    onToolEnd: (name, output, durationMs) => {
      // Complete only the oldest running call with this name. Completing all
      // of them would give overlapping calls of the same tool the same output.
      const finishedIndex = state.toolCalls.findIndex(
        (tc) => tc.name === name && tc.status === 'running'
      );
      updateState({
        toolCalls: state.toolCalls.map((tc, index) =>
          index === finishedIndex
            ? { ...tc, output, durationMs, status: 'success' as const }
            : tc
        ),
      });
    },
    onNodeStart: (_, phase) => {
      updateState({ currentPhase: phase });
    },
    onProgress: (current, total, _) => {
      updateState({
        progress: {
          current,
          total,
          // Guard against division by zero.
          percentage: total > 0 ? Math.round((current / total) * 100) : 0,
        },
      });
    },
    onFinding: (finding, _) => {
      updateState({
        findings: [...state.findings, finding],
      });
    },
    onComplete: () => {
      updateState({ isComplete: true });
    },
    onError: (error) => {
      updateState({ error, isComplete: true });
    },
  });
}
|
||||||
|
|
||||||
|
|
@ -43,6 +43,9 @@ export interface AgentTask {
|
||||||
|
|
||||||
// 进度
|
// 进度
|
||||||
progress_percentage: number;
|
progress_percentage: number;
|
||||||
|
|
||||||
|
// 错误信息
|
||||||
|
error_message: string | null;
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface AgentFinding {
|
export interface AgentFinding {
|
||||||
|
|
@ -249,7 +252,7 @@ export async function* streamAgentEvents(
|
||||||
afterSequence = 0,
|
afterSequence = 0,
|
||||||
signal?: AbortSignal
|
signal?: AbortSignal
|
||||||
): AsyncGenerator<AgentEvent, void, unknown> {
|
): AsyncGenerator<AgentEvent, void, unknown> {
|
||||||
const token = localStorage.getItem("auth_token");
|
const token = localStorage.getItem("access_token");
|
||||||
const baseUrl = import.meta.env.VITE_API_URL || "";
|
const baseUrl = import.meta.env.VITE_API_URL || "";
|
||||||
const url = `${baseUrl}/api/v1/agent-tasks/${taskId}/events?after_sequence=${afterSequence}`;
|
const url = `${baseUrl}/api/v1/agent-tasks/${taskId}/events?after_sequence=${afterSequence}`;
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue