feat(agent): 增强 RAG 配置和工具集成
- 扩展嵌入模型配置选项,支持独立 API Key 和 Base URL - 重构 RAG 初始化逻辑,支持用户自定义嵌入配置 - 新增语义搜索工具并集成到 Recon 和 Analysis Agent - 完善系统提示,明确不同代码搜索工具的使用场景
This commit is contained in:
parent
3bdbbf254b
commit
5f07403850
|
|
@ -296,12 +296,13 @@ async def _execute_agent_task(task_id: str):
|
|||
|
||||
# 初始化工具集 - 传递排除模式和目标文件以及预初始化的 sandbox_manager
|
||||
tools = await _initialize_tools(
|
||||
project_root,
|
||||
llm_service,
|
||||
project_root,
|
||||
llm_service,
|
||||
user_config,
|
||||
sandbox_manager=sandbox_manager,
|
||||
exclude_patterns=task.exclude_patterns,
|
||||
target_files=task.target_files,
|
||||
project_id=str(project.id), # 🔥 传递 project_id 用于 RAG
|
||||
)
|
||||
|
||||
# 创建子 Agent
|
||||
|
|
@ -552,15 +553,16 @@ async def _get_user_config(db: AsyncSession, user_id: Optional[str]) -> Optional
|
|||
|
||||
|
||||
async def _initialize_tools(
|
||||
project_root: str,
|
||||
llm_service,
|
||||
project_root: str,
|
||||
llm_service,
|
||||
user_config: Optional[Dict[str, Any]],
|
||||
sandbox_manager: Any, # 传递预初始化的 SandboxManager
|
||||
exclude_patterns: Optional[List[str]] = None,
|
||||
target_files: Optional[List[str]] = None,
|
||||
project_id: Optional[str] = None, # 🔥 用于 RAG collection_name
|
||||
) -> Dict[str, Dict[str, Any]]:
|
||||
"""初始化工具集
|
||||
|
||||
|
||||
Args:
|
||||
project_root: 项目根目录
|
||||
llm_service: LLM 服务
|
||||
|
|
@ -568,6 +570,7 @@ async def _initialize_tools(
|
|||
sandbox_manager: 沙箱管理器
|
||||
exclude_patterns: 排除模式列表
|
||||
target_files: 目标文件列表
|
||||
project_id: 项目 ID(用于 RAG collection_name)
|
||||
"""
|
||||
from app.services.agent.tools import (
|
||||
FileReadTool, FileSearchTool, ListFilesTool,
|
||||
|
|
@ -577,12 +580,82 @@ async def _initialize_tools(
|
|||
ThinkTool, ReflectTool,
|
||||
CreateVulnerabilityReportTool,
|
||||
VulnerabilityValidationTool,
|
||||
# 🔥 RAG 工具
|
||||
RAGQueryTool, SecurityCodeSearchTool, FunctionContextTool,
|
||||
)
|
||||
from app.services.agent.knowledge import (
|
||||
SecurityKnowledgeQueryTool,
|
||||
GetVulnerabilityKnowledgeTool,
|
||||
)
|
||||
|
||||
# 🔥 RAG 相关导入
|
||||
from app.services.rag import CodeIndexer, CodeRetriever, EmbeddingService
|
||||
from app.core.config import settings
|
||||
|
||||
# ============ 🔥 初始化 RAG 系统 ============
|
||||
retriever = None
|
||||
try:
|
||||
# 从用户配置中获取 embedding 配置
|
||||
user_llm_config = (user_config or {}).get('llmConfig', {})
|
||||
user_other_config = (user_config or {}).get('otherConfig', {})
|
||||
user_embedding_config = user_other_config.get('embedding_config', {})
|
||||
|
||||
# Embedding Provider 优先级:用户嵌入配置 > 环境变量
|
||||
embedding_provider = (
|
||||
user_embedding_config.get('provider') or
|
||||
getattr(settings, 'EMBEDDING_PROVIDER', 'openai')
|
||||
)
|
||||
|
||||
# Embedding Model 优先级:用户嵌入配置 > 环境变量
|
||||
embedding_model = (
|
||||
user_embedding_config.get('model') or
|
||||
getattr(settings, 'EMBEDDING_MODEL', 'text-embedding-3-small')
|
||||
)
|
||||
|
||||
# API Key 优先级:用户嵌入配置 > 环境变量 EMBEDDING_API_KEY > 用户 LLM 配置 > 环境变量 LLM_API_KEY
|
||||
# 注意:API Key 可以共享,因为很多用户使用同一个 OpenAI Key 做 LLM 和 Embedding
|
||||
embedding_api_key = (
|
||||
user_embedding_config.get('api_key') or
|
||||
getattr(settings, 'EMBEDDING_API_KEY', None) or
|
||||
user_llm_config.get('llmApiKey') or
|
||||
getattr(settings, 'LLM_API_KEY', '') or
|
||||
''
|
||||
)
|
||||
|
||||
# Base URL 优先级:用户嵌入配置 > 环境变量 EMBEDDING_BASE_URL > None(使用提供商默认地址)
|
||||
# 🔥 重要:Base URL 不应该回退到 LLM 的 base_url,因为 Embedding 和 LLM 可能使用完全不同的服务
|
||||
# 例如:LLM 使用 SiliconFlow,但 Embedding 使用 HuggingFace
|
||||
embedding_base_url = (
|
||||
user_embedding_config.get('base_url') or
|
||||
getattr(settings, 'EMBEDDING_BASE_URL', None) or
|
||||
None
|
||||
)
|
||||
|
||||
logger.info(f"RAG 配置: provider={embedding_provider}, model={embedding_model}, base_url={embedding_base_url or '(使用默认)'}")
|
||||
|
||||
# 创建 Embedding 服务
|
||||
embedding_service = EmbeddingService(
|
||||
provider=embedding_provider,
|
||||
model=embedding_model,
|
||||
api_key=embedding_api_key,
|
||||
base_url=embedding_base_url,
|
||||
)
|
||||
|
||||
# 创建 collection_name(基于 project_id)
|
||||
collection_name = f"project_{project_id}" if project_id else "default_project"
|
||||
|
||||
# 创建 CodeRetriever(用于搜索)
|
||||
retriever = CodeRetriever(
|
||||
collection_name=collection_name,
|
||||
embedding_service=embedding_service,
|
||||
persist_directory=settings.VECTOR_DB_PATH,
|
||||
)
|
||||
|
||||
logger.info(f"✅ RAG 系统初始化成功: collection={collection_name}")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"⚠️ RAG 系统初始化失败: {e}")
|
||||
retriever = None
|
||||
|
||||
# 基础工具 - 传递排除模式和目标文件
|
||||
base_tools = {
|
||||
"read_file": FileReadTool(project_root, exclude_patterns, target_files),
|
||||
|
|
@ -604,6 +677,11 @@ async def _initialize_tools(
|
|||
"trufflehog_scan": TruffleHogTool(project_root, sandbox_manager),
|
||||
"osv_scan": OSVScannerTool(project_root, sandbox_manager),
|
||||
}
|
||||
|
||||
# 🔥 注册 RAG 工具到 Recon Agent
|
||||
if retriever:
|
||||
recon_tools["rag_query"] = RAGQueryTool(retriever)
|
||||
logger.info("✅ RAG 工具 (rag_query) 已注册到 Recon Agent")
|
||||
|
||||
# Analysis 工具
|
||||
# 🔥 导入智能扫描工具
|
||||
|
|
@ -630,6 +708,15 @@ async def _initialize_tools(
|
|||
"query_security_knowledge": SecurityKnowledgeQueryTool(),
|
||||
"get_vulnerability_knowledge": GetVulnerabilityKnowledgeTool(),
|
||||
}
|
||||
|
||||
# 🔥 注册 RAG 工具到 Analysis Agent
|
||||
if retriever:
|
||||
analysis_tools["rag_query"] = RAGQueryTool(retriever)
|
||||
analysis_tools["security_search"] = SecurityCodeSearchTool(retriever)
|
||||
analysis_tools["function_context"] = FunctionContextTool(retriever)
|
||||
logger.info("✅ RAG 工具 (rag_query, security_search, function_context) 已注册到 Analysis Agent")
|
||||
else:
|
||||
logger.warning("⚠️ RAG 未初始化,rag_query/security_search/function_context 工具不可用")
|
||||
|
||||
# Verification 工具
|
||||
# 🔥 导入沙箱工具
|
||||
|
|
|
|||
|
|
@ -78,10 +78,12 @@ class Settings(BaseSettings):
|
|||
OUTPUT_LANGUAGE: str = "zh-CN"
|
||||
|
||||
# ============ Agent 模块配置 ============
|
||||
|
||||
# 嵌入模型配置
|
||||
EMBEDDING_PROVIDER: str = "openai" # openai, ollama, litellm
|
||||
|
||||
# 嵌入模型配置(独立于 LLM 配置)
|
||||
EMBEDDING_PROVIDER: str = "openai" # openai, azure, ollama, cohere, huggingface, jina
|
||||
EMBEDDING_MODEL: str = "text-embedding-3-small"
|
||||
EMBEDDING_API_KEY: Optional[str] = None # 嵌入模型专用 API Key(留空则使用 LLM_API_KEY)
|
||||
EMBEDDING_BASE_URL: Optional[str] = None # 嵌入模型专用 Base URL(留空使用提供商默认地址)
|
||||
|
||||
# 向量数据库配置
|
||||
VECTOR_DB_PATH: str = "./data/vector_db" # 向量数据库持久化目录
|
||||
|
|
|
|||
|
|
@ -142,30 +142,45 @@ class AgentRunner:
|
|||
async def _initialize_rag(self):
|
||||
"""初始化 RAG 系统"""
|
||||
await self.event_emitter.emit_info("📚 初始化 RAG 代码检索系统...")
|
||||
|
||||
|
||||
try:
|
||||
# 🔥 从用户配置中获取 LLM 配置(用于 Embedding API Key)
|
||||
# 优先级:用户配置 > 环境变量
|
||||
# 🔥 从用户配置中获取配置
|
||||
# 优先级:用户嵌入配置 > 用户 LLM 配置 > 环境变量
|
||||
user_llm_config = self.user_config.get('llmConfig', {})
|
||||
|
||||
# 获取 Embedding 配置(优先使用用户配置的 LLM API Key)
|
||||
embedding_provider = getattr(settings, 'EMBEDDING_PROVIDER', 'openai')
|
||||
embedding_model = getattr(settings, 'EMBEDDING_MODEL', 'text-embedding-3-small')
|
||||
|
||||
# 🔥 API Key 优先级:用户配置 > 环境变量
|
||||
user_other_config = self.user_config.get('otherConfig', {})
|
||||
user_embedding_config = user_other_config.get('embedding_config', {})
|
||||
|
||||
# 🔥 Embedding Provider 优先级:用户嵌入配置 > 环境变量
|
||||
embedding_provider = (
|
||||
user_embedding_config.get('provider') or
|
||||
getattr(settings, 'EMBEDDING_PROVIDER', 'openai')
|
||||
)
|
||||
|
||||
# 🔥 Embedding Model 优先级:用户嵌入配置 > 环境变量
|
||||
embedding_model = (
|
||||
user_embedding_config.get('model') or
|
||||
getattr(settings, 'EMBEDDING_MODEL', 'text-embedding-3-small')
|
||||
)
|
||||
|
||||
# 🔥 API Key 优先级:用户嵌入配置 > 用户 LLM 配置 > 环境变量
|
||||
embedding_api_key = (
|
||||
user_embedding_config.get('api_key') or
|
||||
user_llm_config.get('llmApiKey') or
|
||||
getattr(settings, 'LLM_API_KEY', '') or
|
||||
''
|
||||
)
|
||||
|
||||
# 🔥 Base URL 优先级:用户配置 > 环境变量
|
||||
|
||||
# 🔥 Base URL 优先级:用户嵌入配置 > 用户 LLM 配置 > 环境变量
|
||||
embedding_base_url = (
|
||||
user_embedding_config.get('base_url') or
|
||||
user_llm_config.get('llmBaseUrl') or
|
||||
getattr(settings, 'LLM_BASE_URL', None) or
|
||||
None
|
||||
)
|
||||
|
||||
|
||||
logger.info(f"RAG 配置: provider={embedding_provider}, model={embedding_model}")
|
||||
await self.event_emitter.emit_info(f"嵌入模型: {embedding_provider}/{embedding_model}")
|
||||
|
||||
embedding_service = EmbeddingService(
|
||||
provider=embedding_provider,
|
||||
model=embedding_model,
|
||||
|
|
@ -267,11 +282,14 @@ class AgentRunner:
|
|||
"safety_scan": SafetyTool(self.project_root, self.sandbox_manager),
|
||||
"npm_audit": NpmAuditTool(self.project_root, self.sandbox_manager),
|
||||
}
|
||||
|
||||
|
||||
# RAG 工具(Recon 用于语义搜索)
|
||||
if self.retriever:
|
||||
self.recon_tools["rag_query"] = RAGQueryTool(self.retriever)
|
||||
|
||||
logger.info("✅ RAG 工具已注册到 Recon Agent")
|
||||
else:
|
||||
logger.warning("⚠️ RAG 未初始化,rag_query 工具不可用")
|
||||
|
||||
# ============ Analysis Agent 专属工具 ============
|
||||
# 职责:漏洞分析、代码审计、模式匹配
|
||||
self.analysis_tools = {
|
||||
|
|
@ -300,11 +318,13 @@ class AgentRunner:
|
|||
"query_security_knowledge": SecurityKnowledgeQueryTool(),
|
||||
"get_vulnerability_knowledge": GetVulnerabilityKnowledgeTool(),
|
||||
}
|
||||
|
||||
|
||||
# RAG 工具(Analysis 用于安全相关代码搜索)
|
||||
if self.retriever:
|
||||
self.analysis_tools["security_search"] = SecurityCodeSearchTool(self.retriever)
|
||||
self.analysis_tools["function_context"] = FunctionContextTool(self.retriever)
|
||||
self.analysis_tools["rag_query"] = RAGQueryTool(self.retriever) # 通用语义搜索
|
||||
self.analysis_tools["security_search"] = SecurityCodeSearchTool(self.retriever) # 安全代码搜索
|
||||
self.analysis_tools["function_context"] = FunctionContextTool(self.retriever) # 函数上下文
|
||||
logger.info("✅ RAG 工具已注册到 Analysis Agent (rag_query, security_search, function_context)")
|
||||
|
||||
# ============ Verification Agent 专属工具 ============
|
||||
# 职责:漏洞验证、PoC 执行、误报排除
|
||||
|
|
|
|||
|
|
@ -142,11 +142,28 @@ TOOL_USAGE_GUIDE = """
|
|||
#### 辅助工具
|
||||
| 工具 | 用途 |
|
||||
|------|------|
|
||||
| `rag_query` | **语义搜索代码**(推荐!比 search_code 更智能,理解代码含义) |
|
||||
| `security_search` | **安全相关代码搜索**(专门查找安全敏感代码) |
|
||||
| `function_context` | **函数上下文搜索**(获取函数的调用关系和上下文) |
|
||||
| `list_files` | 了解项目结构 |
|
||||
| `read_file` | 读取文件内容验证发现 |
|
||||
| `search_code` | 搜索相关代码 |
|
||||
| `search_code` | 关键词搜索代码(精确匹配) |
|
||||
| `query_security_knowledge` | 查询安全知识库 |
|
||||
|
||||
### 🔍 代码搜索工具对比
|
||||
| 工具 | 特点 | 适用场景 |
|
||||
|------|------|---------|
|
||||
| `rag_query` | **语义搜索**,理解代码含义 | 查找"处理用户输入的函数"、"数据库查询逻辑" |
|
||||
| `security_search` | **安全专用搜索** | 查找"SQL注入相关代码"、"认证授权代码" |
|
||||
| `function_context` | **函数上下文** | 查找某函数的调用者和被调用者 |
|
||||
| `search_code` | **关键词搜索**,精确匹配 | 查找特定函数名、变量名、字符串 |
|
||||
|
||||
**推荐**:
|
||||
1. 查找安全相关代码时优先使用 `security_search`
|
||||
2. 理解函数关系时使用 `function_context`
|
||||
3. 通用语义搜索使用 `rag_query`
|
||||
4. 精确匹配时使用 `search_code`
|
||||
|
||||
### 📋 推荐分析流程
|
||||
|
||||
#### 第一步:快速侦察(5%时间)
|
||||
|
|
@ -156,6 +173,12 @@ Action Input: {"path": "."}
|
|||
```
|
||||
了解项目结构、技术栈、入口点
|
||||
|
||||
**语义搜索高风险代码(推荐!):**
|
||||
```
|
||||
Action: rag_query
|
||||
Action Input: {"query": "处理用户输入或执行数据库查询的函数", "top_k": 10}
|
||||
```
|
||||
|
||||
#### 第二步:外部工具全面扫描(60%时间)⚡重点!
|
||||
**根据技术栈选择对应工具,并行执行多个扫描:**
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue