214 lines
6.5 KiB
Python
214 lines
6.5 KiB
Python
"""
|
|
Agent 单元测试
|
|
测试各个 Agent 的功能
|
|
"""
|
|
|
|
import pytest
|
|
import asyncio
|
|
import os
|
|
from unittest.mock import MagicMock, AsyncMock, patch
|
|
|
|
from app.services.agent.agents.base import BaseAgent, AgentConfig, AgentResult, AgentType, AgentPattern
|
|
from app.services.agent.agents.recon import ReconAgent
|
|
from app.services.agent.agents.analysis import AnalysisAgent
|
|
from app.services.agent.agents.verification import VerificationAgent
|
|
|
|
|
|
class TestReconAgent:
    """Tests for ReconAgent."""

    @pytest.fixture
    def recon_agent(self, temp_project_dir, mock_llm_service, mock_event_emitter):
        """Create a ReconAgent whose file tools are rooted at the temp project.

        ``temp_project_dir``, ``mock_llm_service`` and ``mock_event_emitter``
        are fixtures not defined in this file (presumably in conftest.py).
        """
        from app.services.agent.tools import (
            FileReadTool, FileSearchTool, ListFilesTool,
        )

        tools = {
            "list_files": ListFilesTool(temp_project_dir),
            "read_file": FileReadTool(temp_project_dir),
            "search_code": FileSearchTool(temp_project_dir),
        }

        return ReconAgent(
            llm_service=mock_llm_service,
            tools=tools,
            event_emitter=mock_event_emitter,
        )

    @staticmethod
    async def _run_with_root(agent, project_root):
        """Run *agent* with only a project root and an empty config.

        Returns whatever ``agent.run`` returns; callers do their own asserts.
        """
        return await agent.run({
            "project_info": {"root": project_root},
            "config": {},
        })

    @pytest.mark.asyncio
    async def test_recon_agent_run(self, recon_agent, temp_project_dir):
        """A full recon run succeeds and returns the core recon fields."""
        result = await recon_agent.run({
            "project_info": {
                "name": "Test Project",
                "root": temp_project_dir,
            },
            "config": {},
        })

        # Show the agent's own error message if the run failed.
        assert result.success is True, result.error
        assert result.data is not None

        # Verify the structure of the returned data.
        data = result.data
        assert "tech_stack" in data
        assert "entry_points" in data or "high_risk_areas" in data

    @pytest.mark.asyncio
    async def test_recon_agent_identifies_python(self, recon_agent, temp_project_dir):
        """The recon run reports at least one detected language."""
        result = await self._run_with_root(recon_agent, temp_project_dir)

        assert result.success is True, result.error
        # Guard before .get() so a missing payload fails as an assertion,
        # not as an AttributeError on None.
        assert result.data is not None

        tech_stack = result.data.get("tech_stack", {})
        languages = tech_stack.get("languages", [])

        # Only non-emptiness is checked; this passes for any detected language.
        # NOTE(review): tighten to `assert "Python" in languages` once the
        # fixture project / mocked LLM are confirmed to guarantee it.
        assert len(languages) > 0, f"no languages detected: {tech_stack!r}"

    @pytest.mark.asyncio
    async def test_recon_agent_finds_high_risk_areas(self, recon_agent, temp_project_dir):
        """The recon run flags at least one high-risk area in the fixture project."""
        result = await self._run_with_root(recon_agent, temp_project_dir)

        assert result.success is True, result.error
        assert result.data is not None

        high_risk_areas = result.data.get("high_risk_areas", [])

        # The fixture project is expected to contain high-risk areas.
        assert len(high_risk_areas) > 0, f"no high-risk areas in {result.data!r}"
|
|
|
|
|
|
class TestAnalysisAgent:
    """Tests for AnalysisAgent."""

    @pytest.fixture
    def analysis_agent(self, temp_project_dir, mock_llm_service, mock_event_emitter):
        """Create an AnalysisAgent whose tools are rooted at the temp project.

        ``temp_project_dir``, ``mock_llm_service`` and ``mock_event_emitter``
        are fixtures not defined in this file (presumably in conftest.py).
        """
        from app.services.agent.tools import (
            FileReadTool, FileSearchTool, PatternMatchTool,
        )

        tools = {
            "read_file": FileReadTool(temp_project_dir),
            "search_code": FileSearchTool(temp_project_dir),
            "pattern_match": PatternMatchTool(temp_project_dir),
        }

        return AnalysisAgent(
            llm_service=mock_llm_service,
            tools=tools,
            event_emitter=mock_event_emitter,
        )

    @staticmethod
    async def _analyze(agent, high_risk_areas):
        """Run *agent* on a Python tech stack with the given high-risk files.

        Returns whatever ``agent.run`` returns; callers do their own asserts.
        """
        return await agent.run({
            "tech_stack": {"languages": ["Python"]},
            "entry_points": [],
            "high_risk_areas": high_risk_areas,
            "config": {},
        })

    @pytest.mark.asyncio
    async def test_analysis_agent_run(self, analysis_agent, temp_project_dir):
        """An analysis run over two high-risk files succeeds and returns data."""
        result = await self._analyze(
            analysis_agent,
            ["src/sql_vuln.py", "src/cmd_vuln.py"],
        )

        # Show the agent's own error message if the run failed.
        assert result.success is True, result.error
        assert result.data is not None

    @pytest.mark.asyncio
    async def test_analysis_agent_finds_vulnerabilities(self, analysis_agent, temp_project_dir):
        """An analysis run over the vulnerable sample files returns a findings list."""
        result = await self._analyze(
            analysis_agent,
            [
                "src/sql_vuln.py",
                "src/cmd_vuln.py",
                "src/xss_vuln.py",
                "src/secrets.py",
            ],
        )

        assert result.success is True, result.error
        # Guard before .get() so a missing payload fails as an assertion,
        # not as an AttributeError on None.
        assert result.data is not None

        findings = result.data.get("findings", [])

        # Some vulnerabilities should be found, but the exact count depends on
        # the analysis logic, so only the type is checked.
        # NOTE(review): consider asserting `len(findings) > 0` once the mocked
        # LLM / pattern tools are confirmed to report findings deterministically.
        assert isinstance(findings, list), f"findings is {type(findings).__name__}"
|
|
|
|
|
|
class TestAgentResult:
    """Tests for the AgentResult container."""

    def test_agent_result_success(self):
        """A successful result keeps its iteration and tool-call counters."""
        outcome = AgentResult(
            success=True,
            data={"findings": []},
            iterations=5,
            tool_calls=10,
        )

        assert outcome.success is True
        assert outcome.iterations == 5
        assert outcome.tool_calls == 10

    def test_agent_result_failure(self):
        """A failed result exposes the error message it was built with."""
        message = "Test error"
        outcome = AgentResult(success=False, error=message)

        assert outcome.success is False
        assert outcome.error == message

    def test_agent_result_to_dict(self):
        """to_dict() carries over the success flag and the iteration count."""
        serialized = AgentResult(
            success=True,
            data={"key": "value"},
            iterations=3,
        ).to_dict()

        assert serialized["success"] is True
        assert serialized["iterations"] == 3
|
|
|
|
|
|
class TestAgentConfig:
    """Tests for AgentConfig default and overridden fields."""

    def test_agent_config_defaults(self):
        """Fields that are not passed explicitly fall back to their defaults."""
        cfg = AgentConfig(name="Test", agent_type=AgentType.RECON)

        expected_defaults = {
            "pattern": AgentPattern.REACT,
            "max_iterations": 20,
            "temperature": 0.1,
        }
        for field_name, wanted in expected_defaults.items():
            assert getattr(cfg, field_name) == wanted, field_name

    def test_agent_config_custom(self):
        """Explicitly supplied fields override the defaults."""
        overrides = {
            "pattern": AgentPattern.PLAN_AND_EXECUTE,
            "max_iterations": 50,
            "temperature": 0.5,
        }
        cfg = AgentConfig(name="Custom", agent_type=AgentType.ANALYSIS, **overrides)

        for field_name, wanted in overrides.items():
            assert getattr(cfg, field_name) == wanted, field_name
|
|
|