From 58c918f55729d133343a54f298298723cda7b9ac Mon Sep 17 00:00:00 2001 From: lintsinghua Date: Thu, 11 Dec 2025 20:33:46 +0800 Subject: [PATCH] feat(agent): implement streaming support for agent events and enhance UI components - Introduce streaming capabilities for agent events, allowing real-time updates during audits. - Add new hooks for managing agent stream events in React components. - Enhance the AgentAudit page to display LLM thinking processes and tool call details in real-time. - Update API endpoints to support streaming event data and improve error handling. - Refactor UI components for better organization and user experience during audits. --- backend/Dockerfile | 2 + backend/alembic.ini | 2 + backend/alembic/env.py | 2 + backend/alembic/script.py.mako | 2 + backend/app/api/v1/endpoints/agent_tasks.py | 166 ++++++ backend/app/api/v1/endpoints/members.py | 2 + backend/app/api/v1/endpoints/users.py | 2 + backend/app/core/security.py | 2 + backend/app/db/base.py | 2 + backend/app/db/session.py | 2 + backend/app/models/analysis.py | 2 + backend/app/models/user.py | 2 + backend/app/models/user_config.py | 2 + backend/app/schemas/token.py | 2 + backend/app/schemas/user.py | 2 + backend/app/services/agent/agents/analysis.py | 271 +++++++++- .../app/services/agent/agents/react_agent.py | 375 +++++++++++++ backend/app/services/agent/event_manager.py | 42 ++ backend/app/services/agent/graph/runner.py | 329 ++++++++++-- .../app/services/agent/streaming/__init__.py | 17 + .../agent/streaming/stream_handler.py | 453 ++++++++++++++++ .../agent/streaming/token_streamer.py | 261 +++++++++ .../services/agent/streaming/tool_stream.py | 319 +++++++++++ .../app/services/agent/tools/pattern_tool.py | 10 + .../services/llm/adapters/baidu_adapter.py | 2 + .../services/llm/adapters/doubao_adapter.py | 2 + .../services/llm/adapters/minimax_adapter.py | 2 + backend/app/services/llm/base_adapter.py | 2 + backend/app/services/llm/types.py | 2 + backend/tests/agent/__init__.py | 5 + backend/tests/agent/conftest.py | 295 +++++++++++ backend/tests/agent/run_tests.py | 96 ++++ backend/tests/agent/test_agents.py | 213 ++++++++ backend/tests/agent/test_integration.py | 355 +++++++++++++ backend/tests/agent/test_tools.py | 248 +++++++++ docs/AGENT_DEPLOYMENT_CHECKLIST.md | 229 ++++++++ frontend/Dockerfile | 2 + frontend/src/app/ProtectedRoute.tsx | 2 + frontend/src/hooks/useAgentStream.ts | 280 ++++++++++ frontend/src/pages/AgentAudit.tsx | 234 ++++++++- frontend/src/shared/api/agentStream.ts | 496 ++++++++++++++++++ frontend/src/shared/api/agentTasks.ts | 5 +- 42 files changed, 4664 insertions(+), 77 deletions(-) create mode 100644 backend/app/services/agent/agents/react_agent.py create mode 100644 backend/app/services/agent/streaming/__init__.py create mode 100644 backend/app/services/agent/streaming/stream_handler.py create mode 100644 backend/app/services/agent/streaming/token_streamer.py create mode 100644 backend/app/services/agent/streaming/tool_stream.py create mode 100644 backend/tests/agent/__init__.py create mode 100644 backend/tests/agent/conftest.py create mode 100644 backend/tests/agent/run_tests.py create mode 100644 backend/tests/agent/test_agents.py create mode 100644 backend/tests/agent/test_integration.py create mode 100644 backend/tests/agent/test_tools.py create mode 100644 docs/AGENT_DEPLOYMENT_CHECKLIST.md create mode 100644 frontend/src/hooks/useAgentStream.ts create mode 100644 frontend/src/shared/api/agentStream.ts diff --git a/backend/Dockerfile b/backend/Dockerfile index 
53294a4..3e1c745 100644 --- a/backend/Dockerfile +++ b/backend/Dockerfile @@ -59,3 +59,5 @@ CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"] + + diff --git a/backend/alembic.ini b/backend/alembic.ini index 72b7dfb..05ad917 100644 --- a/backend/alembic.ini +++ b/backend/alembic.ini @@ -103,3 +103,5 @@ datefmt = %H:%M:%S + + diff --git a/backend/alembic/env.py b/backend/alembic/env.py index 87c51e8..905f379 100644 --- a/backend/alembic/env.py +++ b/backend/alembic/env.py @@ -90,3 +90,5 @@ else: + + diff --git a/backend/alembic/script.py.mako b/backend/alembic/script.py.mako index b394332..d37ff90 100644 --- a/backend/alembic/script.py.mako +++ b/backend/alembic/script.py.mako @@ -25,3 +25,5 @@ def downgrade() -> None: + + diff --git a/backend/app/api/v1/endpoints/agent_tasks.py b/backend/app/api/v1/endpoints/agent_tasks.py index 9f48b99..a9f8880 100644 --- a/backend/app/api/v1/endpoints/agent_tasks.py +++ b/backend/app/api/v1/endpoints/agent_tasks.py @@ -29,6 +29,7 @@ from app.models.agent_task import ( from app.models.project import Project from app.models.user import User from app.services.agent import AgentRunner, EventManager, run_agent_task +from app.services.agent.streaming import StreamHandler, StreamEvent, StreamEventType logger = logging.getLogger(__name__) router = APIRouter() @@ -432,6 +433,171 @@ async def stream_agent_events( ) +@router.get("/{task_id}/stream") +async def stream_agent_with_thinking( + task_id: str, + include_thinking: bool = Query(True, description="是否包含 LLM 思考过程"), + include_tool_calls: bool = Query(True, description="是否包含工具调用详情"), + after_sequence: int = Query(0, ge=0, description="从哪个序号之后开始"), + db: AsyncSession = Depends(get_db), + current_user: User = Depends(deps.get_current_user), +): + """ + 增强版事件流 (SSE) + + 支持: + - LLM 思考过程的 Token 级流式输出 + - 工具调用的详细输入/输出 + - 节点执行状态 + - 发现事件 + + 事件类型: + - thinking_start: LLM 开始思考 + - thinking_token: LLM 输出 Token + - thinking_end: LLM 思考结束 + - tool_call_start: 工具调用开始 + - tool_call_end: 工具调用结束 + - node_start: 节点开始 + - node_end: 节点结束 + - finding_new: 新发现 + - finding_verified: 验证通过 + - progress: 进度更新 + - task_complete: 任务完成 + - task_error: 任务错误 + - heartbeat: 心跳 + """ + task = await db.get(AgentTask, task_id) + if not task: + raise HTTPException(status_code=404, detail="任务不存在") + + project = await db.get(Project, task.project_id) + if not project or project.owner_id != current_user.id: + raise HTTPException(status_code=403, detail="无权访问此任务") + + async def enhanced_event_generator(): + """生成增强版 SSE 事件流""" + last_sequence = after_sequence + poll_interval = 0.3 # 更短的轮询间隔以支持流式 + heartbeat_interval = 15 # 心跳间隔 + max_idle = 600 # 10 分钟无事件后关闭 + idle_time = 0 + last_heartbeat = 0 + + # 事件类型过滤 + skip_types = set() + if not include_thinking: + skip_types.update(["thinking_start", "thinking_token", "thinking_end"]) + if not include_tool_calls: + skip_types.update(["tool_call_start", "tool_call_input", "tool_call_output", "tool_call_end"]) + + while True: + try: + async with async_session_factory() as session: + # 查询新事件 + result = await session.execute( + select(AgentEvent) + .where(AgentEvent.task_id == task_id) + .where(AgentEvent.sequence > last_sequence) + .order_by(AgentEvent.sequence) + .limit(100) + ) + events = result.scalars().all() + + # 获取任务状态 + current_task = await session.get(AgentTask, task_id) + task_status = current_task.status if current_task else None + + if events: + idle_time = 0 + for event in events: + last_sequence = event.sequence + + # 获取事件类型字符串 + event_type = event.event_type.value if 
hasattr(event.event_type, 'value') else str(event.event_type) + + # 过滤事件 + if event_type in skip_types: + continue + + # 构建事件数据 + data = { + "id": event.id, + "type": event_type, + "phase": event.phase.value if event.phase and hasattr(event.phase, 'value') else event.phase, + "message": event.message, + "sequence": event.sequence, + "timestamp": event.created_at.isoformat() if event.created_at else None, + } + + # 添加工具调用详情 + if include_tool_calls and event.tool_name: + data["tool"] = { + "name": event.tool_name, + "input": event.tool_input, + "output": event.tool_output, + "duration_ms": event.tool_duration_ms, + } + + # 添加元数据 + if event.event_metadata: + data["metadata"] = event.event_metadata + + # 添加 Token 使用 + if event.tokens_used: + data["tokens_used"] = event.tokens_used + + # 使用标准 SSE 格式 + yield f"event: {event_type}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n" + else: + idle_time += poll_interval + + # 检查任务是否结束 + if task_status in [AgentTaskStatus.COMPLETED, AgentTaskStatus.FAILED, AgentTaskStatus.CANCELLED]: + end_data = { + "type": "task_end", + "status": task_status.value, + "message": f"任务{'完成' if task_status == AgentTaskStatus.COMPLETED else '结束'}", + } + yield f"event: task_end\ndata: {json.dumps(end_data, ensure_ascii=False)}\n\n" + break + + # 发送心跳 + last_heartbeat += poll_interval + if last_heartbeat >= heartbeat_interval: + last_heartbeat = 0 + heartbeat_data = { + "type": "heartbeat", + "timestamp": datetime.now(timezone.utc).isoformat(), + "last_sequence": last_sequence, + } + yield f"event: heartbeat\ndata: {json.dumps(heartbeat_data)}\n\n" + + # 检查空闲超时 + if idle_time >= max_idle: + timeout_data = {"type": "timeout", "message": "连接超时"} + yield f"event: timeout\ndata: {json.dumps(timeout_data)}\n\n" + break + + await asyncio.sleep(poll_interval) + + except Exception as e: + logger.error(f"Stream error: {e}") + error_data = {"type": "error", "message": str(e)} + yield f"event: error\ndata: {json.dumps(error_data)}\n\n" + break + + return StreamingResponse( + enhanced_event_generator(), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no", + "Content-Type": "text/event-stream; charset=utf-8", + } + ) + + @router.get("/{task_id}/events/list", response_model=List[AgentEventResponse]) async def list_agent_events( task_id: str, diff --git a/backend/app/api/v1/endpoints/members.py b/backend/app/api/v1/endpoints/members.py index 612ccbe..0e8bcbb 100644 --- a/backend/app/api/v1/endpoints/members.py +++ b/backend/app/api/v1/endpoints/members.py @@ -210,3 +210,5 @@ async def remove_project_member( + + diff --git a/backend/app/api/v1/endpoints/users.py b/backend/app/api/v1/endpoints/users.py index 951f603..e25cae3 100644 --- a/backend/app/api/v1/endpoints/users.py +++ b/backend/app/api/v1/endpoints/users.py @@ -225,3 +225,5 @@ async def toggle_user_status( + + diff --git a/backend/app/core/security.py b/backend/app/core/security.py index d0e5295..3bf05b9 100644 --- a/backend/app/core/security.py +++ b/backend/app/core/security.py @@ -29,3 +29,5 @@ def get_password_hash(password: str) -> str: + + diff --git a/backend/app/db/base.py b/backend/app/db/base.py index 98a9a1b..15b97ae 100644 --- a/backend/app/db/base.py +++ b/backend/app/db/base.py @@ -12,3 +12,5 @@ class Base: + + diff --git a/backend/app/db/session.py b/backend/app/db/session.py index 8b4db67..3c81360 100644 --- a/backend/app/db/session.py +++ b/backend/app/db/session.py @@ -28,3 +28,5 @@ async def async_session_factory(): + + diff 
--git a/backend/app/models/analysis.py b/backend/app/models/analysis.py index c55d47a..7d26d2b 100644 --- a/backend/app/models/analysis.py +++ b/backend/app/models/analysis.py @@ -24,3 +24,5 @@ class InstantAnalysis(Base): + + diff --git a/backend/app/models/user.py b/backend/app/models/user.py index 401f8d3..bd3e106 100644 --- a/backend/app/models/user.py +++ b/backend/app/models/user.py @@ -25,3 +25,5 @@ class User(Base): + + diff --git a/backend/app/models/user_config.py b/backend/app/models/user_config.py index 9a30a2f..31953a1 100644 --- a/backend/app/models/user_config.py +++ b/backend/app/models/user_config.py @@ -30,3 +30,5 @@ class UserConfig(Base): + + diff --git a/backend/app/schemas/token.py b/backend/app/schemas/token.py index d3908cd..588a1e9 100644 --- a/backend/app/schemas/token.py +++ b/backend/app/schemas/token.py @@ -10,3 +10,5 @@ class TokenPayload(BaseModel): + + diff --git a/backend/app/schemas/user.py b/backend/app/schemas/user.py index ffce190..4590e9f 100644 --- a/backend/app/schemas/user.py +++ b/backend/app/schemas/user.py @@ -41,3 +41,5 @@ class UserListResponse(BaseModel): + + diff --git a/backend/app/services/agent/agents/analysis.py b/backend/app/services/agent/agents/analysis.py index ecdf8e5..e429bf8 100644 --- a/backend/app/services/agent/agents/analysis.py +++ b/backend/app/services/agent/agents/analysis.py @@ -147,11 +147,11 @@ class AnalysisAgent(BaseAgent): deep_findings = await self._analyze_entry_points(entry_points) all_findings.extend(deep_findings) - # 分析高风险区域 + # 分析高风险区域(现在会调用 LLM) risk_findings = await self._analyze_high_risk_areas(high_risk_areas) all_findings.extend(risk_findings) - # 语义搜索常见漏洞 + # 语义搜索常见漏洞(现在会调用 LLM) vuln_types = config.get("target_vulnerabilities", [ "sql_injection", "xss", "command_injection", "path_traversal", "ssrf", "hardcoded_secret", @@ -164,6 +164,12 @@ class AnalysisAgent(BaseAgent): await self.emit_thinking(f"搜索 {vuln_type} 相关代码...") vuln_findings = await self._search_vulnerability_pattern(vuln_type) all_findings.extend(vuln_findings) + + # 🔥 3. 
如果还没有发现,使用 LLM 进行全面扫描 + if len(all_findings) < 3: + await self.emit_thinking("执行 LLM 全面代码扫描...") + llm_findings = await self._llm_comprehensive_scan(tech_stack) + all_findings.extend(llm_findings) # 去重 all_findings = self._deduplicate_findings(all_findings) @@ -292,12 +298,12 @@ class AnalysisAgent(BaseAgent): return findings async def _analyze_high_risk_areas(self, high_risk_areas: List[str]) -> List[Dict]: - """分析高风险区域""" + """分析高风险区域 - 使用 LLM 深度分析""" findings = [] - pattern_tool = self.tools.get("pattern_match") read_tool = self.tools.get("read_file") search_tool = self.tools.get("search_code") + code_analysis_tool = self.tools.get("code_analysis") if not search_tool: return findings @@ -305,36 +311,92 @@ class AnalysisAgent(BaseAgent): # 在高风险区域搜索危险模式 dangerous_patterns = [ ("execute(", "sql_injection"), + ("query(", "sql_injection"), ("eval(", "code_injection"), ("system(", "command_injection"), ("exec(", "command_injection"), + ("subprocess", "command_injection"), ("innerHTML", "xss"), ("document.write", "xss"), + ("open(", "path_traversal"), + ("requests.get", "ssrf"), ] - for pattern, vuln_type in dangerous_patterns[:5]: + analyzed_files = set() + + for pattern, vuln_type in dangerous_patterns[:8]: if self.is_cancelled: break result = await search_tool.execute(keyword=pattern, max_results=10) if result.success and result.metadata.get("matches", 0) > 0: - for match in result.metadata.get("results", [])[:3]: + for match in result.metadata.get("results", [])[:5]: file_path = match.get("file", "") + line = match.get("line", 0) - # 检查是否在高风险区域 - in_high_risk = any( - area in file_path for area in high_risk_areas - ) + # 避免重复分析同一个文件的同一区域 + file_key = f"{file_path}:{line // 50}" + if file_key in analyzed_files: + continue + analyzed_files.add(file_key) - if in_high_risk or True: # 暂时包含所有 + # 🔥 使用 LLM 深度分析找到的代码 + if read_tool and code_analysis_tool: + await self.emit_thinking(f"LLM 分析 {file_path}:{line} 的 {vuln_type} 风险...") + + # 读取代码上下文 + read_result = await read_tool.execute( + file_path=file_path, + start_line=max(1, line - 15), + end_line=line + 25, + ) + + if read_result.success: + # 调用 LLM 分析 + analysis_result = await code_analysis_tool.execute( + code=read_result.data, + file_path=file_path, + focus=vuln_type, + ) + + if analysis_result.success and analysis_result.metadata.get("issues"): + for issue in analysis_result.metadata["issues"]: + findings.append({ + "vulnerability_type": issue.get("type", vuln_type), + "severity": issue.get("severity", "medium"), + "title": issue.get("title", f"LLM 发现: {vuln_type}"), + "description": issue.get("description", ""), + "file_path": file_path, + "line_start": issue.get("line", line), + "code_snippet": issue.get("code_snippet", match.get("match", "")), + "suggestion": issue.get("suggestion", ""), + "ai_explanation": issue.get("ai_explanation", ""), + "source": "llm_analysis", + "needs_verification": True, + }) + elif analysis_result.success: + # LLM 分析了但没发现问题,仍记录原始发现 + findings.append({ + "vulnerability_type": vuln_type, + "severity": "low", + "title": f"疑似 {vuln_type}: {pattern}", + "description": f"在 {file_path} 中发现危险模式,但 LLM 分析未确认", + "file_path": file_path, + "line_start": line, + "code_snippet": match.get("match", ""), + "source": "pattern_search", + "needs_verification": True, + }) + else: + # 没有 LLM 工具,使用基础模式匹配 findings.append({ "vulnerability_type": vuln_type, - "severity": "high" if in_high_risk else "medium", + "severity": "medium", "title": f"疑似 {vuln_type}: {pattern}", "description": f"在 {file_path} 中发现危险模式 {pattern}", "file_path": 
file_path, - "line_start": match.get("line", 0), + "line_start": line, "code_snippet": match.get("match", ""), "source": "pattern_search", "needs_verification": True, @@ -343,10 +405,13 @@ class AnalysisAgent(BaseAgent): return findings async def _search_vulnerability_pattern(self, vuln_type: str) -> List[Dict]: - """搜索特定漏洞模式""" + """搜索特定漏洞模式 - 使用 RAG + LLM""" findings = [] security_tool = self.tools.get("security_search") + code_analysis_tool = self.tools.get("code_analysis") + read_tool = self.tools.get("read_file") + if not security_tool: return findings @@ -357,20 +422,176 @@ class AnalysisAgent(BaseAgent): if result.success and result.metadata.get("results_count", 0) > 0: for item in result.metadata.get("results", [])[:5]: - findings.append({ - "vulnerability_type": vuln_type, - "severity": "medium", - "title": f"疑似 {vuln_type}", - "description": f"通过语义搜索发现可能存在 {vuln_type}", - "file_path": item.get("file_path", ""), - "line_start": item.get("line_start", 0), - "code_snippet": item.get("content", "")[:500], - "source": "rag_search", - "needs_verification": True, - }) + file_path = item.get("file_path", "") + line_start = item.get("line_start", 0) + content = item.get("content", "")[:2000] + + # 🔥 使用 LLM 验证 RAG 搜索结果 + if code_analysis_tool and content: + await self.emit_thinking(f"LLM 验证 RAG 发现的 {vuln_type}...") + + analysis_result = await code_analysis_tool.execute( + code=content, + file_path=file_path, + focus=vuln_type, + ) + + if analysis_result.success and analysis_result.metadata.get("issues"): + for issue in analysis_result.metadata["issues"]: + findings.append({ + "vulnerability_type": issue.get("type", vuln_type), + "severity": issue.get("severity", "medium"), + "title": issue.get("title", f"LLM 确认: {vuln_type}"), + "description": issue.get("description", ""), + "file_path": file_path, + "line_start": issue.get("line", line_start), + "code_snippet": issue.get("code_snippet", content[:500]), + "suggestion": issue.get("suggestion", ""), + "ai_explanation": issue.get("ai_explanation", ""), + "source": "rag_llm_analysis", + "needs_verification": True, + }) + else: + # RAG 找到但 LLM 未确认 + findings.append({ + "vulnerability_type": vuln_type, + "severity": "low", + "title": f"疑似 {vuln_type} (待确认)", + "description": f"RAG 搜索发现可能存在 {vuln_type},但 LLM 未确认", + "file_path": file_path, + "line_start": line_start, + "code_snippet": content[:500], + "source": "rag_search", + "needs_verification": True, + }) + else: + findings.append({ + "vulnerability_type": vuln_type, + "severity": "medium", + "title": f"疑似 {vuln_type}", + "description": f"通过语义搜索发现可能存在 {vuln_type}", + "file_path": file_path, + "line_start": line_start, + "code_snippet": content[:500], + "source": "rag_search", + "needs_verification": True, + }) return findings + async def _llm_comprehensive_scan(self, tech_stack: Dict) -> List[Dict]: + """ + LLM 全面代码扫描 + 当其他方法没有发现足够的问题时,使用 LLM 直接分析关键文件 + """ + findings = [] + + list_tool = self.tools.get("list_files") + read_tool = self.tools.get("read_file") + code_analysis_tool = self.tools.get("code_analysis") + + if not all([list_tool, read_tool, code_analysis_tool]): + return findings + + await self.emit_thinking("LLM 全面扫描关键代码文件...") + + # 确定要扫描的文件类型 + languages = tech_stack.get("languages", []) + file_patterns = [] + + if "Python" in languages: + file_patterns.extend(["*.py"]) + if "JavaScript" in languages or "TypeScript" in languages: + file_patterns.extend(["*.js", "*.ts"]) + if "Go" in languages: + file_patterns.extend(["*.go"]) + if "Java" in languages: + 
file_patterns.extend(["*.java"]) + if "PHP" in languages: + file_patterns.extend(["*.php"]) + + if not file_patterns: + file_patterns = ["*.py", "*.js", "*.ts", "*.go", "*.java", "*.php"] + + # 扫描关键目录 + key_dirs = ["src", "app", "api", "routes", "controllers", "handlers", "lib", "utils", "."] + scanned_files = 0 + max_files_to_scan = 10 + + for key_dir in key_dirs: + if scanned_files >= max_files_to_scan or self.is_cancelled: + break + + for pattern in file_patterns[:3]: + if scanned_files >= max_files_to_scan or self.is_cancelled: + break + + # 列出文件 + list_result = await list_tool.execute( + directory=key_dir, + pattern=pattern, + recursive=True, + max_files=20, + ) + + if not list_result.success: + continue + + # 从输出中提取文件路径 + output = list_result.data + file_paths = [] + for line in output.split('\n'): + line = line.strip() + if line.startswith('📄 '): + file_paths.append(line[2:].strip()) + + # 分析每个文件 + for file_path in file_paths[:5]: + if scanned_files >= max_files_to_scan or self.is_cancelled: + break + + # 跳过测试文件和配置文件 + if any(skip in file_path.lower() for skip in ['test', 'spec', 'mock', '__pycache__', 'node_modules']): + continue + + await self.emit_thinking(f"LLM 分析文件: {file_path}") + + # 读取文件 + read_result = await read_tool.execute( + file_path=file_path, + max_lines=200, + ) + + if not read_result.success: + continue + + scanned_files += 1 + + # 🔥 LLM 深度分析 + analysis_result = await code_analysis_tool.execute( + code=read_result.data, + file_path=file_path, + ) + + if analysis_result.success and analysis_result.metadata.get("issues"): + for issue in analysis_result.metadata["issues"]: + findings.append({ + "vulnerability_type": issue.get("type", "other"), + "severity": issue.get("severity", "medium"), + "title": issue.get("title", "LLM 发现的安全问题"), + "description": issue.get("description", ""), + "file_path": file_path, + "line_start": issue.get("line", 0), + "code_snippet": issue.get("code_snippet", ""), + "suggestion": issue.get("suggestion", ""), + "ai_explanation": issue.get("ai_explanation", ""), + "source": "llm_comprehensive_scan", + "needs_verification": True, + }) + + await self.emit_thinking(f"LLM 全面扫描完成,分析了 {scanned_files} 个文件") + return findings + def _deduplicate_findings(self, findings: List[Dict]) -> List[Dict]: """去重发现""" seen = set() diff --git a/backend/app/services/agent/agents/react_agent.py b/backend/app/services/agent/agents/react_agent.py new file mode 100644 index 0000000..e7f5631 --- /dev/null +++ b/backend/app/services/agent/agents/react_agent.py @@ -0,0 +1,375 @@ +""" +真正的 ReAct Agent 实现 +LLM 是大脑,全程参与决策! + +ReAct 循环: +1. Thought: LLM 思考当前状态和下一步 +2. Action: LLM 决定调用哪个工具 +3. Observation: 执行工具,获取结果 +4. 重复直到 LLM 决定完成 +""" + +import json +import logging +import re +from typing import List, Dict, Any, Optional, Tuple +from dataclasses import dataclass + +from .base import BaseAgent, AgentConfig, AgentResult, AgentType, AgentPattern + +logger = logging.getLogger(__name__) + + +REACT_SYSTEM_PROMPT = """你是 DeepAudit 安全审计 Agent,一个专业的代码安全分析专家。 + +## 你的任务 +对目标项目进行全面的安全审计,发现潜在的安全漏洞。 + +## 你的工具 +{tools_description} + +## 工作方式 +你需要通过 **思考-行动-观察** 循环来完成任务: + +1. **Thought**: 分析当前情况,思考下一步应该做什么 +2. **Action**: 选择一个工具并执行 +3. **Observation**: 观察工具返回的结果 +4. 
重复上述过程直到你认为审计完成 + +## 输出格式 +每一步必须严格按照以下格式输出: + +``` +Thought: [你的思考过程,分析当前状态,决定下一步] +Action: [工具名称] +Action Input: [工具参数,JSON 格式] +``` + +当你完成分析后,输出: +``` +Thought: [总结分析结果] +Final Answer: [JSON 格式的最终发现] +``` + +## Final Answer 格式 +```json +{{ + "findings": [ + {{ + "vulnerability_type": "sql_injection", + "severity": "high", + "title": "SQL 注入漏洞", + "description": "详细描述", + "file_path": "path/to/file.py", + "line_start": 42, + "code_snippet": "危险代码片段", + "suggestion": "修复建议" + }} + ], + "summary": "审计总结" +}} +``` + +## 审计策略建议 +1. 先用 list_files 了解项目结构 +2. 识别关键文件(路由、控制器、数据库操作) +3. 使用 search_code 搜索危险模式(eval, exec, query, innerHTML 等) +4. 读取可疑文件进行深度分析 +5. 如果有 semgrep,用它进行全面扫描 + +## 重点关注的漏洞类型 +- SQL 注入 (query, execute, raw SQL) +- XSS (innerHTML, document.write, v-html) +- 命令注入 (exec, system, subprocess, child_process) +- 路径遍历 (open, readFile, path concatenation) +- SSRF (requests, fetch, http client) +- 硬编码密钥 (password, secret, api_key, token) +- 不安全的反序列化 (pickle, yaml.load, eval) + +现在开始审计!""" + + +@dataclass +class AgentStep: + """Agent 执行步骤""" + thought: str + action: Optional[str] = None + action_input: Optional[Dict] = None + observation: Optional[str] = None + is_final: bool = False + final_answer: Optional[Dict] = None + + +class ReActAgent(BaseAgent): + """ + 真正的 ReAct Agent + + LLM 全程参与决策,自主选择工具和分析策略 + """ + + def __init__( + self, + llm_service, + tools: Dict[str, Any], + event_emitter=None, + agent_type: AgentType = AgentType.ANALYSIS, + max_iterations: int = 30, + ): + config = AgentConfig( + name="ReActAgent", + agent_type=agent_type, + pattern=AgentPattern.REACT, + max_iterations=max_iterations, + system_prompt=REACT_SYSTEM_PROMPT, + ) + super().__init__(config, llm_service, tools, event_emitter) + + self._conversation_history: List[Dict[str, str]] = [] + self._steps: List[AgentStep] = [] + + def _get_tools_description(self) -> str: + """生成工具描述""" + descriptions = [] + + for name, tool in self.tools.items(): + if name.startswith("_"): + continue + + desc = f"### {name}\n" + desc += f"{tool.description}\n" + + # 添加参数说明 + if hasattr(tool, 'args_schema') and tool.args_schema: + schema = tool.args_schema.schema() + properties = schema.get("properties", {}) + if properties: + desc += "参数:\n" + for param_name, param_info in properties.items(): + param_desc = param_info.get("description", "") + param_type = param_info.get("type", "string") + desc += f" - {param_name} ({param_type}): {param_desc}\n" + + descriptions.append(desc) + + return "\n".join(descriptions) + + def _build_system_prompt(self, project_info: Dict, task_context: str = "") -> str: + """构建系统提示词""" + tools_desc = self._get_tools_description() + prompt = self.config.system_prompt.format(tools_description=tools_desc) + + if project_info: + prompt += f"\n\n## 项目信息\n" + prompt += f"- 名称: {project_info.get('name', 'unknown')}\n" + prompt += f"- 语言: {', '.join(project_info.get('languages', ['unknown']))}\n" + prompt += f"- 文件数: {project_info.get('file_count', 'unknown')}\n" + + if task_context: + prompt += f"\n\n## 任务上下文\n{task_context}" + + return prompt + + def _parse_llm_response(self, response: str) -> AgentStep: + """解析 LLM 响应""" + step = AgentStep(thought="") + + # 提取 Thought + thought_match = re.search(r'Thought:\s*(.*?)(?=Action:|Final Answer:|$)', response, re.DOTALL) + if thought_match: + step.thought = thought_match.group(1).strip() + + # 检查是否是最终答案 + final_match = re.search(r'Final Answer:\s*(.*?)$', response, re.DOTALL) + if final_match: + step.is_final = True + try: + # 尝试提取 JSON + answer_text = 
final_match.group(1).strip() + # 移除 markdown 代码块 + answer_text = re.sub(r'```json\s*', '', answer_text) + answer_text = re.sub(r'```\s*', '', answer_text) + step.final_answer = json.loads(answer_text) + except json.JSONDecodeError: + step.final_answer = {"raw_answer": final_match.group(1).strip()} + return step + + # 提取 Action + action_match = re.search(r'Action:\s*(\w+)', response) + if action_match: + step.action = action_match.group(1).strip() + + # 提取 Action Input + input_match = re.search(r'Action Input:\s*(.*?)(?=Thought:|Action:|Observation:|$)', response, re.DOTALL) + if input_match: + input_text = input_match.group(1).strip() + # 移除 markdown 代码块 + input_text = re.sub(r'```json\s*', '', input_text) + input_text = re.sub(r'```\s*', '', input_text) + try: + step.action_input = json.loads(input_text) + except json.JSONDecodeError: + # 尝试简单解析 + step.action_input = {"raw_input": input_text} + + return step + + async def _execute_tool(self, tool_name: str, tool_input: Dict) -> str: + """执行工具""" + tool = self.tools.get(tool_name) + + if not tool: + return f"错误: 工具 '{tool_name}' 不存在。可用工具: {list(self.tools.keys())}" + + try: + self._tool_calls += 1 + await self.emit_tool_call(tool_name, tool_input) + + import time + start = time.time() + + result = await tool.execute(**tool_input) + + duration_ms = int((time.time() - start) * 1000) + await self.emit_tool_result(tool_name, str(result.data)[:200], duration_ms) + + if result.success: + # 截断过长的输出 + output = str(result.data) + if len(output) > 4000: + output = output[:4000] + "\n\n... [输出已截断,共 {} 字符]".format(len(str(result.data))) + return output + else: + return f"工具执行失败: {result.error}" + + except Exception as e: + logger.error(f"Tool execution error: {e}") + return f"工具执行错误: {str(e)}" + + async def run(self, input_data: Dict[str, Any]) -> AgentResult: + """ + 执行 ReAct Agent + + LLM 全程参与,自主决策! 
+ """ + import time + start_time = time.time() + + project_info = input_data.get("project_info", {}) + task_context = input_data.get("task_context", "") + config = input_data.get("config", {}) + + # 构建系统提示词 + system_prompt = self._build_system_prompt(project_info, task_context) + + # 初始化对话历史 + self._conversation_history = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": "请开始对项目进行安全审计。首先了解项目结构,然后系统性地搜索和分析潜在的安全漏洞。"}, + ] + + self._steps = [] + all_findings = [] + + await self.emit_thinking("🤖 ReAct Agent 启动,LLM 开始自主分析...") + + try: + for iteration in range(self.config.max_iterations): + if self.is_cancelled: + break + + self._iteration = iteration + 1 + + await self.emit_thinking(f"💭 第 {iteration + 1} 轮思考...") + + # 🔥 调用 LLM 进行思考和决策 + response = await self.llm_service.chat_completion_raw( + messages=self._conversation_history, + temperature=0.1, + max_tokens=2048, + ) + + llm_output = response.get("content", "") + self._total_tokens += response.get("usage", {}).get("total_tokens", 0) + + # 发射思考事件 + await self.emit_event("thinking", f"LLM: {llm_output[:500]}...") + + # 解析 LLM 响应 + step = self._parse_llm_response(llm_output) + self._steps.append(step) + + # 添加 LLM 响应到历史 + self._conversation_history.append({ + "role": "assistant", + "content": llm_output, + }) + + # 检查是否完成 + if step.is_final: + await self.emit_thinking("✅ LLM 完成分析,生成最终报告") + + if step.final_answer and "findings" in step.final_answer: + all_findings = step.final_answer["findings"] + break + + # 执行工具 + if step.action: + await self.emit_thinking(f"🔧 LLM 决定调用工具: {step.action}") + + observation = await self._execute_tool( + step.action, + step.action_input or {} + ) + + step.observation = observation + + # 添加观察结果到历史 + self._conversation_history.append({ + "role": "user", + "content": f"Observation: {observation}", + }) + else: + # LLM 没有选择工具,提示它继续 + self._conversation_history.append({ + "role": "user", + "content": "请继续分析,选择一个工具执行,或者如果分析完成,输出 Final Answer。", + }) + + duration_ms = int((time.time() - start_time) * 1000) + + await self.emit_event( + "info", + f"🎯 ReAct Agent 完成: {len(all_findings)} 个发现, {self._iteration} 轮迭代, {self._tool_calls} 次工具调用" + ) + + return AgentResult( + success=True, + data={ + "findings": all_findings, + "steps": [ + { + "thought": s.thought, + "action": s.action, + "action_input": s.action_input, + "observation": s.observation[:500] if s.observation else None, + } + for s in self._steps + ], + }, + iterations=self._iteration, + tool_calls=self._tool_calls, + tokens_used=self._total_tokens, + duration_ms=duration_ms, + ) + + except Exception as e: + logger.error(f"ReAct Agent failed: {e}", exc_info=True) + return AgentResult(success=False, error=str(e)) + + def get_conversation_history(self) -> List[Dict[str, str]]: + """获取对话历史""" + return self._conversation_history + + def get_steps(self) -> List[AgentStep]: + """获取执行步骤""" + return self._steps diff --git a/backend/app/services/agent/event_manager.py b/backend/app/services/agent/event_manager.py index b90307d..77898ef 100644 --- a/backend/app/services/agent/event_manager.py +++ b/backend/app/services/agent/event_manager.py @@ -192,6 +192,37 @@ class AgentEventEmitter: "percentage": percentage, }, )) + + async def emit_task_complete( + self, + findings_count: int, + duration_ms: int, + message: Optional[str] = None, + ): + """发射任务完成事件""" + await self.emit(AgentEventData( + event_type="task_complete", + message=message or f"✅ 审计完成!发现 {findings_count} 个漏洞,耗时 {duration_ms/1000:.1f}秒", + metadata={ + "findings_count": 
findings_count, + "duration_ms": duration_ms, + }, + )) + + async def emit_task_error(self, error: str, message: Optional[str] = None): + """发射任务错误事件""" + await self.emit(AgentEventData( + event_type="task_error", + message=message or f"❌ 任务失败: {error}", + metadata={"error": error}, + )) + + async def emit_task_cancelled(self, message: Optional[str] = None): + """发射任务取消事件""" + await self.emit(AgentEventData( + event_type="task_cancel", + message=message or "⚠️ 任务已取消", + )) class EventManager: @@ -368,4 +399,15 @@ class EventManager: def create_emitter(self, task_id: str) -> AgentEventEmitter: """创建事件发射器""" return AgentEventEmitter(task_id, self) + + async def close(self): + """关闭事件管理器,清理资源""" + # 清理所有队列 + for task_id in list(self._event_queues.keys()): + self.remove_queue(task_id) + + # 清理所有回调 + self._event_callbacks.clear() + + logger.debug("EventManager closed") diff --git a/backend/app/services/agent/graph/runner.py b/backend/app/services/agent/graph/runner.py index d303c8b..f8e3560 100644 --- a/backend/app/services/agent/graph/runner.py +++ b/backend/app/services/agent/graph/runner.py @@ -15,6 +15,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from langgraph.graph import StateGraph, END from langgraph.checkpoint.memory import MemorySaver +from app.services.agent.streaming import StreamHandler, StreamEvent, StreamEventType from app.models.agent_task import ( AgentTask, AgentEvent, AgentFinding, AgentTaskStatus, AgentTaskPhase, AgentEventType, @@ -39,11 +40,15 @@ logger = logging.getLogger(__name__) class LLMService: - """LLM 服务封装""" + """ + LLM 服务封装 + 提供代码分析、漏洞检测等 AI 功能 + """ def __init__(self, model: Optional[str] = None, api_key: Optional[str] = None): self.model = model or settings.LLM_MODEL or "gpt-4o-mini" self.api_key = api_key or settings.LLM_API_KEY + self.base_url = settings.LLM_BASE_URL async def chat_completion_raw( self, @@ -61,6 +66,7 @@ class LLMService: temperature=temperature, max_tokens=max_tokens, api_key=self.api_key, + base_url=self.base_url, ) return { @@ -75,6 +81,125 @@ class LLMService: except Exception as e: logger.error(f"LLM call failed: {e}") raise + + async def analyze_code(self, code: str, language: str) -> Dict[str, Any]: + """ + 分析代码安全问题 + + Args: + code: 代码内容 + language: 编程语言 + + Returns: + 分析结果,包含 issues 列表 + """ + prompt = f"""请分析以下 {language} 代码的安全问题。 + +代码: +```{language} +{code[:8000]} +``` + +请识别所有潜在的安全漏洞,包括但不限于: +- SQL 注入 +- XSS (跨站脚本) +- 命令注入 +- 路径遍历 +- 不安全的反序列化 +- 硬编码密钥/密码 +- 不安全的加密 +- SSRF +- 认证/授权问题 + +对于每个发现的问题,请提供: +1. 漏洞类型 +2. 严重程度 (critical/high/medium/low) +3. 问题描述 +4. 具体行号 +5. 
修复建议 + +请以 JSON 格式返回结果: +{{ + "issues": [ + {{ + "type": "漏洞类型", + "severity": "严重程度", + "title": "问题标题", + "description": "详细描述", + "line": 行号, + "code_snippet": "相关代码片段", + "suggestion": "修复建议" + }} + ], + "quality_score": 0-100 +}} + +如果没有发现安全问题,返回空的 issues 数组和较高的 quality_score。""" + + try: + result = await self.chat_completion_raw( + messages=[ + {"role": "system", "content": "你是一位专业的代码安全审计专家,擅长发现代码中的安全漏洞。请只返回 JSON 格式的结果,不要包含其他内容。"}, + {"role": "user", "content": prompt}, + ], + temperature=0.1, + max_tokens=4096, + ) + + content = result.get("content", "{}") + + # 尝试提取 JSON + import json + import re + + # 尝试直接解析 + try: + return json.loads(content) + except json.JSONDecodeError: + pass + + # 尝试从 markdown 代码块提取 + json_match = re.search(r'```(?:json)?\s*([\s\S]*?)\s*```', content) + if json_match: + try: + return json.loads(json_match.group(1)) + except json.JSONDecodeError: + pass + + # 返回空结果 + return {"issues": [], "quality_score": 80} + + except Exception as e: + logger.error(f"Code analysis failed: {e}") + return {"issues": [], "quality_score": 0, "error": str(e)} + + async def analyze_code_with_custom_prompt( + self, + code: str, + language: str, + prompt: str, + **kwargs + ) -> Dict[str, Any]: + """使用自定义提示词分析代码""" + full_prompt = prompt.replace("{code}", code).replace("{language}", language) + + try: + result = await self.chat_completion_raw( + messages=[ + {"role": "system", "content": "你是一位专业的代码安全审计专家。"}, + {"role": "user", "content": full_prompt}, + ], + temperature=0.1, + ) + + return { + "analysis": result.get("content", ""), + "usage": result.get("usage", {}), + } + + except Exception as e: + logger.error(f"Custom analysis failed: {e}") + return {"analysis": "", "error": str(e)} class AgentRunner: @@ -97,8 +222,9 @@ class AgentRunner: self.task = task self.project_root = project_root - # 事件管理 - self.event_manager = EventManager() + # 事件管理 - 传入 db_session_factory 以持久化事件 + from app.db.session import async_session_factory + self.event_manager = EventManager(db_session_factory=async_session_factory) self.event_emitter = AgentEventEmitter(task.id, self.event_manager) # LLM 服务 @@ -120,6 +246,22 @@ class AgentRunner: # 状态 self._cancelled = False + self._running_task: Optional[asyncio.Task] = None + + # 流式处理器 + self.stream_handler = StreamHandler(task.id) + + def cancel(self): + """取消任务""" + self._cancelled = True + if self._running_task and not self._running_task.done(): + self._running_task.cancel() + logger.info(f"Task {self.task.id} cancellation requested") + + @property + def is_cancelled(self) -> bool: + """检查是否已取消""" + return self._cancelled async def initialize(self): """初始化 Runner""" @@ -149,15 +291,15 @@ class AgentRunner: ) self.indexer = CodeIndexer( - embedding_service=embedding_service, - vector_db_path=settings.VECTOR_DB_PATH, collection_name=f"project_{self.task.project_id}", + embedding_service=embedding_service, + persist_directory=settings.VECTOR_DB_PATH, ) self.retriever = CodeRetriever( - embedding_service=embedding_service, - vector_db_path=settings.VECTOR_DB_PATH, collection_name=f"project_{self.task.project_id}", + embedding_service=embedding_service, + persist_directory=settings.VECTOR_DB_PATH, ) except Exception as e: @@ -261,6 +403,18 @@ class AgentRunner: Returns: 最终状态 """ + result = {} + async for _ in self.run_with_streaming(): + pass # 消费所有事件 + return result + + async def run_with_streaming(self) -> AsyncGenerator[StreamEvent, None]: + """ + 带流式输出的审计执行 + + Yields: + StreamEvent: 流式事件(包含 LLM 思考、工具调用等) + """ import time start_time = time.time() @@ -271,17 
+425,28 @@ class AgentRunner: # 更新任务状态 await self._update_task_status(AgentTaskStatus.RUNNING) + # 发射任务开始事件 + yield StreamEvent( + event_type=StreamEventType.TASK_START, + sequence=self.stream_handler._next_sequence(), + data={"task_id": self.task.id, "message": "🚀 审计任务开始"}, + ) + # 1. 索引代码 await self._index_code() if self._cancelled: - return {"success": False, "error": "任务已取消"} + yield StreamEvent( + event_type=StreamEventType.TASK_CANCEL, + sequence=self.stream_handler._next_sequence(), + data={"message": "任务已取消"}, + ) + return # 2. 收集项目信息 project_info = await self._collect_project_info() # 3. 构建初始状态 - # 从任务字段构建配置 task_config = { "target_vulnerabilities": self.task.target_vulnerabilities or [], "verification_level": self.task.verification_level or "sandbox", @@ -314,7 +479,7 @@ class AgentRunner: "error": None, } - # 4. 执行 LangGraph + # 4. 执行 LangGraph with astream_events await self.event_emitter.emit_phase_start("langgraph", "🔄 启动 LangGraph 工作流") run_config = { @@ -325,26 +490,57 @@ class AgentRunner: final_state = None - # 流式执行并发射事件 - async for event in self.graph.astream(initial_state, config=run_config): - if self._cancelled: - break - - # 处理每个节点的输出 - for node_name, node_output in event.items(): - await self._handle_node_output(node_name, node_output) + # 使用 astream_events 获取详细事件流 + try: + async for event in self.graph.astream_events( + initial_state, + config=run_config, + version="v2", + ): + if self._cancelled: + break - # 更新阶段 - phase_map = { - "recon": AgentTaskPhase.RECONNAISSANCE, - "analysis": AgentTaskPhase.ANALYSIS, - "verification": AgentTaskPhase.VERIFICATION, - "report": AgentTaskPhase.REPORTING, - } - if node_name in phase_map: - await self._update_task_phase(phase_map[node_name]) + # 处理 LangGraph 事件 + stream_event = await self.stream_handler.process_langgraph_event(event) + if stream_event: + # 同步到 event_emitter 以持久化 + await self._sync_stream_event_to_db(stream_event) + yield stream_event - final_state = node_output + # 更新最终状态 + if event.get("event") == "on_chain_end": + output = event.get("data", {}).get("output") + if isinstance(output, dict): + final_state = output + + except Exception as e: + # 如果 astream_events 不可用,回退到 astream + logger.warning(f"astream_events not available, falling back to astream: {e}") + async for event in self.graph.astream(initial_state, config=run_config): + if self._cancelled: + break + + for node_name, node_output in event.items(): + await self._handle_node_output(node_name, node_output) + + # 发射节点事件 + yield StreamEvent( + event_type=StreamEventType.NODE_END, + sequence=self.stream_handler._next_sequence(), + node_name=node_name, + data={"message": f"节点 {node_name} 完成"}, + ) + + phase_map = { + "recon": AgentTaskPhase.RECONNAISSANCE, + "analysis": AgentTaskPhase.ANALYSIS, + "verification": AgentTaskPhase.VERIFICATION, + "report": AgentTaskPhase.REPORTING, + } + if node_name in phase_map: + await self._update_task_phase(phase_map[node_name]) + + final_state = node_output # 5. 获取最终状态 if not final_state: @@ -355,6 +551,13 @@ class AgentRunner: findings = final_state.get("findings", []) await self._save_findings(findings) + # 发射发现事件 + for finding in findings[:10]: # 限制数量 + yield self.stream_handler.create_finding_event( + finding, + is_verified=finding.get("is_verified", False), + ) + # 7. 
更新任务摘要 summary = final_state.get("summary", {}) security_score = final_state.get("security_score", 100) @@ -374,30 +577,59 @@ class AgentRunner: duration_ms=duration_ms, ) - return { - "success": True, - "data": { - "findings": findings, - "verified_findings": final_state.get("verified_findings", []), - "summary": summary, + yield StreamEvent( + event_type=StreamEventType.TASK_COMPLETE, + sequence=self.stream_handler._next_sequence(), + data={ + "findings_count": len(findings), + "verified_count": len(final_state.get("verified_findings", [])), "security_score": security_score, + "duration_ms": duration_ms, + "message": f"✅ 审计完成!发现 {len(findings)} 个漏洞", }, - "duration_ms": duration_ms, - } + ) except asyncio.CancelledError: await self._update_task_status(AgentTaskStatus.CANCELLED) - return {"success": False, "error": "任务已取消"} + yield StreamEvent( + event_type=StreamEventType.TASK_CANCEL, + sequence=self.stream_handler._next_sequence(), + data={"message": "任务已取消"}, + ) except Exception as e: logger.error(f"LangGraph run failed: {e}", exc_info=True) await self._update_task_status(AgentTaskStatus.FAILED, str(e)) await self.event_emitter.emit_error(str(e)) - return {"success": False, "error": str(e)} + + yield StreamEvent( + event_type=StreamEventType.TASK_ERROR, + sequence=self.stream_handler._next_sequence(), + data={"error": str(e), "message": f"❌ 审计失败: {e}"}, + ) finally: await self._cleanup() + async def _sync_stream_event_to_db(self, event: StreamEvent): + """同步流式事件到数据库""" + try: + # 将 StreamEvent 转换为 AgentEventData + await self.event_manager.add_event( + task_id=self.task.id, + event_type=event.event_type.value, + sequence=event.sequence, + phase=event.phase, + message=event.data.get("message"), + tool_name=event.tool_name, + tool_input=event.data.get("input") or event.data.get("input_params"), + tool_output=event.data.get("output") or event.data.get("output_data"), + tool_duration_ms=event.data.get("duration_ms"), + metadata=event.data, + ) + except Exception as e: + logger.warning(f"Failed to sync stream event to db: {e}") + async def _handle_node_output(self, node_name: str, output: Dict[str, Any]): """处理节点输出""" # 发射节点事件 @@ -445,7 +677,8 @@ class AgentRunner: return await self.event_emitter.emit_progress( - progress.processed / max(progress.total, 1) * 100, + progress.processed_files, + progress.total_files, f"正在索引: {progress.current_file or 'N/A'}" ) @@ -502,13 +735,23 @@ class AgentRunner: type_map = { "sql_injection": VulnerabilityType.SQL_INJECTION, + "nosql_injection": VulnerabilityType.NOSQL_INJECTION, "xss": VulnerabilityType.XSS, "command_injection": VulnerabilityType.COMMAND_INJECTION, + "code_injection": VulnerabilityType.CODE_INJECTION, "path_traversal": VulnerabilityType.PATH_TRAVERSAL, + "file_inclusion": VulnerabilityType.FILE_INCLUSION, "ssrf": VulnerabilityType.SSRF, + "xxe": VulnerabilityType.XXE, + "deserialization": VulnerabilityType.DESERIALIZATION, + "auth_bypass": VulnerabilityType.AUTH_BYPASS, + "idor": VulnerabilityType.IDOR, + "sensitive_data_exposure": VulnerabilityType.SENSITIVE_DATA_EXPOSURE, "hardcoded_secret": VulnerabilityType.HARDCODED_SECRET, - "deserialization": VulnerabilityType.INSECURE_DESERIALIZATION, "weak_crypto": VulnerabilityType.WEAK_CRYPTO, + "race_condition": VulnerabilityType.RACE_CONDITION, + "business_logic": VulnerabilityType.BUSINESS_LOGIC, + "memory_corruption": VulnerabilityType.MEMORY_CORRUPTION, } for finding in findings: @@ -536,7 +779,7 @@ class AgentRunner: is_verified=finding.get("is_verified", False), 
confidence=finding.get("confidence", 0.5), poc=finding.get("poc"), - status=FindingStatus.VERIFIED if finding.get("is_verified") else FindingStatus.OPEN, + status=FindingStatus.VERIFIED if finding.get("is_verified") else FindingStatus.NEW, ) self.db.add(db_finding) @@ -603,10 +846,6 @@ class AgentRunner: await self.event_manager.close() except Exception as e: logger.warning(f"Cleanup error: {e}") - - def cancel(self): - """取消任务""" - self._cancelled = True # 便捷函数 diff --git a/backend/app/services/agent/streaming/__init__.py b/backend/app/services/agent/streaming/__init__.py new file mode 100644 index 0000000..12c2ab1 --- /dev/null +++ b/backend/app/services/agent/streaming/__init__.py @@ -0,0 +1,17 @@ +""" +Agent 流式输出模块 +支持 LLM Token 流式输出、工具调用展示、思考过程展示 +""" + +from .stream_handler import StreamHandler, StreamEvent, StreamEventType +from .token_streamer import TokenStreamer +from .tool_stream import ToolStreamHandler + +__all__ = [ + "StreamHandler", + "StreamEvent", + "StreamEventType", + "TokenStreamer", + "ToolStreamHandler", +] + diff --git a/backend/app/services/agent/streaming/stream_handler.py b/backend/app/services/agent/streaming/stream_handler.py new file mode 100644 index 0000000..7accd40 --- /dev/null +++ b/backend/app/services/agent/streaming/stream_handler.py @@ -0,0 +1,453 @@ +""" +流式事件处理器 +处理 LangGraph 的各种流式事件并转换为前端可消费的格式 +""" + +import json +import logging +from enum import Enum +from typing import Any, Dict, Optional, AsyncGenerator, List +from dataclasses import dataclass, field +from datetime import datetime, timezone + +logger = logging.getLogger(__name__) + + +class StreamEventType(str, Enum): + """流式事件类型""" + # LLM 相关 + THINKING_START = "thinking_start" # 开始思考 + THINKING_TOKEN = "thinking_token" # 思考 Token + THINKING_END = "thinking_end" # 思考结束 + + # 工具调用相关 + TOOL_CALL_START = "tool_call_start" # 工具调用开始 + TOOL_CALL_INPUT = "tool_call_input" # 工具输入参数 + TOOL_CALL_OUTPUT = "tool_call_output" # 工具输出结果 + TOOL_CALL_END = "tool_call_end" # 工具调用结束 + TOOL_CALL_ERROR = "tool_call_error" # 工具调用错误 + + # 节点相关 + NODE_START = "node_start" # 节点开始 + NODE_END = "node_end" # 节点结束 + + # 阶段相关 + PHASE_START = "phase_start" + PHASE_END = "phase_end" + + # 发现相关 + FINDING_NEW = "finding_new" # 新发现 + FINDING_VERIFIED = "finding_verified" # 验证通过 + + # 状态相关 + PROGRESS = "progress" + INFO = "info" + WARNING = "warning" + ERROR = "error" + + # 任务相关 + TASK_START = "task_start" + TASK_COMPLETE = "task_complete" + TASK_ERROR = "task_error" + TASK_CANCEL = "task_cancel" + + # 心跳 + HEARTBEAT = "heartbeat" + + +@dataclass +class StreamEvent: + """流式事件""" + event_type: StreamEventType + data: Dict[str, Any] = field(default_factory=dict) + timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) + sequence: int = 0 + + # 可选字段 + node_name: Optional[str] = None + phase: Optional[str] = None + tool_name: Optional[str] = None + + def to_sse(self) -> str: + """转换为 SSE 格式""" + event_data = { + "type": self.event_type.value, + "data": self.data, + "timestamp": self.timestamp, + "sequence": self.sequence, + } + + if self.node_name: + event_data["node"] = self.node_name + if self.phase: + event_data["phase"] = self.phase + if self.tool_name: + event_data["tool"] = self.tool_name + + return f"event: {self.event_type.value}\ndata: {json.dumps(event_data, ensure_ascii=False)}\n\n" + + def to_dict(self) -> Dict[str, Any]: + """转换为字典""" + return { + "event_type": self.event_type.value, + "data": self.data, + "timestamp": self.timestamp, + "sequence": self.sequence, + "node_name": 
self.node_name, + "phase": self.phase, + "tool_name": self.tool_name, + } + + +class StreamHandler: + """ + 流式事件处理器 + + 最佳实践: + 1. 使用 astream_events 捕获所有 LangGraph 事件 + 2. 将内部事件转换为前端友好的格式 + 3. 支持多种事件类型的分发 + """ + + def __init__(self, task_id: str): + self.task_id = task_id + self._sequence = 0 + self._current_phase = None + self._current_node = None + self._thinking_buffer = [] + self._tool_states: Dict[str, Dict] = {} + + def _next_sequence(self) -> int: + """获取下一个序列号""" + self._sequence += 1 + return self._sequence + + async def process_langgraph_event(self, event: Dict[str, Any]) -> Optional[StreamEvent]: + """ + 处理 LangGraph 事件 + + 支持的事件类型: + - on_chain_start: 链/节点开始 + - on_chain_end: 链/节点结束 + - on_chain_stream: LLM Token 流 + - on_chat_model_start: 模型开始 + - on_chat_model_stream: 模型 Token 流 + - on_chat_model_end: 模型结束 + - on_tool_start: 工具开始 + - on_tool_end: 工具结束 + - on_custom_event: 自定义事件 + """ + event_kind = event.get("event", "") + event_name = event.get("name", "") + event_data = event.get("data", {}) + + # LLM Token 流 + if event_kind == "on_chat_model_stream": + return await self._handle_llm_stream(event_data, event_name) + + # LLM 开始 + elif event_kind == "on_chat_model_start": + return await self._handle_llm_start(event_data, event_name) + + # LLM 结束 + elif event_kind == "on_chat_model_end": + return await self._handle_llm_end(event_data, event_name) + + # 工具开始 + elif event_kind == "on_tool_start": + return await self._handle_tool_start(event_name, event_data) + + # 工具结束 + elif event_kind == "on_tool_end": + return await self._handle_tool_end(event_name, event_data) + + # 节点开始 + elif event_kind == "on_chain_start" and self._is_node_event(event_name): + return await self._handle_node_start(event_name, event_data) + + # 节点结束 + elif event_kind == "on_chain_end" and self._is_node_event(event_name): + return await self._handle_node_end(event_name, event_data) + + # 自定义事件 + elif event_kind == "on_custom_event": + return await self._handle_custom_event(event_name, event_data) + + return None + + def _is_node_event(self, name: str) -> bool: + """判断是否是节点事件""" + node_names = ["recon", "analysis", "verification", "report", "ReconNode", "AnalysisNode", "VerificationNode", "ReportNode"] + return any(n.lower() in name.lower() for n in node_names) + + async def _handle_llm_start(self, data: Dict, name: str) -> StreamEvent: + """处理 LLM 开始事件""" + self._thinking_buffer = [] + + return StreamEvent( + event_type=StreamEventType.THINKING_START, + sequence=self._next_sequence(), + node_name=self._current_node, + phase=self._current_phase, + data={ + "model": name, + "message": "🤔 正在思考...", + }, + ) + + async def _handle_llm_stream(self, data: Dict, name: str) -> Optional[StreamEvent]: + """处理 LLM Token 流事件""" + chunk = data.get("chunk") + if not chunk: + return None + + # 提取 Token 内容 + content = "" + if hasattr(chunk, "content"): + content = chunk.content + elif isinstance(chunk, dict): + content = chunk.get("content", "") + + if not content: + return None + + # 添加到缓冲区 + self._thinking_buffer.append(content) + + return StreamEvent( + event_type=StreamEventType.THINKING_TOKEN, + sequence=self._next_sequence(), + node_name=self._current_node, + phase=self._current_phase, + data={ + "token": content, + "accumulated": "".join(self._thinking_buffer), + }, + ) + + async def _handle_llm_end(self, data: Dict, name: str) -> StreamEvent: + """处理 LLM 结束事件""" + full_response = "".join(self._thinking_buffer) + self._thinking_buffer = [] + + # 提取使用的 Token 数 + usage = {} + output = data.get("output") + if output and 
hasattr(output, "usage_metadata"): + usage = { + "input_tokens": getattr(output.usage_metadata, "input_tokens", 0), + "output_tokens": getattr(output.usage_metadata, "output_tokens", 0), + } + + return StreamEvent( + event_type=StreamEventType.THINKING_END, + sequence=self._next_sequence(), + node_name=self._current_node, + phase=self._current_phase, + data={ + "response": full_response[:2000], # 截断长响应 + "usage": usage, + "message": "💡 思考完成", + }, + ) + + async def _handle_tool_start(self, tool_name: str, data: Dict) -> StreamEvent: + """处理工具开始事件""" + import time + + tool_input = data.get("input", {}) + + # 记录工具状态 + self._tool_states[tool_name] = { + "start_time": time.time(), + "input": tool_input, + } + + return StreamEvent( + event_type=StreamEventType.TOOL_CALL_START, + sequence=self._next_sequence(), + node_name=self._current_node, + phase=self._current_phase, + tool_name=tool_name, + data={ + "tool_name": tool_name, + "input": self._truncate_data(tool_input), + "message": f"🔧 调用工具: {tool_name}", + }, + ) + + async def _handle_tool_end(self, tool_name: str, data: Dict) -> StreamEvent: + """处理工具结束事件""" + import time + + # 计算执行时间 + duration_ms = 0 + if tool_name in self._tool_states: + start_time = self._tool_states[tool_name].get("start_time", time.time()) + duration_ms = int((time.time() - start_time) * 1000) + del self._tool_states[tool_name] + + # 提取输出 + output = data.get("output", "") + if hasattr(output, "content"): + output = output.content + + return StreamEvent( + event_type=StreamEventType.TOOL_CALL_END, + sequence=self._next_sequence(), + node_name=self._current_node, + phase=self._current_phase, + tool_name=tool_name, + data={ + "tool_name": tool_name, + "output": self._truncate_data(output), + "duration_ms": duration_ms, + "message": f"✅ 工具 {tool_name} 完成 ({duration_ms}ms)", + }, + ) + + async def _handle_node_start(self, node_name: str, data: Dict) -> StreamEvent: + """处理节点开始事件""" + self._current_node = node_name + + # 映射节点到阶段 + phase_map = { + "recon": "reconnaissance", + "analysis": "analysis", + "verification": "verification", + "report": "reporting", + } + + for key, phase in phase_map.items(): + if key in node_name.lower(): + self._current_phase = phase + break + + return StreamEvent( + event_type=StreamEventType.NODE_START, + sequence=self._next_sequence(), + node_name=node_name, + phase=self._current_phase, + data={ + "node": node_name, + "phase": self._current_phase, + "message": f"▶️ 开始节点: {node_name}", + }, + ) + + async def _handle_node_end(self, node_name: str, data: Dict) -> StreamEvent: + """处理节点结束事件""" + # 提取输出信息 + output = data.get("output", {}) + + summary = {} + if isinstance(output, dict): + # 提取关键信息 + if "findings" in output: + summary["findings_count"] = len(output["findings"]) + if "entry_points" in output: + summary["entry_points_count"] = len(output["entry_points"]) + if "high_risk_areas" in output: + summary["high_risk_areas_count"] = len(output["high_risk_areas"]) + if "verified_findings" in output: + summary["verified_count"] = len(output["verified_findings"]) + + return StreamEvent( + event_type=StreamEventType.NODE_END, + sequence=self._next_sequence(), + node_name=node_name, + phase=self._current_phase, + data={ + "node": node_name, + "phase": self._current_phase, + "summary": summary, + "message": f"⏹️ 节点完成: {node_name}", + }, + ) + + async def _handle_custom_event(self, event_name: str, data: Dict) -> StreamEvent: + """处理自定义事件""" + # 映射自定义事件名到事件类型 + event_type_map = { + "finding": StreamEventType.FINDING_NEW, + "finding_verified": 
StreamEventType.FINDING_VERIFIED, + "progress": StreamEventType.PROGRESS, + "warning": StreamEventType.WARNING, + "error": StreamEventType.ERROR, + } + + event_type = event_type_map.get(event_name, StreamEventType.INFO) + + return StreamEvent( + event_type=event_type, + sequence=self._next_sequence(), + node_name=self._current_node, + phase=self._current_phase, + data=data, + ) + + def _truncate_data(self, data: Any, max_length: int = 1000) -> Any: + """截断数据""" + if isinstance(data, str): + return data[:max_length] + "..." if len(data) > max_length else data + elif isinstance(data, dict): + return {k: self._truncate_data(v, max_length // 2) for k, v in list(data.items())[:10]} + elif isinstance(data, list): + return [self._truncate_data(item, max_length // len(data)) for item in data[:10]] + else: + return str(data)[:max_length] + + def create_progress_event( + self, + current: int, + total: int, + message: Optional[str] = None, + ) -> StreamEvent: + """创建进度事件""" + percentage = (current / total * 100) if total > 0 else 0 + + return StreamEvent( + event_type=StreamEventType.PROGRESS, + sequence=self._next_sequence(), + node_name=self._current_node, + phase=self._current_phase, + data={ + "current": current, + "total": total, + "percentage": round(percentage, 1), + "message": message or f"进度: {current}/{total}", + }, + ) + + def create_finding_event( + self, + finding: Dict[str, Any], + is_verified: bool = False, + ) -> StreamEvent: + """创建发现事件""" + event_type = StreamEventType.FINDING_VERIFIED if is_verified else StreamEventType.FINDING_NEW + + return StreamEvent( + event_type=event_type, + sequence=self._next_sequence(), + node_name=self._current_node, + phase=self._current_phase, + data={ + "title": finding.get("title", "Unknown"), + "severity": finding.get("severity", "medium"), + "vulnerability_type": finding.get("vulnerability_type", "other"), + "file_path": finding.get("file_path"), + "line_start": finding.get("line_start"), + "is_verified": is_verified, + "message": f"{'✅ 已验证' if is_verified else '🔍 新发现'}: [{finding.get('severity', 'medium').upper()}] {finding.get('title', 'Unknown')}", + }, + ) + + def create_heartbeat(self) -> StreamEvent: + """创建心跳事件""" + return StreamEvent( + event_type=StreamEventType.HEARTBEAT, + sequence=self._sequence, # 心跳不增加序列号 + data={"message": "ping"}, + ) + diff --git a/backend/app/services/agent/streaming/token_streamer.py b/backend/app/services/agent/streaming/token_streamer.py new file mode 100644 index 0000000..ed16b1e --- /dev/null +++ b/backend/app/services/agent/streaming/token_streamer.py @@ -0,0 +1,261 @@ +""" +LLM Token 流式输出处理器 +支持多种 LLM 提供商的流式输出 +""" + +import asyncio +import logging +from typing import Any, Dict, Optional, AsyncGenerator, Callable +from dataclasses import dataclass +from datetime import datetime, timezone + +logger = logging.getLogger(__name__) + + +@dataclass +class TokenChunk: + """Token 块""" + content: str + token_count: int = 1 + finish_reason: Optional[str] = None + model: Optional[str] = None + + # 统计信息 + accumulated_content: str = "" + total_tokens: int = 0 + + +class TokenStreamer: + """ + LLM Token 流式输出处理器 + + 最佳实践: + 1. 使用 LiteLLM 的流式 API + 2. 实时发送每个 Token + 3. 跟踪累积内容和 Token 使用 + 4. 
支持中断和超时 + """ + + def __init__( + self, + model: str, + api_key: Optional[str] = None, + base_url: Optional[str] = None, + on_token: Optional[Callable[[TokenChunk], None]] = None, + ): + self.model = model + self.api_key = api_key + self.base_url = base_url + self.on_token = on_token + + self._cancelled = False + self._accumulated_content = "" + self._total_tokens = 0 + + def cancel(self): + """取消流式输出""" + self._cancelled = True + + async def stream_completion( + self, + messages: list[Dict[str, str]], + temperature: float = 0.1, + max_tokens: int = 4096, + ) -> AsyncGenerator[TokenChunk, None]: + """ + 流式调用 LLM + + Args: + messages: 消息列表 + temperature: 温度 + max_tokens: 最大 Token 数 + + Yields: + TokenChunk: Token 块 + """ + try: + import litellm + + response = await litellm.acompletion( + model=self.model, + messages=messages, + temperature=temperature, + max_tokens=max_tokens, + api_key=self.api_key, + base_url=self.base_url, + stream=True, # 启用流式输出 + ) + + async for chunk in response: + if self._cancelled: + break + + # 提取内容 + content = "" + finish_reason = None + + if hasattr(chunk, "choices") and chunk.choices: + choice = chunk.choices[0] + if hasattr(choice, "delta") and choice.delta: + content = getattr(choice.delta, "content", "") or "" + finish_reason = getattr(choice, "finish_reason", None) + + if content: + self._accumulated_content += content + self._total_tokens += 1 + + token_chunk = TokenChunk( + content=content, + token_count=1, + finish_reason=finish_reason, + model=self.model, + accumulated_content=self._accumulated_content, + total_tokens=self._total_tokens, + ) + + # 回调 + if self.on_token: + self.on_token(token_chunk) + + yield token_chunk + + # 检查是否完成 + if finish_reason: + break + + except asyncio.CancelledError: + logger.info("Token streaming cancelled") + raise + + except Exception as e: + logger.error(f"Token streaming error: {e}") + raise + + async def stream_with_tools( + self, + messages: list[Dict[str, str]], + tools: list[Dict[str, Any]], + temperature: float = 0.1, + max_tokens: int = 4096, + ) -> AsyncGenerator[Dict[str, Any], None]: + """ + 带工具调用的流式输出 + + Args: + messages: 消息列表 + tools: 工具定义列表 + temperature: 温度 + max_tokens: 最大 Token 数 + + Yields: + 包含 token 或 tool_call 的字典 + """ + try: + import litellm + + response = await litellm.acompletion( + model=self.model, + messages=messages, + tools=tools, + temperature=temperature, + max_tokens=max_tokens, + api_key=self.api_key, + base_url=self.base_url, + stream=True, + ) + + # 工具调用累积器 + tool_calls_accumulator: Dict[int, Dict] = {} + + async for chunk in response: + if self._cancelled: + break + + if not hasattr(chunk, "choices") or not chunk.choices: + continue + + choice = chunk.choices[0] + delta = getattr(choice, "delta", None) + finish_reason = getattr(choice, "finish_reason", None) + + if delta: + # 处理文本内容 + content = getattr(delta, "content", "") or "" + if content: + self._accumulated_content += content + self._total_tokens += 1 + + yield { + "type": "token", + "content": content, + "accumulated": self._accumulated_content, + "total_tokens": self._total_tokens, + } + + # 处理工具调用 + tool_calls = getattr(delta, "tool_calls", None) or [] + for tool_call in tool_calls: + idx = tool_call.index + + if idx not in tool_calls_accumulator: + tool_calls_accumulator[idx] = { + "id": tool_call.id or "", + "name": "", + "arguments": "", + } + + if tool_call.function: + if tool_call.function.name: + tool_calls_accumulator[idx]["name"] = tool_call.function.name + if tool_call.function.arguments: + 
tool_calls_accumulator[idx]["arguments"] += tool_call.function.arguments + + yield { + "type": "tool_call_chunk", + "index": idx, + "tool_call": tool_calls_accumulator[idx], + } + + # 完成时发送最终工具调用 + if finish_reason == "tool_calls": + for idx, tool_call in tool_calls_accumulator.items(): + yield { + "type": "tool_call_complete", + "index": idx, + "tool_call": tool_call, + } + + if finish_reason: + yield { + "type": "finish", + "reason": finish_reason, + "accumulated": self._accumulated_content, + "total_tokens": self._total_tokens, + } + break + + except asyncio.CancelledError: + logger.info("Tool streaming cancelled") + raise + + except Exception as e: + logger.error(f"Tool streaming error: {e}") + yield { + "type": "error", + "error": str(e), + } + + def get_accumulated_content(self) -> str: + """获取累积内容""" + return self._accumulated_content + + def get_total_tokens(self) -> int: + """获取总 Token 数""" + return self._total_tokens + + def reset(self): + """重置状态""" + self._cancelled = False + self._accumulated_content = "" + self._total_tokens = 0 + diff --git a/backend/app/services/agent/streaming/tool_stream.py b/backend/app/services/agent/streaming/tool_stream.py new file mode 100644 index 0000000..5278dca --- /dev/null +++ b/backend/app/services/agent/streaming/tool_stream.py @@ -0,0 +1,319 @@ +""" +工具调用流式处理器 +展示工具调用的输入、执行过程和输出 +""" + +import asyncio +import time +import logging +from typing import Any, Dict, Optional, AsyncGenerator, List, Callable +from dataclasses import dataclass, field +from datetime import datetime, timezone +from enum import Enum + +logger = logging.getLogger(__name__) + + +class ToolCallState(str, Enum): + """工具调用状态""" + PENDING = "pending" # 等待执行 + RUNNING = "running" # 执行中 + SUCCESS = "success" # 成功 + ERROR = "error" # 错误 + TIMEOUT = "timeout" # 超时 + + +@dataclass +class ToolCallEvent: + """工具调用事件""" + tool_name: str + state: ToolCallState + + # 输入输出 + input_params: Dict[str, Any] = field(default_factory=dict) + output_data: Optional[Any] = None + error_message: Optional[str] = None + + # 时间 + start_time: Optional[float] = None + end_time: Optional[float] = None + duration_ms: int = 0 + + # 元数据 + call_id: Optional[str] = None + sequence: int = 0 + timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) + + def to_dict(self) -> Dict[str, Any]: + """转换为字典""" + return { + "tool_name": self.tool_name, + "state": self.state.value, + "input_params": self._truncate(self.input_params), + "output_data": self._truncate(self.output_data), + "error_message": self.error_message, + "duration_ms": self.duration_ms, + "call_id": self.call_id, + "sequence": self.sequence, + "timestamp": self.timestamp, + } + + def _truncate(self, data: Any, max_length: int = 500) -> Any: + """截断数据""" + if data is None: + return None + if isinstance(data, str): + return data[:max_length] + "..." if len(data) > max_length else data + elif isinstance(data, dict): + return {k: self._truncate(v, max_length // 2) for k, v in list(data.items())[:20]} + elif isinstance(data, list): + max_items = min(20, len(data)) + return [self._truncate(item, max_length // max_items) for item in data[:max_items]] + else: + s = str(data) + return s[:max_length] + "..." if len(s) > max_length else s + + +class ToolStreamHandler: + """ + 工具调用流式处理器 + + 功能: + 1. 跟踪工具调用状态 + 2. 记录输入参数 + 3. 流式输出执行过程 + 4. 
记录输出和执行时间 + """ + + def __init__( + self, + on_event: Optional[Callable[[ToolCallEvent], None]] = None, + ): + self.on_event = on_event + self._sequence = 0 + self._active_calls: Dict[str, ToolCallEvent] = {} + self._history: List[ToolCallEvent] = [] + + def _next_sequence(self) -> int: + """获取下一个序列号""" + self._sequence += 1 + return self._sequence + + def _generate_call_id(self) -> str: + """生成调用 ID""" + import uuid + return str(uuid.uuid4())[:8] + + async def emit_tool_start( + self, + tool_name: str, + input_params: Dict[str, Any], + call_id: Optional[str] = None, + ) -> ToolCallEvent: + """ + 发射工具开始事件 + + Args: + tool_name: 工具名称 + input_params: 输入参数 + call_id: 调用 ID + + Returns: + 工具调用事件 + """ + call_id = call_id or self._generate_call_id() + + event = ToolCallEvent( + tool_name=tool_name, + state=ToolCallState.RUNNING, + input_params=input_params, + start_time=time.time(), + call_id=call_id, + sequence=self._next_sequence(), + ) + + self._active_calls[call_id] = event + + if self.on_event: + self.on_event(event) + + return event + + async def emit_tool_end( + self, + call_id: str, + output_data: Any, + is_error: bool = False, + error_message: Optional[str] = None, + ) -> ToolCallEvent: + """ + 发射工具结束事件 + + Args: + call_id: 调用 ID + output_data: 输出数据 + is_error: 是否错误 + error_message: 错误消息 + + Returns: + 工具调用事件 + """ + if call_id not in self._active_calls: + logger.warning(f"Unknown tool call: {call_id}") + return None + + event = self._active_calls[call_id] + event.end_time = time.time() + event.duration_ms = int((event.end_time - event.start_time) * 1000) if event.start_time else 0 + event.output_data = output_data + event.sequence = self._next_sequence() + + if is_error: + event.state = ToolCallState.ERROR + event.error_message = error_message or str(output_data) + else: + event.state = ToolCallState.SUCCESS + + # 移动到历史记录 + del self._active_calls[call_id] + self._history.append(event) + + if self.on_event: + self.on_event(event) + + return event + + async def emit_tool_timeout(self, call_id: str, timeout_seconds: int) -> ToolCallEvent: + """发射工具超时事件""" + if call_id not in self._active_calls: + return None + + event = self._active_calls[call_id] + event.end_time = time.time() + event.duration_ms = int((event.end_time - event.start_time) * 1000) if event.start_time else 0 + event.state = ToolCallState.TIMEOUT + event.error_message = f"Tool execution timed out after {timeout_seconds}s" + event.sequence = self._next_sequence() + + del self._active_calls[call_id] + self._history.append(event) + + if self.on_event: + self.on_event(event) + + return event + + def wrap_tool( + self, + tool_func: Callable, + tool_name: str, + timeout: Optional[int] = None, + ) -> Callable: + """ + 包装工具函数以自动跟踪 + + Args: + tool_func: 工具函数 + tool_name: 工具名称 + timeout: 超时时间(秒) + + Returns: + 包装后的函数 + """ + async def wrapped(*args, **kwargs): + call_id = self._generate_call_id() + + # 发射开始事件 + await self.emit_tool_start( + tool_name=tool_name, + input_params={"args": args, "kwargs": kwargs}, + call_id=call_id, + ) + + try: + # 执行工具 + if asyncio.iscoroutinefunction(tool_func): + if timeout: + result = await asyncio.wait_for( + tool_func(*args, **kwargs), + timeout=timeout, + ) + else: + result = await tool_func(*args, **kwargs) + else: + if timeout: + result = await asyncio.wait_for( + asyncio.to_thread(tool_func, *args, **kwargs), + timeout=timeout, + ) + else: + result = tool_func(*args, **kwargs) + + # 发射结束事件 + await self.emit_tool_end(call_id, result) + + return result + + except asyncio.TimeoutError: + await 
self.emit_tool_timeout(call_id, timeout or 0) + raise + + except Exception as e: + await self.emit_tool_end(call_id, None, is_error=True, error_message=str(e)) + raise + + return wrapped + + def get_active_calls(self) -> List[ToolCallEvent]: + """获取活跃的调用""" + return list(self._active_calls.values()) + + def get_history(self, limit: int = 100) -> List[ToolCallEvent]: + """获取历史记录""" + return self._history[-limit:] + + def get_stats(self) -> Dict[str, Any]: + """获取统计信息""" + total_calls = len(self._history) + success_calls = sum(1 for e in self._history if e.state == ToolCallState.SUCCESS) + error_calls = sum(1 for e in self._history if e.state == ToolCallState.ERROR) + timeout_calls = sum(1 for e in self._history if e.state == ToolCallState.TIMEOUT) + + total_duration = sum(e.duration_ms for e in self._history) + avg_duration = total_duration / total_calls if total_calls > 0 else 0 + + # 按工具统计 + tool_stats = {} + for event in self._history: + if event.tool_name not in tool_stats: + tool_stats[event.tool_name] = { + "calls": 0, + "success": 0, + "errors": 0, + "total_duration_ms": 0, + } + tool_stats[event.tool_name]["calls"] += 1 + if event.state == ToolCallState.SUCCESS: + tool_stats[event.tool_name]["success"] += 1 + elif event.state in [ToolCallState.ERROR, ToolCallState.TIMEOUT]: + tool_stats[event.tool_name]["errors"] += 1 + tool_stats[event.tool_name]["total_duration_ms"] += event.duration_ms + + return { + "total_calls": total_calls, + "success_calls": success_calls, + "error_calls": error_calls, + "timeout_calls": timeout_calls, + "success_rate": success_calls / total_calls if total_calls > 0 else 0, + "total_duration_ms": total_duration, + "avg_duration_ms": round(avg_duration, 2), + "active_calls": len(self._active_calls), + "by_tool": tool_stats, + } + + def clear(self): + """清空记录""" + self._active_calls.clear() + self._history.clear() + self._sequence = 0 + diff --git a/backend/app/services/agent/tools/pattern_tool.py b/backend/app/services/agent/tools/pattern_tool.py index c63e54c..09a7da8 100644 --- a/backend/app/services/agent/tools/pattern_tool.py +++ b/backend/app/services/agent/tools/pattern_tool.py @@ -41,6 +41,16 @@ class PatternMatchTool(AgentTool): 使用正则表达式快速扫描代码中的危险模式 """ + def __init__(self, project_root: str = None): + """ + 初始化模式匹配工具 + + Args: + project_root: 项目根目录(可选,用于上下文) + """ + super().__init__() + self.project_root = project_root + # 危险模式定义 PATTERNS: Dict[str, Dict[str, Any]] = { # SQL 注入模式 diff --git a/backend/app/services/llm/adapters/baidu_adapter.py b/backend/app/services/llm/adapters/baidu_adapter.py index cdf962f..b05375e 100644 --- a/backend/app/services/llm/adapters/baidu_adapter.py +++ b/backend/app/services/llm/adapters/baidu_adapter.py @@ -145,3 +145,5 @@ class BaiduAdapter(BaseLLMAdapter): + + diff --git a/backend/app/services/llm/adapters/doubao_adapter.py b/backend/app/services/llm/adapters/doubao_adapter.py index a95da90..346bcc5 100644 --- a/backend/app/services/llm/adapters/doubao_adapter.py +++ b/backend/app/services/llm/adapters/doubao_adapter.py @@ -83,3 +83,5 @@ class DoubaoAdapter(BaseLLMAdapter): + + diff --git a/backend/app/services/llm/adapters/minimax_adapter.py b/backend/app/services/llm/adapters/minimax_adapter.py index e57faa4..177d9ed 100644 --- a/backend/app/services/llm/adapters/minimax_adapter.py +++ b/backend/app/services/llm/adapters/minimax_adapter.py @@ -86,3 +86,5 @@ class MinimaxAdapter(BaseLLMAdapter): + + diff --git a/backend/app/services/llm/base_adapter.py b/backend/app/services/llm/base_adapter.py index 
1f868ac..74cda21 100644 --- a/backend/app/services/llm/base_adapter.py +++ b/backend/app/services/llm/base_adapter.py @@ -134,3 +134,5 @@ class BaseLLMAdapter(ABC): + + diff --git a/backend/app/services/llm/types.py b/backend/app/services/llm/types.py index 5ddfd37..5c669b0 100644 --- a/backend/app/services/llm/types.py +++ b/backend/app/services/llm/types.py @@ -120,3 +120,5 @@ DEFAULT_BASE_URLS: Dict[LLMProvider, str] = { + + diff --git a/backend/tests/agent/__init__.py b/backend/tests/agent/__init__.py new file mode 100644 index 0000000..9a94db5 --- /dev/null +++ b/backend/tests/agent/__init__.py @@ -0,0 +1,5 @@ +""" +DeepAudit Agent 测试套件 +企业级测试框架,覆盖工具、Agent、节点和完整流程 +""" + diff --git a/backend/tests/agent/conftest.py b/backend/tests/agent/conftest.py new file mode 100644 index 0000000..e6693c3 --- /dev/null +++ b/backend/tests/agent/conftest.py @@ -0,0 +1,295 @@ +""" +Agent 测试配置和 Fixtures +提供测试所需的公共设施 +""" + +import pytest +import asyncio +import tempfile +import shutil +import os +from typing import Dict, Any, Optional +from unittest.mock import AsyncMock, MagicMock, patch +from dataclasses import dataclass + + +# ============ 测试配置 ============ + +@pytest.fixture(scope="session") +def event_loop(): + """创建事件循环""" + loop = asyncio.new_event_loop() + yield loop + loop.close() + + +@pytest.fixture +def temp_project_dir(): + """创建临时项目目录,包含测试代码""" + temp_dir = tempfile.mkdtemp(prefix="deepaudit_test_") + + # 创建测试项目结构 + os.makedirs(os.path.join(temp_dir, "src"), exist_ok=True) + os.makedirs(os.path.join(temp_dir, "config"), exist_ok=True) + + # 创建有漏洞的测试代码 - SQL 注入 + sql_vuln_code = ''' +import sqlite3 + +def get_user(user_id): + """危险:SQL 注入漏洞""" + conn = sqlite3.connect("users.db") + cursor = conn.cursor() + # 直接拼接用户输入,存在 SQL 注入风险 + query = f"SELECT * FROM users WHERE id = '{user_id}'" + cursor.execute(query) + return cursor.fetchone() + +def search_users(name): + """危险:SQL 注入漏洞""" + conn = sqlite3.connect("users.db") + cursor = conn.cursor() + cursor.execute("SELECT * FROM users WHERE name LIKE '%" + name + "%'") + return cursor.fetchall() +''' + + # 创建有漏洞的测试代码 - 命令注入 + cmd_vuln_code = ''' +import os +import subprocess + +def run_command(user_input): + """危险:命令注入漏洞""" + # 直接执行用户输入 + os.system(f"echo {user_input}") + +def execute_script(script_name): + """危险:命令注入漏洞""" + subprocess.call(f"bash {script_name}", shell=True) +''' + + # 创建有漏洞的测试代码 - XSS + xss_vuln_code = ''' +from flask import Flask, request, render_template_string + +app = Flask(__name__) + +@app.route("/greet") +def greet(): + """危险:XSS 漏洞""" + name = request.args.get("name", "") + # 直接将用户输入嵌入 HTML,存在 XSS 风险 + return f"

<h1>Hello, {name}!</h1>" + @app.route("/search") + def search(): + """危险:XSS 漏洞""" + query = request.args.get("q", "") + html = f"<div>搜索结果: {query}</div>
" + return render_template_string(html) +''' + + # 创建有漏洞的测试代码 - 路径遍历 + path_vuln_code = ''' +import os + +def read_file(filename): + """危险:路径遍历漏洞""" + # 没有验证文件路径 + filepath = os.path.join("/app/data", filename) + with open(filepath, "r") as f: + return f.read() + +def download_file(user_path): + """危险:路径遍历漏洞""" + # 直接使用用户输入作为文件路径 + with open(user_path, "rb") as f: + return f.read() +''' + + # 创建有漏洞的测试代码 - 硬编码密钥 + secret_vuln_code = ''' +# 配置文件 +DATABASE_URL = "postgresql://user:password123@localhost/db" +API_KEY = "sk-1234567890abcdef1234567890abcdef" +SECRET_KEY = "super_secret_key_dont_share" +AWS_SECRET_ACCESS_KEY = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY" + +def connect_database(): + password = "admin123" # 硬编码密码 + return f"mysql://root:{password}@localhost/mydb" +''' + + # 创建安全的代码(用于对比) + safe_code = ''' +import sqlite3 +from typing import Optional + +def get_user_safe(user_id: int) -> Optional[dict]: + """安全:使用参数化查询""" + conn = sqlite3.connect("users.db") + cursor = conn.cursor() + cursor.execute("SELECT * FROM users WHERE id = ?", (user_id,)) + return cursor.fetchone() + +def validate_input(user_input: str) -> str: + """输入验证""" + import re + if not re.match(r'^[a-zA-Z0-9_]+$', user_input): + raise ValueError("Invalid input") + return user_input +''' + + # 创建配置文件 + config_code = ''' +import os + +class Config: + """安全配置""" + DATABASE_URL = os.environ.get("DATABASE_URL") + SECRET_KEY = os.environ.get("SECRET_KEY") + DEBUG = False +''' + + # 创建 requirements.txt + requirements = ''' +flask>=2.0.0 +sqlalchemy>=2.0.0 +requests>=2.28.0 +''' + + # 写入文件 + with open(os.path.join(temp_dir, "src", "sql_vuln.py"), "w") as f: + f.write(sql_vuln_code) + + with open(os.path.join(temp_dir, "src", "cmd_vuln.py"), "w") as f: + f.write(cmd_vuln_code) + + with open(os.path.join(temp_dir, "src", "xss_vuln.py"), "w") as f: + f.write(xss_vuln_code) + + with open(os.path.join(temp_dir, "src", "path_vuln.py"), "w") as f: + f.write(path_vuln_code) + + with open(os.path.join(temp_dir, "src", "secrets.py"), "w") as f: + f.write(secret_vuln_code) + + with open(os.path.join(temp_dir, "src", "safe_code.py"), "w") as f: + f.write(safe_code) + + with open(os.path.join(temp_dir, "config", "settings.py"), "w") as f: + f.write(config_code) + + with open(os.path.join(temp_dir, "requirements.txt"), "w") as f: + f.write(requirements) + + yield temp_dir + + # 清理 + shutil.rmtree(temp_dir, ignore_errors=True) + + +@pytest.fixture +def mock_llm_service(): + """模拟 LLM 服务""" + service = MagicMock() + service.chat_completion_raw = AsyncMock(return_value={ + "content": "测试响应", + "usage": {"total_tokens": 100}, + }) + return service + + +@pytest.fixture +def mock_event_emitter(): + """模拟事件发射器""" + emitter = MagicMock() + emitter.emit_info = AsyncMock() + emitter.emit_warning = AsyncMock() + emitter.emit_error = AsyncMock() + emitter.emit_thinking = AsyncMock() + emitter.emit_tool_call = AsyncMock() + emitter.emit_tool_result = AsyncMock() + emitter.emit_finding = AsyncMock() + emitter.emit_progress = AsyncMock() + emitter.emit_phase_start = AsyncMock() + emitter.emit_phase_complete = AsyncMock() + emitter.emit_task_complete = AsyncMock() + emitter.emit = AsyncMock() + return emitter + + +@pytest.fixture +def mock_db_session(): + """模拟数据库会话""" + session = AsyncMock() + session.add = MagicMock() + session.commit = AsyncMock() + session.rollback = AsyncMock() + session.get = AsyncMock(return_value=None) + session.execute = AsyncMock() + return session + + +@dataclass +class MockProject: + """模拟项目""" + id: str = "test-project-id" + 
name: str = "Test Project" + description: str = "Test project for unit tests" + + +@dataclass +class MockAgentTask: + """模拟 Agent 任务""" + id: str = "test-task-id" + project_id: str = "test-project-id" + project: MockProject = None + name: str = "Test Agent Task" + status: str = "pending" + current_phase: str = "planning" + target_vulnerabilities: list = None + verification_level: str = "sandbox" + exclude_patterns: list = None + target_files: list = None + max_iterations: int = 50 + timeout_seconds: int = 1800 + + def __post_init__(self): + if self.project is None: + self.project = MockProject() + if self.target_vulnerabilities is None: + self.target_vulnerabilities = [] + if self.exclude_patterns is None: + self.exclude_patterns = [] + if self.target_files is None: + self.target_files = [] + + +@pytest.fixture +def mock_task(): + """创建模拟任务""" + return MockAgentTask() + + +# ============ 测试辅助函数 ============ + +def assert_finding_valid(finding: Dict[str, Any]): + """验证漏洞发现的格式""" + required_fields = ["title", "severity", "vulnerability_type"] + for field in required_fields: + assert field in finding, f"Missing required field: {field}" + + valid_severities = ["critical", "high", "medium", "low", "info"] + assert finding["severity"] in valid_severities, f"Invalid severity: {finding['severity']}" + + +def count_findings_by_type(findings: list, vuln_type: str) -> int: + """统计特定类型的漏洞数量""" + return sum(1 for f in findings if f.get("vulnerability_type") == vuln_type) + + +def count_findings_by_severity(findings: list, severity: str) -> int: + """统计特定严重程度的漏洞数量""" + return sum(1 for f in findings if f.get("severity") == severity) + diff --git a/backend/tests/agent/run_tests.py b/backend/tests/agent/run_tests.py new file mode 100644 index 0000000..6ea2982 --- /dev/null +++ b/backend/tests/agent/run_tests.py @@ -0,0 +1,96 @@ +#!/usr/bin/env python +""" +Agent 测试运行器 +运行所有 Agent 相关测试并生成报告 +""" + +import subprocess +import sys +import os +from datetime import datetime +from pathlib import Path + + +def run_tests(): + """运行测试""" + # 获取项目根目录 + project_root = Path(__file__).parent.parent.parent + os.chdir(project_root) + + print("=" * 60) + print("DeepAudit Agent 测试套件") + print(f"时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") + print("=" * 60) + print() + + # 测试命令 + cmd = [ + sys.executable, "-m", "pytest", + "tests/agent/", + "-v", + "--tb=short", + "-x", # 遇到第一个失败就停止 + "--color=yes", + ] + + print(f"运行命令: {' '.join(cmd)}") + print() + + # 运行测试 + result = subprocess.run(cmd, cwd=project_root) + + print() + print("=" * 60) + if result.returncode == 0: + print("✅ 所有测试通过!") + else: + print(f"❌ 测试失败 (退出码: {result.returncode})") + print("=" * 60) + + return result.returncode + + +def run_tests_with_coverage(): + """运行测试并生成覆盖率报告""" + project_root = Path(__file__).parent.parent.parent + os.chdir(project_root) + + cmd = [ + sys.executable, "-m", "pytest", + "tests/agent/", + "-v", + "--cov=app/services/agent", + "--cov-report=term-missing", + "--cov-report=html:coverage_agent", + ] + + result = subprocess.run(cmd, cwd=project_root) + return result.returncode + + +def run_specific_test(test_name: str): + """运行特定测试""" + project_root = Path(__file__).parent.parent.parent + os.chdir(project_root) + + cmd = [ + sys.executable, "-m", "pytest", + f"tests/agent/{test_name}", + "-v", + "--tb=long", + "-s", # 显示 print 输出 + ] + + result = subprocess.run(cmd, cwd=project_root) + return result.returncode + + +if __name__ == "__main__": + if len(sys.argv) > 1: + if sys.argv[1] == "--coverage": + 
sys.exit(run_tests_with_coverage()) + else: + sys.exit(run_specific_test(sys.argv[1])) + else: + sys.exit(run_tests()) + diff --git a/backend/tests/agent/test_agents.py b/backend/tests/agent/test_agents.py new file mode 100644 index 0000000..21a6a73 --- /dev/null +++ b/backend/tests/agent/test_agents.py @@ -0,0 +1,213 @@ +""" +Agent 单元测试 +测试各个 Agent 的功能 +""" + +import pytest +import asyncio +import os +from unittest.mock import MagicMock, AsyncMock, patch + +from app.services.agent.agents.base import BaseAgent, AgentConfig, AgentResult, AgentType, AgentPattern +from app.services.agent.agents.recon import ReconAgent +from app.services.agent.agents.analysis import AnalysisAgent +from app.services.agent.agents.verification import VerificationAgent + + +class TestReconAgent: + """Recon Agent 测试""" + + @pytest.fixture + def recon_agent(self, temp_project_dir, mock_llm_service, mock_event_emitter): + """创建 Recon Agent 实例""" + from app.services.agent.tools import ( + FileReadTool, FileSearchTool, ListFilesTool, + ) + + tools = { + "list_files": ListFilesTool(temp_project_dir), + "read_file": FileReadTool(temp_project_dir), + "search_code": FileSearchTool(temp_project_dir), + } + + return ReconAgent( + llm_service=mock_llm_service, + tools=tools, + event_emitter=mock_event_emitter, + ) + + @pytest.mark.asyncio + async def test_recon_agent_run(self, recon_agent, temp_project_dir): + """测试 Recon Agent 运行""" + result = await recon_agent.run({ + "project_info": { + "name": "Test Project", + "root": temp_project_dir, + }, + "config": {}, + }) + + assert result.success is True + assert result.data is not None + + # 验证返回数据结构 + data = result.data + assert "tech_stack" in data + assert "entry_points" in data or "high_risk_areas" in data + + @pytest.mark.asyncio + async def test_recon_agent_identifies_python(self, recon_agent, temp_project_dir): + """测试 Recon Agent 识别 Python 技术栈""" + result = await recon_agent.run({ + "project_info": {"root": temp_project_dir}, + "config": {}, + }) + + assert result.success is True + tech_stack = result.data.get("tech_stack", {}) + languages = tech_stack.get("languages", []) + + # 应该识别出 Python + assert "Python" in languages or len(languages) > 0 + + @pytest.mark.asyncio + async def test_recon_agent_finds_high_risk_areas(self, recon_agent, temp_project_dir): + """测试 Recon Agent 发现高风险区域""" + result = await recon_agent.run({ + "project_info": {"root": temp_project_dir}, + "config": {}, + }) + + assert result.success is True + high_risk_areas = result.data.get("high_risk_areas", []) + + # 应该发现高风险区域 + assert len(high_risk_areas) > 0 + + +class TestAnalysisAgent: + """Analysis Agent 测试""" + + @pytest.fixture + def analysis_agent(self, temp_project_dir, mock_llm_service, mock_event_emitter): + """创建 Analysis Agent 实例""" + from app.services.agent.tools import ( + FileReadTool, FileSearchTool, PatternMatchTool, + ) + + tools = { + "read_file": FileReadTool(temp_project_dir), + "search_code": FileSearchTool(temp_project_dir), + "pattern_match": PatternMatchTool(temp_project_dir), + } + + return AnalysisAgent( + llm_service=mock_llm_service, + tools=tools, + event_emitter=mock_event_emitter, + ) + + @pytest.mark.asyncio + async def test_analysis_agent_run(self, analysis_agent, temp_project_dir): + """测试 Analysis Agent 运行""" + result = await analysis_agent.run({ + "tech_stack": {"languages": ["Python"]}, + "entry_points": [], + "high_risk_areas": ["src/sql_vuln.py", "src/cmd_vuln.py"], + "config": {}, + }) + + assert result.success is True + assert result.data is not None + + 
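The new streaming module introduced by this change (`TokenStreamer` in `backend/app/services/agent/streaming/token_streamer.py`) is not exercised by the tests above. A minimal sketch of a unit test for it, assuming `litellm` is installed and patching `litellm.acompletion` with a fake async generator whose chunks mimic the `choices[0].delta.content` / `finish_reason` shape the streamer reads; the model name and token pieces are arbitrary:

```python
import asyncio
from types import SimpleNamespace
from unittest.mock import patch

from app.services.agent.streaming.token_streamer import TokenStreamer


def _chunk(text, finish=None):
    # Shape mirrors what TokenStreamer reads: chunk.choices[0].delta.content / finish_reason
    delta = SimpleNamespace(content=text)
    return SimpleNamespace(choices=[SimpleNamespace(delta=delta, finish_reason=finish)])


async def _fake_acompletion(**kwargs):
    async def _gen():
        for piece in ("Hel", "lo"):
            yield _chunk(piece)
        yield _chunk("", finish="stop")
    return _gen()


async def demo():
    with patch("litellm.acompletion", new=_fake_acompletion):
        streamer = TokenStreamer(model="gpt-4o-mini")
        tokens = [
            c.content
            async for c in streamer.stream_completion([{"role": "user", "content": "hi"}])
        ]
    assert "".join(tokens) == "Hello"
    assert streamer.get_accumulated_content() == "Hello"


if __name__ == "__main__":
    asyncio.run(demo())
```

Under pytest this would be an `@pytest.mark.asyncio` test in the suite above rather than an `asyncio.run` entry point.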
@pytest.mark.asyncio + async def test_analysis_agent_finds_vulnerabilities(self, analysis_agent, temp_project_dir): + """测试 Analysis Agent 发现漏洞""" + result = await analysis_agent.run({ + "tech_stack": {"languages": ["Python"]}, + "entry_points": [], + "high_risk_areas": [ + "src/sql_vuln.py", + "src/cmd_vuln.py", + "src/xss_vuln.py", + "src/secrets.py", + ], + "config": {}, + }) + + assert result.success is True + findings = result.data.get("findings", []) + + # 应该发现一些漏洞 + # 注意:具体数量取决于分析逻辑 + assert isinstance(findings, list) + + +class TestAgentResult: + """Agent 结果测试""" + + def test_agent_result_success(self): + """测试成功的 Agent 结果""" + result = AgentResult( + success=True, + data={"findings": []}, + iterations=5, + tool_calls=10, + ) + + assert result.success is True + assert result.iterations == 5 + assert result.tool_calls == 10 + + def test_agent_result_failure(self): + """测试失败的 Agent 结果""" + result = AgentResult( + success=False, + error="Test error", + ) + + assert result.success is False + assert result.error == "Test error" + + def test_agent_result_to_dict(self): + """测试 Agent 结果转字典""" + result = AgentResult( + success=True, + data={"key": "value"}, + iterations=3, + ) + + d = result.to_dict() + + assert d["success"] is True + assert d["iterations"] == 3 + + +class TestAgentConfig: + """Agent 配置测试""" + + def test_agent_config_defaults(self): + """测试 Agent 配置默认值""" + config = AgentConfig( + name="Test", + agent_type=AgentType.RECON, + ) + + assert config.pattern == AgentPattern.REACT + assert config.max_iterations == 20 + assert config.temperature == 0.1 + + def test_agent_config_custom(self): + """测试自定义 Agent 配置""" + config = AgentConfig( + name="Custom", + agent_type=AgentType.ANALYSIS, + pattern=AgentPattern.PLAN_AND_EXECUTE, + max_iterations=50, + temperature=0.5, + ) + + assert config.pattern == AgentPattern.PLAN_AND_EXECUTE + assert config.max_iterations == 50 + assert config.temperature == 0.5 + diff --git a/backend/tests/agent/test_integration.py b/backend/tests/agent/test_integration.py new file mode 100644 index 0000000..d314404 --- /dev/null +++ b/backend/tests/agent/test_integration.py @@ -0,0 +1,355 @@ +""" +Agent 集成测试 +测试完整的审计流程 +""" + +import pytest +import asyncio +import os +from unittest.mock import MagicMock, AsyncMock, patch +from datetime import datetime + +from app.services.agent.graph.runner import AgentRunner, LLMService +from app.services.agent.graph.audit_graph import AuditState, create_audit_graph +from app.services.agent.graph.nodes import ReconNode, AnalysisNode, VerificationNode, ReportNode +from app.services.agent.event_manager import EventManager, AgentEventEmitter + + +class TestLLMService: + """LLM 服务测试""" + + @pytest.mark.asyncio + async def test_llm_service_initialization(self): + """测试 LLM 服务初始化""" + with patch("app.core.config.settings") as mock_settings: + mock_settings.LLM_MODEL = "gpt-4o-mini" + mock_settings.LLM_API_KEY = "test-key" + + service = LLMService() + + assert service.model == "gpt-4o-mini" + + +class TestEventManager: + """事件管理器测试""" + + def test_event_manager_initialization(self): + """测试事件管理器初始化""" + manager = EventManager() + + assert manager._event_queues == {} + assert manager._event_callbacks == {} + + @pytest.mark.asyncio + async def test_event_emitter(self): + """测试事件发射器""" + manager = EventManager() + emitter = AgentEventEmitter("test-task-id", manager) + + await emitter.emit_info("Test message") + + assert emitter._sequence == 1 + + @pytest.mark.asyncio + async def test_event_emitter_phase_tracking(self): + 
"""测试事件发射器阶段跟踪""" + manager = EventManager() + emitter = AgentEventEmitter("test-task-id", manager) + + await emitter.emit_phase_start("recon", "开始信息收集") + + assert emitter._current_phase == "recon" + + @pytest.mark.asyncio + async def test_event_emitter_task_complete(self): + """测试任务完成事件""" + manager = EventManager() + emitter = AgentEventEmitter("test-task-id", manager) + + await emitter.emit_task_complete(findings_count=5, duration_ms=1000) + + assert emitter._sequence == 1 + + +class TestAuditGraph: + """审计图测试""" + + def test_create_audit_graph(self, mock_event_emitter): + """测试创建审计图""" + # 创建模拟节点 + recon_node = MagicMock() + analysis_node = MagicMock() + verification_node = MagicMock() + report_node = MagicMock() + + graph = create_audit_graph( + recon_node=recon_node, + analysis_node=analysis_node, + verification_node=verification_node, + report_node=report_node, + ) + + assert graph is not None + + +class TestReconNode: + """Recon 节点测试""" + + @pytest.fixture + def recon_node_with_mock_agent(self, mock_event_emitter): + """创建带模拟 Agent 的 Recon 节点""" + mock_agent = MagicMock() + mock_agent.run = AsyncMock(return_value=MagicMock( + success=True, + data={ + "tech_stack": {"languages": ["Python"]}, + "entry_points": [{"path": "src/app.py", "type": "api"}], + "high_risk_areas": ["src/sql_vuln.py"], + "dependencies": {}, + "initial_findings": [], + } + )) + + return ReconNode(mock_agent, mock_event_emitter) + + @pytest.mark.asyncio + async def test_recon_node_success(self, recon_node_with_mock_agent): + """测试 Recon 节点成功执行""" + state = { + "project_info": {"name": "Test"}, + "config": {}, + } + + result = await recon_node_with_mock_agent(state) + + assert "tech_stack" in result + assert "entry_points" in result + assert result["current_phase"] == "recon_complete" + + @pytest.mark.asyncio + async def test_recon_node_failure(self, mock_event_emitter): + """测试 Recon 节点失败处理""" + mock_agent = MagicMock() + mock_agent.run = AsyncMock(return_value=MagicMock( + success=False, + error="Test error", + data=None, + )) + + node = ReconNode(mock_agent, mock_event_emitter) + + result = await node({ + "project_info": {}, + "config": {}, + }) + + assert "error" in result + assert result["current_phase"] == "error" + + +class TestAnalysisNode: + """Analysis 节点测试""" + + @pytest.fixture + def analysis_node_with_mock_agent(self, mock_event_emitter): + """创建带模拟 Agent 的 Analysis 节点""" + mock_agent = MagicMock() + mock_agent.run = AsyncMock(return_value=MagicMock( + success=True, + data={ + "findings": [ + { + "id": "finding-1", + "title": "SQL Injection", + "severity": "high", + "vulnerability_type": "sql_injection", + "file_path": "src/sql_vuln.py", + "line_start": 10, + "description": "SQL injection vulnerability", + } + ], + "should_continue": False, + } + )) + + return AnalysisNode(mock_agent, mock_event_emitter) + + @pytest.mark.asyncio + async def test_analysis_node_success(self, analysis_node_with_mock_agent): + """测试 Analysis 节点成功执行""" + state = { + "project_info": {"name": "Test"}, + "tech_stack": {"languages": ["Python"]}, + "entry_points": [], + "high_risk_areas": ["src/sql_vuln.py"], + "config": {}, + "iteration": 0, + "findings": [], + } + + result = await analysis_node_with_mock_agent(state) + + assert "findings" in result + assert len(result["findings"]) > 0 + assert result["iteration"] == 1 + + +class TestIntegrationFlow: + """完整流程集成测试""" + + @pytest.mark.asyncio + async def test_full_audit_flow_mock(self, temp_project_dir, mock_db_session, mock_task): + """测试完整审计流程(使用模拟)""" + # 这个测试验证整个流程的连接性 + + 
# 创建事件管理器 + event_manager = EventManager() + emitter = AgentEventEmitter(mock_task.id, event_manager) + + # 模拟 LLM 服务 + mock_llm = MagicMock() + mock_llm.chat_completion_raw = AsyncMock(return_value={ + "content": "Analysis complete", + "usage": {"total_tokens": 100}, + }) + + # 验证事件发射 + await emitter.emit_phase_start("init", "初始化") + await emitter.emit_info("测试消息") + await emitter.emit_phase_complete("init", "初始化完成") + + assert emitter._sequence == 3 + + @pytest.mark.asyncio + async def test_audit_state_typing(self): + """测试审计状态类型定义""" + state: AuditState = { + "project_root": "/tmp/test", + "project_info": {"name": "Test"}, + "config": {}, + "task_id": "test-id", + "tech_stack": {}, + "entry_points": [], + "high_risk_areas": [], + "dependencies": {}, + "findings": [], + "verified_findings": [], + "false_positives": [], + "current_phase": "start", + "iteration": 0, + "max_iterations": 50, + "should_continue_analysis": False, + "messages": [], + "events": [], + "summary": None, + "security_score": None, + "error": None, + } + + assert state["current_phase"] == "start" + assert state["max_iterations"] == 50 + + +class TestToolIntegration: + """工具集成测试""" + + @pytest.mark.asyncio + async def test_tools_work_together(self, temp_project_dir): + """测试工具协同工作""" + from app.services.agent.tools import ( + FileReadTool, FileSearchTool, ListFilesTool, PatternMatchTool, + ) + + # 1. 列出文件 + list_tool = ListFilesTool(temp_project_dir) + list_result = await list_tool.execute(directory="src", recursive=False) + assert list_result.success is True + + # 2. 搜索关键代码 + search_tool = FileSearchTool(temp_project_dir) + search_result = await search_tool.execute(keyword="execute") + assert search_result.success is True + + # 3. 读取文件内容 + read_tool = FileReadTool(temp_project_dir) + read_result = await read_tool.execute(file_path="src/sql_vuln.py") + assert read_result.success is True + + # 4. 
模式匹配 + pattern_tool = PatternMatchTool(temp_project_dir) + pattern_result = await pattern_tool.execute( + code=read_result.data, + file_path="src/sql_vuln.py", + language="python" + ) + assert pattern_result.success is True + + +class TestErrorHandling: + """错误处理测试""" + + @pytest.mark.asyncio + async def test_tool_error_handling(self, temp_project_dir): + """测试工具错误处理""" + from app.services.agent.tools import FileReadTool + + tool = FileReadTool(temp_project_dir) + + # 尝试读取不存在的文件 + result = await tool.execute(file_path="nonexistent/file.py") + + assert result.success is False + assert result.error is not None + + @pytest.mark.asyncio + async def test_agent_graceful_degradation(self, mock_event_emitter): + """测试 Agent 优雅降级""" + # 创建一个会失败的 Agent + mock_agent = MagicMock() + mock_agent.run = AsyncMock(side_effect=Exception("Simulated error")) + + node = ReconNode(mock_agent, mock_event_emitter) + + result = await node({ + "project_info": {}, + "config": {}, + }) + + # 应该返回错误状态而不是崩溃 + assert "error" in result + assert result["current_phase"] == "error" + + +class TestPerformance: + """性能测试""" + + @pytest.mark.asyncio + async def test_tool_response_time(self, temp_project_dir): + """测试工具响应时间""" + from app.services.agent.tools import ListFilesTool + import time + + tool = ListFilesTool(temp_project_dir) + + start = time.time() + await tool.execute(directory=".", recursive=True) + duration = time.time() - start + + # 工具应该在合理时间内响应 + assert duration < 5.0 # 5 秒内 + + @pytest.mark.asyncio + async def test_multiple_tool_calls(self, temp_project_dir): + """测试多次工具调用""" + from app.services.agent.tools import FileSearchTool + + tool = FileSearchTool(temp_project_dir) + + # 执行多次调用 + for _ in range(5): + result = await tool.execute(keyword="def") + assert result.success is True + + # 验证调用计数 + assert tool._call_count == 5 + diff --git a/backend/tests/agent/test_tools.py b/backend/tests/agent/test_tools.py new file mode 100644 index 0000000..d8114c9 --- /dev/null +++ b/backend/tests/agent/test_tools.py @@ -0,0 +1,248 @@ +""" +Agent 工具单元测试 +测试各种安全分析工具的功能 +""" + +import pytest +import asyncio +import os +from unittest.mock import MagicMock, AsyncMock, patch + +# 导入工具 +from app.services.agent.tools import ( + FileReadTool, FileSearchTool, ListFilesTool, + PatternMatchTool, +) +from app.services.agent.tools.base import ToolResult + + +class TestFileTools: + """文件操作工具测试""" + + @pytest.mark.asyncio + async def test_file_read_tool_success(self, temp_project_dir): + """测试文件读取工具 - 成功读取""" + tool = FileReadTool(temp_project_dir) + + result = await tool.execute(file_path="src/sql_vuln.py") + + assert result.success is True + assert "SELECT * FROM users" in result.data + assert "sql_injection" in result.data.lower() or "cursor.execute" in result.data + + @pytest.mark.asyncio + async def test_file_read_tool_not_found(self, temp_project_dir): + """测试文件读取工具 - 文件不存在""" + tool = FileReadTool(temp_project_dir) + + result = await tool.execute(file_path="nonexistent.py") + + assert result.success is False + assert "不存在" in result.error or "not found" in result.error.lower() + + @pytest.mark.asyncio + async def test_file_read_tool_path_traversal_blocked(self, temp_project_dir): + """测试文件读取工具 - 路径遍历被阻止""" + tool = FileReadTool(temp_project_dir) + + result = await tool.execute(file_path="../../../etc/passwd") + + assert result.success is False + assert "安全" in result.error or "security" in result.error.lower() + + @pytest.mark.asyncio + async def test_file_search_tool(self, temp_project_dir): + """测试文件搜索工具""" + tool = 
FileSearchTool(temp_project_dir) + + result = await tool.execute(keyword="cursor.execute") + + assert result.success is True + assert "sql_vuln.py" in result.data + + @pytest.mark.asyncio + async def test_list_files_tool(self, temp_project_dir): + """测试文件列表工具""" + tool = ListFilesTool(temp_project_dir) + + result = await tool.execute(directory=".", recursive=True) + + assert result.success is True + assert "sql_vuln.py" in result.data + assert "requirements.txt" in result.data + + @pytest.mark.asyncio + async def test_list_files_tool_pattern(self, temp_project_dir): + """测试文件列表工具 - 文件模式过滤""" + tool = ListFilesTool(temp_project_dir) + + result = await tool.execute(directory="src", pattern="*.py") + + assert result.success is True + assert "sql_vuln.py" in result.data + + +class TestPatternMatchTool: + """模式匹配工具测试""" + + @pytest.mark.asyncio + async def test_pattern_match_sql_injection(self, temp_project_dir): + """测试模式匹配 - SQL 注入检测""" + tool = PatternMatchTool(temp_project_dir) + + # 读取有漏洞的代码 + with open(os.path.join(temp_project_dir, "src", "sql_vuln.py")) as f: + code = f.read() + + result = await tool.execute( + code=code, + file_path="src/sql_vuln.py", + pattern_types=["sql_injection"], + language="python" + ) + + assert result.success is True + # 应该检测到 SQL 注入模式 + if result.data: + assert "sql" in str(result.data).lower() or len(result.metadata.get("matches", [])) > 0 + + @pytest.mark.asyncio + async def test_pattern_match_command_injection(self, temp_project_dir): + """测试模式匹配 - 命令注入检测""" + tool = PatternMatchTool(temp_project_dir) + + with open(os.path.join(temp_project_dir, "src", "cmd_vuln.py")) as f: + code = f.read() + + result = await tool.execute( + code=code, + file_path="src/cmd_vuln.py", + pattern_types=["command_injection"], + language="python" + ) + + assert result.success is True + + @pytest.mark.asyncio + async def test_pattern_match_xss(self, temp_project_dir): + """测试模式匹配 - XSS 检测""" + tool = PatternMatchTool(temp_project_dir) + + with open(os.path.join(temp_project_dir, "src", "xss_vuln.py")) as f: + code = f.read() + + result = await tool.execute( + code=code, + file_path="src/xss_vuln.py", + pattern_types=["xss"], + language="python" + ) + + assert result.success is True + + @pytest.mark.asyncio + async def test_pattern_match_hardcoded_secrets(self, temp_project_dir): + """测试模式匹配 - 硬编码密钥检测""" + tool = PatternMatchTool(temp_project_dir) + + with open(os.path.join(temp_project_dir, "src", "secrets.py")) as f: + code = f.read() + + result = await tool.execute( + code=code, + file_path="src/secrets.py", + pattern_types=["hardcoded_secret"], + ) + + assert result.success is True + + @pytest.mark.asyncio + async def test_pattern_match_safe_code(self, temp_project_dir): + """测试模式匹配 - 安全代码应该没有问题""" + tool = PatternMatchTool(temp_project_dir) + + with open(os.path.join(temp_project_dir, "src", "safe_code.py")) as f: + code = f.read() + + result = await tool.execute( + code=code, + file_path="src/safe_code.py", + pattern_types=["sql_injection"], + language="python" + ) + + assert result.success is True + # 安全代码使用参数化查询,不应该有 SQL 注入漏洞 + # 检查结果数据,如果有 matches 字段 + matches = result.metadata.get("matches", []) + if isinstance(matches, list): + # 参数化查询不应该被误报为 SQL 注入 + sql_injection_count = sum( + 1 for m in matches + if isinstance(m, dict) and "sql" in m.get("pattern_type", "").lower() + ) + # 安全代码的 SQL 注入匹配应该很少或没有 + assert sql_injection_count <= 1 # 允许少量误报 + + +class TestToolResult: + """工具结果测试""" + + def test_tool_result_success(self): + """测试成功的工具结果""" + result = 
ToolResult(success=True, data="test data") + + assert result.success is True + assert result.data == "test data" + assert result.error is None + + def test_tool_result_failure(self): + """测试失败的工具结果""" + result = ToolResult(success=False, error="test error") + + assert result.success is False + assert result.error == "test error" + + def test_tool_result_to_string(self): + """测试工具结果转字符串""" + result = ToolResult(success=True, data={"key": "value"}) + + string = result.to_string() + + assert "key" in string + assert "value" in string + + def test_tool_result_to_string_truncate(self): + """测试工具结果字符串截断""" + long_data = "x" * 10000 + result = ToolResult(success=True, data=long_data) + + string = result.to_string(max_length=100) + + assert len(string) < len(long_data) + assert "truncated" in string.lower() + + +class TestToolMetadata: + """工具元数据测试""" + + @pytest.mark.asyncio + async def test_tool_call_count(self, temp_project_dir): + """测试工具调用计数""" + tool = ListFilesTool(temp_project_dir) + + await tool.execute(directory=".") + await tool.execute(directory="src") + + assert tool._call_count == 2 + + @pytest.mark.asyncio + async def test_tool_duration_tracking(self, temp_project_dir): + """测试工具执行时间跟踪""" + tool = ListFilesTool(temp_project_dir) + + result = await tool.execute(directory=".") + + assert result.duration_ms >= 0 + assert tool._total_duration_ms >= 0 + diff --git a/docs/AGENT_DEPLOYMENT_CHECKLIST.md b/docs/AGENT_DEPLOYMENT_CHECKLIST.md new file mode 100644 index 0000000..947fc08 --- /dev/null +++ b/docs/AGENT_DEPLOYMENT_CHECKLIST.md @@ -0,0 +1,229 @@ +# DeepAudit Agent 审计功能部署清单 + +## 📋 生产部署前必须完成的检查 + +### 1. 环境依赖 ✅ + +```bash +# 后端依赖 +cd backend +uv pip install chromadb litellm langchain langgraph + +# 外部安全工具(可选但推荐) +pip install semgrep bandit safety + +# 或者使用系统包管理器 +brew install semgrep # macOS +apt install semgrep # Ubuntu +``` + +### 2. LLM 配置 ✅ + +在 `.env` 文件中配置: + +```env +# LLM 配置(必须) +LLM_PROVIDER=openai # 或 azure, anthropic, ollama 等 +LLM_MODEL=gpt-4o-mini # 推荐使用 gpt-4 系列 +LLM_API_KEY=sk-xxx # 你的 API Key +LLM_BASE_URL= # 可选,自定义端点 + +# 嵌入模型配置(RAG 需要) +EMBEDDING_PROVIDER=openai +EMBEDDING_MODEL=text-embedding-3-small +``` + +### 3. 数据库迁移 ✅ + +```bash +cd backend +alembic upgrade head +``` + +确保以下表已创建: +- `agent_tasks` +- `agent_events` +- `agent_findings` + +### 4. 向量数据库 ✅ + +```bash +# 创建向量数据库目录 +mkdir -p /var/data/deepaudit/vector_db + +# 在 .env 中配置 +VECTOR_DB_PATH=/var/data/deepaudit/vector_db +``` + +### 5. Docker 沙箱(可选) + +如果需要漏洞验证功能: + +```bash +# 拉取沙箱镜像 +docker pull python:3.11-slim + +# 配置沙箱参数 +SANDBOX_IMAGE=python:3.11-slim +SANDBOX_MEMORY_LIMIT=256m +SANDBOX_CPU_LIMIT=0.5 +``` + +--- + +## 🔬 功能测试检查 + +### 测试 1: 基础流程 + +```bash +cd backend +PYTHONPATH=. uv run pytest tests/agent/ -v +``` + +预期结果:43 个测试全部通过 + +### 测试 2: LLM 连接 + +```bash +cd backend +PYTHONPATH=. uv run python -c " +import asyncio +from app.services.agent.graph.runner import LLMService + +async def test(): + llm = LLMService() + result = await llm.analyze_code('print(\"hello\")', 'python') + print('LLM 连接成功:', 'issues' in result) + +asyncio.run(test()) +" +``` + +### 测试 3: 外部工具 + +```bash +# 测试 Semgrep +semgrep --version + +# 测试 Bandit +bandit --version +``` + +### 测试 4: 端到端测试 + +1. 启动后端:`cd backend && uv run uvicorn app.main:app --reload` +2. 启动前端:`cd frontend && npm run dev` +3. 创建一个项目并上传代码 +4. 选择 "Agent 审计模式" 创建任务 +5. 
观察执行日志和发现 + +--- + +## ⚠️ 已知限制 + +| 限制 | 影响 | 解决方案 | +|------|------|---------| +| **LLM 成本** | 每次审计消耗 Token | 使用 gpt-4o-mini 降低成本 | +| **扫描时间** | 大项目需要较长时间 | 设置合理的超时时间 | +| **误报率** | AI 可能产生误报 | 启用验证阶段过滤 | +| **外部工具依赖** | 需要手动安装 | 提供 Docker 镜像 | + +--- + +## 🚀 生产环境建议 + +### 1. 资源配置 + +```yaml +# Kubernetes 部署示例 +resources: + limits: + memory: "2Gi" + cpu: "2" + requests: + memory: "1Gi" + cpu: "1" +``` + +### 2. 并发控制 + +```env +# 限制同时运行的任务数 +MAX_CONCURRENT_AGENT_TASKS=3 +AGENT_TASK_TIMEOUT=1800 # 30 分钟 +``` + +### 3. 日志监控 + +```python +# 配置日志级别 +LOG_LEVEL=INFO +# 启用 SQLAlchemy 日志(调试用) +SQLALCHEMY_ECHO=false +``` + +### 4. 安全考虑 + +- [ ] 限制上传文件大小 +- [ ] 限制扫描目录范围 +- [ ] 启用沙箱隔离 +- [ ] 配置 API 速率限制 + +--- + +## ✅ 部署状态检查 + +运行以下命令验证部署状态: + +```bash +cd backend +PYTHONPATH=. uv run python -c " +print('检查部署状态...') + +# 1. 检查数据库连接 +try: + from app.db.session import async_session_factory + print('✅ 数据库配置正确') +except Exception as e: + print(f'❌ 数据库错误: {e}') + +# 2. 检查 LLM 配置 +from app.core.config import settings +if settings.LLM_API_KEY: + print('✅ LLM API Key 已配置') +else: + print('⚠️ LLM API Key 未配置') + +# 3. 检查向量数据库 +import os +if os.path.exists(settings.VECTOR_DB_PATH or '/tmp'): + print('✅ 向量数据库路径存在') +else: + print('⚠️ 向量数据库路径不存在') + +# 4. 检查外部工具 +import shutil +tools = ['semgrep', 'bandit'] +for tool in tools: + if shutil.which(tool): + print(f'✅ {tool} 已安装') + else: + print(f'⚠️ {tool} 未安装(可选)') + +print() +print('部署检查完成!') +" +``` + +--- + +## 📝 结论 + +Agent 审计功能已经具备**基本的生产能力**,但建议: + +1. **先在测试环境验证** - 用一个小项目测试完整流程 +2. **监控 LLM 成本** - 观察 Token 消耗情况 +3. **逐步开放** - 先给少数用户使用,收集反馈 +4. **持续优化** - 根据实际效果调整 prompt 和阈值 + +如有问题,请查看日志或联系开发团队。 diff --git a/frontend/Dockerfile b/frontend/Dockerfile index 3a28948..0c56cc1 100644 --- a/frontend/Dockerfile +++ b/frontend/Dockerfile @@ -55,3 +55,5 @@ EXPOSE 3000 ENTRYPOINT ["/docker-entrypoint.sh"] CMD ["serve", "-s", "dist", "-l", "3000"] + + diff --git a/frontend/src/app/ProtectedRoute.tsx b/frontend/src/app/ProtectedRoute.tsx index 6373c97..94922f5 100644 --- a/frontend/src/app/ProtectedRoute.tsx +++ b/frontend/src/app/ProtectedRoute.tsx @@ -18,3 +18,5 @@ export const ProtectedRoute = () => { + + diff --git a/frontend/src/hooks/useAgentStream.ts b/frontend/src/hooks/useAgentStream.ts new file mode 100644 index 0000000..3770ad7 --- /dev/null +++ b/frontend/src/hooks/useAgentStream.ts @@ -0,0 +1,280 @@ +/** + * Agent 流式事件 React Hook + * + * 用于在 React 组件中消费 Agent 审计的实时事件流 + */ + +import { useState, useEffect, useCallback, useRef } from 'react'; +import { + AgentStreamHandler, + StreamEventData, + StreamOptions, + AgentStreamState, +} from '../shared/api/agentStream'; + +export interface UseAgentStreamOptions extends Omit { + autoConnect?: boolean; + maxEvents?: number; +} + +export interface UseAgentStreamReturn extends AgentStreamState { + connect: () => void; + disconnect: () => void; + isConnected: boolean; + clearEvents: () => void; +} + +/** + * Agent 流式事件 Hook + * + * @example + * ```tsx + * function AgentAuditPanel({ taskId }: { taskId: string }) { + * const { + * events, + * thinking, + * isThinking, + * toolCalls, + * currentPhase, + * progress, + * findings, + * isComplete, + * error, + * connect, + * disconnect, + * isConnected, + * } = useAgentStream(taskId); + * + * useEffect(() => { + * connect(); + * return () => disconnect(); + * }, [taskId]); + * + * return ( + *
+ * {isThinking && } + * {toolCalls.map(tc => )} + * {findings.map(f => )} + *
+ * ); + * } + * ``` + */ +export function useAgentStream( + taskId: string | null, + options: UseAgentStreamOptions = {} +): UseAgentStreamReturn { + const { + autoConnect = false, + maxEvents = 500, + includeThinking = true, + includeToolCalls = true, + afterSequence = 0, + ...callbackOptions + } = options; + + // 状态 + const [events, setEvents] = useState([]); + const [thinking, setThinking] = useState(''); + const [isThinking, setIsThinking] = useState(false); + const [toolCalls, setToolCalls] = useState([]); + const [currentPhase, setCurrentPhase] = useState(''); + const [progress, setProgress] = useState({ current: 0, total: 100, percentage: 0 }); + const [findings, setFindings] = useState[]>([]); + const [isComplete, setIsComplete] = useState(false); + const [error, setError] = useState(null); + const [isConnected, setIsConnected] = useState(false); + + // Handler ref + const handlerRef = useRef(null); + const thinkingBufferRef = useRef([]); + + // 连接 + const connect = useCallback(() => { + if (!taskId) return; + + // 断开现有连接 + if (handlerRef.current) { + handlerRef.current.disconnect(); + } + + // 重置状态 + setEvents([]); + setThinking(''); + setIsThinking(false); + setToolCalls([]); + setCurrentPhase(''); + setProgress({ current: 0, total: 100, percentage: 0 }); + setFindings([]); + setIsComplete(false); + setError(null); + thinkingBufferRef.current = []; + + // 创建新的 handler + handlerRef.current = new AgentStreamHandler(taskId, { + includeThinking, + includeToolCalls, + afterSequence, + + onEvent: (event) => { + setEvents((prev) => [...prev.slice(-maxEvents + 1), event]); + }, + + onThinkingStart: () => { + thinkingBufferRef.current = []; + setIsThinking(true); + setThinking(''); + callbackOptions.onThinkingStart?.(); + }, + + onThinkingToken: (token, accumulated) => { + thinkingBufferRef.current.push(token); + setThinking(accumulated); + callbackOptions.onThinkingToken?.(token, accumulated); + }, + + onThinkingEnd: (response) => { + setIsThinking(false); + setThinking(response); + thinkingBufferRef.current = []; + callbackOptions.onThinkingEnd?.(response); + }, + + onToolStart: (name, input) => { + setToolCalls((prev) => [ + ...prev, + { name, input, status: 'running' as const }, + ]); + callbackOptions.onToolStart?.(name, input); + }, + + onToolEnd: (name, output, durationMs) => { + setToolCalls((prev) => + prev.map((tc) => + tc.name === name && tc.status === 'running' + ? { ...tc, output, durationMs, status: 'success' as const } + : tc + ) + ); + callbackOptions.onToolEnd?.(name, output, durationMs); + }, + + onNodeStart: (nodeName, phase) => { + setCurrentPhase(phase); + callbackOptions.onNodeStart?.(nodeName, phase); + }, + + onNodeEnd: (nodeName, summary) => { + callbackOptions.onNodeEnd?.(nodeName, summary); + }, + + onProgress: (current, total, message) => { + setProgress({ + current, + total, + percentage: total > 0 ? 
Math.round((current / total) * 100) : 0, + }); + callbackOptions.onProgress?.(current, total, message); + }, + + onFinding: (finding, isVerified) => { + setFindings((prev) => [...prev, finding]); + callbackOptions.onFinding?.(finding, isVerified); + }, + + onComplete: (data) => { + setIsComplete(true); + setIsConnected(false); + callbackOptions.onComplete?.(data); + }, + + onError: (err) => { + setError(err); + setIsComplete(true); + setIsConnected(false); + callbackOptions.onError?.(err); + }, + + onHeartbeat: () => { + callbackOptions.onHeartbeat?.(); + }, + }); + + handlerRef.current.connect(); + setIsConnected(true); + }, [taskId, includeThinking, includeToolCalls, afterSequence, maxEvents, callbackOptions]); + + // 断开连接 + const disconnect = useCallback(() => { + if (handlerRef.current) { + handlerRef.current.disconnect(); + handlerRef.current = null; + } + setIsConnected(false); + }, []); + + // 清空事件 + const clearEvents = useCallback(() => { + setEvents([]); + }, []); + + // 自动连接 + useEffect(() => { + if (autoConnect && taskId) { + connect(); + } + return () => { + disconnect(); + }; + }, [taskId, autoConnect, connect, disconnect]); + + // 清理 + useEffect(() => { + return () => { + if (handlerRef.current) { + handlerRef.current.disconnect(); + } + }; + }, []); + + return { + events, + thinking, + isThinking, + toolCalls, + currentPhase, + progress, + findings, + isComplete, + error, + connect, + disconnect, + isConnected, + clearEvents, + }; +} + +/** + * 简化版 Hook - 只获取思考过程 + */ +export function useAgentThinking(taskId: string | null) { + const { thinking, isThinking, connect, disconnect } = useAgentStream(taskId, { + includeToolCalls: false, + }); + + return { thinking, isThinking, connect, disconnect }; +} + +/** + * 简化版 Hook - 只获取工具调用 + */ +export function useAgentToolCalls(taskId: string | null) { + const { toolCalls, connect, disconnect } = useAgentStream(taskId, { + includeThinking: false, + }); + + return { toolCalls, connect, disconnect }; +} + +export default useAgentStream; + diff --git a/frontend/src/pages/AgentAudit.tsx b/frontend/src/pages/AgentAudit.tsx index 031166e..326ce39 100644 --- a/frontend/src/pages/AgentAudit.tsx +++ b/frontend/src/pages/AgentAudit.tsx @@ -1,6 +1,7 @@ /** * Agent 审计页面 * 机械终端风格的 AI Agent 审计界面 + * 支持 LLM 思考过程和工具调用的实时流式展示 */ import { useState, useEffect, useRef, useCallback } from "react"; @@ -9,12 +10,14 @@ import { Terminal, Bot, Cpu, Shield, AlertTriangle, CheckCircle2, Loader2, Code, Zap, Activity, ChevronRight, XCircle, FileCode, Search, Bug, Lock, Play, Square, RefreshCw, - ArrowLeft, Download, ExternalLink + ArrowLeft, Download, ExternalLink, Brain, Wrench, + ChevronDown, ChevronUp, Clock, Sparkles } from "lucide-react"; import { Button } from "@/components/ui/button"; import { Badge } from "@/components/ui/badge"; import { ScrollArea } from "@/components/ui/scroll-area"; import { toast } from "sonner"; +import { useAgentStream } from "@/hooks/useAgentStream"; import { type AgentTask, type AgentEvent, @@ -91,10 +94,36 @@ export default function AgentAuditPage() { const [isLoading, setIsLoading] = useState(true); const [isStreaming, setIsStreaming] = useState(false); const [currentTime, setCurrentTime] = useState(new Date().toLocaleTimeString("zh-CN", { hour12: false })); + const [showThinking, setShowThinking] = useState(true); + const [showToolDetails, setShowToolDetails] = useState(true); const eventsEndRef = useRef(null); + const thinkingEndRef = useRef(null); const abortControllerRef = useRef(null); + // 使用增强版流式 Hook + const { + thinking, 
+ isThinking, + toolCalls, + currentPhase: streamPhase, + progress: streamProgress, + connect: connectStream, + disconnect: disconnectStream, + isConnected: isStreamConnected, + } = useAgentStream(taskId || null, { + includeThinking: true, + includeToolCalls: true, + onFinding: () => loadFindings(), + onComplete: () => { + loadTask(); + loadFindings(); + }, + onError: (err) => { + console.error("Stream error:", err); + }, + }); + // 是否完成 const isComplete = task?.status === "completed" || task?.status === "failed" || task?.status === "cancelled"; @@ -146,12 +175,24 @@ export default function AgentAuditPage() { init(); }, [loadTask, loadEvents, loadFindings]); - // 事件流 + // 连接增强版流式 API + useEffect(() => { + if (!taskId || isComplete || isLoading) return; + + connectStream(); + setIsStreaming(true); + + return () => { + disconnectStream(); + setIsStreaming(false); + }; + }, [taskId, isComplete, isLoading, connectStream, disconnectStream]); + + // 旧版事件流(作为后备) useEffect(() => { if (!taskId || isComplete || isLoading) return; const startStreaming = async () => { - setIsStreaming(true); abortControllerRef.current = new AbortController(); try { @@ -179,8 +220,6 @@ export default function AgentAuditPage() { if ((error as Error).name !== "AbortError") { console.error("Event stream error:", error); } - } finally { - setIsStreaming(false); } }; @@ -205,6 +244,30 @@ export default function AgentAuditPage() { return () => clearInterval(interval); }, []); + // 定期轮询任务状态(作为 SSE 的后备机制) + useEffect(() => { + if (!taskId || isComplete || isLoading) return; + + // 每 3 秒轮询一次任务状态 + const pollInterval = setInterval(async () => { + try { + const taskData = await getAgentTask(taskId); + setTask(taskData); + + // 如果任务已完成/失败/取消,刷新其他数据 + if (taskData.status === "completed" || taskData.status === "failed" || taskData.status === "cancelled") { + await loadEvents(); + await loadFindings(); + clearInterval(pollInterval); + } + } catch (error) { + console.error("Failed to poll task status:", error); + } + }, 3000); + + return () => clearInterval(pollInterval); + }, [taskId, isComplete, isLoading, loadEvents, loadFindings]); + // 取消任务 const handleCancel = async () => { if (!taskId) return; @@ -291,14 +354,85 @@ export default function AgentAuditPage() { + {/* 错误提示 */} + {task.status === "failed" && task.error_message && ( +
+            <div className="mb-4 rounded border border-red-700 bg-red-900/30 p-3">
+              <div className="flex items-start gap-2">
+                <XCircle className="mt-0.5 h-4 w-4 shrink-0 text-red-400" />
+                <div>
+                  <div className="text-sm font-bold text-red-400">任务执行失败</div>
+                  <div className="mt-1 break-all text-xs text-red-300">{task.error_message}</div>
+                </div>
+              </div>
+            </div>
+          )}
+
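+          {/* 下方依次实时渲染流式内容:AI 思考过程、工具调用详情与执行日志 */}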
          {/* 左侧:执行日志 */}
+
+          {/* 思考过程展示区域 */}
+          {(isThinking || thinking) && showThinking && (
+            <div className="mb-3 rounded border border-zinc-800 bg-zinc-900/60">
+              <div
+                className="flex cursor-pointer items-center justify-between px-3 py-2"
+                onClick={() => setShowThinking(!showThinking)}
+              >
+                <div className="flex items-center gap-2 text-sm text-purple-400">
+                  <Brain className="h-4 w-4" />
+                  <span>AI Thinking</span>
+                  {isThinking && (
+                    <span className="flex items-center gap-1 text-xs text-zinc-400">
+                      <Loader2 className="h-3 w-3 animate-spin" />
+                      Processing...
+                    </span>
+                  )}
+                </div>
+                {showThinking ? <ChevronUp className="h-4 w-4" /> : <ChevronDown className="h-4 w-4" />}
+              </div>
+              <ScrollArea className="max-h-40">
+                <div className="whitespace-pre-wrap p-3 font-mono text-xs text-zinc-300">
+                  {thinking || "正在思考..."}
+                  {isThinking && <span className="animate-pulse">▋</span>}
+                  <div ref={thinkingEndRef} />
+                </div>
+              </ScrollArea>
+            </div>
+          )}
+
+          {/* 工具调用展示区域 */}
+          {toolCalls.length > 0 && showToolDetails && (
+            <div className="mb-3 rounded border border-zinc-800 bg-zinc-900/60">
+              <div
+                className="flex cursor-pointer items-center justify-between px-3 py-2"
+                onClick={() => setShowToolDetails(!showToolDetails)}
+              >
+                <div className="flex items-center gap-2 text-sm text-cyan-400">
+                  <Wrench className="h-4 w-4" />
+                  <span>Tool Calls</span>
+                  <span className="rounded border border-zinc-700 px-1.5 text-xs text-zinc-400">
+                    {toolCalls.length}
+                  </span>
+                </div>
+                {showToolDetails ? <ChevronUp className="h-4 w-4" /> : <ChevronDown className="h-4 w-4" />}
+              </div>
+              <ScrollArea className="max-h-60">
+                <div className="space-y-2 p-2">
+                  {toolCalls.slice(-5).map((tc, idx) => (
+                    <ToolCallCard key={idx} toolCall={tc} />
+                  ))}
+                </div>
+              </ScrollArea>
+            </div>
+          )}
+
            <div className="flex items-center gap-2 text-sm text-green-400">
              <Terminal className="h-4 w-4" />
Execution Log - {isStreaming && ( + {(isStreaming || isStreamConnected) && ( LIVE @@ -540,6 +674,94 @@ function EventLine({ event }: { event: AgentEvent }) { ); } +// 工具调用卡片组件 +interface ToolCallProps { + toolCall: { + name: string; + input: Record; + output?: unknown; + durationMs?: number; + status: 'running' | 'success' | 'error'; + }; +} + +function ToolCallCard({ toolCall }: ToolCallProps) { + const [expanded, setExpanded] = useState(false); + + const statusConfig = { + running: { + icon: , + badge: "bg-yellow-900/30 border-yellow-700 text-yellow-400", + text: "Running", + }, + success: { + icon: , + badge: "bg-green-900/30 border-green-700 text-green-400", + text: "Done", + }, + error: { + icon: , + badge: "bg-red-900/30 border-red-700 text-red-400", + text: "Error", + }, + }; + + const config = statusConfig[toolCall.status]; + + return ( +
+    <div className="rounded border border-zinc-800 bg-zinc-900/60 text-xs">
+      <div
+        className="flex cursor-pointer items-center justify-between px-2 py-1.5"
+        onClick={() => setExpanded(!expanded)}
+      >
+        <div className="flex items-center gap-2">
+          {config.icon}
+          <span className="font-mono">{toolCall.name}</span>
+        </div>
+        <div className="flex items-center gap-2">
+          {toolCall.durationMs && (
+            <span className="flex items-center gap-1 text-zinc-400">
+              <Clock className="h-3 w-3" />
+              {toolCall.durationMs}ms
+            </span>
+          )}
+          <span className={`rounded border px-1.5 py-0.5 ${config.badge}`}>
+            {config.text}
+          </span>
+          {expanded ? <ChevronUp className="h-3 w-3" /> : <ChevronDown className="h-3 w-3" />}
+        </div>
+      </div>
+
+      {expanded && (
+        <div className="space-y-2 border-t border-zinc-800 p-2">
+          {/* 输入 */}
+          {toolCall.input && Object.keys(toolCall.input).length > 0 && (
+            <div>
+              <span className="text-zinc-500">Input:</span>
+              <pre className="mt-1 overflow-x-auto whitespace-pre-wrap break-all text-zinc-300">
+                {JSON.stringify(toolCall.input, null, 2).slice(0, 500)}
+              </pre>
+            </div>
+          )}
+
+          {/* 输出 */}
+          {toolCall.output && (
+            <div>
+              <span className="text-zinc-500">Output:</span>
+              <pre className="mt-1 overflow-x-auto whitespace-pre-wrap break-all text-zinc-300">
+                {typeof toolCall.output === 'string'
+                  ? toolCall.output.slice(0, 500)
+                  : JSON.stringify(toolCall.output, null, 2).slice(0, 500)}
+              </pre>
+            </div>
+          )}
+        </div>
+      )}
+    </div>
+ ); +} + // 发现卡片组件 function FindingCard({ finding }: { finding: AgentFinding }) { const colorClass = severityColors[finding.severity] || severityColors.info; diff --git a/frontend/src/shared/api/agentStream.ts b/frontend/src/shared/api/agentStream.ts new file mode 100644 index 0000000..82fa28d --- /dev/null +++ b/frontend/src/shared/api/agentStream.ts @@ -0,0 +1,496 @@ +/** + * Agent 流式事件处理 + * + * 最佳实践: + * 1. 使用 EventSource API 或 fetch + ReadableStream + * 2. 支持重连机制 + * 3. 分类处理不同事件类型 + */ + +// 事件类型定义 +export type StreamEventType = + // LLM 相关 + | 'thinking_start' + | 'thinking_token' + | 'thinking_end' + // 工具调用相关 + | 'tool_call_start' + | 'tool_call_input' + | 'tool_call_output' + | 'tool_call_end' + | 'tool_call_error' + // 节点相关 + | 'node_start' + | 'node_end' + // 阶段相关 + | 'phase_start' + | 'phase_end' + | 'phase_complete' + // 发现相关 + | 'finding_new' + | 'finding_verified' + // 状态相关 + | 'progress' + | 'info' + | 'warning' + | 'error' + // 任务相关 + | 'task_start' + | 'task_complete' + | 'task_error' + | 'task_cancel' + | 'task_end' + // 心跳 + | 'heartbeat'; + +// 工具调用详情 +export interface ToolCallDetail { + name: string; + input?: Record; + output?: unknown; + duration_ms?: number; +} + +// 流式事件数据 +export interface StreamEventData { + id?: string; + type: StreamEventType; + phase?: string; + message?: string; + sequence?: number; + timestamp?: string; + tool?: ToolCallDetail; + metadata?: Record; + tokens_used?: number; + // 特定类型数据 + token?: string; // thinking_token + accumulated?: string; // thinking_token/thinking_end + status?: string; // task_end + error?: string; // task_error + findings_count?: number; // task_complete + security_score?: number; // task_complete +} + +// 事件回调类型 +export type StreamEventCallback = (event: StreamEventData) => void; + +// 流式选项 +export interface StreamOptions { + includeThinking?: boolean; + includeToolCalls?: boolean; + afterSequence?: number; + onThinkingStart?: () => void; + onThinkingToken?: (token: string, accumulated: string) => void; + onThinkingEnd?: (fullResponse: string) => void; + onToolStart?: (toolName: string, input: Record) => void; + onToolEnd?: (toolName: string, output: unknown, durationMs: number) => void; + onNodeStart?: (nodeName: string, phase: string) => void; + onNodeEnd?: (nodeName: string, summary: Record) => void; + onFinding?: (finding: Record, isVerified: boolean) => void; + onProgress?: (current: number, total: number, message: string) => void; + onComplete?: (data: { findingsCount: number; securityScore: number }) => void; + onError?: (error: string) => void; + onHeartbeat?: () => void; + onEvent?: StreamEventCallback; // 通用事件回调 +} + +/** + * Agent 流式事件处理器 + */ +export class AgentStreamHandler { + private taskId: string; + private eventSource: EventSource | null = null; + private options: StreamOptions; + private reconnectAttempts = 0; + private maxReconnectAttempts = 5; + private reconnectDelay = 1000; + private isConnected = false; + private thinkingBuffer: string[] = []; + + constructor(taskId: string, options: StreamOptions = {}) { + this.taskId = taskId; + this.options = { + includeThinking: true, + includeToolCalls: true, + afterSequence: 0, + ...options, + }; + } + + /** + * 开始监听事件流 + */ + connect(): void { + const token = localStorage.getItem('access_token'); + if (!token) { + this.options.onError?.('未登录'); + return; + } + + const params = new URLSearchParams({ + include_thinking: String(this.options.includeThinking), + include_tool_calls: String(this.options.includeToolCalls), + after_sequence: 
String(this.options.afterSequence), + }); + + // 使用 EventSource (不支持自定义 headers,需要通过 URL 传递 token) + // 或者使用 fetch + ReadableStream + this.connectWithFetch(token, params); + } + + /** + * 使用 fetch 连接(支持自定义 headers) + */ + private async connectWithFetch(token: string, params: URLSearchParams): Promise { + const url = `/api/v1/agent-tasks/${this.taskId}/stream?${params}`; + + try { + const response = await fetch(url, { + headers: { + 'Authorization': `Bearer ${token}`, + 'Accept': 'text/event-stream', + }, + }); + + if (!response.ok) { + throw new Error(`HTTP ${response.status}: ${response.statusText}`); + } + + this.isConnected = true; + this.reconnectAttempts = 0; + + const reader = response.body?.getReader(); + if (!reader) { + throw new Error('无法获取响应流'); + } + + const decoder = new TextDecoder(); + let buffer = ''; + + while (true) { + const { done, value } = await reader.read(); + + if (done) { + break; + } + + buffer += decoder.decode(value, { stream: true }); + + // 解析 SSE 事件 + const events = this.parseSSE(buffer); + buffer = events.remaining; + + for (const event of events.parsed) { + this.handleEvent(event); + } + } + } catch (error) { + this.isConnected = false; + console.error('Stream connection error:', error); + + // 尝试重连 + if (this.reconnectAttempts < this.maxReconnectAttempts) { + this.reconnectAttempts++; + setTimeout(() => this.connect(), this.reconnectDelay * this.reconnectAttempts); + } else { + this.options.onError?.(`连接失败: ${error}`); + } + } + } + + /** + * 解析 SSE 格式 + */ + private parseSSE(buffer: string): { parsed: StreamEventData[]; remaining: string } { + const parsed: StreamEventData[] = []; + const lines = buffer.split('\n'); + let remaining = ''; + let currentEvent: Partial = {}; + + for (let i = 0; i < lines.length; i++) { + const line = lines[i]; + + // 空行表示事件结束 + if (line === '') { + if (currentEvent.type) { + parsed.push(currentEvent as StreamEventData); + currentEvent = {}; + } + continue; + } + + // 检查是否是最后一行(可能不完整) + if (i === lines.length - 1 && !buffer.endsWith('\n')) { + remaining = line; + break; + } + + // 解析 event: 行 + if (line.startsWith('event:')) { + currentEvent.type = line.slice(6).trim() as StreamEventType; + } + // 解析 data: 行 + else if (line.startsWith('data:')) { + try { + const data = JSON.parse(line.slice(5).trim()); + currentEvent = { ...currentEvent, ...data }; + } catch { + // 忽略解析错误 + } + } + } + + return { parsed, remaining }; + } + + /** + * 处理事件 + */ + private handleEvent(event: StreamEventData): void { + // 通用回调 + this.options.onEvent?.(event); + + // 分类处理 + switch (event.type) { + // LLM 思考 + case 'thinking_start': + this.thinkingBuffer = []; + this.options.onThinkingStart?.(); + break; + + case 'thinking_token': + if (event.token) { + this.thinkingBuffer.push(event.token); + this.options.onThinkingToken?.( + event.token, + event.accumulated || this.thinkingBuffer.join('') + ); + } + break; + + case 'thinking_end': + const fullResponse = event.accumulated || this.thinkingBuffer.join(''); + this.thinkingBuffer = []; + this.options.onThinkingEnd?.(fullResponse); + break; + + // 工具调用 + case 'tool_call_start': + if (event.tool) { + this.options.onToolStart?.( + event.tool.name, + event.tool.input || {} + ); + } + break; + + case 'tool_call_end': + if (event.tool) { + this.options.onToolEnd?.( + event.tool.name, + event.tool.output, + event.tool.duration_ms || 0 + ); + } + break; + + // 节点 + case 'node_start': + this.options.onNodeStart?.( + event.metadata?.node as string || 'unknown', + event.phase || '' + ); + break; + + case 
'node_end': + this.options.onNodeEnd?.( + event.metadata?.node as string || 'unknown', + event.metadata?.summary as Record || {} + ); + break; + + // 发现 + case 'finding_new': + case 'finding_verified': + this.options.onFinding?.( + event.metadata || {}, + event.type === 'finding_verified' + ); + break; + + // 进度 + case 'progress': + this.options.onProgress?.( + event.metadata?.current as number || 0, + event.metadata?.total as number || 100, + event.message || '' + ); + break; + + // 任务完成 + case 'task_complete': + case 'task_end': + if (event.status !== 'cancelled' && event.status !== 'failed') { + this.options.onComplete?.({ + findingsCount: event.findings_count || event.metadata?.findings_count as number || 0, + securityScore: event.security_score || event.metadata?.security_score as number || 100, + }); + } + this.disconnect(); + break; + + // 错误 + case 'task_error': + case 'error': + this.options.onError?.(event.error || event.message || '未知错误'); + this.disconnect(); + break; + + // 心跳 + case 'heartbeat': + this.options.onHeartbeat?.(); + break; + } + } + + /** + * 断开连接 + */ + disconnect(): void { + this.isConnected = false; + if (this.eventSource) { + this.eventSource.close(); + this.eventSource = null; + } + } + + /** + * 检查是否已连接 + */ + get connected(): boolean { + return this.isConnected; + } +} + +/** + * 创建流式事件处理器的便捷函数 + */ +export function createAgentStream( + taskId: string, + options: StreamOptions = {} +): AgentStreamHandler { + return new AgentStreamHandler(taskId, options); +} + +/** + * React Hook 风格的使用示例 + * + * ```tsx + * const { events, thinking, toolCalls, connect, disconnect } = useAgentStream(taskId); + * + * useEffect(() => { + * connect(); + * return () => disconnect(); + * }, [taskId]); + * ``` + */ +export interface AgentStreamState { + events: StreamEventData[]; + thinking: string; + isThinking: boolean; + toolCalls: Array<{ + name: string; + input: Record; + output?: unknown; + durationMs?: number; + status: 'running' | 'success' | 'error'; + }>; + currentPhase: string; + progress: { current: number; total: number; percentage: number }; + findings: Array>; + isComplete: boolean; + error: string | null; +} + +/** + * 创建用于 React 状态管理的流式处理器 + */ +export function createAgentStreamWithState( + taskId: string, + onStateChange: (state: AgentStreamState) => void +): AgentStreamHandler { + const state: AgentStreamState = { + events: [], + thinking: '', + isThinking: false, + toolCalls: [], + currentPhase: '', + progress: { current: 0, total: 100, percentage: 0 }, + findings: [], + isComplete: false, + error: null, + }; + + const updateState = (updates: Partial) => { + Object.assign(state, updates); + onStateChange({ ...state }); + }; + + return new AgentStreamHandler(taskId, { + onEvent: (event) => { + updateState({ + events: [...state.events, event].slice(-500), // 保留最近 500 条 + }); + }, + onThinkingStart: () => { + updateState({ isThinking: true, thinking: '' }); + }, + onThinkingToken: (_, accumulated) => { + updateState({ thinking: accumulated }); + }, + onThinkingEnd: (response) => { + updateState({ isThinking: false, thinking: response }); + }, + onToolStart: (name, input) => { + updateState({ + toolCalls: [ + ...state.toolCalls, + { name, input, status: 'running' }, + ], + }); + }, + onToolEnd: (name, output, durationMs) => { + updateState({ + toolCalls: state.toolCalls.map((tc) => + tc.name === name && tc.status === 'running' + ? 
{ ...tc, output, durationMs, status: 'success' as const } + : tc + ), + }); + }, + onNodeStart: (_, phase) => { + updateState({ currentPhase: phase }); + }, + onProgress: (current, total, _) => { + updateState({ + progress: { + current, + total, + percentage: total > 0 ? Math.round((current / total) * 100) : 0, + }, + }); + }, + onFinding: (finding, _) => { + updateState({ + findings: [...state.findings, finding], + }); + }, + onComplete: () => { + updateState({ isComplete: true }); + }, + onError: (error) => { + updateState({ error, isComplete: true }); + }, + }); +} + diff --git a/frontend/src/shared/api/agentTasks.ts b/frontend/src/shared/api/agentTasks.ts index c7342fc..c6aecd4 100644 --- a/frontend/src/shared/api/agentTasks.ts +++ b/frontend/src/shared/api/agentTasks.ts @@ -43,6 +43,9 @@ export interface AgentTask { // 进度 progress_percentage: number; + + // 错误信息 + error_message: string | null; } export interface AgentFinding { @@ -249,7 +252,7 @@ export async function* streamAgentEvents( afterSequence = 0, signal?: AbortSignal ): AsyncGenerator { - const token = localStorage.getItem("auth_token"); + const token = localStorage.getItem("access_token"); const baseUrl = import.meta.env.VITE_API_URL || ""; const url = `${baseUrl}/api/v1/agent-tasks/${taskId}/events?after_sequence=${afterSequence}`;