CodeReview/backend/app/services/agent/agents/recon.py

436 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Recon Agent (信息收集层)
负责项目结构分析、技术栈识别、入口点识别
类型: ReAct
"""
import asyncio
import logging
import os
from typing import List, Dict, Any, Optional
from dataclasses import dataclass, field
from .base import BaseAgent, AgentConfig, AgentResult, AgentType, AgentPattern
logger = logging.getLogger(__name__)

# System prompt passed as ``AgentConfig.system_prompt`` by ReconAgent. The
# "tools you can use" section should list every tool name that ReconAgent
# configures in its ``tools`` list.
RECON_SYSTEM_PROMPT = """你是 DeepAudit 的信息收集 Agent，负责在安全审计前收集项目信息。
## 你的职责
1. 分析项目结构和目录布局
2. 识别使用的技术栈和框架
3. 找出应用程序入口点
4. 分析依赖和第三方库
5. 识别高风险区域
## 你可以使用的工具
- list_files: 列出目录内容
- read_file: 读取文件内容
- search_code: 搜索代码
- semgrep_scan: Semgrep 扫描
- npm_audit: npm 依赖审计
- safety_scan: Python 依赖审计
- gitleaks_scan: 密钥泄露扫描
- osv_scan: OSV 依赖漏洞扫描
## 信息收集要点
1. **目录结构**: 了解项目布局,识别源码、配置、测试目录
2. **技术栈**: 检测语言、框架、数据库等
3. **入口点**: API 路由、控制器、处理函数
4. **配置文件**: 环境变量、数据库配置、API 密钥
5. **依赖**: package.json, requirements.txt, go.mod 等
6. **安全相关**: 认证、授权、加密相关代码
## 输出格式
完成后返回 JSON:
```json
{
"project_structure": {...},
"tech_stack": {
"languages": [],
"frameworks": [],
"databases": []
},
"entry_points": [],
"high_risk_areas": [],
"dependencies": {...},
"initial_findings": []
}
```
请系统性地收集信息,为后续分析做准备。"""
class ReconAgent(BaseAgent):
    """
    Information-gathering (reconnaissance) agent.

    Configured with the ReAct pattern, but ``run`` as written drives a fixed
    sequence of steps: project structure -> tech stack -> dependency audit
    -> secret scan -> entry points -> high-risk areas. Every tool-backed
    step looks its tool up by name in ``self.tools`` and is skipped when
    that tool is not available.
    """

    # Substrings that mark a line of the ``list_files`` output as a
    # configuration / manifest file. "setup.py" is listed because
    # _identify_tech_stack treats it as a Python marker.
    _CONFIG_FILE_PATTERNS = (
        "package.json", "requirements.txt", "setup.py", "go.mod", "Cargo.toml",
        "pom.xml", "build.gradle", ".env", "config.py", "settings.py",
        "docker-compose.yml", "Dockerfile",
    )

    # (package.json dependency name, reported framework name).
    _NPM_FRAMEWORKS = (
        ("react", "React"),
        ("vue", "Vue"),
        ("express", "Express"),
        ("fastify", "Fastify"),
        ("next", "Next.js"),
    )

    # (requirements.txt substring, tech_stack key, reported name).
    _PIP_MARKERS = (
        ("django", "frameworks", "Django"),
        ("flask", "frameworks", "Flask"),
        ("fastapi", "frameworks", "FastAPI"),
        ("sqlalchemy", "databases", "SQLAlchemy"),
        ("pymongo", "databases", "MongoDB"),
    )

    # Generic directory names flagged as high risk by naming convention alone
    # (they are reported whether or not the project actually has them).
    _RISK_DIRS = (
        "auth/", "authentication/", "login/",
        "api/", "routes/", "controllers/", "handlers/",
        "db/", "database/", "models/",
        "admin/", "management/",
        "upload/", "file/",
        "payment/", "billing/",
    )

    _MAX_ENTRY_POINT_SEARCHES = 10  # patterns passed to search_code per run
    _MAX_MATCHES_PER_PATTERN = 5    # matches kept from each search
    _MAX_ENTRY_POINTS = 30          # cap on the returned entry points
    _MAX_HIGH_RISK_AREAS = 20       # cap on the returned high-risk areas

    def __init__(
        self,
        llm_service,
        tools: Dict[str, Any],
        event_emitter=None,
    ):
        """
        Args:
            llm_service: LLM service, forwarded to ``BaseAgent``.
            tools: Mapping of tool name -> tool object. Tools are fetched by
                name (e.g. ``"list_files"``) and invoked with
                ``await tool.execute(...)``.
            event_emitter: Optional event emitter, forwarded to ``BaseAgent``.
        """
        config = AgentConfig(
            name="Recon",
            agent_type=AgentType.RECON,
            pattern=AgentPattern.REACT,
            max_iterations=15,
            system_prompt=RECON_SYSTEM_PROMPT,
            tools=[
                "list_files", "read_file", "search_code",
                "semgrep_scan", "npm_audit", "safety_scan",
                "gitleaks_scan", "osv_scan",
            ],
        )
        super().__init__(config, llm_service, tools, event_emitter)

    async def run(self, input_data: Dict[str, Any]) -> AgentResult:
        """
        Run the information-gathering pipeline.

        Args:
            input_data: Task input; no key of it is read by the current
                pipeline.

        Returns:
            On success, an ``AgentResult`` whose ``data`` has the keys
            ``project_structure``, ``tech_stack``, ``entry_points``,
            ``high_risk_areas``, ``dependencies`` and ``initial_findings``.
            Any ``Exception`` raised by a step is logged and converted into
            ``AgentResult(success=False, error=str(exc))``.
        """
        import time

        # Monotonic clock: wall-clock adjustments must not skew the duration.
        start_time = time.monotonic()

        try:
            await self.emit_thinking("开始信息收集...")

            # Skeleton of the collected result; each step fills one part.
            result_data = {
                "project_structure": {},
                "tech_stack": {
                    "languages": [],
                    "frameworks": [],
                    "databases": [],
                },
                "entry_points": [],
                "high_risk_areas": [],
                "dependencies": {},
                "initial_findings": [],
            }

            # 1. Project structure.
            await self.emit_thinking("分析项目结构...")
            structure = await self._analyze_structure()
            result_data["project_structure"] = structure

            # 2. Tech stack (inferred from the config files found in step 1).
            await self.emit_thinking("识别技术栈...")
            tech_stack = await self._identify_tech_stack(structure)
            result_data["tech_stack"] = tech_stack

            # 3. Dependency vulnerability scanning.
            await self.emit_thinking("扫描依赖漏洞...")
            deps_result = await self._scan_dependencies(tech_stack)
            result_data["dependencies"] = deps_result.get("dependencies", {})
            if deps_result.get("findings"):
                result_data["initial_findings"].extend(deps_result["findings"])

            # 4. Quick secret-leak scan.
            await self.emit_thinking("扫描密钥泄露...")
            secrets_result = await self._scan_secrets()
            if secrets_result.get("findings"):
                result_data["initial_findings"].extend(secrets_result["findings"])

            # 5. Entry points.
            await self.emit_thinking("识别入口点...")
            entry_points = await self._identify_entry_points(tech_stack)
            result_data["entry_points"] = entry_points

            # 6. High-risk areas (pure computation, no tool calls).
            result_data["high_risk_areas"] = self._identify_high_risk_areas(
                structure, tech_stack, entry_points
            )

            duration_ms = int((time.monotonic() - start_time) * 1000)
            await self.emit_event(
                "info",
                f"信息收集完成: 发现 {len(result_data['entry_points'])} 个入口点, "
                f"{len(result_data['high_risk_areas'])} 个高风险区域, "
                f"{len(result_data['initial_findings'])} 个初步发现"
            )
            return AgentResult(
                success=True,
                data=result_data,
                iterations=self._iteration,
                tool_calls=self._tool_calls,
                tokens_used=self._total_tokens,
                duration_ms=duration_ms,
            )
        except Exception as e:
            logger.exception("Recon agent failed: %s", e)
            return AgentResult(
                success=False,
                error=str(e),
                duration_ms=int((time.monotonic() - start_time) * 1000),
            )

    async def _analyze_structure(self) -> Dict[str, Any]:
        """
        Collect a coarse project structure via the ``list_files`` tool.

        Returns:
            Dict with ``directories`` and ``files_by_type`` (currently never
            populated), ``config_files`` (lines of the listing matching
            ``_CONFIG_FILE_PATTERNS``, each at most once) and
            ``total_files`` (the tool's ``file_count`` metadata).
        """
        structure = {
            "directories": [],
            "files_by_type": {},
            "config_files": [],
            "total_files": 0,
        }

        list_tool = self.tools.get("list_files")
        if not list_tool:
            return structure

        result = await list_tool.execute(directory=".", recursive=True, max_files=300)
        if not result.success:
            return structure

        structure["total_files"] = result.metadata.get("file_count", 0)

        # The listing is parsed as plain text, one entry per line.
        # NOTE(review): the exact line format is defined by list_files and
        # may include decoration around the path — confirm before relying on
        # these entries as bare paths.
        if isinstance(result.data, str):
            for raw_line in result.data.split("\n"):
                line = raw_line.strip()
                # any(): a line matching several patterns is recorded once.
                if any(pattern in line for pattern in self._CONFIG_FILE_PATTERNS):
                    structure["config_files"].append(line)

        return structure

    async def _identify_tech_stack(self, structure: Dict) -> Dict[str, Any]:
        """
        Infer languages, frameworks, databases and package managers.

        Languages and package managers come from the names of the config
        files found by ``_analyze_structure``; frameworks and databases come
        from reading the root ``package.json`` / ``requirements.txt`` with
        the ``read_file`` tool when such files were seen.

        Returns:
            Dict of four lists, each de-duplicated with first-seen order
            preserved (so the output is deterministic across runs).
        """
        tech_stack = {
            "languages": [],
            "frameworks": [],
            "databases": [],
            "package_managers": [],
        }
        config_files = structure.get("config_files", [])

        # Infer languages / package managers from config file names.
        for cfg in config_files:
            if "package.json" in cfg:
                tech_stack["languages"].append("JavaScript/TypeScript")
                tech_stack["package_managers"].append("npm")
            elif "requirements.txt" in cfg or "setup.py" in cfg:
                tech_stack["languages"].append("Python")
                tech_stack["package_managers"].append("pip")
            elif "go.mod" in cfg:
                tech_stack["languages"].append("Go")
            elif "Cargo.toml" in cfg:
                tech_stack["languages"].append("Rust")
            elif "pom.xml" in cfg or "build.gradle" in cfg:
                tech_stack["languages"].append("Java")

        read_tool = self.tools.get("read_file")

        # package.json -> JS frameworks. Dependency names appear as quoted
        # JSON keys, so match '"name"' to avoid hits like "preact" or any
        # occurrence of the word "next".
        if read_tool and any("package.json" in cfg for cfg in config_files):
            result = await read_tool.execute(file_path="package.json", max_lines=100)
            if result.success and isinstance(result.data, str):
                content = result.data.lower()
                for package, framework in self._NPM_FRAMEWORKS:
                    if f'"{package}"' in content:
                        tech_stack["frameworks"].append(framework)

        # requirements.txt -> Python frameworks / database libraries.
        if read_tool and any("requirements.txt" in cfg for cfg in config_files):
            result = await read_tool.execute(file_path="requirements.txt", max_lines=50)
            if result.success and isinstance(result.data, str):
                content = result.data.lower()
                for marker, category, name in self._PIP_MARKERS:
                    if marker in content:
                        tech_stack[category].append(name)

        # Order-preserving de-duplication (set() would scramble the order).
        for key in ("languages", "frameworks", "databases", "package_managers"):
            tech_stack[key] = list(dict.fromkeys(tech_stack[key]))

        return tech_stack

    async def _scan_dependencies(self, tech_stack: Dict) -> Dict[str, Any]:
        """
        Audit third-party dependencies with the available scanners.

        ``npm_audit`` runs only when "npm" was detected, ``safety_scan`` only
        when "pip" was detected; ``osv_scan`` runs whenever it is available.

        Returns:
            ``{"dependencies": {scanner: tool metadata}, "findings": [...]}``.
            Findings are one entry per critical/high npm severity bucket and
            a single entry when safety reports any issue.
        """
        result = {
            "dependencies": {},
            "findings": [],
        }
        package_managers = tech_stack.get("package_managers", [])

        # npm audit
        if "npm" in package_managers:
            npm_tool = self.tools.get("npm_audit")
            if npm_tool:
                npm_result = await npm_tool.execute()
                if npm_result.success and npm_result.metadata.get("findings_count", 0) > 0:
                    result["dependencies"]["npm"] = npm_result.metadata
                    # Convert severity counts into findings (critical/high only).
                    severity_counts = npm_result.metadata.get("severity_counts") or {}
                    for sev, count in severity_counts.items():
                        if count > 0 and sev in ("critical", "high"):
                            result["findings"].append({
                                "vulnerability_type": "dependency_vulnerability",
                                "severity": sev,
                                "title": f"npm 依赖漏洞 ({count}{sev})",
                                "source": "npm_audit",
                            })

        # Safety (Python)
        if "pip" in package_managers:
            safety_tool = self.tools.get("safety_scan")
            if safety_tool:
                safety_result = await safety_tool.execute()
                if safety_result.success and safety_result.metadata.get("findings_count", 0) > 0:
                    result["dependencies"]["pip"] = safety_result.metadata
                    result["findings"].append({
                        "vulnerability_type": "dependency_vulnerability",
                        "severity": "high",
                        "title": "Python 依赖漏洞",
                        "source": "safety",
                    })

        # OSV Scanner: metadata is recorded, no findings are generated.
        osv_tool = self.tools.get("osv_scan")
        if osv_tool:
            osv_result = await osv_tool.execute()
            if osv_result.success and osv_result.metadata.get("findings_count", 0) > 0:
                result["dependencies"]["osv"] = osv_result.metadata

        return result

    async def _scan_secrets(self) -> Dict[str, Any]:
        """
        Scan for leaked secrets with the ``gitleaks_scan`` tool.

        Returns:
            ``{"findings": [...]}`` with one high-severity
            ``hardcoded_secret`` finding per gitleaks hit (empty when the
            tool is missing or reports nothing).
        """
        result = {"findings": []}

        gitleaks_tool = self.tools.get("gitleaks_scan")
        if not gitleaks_tool:
            return result

        gl_result = await gitleaks_tool.execute()
        if gl_result.success and gl_result.metadata.get("findings_count", 0) > 0:
            for finding in gl_result.metadata.get("findings") or []:
                result["findings"].append({
                    "vulnerability_type": "hardcoded_secret",
                    "severity": "high",
                    "title": f"密钥泄露: {finding.get('rule', 'unknown')}",
                    "file_path": finding.get("file"),
                    "line_start": finding.get("line"),
                    "source": "gitleaks",
                })

        return result

    async def _identify_entry_points(self, tech_stack: Dict) -> List[Dict[str, Any]]:
        """
        Locate likely request entry points with the ``search_code`` tool.

        Framework-specific patterns (chosen from ``tech_stack["frameworks"]``)
        are searched first, then generic handler/controller patterns; only
        the first ``_MAX_ENTRY_POINT_SEARCHES`` patterns are searched.

        Returns:
            Up to ``_MAX_ENTRY_POINTS`` dicts with ``type``, ``file``,
            ``line`` and ``pattern`` keys, unique per (file, line). Matches
            without a file are skipped.
        """
        entry_points: List[Dict[str, Any]] = []

        search_tool = self.tools.get("search_code")
        if not search_tool:
            return entry_points

        # Framework-specific entry-point patterns.
        search_patterns = []
        frameworks = tech_stack.get("frameworks", [])
        if "Express" in frameworks:
            search_patterns.extend([
                ("app.get(", "Express GET route"),
                ("app.post(", "Express POST route"),
                ("router.get(", "Express router GET"),
                ("router.post(", "Express router POST"),
            ])
        if "FastAPI" in frameworks:
            search_patterns.extend([
                ("@app.get(", "FastAPI GET endpoint"),
                ("@app.post(", "FastAPI POST endpoint"),
                ("@router.get(", "FastAPI router GET"),
                ("@router.post(", "FastAPI router POST"),
            ])
        if "Django" in frameworks:
            search_patterns.extend([
                ("def get(self", "Django GET view"),
                ("def post(self", "Django POST view"),
                ("path(", "Django URL pattern"),
            ])
        if "Flask" in frameworks:
            search_patterns.extend([
                ("@app.route(", "Flask route"),
                ("@blueprint.route(", "Flask blueprint route"),
            ])

        # Generic patterns. NOTE(review): the ".*" entries only work if
        # search_code treats ``keyword`` as a regex — confirm.
        search_patterns.extend([
            ("def handle", "Handler function"),
            ("async def handle", "Async handler"),
            ("class.*Controller", "Controller class"),
            ("class.*Handler", "Handler class"),
        ])

        seen = set()
        for pattern, description in search_patterns[: self._MAX_ENTRY_POINT_SEARCHES]:
            # The result is capped anyway, so further searches are wasted calls.
            if len(entry_points) >= self._MAX_ENTRY_POINTS:
                break

            result = await search_tool.execute(keyword=pattern, max_results=10)
            if not (result.success and result.metadata.get("matches", 0) > 0):
                continue

            matches = result.metadata.get("results") or []
            for match in matches[: self._MAX_MATCHES_PER_PATTERN]:
                file_path = match.get("file")
                if not file_path:
                    continue  # cannot be located or mapped to a directory
                key = (file_path, match.get("line"))
                # e.g. every "async def handle" hit is also a "def handle" hit.
                if key in seen:
                    continue
                seen.add(key)
                entry_points.append({
                    "type": description,
                    "file": file_path,
                    "line": match.get("line"),
                    "pattern": pattern,
                })

        return entry_points[: self._MAX_ENTRY_POINTS]

    def _identify_high_risk_areas(
        self,
        structure: Dict,
        tech_stack: Dict,
        entry_points: List[Dict],
    ) -> List[str]:
        """
        Build the list of high-risk directories (each ending in "/").

        Directories containing discovered entry points come first, followed
        by the generic ``_RISK_DIRS`` names; the list is capped at
        ``_MAX_HIGH_RISK_AREAS``. ``structure`` and ``tech_stack`` are
        currently unused.

        Returns:
            De-duplicated directory strings in a deterministic order.
        """
        # A dict is used as an insertion-ordered set so that the result and
        # the truncation below are deterministic (a set's order is not).
        high_risk: Dict[str, None] = {}

        # Evidence-based directories first so the cap never drops them in
        # favour of the speculative generic names.
        for ep in entry_points:
            # `or ""`: the key may be present with a None value.
            file_path = ep.get("file") or ""
            if "/" in file_path:
                high_risk[file_path.rsplit("/", 1)[0] + "/"] = None

        for dir_name in self._RISK_DIRS:
            high_risk[dir_name] = None

        return list(high_risk)[: self._MAX_HIGH_RISK_AREAS]