feat(agent): enhance the vulnerability-finding pipeline and frontend compatibility

- Backend: accept the legacy 'finding' event type for backward compatibility
- Improve finding normalization and deduplication logic
- Add PoC generation requirements and related fields
- Streamline sandbox configuration initialization
- Frontend: add the ADD_FINDING action and state management
- Harden event-stream handling and sequence-number filtering
- Improve historical-event loading and SSE connection logic
- Add verification status and PoC details to reports
lintsinghua 2025-12-13 18:45:05 +08:00
parent 6d98f29fa6
commit 4e4dd05ddb
13 changed files with 695 additions and 156 deletions

View File

@ -393,10 +393,20 @@ async def _execute_agent_task(task_id: str):
task.tool_calls_count = result.tool_calls
task.tokens_used = result.tokens_used
# 🔥 Count analyzed files (extract unique file paths from findings)
analyzed_file_set = set()
for f in findings:
if isinstance(f, dict):
file_path = f.get("file_path") or f.get("file") or f.get("location", "").split(":")[0]
if file_path:
analyzed_file_set.add(file_path)
task.analyzed_files = len(analyzed_file_set) if analyzed_file_set else task.total_files
# Count severities and verification status
verified_count = 0
for f in findings:
if isinstance(f, dict):
sev = str(f.get("severity", "low")).lower()
if sev == "critical":
task.critical_count += 1
elif sev == "high":
@ -405,6 +415,10 @@ async def _execute_agent_task(task_id: str):
task.medium_count += 1
elif sev == "low":
task.low_count += 1
# 🔥 Count verified findings
if f.get("is_verified") or f.get("verdict") == "confirmed":
verified_count += 1
task.verified_count = verified_count
# Calculate the security score
task.security_score = _calculate_security_score(findings)
@ -462,15 +476,21 @@ async def _execute_agent_task(task_id: str):
logger.error(f"Failed to update task status: {db_error}")
finally:
# 🔥 Save the Agent tree to the database before cleanup
try:
async with async_session_factory() as save_db:
await _save_agent_tree(save_db, task_id)
except Exception as save_error:
logger.error(f"Failed to save agent tree: {save_error}")
# Cleanup
_running_orchestrators.pop(task_id, None)
_running_tasks.pop(task_id, None)
_running_event_managers.pop(task_id, None)
_running_asyncio_tasks.pop(task_id, None)  # 🔥 Clean up the asyncio task
# Unregister from the Registry
if orchestrator:
agent_registry.unregister_agent(orchestrator.agent_id)
# 🔥 Clear the entire Agent registry (including all sub-agents)
agent_registry.clear()
logger.debug(f"Task {task_id} cleaned up")
@ -713,7 +733,11 @@ async def _collect_project_info(
async def _save_findings(db: AsyncSession, task_id: str, findings: List[Dict]) -> None:
"""保存发现到数据库"""
"""
保存发现到数据库
🔥 增强版支持多种 Agent 输出格式健壮的字段映射
"""
from app.models.agent_task import VulnerabilityType
logger.info(f"[SaveFindings] Starting to save {len(findings)} findings for task {task_id}")
@ -744,8 +768,12 @@ async def _save_findings(db: AsyncSession, task_id: str, findings: List[Dict]) -
"idor": VulnerabilityType.IDOR,
"sensitive_data_exposure": VulnerabilityType.SENSITIVE_DATA_EXPOSURE,
"hardcoded_secret": VulnerabilityType.HARDCODED_SECRET,
"deserialization": VulnerabilityType.DESERIALIZATION, # Added common type
"weak_crypto": VulnerabilityType.WEAK_CRYPTO, # Added common type
"deserialization": VulnerabilityType.DESERIALIZATION,
"weak_crypto": VulnerabilityType.WEAK_CRYPTO,
"file_inclusion": VulnerabilityType.FILE_INCLUSION,
"race_condition": VulnerabilityType.RACE_CONDITION,
"business_logic": VulnerabilityType.BUSINESS_LOGIC,
"memory_corruption": VulnerabilityType.MEMORY_CORRUPTION,
}
saved_count = 0
@ -753,49 +781,185 @@ async def _save_findings(db: AsyncSession, task_id: str, findings: List[Dict]) -
for finding in findings:
if not isinstance(finding, dict):
logger.debug(f"[SaveFindings] Skipping non-dict finding: {type(finding)}")
continue
try:
# 🔥 Handle severity (case-insensitive, support multiple field names)
raw_severity = str(
finding.get("severity") or
finding.get("risk") or
"medium"
).lower().strip()
severity_enum = severity_map.get(raw_severity, VulnerabilitySeverity.MEDIUM)
# 🔥 Handle vulnerability type (case-insensitive & snake_case normalization)
# Support multiple field names: vulnerability_type, type, vuln_type
raw_type = str(
finding.get("vulnerability_type") or
finding.get("type") or
finding.get("vuln_type") or
"other"
).lower().strip().replace(" ", "_").replace("-", "_")
type_enum = type_map.get(raw_type, VulnerabilityType.OTHER)
# 🔥 Additional fallback for common Agent output variations
if "sqli" in raw_type or "sql" in raw_type:
type_enum = VulnerabilityType.SQL_INJECTION
if "xss" in raw_type:
type_enum = VulnerabilityType.XSS
if "rce" in raw_type or "command" in raw_type or "cmd" in raw_type:
type_enum = VulnerabilityType.COMMAND_INJECTION
if "traversal" in raw_type or "lfi" in raw_type or "rfi" in raw_type:
type_enum = VulnerabilityType.PATH_TRAVERSAL
if "ssrf" in raw_type:
type_enum = VulnerabilityType.SSRF
if "xxe" in raw_type:
type_enum = VulnerabilityType.XXE
if "auth" in raw_type:
type_enum = VulnerabilityType.AUTH_BYPASS
if "secret" in raw_type or "credential" in raw_type or "password" in raw_type:
type_enum = VulnerabilityType.HARDCODED_SECRET
if "deserial" in raw_type:
type_enum = VulnerabilityType.DESERIALIZATION
# 🔥 Handle file path (support multiple field names)
file_path = (
finding.get("file_path") or
finding.get("file") or
finding.get("location", "").split(":")[0] if ":" in finding.get("location", "") else finding.get("location")
)
# 🔥 Handle line numbers (support multiple formats)
line_start = finding.get("line_start") or finding.get("line")
if not line_start and ":" in finding.get("location", ""):
try:
line_start = int(finding.get("location", "").split(":")[1])
except (ValueError, IndexError):
line_start = None
line_end = finding.get("line_end") or line_start
# 🔥 Handle code snippet (support multiple field names)
code_snippet = (
finding.get("code_snippet") or
finding.get("code") or
finding.get("vulnerable_code")
)
# 🔥 Handle title (generate from type if not provided)
title = finding.get("title")
if not title:
# Generate title from vulnerability type and file
type_display = raw_type.replace("_", " ").title()
if file_path:
title = f"{type_display} in {os.path.basename(file_path)}"
else:
title = f"{type_display} Vulnerability"
# 🔥 Handle description (support multiple field names)
description = (
finding.get("description") or
finding.get("details") or
finding.get("explanation") or
finding.get("impact") or
""
)
# 🔥 Handle suggestion/recommendation
suggestion = (
finding.get("suggestion") or
finding.get("recommendation") or
finding.get("remediation") or
finding.get("fix")
)
# 🔥 Handle confidence (map to ai_confidence field in model)
confidence = finding.get("confidence") or finding.get("ai_confidence") or 0.5
if isinstance(confidence, str):
try:
confidence = float(confidence)
except ValueError:
confidence = 0.5
# 🔥 Handle verification status
is_verified = finding.get("is_verified", False)
if finding.get("verdict") == "confirmed":
is_verified = True
# 🔥 Handle PoC information
poc_data = finding.get("poc", {})
has_poc = bool(poc_data)
poc_code = None
poc_description = None
poc_steps = None
if isinstance(poc_data, dict):
poc_description = poc_data.get("description")
poc_steps = poc_data.get("steps")
poc_code = poc_data.get("payload") or poc_data.get("code")
elif isinstance(poc_data, str):
poc_description = poc_data
# 🔥 Handle verification details
verification_method = finding.get("verification_method")
verification_result = None
if finding.get("verification_details"):
verification_result = {"details": finding.get("verification_details")}
# 🔥 Handle CWE and CVSS
cwe_id = finding.get("cwe_id") or finding.get("cwe")
cvss_score = finding.get("cvss_score") or finding.get("cvss")
if isinstance(cvss_score, str):
try:
cvss_score = float(cvss_score)
except ValueError:
cvss_score = None
db_finding = AgentFinding(
id=str(uuid4()),
task_id=task_id,
vulnerability_type=type_enum,
severity=severity_enum,
title=title[:500] if title else "Unknown Vulnerability",
description=description[:5000] if description else "",
file_path=file_path[:500] if file_path else None,
line_start=line_start,
line_end=line_end,
code_snippet=code_snippet[:10000] if code_snippet else None,
suggestion=suggestion[:5000] if suggestion else None,
is_verified=is_verified,
ai_confidence=confidence, # 🔥 FIX: Use ai_confidence, not confidence
status=FindingStatus.VERIFIED if is_verified else FindingStatus.NEW,
# 🔥 Additional fields
has_poc=has_poc,
poc_code=poc_code,
poc_description=poc_description,
poc_steps=poc_steps,
verification_method=verification_method,
verification_result=verification_result,
cvss_score=cvss_score,
# References for CWE
references=[{"cwe": cwe_id}] if cwe_id else None,
)
db.add(db_finding)
saved_count += 1
logger.debug(f"[SaveFindings] Prepared finding: {title[:50]}... ({severity_enum})")
except Exception as e:
logger.warning(f"Failed to save finding: {e}, data: {finding}")
import traceback
logger.debug(f"[SaveFindings] Traceback: {traceback.format_exc()}")
logger.info(f"Successfully prepared {saved_count} findings for commit")
try:
await db.commit()
logger.info(f"[SaveFindings] Successfully committed {saved_count} findings to database")
except Exception as e:
logger.error(f"Failed to commit findings: {e}")
await db.rollback()
def _calculate_security_score(findings: List[Dict]) -> float:
@ -822,6 +986,92 @@ def _calculate_security_score(findings: List[Dict]) -> float:
return float(score)
async def _save_agent_tree(db: AsyncSession, task_id: str) -> None:
"""
Save the Agent tree to the database
🔥 Called before the task finishes; persists the in-memory Agent tree to the database
"""
from app.models.agent_task import AgentTreeNode
from app.services.agent.core import agent_registry
try:
tree = agent_registry.get_agent_tree()
nodes = tree.get("nodes", {})
if not nodes:
logger.warning(f"[SaveAgentTree] No agent nodes to save for task {task_id}")
return
logger.info(f"[SaveAgentTree] Saving {len(nodes)} agent nodes for task {task_id}")
# Compute the depth of each node
def get_depth(agent_id: str, visited: set = None) -> int:
if visited is None:
visited = set()
if agent_id in visited:
return 0
visited.add(agent_id)
node = nodes.get(agent_id)
if not node:
return 0
parent_id = node.get("parent_id")
if not parent_id:
return 0
return 1 + get_depth(parent_id, visited)
saved_count = 0
for agent_id, node_data in nodes.items():
# Fetch statistics from the Agent instance
agent_instance = agent_registry.get_agent(agent_id)
iterations = 0
tool_calls = 0
tokens_used = 0
if agent_instance and hasattr(agent_instance, 'get_stats'):
stats = agent_instance.get_stats()
iterations = stats.get("iterations", 0)
tool_calls = stats.get("tool_calls", 0)
tokens_used = stats.get("tokens_used", 0)
# Get the findings count from the result
findings_count = 0
result_summary = None
if node_data.get("result"):
result = node_data.get("result", {})
if isinstance(result, dict):
findings_count = len(result.get("findings", []))
if result.get("summary"):
result_summary = str(result.get("summary"))[:2000]
tree_node = AgentTreeNode(
id=str(uuid4()),
task_id=task_id,
agent_id=agent_id,
agent_name=node_data.get("name", "Unknown"),
agent_type=node_data.get("type", "unknown"),
parent_agent_id=node_data.get("parent_id"),
depth=get_depth(agent_id),
task_description=node_data.get("task"),
knowledge_modules=node_data.get("knowledge_modules"),
status=node_data.get("status", "unknown"),
result_summary=result_summary,
findings_count=findings_count,
iterations=iterations,
tool_calls=tool_calls,
tokens_used=tokens_used,
)
db.add(tree_node)
saved_count += 1
await db.commit()
logger.info(f"[SaveAgentTree] Successfully saved {saved_count} agent nodes to database")
except Exception as e:
logger.error(f"[SaveAgentTree] Failed to save agent tree: {e}", exc_info=True)
await db.rollback()
# ============ API Endpoints ============
@router.post("/", response_model=AgentTaskResponse)
@ -1992,6 +2242,9 @@ async def generate_audit_report(
"code_snippet": f.code_snippet,
"is_verified": f.is_verified,
"has_poc": f.has_poc,
"poc_code": f.poc_code,
"poc_description": f.poc_description,
"poc_steps": f.poc_steps,
"confidence": f.ai_confidence,
"suggestion": f.suggestion,
"fix_code": f.fix_code,
@ -2169,6 +2422,30 @@ async def generate_audit_report(
md_lines.append("```")
md_lines.append("")
# 🔥 Append PoC details
if f.has_poc:
md_lines.append("**Proof of Concept (PoC):**")
md_lines.append("")
if f.poc_description:
md_lines.append(f"*{f.poc_description}*")
md_lines.append("")
if f.poc_steps:
md_lines.append("**Reproduction Steps:**")
md_lines.append("")
for step_idx, step in enumerate(f.poc_steps, 1):
md_lines.append(f"{step_idx}. {step}")
md_lines.append("")
if f.poc_code:
md_lines.append("**PoC Payload:**")
md_lines.append("")
md_lines.append("```")
md_lines.append(f.poc_code.strip())
md_lines.append("```")
md_lines.append("")
md_lines.append("---")
md_lines.append("")

View File

@ -92,7 +92,7 @@ class Settings(BaseSettings):
AGENT_TIMEOUT_SECONDS: int = 1800  # Agent timeout (30 minutes)
# Sandbox configuration
SANDBOX_IMAGE: str = "python:3.11-slim"  # Sandbox Docker image
SANDBOX_MEMORY_LIMIT: str = "512m"  # Sandbox memory limit
SANDBOX_CPU_LIMIT: float = 1.0  # Sandbox CPU limit
SANDBOX_TIMEOUT: int = 60  # Sandbox command timeout (seconds)

View File

@ -422,6 +422,7 @@ Final Answer: {{"findings": [...], "summary": "..."}}"""
# Check whether the analysis is finished
if step.is_final:
await self.emit_llm_decision("完成安全分析", "LLM 判断分析已充分")
logger.info(f"[{self.name}] Received Final Answer: {step.final_answer}")
if step.final_answer and "findings" in step.final_answer:
all_findings = step.final_answer["findings"]
logger.info(f"[{self.name}] Final Answer contains {len(all_findings)} findings")
@ -438,7 +439,7 @@ Final Answer: {{"findings": [...], "summary": "..."}}"""
f"发现 {finding.get('severity', 'medium')} 级别漏洞: {finding.get('title', 'Unknown')}"
)
else:
logger.warning(f"[{self.name}] Final Answer has no 'findings' key or is None: {step.final_answer}")
# 🔥 Record work completion
self.record_work(f"完成安全分析,发现 {len(all_findings)} 个潜在漏洞")

View File

@ -742,8 +742,22 @@ class BaseAgent(ABC):
# ============ Finding-related events ============
async def emit_finding(self, title: str, severity: str, vuln_type: str, file_path: str = "", is_verified: bool = False):
"""发射漏洞发现事件"""
import uuid
finding_id = str(uuid.uuid4())
# 🔥 Use EventManager.emit_finding to send the correct event type
if self.event_emitter and hasattr(self.event_emitter, 'emit_finding'):
await self.event_emitter.emit_finding(
finding_id=finding_id,
title=title,
severity=severity,
vulnerability_type=vuln_type,
is_verified=is_verified,
)
else:
# Fallback: use generic event emission
severity_emoji = {
"critical": "🔴",
"high": "🟠",
@ -751,14 +765,17 @@ class BaseAgent(ABC):
"low": "🟢",
}.get(severity.lower(), "")
event_type = "finding_verified" if is_verified else "finding_new"
await self.emit_event(
"finding",
event_type,
f"{severity_emoji} [{self.name}] 发现漏洞: [{severity.upper()}] {title}\n 类型: {vuln_type}\n 位置: {file_path}",
metadata={
"id": finding_id,
"title": title,
"severity": severity,
"vulnerability_type": vuln_type,
"file_path": file_path,
"is_verified": is_verified,
}
)

View File

@ -432,6 +432,8 @@ Action Input: {{"参数": "值"}}
# 🔥 CRITICAL: Log final findings count before returning
logger.info(f"[Orchestrator] Final result: {len(self._all_findings)} findings collected")
if len(self._all_findings) == 0:
logger.warning(f"[Orchestrator] ⚠️ No findings collected! Dispatched agents: {list(self._dispatched_tasks.keys())}, Iterations: {self._iteration}")
for i, f in enumerate(self._all_findings[:5]): # Log first 5 for debugging
logger.debug(f"[Orchestrator] Finding {i+1}: {f.get('title', 'N/A')} - {f.get('vulnerability_type', 'N/A')}")
@ -654,18 +656,27 @@ Action Input: {{"参数": "值"}}
# The findings field usually comes from the Analysis/Verification Agent
# initial_findings comes from the Recon Agent
raw_findings = data.get("findings", [])
logger.info(f"[Orchestrator] {agent_name} returned data with {len(raw_findings)} findings in 'findings' field")
# 🔥 ENHANCED: Also check for initial_findings (from Recon), improved logic:
# check initial_findings even when findings is an empty list
if "initial_findings" in data:
initial = data.get("initial_findings", [])
logger.info(f"[Orchestrator] {agent_name} has {len(initial)} initial_findings")
for f in initial:
if isinstance(f, dict):
# 🔥 Normalize the finding format, handling the shape returned by Recon
normalized = self._normalize_finding(f)
if normalized not in raw_findings:
raw_findings.append(normalized)
elif isinstance(f, str):
logger.debug(f"[Orchestrator] Skipping string finding: {f[:50]}...")
# 🔥 Also check high_risk_areas from Recon for potential findings
if agent_name == "recon" and "high_risk_areas" in data:
high_risk = data.get("high_risk_areas", [])
logger.info(f"[Orchestrator] {agent_name} identified {len(high_risk)} high risk areas")
if raw_findings:
# Only add dict-format findings
@ -673,33 +684,58 @@ Action Input: {{"参数": "值"}}
logger.info(f"[Orchestrator] {agent_name} returned {len(valid_findings)} valid findings")
# 🔥 ENHANCED: Merge findings with better deduplication
for new_f in valid_findings:
# Normalize the finding first
normalized_new = self._normalize_finding(new_f)
# Create fingerprint for deduplication (file + description similarity)
new_file = normalized_new.get("file_path", "").lower().strip()
new_desc = (normalized_new.get("description", "") or "").lower()[:100]
new_type = (normalized_new.get("vulnerability_type", "") or "").lower()
new_line = normalized_new.get("line_start") or normalized_new.get("line", 0)
# Check if exists (more flexible matching)
found = False
for i, existing_f in enumerate(self._all_findings):
existing_file = (existing_f.get("file_path", "") or existing_f.get("file", "")).lower().strip()
existing_desc = (existing_f.get("description", "") or "").lower()[:100]
existing_type = (existing_f.get("vulnerability_type", "") or existing_f.get("type", "")).lower()
existing_line = existing_f.get("line_start") or existing_f.get("line", 0)
# Match if same file AND (same line OR similar description OR same vulnerability type)
same_file = new_file and existing_file and (
new_file == existing_file or
new_file.endswith(existing_file) or
existing_file.endswith(new_file)
)
same_line = new_line and existing_line and new_line == existing_line
similar_desc = new_desc and existing_desc and (
new_desc in existing_desc or existing_desc in new_desc
)
same_type = new_type and existing_type and (
new_type == existing_type or
(new_type in existing_type) or (existing_type in new_type)
)
if same_file and (same_line or similar_desc or same_type):
# Update existing with new info (e.g. verification results)
# Prefer verified data over unverified
merged = {**existing_f, **normalized_new}
# Keep the better title
if normalized_new.get("title") and len(normalized_new.get("title", "")) > len(existing_f.get("title", "")):
merged["title"] = normalized_new["title"]
# Keep verified status if either is verified
if existing_f.get("is_verified") or normalized_new.get("is_verified"):
merged["is_verified"] = True
self._all_findings[i] = merged
found = True
logger.debug(f"[Orchestrator] Updated existing finding: {new_key}")
logger.info(f"[Orchestrator] Merged finding: {new_file}:{new_line} ({new_type})")
break
if not found:
self._all_findings.append(normalized_new)
logger.info(f"[Orchestrator] Added new finding: {new_file}:{new_line} ({new_type})")
logger.info(f"[Orchestrator] Total findings now: {len(self._all_findings)}")
else:
@ -752,13 +788,13 @@ Action Input: {{"参数": "值"}}
observation = f"""## {agent_name} Agent 执行结果
**状态**: 成功
**发现数量**: {len(valid_findings)}
**迭代次数**: {result.iterations}
**耗时**: {result.duration_ms}ms
### 发现摘要
"""
for i, f in enumerate(valid_findings[:10]):
if not isinstance(f, dict):
continue
observation += f"""
@ -768,8 +804,8 @@ Action Input: {{"参数": "值"}}
- 描述: {f.get('description', '')[:200]}...
"""
if len(valid_findings) > 10:
observation += f"\n... 还有 {len(valid_findings) - 10} 个发现"
if data.get("summary"):
observation += f"\n\n### Agent 总结\n{data['summary']}"
@ -782,6 +818,94 @@ Action Input: {{"参数": "值"}}
logger.error(f"Sub-agent dispatch failed: {e}", exc_info=True)
return f"## 调度失败\n\n错误: {str(e)}"
def _normalize_finding(self, finding: Dict[str, Any]) -> Dict[str, Any]:
"""
标准化发现格式
不同 Agent 可能返回不同格式的发现这个方法将它们标准化为统一格式
"""
normalized = dict(finding) # 复制原始数据
# 🔥 Map the location field -> file_path + line_start
if "location" in normalized and "file_path" not in normalized:
location = normalized["location"]
if isinstance(location, str) and ":" in location:
parts = location.split(":")
normalized["file_path"] = parts[0]
try:
normalized["line_start"] = int(parts[1])
except (ValueError, IndexError):
pass
elif isinstance(location, str):
normalized["file_path"] = location
# 🔥 Map the file field -> file_path
if "file" in normalized and "file_path" not in normalized:
normalized["file_path"] = normalized["file"]
# 🔥 Map the line field -> line_start
if "line" in normalized and "line_start" not in normalized:
normalized["line_start"] = normalized["line"]
# 🔥 Map the type field -> vulnerability_type
if "type" in normalized and "vulnerability_type" not in normalized:
# Not every type value is a vulnerability type; e.g. "Vulnerability" is just a label
type_val = normalized["type"]
if type_val and type_val.lower() not in ["vulnerability", "finding", "issue"]:
normalized["vulnerability_type"] = type_val
elif "description" in normalized:
# 尝试从描述中推断漏洞类型
desc = normalized["description"].lower()
if "command injection" in desc or "rce" in desc or "system(" in desc:
normalized["vulnerability_type"] = "command_injection"
elif "sql injection" in desc or "sqli" in desc:
normalized["vulnerability_type"] = "sql_injection"
elif "xss" in desc or "cross-site scripting" in desc:
normalized["vulnerability_type"] = "xss"
elif "path traversal" in desc or "directory traversal" in desc:
normalized["vulnerability_type"] = "path_traversal"
elif "ssrf" in desc:
normalized["vulnerability_type"] = "ssrf"
elif "xxe" in desc:
normalized["vulnerability_type"] = "xxe"
else:
normalized["vulnerability_type"] = "other"
# 🔥 Map the risk field -> severity first, so the default below does not shadow it
if "risk" in normalized and "severity" not in normalized:
normalized["severity"] = str(normalized["risk"]).lower()
# 🔥 Ensure severity exists and is lowercase
if "severity" in normalized:
normalized["severity"] = str(normalized["severity"]).lower()
else:
normalized["severity"] = "medium"
# 🔥 Generate a title if one is missing
if "title" not in normalized:
vuln_type = normalized.get("vulnerability_type", "Unknown")
file_path = normalized.get("file_path", "")
if file_path:
import os
normalized["title"] = f"{vuln_type.replace('_', ' ').title()} in {os.path.basename(file_path)}"
else:
normalized["title"] = f"{vuln_type.replace('_', ' ').title()} Vulnerability"
# 🔥 Map the code field -> code_snippet
if "code" in normalized and "code_snippet" not in normalized:
normalized["code_snippet"] = normalized["code"]
# 🔥 Map recommendation -> suggestion
if "recommendation" in normalized and "suggestion" not in normalized:
normalized["suggestion"] = normalized["recommendation"]
# 🔥 Append impact to the description
if "impact" in normalized and normalized.get("description"):
if "impact" not in normalized["description"].lower():
normalized["description"] += f"\n\nImpact: {normalized['impact']}"
return normalized
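
For illustration, here is a minimal standalone sketch of this normalization (the helper name and sample values are hypothetical, not part of the commit; it covers only the location and severity rules shown above):

def normalize_finding_sketch(finding: dict) -> dict:
    normalized = dict(finding)
    # "app/db.py:42"-style locations become file_path + line_start
    location = normalized.get("location", "")
    if "file_path" not in normalized and isinstance(location, str) and location:
        file_path, _, line = location.partition(":")
        normalized["file_path"] = file_path
        if line.isdigit():
            normalized["line_start"] = int(line)
    # Ensure a lowercase severity with a medium default
    normalized.setdefault("severity", "medium")
    normalized["severity"] = str(normalized["severity"]).lower()
    return normalized

# A Recon-style finding is folded into the unified shape:
print(normalize_finding_sketch({"location": "app/db.py:42", "type": "SQL Injection"}))
# -> {'location': 'app/db.py:42', 'type': 'SQL Injection',
#     'file_path': 'app/db.py', 'line_start': 42, 'severity': 'medium'}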
def _summarize_findings(self) -> str:
"""汇总当前发现"""
if not self._all_findings:

View File

@ -120,6 +120,10 @@ Final Answer: [JSON 格式的验证报告]
2. **深入理解** - 理解代码逻辑不要表面判断
3. **证据支撑** - 判定要有依据
4. **安全第一** - 沙箱测试要谨慎
5. **🔥 PoC 生成** - 对于 confirmed likely 的漏洞**必须**生成 PoC:
- poc.description: 简要描述这个 PoC 的作用
- poc.steps: 详细的复现步骤列表
- poc.payload: 实际的攻击载荷或测试代码
现在开始验证漏洞发现"""
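
For reference, a single finding satisfying requirement 5 might look like the sketch below (a hypothetical example, with illustrative values; the poc sub-fields match what _save_findings reads: description, steps, and payload or code):

finding_example = {
    "title": "SQL Injection in user lookup",
    "severity": "high",
    "vulnerability_type": "sql_injection",
    "file_path": "app/api/users.py",   # hypothetical path
    "line_start": 42,
    "verdict": "confirmed",
    "is_verified": True,
    "poc": {
        "description": "Bypasses the WHERE clause via the unsanitized id parameter",
        "steps": [
            "Start the app and log in as any user",
            "Request /api/users?id=1 OR 1=1",
            "Observe that all user rows are returned",
        ],
        "payload": "1 OR 1=1",
    },
}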

View File

@ -171,6 +171,7 @@ class AgentEventEmitter:
finding_id=finding_id,
message=f"{'✅ 已验证' if is_verified else '🔍 新发现'}: [{severity.upper()}] {title}",
metadata={
"id": finding_id, # 🔥 添加 id 字段供前端使用
"title": title,
"severity": severity,
"vulnerability_type": vulnerability_type,
@ -331,6 +332,31 @@ class EventManager:
"""保存事件到数据库"""
from app.models.agent_task import AgentEvent
# 🔥 Strip invalid UTF-8 characters (e.g. binary content)
def sanitize_string(s):
"""Strip invalid UTF-8 characters from a string"""
if s is None:
return None
if not isinstance(s, str):
s = str(s)
# Remove NULL bytes and other non-printable control characters (keep newlines and tabs)
return ''.join(
char for char in s
if char in '\n\r\t' or (ord(char) >= 32 and ord(char) != 127)
)
def sanitize_dict(d):
"""递归清理字典中的字符串值"""
if d is None:
return None
if isinstance(d, dict):
return {k: sanitize_dict(v) for k, v in d.items()}
elif isinstance(d, list):
return [sanitize_dict(item) for item in d]
elif isinstance(d, str):
return sanitize_string(d)
return d
async with self.db_session_factory() as db:
event = AgentEvent(
id=event_data["id"],
@ -338,14 +364,14 @@ class EventManager:
event_type=event_data["event_type"],
sequence=event_data["sequence"],
phase=event_data["phase"],
message=sanitize_string(event_data["message"]),  # 🔥 Sanitize the message
tool_name=event_data["tool_name"],
tool_input=sanitize_dict(event_data["tool_input"]),  # 🔥 Sanitize tool input
tool_output=sanitize_dict(event_data["tool_output"]),  # 🔥 Sanitize tool output
tool_duration_ms=event_data["tool_duration_ms"],
finding_id=event_data["finding_id"],
tokens_used=event_data["tokens_used"],
event_metadata=sanitize_dict(event_data["metadata"]),  # 🔥 Sanitize metadata
)
db.add(event)
await db.commit()
@ -406,7 +432,10 @@ class EventManager:
🔥 Important: this method first drains events already buffered in the queue
(produced before the SSE connection), then keeps streaming new events in real time.
Only events with sequence > after_sequence are returned.
"""
logger.info(f"[StreamEvents] Task {task_id}: Starting stream with after_sequence={after_sequence}")
# Get the existing queue (created by AgentRunner during initialization)
queue = self._event_queues.get(task_id)
@ -417,9 +446,17 @@ class EventManager:
# 🔥 First drain events already buffered in the queue (produced before the SSE connection)
buffered_count = 0
skipped_count = 0
while not queue.empty():
try:
buffered_event = queue.get_nowait()
# 🔥 Filter out events with sequence <= after_sequence
event_sequence = buffered_event.get("sequence", 0)
if event_sequence <= after_sequence:
skipped_count += 1
continue
buffered_count += 1
yield buffered_event
@ -432,19 +469,25 @@ class EventManager:
# Check for a terminal event
if event_type in ["task_complete", "task_error", "task_cancel"]:
logger.debug(f"Task {task_id} already completed, sent {buffered_count} buffered events")
logger.debug(f"Task {task_id} already completed, sent {buffered_count} buffered events (skipped {skipped_count})")
return
except asyncio.QueueEmpty:
break
if buffered_count > 0 or skipped_count > 0:
logger.debug(f"Drained queue for task {task_id}: sent {buffered_count}, skipped {skipped_count} (after_sequence={after_sequence})")
# Then stream new events in real time
try:
while True:
try:
event = await asyncio.wait_for(queue.get(), timeout=30)
# 🔥 Filter out events with sequence <= after_sequence
event_sequence = event.get("sequence", 0)
if event_sequence <= after_sequence:
continue
yield event
# 🔥 Add a micro-delay for thinking_token to keep the streaming effect
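
As a sketch of the resume semantics, a client can remember the highest sequence it has processed and pass it back on reconnect (assuming EventManager.stream_events(task_id, after_sequence=...) as the docstring above suggests; handle() is a hypothetical application callback):

def handle(event: dict) -> None:
    # Hypothetical application callback
    print(event.get("event_type"), event.get("message"))

async def consume(event_manager, task_id: str, last_sequence: int = 0) -> int:
    # Events with sequence <= last_sequence are filtered server-side on (re)connect
    async for event in event_manager.stream_events(task_id, after_sequence=last_sequence):
        last_sequence = max(last_sequence, event.get("sequence", 0))
        handle(event)
    return last_sequence  # persist this for the next reconnect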

View File

@ -274,15 +274,20 @@ class AgentRunner:
# Sandbox tools (available to the Verification Agent only)
try:
from app.services.agent.tools.sandbox_tool import SandboxConfig
sandbox_config = SandboxConfig(
image=settings.SANDBOX_IMAGE,
memory_limit=settings.SANDBOX_MEMORY_LIMIT,
cpu_limit=settings.SANDBOX_CPU_LIMIT,
timeout=settings.SANDBOX_TIMEOUT,
network_mode=settings.SANDBOX_NETWORK_MODE,
)
self.sandbox_manager = SandboxManager(config=sandbox_config)
self.verification_tools["sandbox_exec"] = SandboxTool(self.sandbox_manager)
self.verification_tools["sandbox_http"] = SandboxHttpTool(self.sandbox_manager)
self.verification_tools["verify_vulnerability"] = VulnerabilityVerifyTool(self.sandbox_manager)
logger.info("Sandbox tools initialized successfully")
except Exception as e:
logger.warning(f"Sandbox initialization failed: {e}")

View File

@ -256,11 +256,14 @@ class LiteLLMAdapter(BaseLLMAdapter):
accumulated_content = ""
final_usage = None  # 🔥 Holds the final usage info
chunk_count = 0  # 🔥 Tracks the chunk count
try:
response = await litellm.acompletion(**kwargs)
async for chunk in response:
chunk_count += 1
# 🔥 Check for usage info (some APIs include it in the final chunk)
if hasattr(chunk, "usage") and chunk.usage:
final_usage = {
@ -271,6 +274,7 @@ class LiteLLMAdapter(BaseLLMAdapter):
logger.debug(f"Got usage from chunk: {final_usage}")
if not chunk.choices:
# 🔥 Some models may send chunks without choices (e.g. heartbeats)
continue
delta = chunk.choices[0].delta
@ -284,9 +288,8 @@ class LiteLLMAdapter(BaseLLMAdapter):
"content": content,
"accumulated": accumulated_content,
}
# 🔥 ENHANCED: handle chunks that carry no content and no finish_reason
# Some models (e.g. Zhipu GLM) may return no content in certain chunks
if finish_reason:
# Stream finished
@ -300,6 +303,10 @@ class LiteLLMAdapter(BaseLLMAdapter):
}
logger.debug(f"Estimated usage: {final_usage}")
# 🔥 ENHANCED: warn if the accumulated content is empty despite a finish_reason
if not accumulated_content:
logger.warning(f"Stream completed with no content after {chunk_count} chunks, finish_reason={finish_reason}")
yield {
"type": "done",
"content": accumulated_content,
@ -308,8 +315,26 @@ class LiteLLMAdapter(BaseLLMAdapter):
}
break
# 🔥 ENHANCED: if the loop ends without a finish_reason, still yield a done event
if accumulated_content:
logger.warning(f"Stream ended without finish_reason, returning accumulated content ({len(accumulated_content)} chars)")
if not final_usage:
output_tokens_estimate = estimate_tokens(accumulated_content)
final_usage = {
"prompt_tokens": input_tokens_estimate,
"completion_tokens": output_tokens_estimate,
"total_tokens": input_tokens_estimate + output_tokens_estimate,
}
yield {
"type": "done",
"content": accumulated_content,
"usage": final_usage,
"finish_reason": "complete",
}
except Exception as e:
# 🔥 Even on error, try to return estimated usage
logger.error(f"Stream error: {e}")
output_tokens_estimate = estimate_tokens(accumulated_content) if accumulated_content else 0
yield {
"type": "error",

View File

@ -42,6 +42,16 @@ function agentAuditReducer(state: AgentAuditState, action: AgentAuditAction): Ag
case 'SET_FINDINGS':
return { ...state, findings: action.payload };
case 'ADD_FINDING': {
// 🔥 Add a single finding, avoiding duplicates
const newFinding = action.payload;
const existingIds = new Set(state.findings.map(f => f.id));
if (newFinding.id && existingIds.has(newFinding.id)) {
return state; // Already exists, do not add
}
return { ...state, findings: [...state.findings, newFinding] };
}
case 'SET_AGENT_TREE':
return { ...state, agentTree: action.payload };

View File

@ -63,7 +63,11 @@ function AgentAuditPageContent() {
const previousTaskIdRef = useRef<string | undefined>(undefined);
const disconnectStreamRef = useRef<(() => void) | null>(null);
const lastEventSequenceRef = useRef<number>(0);
const hasConnectedRef = useRef<boolean>(false); // 🔥 Tracks whether the SSE stream is connected
const hasLoadedHistoricalEventsRef = useRef<boolean>(false); // 🔥 Tracks whether historical events were loaded
// 🔥 Use state for the historical-events loading flag and to trigger streamOptions recomputation
const [afterSequence, setAfterSequence] = useState<number>(0);
const [historicalEventsLoaded, setHistoricalEventsLoaded] = useState<boolean>(false);
// 🔥 Reset state immediately when taskId changes (clear old logs when a new task is created)
useEffect(() => {
@ -79,7 +83,10 @@ function AgentAuditPageContent() {
setShowSplash(!taskId);
// 3. Reset the event sequence number and loading flags
lastEventSequenceRef.current = 0;
hasConnectedRef.current = false; // 🔥 Reset the SSE connection flag
hasLoadedHistoricalEventsRef.current = false; // 🔥 Reset the historical-events-loaded flag
setHistoricalEventsLoaded(false); // 🔥 Reset the historical-events-loaded state
setAfterSequence(0); // 🔥 Reset the afterSequence state
}
previousTaskIdRef.current = taskId;
}, [taskId, reset]);
@ -141,6 +148,14 @@ function AgentAuditPageContent() {
// 🔥 NEW: load historical events and convert them into log items
const loadHistoricalEvents = useCallback(async () => {
if (!taskId) return 0;
// 🔥 Prevent loading historical events twice
if (hasLoadedHistoricalEventsRef.current) {
console.log('[AgentAudit] Historical events already loaded, skipping');
return 0;
}
hasLoadedHistoricalEventsRef.current = true;
try {
console.log(`[AgentAudit] Fetching historical events for task ${taskId}...`);
const events = await getAgentEvents(taskId, { limit: 500 });
@ -356,20 +371,22 @@ function AgentAuditPageContent() {
});
console.log(`[AgentAudit] Processed ${processedCount} events into logs, last sequence: ${lastEventSequenceRef.current}`);
// 🔥 Update the afterSequence state to trigger streamOptions recomputation
setAfterSequence(lastEventSequenceRef.current);
return events.length;
} catch (err) {
console.error('[AgentAudit] Failed to load historical events:', err);
return 0;
}
}, [taskId, dispatch]);
}, [taskId, dispatch, setAfterSequence]);
// ============ Stream Event Handling ============
const streamOptions = useMemo(() => ({
includeThinking: true,
includeToolCalls: true,
// 🔥 Read from state so the latest value is picked up once historical events have loaded
afterSequence: afterSequence,
onEvent: (event: { type: string; message?: string; metadata?: { agent_name?: string; agent?: string } }) => {
if (event.metadata?.agent_name) {
setCurrentAgentName(event.metadata.agent_name);
@ -478,7 +495,20 @@ function AgentAuditPageContent() {
agentName: getCurrentAgentName() || undefined,
}
});
// 🔥 Push the finding directly into state instead of reloading from the API
// (while the task is running, the database has no rows yet)
dispatch({
type: 'ADD_FINDING',
payload: {
id: (finding.id as string) || `finding-${Date.now()}`,
title: (finding.title as string) || 'Vulnerability found',
severity: (finding.severity as string) || 'medium',
vulnerability_type: (finding.vulnerability_type as string) || 'unknown',
file_path: finding.file_path as string,
line_start: finding.line_start as number,
description: finding.description as string,
is_verified: (finding.is_verified as boolean) || false,
}
});
},
onComplete: () => {
dispatch({ type: 'ADD_LOG', payload: { type: 'info', title: 'Audit completed successfully' } });
@ -489,7 +519,7 @@ function AgentAuditPageContent() {
onError: (err: string) => {
dispatch({ type: 'ADD_LOG', payload: { type: 'error', title: `Error: ${err}` } });
},
}), [afterSequence, dispatch, loadTask, loadFindings, loadAgentTree, debouncedLoadAgentTree,
updateLog, removeLog, getCurrentAgentName, getCurrentThinkingId,
setCurrentAgentName, setCurrentThinkingId]);
@ -523,7 +553,7 @@ function AgentAuditPageContent() {
}
setShowSplash(false);
setLoading(true);
setHistoricalEventsLoaded(false);
const loadAllData = async () => {
try {
@ -534,11 +564,11 @@ function AgentAuditPageContent() {
const eventsLoaded = await loadHistoricalEvents();
console.log(`[AgentAudit] Loaded ${eventsLoaded} historical events for task ${taskId}`);
// Mark historical events as loaded (setAfterSequence is already called inside loadHistoricalEvents)
setHistoricalEventsLoaded(true);
} catch (error) {
console.error('[AgentAudit] Failed to load data:', error);
setHistoricalEventsLoaded(true); // Mark as done even on error, to avoid waiting forever
} finally {
setLoading(false);
}
@ -552,20 +582,21 @@ function AgentAuditPageContent() {
// Wait until historical events are loaded, and only while the task is running
if (!taskId || !task?.status || task.status !== 'running') return;
// 🔥 Gate on state so we connect only after historical events have finished loading
if (!historicalEventsLoaded) return;
// 🔥 Avoid duplicate connections - connect only once
if (hasConnectedRef.current) return;
hasConnectedRef.current = true;
console.log(`[AgentAudit] Connecting to stream with afterSequence=${afterSequence}`);
connectStream();
dispatch({ type: 'ADD_LOG', payload: { type: 'info', title: 'Connected to audit stream' } });
return () => {
disconnectStream();
};
}, [taskId, task?.status, historicalEventsLoaded, connectStream, disconnectStream, dispatch, afterSequence]);
// Polling
useEffect(() => {

View File

@ -71,6 +71,7 @@ export interface AgentTreeResponse {
export type AgentAuditAction =
| { type: 'SET_TASK'; payload: AgentTask }
| { type: 'SET_FINDINGS'; payload: AgentFinding[] }
| { type: 'ADD_FINDING'; payload: Partial<AgentFinding> & { id: string } }
| { type: 'SET_AGENT_TREE'; payload: AgentTreeResponse }
| { type: 'SET_LOGS'; payload: LogItem[] }
| { type: 'ADD_LOG'; payload: Omit<LogItem, 'id' | 'time'> & { id?: string } }

View File

@ -396,6 +396,7 @@ export class AgentStreamHandler {
break;
// Findings
case 'finding': // 🔥 Backward compatibility with the legacy event type
case 'finding_new':
case 'finding_verified':
this.options.onFinding?.(