feat(agent): implement streaming support for agent events and enhance UI components

- Introduce streaming capabilities for agent events, allowing real-time updates during audits.
- Add new hooks for managing agent stream events in React components.
- Enhance the AgentAudit page to display LLM thinking processes and tool call details in real-time.
- Update API endpoints to support streaming event data and improve error handling.
- Refactor UI components for better organization and user experience during audits.
This commit is contained in:
lintsinghua 2025-12-11 20:33:46 +08:00
parent a43ebf1793
commit 58c918f557
42 changed files with 4664 additions and 77 deletions

View File

@ -59,3 +59,5 @@ CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]

View File

@ -103,3 +103,5 @@ datefmt = %H:%M:%S

View File

@ -90,3 +90,5 @@ else:

View File

@ -25,3 +25,5 @@ def downgrade() -> None:

View File

@ -29,6 +29,7 @@ from app.models.agent_task import (
from app.models.project import Project
from app.models.user import User
from app.services.agent import AgentRunner, EventManager, run_agent_task
from app.services.agent.streaming import StreamHandler, StreamEvent, StreamEventType
logger = logging.getLogger(__name__)
router = APIRouter()
@ -432,6 +433,171 @@ async def stream_agent_events(
)
@router.get("/{task_id}/stream")
async def stream_agent_with_thinking(
    task_id: str,
    include_thinking: bool = Query(True, description="是否包含 LLM 思考过程"),
    include_tool_calls: bool = Query(True, description="是否包含工具调用详情"),
    after_sequence: int = Query(0, ge=0, description="从哪个序号之后开始"),
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(deps.get_current_user),
):
    """
    Enhanced agent event stream (SSE).

    Polls persisted ``AgentEvent`` rows and relays them to the client as
    Server-Sent Events, supporting:
      - token-level streaming of the LLM thinking process
      - detailed tool-call input/output
      - node execution status, findings and progress updates

    Emitted event types: ``thinking_start`` / ``thinking_token`` /
    ``thinking_end``, ``tool_call_start`` / ``tool_call_end``,
    ``node_start`` / ``node_end``, ``finding_new``, ``finding_verified``,
    ``progress``, ``task_complete``, ``task_error``, ``heartbeat``.

    Clients may resume from a known position via ``after_sequence`` and
    opt out of thinking / tool-call detail with the boolean flags.

    Raises:
        HTTPException: 404 if the task does not exist, 403 if the current
            user does not own the task's project.
    """
    task = await db.get(AgentTask, task_id)
    if not task:
        raise HTTPException(status_code=404, detail="任务不存在")

    project = await db.get(Project, task.project_id)
    if not project or project.owner_id != current_user.id:
        raise HTTPException(status_code=403, detail="无权访问此任务")

    async def enhanced_event_generator():
        """Yield SSE frames until the task ends, times out, or errors."""
        last_sequence = after_sequence
        poll_interval = 0.3      # short polling interval for near-real-time streaming
        heartbeat_interval = 15  # seconds between heartbeat frames
        max_idle = 600           # close the stream after 10 minutes without events
        idle_time = 0
        last_heartbeat = 0

        # Event types the client opted out of.
        skip_types = set()
        if not include_thinking:
            skip_types.update(["thinking_start", "thinking_token", "thinking_end"])
        if not include_tool_calls:
            skip_types.update(["tool_call_start", "tool_call_input", "tool_call_output", "tool_call_end"])

        terminal_statuses = (
            AgentTaskStatus.COMPLETED,
            AgentTaskStatus.FAILED,
            AgentTaskStatus.CANCELLED,
        )

        while True:
            try:
                async with async_session_factory() as session:
                    # Fetch the next batch of events after the last seen sequence.
                    result = await session.execute(
                        select(AgentEvent)
                        .where(AgentEvent.task_id == task_id)
                        .where(AgentEvent.sequence > last_sequence)
                        .order_by(AgentEvent.sequence)
                        .limit(100)
                    )
                    events = result.scalars().all()

                    # Current task status, used below to detect completion.
                    current_task = await session.get(AgentTask, task_id)
                    task_status = current_task.status if current_task else None

                if events:
                    idle_time = 0
                    for event in events:
                        last_sequence = event.sequence

                        # Normalize the event type to a plain string.
                        event_type = event.event_type.value if hasattr(event.event_type, 'value') else str(event.event_type)

                        # Apply the client's event-type filter.
                        if event_type in skip_types:
                            continue

                        data = {
                            "id": event.id,
                            "type": event_type,
                            "phase": event.phase.value if event.phase and hasattr(event.phase, 'value') else event.phase,
                            "message": event.message,
                            "sequence": event.sequence,
                            "timestamp": event.created_at.isoformat() if event.created_at else None,
                        }

                        # Attach tool-call detail when requested and present.
                        if include_tool_calls and event.tool_name:
                            data["tool"] = {
                                "name": event.tool_name,
                                "input": event.tool_input,
                                "output": event.tool_output,
                                "duration_ms": event.tool_duration_ms,
                            }

                        if event.event_metadata:
                            data["metadata"] = event.event_metadata

                        if event.tokens_used:
                            data["tokens_used"] = event.tokens_used

                        # Standard SSE framing: "event:" + "data:" lines.
                        yield f"event: {event_type}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
                else:
                    idle_time += poll_interval

                    # BUGFIX: only end the stream once the backlog is fully
                    # drained. The previous version checked the terminal
                    # status even when a full batch of events had just been
                    # read, which could drop the tail of the event log
                    # (events beyond the 100-row limit) for finished tasks.
                    if task_status in terminal_statuses:
                        end_data = {
                            "type": "task_end",
                            "status": task_status.value,
                            "message": f"任务{'完成' if task_status == AgentTaskStatus.COMPLETED else '结束'}",
                        }
                        yield f"event: task_end\ndata: {json.dumps(end_data, ensure_ascii=False)}\n\n"
                        break

                # Periodic heartbeat so proxies/clients keep the connection open.
                last_heartbeat += poll_interval
                if last_heartbeat >= heartbeat_interval:
                    last_heartbeat = 0
                    heartbeat_data = {
                        "type": "heartbeat",
                        "timestamp": datetime.now(timezone.utc).isoformat(),
                        "last_sequence": last_sequence,
                    }
                    yield f"event: heartbeat\ndata: {json.dumps(heartbeat_data)}\n\n"

                # Idle-timeout guard.
                if idle_time >= max_idle:
                    timeout_data = {"type": "timeout", "message": "连接超时"}
                    yield f"event: timeout\ndata: {json.dumps(timeout_data)}\n\n"
                    break

                await asyncio.sleep(poll_interval)

            except Exception as e:
                logger.error(f"Stream error: {e}")
                error_data = {"type": "error", "message": str(e)}
                yield f"event: error\ndata: {json.dumps(error_data)}\n\n"
                break

    return StreamingResponse(
        enhanced_event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
            "Content-Type": "text/event-stream; charset=utf-8",
        }
    )
@router.get("/{task_id}/events/list", response_model=List[AgentEventResponse])
async def list_agent_events(
task_id: str,

View File

@ -210,3 +210,5 @@ async def remove_project_member(

View File

@ -225,3 +225,5 @@ async def toggle_user_status(

View File

@ -29,3 +29,5 @@ def get_password_hash(password: str) -> str:

View File

@ -12,3 +12,5 @@ class Base:

View File

@ -28,3 +28,5 @@ async def async_session_factory():

View File

@ -24,3 +24,5 @@ class InstantAnalysis(Base):

View File

@ -25,3 +25,5 @@ class User(Base):

View File

@ -30,3 +30,5 @@ class UserConfig(Base):

View File

@ -10,3 +10,5 @@ class TokenPayload(BaseModel):

View File

@ -41,3 +41,5 @@ class UserListResponse(BaseModel):

View File

@ -147,11 +147,11 @@ class AnalysisAgent(BaseAgent):
deep_findings = await self._analyze_entry_points(entry_points)
all_findings.extend(deep_findings)
# 分析高风险区域
# 分析高风险区域(现在会调用 LLM)
risk_findings = await self._analyze_high_risk_areas(high_risk_areas)
all_findings.extend(risk_findings)
# 语义搜索常见漏洞
# 语义搜索常见漏洞(现在会调用 LLM)
vuln_types = config.get("target_vulnerabilities", [
"sql_injection", "xss", "command_injection",
"path_traversal", "ssrf", "hardcoded_secret",
@ -164,6 +164,12 @@ class AnalysisAgent(BaseAgent):
await self.emit_thinking(f"搜索 {vuln_type} 相关代码...")
vuln_findings = await self._search_vulnerability_pattern(vuln_type)
all_findings.extend(vuln_findings)
# 🔥 3. 如果还没有发现,使用 LLM 进行全面扫描
if len(all_findings) < 3:
await self.emit_thinking("执行 LLM 全面代码扫描...")
llm_findings = await self._llm_comprehensive_scan(tech_stack)
all_findings.extend(llm_findings)
# 去重
all_findings = self._deduplicate_findings(all_findings)
@ -292,12 +298,12 @@ class AnalysisAgent(BaseAgent):
return findings
async def _analyze_high_risk_areas(self, high_risk_areas: List[str]) -> List[Dict]:
"""分析高风险区域"""
"""分析高风险区域 - 使用 LLM 深度分析"""
findings = []
pattern_tool = self.tools.get("pattern_match")
read_tool = self.tools.get("read_file")
search_tool = self.tools.get("search_code")
code_analysis_tool = self.tools.get("code_analysis")
if not search_tool:
return findings
@ -305,36 +311,92 @@ class AnalysisAgent(BaseAgent):
# 在高风险区域搜索危险模式
dangerous_patterns = [
("execute(", "sql_injection"),
("query(", "sql_injection"),
("eval(", "code_injection"),
("system(", "command_injection"),
("exec(", "command_injection"),
("subprocess", "command_injection"),
("innerHTML", "xss"),
("document.write", "xss"),
("open(", "path_traversal"),
("requests.get", "ssrf"),
]
for pattern, vuln_type in dangerous_patterns[:5]:
analyzed_files = set()
for pattern, vuln_type in dangerous_patterns[:8]:
if self.is_cancelled:
break
result = await search_tool.execute(keyword=pattern, max_results=10)
if result.success and result.metadata.get("matches", 0) > 0:
for match in result.metadata.get("results", [])[:3]:
for match in result.metadata.get("results", [])[:5]:
file_path = match.get("file", "")
line = match.get("line", 0)
# 检查是否在高风险区域
in_high_risk = any(
area in file_path for area in high_risk_areas
)
# 避免重复分析同一个文件的同一区域
file_key = f"{file_path}:{line // 50}"
if file_key in analyzed_files:
continue
analyzed_files.add(file_key)
if in_high_risk or True: # 暂时包含所有
# 🔥 使用 LLM 深度分析找到的代码
if read_tool and code_analysis_tool:
await self.emit_thinking(f"LLM 分析 {file_path}:{line}{vuln_type} 风险...")
# 读取代码上下文
read_result = await read_tool.execute(
file_path=file_path,
start_line=max(1, line - 15),
end_line=line + 25,
)
if read_result.success:
# 调用 LLM 分析
analysis_result = await code_analysis_tool.execute(
code=read_result.data,
file_path=file_path,
focus=vuln_type,
)
if analysis_result.success and analysis_result.metadata.get("issues"):
for issue in analysis_result.metadata["issues"]:
findings.append({
"vulnerability_type": issue.get("type", vuln_type),
"severity": issue.get("severity", "medium"),
"title": issue.get("title", f"LLM 发现: {vuln_type}"),
"description": issue.get("description", ""),
"file_path": file_path,
"line_start": issue.get("line", line),
"code_snippet": issue.get("code_snippet", match.get("match", "")),
"suggestion": issue.get("suggestion", ""),
"ai_explanation": issue.get("ai_explanation", ""),
"source": "llm_analysis",
"needs_verification": True,
})
elif analysis_result.success:
# LLM 分析了但没发现问题,仍记录原始发现
findings.append({
"vulnerability_type": vuln_type,
"severity": "low",
"title": f"疑似 {vuln_type}: {pattern}",
"description": f"{file_path} 中发现危险模式,但 LLM 分析未确认",
"file_path": file_path,
"line_start": line,
"code_snippet": match.get("match", ""),
"source": "pattern_search",
"needs_verification": True,
})
else:
# 没有 LLM 工具,使用基础模式匹配
findings.append({
"vulnerability_type": vuln_type,
"severity": "high" if in_high_risk else "medium",
"severity": "medium",
"title": f"疑似 {vuln_type}: {pattern}",
"description": f"{file_path} 中发现危险模式 {pattern}",
"file_path": file_path,
"line_start": match.get("line", 0),
"line_start": line,
"code_snippet": match.get("match", ""),
"source": "pattern_search",
"needs_verification": True,
@ -343,10 +405,13 @@ class AnalysisAgent(BaseAgent):
return findings
async def _search_vulnerability_pattern(self, vuln_type: str) -> List[Dict]:
"""搜索特定漏洞模式"""
"""搜索特定漏洞模式 - 使用 RAG + LLM"""
findings = []
security_tool = self.tools.get("security_search")
code_analysis_tool = self.tools.get("code_analysis")
read_tool = self.tools.get("read_file")
if not security_tool:
return findings
@ -357,20 +422,176 @@ class AnalysisAgent(BaseAgent):
if result.success and result.metadata.get("results_count", 0) > 0:
for item in result.metadata.get("results", [])[:5]:
findings.append({
"vulnerability_type": vuln_type,
"severity": "medium",
"title": f"疑似 {vuln_type}",
"description": f"通过语义搜索发现可能存在 {vuln_type}",
"file_path": item.get("file_path", ""),
"line_start": item.get("line_start", 0),
"code_snippet": item.get("content", "")[:500],
"source": "rag_search",
"needs_verification": True,
})
file_path = item.get("file_path", "")
line_start = item.get("line_start", 0)
content = item.get("content", "")[:2000]
# 🔥 使用 LLM 验证 RAG 搜索结果
if code_analysis_tool and content:
await self.emit_thinking(f"LLM 验证 RAG 发现的 {vuln_type}...")
analysis_result = await code_analysis_tool.execute(
code=content,
file_path=file_path,
focus=vuln_type,
)
if analysis_result.success and analysis_result.metadata.get("issues"):
for issue in analysis_result.metadata["issues"]:
findings.append({
"vulnerability_type": issue.get("type", vuln_type),
"severity": issue.get("severity", "medium"),
"title": issue.get("title", f"LLM 确认: {vuln_type}"),
"description": issue.get("description", ""),
"file_path": file_path,
"line_start": issue.get("line", line_start),
"code_snippet": issue.get("code_snippet", content[:500]),
"suggestion": issue.get("suggestion", ""),
"ai_explanation": issue.get("ai_explanation", ""),
"source": "rag_llm_analysis",
"needs_verification": True,
})
else:
# RAG 找到但 LLM 未确认
findings.append({
"vulnerability_type": vuln_type,
"severity": "low",
"title": f"疑似 {vuln_type} (待确认)",
"description": f"RAG 搜索发现可能存在 {vuln_type},但 LLM 未确认",
"file_path": file_path,
"line_start": line_start,
"code_snippet": content[:500],
"source": "rag_search",
"needs_verification": True,
})
else:
findings.append({
"vulnerability_type": vuln_type,
"severity": "medium",
"title": f"疑似 {vuln_type}",
"description": f"通过语义搜索发现可能存在 {vuln_type}",
"file_path": file_path,
"line_start": line_start,
"code_snippet": content[:500],
"source": "rag_search",
"needs_verification": True,
})
return findings
async def _llm_comprehensive_scan(self, tech_stack: Dict) -> List[Dict]:
    """
    LLM comprehensive code scan.

    Fallback strategy used when the other analysis passes did not surface
    enough findings: lists key source files and has the LLM analyze each
    one directly via the ``code_analysis`` tool.

    Args:
        tech_stack: project tech-stack info; ``languages`` selects which
            file globs to scan.

    Returns:
        List of finding dicts (all marked ``needs_verification``).
    """
    findings = []
    list_tool = self.tools.get("list_files")
    read_tool = self.tools.get("read_file")
    code_analysis_tool = self.tools.get("code_analysis")
    # All three tools are required for this strategy.
    if not all([list_tool, read_tool, code_analysis_tool]):
        return findings
    await self.emit_thinking("LLM 全面扫描关键代码文件...")
    # Choose file globs from the detected languages.
    languages = tech_stack.get("languages", [])
    file_patterns = []
    if "Python" in languages:
        file_patterns.extend(["*.py"])
    if "JavaScript" in languages or "TypeScript" in languages:
        file_patterns.extend(["*.js", "*.ts"])
    if "Go" in languages:
        file_patterns.extend(["*.go"])
    if "Java" in languages:
        file_patterns.extend(["*.java"])
    if "PHP" in languages:
        file_patterns.extend(["*.php"])
    if not file_patterns:
        # Unknown stack: fall back to every supported extension.
        file_patterns = ["*.py", "*.js", "*.ts", "*.go", "*.java", "*.php"]
    # Directories most likely to contain entry points and handlers.
    key_dirs = ["src", "app", "api", "routes", "controllers", "handlers", "lib", "utils", "."]
    scanned_files = 0
    max_files_to_scan = 10  # hard cap on per-scan LLM calls
    for key_dir in key_dirs:
        if scanned_files >= max_files_to_scan or self.is_cancelled:
            break
        for pattern in file_patterns[:3]:
            if scanned_files >= max_files_to_scan or self.is_cancelled:
                break
            # List candidate files for this directory/pattern pair.
            list_result = await list_tool.execute(
                directory=key_dir,
                pattern=pattern,
                recursive=True,
                max_files=20,
            )
            if not list_result.success:
                continue
            # Extract file paths from the tool's textual output
            # (file lines are prefixed with a "📄 " marker).
            output = list_result.data
            file_paths = []
            for line in output.split('\n'):
                line = line.strip()
                if line.startswith('📄 '):
                    file_paths.append(line[2:].strip())
            # Analyze each candidate file (bounded to 5 per listing).
            for file_path in file_paths[:5]:
                if scanned_files >= max_files_to_scan or self.is_cancelled:
                    break
                # Skip tests, mocks and vendored/generated directories.
                if any(skip in file_path.lower() for skip in ['test', 'spec', 'mock', '__pycache__', 'node_modules']):
                    continue
                await self.emit_thinking(f"LLM 分析文件: {file_path}")
                # Read a bounded prefix of the file.
                read_result = await read_tool.execute(
                    file_path=file_path,
                    max_lines=200,
                )
                if not read_result.success:
                    continue
                scanned_files += 1
                # Deep LLM analysis of the file content.
                analysis_result = await code_analysis_tool.execute(
                    code=read_result.data,
                    file_path=file_path,
                )
                if analysis_result.success and analysis_result.metadata.get("issues"):
                    for issue in analysis_result.metadata["issues"]:
                        findings.append({
                            "vulnerability_type": issue.get("type", "other"),
                            "severity": issue.get("severity", "medium"),
                            "title": issue.get("title", "LLM 发现的安全问题"),
                            "description": issue.get("description", ""),
                            "file_path": file_path,
                            "line_start": issue.get("line", 0),
                            "code_snippet": issue.get("code_snippet", ""),
                            "suggestion": issue.get("suggestion", ""),
                            "ai_explanation": issue.get("ai_explanation", ""),
                            "source": "llm_comprehensive_scan",
                            "needs_verification": True,
                        })
    await self.emit_thinking(f"LLM 全面扫描完成,分析了 {scanned_files} 个文件")
    return findings
def _deduplicate_findings(self, findings: List[Dict]) -> List[Dict]:
"""去重发现"""
seen = set()

View File

@ -0,0 +1,375 @@
"""
True ReAct agent implementation: the LLM is the brain and takes part in
every decision.

ReAct loop:
1. Thought:     the LLM reasons about the current state and the next step
2. Action:      the LLM decides which tool to invoke
3. Observation: the tool is executed and its result observed
4. Repeat until the LLM decides the audit is complete
"""
import json
import logging
import re
from typing import List, Dict, Any, Optional, Tuple
from dataclasses import dataclass

from .base import BaseAgent, AgentConfig, AgentResult, AgentType, AgentPattern

logger = logging.getLogger(__name__)
# System prompt (Chinese, sent to the LLM verbatim) driving the ReAct loop:
# it documents the available tools via {tools_description}, the required
# Thought / Action / Action Input output format, the Final Answer JSON
# schema, and the audit strategy and vulnerability classes to focus on.
REACT_SYSTEM_PROMPT = """你是 DeepAudit 安全审计 Agent一个专业的代码安全分析专家。
## 你的任务
对目标项目进行全面的安全审计发现潜在的安全漏洞
## 你的工具
{tools_description}
## 工作方式
你需要通过 **思考-行动-观察** 循环来完成任务
1. **Thought**: 分析当前情况思考下一步应该做什么
2. **Action**: 选择一个工具并执行
3. **Observation**: 观察工具返回的结果
4. 重复上述过程直到你认为审计完成
## 输出格式
每一步必须严格按照以下格式输出
```
Thought: [你的思考过程分析当前状态决定下一步]
Action: [工具名称]
Action Input: [工具参数JSON 格式]
```
当你完成分析后输出
```
Thought: [总结分析结果]
Final Answer: [JSON 格式的最终发现]
```
## Final Answer 格式
```json
{{
"findings": [
{{
"vulnerability_type": "sql_injection",
"severity": "high",
"title": "SQL 注入漏洞",
"description": "详细描述",
"file_path": "path/to/file.py",
"line_start": 42,
"code_snippet": "危险代码片段",
"suggestion": "修复建议"
}}
],
"summary": "审计总结"
}}
```
## 审计策略建议
1. 先用 list_files 了解项目结构
2. 识别关键文件路由控制器数据库操作
3. 使用 search_code 搜索危险模式eval, exec, query, innerHTML
4. 读取可疑文件进行深度分析
5. 如果有 semgrep用它进行全面扫描
## 重点关注的漏洞类型
- SQL 注入 (query, execute, raw SQL)
- XSS (innerHTML, document.write, v-html)
- 命令注入 (exec, system, subprocess, child_process)
- 路径遍历 (open, readFile, path concatenation)
- SSRF (requests, fetch, http client)
- 硬编码密钥 (password, secret, api_key, token)
- 不安全的反序列化 (pickle, yaml.load, eval)
现在开始审计"""
@dataclass
class AgentStep:
    """A single Thought/Action/Observation step of the ReAct loop."""
    thought: str                          # the LLM's reasoning text for this step
    action: Optional[str] = None          # tool name chosen by the LLM, if any
    action_input: Optional[Dict] = None   # parsed tool arguments (JSON)
    observation: Optional[str] = None     # tool output observed after execution
    is_final: bool = False                # True when the LLM emitted a Final Answer
    final_answer: Optional[Dict] = None   # parsed Final Answer payload
class ReActAgent(BaseAgent):
    """
    A true ReAct agent.

    The LLM participates throughout, autonomously choosing tools and the
    analysis strategy via a Thought -> Action -> Observation loop.
    """
    def __init__(
        self,
        llm_service,
        tools: Dict[str, Any],
        event_emitter=None,
        agent_type: AgentType = AgentType.ANALYSIS,
        max_iterations: int = 30,
    ):
        """Configure the agent with the ReAct system prompt and tool set."""
        config = AgentConfig(
            name="ReActAgent",
            agent_type=agent_type,
            pattern=AgentPattern.REACT,
            max_iterations=max_iterations,
            system_prompt=REACT_SYSTEM_PROMPT,
        )
        super().__init__(config, llm_service, tools, event_emitter)
        # Full chat transcript resent to the LLM each iteration.
        self._conversation_history: List[Dict[str, str]] = []
        # Parsed steps (Thought/Action/Observation), in execution order.
        self._steps: List[AgentStep] = []
    def _get_tools_description(self) -> str:
        """Render a markdown description of all public tools for the prompt."""
        descriptions = []
        for name, tool in self.tools.items():
            if name.startswith("_"):
                continue
            desc = f"### {name}\n"
            desc += f"{tool.description}\n"
            # Append parameter docs when the tool exposes an args schema.
            if hasattr(tool, 'args_schema') and tool.args_schema:
                schema = tool.args_schema.schema()
                properties = schema.get("properties", {})
                if properties:
                    desc += "参数:\n"
                    for param_name, param_info in properties.items():
                        param_desc = param_info.get("description", "")
                        param_type = param_info.get("type", "string")
                        desc += f" - {param_name} ({param_type}): {param_desc}\n"
            descriptions.append(desc)
        return "\n".join(descriptions)
    def _build_system_prompt(self, project_info: Dict, task_context: str = "") -> str:
        """Fill the system prompt template with tool and project details."""
        tools_desc = self._get_tools_description()
        prompt = self.config.system_prompt.format(tools_description=tools_desc)
        if project_info:
            prompt += f"\n\n## 项目信息\n"
            prompt += f"- 名称: {project_info.get('name', 'unknown')}\n"
            prompt += f"- 语言: {', '.join(project_info.get('languages', ['unknown']))}\n"
            prompt += f"- 文件数: {project_info.get('file_count', 'unknown')}\n"
        if task_context:
            prompt += f"\n\n## 任务上下文\n{task_context}"
        return prompt
    def _parse_llm_response(self, response: str) -> AgentStep:
        """Parse raw LLM output into an AgentStep (Thought/Action or Final Answer)."""
        step = AgentStep(thought="")
        # Extract the Thought section (up to Action / Final Answer).
        thought_match = re.search(r'Thought:\s*(.*?)(?=Action:|Final Answer:|$)', response, re.DOTALL)
        if thought_match:
            step.thought = thought_match.group(1).strip()
        # A Final Answer terminates the loop; try to parse its JSON payload.
        final_match = re.search(r'Final Answer:\s*(.*?)$', response, re.DOTALL)
        if final_match:
            step.is_final = True
            try:
                answer_text = final_match.group(1).strip()
                # Strip markdown code fences before JSON parsing.
                answer_text = re.sub(r'```json\s*', '', answer_text)
                answer_text = re.sub(r'```\s*', '', answer_text)
                step.final_answer = json.loads(answer_text)
            except json.JSONDecodeError:
                # Not valid JSON: keep the raw text for the caller.
                step.final_answer = {"raw_answer": final_match.group(1).strip()}
            return step
        # Extract the chosen tool name.
        action_match = re.search(r'Action:\s*(\w+)', response)
        if action_match:
            step.action = action_match.group(1).strip()
        # Extract the tool arguments (JSON, possibly fenced).
        input_match = re.search(r'Action Input:\s*(.*?)(?=Thought:|Action:|Observation:|$)', response, re.DOTALL)
        if input_match:
            input_text = input_match.group(1).strip()
            # Strip markdown code fences before JSON parsing.
            input_text = re.sub(r'```json\s*', '', input_text)
            input_text = re.sub(r'```\s*', '', input_text)
            try:
                step.action_input = json.loads(input_text)
            except json.JSONDecodeError:
                # Fall back to passing the raw text through.
                step.action_input = {"raw_input": input_text}
        return step
    async def _execute_tool(self, tool_name: str, tool_input: Dict) -> str:
        """Run one tool and return its (possibly truncated) output as text."""
        tool = self.tools.get(tool_name)
        if not tool:
            return f"错误: 工具 '{tool_name}' 不存在。可用工具: {list(self.tools.keys())}"
        try:
            self._tool_calls += 1
            await self.emit_tool_call(tool_name, tool_input)
            import time
            start = time.time()
            result = await tool.execute(**tool_input)
            duration_ms = int((time.time() - start) * 1000)
            await self.emit_tool_result(tool_name, str(result.data)[:200], duration_ms)
            if result.success:
                # Truncate very long outputs to keep the context window small.
                output = str(result.data)
                if len(output) > 4000:
                    output = output[:4000] + "\n\n... [输出已截断,共 {} 字符]".format(len(str(result.data)))
                return output
            else:
                return f"工具执行失败: {result.error}"
        except Exception as e:
            logger.error(f"Tool execution error: {e}")
            return f"工具执行错误: {str(e)}"
    async def run(self, input_data: Dict[str, Any]) -> AgentResult:
        """
        Execute the ReAct loop.

        The LLM participates throughout and makes every decision; the loop
        runs until a Final Answer, cancellation, or max_iterations.
        """
        import time
        start_time = time.time()
        project_info = input_data.get("project_info", {})
        task_context = input_data.get("task_context", "")
        config = input_data.get("config", {})  # NOTE(review): currently unused here
        # Build the system prompt and seed the conversation.
        system_prompt = self._build_system_prompt(project_info, task_context)
        self._conversation_history = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": "请开始对项目进行安全审计。首先了解项目结构,然后系统性地搜索和分析潜在的安全漏洞。"},
        ]
        self._steps = []
        all_findings = []
        await self.emit_thinking("🤖 ReAct Agent 启动LLM 开始自主分析...")
        try:
            for iteration in range(self.config.max_iterations):
                if self.is_cancelled:
                    break
                self._iteration = iteration + 1
                await self.emit_thinking(f"💭 第 {iteration + 1} 轮思考...")
                # Ask the LLM for the next Thought/Action (or Final Answer).
                response = await self.llm_service.chat_completion_raw(
                    messages=self._conversation_history,
                    temperature=0.1,
                    max_tokens=2048,
                )
                llm_output = response.get("content", "")
                self._total_tokens += response.get("usage", {}).get("total_tokens", 0)
                # Surface a prefix of the raw LLM output as a thinking event.
                await self.emit_event("thinking", f"LLM: {llm_output[:500]}...")
                # Parse the LLM response into a structured step.
                step = self._parse_llm_response(llm_output)
                self._steps.append(step)
                # Record the assistant turn in the transcript.
                self._conversation_history.append({
                    "role": "assistant",
                    "content": llm_output,
                })
                # Final Answer terminates the loop.
                if step.is_final:
                    await self.emit_thinking("✅ LLM 完成分析,生成最终报告")
                    if step.final_answer and "findings" in step.final_answer:
                        all_findings = step.final_answer["findings"]
                    break
                # Execute the chosen tool, if any.
                if step.action:
                    await self.emit_thinking(f"🔧 LLM 决定调用工具: {step.action}")
                    observation = await self._execute_tool(
                        step.action,
                        step.action_input or {}
                    )
                    step.observation = observation
                    # Feed the observation back to the LLM as a user turn.
                    self._conversation_history.append({
                        "role": "user",
                        "content": f"Observation: {observation}",
                    })
                else:
                    # No tool selected: nudge the LLM to act or finish.
                    self._conversation_history.append({
                        "role": "user",
                        "content": "请继续分析,选择一个工具执行,或者如果分析完成,输出 Final Answer。",
                    })
            duration_ms = int((time.time() - start_time) * 1000)
            await self.emit_event(
                "info",
                f"🎯 ReAct Agent 完成: {len(all_findings)} 个发现, {self._iteration} 轮迭代, {self._tool_calls} 次工具调用"
            )
            return AgentResult(
                success=True,
                data={
                    "findings": all_findings,
                    "steps": [
                        {
                            "thought": s.thought,
                            "action": s.action,
                            "action_input": s.action_input,
                            "observation": s.observation[:500] if s.observation else None,
                        }
                        for s in self._steps
                    ],
                },
                iterations=self._iteration,
                tool_calls=self._tool_calls,
                tokens_used=self._total_tokens,
                duration_ms=duration_ms,
            )
        except Exception as e:
            logger.error(f"ReAct Agent failed: {e}", exc_info=True)
            return AgentResult(success=False, error=str(e))
    def get_conversation_history(self) -> List[Dict[str, str]]:
        """Return the full chat transcript accumulated so far."""
        return self._conversation_history
    def get_steps(self) -> List[AgentStep]:
        """Return the parsed execution steps."""
        return self._steps

View File

@ -192,6 +192,37 @@ class AgentEventEmitter:
"percentage": percentage,
},
))
async def emit_task_complete(
    self,
    findings_count: int,
    duration_ms: int,
    message: Optional[str] = None,
):
    """Emit a "task_complete" event carrying finding count and duration."""
    text = message or f"✅ 审计完成!发现 {findings_count} 个漏洞,耗时 {duration_ms/1000:.1f} 秒"
    payload = {
        "findings_count": findings_count,
        "duration_ms": duration_ms,
    }
    event = AgentEventData(
        event_type="task_complete",
        message=text,
        metadata=payload,
    )
    await self.emit(event)
async def emit_task_error(self, error: str, message: Optional[str] = None):
    """Emit a "task_error" event with the error description attached."""
    text = message or f"❌ 任务失败: {error}"
    event = AgentEventData(
        event_type="task_error",
        message=text,
        metadata={"error": error},
    )
    await self.emit(event)
async def emit_task_cancelled(self, message: Optional[str] = None):
    """Emit a "task_cancel" event, with a default cancellation message."""
    text = message or "⚠️ 任务已取消"
    event = AgentEventData(event_type="task_cancel", message=text)
    await self.emit(event)
class EventManager:
@ -368,4 +399,15 @@ class EventManager:
def create_emitter(self, task_id: str) -> AgentEventEmitter:
    """Build an event emitter bound to *task_id* and this manager."""
    emitter = AgentEventEmitter(task_id, self)
    return emitter
async def close(self):
    """Shut down the event manager, releasing queues and callbacks."""
    # Drop every registered queue (copy the keys: remove_queue mutates the dict).
    for tid in list(self._event_queues):
        self.remove_queue(tid)
    # Forget all registered callbacks.
    self._event_callbacks.clear()
    logger.debug("EventManager closed")

View File

@ -15,6 +15,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.memory import MemorySaver
from app.services.agent.streaming import StreamHandler, StreamEvent, StreamEventType
from app.models.agent_task import (
AgentTask, AgentEvent, AgentFinding,
AgentTaskStatus, AgentTaskPhase, AgentEventType,
@ -39,11 +40,15 @@ logger = logging.getLogger(__name__)
class LLMService:
"""LLM 服务封装"""
"""
LLM 服务封装
提供代码分析漏洞检测等 AI 功能
"""
def __init__(self, model: Optional[str] = None, api_key: Optional[str] = None):
    """Initialize the service, falling back to app settings for defaults."""
    # Resolution chain: explicit arg -> settings.LLM_MODEL -> "gpt-4o-mini".
    self.model = model or settings.LLM_MODEL or "gpt-4o-mini"
    self.api_key = api_key or settings.LLM_API_KEY
    # Custom endpoint base URL (e.g. a proxy or self-hosted gateway).
    self.base_url = settings.LLM_BASE_URL
async def chat_completion_raw(
self,
@ -61,6 +66,7 @@ class LLMService:
temperature=temperature,
max_tokens=max_tokens,
api_key=self.api_key,
base_url=self.base_url,
)
return {
@ -75,6 +81,125 @@ class LLMService:
except Exception as e:
logger.error(f"LLM call failed: {e}")
raise
async def analyze_code(self, code: str, language: str) -> Dict[str, Any]:
    """
    Analyze code for security issues with the LLM.

    Args:
        code: source code to analyze (truncated to 8000 chars in the prompt)
        language: programming language name, interpolated into the prompt

    Returns:
        Parsed analysis dict with an ``issues`` list and ``quality_score``;
        on LLM failure, an empty issue list plus an ``error`` key.
    """
    # Chinese prompt sent to the LLM verbatim; it requests a strict JSON
    # response listing vulnerabilities with type/severity/line/suggestion.
    prompt = f"""请分析以下 {language} 代码的安全问题。
代码:
```{language}
{code[:8000]}
```
请识别所有潜在的安全漏洞包括但不限于:
- SQL 注入
- XSS (跨站脚本)
- 命令注入
- 路径遍历
- 不安全的反序列化
- 硬编码密钥/密码
- 不安全的加密
- SSRF
- 认证/授权问题
对于每个发现的问题请提供:
1. 漏洞类型
2. 严重程度 (critical/high/medium/low)
3. 问题描述
4. 具体行号
5. 修复建议
请以 JSON 格式返回结果:
{{
"issues": [
{{
"type": "漏洞类型",
"severity": "严重程度",
"title": "问题标题",
"description": "详细描述",
"line": 行号,
"code_snippet": "相关代码片段",
"suggestion": "修复建议"
}}
],
"quality_score": 0-100
}}
如果没有发现安全问题返回空的 issues 数组和较高的 quality_score"""
    try:
        result = await self.chat_completion_raw(
            messages=[
                {"role": "system", "content": "你是一位专业的代码安全审计专家,擅长发现代码中的安全漏洞。请只返回 JSON 格式的结果,不要包含其他内容。"},
                {"role": "user", "content": prompt},
            ],
            temperature=0.1,
            max_tokens=4096,
        )
        content = result.get("content", "{}")
        import json
        import re
        # First attempt: the response is already bare JSON.
        try:
            return json.loads(content)
        except json.JSONDecodeError:
            pass
        # Second attempt: JSON wrapped in a markdown code fence.
        json_match = re.search(r'```(?:json)?\s*([\s\S]*?)\s*```', content)
        if json_match:
            try:
                return json.loads(json_match.group(1))
            except json.JSONDecodeError:
                pass
        # Unparseable response: fall back to an empty, optimistic result.
        return {"issues": [], "quality_score": 80}
    except Exception as e:
        logger.error(f"Code analysis failed: {e}")
        return {"issues": [], "quality_score": 0, "error": str(e)}
async def analyze_code_with_custom_prompt(
    self,
    code: str,
    language: str,
    prompt: str,
    **kwargs
) -> Dict[str, Any]:
    """Analyze code with a caller-supplied prompt template.

    The template's ``{code}`` and ``{language}`` placeholders are replaced
    literally before the prompt is sent to the LLM.
    """
    rendered = prompt.replace("{code}", code)
    rendered = rendered.replace("{language}", language)
    messages = [
        {"role": "system", "content": "你是一位专业的代码安全审计专家。"},
        {"role": "user", "content": rendered},
    ]
    try:
        result = await self.chat_completion_raw(
            messages=messages,
            temperature=0.1,
        )
    except Exception as e:
        logger.error(f"Custom analysis failed: {e}")
        return {"analysis": "", "error": str(e)}
    return {
        "analysis": result.get("content", ""),
        "usage": result.get("usage", {}),
    }
class AgentRunner:
@ -97,8 +222,9 @@ class AgentRunner:
self.task = task
self.project_root = project_root
# 事件管理
self.event_manager = EventManager()
# 事件管理 - 传入 db_session_factory 以持久化事件
from app.db.session import async_session_factory
self.event_manager = EventManager(db_session_factory=async_session_factory)
self.event_emitter = AgentEventEmitter(task.id, self.event_manager)
# LLM 服务
@ -120,6 +246,22 @@ class AgentRunner:
# 状态
self._cancelled = False
self._running_task: Optional[asyncio.Task] = None
# 流式处理器
self.stream_handler = StreamHandler(task.id)
def cancel(self):
    """Request cancellation: set the flag and cancel any in-flight task."""
    self._cancelled = True
    running = self._running_task
    if running is not None and not running.done():
        running.cancel()
    logger.info(f"Task {self.task.id} cancellation requested")
@property
def is_cancelled(self) -> bool:
    """Whether cancellation has been requested for this runner."""
    return self._cancelled
async def initialize(self):
"""初始化 Runner"""
@ -149,15 +291,15 @@ class AgentRunner:
)
self.indexer = CodeIndexer(
embedding_service=embedding_service,
vector_db_path=settings.VECTOR_DB_PATH,
collection_name=f"project_{self.task.project_id}",
embedding_service=embedding_service,
persist_directory=settings.VECTOR_DB_PATH,
)
self.retriever = CodeRetriever(
embedding_service=embedding_service,
vector_db_path=settings.VECTOR_DB_PATH,
collection_name=f"project_{self.task.project_id}",
embedding_service=embedding_service,
persist_directory=settings.VECTOR_DB_PATH,
)
except Exception as e:
@ -261,6 +403,18 @@ class AgentRunner:
Returns:
最终状态
"""
result = {}
async for _ in self.run_with_streaming():
pass # 消费所有事件
return result
async def run_with_streaming(self) -> AsyncGenerator[StreamEvent, None]:
"""
带流式输出的审计执行
Yields:
StreamEvent: 流式事件包含 LLM 思考工具调用等
"""
import time
start_time = time.time()
@ -271,17 +425,28 @@ class AgentRunner:
# 更新任务状态
await self._update_task_status(AgentTaskStatus.RUNNING)
# 发射任务开始事件
yield StreamEvent(
event_type=StreamEventType.TASK_START,
sequence=self.stream_handler._next_sequence(),
data={"task_id": self.task.id, "message": "🚀 审计任务开始"},
)
# 1. 索引代码
await self._index_code()
if self._cancelled:
return {"success": False, "error": "任务已取消"}
yield StreamEvent(
event_type=StreamEventType.TASK_CANCEL,
sequence=self.stream_handler._next_sequence(),
data={"message": "任务已取消"},
)
return
# 2. 收集项目信息
project_info = await self._collect_project_info()
# 3. 构建初始状态
# 从任务字段构建配置
task_config = {
"target_vulnerabilities": self.task.target_vulnerabilities or [],
"verification_level": self.task.verification_level or "sandbox",
@ -314,7 +479,7 @@ class AgentRunner:
"error": None,
}
# 4. 执行 LangGraph
# 4. 执行 LangGraph with astream_events
await self.event_emitter.emit_phase_start("langgraph", "🔄 启动 LangGraph 工作流")
run_config = {
@ -325,26 +490,57 @@ class AgentRunner:
final_state = None
# 流式执行并发射事件
async for event in self.graph.astream(initial_state, config=run_config):
if self._cancelled:
break
# 处理每个节点的输出
for node_name, node_output in event.items():
await self._handle_node_output(node_name, node_output)
# 使用 astream_events 获取详细事件流
try:
async for event in self.graph.astream_events(
initial_state,
config=run_config,
version="v2",
):
if self._cancelled:
break
# 更新阶段
phase_map = {
"recon": AgentTaskPhase.RECONNAISSANCE,
"analysis": AgentTaskPhase.ANALYSIS,
"verification": AgentTaskPhase.VERIFICATION,
"report": AgentTaskPhase.REPORTING,
}
if node_name in phase_map:
await self._update_task_phase(phase_map[node_name])
# 处理 LangGraph 事件
stream_event = await self.stream_handler.process_langgraph_event(event)
if stream_event:
# 同步到 event_emitter 以持久化
await self._sync_stream_event_to_db(stream_event)
yield stream_event
final_state = node_output
# 更新最终状态
if event.get("event") == "on_chain_end":
output = event.get("data", {}).get("output")
if isinstance(output, dict):
final_state = output
except Exception as e:
# 如果 astream_events 不可用,回退到 astream
logger.warning(f"astream_events not available, falling back to astream: {e}")
async for event in self.graph.astream(initial_state, config=run_config):
if self._cancelled:
break
for node_name, node_output in event.items():
await self._handle_node_output(node_name, node_output)
# 发射节点事件
yield StreamEvent(
event_type=StreamEventType.NODE_END,
sequence=self.stream_handler._next_sequence(),
node_name=node_name,
data={"message": f"节点 {node_name} 完成"},
)
phase_map = {
"recon": AgentTaskPhase.RECONNAISSANCE,
"analysis": AgentTaskPhase.ANALYSIS,
"verification": AgentTaskPhase.VERIFICATION,
"report": AgentTaskPhase.REPORTING,
}
if node_name in phase_map:
await self._update_task_phase(phase_map[node_name])
final_state = node_output
# 5. 获取最终状态
if not final_state:
@ -355,6 +551,13 @@ class AgentRunner:
findings = final_state.get("findings", [])
await self._save_findings(findings)
# 发射发现事件
for finding in findings[:10]: # 限制数量
yield self.stream_handler.create_finding_event(
finding,
is_verified=finding.get("is_verified", False),
)
# 7. 更新任务摘要
summary = final_state.get("summary", {})
security_score = final_state.get("security_score", 100)
@ -374,30 +577,59 @@ class AgentRunner:
duration_ms=duration_ms,
)
return {
"success": True,
"data": {
"findings": findings,
"verified_findings": final_state.get("verified_findings", []),
"summary": summary,
yield StreamEvent(
event_type=StreamEventType.TASK_COMPLETE,
sequence=self.stream_handler._next_sequence(),
data={
"findings_count": len(findings),
"verified_count": len(final_state.get("verified_findings", [])),
"security_score": security_score,
"duration_ms": duration_ms,
"message": f"✅ 审计完成!发现 {len(findings)} 个漏洞",
},
"duration_ms": duration_ms,
}
)
except asyncio.CancelledError:
await self._update_task_status(AgentTaskStatus.CANCELLED)
return {"success": False, "error": "任务已取消"}
yield StreamEvent(
event_type=StreamEventType.TASK_CANCEL,
sequence=self.stream_handler._next_sequence(),
data={"message": "任务已取消"},
)
except Exception as e:
logger.error(f"LangGraph run failed: {e}", exc_info=True)
await self._update_task_status(AgentTaskStatus.FAILED, str(e))
await self.event_emitter.emit_error(str(e))
return {"success": False, "error": str(e)}
yield StreamEvent(
event_type=StreamEventType.TASK_ERROR,
sequence=self.stream_handler._next_sequence(),
data={"error": str(e), "message": f"❌ 审计失败: {e}"},
)
finally:
await self._cleanup()
async def _sync_stream_event_to_db(self, event: StreamEvent) -> None:
    """Persist a streaming event through the task's EventManager.

    Maps the in-memory ``StreamEvent`` onto the persisted event schema so
    the event can be read back later. Failures are logged and swallowed on
    purpose: persistence must never break the live SSE stream.
    """
    try:
        # Translate the StreamEvent into the AgentEventData shape.
        await self.event_manager.add_event(
            task_id=self.task.id,
            event_type=event.event_type.value,
            sequence=event.sequence,
            phase=event.phase,
            message=event.data.get("message"),
            # Tool payloads may arrive under either key depending on the
            # producer (StreamHandler uses "input"/"output",
            # ToolStreamHandler uses "input_params"/"output_data").
            tool_input=event.data.get("input") or event.data.get("input_params"),
            tool_output=event.data.get("output") or event.data.get("output_data"),
            tool_duration_ms=event.data.get("duration_ms"),
            metadata=event.data,
        )
    except Exception as e:
        logger.warning(f"Failed to sync stream event to db: {e}")
async def _handle_node_output(self, node_name: str, output: Dict[str, Any]):
"""处理节点输出"""
# 发射节点事件
@ -445,7 +677,8 @@ class AgentRunner:
return
await self.event_emitter.emit_progress(
progress.processed / max(progress.total, 1) * 100,
progress.processed_files,
progress.total_files,
f"正在索引: {progress.current_file or 'N/A'}"
)
@ -502,13 +735,23 @@ class AgentRunner:
type_map = {
"sql_injection": VulnerabilityType.SQL_INJECTION,
"nosql_injection": VulnerabilityType.NOSQL_INJECTION,
"xss": VulnerabilityType.XSS,
"command_injection": VulnerabilityType.COMMAND_INJECTION,
"code_injection": VulnerabilityType.CODE_INJECTION,
"path_traversal": VulnerabilityType.PATH_TRAVERSAL,
"file_inclusion": VulnerabilityType.FILE_INCLUSION,
"ssrf": VulnerabilityType.SSRF,
"xxe": VulnerabilityType.XXE,
"deserialization": VulnerabilityType.DESERIALIZATION,
"auth_bypass": VulnerabilityType.AUTH_BYPASS,
"idor": VulnerabilityType.IDOR,
"sensitive_data_exposure": VulnerabilityType.SENSITIVE_DATA_EXPOSURE,
"hardcoded_secret": VulnerabilityType.HARDCODED_SECRET,
"deserialization": VulnerabilityType.INSECURE_DESERIALIZATION,
"weak_crypto": VulnerabilityType.WEAK_CRYPTO,
"race_condition": VulnerabilityType.RACE_CONDITION,
"business_logic": VulnerabilityType.BUSINESS_LOGIC,
"memory_corruption": VulnerabilityType.MEMORY_CORRUPTION,
}
for finding in findings:
@ -536,7 +779,7 @@ class AgentRunner:
is_verified=finding.get("is_verified", False),
confidence=finding.get("confidence", 0.5),
poc=finding.get("poc"),
status=FindingStatus.VERIFIED if finding.get("is_verified") else FindingStatus.OPEN,
status=FindingStatus.VERIFIED if finding.get("is_verified") else FindingStatus.NEW,
)
self.db.add(db_finding)
@ -603,10 +846,6 @@ class AgentRunner:
await self.event_manager.close()
except Exception as e:
logger.warning(f"Cleanup error: {e}")
def cancel(self):
    """Request cooperative cancellation; the run loop polls this flag."""
    self._cancelled = True
# 便捷函数

View File

@ -0,0 +1,17 @@
"""
Agent 流式输出模块
支持 LLM Token 流式输出工具调用展示思考过程展示
"""
from .stream_handler import StreamHandler, StreamEvent, StreamEventType
from .token_streamer import TokenStreamer
from .tool_stream import ToolStreamHandler
__all__ = [
"StreamHandler",
"StreamEvent",
"StreamEventType",
"TokenStreamer",
"ToolStreamHandler",
]

View File

@ -0,0 +1,453 @@
"""
流式事件处理器
处理 LangGraph 的各种流式事件并转换为前端可消费的格式
"""
import json
import logging
from enum import Enum
from typing import Any, Dict, Optional, AsyncGenerator, List
from dataclasses import dataclass, field
from datetime import datetime, timezone
logger = logging.getLogger(__name__)
class StreamEventType(str, Enum):
    """Kinds of events carried over the streaming channel."""
    # LLM lifecycle
    THINKING_START = "thinking_start"  # LLM started thinking
    THINKING_TOKEN = "thinking_token"  # one streamed LLM token
    THINKING_END = "thinking_end"  # LLM finished thinking
    # Tool calls
    TOOL_CALL_START = "tool_call_start"  # tool invocation started
    TOOL_CALL_INPUT = "tool_call_input"  # tool input parameters
    TOOL_CALL_OUTPUT = "tool_call_output"  # tool output payload
    TOOL_CALL_END = "tool_call_end"  # tool invocation finished
    TOOL_CALL_ERROR = "tool_call_error"  # tool invocation failed
    # Graph nodes
    NODE_START = "node_start"  # node started
    NODE_END = "node_end"  # node finished
    # Audit phases
    PHASE_START = "phase_start"
    PHASE_END = "phase_end"
    # Findings
    FINDING_NEW = "finding_new"  # newly discovered finding
    FINDING_VERIFIED = "finding_verified"  # finding confirmed by verification
    # Status
    PROGRESS = "progress"
    INFO = "info"
    WARNING = "warning"
    ERROR = "error"
    # Task lifecycle
    TASK_START = "task_start"
    TASK_COMPLETE = "task_complete"
    TASK_ERROR = "task_error"
    TASK_CANCEL = "task_cancel"
    # Keep-alive
    HEARTBEAT = "heartbeat"
@dataclass
class StreamEvent:
    """A single streaming event delivered to the frontend."""
    event_type: StreamEventType
    data: Dict[str, Any] = field(default_factory=dict)
    timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
    sequence: int = 0
    # Optional execution context
    node_name: Optional[str] = None
    phase: Optional[str] = None
    tool_name: Optional[str] = None

    def to_sse(self) -> str:
        """Render the event as a Server-Sent Events frame."""
        payload = {
            "type": self.event_type.value,
            "data": self.data,
            "timestamp": self.timestamp,
            "sequence": self.sequence,
        }
        # Context fields are only included when they are set (truthy).
        for key, value in (
            ("node", self.node_name),
            ("phase", self.phase),
            ("tool", self.tool_name),
        ):
            if value:
                payload[key] = value
        body = json.dumps(payload, ensure_ascii=False)
        return f"event: {self.event_type.value}\ndata: {body}\n\n"

    def to_dict(self) -> Dict[str, Any]:
        """Render the event as a plain dict (all fields, even when None)."""
        return {
            "event_type": self.event_type.value,
            "data": self.data,
            "timestamp": self.timestamp,
            "sequence": self.sequence,
            "node_name": self.node_name,
            "phase": self.phase,
            "tool_name": self.tool_name,
        }
class StreamHandler:
    """
    Streaming event processor.

    Converts the events produced by LangGraph's ``astream_events`` into
    frontend-consumable :class:`StreamEvent` objects.

    Approach:
    1. Capture all LangGraph events via ``astream_events``.
    2. Convert internal payloads into a stable, serializable format.
    3. Dispatch by event kind (LLM / tool / node / custom).
    """

    def __init__(self, task_id: str):
        self.task_id = task_id
        # Monotonic counter used to order events on the client side.
        self._sequence = 0
        # Phase/node context, updated as node-start events arrive.
        self._current_phase = None
        self._current_node = None
        # Tokens accumulated for the LLM call currently in flight.
        self._thinking_buffer = []
        # Start time / input per running tool, keyed by tool name.
        # NOTE(review): concurrent calls to the same tool would collide on
        # this key — confirm tools run sequentially per task.
        self._tool_states: Dict[str, Dict] = {}

    def _next_sequence(self) -> int:
        """Return the next sequence number (increments on each call)."""
        self._sequence += 1
        return self._sequence

    async def process_langgraph_event(self, event: Dict[str, Any]) -> Optional[StreamEvent]:
        """
        Process one LangGraph event and return its StreamEvent, if any.

        Supported event kinds:
        - on_chain_start: graph/node start
        - on_chain_end: graph/node end
        - on_chain_stream: LLM token
        - on_chat_model_start: model start
        - on_chat_model_stream: model token
        - on_chat_model_end: model end
        - on_tool_start: tool start
        - on_tool_end: tool end
        - on_custom_event: custom event emitted by a node

        Returns:
            A StreamEvent for kinds we surface, otherwise None.
        """
        event_kind = event.get("event", "")
        event_name = event.get("name", "")
        event_data = event.get("data", {})
        # LLM token stream
        if event_kind == "on_chat_model_stream":
            return await self._handle_llm_stream(event_data, event_name)
        # LLM start
        elif event_kind == "on_chat_model_start":
            return await self._handle_llm_start(event_data, event_name)
        # LLM end
        elif event_kind == "on_chat_model_end":
            return await self._handle_llm_end(event_data, event_name)
        # Tool start
        elif event_kind == "on_tool_start":
            return await self._handle_tool_start(event_name, event_data)
        # Tool end
        elif event_kind == "on_tool_end":
            return await self._handle_tool_end(event_name, event_data)
        # Node start (only names matching the audit graph's nodes)
        elif event_kind == "on_chain_start" and self._is_node_event(event_name):
            return await self._handle_node_start(event_name, event_data)
        # Node end
        elif event_kind == "on_chain_end" and self._is_node_event(event_name):
            return await self._handle_node_end(event_name, event_data)
        # Custom event
        elif event_kind == "on_custom_event":
            return await self._handle_custom_event(event_name, event_data)
        return None

    def _is_node_event(self, name: str) -> bool:
        """Return True when *name* looks like one of the audit graph nodes."""
        node_names = ["recon", "analysis", "verification", "report", "ReconNode", "AnalysisNode", "VerificationNode", "ReportNode"]
        return any(n.lower() in name.lower() for n in node_names)

    async def _handle_llm_start(self, data: Dict, name: str) -> StreamEvent:
        """Handle LLM start: reset the token buffer and announce thinking."""
        self._thinking_buffer = []
        return StreamEvent(
            event_type=StreamEventType.THINKING_START,
            sequence=self._next_sequence(),
            node_name=self._current_node,
            phase=self._current_phase,
            data={
                "model": name,
                "message": "🤔 正在思考...",
            },
        )

    async def _handle_llm_stream(self, data: Dict, name: str) -> Optional[StreamEvent]:
        """Handle one LLM token chunk; returns None for empty chunks."""
        chunk = data.get("chunk")
        if not chunk:
            return None
        # Extract the token text (message object or plain dict).
        content = ""
        if hasattr(chunk, "content"):
            content = chunk.content
        elif isinstance(chunk, dict):
            content = chunk.get("content", "")
        if not content:
            return None
        # Append to the buffer so clients can render the running text.
        self._thinking_buffer.append(content)
        return StreamEvent(
            event_type=StreamEventType.THINKING_TOKEN,
            sequence=self._next_sequence(),
            node_name=self._current_node,
            phase=self._current_phase,
            data={
                "token": content,
                "accumulated": "".join(self._thinking_buffer),
            },
        )

    async def _handle_llm_end(self, data: Dict, name: str) -> StreamEvent:
        """Handle LLM end: flush the buffer and report token usage if present."""
        full_response = "".join(self._thinking_buffer)
        self._thinking_buffer = []
        # Extract token usage when the provider reports it.
        usage = {}
        output = data.get("output")
        if output and hasattr(output, "usage_metadata"):
            usage = {
                "input_tokens": getattr(output.usage_metadata, "input_tokens", 0),
                "output_tokens": getattr(output.usage_metadata, "output_tokens", 0),
            }
        return StreamEvent(
            event_type=StreamEventType.THINKING_END,
            sequence=self._next_sequence(),
            node_name=self._current_node,
            phase=self._current_phase,
            data={
                "response": full_response[:2000],  # truncate long responses
                "usage": usage,
                "message": "💡 思考完成",
            },
        )

    async def _handle_tool_start(self, tool_name: str, data: Dict) -> StreamEvent:
        """Handle tool start: remember start time/input for duration tracking."""
        import time
        tool_input = data.get("input", {})
        # Record the call so _handle_tool_end can compute its duration.
        self._tool_states[tool_name] = {
            "start_time": time.time(),
            "input": tool_input,
        }
        return StreamEvent(
            event_type=StreamEventType.TOOL_CALL_START,
            sequence=self._next_sequence(),
            node_name=self._current_node,
            phase=self._current_phase,
            tool_name=tool_name,
            data={
                "tool_name": tool_name,
                "input": self._truncate_data(tool_input),
                "message": f"🔧 调用工具: {tool_name}",
            },
        )

    async def _handle_tool_end(self, tool_name: str, data: Dict) -> StreamEvent:
        """Handle tool end: compute duration and surface a truncated output."""
        import time
        # Elapsed time from the matching start record, if any.
        duration_ms = 0
        if tool_name in self._tool_states:
            start_time = self._tool_states[tool_name].get("start_time", time.time())
            duration_ms = int((time.time() - start_time) * 1000)
            del self._tool_states[tool_name]
        # Extract the output payload (message object or raw value).
        output = data.get("output", "")
        if hasattr(output, "content"):
            output = output.content
        return StreamEvent(
            event_type=StreamEventType.TOOL_CALL_END,
            sequence=self._next_sequence(),
            node_name=self._current_node,
            phase=self._current_phase,
            tool_name=tool_name,
            data={
                "tool_name": tool_name,
                "output": self._truncate_data(output),
                "duration_ms": duration_ms,
                "message": f"✅ 工具 {tool_name} 完成 ({duration_ms}ms)",
            },
        )

    async def _handle_node_start(self, node_name: str, data: Dict) -> StreamEvent:
        """Handle node start: update the current node/phase context."""
        self._current_node = node_name
        # Map node names onto audit phases.
        phase_map = {
            "recon": "reconnaissance",
            "analysis": "analysis",
            "verification": "verification",
            "report": "reporting",
        }
        for key, phase in phase_map.items():
            if key in node_name.lower():
                self._current_phase = phase
                break
        return StreamEvent(
            event_type=StreamEventType.NODE_START,
            sequence=self._next_sequence(),
            node_name=node_name,
            phase=self._current_phase,
            data={
                "node": node_name,
                "phase": self._current_phase,
                "message": f"▶️ 开始节点: {node_name}",
            },
        )

    async def _handle_node_end(self, node_name: str, data: Dict) -> StreamEvent:
        """Handle node end: summarize headline counters from the node output."""
        output = data.get("output", {})
        summary = {}
        if isinstance(output, dict):
            # Surface only the counts clients care about.
            if "findings" in output:
                summary["findings_count"] = len(output["findings"])
            if "entry_points" in output:
                summary["entry_points_count"] = len(output["entry_points"])
            if "high_risk_areas" in output:
                summary["high_risk_areas_count"] = len(output["high_risk_areas"])
            if "verified_findings" in output:
                summary["verified_count"] = len(output["verified_findings"])
        return StreamEvent(
            event_type=StreamEventType.NODE_END,
            sequence=self._next_sequence(),
            node_name=node_name,
            phase=self._current_phase,
            data={
                "node": node_name,
                "phase": self._current_phase,
                "summary": summary,
                "message": f"⏹️ 节点完成: {node_name}",
            },
        )

    async def _handle_custom_event(self, event_name: str, data: Dict) -> StreamEvent:
        """Handle a custom event by mapping its name onto an event type."""
        # Unknown names fall back to INFO.
        event_type_map = {
            "finding": StreamEventType.FINDING_NEW,
            "finding_verified": StreamEventType.FINDING_VERIFIED,
            "progress": StreamEventType.PROGRESS,
            "warning": StreamEventType.WARNING,
            "error": StreamEventType.ERROR,
        }
        event_type = event_type_map.get(event_name, StreamEventType.INFO)
        return StreamEvent(
            event_type=event_type,
            sequence=self._next_sequence(),
            node_name=self._current_node,
            phase=self._current_phase,
            data=data,
        )

    def _truncate_data(self, data: Any, max_length: int = 1000) -> Any:
        """Recursively clip strings/collections so events stay small."""
        if isinstance(data, str):
            return data[:max_length] + "..." if len(data) > max_length else data
        elif isinstance(data, dict):
            # Keep at most 10 keys; halve the budget for nested values.
            return {k: self._truncate_data(v, max_length // 2) for k, v in list(data.items())[:10]}
        elif isinstance(data, list):
            # Keep at most 10 items; split the budget across the list length.
            return [self._truncate_data(item, max_length // len(data)) for item in data[:10]]
        else:
            return str(data)[:max_length]

    def create_progress_event(
        self,
        current: int,
        total: int,
        message: Optional[str] = None,
    ) -> StreamEvent:
        """Create a PROGRESS event; percentage is 0 when *total* is 0."""
        percentage = (current / total * 100) if total > 0 else 0
        return StreamEvent(
            event_type=StreamEventType.PROGRESS,
            sequence=self._next_sequence(),
            node_name=self._current_node,
            phase=self._current_phase,
            data={
                "current": current,
                "total": total,
                "percentage": round(percentage, 1),
                "message": message or f"进度: {current}/{total}",
            },
        )

    def create_finding_event(
        self,
        finding: Dict[str, Any],
        is_verified: bool = False,
    ) -> StreamEvent:
        """Create a FINDING_NEW or FINDING_VERIFIED event for a finding dict."""
        event_type = StreamEventType.FINDING_VERIFIED if is_verified else StreamEventType.FINDING_NEW
        return StreamEvent(
            event_type=event_type,
            sequence=self._next_sequence(),
            node_name=self._current_node,
            phase=self._current_phase,
            data={
                "title": finding.get("title", "Unknown"),
                "severity": finding.get("severity", "medium"),
                "vulnerability_type": finding.get("vulnerability_type", "other"),
                "file_path": finding.get("file_path"),
                "line_start": finding.get("line_start"),
                "is_verified": is_verified,
                "message": f"{'✅ 已验证' if is_verified else '🔍 新发现'}: [{finding.get('severity', 'medium').upper()}] {finding.get('title', 'Unknown')}",
            },
        )

    def create_heartbeat(self) -> StreamEvent:
        """Create a keep-alive event."""
        return StreamEvent(
            event_type=StreamEventType.HEARTBEAT,
            sequence=self._sequence,  # heartbeats do not advance the sequence
            data={"message": "ping"},
        )

View File

@ -0,0 +1,261 @@
"""
LLM Token 流式输出处理器
支持多种 LLM 提供商的流式输出
"""
import asyncio
import logging
from typing import Any, Dict, Optional, AsyncGenerator, Callable
from dataclasses import dataclass
from datetime import datetime, timezone
logger = logging.getLogger(__name__)
@dataclass
class TokenChunk:
    """A single streamed LLM token chunk plus running totals."""
    content: str  # text of this chunk
    token_count: int = 1  # each chunk is counted as one token
    finish_reason: Optional[str] = None  # set on the final chunk (e.g. "stop")
    model: Optional[str] = None  # model that produced the chunk
    # Running statistics across the whole stream
    accumulated_content: str = ""
    total_tokens: int = 0
class TokenStreamer:
    """
    LLM token streaming helper.

    Approach:
    1. Use LiteLLM's streaming API (``stream=True``).
    2. Emit every token chunk as it arrives.
    3. Track accumulated content and a chunk-based token count.
    4. Support cooperative cancellation.
    """

    def __init__(
        self,
        model: str,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        on_token: Optional[Callable[[TokenChunk], None]] = None,
    ):
        self.model = model
        self.api_key = api_key
        self.base_url = base_url
        # Optional synchronous callback invoked for every emitted TokenChunk.
        self.on_token = on_token
        self._cancelled = False
        self._accumulated_content = ""
        # NOTE(review): incremented once per non-empty chunk, so this is a
        # chunk count, not an exact model token count.
        self._total_tokens = 0

    def cancel(self):
        """Request cancellation; streaming loops stop before the next chunk."""
        self._cancelled = True

    async def stream_completion(
        self,
        messages: list[Dict[str, str]],
        temperature: float = 0.1,
        max_tokens: int = 4096,
    ) -> AsyncGenerator[TokenChunk, None]:
        """
        Stream a chat completion token by token.

        Args:
            messages: Chat messages (role/content dicts).
            temperature: Sampling temperature.
            max_tokens: Maximum number of tokens to generate.

        Yields:
            TokenChunk: one chunk per streamed token.

        Raises:
            asyncio.CancelledError: When the surrounding task is cancelled.
            Exception: Provider errors are logged and re-raised.
        """
        try:
            import litellm
            response = await litellm.acompletion(
                model=self.model,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
                api_key=self.api_key,
                base_url=self.base_url,
                stream=True,  # enable streaming
            )
            async for chunk in response:
                if self._cancelled:
                    break
                # Pull the delta text out of the provider chunk.
                content = ""
                finish_reason = None
                if hasattr(chunk, "choices") and chunk.choices:
                    choice = chunk.choices[0]
                    if hasattr(choice, "delta") and choice.delta:
                        content = getattr(choice.delta, "content", "") or ""
                    finish_reason = getattr(choice, "finish_reason", None)
                if content:
                    self._accumulated_content += content
                    self._total_tokens += 1
                    token_chunk = TokenChunk(
                        content=content,
                        token_count=1,
                        finish_reason=finish_reason,
                        model=self.model,
                        accumulated_content=self._accumulated_content,
                        total_tokens=self._total_tokens,
                    )
                    # Notify the observer before yielding downstream.
                    if self.on_token:
                        self.on_token(token_chunk)
                    yield token_chunk
                # Stop once the provider signals completion.
                if finish_reason:
                    break
        except asyncio.CancelledError:
            logger.info("Token streaming cancelled")
            raise
        except Exception as e:
            logger.error(f"Token streaming error: {e}")
            raise

    async def stream_with_tools(
        self,
        messages: list[Dict[str, str]],
        tools: list[Dict[str, Any]],
        temperature: float = 0.1,
        max_tokens: int = 4096,
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """
        Stream a completion that may include tool calls.

        Args:
            messages: Chat messages.
            tools: Tool definitions (function-calling format).
            temperature: Sampling temperature.
            max_tokens: Maximum number of tokens to generate.

        Yields:
            Dicts with "type" of "token", "tool_call_chunk",
            "tool_call_complete", "finish", or "error".
        """
        try:
            import litellm
            response = await litellm.acompletion(
                model=self.model,
                messages=messages,
                tools=tools,
                temperature=temperature,
                max_tokens=max_tokens,
                api_key=self.api_key,
                base_url=self.base_url,
                stream=True,
            )
            # Accumulates partial tool calls by stream index until complete.
            tool_calls_accumulator: Dict[int, Dict] = {}
            async for chunk in response:
                if self._cancelled:
                    break
                if not hasattr(chunk, "choices") or not chunk.choices:
                    continue
                choice = chunk.choices[0]
                delta = getattr(choice, "delta", None)
                finish_reason = getattr(choice, "finish_reason", None)
                if delta:
                    # Plain text content.
                    content = getattr(delta, "content", "") or ""
                    if content:
                        self._accumulated_content += content
                        self._total_tokens += 1
                        yield {
                            "type": "token",
                            "content": content,
                            "accumulated": self._accumulated_content,
                            "total_tokens": self._total_tokens,
                        }
                    # Tool-call fragments: the name arrives once, the
                    # arguments arrive in pieces that must be concatenated.
                    tool_calls = getattr(delta, "tool_calls", None) or []
                    for tool_call in tool_calls:
                        idx = tool_call.index
                        if idx not in tool_calls_accumulator:
                            tool_calls_accumulator[idx] = {
                                "id": tool_call.id or "",
                                "name": "",
                                "arguments": "",
                            }
                        if tool_call.function:
                            if tool_call.function.name:
                                tool_calls_accumulator[idx]["name"] = tool_call.function.name
                            if tool_call.function.arguments:
                                tool_calls_accumulator[idx]["arguments"] += tool_call.function.arguments
                        yield {
                            "type": "tool_call_chunk",
                            "index": idx,
                            "tool_call": tool_calls_accumulator[idx],
                        }
                # On completion, surface the fully assembled tool calls.
                if finish_reason == "tool_calls":
                    for idx, tool_call in tool_calls_accumulator.items():
                        yield {
                            "type": "tool_call_complete",
                            "index": idx,
                            "tool_call": tool_call,
                        }
                if finish_reason:
                    yield {
                        "type": "finish",
                        "reason": finish_reason,
                        "accumulated": self._accumulated_content,
                        "total_tokens": self._total_tokens,
                    }
                    break
        except asyncio.CancelledError:
            logger.info("Tool streaming cancelled")
            raise
        except Exception as e:
            # Unlike stream_completion, errors here are yielded, not raised.
            logger.error(f"Tool streaming error: {e}")
            yield {
                "type": "error",
                "error": str(e),
            }

    def get_accumulated_content(self) -> str:
        """Return all text accumulated so far."""
        return self._accumulated_content

    def get_total_tokens(self) -> int:
        """Return the number of chunks emitted so far."""
        return self._total_tokens

    def reset(self):
        """Reset the cancel flag and accumulated state for reuse."""
        self._cancelled = False
        self._accumulated_content = ""
        self._total_tokens = 0

View File

@ -0,0 +1,319 @@
"""
工具调用流式处理器
展示工具调用的输入执行过程和输出
"""
import asyncio
import time
import logging
from typing import Any, Dict, Optional, AsyncGenerator, List, Callable
from dataclasses import dataclass, field
from datetime import datetime, timezone
from enum import Enum
logger = logging.getLogger(__name__)
class ToolCallState(str, Enum):
    """Lifecycle states of a tool invocation."""
    PENDING = "pending"  # queued, not yet running
    RUNNING = "running"  # currently executing
    SUCCESS = "success"  # completed without error
    ERROR = "error"  # raised or returned an error
    TIMEOUT = "timeout"  # exceeded its time budget
@dataclass
class ToolCallEvent:
    """One tool invocation: input, outcome, and timing."""
    tool_name: str
    state: ToolCallState
    # Input / output
    input_params: Dict[str, Any] = field(default_factory=dict)
    output_data: Optional[Any] = None
    error_message: Optional[str] = None
    # Timing
    start_time: Optional[float] = None
    end_time: Optional[float] = None
    duration_ms: int = 0
    # Metadata
    call_id: Optional[str] = None
    sequence: int = 0
    timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())

    def to_dict(self) -> Dict[str, Any]:
        """Serialize for transport, clipping oversized payloads."""
        payload = {
            "tool_name": self.tool_name,
            "state": self.state.value,
            "input_params": self._truncate(self.input_params),
            "output_data": self._truncate(self.output_data),
            "error_message": self.error_message,
            "duration_ms": self.duration_ms,
            "call_id": self.call_id,
            "sequence": self.sequence,
            "timestamp": self.timestamp,
        }
        return payload

    def _truncate(self, data: Any, max_length: int = 500) -> Any:
        """Recursively clip strings/collections so events stay small."""
        if data is None:
            return None
        if isinstance(data, str):
            return data if len(data) <= max_length else data[:max_length] + "..."
        if isinstance(data, dict):
            # Keep at most 20 keys; halve the budget for nested values.
            kept = list(data.items())[:20]
            return {key: self._truncate(value, max_length // 2) for key, value in kept}
        if isinstance(data, list):
            # Keep at most 20 items; split the budget across them.
            max_items = min(20, len(data))
            return [self._truncate(entry, max_length // max_items) for entry in data[:max_items]]
        text = str(data)
        return text if len(text) <= max_length else text[:max_length] + "..."
class ToolStreamHandler:
    """
    Tool-call streaming tracker.

    Responsibilities:
    1. Track the state of each tool invocation.
    2. Record input parameters.
    3. Emit events for start / end / timeout.
    4. Record outputs, durations, and aggregate statistics.
    """

    def __init__(
        self,
        on_event: Optional[Callable[[ToolCallEvent], None]] = None,
    ):
        # Optional synchronous callback invoked on every emitted event.
        self.on_event = on_event
        self._sequence = 0
        # In-flight calls keyed by call_id.
        self._active_calls: Dict[str, ToolCallEvent] = {}
        # Completed calls, oldest first.
        self._history: List[ToolCallEvent] = []

    def _next_sequence(self) -> int:
        """Return the next event sequence number."""
        self._sequence += 1
        return self._sequence

    def _generate_call_id(self) -> str:
        """Generate a short call id (first 8 characters of a UUID4)."""
        import uuid
        return str(uuid.uuid4())[:8]

    async def emit_tool_start(
        self,
        tool_name: str,
        input_params: Dict[str, Any],
        call_id: Optional[str] = None,
    ) -> ToolCallEvent:
        """
        Emit a tool-start event and register the call as active.

        Args:
            tool_name: Name of the tool being invoked.
            input_params: Input parameters passed to the tool.
            call_id: Optional caller-supplied id; generated when omitted.

        Returns:
            The recorded tool-call event (state RUNNING).
        """
        call_id = call_id or self._generate_call_id()
        event = ToolCallEvent(
            tool_name=tool_name,
            state=ToolCallState.RUNNING,
            input_params=input_params,
            start_time=time.time(),
            call_id=call_id,
            sequence=self._next_sequence(),
        )
        self._active_calls[call_id] = event
        if self.on_event:
            self.on_event(event)
        return event

    async def emit_tool_end(
        self,
        call_id: str,
        output_data: Any,
        is_error: bool = False,
        error_message: Optional[str] = None,
    ) -> Optional[ToolCallEvent]:
        """
        Emit a tool-end event and move the call into history.

        Args:
            call_id: Id returned by emit_tool_start.
            output_data: Tool output (or error payload).
            is_error: Whether the call failed.
            error_message: Error description when is_error is True.

        Returns:
            The finalized event, or None for an unknown call_id.
        """
        if call_id not in self._active_calls:
            logger.warning(f"Unknown tool call: {call_id}")
            return None
        event = self._active_calls[call_id]
        event.end_time = time.time()
        event.duration_ms = int((event.end_time - event.start_time) * 1000) if event.start_time else 0
        event.output_data = output_data
        event.sequence = self._next_sequence()
        if is_error:
            event.state = ToolCallState.ERROR
            event.error_message = error_message or str(output_data)
        else:
            event.state = ToolCallState.SUCCESS
        # Move from the active set into history.
        del self._active_calls[call_id]
        self._history.append(event)
        if self.on_event:
            self.on_event(event)
        return event

    async def emit_tool_timeout(self, call_id: str, timeout_seconds: int) -> Optional[ToolCallEvent]:
        """Mark an active call as timed out; returns None for unknown ids."""
        if call_id not in self._active_calls:
            return None
        event = self._active_calls[call_id]
        event.end_time = time.time()
        event.duration_ms = int((event.end_time - event.start_time) * 1000) if event.start_time else 0
        event.state = ToolCallState.TIMEOUT
        event.error_message = f"Tool execution timed out after {timeout_seconds}s"
        event.sequence = self._next_sequence()
        del self._active_calls[call_id]
        self._history.append(event)
        if self.on_event:
            self.on_event(event)
        return event

    def wrap_tool(
        self,
        tool_func: Callable,
        tool_name: str,
        timeout: Optional[int] = None,
    ) -> Callable:
        """
        Wrap a tool function so its calls are tracked automatically.

        Args:
            tool_func: The tool callable (sync or async).
            tool_name: Display name used in events.
            timeout: Optional timeout in seconds.

        Returns:
            An async wrapper with the same call signature.
        """
        async def wrapped(*args, **kwargs):
            call_id = self._generate_call_id()
            # Emit the start event before running the tool.
            await self.emit_tool_start(
                tool_name=tool_name,
                input_params={"args": args, "kwargs": kwargs},
                call_id=call_id,
            )
            try:
                # Run the tool, honoring the optional timeout.
                if asyncio.iscoroutinefunction(tool_func):
                    if timeout:
                        result = await asyncio.wait_for(
                            tool_func(*args, **kwargs),
                            timeout=timeout,
                        )
                    else:
                        result = await tool_func(*args, **kwargs)
                else:
                    if timeout:
                        result = await asyncio.wait_for(
                            asyncio.to_thread(tool_func, *args, **kwargs),
                            timeout=timeout,
                        )
                    else:
                        # NOTE(review): a sync tool without a timeout runs
                        # directly on the event loop thread and can block it.
                        result = tool_func(*args, **kwargs)
                # Emit success.
                await self.emit_tool_end(call_id, result)
                return result
            except asyncio.TimeoutError:
                await self.emit_tool_timeout(call_id, timeout or 0)
                raise
            except Exception as e:
                await self.emit_tool_end(call_id, None, is_error=True, error_message=str(e))
                raise
        return wrapped

    def get_active_calls(self) -> List[ToolCallEvent]:
        """Return a snapshot of the currently running calls."""
        return list(self._active_calls.values())

    def get_history(self, limit: int = 100) -> List[ToolCallEvent]:
        """Return the most recent *limit* completed calls."""
        return self._history[-limit:]

    def get_stats(self) -> Dict[str, Any]:
        """Aggregate call statistics, overall and per tool."""
        total_calls = len(self._history)
        success_calls = sum(1 for e in self._history if e.state == ToolCallState.SUCCESS)
        error_calls = sum(1 for e in self._history if e.state == ToolCallState.ERROR)
        timeout_calls = sum(1 for e in self._history if e.state == ToolCallState.TIMEOUT)
        total_duration = sum(e.duration_ms for e in self._history)
        avg_duration = total_duration / total_calls if total_calls > 0 else 0
        # Per-tool breakdown (timeouts are counted as errors here).
        tool_stats = {}
        for event in self._history:
            if event.tool_name not in tool_stats:
                tool_stats[event.tool_name] = {
                    "calls": 0,
                    "success": 0,
                    "errors": 0,
                    "total_duration_ms": 0,
                }
            tool_stats[event.tool_name]["calls"] += 1
            if event.state == ToolCallState.SUCCESS:
                tool_stats[event.tool_name]["success"] += 1
            elif event.state in [ToolCallState.ERROR, ToolCallState.TIMEOUT]:
                tool_stats[event.tool_name]["errors"] += 1
            tool_stats[event.tool_name]["total_duration_ms"] += event.duration_ms
        return {
            "total_calls": total_calls,
            "success_calls": success_calls,
            "error_calls": error_calls,
            "timeout_calls": timeout_calls,
            "success_rate": success_calls / total_calls if total_calls > 0 else 0,
            "total_duration_ms": total_duration,
            "avg_duration_ms": round(avg_duration, 2),
            "active_calls": len(self._active_calls),
            "by_tool": tool_stats,
        }

    def clear(self):
        """Drop active calls and history, and reset the sequence counter."""
        self._active_calls.clear()
        self._history.clear()
        self._sequence = 0

View File

@ -41,6 +41,16 @@ class PatternMatchTool(AgentTool):
使用正则表达式快速扫描代码中的危险模式
"""
def __init__(self, project_root: str = None):
"""
初始化模式匹配工具
Args:
project_root: 项目根目录可选用于上下文
"""
super().__init__()
self.project_root = project_root
# 危险模式定义
PATTERNS: Dict[str, Dict[str, Any]] = {
# SQL 注入模式

View File

@ -145,3 +145,5 @@ class BaiduAdapter(BaseLLMAdapter):

View File

@ -83,3 +83,5 @@ class DoubaoAdapter(BaseLLMAdapter):

View File

@ -86,3 +86,5 @@ class MinimaxAdapter(BaseLLMAdapter):

View File

@ -134,3 +134,5 @@ class BaseLLMAdapter(ABC):

View File

@ -120,3 +120,5 @@ DEFAULT_BASE_URLS: Dict[LLMProvider, str] = {

View File

@ -0,0 +1,5 @@
"""
DeepAudit Agent 测试套件
企业级测试框架覆盖工具Agent节点和完整流程
"""

View File

@ -0,0 +1,295 @@
"""
Agent 测试配置和 Fixtures
提供测试所需的公共设施
"""
import pytest
import asyncio
import tempfile
import shutil
import os
from typing import Dict, Any, Optional
from unittest.mock import AsyncMock, MagicMock, patch
from dataclasses import dataclass
# ============ 测试配置 ============
@pytest.fixture(scope="session")
def event_loop():
    """Session-scoped asyncio event loop shared by all async tests."""
    loop = asyncio.new_event_loop()
    yield loop
    loop.close()
@pytest.fixture
def temp_project_dir():
    """Create a throwaway project tree containing known-vulnerable samples.

    The tree contains deliberately vulnerable code (SQL/command injection,
    XSS, path traversal, hard-coded secrets), one safe module for contrast,
    a config module, and a requirements.txt.

    Yields:
        The absolute path of the temporary project root; removed afterwards.
    """
    temp_dir = tempfile.mkdtemp(prefix="deepaudit_test_")
    os.makedirs(os.path.join(temp_dir, "src"), exist_ok=True)
    os.makedirs(os.path.join(temp_dir, "config"), exist_ok=True)

    # Vulnerable sample: SQL injection.
    sql_vuln_code = '''
import sqlite3
def get_user(user_id):
    """危险SQL 注入漏洞"""
    conn = sqlite3.connect("users.db")
    cursor = conn.cursor()
    # 直接拼接用户输入,存在 SQL 注入风险
    query = f"SELECT * FROM users WHERE id = '{user_id}'"
    cursor.execute(query)
    return cursor.fetchone()
def search_users(name):
    """危险SQL 注入漏洞"""
    conn = sqlite3.connect("users.db")
    cursor = conn.cursor()
    cursor.execute("SELECT * FROM users WHERE name LIKE '%" + name + "%'")
    return cursor.fetchall()
'''
    # Vulnerable sample: command injection.
    cmd_vuln_code = '''
import os
import subprocess
def run_command(user_input):
    """危险:命令注入漏洞"""
    # 直接执行用户输入
    os.system(f"echo {user_input}")
def execute_script(script_name):
    """危险:命令注入漏洞"""
    subprocess.call(f"bash {script_name}", shell=True)
'''
    # Vulnerable sample: XSS.
    xss_vuln_code = '''
from flask import Flask, request, render_template_string
app = Flask(__name__)
@app.route("/greet")
def greet():
    """危险XSS 漏洞"""
    name = request.args.get("name", "")
    # 直接将用户输入嵌入 HTML存在 XSS 风险
    return f"<h1>Hello, {name}!</h1>"
@app.route("/search")
def search():
    """危险XSS 漏洞"""
    query = request.args.get("q", "")
    html = f"<p>搜索结果: {query}</p>"
    return render_template_string(html)
'''
    # Vulnerable sample: path traversal.
    path_vuln_code = '''
import os
def read_file(filename):
    """危险:路径遍历漏洞"""
    # 没有验证文件路径
    filepath = os.path.join("/app/data", filename)
    with open(filepath, "r") as f:
        return f.read()
def download_file(user_path):
    """危险:路径遍历漏洞"""
    # 直接使用用户输入作为文件路径
    with open(user_path, "rb") as f:
        return f.read()
'''
    # Vulnerable sample: hard-coded secrets.
    secret_vuln_code = '''
# 配置文件
DATABASE_URL = "postgresql://user:password123@localhost/db"
API_KEY = "sk-1234567890abcdef1234567890abcdef"
SECRET_KEY = "super_secret_key_dont_share"
AWS_SECRET_ACCESS_KEY = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"
def connect_database():
    password = "admin123"  # 硬编码密码
    return f"mysql://root:{password}@localhost/mydb"
'''
    # Safe code, for contrast with the vulnerable samples.
    safe_code = '''
import sqlite3
from typing import Optional
def get_user_safe(user_id: int) -> Optional[dict]:
    """安全:使用参数化查询"""
    conn = sqlite3.connect("users.db")
    cursor = conn.cursor()
    cursor.execute("SELECT * FROM users WHERE id = ?", (user_id,))
    return cursor.fetchone()
def validate_input(user_input: str) -> str:
    """输入验证"""
    import re
    if not re.match(r'^[a-zA-Z0-9_]+$', user_input):
        raise ValueError("Invalid input")
    return user_input
'''
    # Environment-driven (safe) configuration module.
    config_code = '''
import os
class Config:
    """安全配置"""
    DATABASE_URL = os.environ.get("DATABASE_URL")
    SECRET_KEY = os.environ.get("SECRET_KEY")
    DEBUG = False
'''
    requirements = '''
flask>=2.0.0
sqlalchemy>=2.0.0
requests>=2.28.0
'''

    # Table-driven write: relative path parts -> file content. This replaces
    # eight copy-pasted open/write blocks and fixes the missing encoding —
    # the samples contain Chinese text, which would break under a non-UTF-8
    # platform default encoding.
    project_files = {
        ("src", "sql_vuln.py"): sql_vuln_code,
        ("src", "cmd_vuln.py"): cmd_vuln_code,
        ("src", "xss_vuln.py"): xss_vuln_code,
        ("src", "path_vuln.py"): path_vuln_code,
        ("src", "secrets.py"): secret_vuln_code,
        ("src", "safe_code.py"): safe_code,
        ("config", "settings.py"): config_code,
        ("requirements.txt",): requirements,
    }
    for path_parts, content in project_files.items():
        with open(os.path.join(temp_dir, *path_parts), "w", encoding="utf-8") as f:
            f.write(content)

    yield temp_dir
    # Cleanup after the test.
    shutil.rmtree(temp_dir, ignore_errors=True)
@pytest.fixture
def mock_llm_service():
    """Stub LLM service whose raw chat completion resolves to a canned payload."""
    canned_response = {
        "content": "测试响应",
        "usage": {"total_tokens": 100},
    }
    service = MagicMock()
    service.chat_completion_raw = AsyncMock(return_value=canned_response)
    return service
@pytest.fixture
def mock_event_emitter():
    """Stub event emitter with every emit_* hook replaced by an AsyncMock."""
    emitter = MagicMock()
    async_hooks = (
        "emit_info", "emit_warning", "emit_error",
        "emit_thinking", "emit_tool_call", "emit_tool_result",
        "emit_finding", "emit_progress",
        "emit_phase_start", "emit_phase_complete",
        "emit_task_complete", "emit",
    )
    for hook_name in async_hooks:
        setattr(emitter, hook_name, AsyncMock())
    return emitter
@pytest.fixture
def mock_db_session():
    """Stub async DB session: commit/rollback/get/execute awaitable, add synchronous."""
    session = AsyncMock()
    # ``add`` is synchronous on a real async session, so use a plain MagicMock.
    session.add = MagicMock()
    for awaitable_name in ("commit", "rollback", "execute"):
        setattr(session, awaitable_name, AsyncMock())
    # ``get`` resolves to None so lookups behave like "record not found".
    session.get = AsyncMock(return_value=None)
    return session
@dataclass
class MockProject:
    """Minimal stand-in for the Project model used by agent tests."""

    # Fixed values so tests can assert against stable identifiers.
    id: str = "test-project-id"
    name: str = "Test Project"
    description: str = "Test project for unit tests"
@dataclass
class MockAgentTask:
    """Minimal stand-in for the AgentTask model used by agent tests."""

    id: str = "test-task-id"
    project_id: str = "test-project-id"
    project: MockProject = None
    name: str = "Test Agent Task"
    status: str = "pending"
    current_phase: str = "planning"
    target_vulnerabilities: list = None
    verification_level: str = "sandbox"
    exclude_patterns: list = None
    target_files: list = None
    max_iterations: int = 50
    timeout_seconds: int = 1800

    def __post_init__(self):
        # Dataclasses cannot take mutable defaults directly, so ``None`` acts
        # as a sentinel replaced with a fresh object per instance.
        if self.project is None:
            self.project = MockProject()
        for list_attr in ("target_vulnerabilities", "exclude_patterns", "target_files"):
            if getattr(self, list_attr) is None:
                setattr(self, list_attr, [])
@pytest.fixture
def mock_task():
    """Provide a fresh MockAgentTask populated with its default values."""
    return MockAgentTask()
# ============ 测试辅助函数 ============
def assert_finding_valid(finding: Dict[str, Any]):
    """Assert a finding dict carries the required fields and a known severity."""
    for required in ("title", "severity", "vulnerability_type"):
        assert required in finding, f"Missing required field: {required}"
    valid_severities = ["critical", "high", "medium", "low", "info"]
    severity = finding["severity"]
    assert severity in valid_severities, f"Invalid severity: {severity}"
def count_findings_by_type(findings: list, vuln_type: str) -> int:
    """Return how many findings have the given ``vulnerability_type``."""
    matching = [item for item in findings if item.get("vulnerability_type") == vuln_type]
    return len(matching)
def count_findings_by_severity(findings: list, severity: str) -> int:
    """Return how many findings have the given ``severity``."""
    matching = [item for item in findings if item.get("severity") == severity]
    return len(matching)

View File

@ -0,0 +1,96 @@
#!/usr/bin/env python
"""
Agent 测试运行器
运行所有 Agent 相关测试并生成报告
"""
import subprocess
import sys
import os
from datetime import datetime
from pathlib import Path
def run_tests():
    """Run the full agent test suite and return pytest's exit code."""
    # All paths in the pytest invocation are relative to the project root.
    project_root = Path(__file__).parent.parent.parent
    os.chdir(project_root)

    banner = "=" * 60
    print(banner)
    print("DeepAudit Agent 测试套件")
    print(f"时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print(banner)
    print()

    cmd = [
        sys.executable, "-m", "pytest",
        "tests/agent/",
        "-v",
        "--tb=short",
        "-x",  # stop at the first failure so broken runs fail fast
        "--color=yes",
    ]
    print(f"运行命令: {' '.join(cmd)}")
    print()

    result = subprocess.run(cmd, cwd=project_root)

    print()
    print(banner)
    if result.returncode == 0:
        print("✅ 所有测试通过!")
    else:
        print(f"❌ 测试失败 (退出码: {result.returncode})")
    print(banner)
    return result.returncode
def run_tests_with_coverage():
    """Run the agent tests with coverage output (terminal summary + HTML report)."""
    project_root = Path(__file__).parent.parent.parent
    os.chdir(project_root)
    coverage_args = [
        "--cov=app/services/agent",
        "--cov-report=term-missing",
        "--cov-report=html:coverage_agent",
    ]
    cmd = [sys.executable, "-m", "pytest", "tests/agent/", "-v", *coverage_args]
    return subprocess.run(cmd, cwd=project_root).returncode
def run_specific_test(test_name: str):
    """Run one test file/node under tests/agent/ with long tracebacks and stdout."""
    project_root = Path(__file__).parent.parent.parent
    os.chdir(project_root)
    cmd = [
        sys.executable, "-m", "pytest",
        f"tests/agent/{test_name}",
        "-v",
        "--tb=long",
        "-s",  # show print output from the tests
    ]
    return subprocess.run(cmd, cwd=project_root).returncode
if __name__ == "__main__":
    # CLI dispatch: no args -> full suite; "--coverage" -> coverage run;
    # any other argument is a test file name under tests/agent/.
    args = sys.argv[1:]
    if not args:
        sys.exit(run_tests())
    elif args[0] == "--coverage":
        sys.exit(run_tests_with_coverage())
    else:
        sys.exit(run_specific_test(args[0]))

View File

@ -0,0 +1,213 @@
"""
Agent 单元测试
测试各个 Agent 的功能
"""
import pytest
import asyncio
import os
from unittest.mock import MagicMock, AsyncMock, patch
from app.services.agent.agents.base import BaseAgent, AgentConfig, AgentResult, AgentType, AgentPattern
from app.services.agent.agents.recon import ReconAgent
from app.services.agent.agents.analysis import AnalysisAgent
from app.services.agent.agents.verification import VerificationAgent
class TestReconAgent:
    """Tests for the reconnaissance (Recon) agent."""

    @pytest.fixture
    def recon_agent(self, temp_project_dir, mock_llm_service, mock_event_emitter):
        """Build a ReconAgent wired to file tools rooted at the temp project."""
        from app.services.agent.tools import (
            FileReadTool, FileSearchTool, ListFilesTool,
        )
        tools = {
            "list_files": ListFilesTool(temp_project_dir),
            "read_file": FileReadTool(temp_project_dir),
            "search_code": FileSearchTool(temp_project_dir),
        }
        return ReconAgent(
            llm_service=mock_llm_service,
            tools=tools,
            event_emitter=mock_event_emitter,
        )

    @pytest.mark.asyncio
    async def test_recon_agent_run(self, recon_agent, temp_project_dir):
        """Recon agent completes successfully and returns structured data."""
        result = await recon_agent.run({
            "project_info": {
                "name": "Test Project",
                "root": temp_project_dir,
            },
            "config": {},
        })
        assert result.success is True
        assert result.data is not None
        # Validate the shape of the returned payload.
        data = result.data
        assert "tech_stack" in data
        assert "entry_points" in data or "high_risk_areas" in data

    @pytest.mark.asyncio
    async def test_recon_agent_identifies_python(self, recon_agent, temp_project_dir):
        """Recon agent detects Python in the fixture project's tech stack."""
        result = await recon_agent.run({
            "project_info": {"root": temp_project_dir},
            "config": {},
        })
        assert result.success is True
        tech_stack = result.data.get("tech_stack", {})
        languages = tech_stack.get("languages", [])
        # Python should be identified (the fixture project is pure Python).
        assert "Python" in languages or len(languages) > 0

    @pytest.mark.asyncio
    async def test_recon_agent_finds_high_risk_areas(self, recon_agent, temp_project_dir):
        """Recon agent flags at least one high-risk area in the fixture project."""
        result = await recon_agent.run({
            "project_info": {"root": temp_project_dir},
            "config": {},
        })
        assert result.success is True
        high_risk_areas = result.data.get("high_risk_areas", [])
        # The fixture project contains deliberately vulnerable files.
        assert len(high_risk_areas) > 0
class TestAnalysisAgent:
    """Tests for the vulnerability Analysis agent."""

    @pytest.fixture
    def analysis_agent(self, temp_project_dir, mock_llm_service, mock_event_emitter):
        """Build an AnalysisAgent wired to file/pattern tools on the temp project."""
        from app.services.agent.tools import (
            FileReadTool, FileSearchTool, PatternMatchTool,
        )
        tools = {
            "read_file": FileReadTool(temp_project_dir),
            "search_code": FileSearchTool(temp_project_dir),
            "pattern_match": PatternMatchTool(temp_project_dir),
        }
        return AnalysisAgent(
            llm_service=mock_llm_service,
            tools=tools,
            event_emitter=mock_event_emitter,
        )

    @pytest.mark.asyncio
    async def test_analysis_agent_run(self, analysis_agent, temp_project_dir):
        """Analysis agent completes when given recon-style input state."""
        result = await analysis_agent.run({
            "tech_stack": {"languages": ["Python"]},
            "entry_points": [],
            "high_risk_areas": ["src/sql_vuln.py", "src/cmd_vuln.py"],
            "config": {},
        })
        assert result.success is True
        assert result.data is not None

    @pytest.mark.asyncio
    async def test_analysis_agent_finds_vulnerabilities(self, analysis_agent, temp_project_dir):
        """Analysis agent returns a findings list for vulnerable fixture files."""
        result = await analysis_agent.run({
            "tech_stack": {"languages": ["Python"]},
            "entry_points": [],
            "high_risk_areas": [
                "src/sql_vuln.py",
                "src/cmd_vuln.py",
                "src/xss_vuln.py",
                "src/secrets.py",
            ],
            "config": {},
        })
        assert result.success is True
        findings = result.data.get("findings", [])
        # Should surface some findings; the exact count depends on the
        # analysis logic (the LLM is mocked), so only the type is asserted.
        assert isinstance(findings, list)
class TestAgentResult:
    """Tests for the AgentResult value object."""

    def test_agent_result_success(self):
        """A successful result preserves its data and counters."""
        outcome = AgentResult(
            success=True,
            data={"findings": []},
            iterations=5,
            tool_calls=10,
        )
        assert outcome.success is True
        assert outcome.iterations == 5
        assert outcome.tool_calls == 10

    def test_agent_result_failure(self):
        """A failed result carries the error message."""
        outcome = AgentResult(success=False, error="Test error")
        assert outcome.success is False
        assert outcome.error == "Test error"

    def test_agent_result_to_dict(self):
        """to_dict exposes the success flag and iteration count."""
        outcome = AgentResult(success=True, data={"key": "value"}, iterations=3)
        serialized = outcome.to_dict()
        assert serialized["success"] is True
        assert serialized["iterations"] == 3
class TestAgentConfig:
    """Tests for AgentConfig defaults and explicit overrides."""

    def test_agent_config_defaults(self):
        """Unspecified fields fall back to ReAct / 20 iterations / temperature 0.1."""
        cfg = AgentConfig(name="Test", agent_type=AgentType.RECON)
        assert cfg.pattern == AgentPattern.REACT
        assert cfg.max_iterations == 20
        assert cfg.temperature == 0.1

    def test_agent_config_custom(self):
        """Explicit constructor values override every default."""
        cfg = AgentConfig(
            name="Custom",
            agent_type=AgentType.ANALYSIS,
            pattern=AgentPattern.PLAN_AND_EXECUTE,
            max_iterations=50,
            temperature=0.5,
        )
        assert cfg.pattern == AgentPattern.PLAN_AND_EXECUTE
        assert cfg.max_iterations == 50
        assert cfg.temperature == 0.5

View File

@ -0,0 +1,355 @@
"""
Agent 集成测试
测试完整的审计流程
"""
import pytest
import asyncio
import os
from unittest.mock import MagicMock, AsyncMock, patch
from datetime import datetime
from app.services.agent.graph.runner import AgentRunner, LLMService
from app.services.agent.graph.audit_graph import AuditState, create_audit_graph
from app.services.agent.graph.nodes import ReconNode, AnalysisNode, VerificationNode, ReportNode
from app.services.agent.event_manager import EventManager, AgentEventEmitter
class TestLLMService:
    """Tests for the LLM service wrapper."""

    @pytest.mark.asyncio
    async def test_llm_service_initialization(self):
        """LLMService picks up the configured model name from settings."""
        # NOTE(review): patching "app.core.config.settings" only affects code
        # that resolves the attribute at call time; if runner.py does
        # ``from app.core.config import settings`` at import time the patch
        # may not take effect — confirm the patch target is correct.
        with patch("app.core.config.settings") as mock_settings:
            mock_settings.LLM_MODEL = "gpt-4o-mini"
            mock_settings.LLM_API_KEY = "test-key"
            service = LLMService()
            assert service.model == "gpt-4o-mini"
class TestEventManager:
    """Tests for EventManager and AgentEventEmitter sequencing."""

    def test_event_manager_initialization(self):
        """A fresh manager starts with no queues and no callbacks."""
        mgr = EventManager()
        assert mgr._event_queues == {}
        assert mgr._event_callbacks == {}

    @pytest.mark.asyncio
    async def test_event_emitter(self):
        """Emitting a single info event advances the sequence counter to 1."""
        emitter = AgentEventEmitter("test-task-id", EventManager())
        await emitter.emit_info("Test message")
        assert emitter._sequence == 1

    @pytest.mark.asyncio
    async def test_event_emitter_phase_tracking(self):
        """emit_phase_start records the phase on the emitter."""
        emitter = AgentEventEmitter("test-task-id", EventManager())
        await emitter.emit_phase_start("recon", "开始信息收集")
        assert emitter._current_phase == "recon"

    @pytest.mark.asyncio
    async def test_event_emitter_task_complete(self):
        """emit_task_complete emits exactly one event."""
        emitter = AgentEventEmitter("test-task-id", EventManager())
        await emitter.emit_task_complete(findings_count=5, duration_ms=1000)
        assert emitter._sequence == 1
class TestAuditGraph:
    """Tests for audit graph construction."""

    def test_create_audit_graph(self, mock_event_emitter):
        """create_audit_graph returns a graph when given all four node mocks."""
        node_mocks = {
            "recon_node": MagicMock(),
            "analysis_node": MagicMock(),
            "verification_node": MagicMock(),
            "report_node": MagicMock(),
        }
        graph = create_audit_graph(**node_mocks)
        assert graph is not None
class TestReconNode:
    """Tests for the Recon graph node wrapper."""

    @pytest.fixture
    def recon_node_with_mock_agent(self, mock_event_emitter):
        """ReconNode whose inner agent always succeeds with canned recon data."""
        mock_agent = MagicMock()
        mock_agent.run = AsyncMock(return_value=MagicMock(
            success=True,
            data={
                "tech_stack": {"languages": ["Python"]},
                "entry_points": [{"path": "src/app.py", "type": "api"}],
                "high_risk_areas": ["src/sql_vuln.py"],
                "dependencies": {},
                "initial_findings": [],
            }
        ))
        return ReconNode(mock_agent, mock_event_emitter)

    @pytest.mark.asyncio
    async def test_recon_node_success(self, recon_node_with_mock_agent):
        """Node merges agent output into state and advances to recon_complete."""
        state = {
            "project_info": {"name": "Test"},
            "config": {},
        }
        result = await recon_node_with_mock_agent(state)
        assert "tech_stack" in result
        assert "entry_points" in result
        assert result["current_phase"] == "recon_complete"

    @pytest.mark.asyncio
    async def test_recon_node_failure(self, mock_event_emitter):
        """Node reports an error phase when the inner agent fails."""
        mock_agent = MagicMock()
        mock_agent.run = AsyncMock(return_value=MagicMock(
            success=False,
            error="Test error",
            data=None,
        ))
        node = ReconNode(mock_agent, mock_event_emitter)
        result = await node({
            "project_info": {},
            "config": {},
        })
        assert "error" in result
        assert result["current_phase"] == "error"
class TestAnalysisNode:
    """Tests for the Analysis graph node wrapper."""

    @pytest.fixture
    def analysis_node_with_mock_agent(self, mock_event_emitter):
        """AnalysisNode whose inner agent returns one canned SQL-injection finding."""
        mock_agent = MagicMock()
        mock_agent.run = AsyncMock(return_value=MagicMock(
            success=True,
            data={
                "findings": [
                    {
                        "id": "finding-1",
                        "title": "SQL Injection",
                        "severity": "high",
                        "vulnerability_type": "sql_injection",
                        "file_path": "src/sql_vuln.py",
                        "line_start": 10,
                        "description": "SQL injection vulnerability",
                    }
                ],
                "should_continue": False,
            }
        ))
        return AnalysisNode(mock_agent, mock_event_emitter)

    @pytest.mark.asyncio
    async def test_analysis_node_success(self, analysis_node_with_mock_agent):
        """Node appends findings to state and increments the iteration counter."""
        state = {
            "project_info": {"name": "Test"},
            "tech_stack": {"languages": ["Python"]},
            "entry_points": [],
            "high_risk_areas": ["src/sql_vuln.py"],
            "config": {},
            "iteration": 0,
            "findings": [],
        }
        result = await analysis_node_with_mock_agent(state)
        assert "findings" in result
        assert len(result["findings"]) > 0
        assert result["iteration"] == 1
class TestIntegrationFlow:
    """End-to-end wiring tests using mocks (no real LLM calls)."""

    @pytest.mark.asyncio
    async def test_full_audit_flow_mock(self, temp_project_dir, mock_db_session, mock_task):
        """Event plumbing works across a phase-start / info / phase-complete sequence."""
        # This test validates the connectivity of the overall flow.
        event_manager = EventManager()
        emitter = AgentEventEmitter(mock_task.id, event_manager)
        # Mocked LLM service (documents intended wiring; not invoked below).
        mock_llm = MagicMock()
        mock_llm.chat_completion_raw = AsyncMock(return_value={
            "content": "Analysis complete",
            "usage": {"total_tokens": 100},
        })
        # Emit three events and verify they were sequenced.
        await emitter.emit_phase_start("init", "初始化")
        await emitter.emit_info("测试消息")
        await emitter.emit_phase_complete("init", "初始化完成")
        assert emitter._sequence == 3

    @pytest.mark.asyncio
    async def test_audit_state_typing(self):
        """A fully populated AuditState literal type-checks and is readable."""
        state: AuditState = {
            "project_root": "/tmp/test",
            "project_info": {"name": "Test"},
            "config": {},
            "task_id": "test-id",
            "tech_stack": {},
            "entry_points": [],
            "high_risk_areas": [],
            "dependencies": {},
            "findings": [],
            "verified_findings": [],
            "false_positives": [],
            "current_phase": "start",
            "iteration": 0,
            "max_iterations": 50,
            "should_continue_analysis": False,
            "messages": [],
            "events": [],
            "summary": None,
            "security_score": None,
            "error": None,
        }
        assert state["current_phase"] == "start"
        assert state["max_iterations"] == 50
class TestToolIntegration:
    """Tests that the file tools compose into a usable analysis pipeline."""

    @pytest.mark.asyncio
    async def test_tools_work_together(self, temp_project_dir):
        """list -> search -> read -> pattern-match pipeline succeeds end to end."""
        from app.services.agent.tools import (
            FileReadTool, FileSearchTool, ListFilesTool, PatternMatchTool,
        )
        # 1. List files in the project.
        list_tool = ListFilesTool(temp_project_dir)
        list_result = await list_tool.execute(directory="src", recursive=False)
        assert list_result.success is True
        # 2. Search for key code.
        search_tool = FileSearchTool(temp_project_dir)
        search_result = await search_tool.execute(keyword="execute")
        assert search_result.success is True
        # 3. Read the file content.
        read_tool = FileReadTool(temp_project_dir)
        read_result = await read_tool.execute(file_path="src/sql_vuln.py")
        assert read_result.success is True
        # 4. Pattern-match on the file just read.
        pattern_tool = PatternMatchTool(temp_project_dir)
        pattern_result = await pattern_tool.execute(
            code=read_result.data,
            file_path="src/sql_vuln.py",
            language="python"
        )
        assert pattern_result.success is True
class TestErrorHandling:
    """Failure-path tests for tools and graph nodes."""

    @pytest.mark.asyncio
    async def test_tool_error_handling(self, temp_project_dir):
        """Reading a missing file fails cleanly with an error message."""
        from app.services.agent.tools import FileReadTool
        tool = FileReadTool(temp_project_dir)
        # Attempt to read a file that does not exist.
        result = await tool.execute(file_path="nonexistent/file.py")
        assert result.success is False
        assert result.error is not None

    @pytest.mark.asyncio
    async def test_agent_graceful_degradation(self, mock_event_emitter):
        """A node whose agent raises returns an error state instead of crashing."""
        # Build an agent that always raises.
        mock_agent = MagicMock()
        mock_agent.run = AsyncMock(side_effect=Exception("Simulated error"))
        node = ReconNode(mock_agent, mock_event_emitter)
        result = await node({
            "project_info": {},
            "config": {},
        })
        # The error should surface in state rather than propagate.
        assert "error" in result
        assert result["current_phase"] == "error"
class TestPerformance:
    """Coarse performance smoke tests for the tools."""

    @pytest.mark.asyncio
    async def test_tool_response_time(self, temp_project_dir):
        """Recursive directory listing finishes within a generous 5-second bound."""
        from app.services.agent.tools import ListFilesTool
        import time
        tool = ListFilesTool(temp_project_dir)
        start = time.time()
        await tool.execute(directory=".", recursive=True)
        duration = time.time() - start
        # The tool should respond in a reasonable time.
        assert duration < 5.0  # within 5 seconds

    @pytest.mark.asyncio
    async def test_multiple_tool_calls(self, temp_project_dir):
        """Repeated searches all succeed and the call counter tracks them."""
        from app.services.agent.tools import FileSearchTool
        tool = FileSearchTool(temp_project_dir)
        # Execute several calls in a row.
        for _ in range(5):
            result = await tool.execute(keyword="def")
            assert result.success is True
        # Verify the internal call counter.
        assert tool._call_count == 5

View File

@ -0,0 +1,248 @@
"""
Agent 工具单元测试
测试各种安全分析工具的功能
"""
import pytest
import asyncio
import os
from unittest.mock import MagicMock, AsyncMock, patch
# 导入工具
from app.services.agent.tools import (
FileReadTool, FileSearchTool, ListFilesTool,
PatternMatchTool,
)
from app.services.agent.tools.base import ToolResult
class TestFileTools:
    """Tests for the basic file tools (read / search / list)."""

    @pytest.mark.asyncio
    async def test_file_read_tool_success(self, temp_project_dir):
        """FileReadTool returns the contents of an existing fixture file."""
        tool = FileReadTool(temp_project_dir)
        result = await tool.execute(file_path="src/sql_vuln.py")
        assert result.success is True
        assert "SELECT * FROM users" in result.data
        assert "sql_injection" in result.data.lower() or "cursor.execute" in result.data

    @pytest.mark.asyncio
    async def test_file_read_tool_not_found(self, temp_project_dir):
        """FileReadTool fails with a not-found error for a missing file."""
        tool = FileReadTool(temp_project_dir)
        result = await tool.execute(file_path="nonexistent.py")
        assert result.success is False
        # Error message may be localized (Chinese) or English.
        assert "不存在" in result.error or "not found" in result.error.lower()

    @pytest.mark.asyncio
    async def test_file_read_tool_path_traversal_blocked(self, temp_project_dir):
        """FileReadTool rejects paths escaping the project root."""
        tool = FileReadTool(temp_project_dir)
        result = await tool.execute(file_path="../../../etc/passwd")
        assert result.success is False
        # Error message may be localized (Chinese) or English.
        assert "安全" in result.error or "security" in result.error.lower()

    @pytest.mark.asyncio
    async def test_file_search_tool(self, temp_project_dir):
        """FileSearchTool finds files containing the keyword."""
        tool = FileSearchTool(temp_project_dir)
        result = await tool.execute(keyword="cursor.execute")
        assert result.success is True
        assert "sql_vuln.py" in result.data

    @pytest.mark.asyncio
    async def test_list_files_tool(self, temp_project_dir):
        """ListFilesTool lists the project recursively."""
        tool = ListFilesTool(temp_project_dir)
        result = await tool.execute(directory=".", recursive=True)
        assert result.success is True
        assert "sql_vuln.py" in result.data
        assert "requirements.txt" in result.data

    @pytest.mark.asyncio
    async def test_list_files_tool_pattern(self, temp_project_dir):
        """ListFilesTool honours a glob pattern filter."""
        tool = ListFilesTool(temp_project_dir)
        result = await tool.execute(directory="src", pattern="*.py")
        assert result.success is True
        assert "sql_vuln.py" in result.data
class TestPatternMatchTool:
    """Tests for the pattern-match vulnerability detection tool."""

    @pytest.mark.asyncio
    async def test_pattern_match_sql_injection(self, temp_project_dir):
        """Pattern matching flags SQL-injection patterns in vulnerable code."""
        tool = PatternMatchTool(temp_project_dir)
        # Read the intentionally vulnerable fixture code.
        with open(os.path.join(temp_project_dir, "src", "sql_vuln.py")) as f:
            code = f.read()
        result = await tool.execute(
            code=code,
            file_path="src/sql_vuln.py",
            pattern_types=["sql_injection"],
            language="python"
        )
        assert result.success is True
        # SQL-injection patterns should be detected (payload text or matches).
        if result.data:
            assert "sql" in str(result.data).lower() or len(result.metadata.get("matches", [])) > 0

    @pytest.mark.asyncio
    async def test_pattern_match_command_injection(self, temp_project_dir):
        """Pattern matching runs cleanly on command-injection fixture code."""
        tool = PatternMatchTool(temp_project_dir)
        with open(os.path.join(temp_project_dir, "src", "cmd_vuln.py")) as f:
            code = f.read()
        result = await tool.execute(
            code=code,
            file_path="src/cmd_vuln.py",
            pattern_types=["command_injection"],
            language="python"
        )
        assert result.success is True

    @pytest.mark.asyncio
    async def test_pattern_match_xss(self, temp_project_dir):
        """Pattern matching runs cleanly on XSS fixture code."""
        tool = PatternMatchTool(temp_project_dir)
        with open(os.path.join(temp_project_dir, "src", "xss_vuln.py")) as f:
            code = f.read()
        result = await tool.execute(
            code=code,
            file_path="src/xss_vuln.py",
            pattern_types=["xss"],
            language="python"
        )
        assert result.success is True

    @pytest.mark.asyncio
    async def test_pattern_match_hardcoded_secrets(self, temp_project_dir):
        """Pattern matching runs cleanly on hardcoded-secret fixture code."""
        # NOTE(review): no ``language`` argument here, unlike the sibling
        # tests — confirm whether PatternMatchTool treats it as optional.
        tool = PatternMatchTool(temp_project_dir)
        with open(os.path.join(temp_project_dir, "src", "secrets.py")) as f:
            code = f.read()
        result = await tool.execute(
            code=code,
            file_path="src/secrets.py",
            pattern_types=["hardcoded_secret"],
        )
        assert result.success is True

    @pytest.mark.asyncio
    async def test_pattern_match_safe_code(self, temp_project_dir):
        """Parameterised-query code should produce (almost) no SQL-injection hits."""
        tool = PatternMatchTool(temp_project_dir)
        with open(os.path.join(temp_project_dir, "src", "safe_code.py")) as f:
            code = f.read()
        result = await tool.execute(
            code=code,
            file_path="src/safe_code.py",
            pattern_types=["sql_injection"],
            language="python"
        )
        assert result.success is True
        # The safe code uses parameterised queries, so it should not be
        # reported as SQL injection. Check the matches in the metadata.
        matches = result.metadata.get("matches", [])
        if isinstance(matches, list):
            # Parameterised queries must not be misreported as injection.
            sql_injection_count = sum(
                1 for m in matches
                if isinstance(m, dict) and "sql" in m.get("pattern_type", "").lower()
            )
            # Tolerate at most one false positive for safe code.
            assert sql_injection_count <= 1  # allow a small number of false positives
class TestToolResult:
    """Tests for the ToolResult value object."""

    def test_tool_result_success(self):
        """A successful result carries data and no error."""
        res = ToolResult(success=True, data="test data")
        assert res.success is True
        assert res.data == "test data"
        assert res.error is None

    def test_tool_result_failure(self):
        """A failed result carries the error message."""
        res = ToolResult(success=False, error="test error")
        assert res.success is False
        assert res.error == "test error"

    def test_tool_result_to_string(self):
        """to_string renders dict payloads with keys and values visible."""
        res = ToolResult(success=True, data={"key": "value"})
        rendered = res.to_string()
        assert "key" in rendered
        assert "value" in rendered

    def test_tool_result_to_string_truncate(self):
        """to_string truncates long payloads and marks the truncation."""
        long_payload = "x" * 10000
        res = ToolResult(success=True, data=long_payload)
        rendered = res.to_string(max_length=100)
        assert len(rendered) < len(long_payload)
        assert "truncated" in rendered.lower()
class TestToolMetadata:
    """Tests for per-tool bookkeeping (call counts, execution timing)."""

    @pytest.mark.asyncio
    async def test_tool_call_count(self, temp_project_dir):
        """Each execute() call increments the internal call counter."""
        tool = ListFilesTool(temp_project_dir)
        await tool.execute(directory=".")
        await tool.execute(directory="src")
        assert tool._call_count == 2

    @pytest.mark.asyncio
    async def test_tool_duration_tracking(self, temp_project_dir):
        """Execution duration is recorded on the result and accumulated on the tool."""
        tool = ListFilesTool(temp_project_dir)
        result = await tool.execute(directory=".")
        assert result.duration_ms >= 0
        assert tool._total_duration_ms >= 0

View File

@ -0,0 +1,229 @@
# DeepAudit Agent 审计功能部署清单
## 📋 生产部署前必须完成的检查
### 1. 环境依赖 ✅
```bash
# 后端依赖
cd backend
uv pip install chromadb litellm langchain langgraph
# 外部安全工具(可选但推荐)
pip install semgrep bandit safety
# 或者使用系统包管理器
brew install semgrep # macOS
apt install semgrep # Ubuntu
```
### 2. LLM 配置 ✅
在 `.env` 文件中配置:
```env
# LLM 配置(必须)
LLM_PROVIDER=openai # 或 azure, anthropic, ollama 等
LLM_MODEL=gpt-4o-mini # 推荐使用 gpt-4 系列
LLM_API_KEY=sk-xxx # 你的 API Key
LLM_BASE_URL= # 可选,自定义端点
# 嵌入模型配置RAG 需要)
EMBEDDING_PROVIDER=openai
EMBEDDING_MODEL=text-embedding-3-small
```
### 3. 数据库迁移 ✅
```bash
cd backend
alembic upgrade head
```
确保以下表已创建:
- `agent_tasks`
- `agent_events`
- `agent_findings`
### 4. 向量数据库 ✅
```bash
# 创建向量数据库目录
mkdir -p /var/data/deepaudit/vector_db
# 在 .env 中配置
VECTOR_DB_PATH=/var/data/deepaudit/vector_db
```
### 5. Docker 沙箱(可选)
如果需要漏洞验证功能:
```bash
# 拉取沙箱镜像
docker pull python:3.11-slim
# 配置沙箱参数
SANDBOX_IMAGE=python:3.11-slim
SANDBOX_MEMORY_LIMIT=256m
SANDBOX_CPU_LIMIT=0.5
```
---
## 🔬 功能测试检查
### 测试 1: 基础流程
```bash
cd backend
PYTHONPATH=. uv run pytest tests/agent/ -v
```
预期结果:43 个测试全部通过
### 测试 2: LLM 连接
```bash
cd backend
PYTHONPATH=. uv run python -c "
import asyncio
from app.services.agent.graph.runner import LLMService
async def test():
llm = LLMService()
result = await llm.analyze_code('print(\"hello\")', 'python')
print('LLM 连接成功:', 'issues' in result)
asyncio.run(test())
"
```
### 测试 3: 外部工具
```bash
# 测试 Semgrep
semgrep --version
# 测试 Bandit
bandit --version
```
### 测试 4: 端到端测试
1. 启动后端:`cd backend && uv run uvicorn app.main:app --reload`
2. 启动前端:`cd frontend && npm run dev`
3. 创建一个项目并上传代码
4. 选择 "Agent 审计模式" 创建任务
5. 观察执行日志和发现
---
## ⚠️ 已知限制
| 限制 | 影响 | 解决方案 |
|------|------|---------|
| **LLM 成本** | 每次审计消耗 Token | 使用 gpt-4o-mini 降低成本 |
| **扫描时间** | 大项目需要较长时间 | 设置合理的超时时间 |
| **误报率** | AI 可能产生误报 | 启用验证阶段过滤 |
| **外部工具依赖** | 需要手动安装 | 提供 Docker 镜像 |
---
## 🚀 生产环境建议
### 1. 资源配置
```yaml
# Kubernetes 部署示例
resources:
limits:
memory: "2Gi"
cpu: "2"
requests:
memory: "1Gi"
cpu: "1"
```
### 2. 并发控制
```env
# 限制同时运行的任务数
MAX_CONCURRENT_AGENT_TASKS=3
AGENT_TASK_TIMEOUT=1800 # 30 分钟
```
### 3. 日志监控
```python
# 配置日志级别
LOG_LEVEL=INFO
# 启用 SQLAlchemy 日志(调试用)
SQLALCHEMY_ECHO=false
```
### 4. 安全考虑
- [ ] 限制上传文件大小
- [ ] 限制扫描目录范围
- [ ] 启用沙箱隔离
- [ ] 配置 API 速率限制
---
## ✅ 部署状态检查
运行以下命令验证部署状态:
```bash
cd backend
PYTHONPATH=. uv run python -c "
print('检查部署状态...')
# 1. 检查数据库连接
try:
from app.db.session import async_session_factory
print('✅ 数据库配置正确')
except Exception as e:
print(f'❌ 数据库错误: {e}')
# 2. 检查 LLM 配置
from app.core.config import settings
if settings.LLM_API_KEY:
print('✅ LLM API Key 已配置')
else:
print('⚠️ LLM API Key 未配置')
# 3. 检查向量数据库
import os
if os.path.exists(settings.VECTOR_DB_PATH or '/tmp'):
print('✅ 向量数据库路径存在')
else:
print('⚠️ 向量数据库路径不存在')
# 4. 检查外部工具
import shutil
tools = ['semgrep', 'bandit']
for tool in tools:
if shutil.which(tool):
print(f'✅ {tool} 已安装')
else:
print(f'⚠️ {tool} 未安装(可选)')
print()
print('部署检查完成!')
"
```
---
## 📝 结论
Agent 审计功能已经具备**基本的生产能力**,但建议:
1. **先在测试环境验证** - 用一个小项目测试完整流程
2. **监控 LLM 成本** - 观察 Token 消耗情况
3. **逐步开放** - 先给少数用户使用,收集反馈
4. **持续优化** - 根据实际效果调整 prompt 和阈值
如有问题,请查看日志或联系开发团队。

View File

@ -55,3 +55,5 @@ EXPOSE 3000
ENTRYPOINT ["/docker-entrypoint.sh"]
CMD ["serve", "-s", "dist", "-l", "3000"]

View File

@ -18,3 +18,5 @@ export const ProtectedRoute = () => {

View File

@ -0,0 +1,280 @@
/**
 * Agent stream React hooks.
 *
 * Wrap the SSE-based AgentStreamHandler in React state so components can
 * render LLM thinking, tool calls, progress and findings in real time.
 */
import { useState, useEffect, useCallback, useRef } from 'react';
import {
AgentStreamHandler,
StreamEventData,
StreamOptions,
AgentStreamState,
} from '../shared/api/agentStream';
/** Options accepted by useAgentStream: stream options minus onEvent, plus hook flags. */
export interface UseAgentStreamOptions extends Omit<StreamOptions, 'onEvent'> {
  /** Connect automatically when a taskId is available. */
  autoConnect?: boolean;
  /** Cap on the number of events kept in state (oldest dropped first). */
  maxEvents?: number;
}
/** State exposed by useAgentStream plus connection controls. */
export interface UseAgentStreamReturn extends AgentStreamState {
  /** Open the SSE connection (resets all stream state first). */
  connect: () => void;
  /** Close the SSE connection. */
  disconnect: () => void;
  /** Whether a connection is currently open. */
  isConnected: boolean;
  /** Clear the buffered event list. */
  clearEvents: () => void;
}
/**
* Agent Hook
*
* @example
* ```tsx
* function AgentAuditPanel({ taskId }: { taskId: string }) {
* const {
* events,
* thinking,
* isThinking,
* toolCalls,
* currentPhase,
* progress,
* findings,
* isComplete,
* error,
* connect,
* disconnect,
* isConnected,
* } = useAgentStream(taskId);
*
* useEffect(() => {
* connect();
* return () => disconnect();
* }, [taskId]);
*
* return (
* <div>
* {isThinking && <ThinkingIndicator text={thinking} />}
* {toolCalls.map(tc => <ToolCallCard key={tc.name} {...tc} />)}
* {findings.map(f => <FindingCard key={f.id} {...f} />)}
* </div>
* );
* }
* ```
*/
export function useAgentStream(
  taskId: string | null,
  options: UseAgentStreamOptions = {}
): UseAgentStreamReturn {
  const {
    autoConnect = false,
    maxEvents = 500,
    includeThinking = true,
    includeToolCalls = true,
    afterSequence = 0,
    ...callbackOptions
  } = options;

  // Stream state mirrored into React state.
  const [events, setEvents] = useState<StreamEventData[]>([]);
  const [thinking, setThinking] = useState('');
  const [isThinking, setIsThinking] = useState(false);
  const [toolCalls, setToolCalls] = useState<AgentStreamState['toolCalls']>([]);
  const [currentPhase, setCurrentPhase] = useState('');
  const [progress, setProgress] = useState({ current: 0, total: 100, percentage: 0 });
  const [findings, setFindings] = useState<Record<string, unknown>[]>([]);
  const [isComplete, setIsComplete] = useState(false);
  const [error, setError] = useState<string | null>(null);
  const [isConnected, setIsConnected] = useState(false);

  // Imperative SSE handler and the in-flight thinking token buffer.
  const handlerRef = useRef<AgentStreamHandler | null>(null);
  const thinkingBufferRef = useRef<string[]>([]);

  // BUGFIX: `callbackOptions` is a fresh object literal on every render. It
  // previously sat in `connect`'s dependency array, so `connect` changed
  // identity on every render and the autoConnect effect tore down and
  // reopened the SSE connection each render. Keep the latest callbacks in a
  // ref so `connect` stays referentially stable while still invoking
  // up-to-date callbacks.
  const callbacksRef = useRef(callbackOptions);
  callbacksRef.current = callbackOptions;

  // Open (or reopen) the stream, resetting all derived state first.
  const connect = useCallback(() => {
    if (!taskId) return;

    // Drop any existing connection.
    if (handlerRef.current) {
      handlerRef.current.disconnect();
    }

    // Reset state for the new stream.
    setEvents([]);
    setThinking('');
    setIsThinking(false);
    setToolCalls([]);
    setCurrentPhase('');
    setProgress({ current: 0, total: 100, percentage: 0 });
    setFindings([]);
    setIsComplete(false);
    setError(null);
    thinkingBufferRef.current = [];

    handlerRef.current = new AgentStreamHandler(taskId, {
      includeThinking,
      includeToolCalls,
      afterSequence,
      onEvent: (event) => {
        // Keep at most `maxEvents` entries, dropping the oldest.
        // (Previously `prev.slice(-maxEvents + 1)`, which kept everything
        // when maxEvents === 1.)
        setEvents((prev) => [...prev, event].slice(-maxEvents));
      },
      onThinkingStart: () => {
        thinkingBufferRef.current = [];
        setIsThinking(true);
        setThinking('');
        callbacksRef.current.onThinkingStart?.();
      },
      onThinkingToken: (token, accumulated) => {
        thinkingBufferRef.current.push(token);
        setThinking(accumulated);
        callbacksRef.current.onThinkingToken?.(token, accumulated);
      },
      onThinkingEnd: (response) => {
        setIsThinking(false);
        setThinking(response);
        thinkingBufferRef.current = [];
        callbacksRef.current.onThinkingEnd?.(response);
      },
      onToolStart: (name, input) => {
        setToolCalls((prev) => [
          ...prev,
          { name, input, status: 'running' as const },
        ]);
        callbacksRef.current.onToolStart?.(name, input);
      },
      onToolEnd: (name, output, durationMs) => {
        // NOTE(review): this marks every running call with a matching name
        // as finished; with concurrent calls to the same tool the
        // attribution is ambiguous — confirm tool names are unique per
        // in-flight call.
        setToolCalls((prev) =>
          prev.map((tc) =>
            tc.name === name && tc.status === 'running'
              ? { ...tc, output, durationMs, status: 'success' as const }
              : tc
          )
        );
        callbacksRef.current.onToolEnd?.(name, output, durationMs);
      },
      onNodeStart: (nodeName, phase) => {
        setCurrentPhase(phase);
        callbacksRef.current.onNodeStart?.(nodeName, phase);
      },
      onNodeEnd: (nodeName, summary) => {
        callbacksRef.current.onNodeEnd?.(nodeName, summary);
      },
      onProgress: (current, total, message) => {
        setProgress({
          current,
          total,
          percentage: total > 0 ? Math.round((current / total) * 100) : 0,
        });
        callbacksRef.current.onProgress?.(current, total, message);
      },
      onFinding: (finding, isVerified) => {
        setFindings((prev) => [...prev, finding]);
        callbacksRef.current.onFinding?.(finding, isVerified);
      },
      onComplete: (data) => {
        setIsComplete(true);
        setIsConnected(false);
        callbacksRef.current.onComplete?.(data);
      },
      onError: (err) => {
        setError(err);
        setIsComplete(true);
        setIsConnected(false);
        callbacksRef.current.onError?.(err);
      },
      onHeartbeat: () => {
        callbacksRef.current.onHeartbeat?.();
      },
    });

    handlerRef.current.connect();
    setIsConnected(true);
  }, [taskId, includeThinking, includeToolCalls, afterSequence, maxEvents]);

  // Close the stream and release the handler.
  const disconnect = useCallback(() => {
    if (handlerRef.current) {
      handlerRef.current.disconnect();
      handlerRef.current = null;
    }
    setIsConnected(false);
  }, []);

  // Clear the buffered event list (other stream state untouched).
  const clearEvents = useCallback(() => {
    setEvents([]);
  }, []);

  // Auto-connect on mount / taskId change when requested; `connect` is now
  // stable across renders, so this no longer reconnects every render.
  useEffect(() => {
    if (autoConnect && taskId) {
      connect();
    }
    return () => {
      disconnect();
    };
  }, [taskId, autoConnect, connect, disconnect]);

  // Safety net: always close the connection on unmount.
  useEffect(() => {
    return () => {
      if (handlerRef.current) {
        handlerRef.current.disconnect();
      }
    };
  }, []);

  return {
    events,
    thinking,
    isThinking,
    toolCalls,
    currentPhase,
    progress,
    findings,
    isComplete,
    error,
    connect,
    disconnect,
    isConnected,
    clearEvents,
  };
}
/**
 * Convenience hook: only the LLM thinking stream (tool calls excluded).
 */
export function useAgentThinking(taskId: string | null) {
  const stream = useAgentStream(taskId, { includeToolCalls: false });
  const { thinking, isThinking, connect, disconnect } = stream;
  return { thinking, isThinking, connect, disconnect };
}
/**
 * Convenience hook: only tool-call events (thinking tokens excluded).
 */
export function useAgentToolCalls(taskId: string | null) {
  const stream = useAgentStream(taskId, { includeThinking: false });
  const { toolCalls, connect, disconnect } = stream;
  return { toolCalls, connect, disconnect };
}
export default useAgentStream;

View File

@ -1,6 +1,7 @@
/**
* Agent
* AI Agent
* LLM
*/
import { useState, useEffect, useRef, useCallback } from "react";
@ -9,12 +10,14 @@ import {
Terminal, Bot, Cpu, Shield, AlertTriangle, CheckCircle2,
Loader2, Code, Zap, Activity, ChevronRight, XCircle,
FileCode, Search, Bug, Lock, Play, Square, RefreshCw,
ArrowLeft, Download, ExternalLink
ArrowLeft, Download, ExternalLink, Brain, Wrench,
ChevronDown, ChevronUp, Clock, Sparkles
} from "lucide-react";
import { Button } from "@/components/ui/button";
import { Badge } from "@/components/ui/badge";
import { ScrollArea } from "@/components/ui/scroll-area";
import { toast } from "sonner";
import { useAgentStream } from "@/hooks/useAgentStream";
import {
type AgentTask,
type AgentEvent,
@ -91,10 +94,36 @@ export default function AgentAuditPage() {
const [isLoading, setIsLoading] = useState(true);
const [isStreaming, setIsStreaming] = useState(false);
const [currentTime, setCurrentTime] = useState(new Date().toLocaleTimeString("zh-CN", { hour12: false }));
const [showThinking, setShowThinking] = useState(true);
const [showToolDetails, setShowToolDetails] = useState(true);
const eventsEndRef = useRef<HTMLDivElement>(null);
const thinkingEndRef = useRef<HTMLDivElement>(null);
const abortControllerRef = useRef<AbortController | null>(null);
// 使用增强版流式 Hook
const {
thinking,
isThinking,
toolCalls,
currentPhase: streamPhase,
progress: streamProgress,
connect: connectStream,
disconnect: disconnectStream,
isConnected: isStreamConnected,
} = useAgentStream(taskId || null, {
includeThinking: true,
includeToolCalls: true,
onFinding: () => loadFindings(),
onComplete: () => {
loadTask();
loadFindings();
},
onError: (err) => {
console.error("Stream error:", err);
},
});
// 是否完成
const isComplete = task?.status === "completed" || task?.status === "failed" || task?.status === "cancelled";
@ -146,12 +175,24 @@ export default function AgentAuditPage() {
init();
}, [loadTask, loadEvents, loadFindings]);
// 事件流
// 连接增强版流式 API
useEffect(() => {
if (!taskId || isComplete || isLoading) return;
connectStream();
setIsStreaming(true);
return () => {
disconnectStream();
setIsStreaming(false);
};
}, [taskId, isComplete, isLoading, connectStream, disconnectStream]);
// 旧版事件流(作为后备)
useEffect(() => {
if (!taskId || isComplete || isLoading) return;
const startStreaming = async () => {
setIsStreaming(true);
abortControllerRef.current = new AbortController();
try {
@ -179,8 +220,6 @@ export default function AgentAuditPage() {
if ((error as Error).name !== "AbortError") {
console.error("Event stream error:", error);
}
} finally {
setIsStreaming(false);
}
};
@ -205,6 +244,30 @@ export default function AgentAuditPage() {
return () => clearInterval(interval);
}, []);
// 定期轮询任务状态(作为 SSE 的后备机制)
useEffect(() => {
if (!taskId || isComplete || isLoading) return;
// 每 3 秒轮询一次任务状态
const pollInterval = setInterval(async () => {
try {
const taskData = await getAgentTask(taskId);
setTask(taskData);
// 如果任务已完成/失败/取消,刷新其他数据
if (taskData.status === "completed" || taskData.status === "failed" || taskData.status === "cancelled") {
await loadEvents();
await loadFindings();
clearInterval(pollInterval);
}
} catch (error) {
console.error("Failed to poll task status:", error);
}
}, 3000);
return () => clearInterval(pollInterval);
}, [taskId, isComplete, isLoading, loadEvents, loadFindings]);
// 取消任务
const handleCancel = async () => {
if (!taskId) return;
@ -291,14 +354,85 @@ export default function AgentAuditPage() {
</div>
</div>
{/* 错误提示 */}
{task.status === "failed" && task.error_message && (
<div className="mx-4 mt-2 p-3 bg-red-900/30 border border-red-700 rounded-lg">
<div className="flex items-start gap-2">
<XCircle className="w-5 h-5 text-red-400 flex-shrink-0 mt-0.5" />
<div>
<p className="text-red-400 font-semibold text-sm"></p>
<p className="text-red-300/80 text-xs mt-1 font-mono break-all">{task.error_message}</p>
</div>
</div>
</div>
)}
<div className="flex h-[calc(100vh-56px)]">
{/* 左侧:执行日志 */}
<div className="flex-1 p-4 flex flex-col min-w-0">
{/* 思考过程展示区域 */}
{(isThinking || thinking) && showThinking && (
<div className="mb-4 bg-purple-950/30 rounded-lg border border-purple-800/50 overflow-hidden">
<div
className="flex items-center justify-between px-3 py-2 bg-purple-900/30 border-b border-purple-800/30 cursor-pointer"
onClick={() => setShowThinking(!showThinking)}
>
<div className="flex items-center gap-2 text-xs text-purple-400">
<Brain className={`w-4 h-4 ${isThinking ? "animate-pulse" : ""}`} />
<span className="uppercase tracking-wider">AI Thinking</span>
{isThinking && (
<span className="flex items-center gap-1 text-purple-300">
<Sparkles className="w-3 h-3 animate-spin" />
<span className="text-[10px]">Processing...</span>
</span>
)}
</div>
{showThinking ? <ChevronUp className="w-4 h-4 text-purple-400" /> : <ChevronDown className="w-4 h-4 text-purple-400" />}
</div>
<div className="max-h-40 overflow-y-auto">
<div className="p-3 text-sm text-purple-200/80 font-mono whitespace-pre-wrap">
{thinking || "正在思考..."}
{isThinking && <span className="animate-pulse text-purple-400"></span>}
</div>
<div ref={thinkingEndRef} />
</div>
</div>
)}
{/* 工具调用展示区域 */}
{toolCalls.length > 0 && showToolDetails && (
<div className="mb-4 bg-yellow-950/20 rounded-lg border border-yellow-800/30 overflow-hidden">
<div
className="flex items-center justify-between px-3 py-2 bg-yellow-900/20 border-b border-yellow-800/20 cursor-pointer"
onClick={() => setShowToolDetails(!showToolDetails)}
>
<div className="flex items-center gap-2 text-xs text-yellow-500">
<Wrench className="w-4 h-4" />
<span className="uppercase tracking-wider">Tool Calls</span>
<Badge variant="outline" className="text-[10px] px-1.5 py-0 bg-yellow-900/30 border-yellow-700 text-yellow-400">
{toolCalls.length}
</Badge>
</div>
{showToolDetails ? <ChevronUp className="w-4 h-4 text-yellow-500" /> : <ChevronDown className="w-4 h-4 text-yellow-500" />}
</div>
<div className="max-h-48 overflow-y-auto">
<div className="p-2 space-y-2">
{toolCalls.slice(-5).map((tc, idx) => (
<ToolCallCard key={`${tc.name}-${idx}`} toolCall={tc} />
))}
</div>
</div>
</div>
)}
<div className="flex items-center justify-between mb-3">
<div className="flex items-center gap-2 text-xs text-cyan-400">
<Terminal className="w-4 h-4" />
<span className="uppercase tracking-wider">Execution Log</span>
{isStreaming && (
{(isStreaming || isStreamConnected) && (
<span className="flex items-center gap-1 text-green-400">
<span className="w-2 h-2 bg-green-400 rounded-full animate-pulse" />
LIVE
@ -540,6 +674,94 @@ function EventLine({ event }: { event: AgentEvent }) {
);
}
// Props for the tool-call card component.
interface ToolCallProps {
  toolCall: {
    // Tool name as reported by the stream
    name: string;
    // Input arguments passed to the tool
    input: Record<string, unknown>;
    // Tool result: string or structured data (absent while running)
    output?: unknown;
    // Execution time in milliseconds (absent while running)
    durationMs?: number;
    // Lifecycle state of the call
    status: 'running' | 'success' | 'error';
  };
}
/**
 * Collapsible card showing one tool call: status badge, duration, and
 * (when expanded) truncated input/output payloads.
 */
function ToolCallCard({ toolCall }: ToolCallProps) {
  const [expanded, setExpanded] = useState(false);

  // Icon / badge styling per call status.
  const statusConfig = {
    running: {
      icon: <Loader2 className="w-3 h-3 animate-spin text-yellow-400" />,
      badge: "bg-yellow-900/30 border-yellow-700 text-yellow-400",
      text: "Running",
    },
    success: {
      icon: <CheckCircle2 className="w-3 h-3 text-green-400" />,
      badge: "bg-green-900/30 border-green-700 text-green-400",
      text: "Done",
    },
    error: {
      icon: <XCircle className="w-3 h-3 text-red-400" />,
      badge: "bg-red-900/30 border-red-700 text-red-400",
      text: "Error",
    },
  };

  const config = statusConfig[toolCall.status];

  return (
    <div className="bg-gray-900/50 rounded border border-gray-700/50 overflow-hidden">
      <div
        className="flex items-center justify-between px-2 py-1.5 cursor-pointer hover:bg-gray-800/50"
        onClick={() => setExpanded(!expanded)}
      >
        <div className="flex items-center gap-2">
          {config.icon}
          <span className="text-xs font-mono text-gray-300">{toolCall.name}</span>
        </div>
        <div className="flex items-center gap-2">
          {/* Explicit undefined check: a bare truthiness test renders a
              literal "0" for zero-ms durations and hides them. */}
          {toolCall.durationMs !== undefined && (
            <span className="text-[10px] text-gray-500">
              <Clock className="w-2.5 h-2.5 inline mr-0.5" />
              {toolCall.durationMs}ms
            </span>
          )}
          <Badge variant="outline" className={`text-[10px] px-1 py-0 ${config.badge}`}>
            {config.text}
          </Badge>
          {expanded ? <ChevronUp className="w-3 h-3 text-gray-500" /> : <ChevronDown className="w-3 h-3 text-gray-500" />}
        </div>
      </div>

      {expanded && (
        <div className="border-t border-gray-700/50 text-[11px] font-mono">
          {/* Input (truncated to 500 chars) */}
          {toolCall.input && Object.keys(toolCall.input).length > 0 && (
            <div className="p-2 border-b border-gray-800/50">
              <span className="text-gray-500 text-[10px] uppercase">Input:</span>
              <pre className="mt-1 text-cyan-300/80 whitespace-pre-wrap break-all max-h-24 overflow-y-auto">
                {JSON.stringify(toolCall.input, null, 2).slice(0, 500)}
              </pre>
            </div>
          )}
          {/* Output (truncated to 500 chars). Nullish check instead of
              truthiness so falsy-but-valid outputs ("", 0, false) still show
              and numeric 0 does not leak into the markup. */}
          {toolCall.output !== undefined && toolCall.output !== null && (
            <div className="p-2">
              <span className="text-gray-500 text-[10px] uppercase">Output:</span>
              <pre className="mt-1 text-green-300/80 whitespace-pre-wrap break-all max-h-24 overflow-y-auto">
                {typeof toolCall.output === 'string'
                  ? toolCall.output.slice(0, 500)
                  : JSON.stringify(toolCall.output, null, 2).slice(0, 500)
                }
              </pre>
            </div>
          )}
        </div>
      )}
    </div>
  );
}
// 发现卡片组件
function FindingCard({ finding }: { finding: AgentFinding }) {
const colorClass = severityColors[finding.severity] || severityColors.info;

View File

@ -0,0 +1,496 @@
/**
 * Agent streaming client.
 *
 * Notes:
 * 1. Uses fetch + ReadableStream instead of the EventSource API
 *    (EventSource cannot attach an Authorization header).
 * 2. Parses server-sent events (SSE) incrementally from the byte stream.
 * 3. Dispatches typed callbacks for thinking / tool-call / task events.
 */

// Event type identifiers emitted by the enhanced /stream endpoint.
export type StreamEventType =
  // LLM thinking
  | 'thinking_start'
  | 'thinking_token'
  | 'thinking_end'
  // Tool calls
  | 'tool_call_start'
  | 'tool_call_input'
  | 'tool_call_output'
  | 'tool_call_end'
  | 'tool_call_error'
  // Graph nodes
  | 'node_start'
  | 'node_end'
  // Phases
  | 'phase_start'
  | 'phase_end'
  | 'phase_complete'
  // Findings
  | 'finding_new'
  | 'finding_verified'
  // Status
  | 'progress'
  | 'info'
  | 'warning'
  | 'error'
  // Task lifecycle
  | 'task_start'
  | 'task_complete'
  | 'task_error'
  | 'task_cancel'
  | 'task_end'
  // Keep-alive
  | 'heartbeat';
// Detailed payload of one tool invocation.
export interface ToolCallDetail {
  name: string;
  input?: Record<string, unknown>;
  output?: unknown;
  duration_ms?: number;
}

// One parsed SSE event from the agent stream.
export interface StreamEventData {
  id?: string;
  type: StreamEventType;
  phase?: string;
  message?: string;
  sequence?: number;
  timestamp?: string;
  tool?: ToolCallDetail;
  metadata?: Record<string, unknown>;
  tokens_used?: number;
  // Type-specific fields
  token?: string; // thinking_token
  accumulated?: string; // thinking_token / thinking_end
  status?: string; // task_end
  error?: string; // task_error
  findings_count?: number; // task_complete
  security_score?: number; // task_complete
}
// Generic callback invoked for every parsed event.
export type StreamEventCallback = (event: StreamEventData) => void;

// Connection and callback options for the stream handler.
export interface StreamOptions {
  includeThinking?: boolean; // forward LLM thinking tokens (default true)
  includeToolCalls?: boolean; // forward tool-call details (default true)
  afterSequence?: number; // resume after this event sequence number
  onThinkingStart?: () => void;
  onThinkingToken?: (token: string, accumulated: string) => void;
  onThinkingEnd?: (fullResponse: string) => void;
  onToolStart?: (toolName: string, input: Record<string, unknown>) => void;
  onToolEnd?: (toolName: string, output: unknown, durationMs: number) => void;
  onNodeStart?: (nodeName: string, phase: string) => void;
  onNodeEnd?: (nodeName: string, summary: Record<string, unknown>) => void;
  onFinding?: (finding: Record<string, unknown>, isVerified: boolean) => void;
  onProgress?: (current: number, total: number, message: string) => void;
  onComplete?: (data: { findingsCount: number; securityScore: number }) => void;
  onError?: (error: string) => void;
  onHeartbeat?: () => void;
  onEvent?: StreamEventCallback; // catch-all callback, fired before the typed ones
}
/**
 * Streaming client for the agent enhanced-events endpoint.
 *
 * Connects with fetch + ReadableStream (EventSource cannot send an
 * Authorization header), incrementally parses SSE frames, and dispatches
 * the typed callbacks declared in StreamOptions. Transport errors trigger
 * a linear-backoff reconnect; disconnect() aborts the in-flight request
 * and suppresses any pending reconnect.
 */
export class AgentStreamHandler {
  private taskId: string;
  private eventSource: EventSource | null = null;
  private options: StreamOptions;
  private reconnectAttempts = 0;
  private maxReconnectAttempts = 5;
  private reconnectDelay = 1000; // base delay (ms); multiplied per attempt
  private isConnected = false;
  private thinkingBuffer: string[] = [];
  // Aborts the in-flight fetch so disconnect() actually stops the stream
  // (the original only closed the unused EventSource, leaving the fetch
  // read loop running).
  private abortController: AbortController | null = null;
  // True after an intentional disconnect(); blocks auto-reconnect.
  private closed = false;

  constructor(taskId: string, options: StreamOptions = {}) {
    this.taskId = taskId;
    this.options = {
      includeThinking: true,
      includeToolCalls: true,
      afterSequence: 0,
      ...options,
    };
  }

  /**
   * Open the event stream. Requires an access token in localStorage.
   */
  connect(): void {
    const token = localStorage.getItem('access_token');
    if (!token) {
      this.options.onError?.('未登录');
      return;
    }

    this.closed = false;

    const params = new URLSearchParams({
      include_thinking: String(this.options.includeThinking),
      include_tool_calls: String(this.options.includeToolCalls),
      after_sequence: String(this.options.afterSequence),
    });

    // EventSource does not support custom headers (the token would have to
    // travel in the URL), so use fetch + ReadableStream instead.
    this.connectWithFetch(token, params);
  }

  /**
   * Connect with fetch so the Authorization header can be set.
   */
  private async connectWithFetch(token: string, params: URLSearchParams): Promise<void> {
    const url = `/api/v1/agent-tasks/${this.taskId}/stream?${params}`;
    this.abortController = new AbortController();

    try {
      const response = await fetch(url, {
        headers: {
          'Authorization': `Bearer ${token}`,
          'Accept': 'text/event-stream',
        },
        signal: this.abortController.signal,
      });

      if (!response.ok) {
        throw new Error(`HTTP ${response.status}: ${response.statusText}`);
      }

      this.isConnected = true;
      this.reconnectAttempts = 0;

      const reader = response.body?.getReader();
      if (!reader) {
        throw new Error('无法获取响应流');
      }

      const decoder = new TextDecoder();
      let buffer = '';

      while (true) {
        const { done, value } = await reader.read();
        if (done) {
          break;
        }

        buffer += decoder.decode(value, { stream: true });

        // Parse complete SSE frames; keep the trailing partial line.
        const events = this.parseSSE(buffer);
        buffer = events.remaining;

        for (const event of events.parsed) {
          this.handleEvent(event);
        }
      }
    } catch (error) {
      this.isConnected = false;

      // An intentional disconnect() aborts the request — not a failure.
      if (this.closed || (error as Error)?.name === 'AbortError') {
        return;
      }

      console.error('Stream connection error:', error);

      // Linear-backoff reconnect, unless the handler was closed meanwhile.
      if (this.reconnectAttempts < this.maxReconnectAttempts) {
        this.reconnectAttempts++;
        setTimeout(() => {
          if (!this.closed) {
            this.connect();
          }
        }, this.reconnectDelay * this.reconnectAttempts);
      } else {
        this.options.onError?.(`连接失败: ${error}`);
      }
    }
  }

  /**
   * Incrementally parse SSE text. Returns the complete events found and
   * the unparsed remainder (a possibly-incomplete trailing line).
   */
  private parseSSE(buffer: string): { parsed: StreamEventData[]; remaining: string } {
    const parsed: StreamEventData[] = [];
    const lines = buffer.split('\n');
    let remaining = '';

    let currentEvent: Partial<StreamEventData> = {};

    for (let i = 0; i < lines.length; i++) {
      const line = lines[i];

      // A blank line terminates one SSE event.
      if (line === '') {
        if (currentEvent.type) {
          parsed.push(currentEvent as StreamEventData);
          currentEvent = {};
        }
        continue;
      }

      // The last line without a trailing newline may be incomplete — keep it.
      if (i === lines.length - 1 && !buffer.endsWith('\n')) {
        remaining = line;
        break;
      }

      if (line.startsWith('event:')) {
        currentEvent.type = line.slice(6).trim() as StreamEventType;
      } else if (line.startsWith('data:')) {
        try {
          const data = JSON.parse(line.slice(5).trim());
          currentEvent = { ...currentEvent, ...data };
        } catch {
          // Ignore malformed JSON payloads.
        }
      }
    }

    return { parsed, remaining };
  }

  /**
   * Dispatch one parsed event to the matching callbacks.
   */
  private handleEvent(event: StreamEventData): void {
    // Catch-all callback first.
    this.options.onEvent?.(event);

    switch (event.type) {
      // LLM thinking
      case 'thinking_start':
        this.thinkingBuffer = [];
        this.options.onThinkingStart?.();
        break;

      case 'thinking_token':
        if (event.token) {
          this.thinkingBuffer.push(event.token);
          this.options.onThinkingToken?.(
            event.token,
            event.accumulated || this.thinkingBuffer.join('')
          );
        }
        break;

      case 'thinking_end': {
        // Braced case: keeps the const scoped to this case (no-case-declarations).
        const fullResponse = event.accumulated || this.thinkingBuffer.join('');
        this.thinkingBuffer = [];
        this.options.onThinkingEnd?.(fullResponse);
        break;
      }

      // Tool calls
      case 'tool_call_start':
        if (event.tool) {
          this.options.onToolStart?.(event.tool.name, event.tool.input || {});
        }
        break;

      case 'tool_call_end':
        if (event.tool) {
          this.options.onToolEnd?.(
            event.tool.name,
            event.tool.output,
            event.tool.duration_ms || 0
          );
        }
        break;

      // Graph nodes
      case 'node_start':
        this.options.onNodeStart?.(
          event.metadata?.node as string || 'unknown',
          event.phase || ''
        );
        break;

      case 'node_end':
        this.options.onNodeEnd?.(
          event.metadata?.node as string || 'unknown',
          event.metadata?.summary as Record<string, unknown> || {}
        );
        break;

      // Findings
      case 'finding_new':
      case 'finding_verified':
        this.options.onFinding?.(
          event.metadata || {},
          event.type === 'finding_verified'
        );
        break;

      // Progress
      case 'progress':
        this.options.onProgress?.(
          event.metadata?.current as number || 0,
          event.metadata?.total as number || 100,
          event.message || ''
        );
        break;

      // Task finished
      case 'task_complete':
      case 'task_end':
        if (event.status !== 'cancelled' && event.status !== 'failed') {
          this.options.onComplete?.({
            findingsCount: event.findings_count || event.metadata?.findings_count as number || 0,
            securityScore: event.security_score || event.metadata?.security_score as number || 100,
          });
        }
        this.disconnect();
        break;

      // Errors
      case 'task_error':
      case 'error':
        this.options.onError?.(event.error || event.message || '未知错误');
        this.disconnect();
        break;

      // Keep-alive
      case 'heartbeat':
        this.options.onHeartbeat?.();
        break;
    }
  }

  /**
   * Close the stream: abort the in-flight request and prevent any queued
   * reconnect timer from reopening it.
   */
  disconnect(): void {
    this.closed = true;
    this.isConnected = false;
    if (this.abortController) {
      this.abortController.abort();
      this.abortController = null;
    }
    if (this.eventSource) {
      this.eventSource.close();
      this.eventSource = null;
    }
  }

  /**
   * Whether the stream is currently connected.
   */
  get connected(): boolean {
    return this.isConnected;
  }
}
/**
 * Convenience factory for an AgentStreamHandler.
 */
export function createAgentStream(
  taskId: string,
  options: StreamOptions = {}
): AgentStreamHandler {
  const handler = new AgentStreamHandler(taskId, options);
  return handler;
}
/**
 * Aggregated snapshot of an agent stream, suitable for React state.
 *
 * Example hook usage:
 *
 * ```tsx
 * const { events, thinking, toolCalls, connect, disconnect } = useAgentStream(taskId);
 *
 * useEffect(() => {
 *   connect();
 *   return () => disconnect();
 * }, [taskId]);
 * ```
 */
export interface AgentStreamState {
  events: StreamEventData[]; // every event received (consumers may cap this)
  thinking: string; // accumulated thinking text of the current/last turn
  isThinking: boolean; // true between thinking_start and thinking_end
  toolCalls: Array<{
    name: string;
    input: Record<string, unknown>;
    output?: unknown;
    durationMs?: number;
    status: 'running' | 'success' | 'error';
  }>;
  currentPhase: string; // phase reported by the most recent node_start
  progress: { current: number; total: number; percentage: number };
  findings: Array<Record<string, unknown>>;
  isComplete: boolean; // task ended (success or error)
  error: string | null; // last error message, if any
}
/**
 * Create a handler that folds every stream event into one AgentStreamState
 * snapshot and reports it via onStateChange.
 *
 * The state object is mutated in place and a shallow copy is emitted on
 * every update, so each callback value can be treated as an immutable
 * snapshot (top-level fields are replaced, nested objects are shared).
 */
export function createAgentStreamWithState(
  taskId: string,
  onStateChange: (state: AgentStreamState) => void
): AgentStreamHandler {
  // Single mutable accumulator shared by all callbacks below.
  const state: AgentStreamState = {
    events: [],
    thinking: '',
    isThinking: false,
    toolCalls: [],
    currentPhase: '',
    progress: { current: 0, total: 100, percentage: 0 },
    findings: [],
    isComplete: false,
    error: null,
  };

  // Merge updates into the shared state and emit a fresh snapshot.
  const updateState = (updates: Partial<AgentStreamState>) => {
    Object.assign(state, updates);
    onStateChange({ ...state });
  };

  return new AgentStreamHandler(taskId, {
    onEvent: (event) => {
      updateState({
        events: [...state.events, event].slice(-500), // keep the latest 500 events
      });
    },
    onThinkingStart: () => {
      updateState({ isThinking: true, thinking: '' });
    },
    onThinkingToken: (_, accumulated) => {
      updateState({ thinking: accumulated });
    },
    onThinkingEnd: (response) => {
      updateState({ isThinking: false, thinking: response });
    },
    onToolStart: (name, input) => {
      updateState({
        toolCalls: [
          ...state.toolCalls,
          { name, input, status: 'running' },
        ],
      });
    },
    onToolEnd: (name, output, durationMs) => {
      // NOTE(review): this matches every running call with the same name,
      // so concurrent invocations of one tool are all marked done at once
      // — confirm whether the backend can interleave same-name calls.
      updateState({
        toolCalls: state.toolCalls.map((tc) =>
          tc.name === name && tc.status === 'running'
            ? { ...tc, output, durationMs, status: 'success' as const }
            : tc
        ),
      });
    },
    onNodeStart: (_, phase) => {
      updateState({ currentPhase: phase });
    },
    onProgress: (current, total, _) => {
      updateState({
        progress: {
          current,
          total,
          percentage: total > 0 ? Math.round((current / total) * 100) : 0,
        },
      });
    },
    onFinding: (finding, _) => {
      updateState({
        findings: [...state.findings, finding],
      });
    },
    onComplete: () => {
      updateState({ isComplete: true });
    },
    onError: (error) => {
      updateState({ error, isComplete: true });
    },
  });
}

View File

@ -43,6 +43,9 @@ export interface AgentTask {
// 进度
progress_percentage: number;
// 错误信息
error_message: string | null;
}
export interface AgentFinding {
@ -249,7 +252,7 @@ export async function* streamAgentEvents(
afterSequence = 0,
signal?: AbortSignal
): AsyncGenerator<AgentEvent, void, unknown> {
const token = localStorage.getItem("auth_token");
const token = localStorage.getItem("access_token");
const baseUrl = import.meta.env.VITE_API_URL || "";
const url = `${baseUrl}/api/v1/agent-tasks/${taskId}/events?after_sequence=${afterSequence}`;