feat(agent): implement streaming support for agent events and enhance UI components
- Introduce streaming capabilities for agent events, allowing real-time updates during audits. - Add new hooks for managing agent stream events in React components. - Enhance the AgentAudit page to display LLM thinking processes and tool call details in real-time. - Update API endpoints to support streaming event data and improve error handling. - Refactor UI components for better organization and user experience during audits.
This commit is contained in:
parent
a43ebf1793
commit
58c918f557
|
|
@ -59,3 +59,5 @@ CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
|
|||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -103,3 +103,5 @@ datefmt = %H:%M:%S
|
|||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -90,3 +90,5 @@ else:
|
|||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -25,3 +25,5 @@ def downgrade() -> None:
|
|||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ from app.models.agent_task import (
|
|||
from app.models.project import Project
|
||||
from app.models.user import User
|
||||
from app.services.agent import AgentRunner, EventManager, run_agent_task
|
||||
from app.services.agent.streaming import StreamHandler, StreamEvent, StreamEventType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
|
@ -432,6 +433,171 @@ async def stream_agent_events(
|
|||
)
|
||||
|
||||
|
||||
@router.get("/{task_id}/stream")
|
||||
async def stream_agent_with_thinking(
|
||||
task_id: str,
|
||||
include_thinking: bool = Query(True, description="是否包含 LLM 思考过程"),
|
||||
include_tool_calls: bool = Query(True, description="是否包含工具调用详情"),
|
||||
after_sequence: int = Query(0, ge=0, description="从哪个序号之后开始"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(deps.get_current_user),
|
||||
):
|
||||
"""
|
||||
增强版事件流 (SSE)
|
||||
|
||||
支持:
|
||||
- LLM 思考过程的 Token 级流式输出
|
||||
- 工具调用的详细输入/输出
|
||||
- 节点执行状态
|
||||
- 发现事件
|
||||
|
||||
事件类型:
|
||||
- thinking_start: LLM 开始思考
|
||||
- thinking_token: LLM 输出 Token
|
||||
- thinking_end: LLM 思考结束
|
||||
- tool_call_start: 工具调用开始
|
||||
- tool_call_end: 工具调用结束
|
||||
- node_start: 节点开始
|
||||
- node_end: 节点结束
|
||||
- finding_new: 新发现
|
||||
- finding_verified: 验证通过
|
||||
- progress: 进度更新
|
||||
- task_complete: 任务完成
|
||||
- task_error: 任务错误
|
||||
- heartbeat: 心跳
|
||||
"""
|
||||
task = await db.get(AgentTask, task_id)
|
||||
if not task:
|
||||
raise HTTPException(status_code=404, detail="任务不存在")
|
||||
|
||||
project = await db.get(Project, task.project_id)
|
||||
if not project or project.owner_id != current_user.id:
|
||||
raise HTTPException(status_code=403, detail="无权访问此任务")
|
||||
|
||||
async def enhanced_event_generator():
|
||||
"""生成增强版 SSE 事件流"""
|
||||
last_sequence = after_sequence
|
||||
poll_interval = 0.3 # 更短的轮询间隔以支持流式
|
||||
heartbeat_interval = 15 # 心跳间隔
|
||||
max_idle = 600 # 10 分钟无事件后关闭
|
||||
idle_time = 0
|
||||
last_heartbeat = 0
|
||||
|
||||
# 事件类型过滤
|
||||
skip_types = set()
|
||||
if not include_thinking:
|
||||
skip_types.update(["thinking_start", "thinking_token", "thinking_end"])
|
||||
if not include_tool_calls:
|
||||
skip_types.update(["tool_call_start", "tool_call_input", "tool_call_output", "tool_call_end"])
|
||||
|
||||
while True:
|
||||
try:
|
||||
async with async_session_factory() as session:
|
||||
# 查询新事件
|
||||
result = await session.execute(
|
||||
select(AgentEvent)
|
||||
.where(AgentEvent.task_id == task_id)
|
||||
.where(AgentEvent.sequence > last_sequence)
|
||||
.order_by(AgentEvent.sequence)
|
||||
.limit(100)
|
||||
)
|
||||
events = result.scalars().all()
|
||||
|
||||
# 获取任务状态
|
||||
current_task = await session.get(AgentTask, task_id)
|
||||
task_status = current_task.status if current_task else None
|
||||
|
||||
if events:
|
||||
idle_time = 0
|
||||
for event in events:
|
||||
last_sequence = event.sequence
|
||||
|
||||
# 获取事件类型字符串
|
||||
event_type = event.event_type.value if hasattr(event.event_type, 'value') else str(event.event_type)
|
||||
|
||||
# 过滤事件
|
||||
if event_type in skip_types:
|
||||
continue
|
||||
|
||||
# 构建事件数据
|
||||
data = {
|
||||
"id": event.id,
|
||||
"type": event_type,
|
||||
"phase": event.phase.value if event.phase and hasattr(event.phase, 'value') else event.phase,
|
||||
"message": event.message,
|
||||
"sequence": event.sequence,
|
||||
"timestamp": event.created_at.isoformat() if event.created_at else None,
|
||||
}
|
||||
|
||||
# 添加工具调用详情
|
||||
if include_tool_calls and event.tool_name:
|
||||
data["tool"] = {
|
||||
"name": event.tool_name,
|
||||
"input": event.tool_input,
|
||||
"output": event.tool_output,
|
||||
"duration_ms": event.tool_duration_ms,
|
||||
}
|
||||
|
||||
# 添加元数据
|
||||
if event.event_metadata:
|
||||
data["metadata"] = event.event_metadata
|
||||
|
||||
# 添加 Token 使用
|
||||
if event.tokens_used:
|
||||
data["tokens_used"] = event.tokens_used
|
||||
|
||||
# 使用标准 SSE 格式
|
||||
yield f"event: {event_type}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
|
||||
else:
|
||||
idle_time += poll_interval
|
||||
|
||||
# 检查任务是否结束
|
||||
if task_status in [AgentTaskStatus.COMPLETED, AgentTaskStatus.FAILED, AgentTaskStatus.CANCELLED]:
|
||||
end_data = {
|
||||
"type": "task_end",
|
||||
"status": task_status.value,
|
||||
"message": f"任务{'完成' if task_status == AgentTaskStatus.COMPLETED else '结束'}",
|
||||
}
|
||||
yield f"event: task_end\ndata: {json.dumps(end_data, ensure_ascii=False)}\n\n"
|
||||
break
|
||||
|
||||
# 发送心跳
|
||||
last_heartbeat += poll_interval
|
||||
if last_heartbeat >= heartbeat_interval:
|
||||
last_heartbeat = 0
|
||||
heartbeat_data = {
|
||||
"type": "heartbeat",
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"last_sequence": last_sequence,
|
||||
}
|
||||
yield f"event: heartbeat\ndata: {json.dumps(heartbeat_data)}\n\n"
|
||||
|
||||
# 检查空闲超时
|
||||
if idle_time >= max_idle:
|
||||
timeout_data = {"type": "timeout", "message": "连接超时"}
|
||||
yield f"event: timeout\ndata: {json.dumps(timeout_data)}\n\n"
|
||||
break
|
||||
|
||||
await asyncio.sleep(poll_interval)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Stream error: {e}")
|
||||
error_data = {"type": "error", "message": str(e)}
|
||||
yield f"event: error\ndata: {json.dumps(error_data)}\n\n"
|
||||
break
|
||||
|
||||
return StreamingResponse(
|
||||
enhanced_event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
"Content-Type": "text/event-stream; charset=utf-8",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{task_id}/events/list", response_model=List[AgentEventResponse])
|
||||
async def list_agent_events(
|
||||
task_id: str,
|
||||
|
|
|
|||
|
|
@ -210,3 +210,5 @@ async def remove_project_member(
|
|||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -225,3 +225,5 @@ async def toggle_user_status(
|
|||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -29,3 +29,5 @@ def get_password_hash(password: str) -> str:
|
|||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -12,3 +12,5 @@ class Base:
|
|||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -28,3 +28,5 @@ async def async_session_factory():
|
|||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -24,3 +24,5 @@ class InstantAnalysis(Base):
|
|||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -25,3 +25,5 @@ class User(Base):
|
|||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -30,3 +30,5 @@ class UserConfig(Base):
|
|||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -10,3 +10,5 @@ class TokenPayload(BaseModel):
|
|||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -41,3 +41,5 @@ class UserListResponse(BaseModel):
|
|||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -147,11 +147,11 @@ class AnalysisAgent(BaseAgent):
|
|||
deep_findings = await self._analyze_entry_points(entry_points)
|
||||
all_findings.extend(deep_findings)
|
||||
|
||||
# 分析高风险区域
|
||||
# 分析高风险区域(现在会调用 LLM)
|
||||
risk_findings = await self._analyze_high_risk_areas(high_risk_areas)
|
||||
all_findings.extend(risk_findings)
|
||||
|
||||
# 语义搜索常见漏洞
|
||||
# 语义搜索常见漏洞(现在会调用 LLM)
|
||||
vuln_types = config.get("target_vulnerabilities", [
|
||||
"sql_injection", "xss", "command_injection",
|
||||
"path_traversal", "ssrf", "hardcoded_secret",
|
||||
|
|
@ -165,6 +165,12 @@ class AnalysisAgent(BaseAgent):
|
|||
vuln_findings = await self._search_vulnerability_pattern(vuln_type)
|
||||
all_findings.extend(vuln_findings)
|
||||
|
||||
# 🔥 3. 如果还没有发现,使用 LLM 进行全面扫描
|
||||
if len(all_findings) < 3:
|
||||
await self.emit_thinking("执行 LLM 全面代码扫描...")
|
||||
llm_findings = await self._llm_comprehensive_scan(tech_stack)
|
||||
all_findings.extend(llm_findings)
|
||||
|
||||
# 去重
|
||||
all_findings = self._deduplicate_findings(all_findings)
|
||||
|
||||
|
|
@ -292,12 +298,12 @@ class AnalysisAgent(BaseAgent):
|
|||
return findings
|
||||
|
||||
async def _analyze_high_risk_areas(self, high_risk_areas: List[str]) -> List[Dict]:
|
||||
"""分析高风险区域"""
|
||||
"""分析高风险区域 - 使用 LLM 深度分析"""
|
||||
findings = []
|
||||
|
||||
pattern_tool = self.tools.get("pattern_match")
|
||||
read_tool = self.tools.get("read_file")
|
||||
search_tool = self.tools.get("search_code")
|
||||
code_analysis_tool = self.tools.get("code_analysis")
|
||||
|
||||
if not search_tool:
|
||||
return findings
|
||||
|
|
@ -305,36 +311,92 @@ class AnalysisAgent(BaseAgent):
|
|||
# 在高风险区域搜索危险模式
|
||||
dangerous_patterns = [
|
||||
("execute(", "sql_injection"),
|
||||
("query(", "sql_injection"),
|
||||
("eval(", "code_injection"),
|
||||
("system(", "command_injection"),
|
||||
("exec(", "command_injection"),
|
||||
("subprocess", "command_injection"),
|
||||
("innerHTML", "xss"),
|
||||
("document.write", "xss"),
|
||||
("open(", "path_traversal"),
|
||||
("requests.get", "ssrf"),
|
||||
]
|
||||
|
||||
for pattern, vuln_type in dangerous_patterns[:5]:
|
||||
analyzed_files = set()
|
||||
|
||||
for pattern, vuln_type in dangerous_patterns[:8]:
|
||||
if self.is_cancelled:
|
||||
break
|
||||
|
||||
result = await search_tool.execute(keyword=pattern, max_results=10)
|
||||
|
||||
if result.success and result.metadata.get("matches", 0) > 0:
|
||||
for match in result.metadata.get("results", [])[:3]:
|
||||
for match in result.metadata.get("results", [])[:5]:
|
||||
file_path = match.get("file", "")
|
||||
line = match.get("line", 0)
|
||||
|
||||
# 检查是否在高风险区域
|
||||
in_high_risk = any(
|
||||
area in file_path for area in high_risk_areas
|
||||
# 避免重复分析同一个文件的同一区域
|
||||
file_key = f"{file_path}:{line // 50}"
|
||||
if file_key in analyzed_files:
|
||||
continue
|
||||
analyzed_files.add(file_key)
|
||||
|
||||
# 🔥 使用 LLM 深度分析找到的代码
|
||||
if read_tool and code_analysis_tool:
|
||||
await self.emit_thinking(f"LLM 分析 {file_path}:{line} 的 {vuln_type} 风险...")
|
||||
|
||||
# 读取代码上下文
|
||||
read_result = await read_tool.execute(
|
||||
file_path=file_path,
|
||||
start_line=max(1, line - 15),
|
||||
end_line=line + 25,
|
||||
)
|
||||
|
||||
if in_high_risk or True: # 暂时包含所有
|
||||
if read_result.success:
|
||||
# 调用 LLM 分析
|
||||
analysis_result = await code_analysis_tool.execute(
|
||||
code=read_result.data,
|
||||
file_path=file_path,
|
||||
focus=vuln_type,
|
||||
)
|
||||
|
||||
if analysis_result.success and analysis_result.metadata.get("issues"):
|
||||
for issue in analysis_result.metadata["issues"]:
|
||||
findings.append({
|
||||
"vulnerability_type": issue.get("type", vuln_type),
|
||||
"severity": issue.get("severity", "medium"),
|
||||
"title": issue.get("title", f"LLM 发现: {vuln_type}"),
|
||||
"description": issue.get("description", ""),
|
||||
"file_path": file_path,
|
||||
"line_start": issue.get("line", line),
|
||||
"code_snippet": issue.get("code_snippet", match.get("match", "")),
|
||||
"suggestion": issue.get("suggestion", ""),
|
||||
"ai_explanation": issue.get("ai_explanation", ""),
|
||||
"source": "llm_analysis",
|
||||
"needs_verification": True,
|
||||
})
|
||||
elif analysis_result.success:
|
||||
# LLM 分析了但没发现问题,仍记录原始发现
|
||||
findings.append({
|
||||
"vulnerability_type": vuln_type,
|
||||
"severity": "high" if in_high_risk else "medium",
|
||||
"severity": "low",
|
||||
"title": f"疑似 {vuln_type}: {pattern}",
|
||||
"description": f"在 {file_path} 中发现危险模式,但 LLM 分析未确认",
|
||||
"file_path": file_path,
|
||||
"line_start": line,
|
||||
"code_snippet": match.get("match", ""),
|
||||
"source": "pattern_search",
|
||||
"needs_verification": True,
|
||||
})
|
||||
else:
|
||||
# 没有 LLM 工具,使用基础模式匹配
|
||||
findings.append({
|
||||
"vulnerability_type": vuln_type,
|
||||
"severity": "medium",
|
||||
"title": f"疑似 {vuln_type}: {pattern}",
|
||||
"description": f"在 {file_path} 中发现危险模式 {pattern}",
|
||||
"file_path": file_path,
|
||||
"line_start": match.get("line", 0),
|
||||
"line_start": line,
|
||||
"code_snippet": match.get("match", ""),
|
||||
"source": "pattern_search",
|
||||
"needs_verification": True,
|
||||
|
|
@ -343,10 +405,13 @@ class AnalysisAgent(BaseAgent):
|
|||
return findings
|
||||
|
||||
async def _search_vulnerability_pattern(self, vuln_type: str) -> List[Dict]:
|
||||
"""搜索特定漏洞模式"""
|
||||
"""搜索特定漏洞模式 - 使用 RAG + LLM"""
|
||||
findings = []
|
||||
|
||||
security_tool = self.tools.get("security_search")
|
||||
code_analysis_tool = self.tools.get("code_analysis")
|
||||
read_tool = self.tools.get("read_file")
|
||||
|
||||
if not security_tool:
|
||||
return findings
|
||||
|
||||
|
|
@ -357,20 +422,176 @@ class AnalysisAgent(BaseAgent):
|
|||
|
||||
if result.success and result.metadata.get("results_count", 0) > 0:
|
||||
for item in result.metadata.get("results", [])[:5]:
|
||||
file_path = item.get("file_path", "")
|
||||
line_start = item.get("line_start", 0)
|
||||
content = item.get("content", "")[:2000]
|
||||
|
||||
# 🔥 使用 LLM 验证 RAG 搜索结果
|
||||
if code_analysis_tool and content:
|
||||
await self.emit_thinking(f"LLM 验证 RAG 发现的 {vuln_type}...")
|
||||
|
||||
analysis_result = await code_analysis_tool.execute(
|
||||
code=content,
|
||||
file_path=file_path,
|
||||
focus=vuln_type,
|
||||
)
|
||||
|
||||
if analysis_result.success and analysis_result.metadata.get("issues"):
|
||||
for issue in analysis_result.metadata["issues"]:
|
||||
findings.append({
|
||||
"vulnerability_type": issue.get("type", vuln_type),
|
||||
"severity": issue.get("severity", "medium"),
|
||||
"title": issue.get("title", f"LLM 确认: {vuln_type}"),
|
||||
"description": issue.get("description", ""),
|
||||
"file_path": file_path,
|
||||
"line_start": issue.get("line", line_start),
|
||||
"code_snippet": issue.get("code_snippet", content[:500]),
|
||||
"suggestion": issue.get("suggestion", ""),
|
||||
"ai_explanation": issue.get("ai_explanation", ""),
|
||||
"source": "rag_llm_analysis",
|
||||
"needs_verification": True,
|
||||
})
|
||||
else:
|
||||
# RAG 找到但 LLM 未确认
|
||||
findings.append({
|
||||
"vulnerability_type": vuln_type,
|
||||
"severity": "low",
|
||||
"title": f"疑似 {vuln_type} (待确认)",
|
||||
"description": f"RAG 搜索发现可能存在 {vuln_type},但 LLM 未确认",
|
||||
"file_path": file_path,
|
||||
"line_start": line_start,
|
||||
"code_snippet": content[:500],
|
||||
"source": "rag_search",
|
||||
"needs_verification": True,
|
||||
})
|
||||
else:
|
||||
findings.append({
|
||||
"vulnerability_type": vuln_type,
|
||||
"severity": "medium",
|
||||
"title": f"疑似 {vuln_type}",
|
||||
"description": f"通过语义搜索发现可能存在 {vuln_type}",
|
||||
"file_path": item.get("file_path", ""),
|
||||
"line_start": item.get("line_start", 0),
|
||||
"code_snippet": item.get("content", "")[:500],
|
||||
"file_path": file_path,
|
||||
"line_start": line_start,
|
||||
"code_snippet": content[:500],
|
||||
"source": "rag_search",
|
||||
"needs_verification": True,
|
||||
})
|
||||
|
||||
return findings
|
||||
|
||||
async def _llm_comprehensive_scan(self, tech_stack: Dict) -> List[Dict]:
|
||||
"""
|
||||
LLM 全面代码扫描
|
||||
当其他方法没有发现足够的问题时,使用 LLM 直接分析关键文件
|
||||
"""
|
||||
findings = []
|
||||
|
||||
list_tool = self.tools.get("list_files")
|
||||
read_tool = self.tools.get("read_file")
|
||||
code_analysis_tool = self.tools.get("code_analysis")
|
||||
|
||||
if not all([list_tool, read_tool, code_analysis_tool]):
|
||||
return findings
|
||||
|
||||
await self.emit_thinking("LLM 全面扫描关键代码文件...")
|
||||
|
||||
# 确定要扫描的文件类型
|
||||
languages = tech_stack.get("languages", [])
|
||||
file_patterns = []
|
||||
|
||||
if "Python" in languages:
|
||||
file_patterns.extend(["*.py"])
|
||||
if "JavaScript" in languages or "TypeScript" in languages:
|
||||
file_patterns.extend(["*.js", "*.ts"])
|
||||
if "Go" in languages:
|
||||
file_patterns.extend(["*.go"])
|
||||
if "Java" in languages:
|
||||
file_patterns.extend(["*.java"])
|
||||
if "PHP" in languages:
|
||||
file_patterns.extend(["*.php"])
|
||||
|
||||
if not file_patterns:
|
||||
file_patterns = ["*.py", "*.js", "*.ts", "*.go", "*.java", "*.php"]
|
||||
|
||||
# 扫描关键目录
|
||||
key_dirs = ["src", "app", "api", "routes", "controllers", "handlers", "lib", "utils", "."]
|
||||
scanned_files = 0
|
||||
max_files_to_scan = 10
|
||||
|
||||
for key_dir in key_dirs:
|
||||
if scanned_files >= max_files_to_scan or self.is_cancelled:
|
||||
break
|
||||
|
||||
for pattern in file_patterns[:3]:
|
||||
if scanned_files >= max_files_to_scan or self.is_cancelled:
|
||||
break
|
||||
|
||||
# 列出文件
|
||||
list_result = await list_tool.execute(
|
||||
directory=key_dir,
|
||||
pattern=pattern,
|
||||
recursive=True,
|
||||
max_files=20,
|
||||
)
|
||||
|
||||
if not list_result.success:
|
||||
continue
|
||||
|
||||
# 从输出中提取文件路径
|
||||
output = list_result.data
|
||||
file_paths = []
|
||||
for line in output.split('\n'):
|
||||
line = line.strip()
|
||||
if line.startswith('📄 '):
|
||||
file_paths.append(line[2:].strip())
|
||||
|
||||
# 分析每个文件
|
||||
for file_path in file_paths[:5]:
|
||||
if scanned_files >= max_files_to_scan or self.is_cancelled:
|
||||
break
|
||||
|
||||
# 跳过测试文件和配置文件
|
||||
if any(skip in file_path.lower() for skip in ['test', 'spec', 'mock', '__pycache__', 'node_modules']):
|
||||
continue
|
||||
|
||||
await self.emit_thinking(f"LLM 分析文件: {file_path}")
|
||||
|
||||
# 读取文件
|
||||
read_result = await read_tool.execute(
|
||||
file_path=file_path,
|
||||
max_lines=200,
|
||||
)
|
||||
|
||||
if not read_result.success:
|
||||
continue
|
||||
|
||||
scanned_files += 1
|
||||
|
||||
# 🔥 LLM 深度分析
|
||||
analysis_result = await code_analysis_tool.execute(
|
||||
code=read_result.data,
|
||||
file_path=file_path,
|
||||
)
|
||||
|
||||
if analysis_result.success and analysis_result.metadata.get("issues"):
|
||||
for issue in analysis_result.metadata["issues"]:
|
||||
findings.append({
|
||||
"vulnerability_type": issue.get("type", "other"),
|
||||
"severity": issue.get("severity", "medium"),
|
||||
"title": issue.get("title", "LLM 发现的安全问题"),
|
||||
"description": issue.get("description", ""),
|
||||
"file_path": file_path,
|
||||
"line_start": issue.get("line", 0),
|
||||
"code_snippet": issue.get("code_snippet", ""),
|
||||
"suggestion": issue.get("suggestion", ""),
|
||||
"ai_explanation": issue.get("ai_explanation", ""),
|
||||
"source": "llm_comprehensive_scan",
|
||||
"needs_verification": True,
|
||||
})
|
||||
|
||||
await self.emit_thinking(f"LLM 全面扫描完成,分析了 {scanned_files} 个文件")
|
||||
return findings
|
||||
|
||||
def _deduplicate_findings(self, findings: List[Dict]) -> List[Dict]:
|
||||
"""去重发现"""
|
||||
seen = set()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,375 @@
|
|||
"""
|
||||
真正的 ReAct Agent 实现
|
||||
LLM 是大脑,全程参与决策!
|
||||
|
||||
ReAct 循环:
|
||||
1. Thought: LLM 思考当前状态和下一步
|
||||
2. Action: LLM 决定调用哪个工具
|
||||
3. Observation: 执行工具,获取结果
|
||||
4. 重复直到 LLM 决定完成
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import List, Dict, Any, Optional, Tuple
|
||||
from dataclasses import dataclass
|
||||
|
||||
from .base import BaseAgent, AgentConfig, AgentResult, AgentType, AgentPattern
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
REACT_SYSTEM_PROMPT = """你是 DeepAudit 安全审计 Agent,一个专业的代码安全分析专家。
|
||||
|
||||
## 你的任务
|
||||
对目标项目进行全面的安全审计,发现潜在的安全漏洞。
|
||||
|
||||
## 你的工具
|
||||
{tools_description}
|
||||
|
||||
## 工作方式
|
||||
你需要通过 **思考-行动-观察** 循环来完成任务:
|
||||
|
||||
1. **Thought**: 分析当前情况,思考下一步应该做什么
|
||||
2. **Action**: 选择一个工具并执行
|
||||
3. **Observation**: 观察工具返回的结果
|
||||
4. 重复上述过程直到你认为审计完成
|
||||
|
||||
## 输出格式
|
||||
每一步必须严格按照以下格式输出:
|
||||
|
||||
```
|
||||
Thought: [你的思考过程,分析当前状态,决定下一步]
|
||||
Action: [工具名称]
|
||||
Action Input: [工具参数,JSON 格式]
|
||||
```
|
||||
|
||||
当你完成分析后,输出:
|
||||
```
|
||||
Thought: [总结分析结果]
|
||||
Final Answer: [JSON 格式的最终发现]
|
||||
```
|
||||
|
||||
## Final Answer 格式
|
||||
```json
|
||||
{{
|
||||
"findings": [
|
||||
{{
|
||||
"vulnerability_type": "sql_injection",
|
||||
"severity": "high",
|
||||
"title": "SQL 注入漏洞",
|
||||
"description": "详细描述",
|
||||
"file_path": "path/to/file.py",
|
||||
"line_start": 42,
|
||||
"code_snippet": "危险代码片段",
|
||||
"suggestion": "修复建议"
|
||||
}}
|
||||
],
|
||||
"summary": "审计总结"
|
||||
}}
|
||||
```
|
||||
|
||||
## 审计策略建议
|
||||
1. 先用 list_files 了解项目结构
|
||||
2. 识别关键文件(路由、控制器、数据库操作)
|
||||
3. 使用 search_code 搜索危险模式(eval, exec, query, innerHTML 等)
|
||||
4. 读取可疑文件进行深度分析
|
||||
5. 如果有 semgrep,用它进行全面扫描
|
||||
|
||||
## 重点关注的漏洞类型
|
||||
- SQL 注入 (query, execute, raw SQL)
|
||||
- XSS (innerHTML, document.write, v-html)
|
||||
- 命令注入 (exec, system, subprocess, child_process)
|
||||
- 路径遍历 (open, readFile, path concatenation)
|
||||
- SSRF (requests, fetch, http client)
|
||||
- 硬编码密钥 (password, secret, api_key, token)
|
||||
- 不安全的反序列化 (pickle, yaml.load, eval)
|
||||
|
||||
现在开始审计!"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentStep:
|
||||
"""Agent 执行步骤"""
|
||||
thought: str
|
||||
action: Optional[str] = None
|
||||
action_input: Optional[Dict] = None
|
||||
observation: Optional[str] = None
|
||||
is_final: bool = False
|
||||
final_answer: Optional[Dict] = None
|
||||
|
||||
|
||||
class ReActAgent(BaseAgent):
|
||||
"""
|
||||
真正的 ReAct Agent
|
||||
|
||||
LLM 全程参与决策,自主选择工具和分析策略
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm_service,
|
||||
tools: Dict[str, Any],
|
||||
event_emitter=None,
|
||||
agent_type: AgentType = AgentType.ANALYSIS,
|
||||
max_iterations: int = 30,
|
||||
):
|
||||
config = AgentConfig(
|
||||
name="ReActAgent",
|
||||
agent_type=agent_type,
|
||||
pattern=AgentPattern.REACT,
|
||||
max_iterations=max_iterations,
|
||||
system_prompt=REACT_SYSTEM_PROMPT,
|
||||
)
|
||||
super().__init__(config, llm_service, tools, event_emitter)
|
||||
|
||||
self._conversation_history: List[Dict[str, str]] = []
|
||||
self._steps: List[AgentStep] = []
|
||||
|
||||
def _get_tools_description(self) -> str:
|
||||
"""生成工具描述"""
|
||||
descriptions = []
|
||||
|
||||
for name, tool in self.tools.items():
|
||||
if name.startswith("_"):
|
||||
continue
|
||||
|
||||
desc = f"### {name}\n"
|
||||
desc += f"{tool.description}\n"
|
||||
|
||||
# 添加参数说明
|
||||
if hasattr(tool, 'args_schema') and tool.args_schema:
|
||||
schema = tool.args_schema.schema()
|
||||
properties = schema.get("properties", {})
|
||||
if properties:
|
||||
desc += "参数:\n"
|
||||
for param_name, param_info in properties.items():
|
||||
param_desc = param_info.get("description", "")
|
||||
param_type = param_info.get("type", "string")
|
||||
desc += f" - {param_name} ({param_type}): {param_desc}\n"
|
||||
|
||||
descriptions.append(desc)
|
||||
|
||||
return "\n".join(descriptions)
|
||||
|
||||
def _build_system_prompt(self, project_info: Dict, task_context: str = "") -> str:
|
||||
"""构建系统提示词"""
|
||||
tools_desc = self._get_tools_description()
|
||||
prompt = self.config.system_prompt.format(tools_description=tools_desc)
|
||||
|
||||
if project_info:
|
||||
prompt += f"\n\n## 项目信息\n"
|
||||
prompt += f"- 名称: {project_info.get('name', 'unknown')}\n"
|
||||
prompt += f"- 语言: {', '.join(project_info.get('languages', ['unknown']))}\n"
|
||||
prompt += f"- 文件数: {project_info.get('file_count', 'unknown')}\n"
|
||||
|
||||
if task_context:
|
||||
prompt += f"\n\n## 任务上下文\n{task_context}"
|
||||
|
||||
return prompt
|
||||
|
||||
def _parse_llm_response(self, response: str) -> AgentStep:
|
||||
"""解析 LLM 响应"""
|
||||
step = AgentStep(thought="")
|
||||
|
||||
# 提取 Thought
|
||||
thought_match = re.search(r'Thought:\s*(.*?)(?=Action:|Final Answer:|$)', response, re.DOTALL)
|
||||
if thought_match:
|
||||
step.thought = thought_match.group(1).strip()
|
||||
|
||||
# 检查是否是最终答案
|
||||
final_match = re.search(r'Final Answer:\s*(.*?)$', response, re.DOTALL)
|
||||
if final_match:
|
||||
step.is_final = True
|
||||
try:
|
||||
# 尝试提取 JSON
|
||||
answer_text = final_match.group(1).strip()
|
||||
# 移除 markdown 代码块
|
||||
answer_text = re.sub(r'```json\s*', '', answer_text)
|
||||
answer_text = re.sub(r'```\s*', '', answer_text)
|
||||
step.final_answer = json.loads(answer_text)
|
||||
except json.JSONDecodeError:
|
||||
step.final_answer = {"raw_answer": final_match.group(1).strip()}
|
||||
return step
|
||||
|
||||
# 提取 Action
|
||||
action_match = re.search(r'Action:\s*(\w+)', response)
|
||||
if action_match:
|
||||
step.action = action_match.group(1).strip()
|
||||
|
||||
# 提取 Action Input
|
||||
input_match = re.search(r'Action Input:\s*(.*?)(?=Thought:|Action:|Observation:|$)', response, re.DOTALL)
|
||||
if input_match:
|
||||
input_text = input_match.group(1).strip()
|
||||
# 移除 markdown 代码块
|
||||
input_text = re.sub(r'```json\s*', '', input_text)
|
||||
input_text = re.sub(r'```\s*', '', input_text)
|
||||
try:
|
||||
step.action_input = json.loads(input_text)
|
||||
except json.JSONDecodeError:
|
||||
# 尝试简单解析
|
||||
step.action_input = {"raw_input": input_text}
|
||||
|
||||
return step
|
||||
|
||||
async def _execute_tool(self, tool_name: str, tool_input: Dict) -> str:
|
||||
"""执行工具"""
|
||||
tool = self.tools.get(tool_name)
|
||||
|
||||
if not tool:
|
||||
return f"错误: 工具 '{tool_name}' 不存在。可用工具: {list(self.tools.keys())}"
|
||||
|
||||
try:
|
||||
self._tool_calls += 1
|
||||
await self.emit_tool_call(tool_name, tool_input)
|
||||
|
||||
import time
|
||||
start = time.time()
|
||||
|
||||
result = await tool.execute(**tool_input)
|
||||
|
||||
duration_ms = int((time.time() - start) * 1000)
|
||||
await self.emit_tool_result(tool_name, str(result.data)[:200], duration_ms)
|
||||
|
||||
if result.success:
|
||||
# 截断过长的输出
|
||||
output = str(result.data)
|
||||
if len(output) > 4000:
|
||||
output = output[:4000] + "\n\n... [输出已截断,共 {} 字符]".format(len(str(result.data)))
|
||||
return output
|
||||
else:
|
||||
return f"工具执行失败: {result.error}"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Tool execution error: {e}")
|
||||
return f"工具执行错误: {str(e)}"
|
||||
|
||||
async def run(self, input_data: Dict[str, Any]) -> AgentResult:
|
||||
"""
|
||||
执行 ReAct Agent
|
||||
|
||||
LLM 全程参与,自主决策!
|
||||
"""
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
project_info = input_data.get("project_info", {})
|
||||
task_context = input_data.get("task_context", "")
|
||||
config = input_data.get("config", {})
|
||||
|
||||
# 构建系统提示词
|
||||
system_prompt = self._build_system_prompt(project_info, task_context)
|
||||
|
||||
# 初始化对话历史
|
||||
self._conversation_history = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": "请开始对项目进行安全审计。首先了解项目结构,然后系统性地搜索和分析潜在的安全漏洞。"},
|
||||
]
|
||||
|
||||
self._steps = []
|
||||
all_findings = []
|
||||
|
||||
await self.emit_thinking("🤖 ReAct Agent 启动,LLM 开始自主分析...")
|
||||
|
||||
try:
|
||||
for iteration in range(self.config.max_iterations):
|
||||
if self.is_cancelled:
|
||||
break
|
||||
|
||||
self._iteration = iteration + 1
|
||||
|
||||
await self.emit_thinking(f"💭 第 {iteration + 1} 轮思考...")
|
||||
|
||||
# 🔥 调用 LLM 进行思考和决策
|
||||
response = await self.llm_service.chat_completion_raw(
|
||||
messages=self._conversation_history,
|
||||
temperature=0.1,
|
||||
max_tokens=2048,
|
||||
)
|
||||
|
||||
llm_output = response.get("content", "")
|
||||
self._total_tokens += response.get("usage", {}).get("total_tokens", 0)
|
||||
|
||||
# 发射思考事件
|
||||
await self.emit_event("thinking", f"LLM: {llm_output[:500]}...")
|
||||
|
||||
# 解析 LLM 响应
|
||||
step = self._parse_llm_response(llm_output)
|
||||
self._steps.append(step)
|
||||
|
||||
# 添加 LLM 响应到历史
|
||||
self._conversation_history.append({
|
||||
"role": "assistant",
|
||||
"content": llm_output,
|
||||
})
|
||||
|
||||
# 检查是否完成
|
||||
if step.is_final:
|
||||
await self.emit_thinking("✅ LLM 完成分析,生成最终报告")
|
||||
|
||||
if step.final_answer and "findings" in step.final_answer:
|
||||
all_findings = step.final_answer["findings"]
|
||||
break
|
||||
|
||||
# 执行工具
|
||||
if step.action:
|
||||
await self.emit_thinking(f"🔧 LLM 决定调用工具: {step.action}")
|
||||
|
||||
observation = await self._execute_tool(
|
||||
step.action,
|
||||
step.action_input or {}
|
||||
)
|
||||
|
||||
step.observation = observation
|
||||
|
||||
# 添加观察结果到历史
|
||||
self._conversation_history.append({
|
||||
"role": "user",
|
||||
"content": f"Observation: {observation}",
|
||||
})
|
||||
else:
|
||||
# LLM 没有选择工具,提示它继续
|
||||
self._conversation_history.append({
|
||||
"role": "user",
|
||||
"content": "请继续分析,选择一个工具执行,或者如果分析完成,输出 Final Answer。",
|
||||
})
|
||||
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
await self.emit_event(
|
||||
"info",
|
||||
f"🎯 ReAct Agent 完成: {len(all_findings)} 个发现, {self._iteration} 轮迭代, {self._tool_calls} 次工具调用"
|
||||
)
|
||||
|
||||
return AgentResult(
|
||||
success=True,
|
||||
data={
|
||||
"findings": all_findings,
|
||||
"steps": [
|
||||
{
|
||||
"thought": s.thought,
|
||||
"action": s.action,
|
||||
"action_input": s.action_input,
|
||||
"observation": s.observation[:500] if s.observation else None,
|
||||
}
|
||||
for s in self._steps
|
||||
],
|
||||
},
|
||||
iterations=self._iteration,
|
||||
tool_calls=self._tool_calls,
|
||||
tokens_used=self._total_tokens,
|
||||
duration_ms=duration_ms,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"ReAct Agent failed: {e}", exc_info=True)
|
||||
return AgentResult(success=False, error=str(e))
|
||||
|
||||
def get_conversation_history(self) -> List[Dict[str, str]]:
|
||||
"""获取对话历史"""
|
||||
return self._conversation_history
|
||||
|
||||
def get_steps(self) -> List[AgentStep]:
|
||||
"""获取执行步骤"""
|
||||
return self._steps
|
||||
|
|
@ -193,6 +193,37 @@ class AgentEventEmitter:
|
|||
},
|
||||
))
|
||||
|
||||
async def emit_task_complete(
|
||||
self,
|
||||
findings_count: int,
|
||||
duration_ms: int,
|
||||
message: Optional[str] = None,
|
||||
):
|
||||
"""发射任务完成事件"""
|
||||
await self.emit(AgentEventData(
|
||||
event_type="task_complete",
|
||||
message=message or f"✅ 审计完成!发现 {findings_count} 个漏洞,耗时 {duration_ms/1000:.1f}秒",
|
||||
metadata={
|
||||
"findings_count": findings_count,
|
||||
"duration_ms": duration_ms,
|
||||
},
|
||||
))
|
||||
|
||||
async def emit_task_error(self, error: str, message: Optional[str] = None):
|
||||
"""发射任务错误事件"""
|
||||
await self.emit(AgentEventData(
|
||||
event_type="task_error",
|
||||
message=message or f"❌ 任务失败: {error}",
|
||||
metadata={"error": error},
|
||||
))
|
||||
|
||||
async def emit_task_cancelled(self, message: Optional[str] = None):
|
||||
"""发射任务取消事件"""
|
||||
await self.emit(AgentEventData(
|
||||
event_type="task_cancel",
|
||||
message=message or "⚠️ 任务已取消",
|
||||
))
|
||||
|
||||
|
||||
class EventManager:
|
||||
"""
|
||||
|
|
@ -369,3 +400,14 @@ class EventManager:
|
|||
"""创建事件发射器"""
|
||||
return AgentEventEmitter(task_id, self)
|
||||
|
||||
async def close(self):
|
||||
"""关闭事件管理器,清理资源"""
|
||||
# 清理所有队列
|
||||
for task_id in list(self._event_queues.keys()):
|
||||
self.remove_queue(task_id)
|
||||
|
||||
# 清理所有回调
|
||||
self._event_callbacks.clear()
|
||||
|
||||
logger.debug("EventManager closed")
|
||||
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
from langgraph.graph import StateGraph, END
|
||||
from langgraph.checkpoint.memory import MemorySaver
|
||||
|
||||
from app.services.agent.streaming import StreamHandler, StreamEvent, StreamEventType
|
||||
from app.models.agent_task import (
|
||||
AgentTask, AgentEvent, AgentFinding,
|
||||
AgentTaskStatus, AgentTaskPhase, AgentEventType,
|
||||
|
|
@ -39,11 +40,15 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class LLMService:
|
||||
"""LLM 服务封装"""
|
||||
"""
|
||||
LLM 服务封装
|
||||
提供代码分析、漏洞检测等 AI 功能
|
||||
"""
|
||||
|
||||
def __init__(self, model: Optional[str] = None, api_key: Optional[str] = None):
|
||||
self.model = model or settings.LLM_MODEL or "gpt-4o-mini"
|
||||
self.api_key = api_key or settings.LLM_API_KEY
|
||||
self.base_url = settings.LLM_BASE_URL
|
||||
|
||||
async def chat_completion_raw(
|
||||
self,
|
||||
|
|
@ -61,6 +66,7 @@ class LLMService:
|
|||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
api_key=self.api_key,
|
||||
base_url=self.base_url,
|
||||
)
|
||||
|
||||
return {
|
||||
|
|
@ -76,6 +82,125 @@ class LLMService:
|
|||
logger.error(f"LLM call failed: {e}")
|
||||
raise
|
||||
|
||||
async def analyze_code(self, code: str, language: str) -> Dict[str, Any]:
|
||||
"""
|
||||
分析代码安全问题
|
||||
|
||||
Args:
|
||||
code: 代码内容
|
||||
language: 编程语言
|
||||
|
||||
Returns:
|
||||
分析结果,包含 issues 列表
|
||||
"""
|
||||
prompt = f"""请分析以下 {language} 代码的安全问题。
|
||||
|
||||
代码:
|
||||
```{language}
|
||||
{code[:8000]}
|
||||
```
|
||||
|
||||
请识别所有潜在的安全漏洞,包括但不限于:
|
||||
- SQL 注入
|
||||
- XSS (跨站脚本)
|
||||
- 命令注入
|
||||
- 路径遍历
|
||||
- 不安全的反序列化
|
||||
- 硬编码密钥/密码
|
||||
- 不安全的加密
|
||||
- SSRF
|
||||
- 认证/授权问题
|
||||
|
||||
对于每个发现的问题,请提供:
|
||||
1. 漏洞类型
|
||||
2. 严重程度 (critical/high/medium/low)
|
||||
3. 问题描述
|
||||
4. 具体行号
|
||||
5. 修复建议
|
||||
|
||||
请以 JSON 格式返回结果:
|
||||
{{
|
||||
"issues": [
|
||||
{{
|
||||
"type": "漏洞类型",
|
||||
"severity": "严重程度",
|
||||
"title": "问题标题",
|
||||
"description": "详细描述",
|
||||
"line": 行号,
|
||||
"code_snippet": "相关代码片段",
|
||||
"suggestion": "修复建议"
|
||||
}}
|
||||
],
|
||||
"quality_score": 0-100
|
||||
}}
|
||||
|
||||
如果没有发现安全问题,返回空的 issues 数组和较高的 quality_score。"""
|
||||
|
||||
try:
|
||||
result = await self.chat_completion_raw(
|
||||
messages=[
|
||||
{"role": "system", "content": "你是一位专业的代码安全审计专家,擅长发现代码中的安全漏洞。请只返回 JSON 格式的结果,不要包含其他内容。"},
|
||||
{"role": "user", "content": prompt},
|
||||
],
|
||||
temperature=0.1,
|
||||
max_tokens=4096,
|
||||
)
|
||||
|
||||
content = result.get("content", "{}")
|
||||
|
||||
# 尝试提取 JSON
|
||||
import json
|
||||
import re
|
||||
|
||||
# 尝试直接解析
|
||||
try:
|
||||
return json.loads(content)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# 尝试从 markdown 代码块提取
|
||||
json_match = re.search(r'```(?:json)?\s*([\s\S]*?)\s*```', content)
|
||||
if json_match:
|
||||
try:
|
||||
return json.loads(json_match.group(1))
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# 返回空结果
|
||||
return {"issues": [], "quality_score": 80}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Code analysis failed: {e}")
|
||||
return {"issues": [], "quality_score": 0, "error": str(e)}
|
||||
|
||||
async def analyze_code_with_custom_prompt(
|
||||
self,
|
||||
code: str,
|
||||
language: str,
|
||||
prompt: str,
|
||||
**kwargs
|
||||
) -> Dict[str, Any]:
|
||||
"""使用自定义提示词分析代码"""
|
||||
full_prompt = prompt.replace("{code}", code).replace("{language}", language)
|
||||
|
||||
try:
|
||||
result = await self.chat_completion_raw(
|
||||
messages=[
|
||||
{"role": "system", "content": "你是一位专业的代码安全审计专家。"},
|
||||
{"role": "user", "content": full_prompt},
|
||||
],
|
||||
temperature=0.1,
|
||||
)
|
||||
|
||||
return {
|
||||
"analysis": result.get("content", ""),
|
||||
"usage": result.get("usage", {}),
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Custom analysis failed: {e}")
|
||||
return {"analysis": "", "error": str(e)}
|
||||
|
||||
|
||||
class AgentRunner:
|
||||
"""
|
||||
|
|
@ -97,8 +222,9 @@ class AgentRunner:
|
|||
self.task = task
|
||||
self.project_root = project_root
|
||||
|
||||
# 事件管理
|
||||
self.event_manager = EventManager()
|
||||
# 事件管理 - 传入 db_session_factory 以持久化事件
|
||||
from app.db.session import async_session_factory
|
||||
self.event_manager = EventManager(db_session_factory=async_session_factory)
|
||||
self.event_emitter = AgentEventEmitter(task.id, self.event_manager)
|
||||
|
||||
# LLM 服务
|
||||
|
|
@ -120,6 +246,22 @@ class AgentRunner:
|
|||
|
||||
# 状态
|
||||
self._cancelled = False
|
||||
self._running_task: Optional[asyncio.Task] = None
|
||||
|
||||
# 流式处理器
|
||||
self.stream_handler = StreamHandler(task.id)
|
||||
|
||||
def cancel(self):
|
||||
"""取消任务"""
|
||||
self._cancelled = True
|
||||
if self._running_task and not self._running_task.done():
|
||||
self._running_task.cancel()
|
||||
logger.info(f"Task {self.task.id} cancellation requested")
|
||||
|
||||
@property
|
||||
def is_cancelled(self) -> bool:
|
||||
"""检查是否已取消"""
|
||||
return self._cancelled
|
||||
|
||||
async def initialize(self):
|
||||
"""初始化 Runner"""
|
||||
|
|
@ -149,15 +291,15 @@ class AgentRunner:
|
|||
)
|
||||
|
||||
self.indexer = CodeIndexer(
|
||||
embedding_service=embedding_service,
|
||||
vector_db_path=settings.VECTOR_DB_PATH,
|
||||
collection_name=f"project_{self.task.project_id}",
|
||||
embedding_service=embedding_service,
|
||||
persist_directory=settings.VECTOR_DB_PATH,
|
||||
)
|
||||
|
||||
self.retriever = CodeRetriever(
|
||||
embedding_service=embedding_service,
|
||||
vector_db_path=settings.VECTOR_DB_PATH,
|
||||
collection_name=f"project_{self.task.project_id}",
|
||||
embedding_service=embedding_service,
|
||||
persist_directory=settings.VECTOR_DB_PATH,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
|
|
@ -261,6 +403,18 @@ class AgentRunner:
|
|||
Returns:
|
||||
最终状态
|
||||
"""
|
||||
result = {}
|
||||
async for _ in self.run_with_streaming():
|
||||
pass # 消费所有事件
|
||||
return result
|
||||
|
||||
async def run_with_streaming(self) -> AsyncGenerator[StreamEvent, None]:
|
||||
"""
|
||||
带流式输出的审计执行
|
||||
|
||||
Yields:
|
||||
StreamEvent: 流式事件(包含 LLM 思考、工具调用等)
|
||||
"""
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
|
|
@ -271,17 +425,28 @@ class AgentRunner:
|
|||
# 更新任务状态
|
||||
await self._update_task_status(AgentTaskStatus.RUNNING)
|
||||
|
||||
# 发射任务开始事件
|
||||
yield StreamEvent(
|
||||
event_type=StreamEventType.TASK_START,
|
||||
sequence=self.stream_handler._next_sequence(),
|
||||
data={"task_id": self.task.id, "message": "🚀 审计任务开始"},
|
||||
)
|
||||
|
||||
# 1. 索引代码
|
||||
await self._index_code()
|
||||
|
||||
if self._cancelled:
|
||||
return {"success": False, "error": "任务已取消"}
|
||||
yield StreamEvent(
|
||||
event_type=StreamEventType.TASK_CANCEL,
|
||||
sequence=self.stream_handler._next_sequence(),
|
||||
data={"message": "任务已取消"},
|
||||
)
|
||||
return
|
||||
|
||||
# 2. 收集项目信息
|
||||
project_info = await self._collect_project_info()
|
||||
|
||||
# 3. 构建初始状态
|
||||
# 从任务字段构建配置
|
||||
task_config = {
|
||||
"target_vulnerabilities": self.task.target_vulnerabilities or [],
|
||||
"verification_level": self.task.verification_level or "sandbox",
|
||||
|
|
@ -314,7 +479,7 @@ class AgentRunner:
|
|||
"error": None,
|
||||
}
|
||||
|
||||
# 4. 执行 LangGraph
|
||||
# 4. 执行 LangGraph with astream_events
|
||||
await self.event_emitter.emit_phase_start("langgraph", "🔄 启动 LangGraph 工作流")
|
||||
|
||||
run_config = {
|
||||
|
|
@ -325,16 +490,47 @@ class AgentRunner:
|
|||
|
||||
final_state = None
|
||||
|
||||
# 流式执行并发射事件
|
||||
# 使用 astream_events 获取详细事件流
|
||||
try:
|
||||
async for event in self.graph.astream_events(
|
||||
initial_state,
|
||||
config=run_config,
|
||||
version="v2",
|
||||
):
|
||||
if self._cancelled:
|
||||
break
|
||||
|
||||
# 处理 LangGraph 事件
|
||||
stream_event = await self.stream_handler.process_langgraph_event(event)
|
||||
if stream_event:
|
||||
# 同步到 event_emitter 以持久化
|
||||
await self._sync_stream_event_to_db(stream_event)
|
||||
yield stream_event
|
||||
|
||||
# 更新最终状态
|
||||
if event.get("event") == "on_chain_end":
|
||||
output = event.get("data", {}).get("output")
|
||||
if isinstance(output, dict):
|
||||
final_state = output
|
||||
|
||||
except Exception as e:
|
||||
# 如果 astream_events 不可用,回退到 astream
|
||||
logger.warning(f"astream_events not available, falling back to astream: {e}")
|
||||
async for event in self.graph.astream(initial_state, config=run_config):
|
||||
if self._cancelled:
|
||||
break
|
||||
|
||||
# 处理每个节点的输出
|
||||
for node_name, node_output in event.items():
|
||||
await self._handle_node_output(node_name, node_output)
|
||||
|
||||
# 更新阶段
|
||||
# 发射节点事件
|
||||
yield StreamEvent(
|
||||
event_type=StreamEventType.NODE_END,
|
||||
sequence=self.stream_handler._next_sequence(),
|
||||
node_name=node_name,
|
||||
data={"message": f"节点 {node_name} 完成"},
|
||||
)
|
||||
|
||||
phase_map = {
|
||||
"recon": AgentTaskPhase.RECONNAISSANCE,
|
||||
"analysis": AgentTaskPhase.ANALYSIS,
|
||||
|
|
@ -355,6 +551,13 @@ class AgentRunner:
|
|||
findings = final_state.get("findings", [])
|
||||
await self._save_findings(findings)
|
||||
|
||||
# 发射发现事件
|
||||
for finding in findings[:10]: # 限制数量
|
||||
yield self.stream_handler.create_finding_event(
|
||||
finding,
|
||||
is_verified=finding.get("is_verified", False),
|
||||
)
|
||||
|
||||
# 7. 更新任务摘要
|
||||
summary = final_state.get("summary", {})
|
||||
security_score = final_state.get("security_score", 100)
|
||||
|
|
@ -374,30 +577,59 @@ class AgentRunner:
|
|||
duration_ms=duration_ms,
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"data": {
|
||||
"findings": findings,
|
||||
"verified_findings": final_state.get("verified_findings", []),
|
||||
"summary": summary,
|
||||
yield StreamEvent(
|
||||
event_type=StreamEventType.TASK_COMPLETE,
|
||||
sequence=self.stream_handler._next_sequence(),
|
||||
data={
|
||||
"findings_count": len(findings),
|
||||
"verified_count": len(final_state.get("verified_findings", [])),
|
||||
"security_score": security_score,
|
||||
},
|
||||
"duration_ms": duration_ms,
|
||||
}
|
||||
"message": f"✅ 审计完成!发现 {len(findings)} 个漏洞",
|
||||
},
|
||||
)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
await self._update_task_status(AgentTaskStatus.CANCELLED)
|
||||
return {"success": False, "error": "任务已取消"}
|
||||
yield StreamEvent(
|
||||
event_type=StreamEventType.TASK_CANCEL,
|
||||
sequence=self.stream_handler._next_sequence(),
|
||||
data={"message": "任务已取消"},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"LangGraph run failed: {e}", exc_info=True)
|
||||
await self._update_task_status(AgentTaskStatus.FAILED, str(e))
|
||||
await self.event_emitter.emit_error(str(e))
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
yield StreamEvent(
|
||||
event_type=StreamEventType.TASK_ERROR,
|
||||
sequence=self.stream_handler._next_sequence(),
|
||||
data={"error": str(e), "message": f"❌ 审计失败: {e}"},
|
||||
)
|
||||
|
||||
finally:
|
||||
await self._cleanup()
|
||||
|
||||
async def _sync_stream_event_to_db(self, event: StreamEvent):
|
||||
"""同步流式事件到数据库"""
|
||||
try:
|
||||
# 将 StreamEvent 转换为 AgentEventData
|
||||
await self.event_manager.add_event(
|
||||
task_id=self.task.id,
|
||||
event_type=event.event_type.value,
|
||||
sequence=event.sequence,
|
||||
phase=event.phase,
|
||||
message=event.data.get("message"),
|
||||
tool_name=event.tool_name,
|
||||
tool_input=event.data.get("input") or event.data.get("input_params"),
|
||||
tool_output=event.data.get("output") or event.data.get("output_data"),
|
||||
tool_duration_ms=event.data.get("duration_ms"),
|
||||
metadata=event.data,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to sync stream event to db: {e}")
|
||||
|
||||
async def _handle_node_output(self, node_name: str, output: Dict[str, Any]):
|
||||
"""处理节点输出"""
|
||||
# 发射节点事件
|
||||
|
|
@ -445,7 +677,8 @@ class AgentRunner:
|
|||
return
|
||||
|
||||
await self.event_emitter.emit_progress(
|
||||
progress.processed / max(progress.total, 1) * 100,
|
||||
progress.processed_files,
|
||||
progress.total_files,
|
||||
f"正在索引: {progress.current_file or 'N/A'}"
|
||||
)
|
||||
|
||||
|
|
@ -502,13 +735,23 @@ class AgentRunner:
|
|||
|
||||
type_map = {
|
||||
"sql_injection": VulnerabilityType.SQL_INJECTION,
|
||||
"nosql_injection": VulnerabilityType.NOSQL_INJECTION,
|
||||
"xss": VulnerabilityType.XSS,
|
||||
"command_injection": VulnerabilityType.COMMAND_INJECTION,
|
||||
"code_injection": VulnerabilityType.CODE_INJECTION,
|
||||
"path_traversal": VulnerabilityType.PATH_TRAVERSAL,
|
||||
"file_inclusion": VulnerabilityType.FILE_INCLUSION,
|
||||
"ssrf": VulnerabilityType.SSRF,
|
||||
"xxe": VulnerabilityType.XXE,
|
||||
"deserialization": VulnerabilityType.DESERIALIZATION,
|
||||
"auth_bypass": VulnerabilityType.AUTH_BYPASS,
|
||||
"idor": VulnerabilityType.IDOR,
|
||||
"sensitive_data_exposure": VulnerabilityType.SENSITIVE_DATA_EXPOSURE,
|
||||
"hardcoded_secret": VulnerabilityType.HARDCODED_SECRET,
|
||||
"deserialization": VulnerabilityType.INSECURE_DESERIALIZATION,
|
||||
"weak_crypto": VulnerabilityType.WEAK_CRYPTO,
|
||||
"race_condition": VulnerabilityType.RACE_CONDITION,
|
||||
"business_logic": VulnerabilityType.BUSINESS_LOGIC,
|
||||
"memory_corruption": VulnerabilityType.MEMORY_CORRUPTION,
|
||||
}
|
||||
|
||||
for finding in findings:
|
||||
|
|
@ -536,7 +779,7 @@ class AgentRunner:
|
|||
is_verified=finding.get("is_verified", False),
|
||||
confidence=finding.get("confidence", 0.5),
|
||||
poc=finding.get("poc"),
|
||||
status=FindingStatus.VERIFIED if finding.get("is_verified") else FindingStatus.OPEN,
|
||||
status=FindingStatus.VERIFIED if finding.get("is_verified") else FindingStatus.NEW,
|
||||
)
|
||||
|
||||
self.db.add(db_finding)
|
||||
|
|
@ -604,10 +847,6 @@ class AgentRunner:
|
|||
except Exception as e:
|
||||
logger.warning(f"Cleanup error: {e}")
|
||||
|
||||
def cancel(self):
|
||||
"""取消任务"""
|
||||
self._cancelled = True
|
||||
|
||||
|
||||
# 便捷函数
|
||||
async def run_agent_task(
|
||||
|
|
|
|||
|
|
@ -0,0 +1,17 @@
|
|||
"""
|
||||
Agent 流式输出模块
|
||||
支持 LLM Token 流式输出、工具调用展示、思考过程展示
|
||||
"""
|
||||
|
||||
from .stream_handler import StreamHandler, StreamEvent, StreamEventType
|
||||
from .token_streamer import TokenStreamer
|
||||
from .tool_stream import ToolStreamHandler
|
||||
|
||||
__all__ = [
|
||||
"StreamHandler",
|
||||
"StreamEvent",
|
||||
"StreamEventType",
|
||||
"TokenStreamer",
|
||||
"ToolStreamHandler",
|
||||
]
|
||||
|
||||
|
|
@ -0,0 +1,453 @@
|
|||
"""
|
||||
流式事件处理器
|
||||
处理 LangGraph 的各种流式事件并转换为前端可消费的格式
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Optional, AsyncGenerator, List
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StreamEventType(str, Enum):
|
||||
"""流式事件类型"""
|
||||
# LLM 相关
|
||||
THINKING_START = "thinking_start" # 开始思考
|
||||
THINKING_TOKEN = "thinking_token" # 思考 Token
|
||||
THINKING_END = "thinking_end" # 思考结束
|
||||
|
||||
# 工具调用相关
|
||||
TOOL_CALL_START = "tool_call_start" # 工具调用开始
|
||||
TOOL_CALL_INPUT = "tool_call_input" # 工具输入参数
|
||||
TOOL_CALL_OUTPUT = "tool_call_output" # 工具输出结果
|
||||
TOOL_CALL_END = "tool_call_end" # 工具调用结束
|
||||
TOOL_CALL_ERROR = "tool_call_error" # 工具调用错误
|
||||
|
||||
# 节点相关
|
||||
NODE_START = "node_start" # 节点开始
|
||||
NODE_END = "node_end" # 节点结束
|
||||
|
||||
# 阶段相关
|
||||
PHASE_START = "phase_start"
|
||||
PHASE_END = "phase_end"
|
||||
|
||||
# 发现相关
|
||||
FINDING_NEW = "finding_new" # 新发现
|
||||
FINDING_VERIFIED = "finding_verified" # 验证通过
|
||||
|
||||
# 状态相关
|
||||
PROGRESS = "progress"
|
||||
INFO = "info"
|
||||
WARNING = "warning"
|
||||
ERROR = "error"
|
||||
|
||||
# 任务相关
|
||||
TASK_START = "task_start"
|
||||
TASK_COMPLETE = "task_complete"
|
||||
TASK_ERROR = "task_error"
|
||||
TASK_CANCEL = "task_cancel"
|
||||
|
||||
# 心跳
|
||||
HEARTBEAT = "heartbeat"
|
||||
|
||||
|
||||
@dataclass
|
||||
class StreamEvent:
|
||||
"""流式事件"""
|
||||
event_type: StreamEventType
|
||||
data: Dict[str, Any] = field(default_factory=dict)
|
||||
timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
|
||||
sequence: int = 0
|
||||
|
||||
# 可选字段
|
||||
node_name: Optional[str] = None
|
||||
phase: Optional[str] = None
|
||||
tool_name: Optional[str] = None
|
||||
|
||||
def to_sse(self) -> str:
|
||||
"""转换为 SSE 格式"""
|
||||
event_data = {
|
||||
"type": self.event_type.value,
|
||||
"data": self.data,
|
||||
"timestamp": self.timestamp,
|
||||
"sequence": self.sequence,
|
||||
}
|
||||
|
||||
if self.node_name:
|
||||
event_data["node"] = self.node_name
|
||||
if self.phase:
|
||||
event_data["phase"] = self.phase
|
||||
if self.tool_name:
|
||||
event_data["tool"] = self.tool_name
|
||||
|
||||
return f"event: {self.event_type.value}\ndata: {json.dumps(event_data, ensure_ascii=False)}\n\n"
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""转换为字典"""
|
||||
return {
|
||||
"event_type": self.event_type.value,
|
||||
"data": self.data,
|
||||
"timestamp": self.timestamp,
|
||||
"sequence": self.sequence,
|
||||
"node_name": self.node_name,
|
||||
"phase": self.phase,
|
||||
"tool_name": self.tool_name,
|
||||
}
|
||||
|
||||
|
||||
class StreamHandler:
|
||||
"""
|
||||
流式事件处理器
|
||||
|
||||
最佳实践:
|
||||
1. 使用 astream_events 捕获所有 LangGraph 事件
|
||||
2. 将内部事件转换为前端友好的格式
|
||||
3. 支持多种事件类型的分发
|
||||
"""
|
||||
|
||||
def __init__(self, task_id: str):
|
||||
self.task_id = task_id
|
||||
self._sequence = 0
|
||||
self._current_phase = None
|
||||
self._current_node = None
|
||||
self._thinking_buffer = []
|
||||
self._tool_states: Dict[str, Dict] = {}
|
||||
|
||||
def _next_sequence(self) -> int:
|
||||
"""获取下一个序列号"""
|
||||
self._sequence += 1
|
||||
return self._sequence
|
||||
|
||||
async def process_langgraph_event(self, event: Dict[str, Any]) -> Optional[StreamEvent]:
|
||||
"""
|
||||
处理 LangGraph 事件
|
||||
|
||||
支持的事件类型:
|
||||
- on_chain_start: 链/节点开始
|
||||
- on_chain_end: 链/节点结束
|
||||
- on_chain_stream: LLM Token 流
|
||||
- on_chat_model_start: 模型开始
|
||||
- on_chat_model_stream: 模型 Token 流
|
||||
- on_chat_model_end: 模型结束
|
||||
- on_tool_start: 工具开始
|
||||
- on_tool_end: 工具结束
|
||||
- on_custom_event: 自定义事件
|
||||
"""
|
||||
event_kind = event.get("event", "")
|
||||
event_name = event.get("name", "")
|
||||
event_data = event.get("data", {})
|
||||
|
||||
# LLM Token 流
|
||||
if event_kind == "on_chat_model_stream":
|
||||
return await self._handle_llm_stream(event_data, event_name)
|
||||
|
||||
# LLM 开始
|
||||
elif event_kind == "on_chat_model_start":
|
||||
return await self._handle_llm_start(event_data, event_name)
|
||||
|
||||
# LLM 结束
|
||||
elif event_kind == "on_chat_model_end":
|
||||
return await self._handle_llm_end(event_data, event_name)
|
||||
|
||||
# 工具开始
|
||||
elif event_kind == "on_tool_start":
|
||||
return await self._handle_tool_start(event_name, event_data)
|
||||
|
||||
# 工具结束
|
||||
elif event_kind == "on_tool_end":
|
||||
return await self._handle_tool_end(event_name, event_data)
|
||||
|
||||
# 节点开始
|
||||
elif event_kind == "on_chain_start" and self._is_node_event(event_name):
|
||||
return await self._handle_node_start(event_name, event_data)
|
||||
|
||||
# 节点结束
|
||||
elif event_kind == "on_chain_end" and self._is_node_event(event_name):
|
||||
return await self._handle_node_end(event_name, event_data)
|
||||
|
||||
# 自定义事件
|
||||
elif event_kind == "on_custom_event":
|
||||
return await self._handle_custom_event(event_name, event_data)
|
||||
|
||||
return None
|
||||
|
||||
def _is_node_event(self, name: str) -> bool:
|
||||
"""判断是否是节点事件"""
|
||||
node_names = ["recon", "analysis", "verification", "report", "ReconNode", "AnalysisNode", "VerificationNode", "ReportNode"]
|
||||
return any(n.lower() in name.lower() for n in node_names)
|
||||
|
||||
async def _handle_llm_start(self, data: Dict, name: str) -> StreamEvent:
|
||||
"""处理 LLM 开始事件"""
|
||||
self._thinking_buffer = []
|
||||
|
||||
return StreamEvent(
|
||||
event_type=StreamEventType.THINKING_START,
|
||||
sequence=self._next_sequence(),
|
||||
node_name=self._current_node,
|
||||
phase=self._current_phase,
|
||||
data={
|
||||
"model": name,
|
||||
"message": "🤔 正在思考...",
|
||||
},
|
||||
)
|
||||
|
||||
async def _handle_llm_stream(self, data: Dict, name: str) -> Optional[StreamEvent]:
|
||||
"""处理 LLM Token 流事件"""
|
||||
chunk = data.get("chunk")
|
||||
if not chunk:
|
||||
return None
|
||||
|
||||
# 提取 Token 内容
|
||||
content = ""
|
||||
if hasattr(chunk, "content"):
|
||||
content = chunk.content
|
||||
elif isinstance(chunk, dict):
|
||||
content = chunk.get("content", "")
|
||||
|
||||
if not content:
|
||||
return None
|
||||
|
||||
# 添加到缓冲区
|
||||
self._thinking_buffer.append(content)
|
||||
|
||||
return StreamEvent(
|
||||
event_type=StreamEventType.THINKING_TOKEN,
|
||||
sequence=self._next_sequence(),
|
||||
node_name=self._current_node,
|
||||
phase=self._current_phase,
|
||||
data={
|
||||
"token": content,
|
||||
"accumulated": "".join(self._thinking_buffer),
|
||||
},
|
||||
)
|
||||
|
||||
async def _handle_llm_end(self, data: Dict, name: str) -> StreamEvent:
|
||||
"""处理 LLM 结束事件"""
|
||||
full_response = "".join(self._thinking_buffer)
|
||||
self._thinking_buffer = []
|
||||
|
||||
# 提取使用的 Token 数
|
||||
usage = {}
|
||||
output = data.get("output")
|
||||
if output and hasattr(output, "usage_metadata"):
|
||||
usage = {
|
||||
"input_tokens": getattr(output.usage_metadata, "input_tokens", 0),
|
||||
"output_tokens": getattr(output.usage_metadata, "output_tokens", 0),
|
||||
}
|
||||
|
||||
return StreamEvent(
|
||||
event_type=StreamEventType.THINKING_END,
|
||||
sequence=self._next_sequence(),
|
||||
node_name=self._current_node,
|
||||
phase=self._current_phase,
|
||||
data={
|
||||
"response": full_response[:2000], # 截断长响应
|
||||
"usage": usage,
|
||||
"message": "💡 思考完成",
|
||||
},
|
||||
)
|
||||
|
||||
async def _handle_tool_start(self, tool_name: str, data: Dict) -> StreamEvent:
|
||||
"""处理工具开始事件"""
|
||||
import time
|
||||
|
||||
tool_input = data.get("input", {})
|
||||
|
||||
# 记录工具状态
|
||||
self._tool_states[tool_name] = {
|
||||
"start_time": time.time(),
|
||||
"input": tool_input,
|
||||
}
|
||||
|
||||
return StreamEvent(
|
||||
event_type=StreamEventType.TOOL_CALL_START,
|
||||
sequence=self._next_sequence(),
|
||||
node_name=self._current_node,
|
||||
phase=self._current_phase,
|
||||
tool_name=tool_name,
|
||||
data={
|
||||
"tool_name": tool_name,
|
||||
"input": self._truncate_data(tool_input),
|
||||
"message": f"🔧 调用工具: {tool_name}",
|
||||
},
|
||||
)
|
||||
|
||||
async def _handle_tool_end(self, tool_name: str, data: Dict) -> StreamEvent:
|
||||
"""处理工具结束事件"""
|
||||
import time
|
||||
|
||||
# 计算执行时间
|
||||
duration_ms = 0
|
||||
if tool_name in self._tool_states:
|
||||
start_time = self._tool_states[tool_name].get("start_time", time.time())
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
del self._tool_states[tool_name]
|
||||
|
||||
# 提取输出
|
||||
output = data.get("output", "")
|
||||
if hasattr(output, "content"):
|
||||
output = output.content
|
||||
|
||||
return StreamEvent(
|
||||
event_type=StreamEventType.TOOL_CALL_END,
|
||||
sequence=self._next_sequence(),
|
||||
node_name=self._current_node,
|
||||
phase=self._current_phase,
|
||||
tool_name=tool_name,
|
||||
data={
|
||||
"tool_name": tool_name,
|
||||
"output": self._truncate_data(output),
|
||||
"duration_ms": duration_ms,
|
||||
"message": f"✅ 工具 {tool_name} 完成 ({duration_ms}ms)",
|
||||
},
|
||||
)
|
||||
|
||||
async def _handle_node_start(self, node_name: str, data: Dict) -> StreamEvent:
|
||||
"""处理节点开始事件"""
|
||||
self._current_node = node_name
|
||||
|
||||
# 映射节点到阶段
|
||||
phase_map = {
|
||||
"recon": "reconnaissance",
|
||||
"analysis": "analysis",
|
||||
"verification": "verification",
|
||||
"report": "reporting",
|
||||
}
|
||||
|
||||
for key, phase in phase_map.items():
|
||||
if key in node_name.lower():
|
||||
self._current_phase = phase
|
||||
break
|
||||
|
||||
return StreamEvent(
|
||||
event_type=StreamEventType.NODE_START,
|
||||
sequence=self._next_sequence(),
|
||||
node_name=node_name,
|
||||
phase=self._current_phase,
|
||||
data={
|
||||
"node": node_name,
|
||||
"phase": self._current_phase,
|
||||
"message": f"▶️ 开始节点: {node_name}",
|
||||
},
|
||||
)
|
||||
|
||||
async def _handle_node_end(self, node_name: str, data: Dict) -> StreamEvent:
|
||||
"""处理节点结束事件"""
|
||||
# 提取输出信息
|
||||
output = data.get("output", {})
|
||||
|
||||
summary = {}
|
||||
if isinstance(output, dict):
|
||||
# 提取关键信息
|
||||
if "findings" in output:
|
||||
summary["findings_count"] = len(output["findings"])
|
||||
if "entry_points" in output:
|
||||
summary["entry_points_count"] = len(output["entry_points"])
|
||||
if "high_risk_areas" in output:
|
||||
summary["high_risk_areas_count"] = len(output["high_risk_areas"])
|
||||
if "verified_findings" in output:
|
||||
summary["verified_count"] = len(output["verified_findings"])
|
||||
|
||||
return StreamEvent(
|
||||
event_type=StreamEventType.NODE_END,
|
||||
sequence=self._next_sequence(),
|
||||
node_name=node_name,
|
||||
phase=self._current_phase,
|
||||
data={
|
||||
"node": node_name,
|
||||
"phase": self._current_phase,
|
||||
"summary": summary,
|
||||
"message": f"⏹️ 节点完成: {node_name}",
|
||||
},
|
||||
)
|
||||
|
||||
async def _handle_custom_event(self, event_name: str, data: Dict) -> StreamEvent:
|
||||
"""处理自定义事件"""
|
||||
# 映射自定义事件名到事件类型
|
||||
event_type_map = {
|
||||
"finding": StreamEventType.FINDING_NEW,
|
||||
"finding_verified": StreamEventType.FINDING_VERIFIED,
|
||||
"progress": StreamEventType.PROGRESS,
|
||||
"warning": StreamEventType.WARNING,
|
||||
"error": StreamEventType.ERROR,
|
||||
}
|
||||
|
||||
event_type = event_type_map.get(event_name, StreamEventType.INFO)
|
||||
|
||||
return StreamEvent(
|
||||
event_type=event_type,
|
||||
sequence=self._next_sequence(),
|
||||
node_name=self._current_node,
|
||||
phase=self._current_phase,
|
||||
data=data,
|
||||
)
|
||||
|
||||
def _truncate_data(self, data: Any, max_length: int = 1000) -> Any:
|
||||
"""截断数据"""
|
||||
if isinstance(data, str):
|
||||
return data[:max_length] + "..." if len(data) > max_length else data
|
||||
elif isinstance(data, dict):
|
||||
return {k: self._truncate_data(v, max_length // 2) for k, v in list(data.items())[:10]}
|
||||
elif isinstance(data, list):
|
||||
return [self._truncate_data(item, max_length // len(data)) for item in data[:10]]
|
||||
else:
|
||||
return str(data)[:max_length]
|
||||
|
||||
def create_progress_event(
|
||||
self,
|
||||
current: int,
|
||||
total: int,
|
||||
message: Optional[str] = None,
|
||||
) -> StreamEvent:
|
||||
"""创建进度事件"""
|
||||
percentage = (current / total * 100) if total > 0 else 0
|
||||
|
||||
return StreamEvent(
|
||||
event_type=StreamEventType.PROGRESS,
|
||||
sequence=self._next_sequence(),
|
||||
node_name=self._current_node,
|
||||
phase=self._current_phase,
|
||||
data={
|
||||
"current": current,
|
||||
"total": total,
|
||||
"percentage": round(percentage, 1),
|
||||
"message": message or f"进度: {current}/{total}",
|
||||
},
|
||||
)
|
||||
|
||||
def create_finding_event(
|
||||
self,
|
||||
finding: Dict[str, Any],
|
||||
is_verified: bool = False,
|
||||
) -> StreamEvent:
|
||||
"""创建发现事件"""
|
||||
event_type = StreamEventType.FINDING_VERIFIED if is_verified else StreamEventType.FINDING_NEW
|
||||
|
||||
return StreamEvent(
|
||||
event_type=event_type,
|
||||
sequence=self._next_sequence(),
|
||||
node_name=self._current_node,
|
||||
phase=self._current_phase,
|
||||
data={
|
||||
"title": finding.get("title", "Unknown"),
|
||||
"severity": finding.get("severity", "medium"),
|
||||
"vulnerability_type": finding.get("vulnerability_type", "other"),
|
||||
"file_path": finding.get("file_path"),
|
||||
"line_start": finding.get("line_start"),
|
||||
"is_verified": is_verified,
|
||||
"message": f"{'✅ 已验证' if is_verified else '🔍 新发现'}: [{finding.get('severity', 'medium').upper()}] {finding.get('title', 'Unknown')}",
|
||||
},
|
||||
)
|
||||
|
||||
def create_heartbeat(self) -> StreamEvent:
|
||||
"""创建心跳事件"""
|
||||
return StreamEvent(
|
||||
event_type=StreamEventType.HEARTBEAT,
|
||||
sequence=self._sequence, # 心跳不增加序列号
|
||||
data={"message": "ping"},
|
||||
)
|
||||
|
||||
|
|
@ -0,0 +1,261 @@
|
|||
"""
|
||||
LLM Token 流式输出处理器
|
||||
支持多种 LLM 提供商的流式输出
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any, Dict, Optional, AsyncGenerator, Callable
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TokenChunk:
|
||||
"""Token 块"""
|
||||
content: str
|
||||
token_count: int = 1
|
||||
finish_reason: Optional[str] = None
|
||||
model: Optional[str] = None
|
||||
|
||||
# 统计信息
|
||||
accumulated_content: str = ""
|
||||
total_tokens: int = 0
|
||||
|
||||
|
||||
class TokenStreamer:
|
||||
"""
|
||||
LLM Token 流式输出处理器
|
||||
|
||||
最佳实践:
|
||||
1. 使用 LiteLLM 的流式 API
|
||||
2. 实时发送每个 Token
|
||||
3. 跟踪累积内容和 Token 使用
|
||||
4. 支持中断和超时
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
api_key: Optional[str] = None,
|
||||
base_url: Optional[str] = None,
|
||||
on_token: Optional[Callable[[TokenChunk], None]] = None,
|
||||
):
|
||||
self.model = model
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url
|
||||
self.on_token = on_token
|
||||
|
||||
self._cancelled = False
|
||||
self._accumulated_content = ""
|
||||
self._total_tokens = 0
|
||||
|
||||
def cancel(self):
|
||||
"""取消流式输出"""
|
||||
self._cancelled = True
|
||||
|
||||
async def stream_completion(
|
||||
self,
|
||||
messages: list[Dict[str, str]],
|
||||
temperature: float = 0.1,
|
||||
max_tokens: int = 4096,
|
||||
) -> AsyncGenerator[TokenChunk, None]:
|
||||
"""
|
||||
流式调用 LLM
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
temperature: 温度
|
||||
max_tokens: 最大 Token 数
|
||||
|
||||
Yields:
|
||||
TokenChunk: Token 块
|
||||
"""
|
||||
try:
|
||||
import litellm
|
||||
|
||||
response = await litellm.acompletion(
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
api_key=self.api_key,
|
||||
base_url=self.base_url,
|
||||
stream=True, # 启用流式输出
|
||||
)
|
||||
|
||||
async for chunk in response:
|
||||
if self._cancelled:
|
||||
break
|
||||
|
||||
# 提取内容
|
||||
content = ""
|
||||
finish_reason = None
|
||||
|
||||
if hasattr(chunk, "choices") and chunk.choices:
|
||||
choice = chunk.choices[0]
|
||||
if hasattr(choice, "delta") and choice.delta:
|
||||
content = getattr(choice.delta, "content", "") or ""
|
||||
finish_reason = getattr(choice, "finish_reason", None)
|
||||
|
||||
if content:
|
||||
self._accumulated_content += content
|
||||
self._total_tokens += 1
|
||||
|
||||
token_chunk = TokenChunk(
|
||||
content=content,
|
||||
token_count=1,
|
||||
finish_reason=finish_reason,
|
||||
model=self.model,
|
||||
accumulated_content=self._accumulated_content,
|
||||
total_tokens=self._total_tokens,
|
||||
)
|
||||
|
||||
# 回调
|
||||
if self.on_token:
|
||||
self.on_token(token_chunk)
|
||||
|
||||
yield token_chunk
|
||||
|
||||
# 检查是否完成
|
||||
if finish_reason:
|
||||
break
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Token streaming cancelled")
|
||||
raise
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Token streaming error: {e}")
|
||||
raise
|
||||
|
||||
async def stream_with_tools(
|
||||
self,
|
||||
messages: list[Dict[str, str]],
|
||||
tools: list[Dict[str, Any]],
|
||||
temperature: float = 0.1,
|
||||
max_tokens: int = 4096,
|
||||
) -> AsyncGenerator[Dict[str, Any], None]:
|
||||
"""
|
||||
带工具调用的流式输出
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
tools: 工具定义列表
|
||||
temperature: 温度
|
||||
max_tokens: 最大 Token 数
|
||||
|
||||
Yields:
|
||||
包含 token 或 tool_call 的字典
|
||||
"""
|
||||
try:
|
||||
import litellm
|
||||
|
||||
response = await litellm.acompletion(
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
api_key=self.api_key,
|
||||
base_url=self.base_url,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
# 工具调用累积器
|
||||
tool_calls_accumulator: Dict[int, Dict] = {}
|
||||
|
||||
async for chunk in response:
|
||||
if self._cancelled:
|
||||
break
|
||||
|
||||
if not hasattr(chunk, "choices") or not chunk.choices:
|
||||
continue
|
||||
|
||||
choice = chunk.choices[0]
|
||||
delta = getattr(choice, "delta", None)
|
||||
finish_reason = getattr(choice, "finish_reason", None)
|
||||
|
||||
if delta:
|
||||
# 处理文本内容
|
||||
content = getattr(delta, "content", "") or ""
|
||||
if content:
|
||||
self._accumulated_content += content
|
||||
self._total_tokens += 1
|
||||
|
||||
yield {
|
||||
"type": "token",
|
||||
"content": content,
|
||||
"accumulated": self._accumulated_content,
|
||||
"total_tokens": self._total_tokens,
|
||||
}
|
||||
|
||||
# 处理工具调用
|
||||
tool_calls = getattr(delta, "tool_calls", None) or []
|
||||
for tool_call in tool_calls:
|
||||
idx = tool_call.index
|
||||
|
||||
if idx not in tool_calls_accumulator:
|
||||
tool_calls_accumulator[idx] = {
|
||||
"id": tool_call.id or "",
|
||||
"name": "",
|
||||
"arguments": "",
|
||||
}
|
||||
|
||||
if tool_call.function:
|
||||
if tool_call.function.name:
|
||||
tool_calls_accumulator[idx]["name"] = tool_call.function.name
|
||||
if tool_call.function.arguments:
|
||||
tool_calls_accumulator[idx]["arguments"] += tool_call.function.arguments
|
||||
|
||||
yield {
|
||||
"type": "tool_call_chunk",
|
||||
"index": idx,
|
||||
"tool_call": tool_calls_accumulator[idx],
|
||||
}
|
||||
|
||||
# 完成时发送最终工具调用
|
||||
if finish_reason == "tool_calls":
|
||||
for idx, tool_call in tool_calls_accumulator.items():
|
||||
yield {
|
||||
"type": "tool_call_complete",
|
||||
"index": idx,
|
||||
"tool_call": tool_call,
|
||||
}
|
||||
|
||||
if finish_reason:
|
||||
yield {
|
||||
"type": "finish",
|
||||
"reason": finish_reason,
|
||||
"accumulated": self._accumulated_content,
|
||||
"total_tokens": self._total_tokens,
|
||||
}
|
||||
break
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Tool streaming cancelled")
|
||||
raise
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Tool streaming error: {e}")
|
||||
yield {
|
||||
"type": "error",
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
def get_accumulated_content(self) -> str:
|
||||
"""获取累积内容"""
|
||||
return self._accumulated_content
|
||||
|
||||
def get_total_tokens(self) -> int:
|
||||
"""获取总 Token 数"""
|
||||
return self._total_tokens
|
||||
|
||||
def reset(self):
|
||||
"""重置状态"""
|
||||
self._cancelled = False
|
||||
self._accumulated_content = ""
|
||||
self._total_tokens = 0
|
||||
|
||||
|
|
@ -0,0 +1,319 @@
|
|||
"""
|
||||
工具调用流式处理器
|
||||
展示工具调用的输入、执行过程和输出
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
import logging
|
||||
from typing import Any, Dict, Optional, AsyncGenerator, List, Callable
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ToolCallState(str, Enum):
|
||||
"""工具调用状态"""
|
||||
PENDING = "pending" # 等待执行
|
||||
RUNNING = "running" # 执行中
|
||||
SUCCESS = "success" # 成功
|
||||
ERROR = "error" # 错误
|
||||
TIMEOUT = "timeout" # 超时
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolCallEvent:
|
||||
"""工具调用事件"""
|
||||
tool_name: str
|
||||
state: ToolCallState
|
||||
|
||||
# 输入输出
|
||||
input_params: Dict[str, Any] = field(default_factory=dict)
|
||||
output_data: Optional[Any] = None
|
||||
error_message: Optional[str] = None
|
||||
|
||||
# 时间
|
||||
start_time: Optional[float] = None
|
||||
end_time: Optional[float] = None
|
||||
duration_ms: int = 0
|
||||
|
||||
# 元数据
|
||||
call_id: Optional[str] = None
|
||||
sequence: int = 0
|
||||
timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""转换为字典"""
|
||||
return {
|
||||
"tool_name": self.tool_name,
|
||||
"state": self.state.value,
|
||||
"input_params": self._truncate(self.input_params),
|
||||
"output_data": self._truncate(self.output_data),
|
||||
"error_message": self.error_message,
|
||||
"duration_ms": self.duration_ms,
|
||||
"call_id": self.call_id,
|
||||
"sequence": self.sequence,
|
||||
"timestamp": self.timestamp,
|
||||
}
|
||||
|
||||
def _truncate(self, data: Any, max_length: int = 500) -> Any:
|
||||
"""截断数据"""
|
||||
if data is None:
|
||||
return None
|
||||
if isinstance(data, str):
|
||||
return data[:max_length] + "..." if len(data) > max_length else data
|
||||
elif isinstance(data, dict):
|
||||
return {k: self._truncate(v, max_length // 2) for k, v in list(data.items())[:20]}
|
||||
elif isinstance(data, list):
|
||||
max_items = min(20, len(data))
|
||||
return [self._truncate(item, max_length // max_items) for item in data[:max_items]]
|
||||
else:
|
||||
s = str(data)
|
||||
return s[:max_length] + "..." if len(s) > max_length else s
|
||||
|
||||
|
||||
class ToolStreamHandler:
|
||||
"""
|
||||
工具调用流式处理器
|
||||
|
||||
功能:
|
||||
1. 跟踪工具调用状态
|
||||
2. 记录输入参数
|
||||
3. 流式输出执行过程
|
||||
4. 记录输出和执行时间
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
on_event: Optional[Callable[[ToolCallEvent], None]] = None,
|
||||
):
|
||||
self.on_event = on_event
|
||||
self._sequence = 0
|
||||
self._active_calls: Dict[str, ToolCallEvent] = {}
|
||||
self._history: List[ToolCallEvent] = []
|
||||
|
||||
def _next_sequence(self) -> int:
|
||||
"""获取下一个序列号"""
|
||||
self._sequence += 1
|
||||
return self._sequence
|
||||
|
||||
def _generate_call_id(self) -> str:
|
||||
"""生成调用 ID"""
|
||||
import uuid
|
||||
return str(uuid.uuid4())[:8]
|
||||
|
||||
async def emit_tool_start(
|
||||
self,
|
||||
tool_name: str,
|
||||
input_params: Dict[str, Any],
|
||||
call_id: Optional[str] = None,
|
||||
) -> ToolCallEvent:
|
||||
"""
|
||||
发射工具开始事件
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称
|
||||
input_params: 输入参数
|
||||
call_id: 调用 ID
|
||||
|
||||
Returns:
|
||||
工具调用事件
|
||||
"""
|
||||
call_id = call_id or self._generate_call_id()
|
||||
|
||||
event = ToolCallEvent(
|
||||
tool_name=tool_name,
|
||||
state=ToolCallState.RUNNING,
|
||||
input_params=input_params,
|
||||
start_time=time.time(),
|
||||
call_id=call_id,
|
||||
sequence=self._next_sequence(),
|
||||
)
|
||||
|
||||
self._active_calls[call_id] = event
|
||||
|
||||
if self.on_event:
|
||||
self.on_event(event)
|
||||
|
||||
return event
|
||||
|
||||
async def emit_tool_end(
|
||||
self,
|
||||
call_id: str,
|
||||
output_data: Any,
|
||||
is_error: bool = False,
|
||||
error_message: Optional[str] = None,
|
||||
) -> ToolCallEvent:
|
||||
"""
|
||||
发射工具结束事件
|
||||
|
||||
Args:
|
||||
call_id: 调用 ID
|
||||
output_data: 输出数据
|
||||
is_error: 是否错误
|
||||
error_message: 错误消息
|
||||
|
||||
Returns:
|
||||
工具调用事件
|
||||
"""
|
||||
if call_id not in self._active_calls:
|
||||
logger.warning(f"Unknown tool call: {call_id}")
|
||||
return None
|
||||
|
||||
event = self._active_calls[call_id]
|
||||
event.end_time = time.time()
|
||||
event.duration_ms = int((event.end_time - event.start_time) * 1000) if event.start_time else 0
|
||||
event.output_data = output_data
|
||||
event.sequence = self._next_sequence()
|
||||
|
||||
if is_error:
|
||||
event.state = ToolCallState.ERROR
|
||||
event.error_message = error_message or str(output_data)
|
||||
else:
|
||||
event.state = ToolCallState.SUCCESS
|
||||
|
||||
# 移动到历史记录
|
||||
del self._active_calls[call_id]
|
||||
self._history.append(event)
|
||||
|
||||
if self.on_event:
|
||||
self.on_event(event)
|
||||
|
||||
return event
|
||||
|
||||
async def emit_tool_timeout(self, call_id: str, timeout_seconds: int) -> ToolCallEvent:
|
||||
"""发射工具超时事件"""
|
||||
if call_id not in self._active_calls:
|
||||
return None
|
||||
|
||||
event = self._active_calls[call_id]
|
||||
event.end_time = time.time()
|
||||
event.duration_ms = int((event.end_time - event.start_time) * 1000) if event.start_time else 0
|
||||
event.state = ToolCallState.TIMEOUT
|
||||
event.error_message = f"Tool execution timed out after {timeout_seconds}s"
|
||||
event.sequence = self._next_sequence()
|
||||
|
||||
del self._active_calls[call_id]
|
||||
self._history.append(event)
|
||||
|
||||
if self.on_event:
|
||||
self.on_event(event)
|
||||
|
||||
return event
|
||||
|
||||
def wrap_tool(
|
||||
self,
|
||||
tool_func: Callable,
|
||||
tool_name: str,
|
||||
timeout: Optional[int] = None,
|
||||
) -> Callable:
|
||||
"""
|
||||
包装工具函数以自动跟踪
|
||||
|
||||
Args:
|
||||
tool_func: 工具函数
|
||||
tool_name: 工具名称
|
||||
timeout: 超时时间(秒)
|
||||
|
||||
Returns:
|
||||
包装后的函数
|
||||
"""
|
||||
async def wrapped(*args, **kwargs):
|
||||
call_id = self._generate_call_id()
|
||||
|
||||
# 发射开始事件
|
||||
await self.emit_tool_start(
|
||||
tool_name=tool_name,
|
||||
input_params={"args": args, "kwargs": kwargs},
|
||||
call_id=call_id,
|
||||
)
|
||||
|
||||
try:
|
||||
# 执行工具
|
||||
if asyncio.iscoroutinefunction(tool_func):
|
||||
if timeout:
|
||||
result = await asyncio.wait_for(
|
||||
tool_func(*args, **kwargs),
|
||||
timeout=timeout,
|
||||
)
|
||||
else:
|
||||
result = await tool_func(*args, **kwargs)
|
||||
else:
|
||||
if timeout:
|
||||
result = await asyncio.wait_for(
|
||||
asyncio.to_thread(tool_func, *args, **kwargs),
|
||||
timeout=timeout,
|
||||
)
|
||||
else:
|
||||
result = tool_func(*args, **kwargs)
|
||||
|
||||
# 发射结束事件
|
||||
await self.emit_tool_end(call_id, result)
|
||||
|
||||
return result
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
await self.emit_tool_timeout(call_id, timeout or 0)
|
||||
raise
|
||||
|
||||
except Exception as e:
|
||||
await self.emit_tool_end(call_id, None, is_error=True, error_message=str(e))
|
||||
raise
|
||||
|
||||
return wrapped
|
||||
|
||||
def get_active_calls(self) -> List[ToolCallEvent]:
|
||||
"""获取活跃的调用"""
|
||||
return list(self._active_calls.values())
|
||||
|
||||
def get_history(self, limit: int = 100) -> List[ToolCallEvent]:
|
||||
"""获取历史记录"""
|
||||
return self._history[-limit:]
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""获取统计信息"""
|
||||
total_calls = len(self._history)
|
||||
success_calls = sum(1 for e in self._history if e.state == ToolCallState.SUCCESS)
|
||||
error_calls = sum(1 for e in self._history if e.state == ToolCallState.ERROR)
|
||||
timeout_calls = sum(1 for e in self._history if e.state == ToolCallState.TIMEOUT)
|
||||
|
||||
total_duration = sum(e.duration_ms for e in self._history)
|
||||
avg_duration = total_duration / total_calls if total_calls > 0 else 0
|
||||
|
||||
# 按工具统计
|
||||
tool_stats = {}
|
||||
for event in self._history:
|
||||
if event.tool_name not in tool_stats:
|
||||
tool_stats[event.tool_name] = {
|
||||
"calls": 0,
|
||||
"success": 0,
|
||||
"errors": 0,
|
||||
"total_duration_ms": 0,
|
||||
}
|
||||
tool_stats[event.tool_name]["calls"] += 1
|
||||
if event.state == ToolCallState.SUCCESS:
|
||||
tool_stats[event.tool_name]["success"] += 1
|
||||
elif event.state in [ToolCallState.ERROR, ToolCallState.TIMEOUT]:
|
||||
tool_stats[event.tool_name]["errors"] += 1
|
||||
tool_stats[event.tool_name]["total_duration_ms"] += event.duration_ms
|
||||
|
||||
return {
|
||||
"total_calls": total_calls,
|
||||
"success_calls": success_calls,
|
||||
"error_calls": error_calls,
|
||||
"timeout_calls": timeout_calls,
|
||||
"success_rate": success_calls / total_calls if total_calls > 0 else 0,
|
||||
"total_duration_ms": total_duration,
|
||||
"avg_duration_ms": round(avg_duration, 2),
|
||||
"active_calls": len(self._active_calls),
|
||||
"by_tool": tool_stats,
|
||||
}
|
||||
|
||||
def clear(self):
|
||||
"""清空记录"""
|
||||
self._active_calls.clear()
|
||||
self._history.clear()
|
||||
self._sequence = 0
|
||||
|
||||
|
|
@ -41,6 +41,16 @@ class PatternMatchTool(AgentTool):
|
|||
使用正则表达式快速扫描代码中的危险模式
|
||||
"""
|
||||
|
||||
def __init__(self, project_root: str = None):
|
||||
"""
|
||||
初始化模式匹配工具
|
||||
|
||||
Args:
|
||||
project_root: 项目根目录(可选,用于上下文)
|
||||
"""
|
||||
super().__init__()
|
||||
self.project_root = project_root
|
||||
|
||||
# 危险模式定义
|
||||
PATTERNS: Dict[str, Dict[str, Any]] = {
|
||||
# SQL 注入模式
|
||||
|
|
|
|||
|
|
@ -145,3 +145,5 @@ class BaiduAdapter(BaseLLMAdapter):
|
|||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -83,3 +83,5 @@ class DoubaoAdapter(BaseLLMAdapter):
|
|||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -86,3 +86,5 @@ class MinimaxAdapter(BaseLLMAdapter):
|
|||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -134,3 +134,5 @@ class BaseLLMAdapter(ABC):
|
|||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -120,3 +120,5 @@ DEFAULT_BASE_URLS: Dict[LLMProvider, str] = {
|
|||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,5 @@
|
|||
"""
|
||||
DeepAudit Agent 测试套件
|
||||
企业级测试框架,覆盖工具、Agent、节点和完整流程
|
||||
"""
|
||||
|
||||
|
|
@ -0,0 +1,295 @@
|
|||
"""
|
||||
Agent 测试配置和 Fixtures
|
||||
提供测试所需的公共设施
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
import tempfile
|
||||
import shutil
|
||||
import os
|
||||
from typing import Dict, Any, Optional
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
# ============ 测试配置 ============
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def event_loop():
|
||||
"""创建事件循环"""
|
||||
loop = asyncio.new_event_loop()
|
||||
yield loop
|
||||
loop.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_project_dir():
|
||||
"""创建临时项目目录,包含测试代码"""
|
||||
temp_dir = tempfile.mkdtemp(prefix="deepaudit_test_")
|
||||
|
||||
# 创建测试项目结构
|
||||
os.makedirs(os.path.join(temp_dir, "src"), exist_ok=True)
|
||||
os.makedirs(os.path.join(temp_dir, "config"), exist_ok=True)
|
||||
|
||||
# 创建有漏洞的测试代码 - SQL 注入
|
||||
sql_vuln_code = '''
|
||||
import sqlite3
|
||||
|
||||
def get_user(user_id):
|
||||
"""危险:SQL 注入漏洞"""
|
||||
conn = sqlite3.connect("users.db")
|
||||
cursor = conn.cursor()
|
||||
# 直接拼接用户输入,存在 SQL 注入风险
|
||||
query = f"SELECT * FROM users WHERE id = '{user_id}'"
|
||||
cursor.execute(query)
|
||||
return cursor.fetchone()
|
||||
|
||||
def search_users(name):
|
||||
"""危险:SQL 注入漏洞"""
|
||||
conn = sqlite3.connect("users.db")
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT * FROM users WHERE name LIKE '%" + name + "%'")
|
||||
return cursor.fetchall()
|
||||
'''
|
||||
|
||||
# 创建有漏洞的测试代码 - 命令注入
|
||||
cmd_vuln_code = '''
|
||||
import os
|
||||
import subprocess
|
||||
|
||||
def run_command(user_input):
|
||||
"""危险:命令注入漏洞"""
|
||||
# 直接执行用户输入
|
||||
os.system(f"echo {user_input}")
|
||||
|
||||
def execute_script(script_name):
|
||||
"""危险:命令注入漏洞"""
|
||||
subprocess.call(f"bash {script_name}", shell=True)
|
||||
'''
|
||||
|
||||
# 创建有漏洞的测试代码 - XSS
|
||||
xss_vuln_code = '''
|
||||
from flask import Flask, request, render_template_string
|
||||
|
||||
app = Flask(__name__)
|
||||
|
||||
@app.route("/greet")
|
||||
def greet():
|
||||
"""危险:XSS 漏洞"""
|
||||
name = request.args.get("name", "")
|
||||
# 直接将用户输入嵌入 HTML,存在 XSS 风险
|
||||
return f"<h1>Hello, {name}!</h1>"
|
||||
|
||||
@app.route("/search")
|
||||
def search():
|
||||
"""危险:XSS 漏洞"""
|
||||
query = request.args.get("q", "")
|
||||
html = f"<p>搜索结果: {query}</p>"
|
||||
return render_template_string(html)
|
||||
'''
|
||||
|
||||
# 创建有漏洞的测试代码 - 路径遍历
|
||||
path_vuln_code = '''
|
||||
import os
|
||||
|
||||
def read_file(filename):
|
||||
"""危险:路径遍历漏洞"""
|
||||
# 没有验证文件路径
|
||||
filepath = os.path.join("/app/data", filename)
|
||||
with open(filepath, "r") as f:
|
||||
return f.read()
|
||||
|
||||
def download_file(user_path):
|
||||
"""危险:路径遍历漏洞"""
|
||||
# 直接使用用户输入作为文件路径
|
||||
with open(user_path, "rb") as f:
|
||||
return f.read()
|
||||
'''
|
||||
|
||||
# 创建有漏洞的测试代码 - 硬编码密钥
|
||||
secret_vuln_code = '''
|
||||
# 配置文件
|
||||
DATABASE_URL = "postgresql://user:password123@localhost/db"
|
||||
API_KEY = "sk-1234567890abcdef1234567890abcdef"
|
||||
SECRET_KEY = "super_secret_key_dont_share"
|
||||
AWS_SECRET_ACCESS_KEY = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"
|
||||
|
||||
def connect_database():
|
||||
password = "admin123" # 硬编码密码
|
||||
return f"mysql://root:{password}@localhost/mydb"
|
||||
'''
|
||||
|
||||
# 创建安全的代码(用于对比)
|
||||
safe_code = '''
|
||||
import sqlite3
|
||||
from typing import Optional
|
||||
|
||||
def get_user_safe(user_id: int) -> Optional[dict]:
|
||||
"""安全:使用参数化查询"""
|
||||
conn = sqlite3.connect("users.db")
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT * FROM users WHERE id = ?", (user_id,))
|
||||
return cursor.fetchone()
|
||||
|
||||
def validate_input(user_input: str) -> str:
|
||||
"""输入验证"""
|
||||
import re
|
||||
if not re.match(r'^[a-zA-Z0-9_]+$', user_input):
|
||||
raise ValueError("Invalid input")
|
||||
return user_input
|
||||
'''
|
||||
|
||||
# 创建配置文件
|
||||
config_code = '''
|
||||
import os
|
||||
|
||||
class Config:
|
||||
"""安全配置"""
|
||||
DATABASE_URL = os.environ.get("DATABASE_URL")
|
||||
SECRET_KEY = os.environ.get("SECRET_KEY")
|
||||
DEBUG = False
|
||||
'''
|
||||
|
||||
# 创建 requirements.txt
|
||||
requirements = '''
|
||||
flask>=2.0.0
|
||||
sqlalchemy>=2.0.0
|
||||
requests>=2.28.0
|
||||
'''
|
||||
|
||||
# 写入文件
|
||||
with open(os.path.join(temp_dir, "src", "sql_vuln.py"), "w") as f:
|
||||
f.write(sql_vuln_code)
|
||||
|
||||
with open(os.path.join(temp_dir, "src", "cmd_vuln.py"), "w") as f:
|
||||
f.write(cmd_vuln_code)
|
||||
|
||||
with open(os.path.join(temp_dir, "src", "xss_vuln.py"), "w") as f:
|
||||
f.write(xss_vuln_code)
|
||||
|
||||
with open(os.path.join(temp_dir, "src", "path_vuln.py"), "w") as f:
|
||||
f.write(path_vuln_code)
|
||||
|
||||
with open(os.path.join(temp_dir, "src", "secrets.py"), "w") as f:
|
||||
f.write(secret_vuln_code)
|
||||
|
||||
with open(os.path.join(temp_dir, "src", "safe_code.py"), "w") as f:
|
||||
f.write(safe_code)
|
||||
|
||||
with open(os.path.join(temp_dir, "config", "settings.py"), "w") as f:
|
||||
f.write(config_code)
|
||||
|
||||
with open(os.path.join(temp_dir, "requirements.txt"), "w") as f:
|
||||
f.write(requirements)
|
||||
|
||||
yield temp_dir
|
||||
|
||||
# 清理
|
||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_llm_service():
|
||||
"""模拟 LLM 服务"""
|
||||
service = MagicMock()
|
||||
service.chat_completion_raw = AsyncMock(return_value={
|
||||
"content": "测试响应",
|
||||
"usage": {"total_tokens": 100},
|
||||
})
|
||||
return service
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_event_emitter():
|
||||
"""模拟事件发射器"""
|
||||
emitter = MagicMock()
|
||||
emitter.emit_info = AsyncMock()
|
||||
emitter.emit_warning = AsyncMock()
|
||||
emitter.emit_error = AsyncMock()
|
||||
emitter.emit_thinking = AsyncMock()
|
||||
emitter.emit_tool_call = AsyncMock()
|
||||
emitter.emit_tool_result = AsyncMock()
|
||||
emitter.emit_finding = AsyncMock()
|
||||
emitter.emit_progress = AsyncMock()
|
||||
emitter.emit_phase_start = AsyncMock()
|
||||
emitter.emit_phase_complete = AsyncMock()
|
||||
emitter.emit_task_complete = AsyncMock()
|
||||
emitter.emit = AsyncMock()
|
||||
return emitter
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_session():
|
||||
"""模拟数据库会话"""
|
||||
session = AsyncMock()
|
||||
session.add = MagicMock()
|
||||
session.commit = AsyncMock()
|
||||
session.rollback = AsyncMock()
|
||||
session.get = AsyncMock(return_value=None)
|
||||
session.execute = AsyncMock()
|
||||
return session
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockProject:
|
||||
"""模拟项目"""
|
||||
id: str = "test-project-id"
|
||||
name: str = "Test Project"
|
||||
description: str = "Test project for unit tests"
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockAgentTask:
|
||||
"""模拟 Agent 任务"""
|
||||
id: str = "test-task-id"
|
||||
project_id: str = "test-project-id"
|
||||
project: MockProject = None
|
||||
name: str = "Test Agent Task"
|
||||
status: str = "pending"
|
||||
current_phase: str = "planning"
|
||||
target_vulnerabilities: list = None
|
||||
verification_level: str = "sandbox"
|
||||
exclude_patterns: list = None
|
||||
target_files: list = None
|
||||
max_iterations: int = 50
|
||||
timeout_seconds: int = 1800
|
||||
|
||||
def __post_init__(self):
|
||||
if self.project is None:
|
||||
self.project = MockProject()
|
||||
if self.target_vulnerabilities is None:
|
||||
self.target_vulnerabilities = []
|
||||
if self.exclude_patterns is None:
|
||||
self.exclude_patterns = []
|
||||
if self.target_files is None:
|
||||
self.target_files = []
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_task():
|
||||
"""创建模拟任务"""
|
||||
return MockAgentTask()
|
||||
|
||||
|
||||
# ============ 测试辅助函数 ============
|
||||
|
||||
def assert_finding_valid(finding: Dict[str, Any]):
|
||||
"""验证漏洞发现的格式"""
|
||||
required_fields = ["title", "severity", "vulnerability_type"]
|
||||
for field in required_fields:
|
||||
assert field in finding, f"Missing required field: {field}"
|
||||
|
||||
valid_severities = ["critical", "high", "medium", "low", "info"]
|
||||
assert finding["severity"] in valid_severities, f"Invalid severity: {finding['severity']}"
|
||||
|
||||
|
||||
def count_findings_by_type(findings: list, vuln_type: str) -> int:
|
||||
"""统计特定类型的漏洞数量"""
|
||||
return sum(1 for f in findings if f.get("vulnerability_type") == vuln_type)
|
||||
|
||||
|
||||
def count_findings_by_severity(findings: list, severity: str) -> int:
|
||||
"""统计特定严重程度的漏洞数量"""
|
||||
return sum(1 for f in findings if f.get("severity") == severity)
|
||||
|
||||
|
|
@ -0,0 +1,96 @@
|
|||
#!/usr/bin/env python
|
||||
"""
|
||||
Agent 测试运行器
|
||||
运行所有 Agent 相关测试并生成报告
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
import sys
|
||||
import os
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def run_tests():
|
||||
"""运行测试"""
|
||||
# 获取项目根目录
|
||||
project_root = Path(__file__).parent.parent.parent
|
||||
os.chdir(project_root)
|
||||
|
||||
print("=" * 60)
|
||||
print("DeepAudit Agent 测试套件")
|
||||
print(f"时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
print("=" * 60)
|
||||
print()
|
||||
|
||||
# 测试命令
|
||||
cmd = [
|
||||
sys.executable, "-m", "pytest",
|
||||
"tests/agent/",
|
||||
"-v",
|
||||
"--tb=short",
|
||||
"-x", # 遇到第一个失败就停止
|
||||
"--color=yes",
|
||||
]
|
||||
|
||||
print(f"运行命令: {' '.join(cmd)}")
|
||||
print()
|
||||
|
||||
# 运行测试
|
||||
result = subprocess.run(cmd, cwd=project_root)
|
||||
|
||||
print()
|
||||
print("=" * 60)
|
||||
if result.returncode == 0:
|
||||
print("✅ 所有测试通过!")
|
||||
else:
|
||||
print(f"❌ 测试失败 (退出码: {result.returncode})")
|
||||
print("=" * 60)
|
||||
|
||||
return result.returncode
|
||||
|
||||
|
||||
def run_tests_with_coverage():
|
||||
"""运行测试并生成覆盖率报告"""
|
||||
project_root = Path(__file__).parent.parent.parent
|
||||
os.chdir(project_root)
|
||||
|
||||
cmd = [
|
||||
sys.executable, "-m", "pytest",
|
||||
"tests/agent/",
|
||||
"-v",
|
||||
"--cov=app/services/agent",
|
||||
"--cov-report=term-missing",
|
||||
"--cov-report=html:coverage_agent",
|
||||
]
|
||||
|
||||
result = subprocess.run(cmd, cwd=project_root)
|
||||
return result.returncode
|
||||
|
||||
|
||||
def run_specific_test(test_name: str):
|
||||
"""运行特定测试"""
|
||||
project_root = Path(__file__).parent.parent.parent
|
||||
os.chdir(project_root)
|
||||
|
||||
cmd = [
|
||||
sys.executable, "-m", "pytest",
|
||||
f"tests/agent/{test_name}",
|
||||
"-v",
|
||||
"--tb=long",
|
||||
"-s", # 显示 print 输出
|
||||
]
|
||||
|
||||
result = subprocess.run(cmd, cwd=project_root)
|
||||
return result.returncode
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) > 1:
|
||||
if sys.argv[1] == "--coverage":
|
||||
sys.exit(run_tests_with_coverage())
|
||||
else:
|
||||
sys.exit(run_specific_test(sys.argv[1]))
|
||||
else:
|
||||
sys.exit(run_tests())
|
||||
|
||||
|
|
@ -0,0 +1,213 @@
|
|||
"""
|
||||
Agent 单元测试
|
||||
测试各个 Agent 的功能
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
import os
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
|
||||
from app.services.agent.agents.base import BaseAgent, AgentConfig, AgentResult, AgentType, AgentPattern
|
||||
from app.services.agent.agents.recon import ReconAgent
|
||||
from app.services.agent.agents.analysis import AnalysisAgent
|
||||
from app.services.agent.agents.verification import VerificationAgent
|
||||
|
||||
|
||||
class TestReconAgent:
|
||||
"""Recon Agent 测试"""
|
||||
|
||||
@pytest.fixture
|
||||
def recon_agent(self, temp_project_dir, mock_llm_service, mock_event_emitter):
|
||||
"""创建 Recon Agent 实例"""
|
||||
from app.services.agent.tools import (
|
||||
FileReadTool, FileSearchTool, ListFilesTool,
|
||||
)
|
||||
|
||||
tools = {
|
||||
"list_files": ListFilesTool(temp_project_dir),
|
||||
"read_file": FileReadTool(temp_project_dir),
|
||||
"search_code": FileSearchTool(temp_project_dir),
|
||||
}
|
||||
|
||||
return ReconAgent(
|
||||
llm_service=mock_llm_service,
|
||||
tools=tools,
|
||||
event_emitter=mock_event_emitter,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_recon_agent_run(self, recon_agent, temp_project_dir):
|
||||
"""测试 Recon Agent 运行"""
|
||||
result = await recon_agent.run({
|
||||
"project_info": {
|
||||
"name": "Test Project",
|
||||
"root": temp_project_dir,
|
||||
},
|
||||
"config": {},
|
||||
})
|
||||
|
||||
assert result.success is True
|
||||
assert result.data is not None
|
||||
|
||||
# 验证返回数据结构
|
||||
data = result.data
|
||||
assert "tech_stack" in data
|
||||
assert "entry_points" in data or "high_risk_areas" in data
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_recon_agent_identifies_python(self, recon_agent, temp_project_dir):
|
||||
"""测试 Recon Agent 识别 Python 技术栈"""
|
||||
result = await recon_agent.run({
|
||||
"project_info": {"root": temp_project_dir},
|
||||
"config": {},
|
||||
})
|
||||
|
||||
assert result.success is True
|
||||
tech_stack = result.data.get("tech_stack", {})
|
||||
languages = tech_stack.get("languages", [])
|
||||
|
||||
# 应该识别出 Python
|
||||
assert "Python" in languages or len(languages) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_recon_agent_finds_high_risk_areas(self, recon_agent, temp_project_dir):
|
||||
"""测试 Recon Agent 发现高风险区域"""
|
||||
result = await recon_agent.run({
|
||||
"project_info": {"root": temp_project_dir},
|
||||
"config": {},
|
||||
})
|
||||
|
||||
assert result.success is True
|
||||
high_risk_areas = result.data.get("high_risk_areas", [])
|
||||
|
||||
# 应该发现高风险区域
|
||||
assert len(high_risk_areas) > 0
|
||||
|
||||
|
||||
class TestAnalysisAgent:
|
||||
"""Analysis Agent 测试"""
|
||||
|
||||
@pytest.fixture
|
||||
def analysis_agent(self, temp_project_dir, mock_llm_service, mock_event_emitter):
|
||||
"""创建 Analysis Agent 实例"""
|
||||
from app.services.agent.tools import (
|
||||
FileReadTool, FileSearchTool, PatternMatchTool,
|
||||
)
|
||||
|
||||
tools = {
|
||||
"read_file": FileReadTool(temp_project_dir),
|
||||
"search_code": FileSearchTool(temp_project_dir),
|
||||
"pattern_match": PatternMatchTool(temp_project_dir),
|
||||
}
|
||||
|
||||
return AnalysisAgent(
|
||||
llm_service=mock_llm_service,
|
||||
tools=tools,
|
||||
event_emitter=mock_event_emitter,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_analysis_agent_run(self, analysis_agent, temp_project_dir):
|
||||
"""测试 Analysis Agent 运行"""
|
||||
result = await analysis_agent.run({
|
||||
"tech_stack": {"languages": ["Python"]},
|
||||
"entry_points": [],
|
||||
"high_risk_areas": ["src/sql_vuln.py", "src/cmd_vuln.py"],
|
||||
"config": {},
|
||||
})
|
||||
|
||||
assert result.success is True
|
||||
assert result.data is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_analysis_agent_finds_vulnerabilities(self, analysis_agent, temp_project_dir):
|
||||
"""测试 Analysis Agent 发现漏洞"""
|
||||
result = await analysis_agent.run({
|
||||
"tech_stack": {"languages": ["Python"]},
|
||||
"entry_points": [],
|
||||
"high_risk_areas": [
|
||||
"src/sql_vuln.py",
|
||||
"src/cmd_vuln.py",
|
||||
"src/xss_vuln.py",
|
||||
"src/secrets.py",
|
||||
],
|
||||
"config": {},
|
||||
})
|
||||
|
||||
assert result.success is True
|
||||
findings = result.data.get("findings", [])
|
||||
|
||||
# 应该发现一些漏洞
|
||||
# 注意:具体数量取决于分析逻辑
|
||||
assert isinstance(findings, list)
|
||||
|
||||
|
||||
class TestAgentResult:
|
||||
"""Agent 结果测试"""
|
||||
|
||||
def test_agent_result_success(self):
|
||||
"""测试成功的 Agent 结果"""
|
||||
result = AgentResult(
|
||||
success=True,
|
||||
data={"findings": []},
|
||||
iterations=5,
|
||||
tool_calls=10,
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.iterations == 5
|
||||
assert result.tool_calls == 10
|
||||
|
||||
def test_agent_result_failure(self):
|
||||
"""测试失败的 Agent 结果"""
|
||||
result = AgentResult(
|
||||
success=False,
|
||||
error="Test error",
|
||||
)
|
||||
|
||||
assert result.success is False
|
||||
assert result.error == "Test error"
|
||||
|
||||
def test_agent_result_to_dict(self):
|
||||
"""测试 Agent 结果转字典"""
|
||||
result = AgentResult(
|
||||
success=True,
|
||||
data={"key": "value"},
|
||||
iterations=3,
|
||||
)
|
||||
|
||||
d = result.to_dict()
|
||||
|
||||
assert d["success"] is True
|
||||
assert d["iterations"] == 3
|
||||
|
||||
|
||||
class TestAgentConfig:
|
||||
"""Agent 配置测试"""
|
||||
|
||||
def test_agent_config_defaults(self):
|
||||
"""测试 Agent 配置默认值"""
|
||||
config = AgentConfig(
|
||||
name="Test",
|
||||
agent_type=AgentType.RECON,
|
||||
)
|
||||
|
||||
assert config.pattern == AgentPattern.REACT
|
||||
assert config.max_iterations == 20
|
||||
assert config.temperature == 0.1
|
||||
|
||||
def test_agent_config_custom(self):
|
||||
"""测试自定义 Agent 配置"""
|
||||
config = AgentConfig(
|
||||
name="Custom",
|
||||
agent_type=AgentType.ANALYSIS,
|
||||
pattern=AgentPattern.PLAN_AND_EXECUTE,
|
||||
max_iterations=50,
|
||||
temperature=0.5,
|
||||
)
|
||||
|
||||
assert config.pattern == AgentPattern.PLAN_AND_EXECUTE
|
||||
assert config.max_iterations == 50
|
||||
assert config.temperature == 0.5
|
||||
|
||||
|
|
@ -0,0 +1,355 @@
|
|||
"""
|
||||
Agent 集成测试
|
||||
测试完整的审计流程
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
import os
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
from datetime import datetime
|
||||
|
||||
from app.services.agent.graph.runner import AgentRunner, LLMService
|
||||
from app.services.agent.graph.audit_graph import AuditState, create_audit_graph
|
||||
from app.services.agent.graph.nodes import ReconNode, AnalysisNode, VerificationNode, ReportNode
|
||||
from app.services.agent.event_manager import EventManager, AgentEventEmitter
|
||||
|
||||
|
||||
class TestLLMService:
|
||||
"""LLM 服务测试"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_service_initialization(self):
|
||||
"""测试 LLM 服务初始化"""
|
||||
with patch("app.core.config.settings") as mock_settings:
|
||||
mock_settings.LLM_MODEL = "gpt-4o-mini"
|
||||
mock_settings.LLM_API_KEY = "test-key"
|
||||
|
||||
service = LLMService()
|
||||
|
||||
assert service.model == "gpt-4o-mini"
|
||||
|
||||
|
||||
class TestEventManager:
|
||||
"""事件管理器测试"""
|
||||
|
||||
def test_event_manager_initialization(self):
|
||||
"""测试事件管理器初始化"""
|
||||
manager = EventManager()
|
||||
|
||||
assert manager._event_queues == {}
|
||||
assert manager._event_callbacks == {}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_emitter(self):
|
||||
"""测试事件发射器"""
|
||||
manager = EventManager()
|
||||
emitter = AgentEventEmitter("test-task-id", manager)
|
||||
|
||||
await emitter.emit_info("Test message")
|
||||
|
||||
assert emitter._sequence == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_emitter_phase_tracking(self):
|
||||
"""测试事件发射器阶段跟踪"""
|
||||
manager = EventManager()
|
||||
emitter = AgentEventEmitter("test-task-id", manager)
|
||||
|
||||
await emitter.emit_phase_start("recon", "开始信息收集")
|
||||
|
||||
assert emitter._current_phase == "recon"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_emitter_task_complete(self):
|
||||
"""测试任务完成事件"""
|
||||
manager = EventManager()
|
||||
emitter = AgentEventEmitter("test-task-id", manager)
|
||||
|
||||
await emitter.emit_task_complete(findings_count=5, duration_ms=1000)
|
||||
|
||||
assert emitter._sequence == 1
|
||||
|
||||
|
||||
class TestAuditGraph:
|
||||
"""审计图测试"""
|
||||
|
||||
def test_create_audit_graph(self, mock_event_emitter):
|
||||
"""测试创建审计图"""
|
||||
# 创建模拟节点
|
||||
recon_node = MagicMock()
|
||||
analysis_node = MagicMock()
|
||||
verification_node = MagicMock()
|
||||
report_node = MagicMock()
|
||||
|
||||
graph = create_audit_graph(
|
||||
recon_node=recon_node,
|
||||
analysis_node=analysis_node,
|
||||
verification_node=verification_node,
|
||||
report_node=report_node,
|
||||
)
|
||||
|
||||
assert graph is not None
|
||||
|
||||
|
||||
class TestReconNode:
|
||||
"""Recon 节点测试"""
|
||||
|
||||
@pytest.fixture
|
||||
def recon_node_with_mock_agent(self, mock_event_emitter):
|
||||
"""创建带模拟 Agent 的 Recon 节点"""
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.run = AsyncMock(return_value=MagicMock(
|
||||
success=True,
|
||||
data={
|
||||
"tech_stack": {"languages": ["Python"]},
|
||||
"entry_points": [{"path": "src/app.py", "type": "api"}],
|
||||
"high_risk_areas": ["src/sql_vuln.py"],
|
||||
"dependencies": {},
|
||||
"initial_findings": [],
|
||||
}
|
||||
))
|
||||
|
||||
return ReconNode(mock_agent, mock_event_emitter)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_recon_node_success(self, recon_node_with_mock_agent):
|
||||
"""测试 Recon 节点成功执行"""
|
||||
state = {
|
||||
"project_info": {"name": "Test"},
|
||||
"config": {},
|
||||
}
|
||||
|
||||
result = await recon_node_with_mock_agent(state)
|
||||
|
||||
assert "tech_stack" in result
|
||||
assert "entry_points" in result
|
||||
assert result["current_phase"] == "recon_complete"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_recon_node_failure(self, mock_event_emitter):
|
||||
"""测试 Recon 节点失败处理"""
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.run = AsyncMock(return_value=MagicMock(
|
||||
success=False,
|
||||
error="Test error",
|
||||
data=None,
|
||||
))
|
||||
|
||||
node = ReconNode(mock_agent, mock_event_emitter)
|
||||
|
||||
result = await node({
|
||||
"project_info": {},
|
||||
"config": {},
|
||||
})
|
||||
|
||||
assert "error" in result
|
||||
assert result["current_phase"] == "error"
|
||||
|
||||
|
||||
class TestAnalysisNode:
|
||||
"""Analysis 节点测试"""
|
||||
|
||||
@pytest.fixture
|
||||
def analysis_node_with_mock_agent(self, mock_event_emitter):
|
||||
"""创建带模拟 Agent 的 Analysis 节点"""
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.run = AsyncMock(return_value=MagicMock(
|
||||
success=True,
|
||||
data={
|
||||
"findings": [
|
||||
{
|
||||
"id": "finding-1",
|
||||
"title": "SQL Injection",
|
||||
"severity": "high",
|
||||
"vulnerability_type": "sql_injection",
|
||||
"file_path": "src/sql_vuln.py",
|
||||
"line_start": 10,
|
||||
"description": "SQL injection vulnerability",
|
||||
}
|
||||
],
|
||||
"should_continue": False,
|
||||
}
|
||||
))
|
||||
|
||||
return AnalysisNode(mock_agent, mock_event_emitter)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_analysis_node_success(self, analysis_node_with_mock_agent):
|
||||
"""测试 Analysis 节点成功执行"""
|
||||
state = {
|
||||
"project_info": {"name": "Test"},
|
||||
"tech_stack": {"languages": ["Python"]},
|
||||
"entry_points": [],
|
||||
"high_risk_areas": ["src/sql_vuln.py"],
|
||||
"config": {},
|
||||
"iteration": 0,
|
||||
"findings": [],
|
||||
}
|
||||
|
||||
result = await analysis_node_with_mock_agent(state)
|
||||
|
||||
assert "findings" in result
|
||||
assert len(result["findings"]) > 0
|
||||
assert result["iteration"] == 1
|
||||
|
||||
|
||||
class TestIntegrationFlow:
|
||||
"""完整流程集成测试"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_full_audit_flow_mock(self, temp_project_dir, mock_db_session, mock_task):
|
||||
"""测试完整审计流程(使用模拟)"""
|
||||
# 这个测试验证整个流程的连接性
|
||||
|
||||
# 创建事件管理器
|
||||
event_manager = EventManager()
|
||||
emitter = AgentEventEmitter(mock_task.id, event_manager)
|
||||
|
||||
# 模拟 LLM 服务
|
||||
mock_llm = MagicMock()
|
||||
mock_llm.chat_completion_raw = AsyncMock(return_value={
|
||||
"content": "Analysis complete",
|
||||
"usage": {"total_tokens": 100},
|
||||
})
|
||||
|
||||
# 验证事件发射
|
||||
await emitter.emit_phase_start("init", "初始化")
|
||||
await emitter.emit_info("测试消息")
|
||||
await emitter.emit_phase_complete("init", "初始化完成")
|
||||
|
||||
assert emitter._sequence == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_audit_state_typing(self):
|
||||
"""测试审计状态类型定义"""
|
||||
state: AuditState = {
|
||||
"project_root": "/tmp/test",
|
||||
"project_info": {"name": "Test"},
|
||||
"config": {},
|
||||
"task_id": "test-id",
|
||||
"tech_stack": {},
|
||||
"entry_points": [],
|
||||
"high_risk_areas": [],
|
||||
"dependencies": {},
|
||||
"findings": [],
|
||||
"verified_findings": [],
|
||||
"false_positives": [],
|
||||
"current_phase": "start",
|
||||
"iteration": 0,
|
||||
"max_iterations": 50,
|
||||
"should_continue_analysis": False,
|
||||
"messages": [],
|
||||
"events": [],
|
||||
"summary": None,
|
||||
"security_score": None,
|
||||
"error": None,
|
||||
}
|
||||
|
||||
assert state["current_phase"] == "start"
|
||||
assert state["max_iterations"] == 50
|
||||
|
||||
|
||||
class TestToolIntegration:
|
||||
"""工具集成测试"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tools_work_together(self, temp_project_dir):
|
||||
"""测试工具协同工作"""
|
||||
from app.services.agent.tools import (
|
||||
FileReadTool, FileSearchTool, ListFilesTool, PatternMatchTool,
|
||||
)
|
||||
|
||||
# 1. 列出文件
|
||||
list_tool = ListFilesTool(temp_project_dir)
|
||||
list_result = await list_tool.execute(directory="src", recursive=False)
|
||||
assert list_result.success is True
|
||||
|
||||
# 2. 搜索关键代码
|
||||
search_tool = FileSearchTool(temp_project_dir)
|
||||
search_result = await search_tool.execute(keyword="execute")
|
||||
assert search_result.success is True
|
||||
|
||||
# 3. 读取文件内容
|
||||
read_tool = FileReadTool(temp_project_dir)
|
||||
read_result = await read_tool.execute(file_path="src/sql_vuln.py")
|
||||
assert read_result.success is True
|
||||
|
||||
# 4. 模式匹配
|
||||
pattern_tool = PatternMatchTool(temp_project_dir)
|
||||
pattern_result = await pattern_tool.execute(
|
||||
code=read_result.data,
|
||||
file_path="src/sql_vuln.py",
|
||||
language="python"
|
||||
)
|
||||
assert pattern_result.success is True
|
||||
|
||||
|
||||
class TestErrorHandling:
|
||||
"""错误处理测试"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_error_handling(self, temp_project_dir):
|
||||
"""测试工具错误处理"""
|
||||
from app.services.agent.tools import FileReadTool
|
||||
|
||||
tool = FileReadTool(temp_project_dir)
|
||||
|
||||
# 尝试读取不存在的文件
|
||||
result = await tool.execute(file_path="nonexistent/file.py")
|
||||
|
||||
assert result.success is False
|
||||
assert result.error is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_graceful_degradation(self, mock_event_emitter):
|
||||
"""测试 Agent 优雅降级"""
|
||||
# 创建一个会失败的 Agent
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.run = AsyncMock(side_effect=Exception("Simulated error"))
|
||||
|
||||
node = ReconNode(mock_agent, mock_event_emitter)
|
||||
|
||||
result = await node({
|
||||
"project_info": {},
|
||||
"config": {},
|
||||
})
|
||||
|
||||
# 应该返回错误状态而不是崩溃
|
||||
assert "error" in result
|
||||
assert result["current_phase"] == "error"
|
||||
|
||||
|
||||
class TestPerformance:
|
||||
"""性能测试"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_response_time(self, temp_project_dir):
|
||||
"""测试工具响应时间"""
|
||||
from app.services.agent.tools import ListFilesTool
|
||||
import time
|
||||
|
||||
tool = ListFilesTool(temp_project_dir)
|
||||
|
||||
start = time.time()
|
||||
await tool.execute(directory=".", recursive=True)
|
||||
duration = time.time() - start
|
||||
|
||||
# 工具应该在合理时间内响应
|
||||
assert duration < 5.0 # 5 秒内
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_tool_calls(self, temp_project_dir):
|
||||
"""测试多次工具调用"""
|
||||
from app.services.agent.tools import FileSearchTool
|
||||
|
||||
tool = FileSearchTool(temp_project_dir)
|
||||
|
||||
# 执行多次调用
|
||||
for _ in range(5):
|
||||
result = await tool.execute(keyword="def")
|
||||
assert result.success is True
|
||||
|
||||
# 验证调用计数
|
||||
assert tool._call_count == 5
|
||||
|
||||
|
|
@ -0,0 +1,248 @@
|
|||
"""
|
||||
Agent 工具单元测试
|
||||
测试各种安全分析工具的功能
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
import os
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
|
||||
# 导入工具
|
||||
from app.services.agent.tools import (
|
||||
FileReadTool, FileSearchTool, ListFilesTool,
|
||||
PatternMatchTool,
|
||||
)
|
||||
from app.services.agent.tools.base import ToolResult
|
||||
|
||||
|
||||
class TestFileTools:
|
||||
"""文件操作工具测试"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_file_read_tool_success(self, temp_project_dir):
|
||||
"""测试文件读取工具 - 成功读取"""
|
||||
tool = FileReadTool(temp_project_dir)
|
||||
|
||||
result = await tool.execute(file_path="src/sql_vuln.py")
|
||||
|
||||
assert result.success is True
|
||||
assert "SELECT * FROM users" in result.data
|
||||
assert "sql_injection" in result.data.lower() or "cursor.execute" in result.data
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_file_read_tool_not_found(self, temp_project_dir):
|
||||
"""测试文件读取工具 - 文件不存在"""
|
||||
tool = FileReadTool(temp_project_dir)
|
||||
|
||||
result = await tool.execute(file_path="nonexistent.py")
|
||||
|
||||
assert result.success is False
|
||||
assert "不存在" in result.error or "not found" in result.error.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_file_read_tool_path_traversal_blocked(self, temp_project_dir):
|
||||
"""测试文件读取工具 - 路径遍历被阻止"""
|
||||
tool = FileReadTool(temp_project_dir)
|
||||
|
||||
result = await tool.execute(file_path="../../../etc/passwd")
|
||||
|
||||
assert result.success is False
|
||||
assert "安全" in result.error or "security" in result.error.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_file_search_tool(self, temp_project_dir):
|
||||
"""测试文件搜索工具"""
|
||||
tool = FileSearchTool(temp_project_dir)
|
||||
|
||||
result = await tool.execute(keyword="cursor.execute")
|
||||
|
||||
assert result.success is True
|
||||
assert "sql_vuln.py" in result.data
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_files_tool(self, temp_project_dir):
|
||||
"""测试文件列表工具"""
|
||||
tool = ListFilesTool(temp_project_dir)
|
||||
|
||||
result = await tool.execute(directory=".", recursive=True)
|
||||
|
||||
assert result.success is True
|
||||
assert "sql_vuln.py" in result.data
|
||||
assert "requirements.txt" in result.data
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_files_tool_pattern(self, temp_project_dir):
|
||||
"""测试文件列表工具 - 文件模式过滤"""
|
||||
tool = ListFilesTool(temp_project_dir)
|
||||
|
||||
result = await tool.execute(directory="src", pattern="*.py")
|
||||
|
||||
assert result.success is True
|
||||
assert "sql_vuln.py" in result.data
|
||||
|
||||
|
||||
class TestPatternMatchTool:
|
||||
"""模式匹配工具测试"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pattern_match_sql_injection(self, temp_project_dir):
|
||||
"""测试模式匹配 - SQL 注入检测"""
|
||||
tool = PatternMatchTool(temp_project_dir)
|
||||
|
||||
# 读取有漏洞的代码
|
||||
with open(os.path.join(temp_project_dir, "src", "sql_vuln.py")) as f:
|
||||
code = f.read()
|
||||
|
||||
result = await tool.execute(
|
||||
code=code,
|
||||
file_path="src/sql_vuln.py",
|
||||
pattern_types=["sql_injection"],
|
||||
language="python"
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
# 应该检测到 SQL 注入模式
|
||||
if result.data:
|
||||
assert "sql" in str(result.data).lower() or len(result.metadata.get("matches", [])) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pattern_match_command_injection(self, temp_project_dir):
|
||||
"""测试模式匹配 - 命令注入检测"""
|
||||
tool = PatternMatchTool(temp_project_dir)
|
||||
|
||||
with open(os.path.join(temp_project_dir, "src", "cmd_vuln.py")) as f:
|
||||
code = f.read()
|
||||
|
||||
result = await tool.execute(
|
||||
code=code,
|
||||
file_path="src/cmd_vuln.py",
|
||||
pattern_types=["command_injection"],
|
||||
language="python"
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pattern_match_xss(self, temp_project_dir):
|
||||
"""测试模式匹配 - XSS 检测"""
|
||||
tool = PatternMatchTool(temp_project_dir)
|
||||
|
||||
with open(os.path.join(temp_project_dir, "src", "xss_vuln.py")) as f:
|
||||
code = f.read()
|
||||
|
||||
result = await tool.execute(
|
||||
code=code,
|
||||
file_path="src/xss_vuln.py",
|
||||
pattern_types=["xss"],
|
||||
language="python"
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pattern_match_hardcoded_secrets(self, temp_project_dir):
|
||||
"""测试模式匹配 - 硬编码密钥检测"""
|
||||
tool = PatternMatchTool(temp_project_dir)
|
||||
|
||||
with open(os.path.join(temp_project_dir, "src", "secrets.py")) as f:
|
||||
code = f.read()
|
||||
|
||||
result = await tool.execute(
|
||||
code=code,
|
||||
file_path="src/secrets.py",
|
||||
pattern_types=["hardcoded_secret"],
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pattern_match_safe_code(self, temp_project_dir):
|
||||
"""测试模式匹配 - 安全代码应该没有问题"""
|
||||
tool = PatternMatchTool(temp_project_dir)
|
||||
|
||||
with open(os.path.join(temp_project_dir, "src", "safe_code.py")) as f:
|
||||
code = f.read()
|
||||
|
||||
result = await tool.execute(
|
||||
code=code,
|
||||
file_path="src/safe_code.py",
|
||||
pattern_types=["sql_injection"],
|
||||
language="python"
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
# 安全代码使用参数化查询,不应该有 SQL 注入漏洞
|
||||
# 检查结果数据,如果有 matches 字段
|
||||
matches = result.metadata.get("matches", [])
|
||||
if isinstance(matches, list):
|
||||
# 参数化查询不应该被误报为 SQL 注入
|
||||
sql_injection_count = sum(
|
||||
1 for m in matches
|
||||
if isinstance(m, dict) and "sql" in m.get("pattern_type", "").lower()
|
||||
)
|
||||
# 安全代码的 SQL 注入匹配应该很少或没有
|
||||
assert sql_injection_count <= 1 # 允许少量误报
|
||||
|
||||
|
||||
class TestToolResult:
|
||||
"""工具结果测试"""
|
||||
|
||||
def test_tool_result_success(self):
|
||||
"""测试成功的工具结果"""
|
||||
result = ToolResult(success=True, data="test data")
|
||||
|
||||
assert result.success is True
|
||||
assert result.data == "test data"
|
||||
assert result.error is None
|
||||
|
||||
def test_tool_result_failure(self):
|
||||
"""测试失败的工具结果"""
|
||||
result = ToolResult(success=False, error="test error")
|
||||
|
||||
assert result.success is False
|
||||
assert result.error == "test error"
|
||||
|
||||
def test_tool_result_to_string(self):
|
||||
"""测试工具结果转字符串"""
|
||||
result = ToolResult(success=True, data={"key": "value"})
|
||||
|
||||
string = result.to_string()
|
||||
|
||||
assert "key" in string
|
||||
assert "value" in string
|
||||
|
||||
def test_tool_result_to_string_truncate(self):
|
||||
"""测试工具结果字符串截断"""
|
||||
long_data = "x" * 10000
|
||||
result = ToolResult(success=True, data=long_data)
|
||||
|
||||
string = result.to_string(max_length=100)
|
||||
|
||||
assert len(string) < len(long_data)
|
||||
assert "truncated" in string.lower()
|
||||
|
||||
|
||||
class TestToolMetadata:
|
||||
"""工具元数据测试"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_call_count(self, temp_project_dir):
|
||||
"""测试工具调用计数"""
|
||||
tool = ListFilesTool(temp_project_dir)
|
||||
|
||||
await tool.execute(directory=".")
|
||||
await tool.execute(directory="src")
|
||||
|
||||
assert tool._call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_duration_tracking(self, temp_project_dir):
|
||||
"""测试工具执行时间跟踪"""
|
||||
tool = ListFilesTool(temp_project_dir)
|
||||
|
||||
result = await tool.execute(directory=".")
|
||||
|
||||
assert result.duration_ms >= 0
|
||||
assert tool._total_duration_ms >= 0
|
||||
|
||||
|
|
@ -0,0 +1,229 @@
|
|||
# DeepAudit Agent 审计功能部署清单
|
||||
|
||||
## 📋 生产部署前必须完成的检查
|
||||
|
||||
### 1. 环境依赖 ✅
|
||||
|
||||
```bash
|
||||
# 后端依赖
|
||||
cd backend
|
||||
uv pip install chromadb litellm langchain langgraph
|
||||
|
||||
# 外部安全工具(可选但推荐)
|
||||
pip install semgrep bandit safety
|
||||
|
||||
# 或者使用系统包管理器
|
||||
brew install semgrep # macOS
|
||||
apt install semgrep # Ubuntu
|
||||
```
|
||||
|
||||
### 2. LLM 配置 ✅
|
||||
|
||||
在 `.env` 文件中配置:
|
||||
|
||||
```env
|
||||
# LLM 配置(必须)
|
||||
LLM_PROVIDER=openai # 或 azure, anthropic, ollama 等
|
||||
LLM_MODEL=gpt-4o-mini # 推荐使用 gpt-4 系列
|
||||
LLM_API_KEY=sk-xxx # 你的 API Key
|
||||
LLM_BASE_URL= # 可选,自定义端点
|
||||
|
||||
# 嵌入模型配置(RAG 需要)
|
||||
EMBEDDING_PROVIDER=openai
|
||||
EMBEDDING_MODEL=text-embedding-3-small
|
||||
```
|
||||
|
||||
### 3. 数据库迁移 ✅
|
||||
|
||||
```bash
|
||||
cd backend
|
||||
alembic upgrade head
|
||||
```
|
||||
|
||||
确保以下表已创建:
|
||||
- `agent_tasks`
|
||||
- `agent_events`
|
||||
- `agent_findings`
|
||||
|
||||
### 4. 向量数据库 ✅
|
||||
|
||||
```bash
|
||||
# 创建向量数据库目录
|
||||
mkdir -p /var/data/deepaudit/vector_db
|
||||
|
||||
# 在 .env 中配置
|
||||
VECTOR_DB_PATH=/var/data/deepaudit/vector_db
|
||||
```
|
||||
|
||||
### 5. Docker 沙箱(可选)
|
||||
|
||||
如果需要漏洞验证功能:
|
||||
|
||||
```bash
|
||||
# 拉取沙箱镜像
|
||||
docker pull python:3.11-slim
|
||||
|
||||
# 配置沙箱参数
|
||||
SANDBOX_IMAGE=python:3.11-slim
|
||||
SANDBOX_MEMORY_LIMIT=256m
|
||||
SANDBOX_CPU_LIMIT=0.5
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 🔬 功能测试检查
|
||||
|
||||
### 测试 1: 基础流程
|
||||
|
||||
```bash
|
||||
cd backend
|
||||
PYTHONPATH=. uv run pytest tests/agent/ -v
|
||||
```
|
||||
|
||||
预期结果:43 个测试全部通过
|
||||
|
||||
### 测试 2: LLM 连接
|
||||
|
||||
```bash
|
||||
cd backend
|
||||
PYTHONPATH=. uv run python -c "
|
||||
import asyncio
|
||||
from app.services.agent.graph.runner import LLMService
|
||||
|
||||
async def test():
|
||||
llm = LLMService()
|
||||
result = await llm.analyze_code('print(\"hello\")', 'python')
|
||||
print('LLM 连接成功:', 'issues' in result)
|
||||
|
||||
asyncio.run(test())
|
||||
"
|
||||
```
|
||||
|
||||
### 测试 3: 外部工具
|
||||
|
||||
```bash
|
||||
# 测试 Semgrep
|
||||
semgrep --version
|
||||
|
||||
# 测试 Bandit
|
||||
bandit --version
|
||||
```
|
||||
|
||||
### 测试 4: 端到端测试
|
||||
|
||||
1. 启动后端:`cd backend && uv run uvicorn app.main:app --reload`
|
||||
2. 启动前端:`cd frontend && npm run dev`
|
||||
3. 创建一个项目并上传代码
|
||||
4. 选择 "Agent 审计模式" 创建任务
|
||||
5. 观察执行日志和发现
|
||||
|
||||
---
|
||||
|
||||
## ⚠️ 已知限制
|
||||
|
||||
| 限制 | 影响 | 解决方案 |
|
||||
|------|------|---------|
|
||||
| **LLM 成本** | 每次审计消耗 Token | 使用 gpt-4o-mini 降低成本 |
|
||||
| **扫描时间** | 大项目需要较长时间 | 设置合理的超时时间 |
|
||||
| **误报率** | AI 可能产生误报 | 启用验证阶段过滤 |
|
||||
| **外部工具依赖** | 需要手动安装 | 提供 Docker 镜像 |
|
||||
|
||||
---
|
||||
|
||||
## 🚀 生产环境建议
|
||||
|
||||
### 1. 资源配置
|
||||
|
||||
```yaml
|
||||
# Kubernetes 部署示例
|
||||
resources:
|
||||
limits:
|
||||
memory: "2Gi"
|
||||
cpu: "2"
|
||||
requests:
|
||||
memory: "1Gi"
|
||||
cpu: "1"
|
||||
```
|
||||
|
||||
### 2. 并发控制
|
||||
|
||||
```env
|
||||
# 限制同时运行的任务数
|
||||
MAX_CONCURRENT_AGENT_TASKS=3
|
||||
AGENT_TASK_TIMEOUT=1800 # 30 分钟
|
||||
```
|
||||
|
||||
### 3. 日志监控
|
||||
|
||||
```python
|
||||
# 配置日志级别
|
||||
LOG_LEVEL=INFO
|
||||
# 启用 SQLAlchemy 日志(调试用)
|
||||
SQLALCHEMY_ECHO=false
|
||||
```
|
||||
|
||||
### 4. 安全考虑
|
||||
|
||||
- [ ] 限制上传文件大小
|
||||
- [ ] 限制扫描目录范围
|
||||
- [ ] 启用沙箱隔离
|
||||
- [ ] 配置 API 速率限制
|
||||
|
||||
---
|
||||
|
||||
## ✅ 部署状态检查
|
||||
|
||||
运行以下命令验证部署状态:
|
||||
|
||||
```bash
|
||||
cd backend
|
||||
PYTHONPATH=. uv run python -c "
|
||||
print('检查部署状态...')
|
||||
|
||||
# 1. 检查数据库连接
|
||||
try:
|
||||
from app.db.session import async_session_factory
|
||||
print('✅ 数据库配置正确')
|
||||
except Exception as e:
|
||||
print(f'❌ 数据库错误: {e}')
|
||||
|
||||
# 2. 检查 LLM 配置
|
||||
from app.core.config import settings
|
||||
if settings.LLM_API_KEY:
|
||||
print('✅ LLM API Key 已配置')
|
||||
else:
|
||||
print('⚠️ LLM API Key 未配置')
|
||||
|
||||
# 3. 检查向量数据库
|
||||
import os
|
||||
if os.path.exists(settings.VECTOR_DB_PATH or '/tmp'):
|
||||
print('✅ 向量数据库路径存在')
|
||||
else:
|
||||
print('⚠️ 向量数据库路径不存在')
|
||||
|
||||
# 4. 检查外部工具
|
||||
import shutil
|
||||
tools = ['semgrep', 'bandit']
|
||||
for tool in tools:
|
||||
if shutil.which(tool):
|
||||
print(f'✅ {tool} 已安装')
|
||||
else:
|
||||
print(f'⚠️ {tool} 未安装(可选)')
|
||||
|
||||
print()
|
||||
print('部署检查完成!')
|
||||
"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 📝 结论
|
||||
|
||||
Agent 审计功能已经具备**基本的生产能力**,但建议:
|
||||
|
||||
1. **先在测试环境验证** - 用一个小项目测试完整流程
|
||||
2. **监控 LLM 成本** - 观察 Token 消耗情况
|
||||
3. **逐步开放** - 先给少数用户使用,收集反馈
|
||||
4. **持续优化** - 根据实际效果调整 prompt 和阈值
|
||||
|
||||
如有问题,请查看日志或联系开发团队。
|
||||
|
|
@ -55,3 +55,5 @@ EXPOSE 3000
|
|||
ENTRYPOINT ["/docker-entrypoint.sh"]
|
||||
CMD ["serve", "-s", "dist", "-l", "3000"]
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -18,3 +18,5 @@ export const ProtectedRoute = () => {
|
|||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,280 @@
|
|||
/**
|
||||
* Agent 流式事件 React Hook
|
||||
*
|
||||
* 用于在 React 组件中消费 Agent 审计的实时事件流
|
||||
*/
|
||||
|
||||
import { useState, useEffect, useCallback, useRef } from 'react';
|
||||
import {
|
||||
AgentStreamHandler,
|
||||
StreamEventData,
|
||||
StreamOptions,
|
||||
AgentStreamState,
|
||||
} from '../shared/api/agentStream';
|
||||
|
||||
export interface UseAgentStreamOptions extends Omit<StreamOptions, 'onEvent'> {
|
||||
autoConnect?: boolean;
|
||||
maxEvents?: number;
|
||||
}
|
||||
|
||||
export interface UseAgentStreamReturn extends AgentStreamState {
|
||||
connect: () => void;
|
||||
disconnect: () => void;
|
||||
isConnected: boolean;
|
||||
clearEvents: () => void;
|
||||
}
|
||||
|
||||
/**
|
||||
* Agent 流式事件 Hook
|
||||
*
|
||||
* @example
|
||||
* ```tsx
|
||||
* function AgentAuditPanel({ taskId }: { taskId: string }) {
|
||||
* const {
|
||||
* events,
|
||||
* thinking,
|
||||
* isThinking,
|
||||
* toolCalls,
|
||||
* currentPhase,
|
||||
* progress,
|
||||
* findings,
|
||||
* isComplete,
|
||||
* error,
|
||||
* connect,
|
||||
* disconnect,
|
||||
* isConnected,
|
||||
* } = useAgentStream(taskId);
|
||||
*
|
||||
* useEffect(() => {
|
||||
* connect();
|
||||
* return () => disconnect();
|
||||
* }, [taskId]);
|
||||
*
|
||||
* return (
|
||||
* <div>
|
||||
* {isThinking && <ThinkingIndicator text={thinking} />}
|
||||
* {toolCalls.map(tc => <ToolCallCard key={tc.name} {...tc} />)}
|
||||
* {findings.map(f => <FindingCard key={f.id} {...f} />)}
|
||||
* </div>
|
||||
* );
|
||||
* }
|
||||
* ```
|
||||
*/
|
||||
export function useAgentStream(
|
||||
taskId: string | null,
|
||||
options: UseAgentStreamOptions = {}
|
||||
): UseAgentStreamReturn {
|
||||
const {
|
||||
autoConnect = false,
|
||||
maxEvents = 500,
|
||||
includeThinking = true,
|
||||
includeToolCalls = true,
|
||||
afterSequence = 0,
|
||||
...callbackOptions
|
||||
} = options;
|
||||
|
||||
// 状态
|
||||
const [events, setEvents] = useState<StreamEventData[]>([]);
|
||||
const [thinking, setThinking] = useState('');
|
||||
const [isThinking, setIsThinking] = useState(false);
|
||||
const [toolCalls, setToolCalls] = useState<AgentStreamState['toolCalls']>([]);
|
||||
const [currentPhase, setCurrentPhase] = useState('');
|
||||
const [progress, setProgress] = useState({ current: 0, total: 100, percentage: 0 });
|
||||
const [findings, setFindings] = useState<Record<string, unknown>[]>([]);
|
||||
const [isComplete, setIsComplete] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const [isConnected, setIsConnected] = useState(false);
|
||||
|
||||
// Handler ref
|
||||
const handlerRef = useRef<AgentStreamHandler | null>(null);
|
||||
const thinkingBufferRef = useRef<string[]>([]);
|
||||
|
||||
// 连接
|
||||
const connect = useCallback(() => {
|
||||
if (!taskId) return;
|
||||
|
||||
// 断开现有连接
|
||||
if (handlerRef.current) {
|
||||
handlerRef.current.disconnect();
|
||||
}
|
||||
|
||||
// 重置状态
|
||||
setEvents([]);
|
||||
setThinking('');
|
||||
setIsThinking(false);
|
||||
setToolCalls([]);
|
||||
setCurrentPhase('');
|
||||
setProgress({ current: 0, total: 100, percentage: 0 });
|
||||
setFindings([]);
|
||||
setIsComplete(false);
|
||||
setError(null);
|
||||
thinkingBufferRef.current = [];
|
||||
|
||||
// 创建新的 handler
|
||||
handlerRef.current = new AgentStreamHandler(taskId, {
|
||||
includeThinking,
|
||||
includeToolCalls,
|
||||
afterSequence,
|
||||
|
||||
onEvent: (event) => {
|
||||
setEvents((prev) => [...prev.slice(-maxEvents + 1), event]);
|
||||
},
|
||||
|
||||
onThinkingStart: () => {
|
||||
thinkingBufferRef.current = [];
|
||||
setIsThinking(true);
|
||||
setThinking('');
|
||||
callbackOptions.onThinkingStart?.();
|
||||
},
|
||||
|
||||
onThinkingToken: (token, accumulated) => {
|
||||
thinkingBufferRef.current.push(token);
|
||||
setThinking(accumulated);
|
||||
callbackOptions.onThinkingToken?.(token, accumulated);
|
||||
},
|
||||
|
||||
onThinkingEnd: (response) => {
|
||||
setIsThinking(false);
|
||||
setThinking(response);
|
||||
thinkingBufferRef.current = [];
|
||||
callbackOptions.onThinkingEnd?.(response);
|
||||
},
|
||||
|
||||
onToolStart: (name, input) => {
|
||||
setToolCalls((prev) => [
|
||||
...prev,
|
||||
{ name, input, status: 'running' as const },
|
||||
]);
|
||||
callbackOptions.onToolStart?.(name, input);
|
||||
},
|
||||
|
||||
onToolEnd: (name, output, durationMs) => {
|
||||
setToolCalls((prev) =>
|
||||
prev.map((tc) =>
|
||||
tc.name === name && tc.status === 'running'
|
||||
? { ...tc, output, durationMs, status: 'success' as const }
|
||||
: tc
|
||||
)
|
||||
);
|
||||
callbackOptions.onToolEnd?.(name, output, durationMs);
|
||||
},
|
||||
|
||||
onNodeStart: (nodeName, phase) => {
|
||||
setCurrentPhase(phase);
|
||||
callbackOptions.onNodeStart?.(nodeName, phase);
|
||||
},
|
||||
|
||||
onNodeEnd: (nodeName, summary) => {
|
||||
callbackOptions.onNodeEnd?.(nodeName, summary);
|
||||
},
|
||||
|
||||
onProgress: (current, total, message) => {
|
||||
setProgress({
|
||||
current,
|
||||
total,
|
||||
percentage: total > 0 ? Math.round((current / total) * 100) : 0,
|
||||
});
|
||||
callbackOptions.onProgress?.(current, total, message);
|
||||
},
|
||||
|
||||
onFinding: (finding, isVerified) => {
|
||||
setFindings((prev) => [...prev, finding]);
|
||||
callbackOptions.onFinding?.(finding, isVerified);
|
||||
},
|
||||
|
||||
onComplete: (data) => {
|
||||
setIsComplete(true);
|
||||
setIsConnected(false);
|
||||
callbackOptions.onComplete?.(data);
|
||||
},
|
||||
|
||||
onError: (err) => {
|
||||
setError(err);
|
||||
setIsComplete(true);
|
||||
setIsConnected(false);
|
||||
callbackOptions.onError?.(err);
|
||||
},
|
||||
|
||||
onHeartbeat: () => {
|
||||
callbackOptions.onHeartbeat?.();
|
||||
},
|
||||
});
|
||||
|
||||
handlerRef.current.connect();
|
||||
setIsConnected(true);
|
||||
}, [taskId, includeThinking, includeToolCalls, afterSequence, maxEvents, callbackOptions]);
|
||||
|
||||
// 断开连接
|
||||
const disconnect = useCallback(() => {
|
||||
if (handlerRef.current) {
|
||||
handlerRef.current.disconnect();
|
||||
handlerRef.current = null;
|
||||
}
|
||||
setIsConnected(false);
|
||||
}, []);
|
||||
|
||||
// 清空事件
|
||||
const clearEvents = useCallback(() => {
|
||||
setEvents([]);
|
||||
}, []);
|
||||
|
||||
// 自动连接
|
||||
useEffect(() => {
|
||||
if (autoConnect && taskId) {
|
||||
connect();
|
||||
}
|
||||
return () => {
|
||||
disconnect();
|
||||
};
|
||||
}, [taskId, autoConnect, connect, disconnect]);
|
||||
|
||||
// 清理
|
||||
useEffect(() => {
|
||||
return () => {
|
||||
if (handlerRef.current) {
|
||||
handlerRef.current.disconnect();
|
||||
}
|
||||
};
|
||||
}, []);
|
||||
|
||||
return {
|
||||
events,
|
||||
thinking,
|
||||
isThinking,
|
||||
toolCalls,
|
||||
currentPhase,
|
||||
progress,
|
||||
findings,
|
||||
isComplete,
|
||||
error,
|
||||
connect,
|
||||
disconnect,
|
||||
isConnected,
|
||||
clearEvents,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* 简化版 Hook - 只获取思考过程
|
||||
*/
|
||||
export function useAgentThinking(taskId: string | null) {
|
||||
const { thinking, isThinking, connect, disconnect } = useAgentStream(taskId, {
|
||||
includeToolCalls: false,
|
||||
});
|
||||
|
||||
return { thinking, isThinking, connect, disconnect };
|
||||
}
|
||||
|
||||
/**
|
||||
* 简化版 Hook - 只获取工具调用
|
||||
*/
|
||||
export function useAgentToolCalls(taskId: string | null) {
|
||||
const { toolCalls, connect, disconnect } = useAgentStream(taskId, {
|
||||
includeThinking: false,
|
||||
});
|
||||
|
||||
return { toolCalls, connect, disconnect };
|
||||
}
|
||||
|
||||
export default useAgentStream;
|
||||
|
||||
|
|
@ -1,6 +1,7 @@
|
|||
/**
|
||||
* Agent 审计页面
|
||||
* 机械终端风格的 AI Agent 审计界面
|
||||
* 支持 LLM 思考过程和工具调用的实时流式展示
|
||||
*/
|
||||
|
||||
import { useState, useEffect, useRef, useCallback } from "react";
|
||||
|
|
@ -9,12 +10,14 @@ import {
|
|||
Terminal, Bot, Cpu, Shield, AlertTriangle, CheckCircle2,
|
||||
Loader2, Code, Zap, Activity, ChevronRight, XCircle,
|
||||
FileCode, Search, Bug, Lock, Play, Square, RefreshCw,
|
||||
ArrowLeft, Download, ExternalLink
|
||||
ArrowLeft, Download, ExternalLink, Brain, Wrench,
|
||||
ChevronDown, ChevronUp, Clock, Sparkles
|
||||
} from "lucide-react";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Badge } from "@/components/ui/badge";
|
||||
import { ScrollArea } from "@/components/ui/scroll-area";
|
||||
import { toast } from "sonner";
|
||||
import { useAgentStream } from "@/hooks/useAgentStream";
|
||||
import {
|
||||
type AgentTask,
|
||||
type AgentEvent,
|
||||
|
|
@ -91,10 +94,36 @@ export default function AgentAuditPage() {
|
|||
const [isLoading, setIsLoading] = useState(true);
|
||||
const [isStreaming, setIsStreaming] = useState(false);
|
||||
const [currentTime, setCurrentTime] = useState(new Date().toLocaleTimeString("zh-CN", { hour12: false }));
|
||||
const [showThinking, setShowThinking] = useState(true);
|
||||
const [showToolDetails, setShowToolDetails] = useState(true);
|
||||
|
||||
const eventsEndRef = useRef<HTMLDivElement>(null);
|
||||
const thinkingEndRef = useRef<HTMLDivElement>(null);
|
||||
const abortControllerRef = useRef<AbortController | null>(null);
|
||||
|
||||
// 使用增强版流式 Hook
|
||||
const {
|
||||
thinking,
|
||||
isThinking,
|
||||
toolCalls,
|
||||
currentPhase: streamPhase,
|
||||
progress: streamProgress,
|
||||
connect: connectStream,
|
||||
disconnect: disconnectStream,
|
||||
isConnected: isStreamConnected,
|
||||
} = useAgentStream(taskId || null, {
|
||||
includeThinking: true,
|
||||
includeToolCalls: true,
|
||||
onFinding: () => loadFindings(),
|
||||
onComplete: () => {
|
||||
loadTask();
|
||||
loadFindings();
|
||||
},
|
||||
onError: (err) => {
|
||||
console.error("Stream error:", err);
|
||||
},
|
||||
});
|
||||
|
||||
// 是否完成
|
||||
const isComplete = task?.status === "completed" || task?.status === "failed" || task?.status === "cancelled";
|
||||
|
||||
|
|
@ -146,12 +175,24 @@ export default function AgentAuditPage() {
|
|||
init();
|
||||
}, [loadTask, loadEvents, loadFindings]);
|
||||
|
||||
// 事件流
|
||||
// 连接增强版流式 API
|
||||
useEffect(() => {
|
||||
if (!taskId || isComplete || isLoading) return;
|
||||
|
||||
connectStream();
|
||||
setIsStreaming(true);
|
||||
|
||||
return () => {
|
||||
disconnectStream();
|
||||
setIsStreaming(false);
|
||||
};
|
||||
}, [taskId, isComplete, isLoading, connectStream, disconnectStream]);
|
||||
|
||||
// 旧版事件流(作为后备)
|
||||
useEffect(() => {
|
||||
if (!taskId || isComplete || isLoading) return;
|
||||
|
||||
const startStreaming = async () => {
|
||||
setIsStreaming(true);
|
||||
abortControllerRef.current = new AbortController();
|
||||
|
||||
try {
|
||||
|
|
@ -179,8 +220,6 @@ export default function AgentAuditPage() {
|
|||
if ((error as Error).name !== "AbortError") {
|
||||
console.error("Event stream error:", error);
|
||||
}
|
||||
} finally {
|
||||
setIsStreaming(false);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -205,6 +244,30 @@ export default function AgentAuditPage() {
|
|||
return () => clearInterval(interval);
|
||||
}, []);
|
||||
|
||||
// 定期轮询任务状态(作为 SSE 的后备机制)
|
||||
useEffect(() => {
|
||||
if (!taskId || isComplete || isLoading) return;
|
||||
|
||||
// 每 3 秒轮询一次任务状态
|
||||
const pollInterval = setInterval(async () => {
|
||||
try {
|
||||
const taskData = await getAgentTask(taskId);
|
||||
setTask(taskData);
|
||||
|
||||
// 如果任务已完成/失败/取消,刷新其他数据
|
||||
if (taskData.status === "completed" || taskData.status === "failed" || taskData.status === "cancelled") {
|
||||
await loadEvents();
|
||||
await loadFindings();
|
||||
clearInterval(pollInterval);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Failed to poll task status:", error);
|
||||
}
|
||||
}, 3000);
|
||||
|
||||
return () => clearInterval(pollInterval);
|
||||
}, [taskId, isComplete, isLoading, loadEvents, loadFindings]);
|
||||
|
||||
// 取消任务
|
||||
const handleCancel = async () => {
|
||||
if (!taskId) return;
|
||||
|
|
@ -291,14 +354,85 @@ export default function AgentAuditPage() {
|
|||
</div>
|
||||
</div>
|
||||
|
||||
{/* 错误提示 */}
|
||||
{task.status === "failed" && task.error_message && (
|
||||
<div className="mx-4 mt-2 p-3 bg-red-900/30 border border-red-700 rounded-lg">
|
||||
<div className="flex items-start gap-2">
|
||||
<XCircle className="w-5 h-5 text-red-400 flex-shrink-0 mt-0.5" />
|
||||
<div>
|
||||
<p className="text-red-400 font-semibold text-sm">任务执行失败</p>
|
||||
<p className="text-red-300/80 text-xs mt-1 font-mono break-all">{task.error_message}</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="flex h-[calc(100vh-56px)]">
|
||||
{/* 左侧:执行日志 */}
|
||||
<div className="flex-1 p-4 flex flex-col min-w-0">
|
||||
|
||||
{/* 思考过程展示区域 */}
|
||||
{(isThinking || thinking) && showThinking && (
|
||||
<div className="mb-4 bg-purple-950/30 rounded-lg border border-purple-800/50 overflow-hidden">
|
||||
<div
|
||||
className="flex items-center justify-between px-3 py-2 bg-purple-900/30 border-b border-purple-800/30 cursor-pointer"
|
||||
onClick={() => setShowThinking(!showThinking)}
|
||||
>
|
||||
<div className="flex items-center gap-2 text-xs text-purple-400">
|
||||
<Brain className={`w-4 h-4 ${isThinking ? "animate-pulse" : ""}`} />
|
||||
<span className="uppercase tracking-wider">AI Thinking</span>
|
||||
{isThinking && (
|
||||
<span className="flex items-center gap-1 text-purple-300">
|
||||
<Sparkles className="w-3 h-3 animate-spin" />
|
||||
<span className="text-[10px]">Processing...</span>
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
{showThinking ? <ChevronUp className="w-4 h-4 text-purple-400" /> : <ChevronDown className="w-4 h-4 text-purple-400" />}
|
||||
</div>
|
||||
|
||||
<div className="max-h-40 overflow-y-auto">
|
||||
<div className="p-3 text-sm text-purple-200/80 font-mono whitespace-pre-wrap">
|
||||
{thinking || "正在思考..."}
|
||||
{isThinking && <span className="animate-pulse text-purple-400">▌</span>}
|
||||
</div>
|
||||
<div ref={thinkingEndRef} />
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* 工具调用展示区域 */}
|
||||
{toolCalls.length > 0 && showToolDetails && (
|
||||
<div className="mb-4 bg-yellow-950/20 rounded-lg border border-yellow-800/30 overflow-hidden">
|
||||
<div
|
||||
className="flex items-center justify-between px-3 py-2 bg-yellow-900/20 border-b border-yellow-800/20 cursor-pointer"
|
||||
onClick={() => setShowToolDetails(!showToolDetails)}
|
||||
>
|
||||
<div className="flex items-center gap-2 text-xs text-yellow-500">
|
||||
<Wrench className="w-4 h-4" />
|
||||
<span className="uppercase tracking-wider">Tool Calls</span>
|
||||
<Badge variant="outline" className="text-[10px] px-1.5 py-0 bg-yellow-900/30 border-yellow-700 text-yellow-400">
|
||||
{toolCalls.length}
|
||||
</Badge>
|
||||
</div>
|
||||
{showToolDetails ? <ChevronUp className="w-4 h-4 text-yellow-500" /> : <ChevronDown className="w-4 h-4 text-yellow-500" />}
|
||||
</div>
|
||||
|
||||
<div className="max-h-48 overflow-y-auto">
|
||||
<div className="p-2 space-y-2">
|
||||
{toolCalls.slice(-5).map((tc, idx) => (
|
||||
<ToolCallCard key={`${tc.name}-${idx}`} toolCall={tc} />
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="flex items-center justify-between mb-3">
|
||||
<div className="flex items-center gap-2 text-xs text-cyan-400">
|
||||
<Terminal className="w-4 h-4" />
|
||||
<span className="uppercase tracking-wider">Execution Log</span>
|
||||
{isStreaming && (
|
||||
{(isStreaming || isStreamConnected) && (
|
||||
<span className="flex items-center gap-1 text-green-400">
|
||||
<span className="w-2 h-2 bg-green-400 rounded-full animate-pulse" />
|
||||
LIVE
|
||||
|
|
@ -540,6 +674,94 @@ function EventLine({ event }: { event: AgentEvent }) {
|
|||
);
|
||||
}
|
||||
|
||||
// 工具调用卡片组件
|
||||
interface ToolCallProps {
|
||||
toolCall: {
|
||||
name: string;
|
||||
input: Record<string, unknown>;
|
||||
output?: unknown;
|
||||
durationMs?: number;
|
||||
status: 'running' | 'success' | 'error';
|
||||
};
|
||||
}
|
||||
|
||||
function ToolCallCard({ toolCall }: ToolCallProps) {
|
||||
const [expanded, setExpanded] = useState(false);
|
||||
|
||||
const statusConfig = {
|
||||
running: {
|
||||
icon: <Loader2 className="w-3 h-3 animate-spin text-yellow-400" />,
|
||||
badge: "bg-yellow-900/30 border-yellow-700 text-yellow-400",
|
||||
text: "Running",
|
||||
},
|
||||
success: {
|
||||
icon: <CheckCircle2 className="w-3 h-3 text-green-400" />,
|
||||
badge: "bg-green-900/30 border-green-700 text-green-400",
|
||||
text: "Done",
|
||||
},
|
||||
error: {
|
||||
icon: <XCircle className="w-3 h-3 text-red-400" />,
|
||||
badge: "bg-red-900/30 border-red-700 text-red-400",
|
||||
text: "Error",
|
||||
},
|
||||
};
|
||||
|
||||
const config = statusConfig[toolCall.status];
|
||||
|
||||
return (
|
||||
<div className="bg-gray-900/50 rounded border border-gray-700/50 overflow-hidden">
|
||||
<div
|
||||
className="flex items-center justify-between px-2 py-1.5 cursor-pointer hover:bg-gray-800/50"
|
||||
onClick={() => setExpanded(!expanded)}
|
||||
>
|
||||
<div className="flex items-center gap-2">
|
||||
{config.icon}
|
||||
<span className="text-xs font-mono text-gray-300">{toolCall.name}</span>
|
||||
</div>
|
||||
<div className="flex items-center gap-2">
|
||||
{toolCall.durationMs && (
|
||||
<span className="text-[10px] text-gray-500">
|
||||
<Clock className="w-2.5 h-2.5 inline mr-0.5" />
|
||||
{toolCall.durationMs}ms
|
||||
</span>
|
||||
)}
|
||||
<Badge variant="outline" className={`text-[10px] px-1 py-0 ${config.badge}`}>
|
||||
{config.text}
|
||||
</Badge>
|
||||
{expanded ? <ChevronUp className="w-3 h-3 text-gray-500" /> : <ChevronDown className="w-3 h-3 text-gray-500" />}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{expanded && (
|
||||
<div className="border-t border-gray-700/50 text-[11px] font-mono">
|
||||
{/* 输入 */}
|
||||
{toolCall.input && Object.keys(toolCall.input).length > 0 && (
|
||||
<div className="p-2 border-b border-gray-800/50">
|
||||
<span className="text-gray-500 text-[10px] uppercase">Input:</span>
|
||||
<pre className="mt-1 text-cyan-300/80 whitespace-pre-wrap break-all max-h-24 overflow-y-auto">
|
||||
{JSON.stringify(toolCall.input, null, 2).slice(0, 500)}
|
||||
</pre>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* 输出 */}
|
||||
{toolCall.output && (
|
||||
<div className="p-2">
|
||||
<span className="text-gray-500 text-[10px] uppercase">Output:</span>
|
||||
<pre className="mt-1 text-green-300/80 whitespace-pre-wrap break-all max-h-24 overflow-y-auto">
|
||||
{typeof toolCall.output === 'string'
|
||||
? toolCall.output.slice(0, 500)
|
||||
: JSON.stringify(toolCall.output, null, 2).slice(0, 500)
|
||||
}
|
||||
</pre>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// 发现卡片组件
|
||||
function FindingCard({ finding }: { finding: AgentFinding }) {
|
||||
const colorClass = severityColors[finding.severity] || severityColors.info;
|
||||
|
|
|
|||
|
|
@ -0,0 +1,496 @@
|
|||
/**
|
||||
* Agent 流式事件处理
|
||||
*
|
||||
* 最佳实践:
|
||||
* 1. 使用 EventSource API 或 fetch + ReadableStream
|
||||
* 2. 支持重连机制
|
||||
* 3. 分类处理不同事件类型
|
||||
*/
|
||||
|
||||
// 事件类型定义
|
||||
export type StreamEventType =
|
||||
// LLM 相关
|
||||
| 'thinking_start'
|
||||
| 'thinking_token'
|
||||
| 'thinking_end'
|
||||
// 工具调用相关
|
||||
| 'tool_call_start'
|
||||
| 'tool_call_input'
|
||||
| 'tool_call_output'
|
||||
| 'tool_call_end'
|
||||
| 'tool_call_error'
|
||||
// 节点相关
|
||||
| 'node_start'
|
||||
| 'node_end'
|
||||
// 阶段相关
|
||||
| 'phase_start'
|
||||
| 'phase_end'
|
||||
| 'phase_complete'
|
||||
// 发现相关
|
||||
| 'finding_new'
|
||||
| 'finding_verified'
|
||||
// 状态相关
|
||||
| 'progress'
|
||||
| 'info'
|
||||
| 'warning'
|
||||
| 'error'
|
||||
// 任务相关
|
||||
| 'task_start'
|
||||
| 'task_complete'
|
||||
| 'task_error'
|
||||
| 'task_cancel'
|
||||
| 'task_end'
|
||||
// 心跳
|
||||
| 'heartbeat';
|
||||
|
||||
// 工具调用详情
|
||||
export interface ToolCallDetail {
|
||||
name: string;
|
||||
input?: Record<string, unknown>;
|
||||
output?: unknown;
|
||||
duration_ms?: number;
|
||||
}
|
||||
|
||||
// 流式事件数据
|
||||
export interface StreamEventData {
|
||||
id?: string;
|
||||
type: StreamEventType;
|
||||
phase?: string;
|
||||
message?: string;
|
||||
sequence?: number;
|
||||
timestamp?: string;
|
||||
tool?: ToolCallDetail;
|
||||
metadata?: Record<string, unknown>;
|
||||
tokens_used?: number;
|
||||
// 特定类型数据
|
||||
token?: string; // thinking_token
|
||||
accumulated?: string; // thinking_token/thinking_end
|
||||
status?: string; // task_end
|
||||
error?: string; // task_error
|
||||
findings_count?: number; // task_complete
|
||||
security_score?: number; // task_complete
|
||||
}
|
||||
|
||||
// 事件回调类型
|
||||
export type StreamEventCallback = (event: StreamEventData) => void;
|
||||
|
||||
// 流式选项
|
||||
export interface StreamOptions {
|
||||
includeThinking?: boolean;
|
||||
includeToolCalls?: boolean;
|
||||
afterSequence?: number;
|
||||
onThinkingStart?: () => void;
|
||||
onThinkingToken?: (token: string, accumulated: string) => void;
|
||||
onThinkingEnd?: (fullResponse: string) => void;
|
||||
onToolStart?: (toolName: string, input: Record<string, unknown>) => void;
|
||||
onToolEnd?: (toolName: string, output: unknown, durationMs: number) => void;
|
||||
onNodeStart?: (nodeName: string, phase: string) => void;
|
||||
onNodeEnd?: (nodeName: string, summary: Record<string, unknown>) => void;
|
||||
onFinding?: (finding: Record<string, unknown>, isVerified: boolean) => void;
|
||||
onProgress?: (current: number, total: number, message: string) => void;
|
||||
onComplete?: (data: { findingsCount: number; securityScore: number }) => void;
|
||||
onError?: (error: string) => void;
|
||||
onHeartbeat?: () => void;
|
||||
onEvent?: StreamEventCallback; // 通用事件回调
|
||||
}
|
||||
|
||||
/**
|
||||
* Agent 流式事件处理器
|
||||
*/
|
||||
export class AgentStreamHandler {
|
||||
private taskId: string;
|
||||
private eventSource: EventSource | null = null;
|
||||
private options: StreamOptions;
|
||||
private reconnectAttempts = 0;
|
||||
private maxReconnectAttempts = 5;
|
||||
private reconnectDelay = 1000;
|
||||
private isConnected = false;
|
||||
private thinkingBuffer: string[] = [];
|
||||
|
||||
constructor(taskId: string, options: StreamOptions = {}) {
|
||||
this.taskId = taskId;
|
||||
this.options = {
|
||||
includeThinking: true,
|
||||
includeToolCalls: true,
|
||||
afterSequence: 0,
|
||||
...options,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* 开始监听事件流
|
||||
*/
|
||||
connect(): void {
|
||||
const token = localStorage.getItem('access_token');
|
||||
if (!token) {
|
||||
this.options.onError?.('未登录');
|
||||
return;
|
||||
}
|
||||
|
||||
const params = new URLSearchParams({
|
||||
include_thinking: String(this.options.includeThinking),
|
||||
include_tool_calls: String(this.options.includeToolCalls),
|
||||
after_sequence: String(this.options.afterSequence),
|
||||
});
|
||||
|
||||
// 使用 EventSource (不支持自定义 headers,需要通过 URL 传递 token)
|
||||
// 或者使用 fetch + ReadableStream
|
||||
this.connectWithFetch(token, params);
|
||||
}
|
||||
|
||||
/**
|
||||
* 使用 fetch 连接(支持自定义 headers)
|
||||
*/
|
||||
private async connectWithFetch(token: string, params: URLSearchParams): Promise<void> {
|
||||
const url = `/api/v1/agent-tasks/${this.taskId}/stream?${params}`;
|
||||
|
||||
try {
|
||||
const response = await fetch(url, {
|
||||
headers: {
|
||||
'Authorization': `Bearer ${token}`,
|
||||
'Accept': 'text/event-stream',
|
||||
},
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(`HTTP ${response.status}: ${response.statusText}`);
|
||||
}
|
||||
|
||||
this.isConnected = true;
|
||||
this.reconnectAttempts = 0;
|
||||
|
||||
const reader = response.body?.getReader();
|
||||
if (!reader) {
|
||||
throw new Error('无法获取响应流');
|
||||
}
|
||||
|
||||
const decoder = new TextDecoder();
|
||||
let buffer = '';
|
||||
|
||||
while (true) {
|
||||
const { done, value } = await reader.read();
|
||||
|
||||
if (done) {
|
||||
break;
|
||||
}
|
||||
|
||||
buffer += decoder.decode(value, { stream: true });
|
||||
|
||||
// 解析 SSE 事件
|
||||
const events = this.parseSSE(buffer);
|
||||
buffer = events.remaining;
|
||||
|
||||
for (const event of events.parsed) {
|
||||
this.handleEvent(event);
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
this.isConnected = false;
|
||||
console.error('Stream connection error:', error);
|
||||
|
||||
// 尝试重连
|
||||
if (this.reconnectAttempts < this.maxReconnectAttempts) {
|
||||
this.reconnectAttempts++;
|
||||
setTimeout(() => this.connect(), this.reconnectDelay * this.reconnectAttempts);
|
||||
} else {
|
||||
this.options.onError?.(`连接失败: ${error}`);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 解析 SSE 格式
|
||||
*/
|
||||
private parseSSE(buffer: string): { parsed: StreamEventData[]; remaining: string } {
|
||||
const parsed: StreamEventData[] = [];
|
||||
const lines = buffer.split('\n');
|
||||
let remaining = '';
|
||||
let currentEvent: Partial<StreamEventData> = {};
|
||||
|
||||
for (let i = 0; i < lines.length; i++) {
|
||||
const line = lines[i];
|
||||
|
||||
// 空行表示事件结束
|
||||
if (line === '') {
|
||||
if (currentEvent.type) {
|
||||
parsed.push(currentEvent as StreamEventData);
|
||||
currentEvent = {};
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
// 检查是否是最后一行(可能不完整)
|
||||
if (i === lines.length - 1 && !buffer.endsWith('\n')) {
|
||||
remaining = line;
|
||||
break;
|
||||
}
|
||||
|
||||
// 解析 event: 行
|
||||
if (line.startsWith('event:')) {
|
||||
currentEvent.type = line.slice(6).trim() as StreamEventType;
|
||||
}
|
||||
// 解析 data: 行
|
||||
else if (line.startsWith('data:')) {
|
||||
try {
|
||||
const data = JSON.parse(line.slice(5).trim());
|
||||
currentEvent = { ...currentEvent, ...data };
|
||||
} catch {
|
||||
// 忽略解析错误
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return { parsed, remaining };
|
||||
}
|
||||
|
||||
/**
|
||||
* 处理事件
|
||||
*/
|
||||
private handleEvent(event: StreamEventData): void {
|
||||
// 通用回调
|
||||
this.options.onEvent?.(event);
|
||||
|
||||
// 分类处理
|
||||
switch (event.type) {
|
||||
// LLM 思考
|
||||
case 'thinking_start':
|
||||
this.thinkingBuffer = [];
|
||||
this.options.onThinkingStart?.();
|
||||
break;
|
||||
|
||||
case 'thinking_token':
|
||||
if (event.token) {
|
||||
this.thinkingBuffer.push(event.token);
|
||||
this.options.onThinkingToken?.(
|
||||
event.token,
|
||||
event.accumulated || this.thinkingBuffer.join('')
|
||||
);
|
||||
}
|
||||
break;
|
||||
|
||||
case 'thinking_end':
|
||||
const fullResponse = event.accumulated || this.thinkingBuffer.join('');
|
||||
this.thinkingBuffer = [];
|
||||
this.options.onThinkingEnd?.(fullResponse);
|
||||
break;
|
||||
|
||||
// 工具调用
|
||||
case 'tool_call_start':
|
||||
if (event.tool) {
|
||||
this.options.onToolStart?.(
|
||||
event.tool.name,
|
||||
event.tool.input || {}
|
||||
);
|
||||
}
|
||||
break;
|
||||
|
||||
case 'tool_call_end':
|
||||
if (event.tool) {
|
||||
this.options.onToolEnd?.(
|
||||
event.tool.name,
|
||||
event.tool.output,
|
||||
event.tool.duration_ms || 0
|
||||
);
|
||||
}
|
||||
break;
|
||||
|
||||
// 节点
|
||||
case 'node_start':
|
||||
this.options.onNodeStart?.(
|
||||
event.metadata?.node as string || 'unknown',
|
||||
event.phase || ''
|
||||
);
|
||||
break;
|
||||
|
||||
case 'node_end':
|
||||
this.options.onNodeEnd?.(
|
||||
event.metadata?.node as string || 'unknown',
|
||||
event.metadata?.summary as Record<string, unknown> || {}
|
||||
);
|
||||
break;
|
||||
|
||||
// 发现
|
||||
case 'finding_new':
|
||||
case 'finding_verified':
|
||||
this.options.onFinding?.(
|
||||
event.metadata || {},
|
||||
event.type === 'finding_verified'
|
||||
);
|
||||
break;
|
||||
|
||||
// 进度
|
||||
case 'progress':
|
||||
this.options.onProgress?.(
|
||||
event.metadata?.current as number || 0,
|
||||
event.metadata?.total as number || 100,
|
||||
event.message || ''
|
||||
);
|
||||
break;
|
||||
|
||||
// 任务完成
|
||||
case 'task_complete':
|
||||
case 'task_end':
|
||||
if (event.status !== 'cancelled' && event.status !== 'failed') {
|
||||
this.options.onComplete?.({
|
||||
findingsCount: event.findings_count || event.metadata?.findings_count as number || 0,
|
||||
securityScore: event.security_score || event.metadata?.security_score as number || 100,
|
||||
});
|
||||
}
|
||||
this.disconnect();
|
||||
break;
|
||||
|
||||
// 错误
|
||||
case 'task_error':
|
||||
case 'error':
|
||||
this.options.onError?.(event.error || event.message || '未知错误');
|
||||
this.disconnect();
|
||||
break;
|
||||
|
||||
// 心跳
|
||||
case 'heartbeat':
|
||||
this.options.onHeartbeat?.();
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 断开连接
|
||||
*/
|
||||
disconnect(): void {
|
||||
this.isConnected = false;
|
||||
if (this.eventSource) {
|
||||
this.eventSource.close();
|
||||
this.eventSource = null;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查是否已连接
|
||||
*/
|
||||
get connected(): boolean {
|
||||
return this.isConnected;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建流式事件处理器的便捷函数
|
||||
*/
|
||||
export function createAgentStream(
|
||||
taskId: string,
|
||||
options: StreamOptions = {}
|
||||
): AgentStreamHandler {
|
||||
return new AgentStreamHandler(taskId, options);
|
||||
}
|
||||
|
||||
/**
|
||||
* React Hook 风格的使用示例
|
||||
*
|
||||
* ```tsx
|
||||
* const { events, thinking, toolCalls, connect, disconnect } = useAgentStream(taskId);
|
||||
*
|
||||
* useEffect(() => {
|
||||
* connect();
|
||||
* return () => disconnect();
|
||||
* }, [taskId]);
|
||||
* ```
|
||||
*/
|
||||
export interface AgentStreamState {
|
||||
events: StreamEventData[];
|
||||
thinking: string;
|
||||
isThinking: boolean;
|
||||
toolCalls: Array<{
|
||||
name: string;
|
||||
input: Record<string, unknown>;
|
||||
output?: unknown;
|
||||
durationMs?: number;
|
||||
status: 'running' | 'success' | 'error';
|
||||
}>;
|
||||
currentPhase: string;
|
||||
progress: { current: number; total: number; percentage: number };
|
||||
findings: Array<Record<string, unknown>>;
|
||||
isComplete: boolean;
|
||||
error: string | null;
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建用于 React 状态管理的流式处理器
|
||||
*/
|
||||
export function createAgentStreamWithState(
|
||||
taskId: string,
|
||||
onStateChange: (state: AgentStreamState) => void
|
||||
): AgentStreamHandler {
|
||||
const state: AgentStreamState = {
|
||||
events: [],
|
||||
thinking: '',
|
||||
isThinking: false,
|
||||
toolCalls: [],
|
||||
currentPhase: '',
|
||||
progress: { current: 0, total: 100, percentage: 0 },
|
||||
findings: [],
|
||||
isComplete: false,
|
||||
error: null,
|
||||
};
|
||||
|
||||
const updateState = (updates: Partial<AgentStreamState>) => {
|
||||
Object.assign(state, updates);
|
||||
onStateChange({ ...state });
|
||||
};
|
||||
|
||||
return new AgentStreamHandler(taskId, {
|
||||
onEvent: (event) => {
|
||||
updateState({
|
||||
events: [...state.events, event].slice(-500), // 保留最近 500 条
|
||||
});
|
||||
},
|
||||
onThinkingStart: () => {
|
||||
updateState({ isThinking: true, thinking: '' });
|
||||
},
|
||||
onThinkingToken: (_, accumulated) => {
|
||||
updateState({ thinking: accumulated });
|
||||
},
|
||||
onThinkingEnd: (response) => {
|
||||
updateState({ isThinking: false, thinking: response });
|
||||
},
|
||||
onToolStart: (name, input) => {
|
||||
updateState({
|
||||
toolCalls: [
|
||||
...state.toolCalls,
|
||||
{ name, input, status: 'running' },
|
||||
],
|
||||
});
|
||||
},
|
||||
onToolEnd: (name, output, durationMs) => {
|
||||
updateState({
|
||||
toolCalls: state.toolCalls.map((tc) =>
|
||||
tc.name === name && tc.status === 'running'
|
||||
? { ...tc, output, durationMs, status: 'success' as const }
|
||||
: tc
|
||||
),
|
||||
});
|
||||
},
|
||||
onNodeStart: (_, phase) => {
|
||||
updateState({ currentPhase: phase });
|
||||
},
|
||||
onProgress: (current, total, _) => {
|
||||
updateState({
|
||||
progress: {
|
||||
current,
|
||||
total,
|
||||
percentage: total > 0 ? Math.round((current / total) * 100) : 0,
|
||||
},
|
||||
});
|
||||
},
|
||||
onFinding: (finding, _) => {
|
||||
updateState({
|
||||
findings: [...state.findings, finding],
|
||||
});
|
||||
},
|
||||
onComplete: () => {
|
||||
updateState({ isComplete: true });
|
||||
},
|
||||
onError: (error) => {
|
||||
updateState({ error, isComplete: true });
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
|
|
@ -43,6 +43,9 @@ export interface AgentTask {
|
|||
|
||||
// 进度
|
||||
progress_percentage: number;
|
||||
|
||||
// 错误信息
|
||||
error_message: string | null;
|
||||
}
|
||||
|
||||
export interface AgentFinding {
|
||||
|
|
@ -249,7 +252,7 @@ export async function* streamAgentEvents(
|
|||
afterSequence = 0,
|
||||
signal?: AbortSignal
|
||||
): AsyncGenerator<AgentEvent, void, unknown> {
|
||||
const token = localStorage.getItem("auth_token");
|
||||
const token = localStorage.getItem("access_token");
|
||||
const baseUrl = import.meta.env.VITE_API_URL || "";
|
||||
const url = `${baseUrl}/api/v1/agent-tasks/${taskId}/events?after_sequence=${afterSequence}`;
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue