470 lines
17 KiB
Python
470 lines
17 KiB
Python
"""
Analysis Agent (漏洞分析层)
负责代码审计、RAG 查询、模式匹配、数据流分析

类型: ReAct
"""

import asyncio
import logging
import time
from typing import List, Dict, Any, Optional

from .base import BaseAgent, AgentConfig, AgentResult, AgentType, AgentPattern

logger = logging.getLogger(__name__)


# System prompt handed to the analysis agent through AgentConfig.system_prompt
# (written in Chinese). It lists the tools this agent is configured with, the
# intended analysis strategy, the vulnerability classes to focus on, and the
# JSON template the model should use when reporting findings.
# NOTE(review): the JSON example is a template for the model (placeholders such
# as 行号 and true/false are not valid JSON); this module never parses it.
# NOTE(review): the example is shown without indentation; confirm against
# version control whether its lines were originally indented.
ANALYSIS_SYSTEM_PROMPT = """你是 DeepAudit 的漏洞分析 Agent,负责深度代码安全分析。

## 你的职责
1. 使用静态分析工具快速扫描
2. 使用 RAG 进行语义代码搜索
3. 追踪数据流(从用户输入到危险函数)
4. 分析业务逻辑漏洞
5. 评估漏洞严重程度

## 你可以使用的工具
### 外部扫描工具
- semgrep_scan: Semgrep 静态分析(推荐首先使用)
- bandit_scan: Python 安全扫描

### RAG 语义搜索
- rag_query: 语义代码搜索
- security_search: 安全相关代码搜索
- function_context: 函数上下文分析

### 深度分析
- pattern_match: 危险模式匹配
- code_analysis: LLM 深度代码分析
- dataflow_analysis: 数据流追踪
- vulnerability_validation: 漏洞验证

### 文件操作
- read_file: 读取文件
- search_code: 关键字搜索

## 分析策略
1. **快速扫描**: 先用 Semgrep 快速发现问题
2. **语义搜索**: 用 RAG 找到相关代码
3. **深度分析**: 对可疑代码进行 LLM 分析
4. **数据流追踪**: 追踪用户输入的流向

## 重点关注
- SQL 注入、NoSQL 注入
- XSS(反射型、存储型、DOM型)
- 命令注入、代码注入
- 路径遍历、任意文件访问
- SSRF、XXE
- 不安全的反序列化
- 认证/授权绕过
- 敏感信息泄露

## 输出格式
发现漏洞时,返回结构化信息:
```json
{
"findings": [
{
"vulnerability_type": "漏洞类型",
"severity": "critical/high/medium/low",
"title": "漏洞标题",
"description": "详细描述",
"file_path": "文件路径",
"line_start": 行号,
"code_snippet": "代码片段",
"source": "污点源",
"sink": "危险函数",
"suggestion": "修复建议",
"needs_verification": true/false
}
]
}
```

请系统性地分析代码,发现真实的安全漏洞。"""


class AnalysisAgent(BaseAgent):
    """
    Vulnerability analysis agent.

    Registered with ``AgentPattern.REACT`` and the tool whitelist built in
    ``__init__``. ``run`` itself does not loop over LLM decisions. It executes
    a fixed sequence of tool calls and returns the merged, de-duplicated
    findings. The sequence is: static scanners, entry-point review,
    dangerous-keyword search, then semantic search. Every finding is flagged
    ``needs_verification=True`` for a later verification stage.

    Tool calls are best-effort. A tool that raises is logged and treated like
    a tool that returned ``success=False``, so one broken scanner does not
    discard the findings produced by the others.
    """

    # Vulnerability classes searched semantically when the caller's config has
    # no ``target_vulnerabilities`` entry.
    DEFAULT_TARGET_VULNERABILITIES = (
        "sql_injection", "xss", "command_injection",
        "path_traversal", "ssrf", "hardcoded_secret",
    )
    # Cap on semantic searches per run; ``config["max_vuln_types"]`` overrides
    # it. With the default cap the sixth default class ("hardcoded_secret") is
    # not searched.
    MAX_VULN_TYPES = 5
    # Cap on Recon entry points sent to ``code_analysis``.
    MAX_ENTRY_POINTS = 10
    # (keyword, vulnerability_type) pairs searched with ``search_code``.
    DANGEROUS_PATTERNS = (
        ("execute(", "sql_injection"),
        ("eval(", "code_injection"),
        ("system(", "command_injection"),
        ("exec(", "command_injection"),
        ("innerHTML", "xss"),
        ("document.write", "xss"),
    )
    # Only the first N patterns are searched. With the default cap the last
    # one ("document.write") is skipped, as before.
    MAX_DANGEROUS_PATTERNS = 5
    # Per-search caps on how many tool results become findings.
    MAX_MATCHES_PER_PATTERN = 3
    MAX_SEMANTIC_RESULTS = 5

    # Semgrep ``extra.severity`` -> agent severity.
    SEMGREP_SEVERITY = {
        "ERROR": "high",
        "WARNING": "medium",
        "INFO": "low",
    }

    # Bandit test id -> agent vulnerability type.
    BANDIT_TEST_TYPES = {
        "B101": "assert_used",
        "B102": "exec_used",
        "B103": "insecure_file_permissions",  # set_bad_file_permissions
        "B104": "hardcoded_bind_all",
        "B105": "hardcoded_password",
        "B106": "hardcoded_password",
        "B107": "hardcoded_password",
        "B108": "hardcoded_tmp",
        "B301": "deserialization",  # pickle
        "B302": "deserialization",  # marshal
        "B303": "weak_crypto",  # insecure hash (md5/sha1)
        "B304": "weak_crypto",  # insecure cipher
        "B305": "weak_crypto",  # insecure cipher mode
        "B306": "insecure_temp_file",  # tempfile.mktemp
        "B307": "code_injection",  # eval
        "B308": "xss",  # mark_safe
        "B310": "ssrf",  # urllib urlopen
        "B311": "weak_random",
        "B312": "telnet",
        "B313": "xxe",  # xml.etree.cElementTree
        "B314": "xxe",  # xml.etree.ElementTree
        "B315": "xxe",  # xml.sax.expatreader
        "B316": "xxe",  # xml.dom.expatbuilder
        "B317": "xxe",  # xml.sax
        "B318": "xxe",  # xml.dom.minidom
        "B319": "xxe",  # xml.dom.pulldom
        "B320": "xxe",  # lxml.etree
        "B323": "ssl_verify",  # unverified SSL context
        "B324": "weak_crypto",  # hashlib with insecure hash
        "B501": "ssl_verify",
        "B502": "ssl_verify",
        "B503": "ssl_verify",
        "B504": "ssl_verify",
        "B505": "weak_crypto",
        "B506": "yaml_load",
        "B507": "ssh_key",
        "B601": "command_injection",
        "B602": "command_injection",
        "B603": "command_injection",
        "B604": "command_injection",
        "B605": "command_injection",
        "B606": "command_injection",
        "B607": "command_injection",
        "B608": "sql_injection",
        "B609": "command_injection",  # linux commands wildcard injection
        "B610": "sql_injection",
        "B611": "sql_injection",
        "B701": "xss",
        "B702": "xss",
        "B703": "xss",
    }

    def __init__(
        self,
        llm_service,
        tools: Dict[str, Any],
        event_emitter=None,
    ):
        """
        Args:
            llm_service: LLM client, passed through to ``BaseAgent``.
            tools: Mapping of tool name -> tool object. Tools used by ``run``
                are looked up here by name and skipped when absent.
            event_emitter: Optional event sink, passed through to ``BaseAgent``.
        """
        config = AgentConfig(
            name="Analysis",
            agent_type=AgentType.ANALYSIS,
            pattern=AgentPattern.REACT,
            max_iterations=30,
            system_prompt=ANALYSIS_SYSTEM_PROMPT,
            tools=[
                "semgrep_scan", "bandit_scan",
                "rag_query", "security_search", "function_context",
                "pattern_match", "code_analysis", "dataflow_analysis",
                "vulnerability_validation",
                "read_file", "search_code",
            ],
        )
        super().__init__(config, llm_service, tools, event_emitter)

    async def run(self, input_data: Dict[str, Any]) -> AgentResult:
        """
        Run the vulnerability-analysis phase(s).

        Args:
            input_data: Task payload. Keys read here:
                - ``phase_name``: one of three values. ``"static_analysis"``
                  runs the static step, ``"deep_analysis"`` runs the deep step,
                  and ``"analysis"`` (the default) runs both. Any other value
                  runs nothing and returns no findings.
                - ``config``: may contain ``target_vulnerabilities`` (list of
                  vulnerability-type names) and ``max_vuln_types`` (int cap).
                - ``plan``: may contain ``high_risk_areas``, used when Recon
                  reports none.
                - ``previous_results``: ``["recon"]["data"]`` may contain
                  ``high_risk_areas``, ``tech_stack`` and ``entry_points``.

        Returns:
            On success, ``AgentResult`` with ``data={"findings": [...]}``.
            If an unexpected exception escapes, ``success=False`` with
            ``error`` set.
        """
        start_time = time.time()

        try:
            phase_name = input_data.get("phase_name", "analysis")
            config = input_data.get("config") or {}
            plan = input_data.get("plan") or {}
            previous_results = input_data.get("previous_results") or {}

            # Information gathered by the Recon agent, if it ran.
            recon = previous_results.get("recon") or {}
            recon_data = recon.get("data") or {}
            # Fall back to the plan when Recon reported no high-risk areas.
            high_risk_areas = (
                recon_data.get("high_risk_areas")
                or plan.get("high_risk_areas")
                or []
            )
            tech_stack = recon_data.get("tech_stack") or {}
            entry_points = recon_data.get("entry_points") or []

            all_findings: List[Dict] = []

            # 1. Static analysis: external scanners.
            if phase_name in ("static_analysis", "analysis"):
                await self.emit_thinking("执行静态代码分析...")
                all_findings.extend(await self._run_static_analysis(tech_stack))

            # 2. Deep analysis: entry points, dangerous keywords, semantic search.
            if phase_name in ("deep_analysis", "analysis"):
                await self.emit_thinking("执行深度漏洞分析...")
                all_findings.extend(await self._analyze_entry_points(entry_points))
                all_findings.extend(await self._analyze_high_risk_areas(high_risk_areas))

                vuln_types = config.get(
                    "target_vulnerabilities", list(self.DEFAULT_TARGET_VULNERABILITIES)
                )
                if isinstance(vuln_types, str):
                    # A single name; slicing a str would iterate characters.
                    vuln_types = [vuln_types]
                max_vuln_types = config.get("max_vuln_types", self.MAX_VULN_TYPES)

                for vuln_type in list(vuln_types or [])[:max_vuln_types]:
                    if self.is_cancelled:
                        break
                    await self.emit_thinking(f"搜索 {vuln_type} 相关代码...")
                    all_findings.extend(await self._search_vulnerability_pattern(vuln_type))

            all_findings = self._deduplicate_findings(all_findings)

            await self.emit_event(
                "info",
                f"分析完成: 发现 {len(all_findings)} 个潜在漏洞"
            )

            return AgentResult(
                success=True,
                data={"findings": all_findings},
                iterations=self._iteration,
                tool_calls=self._tool_calls,
                tokens_used=self._total_tokens,
                duration_ms=int((time.time() - start_time) * 1000),
            )

        except Exception as e:
            logger.exception("Analysis agent failed: %s", e)
            return AgentResult(
                success=False,
                error=str(e),
                duration_ms=int((time.time() - start_time) * 1000),
            )

    # ------------------------------------------------------------------
    # Tool plumbing
    # ------------------------------------------------------------------

    async def _call_tool_best_effort(self, tool_name: str, tool: Any, **kwargs: Any) -> Optional[Any]:
        """Await ``tool.execute(**kwargs)``; log and return ``None`` if it raises."""
        try:
            return await tool.execute(**kwargs)
        except Exception as e:
            # Treated like an unsuccessful result so the remaining steps still
            # run. asyncio.CancelledError is a BaseException and still propagates.
            logger.warning("Tool %s failed: %s", tool_name, e, exc_info=True)
            return None

    @staticmethod
    def _metadata(result: Any) -> Dict:
        """Return ``result.metadata``, or ``{}`` when missing or ``None``."""
        return getattr(result, "metadata", None) or {}

    @classmethod
    def _result_items(cls, result: Any, count_key: str, items_key: str) -> List[Dict]:
        """
        Return the dict entries of ``metadata[items_key]`` for a successful result.

        Returns ``[]`` in any of these cases:

        - the tool raised (``result is None``);
        - the tool reported failure;
        - ``metadata[count_key]`` is missing or not positive.
        """
        if result is None or not getattr(result, "success", False):
            return []
        metadata = cls._metadata(result)
        if not (metadata.get(count_key) or 0) > 0:
            return []
        return [item for item in (metadata.get(items_key) or []) if isinstance(item, dict)]

    @staticmethod
    def _coerce_line(value: Any, default: int = 1) -> int:
        """Convert a line number to a positive int, using ``default`` when invalid."""
        try:
            line = int(value)
        except (TypeError, ValueError):
            return default
        return line if line >= 1 else default

    @staticmethod
    def _uses_language(tech_stack: Dict, language: str) -> bool:
        """Case-insensitively check whether ``tech_stack["languages"]`` lists ``language``."""
        languages = (tech_stack or {}).get("languages") or []
        wanted = language.lower()
        if isinstance(languages, str):
            return wanted in languages.lower()
        return any(str(lang).lower() == wanted for lang in languages)

    # ------------------------------------------------------------------
    # Analysis steps
    # ------------------------------------------------------------------

    async def _run_static_analysis(self, tech_stack: Dict) -> List[Dict]:
        """
        Run the external static scanners and normalize their findings.

        Semgrep (``p/security-audit`` rules, at most 30 results) runs whenever
        the ``semgrep_scan`` tool is available. Bandit runs only when the tech
        stack lists Python.
        """
        findings: List[Dict] = []

        semgrep_tool = self.tools.get("semgrep_scan")
        if semgrep_tool:
            await self.emit_tool_call("semgrep_scan", {"rules": "p/security-audit"})
            result = await self._call_tool_best_effort(
                "semgrep_scan", semgrep_tool, rules="p/security-audit", max_results=30,
            )
            findings.extend(
                self._semgrep_finding(raw)
                for raw in self._result_items(result, "findings_count", "findings")
            )

        if self._uses_language(tech_stack, "Python"):
            bandit_tool = self.tools.get("bandit_scan")
            if bandit_tool:
                await self.emit_tool_call("bandit_scan", {})
                result = await self._call_tool_best_effort("bandit_scan", bandit_tool)
                findings.extend(
                    self._bandit_finding(raw)
                    for raw in self._result_items(result, "findings_count", "findings")
                )

        return findings

    def _semgrep_finding(self, raw: Dict) -> Dict:
        """Convert one Semgrep result dict into the agent's finding dict."""
        extra = raw.get("extra") or {}
        check_id = raw.get("check_id") or ""
        return {
            "vulnerability_type": self._map_semgrep_rule(check_id),
            "severity": self._map_semgrep_severity(extra.get("severity", "")),
            "title": check_id or "Semgrep Finding",
            "description": extra.get("message", ""),
            "file_path": raw.get("path", ""),
            "line_start": (raw.get("start") or {}).get("line", 0),
            "code_snippet": extra.get("lines", ""),
            "source": "semgrep",
            "needs_verification": True,
        }

    def _bandit_finding(self, raw: Dict) -> Dict:
        """Convert one Bandit result dict into the agent's finding dict."""
        return {
            "vulnerability_type": self._map_bandit_test(raw.get("test_id", "")),
            "severity": str(raw.get("issue_severity") or "medium").lower(),
            "title": raw.get("test_name", "Bandit Finding"),
            "description": raw.get("issue_text", ""),
            "file_path": raw.get("filename", ""),
            "line_start": raw.get("line_number", 0),
            "code_snippet": raw.get("code", ""),
            "source": "bandit",
            "needs_verification": True,
        }

    async def _analyze_entry_points(self, entry_points: List[Dict]) -> List[Dict]:
        """
        Send the code around each Recon entry point to ``code_analysis``.

        For each entry point, the steps are:

        1. Read from 20 lines before to 50 lines after the entry point.
        2. Pass that text to the ``code_analysis`` tool.
        3. Convert each reported issue into a finding.

        At most ``MAX_ENTRY_POINTS`` entry points are examined. Entries that
        are not dicts or lack a ``file`` are skipped. Returns ``[]`` when
        either ``read_file`` or ``code_analysis`` is unavailable.
        """
        findings: List[Dict] = []

        code_analysis_tool = self.tools.get("code_analysis")
        read_tool = self.tools.get("read_file")
        if not code_analysis_tool or not read_tool:
            return findings

        for ep in list(entry_points or [])[: self.MAX_ENTRY_POINTS]:
            if self.is_cancelled:
                break
            if not isinstance(ep, dict):
                continue

            file_path = ep.get("file") or ""
            if not file_path:
                continue
            line = self._coerce_line(ep.get("line", 1))

            read_result = await self._call_tool_best_effort(
                "read_file",
                read_tool,
                file_path=file_path,
                start_line=max(1, line - 20),
                end_line=line + 50,
            )
            if read_result is None or not read_result.success:
                continue

            analysis_result = await self._call_tool_best_effort(
                "code_analysis",
                code_analysis_tool,
                code=read_result.data,
                file_path=file_path,
            )
            if analysis_result is None or not analysis_result.success:
                continue

            for issue in self._metadata(analysis_result).get("issues") or []:
                if not isinstance(issue, dict):
                    continue
                findings.append({
                    "vulnerability_type": issue.get("type", "unknown"),
                    "severity": issue.get("severity", "medium"),
                    "title": issue.get("title", "Security Issue"),
                    "description": issue.get("description", ""),
                    "file_path": file_path,
                    "line_start": issue.get("line", line),
                    "code_snippet": issue.get("code_snippet", ""),
                    "suggestion": issue.get("suggestion", ""),
                    "source": "code_analysis",
                    "needs_verification": True,
                })

        return findings

    async def _analyze_high_risk_areas(self, high_risk_areas: List[str]) -> List[Dict]:
        """
        Keyword-search for dangerous call patterns across the codebase.

        Every match is reported. Its severity is "high" when its file path
        contains one of ``high_risk_areas`` (substring match), else "medium".
        Returns ``[]`` when the ``search_code`` tool is unavailable.
        """
        findings: List[Dict] = []

        search_tool = self.tools.get("search_code")
        if not search_tool:
            return findings

        # Only non-empty strings are usable. An empty area would match every
        # path, and a non-string area would raise TypeError in `in`.
        # NOTE(review): if Recon ever emits dict areas, they are ignored here.
        areas = [area for area in (high_risk_areas or []) if isinstance(area, str) and area]

        for pattern, vuln_type in self.DANGEROUS_PATTERNS[: self.MAX_DANGEROUS_PATTERNS]:
            if self.is_cancelled:
                break

            result = await self._call_tool_best_effort(
                "search_code", search_tool, keyword=pattern, max_results=10,
            )
            matches = self._result_items(result, "matches", "results")

            for match in matches[: self.MAX_MATCHES_PER_PATTERN]:
                file_path = match.get("file") or ""
                in_high_risk = any(area in file_path for area in areas)
                findings.append({
                    "vulnerability_type": vuln_type,
                    "severity": "high" if in_high_risk else "medium",
                    "title": f"疑似 {vuln_type}: {pattern}",
                    "description": f"在 {file_path} 中发现危险模式 {pattern}",
                    "file_path": file_path,
                    "line_start": match.get("line", 0),
                    "code_snippet": match.get("match", ""),
                    "source": "pattern_search",
                    "needs_verification": True,
                })

        return findings

    async def _search_vulnerability_pattern(self, vuln_type: str) -> List[Dict]:
        """
        Run the ``security_search`` tool for one vulnerability type.

        Up to ``MAX_SEMANTIC_RESULTS`` results become medium-severity
        findings, with code snippets truncated to 500 characters.
        """
        findings: List[Dict] = []

        security_tool = self.tools.get("security_search")
        if not security_tool:
            return findings

        result = await self._call_tool_best_effort(
            "security_search", security_tool, vulnerability_type=vuln_type, top_k=10,
        )

        for item in self._result_items(result, "results_count", "results")[: self.MAX_SEMANTIC_RESULTS]:
            findings.append({
                "vulnerability_type": vuln_type,
                "severity": "medium",
                "title": f"疑似 {vuln_type}",
                "description": f"通过语义搜索发现可能存在 {vuln_type}",
                "file_path": item.get("file_path", ""),
                "line_start": item.get("line_start", 0),
                "code_snippet": (item.get("content") or "")[:500],
                "source": "rag_search",
                "needs_verification": True,
            })

        return findings

    # ------------------------------------------------------------------
    # Normalization helpers
    # ------------------------------------------------------------------

    def _deduplicate_findings(self, findings: List[Dict]) -> List[Dict]:
        """Keep the first finding per (file_path, line_start, vulnerability_type), preserving order."""
        seen = set()
        unique: List[Dict] = []

        for finding in findings:
            key = (
                finding.get("file_path", ""),
                finding.get("line_start", 0),
                finding.get("vulnerability_type", ""),
            )
            if key not in seen:
                seen.add(key)
                unique.append(finding)

        return unique

    def _map_semgrep_rule(self, rule_id: str) -> str:
        """
        Map a Semgrep rule id to an agent vulnerability type by substring.

        Specific categories are checked before the generic "injection"
        fallback. Otherwise rules namespaced under ``security.injection``
        (SSRF, path traversal, ...) would all collapse into command_injection.
        Crypto is checked before secrets, so key-size rules are not classified
        as hardcoded secrets.
        """
        rule_lower = (rule_id or "").lower()

        def has(*needles: str) -> bool:
            return any(needle in rule_lower for needle in needles)

        if has("sql"):
            return "sql_injection"
        if has("xss"):
            return "xss"
        if has("ssrf"):
            return "ssrf"
        if has("xxe"):
            return "xxe"
        if has("traversal", "path"):
            return "path_traversal"
        if has("deserial"):
            return "deserialization"
        if has("command", "shell", "subprocess", "child-process", "os-system"):
            return "command_injection"
        if has("eval", "exec", "code-injection"):
            return "code_injection"
        if has("injection"):
            # Legacy fallback for other injection rules.
            return "command_injection"
        if has("crypto"):
            return "weak_crypto"
        if has("secret", "password", "key"):
            return "hardcoded_secret"
        return "other"

    def _map_semgrep_severity(self, severity: str) -> str:
        """Map Semgrep ERROR/WARNING/INFO (any case) to high/medium/low; default "medium"."""
        return self.SEMGREP_SEVERITY.get(str(severity or "").upper(), "medium")

    def _map_bandit_test(self, test_id: str) -> str:
        """Map a Bandit test id (e.g. "B602") to an agent vulnerability type; default "other"."""
        return self.BANDIT_TEST_TYPES.get(str(test_id or "").upper(), "other")