feat(agent): enhance vulnerability finding pipeline and frontend compatibility

- Backend: accept the legacy 'finding' event type for compatibility
- Improve finding normalization and deduplication logic
- Add PoC generation requirements and the related fields
- Streamline sandbox configuration initialization
- Frontend: add the ADD_FINDING action and its state handling
- Harden event-stream handling and sequence-number filtering
- Improve historical-event loading and the SSE connection logic
- Include verification status and PoC details in reports
This commit is contained in:
parent 6d98f29fa6
commit 4e4dd05ddb
@@ -393,10 +393,20 @@ async def _execute_agent_task(task_id: str):
        task.tool_calls_count = result.tool_calls
        task.tokens_used = result.tokens_used

-       # Tally severity levels
+       # 🔥 Count analyzed files (unique files extracted from findings)
+       analyzed_file_set = set()
        for f in findings:
            if isinstance(f, dict):
-               sev = f.get("severity", "low")
+               file_path = f.get("file_path") or f.get("file") or f.get("location", "").split(":")[0]
+               if file_path:
+                   analyzed_file_set.add(file_path)
+       task.analyzed_files = len(analyzed_file_set) if analyzed_file_set else task.total_files
+
+       # Tally severity levels and verification status
+       verified_count = 0
+       for f in findings:
+           if isinstance(f, dict):
+               sev = str(f.get("severity", "low")).lower()
                if sev == "critical":
                    task.critical_count += 1
                elif sev == "high":
@@ -405,6 +415,10 @@ async def _execute_agent_task(task_id: str):
                    task.medium_count += 1
                elif sev == "low":
                    task.low_count += 1
+               # 🔥 Count verified findings
+               if f.get("is_verified") or f.get("verdict") == "confirmed":
+                   verified_count += 1
+       task.verified_count = verified_count

        # Compute the security score
        task.security_score = _calculate_security_score(findings)
@@ -462,15 +476,21 @@ async def _execute_agent_task(task_id: str):
            logger.error(f"Failed to update task status: {db_error}")

    finally:
+       # 🔥 Persist the agent tree to the database before cleanup
+       try:
+           async with async_session_factory() as save_db:
+               await _save_agent_tree(save_db, task_id)
+       except Exception as save_error:
+           logger.error(f"Failed to save agent tree: {save_error}")
+
        # Cleanup
        _running_orchestrators.pop(task_id, None)
        _running_tasks.pop(task_id, None)
        _running_event_managers.pop(task_id, None)
        _running_asyncio_tasks.pop(task_id, None)  # 🔥 clean up the asyncio task

        # Unregister from the registry
        if orchestrator:
            agent_registry.unregister_agent(orchestrator.agent_id)
        # 🔥 Clear the whole agent registry (including all sub-agents)
        agent_registry.clear()

        logger.debug(f"Task {task_id} cleaned up")
@@ -713,7 +733,11 @@ async def _collect_project_info(


 async def _save_findings(db: AsyncSession, task_id: str, findings: List[Dict]) -> None:
-    """Save findings to the database"""
+    """
+    Save findings to the database
+
+    🔥 Enhanced: supports multiple Agent output formats with robust field mapping
+    """
    from app.models.agent_task import VulnerabilityType

    logger.info(f"[SaveFindings] Starting to save {len(findings)} findings for task {task_id}")
@@ -744,8 +768,12 @@ async def _save_findings(db: AsyncSession, task_id: str, findings: List[Dict]) -
        "idor": VulnerabilityType.IDOR,
        "sensitive_data_exposure": VulnerabilityType.SENSITIVE_DATA_EXPOSURE,
        "hardcoded_secret": VulnerabilityType.HARDCODED_SECRET,
-       "deserialization": VulnerabilityType.DESERIALIZATION,  # Added common type
-       "weak_crypto": VulnerabilityType.WEAK_CRYPTO,  # Added common type
+       "deserialization": VulnerabilityType.DESERIALIZATION,
+       "weak_crypto": VulnerabilityType.WEAK_CRYPTO,
+       "file_inclusion": VulnerabilityType.FILE_INCLUSION,
+       "race_condition": VulnerabilityType.RACE_CONDITION,
+       "business_logic": VulnerabilityType.BUSINESS_LOGIC,
+       "memory_corruption": VulnerabilityType.MEMORY_CORRUPTION,
    }

    saved_count = 0
@@ -753,49 +781,185 @@ async def _save_findings(db: AsyncSession, task_id: str, findings: List[Dict]) -

    for finding in findings:
        if not isinstance(finding, dict):
            logger.debug(f"[SaveFindings] Skipping non-dict finding: {type(finding)}")
            continue

        try:
-           # Handle severity (case-insensitive)
-           raw_severity = str(finding.get("severity", "medium")).lower().strip()
+           # 🔥 Handle severity (case-insensitive, support multiple field names)
+           raw_severity = str(
+               finding.get("severity") or
+               finding.get("risk") or
+               "medium"
+           ).lower().strip()
            severity_enum = severity_map.get(raw_severity, VulnerabilitySeverity.MEDIUM)

-           # Handle vulnerability type (case-insensitive & snake_case normalization)
-           raw_type = str(finding.get("vulnerability_type", "other")).lower().strip().replace(" ", "_")
+           # 🔥 Handle vulnerability type (case-insensitive & snake_case normalization)
+           # Support multiple field names: vulnerability_type, type, vuln_type
+           raw_type = str(
+               finding.get("vulnerability_type") or
+               finding.get("type") or
+               finding.get("vuln_type") or
+               "other"
+           ).lower().strip().replace(" ", "_").replace("-", "_")
+
            type_enum = type_map.get(raw_type, VulnerabilityType.OTHER)

-           # Additional fallback for known Agent output variations
-           if "sqli" in raw_type: type_enum = VulnerabilityType.SQL_INJECTION
-           if "xss" in raw_type: type_enum = VulnerabilityType.XSS
-           if "rce" in raw_type or "command" in raw_type: type_enum = VulnerabilityType.COMMAND_INJECTION
+           # 🔥 Additional fallback for common Agent output variations
+           if "sqli" in raw_type or "sql" in raw_type:
+               type_enum = VulnerabilityType.SQL_INJECTION
+           if "xss" in raw_type:
+               type_enum = VulnerabilityType.XSS
+           if "rce" in raw_type or "command" in raw_type or "cmd" in raw_type:
+               type_enum = VulnerabilityType.COMMAND_INJECTION
+           if "traversal" in raw_type or "lfi" in raw_type or "rfi" in raw_type:
+               type_enum = VulnerabilityType.PATH_TRAVERSAL
+           if "ssrf" in raw_type:
+               type_enum = VulnerabilityType.SSRF
+           if "xxe" in raw_type:
+               type_enum = VulnerabilityType.XXE
+           if "auth" in raw_type:
+               type_enum = VulnerabilityType.AUTH_BYPASS
+           if "secret" in raw_type or "credential" in raw_type or "password" in raw_type:
+               type_enum = VulnerabilityType.HARDCODED_SECRET
+           if "deserial" in raw_type:
+               type_enum = VulnerabilityType.DESERIALIZATION
+
+           # 🔥 Handle file path (support multiple field names); the conditional
+           # is parenthesized so it only supplies the location-based fallback
+           file_path = (
+               finding.get("file_path") or
+               finding.get("file") or
+               (finding.get("location", "").split(":")[0] if ":" in finding.get("location", "") else finding.get("location"))
+           )
+
+           # 🔥 Handle line numbers (support multiple formats)
+           line_start = finding.get("line_start") or finding.get("line")
+           if not line_start and ":" in finding.get("location", ""):
+               try:
+                   line_start = int(finding.get("location", "").split(":")[1])
+               except (ValueError, IndexError):
+                   line_start = None
+
+           line_end = finding.get("line_end") or line_start
+
+           # 🔥 Handle code snippet (support multiple field names)
+           code_snippet = (
+               finding.get("code_snippet") or
+               finding.get("code") or
+               finding.get("vulnerable_code")
+           )
+
+           # 🔥 Handle title (generate from type if not provided)
+           title = finding.get("title")
+           if not title:
+               # Generate title from vulnerability type and file
+               type_display = raw_type.replace("_", " ").title()
+               if file_path:
+                   title = f"{type_display} in {os.path.basename(file_path)}"
+               else:
+                   title = f"{type_display} Vulnerability"
+
+           # 🔥 Handle description (support multiple field names)
+           description = (
+               finding.get("description") or
+               finding.get("details") or
+               finding.get("explanation") or
+               finding.get("impact") or
+               ""
+           )
+
+           # 🔥 Handle suggestion/recommendation
+           suggestion = (
+               finding.get("suggestion") or
+               finding.get("recommendation") or
+               finding.get("remediation") or
+               finding.get("fix")
+           )
+
+           # 🔥 Handle confidence (map to ai_confidence field in model)
+           confidence = finding.get("confidence") or finding.get("ai_confidence") or 0.5
+           if isinstance(confidence, str):
+               try:
+                   confidence = float(confidence)
+               except ValueError:
+                   confidence = 0.5
+
+           # 🔥 Handle verification status
+           is_verified = finding.get("is_verified", False)
+           if finding.get("verdict") == "confirmed":
+               is_verified = True
+
+           # 🔥 Handle PoC information
+           poc_data = finding.get("poc", {})
+           has_poc = bool(poc_data)
+           poc_code = None
+           poc_description = None
+           poc_steps = None
+
+           if isinstance(poc_data, dict):
+               poc_description = poc_data.get("description")
+               poc_steps = poc_data.get("steps")
+               poc_code = poc_data.get("payload") or poc_data.get("code")
+           elif isinstance(poc_data, str):
+               poc_description = poc_data
+
+           # 🔥 Handle verification details
+           verification_method = finding.get("verification_method")
+           verification_result = None
+           if finding.get("verification_details"):
+               verification_result = {"details": finding.get("verification_details")}
+
+           # 🔥 Handle CWE and CVSS
+           cwe_id = finding.get("cwe_id") or finding.get("cwe")
+           cvss_score = finding.get("cvss_score") or finding.get("cvss")
+           if isinstance(cvss_score, str):
+               try:
+                   cvss_score = float(cvss_score)
+               except ValueError:
+                   cvss_score = None

            db_finding = AgentFinding(
                id=str(uuid4()),
                task_id=task_id,
                vulnerability_type=type_enum,
                severity=severity_enum,
-               title=finding.get("title", "Unknown"),
-               description=finding.get("description", ""),
-               file_path=finding.get("file_path"),
-               line_start=finding.get("line_start"),
-               line_end=finding.get("line_end"),
-               code_snippet=finding.get("code_snippet"),
-               suggestion=finding.get("suggestion") or finding.get("recommendation"),
-               is_verified=finding.get("is_verified", False),
-               confidence=finding.get("confidence", 0.5),
-               status=FindingStatus.VERIFIED if finding.get("is_verified") else FindingStatus.NEW,
+               title=title[:500] if title else "Unknown Vulnerability",
+               description=description[:5000] if description else "",
+               file_path=file_path[:500] if file_path else None,
+               line_start=line_start,
+               line_end=line_end,
+               code_snippet=code_snippet[:10000] if code_snippet else None,
+               suggestion=suggestion[:5000] if suggestion else None,
+               is_verified=is_verified,
+               ai_confidence=confidence,  # 🔥 FIX: Use ai_confidence, not confidence
+               status=FindingStatus.VERIFIED if is_verified else FindingStatus.NEW,
+               # 🔥 Additional fields
+               has_poc=has_poc,
+               poc_code=poc_code,
+               poc_description=poc_description,
+               poc_steps=poc_steps,
+               verification_method=verification_method,
+               verification_result=verification_result,
+               cvss_score=cvss_score,
+               # References for CWE
+               references=[{"cwe": cwe_id}] if cwe_id else None,
            )
            db.add(db_finding)
            saved_count += 1
+           logger.debug(f"[SaveFindings] Prepared finding: {title[:50]}... ({severity_enum})")

        except Exception as e:
            logger.warning(f"Failed to save finding: {e}, data: {finding}")
+           import traceback
+           logger.debug(f"[SaveFindings] Traceback: {traceback.format_exc()}")

    logger.info(f"Successfully prepared {saved_count} findings for commit")

+   try:
        await db.commit()
+       logger.info(f"[SaveFindings] Successfully committed {saved_count} findings to database")
+   except Exception as e:
+       logger.error(f"Failed to commit findings: {e}")
+       await db.rollback()


def _calculate_security_score(findings: List[Dict]) -> float:
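The tolerant field mapping above can be reduced to a small standalone sketch. Everything below is illustrative only: "VulnType" and the sample keys are stand-ins for the project's own enum and agent payloads, not the real model.

    # Minimal sketch of the tolerant type coercion implemented in this hunk.
    from enum import Enum

    class VulnType(Enum):
        SQL_INJECTION = "sql_injection"
        OTHER = "other"

    def coerce_type(finding: dict) -> VulnType:
        raw = str(
            finding.get("vulnerability_type") or finding.get("type") or
            finding.get("vuln_type") or "other"
        ).lower().strip().replace(" ", "_").replace("-", "_")
        if "sqli" in raw or "sql" in raw:
            return VulnType.SQL_INJECTION
        return VulnType.OTHER

    # Agents emitting {"type": "SQL Injection"} or {"vuln_type": "SQLi"}
    # both normalize to the same enum member:
    assert coerce_type({"type": "SQL Injection"}) is VulnType.SQL_INJECTION
    assert coerce_type({"vuln_type": "SQLi"}) is VulnType.SQL_INJECTION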
@@ -822,6 +986,92 @@ def _calculate_security_score(findings: List[Dict]) -> float:
    return float(score)


+async def _save_agent_tree(db: AsyncSession, task_id: str) -> None:
+   """
+   Save the agent tree to the database
+
+   🔥 Called before the task finishes, to persist the in-memory agent tree
+   """
+   from app.models.agent_task import AgentTreeNode
+   from app.services.agent.core import agent_registry
+
+   try:
+       tree = agent_registry.get_agent_tree()
+       nodes = tree.get("nodes", {})
+
+       if not nodes:
+           logger.warning(f"[SaveAgentTree] No agent nodes to save for task {task_id}")
+           return
+
+       logger.info(f"[SaveAgentTree] Saving {len(nodes)} agent nodes for task {task_id}")
+
+       # Compute each node's depth
+       def get_depth(agent_id: str, visited: set = None) -> int:
+           if visited is None:
+               visited = set()
+           if agent_id in visited:
+               return 0
+           visited.add(agent_id)
+           node = nodes.get(agent_id)
+           if not node:
+               return 0
+           parent_id = node.get("parent_id")
+           if not parent_id:
+               return 0
+           return 1 + get_depth(parent_id, visited)
+
+       saved_count = 0
+       for agent_id, node_data in nodes.items():
+           # Pull runtime stats from the agent instance
+           agent_instance = agent_registry.get_agent(agent_id)
+           iterations = 0
+           tool_calls = 0
+           tokens_used = 0
+
+           if agent_instance and hasattr(agent_instance, 'get_stats'):
+               stats = agent_instance.get_stats()
+               iterations = stats.get("iterations", 0)
+               tool_calls = stats.get("tool_calls", 0)
+               tokens_used = stats.get("tokens_used", 0)
+
+           # Pull the findings count from the result
+           findings_count = 0
+           result_summary = None
+           if node_data.get("result"):
+               result = node_data.get("result", {})
+               if isinstance(result, dict):
+                   findings_count = len(result.get("findings", []))
+                   if result.get("summary"):
+                       result_summary = str(result.get("summary"))[:2000]
+
+           tree_node = AgentTreeNode(
+               id=str(uuid4()),
+               task_id=task_id,
+               agent_id=agent_id,
+               agent_name=node_data.get("name", "Unknown"),
+               agent_type=node_data.get("type", "unknown"),
+               parent_agent_id=node_data.get("parent_id"),
+               depth=get_depth(agent_id),
+               task_description=node_data.get("task"),
+               knowledge_modules=node_data.get("knowledge_modules"),
+               status=node_data.get("status", "unknown"),
+               result_summary=result_summary,
+               findings_count=findings_count,
+               iterations=iterations,
+               tool_calls=tool_calls,
+               tokens_used=tokens_used,
+           )
+           db.add(tree_node)
+           saved_count += 1
+
+       await db.commit()
+       logger.info(f"[SaveAgentTree] Successfully saved {saved_count} agent nodes to database")
+
+   except Exception as e:
+       logger.error(f"[SaveAgentTree] Failed to save agent tree: {e}", exc_info=True)
+       await db.rollback()


# ============ API Endpoints ============

@router.post("/", response_model=AgentTaskResponse)
@@ -1992,6 +2242,9 @@ async def generate_audit_report(
            "code_snippet": f.code_snippet,
            "is_verified": f.is_verified,
+           "has_poc": f.has_poc,
+           "poc_code": f.poc_code,
+           "poc_description": f.poc_description,
+           "poc_steps": f.poc_steps,
            "confidence": f.ai_confidence,
            "suggestion": f.suggestion,
            "fix_code": f.fix_code,
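For orientation, a finding entry in the report payload now carries the PoC fields alongside the existing ones. A hypothetical example (all values invented; only the key names come from the hunk above):

    finding_entry = {
        "code_snippet": 'cursor.execute("SELECT * FROM users WHERE id = " + user_id)',
        "is_verified": True,
        "has_poc": True,
        "poc_code": "curl 'http://target/app?id=1 OR 1=1'",
        "poc_description": "Boolean-based SQL injection via the id parameter",
        "poc_steps": ["Start the app", "Send the crafted request", "Observe all rows returned"],
        "confidence": 0.9,
        "suggestion": "Use parameterized queries",
        "fix_code": 'cursor.execute("SELECT * FROM users WHERE id = %s", (user_id,))',
    }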
@@ -2169,6 +2422,30 @@ async def generate_audit_report(
            md_lines.append("```")
            md_lines.append("")

+           # 🔥 Append PoC details
+           if f.has_poc:
+               md_lines.append("**Proof of Concept (PoC):**")
+               md_lines.append("")
+
+               if f.poc_description:
+                   md_lines.append(f"*{f.poc_description}*")
+                   md_lines.append("")
+
+               if f.poc_steps:
+                   md_lines.append("**Reproduction Steps:**")
+                   md_lines.append("")
+                   for step_idx, step in enumerate(f.poc_steps, 1):
+                       md_lines.append(f"{step_idx}. {step}")
+                   md_lines.append("")
+
+               if f.poc_code:
+                   md_lines.append("**PoC Payload:**")
+                   md_lines.append("")
+                   md_lines.append("```")
+                   md_lines.append(f.poc_code.strip())
+                   md_lines.append("```")
+                   md_lines.append("")
+
            md_lines.append("---")
            md_lines.append("")
@@ -92,7 +92,7 @@ class Settings(BaseSettings):
    AGENT_TIMEOUT_SECONDS: int = 1800  # Agent timeout (30 minutes)

    # Sandbox configuration
-   SANDBOX_IMAGE: str = "deepaudit-sandbox:latest"  # sandbox Docker image
+   SANDBOX_IMAGE: str = "python:3.11-slim"  # sandbox Docker image
    SANDBOX_MEMORY_LIMIT: str = "512m"  # sandbox memory limit
    SANDBOX_CPU_LIMIT: float = 1.0  # sandbox CPU limit
    SANDBOX_TIMEOUT: int = 60  # sandbox command timeout (seconds)
@@ -422,6 +422,7 @@ Final Answer: {{"findings": [...], "summary": "..."}}"""
                # Check whether the analysis is done
                if step.is_final:
                    await self.emit_llm_decision("完成安全分析", "LLM 判断分析已充分")
+                   logger.info(f"[{self.name}] Received Final Answer: {step.final_answer}")
                    if step.final_answer and "findings" in step.final_answer:
                        all_findings = step.final_answer["findings"]
                        logger.info(f"[{self.name}] Final Answer contains {len(all_findings)} findings")
@@ -438,7 +439,7 @@ Final Answer: {{"findings": [...], "summary": "..."}}"""
                            f"发现 {finding.get('severity', 'medium')} 级别漏洞: {finding.get('title', 'Unknown')}"
                        )
                    else:
-                       logger.warning(f"[{self.name}] Final Answer has no 'findings' key: {step.final_answer}")
+                       logger.warning(f"[{self.name}] Final Answer has no 'findings' key or is None: {step.final_answer}")

                    # 🔥 Record work completion
                    self.record_work(f"完成安全分析,发现 {len(all_findings)} 个潜在漏洞")
@@ -742,8 +742,22 @@ class BaseAgent(ABC):

    # ============ Finding events ============

-   async def emit_finding(self, title: str, severity: str, vuln_type: str, file_path: str = ""):
+   async def emit_finding(self, title: str, severity: str, vuln_type: str, file_path: str = "", is_verified: bool = False):
        """Emit a vulnerability-finding event"""
+       import uuid
+       finding_id = str(uuid.uuid4())
+
+       # 🔥 Use EventManager.emit_finding so the correct event type is sent
+       if self.event_emitter and hasattr(self.event_emitter, 'emit_finding'):
+           await self.event_emitter.emit_finding(
+               finding_id=finding_id,
+               title=title,
+               severity=severity,
+               vulnerability_type=vuln_type,
+               is_verified=is_verified,
+           )
+       else:
+           # Fallback: use the generic event emitter
            severity_emoji = {
                "critical": "🔴",
                "high": "🟠",
@@ -751,14 +765,17 @@ class BaseAgent(ABC):
                "low": "🟢",
            }.get(severity.lower(), "⚪")

+           event_type = "finding_verified" if is_verified else "finding_new"
            await self.emit_event(
-               "finding",
+               event_type,
                f"{severity_emoji} [{self.name}] 发现漏洞: [{severity.upper()}] {title}\n 类型: {vuln_type}\n 位置: {file_path}",
                metadata={
+                   "id": finding_id,
                    "title": title,
                    "severity": severity,
                    "vulnerability_type": vuln_type,
                    "file_path": file_path,
+                   "is_verified": is_verified,
                }
            )
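A call site for the extended signature might look like the sketch below; it is hypothetical, and only the emit_finding signature comes from the diff above.

    # Hypothetical call site inside a BaseAgent subclass:
    async def _report(self) -> None:
        await self.emit_finding(
            title="SQL injection in login handler",
            severity="high",
            vuln_type="sql_injection",
            file_path="app/auth/login.py",
            is_verified=True,  # new parameter: emits 'finding_verified' instead of 'finding_new'
        )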
@@ -432,6 +432,8 @@ Action Input: {{"参数": "值"}}

        # 🔥 CRITICAL: Log final findings count before returning
        logger.info(f"[Orchestrator] Final result: {len(self._all_findings)} findings collected")
+       if len(self._all_findings) == 0:
+           logger.warning(f"[Orchestrator] ⚠️ No findings collected! Dispatched agents: {list(self._dispatched_tasks.keys())}, Iterations: {self._iteration}")
        for i, f in enumerate(self._all_findings[:5]):  # Log first 5 for debugging
            logger.debug(f"[Orchestrator] Finding {i+1}: {f.get('title', 'N/A')} - {f.get('vulnerability_type', 'N/A')}")
@@ -654,18 +656,27 @@ Action Input: {{"参数": "值"}}
            # The findings field usually comes from the Analysis/Verification Agent
            # initial_findings comes from the Recon Agent
            raw_findings = data.get("findings", [])
+           logger.info(f"[Orchestrator] {agent_name} returned data with {len(raw_findings)} findings in 'findings' field")

-           # 🔥 Also check for initial_findings (from Recon)
-           if not raw_findings and "initial_findings" in data:
+           # 🔥 ENHANCED: Also check for initial_findings (from Recon)
+           # Check initial_findings even when findings is an empty list
+           if "initial_findings" in data:
                initial = data.get("initial_findings", [])
-               # Convert string findings to dict format
-               raw_findings = []
+               logger.info(f"[Orchestrator] {agent_name} has {len(initial)} initial_findings")
                for f in initial:
                    if isinstance(f, dict):
-                       raw_findings.append(f)
+                       # 🔥 Normalize the finding format (handles what Recon returns)
+                       normalized = self._normalize_finding(f)
+                       if normalized not in raw_findings:
+                           raw_findings.append(normalized)
                    elif isinstance(f, str):
-                       # String finding from Recon - skip, it's just an observation
-                       pass
+                       logger.debug(f"[Orchestrator] Skipping string finding: {f[:50]}...")
+
+           # 🔥 Also check high_risk_areas from Recon for potential findings
+           if agent_name == "recon" and "high_risk_areas" in data:
+               high_risk = data.get("high_risk_areas", [])
+               logger.info(f"[Orchestrator] {agent_name} identified {len(high_risk)} high risk areas")

            if raw_findings:
                # Only add dict-format findings
@@ -673,33 +684,58 @@ Action Input: {{"参数": "值"}}

                logger.info(f"[Orchestrator] {agent_name} returned {len(valid_findings)} valid findings")

-               # 🔥 Merge findings to update existing ones and avoid duplicates
+               # 🔥 ENHANCED: Merge findings with better deduplication
                for new_f in valid_findings:
-                   # Create key for identification (file + line + type)
-                   new_key = (
-                       new_f.get("file_path", "") or new_f.get("file", ""),
-                       new_f.get("line_start") or new_f.get("line", 0),
-                       new_f.get("vulnerability_type", "") or new_f.get("type", ""),
-                   )
+                   # Normalize the finding first
+                   normalized_new = self._normalize_finding(new_f)

-                   # Check if exists
+                   # Create fingerprint for deduplication (file + description similarity)
+                   new_file = normalized_new.get("file_path", "").lower().strip()
+                   new_desc = (normalized_new.get("description", "") or "").lower()[:100]
+                   new_type = (normalized_new.get("vulnerability_type", "") or "").lower()
+                   new_line = normalized_new.get("line_start") or normalized_new.get("line", 0)
+
+                   # Check if exists (more flexible matching)
                    found = False
                    for i, existing_f in enumerate(self._all_findings):
-                       existing_key = (
-                           existing_f.get("file_path", "") or existing_f.get("file", ""),
-                           existing_f.get("line_start") or existing_f.get("line", 0),
-                           existing_f.get("vulnerability_type", "") or existing_f.get("type", ""),
-                       )
+                       existing_file = (existing_f.get("file_path", "") or existing_f.get("file", "")).lower().strip()
+                       existing_desc = (existing_f.get("description", "") or "").lower()[:100]
+                       existing_type = (existing_f.get("vulnerability_type", "") or existing_f.get("type", "")).lower()
+                       existing_line = existing_f.get("line_start") or existing_f.get("line", 0)
+
+                       # Match if same file AND (same line OR similar description OR same vulnerability type)
+                       same_file = new_file and existing_file and (
+                           new_file == existing_file or
+                           new_file.endswith(existing_file) or
+                           existing_file.endswith(new_file)
+                       )
-                       if new_key == existing_key:
+                       same_line = new_line and existing_line and new_line == existing_line
+                       similar_desc = new_desc and existing_desc and (
+                           new_desc in existing_desc or existing_desc in new_desc
+                       )
+                       same_type = new_type and existing_type and (
+                           new_type == existing_type or
+                           (new_type in existing_type) or (existing_type in new_type)
+                       )
+
+                       if same_file and (same_line or similar_desc or same_type):
                            # Update existing with new info (e.g. verification results)
-                           self._all_findings[i] = {**existing_f, **new_f}
+                           # Prefer verified data over unverified
+                           merged = {**existing_f, **normalized_new}
+                           # Keep the better title
+                           if normalized_new.get("title") and len(normalized_new.get("title", "")) > len(existing_f.get("title", "")):
+                               merged["title"] = normalized_new["title"]
+                           # Keep verified status if either is verified
+                           if existing_f.get("is_verified") or normalized_new.get("is_verified"):
+                               merged["is_verified"] = True
+                           self._all_findings[i] = merged
                            found = True
-                           logger.debug(f"[Orchestrator] Updated existing finding: {new_key}")
+                           logger.info(f"[Orchestrator] Merged finding: {new_file}:{new_line} ({new_type})")
                            break

                    if not found:
-                       self._all_findings.append(new_f)
-                       logger.debug(f"[Orchestrator] Added new finding: {new_key}")
+                       self._all_findings.append(normalized_new)
+                       logger.info(f"[Orchestrator] Added new finding: {new_file}:{new_line} ({new_type})")

                logger.info(f"[Orchestrator] Total findings now: {len(self._all_findings)}")
            else:
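The merge rule above — same file plus any of same line, overlapping description, or overlapping type — condenses into a small predicate. A minimal, runnable sketch (field names follow the normalized finding dicts in this diff; it is illustrative, not the project's exact code):

    def is_same_finding(a: dict, b: dict) -> bool:
        """Approximate the orchestrator's dedup rule: same file AND
        (same line OR similar description OR same/overlapping type)."""
        fa = (a.get("file_path") or "").lower().strip()
        fb = (b.get("file_path") or "").lower().strip()
        same_file = bool(fa and fb and (fa == fb or fa.endswith(fb) or fb.endswith(fa)))

        la, lb = a.get("line_start") or 0, b.get("line_start") or 0
        da = (a.get("description") or "").lower()[:100]
        db = (b.get("description") or "").lower()[:100]
        ta = (a.get("vulnerability_type") or "").lower()
        tb = (b.get("vulnerability_type") or "").lower()

        same_line = bool(la and lb and la == lb)
        similar_desc = bool(da and db and (da in db or db in da))
        same_type = bool(ta and tb and (ta == tb or ta in tb or tb in ta))
        return same_file and (same_line or similar_desc or same_type)

    # "sql_injection" and "sql_injection_blind" in the same file count as one finding:
    a = {"file_path": "app/db.py", "line_start": 10, "vulnerability_type": "sql_injection", "description": "raw query"}
    b = {"file_path": "src/app/db.py", "line_start": 12, "vulnerability_type": "sql_injection_blind", "description": ""}
    assert is_same_finding(a, b)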
@@ -752,13 +788,13 @@ Action Input: {{"参数": "值"}}
            observation = f"""## {agent_name} Agent 执行结果

**状态**: 成功
-**发现数量**: {len(findings)}
+**发现数量**: {len(valid_findings)}
**迭代次数**: {result.iterations}
**耗时**: {result.duration_ms}ms

### 发现摘要
"""
-           for i, f in enumerate(findings[:10]):
+           for i, f in enumerate(valid_findings[:10]):
                if not isinstance(f, dict):
                    continue
                observation += f"""
@@ -768,8 +804,8 @@ Action Input: {{"参数": "值"}}
- 描述: {f.get('description', '')[:200]}...
"""

-           if len(findings) > 10:
-               observation += f"\n... 还有 {len(findings) - 10} 个发现"
+           if len(valid_findings) > 10:
+               observation += f"\n... 还有 {len(valid_findings) - 10} 个发现"

            if data.get("summary"):
                observation += f"\n\n### Agent 总结\n{data['summary']}"
@@ -782,6 +818,94 @@ Action Input: {{"参数": "值"}}
            logger.error(f"Sub-agent dispatch failed: {e}", exc_info=True)
            return f"## 调度失败\n\n错误: {str(e)}"

+   def _normalize_finding(self, finding: Dict[str, Any]) -> Dict[str, Any]:
+       """
+       Normalize the finding format
+
+       Different agents may return findings in different shapes; this method
+       maps them all onto one unified format.
+       """
+       normalized = dict(finding)  # copy the original data
+
+       # 🔥 Handle the location field -> file_path + line_start
+       if "location" in normalized and "file_path" not in normalized:
+           location = normalized["location"]
+           if isinstance(location, str) and ":" in location:
+               parts = location.split(":")
+               normalized["file_path"] = parts[0]
+               try:
+                   normalized["line_start"] = int(parts[1])
+               except (ValueError, IndexError):
+                   pass
+           elif isinstance(location, str):
+               normalized["file_path"] = location
+
+       # 🔥 Handle the file field -> file_path
+       if "file" in normalized and "file_path" not in normalized:
+           normalized["file_path"] = normalized["file"]
+
+       # 🔥 Handle the line field -> line_start
+       if "line" in normalized and "line_start" not in normalized:
+           normalized["line_start"] = normalized["line"]
+
+       # 🔥 Handle the type field -> vulnerability_type
+       if "type" in normalized and "vulnerability_type" not in normalized:
+           # Not every type is a vulnerability type; "Vulnerability" is just a tag
+           type_val = normalized["type"]
+           if type_val and type_val.lower() not in ["vulnerability", "finding", "issue"]:
+               normalized["vulnerability_type"] = type_val
+           elif "description" in normalized:
+               # Try to infer the vulnerability type from the description
+               desc = normalized["description"].lower()
+               if "command injection" in desc or "rce" in desc or "system(" in desc:
+                   normalized["vulnerability_type"] = "command_injection"
+               elif "sql injection" in desc or "sqli" in desc:
+                   normalized["vulnerability_type"] = "sql_injection"
+               elif "xss" in desc or "cross-site scripting" in desc:
+                   normalized["vulnerability_type"] = "xss"
+               elif "path traversal" in desc or "directory traversal" in desc:
+                   normalized["vulnerability_type"] = "path_traversal"
+               elif "ssrf" in desc:
+                   normalized["vulnerability_type"] = "ssrf"
+               elif "xxe" in desc:
+                   normalized["vulnerability_type"] = "xxe"
+               else:
+                   normalized["vulnerability_type"] = "other"
+
+       # 🔥 Map risk -> severity before applying the default,
+       # otherwise the default would always shadow the risk field
+       if "risk" in normalized and "severity" not in normalized:
+           normalized["severity"] = str(normalized["risk"]).lower()
+
+       # 🔥 Ensure severity exists and is lowercase
+       if "severity" in normalized:
+           normalized["severity"] = str(normalized["severity"]).lower()
+       else:
+           normalized["severity"] = "medium"
+
+       # 🔥 Generate a title if none exists
+       if "title" not in normalized:
+           vuln_type = normalized.get("vulnerability_type", "Unknown")
+           file_path = normalized.get("file_path", "")
+           if file_path:
+               import os
+               normalized["title"] = f"{vuln_type.replace('_', ' ').title()} in {os.path.basename(file_path)}"
+           else:
+               normalized["title"] = f"{vuln_type.replace('_', ' ').title()} Vulnerability"
+
+       # 🔥 Handle the code field -> code_snippet
+       if "code" in normalized and "code_snippet" not in normalized:
+           normalized["code_snippet"] = normalized["code"]
+
+       # 🔥 Handle recommendation -> suggestion
+       if "recommendation" in normalized and "suggestion" not in normalized:
+           normalized["suggestion"] = normalized["recommendation"]
+
+       # 🔥 Fold impact into the description
+       if "impact" in normalized and normalized.get("description"):
+           if "impact" not in normalized["description"].lower():
+               normalized["description"] += f"\n\nImpact: {normalized['impact']}"
+
+       return normalized
+
    def _summarize_findings(self) -> str:
        """Summarize the current findings"""
        if not self._all_findings:
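Fed a Recon-style dict, the normalizer fills in the canonical fields. A quick illustration with invented input (the expected output is approximate):

    # Invented Recon-style input; output keys follow _normalize_finding above.
    raw = {
        "location": "app/views.py:42",
        "type": "Vulnerability",  # just a tag, not a concrete type
        "description": "Possible SQL injection via raw string formatting",
        "risk": "High",
        "code": 'db.execute(f"SELECT * FROM t WHERE id={uid}")',
    }
    # After normalization one would expect roughly:
    # {
    #   "file_path": "app/views.py", "line_start": 42,
    #   "vulnerability_type": "sql_injection", "severity": "high",
    #   "title": "Sql Injection in views.py",
    #   "code_snippet": 'db.execute(f"SELECT * FROM t WHERE id={uid}")',
    #   ... original keys preserved ...
    # }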
@@ -120,6 +120,10 @@ Final Answer: [JSON 格式的验证报告]
2. **深入理解** - 理解代码逻辑,不要表面判断
3. **证据支撑** - 判定要有依据
4. **安全第一** - 沙箱测试要谨慎
+5. **🔥 PoC 生成** - 对于 confirmed 和 likely 的漏洞,**必须**生成 PoC:
+   - poc.description: 简要描述这个 PoC 的作用
+   - poc.steps: 详细的复现步骤列表
+   - poc.payload: 实际的攻击载荷或测试代码

现在开始验证漏洞发现!"""
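The new prompt rule asks the verification agent to return a poc object for confirmed and likely findings. Its requested shape would look roughly like this (an invented example; only the poc.description/steps/payload keys come from the prompt above):

    poc = {
        "description": "Demonstrates command injection through the filename parameter",
        "steps": [
            "Start the target service",
            "Upload a file whose name is '; id #'",
            "Observe the output of `id` in the server response",
        ],
        "payload": "filename='; id #'",
    }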
@@ -171,6 +171,7 @@ class AgentEventEmitter:
            finding_id=finding_id,
            message=f"{'✅ 已验证' if is_verified else '🔍 新发现'}: [{severity.upper()}] {title}",
            metadata={
+               "id": finding_id,  # 🔥 expose the id field for the frontend
                "title": title,
                "severity": severity,
                "vulnerability_type": vulnerability_type,
@@ -331,6 +332,31 @@ class EventManager:
        """Save an event to the database"""
        from app.models.agent_task import AgentEvent

+       # 🔥 Strip invalid UTF-8 characters (e.g. binary content)
+       def sanitize_string(s):
+           """Remove invalid UTF-8 characters from a string"""
+           if s is None:
+               return None
+           if not isinstance(s, str):
+               s = str(s)
+           # Drop NULL bytes and other non-printable control characters
+           # (keep newlines and tabs)
+           return ''.join(
+               char for char in s
+               if char in '\n\r\t' or (ord(char) >= 32 and ord(char) != 127)
+           )
+
+       def sanitize_dict(d):
+           """Recursively sanitize string values inside a dict"""
+           if d is None:
+               return None
+           if isinstance(d, dict):
+               return {k: sanitize_dict(v) for k, v in d.items()}
+           elif isinstance(d, list):
+               return [sanitize_dict(item) for item in d]
+           elif isinstance(d, str):
+               return sanitize_string(d)
+           return d
+
        async with self.db_session_factory() as db:
            event = AgentEvent(
                id=event_data["id"],
@@ -338,14 +364,14 @@ class EventManager:
                event_type=event_data["event_type"],
                sequence=event_data["sequence"],
                phase=event_data["phase"],
-               message=event_data["message"],
+               message=sanitize_string(event_data["message"]),  # 🔥 sanitize the message
                tool_name=event_data["tool_name"],
-               tool_input=event_data["tool_input"],
-               tool_output=event_data["tool_output"],
+               tool_input=sanitize_dict(event_data["tool_input"]),  # 🔥 sanitize tool input
+               tool_output=sanitize_dict(event_data["tool_output"]),  # 🔥 sanitize tool output
                tool_duration_ms=event_data["tool_duration_ms"],
                finding_id=event_data["finding_id"],
                tokens_used=event_data["tokens_used"],
-               event_metadata=event_data["metadata"],
+               event_metadata=sanitize_dict(event_data["metadata"]),  # 🔥 sanitize metadata
            )
            db.add(event)
            await db.commit()
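The sanitizer keeps tabs and newlines but drops NUL bytes and other control characters before the event is written. A standalone check of the same filter logic (copied out of the hunk above purely for illustration):

    def sanitize_string(s):
        if s is None:
            return None
        if not isinstance(s, str):
            s = str(s)
        return ''.join(
            char for char in s
            if char in '\n\r\t' or (ord(char) >= 32 and ord(char) != 127)
        )

    # NUL and bell characters vanish; newlines survive:
    assert sanitize_string("line1\nline2\x00\x07") == "line1\nline2"
    assert sanitize_string(None) is None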
@@ -406,7 +432,10 @@ class EventManager:

        🔥 Important: this method first drains events already buffered in the queue (produced before the SSE connection),
        then keeps streaming new events in real time.
+       Only events with sequence > after_sequence are returned.
        """
+       logger.info(f"[StreamEvents] Task {task_id}: Starting stream with after_sequence={after_sequence}")
+
        # Fetch the existing queue (created by AgentRunner at initialization)
        queue = self._event_queues.get(task_id)
@@ -417,9 +446,17 @@ class EventManager:

        # 🔥 Drain events already buffered in the queue first (produced before the SSE connection)
        buffered_count = 0
+       skipped_count = 0
        while not queue.empty():
            try:
                buffered_event = queue.get_nowait()
+
+               # 🔥 Filter out events with sequence <= after_sequence
+               event_sequence = buffered_event.get("sequence", 0)
+               if event_sequence <= after_sequence:
+                   skipped_count += 1
+                   continue
+
                buffered_count += 1
                yield buffered_event
@@ -432,19 +469,25 @@ class EventManager:

                # Check whether this is a terminating event
                if event_type in ["task_complete", "task_error", "task_cancel"]:
-                   logger.debug(f"Task {task_id} already completed, sent {buffered_count} buffered events")
+                   logger.debug(f"Task {task_id} already completed, sent {buffered_count} buffered events (skipped {skipped_count})")
                    return
            except asyncio.QueueEmpty:
                break

-       if buffered_count > 0:
-           logger.debug(f"Drained {buffered_count} buffered events for task {task_id}")
+       if buffered_count > 0 or skipped_count > 0:
+           logger.debug(f"Drained queue for task {task_id}: sent {buffered_count}, skipped {skipped_count} (after_sequence={after_sequence})")

        # Then stream new events in real time
        try:
            while True:
                try:
                    event = await asyncio.wait_for(queue.get(), timeout=30)

+                   # 🔥 Filter out events with sequence <= after_sequence
+                   event_sequence = event.get("sequence", 0)
+                   if event_sequence <= after_sequence:
+                       continue
+
                    yield event

                    # 🔥 Add a micro-delay for thinking_token to keep the stream smooth
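The same drain-then-filter pattern in miniature: an asyncio queue whose stale entries (sequence <= after_sequence) are skipped. This sketch is independent of the EventManager and runnable as-is:

    import asyncio

    async def drain_after(queue: asyncio.Queue, after_sequence: int):
        """Yield only events newer than after_sequence, mirroring the filter above."""
        while not queue.empty():
            event = queue.get_nowait()
            if event.get("sequence", 0) <= after_sequence:
                continue  # already delivered to this client
            yield event

    async def main():
        q = asyncio.Queue()
        for seq in (1, 2, 3, 4):
            q.put_nowait({"sequence": seq, "message": f"event {seq}"})
        # A client that already saw sequence 2 only receives 3 and 4:
        async for ev in drain_after(q, after_sequence=2):
            print(ev["sequence"], ev["message"])

    asyncio.run(main())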
@@ -274,15 +274,20 @@ class AgentRunner:

        # Sandbox tools (available to the Verification Agent only)
        try:
-           self.sandbox_manager = SandboxManager(
+           from app.services.agent.tools.sandbox_tool import SandboxConfig
+           sandbox_config = SandboxConfig(
                image=settings.SANDBOX_IMAGE,
                memory_limit=settings.SANDBOX_MEMORY_LIMIT,
                cpu_limit=settings.SANDBOX_CPU_LIMIT,
                timeout=settings.SANDBOX_TIMEOUT,
+               network_mode=settings.SANDBOX_NETWORK_MODE,
            )
+           self.sandbox_manager = SandboxManager(config=sandbox_config)

            self.verification_tools["sandbox_exec"] = SandboxTool(self.sandbox_manager)
            self.verification_tools["sandbox_http"] = SandboxHttpTool(self.sandbox_manager)
            self.verification_tools["verify_vulnerability"] = VulnerabilityVerifyTool(self.sandbox_manager)
+           logger.info("Sandbox tools initialized successfully")

        except Exception as e:
            logger.warning(f"Sandbox initialization failed: {e}")
@@ -256,11 +256,14 @@ class LiteLLMAdapter(BaseLLMAdapter):

        accumulated_content = ""
        final_usage = None  # 🔥 store the final usage info
+       chunk_count = 0  # 🔥 track the number of chunks

        try:
            response = await litellm.acompletion(**kwargs)

            async for chunk in response:
+               chunk_count += 1
+
                # 🔥 Check for usage info (some APIs include it in the last chunk)
                if hasattr(chunk, "usage") and chunk.usage:
                    final_usage = {
@@ -271,6 +274,7 @@ class LiteLLMAdapter(BaseLLMAdapter):
                    logger.debug(f"Got usage from chunk: {final_usage}")

                if not chunk.choices:
+                   # 🔥 Some models may send chunks without choices (e.g. heartbeats)
                    continue

                delta = chunk.choices[0].delta
@@ -284,9 +288,8 @@ class LiteLLMAdapter(BaseLLMAdapter):
                        "content": content,
                        "accumulated": accumulated_content,
                    }
-               else:
-                   # Log when we get a chunk without content
-                   logger.debug(f"Chunk with no content: {chunk}")
+               # 🔥 ENHANCED: handle chunks that carry no content but also no finish_reason;
+               # some models (e.g. Zhipu GLM) return no content in certain chunks

                if finish_reason:
                    # Stream finished
@@ -300,6 +303,10 @@ class LiteLLMAdapter(BaseLLMAdapter):
                    }
                    logger.debug(f"Estimated usage: {final_usage}")

+                   # 🔥 ENHANCED: warn when the stream finishes with no accumulated content
+                   if not accumulated_content:
+                       logger.warning(f"Stream completed with no content after {chunk_count} chunks, finish_reason={finish_reason}")
+
                    yield {
                        "type": "done",
                        "content": accumulated_content,
|
@ -308,8 +315,26 @@ class LiteLLMAdapter(BaseLLMAdapter):
|
|||
}
|
||||
break
|
||||
|
||||
# 🔥 ENHANCED: 如果循环结束但没有收到 finish_reason,也需要返回 done
|
||||
if accumulated_content:
|
||||
logger.warning(f"Stream ended without finish_reason, returning accumulated content ({len(accumulated_content)} chars)")
|
||||
if not final_usage:
|
||||
output_tokens_estimate = estimate_tokens(accumulated_content)
|
||||
final_usage = {
|
||||
"prompt_tokens": input_tokens_estimate,
|
||||
"completion_tokens": output_tokens_estimate,
|
||||
"total_tokens": input_tokens_estimate + output_tokens_estimate,
|
||||
}
|
||||
yield {
|
||||
"type": "done",
|
||||
"content": accumulated_content,
|
||||
"usage": final_usage,
|
||||
"finish_reason": "complete",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
# 🔥 即使出错,也尝试返回估算的 usage
|
||||
logger.error(f"Stream error: {e}")
|
||||
output_tokens_estimate = estimate_tokens(accumulated_content) if accumulated_content else 0
|
||||
yield {
|
||||
"type": "error",
|
||||
|
|
|
|||
|
|
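A consumer of this adapter sees a stream of dicts and must handle token, done, and error chunks. A hedged sketch of the consuming side — the "done"/"error" shapes ("type", "content", "usage", "finish_reason") come from the yields above, while the "token" type string for content chunks is an assumption, as it is not visible in this diff:

    # Hypothetical consumer of the adapter's streamed chunks.
    async def consume(stream):
        full_text, usage = "", None
        async for chunk in stream:
            if chunk["type"] == "token":  # assumed type name for content chunks
                print(chunk["content"], end="", flush=True)
            elif chunk["type"] == "done":
                full_text = chunk["content"]
                usage = chunk.get("usage")  # may be estimated, per the fallbacks above
                break
            elif chunk["type"] == "error":
                raise RuntimeError("LLM stream failed")
        return full_text, usage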
@@ -42,6 +42,16 @@ function agentAuditReducer(state: AgentAuditState, action: AgentAuditAction): Ag
    case 'SET_FINDINGS':
      return { ...state, findings: action.payload };

+   case 'ADD_FINDING': {
+     // 🔥 Add a single finding, avoiding duplicates
+     const newFinding = action.payload;
+     const existingIds = new Set(state.findings.map(f => f.id));
+     if (newFinding.id && existingIds.has(newFinding.id)) {
+       return state; // already present, don't add it again
+     }
+     return { ...state, findings: [...state.findings, newFinding] };
+   }
+
    case 'SET_AGENT_TREE':
      return { ...state, agentTree: action.payload };
@@ -63,7 +63,11 @@ function AgentAuditPageContent() {
  const previousTaskIdRef = useRef<string | undefined>(undefined);
  const disconnectStreamRef = useRef<(() => void) | null>(null);
  const lastEventSequenceRef = useRef<number>(0);
- const historicalEventsLoadedRef = useRef<boolean>(false);
+ const hasConnectedRef = useRef<boolean>(false); // 🔥 tracks whether the SSE connection is up
+ const hasLoadedHistoricalEventsRef = useRef<boolean>(false); // 🔥 tracks whether historical events were loaded
+ // 🔥 Use state to mark the historical-load status and retrigger the streamOptions memo
+ const [afterSequence, setAfterSequence] = useState<number>(0);
+ const [historicalEventsLoaded, setHistoricalEventsLoaded] = useState<boolean>(false);

  // 🔥 Reset state immediately when taskId changes (clear old logs for a new task)
  useEffect(() => {
@@ -79,7 +83,10 @@ function AgentAuditPageContent() {
      setShowSplash(!taskId);
      // 3. Reset the event sequence number and load flags
      lastEventSequenceRef.current = 0;
-     historicalEventsLoadedRef.current = false;
+     hasConnectedRef.current = false; // 🔥 reset the SSE connection flag
+     hasLoadedHistoricalEventsRef.current = false; // 🔥 reset the historical-events flag
+     setHistoricalEventsLoaded(false); // 🔥 reset the historical-load state
+     setAfterSequence(0); // 🔥 reset the afterSequence state
    }
    previousTaskIdRef.current = taskId;
  }, [taskId, reset]);
@@ -141,6 +148,14 @@ function AgentAuditPageContent() {
  // 🔥 NEW: Load historical events and convert them into log items
  const loadHistoricalEvents = useCallback(async () => {
    if (!taskId) return 0;

+   // 🔥 Guard against loading historical events twice
+   if (hasLoadedHistoricalEventsRef.current) {
+     console.log('[AgentAudit] Historical events already loaded, skipping');
+     return 0;
+   }
+   hasLoadedHistoricalEventsRef.current = true;
+
    try {
      console.log(`[AgentAudit] Fetching historical events for task ${taskId}...`);
      const events = await getAgentEvents(taskId, { limit: 500 });
@@ -356,20 +371,22 @@ function AgentAuditPageContent() {
      });

      console.log(`[AgentAudit] Processed ${processedCount} events into logs, last sequence: ${lastEventSequenceRef.current}`);
+     // 🔥 Update the afterSequence state to retrigger the streamOptions memo
+     setAfterSequence(lastEventSequenceRef.current);
      return events.length;
    } catch (err) {
      console.error('[AgentAudit] Failed to load historical events:', err);
      return 0;
    }
- }, [taskId, dispatch]);
+ }, [taskId, dispatch, setAfterSequence]);

  // ============ Stream Event Handling ============

  const streamOptions = useMemo(() => ({
    includeThinking: true,
    includeToolCalls: true,
-   // 🔥 Use the last event sequence number to avoid re-receiving historical events
-   afterSequence: lastEventSequenceRef.current,
+   // 🔥 Use the state variable so the latest value is picked up after historical events load
+   afterSequence: afterSequence,
    onEvent: (event: { type: string; message?: string; metadata?: { agent_name?: string; agent?: string } }) => {
      if (event.metadata?.agent_name) {
        setCurrentAgentName(event.metadata.agent_name);
@@ -478,7 +495,20 @@ function AgentAuditPageContent() {
          agentName: getCurrentAgentName() || undefined,
        }
      });
-     loadFindings();
+     // 🔥 Add the finding to state directly instead of via the API (the database has no data yet while running)
+     dispatch({
+       type: 'ADD_FINDING',
+       payload: {
+         id: (finding.id as string) || `finding-${Date.now()}`,
+         title: (finding.title as string) || 'Vulnerability found',
+         severity: (finding.severity as string) || 'medium',
+         vulnerability_type: (finding.vulnerability_type as string) || 'unknown',
+         file_path: finding.file_path as string,
+         line_start: finding.line_start as number,
+         description: finding.description as string,
+         is_verified: (finding.is_verified as boolean) || false,
+       }
+     });
    },
    onComplete: () => {
      dispatch({ type: 'ADD_LOG', payload: { type: 'info', title: 'Audit completed successfully' } });
@@ -489,7 +519,7 @@ function AgentAuditPageContent() {
    onError: (err: string) => {
      dispatch({ type: 'ADD_LOG', payload: { type: 'error', title: `Error: ${err}` } });
    },
- }), [dispatch, loadTask, loadFindings, loadAgentTree, debouncedLoadAgentTree,
+ }), [afterSequence, dispatch, loadTask, loadFindings, loadAgentTree, debouncedLoadAgentTree,
      updateLog, removeLog, getCurrentAgentName, getCurrentThinkingId,
      setCurrentAgentName, setCurrentThinkingId]);
@@ -523,7 +553,7 @@ function AgentAuditPageContent() {
    }
    setShowSplash(false);
    setLoading(true);
-   historicalEventsLoadedRef.current = false;
+   setHistoricalEventsLoaded(false);

    const loadAllData = async () => {
      try {
@@ -534,11 +564,11 @@ function AgentAuditPageContent() {
        const eventsLoaded = await loadHistoricalEvents();
        console.log(`[AgentAudit] Loaded ${eventsLoaded} historical events for task ${taskId}`);

-       // Mark historical events as loaded
-       historicalEventsLoadedRef.current = true;
+       // Mark historical events as loaded (setAfterSequence is already called inside loadHistoricalEvents)
+       setHistoricalEventsLoaded(true);
      } catch (error) {
        console.error('[AgentAudit] Failed to load data:', error);
-       historicalEventsLoadedRef.current = true; // mark done even on error to avoid waiting forever
+       setHistoricalEventsLoaded(true); // mark done even on error to avoid waiting forever
      } finally {
        setLoading(false);
      }
@@ -552,20 +582,21 @@ function AgentAuditPageContent() {
    // Wait for historical events to finish loading, and only while the task is running
    if (!taskId || !task?.status || task.status !== 'running') return;

-   // If historical events haven't finished loading yet, wait a moment
-   const checkAndConnect = () => {
-     if (historicalEventsLoadedRef.current) {
+   // 🔥 Use the state variable so we only connect after historical events have loaded
+   if (!historicalEventsLoaded) return;
+
+   // 🔥 Avoid reconnecting - connect exactly once
+   if (hasConnectedRef.current) return;
+
+   hasConnectedRef.current = true;
+   console.log(`[AgentAudit] Connecting to stream with afterSequence=${afterSequence}`);
    connectStream();
    dispatch({ type: 'ADD_LOG', payload: { type: 'info', title: 'Connected to audit stream' } });
-     } else {
-       // Retry after a short delay
-       setTimeout(checkAndConnect, 100);
-     }
-   };
-
-   checkAndConnect();
-   return () => disconnectStream();
- }, [taskId, task?.status, connectStream, disconnectStream, dispatch]);
+   return () => {
+     disconnectStream();
+   };
+ }, [taskId, task?.status, historicalEventsLoaded, connectStream, disconnectStream, dispatch, afterSequence]);

  // Polling
  useEffect(() => {
@@ -71,6 +71,7 @@ export interface AgentTreeResponse {
export type AgentAuditAction =
  | { type: 'SET_TASK'; payload: AgentTask }
  | { type: 'SET_FINDINGS'; payload: AgentFinding[] }
+ | { type: 'ADD_FINDING'; payload: Partial<AgentFinding> & { id: string } }
  | { type: 'SET_AGENT_TREE'; payload: AgentTreeResponse }
  | { type: 'SET_LOGS'; payload: LogItem[] }
  | { type: 'ADD_LOG'; payload: Omit<LogItem, 'id' | 'time'> & { id?: string } }
@@ -396,6 +396,7 @@ export class AgentStreamHandler {
          break;

        // Findings
+       case 'finding': // 🔥 backward compatibility with the legacy event type
        case 'finding_new':
        case 'finding_verified':
          this.options.onFinding?.(