"""
|
||
Recon Agent (信息收集层)
|
||
负责项目结构分析、技术栈识别、入口点识别
|
||
|
||
类型: ReAct
|
||
"""
|
||
|
||
import asyncio
import logging
import os
import time
from typing import List, Dict, Any, Optional
from dataclasses import dataclass, field

from .base import BaseAgent, AgentConfig, AgentResult, AgentType, AgentPattern

logger = logging.getLogger(__name__)


RECON_SYSTEM_PROMPT = """You are DeepAudit's reconnaissance agent, responsible for gathering project information before the security audit.

## Your responsibilities
1. Analyze the project structure and directory layout
2. Identify the technology stack and frameworks in use
3. Locate application entry points
4. Analyze dependencies and third-party libraries
5. Identify high-risk areas

## Tools available to you
- list_files: list directory contents
- read_file: read file contents
- search_code: search the codebase
- semgrep_scan: run a Semgrep scan
- npm_audit: audit npm dependencies
- safety_scan: audit Python dependencies
- gitleaks_scan: scan for leaked secrets
- osv_scan: scan dependencies against the OSV vulnerability database

## What to collect
1. **Directory structure**: understand the project layout; identify source, config, and test directories
2. **Tech stack**: detect languages, frameworks, databases, etc.
3. **Entry points**: API routes, controllers, handler functions
4. **Config files**: environment variables, database configuration, API keys
5. **Dependencies**: package.json, requirements.txt, go.mod, etc.
6. **Security-related code**: authentication, authorization, and cryptography

## Output format
Return JSON when finished:
```json
{
    "project_structure": {...},
    "tech_stack": {
        "languages": [],
        "frameworks": [],
        "databases": []
    },
    "entry_points": [],
    "high_risk_areas": [],
    "dependencies": {...},
    "initial_findings": []
}
```

Collect information systematically to prepare for the follow-on analysis."""


class ReconAgent(BaseAgent):
    """
    Reconnaissance agent.

    Uses the ReAct pattern to iteratively gather project information.
    """

    def __init__(
        self,
        llm_service,
        tools: Dict[str, Any],
        event_emitter=None,
    ):
        config = AgentConfig(
            name="Recon",
            agent_type=AgentType.RECON,
            pattern=AgentPattern.REACT,
            max_iterations=15,
            system_prompt=RECON_SYSTEM_PROMPT,
            tools=[
                "list_files", "read_file", "search_code",
                "semgrep_scan", "npm_audit", "safety_scan",
                "gitleaks_scan", "osv_scan",
            ],
        )
        super().__init__(config, llm_service, tools, event_emitter)

    async def run(self, input_data: Dict[str, Any]) -> AgentResult:
        """Execute information gathering."""
        start_time = time.time()

        project_info = input_data.get("project_info", {})
        config = input_data.get("config", {})

        try:
            await self.emit_thinking("Starting reconnaissance...")

            # Aggregated results
            result_data = {
                "project_structure": {},
                "tech_stack": {
                    "languages": [],
                    "frameworks": [],
                    "databases": [],
                },
                "entry_points": [],
                "high_risk_areas": [],
                "dependencies": {},
                "initial_findings": [],
            }

            # 1. Analyze the project structure
            await self.emit_thinking("Analyzing project structure...")
            structure = await self._analyze_structure()
            result_data["project_structure"] = structure

            # 2. Identify the tech stack
            await self.emit_thinking("Identifying tech stack...")
            tech_stack = await self._identify_tech_stack(structure)
            result_data["tech_stack"] = tech_stack

            # 3. Scan dependencies for known vulnerabilities
            await self.emit_thinking("Scanning dependencies for vulnerabilities...")
            deps_result = await self._scan_dependencies(tech_stack)
            result_data["dependencies"] = deps_result.get("dependencies", {})
            if deps_result.get("findings"):
                result_data["initial_findings"].extend(deps_result["findings"])

            # 4. Quick secret scan
            await self.emit_thinking("Scanning for leaked secrets...")
            secrets_result = await self._scan_secrets()
            if secrets_result.get("findings"):
                result_data["initial_findings"].extend(secrets_result["findings"])

            # 5. Identify entry points
            await self.emit_thinking("Identifying entry points...")
            entry_points = await self._identify_entry_points(tech_stack)
            result_data["entry_points"] = entry_points

            # 6. Identify high-risk areas
            result_data["high_risk_areas"] = self._identify_high_risk_areas(
                structure, tech_stack, entry_points
            )

            duration_ms = int((time.time() - start_time) * 1000)

            await self.emit_event(
                "info",
                f"Recon complete: {len(result_data['entry_points'])} entry points, "
                f"{len(result_data['high_risk_areas'])} high-risk areas, "
                f"{len(result_data['initial_findings'])} initial findings"
            )

            return AgentResult(
                success=True,
                data=result_data,
                iterations=self._iteration,
                tool_calls=self._tool_calls,
                tokens_used=self._total_tokens,
                duration_ms=duration_ms,
            )

        except Exception as e:
            logger.error(f"Recon agent failed: {e}", exc_info=True)
            return AgentResult(success=False, error=str(e))

    async def _analyze_structure(self) -> Dict[str, Any]:
        """Analyze the project structure."""
        structure = {
            "directories": [],
            "files_by_type": {},
            "config_files": [],
            "total_files": 0,
        }

        # List the project root
        list_tool = self.tools.get("list_files")
        if not list_tool:
            return structure

        result = await list_tool.execute(directory=".", recursive=True, max_files=300)

        if result.success:
            structure["total_files"] = result.metadata.get("file_count", 0)

            # Well-known configuration files to look for
            config_patterns = [
                "package.json", "requirements.txt", "go.mod", "Cargo.toml",
                "pom.xml", "build.gradle", ".env", "config.py", "settings.py",
                "docker-compose.yml", "Dockerfile",
            ]

            # Parse the file list from the tool output
            if isinstance(result.data, str):
                for line in result.data.split('\n'):
                    line = line.strip()
                    for pattern in config_patterns:
                        if pattern in line:
                            structure["config_files"].append(line)

        return structure

    async def _identify_tech_stack(self, structure: Dict) -> Dict[str, Any]:
        """Identify the technology stack."""
        tech_stack = {
            "languages": [],
            "frameworks": [],
            "databases": [],
            "package_managers": [],
        }

        config_files = structure.get("config_files", [])

        # Infer languages and package managers from config files
        for cfg in config_files:
            if "package.json" in cfg:
                tech_stack["languages"].append("JavaScript/TypeScript")
                tech_stack["package_managers"].append("npm")
            elif "requirements.txt" in cfg or "setup.py" in cfg:
                tech_stack["languages"].append("Python")
                tech_stack["package_managers"].append("pip")
            elif "go.mod" in cfg:
                tech_stack["languages"].append("Go")
            elif "Cargo.toml" in cfg:
                tech_stack["languages"].append("Rust")
            elif "pom.xml" in cfg or "build.gradle" in cfg:
                tech_stack["languages"].append("Java")

        # Read package.json to detect JavaScript frameworks
        read_tool = self.tools.get("read_file")
        if read_tool and "package.json" in str(config_files):
            result = await read_tool.execute(file_path="package.json", max_lines=100)
            if result.success:
                content = result.data.lower()
                if "react" in content:
                    tech_stack["frameworks"].append("React")
                if "vue" in content:
                    tech_stack["frameworks"].append("Vue")
                if "express" in content:
                    tech_stack["frameworks"].append("Express")
                if "fastify" in content:
                    tech_stack["frameworks"].append("Fastify")
                if "next" in content:
                    tech_stack["frameworks"].append("Next.js")

        # Read requirements.txt to detect Python frameworks
        if read_tool and "requirements.txt" in str(config_files):
            result = await read_tool.execute(file_path="requirements.txt", max_lines=50)
            if result.success:
                content = result.data.lower()
                if "django" in content:
                    tech_stack["frameworks"].append("Django")
                if "flask" in content:
                    tech_stack["frameworks"].append("Flask")
                if "fastapi" in content:
                    tech_stack["frameworks"].append("FastAPI")
                if "sqlalchemy" in content:
                    tech_stack["databases"].append("SQLAlchemy")
                if "pymongo" in content:
                    tech_stack["databases"].append("MongoDB")

        # Deduplicate
        tech_stack["languages"] = list(set(tech_stack["languages"]))
        tech_stack["frameworks"] = list(set(tech_stack["frameworks"]))
        tech_stack["databases"] = list(set(tech_stack["databases"]))

        return tech_stack

    async def _scan_dependencies(self, tech_stack: Dict) -> Dict[str, Any]:
        """Scan dependencies for known vulnerabilities."""
        result = {
            "dependencies": {},
            "findings": [],
        }

        # npm audit
        if "npm" in tech_stack.get("package_managers", []):
            npm_tool = self.tools.get("npm_audit")
            if npm_tool:
                npm_result = await npm_tool.execute()
                if npm_result.success and npm_result.metadata.get("findings_count", 0) > 0:
                    result["dependencies"]["npm"] = npm_result.metadata

                    # Convert severity counts into finding entries
                    for sev, count in npm_result.metadata.get("severity_counts", {}).items():
                        if count > 0 and sev in ["critical", "high"]:
                            result["findings"].append({
                                "vulnerability_type": "dependency_vulnerability",
                                "severity": sev,
                                "title": f"npm dependency vulnerabilities ({count} {sev})",
                                "source": "npm_audit",
                            })

        # Safety (Python)
        if "pip" in tech_stack.get("package_managers", []):
            safety_tool = self.tools.get("safety_scan")
            if safety_tool:
                safety_result = await safety_tool.execute()
                if safety_result.success and safety_result.metadata.get("findings_count", 0) > 0:
                    result["dependencies"]["pip"] = safety_result.metadata
                    result["findings"].append({
                        "vulnerability_type": "dependency_vulnerability",
                        "severity": "high",
                        "title": "Python dependency vulnerabilities",
                        "source": "safety",
                    })

        # OSV Scanner
        osv_tool = self.tools.get("osv_scan")
        if osv_tool:
            osv_result = await osv_tool.execute()
            if osv_result.success and osv_result.metadata.get("findings_count", 0) > 0:
                result["dependencies"]["osv"] = osv_result.metadata

        return result

    async def _scan_secrets(self) -> Dict[str, Any]:
        """Scan for leaked secrets."""
        result = {"findings": []}

        gitleaks_tool = self.tools.get("gitleaks_scan")
        if gitleaks_tool:
            gl_result = await gitleaks_tool.execute()
            if gl_result.success and gl_result.metadata.get("findings_count", 0) > 0:
                for finding in gl_result.metadata.get("findings", []):
                    result["findings"].append({
                        "vulnerability_type": "hardcoded_secret",
                        "severity": "high",
                        "title": f"Leaked secret: {finding.get('rule', 'unknown')}",
                        "file_path": finding.get("file"),
                        "line_start": finding.get("line"),
                        "source": "gitleaks",
                    })

        return result

    async def _identify_entry_points(self, tech_stack: Dict) -> List[Dict[str, Any]]:
        """Identify application entry points."""
        entry_points = []
        search_tool = self.tools.get("search_code")

        if not search_tool:
            return entry_points

        # Framework-specific search patterns for entry points
        search_patterns = []

        frameworks = tech_stack.get("frameworks", [])

        if "Express" in frameworks:
            search_patterns.extend([
                ("app.get(", "Express GET route"),
                ("app.post(", "Express POST route"),
                ("router.get(", "Express router GET"),
                ("router.post(", "Express router POST"),
            ])

        if "FastAPI" in frameworks:
            search_patterns.extend([
                ("@app.get(", "FastAPI GET endpoint"),
                ("@app.post(", "FastAPI POST endpoint"),
                ("@router.get(", "FastAPI router GET"),
                ("@router.post(", "FastAPI router POST"),
            ])

        if "Django" in frameworks:
            search_patterns.extend([
                ("def get(self", "Django GET view"),
                ("def post(self", "Django POST view"),
                ("path(", "Django URL pattern"),
            ])

        if "Flask" in frameworks:
            search_patterns.extend([
                ("@app.route(", "Flask route"),
                ("@blueprint.route(", "Flask blueprint route"),
            ])

        # Generic patterns
        search_patterns.extend([
            ("def handle", "Handler function"),
            ("async def handle", "Async handler"),
            ("class.*Controller", "Controller class"),
            ("class.*Handler", "Handler class"),
        ])

        for pattern, description in search_patterns[:10]:  # Cap the number of searches
            result = await search_tool.execute(keyword=pattern, max_results=10)
            if result.success and result.metadata.get("matches", 0) > 0:
                for match in result.metadata.get("results", [])[:5]:
                    entry_points.append({
                        "type": description,
                        "file": match.get("file"),
                        "line": match.get("line"),
                        "pattern": pattern,
                    })

        return entry_points[:30]  # Cap the total count

    def _identify_high_risk_areas(
        self,
        structure: Dict,
        tech_stack: Dict,
        entry_points: List[Dict],
    ) -> List[str]:
        """Identify high-risk areas."""
        high_risk = set()

        # Common high-risk directories
        risk_dirs = [
            "auth/", "authentication/", "login/",
            "api/", "routes/", "controllers/", "handlers/",
            "db/", "database/", "models/",
            "admin/", "management/",
            "upload/", "file/",
            "payment/", "billing/",
        ]

        for dir_name in risk_dirs:
            high_risk.add(dir_name)

        # Add directories derived from entry points
        for ep in entry_points:
            file_path = ep.get("file") or ""
            if "/" in file_path:
                dir_path = "/".join(file_path.split("/")[:-1]) + "/"
                high_risk.add(dir_path)

        return list(high_risk)[:20]
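

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the production pipeline).
#
# This shows how a ReconAgent might be wired up and exercised in isolation.
# `FakeToolResult`, `FakeListFilesTool`, and `_demo` are hypothetical stand-ins
# invented for this sketch; the real LLM service, event emitter, and tool
# objects live elsewhere in DeepAudit. The sketch only assumes what this
# module itself relies on: tools are looked up by name in a dict, each tool
# exposes an async execute(**kwargs) returning an object with `success`,
# `data`, and `metadata`, and BaseAgent tolerates a missing LLM service and
# event emitter on this code path.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    class FakeToolResult:
        """Minimal result object carrying the attributes ReconAgent reads."""

        def __init__(self, success=True, data="", metadata=None):
            self.success = success
            self.data = data
            self.metadata = metadata or {}

    class FakeListFilesTool:
        """Pretends the audited project contains a requirements.txt."""

        async def execute(self, **kwargs):
            return FakeToolResult(
                data="requirements.txt\napp/main.py",
                metadata={"file_count": 2},
            )

    async def _demo():
        agent = ReconAgent(
            llm_service=None,  # hypothetical: this demo path never calls the LLM
            tools={"list_files": FakeListFilesTool()},
        )
        result = await agent.run({"project_info": {}, "config": {}})
        print(result.data if result.success else result.error)

    asyncio.run(_demo())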