CodeReview/backend/tests/agent/test_agents.py

214 lines
6.5 KiB
Python
Raw Permalink Normal View History

"""
Agent 单元测试
测试各个 Agent 的功能
"""
import pytest
import asyncio
import os
from unittest.mock import MagicMock, AsyncMock, patch
from app.services.agent.agents.base import BaseAgent, AgentConfig, AgentResult, AgentType, AgentPattern
from app.services.agent.agents.recon import ReconAgent
from app.services.agent.agents.analysis import AnalysisAgent
from app.services.agent.agents.verification import VerificationAgent
class TestReconAgent:
"""Recon Agent 测试"""
@pytest.fixture
def recon_agent(self, temp_project_dir, mock_llm_service, mock_event_emitter):
"""创建 Recon Agent 实例"""
from app.services.agent.tools import (
FileReadTool, FileSearchTool, ListFilesTool,
)
tools = {
"list_files": ListFilesTool(temp_project_dir),
"read_file": FileReadTool(temp_project_dir),
"search_code": FileSearchTool(temp_project_dir),
}
return ReconAgent(
llm_service=mock_llm_service,
tools=tools,
event_emitter=mock_event_emitter,
)
@pytest.mark.asyncio
async def test_recon_agent_run(self, recon_agent, temp_project_dir):
"""测试 Recon Agent 运行"""
result = await recon_agent.run({
"project_info": {
"name": "Test Project",
"root": temp_project_dir,
},
"config": {},
})
assert result.success is True
assert result.data is not None
# 验证返回数据结构
data = result.data
assert "tech_stack" in data
assert "entry_points" in data or "high_risk_areas" in data
@pytest.mark.asyncio
async def test_recon_agent_identifies_python(self, recon_agent, temp_project_dir):
"""测试 Recon Agent 识别 Python 技术栈"""
result = await recon_agent.run({
"project_info": {"root": temp_project_dir},
"config": {},
})
assert result.success is True
tech_stack = result.data.get("tech_stack", {})
languages = tech_stack.get("languages", [])
# 应该识别出 Python
assert "Python" in languages or len(languages) > 0
@pytest.mark.asyncio
async def test_recon_agent_finds_high_risk_areas(self, recon_agent, temp_project_dir):
"""测试 Recon Agent 发现高风险区域"""
result = await recon_agent.run({
"project_info": {"root": temp_project_dir},
"config": {},
})
assert result.success is True
high_risk_areas = result.data.get("high_risk_areas", [])
# 应该发现高风险区域
assert len(high_risk_areas) > 0
class TestAnalysisAgent:
"""Analysis Agent 测试"""
@pytest.fixture
def analysis_agent(self, temp_project_dir, mock_llm_service, mock_event_emitter):
"""创建 Analysis Agent 实例"""
from app.services.agent.tools import (
FileReadTool, FileSearchTool, PatternMatchTool,
)
tools = {
"read_file": FileReadTool(temp_project_dir),
"search_code": FileSearchTool(temp_project_dir),
"pattern_match": PatternMatchTool(temp_project_dir),
}
return AnalysisAgent(
llm_service=mock_llm_service,
tools=tools,
event_emitter=mock_event_emitter,
)
@pytest.mark.asyncio
async def test_analysis_agent_run(self, analysis_agent, temp_project_dir):
"""测试 Analysis Agent 运行"""
result = await analysis_agent.run({
"tech_stack": {"languages": ["Python"]},
"entry_points": [],
"high_risk_areas": ["src/sql_vuln.py", "src/cmd_vuln.py"],
"config": {},
})
assert result.success is True
assert result.data is not None
@pytest.mark.asyncio
async def test_analysis_agent_finds_vulnerabilities(self, analysis_agent, temp_project_dir):
"""测试 Analysis Agent 发现漏洞"""
result = await analysis_agent.run({
"tech_stack": {"languages": ["Python"]},
"entry_points": [],
"high_risk_areas": [
"src/sql_vuln.py",
"src/cmd_vuln.py",
"src/xss_vuln.py",
"src/secrets.py",
],
"config": {},
})
assert result.success is True
findings = result.data.get("findings", [])
# 应该发现一些漏洞
# 注意:具体数量取决于分析逻辑
assert isinstance(findings, list)
class TestAgentResult:
"""Agent 结果测试"""
def test_agent_result_success(self):
"""测试成功的 Agent 结果"""
result = AgentResult(
success=True,
data={"findings": []},
iterations=5,
tool_calls=10,
)
assert result.success is True
assert result.iterations == 5
assert result.tool_calls == 10
def test_agent_result_failure(self):
"""测试失败的 Agent 结果"""
result = AgentResult(
success=False,
error="Test error",
)
assert result.success is False
assert result.error == "Test error"
def test_agent_result_to_dict(self):
"""测试 Agent 结果转字典"""
result = AgentResult(
success=True,
data={"key": "value"},
iterations=3,
)
d = result.to_dict()
assert d["success"] is True
assert d["iterations"] == 3
class TestAgentConfig:
"""Agent 配置测试"""
def test_agent_config_defaults(self):
"""测试 Agent 配置默认值"""
config = AgentConfig(
name="Test",
agent_type=AgentType.RECON,
)
assert config.pattern == AgentPattern.REACT
assert config.max_iterations == 20
assert config.temperature == 0.1
def test_agent_config_custom(self):
"""测试自定义 Agent 配置"""
config = AgentConfig(
name="Custom",
agent_type=AgentType.ANALYSIS,
pattern=AgentPattern.PLAN_AND_EXECUTE,
max_iterations=50,
temperature=0.5,
)
assert config.pattern == AgentPattern.PLAN_AND_EXECUTE
assert config.max_iterations == 50
assert config.temperature == 0.5