CodeReview/backend/app/services/agent/agents/recon.py

"""
Recon Agent (信息收集层)
负责项目结构分析技术栈识别入口点识别
类型: ReAct
"""
import asyncio
import logging
import os
from typing import List, Dict, Any, Optional
from dataclasses import dataclass, field

from .base import BaseAgent, AgentConfig, AgentResult, AgentType, AgentPattern

logger = logging.getLogger(__name__)

RECON_SYSTEM_PROMPT = """你是 DeepAudit 的信息收集 Agent负责在安全审计前收集项目信息。
## 你的职责
1. 分析项目结构和目录布局
2. 识别使用的技术栈和框架
3. 找出应用程序入口点
4. 分析依赖和第三方库
5. 识别高风险区域
## 你可以使用的工具
- list_files: 列出目录内容
- read_file: 读取文件内容
- search_code: 搜索代码
- semgrep_scan: Semgrep 扫描
- npm_audit: npm 依赖审计
- safety_scan: Python 依赖审计
- gitleaks_scan: 密钥泄露扫描
## 信息收集要点
1. **目录结构**: 了解项目布局识别源码配置测试目录
2. **技术栈**: 检测语言框架数据库等
3. **入口点**: API 路由控制器处理函数
4. **配置文件**: 环境变量数据库配置API 密钥
5. **依赖**: package.json, requirements.txt, go.mod
6. **安全相关**: 认证授权加密相关代码
## 输出格式
完成后返回 JSON:
```json
{
"project_structure": {...},
"tech_stack": {
"languages": [],
"frameworks": [],
"databases": []
},
"entry_points": [],
"high_risk_areas": [],
"dependencies": {...},
"initial_findings": []
}
```
请系统性地收集信息为后续分析做准备"""
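
# Note: the JSON schema requested in the prompt above mirrors the result_data
# dictionary assembled in ReconAgent.run() below, so downstream consumers see
# the same shape in either case.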


class ReconAgent(BaseAgent):
    """
    Reconnaissance agent.

    Uses the ReAct pattern to iteratively collect project information.
    """

    def __init__(
        self,
        llm_service,
        tools: Dict[str, Any],
        event_emitter=None,
    ):
        config = AgentConfig(
            name="Recon",
            agent_type=AgentType.RECON,
            pattern=AgentPattern.REACT,
            max_iterations=15,
            system_prompt=RECON_SYSTEM_PROMPT,
            tools=[
                "list_files", "read_file", "search_code",
                "semgrep_scan", "npm_audit", "safety_scan",
                "gitleaks_scan", "osv_scan",
            ],
        )
        super().__init__(config, llm_service, tools, event_emitter)
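
    # run() walks six collection steps in order: project structure, tech
    # stack, dependency vulnerabilities, leaked secrets, entry points, and
    # high-risk areas. Each tool-backed step checks for its tool and falls
    # back to an empty result when it is unavailable, so partial data is
    # still returned to the caller.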
    async def run(self, input_data: Dict[str, Any]) -> AgentResult:
        """Run the information-gathering pipeline."""
        import time

        start_time = time.time()
        project_info = input_data.get("project_info", {})
        config = input_data.get("config", {})

        try:
            await self.emit_thinking("Starting reconnaissance...")

            # Collected results
            result_data = {
                "project_structure": {},
                "tech_stack": {
                    "languages": [],
                    "frameworks": [],
                    "databases": [],
                },
                "entry_points": [],
                "high_risk_areas": [],
                "dependencies": {},
                "initial_findings": [],
            }

            # 1. Analyze the project structure
            await self.emit_thinking("Analyzing project structure...")
            structure = await self._analyze_structure()
            result_data["project_structure"] = structure

            # 2. Identify the tech stack
            await self.emit_thinking("Identifying tech stack...")
            tech_stack = await self._identify_tech_stack(structure)
            result_data["tech_stack"] = tech_stack

            # 3. Scan dependencies for known vulnerabilities
            await self.emit_thinking("Scanning dependencies for vulnerabilities...")
            deps_result = await self._scan_dependencies(tech_stack)
            result_data["dependencies"] = deps_result.get("dependencies", {})
            if deps_result.get("findings"):
                result_data["initial_findings"].extend(deps_result["findings"])

            # 4. Quick secret scan
            await self.emit_thinking("Scanning for leaked secrets...")
            secrets_result = await self._scan_secrets()
            if secrets_result.get("findings"):
                result_data["initial_findings"].extend(secrets_result["findings"])

            # 5. Identify entry points
            await self.emit_thinking("Identifying entry points...")
            entry_points = await self._identify_entry_points(tech_stack)
            result_data["entry_points"] = entry_points

            # 6. Identify high-risk areas
            result_data["high_risk_areas"] = self._identify_high_risk_areas(
                structure, tech_stack, entry_points
            )

            duration_ms = int((time.time() - start_time) * 1000)
            await self.emit_event(
                "info",
                f"Recon complete: {len(result_data['entry_points'])} entry points, "
                f"{len(result_data['high_risk_areas'])} high-risk areas, "
                f"{len(result_data['initial_findings'])} initial findings"
            )

            return AgentResult(
                success=True,
                data=result_data,
                iterations=self._iteration,
                tool_calls=self._tool_calls,
                tokens_used=self._total_tokens,
                duration_ms=duration_ms,
            )
        except Exception as e:
            logger.error(f"Recon agent failed: {e}", exc_info=True)
            return AgentResult(success=False, error=str(e))
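
    # _analyze_structure lists the repository root with the list_files tool;
    # when the tool is missing it returns an empty skeleton so the rest of
    # the pipeline can still run.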
    async def _analyze_structure(self) -> Dict[str, Any]:
        """Analyze the project structure."""
        structure = {
            "directories": [],
            "files_by_type": {},
            "config_files": [],
            "total_files": 0,
        }

        # List the repository root
        list_tool = self.tools.get("list_files")
        if not list_tool:
            return structure

        result = await list_tool.execute(directory=".", recursive=True, max_files=300)
        if result.success:
            structure["total_files"] = result.metadata.get("file_count", 0)

            # Well-known configuration files to look for
            config_patterns = [
                "package.json", "requirements.txt", "go.mod", "Cargo.toml",
                "pom.xml", "build.gradle", ".env", "config.py", "settings.py",
                "docker-compose.yml", "Dockerfile",
            ]

            # Parse the file list from the tool output
            if isinstance(result.data, str):
                for line in result.data.split('\n'):
                    line = line.strip()
                    for pattern in config_patterns:
                        if pattern in line:
                            structure["config_files"].append(line)

        return structure
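
    # Tech-stack detection is heuristic: languages are inferred from
    # well-known manifest files, and frameworks from substring matches in
    # package.json and requirements.txt, so the result is a starting point
    # rather than an exhaustive inventory.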
    async def _identify_tech_stack(self, structure: Dict) -> Dict[str, Any]:
        """Identify the technology stack."""
        tech_stack = {
            "languages": [],
            "frameworks": [],
            "databases": [],
            "package_managers": [],
        }
        config_files = structure.get("config_files", [])

        # Infer languages and package managers from manifest files
        for cfg in config_files:
            if "package.json" in cfg:
                tech_stack["languages"].append("JavaScript/TypeScript")
                tech_stack["package_managers"].append("npm")
            elif "requirements.txt" in cfg or "setup.py" in cfg:
                tech_stack["languages"].append("Python")
                tech_stack["package_managers"].append("pip")
            elif "go.mod" in cfg:
                tech_stack["languages"].append("Go")
            elif "Cargo.toml" in cfg:
                tech_stack["languages"].append("Rust")
            elif "pom.xml" in cfg or "build.gradle" in cfg:
                tech_stack["languages"].append("Java")

        # Read package.json to detect JavaScript frameworks
        read_tool = self.tools.get("read_file")
        if read_tool and "package.json" in str(config_files):
            result = await read_tool.execute(file_path="package.json", max_lines=100)
            if result.success:
                content = result.data
                if "react" in content.lower():
                    tech_stack["frameworks"].append("React")
                if "vue" in content.lower():
                    tech_stack["frameworks"].append("Vue")
                if "express" in content.lower():
                    tech_stack["frameworks"].append("Express")
                if "fastify" in content.lower():
                    tech_stack["frameworks"].append("Fastify")
                if "next" in content.lower():
                    tech_stack["frameworks"].append("Next.js")

        # Read requirements.txt to detect Python frameworks
        if read_tool and "requirements.txt" in str(config_files):
            result = await read_tool.execute(file_path="requirements.txt", max_lines=50)
            if result.success:
                content = result.data.lower()
                if "django" in content:
                    tech_stack["frameworks"].append("Django")
                if "flask" in content:
                    tech_stack["frameworks"].append("Flask")
                if "fastapi" in content:
                    tech_stack["frameworks"].append("FastAPI")
                if "sqlalchemy" in content:
                    tech_stack["databases"].append("SQLAlchemy")
                if "pymongo" in content:
                    tech_stack["databases"].append("MongoDB")

        # De-duplicate
        tech_stack["languages"] = list(set(tech_stack["languages"]))
        tech_stack["frameworks"] = list(set(tech_stack["frameworks"]))
        tech_stack["databases"] = list(set(tech_stack["databases"]))

        return tech_stack
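
    # Dependency scanning only runs the auditors that match the detected
    # package managers (npm_audit for npm, safety_scan for pip); osv_scan is
    # attempted regardless because OSV covers multiple ecosystems.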
    async def _scan_dependencies(self, tech_stack: Dict) -> Dict[str, Any]:
        """Scan dependencies for known vulnerabilities."""
        result = {
            "dependencies": {},
            "findings": [],
        }

        # npm audit
        if "npm" in tech_stack.get("package_managers", []):
            npm_tool = self.tools.get("npm_audit")
            if npm_tool:
                npm_result = await npm_tool.execute()
                if npm_result.success and npm_result.metadata.get("findings_count", 0) > 0:
                    result["dependencies"]["npm"] = npm_result.metadata
                    # Convert severity counts into finding records
                    for sev, count in npm_result.metadata.get("severity_counts", {}).items():
                        if count > 0 and sev in ["critical", "high"]:
                            result["findings"].append({
                                "vulnerability_type": "dependency_vulnerability",
                                "severity": sev,
                                "title": f"npm dependency vulnerabilities ({count} {sev})",
                                "source": "npm_audit",
                            })

        # Safety (Python)
        if "pip" in tech_stack.get("package_managers", []):
            safety_tool = self.tools.get("safety_scan")
            if safety_tool:
                safety_result = await safety_tool.execute()
                if safety_result.success and safety_result.metadata.get("findings_count", 0) > 0:
                    result["dependencies"]["pip"] = safety_result.metadata
                    result["findings"].append({
                        "vulnerability_type": "dependency_vulnerability",
                        "severity": "high",
                        "title": "Python dependency vulnerabilities",
                        "source": "safety",
                    })

        # OSV Scanner
        osv_tool = self.tools.get("osv_scan")
        if osv_tool:
            osv_result = await osv_tool.execute()
            if osv_result.success and osv_result.metadata.get("findings_count", 0) > 0:
                result["dependencies"]["osv"] = osv_result.metadata

        return result
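
    # Secret detection is delegated entirely to gitleaks; each reported leak
    # becomes a high-severity initial finding with its file and line attached.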
    async def _scan_secrets(self) -> Dict[str, Any]:
        """Scan for leaked secrets."""
        result = {"findings": []}
        gitleaks_tool = self.tools.get("gitleaks_scan")
        if gitleaks_tool:
            gl_result = await gitleaks_tool.execute()
            if gl_result.success and gl_result.metadata.get("findings_count", 0) > 0:
                for finding in gl_result.metadata.get("findings", []):
                    result["findings"].append({
                        "vulnerability_type": "hardcoded_secret",
                        "severity": "high",
                        "title": f"Leaked secret: {finding.get('rule', 'unknown')}",
                        "file_path": finding.get("file"),
                        "line_start": finding.get("line"),
                        "source": "gitleaks",
                    })
        return result
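
    # Entry-point discovery builds framework-specific search patterns plus a
    # few generic handler/controller patterns, then caps the work: at most 10
    # patterns are searched, 5 matches kept per pattern, 30 entry points total.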
    async def _identify_entry_points(self, tech_stack: Dict) -> List[Dict[str, Any]]:
        """Identify application entry points."""
        entry_points = []
        search_tool = self.tools.get("search_code")
        if not search_tool:
            return entry_points

        # Framework-specific search patterns
        search_patterns = []
        frameworks = tech_stack.get("frameworks", [])
        if "Express" in frameworks:
            search_patterns.extend([
                ("app.get(", "Express GET route"),
                ("app.post(", "Express POST route"),
                ("router.get(", "Express router GET"),
                ("router.post(", "Express router POST"),
            ])
        if "FastAPI" in frameworks:
            search_patterns.extend([
                ("@app.get(", "FastAPI GET endpoint"),
                ("@app.post(", "FastAPI POST endpoint"),
                ("@router.get(", "FastAPI router GET"),
                ("@router.post(", "FastAPI router POST"),
            ])
        if "Django" in frameworks:
            search_patterns.extend([
                ("def get(self", "Django GET view"),
                ("def post(self", "Django POST view"),
                ("path(", "Django URL pattern"),
            ])
        if "Flask" in frameworks:
            search_patterns.extend([
                ("@app.route(", "Flask route"),
                ("@blueprint.route(", "Flask blueprint route"),
            ])

        # Generic patterns
        search_patterns.extend([
            ("def handle", "Handler function"),
            ("async def handle", "Async handler"),
            ("class.*Controller", "Controller class"),
            ("class.*Handler", "Handler class"),
        ])

        for pattern, description in search_patterns[:10]:  # limit the number of searches
            result = await search_tool.execute(keyword=pattern, max_results=10)
            if result.success and result.metadata.get("matches", 0) > 0:
                for match in result.metadata.get("results", [])[:5]:
                    entry_points.append({
                        "type": description,
                        "file": match.get("file"),
                        "line": match.get("line"),
                        "pattern": pattern,
                    })

        return entry_points[:30]  # cap the total
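
    # High-risk areas combine a fixed list of conventionally sensitive
    # directories with the directories that actually contain discovered entry
    # points, de-duplicated via a set and capped at 20.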
    def _identify_high_risk_areas(
        self,
        structure: Dict,
        tech_stack: Dict,
        entry_points: List[Dict],
    ) -> List[str]:
        """Identify high-risk areas."""
        high_risk = set()

        # Conventionally high-risk directories
        risk_dirs = [
            "auth/", "authentication/", "login/",
            "api/", "routes/", "controllers/", "handlers/",
            "db/", "database/", "models/",
            "admin/", "management/",
            "upload/", "file/",
            "payment/", "billing/",
        ]
        for dir_name in risk_dirs:
            high_risk.add(dir_name)

        # Add the directories that contain discovered entry points
        for ep in entry_points:
            file_path = ep.get("file") or ""  # tolerate entries with no file path
            if "/" in file_path:
                dir_path = "/".join(file_path.split("/")[:-1]) + "/"
                high_risk.add(dir_path)

        return list(high_risk)[:20]
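
    # Example usage (illustrative sketch only; the concrete llm_service, tool
    # registry, and event_emitter objects are provided elsewhere in the project):
    #
    #     agent = ReconAgent(llm_service=llm, tools=tool_registry, event_emitter=emitter)
    #     result = await agent.run({"project_info": {...}, "config": {}})
    #     if result.success:
    #         entry_points = result.data["entry_points"]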