feat(agent): 增强 RAG 配置和工具集成

- 扩展嵌入模型配置选项,支持独立 API Key 和 Base URL
- 重构 RAG 初始化逻辑,支持用户自定义嵌入配置
- 新增语义搜索工具并集成到 Recon 和 Analysis Agent
- 完善系统提示,明确不同代码搜索工具的使用场景
This commit is contained in:
lintsinghua 2025-12-16 13:57:27 +08:00
parent 3bdbbf254b
commit 5f07403850
4 changed files with 159 additions and 27 deletions

View File

@ -296,12 +296,13 @@ async def _execute_agent_task(task_id: str):
# 初始化工具集 - 传递排除模式和目标文件以及预初始化的 sandbox_manager
tools = await _initialize_tools(
project_root,
llm_service,
project_root,
llm_service,
user_config,
sandbox_manager=sandbox_manager,
exclude_patterns=task.exclude_patterns,
target_files=task.target_files,
project_id=str(project.id), # 🔥 传递 project_id 用于 RAG
)
# 创建子 Agent
@ -552,15 +553,16 @@ async def _get_user_config(db: AsyncSession, user_id: Optional[str]) -> Optional
async def _initialize_tools(
project_root: str,
llm_service,
project_root: str,
llm_service,
user_config: Optional[Dict[str, Any]],
sandbox_manager: Any, # 传递预初始化的 SandboxManager
exclude_patterns: Optional[List[str]] = None,
target_files: Optional[List[str]] = None,
project_id: Optional[str] = None, # 🔥 用于 RAG collection_name
) -> Dict[str, Dict[str, Any]]:
"""初始化工具集
Args:
project_root: 项目根目录
llm_service: LLM 服务
@ -568,6 +570,7 @@ async def _initialize_tools(
sandbox_manager: 沙箱管理器
exclude_patterns: 排除模式列表
target_files: 目标文件列表
project_id: 项目 ID用于 RAG collection_name
"""
from app.services.agent.tools import (
FileReadTool, FileSearchTool, ListFilesTool,
@ -577,12 +580,82 @@ async def _initialize_tools(
ThinkTool, ReflectTool,
CreateVulnerabilityReportTool,
VulnerabilityValidationTool,
# 🔥 RAG 工具
RAGQueryTool, SecurityCodeSearchTool, FunctionContextTool,
)
from app.services.agent.knowledge import (
SecurityKnowledgeQueryTool,
GetVulnerabilityKnowledgeTool,
)
# 🔥 RAG 相关导入
from app.services.rag import CodeIndexer, CodeRetriever, EmbeddingService
from app.core.config import settings
# ============ 🔥 初始化 RAG 系统 ============
retriever = None
try:
# 从用户配置中获取 embedding 配置
user_llm_config = (user_config or {}).get('llmConfig', {})
user_other_config = (user_config or {}).get('otherConfig', {})
user_embedding_config = user_other_config.get('embedding_config', {})
# Embedding Provider 优先级:用户嵌入配置 > 环境变量
embedding_provider = (
user_embedding_config.get('provider') or
getattr(settings, 'EMBEDDING_PROVIDER', 'openai')
)
# Embedding Model 优先级:用户嵌入配置 > 环境变量
embedding_model = (
user_embedding_config.get('model') or
getattr(settings, 'EMBEDDING_MODEL', 'text-embedding-3-small')
)
# API Key 优先级:用户嵌入配置 > 环境变量 EMBEDDING_API_KEY > 用户 LLM 配置 > 环境变量 LLM_API_KEY
# 注意API Key 可以共享,因为很多用户使用同一个 OpenAI Key 做 LLM 和 Embedding
embedding_api_key = (
user_embedding_config.get('api_key') or
getattr(settings, 'EMBEDDING_API_KEY', None) or
user_llm_config.get('llmApiKey') or
getattr(settings, 'LLM_API_KEY', '') or
''
)
# Base URL 优先级:用户嵌入配置 > 环境变量 EMBEDDING_BASE_URL > None使用提供商默认地址
# 🔥 重要Base URL 不应该回退到 LLM 的 base_url因为 Embedding 和 LLM 可能使用完全不同的服务
# 例如LLM 使用 SiliconFlow但 Embedding 使用 HuggingFace
embedding_base_url = (
user_embedding_config.get('base_url') or
getattr(settings, 'EMBEDDING_BASE_URL', None) or
None
)
logger.info(f"RAG 配置: provider={embedding_provider}, model={embedding_model}, base_url={embedding_base_url or '(使用默认)'}")
# 创建 Embedding 服务
embedding_service = EmbeddingService(
provider=embedding_provider,
model=embedding_model,
api_key=embedding_api_key,
base_url=embedding_base_url,
)
# 创建 collection_name基于 project_id
collection_name = f"project_{project_id}" if project_id else "default_project"
# 创建 CodeRetriever用于搜索
retriever = CodeRetriever(
collection_name=collection_name,
embedding_service=embedding_service,
persist_directory=settings.VECTOR_DB_PATH,
)
logger.info(f"✅ RAG 系统初始化成功: collection={collection_name}")
except Exception as e:
logger.warning(f"⚠️ RAG 系统初始化失败: {e}")
retriever = None
# 基础工具 - 传递排除模式和目标文件
base_tools = {
"read_file": FileReadTool(project_root, exclude_patterns, target_files),
@ -604,6 +677,11 @@ async def _initialize_tools(
"trufflehog_scan": TruffleHogTool(project_root, sandbox_manager),
"osv_scan": OSVScannerTool(project_root, sandbox_manager),
}
# 🔥 注册 RAG 工具到 Recon Agent
if retriever:
recon_tools["rag_query"] = RAGQueryTool(retriever)
logger.info("✅ RAG 工具 (rag_query) 已注册到 Recon Agent")
# Analysis 工具
# 🔥 导入智能扫描工具
@ -630,6 +708,15 @@ async def _initialize_tools(
"query_security_knowledge": SecurityKnowledgeQueryTool(),
"get_vulnerability_knowledge": GetVulnerabilityKnowledgeTool(),
}
# 🔥 注册 RAG 工具到 Analysis Agent
if retriever:
analysis_tools["rag_query"] = RAGQueryTool(retriever)
analysis_tools["security_search"] = SecurityCodeSearchTool(retriever)
analysis_tools["function_context"] = FunctionContextTool(retriever)
logger.info("✅ RAG 工具 (rag_query, security_search, function_context) 已注册到 Analysis Agent")
else:
logger.warning("⚠️ RAG 未初始化rag_query/security_search/function_context 工具不可用")
# Verification 工具
# 🔥 导入沙箱工具

View File

@ -78,10 +78,12 @@ class Settings(BaseSettings):
OUTPUT_LANGUAGE: str = "zh-CN"
# ============ Agent 模块配置 ============
# 嵌入模型配置
EMBEDDING_PROVIDER: str = "openai" # openai, ollama, litellm
# 嵌入模型配置(独立于 LLM 配置)
EMBEDDING_PROVIDER: str = "openai" # openai, azure, ollama, cohere, huggingface, jina
EMBEDDING_MODEL: str = "text-embedding-3-small"
EMBEDDING_API_KEY: Optional[str] = None # 嵌入模型专用 API Key留空则使用 LLM_API_KEY
EMBEDDING_BASE_URL: Optional[str] = None # 嵌入模型专用 Base URL留空使用提供商默认地址
# 向量数据库配置
VECTOR_DB_PATH: str = "./data/vector_db" # 向量数据库持久化目录

View File

@ -142,30 +142,45 @@ class AgentRunner:
async def _initialize_rag(self):
"""初始化 RAG 系统"""
await self.event_emitter.emit_info("📚 初始化 RAG 代码检索系统...")
try:
# 🔥 从用户配置中获取 LLM 配置(用于 Embedding API Key
# 优先级:用户配置 > 环境变量
# 🔥 从用户配置中获取配置
# 优先级:用户嵌入配置 > 用户 LLM 配置 > 环境变量
user_llm_config = self.user_config.get('llmConfig', {})
# 获取 Embedding 配置(优先使用用户配置的 LLM API Key
embedding_provider = getattr(settings, 'EMBEDDING_PROVIDER', 'openai')
embedding_model = getattr(settings, 'EMBEDDING_MODEL', 'text-embedding-3-small')
# 🔥 API Key 优先级:用户配置 > 环境变量
user_other_config = self.user_config.get('otherConfig', {})
user_embedding_config = user_other_config.get('embedding_config', {})
# 🔥 Embedding Provider 优先级:用户嵌入配置 > 环境变量
embedding_provider = (
user_embedding_config.get('provider') or
getattr(settings, 'EMBEDDING_PROVIDER', 'openai')
)
# 🔥 Embedding Model 优先级:用户嵌入配置 > 环境变量
embedding_model = (
user_embedding_config.get('model') or
getattr(settings, 'EMBEDDING_MODEL', 'text-embedding-3-small')
)
# 🔥 API Key 优先级:用户嵌入配置 > 用户 LLM 配置 > 环境变量
embedding_api_key = (
user_embedding_config.get('api_key') or
user_llm_config.get('llmApiKey') or
getattr(settings, 'LLM_API_KEY', '') or
''
)
# 🔥 Base URL 优先级:用户配置 > 环境变量
# 🔥 Base URL 优先级:用户嵌入配置 > 用户 LLM 配置 > 环境变量
embedding_base_url = (
user_embedding_config.get('base_url') or
user_llm_config.get('llmBaseUrl') or
getattr(settings, 'LLM_BASE_URL', None) or
None
)
logger.info(f"RAG 配置: provider={embedding_provider}, model={embedding_model}")
await self.event_emitter.emit_info(f"嵌入模型: {embedding_provider}/{embedding_model}")
embedding_service = EmbeddingService(
provider=embedding_provider,
model=embedding_model,
@ -267,11 +282,14 @@ class AgentRunner:
"safety_scan": SafetyTool(self.project_root, self.sandbox_manager),
"npm_audit": NpmAuditTool(self.project_root, self.sandbox_manager),
}
# RAG 工具Recon 用于语义搜索)
if self.retriever:
self.recon_tools["rag_query"] = RAGQueryTool(self.retriever)
logger.info("✅ RAG 工具已注册到 Recon Agent")
else:
logger.warning("⚠️ RAG 未初始化rag_query 工具不可用")
# ============ Analysis Agent 专属工具 ============
# 职责:漏洞分析、代码审计、模式匹配
self.analysis_tools = {
@ -300,11 +318,13 @@ class AgentRunner:
"query_security_knowledge": SecurityKnowledgeQueryTool(),
"get_vulnerability_knowledge": GetVulnerabilityKnowledgeTool(),
}
# RAG 工具Analysis 用于安全相关代码搜索)
if self.retriever:
self.analysis_tools["security_search"] = SecurityCodeSearchTool(self.retriever)
self.analysis_tools["function_context"] = FunctionContextTool(self.retriever)
self.analysis_tools["rag_query"] = RAGQueryTool(self.retriever) # 通用语义搜索
self.analysis_tools["security_search"] = SecurityCodeSearchTool(self.retriever) # 安全代码搜索
self.analysis_tools["function_context"] = FunctionContextTool(self.retriever) # 函数上下文
logger.info("✅ RAG 工具已注册到 Analysis Agent (rag_query, security_search, function_context)")
# ============ Verification Agent 专属工具 ============
# 职责漏洞验证、PoC 执行、误报排除

View File

@ -142,11 +142,28 @@ TOOL_USAGE_GUIDE = """
#### 辅助工具
| 工具 | 用途 |
|------|------|
| `rag_query` | **语义搜索代码**推荐 search_code 更智能理解代码含义 |
| `security_search` | **安全相关代码搜索**专门查找安全敏感代码 |
| `function_context` | **函数上下文搜索**获取函数的调用关系和上下文 |
| `list_files` | 了解项目结构 |
| `read_file` | 读取文件内容验证发现 |
| `search_code` | 搜索相关代码 |
| `search_code` | 关键词搜索代码精确匹配 |
| `query_security_knowledge` | 查询安全知识库 |
### 🔍 代码搜索工具对比
| 工具 | 特点 | 适用场景 |
|------|------|---------|
| `rag_query` | **语义搜索**理解代码含义 | 查找"处理用户输入的函数""数据库查询逻辑" |
| `security_search` | **安全专用搜索** | 查找"SQL注入相关代码""认证授权代码" |
| `function_context` | **函数上下文** | 查找某函数的调用者和被调用者 |
| `search_code` | **关键词搜索**精确匹配 | 查找特定函数名变量名字符串 |
**推荐**
1. 查找安全相关代码时优先使用 `security_search`
2. 理解函数关系时使用 `function_context`
3. 通用语义搜索使用 `rag_query`
4. 精确匹配时使用 `search_code`
### 📋 推荐分析流程
#### 第一步快速侦察5%时间)
@ -156,6 +173,12 @@ Action Input: {"path": "."}
```
了解项目结构技术栈入口点
**语义搜索高风险代码推荐**
```
Action: rag_query
Action Input: {"query": "处理用户输入或执行数据库查询的函数", "top_k": 10}
```
#### 第二步外部工具全面扫描60%时间)⚡重点!
**根据技术栈选择对应工具并行执行多个扫描**