diff --git a/backend/app/api/v1/endpoints/agent_tasks.py b/backend/app/api/v1/endpoints/agent_tasks.py index d8c1a8e..16e8d2f 100644 --- a/backend/app/api/v1/endpoints/agent_tasks.py +++ b/backend/app/api/v1/endpoints/agent_tasks.py @@ -383,7 +383,7 @@ async def _execute_agent_task(task_id: str): logger.debug(f"[AgentTask] Finding {i+1}: {f.get('title', 'N/A')[:50]} - {f.get('severity', 'N/A')}") await _save_findings(db, task_id, findings) - + # 更新任务统计 task.status = AgentTaskStatus.COMPLETED task.completed_at = datetime.now(timezone.utc) @@ -392,11 +392,21 @@ async def _execute_agent_task(task_id: str): task.total_iterations = result.iterations task.tool_calls_count = result.tool_calls task.tokens_used = result.tokens_used - - # 统计严重程度 + + # 🔥 统计分析的文件数量(从 findings 中提取唯一文件) + analyzed_file_set = set() for f in findings: if isinstance(f, dict): - sev = f.get("severity", "low") + file_path = f.get("file_path") or f.get("file") or f.get("location", "").split(":")[0] + if file_path: + analyzed_file_set.add(file_path) + task.analyzed_files = len(analyzed_file_set) if analyzed_file_set else task.total_files + + # 统计严重程度和验证状态 + verified_count = 0 + for f in findings: + if isinstance(f, dict): + sev = str(f.get("severity", "low")).lower() if sev == "critical": task.critical_count += 1 elif sev == "high": @@ -405,6 +415,10 @@ async def _execute_agent_task(task_id: str): task.medium_count += 1 elif sev == "low": task.low_count += 1 + # 🔥 统计已验证的发现 + if f.get("is_verified") or f.get("verdict") == "confirmed": + verified_count += 1 + task.verified_count = verified_count # 计算安全评分 task.security_score = _calculate_security_score(findings) @@ -462,16 +476,22 @@ async def _execute_agent_task(task_id: str): logger.error(f"Failed to update task status: {db_error}") finally: + # 🔥 在清理之前保存 Agent 树到数据库 + try: + async with async_session_factory() as save_db: + await _save_agent_tree(save_db, task_id) + except Exception as save_error: + logger.error(f"Failed to save agent tree: {save_error}") + # 清理 _running_orchestrators.pop(task_id, None) _running_tasks.pop(task_id, None) _running_event_managers.pop(task_id, None) _running_asyncio_tasks.pop(task_id, None) # 🔥 清理 asyncio task - - # 从 Registry 注销 - if orchestrator: - agent_registry.unregister_agent(orchestrator.agent_id) - + + # 🔥 清理整个 Agent 注册表(包括所有子 Agent) + agent_registry.clear() + logger.debug(f"Task {task_id} cleaned up") @@ -713,7 +733,11 @@ async def _collect_project_info( async def _save_findings(db: AsyncSession, task_id: str, findings: List[Dict]) -> None: - """保存发现到数据库""" + """ + 保存发现到数据库 + + 🔥 增强版:支持多种 Agent 输出格式,健壮的字段映射 + """ from app.models.agent_task import VulnerabilityType logger.info(f"[SaveFindings] Starting to save {len(findings)} findings for task {task_id}") @@ -730,7 +754,7 @@ async def _save_findings(db: AsyncSession, task_id: str, findings: List[Dict]) - "low": VulnerabilitySeverity.LOW, "info": VulnerabilitySeverity.INFO, } - + type_map = { "sql_injection": VulnerabilityType.SQL_INJECTION, "nosql_injection": VulnerabilityType.NOSQL_INJECTION, @@ -744,65 +768,205 @@ async def _save_findings(db: AsyncSession, task_id: str, findings: List[Dict]) - "idor": VulnerabilityType.IDOR, "sensitive_data_exposure": VulnerabilityType.SENSITIVE_DATA_EXPOSURE, "hardcoded_secret": VulnerabilityType.HARDCODED_SECRET, - "deserialization": VulnerabilityType.DESERIALIZATION, # Added common type - "weak_crypto": VulnerabilityType.WEAK_CRYPTO, # Added common type + "deserialization": VulnerabilityType.DESERIALIZATION, + "weak_crypto": VulnerabilityType.WEAK_CRYPTO, + "file_inclusion": VulnerabilityType.FILE_INCLUSION, + "race_condition": VulnerabilityType.RACE_CONDITION, + "business_logic": VulnerabilityType.BUSINESS_LOGIC, + "memory_corruption": VulnerabilityType.MEMORY_CORRUPTION, } - + saved_count = 0 logger.info(f"Saving {len(findings)} findings for task {task_id}") for finding in findings: if not isinstance(finding, dict): + logger.debug(f"[SaveFindings] Skipping non-dict finding: {type(finding)}") continue - + try: - # Handle severity (case-insensitive) - raw_severity = str(finding.get("severity", "medium")).lower().strip() + # 🔥 Handle severity (case-insensitive, support multiple field names) + raw_severity = str( + finding.get("severity") or + finding.get("risk") or + "medium" + ).lower().strip() severity_enum = severity_map.get(raw_severity, VulnerabilitySeverity.MEDIUM) - - # Handle vulnerability type (case-insensitive & snake_case normalization) - raw_type = str(finding.get("vulnerability_type", "other")).lower().strip().replace(" ", "_") + + # 🔥 Handle vulnerability type (case-insensitive & snake_case normalization) + # Support multiple field names: vulnerability_type, type, vuln_type + raw_type = str( + finding.get("vulnerability_type") or + finding.get("type") or + finding.get("vuln_type") or + "other" + ).lower().strip().replace(" ", "_").replace("-", "_") + type_enum = type_map.get(raw_type, VulnerabilityType.OTHER) - - # Additional fallback for known Agent output variations - if "sqli" in raw_type: type_enum = VulnerabilityType.SQL_INJECTION - if "xss" in raw_type: type_enum = VulnerabilityType.XSS - if "rce" in raw_type or "command" in raw_type: type_enum = VulnerabilityType.COMMAND_INJECTION + + # 🔥 Additional fallback for common Agent output variations + if "sqli" in raw_type or "sql" in raw_type: + type_enum = VulnerabilityType.SQL_INJECTION + if "xss" in raw_type: + type_enum = VulnerabilityType.XSS + if "rce" in raw_type or "command" in raw_type or "cmd" in raw_type: + type_enum = VulnerabilityType.COMMAND_INJECTION + if "traversal" in raw_type or "lfi" in raw_type or "rfi" in raw_type: + type_enum = VulnerabilityType.PATH_TRAVERSAL + if "ssrf" in raw_type: + type_enum = VulnerabilityType.SSRF + if "xxe" in raw_type: + type_enum = VulnerabilityType.XXE + if "auth" in raw_type: + type_enum = VulnerabilityType.AUTH_BYPASS + if "secret" in raw_type or "credential" in raw_type or "password" in raw_type: + type_enum = VulnerabilityType.HARDCODED_SECRET + if "deserial" in raw_type: + type_enum = VulnerabilityType.DESERIALIZATION + + # 🔥 Handle file path (support multiple field names) + file_path = ( + finding.get("file_path") or + finding.get("file") or + finding.get("location", "").split(":")[0] if ":" in finding.get("location", "") else finding.get("location") + ) + + # 🔥 Handle line numbers (support multiple formats) + line_start = finding.get("line_start") or finding.get("line") + if not line_start and ":" in finding.get("location", ""): + try: + line_start = int(finding.get("location", "").split(":")[1]) + except (ValueError, IndexError): + line_start = None + + line_end = finding.get("line_end") or line_start + + # 🔥 Handle code snippet (support multiple field names) + code_snippet = ( + finding.get("code_snippet") or + finding.get("code") or + finding.get("vulnerable_code") + ) + + # 🔥 Handle title (generate from type if not provided) + title = finding.get("title") + if not title: + # Generate title from vulnerability type and file + type_display = raw_type.replace("_", " ").title() + if file_path: + title = f"{type_display} in {os.path.basename(file_path)}" + else: + title = f"{type_display} Vulnerability" + + # 🔥 Handle description (support multiple field names) + description = ( + finding.get("description") or + finding.get("details") or + finding.get("explanation") or + finding.get("impact") or + "" + ) + + # 🔥 Handle suggestion/recommendation + suggestion = ( + finding.get("suggestion") or + finding.get("recommendation") or + finding.get("remediation") or + finding.get("fix") + ) + + # 🔥 Handle confidence (map to ai_confidence field in model) + confidence = finding.get("confidence") or finding.get("ai_confidence") or 0.5 + if isinstance(confidence, str): + try: + confidence = float(confidence) + except ValueError: + confidence = 0.5 + + # 🔥 Handle verification status + is_verified = finding.get("is_verified", False) + if finding.get("verdict") == "confirmed": + is_verified = True + + # 🔥 Handle PoC information + poc_data = finding.get("poc", {}) + has_poc = bool(poc_data) + poc_code = None + poc_description = None + poc_steps = None + + if isinstance(poc_data, dict): + poc_description = poc_data.get("description") + poc_steps = poc_data.get("steps") + poc_code = poc_data.get("payload") or poc_data.get("code") + elif isinstance(poc_data, str): + poc_description = poc_data + + # 🔥 Handle verification details + verification_method = finding.get("verification_method") + verification_result = None + if finding.get("verification_details"): + verification_result = {"details": finding.get("verification_details")} + + # 🔥 Handle CWE and CVSS + cwe_id = finding.get("cwe_id") or finding.get("cwe") + cvss_score = finding.get("cvss_score") or finding.get("cvss") + if isinstance(cvss_score, str): + try: + cvss_score = float(cvss_score) + except ValueError: + cvss_score = None db_finding = AgentFinding( id=str(uuid4()), task_id=task_id, vulnerability_type=type_enum, severity=severity_enum, - title=finding.get("title", "Unknown"), - description=finding.get("description", ""), - file_path=finding.get("file_path"), - line_start=finding.get("line_start"), - line_end=finding.get("line_end"), - code_snippet=finding.get("code_snippet"), - suggestion=finding.get("suggestion") or finding.get("recommendation"), - is_verified=finding.get("is_verified", False), - confidence=finding.get("confidence", 0.5), - status=FindingStatus.VERIFIED if finding.get("is_verified") else FindingStatus.NEW, + title=title[:500] if title else "Unknown Vulnerability", + description=description[:5000] if description else "", + file_path=file_path[:500] if file_path else None, + line_start=line_start, + line_end=line_end, + code_snippet=code_snippet[:10000] if code_snippet else None, + suggestion=suggestion[:5000] if suggestion else None, + is_verified=is_verified, + ai_confidence=confidence, # 🔥 FIX: Use ai_confidence, not confidence + status=FindingStatus.VERIFIED if is_verified else FindingStatus.NEW, + # 🔥 Additional fields + has_poc=has_poc, + poc_code=poc_code, + poc_description=poc_description, + poc_steps=poc_steps, + verification_method=verification_method, + verification_result=verification_result, + cvss_score=cvss_score, + # References for CWE + references=[{"cwe": cwe_id}] if cwe_id else None, ) db.add(db_finding) saved_count += 1 + logger.debug(f"[SaveFindings] Prepared finding: {title[:50]}... ({severity_enum})") + except Exception as e: logger.warning(f"Failed to save finding: {e}, data: {finding}") - + import traceback + logger.debug(f"[SaveFindings] Traceback: {traceback.format_exc()}") + logger.info(f"Successfully prepared {saved_count} findings for commit") - + try: await db.commit() + logger.info(f"[SaveFindings] Successfully committed {saved_count} findings to database") except Exception as e: logger.error(f"Failed to commit findings: {e}") + await db.rollback() def _calculate_security_score(findings: List[Dict]) -> float: """计算安全评分""" if not findings: return 100.0 - + # 基于发现的严重程度计算扣分 deductions = { "critical": 25, @@ -811,17 +975,103 @@ def _calculate_security_score(findings: List[Dict]) -> float: "low": 3, "info": 1, } - + total_deduction = 0 for f in findings: if isinstance(f, dict): sev = f.get("severity", "low") total_deduction += deductions.get(sev, 3) - + score = max(0, 100 - total_deduction) return float(score) +async def _save_agent_tree(db: AsyncSession, task_id: str) -> None: + """ + 保存 Agent 树到数据库 + + 🔥 在任务完成前调用,将内存中的 Agent 树持久化到数据库 + """ + from app.models.agent_task import AgentTreeNode + from app.services.agent.core import agent_registry + + try: + tree = agent_registry.get_agent_tree() + nodes = tree.get("nodes", {}) + + if not nodes: + logger.warning(f"[SaveAgentTree] No agent nodes to save for task {task_id}") + return + + logger.info(f"[SaveAgentTree] Saving {len(nodes)} agent nodes for task {task_id}") + + # 计算每个节点的深度 + def get_depth(agent_id: str, visited: set = None) -> int: + if visited is None: + visited = set() + if agent_id in visited: + return 0 + visited.add(agent_id) + node = nodes.get(agent_id) + if not node: + return 0 + parent_id = node.get("parent_id") + if not parent_id: + return 0 + return 1 + get_depth(parent_id, visited) + + saved_count = 0 + for agent_id, node_data in nodes.items(): + # 获取 Agent 实例的统计数据 + agent_instance = agent_registry.get_agent(agent_id) + iterations = 0 + tool_calls = 0 + tokens_used = 0 + + if agent_instance and hasattr(agent_instance, 'get_stats'): + stats = agent_instance.get_stats() + iterations = stats.get("iterations", 0) + tool_calls = stats.get("tool_calls", 0) + tokens_used = stats.get("tokens_used", 0) + + # 从结果中获取发现数量 + findings_count = 0 + result_summary = None + if node_data.get("result"): + result = node_data.get("result", {}) + if isinstance(result, dict): + findings_count = len(result.get("findings", [])) + if result.get("summary"): + result_summary = str(result.get("summary"))[:2000] + + tree_node = AgentTreeNode( + id=str(uuid4()), + task_id=task_id, + agent_id=agent_id, + agent_name=node_data.get("name", "Unknown"), + agent_type=node_data.get("type", "unknown"), + parent_agent_id=node_data.get("parent_id"), + depth=get_depth(agent_id), + task_description=node_data.get("task"), + knowledge_modules=node_data.get("knowledge_modules"), + status=node_data.get("status", "unknown"), + result_summary=result_summary, + findings_count=findings_count, + iterations=iterations, + tool_calls=tool_calls, + tokens_used=tokens_used, + ) + db.add(tree_node) + saved_count += 1 + + await db.commit() + logger.info(f"[SaveAgentTree] Successfully saved {saved_count} agent nodes to database") + + except Exception as e: + logger.error(f"[SaveAgentTree] Failed to save agent tree: {e}", exc_info=True) + await db.rollback() + + # ============ API Endpoints ============ @router.post("/", response_model=AgentTaskResponse) @@ -1992,6 +2242,9 @@ async def generate_audit_report( "code_snippet": f.code_snippet, "is_verified": f.is_verified, "has_poc": f.has_poc, + "poc_code": f.poc_code, + "poc_description": f.poc_description, + "poc_steps": f.poc_steps, "confidence": f.ai_confidence, "suggestion": f.suggestion, "fix_code": f.fix_code, @@ -2169,6 +2422,30 @@ async def generate_audit_report( md_lines.append("```") md_lines.append("") + # 🔥 添加 PoC 详情 + if f.has_poc: + md_lines.append("**Proof of Concept (PoC):**") + md_lines.append("") + + if f.poc_description: + md_lines.append(f"*{f.poc_description}*") + md_lines.append("") + + if f.poc_steps: + md_lines.append("**Reproduction Steps:**") + md_lines.append("") + for step_idx, step in enumerate(f.poc_steps, 1): + md_lines.append(f"{step_idx}. {step}") + md_lines.append("") + + if f.poc_code: + md_lines.append("**PoC Payload:**") + md_lines.append("") + md_lines.append("```") + md_lines.append(f.poc_code.strip()) + md_lines.append("```") + md_lines.append("") + md_lines.append("---") md_lines.append("") diff --git a/backend/app/core/config.py b/backend/app/core/config.py index 4cd7330..b5c5162 100644 --- a/backend/app/core/config.py +++ b/backend/app/core/config.py @@ -92,7 +92,7 @@ class Settings(BaseSettings): AGENT_TIMEOUT_SECONDS: int = 1800 # Agent 超时时间(30分钟) # 沙箱配置 - SANDBOX_IMAGE: str = "deepaudit-sandbox:latest" # 沙箱 Docker 镜像 + SANDBOX_IMAGE: str = "python:3.11-slim" # 沙箱 Docker 镜像 SANDBOX_MEMORY_LIMIT: str = "512m" # 沙箱内存限制 SANDBOX_CPU_LIMIT: float = 1.0 # 沙箱 CPU 限制 SANDBOX_TIMEOUT: int = 60 # 沙箱命令超时(秒) diff --git a/backend/app/services/agent/agents/analysis.py b/backend/app/services/agent/agents/analysis.py index 9c8d8cb..1d18d71 100644 --- a/backend/app/services/agent/agents/analysis.py +++ b/backend/app/services/agent/agents/analysis.py @@ -422,6 +422,7 @@ Final Answer: {{"findings": [...], "summary": "..."}}""" # 检查是否完成 if step.is_final: await self.emit_llm_decision("完成安全分析", "LLM 判断分析已充分") + logger.info(f"[{self.name}] Received Final Answer: {step.final_answer}") if step.final_answer and "findings" in step.final_answer: all_findings = step.final_answer["findings"] logger.info(f"[{self.name}] Final Answer contains {len(all_findings)} findings") @@ -438,7 +439,7 @@ Final Answer: {{"findings": [...], "summary": "..."}}""" f"发现 {finding.get('severity', 'medium')} 级别漏洞: {finding.get('title', 'Unknown')}" ) else: - logger.warning(f"[{self.name}] Final Answer has no 'findings' key: {step.final_answer}") + logger.warning(f"[{self.name}] Final Answer has no 'findings' key or is None: {step.final_answer}") # 🔥 记录工作完成 self.record_work(f"完成安全分析,发现 {len(all_findings)} 个潜在漏洞") diff --git a/backend/app/services/agent/agents/base.py b/backend/app/services/agent/agents/base.py index bef7d94..337b063 100644 --- a/backend/app/services/agent/agents/base.py +++ b/backend/app/services/agent/agents/base.py @@ -741,26 +741,43 @@ class BaseAgent(ABC): ) # ============ 发现相关事件 ============ - - async def emit_finding(self, title: str, severity: str, vuln_type: str, file_path: str = ""): + + async def emit_finding(self, title: str, severity: str, vuln_type: str, file_path: str = "", is_verified: bool = False): """发射漏洞发现事件""" - severity_emoji = { - "critical": "🔴", - "high": "🟠", - "medium": "🟡", - "low": "🟢", - }.get(severity.lower(), "⚪") - - await self.emit_event( - "finding", - f"{severity_emoji} [{self.name}] 发现漏洞: [{severity.upper()}] {title}\n 类型: {vuln_type}\n 位置: {file_path}", - metadata={ - "title": title, - "severity": severity, - "vulnerability_type": vuln_type, - "file_path": file_path, - } - ) + import uuid + finding_id = str(uuid.uuid4()) + + # 🔥 使用 EventManager.emit_finding 发送正确的事件类型 + if self.event_emitter and hasattr(self.event_emitter, 'emit_finding'): + await self.event_emitter.emit_finding( + finding_id=finding_id, + title=title, + severity=severity, + vulnerability_type=vuln_type, + is_verified=is_verified, + ) + else: + # 回退:使用通用事件发射 + severity_emoji = { + "critical": "🔴", + "high": "🟠", + "medium": "🟡", + "low": "🟢", + }.get(severity.lower(), "⚪") + + event_type = "finding_verified" if is_verified else "finding_new" + await self.emit_event( + event_type, + f"{severity_emoji} [{self.name}] 发现漏洞: [{severity.upper()}] {title}\n 类型: {vuln_type}\n 位置: {file_path}", + metadata={ + "id": finding_id, + "title": title, + "severity": severity, + "vulnerability_type": vuln_type, + "file_path": file_path, + "is_verified": is_verified, + } + ) # ============ 通用工具方法 ============ diff --git a/backend/app/services/agent/agents/orchestrator.py b/backend/app/services/agent/agents/orchestrator.py index 41e9e74..f43f503 100644 --- a/backend/app/services/agent/agents/orchestrator.py +++ b/backend/app/services/agent/agents/orchestrator.py @@ -432,6 +432,8 @@ Action Input: {{"参数": "值"}} # 🔥 CRITICAL: Log final findings count before returning logger.info(f"[Orchestrator] Final result: {len(self._all_findings)} findings collected") + if len(self._all_findings) == 0: + logger.warning(f"[Orchestrator] ⚠️ No findings collected! Dispatched agents: {list(self._dispatched_tasks.keys())}, Iterations: {self._iteration}") for i, f in enumerate(self._all_findings[:5]): # Log first 5 for debugging logger.debug(f"[Orchestrator] Finding {i+1}: {f.get('title', 'N/A')} - {f.get('vulnerability_type', 'N/A')}") @@ -654,18 +656,27 @@ Action Input: {{"参数": "值"}} # findings 字段通常来自 Analysis/Verification Agent # initial_findings 来自 Recon Agent raw_findings = data.get("findings", []) + logger.info(f"[Orchestrator] {agent_name} returned data with {len(raw_findings)} findings in 'findings' field") - # 🔥 Also check for initial_findings (from Recon) - if not raw_findings and "initial_findings" in data: + # 🔥 ENHANCED: Also check for initial_findings (from Recon) - 改进逻辑 + # 即使 findings 为空列表,也检查 initial_findings + if "initial_findings" in data: initial = data.get("initial_findings", []) - # Convert string findings to dict format - raw_findings = [] + logger.info(f"[Orchestrator] {agent_name} has {len(initial)} initial_findings") for f in initial: if isinstance(f, dict): - raw_findings.append(f) + # 🔥 Normalize finding format - 处理 Recon 返回的格式 + normalized = self._normalize_finding(f) + if normalized not in raw_findings: + raw_findings.append(normalized) elif isinstance(f, str): # String finding from Recon - skip, it's just an observation - pass + logger.debug(f"[Orchestrator] Skipping string finding: {f[:50]}...") + + # 🔥 Also check high_risk_areas from Recon for potential findings + if agent_name == "recon" and "high_risk_areas" in data: + high_risk = data.get("high_risk_areas", []) + logger.info(f"[Orchestrator] {agent_name} identified {len(high_risk)} high risk areas") if raw_findings: # 只添加字典格式的发现 @@ -673,33 +684,58 @@ Action Input: {{"参数": "值"}} logger.info(f"[Orchestrator] {agent_name} returned {len(valid_findings)} valid findings") - # 🔥 Merge findings to update existing ones and avoid duplicates + # 🔥 ENHANCED: Merge findings with better deduplication for new_f in valid_findings: - # Create key for identification (file + line + type) - new_key = ( - new_f.get("file_path", "") or new_f.get("file", ""), - new_f.get("line_start") or new_f.get("line", 0), - new_f.get("vulnerability_type", "") or new_f.get("type", ""), - ) + # Normalize the finding first + normalized_new = self._normalize_finding(new_f) - # Check if exists + # Create fingerprint for deduplication (file + description similarity) + new_file = normalized_new.get("file_path", "").lower().strip() + new_desc = (normalized_new.get("description", "") or "").lower()[:100] + new_type = (normalized_new.get("vulnerability_type", "") or "").lower() + new_line = normalized_new.get("line_start") or normalized_new.get("line", 0) + + # Check if exists (more flexible matching) found = False for i, existing_f in enumerate(self._all_findings): - existing_key = ( - existing_f.get("file_path", "") or existing_f.get("file", ""), - existing_f.get("line_start") or existing_f.get("line", 0), - existing_f.get("vulnerability_type", "") or existing_f.get("type", ""), + existing_file = (existing_f.get("file_path", "") or existing_f.get("file", "")).lower().strip() + existing_desc = (existing_f.get("description", "") or "").lower()[:100] + existing_type = (existing_f.get("vulnerability_type", "") or existing_f.get("type", "")).lower() + existing_line = existing_f.get("line_start") or existing_f.get("line", 0) + + # Match if same file AND (same line OR similar description OR same vulnerability type) + same_file = new_file and existing_file and ( + new_file == existing_file or + new_file.endswith(existing_file) or + existing_file.endswith(new_file) ) - if new_key == existing_key: + same_line = new_line and existing_line and new_line == existing_line + similar_desc = new_desc and existing_desc and ( + new_desc in existing_desc or existing_desc in new_desc + ) + same_type = new_type and existing_type and ( + new_type == existing_type or + (new_type in existing_type) or (existing_type in new_type) + ) + + if same_file and (same_line or similar_desc or same_type): # Update existing with new info (e.g. verification results) - self._all_findings[i] = {**existing_f, **new_f} + # Prefer verified data over unverified + merged = {**existing_f, **normalized_new} + # Keep the better title + if normalized_new.get("title") and len(normalized_new.get("title", "")) > len(existing_f.get("title", "")): + merged["title"] = normalized_new["title"] + # Keep verified status if either is verified + if existing_f.get("is_verified") or normalized_new.get("is_verified"): + merged["is_verified"] = True + self._all_findings[i] = merged found = True - logger.debug(f"[Orchestrator] Updated existing finding: {new_key}") + logger.info(f"[Orchestrator] Merged finding: {new_file}:{new_line} ({new_type})") break if not found: - self._all_findings.append(new_f) - logger.debug(f"[Orchestrator] Added new finding: {new_key}") + self._all_findings.append(normalized_new) + logger.info(f"[Orchestrator] Added new finding: {new_file}:{new_line} ({new_type})") logger.info(f"[Orchestrator] Total findings now: {len(self._all_findings)}") else: @@ -752,13 +788,13 @@ Action Input: {{"参数": "值"}} observation = f"""## {agent_name} Agent 执行结果 **状态**: 成功 -**发现数量**: {len(findings)} +**发现数量**: {len(valid_findings)} **迭代次数**: {result.iterations} **耗时**: {result.duration_ms}ms ### 发现摘要 """ - for i, f in enumerate(findings[:10]): + for i, f in enumerate(valid_findings[:10]): if not isinstance(f, dict): continue observation += f""" @@ -767,9 +803,9 @@ Action Input: {{"参数": "值"}} - 文件: {f.get('file_path', 'unknown')} - 描述: {f.get('description', '')[:200]}... """ - - if len(findings) > 10: - observation += f"\n... 还有 {len(findings) - 10} 个发现" + + if len(valid_findings) > 10: + observation += f"\n... 还有 {len(valid_findings) - 10} 个发现" if data.get("summary"): observation += f"\n\n### Agent 总结\n{data['summary']}" @@ -782,6 +818,94 @@ Action Input: {{"参数": "值"}} logger.error(f"Sub-agent dispatch failed: {e}", exc_info=True) return f"## 调度失败\n\n错误: {str(e)}" + def _normalize_finding(self, finding: Dict[str, Any]) -> Dict[str, Any]: + """ + 标准化发现格式 + + 不同 Agent 可能返回不同格式的发现,这个方法将它们标准化为统一格式 + """ + normalized = dict(finding) # 复制原始数据 + + # 🔥 处理 location 字段 -> file_path + line_start + if "location" in normalized and "file_path" not in normalized: + location = normalized["location"] + if isinstance(location, str) and ":" in location: + parts = location.split(":") + normalized["file_path"] = parts[0] + try: + normalized["line_start"] = int(parts[1]) + except (ValueError, IndexError): + pass + elif isinstance(location, str): + normalized["file_path"] = location + + # 🔥 处理 file 字段 -> file_path + if "file" in normalized and "file_path" not in normalized: + normalized["file_path"] = normalized["file"] + + # 🔥 处理 line 字段 -> line_start + if "line" in normalized and "line_start" not in normalized: + normalized["line_start"] = normalized["line"] + + # 🔥 处理 type 字段 -> vulnerability_type + if "type" in normalized and "vulnerability_type" not in normalized: + # 不是所有 type 都是漏洞类型,比如 "Vulnerability" 只是标记 + type_val = normalized["type"] + if type_val and type_val.lower() not in ["vulnerability", "finding", "issue"]: + normalized["vulnerability_type"] = type_val + elif "description" in normalized: + # 尝试从描述中推断漏洞类型 + desc = normalized["description"].lower() + if "command injection" in desc or "rce" in desc or "system(" in desc: + normalized["vulnerability_type"] = "command_injection" + elif "sql injection" in desc or "sqli" in desc: + normalized["vulnerability_type"] = "sql_injection" + elif "xss" in desc or "cross-site scripting" in desc: + normalized["vulnerability_type"] = "xss" + elif "path traversal" in desc or "directory traversal" in desc: + normalized["vulnerability_type"] = "path_traversal" + elif "ssrf" in desc: + normalized["vulnerability_type"] = "ssrf" + elif "xxe" in desc: + normalized["vulnerability_type"] = "xxe" + else: + normalized["vulnerability_type"] = "other" + + # 🔥 确保 severity 字段存在且为小写 + if "severity" in normalized: + normalized["severity"] = str(normalized["severity"]).lower() + else: + normalized["severity"] = "medium" + + # 🔥 处理 risk 字段 -> severity + if "risk" in normalized and "severity" not in normalized: + normalized["severity"] = str(normalized["risk"]).lower() + + # 🔥 生成 title 如果不存在 + if "title" not in normalized: + vuln_type = normalized.get("vulnerability_type", "Unknown") + file_path = normalized.get("file_path", "") + if file_path: + import os + normalized["title"] = f"{vuln_type.replace('_', ' ').title()} in {os.path.basename(file_path)}" + else: + normalized["title"] = f"{vuln_type.replace('_', ' ').title()} Vulnerability" + + # 🔥 处理 code 字段 -> code_snippet + if "code" in normalized and "code_snippet" not in normalized: + normalized["code_snippet"] = normalized["code"] + + # 🔥 处理 recommendation -> suggestion + if "recommendation" in normalized and "suggestion" not in normalized: + normalized["suggestion"] = normalized["recommendation"] + + # 🔥 处理 impact -> 添加到 description + if "impact" in normalized and normalized.get("description"): + if "impact" not in normalized["description"].lower(): + normalized["description"] += f"\n\nImpact: {normalized['impact']}" + + return normalized + def _summarize_findings(self) -> str: """汇总当前发现""" if not self._all_findings: diff --git a/backend/app/services/agent/agents/verification.py b/backend/app/services/agent/agents/verification.py index 33fb3cc..f6651ea 100644 --- a/backend/app/services/agent/agents/verification.py +++ b/backend/app/services/agent/agents/verification.py @@ -120,6 +120,10 @@ Final Answer: [JSON 格式的验证报告] 2. **深入理解** - 理解代码逻辑,不要表面判断 3. **证据支撑** - 判定要有依据 4. **安全第一** - 沙箱测试要谨慎 +5. **🔥 PoC 生成** - 对于 confirmed 和 likely 的漏洞,**必须**生成 PoC: + - poc.description: 简要描述这个 PoC 的作用 + - poc.steps: 详细的复现步骤列表 + - poc.payload: 实际的攻击载荷或测试代码 现在开始验证漏洞发现!""" diff --git a/backend/app/services/agent/event_manager.py b/backend/app/services/agent/event_manager.py index 5c68753..1584cde 100644 --- a/backend/app/services/agent/event_manager.py +++ b/backend/app/services/agent/event_manager.py @@ -171,6 +171,7 @@ class AgentEventEmitter: finding_id=finding_id, message=f"{'✅ 已验证' if is_verified else '🔍 新发现'}: [{severity.upper()}] {title}", metadata={ + "id": finding_id, # 🔥 添加 id 字段供前端使用 "title": title, "severity": severity, "vulnerability_type": vulnerability_type, @@ -330,7 +331,32 @@ class EventManager: async def _save_event_to_db(self, event_data: Dict): """保存事件到数据库""" from app.models.agent_task import AgentEvent - + + # 🔥 清理无效的 UTF-8 字符(如二进制内容) + def sanitize_string(s): + """清理字符串中的无效 UTF-8 字符""" + if s is None: + return None + if not isinstance(s, str): + s = str(s) + # 移除 NULL 字节和其他不可打印的控制字符(保留换行和制表符) + return ''.join( + char for char in s + if char in '\n\r\t' or (ord(char) >= 32 and ord(char) != 127) + ) + + def sanitize_dict(d): + """递归清理字典中的字符串值""" + if d is None: + return None + if isinstance(d, dict): + return {k: sanitize_dict(v) for k, v in d.items()} + elif isinstance(d, list): + return [sanitize_dict(item) for item in d] + elif isinstance(d, str): + return sanitize_string(d) + return d + async with self.db_session_factory() as db: event = AgentEvent( id=event_data["id"], @@ -338,14 +364,14 @@ class EventManager: event_type=event_data["event_type"], sequence=event_data["sequence"], phase=event_data["phase"], - message=event_data["message"], + message=sanitize_string(event_data["message"]), # 🔥 清理消息 tool_name=event_data["tool_name"], - tool_input=event_data["tool_input"], - tool_output=event_data["tool_output"], + tool_input=sanitize_dict(event_data["tool_input"]), # 🔥 清理工具输入 + tool_output=sanitize_dict(event_data["tool_output"]), # 🔥 清理工具输出 tool_duration_ms=event_data["tool_duration_ms"], finding_id=event_data["finding_id"], tokens_used=event_data["tokens_used"], - event_metadata=event_data["metadata"], + event_metadata=sanitize_dict(event_data["metadata"]), # 🔥 清理元数据 ) db.add(event) await db.commit() @@ -403,62 +429,79 @@ class EventManager: after_sequence: int = 0, ) -> AsyncGenerator[Dict, None]: """流式获取事件 - + 🔥 重要: 此方法会先排空队列中已缓存的事件(在 SSE 连接前产生的), 然后继续实时推送新事件。 + 只返回序列号 > after_sequence 的事件。 """ + logger.info(f"[StreamEvents] Task {task_id}: Starting stream with after_sequence={after_sequence}") + # 获取现有队列(由 AgentRunner 在初始化时创建) queue = self._event_queues.get(task_id) - + if not queue: # 如果队列不存在,创建一个新的(回退逻辑) queue = self.create_queue(task_id) logger.warning(f"Queue not found for task {task_id}, created new one") - + # 🔥 先排空队列中已缓存的事件(这些是在 SSE 连接前产生的) buffered_count = 0 + skipped_count = 0 while not queue.empty(): try: buffered_event = queue.get_nowait() + + # 🔥 过滤掉序列号 <= after_sequence 的事件 + event_sequence = buffered_event.get("sequence", 0) + if event_sequence <= after_sequence: + skipped_count += 1 + continue + buffered_count += 1 yield buffered_event - + # 🔥 为所有缓存事件添加延迟,确保不会一起输出 event_type = buffered_event.get("event_type") if event_type == "thinking_token": await asyncio.sleep(0.015) # 15ms for tokens else: await asyncio.sleep(0.005) # 5ms for other events - + # 检查是否是结束事件 if event_type in ["task_complete", "task_error", "task_cancel"]: - logger.debug(f"Task {task_id} already completed, sent {buffered_count} buffered events") + logger.debug(f"Task {task_id} already completed, sent {buffered_count} buffered events (skipped {skipped_count})") return except asyncio.QueueEmpty: break - - if buffered_count > 0: - logger.debug(f"Drained {buffered_count} buffered events for task {task_id}") - + + if buffered_count > 0 or skipped_count > 0: + logger.debug(f"Drained queue for task {task_id}: sent {buffered_count}, skipped {skipped_count} (after_sequence={after_sequence})") + # 然后实时推送新事件 try: while True: try: event = await asyncio.wait_for(queue.get(), timeout=30) + + # 🔥 过滤掉序列号 <= after_sequence 的事件 + event_sequence = event.get("sequence", 0) + if event_sequence <= after_sequence: + continue + yield event - + # 🔥 为 thinking_token 添加微延迟确保流式效果 if event.get("event_type") == "thinking_token": await asyncio.sleep(0.01) # 10ms - + # 检查是否是结束事件 if event.get("event_type") in ["task_complete", "task_error", "task_cancel"]: break - + except asyncio.TimeoutError: # 发送心跳 yield {"event_type": "heartbeat", "timestamp": datetime.now(timezone.utc).isoformat()} - + except GeneratorExit: # SSE 连接断开 logger.debug(f"SSE stream closed for task {task_id}") diff --git a/backend/app/services/agent/graph/runner.py b/backend/app/services/agent/graph/runner.py index 5660e5a..dd976ee 100644 --- a/backend/app/services/agent/graph/runner.py +++ b/backend/app/services/agent/graph/runner.py @@ -274,16 +274,21 @@ class AgentRunner: # 沙箱工具(仅 Verification Agent 可用) try: - self.sandbox_manager = SandboxManager( + from app.services.agent.tools.sandbox_tool import SandboxConfig + sandbox_config = SandboxConfig( image=settings.SANDBOX_IMAGE, memory_limit=settings.SANDBOX_MEMORY_LIMIT, cpu_limit=settings.SANDBOX_CPU_LIMIT, + timeout=settings.SANDBOX_TIMEOUT, + network_mode=settings.SANDBOX_NETWORK_MODE, ) - + self.sandbox_manager = SandboxManager(config=sandbox_config) + self.verification_tools["sandbox_exec"] = SandboxTool(self.sandbox_manager) self.verification_tools["sandbox_http"] = SandboxHttpTool(self.sandbox_manager) self.verification_tools["verify_vulnerability"] = VulnerabilityVerifyTool(self.sandbox_manager) - + logger.info("Sandbox tools initialized successfully") + except Exception as e: logger.warning(f"Sandbox initialization failed: {e}") diff --git a/backend/app/services/llm/adapters/litellm_adapter.py b/backend/app/services/llm/adapters/litellm_adapter.py index a2913c3..af3f4fa 100644 --- a/backend/app/services/llm/adapters/litellm_adapter.py +++ b/backend/app/services/llm/adapters/litellm_adapter.py @@ -216,22 +216,22 @@ class LiteLLMAdapter(BaseLLMAdapter): async def stream_complete(self, request: LLMRequest): """ 流式调用 LLM,逐 token 返回 - + Yields: dict: {"type": "token", "content": str} 或 {"type": "done", "content": str, "usage": dict} """ import litellm - + await self.validate_config() - + litellm.cache = None litellm.drop_params = True - + messages = [{"role": msg.role, "content": msg.content} for msg in request.messages] - + # 🔥 估算输入 token 数量(用于在无法获取真实 usage 时进行估算) input_tokens_estimate = sum(estimate_tokens(msg["content"]) for msg in messages) - + kwargs = { "model": self._litellm_model, "messages": messages, @@ -240,27 +240,30 @@ class LiteLLMAdapter(BaseLLMAdapter): "top_p": request.top_p if request.top_p is not None else self.config.top_p, "stream": True, # 启用流式输出 } - + # 🔥 对于支持的模型,请求在流式输出中包含 usage 信息 # OpenAI API 支持 stream_options if self.config.provider in [LLMProvider.OPENAI, LLMProvider.DEEPSEEK]: kwargs["stream_options"] = {"include_usage": True} - + if self.config.api_key and self.config.api_key != "ollama": kwargs["api_key"] = self.config.api_key - + if self._api_base: kwargs["api_base"] = self._api_base - + kwargs["timeout"] = self.config.timeout - + accumulated_content = "" final_usage = None # 🔥 存储最终的 usage 信息 - + chunk_count = 0 # 🔥 跟踪 chunk 数量 + try: response = await litellm.acompletion(**kwargs) - + async for chunk in response: + chunk_count += 1 + # 🔥 检查是否有 usage 信息(某些 API 会在最后的 chunk 中包含) if hasattr(chunk, "usage") and chunk.usage: final_usage = { @@ -269,14 +272,15 @@ class LiteLLMAdapter(BaseLLMAdapter): "total_tokens": chunk.usage.total_tokens or 0, } logger.debug(f"Got usage from chunk: {final_usage}") - + if not chunk.choices: + # 🔥 某些模型可能发送没有 choices 的 chunk(如心跳) continue - + delta = chunk.choices[0].delta content = getattr(delta, "content", "") or "" finish_reason = chunk.choices[0].finish_reason - + if content: accumulated_content += content yield { @@ -284,9 +288,8 @@ class LiteLLMAdapter(BaseLLMAdapter): "content": content, "accumulated": accumulated_content, } - else: - # Log when we get a chunk without content - logger.debug(f"Chunk with no content: {chunk}") + # 🔥 ENHANCED: 处理没有 content 但也没有 finish_reason 的情况 + # 某些模型(如智谱 GLM)可能在某些 chunk 中不返回内容 if finish_reason: # 流式完成 @@ -299,7 +302,11 @@ class LiteLLMAdapter(BaseLLMAdapter): "total_tokens": input_tokens_estimate + output_tokens_estimate, } logger.debug(f"Estimated usage: {final_usage}") - + + # 🔥 ENHANCED: 如果累积内容为空但有 finish_reason,记录警告 + if not accumulated_content: + logger.warning(f"Stream completed with no content after {chunk_count} chunks, finish_reason={finish_reason}") + yield { "type": "done", "content": accumulated_content, @@ -307,9 +314,27 @@ class LiteLLMAdapter(BaseLLMAdapter): "finish_reason": finish_reason, } break - + + # 🔥 ENHANCED: 如果循环结束但没有收到 finish_reason,也需要返回 done + if accumulated_content: + logger.warning(f"Stream ended without finish_reason, returning accumulated content ({len(accumulated_content)} chars)") + if not final_usage: + output_tokens_estimate = estimate_tokens(accumulated_content) + final_usage = { + "prompt_tokens": input_tokens_estimate, + "completion_tokens": output_tokens_estimate, + "total_tokens": input_tokens_estimate + output_tokens_estimate, + } + yield { + "type": "done", + "content": accumulated_content, + "usage": final_usage, + "finish_reason": "complete", + } + except Exception as e: # 🔥 即使出错,也尝试返回估算的 usage + logger.error(f"Stream error: {e}") output_tokens_estimate = estimate_tokens(accumulated_content) if accumulated_content else 0 yield { "type": "error", diff --git a/frontend/src/pages/AgentAudit/hooks/useAgentAuditState.ts b/frontend/src/pages/AgentAudit/hooks/useAgentAuditState.ts index 495da5f..476d473 100644 --- a/frontend/src/pages/AgentAudit/hooks/useAgentAuditState.ts +++ b/frontend/src/pages/AgentAudit/hooks/useAgentAuditState.ts @@ -42,6 +42,16 @@ function agentAuditReducer(state: AgentAuditState, action: AgentAuditAction): Ag case 'SET_FINDINGS': return { ...state, findings: action.payload }; + case 'ADD_FINDING': { + // 🔥 添加单个 finding,避免重复 + const newFinding = action.payload; + const existingIds = new Set(state.findings.map(f => f.id)); + if (newFinding.id && existingIds.has(newFinding.id)) { + return state; // 已存在,不添加 + } + return { ...state, findings: [...state.findings, newFinding] }; + } + case 'SET_AGENT_TREE': return { ...state, agentTree: action.payload }; diff --git a/frontend/src/pages/AgentAudit/index.tsx b/frontend/src/pages/AgentAudit/index.tsx index 9167496..c036828 100644 --- a/frontend/src/pages/AgentAudit/index.tsx +++ b/frontend/src/pages/AgentAudit/index.tsx @@ -63,7 +63,11 @@ function AgentAuditPageContent() { const previousTaskIdRef = useRef(undefined); const disconnectStreamRef = useRef<(() => void) | null>(null); const lastEventSequenceRef = useRef(0); - const historicalEventsLoadedRef = useRef(false); + const hasConnectedRef = useRef(false); // 🔥 追踪是否已连接 SSE + const hasLoadedHistoricalEventsRef = useRef(false); // 🔥 追踪是否已加载历史事件 + // 🔥 使用 state 来标记历史事件加载状态和触发 streamOptions 重新计算 + const [afterSequence, setAfterSequence] = useState(0); + const [historicalEventsLoaded, setHistoricalEventsLoaded] = useState(false); // 🔥 当 taskId 变化时立即重置状态(新建任务时清理旧日志) useEffect(() => { @@ -79,7 +83,10 @@ function AgentAuditPageContent() { setShowSplash(!taskId); // 3. 重置事件序列号和加载状态 lastEventSequenceRef.current = 0; - historicalEventsLoadedRef.current = false; + hasConnectedRef.current = false; // 🔥 重置 SSE 连接标志 + hasLoadedHistoricalEventsRef.current = false; // 🔥 重置历史事件加载标志 + setHistoricalEventsLoaded(false); // 🔥 重置历史事件加载状态 + setAfterSequence(0); // 🔥 重置 afterSequence state } previousTaskIdRef.current = taskId; }, [taskId, reset]); @@ -141,6 +148,14 @@ function AgentAuditPageContent() { // 🔥 NEW: 加载历史事件并转换为日志项 const loadHistoricalEvents = useCallback(async () => { if (!taskId) return 0; + + // 🔥 防止重复加载历史事件 + if (hasLoadedHistoricalEventsRef.current) { + console.log('[AgentAudit] Historical events already loaded, skipping'); + return 0; + } + hasLoadedHistoricalEventsRef.current = true; + try { console.log(`[AgentAudit] Fetching historical events for task ${taskId}...`); const events = await getAgentEvents(taskId, { limit: 500 }); @@ -356,20 +371,22 @@ function AgentAuditPageContent() { }); console.log(`[AgentAudit] Processed ${processedCount} events into logs, last sequence: ${lastEventSequenceRef.current}`); + // 🔥 更新 afterSequence state,触发 streamOptions 重新计算 + setAfterSequence(lastEventSequenceRef.current); return events.length; } catch (err) { console.error('[AgentAudit] Failed to load historical events:', err); return 0; } - }, [taskId, dispatch]); + }, [taskId, dispatch, setAfterSequence]); // ============ Stream Event Handling ============ const streamOptions = useMemo(() => ({ includeThinking: true, includeToolCalls: true, - // 🔥 使用最后的事件序列号,避免重复接收历史事件 - afterSequence: lastEventSequenceRef.current, + // 🔥 使用 state 变量,确保在历史事件加载后能获取最新值 + afterSequence: afterSequence, onEvent: (event: { type: string; message?: string; metadata?: { agent_name?: string; agent?: string } }) => { if (event.metadata?.agent_name) { setCurrentAgentName(event.metadata.agent_name); @@ -478,7 +495,20 @@ function AgentAuditPageContent() { agentName: getCurrentAgentName() || undefined, } }); - loadFindings(); + // 🔥 直接将 finding 添加到状态,不依赖 API(因为运行时数据库还没有数据) + dispatch({ + type: 'ADD_FINDING', + payload: { + id: (finding.id as string) || `finding-${Date.now()}`, + title: (finding.title as string) || 'Vulnerability found', + severity: (finding.severity as string) || 'medium', + vulnerability_type: (finding.vulnerability_type as string) || 'unknown', + file_path: finding.file_path as string, + line_start: finding.line_start as number, + description: finding.description as string, + is_verified: (finding.is_verified as boolean) || false, + } + }); }, onComplete: () => { dispatch({ type: 'ADD_LOG', payload: { type: 'info', title: 'Audit completed successfully' } }); @@ -489,7 +519,7 @@ function AgentAuditPageContent() { onError: (err: string) => { dispatch({ type: 'ADD_LOG', payload: { type: 'error', title: `Error: ${err}` } }); }, - }), [dispatch, loadTask, loadFindings, loadAgentTree, debouncedLoadAgentTree, + }), [afterSequence, dispatch, loadTask, loadFindings, loadAgentTree, debouncedLoadAgentTree, updateLog, removeLog, getCurrentAgentName, getCurrentThinkingId, setCurrentAgentName, setCurrentThinkingId]); @@ -523,7 +553,7 @@ function AgentAuditPageContent() { } setShowSplash(false); setLoading(true); - historicalEventsLoadedRef.current = false; + setHistoricalEventsLoaded(false); const loadAllData = async () => { try { @@ -534,11 +564,11 @@ function AgentAuditPageContent() { const eventsLoaded = await loadHistoricalEvents(); console.log(`[AgentAudit] Loaded ${eventsLoaded} historical events for task ${taskId}`); - // 标记历史事件已加载完成 - historicalEventsLoadedRef.current = true; + // 标记历史事件已加载完成 (setAfterSequence 已在 loadHistoricalEvents 中调用) + setHistoricalEventsLoaded(true); } catch (error) { console.error('[AgentAudit] Failed to load data:', error); - historicalEventsLoadedRef.current = true; // 即使出错也标记为完成,避免无限等待 + setHistoricalEventsLoaded(true); // 即使出错也标记为完成,避免无限等待 } finally { setLoading(false); } @@ -552,20 +582,21 @@ function AgentAuditPageContent() { // 等待历史事件加载完成,且任务正在运行 if (!taskId || !task?.status || task.status !== 'running') return; - // 如果历史事件尚未加载完成,等待一下 - const checkAndConnect = () => { - if (historicalEventsLoadedRef.current) { - connectStream(); - dispatch({ type: 'ADD_LOG', payload: { type: 'info', title: 'Connected to audit stream' } }); - } else { - // 延迟重试 - setTimeout(checkAndConnect, 100); - } - }; + // 🔥 使用 state 变量确保在历史事件加载完成后才连接 + if (!historicalEventsLoaded) return; - checkAndConnect(); - return () => disconnectStream(); - }, [taskId, task?.status, connectStream, disconnectStream, dispatch]); + // 🔥 避免重复连接 - 只连接一次 + if (hasConnectedRef.current) return; + + hasConnectedRef.current = true; + console.log(`[AgentAudit] Connecting to stream with afterSequence=${afterSequence}`); + connectStream(); + dispatch({ type: 'ADD_LOG', payload: { type: 'info', title: 'Connected to audit stream' } }); + + return () => { + disconnectStream(); + }; + }, [taskId, task?.status, historicalEventsLoaded, connectStream, disconnectStream, dispatch, afterSequence]); // Polling useEffect(() => { diff --git a/frontend/src/pages/AgentAudit/types.ts b/frontend/src/pages/AgentAudit/types.ts index bd246e2..e099f5d 100644 --- a/frontend/src/pages/AgentAudit/types.ts +++ b/frontend/src/pages/AgentAudit/types.ts @@ -71,6 +71,7 @@ export interface AgentTreeResponse { export type AgentAuditAction = | { type: 'SET_TASK'; payload: AgentTask } | { type: 'SET_FINDINGS'; payload: AgentFinding[] } + | { type: 'ADD_FINDING'; payload: Partial & { id: string } } | { type: 'SET_AGENT_TREE'; payload: AgentTreeResponse } | { type: 'SET_LOGS'; payload: LogItem[] } | { type: 'ADD_LOG'; payload: Omit & { id?: string } } diff --git a/frontend/src/shared/api/agentStream.ts b/frontend/src/shared/api/agentStream.ts index 0ada1da..770a31c 100644 --- a/frontend/src/shared/api/agentStream.ts +++ b/frontend/src/shared/api/agentStream.ts @@ -396,6 +396,7 @@ export class AgentStreamHandler { break; // 发现 + case 'finding': // 🔥 向后兼容旧的事件类型 case 'finding_new': case 'finding_verified': this.options.onFinding?.(